继续之前的写。
三、对单个样本进行分类。
''' function: classify the input sample by voting from its K nearest neighbor input: 1. the input feature vector 2. the feature matrix 3. the label list 4. the value of k return: the result label ''' def ClassifySampleByKNN(featureVectorIn, featureMatrix, labelList, kValue): # calculate the distance between feature input vector and the feature matrix disValArray = CalcEucDistance(featureVectorIn,featureMatrix) # sort and return the index theIndexListOfSortedDist = disValArray.argsort() # consider the first k index, vote for the label labelAndCount = {} for i in range(kValue): theLabelIndex = theIndexListOfSortedDist[i] theLabel = labelList[theLabelIndex] labelAndCount[theLabel] = labelAndCount.get(theLabel,0) + 1 sortedLabelAndCount = sorted(labelAndCount.iteritems(), key=lambda x:x[1], reverse=True) return sortedLabelAndCount[0][0]
基本思路就是,首先计算输入样本和训练样本集合的欧氏距离,然后根据距离进行排序,选择距离最小的k个样本,用这些样本对应的标签进行投票,票数最多的标签就是输入样本所对应的标签。
比较有特色的写法是这一句:
# sort and return the index theIndexListOfSortedDist = disValArray.argsort()
disValArray是numpy的一维数组,存储的仅仅是欧式距离的值。argsort直接对这些值进行排序,并且把排序结果所对应的原索引返回回来。很方便。另外一句是sorted函数的调用,按照value来对字典进行排序,用到了函数式编程的lambda表达式。这个用operator也能达到同样的目的。
四、对测试样本文件进行分类,并统计错误率
''' function: classify the samples in test file by KNN algorithm input: 1. the name of training sample file 2. the name of testing sample file 3. the K value for KNN 4. the name of log file ''' def ClassifySampleFileByKNN(sampleFileNameForTrain, sampleFileNameForTest, kValue, logFileName): logFile = open(logFileName,'w') # load the feature matrix and normailize them feaMatTrain, labelListTrain = LoadFeatureMatrixAndLabels(sampleFileNameForTrain) norFeaMatTrain = AutoNormalizeFeatureMatrix(feaMatTrain) feaMatTest, labelListTest = LoadFeatureMatrixAndLabels(sampleFileNameForTest) norFeaMatTest = AutoNormalizeFeatureMatrix(feaMatTest) # classify the test sample and write the result into log errorNumber = 0.0 testSampleNum = norFeaMatTest.shape[0] for i in range(testSampleNum): label = ClassifySampleByKNN(norFeaMatTest[i,:],norFeaMatTrain,labelListTrain,kValue) if label == labelListTest[i]: logFile.write("%d:right\n"%i) else: logFile.write("%d:wrong\n"%i) errorNumber += 1 errorRate = errorNumber / testSampleNum logFile.write("the error rate: %f" %errorRate) logFile.close() return
代码挺多,不过逻辑上就很简单了。没什么好说的。另外,不知道python中的命名是什么习惯?我发现如果完全把变量名字展开,太长了——我的macbook pro显示起来太难看。这里就沿用c/c++的变量简写命名方式了。
五、入口调用函数
类似c/c++的main函数。只要运行kNN.py这个脚本,就会先执行这一段代码:
if __name__ == '__main__': print "You are running KNN.py" ClassifySampleFileByKNN('datingSetOne.txt','datingSetTwo.txt',3,'log.txt')
kNN中的k值我选择的是3。
未完,待续。
时间: 2024-10-31 05:00:47