摘要

本文给出 Batch Normalization 函数的定义, 并求解其在反向传播中的梯度

相关

配套代码, 请参考文章 :

Python和PyTorch对比实现批标准化Batch Normalization函数及反向传播

本文仅介绍Batch Normalization的训练过程, 测试或推理过程请参考 :

Batch Normalization的测试或推理过程及样本参数更新方法

系列文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981

正文

1. 概念

批标准化 (Batch Normalization) 的思想来自传统的机器学习, 主要为了处理数据取值范围相差过大的问题.
比如, 正常成年人每升血液中所含血细胞的数量:

项目 数量
红细胞计数 RBC 3.5×1012∼5.5×10123.5×10^{12} \sim 5.5×10^{12}3.5×1012∼5.5×1012 个
白细胞计数 WBC 5.0×109∼10.0×1095.0×10^9 \sim 10.0×10^95.0×109∼10.0×109 个
血小板计数 PLT 1.5×1011∼3.5×10111.5×10^{11} \sim 3.5×10^{11}1.5×1011∼3.5×1011 个
血红蛋白 Hb 110∼160g/L110 \sim 160g/L110∼160g/L

如果这些指标发生异常变化, 人体就可能患病.
这些数据不仅量级差别非常大, 血红蛋白的单位还和其他项目不一样, 不可能直接用于机器学习.
传统的标准化方法 (Normalization) 是将这些数据统一缩放为 0 ~ 1 之间的数据.

深度神经网络学习过程中的 Batch Normalization 与之类似, 不同点在于数据规模非常大, 只能分批处理, 故称为批标准化.

2. 定义

批标准化是对同一个指标下的数据进行处理的, 与其他指标无关.
将同一个项目下的数据用向量 x 表示:
x=(x1,x2,x2,⋯ ,xk)x = (x_1,x_2,x_2,\cdots,x_k) x=(x1​,x2​,x2​,⋯,xk​)

均值 mmm 及方差 vvv 是标量 :
m=∑t=1kxt/n  v=∑t=1k(xt−m)2/nm=\sum_{t=1}^{k}x_{t}/n\\ \;\\ v =\sum_{t=1}^{k} (x_{t} - m)^2/n m=t=1∑k​xt​/nv=t=1∑k​(xt​−m)2/n

为防止分母为零, 设一个极小数 ε\varepsilonε, 如 ε=10−5\varepsilon=10^{-5}ε=10−5, 则数据标准化为 :
si=xi−mv+εs_{i} = \frac{x_{i} - m}{\sqrt{v + \varepsilon}} si​=v+ε​xi​−m​

为了增强数据的表征力, 添加一个线性变换, 得 :
yi=w⋅si+b  yi  为  xi  经过  BatchNormalization  转换后的数据  w  和  b  是标量,对本批次本指标中所有si是相同的y_i =w \cdot s_i + b\\ \;\\ y_i \;为\;x_i\;经过\;Batch Normalization\;转换后的数据\\ \;\\ w \;和\;b\;是标量, 对本批次本指标中所有 s_i 是相同的 yi​=w⋅si​+byi​为xi​经过BatchNormalization转换后的数据w和b是标量,对本批次本指标中所有si​是相同的

3. 训练过程中的反向传播的梯度

3.1 误差 e 对 x 的梯度

考虑一个 k 维输入向量 x , 经 Batch Normalization 得到向量 y, 往前 forward 传播得到误差值 error (标量 e ). 上游的误差梯度向量 ∇e(y)\nabla e_{(y)}∇e(y)​ 已在反向传播时得到, 求 e 对 x 的梯度.

已知 :
e=forward(y)  ∇e(y)=dedy=(∂ey1,∂ey2,∂ey3,⋯ ,∂eyk)  m=∑t=1kxt/k  v=∑t=1k(xt−m)2/k  si=xi−mv+ε  yi=w⋅si+be=forward(y)\\ \;\\ \nabla e_{(y)}=\frac{de}{dy}=(\frac{\partial e}{y_1}, \frac{\partial e}{y_2}, \frac{\partial e}{y_3}, \cdots, \frac{\partial e}{y_k} )\\ \;\\ m=\sum_{t=1}^{k}x_{t}/k\\ \;\\ v =\sum_{t=1}^{k} (x_{t} - m)^2/k\\ \;\\ s_{i} = \frac{x_{i} - m}{\sqrt{v + \varepsilon}}\\ \;\\ y_i =w \cdot s_i + b\\ e=forward(y)∇e(y)​=dyde​=(y1​∂e​,y2​∂e​,y3​∂e​,⋯,yk​∂e​)m=t=1∑k​xt​/kv=t=1∑k​(xt​−m)2/ksi​=v+ε​xi​−m​yi​=w⋅si​+b
求解过程 :

均值 mmm 和方差 vvv 是标量 :
dmdxi=1/k  dvdxi=2k∑t=1k(xt−m)(dxtdxi−dmdxi)=2k∑t=1k(xt−m)(dxtdxi−1k)\frac{dm}{dx_i} = 1/k \;\\ \frac{dv}{dx_i}=\frac{2}{k}\sum_{t=1}^{k}(x_t-m)(\frac{dx_t}{dx_i}-\frac{dm}{dx_i})=\frac{2}{k}\sum_{t=1}^{k}(x_t-m)(\frac{dx_t}{dx_i}-\frac{1}{k}) dxi​dm​=1/kdxi​dv​=k2​t=1∑k​(xt​−m)(dxi​dxt​​−dxi​dm​)=k2​t=1∑k​(xt​−m)(dxi​dxt​​−k1​)

=2k∑t=1k(xt−m)dxtdxi−2k∑t=1k(xt−m)1k  ∑t=1k(xt−m)=0=\frac{2}{k}\sum_{t=1}^{k}(x_t-m)\frac{dx_t}{dx_i}-\frac{2}{k}\sum_{t=1}^{k}(x_t-m)\frac{1}{k}\\ \;\\ \sum_{t=1}^{k}(x_t-m)=0 =k2​t=1∑k​(xt​−m)dxi​dxt​​−k2​t=1∑k​(xt​−m)k1​t=1∑k​(xt​−m)=0

dxtdxi={1,t=i0,t≠i\frac{dx_t}{dx_i}=\left\{ \begin{array}{rr} 1, & t = i\\ 0, & t \neq i \end{array} \right. dxi​dxt​​={1,0,​t=it̸​=i​

dvdxi=2(xi−m)/k\frac{dv}{dx_i} = 2(x_i-m)/k dxi​dv​=2(xi​−m)/k

向量 sss 对向量 xxx 求导是一个雅可比矩阵 :
∇s(x)=dsdx=(∂s1/∂x1∂s1/∂x2∂s1/∂x3⋯∂s1/∂xk∂s2/∂x1∂s2/∂x2∂s2/∂x3⋯∂s2/∂xk∂s3/∂x1∂s3/∂x2∂s3/∂x3⋯∂s3/∂xk⋮⋮⋮⋱⋮∂sk/∂x1∂sk/∂x2∂sk/∂x3⋯∂sk/∂xk)\nabla s_{(x)}=\frac{ds}{dx}= \begin{pmatrix} \partial s_1/\partial x_1 & \partial s_1/\partial x_2 & \partial s_1/\partial x_3 &\cdots & \partial s_1/\partial x_k\\ \partial s_2/\partial x_1 & \partial s_2/\partial x_2 & \partial s_2/\partial x_3 &\cdots & \partial s_2/\partial x_k\\ \partial s_3/\partial x_1 & \partial s_3/\partial x_2 & \partial s_3/\partial x_3 &\cdots & \partial s_3/\partial x_k\\ \vdots& \vdots & \vdots &\ddots & \vdots\\ \partial s_k/\partial x_1 & \partial s_k/\partial x_2 & \partial s_k/\partial x_3 &\cdots & \partial s_k/\partial x_k\\ \end{pmatrix} ∇s(x)​=dxds​=⎝⎜⎜⎜⎜⎜⎛​∂s1​/∂x1​∂s2​/∂x1​∂s3​/∂x1​⋮∂sk​/∂x1​​∂s1​/∂x2​∂s2​/∂x2​∂s3​/∂x2​⋮∂sk​/∂x2​​∂s1​/∂x3​∂s2​/∂x3​∂s3​/∂x3​⋮∂sk​/∂x3​​⋯⋯⋯⋱⋯​∂s1​/∂xk​∂s2​/∂xk​∂s3​/∂xk​⋮∂sk​/∂xk​​⎠⎟⎟⎟⎟⎟⎞​

当 i=ji = ji=j 时,

∂si∂xj=(1−1/k)(v+ε)−(xi−m)(v+ε)−0.5(xj−m)/kv+ε  =k−1−sisjkv+ε\frac{\partial s_i}{\partial x_j}=\frac{(1-1/k)(\sqrt{v + \varepsilon}) - (x_i - m)(v + \varepsilon)^{-0.5}(x_j-m)/k}{v + \varepsilon}\\ \;\\ =\frac{k-1 - s_is_j}{k\sqrt{v + \varepsilon}} ∂xj​∂si​​=v+ε(1−1/k)(v+ε​)−(xi​−m)(v+ε)−0.5(xj​−m)/k​=kv+ε​k−1−si​sj​​

当 i≠ji \neq ji̸​=j 时,
∂si∂xj=(−1/k)(v+ε)−(xi−m)(v+ε)−0.5(xj−m)/kv+ε  =−1−sisjkv+ε\frac{\partial s_i}{\partial x_j}=\frac{(-1/k)(\sqrt{v + \varepsilon}) - (x_i - m)(v + \varepsilon)^{-0.5}(x_j-m)/k}{v + \varepsilon}\\ \;\\ =\frac{-1 - s_is_j}{k\sqrt{v + \varepsilon}} ∂xj​∂si​​=v+ε(−1/k)(v+ε​)−(xi​−m)(v+ε)−0.5(xj​−m)/k​=kv+ε​−1−si​sj​​
代入上式可得矩阵 ∇s(x)\nabla s_{(x)}∇s(x)​.
dyidsi=w  dyidxj={w(k−1−sisj)/(kv+ε),i=jw(−1−sisj)/(kv+ε),i≠j\frac{dy_i}{ds_i}=w\\ \;\\ \frac{dy_i}{dx_j}= \left\{ \begin{array}{rr} w(k-1 - s_is_j)/({k\sqrt{v + \varepsilon}}), & i = j\\ w(-1 - s_is_j)/({k\sqrt{v + \varepsilon}}), & i \neq j \end{array} \right. dsi​dyi​​=wdxj​dyi​​={w(k−1−si​sj​)/(kv+ε​),w(−1−si​sj​)/(kv+ε​),​i=ji̸​=j​

∇e(x)=∇e(y)∇y(x)\nabla e_{(x)}=\nabla e_{(y)}\nabla y_{(x)} ∇e(x)​=∇e(y)​∇y(x)​

其中, ∇e(y)\nabla e_{(y)}∇e(y)​ 是一个向量, ∇y(x)\nabla y_{(x)}∇y(x)​ 是一个雅克比矩阵, 最后的结果 ∇e(x)\nabla e_{(x)}∇e(x)​ 是一个向量.

为了方便编程实现, 定义一个标量 uuu 和矩阵 RRR, 其中:
u=wkv+ε  rij={k−1−sisj,i=j−1−sisj,i≠ju = \frac{w}{k \sqrt{v + \varepsilon}}\\ \;\\ r_{ij}= \left\{ \begin{array}{rr} k-1 - s_is_j, & i = j\\ -1 - s_is_j, & i \neq j \end{array} \right. u=kv+ε​w​rij​={k−1−si​sj​,−1−si​sj​,​i=ji̸​=j​
则 :
∇y(x)=uR\nabla y_{(x)}=uR ∇y(x)​=uR

3.2 误差 e 对 w 或 b 的梯度

∇e(w)=dedy1dy1dw+dedy2dy2dw+⋯+dedykdykdw=∇e(y)⋅s  ∇e(b)=dedy1dy1db+dedy2dy2db+⋯+dedykdykdb=∑i=1k∇e(y)\nabla e_{(w)}=\frac{de}{dy_1}\frac{dy_1}{dw}+\frac{de}{dy_2}\frac{dy_2}{dw}+ \cdots +\frac{de}{dy_k}\frac{dy_k}{dw}=\nabla e_{(y)} \cdot s\\ \;\\ \nabla e_{(b)}=\frac{de}{dy_1}\frac{dy_1}{db}+\frac{de}{dy_2}\frac{dy_2}{db}+ \cdots +\frac{de}{dy_k}\frac{dy_k}{db}=\sum_{i=1}^{k} \nabla e_{(y)} ∇e(w)​=dy1​de​dwdy1​​+dy2​de​dwdy2​​+⋯+dyk​de​dwdyk​​=∇e(y)​⋅s∇e(b)​=dy1​de​dbdy1​​+dy2​de​dbdy2​​+⋯+dyk​de​dbdyk​​=i=1∑k​∇e(y)​

其中, ∇e(w)\nabla e_{(w)}∇e(w)​ 是向量点积得到的标量, ∇e(b)\nabla e_{(b)}∇e(b)​ 是求和得到的标量.

3.3 小提示

如果输入的是一个 k 行矩阵 X, 每一行对应一条包含 n 个项目的数据, 批标准化是逐列处理的, 编程实现时需要注意这一点.

全文完.

Batch Normalization函数详解及反向传播中的梯度求导相关推荐

  1. 机器学习--多标签softmax + cross-entropy交叉熵损失函数详解及反向传播中的梯度求导

    https://blog.csdn.net/oBrightLamp/article/details/84069835 正文 在大多数教程中, softmax 和 cross-entropy 总是一起出 ...

  2. L2正则化Regularization详解及反向传播的梯度求导

    摘要 本文解释L2正则化Regularization, 求解其在反向传播中的梯度, 并使用TensorFlow和PyTorch验证. 相关 系列文章索引 : https://blog.csdn.net ...

  3. pythonpandas函数详解_对pandas中Series的map函数详解

    Series的map方法可以接受一个函数或含有映射关系的字典型对象. 使用map是一种实现元素级转换以及其他数据清理工作的便捷方式. (DataFrame中对应的是applymap()函数,当然Dat ...

  4. python input函数详解_对Python3中的input函数详解

    下面介绍python3中的input函数及其在python2及pyhton3中的不同. python3中的ininput函数,首先利用help(input)函数查看函数信息: 以上信息说明input函 ...

  5. 【机器学习】详解 BackPropagation 反向传播算法!

    首先介绍一下链式法则 假如我们要求z对x1的偏导数,那么势必得先求z对t1的偏导数,这就是链式法则,一环扣一环 BackPropagation(BP)正是基于链式法则的,接下来用简单的前向传播网络为例 ...

  6. CBN(Cross-Iteration Batch Normalization)论文详解

    原文链接:Cross-Iteration Batch Normalization 代码链接:https://github.com/Howal/Cross-iterationBatchNorm 随着BN ...

  7. 神经网络反向传播的矩阵复合求导计算

    以前一直以为矩阵的复合求导和可微函数的链式求导是一样的,但是在推导神经网络梯度公式的时候往往会出现一些符号次序不对的问题,我这里借用吴恩达编程作业中给出的浅层神经网络的案例来计算反向传播的梯度.关于矩 ...

  8. 转载:汇总详解:矩阵的迹以及迹对矩阵求导

    矩阵的迹概念 矩阵的迹 就是 矩阵的主对角线上所有元素的和. 矩阵A的迹,记作tr(A),可知tra(A)=∑aii,1<=i<=n. 定理:tr(AB) = tr(BA) 证明 定理:t ...

  9. 【机器学习】汇总详解:矩阵基本知识以及矩阵求导

    1.矩阵的基本概念 1.1矩阵的迹(matrix trace) 存在方阵A=(aij)n×n,其主对角线上的所有元素的和,称为此方阵的迹,记作tr(A) tr(A)=a11+a22+--+ann tr ...

最新文章

  1. Visual Studio 2010 Ultimate开发与测试敏捷特性
  2. html写出日出,描写日出优美句子
  3. ppt扇形图怎么显示数据_前方高能!多维数据分析的神器雷达图PPT制作教程来啦!...
  4. C++类中protected访问权限问题
  5. c++用牛顿法开多次根_望远镜的历史之三:大神出世,改变望远镜历史的竟然是牛顿...
  6. javascript之继承
  7. php中正则表达式中的特殊符号
  8. 11.合并两个有序数组
  9. 影刀RPA实操指南丨90%用户都会陷入的excel自动化误区
  10. linux 安装TeamViewer
  11. pygame--图片随键盘移动
  12. 解决mkimage command not found - U-Boot images will not be buil
  13. 51nod 1243 排船的问题
  14. 【计算机网络】第三部分 数据链路层(15) 连接局域网、主干网和虚拟局域网
  15. 监控系统相关的常见面试问题
  16. vdsm:vdsm-client 命令行使用演示
  17. 数字孪生在制造业的7种应用
  18. 批处理 rewriteBatchedStatements=true
  19. 数字化的终局:赛博朋克?社会主义?
  20. Flutter黑马头条项目开发(二.底部切换导航和新闻页面开发)

热门文章

  1. 由“你”而生的公司危机【网络新生媒体的力量】
  2. shell中的数字比较符-eg,-ne, -gt, -It, -ge, -le
  3. 查看 MySQL 数据库中每个表占用的空间大小
  4. 如何在NS2中产生和使用Poisson Traffic
  5. 第一期:栈的经典例题
  6. 网上书城(登录、注册、权限管理)
  7. c语言微信抢红包的随机算法,微信红包的随机算法是怎样实现的?
  8. Qemu,KVM,Virsh傻傻的分不清
  9. Bootstrap 框架响应式网页开发
  10. 深度学习和dqn_深度Q学习方面的改进:双重DQN决斗,优先体验重播和固定…