半监督语义分割方法汇总(附代码分析)
源码地址:https://github.com/HiLab-git/SSL4MIS
目的
本文主要通过对github上源码的分析,学习半监督语义分割的思想,并通过代码提供的数据对比各个半监督方法的效果。
介绍
在语义分割领域,标注往往是比较困难的。因为掩膜标注要求和目标边缘紧密贴合,否则会带来边界上的额外损失。如下图:
相反的,未标注的数据量一般要远远多于标注的数据量。基于此,半监督方法的研究就至关重要了。
机器学习按数据标注情况可分为三种:监督学习,无监督学习和半监督学习。
- 监督学习:在有标记的情况下,对数据进行分类或回归。比如随机森林、SVM和目前流行的全卷积网络、循环神经网络都归为这一类;
- 无监督学习:没有给定事先标记过的范例,自动对输入的资料进行分类或分群。此类有很多机器学习算法,比如k-means、meanshift和PCA,一般通过核函数划分超空间;
- 半监督学习:在有部分标记的情况下,使用所有提供的数据,对输入进行分类和回归的方法。
实验
数据来源
本文实验采用19%标注数据,81%未标注数据进行训练与测试。数据来源为ACDC-Segmentation,该数据集为第戎大学采集的心脏核磁共振影像,标注类型为:背景区域,右心室腔,心肌层和左心室腔。我们使用后三类作为分割结果,使用dice和hd95作为评价指标进行实验。两个指标中,Dice对mask的内部填充比较敏感,而hausdorff distance 对分割出的边界比较敏感。
测量指标
Dice
dice是评价两个目标相关性的指标,又叫F1-score。平衡了召回率和精度的影响,是一个综合性指标。
Dice=1recall−1+precision−1=2TP2TP+FP+FNDice=\frac {1} {recall^{-1}+precision^{-1}}=\frac {2TP} {2TP+FP+FN} Dice=recall−1+precision−11=2TP+FP+FN2TP
hausdorff distance
hausdorff distance是测量点集X的到另外一个集和Y最近点的最大距离。
结合下图直观的说,就是比较两点的距离,取更大值。
hd95(95% hausdorff distance)类似HD,但只取距离排序后的中间的95%距离,其目的是减轻特殊野点的影响。
监督学习(Baseline)
监督学习实验中使用Unet作为分割网络。下文对比方法中,除非指明,否则默认也使用Unet网络进行对照。监督学习流程如下:
监督损失包括Dice loss 和 Cross Entropy loss。
losssupervised=lossCE(X,Y)+lossDice(X,Y),lossCE(X,Y)=∑(−Ylog(X)+(1−Y)log(1−X)),lossDice(X,Y)=1−2∣X⋂Y∣X+Yloss_{supervised}=loss_{CE}(X,Y)+loss_{Dice}(X,Y), \\ loss_{CE}(X, Y) = \sum (-Y \log (X)+(1-Y) \log (1-X)), \\ loss_{Dice}(X, Y)=1-2 \frac {\left | X \bigcap Y\right | } {X+Y} losssupervised=lossCE(X,Y)+lossDice(X,Y),lossCE(X,Y)=∑(−Ylog(X)+(1−Y)log(1−X)),lossDice(X,Y)=1−2X+Y∣X⋂Y∣
其中,X表示网络预测输出,Y表示标注。
监督学习只使用19%的标注数据作为输入源,训练10000次,得到的结果如下:
其中,编号1、2、3分别代表三个分类。监督学习训练过程中Dice最好为81%,这个结果将作为基准。
下面介绍半监督方法。
半监督学习
相对于监督学习,半监督学习增加了一致性损失,用于测量未标注数据的分割结果并使其靠近某一种约束。
下面是具体的方案,除非特殊指明,否则参数条件和监督学习一致。
mean teacher (论文链接)
volume_batch, label_batch = sampled_batch['image'], sampled_batch['label']
...
#加噪
noise = torch.clamp(torch.randn_like(unlabeled_volume_batch) * 0.1, -0.2, 0.2)
ema_inputs = unlabeled_volume_batch + noise
...
#正常结果
outputs = model(volume_batch)
outputs_soft = torch.softmax(outputs, dim=1)
#加噪结果
with torch.no_grad():ema_output = ema_model(ema_inputs)ema_output_soft = torch.softmax(ema_output, dim=1)
...
#一致性损失
consistency_loss = torch.mean((outputs_soft[args.labeled_bs:]-ema_output_soft)**2)
该方法训练效果如下:
mean teacher 得到的最优Dice为82.7%。
uncertainty aware mean teacher(论文链接)
preds = torch.zeros([stride * T, num_classes, w, h]).cuda()
for i in range(T//2):#带噪声输入ema_inputs = volume_batch_r + \torch.clamp(torch.randn_like(volume_batch_r) * 0.1, -0.2, 0.2)with torch.no_grad():preds[2 * stride * i:2 * stride *(i + 1)] = ema_model(ema_inputs)
preds = F.softmax(preds, dim=1)
preds = preds.reshape(T, stride, num_classes, w, h)
preds = torch.mean(preds, dim=0)
#不确定性掩膜计算
uncertainty = -1.0 * torch.sum(preds*torch.log(preds + 1e-6),\dim=1, keepdim=True)
interpolation consistency(论文链接)
#混合输入 input_{mixed}
batch_ux_mixed = unlabeled_volume_batch_0 * \(1.0 - ict_mix_factors) + \unlabeled_volume_batch_1 * ict_mix_factors
#混合输入2
input_volume_batch = torch.cat([labeled_volume_batch, batch_ux_mixed], dim=0)
outputs = model(input_volume_batch)
outputs_soft = torch.softmax(outputs, dim=1)
with torch.no_grad():ema_output_ux0 = torch.softmax(ema_model(unlabeled_volume_batch_0), dim=1)ema_output_ux1 = torch.softmax(ema_model(unlabeled_volume_batch_1), dim=1)#混合输出batch_pred_mixed = ema_output_ux0 * \(1.0 - ict_mix_factors) + ema_output_ux1 * ict_mix_factors
#混合输入和输出计算一致性损失
consistency_weight = get_current_consistency_weight(iter_num//150)consistency_loss = torch.mean((outputs_soft[args.labeled_bs:] - batch_pred_mixed) ** 2)
最小熵约束(论文链接)
outputs = model(volume_batch)
outputs_soft = torch.softmax(outputs, dim=1)
...
# 对所有数据进行熵计算
consistency_loss = losses.entropy_loss(outputs_soft, C=4)
dv (论文没找到)
这个方法还不能用unet跑,作者说正在完善代码。所以本实验使用的是unet_dv网络。直接看结果:
最优Dice=81%。
对抗网络(论文链接)
#假设标注的数据都为真,未标注数据为假
DAN_target = torch.tensor([0] * args.batch_size).cuda()
DAN_target[:args.labeled_bs] = 1
...
# 未标注数据为真,形成对抗,注意这里使用的是DAN_target[:args.labeled_bs]
# 而不是DAN_target[args.labeled_bs:]
DAN_outputs = DAN(outputs_soft[args.labeled_bs:], volume_batch[args.labeled_bs:])
consistency_loss = F.cross_entropy(DAN_outputs, (DAN_target[:args.labeled_bs]).long())
...
#鉴别器损失计算,假设标注的数据都为真,未标注数据为假
DAN_outputs = DAN(outputs_soft, volume_batch)
DAN_loss = F.cross_entropy(DAN_outputs, DAN_target.long())
总结
总的来说,半监督学习方法最终结果都要比基准好。效果最好的是uncertainty aware mean teacher,Dice=84%,原因可能是它用了比较多的约束条件;最少迭代的是对抗网络,它在大概4000次迭代的时候就已经达到最好效果,但具体是不是训练最快的,还要等进一步验证。
半监督语义分割方法汇总(附代码分析)相关推荐
- CVPR 2022 | 商汤/上交/港中文提出U2PL:使用不可靠伪标签的半监督语义分割
点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 作者:Pascal | 已授权转载(源:知乎)编辑:CVer https://zhuanlan.zhih ...
- Semi-supervised Semantic Segmentation with Error Localization Network(基于误差定位网络的半监督语义分割 )
Semi-supervised Semantic Segmentation with Error Localization Network(基于误差定位网络的半监督语义分割 ) Abstract 本文 ...
- CVPR 2021 | 北大MSRA提出CPS:基于交叉伪监督的半监督语义分割
点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 作者:Charles | 源:知乎 https://zhuanlan.zhihu.com/p/37812 ...
- 用于半监督语义分割的基于掩码的数据增强
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 小白导读 论文是学术研究的精华和未来发展的明灯.小白决心每天为大家 ...
- CVPR 2022|U2PL:使用不可靠伪标签的半监督语义分割
本文转自商汤学术 导读 半监督任务的关键在于充分利用无标签数据,商汤科技联合上海交通大学.香港中文大学,基于「 Every Pixel Matters」的理念,有效利用了包括不可靠样本在内的全部无标签 ...
- 对抗学习的半监督语义分割
<font color="red">GAN生成对抗网络:</font>由两个子网络组成,generator和discriminator,在训练过程中,这两个 ...
- 半监督语义分割_paper reading part2
03 SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers Time:2021.05 我 ...
- YOLOV5 的小目标检测网络结构优化方法汇总(附代码)
点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨南山 来源丨 AI约读社 YOLOv5是一种非常受欢迎的单阶段目标检测,以其性能和速度著称,其结 ...
- 三种基于自监督深度估计的语义分割方法(arXiv 2021)
点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨泡泡机器人 来源丨泡泡机器人SLAM 标题: Three Ways to Improve Sem ...
- 【论文汇总】 ECCV 2020 语义分割paper汇总
语义分割 segmentation paper@ECCV 2020 ECCV 2020语义分割文章总结,文章下载链接. 文章目录 语义分割 segmentation paper@ECCV 2020 前 ...
最新文章
- mysql localhost无法登陆_MySQL 'root'@'localhost'无法登录
- 删了手机里的一个html文件,手机太卡,哪些内容可以毫不犹豫的删除?
- [转]LoadRunner 各个指标分析
- 【Android工具】免费二次元追番神器,各种字幕组新番旧番良心资源,重要的事说三遍:没有广告!没有广告!没有广告...
- Eclipse——导出可执行jar包
- 使用ANTLR在5分钟内用Java解析任何语言:例如Python
- 干货来袭!java核心技术卷一pdf
- script-百度换肤效果
- 第一期:浙大版《JAVA语言程序设计教程》(第二版)翁凯等 主编 ——小白的入门之路(上)(一)
- 怎么把mp3格式的音频文件转为文字?
- 图解侧方停车技巧2015高清版
- ContextCapture系列教程(三):大疆精灵4RTK版无人机POS数据提取、处理(处理后勉强达到免相控要求)
- java和python哪个运行速度快_python和java学哪个比较简单点
- 咸鱼软件应用—Cura3D切片
- 让你的微信小程序对用户更加友好:上拉加载和下拉刷新就是关键
- (转)Q格式的转换问题与移位
- 丹尼斯·里奇-c语言之父,Unix之父
- python 自动运维架构师_运维架构师-Python 自动化运维开发-031
- Android实训-家庭财务管理系统
- LeaRun.Framework━ .NET快速开发框架 ━ 工作流程组件介绍