由于Dice系数是图像分割中常用的指标,而PyTorch中没有官方的实现,下面结合网上的教程进行详细实现。


def diceCoeff(pred, gt, smooth=1, activation='sigmoid'):
    r"""First (flawed) Dice coefficient implementation, kept for the analysis below.

    computational formula (as implemented here):

        dice = 2 * (|pred ∩ gt| + smooth) / (|pred| + |gt| + smooth)

    where |pred ∩ gt| = sum(pred * gt) and |pred| + |gt| = sum(pred) + sum(gt),
    computed per sample.

    NOTE: ``smooth`` sits inside the factor 2, so it is counted twice in the
    numerator but only once in the denominator. As the tests below show, this
    inflates the score (above 1 for a perfect match when smooth=1) and gives
    exactly 2 when both pred and gt are all zeros.

    Args:
        pred: predictions; dim 0 is the batch dimension N.
        gt: ground-truth mask with the same number of elements per sample as pred.
        smooth: smoothing constant added inside the numerator and to the denominator.
        activation: None or "none" (use pred as-is), "sigmoid", or "softmax2d".

    Returns:
        0-dim tensor: the per-sample scores summed and divided by N (batch mean).

    Raises:
        NotImplementedError: if ``activation`` is any other value.
    """
    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        # Only reached for unsupported activation names.
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    # reshape instead of view: view() raises on inputs whose memory layout is
    # not viewable (e.g. transposed or channels_last tensors).
    pred_flat = pred.reshape(N, -1)
    gt_flat = gt.reshape(N, -1)

    intersection = (pred_flat * gt_flat).sum(1)
    # Sum of both masks (|pred| + |gt|), not a set union.
    unionset = pred_flat.sum(1) + gt_flat.sum(1)
    # Flawed on purpose (see NOTE above): smooth is multiplied by 2.
    loss = 2 * (intersection + smooth) / (unionset + smooth)

    return loss.sum() / N

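整体思路就是运用Dice的计算公式 $\mathrm{Dice} = \frac{2|A \cap B|}{|A| + |B|}$(代码中的unionset实际上是 $|A| + |B|$,而不是 $A \cup B$)。下面来分析一下可能存在的问题:

1) 当smooth取值较大(如smooth=1)时,dice会偏高,甚至超过1。如下测试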


# pred / gt: torch.Size([1, 3, 4, 4]) -- one sample, 3 channels, 4x4 pixels.
# Per-pixel (channel 0, channel 1, channel 2) encoding:
#   1 0 0 = bladder
#   0 1 0 = tumor
#   0 0 0 = background
pred = torch.Tensor([[
    [[0, 1, 1, 0],
     [1, 0, 0, 1],
     [1, 0, 0, 1],
     [0, 1, 1, 0]],

    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]],

    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]]]])

# Ground truth is identical to the prediction, so the ideal dice is 1.
gt = torch.Tensor([[
    [[0, 1, 1, 0],
     [1, 0, 0, 1],
     [1, 0, 0, 1],
     [0, 1, 1, 0]],

    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]],

    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]]]])

# Bladder is channel 0; slice 0:1 keeps the channel dimension.
dice_bladder1 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1, activation=None)
dice_bladder2 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1e-5, activation=None)

print('smooth=1 : dice={:.4}'.format(dice_bladder1.item()))
print('smooth=1e-5 : dice={:.4}'.format(dice_bladder2.item()))

# Output:
# smooth=1 : dice=1.059
# smooth=1e-5 : dice=1.0

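我们最后预测的是一个3分类的分割图,第一类是bladder,第二类是tumor,第三类是背景。我们先假设bladder的预测pred和gt一样,计算bladder的dice值,发现当smooth=1的时候dice偏高(1.059,已经超过了Dice系数的上限1),而smooth=1e-5时dice比较合理。原因是此时交集为8、$|A|+|B|=16$,而公式把smooth放进了乘2的括号里:$2 \times (8+1)/(16+1) = 18/17 \approx 1.059$,smooth越大,偏差越明显。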

2) 当预测和gt都是背景时,即图中没有要分割的部分(没有tumor或者bladder)时,会导致dice=2。如下测试

# pred / gt: torch.Size([1, 3, 4, 4]) -- one sample, 3 channels, 4x4 pixels.
# Per-pixel (channel 0, channel 1, channel 2) encoding:
#   1 0 0 = bladder
#   0 1 0 = tumor
#   0 0 0 = background
pred = torch.Tensor([[
    [[0, 1, 1, 0],
     [1, 0, 0, 1],
     [1, 0, 0, 1],
     [0, 1, 1, 0]],

    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]],

    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]]]])

gt = torch.Tensor([[
    [[0, 1, 1, 0],
     [1, 0, 0, 1],
     [1, 0, 0, 1],
     [0, 1, 1, 0]],

    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]],

    [[0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]]]])

# Tumor is channel 1, which is all zeros in both pred and gt (no tumor present),
# so the ideal dice is 1.
dice_tumor1 = diceCoeff(pred[:, 1:2, :], gt[:, 1:2, :], smooth=1, activation=None)
dice_tumor2 = diceCoeff(pred[:, 1:2, :], gt[:, 1:2, :], smooth=1e-5, activation=None)

print('smooth=1 : dice={:.4}'.format(dice_tumor1.item()))
print('smooth=1e-5 : dice={:.4}'.format(dice_tumor2.item()))

# Output:
# smooth=1 : dice=2.0
# smooth=1e-5 : dice=2.0

这里我们还是用1)中的数据,只是这里进行tumor的dice计算。可以看到tumor的预测图和gt中的值全是0,即没有tumor,都是背景,但dice=2。这是因为根据公式,此时 $A \cap B$ 和 $|A|+|B|$ 都为0,结果就是 2*smooth / smooth = 2。解决方法是把smooth移到乘2的括号外面,修改如下:


# loss = 2 * (intersection + smooth) / (unionset + smooth) # previous version: smooth is doubled in the numerator

# Fixed version: smooth is added once to both numerator and denominator, so
# identical binary masks (including two all-zero masks) give exactly 1.
loss = (2 * intersection + smooth) / (unionset + smooth)


def diceCoeff(pred, gt, smooth=1e-5, activation='sigmoid'):
    r"""Compute the Dice coefficient per sample and average it over the batch.

    computational formula:

        dice = (2 * |pred ∩ gt| + smooth) / (|pred| + |gt| + smooth)

    where |pred ∩ gt| = sum(pred * gt) and |pred| + |gt| = sum(pred) + sum(gt),
    computed per sample after flattening.

    Args:
        pred: predictions; dim 0 is the batch dimension N.
        gt: ground-truth mask with the same number of elements per sample as pred.
        smooth: smoothing constant added once to numerator and denominator; it
            avoids division by zero and makes two all-zero masks score 1.
        activation: None or "none" (use pred as-is), "sigmoid", or "softmax2d".

    Returns:
        0-dim tensor: the per-sample scores summed and divided by N (batch mean).

    Raises:
        NotImplementedError: if ``activation`` is any other value.
    """
    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        # Only reached for unsupported activation names.
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    # reshape instead of view: view() raises on inputs whose memory layout is
    # not viewable (e.g. transposed or channels_last tensors).
    pred_flat = pred.reshape(N, -1)
    gt_flat = gt.reshape(N, -1)

    intersection = (pred_flat * gt_flat).sum(1)
    # Sum of both masks (|pred| + |gt|), not a set union.
    unionset = pred_flat.sum(1) + gt_flat.sum(1)
    loss = (2 * intersection + smooth) / (unionset + smooth)

    return loss.sum() / N


# After the fix, re-running either test above (bladder or tumor) prints:
# smooth=1 : dice=1.0

# smooth=1e-5 : dice=1.0


def diceCoeffv2(pred, gt, eps=1e-5, activation='sigmoid'):
    r"""Compute the Dice coefficient in TP/FP/FN form, averaged over the batch.

    computational formula:

        dice = (2 * tp + eps) / (2 * tp + fp + fn + eps)

    with tp = sum(pred * gt), fp = sum(pred) - tp, fn = sum(gt) - tp per
    sample. Since 2*tp + fp + fn == sum(pred) + sum(gt), this equals the
    (2 * intersection + smooth) / (unionset + smooth) form.

    Args:
        pred: predictions; dim 0 is the batch dimension N.
        gt: ground-truth mask with the same number of elements per sample as pred.
        eps: smoothing constant added to numerator and denominator; avoids
            division by zero and makes two all-zero masks score 1.
        activation: None or "none" (use pred as-is), "sigmoid", or "softmax2d".

    Returns:
        0-dim tensor: the per-sample scores summed and divided by N (batch mean).

    Raises:
        NotImplementedError: if ``activation`` is any other value.
    """
    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        # Only reached for unsupported activation names.
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    # reshape instead of view: view() raises on inputs whose memory layout is
    # not viewable (e.g. transposed or channels_last tensors).
    pred_flat = pred.reshape(N, -1)
    gt_flat = gt.reshape(N, -1)

    tp = torch.sum(gt_flat * pred_flat, dim=1)
    fp = torch.sum(pred_flat, dim=1) - tp
    fn = torch.sum(gt_flat, dim=1) - tp
    loss = (2 * tp + eps) / (2 * tp + fp + fn + eps)

    return loss.sum() / N



def diceCoeff(pred, gt, smooth=1e-5, activation='sigmoid'):
    r"""Compute the Dice coefficient per sample and average it over the batch.

    computational formula:

        dice = (2 * |pred ∩ gt| + smooth) / (|pred| + |gt| + smooth)

    where |pred ∩ gt| = sum(pred * gt) and |pred| + |gt| = sum(pred) + sum(gt),
    computed per sample after flattening.

    Args:
        pred: predictions; dim 0 is the batch dimension N.
        gt: ground-truth mask with the same number of elements per sample as pred.
        smooth: smoothing constant added once to numerator and denominator; it
            avoids division by zero and makes two all-zero masks score 1.
        activation: None or "none" (use pred as-is), "sigmoid", or "softmax2d".

    Returns:
        0-dim tensor: the per-sample scores summed and divided by N (batch mean).

    Raises:
        NotImplementedError: if ``activation`` is any other value.
    """
    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        # Only reached for unsupported activation names.
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    # reshape instead of view: view() raises on inputs whose memory layout is
    # not viewable (e.g. transposed or channels_last tensors).
    pred_flat = pred.reshape(N, -1)
    gt_flat = gt.reshape(N, -1)

    intersection = (pred_flat * gt_flat).sum(1)
    # Sum of both masks (|pred| + |gt|), not a set union.
    unionset = pred_flat.sum(1) + gt_flat.sum(1)
    loss = (2 * intersection + smooth) / (unionset + smooth)

    return loss.sum() / N

def diceCoeffv2(pred, gt, eps=1e-5, activation='sigmoid'):
    r"""Compute the Dice coefficient in TP/FP/FN form, averaged over the batch.

    computational formula:

        dice = (2 * tp + eps) / (2 * tp + fp + fn + eps)

    with tp = sum(pred * gt), fp = sum(pred) - tp, fn = sum(gt) - tp per
    sample. Since 2*tp + fp + fn == sum(pred) + sum(gt), this equals the
    (2 * intersection + smooth) / (unionset + smooth) form.

    Args:
        pred: predictions; dim 0 is the batch dimension N.
        gt: ground-truth mask with the same number of elements per sample as pred.
        eps: smoothing constant added to numerator and denominator; avoids
            division by zero and makes two all-zero masks score 1.
        activation: None or "none" (use pred as-is), "sigmoid", or "softmax2d".

    Returns:
        0-dim tensor: the per-sample scores summed and divided by N (batch mean).

    Raises:
        NotImplementedError: if ``activation`` is any other value.
    """
    if activation is None or activation == "none":
        activation_fn = lambda x: x
    elif activation == "sigmoid":
        activation_fn = nn.Sigmoid()
    elif activation == "softmax2d":
        activation_fn = nn.Softmax2d()
    else:
        # Only reached for unsupported activation names.
        raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

    pred = activation_fn(pred)

    N = gt.size(0)
    # reshape instead of view: view() raises on inputs whose memory layout is
    # not viewable (e.g. transposed or channels_last tensors).
    pred_flat = pred.reshape(N, -1)
    gt_flat = gt.reshape(N, -1)

    tp = torch.sum(gt_flat * pred_flat, dim=1)
    fp = torch.sum(pred_flat, dim=1) - tp
    fn = torch.sum(gt_flat, dim=1) - tp
    loss = (2 * tp + eps) / (2 * tp + fp + fn + eps)

    return loss.sum() / N

class DiceLoss(nn.Module):
    """Dice loss module: returns ``1 - diceCoeffv2(y_pr, y_gt)``.

    The stored activation name is passed straight through to ``diceCoeffv2``.
    """

    # Name exposed on instances; presumably read by an external training or
    # logging loop -- NOTE(review): confirm the consumer.
    __name__ = 'dice_loss'

    def __init__(self, activation='sigmoid'):
        super().__init__()
        # One of None / "none" / "sigmoid" / "softmax2d" (validated by diceCoeffv2).
        self.activation = activation

    def forward(self, y_pr, y_gt):
        """Return ``1 - dice`` for predictions ``y_pr`` against targets ``y_gt``."""
        dice_score = diceCoeffv2(y_pr, y_gt, activation=self.activation)
        return 1 - dice_score

总结:上面是这几天对dice以及dice loss的一些思考和实现,如有问题和错误,还望指出。

