程序主体:
- 以kNN算法为基础
- 增加了文件数据导入函数
- 增加了可视化操作
- 增加了算法错误率判定
1 # -*- coding:utf-8 -*- 2 from numpy import * 3 import operator 4 import math 5 import matplotlib 6 import matplotlib.pyplot as plt 7 import numpy as np 8 import random 9 import collections 10 11 def classify0(inX,dataSet,labels,k): 12 dataSetSize = dataSet.shape[0] 13 diffMat = tile(inX,(dataSetSize,1)) - dataSet 14 sqDiffMat = diffMat**2 15 print sqDiffMat 16 sqDistances = sqDiffMat.sum(axis = 1) 17 distances = sqDistances.argsort() 18 # sortedDistIndicies = distances.argsort() 19 classCount=collections.OrderedDict() 20 for i in range(k): 21 voteIlabel = labels[distances[i]] 22 classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 23 print classCount 24 sortedClassCount = sorted(classCount.iteritems(),key = operator.itemgetter(1),reverse = True) 25 print sortedClassCount 26 return sortedClassCount[0][0] 27 28 def file2matrix(filename): 29 fr = open(filename) 30 arrayOLines = fr.readlines() 31 numberOfLines = len(arrayOLines) 32 returnMat = zeros((numberOfLines,3)) 33 classLabelVector = [] 34 index = 0 35 for line in arrayOLines: 36 line = line.strip() 37 listFromLine = line.split(‘ ‘) 38 returnMat[index,:] = listFromLine[0:3] 39 classLabelVector.append(int(listFromLine[0])) 40 index+=1 41 return returnMat,classLabelVector 42 43 def autoNorm(dataSet): 44 minVals = dataSet.min(0) 45 maxVals = dataSet.max(0) 46 ranges = maxVals - minVals 47 normDataSet = zeros(shape(dataSet)) 48 m = dataSet.shape[0] 49 normDataSet = dataSet - tile(minVals,(m,1)) 50 normDataSet = normDataSet/tile(ranges,(m,1)) 51 return normDataSet,ranges,minVals 52 53 datingdataMat,datingLabels = file2matrix(‘/Users/tiemuer/PycharmProjects/kNN/data.txt‘) 54 fig = plt.figure() 55 ax = fig.add_subplot(111) 56 ax.scatter(datingdataMat[:,0],datingdataMat[:,1],15.0*array(datingLabels),15.0*array(datingLabels)) 57 plt.show() 58 59 def datingClassTest(): 60 hoRatio = 0.10 61 datingdataMat,datingLabels = file2matrix(‘/Users/tiemuer/PycharmProjects/kNN/data.txt‘) 62 normMat,ranges,minVals = autoNorm(datingdataMat) 63 m = normMat.shape[0] 64 numTestVecs = int(m*hoRatio) 65 errorCount = 0.0 66 n = [] 67 for i in range(numTestVecs): 68 n.append(int(random.randint(0,m))) 69 for i in n: 70 classifierResult = classify0(normMat[i,:],normMat[:m,:],datingLabels[:m],3) 71 print "the classifier cameback with %d,the real answer is:%d"%(classifierResult,datingLabels[i]) 72 if(classifierResult!=datingLabels[i]):errorCount+=1.0 73 print "the total error rate is %f"%(errorCount/float(numTestVecs)) 74 75 76 datingClassTest()
算法改进
- 由于需要不同,书籍上的代码并不能很好契合程序,故作出一些改进。
- 文件数据导入函数中classLabelVector.append(int(listFromLine[0])),其中把listFromLine[-1]变为[0],因为本程序用每组数据第一个位置来做label,这样才能与后面寻找label契合。
- 本程序选取数据集中随机10%的数据来与原数据集进行匹配,因为数据集不够,且数据集完全随机生成,进而发现问题:对于小数据来说很难做到错误率低。主要原因在于对于相近邻的数据统计次数过少,且普通字典是无序排放,导致不能选出最契合的label,本人针对此问题选用有序字典进行label存储,可以进一步改善算法准确率。
- 上一篇文章的kNN算法需要改动,删掉本程序中注释掉的代码(多进行了一次排序导致程序混乱,得不到期望结果)。
- 由于各个数据项特征值大小不同,而它们在程序中的地位同等,故用autoNorm()函数来调整大小:(data-mindata)/(maxdata-mindata)把特征值统一到0-1之间,更具参考性。
初学python,记录学习笔记
- 字典访问需要key,而key的获取可以用dic.keys()来得到一个list
- 对字典进行排序时,利用迭代器方法sortedClassCount = sorted(classCount.iteritems(),key = operator.itemgetter(1),reverse = True),reverse = True 进行升序排序,返回的不是字典,而可以通过下标访问
- open(filename)打开文件,读文件用readlines(),下面记录下read(),readline(),readlines()函数的区别:
- read():读取整个文件,文件内容存放于字符串变量中
- readline():读取文件一行,存放于字符串对象中,比readlines()慢许多
- readlines():读取整个文件,自动把文件分成一个行的列表
- 普通字典是无序的,故不能利用存放次序来寻找字典元素,可以借助collections库中的collections.OrderedDict()来创建一个有序字典。
- strip() 方法用于移除字符串头尾指定的字符(默认为空格)。例如:str.strip([chars])
- split()通过指定分隔符对字符串进行切片,如果参数num 有指定值,则仅分隔 num 个子字符串。例如:split(‘ ‘,1)仅对于空格分割一次。
- 这里不整理可视化作图操作,回头找个时间专门做个整理。
时间: 2024-10-20 05:35:04