一篇文章,带你明白什么是过拟合,欠拟合以及交叉验证

误差模型:过拟合,交叉验证,偏差-方差权衡

作者Natasha Latysheva;Charles Ravarani

发表于cambridgecoding


介绍

??在本文中也许你会掌握机器学习中最核心的概念:偏差-方差权衡.其主要想法是,你想创建尽可能预测准确并且仍能适用于新数据的模型(这是泛化).危险的是,你可以轻松的在你制定的数据中创建过度拟合本地噪音的模型,这样的模型是无用的,并且导致弱泛化能力,因为噪声是随机的,故而在每个数据集中是不同的.从本质上讲,你希望创建仅捕获数据集中有用成份的模型.另一方面,泛化能力很好但是对于产生良好预测过于僵化的模型是另一个极端(这称之为欠拟合).

??我们使用k-近邻算法讨论并展示这些概念,k-近邻带有一个简单的参数k,可以用不同的参数清楚的展示欠拟合,过拟合以及泛化能力的思想.同时,平衡欠拟合和过拟合之间的相关概念称为偏差-方差权衡.这里有一个表格概括了无论是过拟合或者欠拟合模型中一些不同但相同

??我们将解释这些术语的意思,以及他们如何关联的.同样也会讨论交叉验证,这是评估模型准确率和泛化能力的优秀指标.

??你会在未来的所有博文中遇到这些概念,将涵盖模型优化,随机森林,朴素贝叶斯,逻辑回归以及如何将不同模型组合成为集成元模型.

产生数据

??让我们从建立人工数据集开始.你可以轻松的使用sklearn.datasets中的make_classification()函数做到这一点.具体来说,你会生成相对简单的二元分类问题.为了让它更有趣一点,让我们的数据呈现月牙型并加入一些随机噪声.这应该能让其更真实并提高分类观测的难度.

“`

Creating the dataset

e.g. make_moons generates crescent-shaped data

Check out make_classification, which generates linearly-separable data

from sklearn.datasets import make_moons

X, y = make_moons(

n_samples=500, # the number of observations

random_state=1,

noise=0.3

)

Take a peek

print(X[:10,])

print(y[:10])

“`

[[ 0.50316464 0.11135559]

[ 1.06597837 -0.63035547]

[ 0.95663377 0.58199637]

[ 0.33961202 0.40713937]

[ 2.17952333 -0.08488181]

[ 2.00520942 0.7817976 ]

[ 0.12531776 -0.14925731]

[ 1.06990641 0.36447753]

[-0.76391099 -0.6136396 ]

[ 0.55678871 0.8810501 ]]

[1 1 0 0 1 1 1 0 0 0]

??你刚生成的数据集如下图所示:

“`

import matplotlib.pyplot as plt

from matplotlib.colors import ListedColorma

%matplotlib inline # for the plots to appear inline in jupyter notebooks

Plot the first feature against the other, color by class

plt.scatter(X[y == 1, 0], X[y == 1, 1], color=”#EE3D34”, marker=”x”)

plt.scatter(X[y == 0, 0], X[y == 0, 1], color=”#4458A7”, marker=”o”)

“`

<\center>

??接下来,让我们将数据且分为训练集测试集 .训练集用于开发和优化模型.测试集完全分离,直到最后在此运行完成的模型.拥有测试集允许你在之前看不到的数据之外,模型运行良好的估计.

“`

from sklearn.cross_validation import train_test_split

Split into training and test sets

XTrain, XTest, yTrain, yTest = train_test_split(X, y, random_state=1, test_size=0.5)

“`

??使用K近邻(KNN)分类器预测数据集类别.Introduction to Statistical Learning第二章提供了关于KNN理论非常好介绍.我是ISLR书的脑残粉.你同样可以看看之前文章 how to implement the algorithm from scratch in Python.

介绍KNN中的超参数K

??KNN算法的工作原理是,对新数据点利用K近邻信息分配类别标签.只注重于和它最相似数据点的类,并分配这个新数据点到这些近邻中最常见的类.当使用KNN,你需要设定希望算法使用的K值.

??如果K很高(k=99),模型在对未知数据点类别做决策是会考虑大量近邻.这意味着模型是相当受限的,因为它分类实例时,考虑了大量信息.换句话说,一个大的k值导致相当”刚性”的模型行为.

??相反,如果k很低(k=1,或k=2),在做分类决策时只考虑少量近邻,这是非常灵活并且非常复杂的模型,它能完美拟合数据的精确形式.因此模型预测更依赖于数据的局部趋势(关键的是,包含噪声).

??让我们看一看k=99与k=1时KNN算法分类数据的情况.绿色的线是训练数据的决策边界(算法中的阈值决定一个数据点是否属于蓝或红类).

??在本文最后你会学会如何生成这些图像,但是先让我们先深入理论.

??当k=99(左),看起来模型拟合有点太平滑,对于有点接近的数据可以忍受.模型具有低灵活性低复杂度 .它描绘了一个笼统的决策边界.它具有比较高的偏差 ,因为对数据建模并不好,模型化数据的底层生成过程太过简单,并且偏离了事实.但是,如果你扔到另一个稍微不同的数据集,决策边界可能看起来非常相似.这是不会有非常大差异的稳定模型–它具有低方差.

??当k=1(右侧),你可以看到模型过度拟合噪声.从技术上来说,在训练集生成非常完美的预测结果(在右下角的错误等于0.0),但是希望你可以看到这样的拟合方式对于单独数据点过于敏感.牢记你在数据集中添加了噪声.看起来模型拟合对噪声太过重视并且拟合的非常紧密.你可以说,k=1的模型具有高灵活性高复杂度 ,因为它对数据调优非常紧密.同样具有低偏差,如果不出意外,决策边界肯定适合你观测数据的趋势.但是,在稍微改变的数据上,拟合的边界会大大改变,这将是非常显著的.K=1的模型具有高方差 .

??但是模型的泛化能力如何?在新数据上表现如何?

??目前你只能看到训练数据,但是量化训练误差没多大用处.对模型概括刚学习的训练集性能有多好,你不感兴趣.让我们看看在测试集表现如何,因为这会对模型好坏给你一个更直观的印象.试着使用不同的K值:

from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics
knn99 = KNeighborsClassifier(n_neighbors = 99)
knn99.fit(XTrain, yTrain)
yPredK99 = knn99.predict(XTest)
print "Overall Error of k=99 Model:", 1 - round(metrics.accuracy_score(yTest, yPredK99), 2)
knn1 = KNeighborsClassifier(n_neighbors = 1)
knn1.fit(XTrain, yTrain)
yPredK1 = knn1.predict(XTest)
print "Overall Error of k=1 Model:", 1 - round(metrics.accuracy_score(yTest, yPredK1), 2)

Overall Error of k=99 Model: 0.15

Overall Error of k=1 Model: 0.15

??实际上,看起来这些模型对测试集表现的大约同样出色.下面是通过训练集学习到的决策边界应用于测试集.看能否找出两个模型错误的预测.

??两个模型出错有不同的原因.看起来k=99的模型对捕获月牙形数据特征方面表现不是很好(这是欠拟合),而k=1的模型是对噪声严重的过拟合.记住,过拟合的特点是良好的训练表现和糟糕的测试表现,你能在这里观察到这些.

??也许k在1到99的中间值是你想要的?

knn50 = KNeighborsClassifier(n_neighbors = 50)
knn50.fit(XTrain, yTrain)
yPredK50 = knn50.predict(XTest)
print "Overall Error of k=50 Model:", 1 - round(metrics.accuracy_score(yTest, yPredK50), 2)

Overall Error of k=50 Model: 0.11

??看起来好了点.让我们检查k=50时模型的决策边界.

??不错!模型拟合类似数据集的实际趋势,这种改善体现在较低的测试误差.

偏差-方差权衡:结论意见

??希望你现在对模型的欠拟合和过拟合有良好的理解.看现在是否理解本文开头的所有术语.基本上,发现过拟合和欠拟合之间正确的平衡关系相当于偏差-方差权衡.

??总的来说,当你对一个数据集训练机器学习算法,关注模型在一个独立数据模型的表现如何.对于训练集做好分类是不够的.本质上来讲,只关心构建可泛化的模型–对于训练集获得100%的准确率并不令人印象深刻,仅仅是过拟合的指标.过拟合是紧密拟合模型,并且调优噪声而不是信号的情况.

??更清楚的讲,你不是建模数据集中的趋势.而是尝试建模真实世界过程,引导我们研究数据.你恰好使用的具体数据集只是基础事实的一小部分实例,其中包含噪声和自身的特点.

??下列汇总图片展示在训练集和测试集上欠拟合(高偏差,低方差),正确拟合,以及过拟合(低偏差,高方差)模型如何表现:

??建立泛化模型这种想法背后的动机是切分数据集为为一个训练集和测试集(在你分析的最后提供模型性能的准确测量).

??但是,它也有可能过拟合测试数据.如果你对测试集尝试许多不同模型,并为了追求精度不断改变它们,然后测试集的信息可能不经意地渗入到模型创建阶段.你需要一个办法解决.

使用K折交叉验证评估模型性能

??输入K折交叉验证,这是仅使用训练集衡量模型性能的一个方便技术.该过程如下:你随机划分训练集为k等份;然后,我们在k-1/k的训练集上训练数据;对剩下的一部分评估性能.这给你一些模型性能的指标(如,整体精度).接下来训练在不同的k-1/k训练集训练算法,并在剩下的1部分评估.你重复这个过程k次,得到k个不同的模型性能度量,利用这些值的平均值得到整体性能的度量.继续例子,10折交叉验证背后如下:

??你可以使用k折交叉验证获得模型精度的评估,同样可以利用这些估计调整你的模型直到令你满意.这使得你不用最后测试数据,因此避免了过拟合的危险.换句话说,交叉验证提供一种方式模拟比你实际拥有更多的数据,因此你不用建模最后才使用测试集.k折交叉验证以及其变种是非常流行并且非常有用,尤其你尝试许多不同的模型(如果你想测试不同参数模型性能如何).

比较训练误差,交叉验证误差和测试误差

??那么,什么k是最佳的?对训练数据构建模形式尝试不同K值,看对训练集本身和测试集预测类别的结果模型如何.最后看K折交叉验证如何支出最好的K.

??注:实践中,当扫描这样的参数,使用训练集测试模型是以个糟糕的主意.相同的方式,你不能使用测试集多次浏览一个参数(每个参数值一次).接下来,你是用这些计算只是作为例子.实践中,只有K折交叉验证是一种安全的方法!

import numpy as np
from sklearn.cross_validation import train_test_split, cross_val_score

knn = KNeighborsClassifier()

# the range of number of neighbors you want to test

n_neighbors = np.arange(1, 141, 2)

# here you store the models for each dataset used

train_scores = list()
test_scores = list()
cv_scores = list()

# loop through possible n_neighbors and try them out

for n in n_neighbors:
knn.n_neighbors = n
knn.fit(XTrain, yTrain)
train_scores.append(1 - metrics.accuracy_score(yTrain, knn.predict(XTrain))) # this will over-estimate the accuracy
test_scores.append(1 - metrics.accuracy_score(yTest, knn.predict(XTest)))
cv_scores.append(1 - cross_val_score(knn, XTrain, yTrain, cv = 10).mean()) # you take the mean of the CV scores

??那么最优的k是多少?当多个同样的预测误差,你随便挑一个最小的作为k值.


# what do these different datasets think is the best value of k?

print(
‘The best values of k are: n‘
‘{} according to the Training Setn‘
‘{} according to the Test Set andn‘
‘{} according to Cross-Validation‘.format(
min(n_neighbors[train_scores == min(train_scores)]),
min(n_neighbors[test_scores == min(test_scores)]),
min(n_neighbors[cv_scores == min(cv_scores)])
)
)

最优K是:

1 according to the Training Set

23 according to the Test Set and

11 according to Cross-Validation

??不仅仅是收集最优的k,还需要对一系列测试的K看看预测误差.


# let‘s plot the error you get with different values of k

plt.figure(figsize=(10,7.5))
plt.plot(n_neighbors, train_scores, c="black", label="Training Set")
plt.plot(n_neighbors, test_scores, c="black", linestyle="--", label="Test Set")
plt.plot(n_neighbors, cv_scores, c="green", label="Cross-Validation")
plt.xlabel(‘Number of K Nearest Neighbors‘)
plt.ylabel(‘Classification Error‘)
plt.gca().invert_xaxis()
plt.legend(loc = "lower left")
plt.show()

??让我们谈谈训练集的分类错误.你考虑少量近邻,训练集会得到低的预测误差.这是有道理的,因为在做新的分类是,逼近每个点只考虑它本身的情况.测试误差遵循类似的轨迹,但是在某个点后由于过拟合而增长.这种现象表明,构建的训练集模型拟合在指定测试集样本上建模效果不好.

??在该图中可以看到,尤其是对于k的低值,采用k折交叉验证突出参数空间的区域(即k的非常低的值),这是非常容易出现过拟合的。尽管交叉验证和测试集的评估导致一些不同的最优解,它们都是相当不错的,并且大致正确。你也可以看到,交叉验证是测试误差的合理估计。这种类型的情节是好的,以获得确定参数如何影响模型表现的良好感觉,并帮助建立数据集的直觉来学习。

代码展示

??这是生成以上所有图片,训练测试不同kNN算法的代码.代码是scikit-learn样例改编的代码,主要处理决策边界的计算并让图片好看.

包含机器学习中拆分数据集,算法拟合以及测试的部分。

def detect_plot_dimension(X, h=0.02, b=0.05):
x_min, x_max = X[:, 0].min() - b, X[:, 0].max() + b
y_min, y_max = X[:, 1].min() - b, X[:, 1].max() + b
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
dimension = xx, yy
return dimension

def detect_decision_boundary(dimension, model):
xx, yy = dimension # unpack the dimensions
boundary = model.predict(np.c_[xx.ravel(), yy.ravel()])
boundary = boundary.reshape(xx.shape) # Put the result into a color plot
return boundary

def plot_decision_boundary(panel, dimension, boundary, colors=[‘#DADDED‘, ‘#FBD8D8‘]):
xx, yy = dimension # unpack the dimensions
panel.contourf(xx, yy, boundary, cmap=ListedColormap(colors), alpha=1)
panel.contour(xx, yy, boundary, colors="g", alpha=1, linewidths=0.5) # the decision boundary in green

def plot_dataset(panel, X, y, colors=["#EE3D34", "#4458A7"], markers=["x", "o"]):
panel.scatter(X[y == 1, 0], X[y == 1, 1], color=colors[0], marker=markers[0])
panel.scatter(X[y == 0, 0], X[y == 0, 1], color=colors[1], marker=markers[1])

def calculate_prediction_error(model, X, y):
yPred = model.predict(X)
score = 1 - round(metrics.accuracy_score(y, yPred), 2)
return score

def plot_prediction_error(panel, dimension, score, b=.3):
xx, yy = dimension # unpack the dimensions
panel.text(xx.max() - b, yy.min() + b, (‘%.2f‘ % score).lstrip(‘0‘), size=15, horizontalalignment=‘right‘)

def explore_fitting_boundaries(model, n_neighbors, datasets, width):

# determine the height of the plot given the aspect ration of each panel should be equal

height = float(width)/len(n_neighbors) * len(datasets.keys())

nrows = len(datasets.keys())
ncols = len(n_neighbors)

# set up the plot

figure, axes = plt.subplots(
nrows,
ncols,
figsize=(width, height),
sharex=True,
sharey=True
)

dimension = detect_plot_dimension(X, h=0.02) # the dimension each subplot based on the data

# Plotting the dataset and decision boundaries

i = 0
for n in n_neighbors:
model.n_neighbors = n
model.fit(datasets["Training Set"][0], datasets["Training Set"][1])
boundary = detect_decision_boundary(dimension, model)
j = 0
for d in datasets.keys():
try:
panel = axes[j, i]
except (TypeError, IndexError):
if (nrows * ncols) == 1:
panel = axes
elif nrows == 1: # if you only have one dataset
panel = axes[i]
elif ncols == 1: # if you only try one number of neighbors
panel = axes[j]
plot_decision_boundary(panel, dimension, boundary) # plot the decision boundary
plot_dataset(panel, X=datasets[d][0], y=datasets[d][1]) # plot the observations
score = calculate_prediction_error(model, X=datasets[d][0], y=datasets[d][1])
plot_prediction_error(panel, dimension, score, b=0.2) # plot the score

# make compacted layout

panel.set_frame_on(False)
panel.set_xticks([])
panel.set_yticks([])

# format the axis labels

if i == 0:
panel.set_ylabel(d)
if j == 0:
panel.set_title(‘k={}‘.format(n))
j += 1
i += 1

plt.subplots_adjust(hspace=0, wspace=0) # make compacted layout

??然后,你可以这样运行代码:


# specify the model and settings

model = KNeighborsClassifier()
n_neighbors = [200, 99, 50, 23, 11, 1]
datasets = {
"Training Set": [XTrain, yTrain],
"Test Set": [XTest, yTest]
}
width = 20

# explore_fitting_boundaries(model, n_neighbors, datasets, width)

explore_fitting_boundaries(model=model, n_neighbors=n_neighbors, datasets=datasets, width=width)

结论

??偏差-方差权衡出现在机器学习的不同领域.所有算法都可以认为具有一定弹性,而且不仅仅是KNN.发现描述良好数据模式并且可以泛化新数据,这样灵活的最佳点的目标适用于基本上所有算法.

时间: 2024-10-10 06:52:11

一篇文章,带你明白什么是过拟合,欠拟合以及交叉验证的相关文章

PDF怎么拆分成多个PDF,看完这篇文章你就明白了

PDF文件对于每一个经常在职场上工作的人来说,是特别常见的一个文档格式,PDF格式深受人们的喜爱,因为是特别好用的,但同时也是比较难进行编辑和修改的,特别是遇到PDF文档过长,为了方便浏览和及时查找对我们有用的内容,这就需要将PDF文档拆分成多个PDF,那么PDF怎么拆分成多个PDF?通过今天的文章就来告诉大家PDF文档拆分的方法,看完这篇文章你就明白了,那么我们就一起来看看吧.?方法一:软件拆分法借助软件:如果想要将PDF文档拆分成多个PDF,那就需要借助迅捷PDF转换器来实现,这个软件有着丰

两篇文章带你走入.NET Core 世界:Kestrel+Nginx+Supervisor 部署上云服务器(二)

背景: 上一篇:两篇文章带你走入.NET Core 世界:CentOS+Kestrel+Ngnix 虚拟机先走一遍(一) 已经交待了背景,这篇就省下背景了,这是第二篇文章了,看完就木有下篇了. 直接进入主题: 1.购买云服务器 之前在虚拟机跑了一下,感觉还不够真实,于是,准备买台服务器,认真的跑一下. 有阿里云,腾讯云,华为云,还有好多云,去哪买一个? 之前做为华为云的云享专家去参加了一下活动,本来也准备写篇文章,不过相同游记文太多, 这里就转一篇了:让华为云MVP告诉你——在华为的一天可以做什

三篇文章带你极速入门php(三)之php原生实现登陆注册

看下成果 ps:纯天然h5,绝不添加任何添加剂(css)以及化学成分(js)(<( ̄ ﹌  ̄)我就是喜欢纯天然,不接受任何反驳) 关于本文 用原生的php和html做了一个登陆注册,大概是可以窥见一般php开发的样子了.不过,low的地方区别提前说一下: 这个是多入口,一般程序都是单入口,单入口就是统一通过index.php进入,然后再引入其他文件,调用其代码,多入口就是每次通过不同文件进入(比如一会展示的login.php和register.php) 保留登陆信息用的是session,现在普遍

三篇文章带你极速入门php(一)之语法

本文适合阅读用户 有其他语言基础的童鞋 看完w3cschool语法教程来回顾一下的童鞋(传送门,想全面看一下php语法推荐这里) 毫无基础然而天资聪慧颇有慧根(不要左顾右看说的就是你,老夫这里有一本<php从入门到放弃>,观你根骨清奇10两银子卖给你如何) 看完本文后你会收获到什么 php的变量的定义,使用 函数的定义,使用,传递参数 数组的定义,调用,常用方法,使用场景 php中循环,判断,选择结构的语法 类的定义,成员变量和成员函数的定义和使用 相信我,认真看完本文,你就已经掌握了php常

还不知道事务消息吗?这篇文章带你全面扫盲!

在分布式系统中,为了保证数据一致性是必须使用分布式事务.分布式事务实现方式就很多种,今天主要介绍一下使用 RocketMQ 事务消息,实现分布事务. 文末有彩蛋,看完再走 为什么需要事务消息? 很多同学可能不知道事务消息是什么,没关系,举一个真实业务场景,先来带你了解一下普通的消息存在问题. 上面业务场景中,当用户支付成功,将会更新支付订单,然后发送 MQ 消息.手续费系统将会通过拉取消息,计算手续费然后保存到另外一个手续费数据库中. 由于计算手续费这个步骤可以离线计算,所以这里采用 MQ 解耦

这篇文章带你彻底理解synchronized

本人免费整理了Java高级资料,涵盖了Java.Redis.MongoDB.MySQL.Zookeeper.Spring Cloud.Dubbo高并发分布式等教程,一共30G,需要自己领取.传送门:https://mp.weixin.qq.com/s/JzddfH-7yNudmkjT0IRL8Q 1. synchronized简介在学习知识前,我们先来看一个现象: public class SynchronizedDemo implements Runnable { private static

我们工作到底为了什么(这篇文章很重要)

我们工作到底为了什么(这篇文章很重要) HP大中华区总裁孙振耀退休感言 : 如果这篇文章没有分享给你,那是我的错. 如果这篇文章分享给你了,你却没有读,继续走弯路的你不要怪我. 如果你看了这篇文章,只读了一半你就说没时间了,说明你已经是个"茫"人了. 如果你看完了,你觉得这篇文章只是讲讲大道理,说明你的人生阅历还不够,需要你把这篇文章珍藏,走出去碰几年壁,头破血流后再回来,再读,你就会感叹自己的年少无知. 如果你看完了,觉得很有道理,然后束之高阁,继续走进拥挤的地铁,依然用着自己昨日的

【转载】如果有人问你数据库的原理,叫他看这篇文章

原文:如果有人问你数据库的原理,叫他看这篇文章 本文由 伯乐在线 - Panblack 翻译,黄利民 校稿.未经许可,禁止转载!英文出处:Christophe Kalenzaga.欢迎加入翻译组. 一提到关系型数据库,我禁不住想:有些东西被忽视了.关系型数据库无处不在,而且种类繁多,从小巧实用的 SQLite 到强大的 Teradata .但很少有文章讲解数据库是如何工作的.你可以自己谷歌/百度一下『关系型数据库原理』,看看结果多么的稀少[译者注:百度为您找到相关结果约1,850,000个…] 

十年后2023年再读这篇文章,看看我将会怎么样?

http://blog.csdn.net/wojiushiwo987/article/details/8453881看到一篇文章不错[清华差生10年奋斗经历] ,写给将要工作的自己,十年后2023年再读这篇文章,看看我将会怎么样? 在2012年收关时刻,看到如此激励的文章,实在是我的幸运.文章讲述了所谓清华差生的奋斗史,从毕业.各种工作经历.与同事.领导关系细致入微的剖析了实战的职场及人和人差距拉开的原因等.正如文中作者指出的那样,这也是我的心灵导师俞敏洪一直教导的,”人生是跑马拉松的过程,不在