本文主要利用k-近邻分类器实现手写识别系统,训练数据集大约2000个样本,每个数字大约有200个样本,每个样本保存在一个txt文件中,手写体图像本身是32X32的二值图像,如下图所示:
手写数字识别系统的测试代码:
from numpy import *
import operator
from os import listdir
#inX 要检测的数据
#dataSet 数据集
#labels 结果集
#k 要对比的长度
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0] #计算有多少行
# tile(inX, (dataSetSize,1))生成对应inX维度的矩阵,方便做差
diffMat = tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat**2 #差求平方
sqDistances = sqDiffMat.sum(axis=1) # axis=0, 表示列 axis=1, 表示行。
distances = sqDistances**0.5 #开方
sortedDistIndicies = distances.argsort() #argsort()排序,求下标
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]] #通过下标索引分类
# 通过构造字典,记录分类频数
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
# 对字段按值排序(从大到小)
sortedClassCount = sorted(classCount.items(),key=lambda classCount:classCount[1], reverse=True)
return sortedClassCount[0][0]
#手写字体识别
#首先,我们需要将图像格式化处理为一个向量,
# 把一个32X32的二进制图像矩阵通过img2vector()函数转换为1X1024的向量:
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32): #图片矩阵为32*32
lineStr = fr.readline() #数据量大,所以使用readline
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
#手写字体识别
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir(r‘trainingDigits‘) #指定文件夹
m = len(trainingFileList) #获取文件夹个数
trainingMat = zeros((m,1024)) #构造m个1024比较矩阵
for i in range(m):
fileNameStr = trainingFileList[i] #获取文件名
fileStr = fileNameStr.split(‘.‘)[0] #按点把文件名字分割
classNumStr = int(fileStr.split(‘_‘)[0]) #按下划线把文件名字分割
hwLabels.append(classNumStr) #实际值添加保存
trainingMat[i,:] = img2vector(r‘trainingDigits/%s‘ % fileNameStr)
testFileList = listdir(‘testDigits‘) #测试数据
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):#同上,处理测试数据
fileNameStr = testFileList[i]
fileStr = fileNameStr.split(‘.‘)[0] #take off .txt
classNumStr = int(fileStr.split(‘_‘)[0])
vectorUnderTest = img2vector(r‘testDigits/%s‘ % fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print ("计算值: %d, 实际值: %d" % (classifierResult, classNumStr))
if (classifierResult != classNumStr): errorCount += 1.0
print ("\n错误出现次数: %d" % errorCount)
print ("\n错误率: %f" % (errorCount/float(mTest)))
handwritingClassTest()
结果:
计算值: 9, 实际值: 9
计算值: 9, 实际值: 9
计算值: 9, 实际值: 9
计算值: 9, 实际值: 9
计算值: 9, 实际值: 9
计算值: 9, 实际值: 9
错误出现次数: 10
错误率: 0.010571
可以看到KNN算法对内存消耗很大(本人12G),中文环境识别不敢想象。