《统计学习方法》第三章,k 近邻法

? k 近邻法来分类,用到了 kd 树的建立和搜索

● 代码

  1 import numpy as np
  2 import matplotlib.pyplot as plt
  3 from mpl_toolkits.mplot3d import Axes3D
  4 from mpl_toolkits.mplot3d.art3d import Poly3DCollection
  5 from matplotlib.patches import Rectangle
  6 import operator
  7 import warnings
  8
  9 warnings.filterwarnings("ignore")
 10 dataSize = 10000
 11 trainRatio = 0.3
 12
 13 def dataSplit(x, y, part):                                                          # 将数据集按给定索引分为两段
 14     return x[:part], y[:part],x[part:],y[part:]
 15
 16 def myColor(x):                                                                     # 颜色函数,用于对散点染色
 17     r = np.select([x < 1/2, x < 3/4, x <= 1, True],[0, 4 * x - 2, 1, 0])
 18     g = np.select([x < 1/4, x < 3/4, x <= 1, True],[4 * x, 1, 4 - 4 * x, 0])
 19     b = np.select([x < 1/4, x < 1/2, x <= 1, True],[1, 2 - 4 * x, 0, 0])
 20     return [r**2,g**2,b**2]
 21
 22 def mold(x, y):                                                                     # 距离采用欧氏距离的平方
 23     return np.sum((x - y)**2)
 24
 25 def createData(dim, kind, count = dataSize):                                        # 创建数据集
 26     np.random.seed(103)
 27     X = np.random.rand(count, dim)
 28     center = np.random.rand(kind, dim)
 29     Y = [ chr(65 + np.argmin(np.sum((X[i] - center)**2, 1))) for i in range(count) ]
 30     #print(output)
 31     classCount = dict([ [chr(65 + i),0] for i in range(kind) ])
 32     for i in range(count):
 33         classCount[Y[i]] +=1
 34     print("dim = %d, kind = %d, dataSize = %d,"%(dim, kind, count))
 35     for i in range(kind):
 36         print("kind %c -> %4d"%(chr(65+i), classCount[chr(65+i)]))
 37     return X, np.array(Y)
 38
 39 def buildKdTree(dataX, dataY, dividDim):                            # 建立 kd 树,每个节点具有的成员有:
 40     count, dim = np.shape(dataX)                                    # count 总结点数,dividDim 根节点用来划分空间的坐标的序号
 41     if count == 0:                                                  # point 根节点坐标,kind 根节点类别
 42         return {‘count‘: 0}                                         # leftChild rightChild 左右子节点
 43     if count == 1:
 44         return {‘count‘: 1, ‘point‘: dataX[0], ‘kind‘: dataY[0]}    # 总结点只有 0 或者 1 时只有部分成员就够了
 45
 46     #print(count)                                                    # 调试用,显示当前节点情况
 47     index = np.lexsort((np.ones(count),dataX[:,dividDim]))          # 用 dataX 的值大小来给 dataX 和 dataY 排序,以便查找中位数、切割数据
 48     childDataX = dataX[index]
 49     childDataY = dataY[index]
 50     return {‘count‘: count, ‘index‘: dividDim, ‘point‘: childDataX[count>>1], ‘kind‘: dataY[count>>1],  51             ‘leftChild‘: buildKdTree(childDataX[:count>>1], childDataY[:count>>1], (dividDim + 1) % dim),  52             ‘rightChild‘: buildKdTree(childDataX[(count>>1) + 1:], childDataY[(count>>1) + 1:], (dividDim + 1) % dim)}
 53
 54 def findNearest(origin, nowTree, dividDim):                         # 搜索 kd 树,寻找最近邻点
 55     if nowTree[‘count‘] == 0:                                       # 空子树,返回一个极大的距离
 56         return np.inf, ‘?‘
 57     if nowTree[‘count‘] == 1:                                       # 单点子树,返回距离和类别
 58         return mold(origin, nowTree[‘point‘]), nowTree[‘kind‘]
 59
 60     dim = len(origin)
 61     moldCenter = mold(origin, nowTree[‘point‘])                                 # 母节点距离
 62
 63     if origin[dividDim] < nowTree[‘point‘][dividDim]:                           # 左支
 64         temp = findNearest(origin, nowTree[‘leftChild‘], (dividDim+1)%dim)
 65         if origin[dividDim] + temp[0] > nowTree[‘point‘][dividDim]:             # 穿透分界线,要算右边,最近点为母节点或新子节点
 66             temp = findNearest(origin, nowTree[‘rightChild‘], (dividDim+1)%dim) # 没穿分界线,不算右边,最近点在母节点或旧子节点
 67     else:                                                                       # 右支
 68         temp = findNearest(origin, nowTree[‘rightChild‘], (dividDim+1)%dim)
 69         if origin[dividDim] - temp[0] < nowTree[‘point‘][dividDim]:             # 穿透分界线,要算左边
 70             temp = findNearest(origin, nowTree[‘leftChild‘], (dividDim+1)%dim)  # 没穿分界线,不算左边
 71
 72     if moldCenter < temp[0]:                                                    # 所有分支的比较集中在母节点和挑出来的子节点之间
 73         return moldCenter, nowTree[‘kind‘]
 74     else:
 75         return temp
 76
 77 def vote(point, k, trainX, trainY):                                             # 计算所有距离,选取
 78     distance = np.sum((point - trainX)**2, 1)                                   # 计算
 79     queue = sorted(list(zip(distance[:k], trainY[:k])))                         # 取出前 k 项排好序
 80     for j in range(k, len(distance)):
 81         if distance[j] < queue[-1][0]:                                          # 每次有更优的点就把 queue 中最差的点替换掉,然后排序
 82             queue[-1] = (distance[j], trainY[j])
 83             queue.sort()
 84     kindCount = {}                                                              # 投票阶段
 85     for line in queue:
 86         if line[1] not in kindCount.keys():
 87             kindCount[line[1]] = 0
 88         kindCount[line[1]] += 1
 89     output = sorted(kindCount.items(),key = operator.itemgetter(1),reverse = True)
 90     return output[0][0]
 91
 92 def test(dim, kind, k):
 93     allX, allY = createData(dim, kind)
 94     trainX, trainY, testX, testY = dataSplit(allX, allY, int(dataSize * trainRatio))
 95     myResult = np.array([ ‘?‘ for i in range(len(testX)) ])         # 存放测试结果
 96
 97     if k == 1:                                                      # 一个最近邻时使用 kd 树,否则用正常的的计算距离排序
 98         tree = buildKdTree(trainX, trainY, 0)
 99         for i in range(len(testX)):                                 # 每次循环解决一个测试样本
100             myResult[i] = findNearest(testX[i], tree, 0)[1]
101     else:
102         if k > len(testX):
103             return None
104         for i in range(len(testX)):                                 # 每次循环解决一个测试样本
105             myResult[i] = vote(testX[i], k, trainX, trainY)
106
107     errorRatio = np.sum((myResult != np.array(testY)).astype(int)**2) / (dataSize * (1 - trainRatio))
108     print("k = %d, errorRatio = %4f\n"%(k, errorRatio))
109     if dim >= 4:                                                    # 4维以上不画图,只输出测试错误率
110         return
111
112     errorP = []                                                     # 分类错误的点
113     classP = [ [] for i in range(kind) ]                            # 正确分到各类的的点
114     for i in range(len(testX)):
115         if myResult[i] != testY[i]:
116             errorP.append(testX[i])
117         else:
118             classP[ord(myResult[i]) - 65].append(testX[i])
119     errorP = np.array(errorP)
120     classP = [ np.array(classP[i]) for i in range(kind) ]
121
122     fig = plt.figure(figsize=(10, 8))
123
124     if dim == 1:                                                    # 分不同属性维度画图
125         plt.xlim(-0.1, 1.1)
126         plt.ylim(-0.1, 1.1)
127         for i in range(kind):
128             plt.scatter(classP[i][:,0], np.ones(len(classP[i]))*i, color = myColor(i/kind), s = 8, label = "class" + str(i))
129         if len(errorP) != 0:
130             plt.scatter(errorP[:,0], (errorP[:,0] > 0.5).astype(int), color = myColor(1), s = 16, label = "errorData")
131         R = [ Rectangle((0,0),0,0, color = myColor(i/kind)) for i in range(kind) ] + [ Rectangle((0,0),0,0, color = myColor(1)) ]
132         plt.legend(R, [ "class" + chr(i+65) for i in range(kind) ] + ["errorData"], loc=[0.84, 0.012], ncol=1, numpoints=1, framealpha = 1)
133
134     if dim == 2:
135         plt.xlim(-0.1, 1.1)
136         plt.ylim(-0.1, 1.1)
137         for i in range(kind):
138             plt.scatter(classP[i][:,0], classP[i][:,1], color = myColor(i/kind), s = 8, label = "class" + str(i))
139         if len(errorP) != 0:
140             plt.scatter(errorP[:,0], errorP[:,1], color = myColor(1), s = 16, label = "errorData")
141         R = [ Rectangle((0,0),0,0, color = myColor(i/kind)) for i in range(kind) ] + [ Rectangle((0,0),0,0, color = myColor(1)) ]
142         plt.legend(R, [ "class" + chr(i+65) for i in range(kind) ] + ["errorData"], loc=[0.84, 0.012], ncol=1, numpoints=1, framealpha = 1)
143
144     if dim == 3:
145         ax = Axes3D(fig)
146         ax.set_xlim3d(-0.1, 1.1)
147         ax.set_ylim3d(-0.1, 1.1)
148         ax.set_zlim3d(-0.1, 1.1)
149         ax.set_xlabel(‘X‘, fontdict={‘size‘: 15, ‘color‘: ‘k‘})
150         ax.set_ylabel(‘Y‘, fontdict={‘size‘: 15, ‘color‘: ‘k‘})
151         ax.set_zlabel(‘Z‘, fontdict={‘size‘: 15, ‘color‘: ‘k‘})
152         for i in range(kind):
153             ax.scatter(classP[i][:,0], classP[i][:,1],classP[i][:,2], color = myColor(i/kind), s = 8, label = "class" + str(i))
154         if len(errorP) != 0:
155             ax.scatter(errorP[:,0], errorP[:,1],errorP[:,2], color = myColor(1), s = 16, label = "errorData")
156         R = [ Rectangle((0,0),0,0, color = myColor(i/kind)) for i in range(kind) ] + [ Rectangle((0,0),0,0, color = myColor(1)) ]
157         plt.legend(R, [ "class" + chr(i+65) for i in range(kind) ] + ["errorData"], loc=[0.85, 0.02], ncol=1, numpoints=1, framealpha = 1)
158
159     fig.savefig("R:\\dim" + str(dim) + "kind" + str(kind) + ".png")
160     plt.close()
161
162 if __name__ == ‘__main__‘:
163     test(2, 2, 1)
164     test(2, 3, 1)
165     test(3, 3, 1)
166     test(4, 3, 1)
167     test(2, 3, 2)
168     test(2, 4, 3)
169     test(3, 3, 2)
170     test(3, 4, 3)
171     test(4, 3, 2)
172     test(4, 4, 4)

● 输出结果

dim = 2, kind = 2, dataSize = 10000,
kind A -> 5301
kind B -> 4699
k = 1, errorRatio = 0.011143

dim = 2, kind = 3, dataSize = 10000,
kind A -> 2740
kind B -> 3197
kind C -> 4063
k = 1, errorRatio = 0.024714

dim = 3, kind = 3, dataSize = 10000,
kind A -> 3693
kind B -> 4232
kind C -> 2075
k = 1, errorRatio = 0.052571

dim = 4, kind = 3, dataSize = 10000,
kind A -> 2640
kind B -> 1765
kind C -> 5595
k = 1, errorRatio = 0.121000

dim = 2, kind = 3, dataSize = 10000,
kind A -> 2740
kind B -> 3197
kind C -> 4063
k = 2, errorRatio = 0.009857

dim = 2, kind = 4, dataSize = 10000,
kind A -> 2740
kind B -> 3000
kind C -> 2387
kind D -> 1873
k = 3, errorRatio = 0.013571

dim = 3, kind = 3, dataSize = 10000,
kind A -> 3693
kind B -> 4232
kind C -> 2075
k = 2, errorRatio = 0.028571

dim = 3, kind = 4, dataSize = 10000,
kind A -> 3029
kind B -> 3379
kind C ->  917
kind D -> 2675
k = 3, errorRatio = 0.038000

dim = 4, kind = 3, dataSize = 10000,
kind A -> 2640
kind B -> 1765
kind C -> 5595
k = 2, errorRatio = 0.062286

dim = 4, kind = 4, dataSize = 10000,
kind A -> 2472
kind B -> 1752
kind C -> 3365
kind D -> 2411
k = 4, errorRatio = 0.079429

● 画图

原文地址:https://www.cnblogs.com/cuancuancuanhao/p/11160291.html

时间: 2024-07-28 22:40:19

《统计学习方法》第三章,k 近邻法的相关文章

《统计学习方法》:第三章 K 近邻算法

k -- NN k--NN 是一种基本分类和回归方法.对新实例进行分类时,通过已经训练的数据求出 k 个最近实例,通过多数表决进行分类.故 k 邻近算法具有不显式的学习过程. 三个基本要素:k 值选择,距离度量,分类决策规则. 1. k 近邻算法 原理:给定一个训练集,对于新输入的实例,在训练集中找到与其相似的 k 个实例,这 k 个实例的多数属于某一类,就将该实例归属到这一类. 输入:训练数据集 \(T = \{(x_1,y_1),(x_2,y_2),...,(x_3,y_3)\}\) 其中,

第3章 K近邻法

参考: http://www.cnblogs.com/juefan/p/3807713.html http://blog.csdn.net/v_july_v/article/details/8203674/ http://www.cnblogs.com/imczxj/p/3941703.html

统计学习方法 笔记&lt;第一章&gt;

第一章 统计学习方法概述 1.1 统计学习 统计学习(statistical learning)是关于计算机基于数据概率模型并运用模型进行预测和分析的学科.统计学习也称为统计机器学习,现在人们提及的机器学习一般都是指统计机器学习. 统计学习的对象是数据(data),关于数据的基本假设是同类数据具有一定的统计规律性(前提):比如可以用随机变量描述数据中的特征,用概率分布描述数据的统计规律等. 统计学习的目的:对现有的数据进行分析,构建概率统计模型,分析和预测未知新数据,同时也需要考虑模型的复杂度以

统计学习方法 (第3章)K近邻法 学习笔记

第3章 K近邻法 k近邻算法简单.直观:给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的k个实例,这k个实例的多数属于某个类,就把该输入实例分为这个类.当K=1时,又称为最近邻算法,这时候就是将训练数据集中与x最邻近点作为x的类. 3.1 k近邻模型 模型由三个基本要素--距离度量.k值得选择.和分类决策规则决定. 3.1.1 距离度量 p=2时,称为欧式距离,p=1时,称为曼哈顿距离. 3.1.2 k值的选择 k 值的选择会对k 近邻法的结果产生重大影响.如果选择较小的k

统计学习方法与Python实现(二)——k近邻法

统计学习方法与Python实现(二)——k近邻法 iwehdio的博客园:https://www.cnblogs.com/iwehdio/ 1.定义 k近邻法假设给定一个训练数据集,其中的实例类别已定.分类时,对新的实例,根据其k个最近邻的训练实例的类别,通过多数表决的方式进行预测.k近邻法不具有显式的学习过程,而实际上是利用训练数据集对特征空间进行划分,并作为其分类的模型.k近邻法的三个基本要素是 k值的选择.距离度量和分类决策规则. k近邻法的模型是将特征空间划分成一些称为单元的子空间,并且

李航统计学习方法——算法2——k近邻法

一.K近邻算法 k近邻法(k-nearest neighbor,k-NN)是一种基本分类与回归方法,输入实例的特征向量,输出实例的类别,其中类别可取多类 二.k近邻模型 2.1 距离度量 距离定义: (1)当p=1,称为曼哈顿距离 (2)当p=2,称为欧式距离 (3)当p取无穷大时,它是各个坐标距离的最大值 max|xi-xj| 注意:p值的选择会影响分类结果,例如二维空间的三个点 x1=(1,1),x2=(5,1), x3=(4,4) 由于x1和x2只有第二维上不同,不管p值如何变化,Lp始终

K近邻法(KNN)原理小结

K近邻法(k-nearst neighbors,KNN)是一种很基本的机器学习方法了,在我们平常的生活中也会不自主的应用.比如,我们判断一个人的人品,只需要观察他来往最密切的几个人的人品好坏就可以得出了.这里就运用了KNN的思想.KNN方法既可以做分类,也可以做回归,这点和决策树算法相同. KNN做回归和分类的主要区别在于最后做预测时候的决策方式不同.KNN做分类预测时,一般是选择多数表决法,即训练集里和预测的样本特征最近的K个样本,预测为里面有最多类别数的类别.而KNN做回归时,一般是选择平均

scikit-learn K近邻法类库使用小结

在K近邻法(KNN)原理小结这篇文章,我们讨论了KNN的原理和优缺点,这里我们就从实践出发,对scikit-learn 中KNN相关的类库使用做一个小结.主要关注于类库调参时的一个经验总结. 一.scikit-learn 中KNN相关的类库概述 在scikit-learn 中,与近邻法这一大类相关的类库都在sklearn.neighbors包之中.KNN分类树的类是KNeighborsClassifier,KNN回归树的类是KNeighborsRegressor.除此之外,还有KNN的扩展,即限

3.K近邻法

1. k 近邻算法k近邻法(k-nearest neighbor, k-NN) 是一种基本分类与回归方法.  k近邻法的输入为实例的特征向量, 对应于特征空间的点: 输出为实例的类别, 可以取多类. k近邻法假设给定一个训练数据集, 其中的实例类别已定. 分类时, 对新的实例, 根据其k个最近邻的训练实例的类别, 通过多数表决等方式进行预测.因此, k近邻法不具有显式的学习过程. k近邻法实际上利用训练数据集对特征向量空间进行划分, 并作为其分类的“模型”. k值的选择. 距离度量及分类决策规则