MS-Model【1】:nnU-Net
文章目录
- 前言
- 1. Abstract & Introduction
- 1.1. Abstract
- 1.2. Introduction
- 2. Methods
- 2.1. Network architectures
- 2.1.1. 2D U-Net
- 2.1.2. 3D U-Net
- 2.1.3. U-Net cascade
- 2.2. Dynamic adaptation of network topologies
- 2.3. Preprocessing
- 2.3.1. Cropping
- 2.3.2. Resampling
- 2.3.3. Normalization
- 2.4. Training Procedure
- 2.4.1. Loss Function
- 2.4.2. Learning Rate
- 2.4.3. Data Augmentation
- 2.4.4. Patch Sampling
- 总结
前言
本文提出的 nnU-net
(no-new U-net
),是在 2D & 3D 经典 U-net 的基础上, 稳健而又自适应的框架。nnU-net
移去了冗余的部分,着重于剩下的对模型表现和泛化能力起作用的部分
原论文链接:
nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation
论文复现参考:
MS-Train【1】:nnUNet
1. Abstract & Introduction
1.1. Abstract
U-Net
凭借其直接和成功的架构,迅速发展成为医学图像分割领域的一个常用基准。然而,U-Net
对新问题的适应性包括关于确切的架构、预处理、训练和推理的几个自由度。这些选择并不是相互独立的,而是对整体性能有很大影响。本文介绍了 nnU-Net
(no-new-Net
),它是指在二维和三维虚构 U-Net
的基础上建立的一个稳健和自适应的框架。本文论证了在许多提议的网络设计中去掉多余的 bells
和 whistles
,转而关注其余的方面,这些方面决定了一个方法的性能和可推广性。
1.2. Introduction
本文提出了 nnU-Net
(no-new-Net
)框架。它基于一组相对简单的 U-Net
模型,只包含对原始U-Net
的微小修改。本文省略了最近提出的扩展,例如使用 residual connections
、dense connections
或 attention mechanisms
。nnU-Net
可以自动将其架构适应给定的图像几何,更重要的是,它彻底定义了围绕他们的所有其他步骤。这些步骤包括:
- 预处理,比如 resampling 和 normalization
- 训练,比如损失函数、优化器的设置和数据扩充
- 推断,比如基于图像块的策略、TTA(test-time augmentation)集成和模型集成
- 后处理,比如增强单连通域
2. Methods
2.1. Network architectures
U-Net
是一个成功的 encoder-decoder
网络:
- encoder:工作原理与传统的分类 CNN 类似,以减少空间信息为代价连续聚集语义信息
- decoder:因为在分割中,语义和空间信息对网络的成功至关重要,所以需要使用 decoder 以恢复丢失的空间信息
- decoder 从
U
的底部接收语义信息,并将其与通过跳过连接直接从编码器获得的高分辨率特征图重新组合
- decoder 从
医学图像通常包括第三维,本文提出了一套包括由 2D U-Net
,3D U-Net
和 U-Net cascade
组成的基本 U-Net
架构:
- 2D 和 3D U-Net 以全分辨率生成分割
- cascade 首先生成低分辨率的分割,然后再对其进行细化
- 与原始
U-Net
的比较:- 和原始U-Net类似,在
encoder
部分,本文在池化层之间使用简单的卷积层,在decoder
部分,本文使用转置卷积 - 与原始
U-Net
不同的是,本文使用的激活函数是leaky ReLU
而不是ReLU
(-1e-2);用 instance normalization 替换更加流行的 batch normalization
- 和原始U-Net类似,在
2.1.1. 2D U-Net
当数据是各向异性的时候,传统的 3D 分割方法就会很差,所以在这里给出了 2D U-Net
的网络架构
网络特征:
- 使用全卷积神经网络
- 全卷积神经网络就是卷积取代了全连接层,全连接层必须固定图像大小而卷积不用,所以这个策略保证使用者可以输入任意尺寸的图片,而且输出也是图片,所以这是一个端到端的网络
- 左边的网络是收缩路径:使用卷积和 maxpooling.
- 右边的网络是扩张路径:使用上采样产生的特征图与左侧收缩路径对应层产生的特征图进行 concatenate 操作
- 最后再经过两次反卷积操作,生成特征图,再用两个 1 × 1 1 \times 1 1×1 的卷积做分类得到最后的两张
heatmap
- 例如第一张表示的是第一类的得分,第二张表示第二类的得分
- 然后作为 softmax 函数的输入,算出概率比较大的 softmax 类,选择它作为输入给交叉熵进行反向传播训练
更多关于 U2-Net
网络结构的讲解,可以参考我的另一篇 blog:SS-Model【6】:U2-Net
2.1.2. 3D U-Net
3D 网络的效果好,但是太占用 GPU 显存。一般情况下,可以使用小一点的图像块去训练,但是当面对比较大的图像如肝等,这种基于块的方法就会阻碍训练。这是因为受限于感受野的大小,网络结构不能收集足够的上下文信息去正确的识别大图像的特征
网络内容:
- 核心
- 训练过程只要求一部分
2D slices
,去生成密集的立体分割
- 训练过程只要求一部分
- 两种实现方法
- 在一个稀疏标注的数据集上训练并在此数据集上预测其他未标注的地方
- 在多个稀疏标注的数据集上训练,然后泛化到新的数据
网络特征:
- 网络结构的前半部分(
analysis path
)包含如下卷积操作- 每一层神经网络都包含了两个 3 × 3 × 3 3 \times 3 \times 3 3×3×3 的卷积
Batch Normalization
(为了让网络能更好的收敛)ReLU
- 下采样: 2 × 2 × 2 2 \times 2 \times 2 2×2×2 的
max_polling
,步长 stride = 2- 通过在最大池化之前将通道数量加倍来避免瓶颈
- 网络结构的合成路径(
synthesis path
)则执行下面的操作- 上采样: 2 × 2 × 2 2 \times 2 \times 2 2×2×2,步长 stride = 2
- 两个正常的卷积操作: 3 × 3 × 3 3 \times 3 \times 3 3×3×3
- Batch Normalization
- ReLU
- 把在
analysis path
上相对应的网络层的结果作为decoder
的部分输入,这样子做的原因跟 U-Net 博文 中提到的一样,是为了能采集到特征分析中保留下来的高像素特征信息,以便图像可以更好的合成 - 在最后一层, 1 × 1 × 1 1 \times 1 \times 1 1×1×1 卷积将输出通道的数量减少到 3 个标签
- 加权
softmax
损失函数将未标记像素的权重设置为零可以仅从已标记像素学习。降低了频繁出现的背景的权重,增加了内管的权重,以达到小管和背景体素对损失的均衡影响。
2.1.3. U-Net cascade
为了解决 3D U-Net
在大图像尺寸数据集上的缺陷,本文提出了级联模型:
- 第一级
3D U-Net
在下采样的图像上进行训练,然后将结果上采样到原始的体素spacing。 - 将上采样的结果作为一个额外的输入通道(one-hot 编码)送入第二级
3D U-Net
,并使用基于图像块的策略在全分辨率的图像上进行训练
2.2. Dynamic adaptation of network topologies
由于输入图像大小的不同,输入图像块大小和每个轴池化操作的数量(同样也是卷积层的数量)必须能够自适应每个数据集去考虑充足的空间信息聚合。除了自适应图像几何,还需要考虑显存的使用。
指导方针是动态平衡 batch size 和网络容量:
- 将 patch size 初始化为图像大小的中位数
- 迭代地减少 patch size,同时调整网络拓扑架构(网络深度、池化操作数量、池化操作位置、feature map 的尺寸、卷积核尺寸)
- 直到网络可以在给定 GPU 的限制下,且 batch 至少是 2 的情况下,可以被 train
网络初始配置:
- 2D U-Net
- 图像大小 = 256 × 256 256 \times 256 256×256,batch size = 42,最高层的特征图谱数量 = 30(每个下采样特征图谱数量翻倍)
- 自动将这些参数调整为每个数据集的中值平面大小(这里使用面内间距最小的平面,对应于最高的分辨率),以便网络有效地训练整个切片
- 本文将网络配置为沿每个轴池化,直到该轴的特征图谱小于8(但最多不超过6个池化操作)
- 3D U-Net
- 图像大小 = 128 × 128 × 128 128 \times 128 \times 128 128×128×128,batch size = 2,最高层的特征图谱数量 = 30
- 由于显存限制,不去增加图像大小超过 12 8 3 128^3 1283 体素,而是匹配输入图像和数据集中体素中值大小的比率
- 如果数据集的形状中值比 12 8 3 128^3 1283 小,那就使用形状的中值作为输入的图像大小并且增加 batch size(目的是将体素的数量和 KaTeX parse error: Undefined control sequence: \tiems at position 5: 128 \̲t̲i̲e̲m̲s̲ ̲128 \times 128,batch size 为 2 的体素数量相等)。沿每个轴最多池化5次直到特征图谱大小为8
2.3. Preprocessing
nnU-Net
的预处理是在没有任何用户干预的情况下执行的
2.3.1. Cropping
所有数据都被裁剪到非零值区域
2.3.2. Resampling
CNN
本身并不理解体素间距。 在医学图像中,不同的扫描仪或不同的采集协议通常会产生具有不同体素间距的数据集
为了使我们的网络能够正确学习空间语义,所有患者都被重新采样到各自数据集的中值体素间距,其中三阶样条插值用于图像数据,最近邻居插值用于相应的分割掩码
是否需要经过 U-Net cascade
模型,由以下方法确定:
- 如果重采样数据的形状中值可以作为
3D U-Net
中的输入图像(batch size = 2)的 4 倍以上,则使用U-Net cascade
模型,且数据集需要重新采样到较低的分辨率- 可以通过将体素间距增加 2 倍来完成(降低分辨率),直到满足上述标准
- 如果数据集是各向异性的,则首先对较高分辨率的轴进行下采样,直到它们与低分辨率轴匹配,然后才同时对所有轴进行下采样
2.3.3. Normalization
- 对于
CT
图像,训练集中所有 segmentation mask 中的 value 会被收集,整体的数据集会先被 clip 到 [ 0.5 , 99.5 ] [0.5, 99.5] [0.5,99.5] 百分位,然后通过收集的数据的 mean 和标准差进行z-score
正则化- 需要注意的是,如果因为裁剪减少了病例平均大小的 1/4 或更多,则标准化只在非零元素的 mask 内部进行,并且 mask 外的所有值设为 0
- 对于
MRI
图像以及其他图像,直接进行z-score
标准化
2.4. Training Procedure
所有模型都从头开始训练,并在训练集上使用五折交叉验证进行评估
2.4.1. Loss Function
结合 dice 和交叉熵损失来训练网络:
L t o t a l = L d i c e + L C E \mathcal{L}_{total} = \mathcal{L}_{dice} + \mathcal{L}_{CE} Ltotal=Ldice+LCE
- 对于在全训练集上训练的
3D U-Net
(如果不需要 cascade,则是 U-Net cascade 的第一阶段和 3D U-Net),计算 batch 里每个样本的 dice 损失,并计算 batch 中的平均值 - 对于所有其他网络,将 batch 中的样本(samples)解释为伪体积(volume),并计算批次中所有体素的 dice 损失
对于目前大多数的图像分割任务来说,使用最多评价指标就是 dice 相似系数 (Dice Similarity Coefficient) 。Dice 系数是计算两个样本之间的相似度,即考察两个样本之间重叠的范围,范围通常在0-1之间。
- 若为1,则证明两个样本完全重合
- 若为0,则证明两个样本没有相同的像素
计算方法如下:
D i c e ( P , T ) = 2 T P F P + 2 T P + F N Dice(P, T) = \frac{2TP}{FP + 2TP + FN} Dice(P,T)=FP+2TP+FN2TP
其中:
TP (True Positive)
为判定为正样本,事实上也是正样本TN (True Negative)
为判定为负样本,事实上也是负样本FP (False Positive )
为判定为正样本,事实上为负样本FN (False Negative)
为判定为负样本,事实上为正样本
分母即:
- FP + TP = 所有分类为阳性的样本
- TP + FN = 真阳 + 假阴 = 所有真的是阳性的样本
dice 的损失函数为:
L d i c e = − 2 ∣ K ∣ ∑ k ∈ K ∑ i ∈ I u i k v i k ∑ i ∈ I u i k + ∑ i ∈ I v i k \mathcal{L}_{dice} = - \frac{2}{|K|} \displaystyle\sum_{k \in K} \frac{\sum_{i \in I}u_i^k v_i^k}{\sum_{i \in I}u_i^k + \sum_{i \in I}v_i^k} Ldice=−∣K∣2k∈K∑∑i∈Iuik+∑i∈Ivik∑i∈Iuikvik
参数含义:
u
是网络的 softmax 输出v
是 ground Truth 的 one hot 编码k
为类别数u
和v
都具有形状 I × K I \times K I×K
2.4.2. Learning Rate
优化器:Adam,初始学习率 3e-4,每个 epoch 有 250 个 batch
学习率调整策略:计算训练集和验证集的指数滑动平均 loss,如果训练集的指数滑动平均 loss 在近 30 个 epoch 内减少不够 5e-3,则学习率衰减 5 倍
训练停止条件:当学习率大于 10-6 且验证集的指数滑动平均 loss 在近 60 个 epoch 内减少不到 5e-3,则终止训练
2.4.3. Data Augmentation
从有限的训练数据训练大型神经网络时,必须特别注意防止过度拟合。 本文在训练期间动态应用了如下所示的多种增强技术来解决这个问题:
- random rotations
- random scaling
- random elastic deformations
- gamma correction augmentation
- mirroring
需要注意的是,如果 3D U-Net
的输入图像块尺寸的最大边长是最短边长的两倍以上,这种情况对每个 2 维面做数据增广,然后逐个切片地将其应用于每个样本
U-Net cascade 的第二级接受前一级的输出作为输入的一部分,为了防止强 co-adaptation,应用随机形态学操作(腐蚀、膨胀、开运算、闭运算)去随机移除掉这些分割结果的连通域
2.4.4. Patch Sampling
为了增强网络训练的稳定性,强制每个 batch 中超过 1/3 的样本包含至少一个随机选择的前景
总结
本文提出了用于医疗领域的 nnU-Net
分割框架,该框架直接围绕原始 U-Net
架构构建,并动态调整自身以适应任何给定数据集的细节。 基于本文的假设,即非架构修改可能比最近提出的一些架构修改更强大,该框架的本质是自适应预处理、训练方案和推理的彻底设计。 适应新分割任务所需的所有设计选择均以全自动方式完成,无需手动交互。
参考资料
MS-Model【1】:nnU-Net相关推荐
- CSI笔记【5】:Widar2.0: Passive Human Tracking with a Single Wi-Fi Link论文阅读
CSI笔记[5]:Widar2.0: Passive Human Tracking with a Single Wi-Fi Link论文笔记 前言 Abstract 1 INTRODUCTION 2 ...
- CSI笔记【6】:Guaranteeing spoof-resilient multi-robot networks论文阅读
CSI笔记[6]:Guaranteeing spoof-resilient multi-robot networks论文阅读 Abstract 1 Introduction 1.1 Contribut ...
- CV-Model【8】:ConvNeXt
文章目录 前言 1. Abstract & Introduction 1.1. Abstract 1.2. Introduction 2. Modernizing a ConvNet: a R ...
- CV-Model【6】:Vision Transformer
系列文章目录 Transformer 系列网络(一): CV-Model[5]:Transformer Transformer 系列网络(二): CV-Model[6]:Vision Transfor ...
- CV-Model【5】:Transformer
系列文章目录 Transformer 系列网络(一): CV-Model[5]:Transformer Transformer 系列网络(二): CV-Model[6]:Vision Transfor ...
- MS-Train【2】:nnFormer
文章目录 前言 1. 安装 2. 训练与测试 2.1. 数据处理 2.1.1. 整理数据路径 2.1.2. 设置 nnFormer 读取文件的路径 2.1.3. 数据集预处理 2.2. 训练 2.2. ...
- SS-Model【6】:U2-Net
系列文章目录 U-Net语义分割系列(一): SS-Model[5]:U-Net U-Net语义分割系列(二): SS-Model[6]:U2-Net 文章目录 系列文章目录 前言 1. Abstra ...
- 机器学习入门篇【一】:以拉家常的方式讲机器学习
前言 因为对机器学习比较感兴趣,最近也可能会用得上,所以想浅浅的谈一谈机器学习,大佬就不用在这浪费时间了,不涉及公式推导.甚至该篇都称不上是什么经验贴,只能说是最近搜寻有些资料有感而发. 那么想通过这 ...
- CSI笔记【7】:Crowd Vetting: Rejecting Adversaries via Collaboration with Application to......论文阅读
CSI笔记[7]:Crowd Vetting: Rejecting Adversaries via Collaboration with Application to Multi-Robot Floc ...
- 【TensorFlow】:Eager Mode(动态图模式)
[TensorFlow]:Eager Mode(动态图模式) http://www.360doc.com/content/18/1207/16/7669533_800020620.shtml
最新文章
- 数据解析_485型风速和风向变送器数据包解析
- 服务器预装操作系统,服务器预装操作系统吗
- 开发vue底部导航栏组件
- 民主湖呀,不知道是好看还是破烂
- Html 教程 (4) <head>
- 公需科目必须学吗_税务师要继续教育吗,2019税务师怎样继续教育?
- ASP.NET 2.0中控件的简单异步回调
- Application Verifier
- 一键多功能按键识别c语言,单片机一键多功能按键识别设计
- 第一次使用pyqt5解决的几个小问题
- 安卓系统所有可声明的权限
- ASP.NET中EnableViewState
- Oracle技巧查询,很香
- 澳洲2022人口普查结果出炉--华人占比开始下降
- 透过容抗来看电容量和频率的关系
- maya(学习笔记)之Arnold渲染器二
- 没钱没资源没人脉?年入千万的她写了这本副业思维的书
- FFT的C语言实现,对照MATLIB
- 【排错必看】Windows系统安装mysql时常见问题及解决方法
- 文学院计算机报名是access吗,ACCESS综合练习范文