参考视频:03.梯度下降算法_哔哩哔哩_bilibili

显然使用穷举法效率太低了,如果权重多一些,时间复杂度将是指数级的增长。所以我们需要使用梯度下降算法来优化。

梯度Gradient: ∂ c o s t ∂ w \frac{\partial{cost}}{\partial{w}} ∂w∂cost​

用梯度来更新权重w:

w = w − α ∂ c o s t ∂ w w=w-\alpha\frac{\partial{cost}}{\partial{w}} w=w−α∂w∂cost​

(如果梯度大于0,则表示当前方向是误差增大的方向,需要往方向更新w;反之,如果梯度小于0,则表示沿当前方向是误差减小)

1 梯度下降算法

参考视频中 y = w x y=wx y=wx的样例,我们尝试一下 y = w x + b y=wx+b y=wx+b

还是以 y = 3 x + 2 y=3x+2 y=3x+2为例(事先不知道)

学习率设为0.01,训练2000次

代码如下:

import numpy as np
import matplotlib.pyplot as pltx_data = [1.0, 2.0, 3.0]
y_data = [5.0, 8.0, 11.0]w = 2.0
b = 1.0
lr = 0.01def forward(x):return w * x + b# 计算误差
def cost(xs, ys):cost = 0for x, y in zip(xs, ys):y_pred = forward(x)cost += (y_pred - y) ** 2return cost / len(xs)# 计算w的梯度
def gradient_w(xs, ys):grad = 0for x, y in zip(xs, ys):grad += 2 * x * (x * w + b - y)return grad / len(xs)# 计算b的梯度
def gradient_b(xs, ys):grad = 0for x, y in zip(xs, ys):grad += 2 * (x * w + b - y)return grad / len(xs)shuffle_array = [0, 1, 2]
cost_list = []for epoch in range(2000):   # 训练2000次cost_val = cost(x_data, y_data)grad_w_val = gradient_w(x_data, y_data)grad_b_val = gradient_b(x_data, y_data)w -= lr * grad_w_val    # 更新wb -= lr * grad_b_val    # 更新bprint('Epoch:', epoch, 'w=', w, ' b=', b, ' loss=', cost_val)# print('gw:', grad_w_val, 'gb:', grad_b_val)cost_list.append(cost_val)
print('Predict(after training', 4, forward(4))plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Cost')
plt.plot(np.arange(0, 2000, 1), np.array(cost_list))
plt.show()

2000次训练后结果如下,很逼近我们的答案w=3,b=2


2 随机梯度下降

用随机梯度下降的思想进一步优化,学习率调整为0.05,训练次数调整为100次(实际为更新300次)

和上面不同的主要是,每一次从训练集中不放回地抽取一组(x,y)计算梯度,更新我们的权重w和偏置b

我们可以定义一个shuffle数组存数据索引,每个epoch开始前打乱数组顺序,再根据打乱的索引取数据

shuffle_array = [0,1,2]
random.shuffle(shuffle_array) # 打乱顺序

完整代码如下:

import numpy as np
import matplotlib.pyplot as plt
import randomx_data = [1.0, 2.0, 3.0]
y_data = [5.0, 8.0, 11.0]w = 2.0
b = 1.0
lr = 0.05def forward(x):return w * x + bdef loss(x, y):y_pred = forward(x)return (y_pred - y) ** 2def gradient_w(x, y):return 2 * x * (x * w + b - y)def gradient_b(x, y):return 2 * (x * w + b - y)shuffle_array = [0,1,2]
cost_list = []for epoch in range(100):random.shuffle(shuffle_array)print(shuffle_array)for i in shuffle_array:x, y = x_data[i], y_data[i]cost_val = loss(x, y)grad_w_val = gradient_w(x, y)grad_b_val = gradient_b(x, y)w -= lr * grad_w_valb -= lr * grad_b_valprint('Epoch:', epoch, 'w=', w, ' b=', b, ' loss=', cost_val)cost_list.append(cost_val)
print('Predict(after training', 4, forward(4))plt.figure()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.plot(np.arange(0, 100, 1), np.array(cost_list))
plt.show()

可以发现在进一步优化后,我们只用了300次的训练就达到了预期的效果

线性模型(梯度下降随机梯度下降)相关推荐

  1. 坐标下降+随机梯度下降

    坐标下降+随机梯度下降 坐标轴下降法(Coordinate Descent, CD)是一种迭代法,通过启发式的方法一步步的迭代求解函数的最小值,和梯度下降法(GD)不同的时候,坐标轴下降法是沿着坐标轴 ...

  2. 【数据挖掘】神经网络 后向传播算法 ( 梯度下降过程 | 梯度方向说明 | 梯度下降原理 | 损失函数 | 损失函数求导 | 批量梯度下降法 | 随机梯度下降法 | 小批量梯度下降法 )

    文章目录 I . 梯度下降 Gradient Descent 简介 ( 梯度下降过程 | 梯度下降方向 ) II . 梯度下降 示例说明 ( 单个参数 ) III . 梯度下降 示例说明 ( 多个参数 ...

  3. 批梯度下降 随机梯度下降_梯度下降及其变体快速指南

    批梯度下降 随机梯度下降 In this article, I am going to discuss the Gradient Descent algorithm. The next article ...

  4. 线性回归随机梯度下降_线性回归的批次梯度与随机梯度下降

    线性回归随机梯度下降 In this article, we will introduce about batch gradient and stochastic gradient descent m ...

  5. 批量梯度下降 | 随机梯度下降 | 小批度梯度下降

    文章目录 1. 什么是梯度?求梯度有什么公式? 2. 批量梯度下降 | 随机梯度下降 | 小批度梯度下降 区别 3. 随机梯度下降的两种方式:原始形式 和 对偶形式 1. 什么是梯度?求梯度有什么公式 ...

  6. 梯度下降法—随机梯度下降

    1.算法描述 批量梯度下降的主要问题是它要用整个训练集来计算每一步的梯度,训练集大时算法特别慢.相反,随机梯度下降,每一步在训练集中随机选择一个实例,并且仅基于该单个实例来计算梯度. 与使用批量梯度下 ...

  7. 一般梯度、随机梯度、相对梯度和自然梯度

    一般梯度 也称常规梯度,就是 f ( w ⃗ ) f(\vec w) f(w ) 对 w ⃗ \vec w w 的偏导,即 ∂ f ( w ⃗ ) ∂ w ⃗ \frac{\partial f(\ve ...

  8. Sklearn官方文档中文整理4——随机梯度下降和最近邻篇

    Sklearn官方文档中文整理4--随机梯度下降和最近邻篇 1. 监督学习 1.5. 随机梯度下降 1.5.1. 分类[linear_model.SGDClassifier] 1.5.2. 回归[li ...

  9. 【转载】深度学习数学基础(二)~随机梯度下降(Stochastic Gradient Descent, SGD)

    Source: 作者:Evan 链接:https://www.zhihu.com/question/264189719/answer/291167114 来源:知乎 著作权归作者所有.商业转载请联系作 ...

最新文章

  1. 网站用户登录验证:Servlet+JSP VS Struts书剑恩仇录
  2. 用Xwt构建跨平台应用程序[转载]
  3. 签名工具 signtool.exe 参数简介
  4. C++ Primer 5th笔记(chap 17 标准库特殊设施)smatch
  5. python中可选参数_带可选参数的Python函数
  6. Spring4.x集成xfire1.26 问题汇总
  7. 爬虫职业道德----查看Robots.txt
  8. url即统一资源定位符
  9. 详解Android动画之Tween Animation
  10. JQueryDOM之CSS操作
  11. Linux 安装Nginx详细图解教程
  12. Mysql中的delimiter详解
  13. HDU 2234 无题I
  14. python 3模块导入(import)问题一则
  15. 信噪比 香农公式_「香农公式」信噪比/香农公式 - seo实验室
  16. @@@Blog总目录@@@
  17. 广告SDK平台中的CPA、CPS、CPM、CPT、CPC 是什么
  18. 1149:最长单词2
  19. SAP message TK 248 solved
  20. 发送短信并存入短信库

热门文章

  1. c语言实现24位彩色图像二值化
  2. 工业物联网创新方案亮相2018云栖大会
  3. Android高手秘笈之View的挂载
  4. 申请CVE的姿势总结
  5. 数学建模所需计算机知识
  6. Python编程——数字
  7. 微信提现报证书已过期
  8. 央视解说之韩乔生巅峰之作--夏普
  9. 计算机派位志愿填报技巧,海淀小升初哪些入学途径采取电脑派位 志愿又怎么填报 2021家长了解...
  10. 题解 SP2916 【GSS5 - Can you answer these queries V】