前言

在PyTorch框架下使用F.cross_entropy()函数时,偶尔会报错ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed

错误信息类似下面打印信息:

/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
THCudaCheck FAIL file=/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu line=83 error=59 : device-side assert triggered
Traceback (most recent call last):File "tutorial.py", line 100, in <module>model = train_model(model, criterion, optim_scheduler_ft, num_epochs=25)File "tutorial.py", line 80, in train_modelloss = criterion(outputs, labels)File "python3.7/site-packages/torch/nn/modules/module.py", line 206, in __call__result = self.forward(*input, **kwargs)File "python3.7/site-packages/torch/nn/modules/loss.py", line 313, in forwardself.weight, self.size_average)File "python3.7/site-packages/torch/nn/functional.py", line 509, in cross_entropyreturn nll_loss(log_softmax(input), target, weight, size_average)File "python3.7/site-packages/torch/nn/functional.py", line 477, in nll_lossreturn f(input, target)File "python3.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forwardoutput, *self.additional_args)
RuntimeError: cuda runtime error (59) : device-side assert triggered at /py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu:83

通常情况下,这是由于求交叉熵函数在计算时遇到了类别错误的问题,即不满足t >= 0 && t < n_classes条件。

t >= 0 && t < n_classes条件

在分类任务中,需要调用torch.nn.functional.cross_entropy()函数求交叉熵,从PyTorch官网可以看到该函数定义:

torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

可以注意到有一个key-value是ignore_index=-100。这是在交叉熵计算时被跳过的部分。通常是在数据增强中的填充值。

而在代码运行中报错ClassNLLCriterion Assertion `t >= 0 && t < n_classes ` failed,大部分都是由于没有正确处理好label(ground truth)导致的。例如在数据增强中,填充数据使用了负数,或者使用了某大正数(如255),而在调用torch.nn.functional.cross_entropy()方法时却没有传入正确的ignore_index。这就会导致运行过程中的Assertion Error。

代码示例

数据增强部分

import torchvision.transforms.functional as tftf.pad(cropped_img, padding_tuple, padding_mode="reflect"),
tf.affine(mask, translate=(-x_offset, -y_offset), scale=1.0, angle=0.0, shear=0.0,fillcolor=250,)

求交叉熵部分

import torch
import torch.nn.functional as F
import torch.nn as nndef cross_entropy2d(input, target, weight=None, reduction='none'):n, c, h, w = input.size()nt, ht, wt = target.size()if h != ht or w != wt:input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)target = target.view(-1)loss = F.cross_entropy(input, target, weight=weight, reduction=reduction, ignore_index=255)return loss

分析

可以看到在数据增强时的填充值为250(fillcolor=250),但在求交叉熵时却传入了ignore_index=255。因此在代码运行时,F.cross_entropy部分便会报错ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed。只需要统一好label部分填充数据和计算交叉熵时需要忽略的class就可以避免出现这一问题。

其他

在PyTorch框架下,使用无用label值进行填充和处理时,要注意在使用scatter_函数时也需要注意对无用label进行提前处理,否则在使用data.scatter_()时同样也会报类似类别index错误。

labels = labels[:, :, :].view(size[0], 1, size[1], size[2])
oneHot_size = (size[0], classes, size[1], size[2])
labels_real = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
# ignore_index=255
# labels[labels.data[::] == ignore_index] = 0
labels_real = labels_real.scatter_(1, labels.data.long().cuda(), 1.0)

参考资料

[1] torch.nn.functional — PyTorch 1.8.0 documentation
[2] Pytorch里的CrossEntropyLoss详解 - marsggbo - 博客园
[3] RuntimeError: cuda runtime error (59) : device-side assert triggered when running transfer_learning_tutorial · Issue #1204 · pytorch/pytorch
[4] PyTorch 中,nn 与 nn.functional 有什么区别? - 知乎
[5] FaceParsing.PyTorch/augmentations.py at master · TracelessLe/FaceParsing.PyTorch

PyTorch使用F.cross_entropy报错Assertion `t >= 0 t < n_classes` failed问题记录相关推荐

  1. RuntimeError: Assertion cur_target 0 cur_target n_classes failed

    问题描述 使用pytorch的函数 torch.nn.CrossEntropyLoss()计算Loss时报错: RuntimeError: Assertion `cur_target >= 0 ...

  2. MaskRCNN-Benchmark框架Assertion 't ** 0 t ** n_classes' failed可能的原因

    Mask R-CNN Benchmark是一个完全由PyTorch 1.0写成,快速.模块化的Faster R-CNN和Mask R-CNN组件.该项目旨在让用户更容易地创建一种模块,实现对图片中物品 ...

  3. RuntimeError Assertion cur_target = 0 cur_target n_classes failed

    RuntimeError: Assertion `cur_target >= 0 && cur_target < n_classes' failed criterion = ...

  4. PyTorch中F.cross_entropy()函数

    对PyTorch中F.cross_entropy()的理解 PyTorch提供了求交叉熵的两个常用函数: 一个是F.cross_entropy(), 另一个是F.nll_entropy(), 是对F. ...

  5. Maven 新版本 3.8.1 打包报错 maven-default-http-blocker (http://0.0.0.0/): Blocked mirror for repositories

    Maven 新版本 3.8.1 打包报错 maven-default-http-blocker (http://0.0.0.0/): Blocked mirror for repositories [ ...

  6. easy-mock本地部署成功,访问报错:EADDRNOTAVAIL 0.0.0.0:7300 解决方案

    easy-mock本地部署成功,访问报错:EADDRNOTAVAIL 0.0.0.0:7300 解决方案 参考文章: (1)easy-mock本地部署成功,访问报错:EADDRNOTAVAIL 0.0 ...

  7. clion pycharm 报错 bash: line 0: cd: /xxx: No such file or directory

    现象 : clion 报错 bash: line 0: cd: /tmp/src/cmake-build-debug: No such file or directory pycharm 报错 bas ...

  8. Hadoop报错:All specified directories are failed to load.

    Hadoop报错:All specified directories are failed to load. 先将所有的Hadoop进程停掉,然后清空将所有节点的Hadoop的安装目录下的data目录 ...

  9. linux 修改网卡报错xe,centos修改端口出现Failed to start OpenSSH server daemon 启动报错和-xe报错的解决方法...

    修改SSH端口: # vi /etc/ssh/sshd_config 里面找port 22,在前面加上#,去掉注释,然后把22改成你想要的端口. #port 22 建议去掉注释之前,先增加你需要的端口 ...

最新文章

  1. R语言构建xgboost模型:使用xgboost构建泊松回归(poisson regression)模型
  2. matlab的输出(命令窗口、fprint函数、sprintf函数、disp函数)
  3. 运行维护:UPS电源并列运行分析及维护应用
  4. 1.3.3 系统调用(执行过程、访管指令、库函数与系统调用)
  5. mysql 类型解释_MySQL 数据类型说明解释
  6. java读取excel中的数据存到数据库
  7. java boolean 多线程_JAVA多线程两个实用的辅助类(CountDownLatch和AtomicBoolean)
  8. Python学习 Day 2-数据类型和变量
  9. 19岁少女辍学就业,却遭身价2.5亿创业公司解雇
  10. CentOS 7安装之后的七个事
  11. Java:Spi 小实战
  12. 模电数电c语言笔试题,模电数电题面试题集锦
  13. Python:用类与对象写一元二次方程计算器中遇到的错误
  14. 单页面动画 html5,9款惊艳的HTML5/CSS3动画应用赏析
  15. 九爷带你了解 Tomcat 优化
  16. RK3399 4K 带宽不足[drm:vop_isr] ERROR POST_BUF_EMPTY irq err
  17. AUTOCAD——偏移命令、移动命令
  18. adb连接夜神模拟器提示:adb unable to connect to 127.0.0.162001 cannot connect to 127.0.0.16200 由于目标 计算机积极拒绝
  19. 学校计算机刷卡机,海口学校食堂系统,食堂就餐刷卡机
  20. ext核心API详解

热门文章

  1. C# Nullable 类型转换报错问题
  2. sublime text 一些常用插件
  3. (五)OpenCV | 斑点中心检测(图像矩)
  4. uniapp小程序毛玻璃效果白边去除
  5. p​o​s​t​m​a​r​k​使​用
  6. html监听页面滚动高度,jquery如何监听滚动条事件?
  7. html canvas生成图片,html5 canvas画板涂鸦生成图片代码
  8. mac tar解压错误
  9. 我的世界 RPG 服务器物品系统 - 原版物品重分类 (SpigotVanilla-2)
  10. 牛轧糖Android 7.1系统,小米5C吃上“牛轧糖”推送安卓7.1:系统更流畅,联通信号更稳定!...