项目github地址:bitcarmanlee easy-algorithm-interview-and-practice
欢迎大家star,留言,一起学习进步

1.前言

最小二乘法线性回归作为最基础的线性回归,在统计和机器学习中都有重要的地位。在机器学习中,线性回归用来从数据中获得启示来帮助预测,因此如何得到最拟合数据的函数和防止过拟合是研究重点。
假设我们的拟合函数是y=ax+by = ax + by=ax+b,标准的线性最小二乘采用MSE做为loss function。那么在用梯度下降求解的时候,参数a,b对应的梯度分别为:

∂∂aL(a,b)=∂∂a(12m∑i=1m(axi+b−yi)2)=1m∑i=1m(axi+b−yi)xi\frac{\partial}{\partial a} L(a, b) = \frac{\partial}{\partial a} \left(\frac{1}{2m} \sum_{i=1}^m(ax^i + b - y^i)^2 \right) = \frac{1}{m}\sum_{i=1}^m(ax^i + b - y^i) x^i∂a∂​L(a,b)=∂a∂​(2m1​i=1∑m​(axi+b−yi)2)=m1​i=1∑m​(axi+b−yi)xi
∂∂bL(a,b)=∂∂b(12m∑i=1m(axi+b−yi)2)=1m∑i=1m(axi+b−yi)\frac{\partial}{\partial b} L(a, b) = \frac{\partial}{\partial b} \left(\frac{1}{2m} \sum_{i=1}^m(ax^i + b - y^i)^2 \right) = \frac{1}{m}\sum_{i=1}^m(ax^i + b - y^i)∂b∂​L(a,b)=∂b∂​(2m1​i=1∑m​(axi+b−yi)2)=m1​i=1∑m​(axi+b−yi)

2.梯度下降求解最小二乘

import numpy as npalpha = 0.01
eps = 1e-6                                    x = [1., 2., 3., 4., 5., 6., 7., 8., 9.]
y = [3., 5., 7., 9., 11., 13., 15., 17., 19.] def solve_by_gradient():# m is sample numm = len(x)a, b= 0, 0sse, sse_new = 0, 0grad_a, grad_b = 0, 0count = 0for step in range(100000):count += 1for i in range(m):base = a * x[i] + b - y[i]grad_a += x[i] * basegrad_b += basegrad_a = grad_a / mgrad_b = grad_b / ma -= alpha * grad_ab -= alpha * grad_b# loss function: Mean Squared Error, MSE# because 2m is a const, so 1/2m can be ignoredfor j in range(m):sse_new += (a * x[j] + b - y[j]) ** 2if abs(sse_new - sse) < eps:breakelse:sse = sse_newprint('{0} * x + {1}'.format(a, b))print "count is: " , countsolve_by_gradient()

上面的代码严格按照梯度进行迭代而来,最后输出的结果为:

2.00003690546 * x + 0.999758438895
count is:  3386

由结果可知,最后还是比好好的拟合出了数据反应的y=2x+1的规律。
不过问题也比较明显,也比较好的证明了梯度下降的一个缺点:收敛速度很慢。像我们这个简单的例子,用了3386次迭代才最终收敛。

3.用矩阵求解的方式直接计算

在参考文献1中,我们给出了最小二乘矩阵求解的直接计算方式:
θ=(ATA)−1ATY\theta = (A^TA)^{-1}A^TYθ=(ATA)−1ATY

利用这个公式我们来实现一下:

x = [1., 2., 3., 4., 5., 6., 7., 8., 9.]
y = [3., 5., 7., 9., 11., 13., 15., 17., 19.]def solve_by_gd_matrix():x0 = [1.0 for i in range(9)]xarray = np.column_stack((x, x0))xmatrix = np.mat(xarray, float)yarray = np.array(y)ymatrix = np.mat(yarray, float)theta = (xmatrix.T * xmatrix).I * xmatrix.T * ymatrix.Tprint(theta)

代码中的x0,就是相当于偏置项b。θ\thetaθ求解公式直接套用上面的公式,最后代码运行的结果为:

[[2.][1.]]

直接精确求出a=2, b=1。

最小二乘法矩阵求解的推导过程如下,内容来自参考文献2.

参考文献:
1.https://mp.csdn.net/mdeditor/51589143#
2.https://zhuanlan.zhihu.com/p/33899560

用梯度下降求解最小二乘线性回归python实现相关推荐

  1. 梯度下降及一元线性回归[python代码](二)

    第一章.AI人工智能入门之梯度下降及一元线性回归(2) 目录 第一章.AI人工智能入门之梯度下降及一元线性回归(2) 一.线性回归是什么? 二.线性回归的应用 三.线性回归的一般形式 四.一元线性回归 ...

  2. 机器学习基础:理解梯度下降本质「附Python代码」

    https://www.toutiao.com/a6646958932096975373/ 2019-01-16 13:15:26 今天我们尝试用最简单的方式来理解梯度下降,在之后我们会尝试理解更复杂 ...

  3. 唐宇迪​​机器学习实战——梯度下降求解逻辑回归(理论基础+源代码实现)

    问题的提出 符号问题,这里的lg就是指log2,你的理解是正确的!在计算机科学中有些符号的使用跟我们在数学中使用的有区别.比如有时候log用来表示自然对数(以e为底数).希望对你有帮助! 首先计算机科 ...

  4. 【视频】什么是梯度下降?用线性回归解释和R语言估计GARCH实例

    全文链接:http://tecdat.cn/?p=23606 梯度下降是一种优化算法,能够为各种问题找到最佳解决方案(点击文末"阅读原文"获取完整代码数据). 梯度下降是什么? 梯 ...

  5. 梯度下降法求解多元线性回归 — NumPy

    梯度下降法求解多元线性回归问题 使用梯度下降法求解一元线性回归的方法也可以被推广到求解多元线性回归问题. 这是多元线性回归的模型: 其中的 X 和 W 都是 m+1 维的向量. 下图为它的损失函数: ...

  6. (MATLAB)使用梯度下降进行一元线性回归

    使用梯度下降进行一元线性回归 Step1:选择函数模型:y=wx+b Step2:选择损失函数: Step3: Step4:根据公式: 来更新w和b,最终得到最优解. 梯度下降函数代码如下: func ...

  7. python实现梯度下降求解方程斜率和截距,线性回归算法LinearRegression

    使用sklearn可以实现线性回归算法,新手练习写一个类似于线性回归的算法,进行比较. 1.调用sklearn中的LinearRegression: 导包: import numpy as np im ...

  8. 梯度下降原理及线性回归代码实现(python/java/c++)

    "梯度下降"顾名思义通过一步一步迭代逼近理想结果,当达到一定的精度或者超过迭代次数才退出,所以所获得的结果是一个近似值.在其他博客上面基本都有一个通俗的比喻:从山顶一步步下山.下面 ...

  9. 梯度下降python编程实现_【机器学习】线性回归——单变量梯度下降的实现(Python版)...

    [线性回归] 如果要用一句话来解释线性回归是什么的话,那么我的理解是这样子的:**线性回归,是从大量的数据中找出最优的线性(y=ax+b)拟合函数,通过数据确定函数中的未知参数,进而进行后续操作(预测 ...

  10. 梯度下降算法和牛顿算法原理以及使用python用梯度下降和最小二乘算法求回归系数

    梯度下降算法 以下内容参考 微信公众号 AI学习与实践平台 SIGAI 导度和梯度的问题 因为我们做的是多元函数的极值求解问题,所以我们直接讨论多元函数.多元函数的梯度定义为: 其中称为梯度算子,它作 ...

最新文章

  1. 大数据时代,谁的眼神锁定你?
  2. 百度分拆金融业务,融资19亿美元
  3. 后门技术(HOOK篇)之DT_RPATH
  4. ubuntu下配置交叉编译环境
  5. ckks方案优化最好的_站群如何优化才能提高SEO效果?站群优化方案有哪些?
  6. NumPy 矩阵库(Matrix)
  7. kali如何制作php字典_Kali Linux安装搜狗输入法
  8. 控制器及其中$scope
  9. ggbiplot设置分组_比PCA更好用的监督排序—LDA分析、作图及添加置信-ggord
  10. 曾在美国生产的苹果电脑,如今却败得一塌糊涂
  11. Ajax提交表单数据(包含文件)
  12. 213.打家劫舍II(力扣leetcode) 博主可答疑该问题
  13. 九大ICT企业年中业绩大比拼
  14. ms10_002(极光漏洞)渗透步骤——MSF搭建钓鱼网站
  15. 智慧路灯网关下的校园智慧路灯照明解决方案
  16. android平台数字看板,数据看板
  17. 关于Chrome浏览器主页被2345篡改
  18. 史上五大最光明的白帽黑客介绍 都有很大技术贡献
  19. BUGS 小胡的学习日志
  20. 把握出租车行驶的数据脉搏 :出租车轨迹数据给你答案!

热门文章

  1. UML中依赖(Dependency)和关联(Association)之间的区别
  2. 今天遇到的一个诡异的core和解决 std::sort
  3. 正则表达式案例分析 (二)
  4. C基础(41——45)
  5. “一夜成名”需要多久?他花了20年!
  6. Monkey学习笔记三:Monkey脚本编写
  7. [ERROR]-Error: failure: repodata/filelists.xml.gz from addons: [Errno 256] No more mirrors to try.
  8. 信息技术手册查重错误比对分析程序开发记录3
  9. ImportError: No module named 'requests.packages.urllib3'
  10. Python chapter 2amp;3 learning notes