在上次的代码重写中使用了sklearn.LinearRegression 类进行了线性回归之后猜测其使用的是常用的梯度下降+反向传播算法实现,所以今天来学习它的源码实现。但是在看到源码的一瞬间突然有种怀疑人生的感觉,我是谁?我在哪?果然大佬的代码只能让我膜拜。

在一目十行地看完代码之后,我发现了一个问题,梯度的单词是gradient,一般在代码中会使用缩写grad 来表示梯度,而在这个代码中除了Gram 之外竟然没有一个以'g' 开头的单词,更不用说gradient 了。那么代码中包括注释压根没提到过梯度,是不是说明这里根本没有使用梯度下降算法呢,换言之就是是否还有其他方法来实现最小二乘法的线性回归呢?带着这个疑问,我开始仔细阅读LinearRegression.fit() 函数。

首先都是参数处理,这里虽然看不太懂但是也能大概知道他在做什么,所以可以跳过。

然后来到了核心代码,核心代码中使用的几个判断:

 1 if self.positive:
 2     if y.ndim < 2:
 3         pass
 4     else:
 5         pass
 6 elif sp.issparse(X):
 7     if y.ndim < 2:
 8         pass
 9     else:
10         pass
11 else:
12     pass

self.positive 是在使用密集矩阵的时候设置的参数,y.ndim 表示y 的维度,简单来说就是y 中有几个[],所以大概能知道代码将密集矩阵与稀疏矩阵区分开,并且将一维矩阵与多维矩阵区分开,意味着不同的类别使用不同的方法。

if self.positive 分段解析

购物返利 https://m.cpa5.cn/

1 if self.positive:
2     if y.ndim < 2:
3         self.coef_, self._residues = optimize.nnls(X, y)
4     else:
5         # scipy.optimize.nnls cannot handle y with shape (M, K)
6         outs = Parallel(n_jobs=n_jobs_)(
7             delayed(optimize.nnls)(X, y[:, j])
8             for j in range(y.shape[1]))
9         self.coef_, self._residues = map(np.vstack, zip(*outs))

可以看出y 的维度小于2 的话使用optimize.nnls() 方法,否则进行其他处理,因为“scipy.optimize.nnls cannot handle y with shape (M, K)”,但看到之后也调用了optimize.nnls,所以应该是将矩阵处理成可以使用的样子。

并且值得注意的是这里使用了Parallel(n_jobs=n_jobs_)(delayed(optimize.nnls)(X, y[:, j])for j in range(y.shape[1])) 的调用方式,也就是形如fun(x)(y) 的方式,这意味着函数内定义了另一个函数,第一个括号是fun 的参数,第二个括号是给fun 函数内定义的函数的参数。

elif sp.issparse(X) 分段解析

 1 elif sp.issparse(X):
 2 X_offset_scale = X_offset / X_scale
 3
 4
 5 def matvec(b):
 6     return X.dot(b) - b.dot(X_offset_scale)
 7
 8
 9 def rmatvec(b):
10     return X.T.dot(b) - X_offset_scale * np.sum(b)
11
12
13 X_centered = sparse.linalg.LinearOperator(shape=X.shape,
14                                           matvec=matvec,
15                                           rmatvec=rmatvec)
16
17 if y.ndim < 2:
18     out = sparse_lsqr(X_centered, y)
19     self.coef_ = out[0]
20     self._residues = out[3]
21 else:
22     # sparse_lstsq cannot handle y with shape (M, K)
23     outs = Parallel(n_jobs=n_jobs_)(
24         delayed(sparse_lsqr)(X_centered, y[:, j].ravel())
25         for j in range(y.shape[1]))
26     self.coef_ = np.vstack([out[0] for out in outs])
27     self._residues = np.vstack([out[3] for out in outs])

可以看到先是对数据进行了处理,然后调用了sparse_lsqr() 函数。

剩余分段解析

1 else:
2 self.coef_, self._residues, self.rank_, self.singular_ = \n3     linalg.lstsq(X, y)
4 self.coef_ = self.coef_.T
5
6 if y.ndim == 1:
7     self.coef_ = np.ravel(self.coef_)
8 self._set_intercept(X_offset, y_offset, X_scale)

使用了linalg.lstsq() 函数。

以上我们可以看到代码中一共使用了3 个方法来实现线性回归:optimize.nnls()、sparse_lsqr()、linalg.lstsq()

optimize.nnls() 分析

NNLS 即非负正则化最小二乘法,代码实现由scipy.optimize.nnls 提供,这里只是将其封装起来,在源码的注释中提到该算法的FORTRAN 代码在Charles L. Lawson 与Richard J. Hanson 两位教授于1987 年所著的《Solving Least Squares Problems》中发布。“The algorithm is an active set method. It solves the KKT (Karush-Kuhn-Tucker) conditions for the non-negative least squares problem.” 可惜由于本人水平有限,并不能从书中或者此处的代码中学到该算法的精髓,只能先挖一个坑,以后有所提高了再来研究该算法。

sparse_lsqr() 分析

LSQR 即最小二乘QR分解算法,代码实现由scipy.sparse.linalg.lsqr 类提供,这里只是将其封装起来,在文档中可以看到:

LSQR uses an iterative method to approximate the solution.  The number of iterations required to reach a certain accuracy depends strongly on the scaling of the problem.  Poor scaling of the rows or columns of A should therefore be avoided where possible.

同时也给出了参考文献:

[1] C. C. Paige and M. A. Saunders (1982a). "LSQR: An algorithm for sparse linear equations and sparse least squares", ACM TOMS 8(1), 43-71.
[2] C. C. Paige and M. A. Saunders (1982b). "Algorithm 583.  LSQR: Sparse linear equations and least squares problems", ACM TOMS 8(2), 195-209.
[3] M. A. Saunders (1995).  "Solution of sparse rectangular systems using LSQR and CRAIG", BIT 35, 588-604.

可以看到LSQR 算法是Paige 和Saunders 于1982 年提出的一种方法,但水平有限,暂时并不清楚其中原理。

linalg.lstsq() 分析

LSTSQ 是 LeaST SQuare (最小二乘)的意思,也就是普通的最小二乘法。代码实现由scipy.linalg.lstsq 提供,这里只是将其封装起来。通过官方文档提供的源代码链接,我找到了lstsq 函数的源代码,注释中提到了:

Which LAPACK driver is used to solve the least-squares problem.
Options are ``'gelsd'``, ``'gelsy'``, ``'gelss'``. Default(``'gelsd'``) is a good choice.  However, ``'gelsy'`` can be slightly faster on many problems.  ``'gelss'`` was used historically.  It is generally slow but uses less memory.

也就是说有三个选项:gelsd(默认推荐)、gelsy(可能稍快)、gelss(使用内存少),那么来看看他们分别使用什么方法来解决最小二乘法吧。

 1 if driver in ('gelss', 'gelsd'):
 2     if driver == 'gelss':
 3         lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
 4         v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
 5                                                 overwrite_a=overwrite_a,
 6                                                 overwrite_b=overwrite_b)
 7
 8     elif driver == 'gelsd':
 9         if real_data:
10             lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
11             x, s, rank, info = lapack_func(a1, b1, lwork,
12                                            iwork, cond, False, False)
13         else:  # complex data
14             lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
15                                                  nrhs, cond)
16             x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
17                                            cond, False, False)
18 elif driver == 'gelsy':
19         lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
20         jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
21         v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
22                                           lwork, False, False)

看到调用了lapack_func,但是找了一下spicy 并没有发现这个函数,于是搜索lapack_func 找到了定义:

lapack_func, lapack_lwork = get_lapack_funcs((driver,'%s_lwork' % driver),(a1, b1))

原来它调用了get_lapack_funcs 函数,于是去找该函数的文档,从文档中看到了该方法的用处:“This routine automatically chooses between Fortran/C interfaces. Fortran code is used whenever possible for arrays with column major order. In all other cases, C code is preferred.”,原来这是一个Fortran 语言和C 语言的接口且首选C 语言,并且从源码中我看到了它是调用_get_funcs 实现,那3 个单词是C 语言的函数名,所幸我对C 语言的熟悉程度打过Python,于是去找C 语言API 进行学习。

gelsd:Computes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A and a divide and conquer method. 分治法的奇异值分解找出最优解。

gelss:Computes the minimum-norm solution to a linear least squares problem using the singular value decomposition of A. 使用奇异值分解找出最优解。

gelsy:Computes the minimum-norm solution to a linear least squares problem using a complete orthogonal factorization of A. 使用完全正交分解的方式找出最优解。

至此对于所有函数的分析基本上是结束了,密集矩阵使用的是NNLS 算法,稀疏矩阵使用的是LSQR 算法,其他的使用的则是最常用的最小二乘法算法。回到开头的问题,为什么没有使用梯度下降呢?难道梯度下降并不是解决最小二乘或者说是线性回归的算法吗?在网上查阅了很多材料之后我发现自己的问题:我陷入了误区。

先来说说最小二乘法,最小二乘法在我初高中的时候学习简单线性回归的时候就接触了,根据百度百科的词条,最小二乘法(又称最小平方法)是一种数学优化技术。它通过最小化误差的平方和寻找数据的最佳函数匹配。利用最小二乘法可以简便地求得未知的数据,并使得这些求得的数据与实际数据之间误差的平方和为最小,也就是说最小二乘法的公式是目标函数 = MIN( ∑ (预测值 – 实际值)² )。这就意味着如果人工计算的话就需要穷举所有的函数,计算他们的损失然后找出最小的那个函数。

再来说说梯度下降,梯度下降是迭代法的一种,通过更新梯度来找到损失最小的那个函数,不知你发现问题了吗?最小二乘法与梯度下降是在做同一件事情,也就是最优化问题,两个是并行的关系,并不存在谁解决谁。百度百科中关于梯度下降的词条中提到:“梯度下降是迭代法的一种,可以用于求解最小二乘问题(线性和非线性都可以)。在求解机器学习算法的模型参数,即无约束优化问题时,梯度下降(Gradient Descent)是最常采用的方法之一,另一种常用的方法是最小二乘法。”这里很清晰的指出了最小二乘法和梯度下降法的关系。

再来说说另一个我之前并不知道的概念:最小二乘准则。百度百科中提到这是一种对于偏差程度的评估准则,与上两者不同,上述的算法都是基于最小二乘准则提出的对于最小二乘法优化问题的解决方案,也就是如果不穷举的话如何找到最小二乘法的最优解。

列出我查阅的资料:

1、知乎-最小二乘法和梯度下降法有哪些区别?

2、百度百科-梯度下降

3、百度百科-最小二乘法

总结

1、虽然找到了sklearn.LinearRegression 类中对于线性回归的算法及实现,但发现并没有使用到梯度下降法,而是使用最小二乘法找到最优解,解开了我对最小二乘法与梯度下降到误解,但由于之前并未从事过算法研究与数学分析,对相应的算法一知半解,所以这里的代码难以看懂,只能就此作罢,学习了相应的算法之后再来学习代码实现。

2、在学习源码的过程中以及写这篇文章的过程中发现对于python的有些概念还是不太清晰,比如函数和方法还有fun(x)(y) 的调用方式,所以能看到上文中有些使用“函数”有些使用“方法”,可能并不对应,但之后熟悉了才能修改。

3、源码中的注释充斥着许多数学词汇,读起来让我异常头疼,几乎都要使用翻译软件才能理解,同时有些平常使用的词汇我也不懂,这个时候英语的作用就十分必要了。

4、总的来说,基础不扎实,水平不高,所以对于其中的精髓难以理解,同时可能文章中错漏百出,但我并未发现,这就是目前的问题或者说困境,勤加学习才能脱离。

深度学习03-sklearn.LinearRegression 源码学习相关推荐

  1. Vuex 4源码学习笔记 - 通过Vuex源码学习E2E测试(十一)

    在上一篇笔记中:Vuex 4源码学习笔记 - 做好changelog更新日志很重要(十) 我们学到了通过conventional-changelog来生成项目的Changelog更新日志,通过更新日志 ...

  2. Python Sklearn库源码学习--kmeans

    前言: 分析体检数据希望不拘泥于Sklearn库中已有的聚类算法,想着改一下Kmeans算法.本着学习的目的,现在开始查看sklearn的源代码.希望能够写成一个通用的包. 有必要先交代一下我使用的p ...

  3. springmvc项目在启动完成之后执行一次方法_学习笔记21-springmvc部分源码学习

    SpringMVC:请求处理流程 这几张图讲的大致都是一个东西,就是springmvc的请求处理流程: ① 用户发送请求到springmvc框架提供的DispatcherServlet 这个前端控制器 ...

  4. RocketMQ源码学习

    RocketMQ源码学习 文章目录 RocketMQ源码学习 Producer 是怎么将消息发送至 Broker 的? 同步发送 异步发送 队列选择器 事务消息 原理 Broker 是怎么处理客户端发 ...

  5. clickhouse-jdbc 源码学习

    clickhouse-jdbc 源码学习 文章目录 clickhouse-jdbc 源码学习 包介绍 依赖版本 搭建环境版本如下 QA 1.LocalDate/LocalDateTime不兼容 2.一 ...

  6. 【博学谷学习记录】超强总结,用心分享 | 架构师 Mybatis源码学习总结

    Mybatis源码学习 文章目录 Mybatis源码学习 一.Mybatis架构设计 二.源码剖析 1.如何解析的全局配置文件 解析配置文件源码流程 2.如何解析的映射配置文件 Select inse ...

  7. Hadoop HDFS源码学习之NameNode部分

    NameNode源码学习 文章目录 NameNode源码学习 一.文件系统目录树(第一关系) 2.1 INode相关类 2.2 快照特性的实现 2.3 FSEditLog类 2.4 FSImage类 ...

  8. hystrix 源码 线程池隔离_Spring Cloud Hystrix 源码学习合集

    # Spring Cloud Hystrix 源码学习合集 **Hystrix: Latency and Fault Tolerance for Distributed Systems** ![](h ...

  9. postgresql源码学习(57)—— pg中的四种动态库加载方法

    一. 基础知识 1. 什么是库 库其实就是一些通用代码,可以在程序中重复使用,比如一些数学函数,可以不需要自己编写,直接调用相关函数即可实现,避免重复造轮子. 在linux中,支持两种类型的库: 1. ...

最新文章

  1. Nginx Web服务应用
  2. 如何做流数据分析,Byron Ellis来告诉你...
  3. 2019-03-18-算法-进化(实现strStr())
  4. 《细胞》突破性成果!北京林业大学钮世辉等解开“中国松”基因密码
  5. table表格边框样式_如何在CAD创建、导入表格?原来CAD的表格功能这么强大
  6. py2neo database
  7. python学习第二天标准输入输出和注释用法
  8. [学习笔记-SLAM篇]视觉SLAM十四讲ch3
  9. 地理中经纬度的英文名称
  10. 五险一金及个税计算器
  11. ADSL上网常见故障解答
  12. 机器学习入门——简单线性回归
  13. Android 11.0 12.0关机界面全屏显示(UI全屏显示)
  14. 深度学习工作站装机指南
  15. [WDS] Disconnected!问题解决
  16. slqdbx mysql_sqldbx下载 v4.3 附sqldbx使用教程
  17. 华为发展鸿蒙系统再出奇招,为了留存现有手机用户可谓费尽脑汁
  18. 【机器学习】Python 3.0 简单实现K-邻近法
  19. 清明节全网灰色主题实现原理
  20. git初始化本地仓库,远程提交远端代码教程

热门文章

  1. 利用百度tts 实现文字转语音(node)
  2. Android 程序后台运行和锁屏运行
  3. 旗舰手机正在向技术“深水区”挺进
  4. CAD垂直标注出现拐角的问题
  5. 聚合支付的优势有哪些?
  6. 关于seo写作内容的一些探讨
  7. Kubernetes 「驾驶舱」 kubectl 知多少?
  8. STM32F103C8T6使用modbus协议读取温湿度传感器
  9. CAD文件(dwg)的加载-ArcEngine
  10. HTML练习--做一个在线简历