点击上方“MLNLP”,选择“星标”公众号

重磅干货,第一时间送达

作者丨CV路上一名研究僧

知乎专栏丨深度图像与视频增强

地址丨https://zhuanlan.zhihu.com/p/79046709

0. 遇到大坑

笔者在最近的项目中用到了自定义loss函数,代码一切都准备就绪后,在训练时遇到了梯度爆炸的问题,每次训练几个iterations后,梯度和loss都会变为nan。一般情况下,梯度变为nan都是出现了  ,  等情况,导致结果变为+inf,也就成了nan。

1. 问题分析

笔者需要的loss函数如下:

其中,  。

从理论上分析,这个loss函数在反向传播过程中很可能会遇到梯度爆炸,这是为什么呢?反向传播的过程是对loss链式求一阶导数的过程,那么,  的导数为:

由于  ,这个导数又可以表示为:

这样的话,出现了类似于  的表达式,也就会出现典型的$0/1$问题了。为了避免这个问题,首先进行了如下的  改变:

经过改变,在$x_i=0$时,不再是  问题了,而是转换为了一个线性函数,梯度成为了恒定的12.9,从理论上来看,避免了梯度爆炸的问题。

2. PyTorch初步实现

在实现这一过程时,依旧...遇到了大坑,下面通过示例代码来说明:

"""        loss = mse(X, gamma_inv(X))        """        def loss_function(x):        mask = (x < 0.003).float()        gamma_x = mask * 12.9 * x + (1-mask) * (x ** 0.5)        loss = torch.mean((x - gamma_x) ** 2)return lossif __name__ == '__main__':        x = Variable(torch.FloatTensor([0, 0.0025, 0.5, 0.8, 1]), requires_grad=True)        loss = loss_function(x)print('loss:', loss)        loss.backward()print(x.grad)

改进后的  是一个分支结构,在实现时,就采用了类似于Matlab中矩阵计算的mask方式,mask定义为  ,满足条件的$x_i$在mask中对应位置的值为1,因此,  的结构只会保留  的结果,同样的道理,  就实现了上述改进后的  公式。

按理来说,此时,在反向传播过程中的梯度应该是正确的,但是,上面代码的输出结果为:

loss: tensor(0.0105, grad_fn=)tensor([    nan,  0.1416, -0.0243, -0.0167,  0.0000])

emmm....依旧为nan,问题在理论层面得到了解决,但是,在实现层面依旧没能解决.....

3. 源码调试分析

上面源码的问题依旧在  的实现,这个过程,在Python解释器解释的过程或许是这样的:

  1. 计算  ,对mask进行广播式的乘法,结果为:原本为1的位置变为了12.9,原本为0的位置依旧为0;

  2. 将1.的结果继续与x相乘,本质上仍然是与x的每个元素相乘,只是mask中不满足条件的  位置为0,表现出的结果是仅对满足条件的  进行了计算;

  3. 按照2.所述的原理,  公式的后半部分也是同样的计算过程,即,  中的每个值依旧会进行  的计算;

按照上述过程进行前向传播,在反向传播时,梯度不是从某一个分支得到的,而是两个分支的题目相加得到的,换句话说,依旧没能解决梯度变为nan的问题。

4. 源码改进及问题解决

经过第三部分的分析,知道了梯度变为nan的根本原因是当  时依旧参与了  的计算,导致在反向传播时计算出的梯度为nan。

要解决这个问题,就要保证在  时不会进行这样的计算。

新的PyTorch代码如下:

def loss_function(x):    mask = x < 0.003    gamma_x = torch.FloatTensor(x.size()).type_as(x)    gamma_x[mask] = 12.9 * x[mask]    mask = x >= 0.003    gamma_x[mask] = x[mask] ** 0.5    loss = torch.mean((x - gamma_x) ** 2)return lossif __name__ == '__main__':    x = Variable(torch.FloatTensor([0, 0.0025, 0.5, 0.8, 1]), requires_grad=True)    loss = loss_function(x)print('loss:', loss)    loss.backward()print(x.grad)

改变的地方位于`loss_function`,改变了对于  分支的处理方式,控制并保住每次计算仅有满足条件的值可以参与。此时输出为:

loss: tensor(0.0105, grad_fn=)tensor([ 0.0000,  0.1416, -0.0243, -0.0167,  0.0000])

就此,问题解决!

如有疑问,欢迎留言~

推荐阅读:

实战 | Pytorch BiLSTM + CRF做NER

如何评价Word2Vec作者提出的fastText算法?深度学习是否在文本分类等简单任务上没有优势?

从Word2Vec到Bert,聊聊词向量的前世今生(一)

bert pytorch源码_【PyTorch】梯度爆炸、loss在反向传播变为nan相关推荐

  1. Transformer-XL解读(论文 + PyTorch源码)

    前言 目前在NLP领域中,处理语言建模问题有两种最先进的架构:RNN和Transformer.RNN按照序列顺序逐个学习输入的单词或字符之间的关系,而Transformer则接收一整段序列,然后使用s ...

  2. ELMo解读(论文 + PyTorch源码)

    ELMo的概念也是很早就出了,应该是18年初的事情了.但我仍然是后知后觉,居然还是等BERT出来很久之后,才知道有这么个东西.这两天才仔细看了下论文和源码,在这里做一些记录,如果有不详实的地方,欢迎指 ...

  3. [源码解析] PyTorch 流水线并行实现 (1)--基础知识

    [源码解析] PyTorch 流水线并行实现 (1)–基础知识 文章目录 [源码解析] PyTorch 流水线并行实现 (1)--基础知识 0x00 摘要 0x01 历史 1.1 GPipe 1.2 ...

  4. 基于Pytorch源码对SGD、momentum、Nesterov学习

    目前神经网络的监督学习过程通常为: 数据加载(load)进神经网络 经过网络参数对数据的计算,得出预测值(predict) 根据预测值与标注值(label)之间的差距,产生损失(loss) 通过反向传 ...

  5. [源码解析] PyTorch 分布式(2) ----- DataParallel(上)

    [源码解析] PyTorch 分布式(2) ----- DataParallel(上) 文章目录 [源码解析] PyTorch 分布式(2) ----- DataParallel(上) 0x00 摘要 ...

  6. [源码解析] PyTorch 流水线并行实现 (6)--并行计算

    [源码解析] PyTorch 流水线并行实现 (6)–并行计算 文章目录 [源码解析] PyTorch 流水线并行实现 (6)--并行计算 0x00 摘要 0x01 总体架构 1.1 使用 1.2 前 ...

  7. [源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎

    [源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎 文章目录 [源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎 0x00 摘要 0x01 前文回顾 1.1 ...

  8. [源码解析] PyTorch分布式优化器(1)----基石篇

    [源码解析] PyTorch分布式优化器(1)----基石篇 文章目录 [源码解析] PyTorch分布式优化器(1)----基石篇 0x00 摘要 0x01 从问题出发 1.1 示例 1.2 问题点 ...

  9. pytorch 测试每一类_DeepFM全方面解析(附pytorch源码)

    写在前面 最近看了DeepFM这个模型.把我学习的思路和总结放上来给大家和未来的自己做个参考和借鉴.文章主要希望能串起学习DeepFM的各个环节,梳理整个学习思路.以"我"的角度浅 ...

最新文章

  1. 蕨叶形生物刷新生命史,动物界至少起源于5.7亿年前
  2. 【Scratch】青少年蓝桥杯_每日一题_3.17_蹦床
  3. Android和.NET通用的AES算法 (转) 好东东 收藏一下
  4. VTK:轮廓 Glow Pass用法实战
  5. 解决虚拟机能ping通宿主机,而宿主机不能ping通虚拟机
  6. it编年史_Java的编年史和低延迟
  7. ux设计师怎样找同类产品_没有预算? 别找借口。 便宜的UX上的UX 2:让我们开始构建。...
  8. Android AsyncTask 详解及注意事项
  9. 再记AE与AO的区别与联系
  10. 给char赋超过范围的值会发生什么
  11. 逍遥模拟器安装激活面具magisk教程
  12. 新版jadx-gui导入dex会提示Bad checksum
  13. 人社部:全力支持创业和灵活就业
  14. ue编辑器c语言语法高亮文件,自己动手做 UEStudio/UltraEdit 的语法高亮文件 (*.uew)...
  15. TOOD: Task-aligned One-stage Object Detection 原理与代码解析
  16. Linux指令--traceroute,netstat,ss
  17. 如何添加51la代码及隐藏统计图标
  18. 角谷猜想(次数+过程)
  19. 第七十七篇:车辆安全-车载软件C++语言开发指南(AUTOSAR C++)
  20. [转载]【转】ArcGIS 10安装方法(对比流行的2种安装方法)||迅雷电驴下载

热门文章

  1. T-PAMI 2021 | 换个损失函数就能实现数据扩增?
  2. 深度学习检测小目标常用方法
  3. “考研3次,读博7年,英语极烂”,他却做出诺奖级成果
  4. 2008年上半年 网络工程师 上下午试卷【附带答案】
  5. Linux 之Cut命令详解
  6. Python之Scrapy爬虫的常用命令
  7. python学习 爬取亚马逊网页,失败后。修改HTTP报文头部后成功!
  8. AOI光学自动检测技术 | 基本原理与设备构成
  9. 拒绝遗忘:高效的动态规划算法
  10. 第七篇:并发-恢复机制