半监督3D医学图像分割(三):URPC
Efficient Semi-supervised Gross Target Volume of Nasopharyngeal Carcinoma Segmentation via Uncertainty Rectified Pyramid Consistency
深度学习归根结底是数据驱动的,模型训练的好坏依赖于数据集。在医学图像分割领域,即使是像nn-UNet那样强大的训练框架,也受限于数据集的大小。相比自然图像,医学图像的标注代价更加昂贵,相反,无标注的图像有很多。半监督学习的目的就是将无标注的数据利用起来,达到比单独用有标注数据集更好的效果。
前两篇博客介绍的方法,都是student-teacher双路模型,URPC是单路模型。TS模型是在学生网络和教师网络之间做一致性损失,URPC则是多级特征内部做一致性损失,相比TS架构的计算量和显存大大降低。
网络结构
Overview of the proposed Uncertainty Rectified Pyramid Consistency framework
- 以U-Net为基础(Backbone),E是编码器(Encoder),D是解码器(Decoder)
- p0,p1,p2,ps是解码器不同层的预测结果,通过上采样统一尺寸
- 绿色和红色箭头分别代表分割损失(有标注)和一致性损失(无标注)
论文在U-Net的解码器增加了一个金字塔预测结构,叫做PPNet。PPNet输出多尺度的预测结果,和标签计算分割损失,与Rectified后的特征做一致性损失。同时引入了深监督,一致性损失,不确定性抑制的概念。
[p0′,p1′,p2′,ps′]=[f(x∣D0),f(x∣D1),f(x∣D2),f(x∣D3)][p_0',p_1',p_2',p_s']=[f(x|D_0),f(x|D_1),f(x|D_2),f(x|D_3)] [p0′,p1′,p2′,ps′]=[f(x∣D0),f(x∣D1),f(x∣D2),f(x∣D3)]
- p’是解码器不同层的输出,不同特征的分辨率和通道数是不一样的
[p0,p1,p2,ps]=[g(p0′),g(p1′),g(p2′),g(ps′)][p_0,p_1,p_2,p_s]=[g(p_0'),g(p_1'),g(p_2'),g(p_s')] [p0,p1,p2,ps]=[g(p0′),g(p1′),g(p2′),g(ps′)]
- g由上采样模块、 1x1x1的卷积层和softmax层组成
- p是概率图(C x H x W x D),此时分辨率和通道数都是一样的
损失函数
1.分割损失
医学图像分割任务中常用的交叉熵和dice损失,s个预测结果分别与标签计算损失,然后取平均
Lsup=1S∑s=0S−1Ldice(ps,yi)+Lce(ps,yi)2L_{sup}=\frac{1}{S}\sum_{s=0}^{S-1}{\frac{L_{dice}(p_s,y_i)+L_{ce}(p_s,y_i)}{2}} Lsup=S1s=0∑S−12Ldice(ps,yi)+Lce(ps,yi)
2.无监督损失
为了有效利用无标注数据,URPC利用多尺度特征计算一致性损失,从而引入了正则化。具体而言,设计了金字塔一致性损失最小化不同尺度预测之间的差异。首先,对这些预测结果求平均:
pc=1S∑s=0S−1psp_c=\frac{1}{S}\sum_{s=0}^{S-1}{p_s} pc=S1s=0∑S−1ps
金字塔一致性损失定义为:
Lpyc=1S∑s=0S−1∣∣ps−pc∣∣2L_{pyc}=\frac{1}{S}\sum_{s=0}^{S-1}||p_s-p_c||_2 Lpyc=S1s=0∑S−1∣∣ps−pc∣∣2
ps是不同尺度的预测结果,pc是均值,ps与pc计算MSE
3.不确定性修正
计算不确定度
不同尺度的特征分辨率不同,如果输入的原图是 H x W x D,上采样四次,则p0~p3的分辨率为 H x W x D,H/2 x W/2 x D/2,H/4 x W/4 x D/4,H/8 x W/8 x D/8。在U-Net网络中,分辨率越低的特征图,通道数越多,语义特征越高级,捕获的低频信息越多。反之,分辨率越高的特征,包含的高频信息越多。由于不同特征的频率信息不同,直接上采样后计算一致性可能存在问题,比如细节信息的丢失。
与上一篇博客提到的UA-MT不同,URPC只需要一次前向传播,能够高效的计算不确定度,用的是KL散度计算预测结果和平均预测结果之间的差异。
Ds≈∑j=0Cpsj⋅logpsjpcjD_s\approx\sum_{j=0}^{C}{p_s^j \cdot log\frac{p_s^j}{p_c^j}} Ds≈j=0∑Cpsj⋅logpcjpsj
- C是分割类别,psj是ps的第j个通道,pcj是pc的第j个通道
- Ds是ps和pc之间的KL散度,用来表示不确定度,形状是 C x H x W x D
- pc可以认为是多尺度预测结果的中心,Ds越大代表离中心越远,不确定性越高
不确定性修正
根据不确定度对上文提到的金字塔一致性损失做了修正,
ωsv=e−Dsv\omega_s^v=e^{-D_s^v} ωsv=e−Dsv
- psv和pcv分别表示ps和pc在体素v处的概率向量
- 上式第一项是修正后的金字塔一致性损失,第二项有点像正则化,用来降低不确定性
- wsv和Dsv是当前体素的权重和不确定度,根据公式,不确定度越高的区域,分配的权重越低
4.损失函数
URPC在有标注图像上的分割损失和无标注图像上的一致性损失,用一个公式表示
Ltotal=Lsup+λ⋅LunsupL_{total}=L_{sup}+\lambda\cdot L_{unsup} Ltotal=Lsup+λ⋅Lunsup
λ(t)=ωmax⋅e−5(1−ttmax)2\lambda(t) = \omega_{max} \cdot e^{-5(1-\frac{t}{t_{max}})^2} λ(t)=ωmax⋅e−5(1−tmaxt)2
- Lsup和Lunsup对应有监督损失和无监督损失
- λ是无监督损失的权重,在训练过程中逐渐增加,防止网络训练前期被无意义的目标影响
代码解读
网络就是深监督的U-Net,另外添加4个Dropout
层增加了多尺度特征的随机性
class unet_3D_dv_semi(nn.Module):def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):super(unet_3D_dv_semi, self).__init__()self.is_deconv = is_deconvself.in_channels = in_channelsself.is_batchnorm = is_batchnormself.feature_scale = feature_scalefilters = [64, 128, 256, 512, 1024]filters = [int(x / self.feature_scale) for x in filters]# downsamplingself.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm, kernel_size=(3, 3, 3), padding_size=(1, 1, 1))self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm, kernel_size=(3, 3, 3), padding_size=(1, 1, 1))self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))self.conv3 = UnetConv3(filters[1], filters[2], self.is_batchnorm, kernel_size=(3, 3, 3), padding_size=(1, 1, 1))self.maxpool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))self.conv4 = UnetConv3(filters[2], filters[3], self.is_batchnorm, kernel_size=(3, 3, 3), padding_size=(1, 1, 1))self.maxpool4 = nn.MaxPool3d(kernel_size=(2, 2, 2))self.center = UnetConv3(filters[3], filters[4], self.is_batchnorm, kernel_size=(3, 3, 3), padding_size=(1, 1, 1))# upsamplingself.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)# deep supervisionself.dsv4 = UnetDsv3(in_size=filters[3], out_size=n_classes, scale_factor=8)self.dsv3 = UnetDsv3(in_size=filters[2], out_size=n_classes, scale_factor=4)self.dsv2 = UnetDsv3(in_size=filters[1], out_size=n_classes, scale_factor=2)self.dsv1 = nn.Conv3d(in_channels=filters[0], out_channels=n_classes, kernel_size=1)self.dropout1 = nn.Dropout3d(p=0.5)self.dropout2 = nn.Dropout3d(p=0.3)self.dropout3 = nn.Dropout3d(p=0.2)self.dropout4 = nn.Dropout3d(p=0.1)def forward(self, inputs):conv1 = self.conv1(inputs)maxpool1 = self.maxpool1(conv1)conv2 = self.conv2(maxpool1)maxpool2 = self.maxpool2(conv2)conv3 = self.conv3(maxpool2)maxpool3 = self.maxpool3(conv3)conv4 = self.conv4(maxpool3)maxpool4 = self.maxpool4(conv4)center = self.center(maxpool4)up4 = self.up_concat4(conv4, center)up4 = self.dropout1(up4)up3 = self.up_concat3(conv3, up4)up3 = self.dropout2(up3)up2 = self.up_concat2(conv2, up3)up2 = self.dropout3(up2)up1 = self.up_concat1(conv1, up2)up1 = self.dropout4(up1)# Deep Supervisiondsv4 = self.dsv4(up4)dsv3 = self.dsv3(up3)dsv2 = self.dsv2(up2)dsv1 = self.dsv1(up1)return dsv1, dsv2, dsv3, dsv4
dsv1,dsv2,dsv3,dsv4对应网络图中的p0,p1,p2,p3dsv1, dsv2, dsv3, dsv4对应网络图中的p_0, p_1,p_2,p_3 dsv1,dsv2,dsv3,dsv4对应网络图中的p0,p1,p2,p3
for epoch_num in tqdm(range(max_epoch), ncols=70):for i_batch, sampled_batch in enumerate(trainloader):volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']volume_batch, label_batch = volume_batch.cuda(), label_batch.cuda()unlabeled_volume_batch = volume_batch[labeled_bs:]outputs_aux1, outputs_aux2, outputs_aux3, outputs_aux4, = model(volume_batch)outputs_aux1_soft = torch.softmax(outputs_aux1, dim=1)outputs_aux2_soft = torch.softmax(outputs_aux2, dim=1)outputs_aux3_soft = torch.softmax(outputs_aux3, dim=1)outputs_aux4_soft = torch.softmax(outputs_aux4, dim=1)
- outputs_aux1, outputs_aux2, outputs_aux3, outputs_aux4是多尺度的预测结果
loss_ce_aux1 = ce_loss(outputs_aux1[:args.labeled_bs], label_batch[:args.labeled_bs])loss_ce_aux2 = ce_loss(outputs_aux2[:args.labeled_bs], label_batch[:args.labeled_bs])loss_ce_aux3 = ce_loss(outputs_aux3[:args.labeled_bs], label_batch[:args.labeled_bs])loss_ce_aux4 = ce_loss(outputs_aux4[:args.labeled_bs], label_batch[:args.labeled_bs])loss_dice_aux1 = dice_loss(outputs_aux1_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))loss_dice_aux2 = dice_loss(outputs_aux2_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))loss_dice_aux3 = dice_loss(outputs_aux3_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))loss_dice_aux4 = dice_loss(outputs_aux4_soft[:args.labeled_bs], label_batch[:args.labeled_bs].unsqueeze(1))supervised_loss = (loss_ce_aux1+loss_ce_aux2+loss_ce_aux3+loss_ce_aux4 +loss_dice_aux1+loss_dice_aux2+loss_dice_aux3+loss_dice_aux4)/8
- 对有标注的图像计算分割损失,即常用的交叉熵和dice损失
preds = (outputs_aux1_soft + outputs_aux2_soft+outputs_aux3_soft+outputs_aux4_soft)/4variance_aux1 = torch.sum(kl_distance(torch.log(outputs_aux1_soft[args.labeled_bs:]), preds[args.labeled_bs:]), dim=1, keepdim=True)exp_variance_aux1 = torch.exp(-variance_aux1)variance_aux2 = torch.sum(kl_distance(torch.log(outputs_aux2_soft[args.labeled_bs:]), preds[args.labeled_bs:]), dim=1, keepdim=True)exp_variance_aux2 = torch.exp(-variance_aux2)variance_aux3 = torch.sum(kl_distance(torch.log(outputs_aux3_soft[args.labeled_bs:]), preds[args.labeled_bs:]), dim=1, keepdim=True)exp_variance_aux3 = torch.exp(-variance_aux3)variance_aux4 = torch.sum(kl_distance(torch.log(outputs_aux4_soft[args.labeled_bs:]), preds[args.labeled_bs:]), dim=1, keepdim=True)exp_variance_aux4 = torch.exp(-variance_aux4)consistency_dist_aux1 = (preds[args.labeled_bs:] - outputs_aux1_soft[args.labeled_bs:]) ** 2consistency_loss_aux1 = torch.mean(consistency_dist_aux1 * exp_variance_aux1) / (torch.mean(exp_variance_aux1) + 1e-8) + torch.mean(variance_aux1)consistency_dist_aux2 = (preds[args.labeled_bs:] - outputs_aux2_soft[args.labeled_bs:]) ** 2consistency_loss_aux2 = torch.mean(consistency_dist_aux2 * exp_variance_aux2) / (torch.mean(exp_variance_aux2) + 1e-8) + torch.mean(variance_aux2)consistency_dist_aux3 = (preds[args.labeled_bs:] - outputs_aux3_soft[args.labeled_bs:]) ** 2consistency_loss_aux3 = torch.mean(consistency_dist_aux3 * exp_variance_aux3) / (torch.mean(exp_variance_aux3) + 1e-8) + torch.mean(variance_aux3)consistency_dist_aux4 = (preds[args.labeled_bs:] - outputs_aux4_soft[args.labeled_bs:]) ** 2consistency_loss_aux4 = torch.mean(consistency_dist_aux4 * exp_variance_aux4) / (torch.mean(exp_variance_aux4) + 1e-8) + torch.mean(variance_aux4)consistency_loss = (consistency_loss_aux1 + consistency_loss_aux2 + consistency_loss_aux3 + consistency_loss_aux4) / 4
- preds对应pc
pc=1S∑s=0S−1psp_c=\frac{1}{S}\sum_{s=0}^{S-1}{p_s} pc=S1s=0∑S−1ps
- variance_aux对应Ds
Ds≈∑j=0Cpsj⋅logpsjpcjD_s\approx\sum_{j=0}^{C}{p_s^j \cdot log\frac{p_s^j}{p_c^j}} Ds≈j=0∑Cpsj⋅logpcjpsj
- exp_variance_aux对应wsv
ωsv=e−Dsv\omega_s^v=e^{-D_s^v} ωsv=e−Dsv
- consistency_loss对应无监督损失
consistency_weight = get_current_consistency_weight(iter_num//150)loss = supervised_loss + consistency_weight * consistency_loss
- consistency_weight对应λ,w_max=0.1,随iteration逐渐增加到0.1
λ(t)=ωmax⋅e−5(1−ttmax)2\lambda(t) = \omega_{max} \cdot e^{-5(1-\frac{t}{t_{max}})^2} λ(t)=ωmax⋅e−5(1−tmaxt)2
其余代码细节见LASeg: 2018 Left Atrium Segmentation (MRI)中的train_URPC.py
实验结果
论文实验
原论文是在鼻咽癌核磁数据集上做的实验
NPC数据集消融实验(18例有标签,162例无标签,10%)
- S是多尺度特征的数量,GTVnx和GTVnd是分割的不同区域,DSC是dice系数,ASD是平均表面距离
- S=4时,分割精度最高。S=3时效果也不错,看的出来,继续增加S提升不大,反而会增加网络参数和计算量
- UR(uncertainty rectification)是不确定性修正,UM(uncertainty minimization)是不确定度抑制
与其他方法的对比结果
- 同样是10%的标签率,SL是监督学习
- 所有的半监督网络都以3D U-Net为backbone,URPC的表现是最好的
可视化结果
- 图a中,不确定的区域主要集中在分割目标的边界区域
- 图b中,随着标签率从10%提高到50%,URPC的dice指标一直是比DAN高的
我的实验
我在左心房数据集(LAHeart2018)上做的实验,一共154例数据,123例当做训练集,31例当做测试集。
Loss变化曲线
- 我这里的对比并不严谨,mean teacher网络里面的backbone是V-Net,URPC的backbone是U-Net
- 根据我自己做的实验,URPC效果在左心房核磁数据集上的表现,比MT没有多大提高
分割结果重建图:红色是金标签,蓝色是模型预测结果
- 相比只使用标注数据集的全监督方法,半监督方法不管是在评价指标或者可视化分割结果上,都是有显著提高的。
URPC的优点在于不修改网络结构,把U-Net网络中的多级特征利用起来,相互之间做一致性损失,并引入不确定度修正预测结果。
相比student-teacher网络,训练起来更加简单。
参考资料:
Luo X, Liao W, Chen J, et al. Efficient semi-supervised gross target volume of nasopharyngeal carcinoma segmentation via uncertainty rectified pyramid consistency[C]//International Conference on Medical Image Computing and Computer-Assisted Intervention. Springer, Cham, 2021: 318-329.
HiLab-git/SSL4MIS: Semi Supervised Learning for Medical Image Segmentation, a collection of literature reviews and code implementations.
项目地址:
LASeg: 2018 Left Atrium Segmentation (MRI)
如有问题,欢迎联系 ‘ice_rain123@foxmail.com’
半监督3D医学图像分割(三):URPC相关推荐
- 半监督3D医学图像分割(四):SASSNet
形状感知半监督医学图像分割 Shape-aware Semi-supervised 3D Semantic Segmentation for Medical Images 研究背景 随着人工智能技术在 ...
- (新SOTA)UNETR++:轻量级的、高效、准确的共享权重的3D医学图像分割
(新SOTA)UNETR++:轻量级的.高效.准确的共享权重的3D医学图像分割 0 Abstract 由于Transformer模型的成功,最近的工作研究了它们在3D医学分割任务中的适用性.在Tran ...
- 深度学习实战(十):使用 PyTorch 进行 3D 医学图像分割
深度学习实战(十):使用 PyTorch 进行 3D 医学图像分割 1. 项目简介 2. 3D医学图像分割的需求 3. 医学图像和MRI 4. 三维医学图像表示 5. 3D-Unet模型 5.1损失函 ...
- 《弱监督/半监督的DCNN图像分割》笔记
<Weakly- and Semi-Supervised Learning of a Deep Convolutional Network for Semantic Image Segmenta ...
- 半监督目标检测(三)
目录 ISMT 动机 1. Overview 2. Pseudo Labels Fusion 3. Interactive Self-Training 4. Mean Teacher Unbiased ...
- CVPR 2020 | 利用强化学习进行交互式3D医学图像分割
点击上方"视学算法",选择"星标" 快速获得最新干货 本文转载自:机器之心 如何提高交互式图像分割算法的效率?上海交大和华师大的研究者提出了一种基于多智能体深度 ...
- 【深度学习】V-Net 3D医学图像分割 Dice loss 损失
论文:https://arxiv.org/abs/1606.04797 论文 本文引入Dice coefficient 去处理医学3D图像里面 前景和背景体素数量严重不平衡的情况. 网络用于处理3D图 ...
- 基于弱监督深度学习的医学图像分割方法综述
基于弱监督深度学习的医学图像分割方法综述 摘要:基于深度学习的医学影像分割尽管精度在不断的提升,但是离不开大规模的高质量标注数据的训练,被称为弱监督学习的深度学习的一个分支正在帮助医生通过减少对完整和 ...
- 基于深度学习的自然图像和医学图像分割:网络结构设计
来源:知乎.极市平台.深度学习爱好者作者丨李慕清@知乎 https://zhuanlan.zhihu.com/p/104854615 本文约5100字,建议阅读10分钟 本文首先介绍一些经典的语义分割 ...
最新文章
- POJ - 3694 Network tanjar割边+lca
- linux网络编程——webserver服务器编写
- 用户可以改变计算机功能键吗,电脑键盘快捷键怎么更改
- mysql内连接部门平均值_详解MySql基本查询、连接查询、子查询、正则表达查询_MySQL...
- sql server版本 性能_迁移到高版本 SQL 数据库后,性能变差了
- mysql2个字段还会map_通过注解实现MyBatis将sql查询结果的两个字段分别作为map的key,value...
- CF584D 【Dima and Lisa】题解
- idea的Database导出导入表操作
- 虚拟机实验Windows10备份和还原
- 【高等数学】上册知识点复习
- 链表(c语言),c语言链表(c语言链表详解)
- c# 角度和弧度的转换
- pc登录2个微信客户端
- H3C路由器静态NAT_不同网段的两个路由器如何互通?
- 【python】启动客户端报错:OSError: [WinError 740] 请求的操作需要提升。
- 中介者模式的实际应用
- 数据分析系列:绩效(效率)评价与python实现(层析分析、topsis、DEA)
- Ajax洗洁精的特性,洗涤剂中常用表面活性剂的特点
- matlab中plot函数如何在图像上标记某些点?
- bilibili视频下载神器[无广告]