对原始unet的改进主要有两方面,第一是卷积块的改进,第二是unet模型结构的改进。
对于卷积块的改进,可以把原来的卷积块换成残差块。
对于模型结构的改进可以在模型结构上多加U。
残差块有5种,resnet18、resnet34、resnet50、resnet101、resnet152,resnet后跟的数字表示卷积层数,前两种残差块类型是basic block,后三种是bottleneck block。

图1.图片来自[1]
如图1所示,左边是一个basic block,右边是一个bottleneck block。basic block 由两组通道数相同、kernel size都是33的卷积和一个绕过这两组卷积的short cut构成。bottleneck block 由kenel_size分别为11、33、11,前两个通道数相同、最后一个通道数是前两个通道数4倍的3组卷积和绕过这3组卷积的short cut构成。

图2.图片来自[1]
图2为5种残差块内部kernel size和通道数的设置。

unet结构的改进方法是unet嵌套,如图3所示:

图3.图片来自[2]
原来的unet下采样n次后上采样n次。unet++在原有的基础上,在第n-1次下采样后接着上采样n-1次,以此类推,直到第1次下采样后接着上采样1次。同时,每次新增加上采样后对应的skip也加上。
unet++主要有4种,如图4所示:

图4.图片来自[2]

resnet结合unet++L2的代码如下:


import torch
import torch.nn as nn
import torch.nn.functional as Fclass BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1):super().__init__()self.residual_function = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels * BasicBlock.expansion, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))self.shortcut = nn.Sequential()if stride != 1 or in_channels != BasicBlock.expansion * out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BasicBlock.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels * BasicBlock.expansion))def forward(self, x):return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))class BottleNeck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1):super().__init__()self.residual_function = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True),nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, bias=False),nn.BatchNorm2d(out_channels * BottleNeck.expansion),)self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels * BottleNeck.expansion:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, stride=stride, kernel_size=1, bias=False),nn.BatchNorm2d(out_channels * BottleNeck.expansion))def forward(self, x):return nn.ReLU(inplace=True)(self.residual_function(x) + self.shortcut(x))class ResNet(nn.Module):def __init__(self, in_chans, block, num_block, num_classes=100):super().__init__()self.block = blockself.in_channels = 64self.conv1 = nn.Sequential(nn.Conv2d(in_chans, 64, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True))self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.conv2_x = self._make_layer(block, 64, num_block[0], 1)self.conv3_x = self._make_layer(block, 128, num_block[1], 2)self.conv4_x = self._make_layer(block, 256, num_block[2], 2)self.conv5_x = self._make_layer(block, 512, num_block[3], 2)self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride):strides = [stride] + [1] * (num_blocks - 1)layers = []for stride in strides:layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channels * block.expansionreturn nn.Sequential(*layers)def forward(self, x):f1 = self.conv1(x)f2 = self.conv2_x(self.pool(f1))f3 = self.conv3_x(f2)f4 = self.conv4_x(f3)f5 = self.conv5_x(f4)output = self.avg_pool(f5)output = output.view(output.size(0), -1)output = self.fc(output)return f1,f2,f3,outputdef resnet18(in_chans):return ResNet(in_chans, BasicBlock, [2, 2, 2, 2])def resnet34(in_chans):return ResNet(in_chans, BasicBlock, [3, 4, 6, 3])def resnet50(in_chans):return ResNet(in_chans, BottleNeck, [3, 4, 6, 3])def resnet101(in_chans):return ResNet(in_chans, BottleNeck, [3, 4, 23, 3])def resnet152(in_chans):return ResNet(in_chans, BottleNeck, [3, 8, 36, 3])"""### ResNet_UNetpp"""class ConvBlock(nn.Module):def __init__(self, in_chans, out_chans, stride):super(ConvBlock, self).__init__()self.conv1 = nn.Conv2d(in_chans, out_chans, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(out_chans)self.relu1 = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_chans, out_chans, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(out_chans)self.relu2 = nn.ReLU(inplace=True)def forward(self, x):x = self.relu1(self.bn1(self.conv1(x)))out = self.relu2(self.bn2(self.conv2(x)))return outclass UpConvBlock(nn.Module):def __init__(self, in_chans, bridge_chans_list, out_chans):super(UpConvBlock, self).__init__()self.up = nn.ConvTranspose2d(in_chans, out_chans, kernel_size=2, stride=2)self.conv_block = BasicBlock(out_chans + sum(bridge_chans_list), out_chans, 1)def forward(self, x, bridge_list):x = self.up(x)x = torch.cat([x] + bridge_list, dim=1)out = self.conv_block(x)return outclass ResNet_UNetpp(nn.Module):def __init__(self, in_chans=1, n_classes=2, backbone=resnet18):super(ResNet_UNetpp, self).__init__()'''
兼容resnet18/34/50/101/152'''#backbone.block.expansion feat_chans = [64*backbone(in_chans).block.expansion, 128*backbone(in_chans).block.expansion, 256*backbone(in_chans).block.expansion]self.conv_x00 = backbone(in_chans).block(in_chans, feat_chans[0]//(backbone(in_chans).block.expansion), 1)self.conv_x10 = backbone(in_chans).block(feat_chans[0], feat_chans[1]//(backbone(in_chans).block.expansion), 2)self.conv_x20 = backbone(in_chans).block(feat_chans[1], feat_chans[2]//(backbone(in_chans).block.expansion), 2)self.conv_x01 = UpConvBlock(feat_chans[1], [feat_chans[0]], feat_chans[0])self.conv_x11 = UpConvBlock(feat_chans[2], [feat_chans[1]], feat_chans[1])self.conv_x02 = UpConvBlock(feat_chans[1], [feat_chans[0], feat_chans[0]], feat_chans[0])self.cls_conv_x01 = nn.Conv2d(feat_chans[0], 2, kernel_size=1)self.cls_conv_x02 = nn.Conv2d(feat_chans[0], 2, kernel_size=1)def forward(self, x):x00 = self.conv_x00(x)x10 = self.conv_x10(x00)x20 = self.conv_x20(x10)x01 = self.conv_x01(x10, [x00])x11 = self.conv_x11(x20, [x10])x02 = self.conv_x02(x11, [x00, x01])out01 = self.cls_conv_x01(x01)out02 = self.cls_conv_x02(x02)print('x00', x00.shape)print('x10', x10.shape)print('x20', x20.shape)print('x01', x01.shape)print('x11', x11.shape)print('x02', x02.shape)print('out01', out01.shape)print('out02', out02.shape)return out01, out02x = torch.randn((2, 1, 224, 224), dtype=torch.float32)
model = ResNet_UNetpp(in_chans=1, backbone=resnet50)
y1, y2 = model(x)

结果:

Refferences:
[1]He K, Zhang X, Ren S, et al. Deep residual learning for image recognition[C]//Proceedings of the IEEE conference on computer vision and pattern recognition. 2016: 770-778.
[2]Zhou Z, Siddiquee M M R, Tajbakhsh N, et al. Unet++: A nested u-net architecture for medical image segmentation[M]//Deep learning in medical image analysis and multimodal learning for clinical decision support. Springer, Cham, 2018: 3-11.

resnet_unetpp相关推荐

最新文章

  1. WPF指南之一(WPF的结构)
  2. [异常特工]android常见bug跟踪
  3. python pandas加速包
  4. Snmp扫描-snmpwalk、snmpcheck
  5. 转:微服务架构:BFF和网关是如何演化出来的?(这篇文章相当棒)
  6. java后台工具类-通过交易码获得方法名
  7. 汇编比较两个数大小_计算机是怎样跑起来的 -- 体验一次汇编过程
  8. idea中Terminal输入命令git log后如何退出
  9. STM32F405 标准库 SHT20温湿度传感器
  10. 大数据之Oozie——源码分析(一)程序入口
  11. android 用LruCache读取大图片并缓存(转)
  12. noip模拟赛 czy的后宫
  13. 安装python3.7的步骤_如何在Debian 9上安装Python 3.7?
  14. (转载)New poker 2总算放出新固件了!
  15. 乞丐的一句话,感动中国13亿人。
  16. 海南旅游自由行攻略怎么玩
  17. 吴恩达深度学习笔记——改善深层神经网络:超参数调整,正则化,最优化(Hyperparameter Tuning)
  18. 计算机系统建模_包图
  19. Unity导入模型贴贴图一面有贴图另一面透明的解决方法
  20. 关于PS新建(PS如何新建)

热门文章

  1. HTML表格自动排序
  2. C++三目运算符(简述)
  3. MYSQL 获取当前日期及日期格式,和常用时间转换函数
  4. [nginx代理配置][nginx proxy_pass][nginx从一台服务器代理到另外一台服务器,浏览器地址不改变]
  5. Java list.toArray()和list.toArray(T[] a)
  6. 任务栏中间的活动窗口图标不见了怎么办
  7. 基于工业路由器的智慧医疗远程监控系统
  8. 事件驱动架构(EDA/SEDA/DEDA/ESB/CQRS/EventSourcing)
  9. ajax异步请求案例
  10. 倒立摆系统分析及控制