这两天造的轮子每次训练上来几个epoch就爆,试过lr降到1e-5;将所有分母/log都加了极小值;加了梯度裁剪,改了损失函数,lr衰减改成指数形式,加载预训练权重,修改权重初始化方式等等全失败,一度怀疑是我代码逻辑哪里写的有问题,检查数遍无果。就在这时随手加了个warmup,卧槽居然不飞了,个人觉得原因是加载的预训练模型最后几个输出层类别个数和自己的任务类别个数存在差异,直接采用温和的lr衰减方式导致梯度飘了,需要在训练初期快速下降是模型稳定后再采用普通的衰减训练。被爆多次干脆这篇把梯度爆炸和弥散过一遍。

什么是梯度弥散和梯度爆炸?

直观上说深度学习的优化是基于反向传播和链式求导,每层的梯度会进行连乘,如果层数太深就容易造成>1的梯度连乘后变得很大,<1的梯度连乘后变得很小。所以每层中相乘的两个数:一个是初始化权重的值,一个的激活函数的导数就会影响传播值。

梯度弥散通常出现在以下两种情况:

  1. 神经网络层次太深;
  2. 采用了不合适的激活函数;

梯度爆炸通常出现在以下情况:

  1. 神经网络层次太深;
  2. 采用了不合适的初始化权重方式

在深层网络中,不同的层学习的速度差异很大,靠近输出的层学习的情况通常很好,靠近输入的层学习的通常很慢,有时甚至训练了很久,前几层的权值和刚开始随机初始化的值差不多。因此,梯度消失和梯度爆炸的根本原因在于反向传播训练法则,本质在于反向传播这个方法问题。

如何解决梯度弥散和爆炸?

方法一:恒等映射,根据上面提到的问题,无论是梯度弥散还是梯度爆炸都是由于网络层次太深造成连乘的缩小/放大。造成越靠近输入端的权重更新越缓慢,参考残差结构的思想采用恒等映射的方式设计神经网络跳跃连接,使反向传播的梯度值能直接传递到早期的层更新权重,残差结构如下图所示:

方法二:激活函数,每层的梯度值可以选择合适的激活函数来缓解梯度弥散问题;比如下图sigmoid函数的梯度随着x的增大或减小会进入饱和区,公式如下:

relu相比sigmoid属于非饱和激活函数,其导数在正数部分是恒等于1的,因此在深层网络中使用relu激活函数就不会导致梯度消失和爆炸的问题,每层的网络都可以得到相同的更新速度,公式如下:

但是relu激活函数由于负数部分为0,会导致神经元死亡,所以后来出现了很多变种形式,比如:leakrelu,elu等解决了relu的0区间带来的影响。

方法三:BatchNorm,把每层神经网络任意神经元这个输入值的分布强行拉回到接近均值为0方差为1的标准正太分布,即严重偏离的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,使得让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度,提升训练稳定性。

BN应作用在非线性映射前,即对x=Wu+b做规范化。将激活规范为均值和方差一致的手段使得原本会减小的activation的scale变大。

方法四:梯度裁剪,检查误差梯度的值是否超过阈值,如果超过,则截断梯度,将梯度设置为阈值,可以一定程度上缓解梯度爆炸问题。TF代码如下:

gvs = optimizer.compute_gradients(loss[0] + l2_loss, var_list=update_vars)
clip_grad_var = [gv if gv[0] is None else [tf.clip_by_norm(gv[0], 100.), gv[1]] for gv in gvs]
train_op = optimizer.apply_gradients(clip_grad_var, global_step=global_step)
####### tensorflow 提供的 API函数 ######
tf.clip_by_value(t, clip_value_min, clip_value_max)
tf.clip_by_norm(t, clip_norm)
tf.clip_by_average_norm(t, clip_norm)
tf.clip_by_global_norm(t_list, clip_norm)

方法五:权重初始化,上面提到如果初始化权重太大,进过多层的连续相乘回传到输入端会造成梯度值指数级增长,所以选择合适的初始化方式至关重要,一般推荐He初始化和Xavier初始化方式。也可以采用预训练的方法先寻找局部最优,站在一个较好的位置进行微调。如果使用 relu,推荐采用 he_initialization, 即 tf.contrib.layers.variance_scaling_initializer( ),在 relu 网络中,假定每一层有一半的神经元被激活,另一半为 0 ,所以要保持 variance 不变,只需要在 xavier 的基础上再除以 2 。如果激活函数使用 sigmoid 和 tanh,则最好使用 xavier initialization, 即 tf.contrib.layers.xavier_initializer_conv2d( ),保持输入和输出的方差一致,避免了所有的输出值都趋向于0。

方法六:权重正则化,检查网络权重的大小,并惩罚产生较大权重值的损失函数。该过程被称为权重正则化,通常使用的是 L1 惩罚项(权重绝对值)或 L2 惩罚项(权重平方)。TF代码如下:

l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in tf.trainable_variables() if 'weights' in var.name])

方法七:长短时记忆网络,在循环神经网络中,梯度爆炸的发生可能是因为某种网络的训练本身就存在不稳定性,如随时间的反向传播本质上将循环网络转换成深度多层感知机神经网络。使用长短期记忆(LSTM)单元和相关的门类型神经元结构可以减少梯度爆炸问题。采用 LSTM 单元是适合循环神经网络的序列预测的较好方式。

炼丹手册——梯度弥散和爆炸相关推荐

  1. tensorflow中的梯度弥散与梯度爆炸

    定义:把梯度接近于0的现象称为梯度弥散:把梯度值远大于1的现象称为梯度爆炸. 例子1:梯度爆炸 import tensorflow as tf import numpy as np import ma ...

  2. 深度学习之循环神经网络(6)梯度弥散和梯度爆炸

    深度学习之循环神经网络(6)梯度弥散和梯度爆炸  循环神经网络的训练并不稳定,网络的善妒也不能任意加深.那么,为什么循环神经网络会出现训练困难的问题呢?简单回顾梯度推导中的关键表达式: ∂ht∂hi= ...

  3. 深度学习基础知识---梯度弥散 梯度爆炸

    目录 1 梯度弥散.梯度爆炸的成因 2  解决方式 2.1.pretrain+finetune 2.2 梯度裁剪 2.3 权重正则化 2.5 Batch Normalization正则化 2.6 残差 ...

  4. 梯度弥散与梯度爆炸及其解决方法

    梯度弥散(梯度消失): 通常神经网络所用的激活函数是sigmoid函数,sigmod函数容易引起梯度弥散.这个函数能将负无穷到正无穷的数映射到0和1之间,并且对这个函数求导的结果是f′(x)=f(x) ...

  5. LSTM如何解决梯度消失或爆炸的?

    from:https://zhuanlan.zhihu.com/p/44163528 哪些问题? 梯度消失会导致我们的神经网络中前面层的网络权重无法得到更新,也就停止了学习. 梯度爆炸会使得学习不稳定 ...

  6. ztree在刷新时第一个父节点消失_从反向传播推导到梯度消失and爆炸的原因及解决方案(从DNN到RNN,内附详细反向传播公式推导)...

    引言:参加了一家公司的面试和另一家公司的笔试,都问到了这个题!看来很有必要好好准备一下,自己动手推了公式,果然理解更深入了!持续准备面试中... 一. 概述: 想要真正了解梯度爆炸和消失问题,必须手推 ...

  7. alexnet实验偶遇:loss nan, train acc 0.100, test acc 0.100情况,通过bn层加快收敛速度,防止过拟合,防止梯度消失、爆炸

    场景:数据集:官方的fashionminst + 网络:alexnet+pytroch+relu激活函数 源代码:https://zh-v2.d2l.ai/chapter_convolutional- ...

  8. 梯度消失和梯度爆炸_梯度消失、爆炸的原因及解决办法

    一.引入:梯度更新规则 目前优化神经网络的方法都是基于反向传播的思想,即根据损失函数计算的误差通过梯度反向传播的方式,更新优化深度网络的权值.这样做是有一定原因的,首先,深层网络由许多非线性层堆叠而来 ...

  9. RNN梯度消失和爆炸的原因 以及 LSTM如何解决梯度消失问题

    RNN梯度消失和爆炸的原因 经典的RNN结构如下图所示: 假设我们的时间序列只有三段,  为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下: 假设在t=3时刻,损失函数为  . 则对于一 ...

最新文章

  1. 前端使用 Nginx 反向代理彻底解决跨域问题
  2. java 字符串用法_java中字符串的用法
  3. Python3基础-分数运算
  4. Spring注入方法
  5. pycharm python3区别_1.安装Python3和PyCharm
  6. iOS下数据存储的方式
  7. @Scheduled(cron=) spring定时任务时间设置
  8. Linux设备驱动中的异步通知
  9. 猿人学第二题,手撕OB混淆给你看(Step1-开篇)
  10. 74HC595使用方法
  11. 笔记本连不上网(IPV4和IPV6无网络访问权限)解决方法
  12. 3ds Max小白入门小案例|旋转楼梯
  13. 百胜中国拟2025年前开1000家Lavazza咖啡店​;别样肉客开始陆续进驻中国山姆会员商店 | 知消...
  14. 随便画一张,奥古斯丁的世界观 及 lambda
  15. 天使投资人刘峻:腾讯的七条命 |捕手志
  16. 3D游戏之父--John Carmack连载系列(四)
  17. 一文搞懂什么是遗传算法Genetic Algorithm【附应用举例】
  18. 如何在“运行”里打开软件
  19. 渲染吃显卡还是CPU,如何高效3D渲染?
  20. python语音转文字库_有没有语音转文字的APP?

热门文章

  1. 牛客网暑期ACM多校训练营(第三场): A. Ternary String(欧拉降幂+递推)
  2. Codeforces Round #164 (Div. 2):B. Buttons
  3. bzoj 1609: [Usaco2008 Feb]Eating Together麻烦的聚餐(DP)
  4. opencv kmeans聚类 实现图像色彩量化
  5. javascript手机号码、电子邮件正则表达式 一种解决方案
  6. Multisium里如何使用多个不同的VCC
  7. 获取虚拟账号列表失败啥意思_「图」Windows 10 Build 18963发布:可显GPU温度 支持重命名虚拟桌面...
  8. Mysql中外键的 Cascade ,NO ACTION ,Restrict ,SET NULL
  9. 第二次作业——Python基础和软件工程
  10. Vue、J2ee - 001 : Vue项目的创建过程