点击“小詹学Python”,选择“置顶”公众号

重磅干货,第一时间送达

本文转载自数据分析挖掘与算法,禁二次转载

阅读本文需要的知识储备:

  • 高等数学

  • 运筹学

  • Python基础

引出梯度下降


对于,线性回归问题,上一篇我们用的是最小二乘法,很多人听到这个,或许会说:天杀的最小二乘法,因为很多人对它太敏感了。是的,从小到大,天天最小二乘法,能不能来点新花样。这里就用数学算法——梯度下降,来解决,寻优问题。

当然了,我们的目标函数还是:

在开始之前,我还是上大家熟知常见的图片。

梯度下山图片(来源:百度图片)

找了好久,我选了这张图片,因为我觉得这张图片很形象:天气骤变,一个人需要快速下山回家,但是他迷路了,不知道怎么回家,他知道自己家位于这座山海拔最低处。环顾四周,怎么样最快下山回家呢。他个子一定(假设1.8m大个子吧),每次迈步子为平时走路最大步长了,哈哈!(假设保持不变),要往哪个方向走才能使得:每迈出一步,自己下降的高度最大呢?只要每步有效下降高度最大,我们完全有理由相信,他会最快下山回家。

所以:他会告诉自己,我每次要找一个最好的下山方向(有点像“贪心”)。

其实,这个图还反映了另外一个问题,对于有多个极值点的情况,不同的初始出发点,梯度下降可能会陷入局部极小值点。就像一句古诗:不识庐山真面目,只缘身在此山中!这时候,就需要多点随机下山解决。当然了,解决线性回归问题的梯度下降是基于误差平方和,只有二次项,不存在多峰问题。

梯度下降的理论基础

我们都现在都知道这个人的任务是什么了:每次要找一个最好的下山方向。数学微分学告诉我们:其实这里的方向就是我们平时所说的:方向导数,它可以衡量函数值沿着某个方向变化的快慢,只要选择了好的方向(导数),就能快速达到(最大/最小值)。

(1)、梯度的定义

这里还是摆一个公式吧,否则看着不符合我的风格。方向导数定义就不扯远了,那是个关于极限的定义。这里给出三元函数梯度定义公式:

显然,让沿着与梯度方向,夹角为0或者180°时函数值增减最快。

其实,每个多元函数在任一点会有一个梯度。函数在某一点沿着梯度方向,函数值是变化最快的。这里就不过多证明了。

(2)、步长的求法

其实,我们可以设定一个指定步长。但是,这个指定步长到底设为多大合适。众所周知,过大会导致越过极小值点;过小在数据量大时会导致迭代次数过多。所以我们需要一套理论可以来科学得计算步长。保证在步长变换过程中,尽管有时可能会走回头路,但总体趋势是向驻点逼近。

这里简单说一下,假设在图中一点沿着梯度方向存在二阶偏导数,就可以泰勒展开到平方项,进而对这个关于步长的函数求导数,导函数零点就是此时最佳步长。详细可以参见运筹学推导。我尽量少写公式,多说明,哈哈。

用到的公式主要是步长lambda公式如下:

说明:下三角f表示梯度,海赛矩阵,X1,X2这里表示各个变量(这里是两个),对于连续函数,偏导数不分先后,所以不要算两遍,后面写程序会用到!这样每走一步,都会重新设置步长,与定步长相比,是不是更加智能了?

下降停止标志:梯度趋于0,或者小于给定的eps。

有了这些理论基础后,编程实现就容易多了,下面就编程实现了。

线性关系呢。最著名的当数最小二乘法了,很多人都知道。

梯度下降的Python实现


这里用的与上一片一样的数据。

(1)、用到的函数:

不同点的梯度函数,海赛矩阵函数,迭代主函数

这里用到的比如点乘函数,在第一篇《基于最小二乘法的——线性回归拟合(一)》里面有我是放在一个脚本里面的,所以这里没有写两次,你们可以把两个脚本放在一起是没有问题的。

程序代码:

 1#-----------------梯度下降法---------------- 2#返回梯度向量 3def dif(alpha,beta,x,y): 4   dif_alpha = -2*sum(err(alpha,beta,x,y)) 5   dif_beta = -2*dot(err(alpha,beta,x,y),x) 6   return(dif_alpha,dif_beta) 7#返回海赛矩阵 8def hesse(x): 9   return([[2*len(x),2*sum(x)],[2*sum(x),2*dot(x,x)]])10#计算lambda11def lam(x1,x2):12   s1 = dot(x1,[x2[0][0],x2[1][0]])13   s2 = dot(x1,[x2[0][1],x2[1][1]])14   return(dot(x1,x1)/dot([s1,s2],x1))15#导入数学、随机数模块16import math17import random18def grad(x,y):19   #设置最大计算次数20   n_max = 10021   k = 022   error_ = 0.00123   alpha,beta = random.random(),random.random()24   #计算梯度向量25   d_f = dif(alpha,beta,x,y)26   while(math.sqrt(dot(d_f,d_f))>error_ and k<n_max):27      h = hesse(x)28      lamb = lam(d_f,h)29      alpha,beta = [alpha-lamb*d_f[0],beta-lamb*d_f[1]]30      d_f = dif(alpha,beta,x,y)31      k+=132   else:33      return(alpha,beta,k,math.sqrt(dot(d_f,d_f)))34alpha,beta,k,error = grad(x,y)35print('*------------梯度下降-----------*')36print('系数为:',alpha,beta)37print('梯度下降拟合次数为:',k)38print('梯度为:',error)39print('误差为:',error_total(alpha,beta,x,y))40R_square = r_square(alpha,beta,x,y)41print('R方:',R_square)42if(R_square>0.95):43   print('在0.05置信水平下,该线性拟合不错!')44else:45   print('在0.05置信水平下,该线性拟合效果不佳!')46#可视化47plt.figure(2)48plt.scatter(x,y,marker = '*',color = 'b')49plt.xlabel('x label')50plt.ylabel('y label')51plt.title('Linear Fit')52plt.plot(x,[alpha+beta*x_i for x_i in x],color = 'r')53plt.show()5455print('#-------------多个初始点下山---------------#')56for i in range(10):57   alpha,beta,k,error = grad(x,y)58   R_square = r_square(alpha,beta,x,y)59  print('系数为:',alpha,beta,'误差为:',error_total(alpha,beta,x,y),'R方:',R_square)60   if(R_square>0.95):61      print('在0.05置信水平下,该线性拟合不错!')62   else:63      print('在0.05置信水平下,该线性拟合效果不佳!')64   print('*********************************************')

(2)、结果

*------------梯度下降-----------*
系数为:2.1672851935 2.5282506525292012
梯度下降拟合次数为:5
梯度为:1.2745428915606112e-05
误差为:9.898083702910405
R方:0.9558599578256541
在0.05置信水平下,该线性拟合不错!

拟合图如下

 1#-------------多个初始点下山---------------# 2系数为:2.167285891989479 2.528250598680116 3误差为:9.898083702904094 4R方:0.9558599578256822 5在0.05置信水平下,该线性拟合不错! 6********************************************* 7系数为:2.167282336941068 2.5282508727544775 8误差为:9.898083702990858 9R方:0.955859957825295310在0.05置信水平下,该线性拟合不错!11*********************************************12系数为:2.167285928067579 2.528250595898777313误差为:9.89808370290390514R方:0.955859957825683115在0.05置信水平下,该线性拟合不错!16*********************************************17系数为:2.1672811054772247 2.52825096769474818误差为:9.89808370305263519R方:0.955859957825019920在0.05置信水平下,该线性拟合不错!21*********************************************22系数为:2.1672836911979947 2.52825076834759323误差为:9.89808370294174724R方:0.955859957825514425在0.05置信水平下,该线性拟合不错!26*********************************************27系数为:2.1672838440861364 2.528250756561491628误差为:9.89808370293745629R方:0.955859957825533530在0.05置信水平下,该线性拟合不错!31*********************************************32系数为:2.1672853294236947 2.528250642050225333误差为:9.89808370290875134R方:0.955859957825661535在0.05置信水平下,该线性拟合不错!36*********************************************37系数为:2.1672857750441694 2.528250607695918438误差为:9.89808370290477839R方:0.955859957825679240在0.05置信水平下,该线性拟合不错!41*********************************************42系数为:2.16728609101821 2.528250583336422643误差为:9.8980837029032744R方:0.955859957825685945在0.05置信水平下,该线性拟合不错!46*********************************************47系数为:2.1672842715049874 2.52825072360983348误差为:9.89808370292675749R方:0.955859957825581250在0.05置信水平下,该线性拟合不错!51*********************************************

当然了,这里多个初始点随机梯度下降不需要,以后对于多元多峰函数这是有必要的

结果分析

1*----------梯度下降----------*2系数为:2.1672851935 2.52825065252920123梯度下降拟合次数为:54梯度为:1.2745428915606112e-055误差为:9.8980837029104056R方:0.95585995782565417在0.05置信水平下,该线性拟合不错!

可以对比最小二乘法与梯度下降误差,我们猜测肯定是梯度下降误差大一些,因为最小二乘法基于函数极值点求法肯定是全局最优的,梯度下降由于随机原因与步长可能是靠近最优,哈哈!在有多个极值点的情况下可能是局部最优解。

 1*----------最小二乘法-------* 2 3系数为:2.6786542252575067 2.538861110659364 4 5误差为:6.8591175428159215 6 7R方:0.9696451619135048 8 9在0.05置信水平下,该线性拟合不错!1011*------------梯度下降-----------*1213系数为:2.1672851935 2.52825065252920121415梯度下降拟合次数为:51617梯度为:1.2745428915606112e-051819误差为:9.8980837029104052021R方:0.95585995782565412223在0.05置信水平下,该线性拟合不错!

可以看出,梯度为:1.2745428915606112e-05,已经接近0了,跟据实际精度会有不同。显然,梯度下降这里不存在局部极值点问题,只能是步长迈过去了,但这个点一定是靠近最优解的,误差非常小。

推荐阅读:某坤学学吴亦凡,Python可视化分析「大碗宽面」b站弹幕和网易云音乐评论

基于梯度下降法的——线性回归拟合相关推荐

  1. 基于matlab的梯度下降法实现线性回归

    基于matlab的梯度下降法实现线性回归 1 绪论 1.1线性回归的定义 1.2单变量线性回归 1.3多变量线性回归 2 梯度下降 2.1 cost function 2.2 梯度下降:解决线性回归的 ...

  2. 梯度下降法求解线性回归

    梯度下降法求解线性回归 通过梯度下降法求解简单的一元线性回归 分别通过梯度下降算法和sklearn的线性回归模型(即基于最小二乘法)解决简单的一元线性回归实际案例,通过结果对比两个算法的优缺. 通过最 ...

  3. 线性回归介绍及分别使用最小二乘法和梯度下降法对线性回归C++实现

    回归:在这类任务中,计算机程序需要对给定输入预测数值.为了解决这个任务,学习算法需要输出函数f:Rn→R.除了返回结果的形式不一样外,这类问题和分类问题是很像的.这类任务的一个示例是预测投保人的索赔金 ...

  4. python梯度下降法实现线性回归_梯度下降法的python代码实现(多元线性回归)

    梯度下降法的python代码实现(多元线性回归最小化损失函数) 1.梯度下降法主要用来最小化损失函数,是一种比较常用的最优化方法,其具体包含了以下两种不同的方式:批量梯度下降法(沿着梯度变化最快的方向 ...

  5. 机器学习:用梯度下降法实现线性回归

    之前在 机器学习算法数学基础之 -- 线性代数篇 中,总结过求解线性回归的两种方法: 最小二乘法 梯度下降法 这篇文章重点总结一下梯度下降法中的一些细节和需要注意的地方. 梯度下降法是什么 假设有一个 ...

  6. 利用梯度下降法实现线性回归--python实现

    利用梯度下降法代替最小二乘法,求线性回归方程. 首先引用库 import numpy as np import matplotlib.pyplot as plt 定义相应的x和y np.random. ...

  7. python梯度下降法实现线性回归_【机器学习】线性回归——多变量向量化梯度下降算法实现(Python版)...

    [向量化] 单一变量的线性回归函数,我们将其假设为:hθ(χ)=θ0+θ1χh_\theta(\chi)=\theta_0+\theta_1\chihθ​(χ)=θ0​+θ1​χ但是如果我们的变量个数 ...

  8. 机器学习:批量梯度下降法(线性回归中的使用)

    一.推导目标函数 1)基础概念 多元线性回归模型: 多元线性回归的损失函数: 参数 theta:θ = (θ0, θ1, θ3, ..., θn) n:表示模型中有 n 个特征参数: θ1:表示 梯度 ...

  9. 梯度下降法实现线性回归, 实例---预测波士顿房价

    本文先手动实现一个线性回归模型, 然后用sklearn的线性回归模型作对比 import pandas as pd df = pd.read_csv('house_data.csv') #数据集可到网 ...

最新文章

  1. 中国毛纺织行业竞争分析与发展前景展望报告2022-2028年
  2. 一位产品总监打算这样管国家:首先得让大家交得起税。
  3. linux系统编程之使用C++(1)-打开关闭文件
  4. 史上最全python字符串操作指南
  5. C++ 常用集合算法
  6. 制造业数字化转型的困难_智能制造如何助力企业转型升级?百家制造业企业共谋数字化转型路...
  7. 二维数组和指针(包含交换二维数组行列)
  8. bzoj1051 [HAOI2006]受欢迎的牛
  9. Guice集成Properties配置
  10. Oracle全局临时表和私有临时表
  11. 家庭整理-《家庭断舍离》书中的精髓:如何通过家庭断舍离,来消除家居环境和家庭关系间的堆积物,从而拥有自由舒适的家庭关系。
  12. win10误删系统变量path恢复方法
  13. echarts饼图中间默认内容显示与data数据显示切换
  14. 感恩陪伴,链接未来 | Conflux杭州应用开发运营中心成立
  15. plsql导入excel数据
  16. 架构师培训:aop是什么
  17. 结构设计模式:复合模式
  18. 百度奖学金获得者徐立恒:执着创造价值
  19. 2021年中式烹调师(中级)考试题及中式烹调师(中级)找解析
  20. 如何解决element ui 表单resetFields 的indexOf报错

热门文章

  1. PHPStorm无法保存个人设置 ctrl左键无法找到类
  2. 用友老是显示服务器错误,客户端连服务器出现这样的错误框
  3. 进程和线程的概念、区别和联系
  4. mysql内部_使用mysql中的内部加入
  5. 计算两个向量间的欧氏距离_用Numpy实现常见距离度量
  6. 计算机科学与技术属于教育技术学么,计算机教育-计算机教育与教育技术学有什么不同吗?我学的专业是计算机教育,我想 爱问知识人...
  7. java printwrite_Java PrintWriter write(int)用法及代码示例
  8. java forEach使用
  9. rust灯灭了怎么办_Rust Rc 方法整理
  10. python qcut_Python之Pandas库学习(三):数据处理