scikit-learn的线性回归模型都是通过最小化成本函数来计算参数的,通过矩阵乘法和求逆运算来计算参数。当变量很多的时候计算量会非常大,因此我们改用梯度下降法,批量梯度下降法每次迭代都用所有样本,快速收敛但性能不高,随机梯度下降法每次用一个样本调整参数,逐渐逼近,效率高,本节我们来利用随机梯度下降法做拟合

梯度下降法

梯度下降就好比从一个凹凸不平的山顶快速下到山脚下,每一步都会根据当前的坡度来找一个能最快下来的方向。随机梯度下降英文是Stochastic gradient descend(SGD),在scikit-learn中叫做SGDRegressor。

样本实验

依然用上一节的房价样本

X = [[50],[100],[150],[200],[250],[300]]

y = [[150],[200],[250],[280],[310],[330]]

  1. import matplotlib.pyplot as plt

  2. from sklearn.linear_model import SGDRegressor

  3. from sklearn.preprocessing import StandardScaler

  4. plt.figure()#实例化作图变量

  5. plt.title("SGD")

  6. plt.xlabel('x')

  7. plt.ylabel('y')

  8. plt.grid(True)

  9. X_scaler = StandardScaler()

  10. y_scaler = StandardScaler()

  11. X = [[50],[100],[150],[200],[250],[300]]

  12. y = [[150],[200],[250],[280],[310],[330]]

  13. X = X_scaler.fit_transform(X)

  14. y = y_scaler.fit_transform(y)

  15. X_test = [[40],[400]]#用来做最终效果测试

  16. X_test = X_scaler.transform(X_test)

  17. plt.plot(X,y,'k.')

  18. model = SGDRegressor()

  19. model.fit(X,y.ravel())

  20. y_result = model.predict(X_test)

  21. plt.plot(X_test,y_result,'g-')

  22. plt.show()

效果图如下:

这里需要用StandardScaler来对样本数据做正规化,同时对测试数据也要做正规化

我们发现拟合出的直线和样本之间还是有一定偏差的,这是因为随机梯度是随着样本数量的增加不断逼近最优解的,也就是样本数量越多就越准确

优化效果

既然样本数多拟合的好,那么我们把已有的样本重复多次试一下,修改成如下

  1. import matplotlib.pyplot as plt

  2. from sklearn.linear_model import SGDRegressor

  3. from sklearn.preprocessing import StandardScaler

  4. plt.figure()#实例化作图变量

  5. plt.title("SGD")

  6. plt.xlabel('x')

  7. plt.ylabel('y')

  8. plt.grid(True)

  9. X_scaler = StandardScaler()

  10. y_scaler = StandardScaler()

  11. X = [[50],[100],[150],[200],[250],[300],[50],[100],[150],[200],

  12. [250],[300],[50],[100],[150],[200],[250],[300],[50],[100],

  13. [150],[200],[250],[300],[50],[100],[150],[200],[250],[300],

  14. [50],[100],[150],[200],[250],[300],[50],[100],[150],[200],

  15. [250],[300],[50],[100]]

  16. y = [[150],[200],[250],[280],[310],[330],[150],[200],[250],[280],

  17. [310],[330],[150],[200],[250],[280],[310],[330],[150],[200],

  18. [250],[280],[310],[330],[150],[200],[250],[280],[310],[330],

  19. [150],[200],[250],[280],[310],[330],[150],[200],[250],[280],

  20. [310],[330],[150],[200]]

  21. X = X_scaler.fit_transform(X)

  22. y = y_scaler.fit_transform(y)

  23. X_test = [[40],[400]]#用来做最终效果测试

  24. X_test = X_scaler.transform(X_test)

  25. plt.plot(X,y,'k.')

  26. model = SGDRegressor()

  27. model.fit(X,y.ravel())

  28. y_result = model.predict(X_test)

  29. plt.plot(X_test,y_result,'g-')

  30. plt.show()


这回靠谱了许多,实事上,如果再继续重复样本,效果会更逼近

SGDRegressor相关推荐

  1. 梯度下降回归SGDRegressor、岭回归(Ridge)和套索(Lasso)回归、套索最小角回归、ElasticNet回归、正交匹配追踪回归

    梯度下降回归SGDRegressor.岭回归(Ridge)和套索(Lasso)回归.套索最小角回归.ElasticNet回归.正交匹配追踪回归 目录

  2. 机器学习(13)岭回归(线性回归的改进)

    目录 一.基础理论 API 二.岭回归:预测波士顿房价 总代码 一.基础理论 岭回归:带有L2正则化的线性回归.(为了解决过拟合) 对病态数据的拟合要强于最小二乘法 API sklearn.linea ...

  3. 机器学习(11)线性回归(2)实战 -- 正规方程优化、梯度下降优化(波士顿房价预测)

    目录 一.波士顿房价预测(正规方程优化) API 1.获取数据集 2.划分数据集 3.标准化 4. 创建预估器,得到模型 5.模型评估(均方差评估) 代码 二.波士顿房价预测(正规方程优化) API ...

  4. 残差平方和ssr的计算公式为_如何为你的回归问题选择最合适的机器学习方法?...

    文章发布于公号[数智物语] (ID:decision_engine),关注公号不错过每一篇干货. 转自 | AI算法之心(公众号ID:AIHeartForYou) 作者 | 何从庆 什么是回归呢?回归 ...

  5. 如何为回归问题选择最合适的机器学习方法?

    作者 | 何从庆 本文经授权转载自 AI算法之心(id:AIHeartForYou) 在目前的机器学习领域中,最常见的三种任务就是:回归分析.分类分析.聚类分析.在之前的文章中,我曾写过一篇<1 ...

  6. Scikit-learn 发布 0.20版本!新增处理缺失值、合并Pandas等亮点功能

    整理 | Jane 出品 | AI科技大本营 之前一直预告 Scikit-learn 的新版本会在 9 月发布,在马上就要结束的 9 月,我们终于迎来了 Scikit-learn  0.20. 此版本 ...

  7. 基于LightGBM算法实现数据挖掘!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:黄雨龙,中国科学技术大学 对于回归问题,Datawhale已经梳理 ...

  8. 线性回归的改进-岭回归

    线性回归的改进-岭回归 1 API sklearn.linear_model.Ridge(alpha=1.0, fit_intercept=True,solver="auto", ...

  9. 线性回归之案例:波士顿房价预测

    线性回归之案例:波士顿房价预测 数据介绍   [13个特征值,1个目标值] 给定的这些特征,是专家们得出的影响房价的结果属性.此阶段不需要自己去探究特征是否有用,只需要使用这些特征.到后面量化很多特征 ...

最新文章

  1. py2exe——.py文件转换成exe
  2. pytorch默认初始化_小白学PyTorch | 9 tensor数据结构与存储结构
  3. VS2010 定位文件在solution中的位置
  4. Istio 1.7——进击的追风少年
  5. VIM 正则表达式搜索字符串
  6. python课程设计
  7. Python+django网页设计入门(6):文件上传与数据导入
  8. 黑帽大会:苹果网络服务器比微软易入侵
  9. HTML中select的option设置selected=“selected“无效的解决方案
  10. 虚拟机命令里面的光标不动了怎么办_Linux Sever简单笔记(第四堂课)之Linux下的文本编辑器vim(vim中常用的操作方式命令) - 我杨晓东太难了...
  11. 一共81个,开源大数据处理工具汇总(上)
  12. bootstrap实现导航栏的响应式布局,当在小屏幕、手机屏幕浏览时自动折叠隐藏
  13. 12个开源的后台管理系统
  14. 解决libpng warning: iCCP: known incorrect sRGB profile
  15. 你想要创建一个属于自己的网站吗?十大免费网站
  16. Web APls 阶段——第四节——案例:关闭淘宝二维码案例
  17. 企业信息安全需要做到的三点,可以有效的规避大部分风险
  18. 中国移动研究院人工智能中心前端面试题目整理
  19. 字符串匹配KMP算法讲解
  20. javascript 编码转换

热门文章

  1. Python自动发送邮件-smtplib和email库
  2. Python:向函数传递任意数量的实参
  3. delphi base64 java_Base64以及delphi、Java实现[转]
  4. C语言strstr()函数(在主字符串里查找子字符串,返回第一次找到的子字符串以及后面的字符串)
  5. Intel Realsense D435 在windows系统下运行时请修改相机隐私设置以确保摄像头正常运行(没啥子用,还是掉线)
  6. 利用最小二乘法,用直线拟合点时,为什么计算竖直距离而非垂直距离?为什么在线性回归分析中,求的是距离平方和最小,而不是距离之和最小?
  7. BeautifulSoup中的.text方法和get_text()方法的区别
  8. 搭建FastDFS分布式文件方式一(Docker版本)
  9. Spring 使用注解@DependsOn控制Bean加载顺序
  10. Logstash7.6.0同步MySQL到Elasticsearch