由于 Dice系数是图像分割中常用的指标,而在Pytoch中没有官方的实现,下面结合网上的教程进行详细实现。

先来看一个我在网上经常看到的一个版本。

def diceCoeff(pred, gt, smooth=1, activation='sigmoid'):

r""" computational formula:

dice = (2 * (pred ∩ gt)) / (pred ∪ gt)

"""

if activation is None or activation == "none":

activation_fn = lambda x: x

elif activation == "sigmoid":

activation_fn = nn.Sigmoid()

elif activation == "softmax2d":

activation_fn = nn.Softmax2d()

else:

raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

pred = activation_fn(pred)

N = gt.size(0)

pred_flat = pred.view(N, -1)

gt_flat = gt.view(N, -1)

intersection = (pred_flat * gt_flat).sum(1)

unionset = pred_flat.sum(1) + gt_flat.sum(1)

loss = 2 * (intersection + smooth) / (unionset + smooth)

return loss.sum() / N

整体思路就是运用dice的计算公式  2(A∩B / A∪B)。下面来分析一下可能存在的问题:

1)smooth参数是用来防止分母除0的,但是如果smooth=1的话,会使得dice的计算结果略微偏高,看下面的测试代码。

# shape = torch.Size([1, 2, 4, 4])

'''

1 0 0= bladder

0 1 0 = tumor

0 0 0= background

'''

pred = torch.Tensor([[

[[0, 1, 1, 0],

[1, 0, 0, 1],

[1, 0, 0, 1],

[0, 1, 1, 0]],

[[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0]],

[[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0]]]])

gt = torch.Tensor([[

[[0, 1, 1, 0],

[1, 0, 0, 1],

[1, 0, 0, 1],

[0, 1, 1, 0]],

[[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0]],

[[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0]]]])

dice_baldder1 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1, activation=None)

dice_baldder2 = diceCoeff(pred[:, 0:1, :], gt[:, 0:1, :], smooth=1e-5, activation=None)

print('smooth=1 : dice={:.4}'.format(dice_baldder1.item()))

print('smooth=1e-5 : dice={:.4}'.format(dice_baldder2.item()))

# 输出结果

# smooth=1 : dice=1.059

# smooth=1e-5 : dice=1.0

我们最后预测的是一个3分类的分割图,第一类是baldder, 第二类是tumor, 第三类是背景。我们先假设bladder的预测pred和gt一样,计算bladder的dice值,发现当smooth=1的时候,dice偏高, 而smooth=1e-5时dice比较合理。

2) 当预测和gt都是背景时,即图中没有要分割的部分(没有tumor或者bladder)时,会导致dice=2。如下测试

# shape = torch.Size([1, 2, 4, 4])

'''

1 0 0= bladder

0 1 0 = tumor

0 0 0= background

'''

pred = torch.Tensor([[

[[0, 1, 1, 0],

[1, 0, 0, 1],

[1, 0, 0, 1],

[0, 1, 1, 0]],

[[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0]],

[[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0]]]])

gt = torch.Tensor([[

[[0, 1, 1, 0],

[1, 0, 0, 1],

[1, 0, 0, 1],

[0, 1, 1, 0]],

[[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0]],

[[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0],

[0, 0, 0, 0]]]])

dice_tumor1 = diceCoeff(pred[:, 1:2, :], gt[:, 1:2, :], smooth=1, activation=None)

dice_tumor2 = diceCoeff(pred[:, 1:2, :], gt[:, 1:2, :], smooth=1e-5, activation=None)

print('smooth=1 : dice={:.4}'.format(dice_tumor1.item()))

print('smooth=1e-5 : dice={:.4}'.format(dice_tumor2.item()))

# 输出结果

# smooth=1 : dice=2.0

# smooth=1e-5 : dice=2.0

这里我们还是用1)中的数据,只是这里进行tumor的dice计算,可以看到tumor的预测图和gt中的值全是0,即没有tumor,都是背景,但dice=2,这是因为根据公式此时A∩B和A∪B都为0,结果就是 2*smooth / smooth = 2。

解决办法:我想这里应该更改代码的实现方式,用下面的计算公式替换之前的,因为之前加smooth的位置有问题。

# loss = 2 * (intersection + smooth) / (unionset + smooth) # 之前的

loss = (2 * intersection + smooth) / (unionset + smooth)

替换后的dice如下:

def diceCoeff(pred, gt, smooth=1e-5, activation='sigmoid'):

r""" computational formula:

dice = (2 * (pred ∩ gt)) / (pred ∪ gt)

"""

if activation is None or activation == "none":

activation_fn = lambda x: x

elif activation == "sigmoid":

activation_fn = nn.Sigmoid()

elif activation == "softmax2d":

activation_fn = nn.Softmax2d()

else:

raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

pred = activation_fn(pred)

N = gt.size(0)

pred_flat = pred.view(N, -1)

gt_flat = gt.view(N, -1)

intersection = (pred_flat * gt_flat).sum(1)

unionset = pred_flat.sum(1) + gt_flat.sum(1)

loss = (2 * intersection + smooth) / (unionset + smooth)

return loss.sum() / N

用2)中的数据做测试结果如下:dice计算正确

# smooth=1 : dice=1.0

# smooth=1e-5 : dice=1.0

dice的另一种计算方式:这里参考肾脏肿瘤挑战赛提供的dice计算方法。

def diceCoeffv2(pred, gt, eps=1e-5, activation='sigmoid'):

r""" computational formula:

dice = (2 * tp) / (2 * tp + fp + fn)

"""

if activation is None or activation == "none":

activation_fn = lambda x: x

elif activation == "sigmoid":

activation_fn = nn.Sigmoid()

elif activation == "softmax2d":

activation_fn = nn.Softmax2d()

else:

raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

pred = activation_fn(pred)

N = gt.size(0)

pred_flat = pred.view(N, -1)

gt_flat = gt.view(N, -1)

tp = torch.sum(gt_flat * pred_flat, dim=1)

fp = torch.sum(pred_flat, dim=1) - tp

fn = torch.sum(gt_flat, dim=1) - tp

loss = (2 * tp + eps) / (2 * tp + fp + fn + eps)

return loss.sum() / N

经过测试,也是没问题的。

整理代码

def diceCoeff(pred, gt, smooth=1e-5, activation='sigmoid'):

r""" computational formula:

dice = (2 * (pred ∩ gt)) / (pred ∪ gt)

"""

if activation is None or activation == "none":

activation_fn = lambda x: x

elif activation == "sigmoid":

activation_fn = nn.Sigmoid()

elif activation == "softmax2d":

activation_fn = nn.Softmax2d()

else:

raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

pred = activation_fn(pred)

N = gt.size(0)

pred_flat = pred.view(N, -1)

gt_flat = gt.view(N, -1)

intersection = (pred_flat * gt_flat).sum(1)

unionset = pred_flat.sum(1) + gt_flat.sum(1)

loss = (2 * intersection + smooth) / (unionset + smooth)

return loss.sum() / N

def diceCoeffv2(pred, gt, eps=1e-5, activation='sigmoid'):

r""" computational formula:

dice = (2 * tp) / (2 * tp + fp + fn)

"""

if activation is None or activation == "none":

activation_fn = lambda x: x

elif activation == "sigmoid":

activation_fn = nn.Sigmoid()

elif activation == "softmax2d":

activation_fn = nn.Softmax2d()

else:

raise NotImplementedError("Activation implemented for sigmoid and softmax2d 激活函数的操作")

pred = activation_fn(pred)

N = gt.size(0)

pred_flat = pred.view(N, -1)

gt_flat = gt.view(N, -1)

tp = torch.sum(gt_flat * pred_flat, dim=1)

fp = torch.sum(pred_flat, dim=1) - tp

fn = torch.sum(gt_flat, dim=1) - tp

loss = (2 * tp + eps) / (2 * tp + fp + fn + eps)

return loss.sum() / N

class DiceLoss(nn.Module):

__name__ = 'dice_loss'

def __init__(self, activation='sigmoid'):

super(DiceLoss, self).__init__()

self.activation = activation

def forward(self, y_pr, y_gt):

return 1 - diceCoeffv2(y_pr, y_gt, activation=self.activation)

总结:上面是这几天对dice以及dice loss的一些实思考和实现,如有问题和错误,还望指出。

python中dice常见问题_【Pytorch】 Dice系数与Dice Loss损失函数实现相关推荐

  1. python随机抽签列表中的同学值日_神奇的大抽签--Python中的列表_章节测验,期末考试,慕课答案查询公众号...

    神奇的大抽签--Python中的列表_章节测验,期末考试,慕课答案查询公众号 更多相关问题 下图表示几个植物类群的进化关系.下列叙述不正确的是[ ]A.最先出现的植物类群是甲B.乙和丙都是由甲进化来的 ...

  2. 在Python中使用LSTM和PyTorch进行时间序列预测

    全文链接:http://tecdat.cn/?p=8145 顾名思义,时间序列数据是一种随时间变化的数据类型.例如,24小时内的温度,一个月内各种产品的价格,一年中特定公司的股票价格(点击文末&quo ...

  3. 正则表达式在python中的应用_学习正则表达式在python中的应用

    目的:对文本的处理,正则表达式的功能很强大,可以很巧妙的过滤.匹配.获取想要的字符串,是必须学习的技能,这里只记录常用的写法,详细文档可以参看官方帮助文档. 环境:ubuntu 16.04 pytho ...

  4. 在python中设置密码登录_在python中生成密码

    在python中生成密码 我想在python中生成一些字母数字密码. 一些可能的方法是: import string from random import sample, choice chars = ...

  5. python中 什么意思_请问python中%代表什么意思?

    婷婷同学_ 1.格式符例如:a = 'test'print 'it is a %s' %(a)打印的结果就是 it is a test2.单独看%,是一个运算符号,求余数.例如:求模运算,相当于mod ...

  6. python中计算均方误差_在python中查找线性回归的均方误差(使用scikit-learn)

    我试图在python中做一个简单的线性回归,其中x变量是单词 项目描述的计数,Y值是以天为单位的融资速度. 我有点困惑,因为测试的均方根误差(rmse)是13.77. 培训数据为13.88.首先,RM ...

  7. values在python中的意思_相当于Python的values()字典方法的Javascript

    相当于Python的values()字典方法的Javascript 这个问题已经在这里有了答案: 如何获取Javascript对象的所有属性值(不知道键)?                       ...

  8. lambda在python中的用法_在python中对lambda使用.assign()方法

    我在Python中运行以下代码:#Declaring these now for later use in the plots TOP_CAP_TITLE = 'Top 10 market capit ...

  9. python中add函数_如何使用python中的add函数?

    之前向大家介绍过python中的求和函数sum函数,numpy中的sum函数,对于数组可以指定维度进行相加.numpy中还有另一种求和运算方法,即add函数.add函数不仅作用于numpy中加法运算, ...

最新文章

  1. [BZOJ1572][Usaco2009 Open]工作安排Job
  2. 人人都能成为安全防范的高手 ——《黑客新型攻击防范:深入剖析犯罪软件》
  3. 易观与用友推出云融合产品“智能用户运营”,掀开数字营销技术新篇章
  4. 模型可解释性-贝叶斯方法
  5. Python(Windows)下安装各种库的多种方法总结--灵活使用pip
  6. Spring——AOP
  7. ie7ajax 跨域 no transport 解决办法
  8. (六)jQuery选择器
  9. 超微服务器电源短接启动图解_教你一招,让你的电脑启动速度秒杀别人
  10. 计算机简单故障时的排除方法,电脑简单故障排除解决办法大全
  11. statement的增删改查和动态的增删改查
  12. Window Operations(窗口函数的使用)
  13. 虚拟环境--virtualenv
  14. winform Combobox出现System.Data.DataRowView的解决的方法
  15. 彻底明白Java的IO系统
  16. 静态路由配置实例学习记录
  17. hashcat软件的简单实用
  18. 获取钉钉考勤机的打卡记录并且解析
  19. 华为对刷量、刷评论的惩罚是什么?有什么解决办法吗?
  20. 在线视频移动化迁徙加速,UGC待开发

热门文章

  1. Electron桌面应用
  2. FS4412开发板简介
  3. Python: SQLAlchemy 处理 PostgreSQL on conflict
  4. 我与小娜(13):LIGO是什么组织?
  5. NKOI 2495 火车运输
  6. 【chrome插件】公众号后台,固定侧边栏,自动定位菜单位置。
  7. ORACLE之ora-01722和ORA-01403的错误测试
  8. 2019ICPC上海网络赛A 边分治+线段树
  9. 学习 Java ,是看书学习快,还是看视频学习快呢 ?
  10. JavaScript获取元素