Python代码实现一元线性回归

  • 简述
  • 假设函数、损失函数和梯度下降法
  • Python实现一元线性回归
  • 对比sklearn实现的一元线性回归

简述

线性回归模型是机器学习里面最基础的一种模型,是为了解决回归问题,学习机器学习从线性回归开始最好,网上关于机器学习的概述有很多,这里不再详细说明,本博文主要关注初学者常见的一些问题以及本人的一些思考和心得,然后会用Python代码实现线性回归,并对比sklearn实现的线性回归,会以实例的方式展现出来。

假设函数、损失函数和梯度下降法

首先,我们利用sklearn包来生成一组一元回归数据集

import numpy as np
import pandas as pd
from sklearn import datasets   #sklearn生成数据集都在这里
from matplotlib import pyplot as plt#生成一个特征的回归数据集
x,y=datasets.make_regression(n_features=1,noise=15,random_state=2020)  plt.scatter(x,y)
plt.show()

make_regression用于生成回归数据集,在jupyter里面是用Shift+Tab查看参数,大家如果想查什么资料,强烈建议大家多去看看官网的说明文档。

如上所示,我们用肉眼大概觉得下面的这条红色回归线比较合适

那是怎么求得的呢?

本样本集属于一元线性回归问题,我们假设(markdown实在是耗费时间,关键是我还不太会用,o(╥﹏╥)o,只能写在纸上贴出来)

问题一:为什么要用均方误差,而不用平均绝对误差?
回答:其实平均绝对误差也可以代表损失,只不过后面我们要用梯度下降法求参数k,b的偏导,而平均绝对误差带有绝对值,不方便求偏导(在0处不可导),因此选用均方误差,也好理解。
在机器学习中,损失函数要求可微可导,某些损失函数还要求二阶可导,例如xgboost,后面讲到xgboost时再展开。问题二:为什么损失函数前面是1/2m,m个样本不是除以m就可以了吗?
回答:主要是为了求梯度时比较好看,抵消平方提出来的2,其实不影响最终的参数,因为加上2,只是相当于学习率(步长)变小了,不加上2,学习率就不用除以2,但是对于这个凸函数优化而言,最终都可以得到最小值,参数不会变化。这个问题在后面讲标准方程法时,我会具体证明:这个2对参数没有影响。思考:使用均方误差作为线性回归的损失函数,有什么特点?
回答:对异常值非常敏感,由于带平方,如果有1个异常数据远离样本点,在平方的作用下,这个点和正常的回归线之间误差很大,而均方误差是基于整体误差最小,可能因为这一个异常点,而导致线性回归模型参数的改变。

还有一种为什么用均方误差的解释,看这里
链接: 为什么用均方误差.
问题二的解释,请看:
链接: 为什么1/2m不会影响最终的参数
我们都知道这个损失函数是一个关于系数k,b的平方函数(因为只有一个特征),平方函数也是凸函数,因此采用梯度下降法,沿着负梯度改变,一定可以取到最小值,不存在局部最小值的问题(关于什么是凸函数,不懂的同学还请自行了解,这个比较重要,可以简单理解成单调函数)。

梯度下降法这块,相关的博文也很多(了解什么是随机梯度下降法,批量梯度下降法,小批量梯度下降法,这里用的是批量梯度下降法),这里只讲一点,为什么沿着负梯度的方向,可以取到最小值?

对系数求偏导,就是链式法则求复合函数导数,忘记的同学自己复习一下,这里不再展开。
直接贴图,写了这么久,才写了这么点,想详细点有心无力哇,/(ㄒoㄒ)/~~


这里面的θ0和θ1就是上面的b,k

Python实现一元线性回归

根据上面的推导公式,现在用Python来实现一元线性回归

class one_variable_linear():#初始化参数,k为斜率,b为截距,a为学习率,n为迭代次数def __init__(self,k,b,a,n):self.k =k self.b=bself.a=aself.n = n#梯度下降法迭代训练模型参数def fit(self,x,y):#计算总数据量m=len(x)#循环n次for i in range(self.n):b_grad=0k_grad=0#计算梯度的总和再求平均for j in range(m):b_grad += (1/m)*((self.k*x[j]+self.b)-y[j])k_grad += (1/m)*((self.k*x[j]+self.b)-y[j])*x[j]#更新k,bself.b=self.b-(self.a*b_grad)self.k=self.k-(self.a*k_grad)#每迭代10次,就输出一次图像if i%10==0:print('迭代{0}'.format(i)+'次')plt.plot(x,y,'b.')plt.plot(x,self.k*x+self.b,'r')plt.show()self.params= {'k':self.k,'b':self.b}#输出系数return self.params#预测函数def predict(self,x):y_pred =self.k * x + self.breturn y_predlr=one_variable_linear(k=1,b=1,a=0.1,n=60)
lr.fit(x,y)

下面是迭代过程:




便得到了最开始的回归线,其中k=19.2369,b=0.58201

对比sklearn实现的一元线性回归

下面使用sklearn来实现一元线性回归

from sklearn.linear_model import LinearRegression
model = LinearRegression()
model.fit(x,y)
print(model.intercept_)   #输出截距
print(model.coef_)   #输出斜率

0.5820048693454326
[19.2371827]

#sklearn实现的一元线性回归画图
plt.plot(x,y,'b.')
plt.plot(x,model.predict(x),'r')
plt.show()

咦,和自己用Python实现的一元线性回归得到的参数,虽然很接近了,但还是不一样!

问题一:为什么不一样?
回答:其实我们的Python代码,里面参数都是比较随意的,比如迭代次数为60,很多情况下这个迭代次数并不能使模型收敛,只不过今晚对于这个数据集,我试了下,还可以;
用最大迭代次数来终止参数迭代,其实是不太好的方法,这里之所以用这个办法,是为了直观展示梯度下降法的迭代是怎么做的,比如:一般可以选择用△k、△b都小于0.001之类,来判断收敛,
if np.all(△θ) < 0.001:stop iteration
但是,只要是梯度下降法,基本上不能得到代价函数最小值的参数,只能无限逼近,这个大家应该可以理解。问题二:
那sklearn里面的参数到底是用什么办法计算得到的?
回答:矩阵法,标准方程法,这个下一篇再写,还是会用实例来写,毕竟语言能力不行;
sklearn毕竟是标准包,里面的代码都经过大量优化,平时直接调包就好。思考:这个一元回归类Python代码可以优化吗?
回答:优化的点还有很多,比如没有推广到多元线性回归、多项式回归、带正则项的回归等等,大家有兴趣自己修改一下,加参数,加函数就行

今天就写到这里,下篇介绍多元线性回归以及标准方程法。

手写算法—Python代码实现一元线性回归相关推荐

  1. 多元线性回归算法python实现_手写算法-Python代码推广多元线性回归

    1.梯度下降-矩阵形式 上篇文章介绍了一元线性回归,包括Python实现和sklearn实现的实例.对比,以及一些问题点,详情可以看这里: 链接: 手写算法-Python代码实现一元线性回归 里面封装 ...

  2. 手写算法-python代码实现Ridge(L2正则项)回归

    手写算法-python代码实现Ridge回归 Ridge简介 Ridge回归分析与python代码实现 方法一:梯度下降法求解Ridge回归参数 方法二:标准方程法实现Ridge回归 调用sklear ...

  3. python实现tomasulo算法_手写算法-python代码实现KNN

    本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,如有问题请及时联系我们以作处理 原理解析 KNN-全称K-Nearest Neighbor,最近邻算法,可以做分类任务,也可以做回归 ...

  4. python多元非线性拟合csdn_手写算法-Python代码实现非线性回归

    生成非线性数据集 前面我们介绍了Python代码实现线性回归,今天,我们来聊一聊当数据呈现非线性时,这时我们继续用线性表达式去拟合,显然效果会很差,那我们该怎么处理?继续上实例(我们的代码里用到的数据 ...

  5. python机器学习手写算法系列——线性回归

    本系列另一篇文章<决策树> https://blog.csdn.net/juwikuang/article/details/89333344 本文源代码: https://github.c ...

  6. python机器学习手写算法系列——kmeans聚类

    从机器学习到kmeans 聚类是一种非监督学习,他和监督学习里的分类有相似之处,两者都是把样本分布到不同的组里去.区别在于,分类分析是有标签的,聚类是没有标签的.或者说,分类是有y的,聚类是没有y的, ...

  7. 【机器学习与算法】python手写算法:Cart树

    [机器学习与算法]python手写算法:Cart树 背景 代码 输出示例 背景 Cart树算法原理即遍历每个变量的每个分裂节点,找到增益(gini或entropy)最大的分裂节点进行二叉分割. 这里只 ...

  8. python机器学习手写算法系列——逻辑回归

    从机器学习到逻辑回归 今天,我们只关注机器学习到线性回归这条线上的概念.别的以后再说.为了让大家听懂,我这次也不查维基百科了,直接按照自己的理解用大白话说,可能不是很严谨. 机器学习就是机器可以自己学 ...

  9. python识别数字程序_python实现识别手写数字 python图像识别算法

    写在前面 这一段的内容可以说是最难的一部分之一了,因为是识别图像,所以涉及到的算法会相比之前的来说比较困难,所以我尽量会讲得清楚一点. 而且因为在编写的过程中,把前面的一些逻辑也修改了一些,将其变得更 ...

最新文章

  1. php排除无效字查询,php删除无效的字符
  2. 平台电商类的增长策略:从用户激励到养成类游戏
  3. 云幸福–如何在几分钟内安装新的OpenShift Container Platform 3.7
  4. C语言排序方法-----冒泡排序法
  5. Android Camera 编程从入门到精通
  6. jquery.rotate.js 转盘抽奖示例
  7. Hyper-V使用手记(一):无法引导安装FreeBSD7
  8. 根据ip查询所在国家地区(国内外ip均适用)
  9. 技术图文:基于“科比投篮”数据集学Pandas
  10. 计算机为什么不能新建文档,电脑无法新建word文档怎么办
  11. 谷梁科技多元化一卡通系统应用
  12. STM32与DS1302设计时钟芯片,超详细
  13. win7 install solution for intel SKL and BSW platform
  14. 测试鼠标传感器的软件,鼠标该怎么选择?给大家说下鼠标传感器的差距
  15. 我,27岁,程序员,9月无情被辞:想给做开发的提个醒…
  16. 五剑同辉 聚力安全:绿盟科技五大安全实验室震撼发布
  17. ERP 系统的核心是什么?有什么作用?
  18. TeamViewe远程桌面软件连接移动设备的控制面板教程
  19. python高级练习题:多米诺平铺 - 5×2N局【难度:4级】--景越Python编程实例训练营,不同难度Python习题,适合自学Python的新手进阶
  20. PHP 一句话木马 @eval($_POST[‘hack‘]); 语句解析及靶机演示

热门文章

  1. 硬件设计——信号放大的电路
  2. 网易视频云:HBase —— RegionServer宕机案件侦查
  3. MFC vc_mbcsmfc.exe安装失败
  4. 《kaldi语音识别实战》阅读笔记:三音素模型训练—train_deltas.sh解析
  5. Powershell如何修改组策略(group policy)
  6. echarts 梯形柱状图
  7. 广联达java开发_2019年广联达JAVA开发实习面试经验
  8. 【面经】广联达-C++软件开发工程师
  9. 2019年创业亏 800 万元,可以给到你的一些建议
  10. 对计算机病毒防治最科学的方法,常见的计算机病毒防范方法有哪些