Contents

  • Preface
  • I. UNETR Network Architecture
  • II. Code
    • 1. Imports
    • 2. Helper functions and custom Keras layers
    • 3. Building the Vision Transformer
    • 4. Building the full UNETR
    • 5. A quick test

Preface

  I have been trying out various networks for medical image segmentation, and this is my first attempt at a Transformer-CNN segmentation method. I wanted to start with UNETR, which uses a full Vision Transformer (ViT) as its encoder. Unfortunately, this time I could not find any public TensorFlow implementation online, so I had no choice but to stitch one together myself; as usual, consider this a starting point rather than the final word.

  This implementation is the 2D segmentation version. Converting it to 3D is straightforward: the ViT itself is largely agnostic to image dimensionality, and for the CNN part you simply swap the 2D convolutions for 3D ones.

Paper: UNETR: Transformers for 3D Medical Image Segmentation
Reference code:
1. Official Keras example: Image classification with Vision Transformer
2. GitHub user tamasino52's unofficial PyTorch implementation

I. UNETR Network Architecture

[Figure: the full UNETR architecture diagram]
  The full UNETR architecture is shown above. Compared with U-Net, the main change is that the encoder is replaced by a structure similar to ViT-Base/16; essentially everything else follows from that. For example, because the ViT output has a fixed patch-grid size ([H/p, W/p, D/p], p = patch_size), the "skip connection"-like paths in UNETR must use chains of transposed convolutions to restore feature-map resolution (the blue blocks in the figure) before feeding into the conventional U-Net decoder layers (the yellow blocks).
  The full ViT backbone plus the many added convolution blocks make UNETR's parameter count balloon to 92M (the UNETR paper lists nnU-Net at 19M), but the results are genuinely good, and it has become a standard comparison method for many 3D medical image segmentation tasks.
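To make the resolution arithmetic behind those blue blocks concrete: with patch_size = 16, every skip connection leaves the ViT on the same H/16 × W/16 token grid, so the number of stride-2 transposed convolutions each one needs depends only on the target stride. A minimal pure-Python sketch (the helper name is mine, not from the paper):

```python
import math

def deconvs_needed(patch_size, target_stride):
    """Stride-2 transposed convs required to lift the ViT's fixed
    [H/patch_size, W/patch_size] grid up to a feature map at target_stride."""
    return int(math.log2(patch_size) - math.log2(target_stride))

# patch_size 16: skips aimed at strides 8, 4 and 2 need 1, 2 and 3 deconvs
print([deconvs_needed(16, s) for s in (8, 4, 2)])  # [1, 2, 3]
```

This is exactly the pattern the `deconv_conv_block` filter lists below follow: deeper skips need fewer upsampling steps.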

II. Code

  1. My TensorFlow version is 2.8.0; version 2.5.0 or above is recommended.
  2. If you do not have tensorflow_addons installed, just comment out the related statements.
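Instead of commenting lines out by hand, one option is to treat tensorflow_addons as a soft dependency. This is a sketch of one way to do it (the `optional_import` helper is my own, not part of any library); group normalization then simply becomes unavailable when the package is missing:

```python
import importlib

def optional_import(name):
    """Return the module if it can be imported, otherwise None."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None

tfa = optional_import('tensorflow_addons')
# In normalization() below, the 'group' branch could then be guarded with:
#     if tfa is None: raise ImportError('tensorflow_addons is required for group norm')
```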

1. Imports

The code is as follows:

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import keras
import keras.backend as K
from keras.layers import (Layer, BatchNormalization, LayerNormalization, Conv2D, Conv2DTranspose, Embedding, Activation, Dense, Dropout, MultiHeadAttention, add, Input, concatenate, GlobalAveragePooling1D)
from keras.models import Model

2. Helper functions and custom Keras layers

The mlp, Patches, and PatchEncoder code comes from the Keras Code Example linked above.

def mlp(x, hidden_units, dropout_rate):
    if not isinstance(hidden_units, list):
        hidden_units = [hidden_units]
    for units in hidden_units:
        x = Dense(units, activation=tf.nn.gelu)(x)
        x = Dropout(dropout_rate)(x)
    return x


class Patches(Layer):
    '''Extract image patches and flatten them into a sequence:
    [B, H, W, C] -> [B, H/patch_size, W/patch_size, C*(patch_size^2)]
                 -> [B, H*W/(patch_size^2), C*(patch_size^2)]
    '''
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID',
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches


class PatchEncoder(Layer):
    '''Linearly project the patches to projection_dim
    and add a learnable position embedding.
    '''
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = Dense(units=projection_dim)
        self.position_embedding = Embedding(input_dim=num_patches, output_dim=projection_dim)

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded


def normalization(input_tensor, norm_type, name=None):
    if norm_type == 'batch':
        return BatchNormalization(name=None if name is None else name + '_batchnorm')(input_tensor)
    elif norm_type == 'layer':
        return LayerNormalization(epsilon=1e-6, name=None if name is None else name + '_layernorm')(input_tensor)
    elif norm_type == 'group':
        return tfa.layers.GroupNormalization(groups=8, name=None if name is None else name + '_groupnorm')(input_tensor)
    elif norm_type is None:
        return input_tensor
    else:
        raise ValueError('Invalid normalization')


def conv_norm_act(input_tensor, filters, kernel_size, norm_type='batch', act_type='relu', dilation=1):
    '''Conv2D + Normalization(norm_type: str) + Activation(act_type: str)'''
    output_tensor = Conv2D(filters, kernel_size, padding='same',
                           dilation_rate=(dilation, dilation),
                           use_bias=norm_type is None,
                           kernel_initializer='he_normal')(input_tensor)
    output_tensor = normalization(output_tensor, norm_type=norm_type)
    if act_type is not None:
        output_tensor = Activation(act_type)(output_tensor)
    return output_tensor


def conv2d_block(input_tensor, filters, kernel_size, norm_type, use_residual, act_type='relu',
                 double_features=False, dilation=(1, 1), name=None):
    x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[0], use_bias=False,
               kernel_initializer='he_normal',
               name=None if name is None else name + '_conv2d_0')(input_tensor)
    x = normalization(x, norm_type, name=None if name is None else name + '_0')
    x = Activation(act_type, name=None if name is None else name + act_type + '_0')(x)
    if double_features:
        filters *= 2
    x = Conv2D(filters, kernel_size, padding='same', dilation_rate=dilation[1], use_bias=False,
               kernel_initializer='he_normal',
               name=None if name is None else name + '_conv2d_1')(x)
    x = normalization(x, norm_type, name=None if name is None else name + '_1')
    if use_residual:
        if K.int_shape(input_tensor)[-1] != K.int_shape(x)[-1]:
            # 1x1 conv to match channel counts before the residual addition
            shortcut = Conv2D(filters, kernel_size=1, padding='same', use_bias=False,
                              kernel_initializer='he_normal',
                              name=None if name is None else name + '_shortcut_conv2d')(input_tensor)
            shortcut = normalization(shortcut, norm_type, name=None if name is None else name + '_shortcut')
            x = add([x, shortcut])
        else:
            x = add([x, input_tensor])
    x = Activation(act_type, name=None if name is None else name + act_type + '_1')(x)
    return x


def deconv_conv_block(x, filters_list: list, kernel_size, norm_type, act_type):
    '''The blue blocks in the UNETR architecture diagram'''
    for filts in filters_list:
        x = Conv2DTranspose(filts, 2, (2, 2), kernel_initializer='he_normal')(x)
        x = conv_norm_act(x, filts, kernel_size, norm_type, act_type)
    return x


def conv_deconv_block(x, filters, kernel_size, norm_type, use_residual, act_type):
    '''The yellow + green blocks in the UNETR architecture diagram'''
    x = conv2d_block(x, filters, kernel_size, norm_type, use_residual, act_type)
    x = Conv2DTranspose(filters // 2, 2, (2, 2), kernel_initializer='he_normal')(x)
    return x
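As a sanity check on the Patches layer, its shape transformation can be reproduced in plain NumPy (this is only a shape sketch of my own, not the actual TensorFlow op): with a 256×256×3 input and patch_size 16, you should get 256 patches of dimension 768.

```python
import numpy as np

def extract_patches_np(images, patch_size):
    """NumPy version of the Patches shape transformation:
    [B, H, W, C] -> [B, H*W/(p*p), C*p*p]."""
    b, h, w, c = images.shape
    p = patch_size
    x = images.reshape(b, h // p, p, w // p, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # [B, H/p, W/p, p, p, C]
    return x.reshape(b, (h // p) * (w // p), p * p * c)

print(extract_patches_np(np.zeros((2, 256, 256, 3)), 16).shape)  # (2, 256, 768)
```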

3. Building the Vision Transformer

This part is also adapted from the Keras Code Example. The main changes are removing the classification head and adding the extraction of "skip connection" outputs at specific layers; I have kept most of the comments from the original code.

def create_vit(x, patch_size, num_patches, projection_dim, num_heads,
               transformer_units, transformer_layers, dropout_rate, extract_layers):
    skip_connections = []
    # Create patches.
    patches = Patches(patch_size)(x)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    # Create multiple layers of the Transformer block.
    for layer in range(transformer_layers):
        # Layer normalization 1.
        x1 = LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = MultiHeadAttention(num_heads=num_heads,
                                              key_dim=projection_dim // num_heads,
                                              dropout=dropout_rate)(x1, x1)
        # Skip connection 1.
        x2 = add([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=dropout_rate)
        # Skip connection 2.
        encoded_patches = add([x3, x2])
        if layer + 1 in extract_layers:
            skip_connections.append(encoded_patches)
    return skip_connections
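Note that unlike a CNN encoder, the ViT never downsamples: every entry in skip_connections has the same token-grid shape, which is precisely why the UNETR decoder needs its deconvolution chains. A pure-Python sketch of those shapes (batch dimension omitted; the helper name is mine):

```python
def vit_skip_shapes(input_size, patch_size, projection_dim, extract_layers):
    """Per-layer shapes of the skip connections returned by create_vit."""
    num_patches = (input_size // patch_size) ** 2
    return {layer: (num_patches, projection_dim) for layer in extract_layers}

print(vit_skip_shapes(256, 16, 768, [3, 6, 9, 12]))
# {3: (256, 768), 6: (256, 768), 9: (256, 768), 12: (256, 768)}
```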

4. Building the full UNETR

def build_model(
    # ↓ Base arguments
    input_shape=(256, 256, 3),
    class_nums=5,
    # ↓ ViT arguments
    patch_size=16,
    projection_dim=768,
    num_heads=12,
    transformer_units=[2048, 768],
    transformer_layers=12,
    extract_layers=[3, 6, 9, 12],
    dropout_rate=0.1,
    # ↓ Conv arguments
    kernel_size=3,
    conv_norm='batch',
    conv_act='relu',
    use_residual=False,
    # ↓ Other arguments
    show_summary=True,
    output_act='auto',
):
    '''
    input_shape: tuple, (height, width, channel); note this is 2D segmentation
    class_nums: int, number of output channels
    patch_size: int, size of the image patches
    projection_dim: int, projection dimension inside the ViT
    num_heads: int, number of attention heads
    transformer_units: list, hidden units of the ViT's MLP block (must be a list)
    transformer_layers: int, number of stacked Transformer layers
    extract_layers: list, which ViT layers feed the "skip connections"; default [3, 6, 9, 12]
    dropout_rate: float, dropout rate of the ViT part
    kernel_size: int, convolution kernel size
    conv_norm: str, normalization for the conv layers: 'batch', 'layer' or 'group'
    conv_act: str, activation for the conv layers
    use_residual: bool, whether to use residual connections
    show_summary: bool, whether to print the model summary
    output_act: str, output activation; 'auto' chooses based on class_nums,
                or specify 'softmax' / 'sigmoid' yourself
    '''
    z4_de_filts = 512
    z3_de_filts_list = [512]
    z2_de_filts_list = [512, 256]
    z1_de_filts_list = [512, 256, 128]
    z34_conv_filts = 512
    z23_conv_filts = 256
    z12_conv_filts = 128
    z01_conv_filts = 64

    if output_act == 'auto':
        output_act = 'sigmoid' if class_nums == 1 else 'softmax'
    assert input_shape[0] == input_shape[1] and input_shape[0] % patch_size == 0
    num_patches = (input_shape[0] * input_shape[1]) // (patch_size ** 2)

    inputs = Input(input_shape)
    z0 = inputs
    z1, z2, z3, z4 = create_vit(z0, patch_size, num_patches, projection_dim, num_heads,
                                transformer_units, transformer_layers, dropout_rate,
                                extract_layers)
    # [B, num_patches, projection_dim] -> [B, H/patch_size, W/patch_size, projection_dim]
    grid_shape = (-1, input_shape[0] // patch_size, input_shape[1] // patch_size, projection_dim)
    z1 = tf.reshape(z1, grid_shape)
    z2 = tf.reshape(z2, grid_shape)
    z3 = tf.reshape(z3, grid_shape)
    z4 = tf.reshape(z4, grid_shape)

    z4 = Conv2DTranspose(z4_de_filts, 2, (2, 2), kernel_initializer='he_normal')(z4)
    z3 = deconv_conv_block(z3, z3_de_filts_list, kernel_size, conv_norm, conv_act)
    z3 = concatenate([z3, z4])
    z3 = conv_deconv_block(z3, z34_conv_filts, kernel_size, conv_norm, use_residual, conv_act)
    z2 = deconv_conv_block(z2, z2_de_filts_list, kernel_size, conv_norm, conv_act)
    z2 = concatenate([z2, z3])
    z2 = conv_deconv_block(z2, z23_conv_filts, kernel_size, conv_norm, use_residual, conv_act)
    z1 = deconv_conv_block(z1, z1_de_filts_list, kernel_size, conv_norm, conv_act)
    z1 = concatenate([z1, z2])
    z1 = conv_deconv_block(z1, z12_conv_filts, kernel_size, conv_norm, use_residual, conv_act)
    z0 = conv2d_block(z0, z01_conv_filts, kernel_size, conv_norm, use_residual, conv_act)
    z0 = concatenate([z0, z1])
    z0 = conv2d_block(z0, z01_conv_filts, kernel_size, conv_norm, use_residual, conv_act)
    outputs = Conv2D(class_nums, 1, activation=output_act)(z0)

    model = Model(inputs=inputs, outputs=outputs)
    if show_summary:
        model.summary()
    return model
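Before running the full model, it is worth convincing yourself that the spatial sizes at each concatenate() actually line up. The following is pure shape arithmetic (no TensorFlow required; `trace_decoder` is my own helper) following the decoder wiring above with the default input size and patch size:

```python
def trace_decoder(input_size=256, patch_size=16):
    """Walk the spatial resolutions through the decoder in build_model."""
    grid = input_size // patch_size          # ViT token grid: 16
    z4 = grid * 2                            # Conv2DTranspose, stride 2 -> 32
    z3 = grid * 2                            # deconv_conv_block, 1 deconv -> 32
    assert z3 == z4                          # concatenate([z3, z4])
    z3 *= 2                                  # conv_deconv_block ends in a deconv -> 64
    z2 = grid * 2 ** 2                       # deconv_conv_block, 2 deconvs -> 64
    assert z2 == z3                          # concatenate([z2, z3])
    z2 *= 2                                  # -> 128
    z1 = grid * 2 ** 3                       # deconv_conv_block, 3 deconvs -> 128
    assert z1 == z2                          # concatenate([z1, z2])
    z1 *= 2                                  # -> 256
    assert z1 == input_size                  # concatenate([z0, z1]) at full resolution
    return z1

print(trace_decoder())  # 256
```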

5. A quick test

If all of the code above is placed in a single Python script, you can append the following and run the script to try building the network:

if __name__ == '__main__':
    x = np.random.uniform(size=(1, 256, 256, 3))
    model = build_model(
        # ↓ Base arguments
        input_shape=(256, 256, 3),
        class_nums=5,
        # ↓ ViT arguments
        patch_size=16,
        projection_dim=768,
        num_heads=12,
        transformer_units=[2048, 768],
        transformer_layers=12,
        extract_layers=[3, 6, 9, 12],
        dropout_rate=0.1,
        # ↓ Conv arguments
        kernel_size=3,
        conv_norm='batch',
        conv_act='relu',
        use_residual=False,
        # ↓ Other arguments
        show_summary=True,
        output_act='auto',
    )
    y = model(x)
    print(x.shape, y.shape)  # (1, 256, 256, 3) (1, 256, 256, 5)

Sigh.
