原标题:PyTorch常见的12坑

1. nn.Module.cuda 和 Tensor.cuda 的作用效果差异

无论是对于模型还是数据,cuda函数都能实现从CPU到GPU的内存迁移,但是他们的作用效果有所不同。

对于nn.Module:

上面两句能够达到一样的效果,即对model自身进行的内存迁移。

对于Tensor:

和nn.Module不同,调用tensor.cuda只是返回这个tensor对象在GPU内存上的拷贝,而不会对自身进行改变。因此必须对tensor进行重新赋值,即tensor=tensor.cuda.

例子:

2. PyTorch 0.4 计算累积损失的不同

以广泛使用的模式total_loss += loss.data[0]为例。Python0.4.0之前,loss是一个封装了(1,)张量的Variable,但Python0.4.0的loss现在是一个零维的标量。对标量进行 索引是没有意义的(似乎会报 invalid index to scalar variable 的错误)。使用loss.item可以从标量中获取Python数字。所以改为:

如果在累加损失时未将其转换为Python数字,则可能出现程序内存使用量增加的情况。这是因为上面表达式的右侧原本是一个Python浮点数,而它现在是一个零维张量。因此,总损失累加了张量和它们的梯度历史,这可能会产生很大的autograd 图,耗费内存和计算资源。

3. PyTorch 0.4 编写不限制设备的代码

4. torch. Tensor.detach的使用

detach的官方说明如下:

假设有模型A和模型B,我们需要将A的输出作为B的输入,但训练时我们只训练模型B. 那么可以这样做:

input_B = output_A.detach

它可以使两个计算图的梯度传递断开,从而实现我们所需的功能。

5. ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm)

出现这个错误的情况是,在服务器上的docker中运行训练代码时,batch size设置得过大,shared memory不够(因为docker限制了shm).解决方法是,将Dataloader的num_workers设置为0.

6. pytorch中loss函数的参数设置

以CrossEntropyLoss为例:

若 reduce = False,那么 size_average 参数失效,直接返回向量形式的 loss,即batch中每个元素对应的loss.

若 reduce = True,那么 loss 返回的是标量:

如果 size_average = True,返回 loss.mean.

如果 size_average = False,返回 loss.sum.

weight : 输入一个1D的权值向量,为各个类别的loss加权,如下公式所示:

ignore_index : 选择要忽视的目标值,使其对输入梯度不作贡献。如果 size_average = True,那么只计算不被忽视的目标的loss的均值。

reduction : 可选的参数有:‘none’ | ‘elementwise_mean’ | ‘sum’, 正如参数的字面意思,不解释。7. pytorch的可重复性问题 8. 多GPU的处理机制

使用多GPU时,应该记住pytorch的处理逻辑是:

1.在各个GPU上初始化模型。

2.前向传播时,把batch分配到各个GPU上进行计算。

3.得到的输出在主GPU上进行汇总,计算loss并反向传播,更新主GPU上的权值。

4.把主GPU上的模型复制到其它GPU上。

9. num_batches_tracked参数

今天读取模型参数时出现了错误

大概可以看出,这个参数和训练时的归一化的计算方式有关。

因此,我们可以知道该错误是由于训练和测试所用的pytorch版本(0.4.1版本前后的差异)不一致引起的。具体的解决方案是:如果是模型参数(Orderdict格式,很容易修改)里少了num_batches_tracked变量,就加上去,如果是多了就删掉。偷懒的做法是将load_state_dict的strict参数置为False,如下所示:

还看到有人直接修改pytorch 0.4.1的源代码把num_batches_tracked参数删掉的,这就非常不建议了。

10. 训练时损失出现nan的问题

最近在训练模型时出现了损失为nan的情况,发现是个大坑。暂时先记录着。

可能导致梯度出现nan的三个原因:

1. 梯度爆炸。也就是说梯度数值超出范围变成nan. 通常可以调小学习率、加BN层或者做梯度裁剪来试试看有没有解决。

2. 损失函数或者网络设计。比方说,出现了除0,或者出现一些边界情况导致函数不可导,比方说log(0)、sqrt(0).

3. 脏数据。可以事先对输入数据进行判断看看是否存在nan.

补充一下nan数据的判断方法:

注意!像nan或者inf这样的数值不能使用 == 或者 is 来判断!为了安全起见统一使用 math.isnan 或者 numpy.isnan 吧。

例如:

raiseValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

没有什么特别好的解决办法,在训练前用 num_of_samples % batch_size 算一下会不会正好剩下一个样本。

可以考虑将`DataLoader`的`drop_last`选项设为True,这样的话,当最后一个batch凑不满时,就会舍弃掉。

12. 优化器的weight_decay项导致的隐蔽bug

我们都知道weight_decay指的是权值衰减,即在原损失的基础上加上一个L2惩罚项,使得模型趋向于选择更小的权重参数,起到正则化的效果。但是我经常会忽略掉这一项的存在,从而引发了意想不到的问题。

这次的坑是这样的,在训练一个ResNet50的时候,网络的高层部分layer4暂时没有用到,因此也并不会有梯度回传,于是我就放心地将ResNet50的所有参数都传递给Optimizer进行更新了,想着layer4应该能保持原来的权重不变才对。 但是实际上,尽管layer4没有梯度回传,但是weight_decay的作用仍然存在,它使得layer4权值越来越小,趋向于0。后面需要用到layer4的时候,发现输出异常(接近于0),才注意到这个问题的存在。

虽然这样的情况可能不容易遇到,但是还是要谨慎:暂时不需要更新的权值,一定不要传递给Optimizer,避免不必要的麻烦。

13. 优化TensorDataset的数据加载速度

` TensorDataset `提供了已经完全加载到内存中的矩阵的数据读取接口。在使用`TensorDataset`的时候,如果直接用`DataLoader`,会导致数据加载速度非常缓慢,严重拖慢训练速度,分析和解决方案详见https://huangbiubiu.github.io/2019/BEST-PRACTICE-PyTorch-TensorDataset/ 返回搜狐,查看更多

责任编辑:

pytorch统计矩阵非0的个数_PyTorch常见的12坑相关推荐

  1. pytorch统计矩阵非0的个数_矩阵的三种存储方式---三元组法 行逻辑链接法 十字链表法...

    在介绍矩阵的压缩存储前,我们需要明确一个概念:对于特殊矩阵,比如对称矩阵,稀疏矩阵,上(下)三角矩阵,在数据结构中相同的数据元素只存储一个. @[TOC] 三元组顺序表 稀疏矩阵由于其自身的稀疏特性, ...

  2. pytorch统计矩阵非0的个数_计算TensorFlow中非零元素的个数

    tf.count_nonzerocount_nonzero( input_tensor, axis=None, keep_dims=False, dtype=tf.int64, name=None, ...

  3. vba 自定义function返回值_用vba解决excel如何求前面连续为0的个数

    领导布置了任务,要求每天统计当月的发展量,并且统计有多少业务员最多连续多少天发展为0的情况,统计的表格是这样的. 用几行简单的VBA语言就能解决这个问题 首先打开vba编辑窗口,点击开发工具--vis ...

  4. 滴滴2017校园招聘编程题——阶乘末尾0的个数

    1.题目如下图所示: 2.分析:         这个题目描述的很简单,思路看似也很清晰,我们第一想到的肯定就是正常计算和统计--先计算N!阶乘的结果,然后统计结果末尾0的个数.看似这是一个很好的也很 ...

  5. Java黑皮书课后题第5章:*5.1(统计正数和负数的个数然后计算这些数的平均值)编写程序,读入未指定个数的整数,判断读入的正数有多少个、负数有多少个,然后计算输入值的总和和平均值(不记0,浮点表示)

    *5.1(统计正数和负数的个数然后计算这些数的平均值)编写程序,读入未指定个数的整数,判断读入的正数有多少个.负数有多少个,然后计算输入值的总和和平均值(不记0,平均值使用浮点表示) 题目 题目概述 ...

  6. matlab 如何统计矩阵中大于、小于或等于某一值的位置、个数

    善用 find() 和 sum() 函数! 假设存在如下矩阵 a : >> a = [1,2,3;4,5,6;7,8,9] a = 1     2     3      4     5   ...

  7. n阶非零矩阵AB,矩阵AB=0,则A和B的秩都小于n

    设A,B是两个n阶非零方阵,且AB=0,则A和B的秩都小于n 证明: ∵AB=0∵AB=0 ∵AB=0 ∴B的每一个列向量(每一列)都是方程AX=0的解∴B的每一个列向量(每一列)都是方程AX=0的解 ...

  8. python字符串大写字母个数_【python实例】统计字符串里大写字母,小写字母的个数和非字母的个数...

    """ 给定一个以下字符串:统计大写字母的个数,小写字母的个数,非字母的个数. str1 = "ajdkkKDKEK1343KFKiriromfkfKKRIOW ...

  9. Python_计算一个数的阶乘并统计尾部0的个数

    def getNum(num):# 计算阶乘result = 1for i in range(1, num+1):result *= i# 统计尾部0的个数str1 = str(result)str2 ...

最新文章

  1. python gui界面设置数据储存在哪里_我整理的一些常用Python库!让你快速记住这些库的用法!建议收藏...
  2. CvMat,Mat和IplImage之间的转化和拷贝
  3. [转载]C# ListT的并集、交集、差集
  4. http接口测试工具——RESTClient
  5. ABAP--关于重复行的处理
  6. C# 实验四 获取系统时间、点击加一秒功能
  7. dota2比分网_红黑电竞比分横空出世 LPL夏季赛火热进行
  8. OllyDBG 入门之四--破解常用断点设
  9. Oracle运行set autotrace on报错SP2-0618、SP2-0611
  10. python的基本数据结构_python学习笔记-基本数据结构
  11. AngularJs中,如何在render完成之后,执行Js脚本
  12. 用Python计算最长公共子序列和最长公共子串
  13. java工厂到接口_Java基础——接口简单工厂
  14. FFmpeg4.3.2之ffplay log输出级别(三十)
  15. mysql 主主_MySQL双主(主主)架构
  16. 卡方线性趋势检验_趋势卡方检验
  17. HTML5七夕情人节表白网页制作【花瓣图片表白】HTML+CSS+JavaScript html生日快乐祝福网页制作
  18. 解决近期Windows11更新后无法上网的问题
  19. Google earth engine 入门与简介
  20. MATLAB算法实战应用案例精讲-【数据分析】时序异常检测(补充篇)(附Java、R语言和python代码实现)

热门文章

  1. R语言rename重命名dataframe的列名实战:rename重命名dataframe的列名(写错的列名不会被重命名)
  2. R语言gganimate包创建可视化gif动图、可视化动图:ggplot2可视化静态散点图、gganimate包创建动态散点分面图(facet_wrap)动画基于transition_time函数
  3. stat_count() must not be used with a y aesthetic
  4. 使用Google Page Speed
  5. 如何用c语言打出 * * * * * * * * * * * * *?
  6. Pacbio HiFi技术原理与应用软件实例
  7. 函数重载和 函数模板
  8. sqlalchemy 对 mysql 进行增删改查
  9. python opencv 中bmp转raw格式图片并展示
  10. CNN 图像增强--DSLR-Quality Photos on Mobile Devices with Deep Convolutional Networks