1. Abstract

Medical image segmentation is an essential prerequisite for developing healthcare systems, especially for disease diagnosis and treatment planning. Across a wide range of medical image segmentation tasks, the U-shaped architecture (also known as U-Net) has become the de facto standard and achieved tremendous success. However, due to the intrinsic locality of convolution operations, U-Net generally shows limitations in explicitly modeling long-range dependencies. Transformers, designed for sequence-to-sequence prediction, have emerged as alternative architectures with an innate global self-attention mechanism, but they can offer limited localization ability owing to insufficient low-level detail.

In this paper, we propose TransUNet, which merits both Transformers and U-Net, as a strong alternative for medical image segmentation. On the one hand, the Transformer encodes tokenized image patches from a convolutional neural network (CNN) feature map as the input sequence for extracting global context. On the other hand, the decoder upsamples the encoded features and then combines them with high-resolution CNN feature maps to enable precise localization.

We argue that Transformers can serve as strong encoders for medical image segmentation tasks, and that combining them with U-Net enhances finer details by recovering localized spatial information. TransUNet achieves superior performance over various competing methods on different medical applications, including multi-organ segmentation and cardiac segmentation.

2. Introduction

Despite the exceptional representational power of CNN-based approaches, they generally exhibit limitations in modeling explicit long-range relations due to the intrinsic locality of convolution operations. As a result, these architectures often yield weak performance, especially for target structures that show large inter-patient variation in texture, shape, and size. To overcome this limitation, existing studies have proposed building self-attention mechanisms on top of CNN features. On the other hand, Transformers, designed for sequence-to-sequence prediction, have emerged as alternative architectures that dispense with convolution operators entirely and rely solely on attention mechanisms. Unlike prior CNN-based methods, Transformers are not only powerful at modeling global context but also demonstrate superior transferability to downstream tasks under large-scale pre-training. This success has been widely confirmed in machine translation and natural language processing (NLP). More recently, attempts on various image recognition tasks have also matched or exceeded the state of the art.

In this paper, we present the first study that explores the potential of Transformers in the context of medical image segmentation. Interestingly, however, we found that a naive usage (i.e., using a Transformer to encode tokenized image patches and then directly upsampling the hidden feature representations into a full-resolution dense output) does not produce satisfactory results.

This is because the Transformer treats the input as a 1D sequence and focuses on global context modeling at every stage, which yields low-resolution features that lack detailed localization information. This information cannot be effectively recovered by direct upsampling to full resolution, so the resulting segmentation is coarse. CNN architectures (e.g., U-Net), on the other hand, provide a way to extract low-level visual cues that can compensate well for such fine spatial details.

To this end, we propose TransUNet, the first medical image segmentation framework that establishes self-attention from the perspective of sequence-to-sequence prediction. To compensate for the loss of feature resolution introduced by the Transformer, TransUNet adopts a hybrid CNN-Transformer architecture that exploits both the detailed high-resolution spatial information from CNN features and the global context encoded by the Transformer. Inspired by the U-shaped architectural design, the self-attentive features from the Transformer encoder are upsampled and combined with the different high-resolution CNN features skipped from the encoding path to enable precise localization. We show that such a design allows our framework to preserve the advantages of Transformers while also benefiting medical image segmentation. Empirical results suggest that our Transformer-based architecture offers a better way of leveraging self-attention than previous CNN-based self-attention methods. In addition, we observe that more intensive incorporation of low-level features generally leads to better segmentation accuracy. Extensive experiments demonstrate the superiority of the proposed method on various medical image segmentation tasks.

Given an image of size H × W × C, with spatial resolution H × W and C channels, the goal is to predict the corresponding pixel-wise label map of size H × W. The most common approach is to directly train a CNN (e.g., U-Net) that first encodes the image into high-level feature representations and then decodes them back to the full spatial resolution. Unlike existing approaches, our method introduces a self-attention mechanism into the encoder design by using Transformers.

3. Network Structure

The overall network structure of TransUNet (see the architecture figure in the paper) is actually quite easy to follow.

1) First, a CT image is fed in. If the image has only a single channel, the repeat function replicates that channel so that the input is expanded to three channels. A minimal sketch of this step follows below.
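
The snippet below illustrates the channel replication (the behavior matches the VisionTransformer.forward code in Section 4; the batch size is just an example):

import torch

x = torch.randn(2, 1, 224, 224)      # single-channel CT slice, [B,1,H,W]
if x.size(1) == 1:
    # repeat(1, 3, 1, 1) tiles the channel dimension three times -> [B,3,H,W]
    x = x.repeat(1, 3, 1, 1)
print(x.shape)                       # torch.Size([2, 3, 224, 224])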

2) The three-channel image [B,3,224,224] is downsampled by ResNetV2, which encodes the image into high-level feature representations. A features list stores the feature map produced at each downsampling stage. The final ResNetV2 output is [B,16C,14,14], where C = 64 × width_factor. The features list holds feature maps at three scales: [B,C,112,112], [B,4C,56,56], and [B,8C,28,28].

The ResNetV2 structure in detail: the root first applies a 7×7 convolution with stride 2 and padding 3, downsampling to [B,C,112,112]; a 3×3, stride-2 max-pooling layer then downsamples to [B,C,55,55] (the skip features saved afterwards are zero-padded back to the expected sizes, e.g. 56×56, inside forward()); finally, three blocks perform the remaining downsampling and produce the output feature map. The shape arithmetic is traced in the short sketch below.
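
As a quick sanity check on these shapes, the standard convolution output-size formula can be traced by hand (a small helper sketch, not part of the repository code):

def conv_out(size, kernel, stride, padding):
    # floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1

s = conv_out(224, kernel=7, stride=2, padding=3)   # root StdConv2d -> 112
print(s)                                           # 112
s = conv_out(s, kernel=3, stride=2, padding=0)     # MaxPool2d -> 55 (padded to 56 for the skip feature)
print(s)                                           # 55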

3) The downsampled feature map [B,16C,14,14] is then embedded.

A convolution with kernel size patch_size × patch_size and stride patch_size cuts the map into patches, giving [B,hidden_channel,num_patches] with num_patches = (14/patch_size)**2. This is then reshaped to [B,n_p,h_c] (n_p = num_patches, h_c = hidden_channel), and a position embedding of shape [1,n_p,h_c] is added, so the final output feature map is [B,n_p,h_c]. A standalone sketch of this step follows.
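
The sketch below walks through the embedding step in isolation (hidden size 768 and the 14×14 token grid follow the R50-ViT-B_16 configuration; the concrete numbers are assumptions for illustration):

import torch
import torch.nn as nn

B, hidden, grid = 2, 768, 14                 # ViT-B hidden size, 14x14 token grid
cnn_feat = torch.randn(B, 1024, grid, grid)  # [B, 16C, 14, 14] from ResNetV2 (C = 64)

# In the hybrid model the "patch" is a 1x1 window over the CNN grid,
# so the patch embedding reduces to a 1x1 conv into the hidden size.
patch_embed = nn.Conv2d(1024, hidden, kernel_size=1, stride=1)
pos_embed = nn.Parameter(torch.zeros(1, grid * grid, hidden))

x = patch_embed(cnn_feat)          # [B, hidden, 14, 14]
x = x.flatten(2).transpose(1, 2)   # [B, n_patches=196, hidden]
x = x + pos_embed                  # add learnable position embeddings
print(x.shape)                     # torch.Size([2, 196, 768])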

4) Transformer Encoder. There is not much to say here: it is the familiar Transformer, computing self-attention over the token sequence and outputting [B,n_p,h_c]. A minimal pre-norm block is sketched below.
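
Each encoder layer is a standard pre-norm block: LayerNorm → multi-head self-attention → residual, then LayerNorm → MLP → residual. Below is a minimal sketch built on PyTorch's nn.MultiheadAttention rather than the repository's Block/Attention classes (dropout omitted for brevity):

import torch
import torch.nn as nn

class MiniBlock(nn.Module):
    def __init__(self, hidden=768, heads=12, mlp_dim=3072):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden, eps=1e-6)
        self.mlp = nn.Sequential(nn.Linear(hidden, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, hidden))

    def forward(self, x):                # x: [B, n_patches, hidden]
        h = self.norm1(x)
        x = x + self.attn(h, h, h)[0]    # multi-head self-attention + residual
        x = x + self.mlp(self.norm2(x))  # MLP + residual
        return x

x = torch.randn(2, 196, 768)
print(MiniBlock()(x).shape)              # torch.Size([2, 196, 768])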

5) Decoder. First, [B,n_p,h_c] is reshaped to [B,h_c,√n_p,√n_p]; a Conv2d then maps it to [B,512,14,14]. Passing through four decoder blocks finally yields [B,16,224,224].

Each decoder block first doubles the feature-map size with bilinear upsampling, then concatenates it with the corresponding CNN feature saved earlier as a shortcut, and finally applies two Conv2d layers to map the features to a lower-dimensional space, as sketched below.
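
A sketch of one such decoder stage written out inline (reshape the token sequence back to a 2D map, upsample, concatenate the skip feature, and apply two 3×3 conv+BN+ReLU layers); the channel numbers follow the shapes quoted above but are otherwise illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

B, n_patch, hidden = 2, 196, 768
tokens = torch.randn(B, n_patch, hidden)              # Transformer output [B, n_p, h_c]
h = w = int(n_patch ** 0.5)                           # 14

x = tokens.permute(0, 2, 1).reshape(B, hidden, h, w)  # [B, 768, 14, 14]
x = nn.Conv2d(hidden, 512, 3, padding=1)(x)           # conv_more -> [B, 512, 14, 14]

skip = torch.randn(B, 512, 28, 28)                    # high-resolution CNN feature (8C = 512)
x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)  # [B, 512, 28, 28]
x = torch.cat([x, skip], dim=1)                       # [B, 1024, 28, 28]
x = nn.Sequential(                                    # two 3x3 conv blocks -> [B, 256, 28, 28]
    nn.Conv2d(1024, 256, 3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
    nn.Conv2d(256, 256, 3, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
)(x)
print(x.shape)                                        # torch.Size([2, 256, 28, 28])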

6) SegmentationHead. The feature map output by the decoder is turned into the segmentation by a Conv2d that outputs [B,n_classes,224,224]; a minimal sketch follows.
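
The head itself is just a 3×3 convolution that maps the 16-channel decoder output to per-class logits (minimal sketch; n_classes = 9 is an assumption, e.g. a multi-organ setting with 8 organs plus background):

import torch
import torch.nn as nn

n_classes = 9                                   # assumption: 8 organs + background
x = torch.randn(2, 16, 224, 224)                # output of the last decoder block
head = nn.Conv2d(16, n_classes, kernel_size=3, padding=1)
logits = head(x)
print(logits.shape)                             # torch.Size([2, 9, 224, 224])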

4. Complete TransUNet Code

The code below is the complete implementation except for ResNetV2.

# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math
from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

from . import vit_seg_configs as configs
from .vit_seg_modeling_resnet_skip import ResNetV2

logger = logging.getLogger(__name__)

ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


# swish activation: x * sigmoid(x)
def swish(x):
    return x * torch.sigmoid(x)


# choice of activation function
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}


class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # linear layers producing the q, k, v vectors
        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        # linear projections to the q, k, v matrices
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)
        # q times k transposed
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        # scale by sqrt(d)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        # softmax
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)
        # multiply by the v matrix
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        # output projection
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights


class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])
        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings."""
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:   # ResNet
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))
        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        # [B,16C,14,14]
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)  # (B, hidden, n_patches)
        x = x.transpose(-1, -2)  # (B, n_patches, hidden)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features


class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size, self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))


class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        # pass through the stacked Blocks one by one
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights


class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)
        return encoded, attn_weights, features


class Conv2dReLU(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )
        relu = nn.ReLU(inplace=True)
        bn = nn.BatchNorm2d(out_channels)
        super(Conv2dReLU, self).__init__(conv, bn, relu)


class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


class SegmentationHead(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)


class DecoderCup(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )
        decoder_channels = config.decoder_channels
        in_channels = [head_channels] + list(decoder_channels[:-1])
        out_channels = decoder_channels

        if self.config.n_skip != 0:
            skip_channels = self.config.skip_channels
            for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3 - i] = 0
        else:
            skip_channels = [0, 0, 0, 0]

        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()  # reshape from (B, n_patch, hidden) to (B, h, w, hidden)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)
        x = x.contiguous().view(B, hidden, h, w)
        x = self.conv_more(x)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x


class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        # [B,C,H,W]: if the image is single-channel [B,1,H,W], replicate the channel to get [B,3,H,W]
        if x.size()[1] == 1:
            # each argument of repeat is the number of copies along that dimension,
            # e.g. (2, 2) repeats dim 0 twice and dim 1 twice
            x = x.repeat(1, 3, 1, 1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)  # [B,16,H,W]
        logits = self.segmentation_head(x)
        return logits

    def load_from(self, weights):
        with torch.no_grad():
            res_weight = weights
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])
            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1] - 1 == posemb_new.size()[1]:
                posemb = posemb[:, 1:]
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder whole
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(np2th(res_weight["conv_root/kernel"], conv=True))
                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)


CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}

The code below is the ResNetV2 network structure.

import math
from os.path import join as pjoin
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F


def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


class StdConv2d(nn.Conv2d):
    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)


def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=3, stride=stride, padding=1, bias=bias, groups=groups)


def conv1x1(cin, cout, stride=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=1, stride=stride, padding=0, bias=bias)


class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block."""

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout // 4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):
        # Residual branch
        residual = x
        # apply the projection shortcut when a downsample module exists
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])
        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])
        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        self.conv1.weight.copy_(conv1_weight)
        self.conv2.weight.copy_(conv2_weight)
        self.conv3.weight.copy_(conv3_weight)

        self.gn1.weight.copy_(gn1_weight.view(-1))
        self.gn1.bias.copy_(gn1_bias.view(-1))
        self.gn2.weight.copy_(gn2_weight.view(-1))
        self.gn2.bias.copy_(gn2_bias.view(-1))
        self.gn3.weight.copy_(gn3_weight.view(-1))
        self.gn3.bias.copy_(gn3_bias.view(-1))

        if hasattr(self, 'downsample'):
            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

            self.downsample.weight.copy_(proj_conv_weight)
            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))


class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet mode."""

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))  # would give [B,C,55,55]; pooling is applied in forward() instead
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
            ))),
        ]))

    def forward(self, x):
        # [B,3,224,224]
        features = []
        b, c, in_size, _ = x.size()
        x = self.root(x)  # [B,C,112,112]
        features.append(x)
        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)  # [B,C,55,55]
        for i in range(len(self.body) - 1):
            x = self.body[i](x)
            right_size = int(in_size / 4 / (i + 1))
            if x.size()[2] != right_size:
                pad = right_size - x.size()[2]
                assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)  # [B,16C,14,14]
        return x, features[::-1]
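
For completeness, a hedged usage sketch showing how these pieces are typically wired together. The module path, image size, and class count are assumptions based on the official TransUNet repository and are not fixed by the code above:

import torch
from networks.vit_seg_modeling import VisionTransformer, CONFIGS  # assumed path, as in the official repo

config = CONFIGS['R50-ViT-B_16']
config.n_classes = 9                      # assumption: e.g. Synapse multi-organ (8 organs + background)
config.n_skip = 3                         # use three CNN skip connections
config.patches.grid = (14, 14)            # 224 / 16 for a 224x224 input

model = VisionTransformer(config, img_size=224, num_classes=config.n_classes)
x = torch.randn(1, 1, 224, 224)           # single-channel CT slice; repeated to 3 channels internally
logits = model(x)                         # [1, n_classes, 224, 224]
print(logits.shape)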
