台大李宏毅Machine Learning 2017Fall学习笔记 (8)Backpropagation
台大李宏毅Machine Learning 2017Fall学习笔记 (8)Backpropagation
当网络结构很复杂时,会有大量的参数。∇L(θ)\nabla L(\theta)是百万维的向量。如何高效地计算百万维的参数,使用反向传播算法来计算。BP并非是一个和GD不同的训练方法,BP就是GD,只是是一种比较有效率的计算方法。
数学知识铺垫:微积分中的链式法则,很简单。
还是以上节中手写数字识别为例。
xnx^n是一张输入图片,yny^n是网络的输出labellabel向量,y^n\hat y^n是该图片的真值labellabel向量。CnC^n是输出值和真实值的交叉熵损失。定义L(θ)L(\theta)为损失函数。
L(\theta)=\sum_{n=1}^NC^n(\theta)
损失函数对参数的导数为:
\frac{\partial L(\theta)}{\partial w}=\sum_{n=1}^N\frac{\partial C^n(\theta)}{\partial w}
如下图所示: ∂C∂w=∂z∂w∂C∂z\frac{\partial C}{\partial w}=\frac{\partial z}{\partial w}\frac{\partial C}{\partial z}, BackpropagationBackpropagation算法分为两个过程。
Forward pass
首先计算前向传播中的∂z∂w\frac{\partial z}{\partial w}。以上图为例。
\frac{\partial z}{\partial w_1}=x_1
\frac{\partial z}{\partial w_2}=x_2
显然这一步比较简单,某一参数的微分值就是其对应的输入值。注意要把所有 ∂z∂w\frac{\partial z}{\partial w}的值计算出来。
Backward pass
然后计算反向传播中损失函数对于激活函数输入值的偏微分∂C∂z\frac{\partial C}{\partial z}。
如下图中所示:∂C∂z=∂a∂z∂C∂a\frac{\partial C}{\partial z}=\frac{\partial a}{\partial z}\frac{\partial C}{\partial a},∂a∂z=σ′(z)\frac{\partial a}{\partial z}=\sigma'(z)。
利用链式法则计算∂C∂a\frac{\partial C}{\partial a}.
稍微整理一下,成为下图这样。
下图中很形象地展示了反向传播的概念,σ′(z)\sigma'(z)类似模拟电路中的放大器。
最后一步是计算∂C∂z′\frac{\partial C}{\partial z'}和∂C∂z′′\frac{\partial C}{\partial z''}。这分两种情况:1)z′z'和z′′z''的下一层是输出层;2)z′z'和z′′z''的下一层不是输出层。
Case1:Case1:输出层
Case2:Case2:非输出层
不断地递归计算∂C∂z\frac{\partial C}{\partial z},直至输出层,如下图。
注意:在backward pass过程中也需要对所有的zz,计算出∂C∂z\frac{\partial C}{\partial z}.
Summary
一图胜千言。
台大李宏毅Machine Learning 2017Fall学习笔记 (8)Backpropagation相关推荐
- 台大李宏毅Machine Learning 2017Fall学习笔记 (16)Unsupervised Learning:Neighbor Embedding
台大李宏毅Machine Learning 2017Fall学习笔记 (16)Unsupervised Learning:Neighbor Embedding
- 台大李宏毅Machine Learning 2017Fall学习笔记 (14)Unsupervised Learning:Linear Dimension Reduction
台大李宏毅Machine Learning 2017Fall学习笔记 (14)Unsupervised Learning:Linear Dimension Reduction 本博客整理自: http ...
- 台大李宏毅Machine Learning 2017Fall学习笔记 (13)Semi-supervised Learning
台大李宏毅Machine Learning 2017Fall学习笔记 (13)Semi-supervised Learning 本博客参考整理自: http://blog.csdn.net/xzy_t ...
- 台大李宏毅Machine Learning 2017Fall学习笔记 (12)Why Deep?
台大李宏毅Machine Learning 2017Fall学习笔记 (12)Why Deep? 本博客整理自: http://blog.csdn.net/xzy_thu/article/detail ...
- 台大李宏毅Machine Learning 2017Fall学习笔记 (11)Convolutional Neural Network
台大李宏毅Machine Learning 2017Fall学习笔记 (11)Convolutional Neural Network 本博客主要整理自: http://blog.csdn.net/x ...
- 台大李宏毅Machine Learning 2017Fall学习笔记 (10)Tips for Deep Learning
台大李宏毅Machine Learning 2017Fall学习笔记 (10)Tips for Deep Learning 注:本博客主要参照 http://blog.csdn.net/xzy_thu ...
- 台大李宏毅Machine Learning 2017Fall学习笔记 (9)Keras
台大李宏毅Machine Learning 2017Fall学习笔记 (9)Keras 本节课主要讲述了如何利用Keras搭建深度学习模型.Keras是基于TensorFlow封装的上层API,看上去 ...
- 台大李宏毅Machine Learning 2017Fall学习笔记 (7)Introduction of Deep Learning
台大李宏毅Machine Learning 2017Fall学习笔记 (7)Introduction of Deep Learning 最近几年,deep learning发展的越来越快,其应用也越来 ...
- 台大李宏毅Machine Learning 2017Fall学习笔记 (6)Logistic Regression
台大李宏毅Machine Learning 2017Fall学习笔记 (6)Logistic Regression 做Logistic Regression回归,需要3步. Step 1: Funct ...
最新文章
- appium-DesiredCapability详解与实战
- icinga安装介绍,监控软件
- buffer pool mysql_MySQL 5.7版本新特性(修改buffer pool,无需重启服务)
- [云炬Python学习笔记] Python读取指定文件夹下的文件
- 模仿u-boot的makefile结构
- 137_Power BI 自定义矩阵复刻Beyondsoft Calendar
- DOTNET零碎总结---VB.NET修改数据存在多个txtbox时,SQL语句的操作
- 花花的礼物 (huahua)
- 面试题解:输入一个数A,找到大于A的一个最小数B,且B中不存在连续相等的两个数字...
- Linux 应用市场易受RCE和供应链攻击,多个0day未修复
- 结合上下文和篇章特征的多标签情绪分类
- Helm 3 完整教程(五):Helm 内置对象详解
- 【渝粤教育】国家开放大学2018年春季 8647-21T工程经济与管理 参考试题
- BMS 项目过程中遇到的问题
- 几个linux中有趣的游戏
- 深度学习计算机视觉的简介_商业用途计算机视觉简介
- 树莓派4B引脚定义及运行实例
- 热血江湖游戏中断开服务器,为什么最近老是一进去游戏就提示与服务器断开 – 手机爱问...
- 西数硬盘刷新固件_机械硬盘选购:SMR避坑指南
- 基于java开发的网上商城系统
热门文章
- Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
- 【OpenCV】IplImage和char *的相互转换,以及极易忽视的细节
- mysql if join_如何在MySQL中使用JOIN编写正确的If … Else语句?
- oracle性能优化求生指南_Vue项目性能优化--实践指南,网上最全最详细
- linux删除静态arp,Linux如何清理ARP缓存?
- 在C ++中将String转换为Integer并将Integer转换为String
- winform布局、控件
- 为VS2010添加背景图
- ROS的学习(十九)用rosserial创建一个subscriber
- C++的学习(十)类和对象