源码地址:https://github.com/HiLab-git/SSL4MIS

目的

本文主要通过对github上源码的分析,学习半监督语义分割的思想,并通过代码提供的数据对比各个半监督方法的效果。

介绍

在语义分割领域,标注往往是比较困难的。因为掩膜标注要求和目标边缘紧密贴合,否则会带来边界上的额外损失。如下图:

相反的,未标注的数据量一般要远远多于标注的数据量。基于此,半监督方法的研究就至关重要了。
机器学习按数据标注情况可分为三种:监督学习,无监督学习和半监督学习。

  • 监督学习:在有标记的情况下,对数据进行分类或回归。比如随机森林、SVM和目前流行的全卷积网络、循环神经网络都归为这一类;
  • 无监督学习:没有给定事先标记过的范例,自动对输入的资料进行分类或分群。此类有很多机器学习算法,比如k-means、meanshift和PCA,一般通过核函数划分超空间;
  • 半监督学习:在有部分标记的情况下,使用所有提供的数据,对输入进行分类和回归的方法。

实验

数据来源

本文实验采用19%标注数据,81%未标注数据进行训练与测试。数据来源为ACDC-Segmentation,该数据集为第戎大学采集的心脏核磁共振影像,标注类型为:背景区域,右心室腔,心肌层和左心室腔。我们使用后三类作为分割结果,使用dice和hd95作为评价指标进行实验。两个指标中,Dice对mask的内部填充比较敏感,而hausdorff distance 对分割出的边界比较敏感。

测量指标

Dice

dice是评价两个目标相关性的指标,又叫F1-score。平衡了召回率和精度的影响,是一个综合性指标。
Dice=1recall−1+precision−1=2TP2TP+FP+FNDice=\frac {1} {recall^{-1}+precision^{-1}}=\frac {2TP} {2TP+FP+FN} Dice=recall−1+precision−11​=2TP+FP+FN2TP​

hausdorff distance

hausdorff distance是测量点集X的到另外一个集和Y最近点的最大距离。

结合下图直观的说,就是比较两点的距离,取更大值。

hd95(95% hausdorff distance)类似HD,但只取距离排序后的中间的95%距离,其目的是减轻特殊野点的影响。


监督学习(Baseline)

监督学习实验中使用Unet作为分割网络。下文对比方法中,除非指明,否则默认也使用Unet网络进行对照。监督学习流程如下:

监督损失包括Dice loss 和 Cross Entropy loss。
losssupervised=lossCE(X,Y)+lossDice(X,Y),lossCE(X,Y)=∑(−Ylog⁡(X)+(1−Y)log⁡(1−X)),lossDice(X,Y)=1−2∣X⋂Y∣X+Yloss_{supervised}=loss_{CE}(X,Y)+loss_{Dice}(X,Y), \\ loss_{CE}(X, Y) = \sum (-Y \log (X)+(1-Y) \log (1-X)), \\ loss_{Dice}(X, Y)=1-2 \frac {\left | X \bigcap Y\right | } {X+Y} losssupervised​=lossCE​(X,Y)+lossDice​(X,Y),lossCE​(X,Y)=∑(−Ylog(X)+(1−Y)log(1−X)),lossDice​(X,Y)=1−2X+Y∣X⋂Y∣​
其中,X表示网络预测输出,Y表示标注。
监督学习只使用19%的标注数据作为输入源,训练10000次,得到的结果如下:

其中,编号1、2、3分别代表三个分类。监督学习训练过程中Dice最好为81%,这个结果将作为基准。

下面介绍半监督方法。


半监督学习

相对于监督学习,半监督学习增加了一致性损失,用于测量未标注数据的分割结果并使其靠近某一种约束。

下面是具体的方案,除非特殊指明,否则参数条件和监督学习一致。

mean teacher (论文链接)

mean teacher 的一致性损失为:
lossConsistency=1n∑(f(X)−fema(Xema))2loss_{Consistency}= \frac {1} {n} \sum (f(X)-f_{ema}(X_{ema}))^{2} lossConsistency​=n1​∑(f(X)−fema​(Xema​))2
其中femaf_{ema}fema​是无梯度分割网络,XemaX_{ema}Xema​是加噪原始数据。该约束可描述为:使用源图像得到的分割结果,和加噪图像得到的分割结果应该是一致的。

流程如下:

下面是代码分析:

volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
...
#加噪
noise = torch.clamp(torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
ema_inputs = unlabeled_volume_batch + noise
...
#正常结果
outputs = model(volume_batch)
outputs_soft = torch.softmax(outputs, dim=1)
#加噪结果
with torch.no_grad():ema_output = ema_model(ema_inputs)ema_output_soft = torch.softmax(ema_output, dim=1)
...
#一致性损失
consistency_loss = torch.mean((outputs_soft[args.labeled_bs:]-ema_output_soft)**2)

该方法训练效果如下:

mean teacher 得到的最优Dice为82.7%。

uncertainty aware mean teacher(论文链接)

该方法相对于mean teacher,增加了一个不确定性掩膜mask。公式如下:
lossConsistency=lossmeanteacher∗maskuncertainty=−∑(Ylog⁡Y)mask={1uncertainty<th0elseloss_{Consistency}=loss_{mean teacher}*mask \\ uncertainty=- \sum {(Y\log{Y})} \\ mask=\left\{\begin{matrix} 1 \qquad uncertainty<th \\ 0 \qquad else \end{matrix}\right. lossConsistency​=lossmeanteacher​∗maskuncertainty=−∑(YlogY)mask={1uncertainty<th0else​

其中,uncertaintyuncertaintyuncertainty由Xlog(X)Xlog(X)Xlog(X)函数构成,在X=0.36附近最大,两端最小。也就是说,预测图越接近0.36,不确定性越大,不确定性超过一定阈值就置零。

preds = torch.zeros([stride * T, num_classes, w, h]).cuda()
for i in range(T//2):#带噪声输入ema_inputs = volume_batch_r + \torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)with torch.no_grad():preds[2 * stride * i:2 * stride *(i + 1)] = ema_model(ema_inputs)
preds = F.softmax(preds, dim=1)
preds = preds.reshape(T, stride, num_classes, w, h)
preds = torch.mean(preds, dim=0)
#不确定性掩膜计算
uncertainty = -1.0 * torch.sum(preds*torch.log(preds + 1e-6),\dim=1, keepdim=True)

结果如下:

最优Dice=84%。

interpolation consistency(论文链接)

该方法使用两张图像内插作为输入,如下:
lossunsupervised=Mean(outputmixed,f(inputmixed))loss_{unsupervised}=Mean(output_{mixed},f(input_{mixed})) lossunsupervised​=Mean(outputmixed​,f(inputmixed​))
其中,f是不带梯度的分割网络。

#混合输入 input_{mixed}
batch_ux_mixed = unlabeled_volume_batch_0 * \(1.0 - ict_mix_factors) + \unlabeled_volume_batch_1 * ict_mix_factors
#混合输入2
input_volume_batch = torch.cat([labeled_volume_batch, batch_ux_mixed], dim=0)
outputs = model(input_volume_batch)
outputs_soft = torch.softmax(outputs, dim=1)
with torch.no_grad():ema_output_ux0 = torch.softmax(ema_model(unlabeled_volume_batch_0), dim=1)ema_output_ux1 = torch.softmax(ema_model(unlabeled_volume_batch_1), dim=1)#混合输出batch_pred_mixed = ema_output_ux0 * \(1.0 - ict_mix_factors) + ema_output_ux1 * ict_mix_factors
#混合输入和输出计算一致性损失
consistency_weight = get_current_consistency_weight(iter_num//150)consistency_loss = torch.mean((outputs_soft[args.labeled_bs:] - batch_pred_mixed) ** 2)

效果如下:

最优Dice=82%。

最小熵约束(论文链接)

entropy minimization,损失如下:
y=−∑(plog⁡p)log⁡(C)y=- \frac { \sum (p \log p)}{\log (C)} y=−log(C)∑(plogp)​
该公式对输出概率进行最小熵约束,使得p接近0或1时损失较小。其中,C为常数。公式和上一个方法里mask的公式差不多,效果应该也比不上。

outputs = model(volume_batch)
outputs_soft = torch.softmax(outputs, dim=1)
...
# 对所有数据进行熵计算
consistency_loss = losses.entropy_loss(outputs_soft, C=4)

结果如下:

最优Dice=80%。

dv (论文没找到)

这个方法还不能用unet跑,作者说正在完善代码。所以本实验使用的是unet_dv网络。直接看结果:

最优Dice=81%。

对抗网络(论文链接)

该方法应用了对抗网络的思想,设计了一个鉴别器网络DAN。它的核心思想是对鉴别器训练。损失部分由一致性损失LcL_{c}Lc​和鉴别器损失LdL_{d}Ld​组成。假设现有数据集X=(X1,X2)X=(X_{1},X_{2})X=(X1​,X2​),其中已标注子集为X1X_{1}X1​,未标注子集为X2X_{2}X2​,它们由分割模型预测的结果分别为X1X_{1}X1​,Y2Y_{2}Y2​,那么一致性损失和鉴别器损失分别为:
Ld=f(D(X1,Y1),1)+f(D(X2,Y2),0)Lc=f(D(X2,Y2),1)L_{d}=f(D(X_{1},Y_{1}),1)+f(D(X_{2},Y_{2}),0) \\ L_{c}=f(D(X_{2},Y_{2}),1) Ld​=f(D(X1​,Y1​),1)+f(D(X2​,Y2​),0)Lc​=f(D(X2​,Y2​),1)

#假设标注的数据都为真,未标注数据为假
DAN_target = torch.tensor([0] * args.batch_size).cuda()
DAN_target[:args.labeled_bs] = 1
...
# 未标注数据为真,形成对抗,注意这里使用的是DAN_target[:args.labeled_bs]
# 而不是DAN_target[args.labeled_bs:]
DAN_outputs = DAN(outputs_soft[args.labeled_bs:], volume_batch[args.labeled_bs:])
consistency_loss = F.cross_entropy(DAN_outputs, (DAN_target[:args.labeled_bs]).long())
...
#鉴别器损失计算,假设标注的数据都为真,未标注数据为假
DAN_outputs = DAN(outputs_soft, volume_batch)
DAN_loss = F.cross_entropy(DAN_outputs, DAN_target.long())

效果如下:

最优Dice=83%。

总结

总的来说,半监督学习方法最终结果都要比基准好。效果最好的是uncertainty aware mean teacher,Dice=84%,原因可能是它用了比较多的约束条件;最少迭代的是对抗网络,它在大概4000次迭代的时候就已经达到最好效果,但具体是不是训练最快的,还要等进一步验证。

半监督语义分割方法汇总(附代码分析)相关推荐

  1. CVPR 2022 | 商汤/上交/港中文提出U2PL:使用不可靠伪标签的半监督语义分割

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 作者:Pascal  |  已授权转载(源:知乎)编辑:CVer https://zhuanlan.zhih ...

  2. Semi-supervised Semantic Segmentation with Error Localization Network(基于误差定位网络的半监督语义分割 )

    Semi-supervised Semantic Segmentation with Error Localization Network(基于误差定位网络的半监督语义分割 ) Abstract 本文 ...

  3. CVPR 2021 | 北大MSRA提出CPS:基于交叉伪监督的半监督语义分割

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 作者:Charles  |  源:知乎 https://zhuanlan.zhihu.com/p/37812 ...

  4. 用于半监督语义分割的基于掩码的数据增强

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 小白导读 论文是学术研究的精华和未来发展的明灯.小白决心每天为大家 ...

  5. CVPR 2022|U2PL:使用不可靠伪标签的半监督语义分割

    本文转自商汤学术 导读 半监督任务的关键在于充分利用无标签数据,商汤科技联合上海交通大学.香港中文大学,基于「 Every Pixel Matters」的理念,有效利用了包括不可靠样本在内的全部无标签 ...

  6. 对抗学习的半监督语义分割

    <font color="red">GAN生成对抗网络:</font>由两个子网络组成,generator和discriminator,在训练过程中,这两个 ...

  7. 半监督语义分割_paper reading part2

    03 SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers Time:2021.05 我 ...

  8. YOLOV5 的小目标检测网络结构优化方法汇总(附代码)

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨南山 来源丨 AI约读社 YOLOv5是一种非常受欢迎的单阶段目标检测,以其性能和速度著称,其结 ...

  9. 三种基于自监督深度估计的语义分割方法(arXiv 2021)

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨泡泡机器人 来源丨泡泡机器人SLAM 标题: Three Ways to Improve Sem ...

  10. 【论文汇总】 ECCV 2020 语义分割paper汇总

    语义分割 segmentation paper@ECCV 2020 ECCV 2020语义分割文章总结,文章下载链接. 文章目录 语义分割 segmentation paper@ECCV 2020 前 ...

最新文章

  1. mysql localhost无法登陆_MySQL 'root'@'localhost'无法登录
  2. 删了手机里的一个html文件,手机太卡,哪些内容可以毫不犹豫的删除?
  3. [转]LoadRunner 各个指标分析
  4. 【Android工具】免费二次元追番神器,各种字幕组新番旧番良心资源,重要的事说三遍:没有广告!没有广告!没有广告...
  5. Eclipse——导出可执行jar包
  6. 使用ANTLR在5分钟内用Java解析任何语言:例如Python
  7. 干货来袭!java核心技术卷一pdf
  8. script-百度换肤效果
  9. 第一期:浙大版《JAVA语言程序设计教程》(第二版)翁凯等 主编 ——小白的入门之路(上)(一)
  10. 怎么把mp3格式的音频文件转为文字?
  11. 图解侧方停车技巧2015高清版
  12. ContextCapture系列教程(三):大疆精灵4RTK版无人机POS数据提取、处理(处理后勉强达到免相控要求)
  13. java和python哪个运行速度快_python和java学哪个比较简单点
  14. 咸鱼软件应用—Cura3D切片
  15. 让你的微信小程序对用户更加友好:上拉加载和下拉刷新就是关键
  16. (转)Q格式的转换问题与移位
  17. 丹尼斯·里奇-c语言之父,Unix之父
  18. python 自动运维架构师_运维架构师-Python 自动化运维开发-031
  19. Android实训-家庭财务管理系统
  20. LeaRun.Framework━ .NET快速开发框架 ━ 工作流程组件介绍

热门文章

  1. 三对角、五对角追赶法求解线性方程组
  2. 灵格斯怎么屏幕取词_灵格斯词霸怎么用?灵格斯词霸使用手册
  3. [裴礼文数学分析中的典型问题与方法习题参考解答]4.4.9
  4. 转速开环恒压频比异步电动机调速系统仿真
  5. java中UUID类生成32位随机数(附加 6 位随机数)
  6. 主流数据库对比,主流数据库性能、选型对比
  7. 大华linux密码,大华wifi摄像头的初始化和读取视频流
  8. 数字电路基础知识——锁存器与触发器的建立时间和保存时间(一)
  9. 使用QEMU搭建ARM64实验环境
  10. java开发文档怎么写?教你写java技术文档