sklearn交叉验证2-【老鱼学sklearn】

过拟合

过拟合相当于一个人只会读书,却不知如何利用知识进行变通。

相当于他把考试题目背得滚瓜烂熟,但一旦环境稍微有些变化,就死得很惨。

从图形上看,类似下图的最右图:

从数学公式上来看,这个曲线应该是阶数太高的函数,因为一般任意的曲线都能由高阶函数来拟合,它拟合得太好了,因此丧失了泛化的能力。

用Learning curve 检视过拟合

首先加载digits数据集,其包含的是手写体的数字,从0到9:

# 加载数据
digits = load_digits()
X = digits.data
y = digits.target

然后用SVC(支持向量机)对手写体数字进行分类,当然,这里要介绍的核心函数是learning_curve,先上代码看看:

# 导入支持向量机
from sklearn.svm import SVC
model = SVC(gamma=0.001)

train_sizes, train_loss, test_loss = learning_curve(model, X, y, cv=10, scoring=‘neg_mean_squared_error‘, train_sizes=[0.1, 0.25, 0.5, 0.75, 1])
# 平均每一轮所得到的平均方差(共5轮,分别为样本10%、25%、50%、75%、100%)
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

在learning_curve中设置了十一法的数据,同时在打分时使用了neg_mean_squared_error方式,也就是方差值,因此这个最后的得分值是负数;train_sizes指定了5轮检视学习曲线(10%, 25%, 50%, 75%, 100%):

最后,我们把根据每轮的训练集大小对于最终得分的影响程度画个图,相当于做题数量的多少跟最终考试成绩的关系图:

# 可视化图形
import matplotlib.pyplot as plt
plt.plot(train_sizes, train_loss_mean, label="Train")
plt.plot(train_sizes, test_loss_mean, label="Test")
plt.legend()
plt.show()

显示图形为:

看起来随着做题数量的增加,考试成绩越来越好了(损失函数的值越来越小了),并且考试成绩在慢慢逼近平常的训练成绩。

完整的代码如下:

from sklearn.datasets import load_digits

# 加载数据
digits = load_digits()
X = digits.data
y = digits.target

# 加载学习曲线模块
from sklearn.model_selection import learning_curve
import numpy as np

# 导入支持向量机
from sklearn.svm import SVC
model = SVC(gamma=0.001)

train_sizes, train_loss, test_loss = learning_curve(model, X, y, cv=10, scoring=‘neg_mean_squared_error‘, train_sizes=[0.1, 0.25, 0.5, 0.75, 1])
# 平均每一轮所得到的平均方差(共5轮,分别为样本10%、25%、50%、75%、100%)
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

# 可视化图形
import matplotlib.pyplot as plt
plt.plot(train_sizes, train_loss_mean, label="Train")
plt.plot(train_sizes, test_loss_mean, label="Test")
plt.legend()
plt.show()

如果我们把上面代码中SVC参数的gamma值设置为0.1,显示出的图形为:

在上面的图形中,我们会发现再增加训练集的数据并没有使测试集的损失值下降,相当于一个人按照他的学习方式做训练题已经够多了,你做更多的训练题都无法提高你的考试成绩了,这时可能需要考虑的是稍微改变一下你的学习方法说不定就能提高考试成绩呢。

这也说明了,在某些情况下题海战术不一定奏效了。

在机器学习中表示为所学到的模型可能太复杂了,产生了过拟合(过拟合表现为训练集的损失函数在下降,但测试集的损失函数不降反升),不具备泛化能力,例如下图中绿色曲线就是一个过拟合的表现:

相应的损失函数曲线显示如下所示:

因此如果我们想要查看是否有过拟合,可以通过learning_curve函数来进行并以可视化的方式来查看结果。

时间: 2024-07-31 10:17:45

sklearn交叉验证2-【老鱼学sklearn】的相关文章

sklearn交叉验证-【老鱼学sklearn】

交叉验证(Cross validation),有时亦称循环估计, 是一种统计学上将数据样本切割成较小子集的实用方法.于是可以先在一个子集上做分析, 而其它子集则用来做后续对此分析的确认及验证. 一开始的子集被称为训练集.而其它的子集则被称为验证集或测试集.交叉验证是一种评估统计分析.机器学习算法对独立于训练数据的数据集的泛化能力(generalize). 我们以分类花的例子来看下: # 加载iris数据集 from sklearn.datasets import load_iris from s

sklearn数据库-【老鱼学sklearn】

在做机器学习时需要有数据进行训练,幸好sklearn提供了很多已经标注好的数据集供我们进行训练. 本节就来看看sklearn提供了哪些可供训练的数据集. 这些数据位于datasets中,网址为:http://scikit-learn.org/stable/modules/classes.html#module-sklearn.datasets 房价数据 加载波士顿房价数据,可以用于线性回归用: sklearn.datasets.load_boston:http://scikit-learn.or

为何学习matplotlib-【老鱼学matplotlib】

这次老鱼开始学习matplotlib了. 在上个pandas最后一篇博文中,我们已经看到了用matplotlib进行绘图的功能,这次更加系统性地多学习一下关于matplotlib的功能. 在matlab中,其拥有非常强大的显示图表的功能. 在python中,就提供了一个类似matlab软件中的画图库matplotlib,其基本上是模仿matlab中的画图函数. 官网中介绍的显示图表的例子见:http://matplotlib.org/gallery/index.html 要使用,就必须先进行安装

Python 之 sklearn 交叉验证 数据拆分

本文K折验证拟采用的是 Python 中 sklearn 包中的 StratifiedKFold 方法. 方法思想详见:http://scikit-learn.org/stable/modules/cross_validation.html StratifiedKFold is a variation of k-fold which returns stratified folds: each set contains approximately the same percentage of s

tensorflow用dropout解决over fitting-【老鱼学tensorflow】

在机器学习中可能会存在过拟合的问题,表现为在训练集上表现很好,但在测试集中表现不如训练集中的那么好. 图中黑色曲线是正常模型,绿色曲线就是overfitting模型.尽管绿色曲线很精确的区分了所有的训练数据,但是并没有描述数据的整体特征,对新测试数据的适应性较差. 一般用于解决过拟合的方法有增加权重的惩罚机制,比如L2正规化,但在本处我们使用tensorflow提供的dropout方法,在训练的时候, 我们随机忽略掉一些神经元和神经联结 , 是这个神经网络变得"不完整". 用一个不完整

tensorflow分类-【老鱼学tensorflow】

前面我们学习过回归问题,比如对于房价的预测,因为其预测值是个连续的值,因此属于回归问题. 但还有一类问题属于分类的问题,比如我们根据一张图片来辨别它是一只猫还是一只狗.某篇文章的内容是属于体育新闻还是经济新闻等,这个结果是有一个全集的离散值,这类问题就是分类问题. 我有时会把回归问题看成是分类问题,比如对于房价值的预测,在实际的应用中,一般不需要把房价精确到元为单位的,比如对于均价,以上海房价为例,可以分为:5000-10万这样的一个范围段,并且以1000为单位就可以了,尽管这样分出了很多类,但

pandas基本介绍-【老鱼学pandas】

前面我们学习了numpy,现在我们来学习一下pandas. Python Data Analysis Library 或 pandas 主要用于处理类似excel一样的数据格式,其中有表头.数据序列号以及实际的数据,而numpy就仅仅包含了实际的数据. 安装 直接输入: pip3 install pandas 最基本用法 import pandas as pd s = pd.Series([1, 2, 5, 6]) print(s) 输出: 0 1 1 2 2 5 3 6 dtype: int6

pandas设置值-【老鱼学pandas】

本节主要讲述如何根据上篇博客中选择出相应的数据之后,对其中的数据进行修改. 对某个值进行修改 例如,我们想对数据集中第2行第2列的数据进行修改: import pandas as pd import numpy as np dates = pd.date_range("2017-01-08", periods=6) data = pd.DataFrame(np.arange(24).reshape(6, 4), index=dates, columns=["A",

pandas处理丢失数据-【老鱼学pandas】

假设我们的数据集中有缺失值,该如何进行处理呢? 丢弃缺失值的行或列 首先我们定义了数据集的缺失值: import pandas as pd import numpy as np dates = pd.date_range("2017-01-08", periods=6) data = pd.DataFrame(np.arange(24).reshape(6, 4), index=dates, columns=["A", "B", "C&