6.2 梯度爆炸实验

造成简单循环网络较难建模长程依赖问题的原因有两个:梯度爆炸和梯度消失。一般来讲,循环网络的梯度爆炸问题比较容易解决,一般通过权重衰减或梯度截断可以较好地来避免;对于梯度消失问题,更加有效的方式是改变模型,比如通过长短期记忆网络LSTM来进行缓解。

本节将首先进行复现简单循环网络中的梯度爆炸问题,然后尝试使用梯度截断的方式进行解决。这里采用长度为20的数据集进行实验,训练过程中将进行输出 ,和的梯度向量的范数,以此来衡量梯度的变化情况。

6.2.1 梯度打印函数

使用custom_print_log实现了在训练过程中打印梯度的功能,custom_print_log需要接收runner的实例,并通过model.named_parameters()获取该模型中的参数名和参数值. 这里我们分别定义W_list, U_list和b_list,用于分别存储训练过程中参数的梯度范数。

import torchW_list = []
U_list = []
b_list = []# 计算梯度范数
def custom_print_log(runner):model = runner.modelW_grad_l2, U_grad_l2, b_grad_l2 = 0, 0, 0for name, param in model.named_parameters():if name == "rnn_model.W":W_grad_l2 = torch.norm(param.grad, p=2).numpy()if name == "rnn_model.U":U_grad_l2 = torch.norm(param.grad, p=2).numpy()if name == "rnn_model.b":b_grad_l2 = torch.norm(param.grad, p=2).numpy()print(f"[Training] W_grad_l2: {W_grad_l2:.5f}, U_grad_l2: {U_grad_l2:.5f}, b_grad_l2: {b_grad_l2:.5f} ")W_list.append(W_grad_l2)U_list.append(U_grad_l2)b_list.append(b_grad_l2)

【思考】什么是范数,什么是L2范数,这里为什么要打印梯度范数?

范数,是具有“距离”概念的函数。我们知道距离的定义是一个宽泛的概念,只要满足非负、自反、三角不等式就可以称之为距离。范数是一种强化了的距离概念,它在定义上比距离多了一条数乘的运算法则。有时候为了便于理解,我们可以把范数当作距离来理解。

在数学上,范数包括向量范数和矩阵范数,向量范数表征向量空间中向量的大小,矩阵范数表征矩阵引起变化的大小。一种非严密的解释就是,对应向量范数,向量空间中的向量都是有大小的,这个大小如何度量,就是用范数来度量的,不同的范数都可以来度量这个大小,就好比米和尺都可以来度量远近一样;对于矩阵范数,学过线性代数,我们知道,通过运算AX=B,可以将向量X变化为B,矩阵范数就是来度量这个变化大小的。

L2范数
L2范数是我们最常见最常用的范数了,我们用的最多的度量距离欧氏距离就是一种L2范数,它的定义如下:

对于L2范数,它的优化问题如下:

L2范数通常会被用来做优化目标函数的正则化项,防止模型为了迎合训练集而过于复杂造成过拟合的情况,从而提高模型的泛化能力

为什么要打印梯度范数:

函数在某一点处的方向导数在其梯度方向上达到最大值,此最大值即梯度的范数。 而模型的学习过程是通过使用训练数据来最小化损失函数,从而确定参数的值。而最小化损失函数,即通过求导求损失函数的极值。打印梯度范数值可以帮助我们更直观地了解模型训练情况的好坏,梯度过大或过小都有可能导致模型的训练效果变差,因此打印梯度范数有利于我们更快地对模型作出修改。

6.2.2 复现梯度爆炸现象

为了更好地复现梯度爆炸问题,使用SGD优化器将批大小和学习率调大,学习率为0.2,同时在计算交叉熵损失时,将reduction设置为sum,表示将损失进行累加。 代码实现如下:

import os
import random
import torch
import numpy as npnp.random.seed(0)
random.seed(0)
torch.manual_seed(0)# 训练轮次
num_epochs = 50
# 学习率
lr = 0.2
# 输入数字的类别数
num_digits = 10
# 将数字映射为向量的维度
input_size = 32
# 隐状态向量的维度
hidden_size = 32
# 预测数字的类别数
num_classes = 19
# 批大小
batch_size = 64
# 模型保存目录
save_dir = "./checkpoints"# 可以设置不同的length进行不同长度数据的预测实验
length = 20
print(f"\n====> Training SRN with data of length {length}.")# 加载长度为length的数据
data_path = f"D:/datasets/{length}"
train_examples, dev_examples, test_examples = load_data(data_path)
train_set, dev_set, test_set = DigitSumDataset(train_examples), DigitSumDataset(dev_examples),DigitSumDataset(test_examples)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size)
dev_loader = torch.utils.data.DataLoader(dev_set, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size)
# 实例化模型
base_model = SRN(input_size, hidden_size)
model = Model_RNN4SeqClass(base_model, num_digits, input_size, hidden_size, num_classes)
# 指定优化器
optimizer = torch.optim.SGD(model.parameters(),lr)
# 定义评价指标
metric = Accuracy()
# 定义损失函数
loss_fn = nn.CrossEntropyLoss(reduction="sum")# 基于以上组件,实例化Runner
runner = RunnerV3(model, optimizer, loss_fn, metric)# 进行模型训练
model_save_path = os.path.join(save_dir, f"srn_explosion_model_{length}.pdparams")
runner.train(train_loader, dev_loader, num_epochs=num_epochs, eval_steps=100, log_steps=1,save_path=model_save_path, custom_print_log=custom_print_log)

接下来,可以获取训练过程中关于参数梯度的L2范数,并将其绘制为图片以便展示,相应代码如下:

import matplotlib.pyplot as plt
def plot_grad(W_list, U_list, b_list, save_path, keep_steps=40):# 开始绘制图片plt.figure()# 默认保留前40步的结果steps = list(range(keep_steps))plt.plot(steps, W_list[:keep_steps], "r-", color="#e4007f", label="W_grad_l2")plt.plot(steps, U_list[:keep_steps], "-.", color="#f19ec2", label="U_grad_l2")plt.plot(steps, b_list[:keep_steps], "--", color="#000000", label="b_grad_l2")plt.xlabel("step")plt.ylabel("L2 Norm")plt.legend(loc="upper right")plt.show()plt.savefig(save_path)print("image has been saved to: ", save_path)save_path = f"./images/6.8.pdf"
plot_grad(W_list, U_list, b_list, save_path)

此图展示了在训练过程中关于 ,和 参数梯度的L2范数,可以看到经过学习率等方式的调整,梯度范数急剧变大,而后梯度范数几乎为0. 这是因为TanhTanh为SigmoidSigmoid型函数,其饱和区的导数接近于0,由于梯度的急剧变化,参数数值变的较大或较小,容易落入梯度饱和区,导致梯度为0,模型很难继续训练。

接下来,使用该模型在测试集上进行测试。

print(f"Evaluate SRN with data length {length}.")
# 加载训练过程中效果最好的模型
model_path = os.path.join(save_dir, "srn_explosion_model_20.pdparams")
torch.load(model_path)# 使用测试集评价模型,获取测试集上的预测准确率
score, _ = runner.evaluate(test_loader)
print(f"[SRN] length:{length}, Score: {score: .5f}")

6.2.3 使用梯度截断解决梯度爆炸问题

梯度截断是一种可以有效解决梯度爆炸问题的启发式方法,当梯度的模大于一定阈值时,就将它截断成为一个较小的数。一般有两种截断方式:按值截断和按模截断.本实验使用按模截断的方式解决梯度爆炸问题。

按模截断是按照梯度向量的模进行截断,保证梯度向量的模值不大于阈值b,裁剪后的梯度为:

当梯度向量的模不大于阈值时,数值不变,否则对进行数值缩放。

问:在飞桨中,可以使用paddle.nn.ClipGradByNorm进行按模截断.--- pytorch中用什么?

 nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=20, norm_type=2)

在引入梯度截断之后,将重新观察模型的训练情况。这里我们重新实例化一下:模型和优化器,然后组装runner,进行训练。代码实现如下:

# 清空梯度列表
W_list.clear()
U_list.clear()
b_list.clear()
# 实例化模型
base_model = SRN(input_size, hidden_size)
model = Model_RNN4SeqClass(base_model, num_digits, input_size, hidden_size, num_classes)# 定义clip,并实例化优化器optimizer = torch.optim.SGD(lr=lr, params=model.parameters())
# 定义评价指标
metric = Accuracy()
# 定义损失函数
loss_fn = nn.CrossEntropyLoss(reduction="sum")# 实例化Runner
runner = RunnerV3(model, optimizer, loss_fn, metric)# 训练模型
model_save_path = os.path.join(save_dir, f"srn_fix_explosion_model_{length}.pdparams")
runner.train(train_loader, dev_loader, num_epochs=num_epochs, eval_steps=100, log_steps=1, save_path=model_save_path, custom_print_log=custom_print_log)
# 进行模型训练
model_save_path = os.path.join(save_dir, f"srn_explosion_model_{length}.pdparams")

在引入梯度截断后,获取训练过程中关于参数梯度的L2范数,并将其绘制为图片以便展示,相应代码如下:

save_path = f"./images/6.9.pdf"
plot_grad(W_list, U_list, b_list, save_path, keep_steps=100)

展示了引入按模截断的策略之后,模型训练时参数梯度的变化情况。可以看到,随着迭代步骤的进行,梯度始终保持在一个有值的状态,表明按模截断能够很好地解决梯度爆炸的问题.

接下来,使用梯度截断策略的模型在测试集上进行测试。

print(f"Evaluate SRN with data length {length}.")# 加载训练过程中效果最好的模型
model_path = os.path.join(save_dir, f"srn_fix_explosion_model_{length}.pdparams")
runner.load_model(model_path)# 使用测试集评价模型,获取测试集上的预测准确率
score, _ = runner.evaluate(test_loader)
print(f"[SRN] length:{length}, Score: {score: .5f}")

由于为复现梯度爆炸现象,改变了学习率,优化器等,因此准确率相对比较低。但由于采用梯度截断策略后,在后续训练过程中,模型参数能够被更新优化,因此准确率有一定的提升。

【思考题】梯度截断解决梯度爆炸问题的原理是什么?

由于梯度太大会产生梯度爆炸的现象,太小会产生梯度消失的现象(参数不更新),所以为梯度提供一个范围[a,b],

  • 如果梯度大于b,就把它设置为b;
  • 如果梯度小于a,就把它设置为a;
  • 若在此区间,不做变化

梯度裁剪确保了梯度矢量的最大范数。即使在模型的损失函数不规则时,这一技巧也有助于梯度下降保持合理的行为。下面的图片展示了损失函数的陡崖。不采用裁剪,参数将会沿着梯度下降方向剧烈变化,导致其离开了最小值范围;而使用裁剪后参数变化将被限制在一个合理范围内,避免了上面的情况。

总结:

本次实验实现了梯度爆炸的复现,然后使用梯度截断的方式进行解决。梯度范数急剧变大,而后梯度范数几乎为0,这是由于梯度的急剧变化,参数数值变的较大或较小,容易落入梯度饱和区,导致梯度为0,模型很难继续训练。

参考:

NNDL 实验七 循环神经网络(2)梯度爆炸实验_HBU_David的博客-CSDN博客_什么是范数,什么是l2范数,这里为什么要打印梯度范数

什么是梯度裁剪_鸾镜朱颜暗换的博客-CSDN博客_梯度裁剪

范数(norm) 几种范数的简单介绍_Norstc的博客-CSDN博客_平方范数

NNDL 实验七 循环神经网络(2)梯度爆炸实验相关推荐

  1. NNDL 实验七 循环神经网络(1)RNN记忆能力实验

    NNDL 实验七 循环神经网络(1)RNN记忆能力实验 第6章 循环神经网络 6.1 循环神经网络的记忆能力实验 6.1.1 数据集构建 6.1.1.1 数据集的构建函数 6.1.1.2 加载数据并进 ...

  2. 深度学习 实验七 循环神经网络

    文章目录 深度学习 实验七 循环神经网络 一.问题描述 二.设计简要描述 三.程序清单 深度学习 实验七 循环神经网络 一.问题描述 之前见过的所以神经网络(比如全连接网络和卷积神经网络)都有一个主要 ...

  3. 循环神经网络中梯度爆炸的原因

    循环神经网络中梯度爆炸的原因 对于循环神经网络,要在很长时间序列的各个时刻重复应用相同的操作来构建非常深的计算图,并且模型的参数是共享的,所以使得梯度爆炸或者梯度消失的问题更加明显. 假设某个计算图中 ...

  4. NNDL 实验七 循环神经网络(4)基于双向LSTM的文本分类

    6.4 实践:基于双向LSTM模型完成文本分类任务 电影评论可以蕴含丰富的情感:比如喜欢.讨厌.等等. 情感分析(Sentiment Analysis)是为一个文本分类问题,即使用判定给定的一段文本信 ...

  5. NNDL 实验七 循环神经网络(3)LSTM的记忆能力实验

    文章目录 前言 一.6.3 LSTM的记忆能力实验 6.3.1 模型构建 6.3.1.1 LSTM层 6.3.1.2 模型汇总 6.3.2 模型训练 6.3.2.1 训练指定长度的数字预测模型 6.3 ...

  6. 1.8 循环神经网络的梯度消失-深度学习第五课《序列模型》-Stanford吴恩达教授

    ←上一篇 ↓↑ 下一篇→ 1.7 对新序列采样 回到目录 1.9 GRU 单元 循环神经网络的梯度消失 (Vanishing Gradient with RNNs) 你已经了解了RNN时如何工作的了, ...

  7. 循环神经网络——裁剪梯度(应对梯度爆炸)

    循环神经网络中比较容易出现梯度衰减或梯度爆炸,为了应对梯度爆炸,可以进行裁剪梯度.假设把所有模型参数梯度的元素拼接成一个向量g,并设裁剪的阈值是θ\thetaθ.裁剪后的梯度min(θ∣∣g∣∣,1) ...

  8. matlab 实验七 低层绘图操作,matlab实验内容答案

    实验报实验报告告说说明 明 matlab 课课程程实验实验需撰写需撰写 8 个个实验报实验报告 每个告 每个实验报实验报告内容写每次告内容写每次 实验实验内容中内容中标标号呈黑体大号字号呈黑体大号字显 ...

  9. SCAU华南农业大学-数电实验-七进制同步加法计数器-实验报告

    一.Purpose 1.利用数字电路的知识,用74LS73或74LS74(即D触发器或JK触发器)和各种逻辑门实现七进制同步加法计数器. 2.锻炼实验操作技能,使之更熟练. 二.Devices Equ ...

最新文章

  1. nodejs 开发,手把手开始第一个服务器程序(原生)
  2. 死磕java_死磕JavaScript-垃圾收集机制
  3. basename php 中文,php basename不支持中文怎么办
  4. 天寒宜早睡,梦醒闻雪声,倒计时83
  5. 【使用R语言两行语句将搜狗词库转为csv格式】
  6. javasocket编程(javasocket通信)
  7. 登录注册的业务逻辑流程梳理
  8. php加入到jpg,PHP如何将PNG转换成JPG?
  9. 2015年河南省省赛部分题题解
  10. 谷粒商城-商城业务-检索服务
  11. Pthon画皮卡丘源码
  12. FireFox浏览器的about:config
  13. 2020-03-12-脑电分析之线性与非线性变换
  14. img标签src引入svg如何修改颜色
  15. Python可视化分析疫情数据
  16. 【Android】Android App打开手机QQ、微信等应用
  17. 元宇宙链接现实与虚拟 IPFS扮演着怎样的角色?
  18. 时空幻境的体验分析:基于机制
  19. 如何更好的管理图片文件
  20. 遥感图像处理基本操作——遥感图像导入、导出、添加波段、添加删除通道、裁剪

热门文章

  1. 高桥盾react和boost_热门对比丨React pk Boost,你更偏向哪一双?
  2. 素描 山_60秒内素描镜子
  3. c语言计算两个整数的乘积
  4. Java编程那些事儿74——java.lang包介绍1
  5. 漫画:二叉树系列 第五讲(BST的删除)
  6. Java使用itextpdf生成PDF文件,用浏览器下载
  7. N9K配置Vxlan
  8. 「分布式架构」最终一致性:反熵
  9. jQuery MiniUI 快速入门:Hollo, world!(二)_nikofan-ChinaUnix博客
  10. Spark SQL数据通用保存数据_大数据培训