在线性回归模型中使用梯度下降法

In [1]:

import numpy as np
import matplotlib.pyplot as plt
import datetime;print ('Run by CYJ,',datetime.datetime.now())
Run by CYJ, 2022-01-26 12:10:39.645986

In [2]:

np.random.seed(666)
x = 2 * np.random.random(size=100)
y = x * 3. + 4. + np.random.normal(size=100)

In [3]:

X = x.reshape(-1, 1)

In [4]:

X[:20]

Out[4]:

array([[1.40087424],[1.68837329],[1.35302867],[1.45571611],[1.90291591],[0.02540639],[0.8271754 ],[0.09762559],[0.19985712],[1.01613261],[0.40049508],[1.48830834],[0.38578401],[1.4016895 ],[0.58645621],[1.54895891],[0.01021768],[0.22571531],[0.22190734],[0.49533646]])

In [5]:

y[:20]

Out[5]:

array([8.91412688, 8.89446981, 8.85921604, 9.04490343, 8.75831915,4.01914255, 6.84103696, 4.81582242, 3.68561238, 6.46344854,4.61756153, 8.45774339, 3.21438541, 7.98486624, 4.18885101,8.46060979, 4.29706975, 4.06803046, 3.58490782, 7.0558176 ])

In [6]:

plt.scatter(x, y)
plt.show()

使用梯度下降法训练

In [7]:

def J(theta, X_b, y):try:return np.sum((y - X_b.dot(theta))**2) / len(X_b)except:return float('inf')

In [8]:

def dJ(theta, X_b, y):res = np.empty(len(theta))res[0] = np.sum(X_b.dot(theta) - y)for i in range(1, len(theta)):res[i] = (X_b.dot(theta) - y).dot(X_b[:,i])return res * 2 / len(X_b)

In [9]:

def gradient_descent(X_b, y, initial_theta, eta, n_iters = 1e4, epsilon=1e-8):theta = initial_thetacur_iter = 0while cur_iter < n_iters:gradient = dJ(theta, X_b, y)last_theta = thetatheta = theta - eta * gradientif(abs(J(theta, X_b, y) - J(last_theta, X_b, y)) < epsilon):breakcur_iter += 1return theta

In [10]:

X_b = np.hstack([np.ones((len(x), 1)), x.reshape(-1,1)])
initial_theta = np.zeros(X_b.shape[1])
eta = 0.01theta = gradient_descent(X_b, y, initial_theta, eta)

In [11]:

theta

Out[11]:

array([4.02145786, 3.00706277])

封装我们的线性回归算法

In [12]:

from playML.LinearRegression import LinearRegressionlin_reg = LinearRegression()
lin_reg.fit_gd(X, y)

Out[12]:

LinearRegression()

In [13]:

lin_reg.coef_

Out[13]:

array([3.00706277])

In [14]:

lin_reg.intercept_

Out[14]:

4.021457858204859

In [ ]:

												

[云炬python3玩转机器学习] 6-4 在线性回归模型中使用梯度下降法相关推荐

  1. [云炬python3玩转机器学习]5-10 更多关于线性回归的讨论

    线性回归的系数:有正有负,正负代表我们预测的特征与房价是正相关还是负相关 为正就是正相关,换句话就是这个特征越大,房价越高,系数绝对值的大小就决定了影响的程度 技巧:即使数据用线性回归算法进行预测,得 ...

  2. [云炬python3玩转机器学习] 5-7,8 多元线性回归正规解及其实现

    08 实现我们自己的 Linear Regression import numpy as np import matplotlib.pyplot as plt from sklearn import ...

  3. [云炬python3玩转机器学习笔记] 1-3课程所使用的主要技术栈

    课程环境 语言:Python3 框架:Scikit-learn 其他框架:numpy,matplotlib... IDE:Jupyter Notebook,PyCharm,ANACONDA 课程学习基 ...

  4. [云炬python3玩转机器学习笔记] 3-2 Jupter Notebook魔法命令

    xxxxxxxxxx### %run %run¶ In [1]:%run myscript/hello.py hello Machine Learning ! . . .In [2]:xxxxxxxx ...

  5. [云炬python3玩转机器学习笔记] 3-1 Jupyter Notebook

    1+2for _ in range(5):print("Hello, Machine Learning!")5+5*29+9print("天津云炬网络科技有限公司&quo ...

  6. [云炬python3玩转机器学习] 5-6最好的衡量线性回归法的指标: R Squared

    RMSE MAE 无法解决 不同模型之间的误差值比较,如预测房产数据误差为5万元,而预测学生成绩误差是10分,无法进行比较算法是用在哪个问题上好 R Squared就可以解决这种问题 y = y均值这 ...

  7. [云炬python3玩转机器学习笔记] 2-6关于回归和分类

    在这一章,我们了解到了,机器学习主要可以处理的两大类问题,是回归和分类.看起来,似乎有些局限,但是,非常出人意料的,在我们现实生活中,很多问题,都可以通过化简,或者转换的手段,转换成分类问题或者回归问 ...

  8. [云炬python3玩转机器学习笔记] 2-5机器学习相关的哲学思考

    2-5机器学习相关的哲学思考

  9. [云炬python3玩转机器学习笔记] 2-4批量学习、咋西安学习、参数学习和非参数学习

    机器学习的其他分类: 在线学习(online learining)和批量学习(离线学习 batch learning/offline learning): 批量学习(之前没有具体说明的话,都可以用批量 ...

  10. [云炬python3玩转机器学习笔记] 2-2机器学习主要任务

    机器学习(监督学习)的主要任务 一.分类:将给定的数据进行分类- 二分类任务:二选一的方式,yes/no- 多分类任务:结果不仅仅在两个结果中,而是很多结果,获得的结果很明确- 数字识别- 图像识别- ...

最新文章

  1. fcm算法的MATLAB实现,FCM算法的matlab程序(初步)
  2. Windows系统运维转linux系统运维的经历
  3. BZOJ 2287 【POJ Challenge】消失之物
  4. Java编程思想(第4版)读书笔记——01
  5. pytorch 半精度,提升pytorch推理性能
  6. ECSHOP二次开发文档【文件结构和数据库表分析】
  7. 《程序员代码面试指南第二版》Python实现(个人读书笔记)
  8. 为RecyclerView添加下拉刷新(PullToRefresh)功能
  9. mysql的日期和时间函数
  10. 【BZOJ4818】序列计数(动态规划,生成函数)
  11. iPhone升级iOS 15卡在请求更新上怎么办?
  12. android5.0刷机,真快!努比亚手机更新安卓5.0(附刷机方法)
  13. 怎么把vob格式的视频转换成mp4格式
  14. Aho-Corasick算法学习
  15. start-stop-daemon 用法
  16. MySQL基础 - 简单查询
  17. Linux 下的IP/子网计算器:ipcalc
  18. 十一假期,分享几个好玩儿的GitHub项目
  19. Log4j简单xml配置
  20. 【Mind】角膜上皮脱落康复记录

热门文章

  1. springmvc 实现javamail
  2. Android中Spinner的使用
  3. ARM的批量加载/存储指令
  4. poj 2079(旋转卡壳)
  5. hdu 1723(简单dp)
  6. hdu-2209 翻纸牌游戏
  7. 动态规划之——拦截导弹(nyoj79)
  8. NYOJ 655 光棍的yy
  9. hdu 1050 Moving Tables
  10. JSON 常量详情参考 (内含对中文不转义的参数)