简述

线性回归模型是机器学习里面最基础的一种模型,是为了解决回归问题,学习机器学习从线性回归开始最好,网上关于机器学习的概述有很多,这里不再详细说明,本博文主要关注初学者常见的一些问题以及本人的一些思考和心得,然后会用Python代码实现线性回归,并对比sklearn实现的线性回归,会以实例的方式展现出来。

假设函数、损失函数和梯度下降法

首先,我们利用sklearn包来生成一组一元回归数据集

import numpy as np

import pandas as pd

from sklearn import datasets #sklearn生成数据集都在这里

from matplotlib import pyplot as plt

#生成一个特征的回归数据集

x,y=datasets.make_regression(n_features=1,noise=15,random_state=2020) plt.scatter(x,y)

plt.show()1

2

3

4

5

6

7

8

9

10

11

12

make_regression用于生成回归数据集,在jupyter里面是用Shift+Tab查看参数,大家如果想查什么资料,强烈建议大家多去看看官网的说明文档。

如上所示,我们用肉眼大概觉得下面的这条红色回归线比较合适

那是怎么求得的呢?

本样本集属于一元线性回归问题,我们假设(markdown实在是耗费时间,关键是我还不太会用,o(╥﹏╥)o,只能写在纸上贴出来)

问题一:为什么要用均方误差,而不用平均绝对误差?

回答:其实平均绝对误差也可以代表损失,只不过后面我们要用梯度下降法求参数k,b的偏导,而平均绝对误差带有绝对值,不方便求偏导(在0处不可导),因此选用均方误差,也好理解。

在机器学习中,损失函数要求可微可导,某些损失函数还要求二阶可导,例如xgboost,后面讲到xgboost时再展开。

问题二:为什么损失函数前面是1/2m,m个样本不是除以m就可以了吗?

回答:主要是为了求梯度时比较好看,抵消平方提出来的2,其实不影响最终的参数,因为加上2,只是相当于学习率(步长)变小了,不加上2,学习率就不用除以2,但是对于这个凸函数优化而言,最终都可以得到最小值,参数不会变化。这个问题在后面讲标准方程法时,我会具体证明:这个2对参数没有影响。

思考:使用均方误差作为线性回归的损失函数,有什么特点?

回答:对异常值非常敏感,由于带平方,如果有1个异常数据远离样本点,在平方的作用下,这个点和正常的回归线之间误差很大,而均方误差是基于整体误差最小,可能因为这一个异常点,而导致线性回归模型参数的改变。1

2

3

4

5

6

7

8

9

还有一种为什么用均方误差的解释,看这里

链接: 为什么用均方误差.

问题二的解释,请看:

链接: 为什么1/2m不会影响最终的参数

我们都知道这个损失函数是一个关于系数k,b的平方函数(因为只有一个特征),平方函数也是凸函数,因此采用梯度下降法,沿着负梯度改变,一定可以取到最小值,不存在局部最小值的问题(关于什么是凸函数,不懂的同学还请自行了解,这个比较重要,可以简单理解成单调函数)。

梯度下降法这块,相关的博文也很多(了解什么是随机梯度下降法,批量梯度下降法,小批量梯度下降法,这里用的是批量梯度下降法),这里只讲一点,为什么沿着负梯度的方向,可以取到最小值?

对系数求偏导,就是链式法则求复合函数导数,忘记的同学自己复习一下,这里不再展开。

直接贴图,写了这么久,才写了这么点,想详细点有心无力哇,/(ㄒoㄒ)/~~

这里面的θ0和θ1就是上面的b,k

Python实现一元线性回归

根据上面的推导公式,现在用Python来实现一元线性回归

class one_variable_linear(): #初始化参数,k为斜率,b为截距,a为学习率,n为迭代次数 def __init__(self,k,b,a,n): self.k =k self.b=b self.a=a self.n = n #梯度下降法迭代训练模型参数 def fit(self,x,y): #计算总数据量 m=len(x) #循环n次 for i in range(self.n): b_grad=0 k_grad=0 #计算梯度的总和再求平均 for j in range(m): b_grad += (1/m)*((self.k*x[j]+self.b)-y[j]) k_grad += (1/m)*((self.k*x[j]+self.b)-y[j])*x[j] #更新k,b self.b=self.b-(self.a*b_grad) self.k=self.k-(self.a*k_grad) #每迭代10次,就输出一次图像 if i%10==0: print('迭代{0}'.format(i)+'次') plt.plot(x,y,'b.') plt.plot(x,self.k*x+self.b,'r') plt.show() self.params= {'k':self.k,'b':self.b} #输出系数 return self.params #预测函数 def predict(self,x): y_pred =self.k * x + self.b return y_pred

lr=one_variable_linear(k=1,b=1,a=0.1,n=60)

lr.fit(x,y)1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

下面是迭代过程:

便得到了最开始的回归线,其中k=19.2369,b=0.58201

对比sklearn实现的一元线性回归

下面使用sklearn来实现一元线性回归

from sklearn.linear_model import LinearRegression

model = LinearRegression()

model.fit(x,y)

print(model.intercept_) #输出截距

print(model.coef_) #输出斜率1

2

3

4

5

0.5820048693454326

[19.2371827]

#sklearn实现的一元线性回归画图

plt.plot(x,y,'b.')

plt.plot(x,model.predict(x),'r')

plt.show()1

2

3

4

咦,和自己用Python实现的一元线性回归得到的参数,虽然很接近了,但还是不一样!

问题一:为什么不一样?

回答:其实我们的Python代码,里面参数都是比较随意的,比如迭代次数为60,很多情况下这个迭代次数并不能使模型收敛,只不过今晚对于这个数据集,我试了下,还可以;

用最大迭代次数来终止参数迭代,其实是不太好的方法,这里之所以用这个办法,是为了直观展示梯度下降法的迭代是怎么做的,比如:一般可以选择用△k、△b都小于0.001之类,来判断收敛,

if np.all(△θ) < 0.001:

stop iteration

但是,只要是梯度下降法,基本上不能得到代价函数最小值的参数,只能无限逼近,这个大家应该可以理解。

问题二:

那sklearn里面的参数到底是用什么办法计算得到的?

回答:矩阵法,标准方程法,这个下一篇再写,还是会用实例来写,毕竟语言能力不行;

sklearn毕竟是标准包,里面的代码都经过大量优化,平时直接调包就好。

思考:这个一元回归类Python代码可以优化吗?

回答:优化的点还有很多,比如没有推广到多元线性回归、多项式回归、带正则项的回归等等,大家有兴趣自己修改一下,加参数,加函数就行1

2

3

4

5

6

7

8

9

10

11

12

13

14

今天就写到这里,下篇介绍多元线性回归以及标准方程法。

文章来源: blog.csdn.net,作者:Dream-YH,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/weixin_44700798/article/details/110405473

python一元线性回归算法_手写算法—Python代码实现一元线性回归相关推荐

  1. python实现tomasulo算法_手写算法-python代码实现KNN

    本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,如有问题请及时联系我们以作处理 原理解析 KNN-全称K-Nearest Neighbor,最近邻算法,可以做分类任务,也可以做回归 ...

  2. 多元线性回归算法python实现_手写算法-Python代码推广多元线性回归

    1.梯度下降-矩阵形式 上篇文章介绍了一元线性回归,包括Python实现和sklearn实现的实例.对比,以及一些问题点,详情可以看这里: 链接: 手写算法-Python代码实现一元线性回归 里面封装 ...

  3. 前端算法及手写算法JavaScript

    一.手写算法 1.获取url中参数列表,保存为对象 function getUrlParam(){ //获取url中参数列表,保存为对象 var url="http://jjhs/dddh? ...

  4. 扫描线填充算法代码_手写算法并记住它:计数排序

    对于经典算法,你是否也遇到这样的情形:学时觉得很清楚,可过阵子就忘了? 本系列文章就尝试解决这个问题. 研读那些排序算法,细品它们的名字,其实都很贴切. 比如计数排序,所谓"计数" ...

  5. python多元非线性拟合csdn_手写算法-Python代码实现非线性回归

    生成非线性数据集 前面我们介绍了Python代码实现线性回归,今天,我们来聊一聊当数据呈现非线性时,这时我们继续用线性表达式去拟合,显然效果会很差,那我们该怎么处理?继续上实例(我们的代码里用到的数据 ...

  6. python机器学习手写算法系列——线性回归

    本系列另一篇文章<决策树> https://blog.csdn.net/juwikuang/article/details/89333344 本文源代码: https://github.c ...

  7. python手写字母识别_机器学习--kNN算法识别手写字母

    本文主要是用kNN算法对字母图片进行特征提取,分类识别.内容如下: kNN算法及相关Python模块介绍 对字母图片进行特征提取 kNN算法实现 kNN算法分析 一.kNN算法介绍 K近邻(kNN,k ...

  8. 手写算法-python代码实现Ridge(L2正则项)回归

    手写算法-python代码实现Ridge回归 Ridge简介 Ridge回归分析与python代码实现 方法一:梯度下降法求解Ridge回归参数 方法二:标准方程法实现Ridge回归 调用sklear ...

  9. 【Python】基于kNN算法的手写识别系统的实现与分类器测试

    基于kNN算法的手写识别系统 1.      数据准备 使用windows画图工具,手写0-9共10个数字,每个数字写20遍,共200个BMP文件. 方法如下,使用画图工具,打开网格线,调整像素为32 ...

  10. python机器学习手写算法系列——逻辑回归

    从机器学习到逻辑回归 今天,我们只关注机器学习到线性回归这条线上的概念.别的以后再说.为了让大家听懂,我这次也不查维基百科了,直接按照自己的理解用大白话说,可能不是很严谨. 机器学习就是机器可以自己学 ...

最新文章

  1. R语言ggplot2可视化:ggplot2可视化直方图(histogram)并在直方图的顶部外侧(top upper)或者直方图内部添加数值标签
  2. Python爬虫(十三)_JSON模块与JsonPath
  3. ios Swift 中文学习手册
  4. 分享memcache和memcached安装过程(转)
  5. 云原生时代,底层性能如何调优?
  6. org.apache.hadoop.security.AccessControlException: Permission denied: user=anonymous, access=EXECUTE
  7. html设置json请求头,当我想在zf2客户端代码中使用“application/json”时,接受请求标头是“text/html,application/xhtm ...(etc)”...
  8. AttributeError: module 'pymysql' has no attribute 'escape' 错误的出现以及解决
  9. 采用FTP协议实现文件的上传
  10. Python与Java曝漏洞,黑客利用FTP注入攻击可绕过防火墙
  11. 阶乘末尾蓝桥杯java_Java实现第九届蓝桥杯阶乘位数
  12. 批量梯度下降,随机梯度下降和小批量梯度下降的区别
  13. c++学习 -- #program once
  14. VisualStudio移动开发(C#、VB.NET)Smobiler开发平台——AlbumView相册控件的使用方式...
  15. mysql的读写分离配置
  16. 用户调研的操作步骤与过程模板
  17. QT 车牌号正则验证
  18. STM32F103 与 STM32F207/407编程的区别自我总结
  19. 计算机无法进入pe系统,u盘启动盘无法进入pe解决方法
  20. opencv convertTo函数

热门文章

  1. S4 HANA BP 维护客户信贷管理数据
  2. 汽车用组合仪表设计规范
  3. python中的数据存储-json
  4. idea weblogic 部署慢_IDEA+weblogic部署运行项目
  5. RTX2009管理器服务运行状态空白
  6. 华为数通HCNA学习资料
  7. 无刷滑环全面分析大全
  8. 来自吉普赛人祖传的神奇读心术.它能测算出你的内心感应
  9. 金蝶K3供应链-采购系统选项功能描述
  10. 转载:肖知兴:管理到底是个什么鬼,以及怎么破