sklearn中LinearRegression使用及源码解读

sklearn中的LinearRegression

  • 函数原型:class sklearn.linear_model.LinearRegression(fit_intercept=True,normalize=False,copy_X=True,n_jobs=1)
  • fit_intercept:模型是否存在截距
  • normalize:模型是否对数据进行标准化(在回归之前,对X减去平均值再除以二范数),如果fit_intercept被设置为False时,该参数将忽略。

    该函数有属性:coef_可供查看模型训练后得到的估计系数,如果获取的估计系数太大,说明模型有可能过拟合。

    使用样例:

      >>>from sklearn import linear_model
      >>>clf = linear_model.LinearRegression()
      X = [[0,0],[1,1],[2,2]]
      y = [0,1,2]
      >>>clf.fit(X,y)
      >>>print(clf.coef_)
      [ 0.5 0.5]
      >>>print(clf.intercept_)
      1.11022302463e-16

源码分析

在github可以找到LinearRegression的源码:LinearRegression

  • 主要思想:sklearn.linear_model.LinearRegression求解线性回归方程参数时,首先判断训练集X是否是稀疏矩阵,如果是,就用Golub&Kanlan双对角线化过程方法来求解;否则调用C库中LAPACK中的用基于分治法的奇异值分解来求解。在sklearn中并不是使用梯度下降法求解线性回归,而是使用最小二乘法求解。

    sklearn.LinearRegression的fit()方法:

      if sp.issparse(X):#如果X是稀疏矩阵
          if y.ndim < 2:
              out = sparse_lsqr(X, y)
              self.coef_ = out[0]
              self._residues = out[3]
          else:
              # sparse_lstsq cannot handle y with shape (M, K)
              outs = Parallel(n_jobs=n_jobs_)(
                  delayed(sparse_lsqr)(X, y[:, j].ravel())
                  for j in range(y.shape[1]))
              self.coef_ = np.vstack(out[0] for out in outs)
              self._residues = np.vstack(out[3] for out in outs)
      else:
          self.coef_, self._residues, self.rank_, self.singular_ =           linalg.lstsq(X, y)
          self.coef_ = self.coef_.T

几个有趣的点:

  • 如果y的维度小于2,并没有并行操作。
  • 如果训练集X是稀疏矩阵,就用sparse_lsqr()求解,否则使用linalg.lstsq()

linalg.lstsq()

scipy.linalg.lstsq()方法就是用来计算X为非稀疏矩阵时的模型系数。这是使用普通的最小二乘OLS法来求解线性回归参数的。

  • scipy.linalg.lstsq()方法源码

    scipy提供了三种方法来求解least-squres problem最小均方问题,即模型优化目标。其提供了三个选项gelsd,gelsy,geless,这些参数传入了get_lapack_funcs()。这三个参数实际上是C函数名,函数是从LAPACK(Linear Algebra PACKage)中获得的。

    gelsd:它是用singular value decomposition of A and a divide and conquer method方法来求解线性回归方程参数的。

    gelsy:computes the minimum-norm solution to a real/complex linear least squares problem

    gelss:Computes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A.

    scipy.linalg.lstsq()方法使用gelsd求解(并没有为用户提供选项)。

sparse_lsqr()方法源码

sqarse_lsqr()方法用来计算X是稀疏矩阵时的模型系数。sparse_lsqr()就是不同版本的scipy.sparse.linalg.lsqr(),参考自论文C. C. Paige and M. A. Saunders (1982a). "LSQR: An algorithm for sparse linear equations and sparse least squares", ACM TOMS实现。

相关源码如下:

    if sp_version < (0, 15):
        # Backport fix for scikit-learn/scikit-learn#2986 / scipy/scipy#4142
        from ._scipy_sparse_lsqr_backport import lsqr as sparse_lsqr
    else:
        from scipy.sparse.linalg import lsqr as sparse_lsqr

原文地址:https://www.cnblogs.com/mengnan/p/9307642.html

时间: 2024-11-13 05:38:58

sklearn中LinearRegression使用及源码解读的相关文章

JAVA中ArrayList的扩增源码解读

今天某位大佬问了一下我关于ArrayList扩增的大小,本人甚是愚昧,用记忆之中的答案回复了一下,大佬大手一挥,去看源码再来回答我,所以就有了这篇观后感,个人愚见,共同进步吧. 然后先二话不说,上关于ArrayList的源码: /* * Copyright (c) 1997, 2013, Oracle and/or its affiliates. All rights reserved. * ORACLE PROPRIETARY/CONFIDENTIAL. Use is subject to l

sklearn中LinearRegression关键源码解读

问题的引入 我们知道,线性回归方程的参数,可以用梯度下降法求解,或者用正规方程求解. 那sklearn.linear_model.LinearRegression中,是不是可以指定求解方式呢?能不能从中获取梯度相关信息呢? 下面是线性回归最简单的用法. from sklearn import linear_model # Create linear regression object regr = linear_model.LinearRegression() # Train the model

关于cocos2dx lua中的clone函数的源码解读

cocos2dx的clone函数,是深拷贝,完全的拷贝,这段代码内容不多,不过第一次看还是有点晕,我把解读记下来分享一下. 源码解读: function clone(object)--clone函数 local lookup_table = {}--新建table用于记录 local function _copy(object)--_copy(object)函数用于实现复制 if type(object) ~= "table" then return object ---如果内容不是t

QCustomplot使用分享(二) 源码解读

一.头文件概述 从这篇文章开始,我们将正式的进入到QCustomPlot的实践学习中来,首先我们先来学习下QCustomPlot的类图,如果下载了QCustomPlot源码的同学可以自己去QCustomPlot的目录下documentation/qcustomplot下寻找一个名字叫做index.html的文件,将其在浏览器中打开,也是可以找到这个库的类图.如图1所示,是组成一个QCustomPlot类图的可能组成形式. 一个图表(QCustomPlot):包含一个或者多个图层.一个或多个ite

vue源码解读预热-0

vueJS的源码解读 vue源码总共包含约一万行代码量(包括注释)特别感谢作者Evan You开放的源代码,访问地址为Github 代码整体介绍与函数介绍预览 代码模块分析 代码整体思路 总体的分析 从图片中可以看出的为采用IIFE(Immediately-Invoked Function Expression)立即执行的函数表达式的形式进行的代码的编写 常见的几种插件方式: (function(,){}(,))或(function(,){})(,)或!function(){}()等等,其中必有

SpringMVC源码解读 - RequestMapping注解实现解读 - RequestCondition体系

一般我们开发时,使用最多的还是@RequestMapping注解方式. @RequestMapping(value = "/", param = "role=guest", consumes = "!application/json") public void myHtmlService() { // ... } 台前的是RequestMapping ,正经干活的却是RequestCondition,根据配置的不同条件匹配request. @Re

jdk1.8.0_45源码解读——HashMap的实现

jdk1.8.0_45源码解读——HashMap的实现 一.HashMap概述 HashMap是基于哈希表的Map接口实现的,此实现提供所有可选的映射操作.存储的是<key,value>对的映射,允许多个null值和一个null键.但此类不保证映射的顺序,特别是它不保证该顺序恒久不变.  除了HashMap是非同步以及允许使用null外,HashMap 类与 Hashtable大致相同. 此实现假定哈希函数将元素适当地分布在各桶之间,可为基本操作(get 和 put)提供稳定的性能.迭代col

15、Spark Streaming源码解读之No Receivers彻底思考

在前几期文章里讲了带Receiver的Spark Streaming 应用的相关源码解读,但是现在开发Spark Streaming的应用越来越多的采用No Receivers(Direct Approach)的方式,No Receiver的方式的优势: 1. 更强的控制自由度 2. 语义一致性 其实No Receivers的方式更符合我们读取数据,操作数据的思路的.因为Spark 本身是一个计算框架,他底层会有数据来源,如果没有Receivers,我们直接操作数据来源,这其实是一种更自然的方式

jdk1.8.0_45源码解读——Set接口和AbstractSet抽象类的实现

jdk1.8.0_45源码解读——Set接口和AbstractSet抽象类的实现 一. Set架构 如上图: (01) Set 是继承于Collection的接口.它是一个不允许有重复元素的集合.(02) AbstractSet 是一个抽象类,它继承于AbstractCollection.AbstractCollection实现了Set中的绝大部分函数,为Set的实现类提供了便利.(03) HastSet 和 TreeSet 是Set的两个实现类.        HashSet依赖于HashMa