
When using F.cross_entropy() in PyTorch, you may occasionally hit the error ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed, with a log like this:


/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
THCudaCheck FAIL file=/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu line=83 error=59 : device-side assert triggered
Traceback (most recent call last):
  File "tutorial.py", line 100, in <module>
    model = train_model(model, criterion, optim_scheduler_ft, num_epochs=25)
  File "tutorial.py", line 80, in train_model
    loss = criterion(outputs, labels)
  File "python3.7/site-packages/torch/nn/modules/module.py", line 206, in __call__
    result = self.forward(*input, **kwargs)
  File "python3.7/site-packages/torch/nn/modules/loss.py", line 313, in forward
    self.weight, self.size_average)
  File "python3.7/site-packages/torch/nn/functional.py", line 509, in cross_entropy
    return nll_loss(log_softmax(input), target, weight, size_average)
  File "python3.7/site-packages/torch/nn/functional.py", line 477, in nll_loss
    return f(input, target)
  File "python3.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forward
    output, *self.additional_args)
RuntimeError: cuda runtime error (59) : device-side assert triggered at /py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu:83

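This usually means the cross-entropy function received an invalid class label. In other words, some target value t violates the condition t >= 0 && t < n_classes, where n_classes is the number of classes (the size of dimension 1 of input).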
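You can find the offending label values before the loss is called by checking the targets directly. The sketch below uses a hypothetical helper, `find_invalid_targets`, which is not a PyTorch API. It assumes integer class-index targets and skips values equal to `ignore_index`, which is what `cross_entropy` does.

```python
import torch


def find_invalid_targets(target: torch.Tensor, n_classes: int, ignore_index: int = -100) -> list:
    """Return the distinct target values that violate t >= 0 && t < n_classes.

    Entries equal to ignore_index are skipped, mirroring cross_entropy.
    """
    t = target.flatten()
    t = t[t != ignore_index]             # ignored entries are never checked
    bad = t[(t < 0) | (t >= n_classes)]  # negation of the assert condition
    return sorted(bad.unique().tolist())


if __name__ == "__main__":
    labels = torch.tensor([0, 2, 250, 255, -1])
    # 255 is ignored; 250 and -1 are out of range for 19 classes
    print(find_invalid_targets(labels, n_classes=19, ignore_index=255))  # [-1, 250]
```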

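The condition is checked on the target argument of cross_entropy. Every target value must satisfy it unless it equals ignore_index (default -100):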


torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')
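In current PyTorch, `size_average` and `reduce` in this signature are deprecated in favour of `reduction`. The sketch below uses made-up logits and targets to show what `ignore_index` does. With `reduction='mean'`, ignored entries are dropped from both the sum and the denominator. With `reduction='none'`, they get a loss of zero.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_classes = 3
logits = torch.randn(5, n_classes)          # (N, C) scores for 5 samples
target = torch.tensor([0, 2, 255, 1, 255])  # 255 marks samples to ignore

# Mean over the non-ignored samples only
loss_ignored = F.cross_entropy(logits, target, ignore_index=255)

# Same value computed by explicitly dropping the ignored samples
keep = target != 255
loss_subset = F.cross_entropy(logits[keep], target[keep])

# Per-element losses: ignored positions contribute exactly 0
per_elem = F.cross_entropy(logits, target, ignore_index=255, reduction="none")

print(torch.allclose(loss_ignored, loss_subset))  # True
```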


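In practice, this error is most often caused by labels (ground truth) that were not handled correctly. For example, data augmentation may pad the label map with a negative value or with some large positive value (such as 255), while the matching ignore_index is not passed to torch.nn.functional.cross_entropy(). The padding value is then treated as a real class index outside [0, n_classes), and the assertion fails at runtime.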
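CUDA kernels run asynchronously, so the device-side assert may be reported at a later, unrelated call. Two things usually pinpoint the failing line: running the same batch on CPU, or setting the environment variable `CUDA_LAUNCH_BLOCKING=1`. The sketch below reproduces the failure on CPU with an assumed padding value of 255 and 19 classes. It catches both `IndexError` and `RuntimeError`, because the exception type has differed across PyTorch versions.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 19)            # 4 samples, 19 classes (assumed)
target = torch.tensor([3, 7, 255, 0])  # 255 used as a padding label

# Without ignore_index, 255 >= n_classes violates the condition.
# On CPU this is a normal Python exception with a readable message.
try:
    F.cross_entropy(logits, target)
    failed = False
except (IndexError, RuntimeError) as err:
    failed = True
    print(type(err).__name__, err)

# Declaring the padding value as ignore_index makes the same call valid
loss = F.cross_entropy(logits, target, ignore_index=255)
print(failed, torch.isfinite(loss).item())  # True True
```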



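import torchvision.transforms.functional as tf

# Image: reflect padding, so no artificial pixel value is introduced
cropped_img = tf.pad(cropped_img, padding_tuple, padding_mode="reflect")
# Label mask: the area uncovered by the translation is filled with 250
# (newer torchvision versions name this argument fill instead of fillcolor)
mask = tf.affine(mask, translate=(-x_offset, -y_offset), scale=1.0, angle=0.0, shear=0.0, fillcolor=250)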


import torch
import torch.nn as nn
import torch.nn.functional as F


def cross_entropy2d(input, target, weight=None, reduction='none'):
    n, c, h, w = input.size()
    nt, ht, wt = target.size()

    # Resize the prediction to the label resolution if they differ
    if h != ht or w != wt:
        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)

    # (N, C, H, W) -> (N*H*W, C) logits and (N*H*W,) targets
    input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)
    target = target.view(-1)
    # Pixels labelled 255 are ignored; every other value must lie in [0, c)
    loss = F.cross_entropy(input, target, weight=weight, reduction=reduction, ignore_index=255)
    return loss


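Here the label mask is padded with 250 during augmentation (fillcolor=250), but cross-entropy is computed with ignore_index=255. Assuming fewer than 250 classes, 250 is neither ignored nor a valid class index. As a result, F.cross_entropy fails with ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed. The fix is to use the same value for the label padding and for the class the loss ignores.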
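One way to keep the two values from drifting apart is a single shared constant. The sketch below stands in for the torchvision augmentation with `torch.nn.functional.pad`, which is a swapped-in padding step, not the author's code. `IGNORE_INDEX = 255` and `N_CLASSES = 19` are assumed values. `F.cross_entropy` also accepts `(N, C, H, W)` logits with `(N, H, W)` targets directly, so the manual reshaping is optional.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = 255  # assumed "no label" value, shared by augmentation and loss
N_CLASSES = 19      # hypothetical number of classes

# Hypothetical (N, H, W) label mask; pad 2 columns on the right with the
# same constant the loss will ignore (stand-in for the augmentation fill).
mask = torch.randint(0, N_CLASSES, (1, 4, 4))
mask = F.pad(mask, (0, 2), value=IGNORE_INDEX)  # -> (1, 4, 6)

logits = torch.randn(1, N_CLASSES, 4, 6)        # (N, C, H, W)

# The padded pixels are skipped instead of triggering the assertion
loss = F.cross_entropy(logits, mask, ignore_index=IGNORE_INDEX)
print(torch.isfinite(loss).item())  # True
```

With the torchvision code above, the same idea means passing `IGNORE_INDEX` as the mask fill value, rather than hard-coding 250 in one place and 255 in another.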



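# A similar device-side assert fires when labels (N, H, W) are converted to a
# one-hot tensor with scatter_: a label >= classes (e.g. an ignore value of 255)
# is an out-of-range index. Here size = labels.size(), classes = number of classes.
labels = labels.view(size[0], 1, size[1], size[2])
oneHot_size = (size[0], classes, size[1], size[2])
labels_real = torch.zeros(oneHot_size, device="cuda")
# ignore_index = 255
# labels[labels == ignore_index] = 0  # uncomment to map ignored pixels to class 0 first
labels_real = labels_real.scatter_(1, labels.long().cuda(), 1.0)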
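The commented-out workaround makes ignored pixels look like class 0 in the one-hot target. A possible alternative, sketched below, gives ignored pixels an all-zero vector instead. It uses `F.one_hot` and a hypothetical helper `to_one_hot`, and it is not the author's code. Any value other than `ignore_index` that is outside `[0, n_classes)` will still raise an error from `F.one_hot`, which is intended.

```python
import torch
import torch.nn.functional as F


def to_one_hot(labels: torch.Tensor, n_classes: int, ignore_index: int = 255) -> torch.Tensor:
    """Convert (N, H, W) class indices to a float one-hot tensor (N, C, H, W).

    Pixels equal to ignore_index get an all-zero vector.
    """
    ignored = labels == ignore_index
    safe = labels.masked_fill(ignored, 0)          # keep every index in [0, C)
    one_hot = F.one_hot(safe.long(), n_classes)    # (N, H, W, C), int64
    one_hot = one_hot.permute(0, 3, 1, 2).float()  # -> (N, C, H, W)
    return one_hot * (~ignored).unsqueeze(1).float()  # zero out ignored pixels


labels = torch.tensor([[[0, 2], [255, 1]]])  # (1, 2, 2), one ignored pixel
oh = to_one_hot(labels, n_classes=3)
print(oh.shape)  # torch.Size([1, 3, 2, 2])
```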


