参数初始化就是这么一个容易被忽视的重要因素,因为不仅使用者对其重要性缺乏概念,而且这些操作都被TF、pytorch这些框架封装了,你可能不知道的是,糟糕的参数初始化是会阻碍复杂非线性系统的训练的。

本文以MNIST手写体数字识别模型为例来演示参数初始化对模型训练的影响。点击这里查看源码。

Xavier Initialization

早期的参数初始化方法普遍是将数据和参数normalize为高斯分布(均值0方差1),但随着神经网络深度的增加,这方法并不能解决梯度消失问题。

Figure 1: XavierInitialisation.pdf

Xavier初始化的作者,Xavier Glorot,在Understanding the difficulty of training deep feedforward neural networks论文中提出一个洞见:激活值的方差是逐层递减的,这导致反向传播中的梯度也逐层递减。要解决梯度消失,就要避免激活值方差的衰减,最理想的情况是,每层的输出值(激活值)保持高斯分布。

因此,他提出了Xavier初始化:bias初始化为0,为Normalize后的参数乘以一个rescale系数:1/√n,n是输入参数的个数。

公式的推导过程大致如下:

如果上述这段公式你看晕了,也没关系,只要记住结果就好。

接下来,我们要做实验来验证Xavier的高见。

def linear(x, w, b): return x @ w + bdef relu(x): return x.clamp_min(0.)nh = 50
W1 = torch.randn(784, nh)
b1 = torch.zeros(nh)
W2 = torch.randn(nh, 1)
b2 = torch.zeros(1)z1 = linear(x_train, W1, b1)
print(z1.mean(), z1.std())tensor(-0.8809) tensor(26.9281)

这是个简单的线性回归模型:y=ax+b,(W1, b1)和(W2, b2)分别是隐层和输出层的参数,W1/W2初始化为高斯分布,b1/b2初始为0。果然,第一个linear层的输出值(z1)的均值和标准差就已经发生了很大的变化。如果后续使用sigmoid作为激活函数,那梯度消失就会很明显。

现在我们按照Xavier的方法来初始化参数:

W1 = torch.randn(784, nh) * math.sqrt(1 / 784)
b1 = torch.zeros(nh)
W2 = torch.randn(nh, 1) * math.sqrt(1 / nh)
b2 = torch.zeros(1)z1 = linear(x_train, W1, b1)
print(z1.mean(), z1.std())tensor(0.1031) tensor(0.9458)a1 = relu(z1)
a1.mean(), a1.std()(tensor(0.4272), tensor(0.5915))

参数经过Xavier初始化后,linear层的输出值的分布没有大的变化(U[0.1031,0.6458]),依旧接近高斯分布,但是好景不长,relu的激活值分布就开始跑偏了(U[0.4272,0.5915])。

Kaiming Initialization

Xavier初始化的问题在于,它只适用于线性激活函数,但实际上,对于深层神经网络来说,线性激活函数是没有价值,神经网络需要非线性激活函数来构建复杂的非线性系统。今天的神经网络普遍使用relu激活函数。

Kaiming初始化的发明人kaiming he,在Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification论文中提出了针对relu的kaiming初始化。

因为relu会抛弃掉小于0的值,对于一个均值为0的data来说,这就相当于砍掉了一半的值,这样一来,均值就会变大,前面Xavier初始化公式中E(x)=mean=0的情况就不成立了。根据新公式的推导,最终得到新的rescale系数:√(2/n)。更多细节请看论文的section 2.2。

W1 = torch.randn(784, nh) * math.sqrt(2 / 784)
b1 = torch.zeros(nh)
W2 = torch.randn(nh, 1) * math.sqrt(2 / nh)
b2 = torch.zeros(1)z1 = linear(x_train, W1, b1)
a1 = relu(z1)
a1.mean(), a1.std()(tensor(0.4553), tensor(0.7339))

可以看到,Kaiming初始化的表现要优于Xavier初始化,relu之后的输出值标准差还有0.7339(浮动可以达到0.8+)。

实际上,Kaiming初始化已经被Pytorch用作默认的参数初始化函数。

import torch.nn.init as initW1 = torch.zeros(784, nh)
b1 = torch.zeros(nh)
W2 = torch.zeros(nh, 1)
b2 = torch.zeros(1)init.kaiming_normal_(W1, mode='fan_out', nonlinearity='relu')
init.kaiming_normal_(W2, mode='fan_out')
z1 = linear(x_train, W1, b1)
a1 = relu(z1)
print("layer1: ", a1.mean(), a1.std())
z2 = linear(a1, W2, b2)layer1:  tensor(0.5583) tensor(0.8157)
tensor(1.1784) tensor(1.3209)

现在,方差的问题已经解决了,接下来就是均值不为0的问题。因为在x轴上平移data并不会影响data的方差,因此,如果把relu的激活值左移5,结果会如何?

def linear(x, w, b):return x @ w + bdef relu(x):return x.clamp_min(0.) - 0.5def model(x):x = relu(linear(x, W1, b1))print("layer1: ", x.mean(), x.std())x = relu(linear(x, W2, b2))print("layer2: ", x.mean(), x.std())x = linear(x, W3, b3)print("layer3: ", x.mean(), x.std())return xnh = [100, 50]
W1 = torch.zeros(784, nh[0])
b1 = torch.zeros(nh[0])
W2 = torch.zeros(nh[0], nh[1])
b2 = torch.zeros(nh[1])
W3 = torch.zeros(nh[1], 1)
b3 = torch.zeros(1)init.kaiming_normal_(W1, mode='fan_out')
init.kaiming_normal_(W2, mode='fan_out')
init.kaiming_normal_(W3, mode='fan_out')
_ = model(x_train)layer1:  tensor(0.0383) tensor(0.7993)
layer2:  tensor(0.0075) tensor(0.7048)
layer3:  tensor(-0.2149) tensor(0.4493)

结果出乎意料的好,这个三层的模型在没有添加batchnorm的情况下,每层的输入值和输出值都接近高斯分布,虽然数据方差是会逐层递减,但相比normalize初始化和Xavier初始化要好很多。

最后,因为Kaiming初始化是pytorch的默认初始化函数,因此我又用pytorch提供的nn.Linear()和nn.Relu()来构建相同的模型对比测试,结果是大跌眼镜。

class Model(nn.Module):def __init__(self):super().__init__()self.lin1 = nn.Linear(784, nh[0])self.lin2 = nn.Linear(nh[0], nh[1])self.lin3 = nn.Linear(nh[1], 1)self.relu = nn.ReLU()def forward(self, x):x = self.relu(self.lin1(x))print("layer 1: ", x.mean().item(), x.std().item())x = self.relu(self.lin2(x))print("layer 2: ", x.mean().item(), x.std().item())x = self.relu(self.lin3(x))print("layer 3: ", x.mean().item(), x.std().item())return xm = Model()
_ = m(x_train)layer 1:  0.2270725518465042 0.32707411050796
layer 2:  0.033514849841594696 0.23475737869739532
layer 3:  0.013271240517497063 0.09185370802879333

可以看到,第三层的输出已经均值为0、方差为0。去看nn.Linear()类的代码时会看到,它在做初始化时会传入参数a=math.sqrt(5)。我们知道,当输入为负数时,leaky relu的梯度为[0,∞),x = λx,参数a就是这个λ。虽然kaiming_uniform_()的默认网络要使用的激活函数是leaky relu,但a默认值为0,此时leaky relu就等于relu。但现在数据存在负数,因此,mean相比relu模型更接近于0,甚至E(x) > 0的假设都不成立了,因此,rescale系数就不准确了,nn.Linear()才会有这样的表现。

    def reset_parameters(self):init.kaiming_uniform_(self.weight, a=math.sqrt(5))

END

本文通过Xavier和Kaiming初始化来展现了参数初始化的重要性,因为糟糕的初始化容易让神经网络陷入梯度消失的陷阱中。

References

  • Understanding the difficulty of training deep feedforward neural networks
  • Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification
  • understanding-xavier-initialization-in-deep-neural-networks
  • https://github.com/fastai/course-v3/blob/master/nbs/dl2/02_fully_connected.ipynb

搞懂深度网络初始化(Xavier and Kaiming initialization)相关推荐

  1. 太强了! 李宏毅:1 天搞懂深度学习,我总结了 300 页 PPT

    <1 天搞懂深度学习>,300 多页的 ppt,台湾李宏毅教授写的,非常棒.不夸张地说,是我看过最系统,也最通俗易懂的,关于深度学习的文章. 这份 300 页的 PPT,被搬运到了 Sli ...

  2. 下载 | 李宏毅:1 天搞懂深度学习,我总结了 300 页 PPT

    <1 天搞懂深度学习>,300 多页的 ppt,台湾李宏毅教授写的,非常棒.不夸张地说,是我看过最系统,也最通俗易懂的,关于深度学习的文章. 这份 300 页的 PPT,被搬运到了 Sli ...

  3. 【深度学习】李宏毅:1 天搞懂深度学习,我总结了 300 页 PPT(附思维导图)...

    转载自:机器学习算法那些事 ID:Charlotte77 公众号:Charlotte数据挖掘 By    Charlotte77 前言:李宏毅的教材,非常经典,B站有配套视频,文末附下载链接!     ...

  4. 性能优化|解读面试题,彻底搞懂类加载和初始化顺序

    解读面试题,彻底搞懂类加载和初始化顺序 在高级面试过程中,始终逃不过面试官的追问三连: 你知道jvm是怎么加载类的么? 类的初始化顺序你有了解么? 我出一个面试题,你能答出来么? 三连问下来,恐怕自己 ...

  5. 干货 | 台大“一天搞懂深度学习”课程PPT(下载方式见文末!!)

    微信公众号 关键字全网搜索最新排名 [机器学习算法]:排名第一 [机器学习]:排名第一 [Python]:排名第三 [算法]:排名第四 Deep Learing Tutorial 本篇文章我们给出了一 ...

  6. 李宏毅——一天搞懂深度学习PPT学习笔记

    李宏毅一天搞懂机器学习PPT,SildeShare链接:https://www.slideshare.net/tw_dsconf/ss-62245351?qid=108adce3-2c3d-4758- ...

  7. python2.7爬虫实例-用案例让你一文搞懂python网络爬虫

    声明:本文来自于微信公众号  数据EDTA(ID:livandata),作者: livan,授权站长之家转载发布. 很久以前写了一篇爬虫的文章,把它放在CSDN上(livan1234)没想到点击量竟然 ...

  8. [1天搞懂深度学习] 读书笔记 lecture I:Introduction of deep learning

    - 通常机器学习,目的是,找到一个函数,针对任何输入:语音,图片,文字,都能够自动输出正确的结果. - 而我们可以弄一个函数集合,这个集合针对同一个猫的图片的输入,可能有多种输出,比如猫,狗,猴子等, ...

  9. 一文搞懂深度学习正则化的L2范数

    想要彻底弄明白L2范数,必须要有一定的矩阵论知识,L2范数涉及了很多的矩阵变换.在我们进行数学公式的推到之前,我们先对L2范数有一个感性的认识. L2范数是什么? L2范数的定义其实是一个数学概念,其 ...

最新文章

  1. 37. 两个链表的第一个公共结点
  2. 68 Centos7安装Zabbix 5.0 版本
  3. 全卷积神经网路【U-net项目实战】Unet++
  4. Java语言学习思维导图
  5. QT绘制饼状图,自定义切片。
  6. word2vec应用场景_word2vec有什么应用?
  7. 按季度分类汇总_2019年纯碱行业相关上市公司季报 与半年报情况汇总
  8. 配置opencv cmake
  9. jenkins构建后脚本不执行_接口管理工具ApiPost-预(后)执行脚本常用方法集合
  10. 【numpy】数组增加一维(升维)小结
  11. 数据结构------递归+迷宫问题+最短路径问题解决思路
  12. matlab 的 legend 用法
  13. 斜齿轮重合度计算公式_斜齿轮重合度计算
  14. tesseract-ocr验证码识别
  15. 邓凡平WIFI学习笔记4:WiFi Simple configuration
  16. 使用阿里云PCDN降低内容分发成本
  17. Zuul动态路由及动态Filter实现
  18. C++中局部变量和全局变量的存储位置和内存回收机制
  19. hadoop 3.x 启动过程中 Permission denied (publickey,gssapi-keyex,gssapi-with-mic,password).
  20. DSF框架使用(DAO、序列化、注解、服务接口、服务代理)

热门文章

  1. 蒟蒻君的数学学习之路1:斐波那契数列的n种解法
  2. 吉大计算机如何本科进实验室,实验室简介-吉林大学理论化学计算实验室
  3. SqlSession和SqlSessionTemplate的不解姻缘系列之一(总体阐述)
  4. 如何把20秒熊本熊GIF图发送给微信好友
  5. Oracle数据库面试题 精选 Oracle 面试题
  6. 个人计算机组装主板,组装一台电脑需要哪些配件【详细列举】
  7. 《后宫》成明朝女人帮 安以轩称冯绍峰“小强”
  8. 什么是PHP以及PHP的特性
  9. service层中注入conroller_springMVC中controller层调用service层的方式
  10. css 上下左右居中5种方法