随机梯度下降算法(Stochastic gradient descent,SGD)在神经网络模型训练中,是一种很常见的优化算法。这种算法是基于梯度下降算法产生的,所以要理解随机梯度下降算法,必须要对梯度下降算法有一个全面的理解。

梯度下降:

这个算法我在之前的博文Logistic Regression的数学推导过程以及Python实现 中有详细的说明介绍,这里我们再来简单回顾一下梯度下降算法:假设在逻辑斯蒂回归中,预测函数为 hθ(x)=θ0+θ1x1+θ2x2+...+θnxn{h_\theta }(x) = {\theta _0} + {\theta _1}{x_1} + {\theta _2}{x_2} + ... + {\theta _n}{x_n}hθ​(x)=θ0​+θ1​x1​+θ2​x2​+...+θn​xn​,我们用平方损失函数可以得到这个函数的损失函数:J(θ)=12∑i=1m(hθ(x)−y)2J(\theta ) = \frac{1}{2}\sum\limits_{i = 1}^m {{{({h_\theta }(x) - y)}^2}} J(θ)=21​i=1∑m​(hθ​(x)−y)2
我们的目标就是最小化函数的损失函数,我们就对每一个 θi{\theta _i}θi​ 超参数求偏导,就可以得到当前这一轮的梯度,然后损失函数向梯度的反方向进行更新,不断这样进行迭代更新,就可以得到超参数的全局最优解。运用链式求导方法,数学过程可以表示成:
∂∂θjJ(θ)=∂∂θj12(hθ(x)−y)2=(hθ(x)−y)⋅∂∂θj(hθ(x)−y)=(hθ(x)−y)⋅∂∂θj(∑i=0nθixi−y)=(hθ(x)−y)xj\frac{\partial }{{\partial {\theta _j}}}J(\theta ) = \frac{\partial }{{\partial {\theta _j}}}\frac{1}{2}{({h_\theta }(x) - y)^2} = ({h_\theta }(x) - y) \cdot \frac{\partial }{{\partial {\theta _j}}}({h_\theta }(x) - y) \\= ({h_\theta }(x) - y) \cdot \frac{\partial }{{\partial {\theta _j}}}(\sum\limits_{i = 0}^n {{\theta _i}{x_i}} - y) = ({h_\theta }(x) - y){x_j}∂θj​∂​J(θ)=∂θj​∂​21​(hθ​(x)−y)2=(hθ​(x)−y)⋅∂θj​∂​(hθ​(x)−y)=(hθ​(x)−y)⋅∂θj​∂​(i=0∑n​θi​xi​−y)=(hθ​(x)−y)xj​
这是每一轮迭代的梯度,我们加上 learninglearninglearning rateraterate α\alphaα,就可以得到完整的梯度下降的公式:
θj:=θj−α(hθ(x)−y)xj{\theta _j}: = {\theta _j} - \alpha ({h_\theta }(x) - y){x_j}θj​:=θj​−α(hθ​(x)−y)xj​
这个过程就像是在一个山脉中寻找一个最低的山谷,我们用学习率为 α\alphaα 的步长一步步地向山谷的大致方向移动,我们每一步只能向山谷的方向靠近,每一步都在进步,整个过程可以参考下面的动图(图片来自知乎-量子位):

在图中可见,小球从山顶从不同的方向梯度滚下山,这就是梯度下降的过程。但是梯度下降算法每一步的更新都需要计算所有超参数的梯度,迭代速度必然会很慢,我们有没有比较快速的梯度下降算法呢,这里就可以用到我们的随机梯度下降算法
这个算法的流程就是在每次更新的时候使用一个样本进行梯度下降,所谓的随机二字,就是说我们可以随机用一个样本来表示所有的样本,来调整超参数 θ\thetaθ ,算法的公式如下所示:
LoopLoopLoop {
forforfor iii ininin range(m):range(m):range(m): {
θj:=θj+ α(y(i)−hθ(x(i)))xj(i){\theta _j}: = {\theta _j}{\text{ + }}\alpha ({y^{(i)}} - {h_\theta }({x^{(i)}}))x_j^{(i)}θj​:=θj​ + α(y(i)−hθ​(x(i)))xj(i)​
}
}
因为这个样本是随机的,所以每次迭代没有办法得到一个准确的梯度,这样一来虽然每一次迭代得到的损失函数不一定是朝着全局最优方向,但是大体的方向还是朝着全局最优解的方向靠近,直到最后,得到的结果通常就会在全局最优解的附近。这种算法相比普通的梯度下降算法,收敛的速度更快,所以在一般神经网络模型训练中,随机梯度下降算法 SGD 是一种非常常见的优化算法。
这就是大家在训练神经网络模型中常见的随机梯度下降SGD算法,希望可以帮助大家在理解优化算法上有所帮助,谢谢。

随机梯度下降SGD算法理解相关推荐

  1. 批量梯度下降(BGD)、随机梯度下降(SGD)以及小批量梯度下降(MBGD)的理解

    批量梯度下降(BGD).随机梯度下降(SGD)以及小批量梯度下降(MBGD)的理解 </h1><div class="clear"></div> ...

  2. 随机梯度下降(SGD)与经典的梯度下降法的区别

    随机梯度下降(SGD)与经典的梯度下降法的区别 经典的优化方法,例如梯度下降法,在每次迭代过程中需要使用所有的训练数据,这就给求解大规模数据优化问题带来挑战. 知识点:随机梯度下降法(SGD).小批量 ...

  3. 批量梯度下降(BGD)、随机梯度下降(SGD)以及小批量梯度下降(MBGD)及 batch、epoch、iteration

    先介绍一下梯度下降:梯度下降是一种用于机器学习训练参数的一种优化方法.对损失函数进行梯度下降,"梯度"指误差梯度或误差斜率,"下降"指沿着误差斜率移动到误差较小 ...

  4. 深度学习中的随机梯度下降(SGD)简介

    随机梯度下降(Stochastic Gradient Descent, SGD)是梯度下降算法的一个扩展. 机器学习中反复出现的一个问题是好的泛化需要大的训练集,但大的训练集的计算代价也更大.机器学习 ...

  5. 使用随机梯度下降SGD的BP反向传播算法的PyTorch代码实现

    Index 目录索引 写在前面 PyTorch的 .data() PyTorch的 .item() BP with SGD的PyTorch代码实现 参考文章 写在前面 本文将用一个完整的例子,借助Py ...

  6. python实现随机梯度下降(SGD)

    使用神经网络进行样本训练,要实现随机梯度下降算法.这里我根据麦子学院彭亮老师的讲解,总结如下,(神经网络的结构在另一篇博客中已经定义): def SGD(self, training_data, ep ...

  7. 几种优化算法的读书笔记——梯度下降、牛顿法、拟牛顿法、随机梯度下降、AdaGrad、RMSProp、Adam及选择优化算法的建议

    文章目录 1 梯度下降 1.1 特点 1.2 思想 1.3 数学基础 1.4 具体算法 2 牛顿法和拟牛顿法 2.1 特点 2.2 牛顿法 2.2.1 数学基础 2.2.2 思想 2.2.3 具体算法 ...

  8. 神经网络优化算法:随机梯度下降

    什么是优化算法? 优化算法的功能,是通过改善训练方式,来最小化(或最大化)损失函数E(x). 损失函数E(x): 模型内部有些参数,是用来计算测试集中目标值Y的真实值和预测值的偏差程度的,基于这些参数 ...

  9. 梯度下降法的三种形式BGD(批量梯度下降)、SGD(随机梯度下降)以及MBGD(小批量梯度下降)

    在应用机器学习算法时,我们通常采用梯度下降法来对采用的算法进行训练.其实,常用的梯度下降法还具体包含有三种不同的形式,它们也各自有着不同的优缺点. 下面我们以线性回归算法来对三种梯度下降法进行比较. ...

  10. NeurIPS'18 | 种群进化随机梯度下降深度神经网络优化算法框架

    ,欢迎关注公众号:论文收割机(paper_reader) 因为排版问题,很多图片和公式无法直接显示,欢迎关注我们的公众号点击目录来阅读原文. 原文点击 ↓ 种群进化随机梯度下降深度神经网络优化算法框架 ...

最新文章

  1. 需要恢复中断状态的一个场景
  2. php的Traits属性以及基本用法
  3. (五)Cisco dhcp snooping实例3-多交换机环境(DHCP服务器和DHCP客户端位于同VLAN)...
  4. Nginx 优化详解
  5. PHP点击跳转QQ,thinkphp3.2 获取QQconnect Login 跳转到的地址值
  6. python 提升效率_@Python 程序员,如何最大化提升编码效率?
  7. java比赛题目_【蓝桥杯2016第七届比赛题目】JAVA A组
  8. 如何生成 Flink 作业的交互式火焰图?
  9. HttpClient4.3.x的连接管理
  10. python读usb_使用Python来操作Microchip安全芯片
  11. php 密匙加密解密,带密匙的php加密解密示例分享
  12. c语言多位数加减,两个超长正整数的加法
  13. IAR for STM8安装教程
  14. C语言 Linux网络编程(C/S架构) 在线词典
  15. 【统计学习方法】朴素贝叶斯
  16. DT时代商业革命,大数据金融行业应用发展分析
  17. 集成XXL-JOB任务调度中心
  18. 【年终盘点之二】2020 区块链创新项目里程碑事件
  19. 计算机配置及性能测试,硬件配置及性能测试_联想笔记本电脑_笔记本评测-中关村在线...
  20. 神经网络优化图片大全,神经网络优化图片下载

热门文章

  1. 一文吃透strcmp函数
  2. 用英语介绍计算机系统,如何用英语介绍计算机系统
  3. 介绍几款免费APP在线制作社开发生成工具
  4. Spring代码实例系列-绪论
  5. 使用adblock plus浏览器插件屏蔽广告
  6. ROS机器人操作系统学习记录
  7. 如何撰写专利说明书摘要?
  8. ucgui 字体生成与字体个性化编辑
  9. html ios视频播放器,iOS 视频播放器(整理)
  10. J2EE框架学习经典总结