Pytorch搭建FCN网络

  • 前言
    • 原理
    • 代码实现

前言

FCN 全卷积网络,用卷积层替代CNN的全连接层,最后通过转置卷积层得到一个和输入尺寸一致的预测结果:

原理

为了得到更好的分割结果,论文中提到了几种网络结构FCN-32s、FCN-16s、FCN-8s,如图所示:

特征提取骨干网络是一个VGG风格的网络,由多个vgg_block组成,每个vgg_block重复应用数个卷积层+ReLU层,再加上一个池化层。池化层会使宽高减半,也就是说在pool1、pool2、pool3、pool4、pool5之后,特征图的分辨率会分别变为原图像的12\frac {1} {2}21​、14\frac {1} {4}41​、18\frac {1} {8}81​、116\frac {1} {16}161​、132\frac {1} {32}321​。然后,FCN将VGG原来的两个全连接层替换成卷积层conv6-7,此时得到的结果仍然为原图的132\frac {1} {32}321​。
对于FCN-32s,将conv7通过转置卷积上采样32倍得到;
对于FCN-16s,将pool4与conv7上采样2倍得到的特征图进行相加融合,此时得到的特征图尺寸为原图的116\frac {1} {16}161​,再上采样16倍得到。
对于FCN-8s,将pool3、pool4上采样2倍得到的特征图、conv7上采样4倍得到的特征图进行相加融合,此时得到的特征图尺寸为原图的18\frac {1} {8}81​,再上采样8倍得到。

为什么要采取FCN-16s、FCN-8s这两种融合方式呢?这是因为CNN的浅层卷积学习到的是局部特征(边缘、纹理),深层卷积学习语义特征,提高分类性能。而语义分割对细节特征精度要求较高,对于FCN-32s,直接上采样无法恢复细节信息,因此自然想到将浅层网络学习到的特征与深层特征相融合,有了FCN-16s、FCN-8s。实验证明,确实是FCN-8s细节特征最丰富,分割效果最好。

代码实现

下面用Pytorch搭建FCN-8s的网络结构。首先是骨干网络,采用VGG16,注意将VGG原来的两个全连接层替换成卷积层conv6和conv7,并保存pool3、pool4、conv7的结果:

from torch import nn
from torchvision.models import vgg16def vgg_block(num_convs, in_channels, out_channels):blk = []for i in range(num_convs):if i == 0:blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))else:blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))blk.append(nn.ReLU(inplace=True))blk.append(nn.MaxPool2d(kernel_size=2, stride=2))  # 宽高减半return blkclass VGG16(nn.Module):def __init__(self, pretrained=True):super(VGG16, self).__init__()features = []features.extend(vgg_block(2, 3, 64))features.extend(vgg_block(2, 64, 128))features.extend(vgg_block(3, 128, 256))self.index_pool3 = len(features)features.extend(vgg_block(3, 256, 512))self.index_pool4 = len(features)features.extend(vgg_block(3, 512, 512))self.features = nn.Sequential(*features)self.conv6 = nn.Conv2d(512, 4096, kernel_size=1)self.relu = nn.ReLU(inplace=True)self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)# load pretrained params from torchvision.models.vgg16(pretrained=True)if pretrained:pretrained_model = vgg16(pretrained=pretrained)pretrained_params = pretrained_model.state_dict()keys = list(pretrained_params.keys())new_dict = {}for index, key in enumerate(self.features.state_dict().keys()):new_dict[key] = pretrained_params[keys[index]]self.features.load_state_dict(new_dict)def forward(self, x):pool3 = self.features[:self.index_pool3](x)      # 1/8pool4 = self.features[self.index_pool3:self.index_pool4](pool3)  # 1/16pool5 = self.features[self.index_pool4:](pool4)  # 1/32conv6 = self.relu(self.conv6(pool5))  # 1/32conv7 = self.relu(self.conv7(conv6))  # 1/32return pool3, pool4, conv7

然后是FCN-8s,将pool3、pool4上采样2倍得到的特征图、conv7上采样4倍得到的特征图进行相加融合,再上采样8倍。

class FCN(nn.Module):def __init__(self, num_classes, backbone='vgg'):super(FCN, self).__init__()if backbone == 'vgg':self.features = VGG16()self.scores1 = nn.Conv2d(4096, num_classes, kernel_size=1)self.relu = nn.ReLU(inplace=True)self.scores2 = nn.Conv2d(512, num_classes, kernel_size=1)self.scores3 = nn.Conv2d(256, num_classes, kernel_size=1)self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=8, stride=8)self.upsample_4x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=4)self.upsample_2x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=2, stride=2)def forward(self, x):pool3, pool4, conv7 = self.features(x)conv7 = self.relu(self.scores1(conv7))  # 1×1卷积将通道数映射为类别数pool4 = self.relu(self.scores2(pool4))  # 1×1卷积将通道数映射为类别数pool3 = self.relu(self.scores3(pool3))  # 1×1卷积将通道数映射为类别数s = pool3 + self.upsample_2x(pool4) + self.upsample_4x(conv7)  # 相加融合out_8s = self.upsample_8x(s)  # 8倍上采样return out_8s

打印一下网络结构:

net = FCN(num_classes=21)
from torchsummary import summary
summary(net.cuda(), (3, 224, 224))
----------------------------------------------------------------Layer (type)               Output Shape         Param #
================================================================Conv2d-1         [-1, 64, 224, 224]           1,792ReLU-2         [-1, 64, 224, 224]               0Conv2d-3         [-1, 64, 224, 224]          36,928ReLU-4         [-1, 64, 224, 224]               0MaxPool2d-5         [-1, 64, 112, 112]               0Conv2d-6        [-1, 128, 112, 112]          73,856ReLU-7        [-1, 128, 112, 112]               0Conv2d-8        [-1, 128, 112, 112]         147,584ReLU-9        [-1, 128, 112, 112]               0MaxPool2d-10          [-1, 128, 56, 56]               0Conv2d-11          [-1, 256, 56, 56]         295,168ReLU-12          [-1, 256, 56, 56]               0Conv2d-13          [-1, 256, 56, 56]         590,080ReLU-14          [-1, 256, 56, 56]               0Conv2d-15          [-1, 256, 56, 56]         590,080ReLU-16          [-1, 256, 56, 56]               0MaxPool2d-17          [-1, 256, 28, 28]               0Conv2d-18          [-1, 512, 28, 28]       1,180,160ReLU-19          [-1, 512, 28, 28]               0Conv2d-20          [-1, 512, 28, 28]       2,359,808ReLU-21          [-1, 512, 28, 28]               0Conv2d-22          [-1, 512, 28, 28]       2,359,808ReLU-23          [-1, 512, 28, 28]               0MaxPool2d-24          [-1, 512, 14, 14]               0Conv2d-25          [-1, 512, 14, 14]       2,359,808ReLU-26          [-1, 512, 14, 14]               0Conv2d-27          [-1, 512, 14, 14]       2,359,808ReLU-28          [-1, 512, 14, 14]               0Conv2d-29          [-1, 512, 14, 14]       2,359,808ReLU-30          [-1, 512, 14, 14]               0MaxPool2d-31            [-1, 512, 7, 7]               0Conv2d-32           [-1, 4096, 7, 7]       2,101,248ReLU-33           [-1, 4096, 7, 7]               0Conv2d-34           [-1, 4096, 7, 7]      16,781,312ReLU-35           [-1, 4096, 7, 7]               0VGG16-36  [[-1, 256, 28, 28], [-1, 512, 14, 14], [-1, 4096, 7, 7]]               0Conv2d-37             [-1, 21, 7, 7]          86,037ReLU-38             [-1, 21, 7, 7]               0Conv2d-39           [-1, 21, 14, 14]          10,773ReLU-40           [-1, 21, 14, 14]               0Conv2d-41           [-1, 21, 28, 28]           5,397ReLU-42           [-1, 21, 28, 28]               0ConvTranspose2d-43           [-1, 21, 28, 28]           1,785ConvTranspose2d-44           [-1, 21, 28, 28]           7,077ConvTranspose2d-45         [-1, 21, 224, 224]          28,245
================================================================
Total params: 33,736,562
Trainable params: 33,736,562
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 233.14
Params size (MB): 128.69
Estimated Total Size (MB): 362.41
----------------------------------------------------------------

Pytorch搭建FCN网络相关推荐

  1. 关于使用Pytorch搭建FCN网络的笔记

    需要准备的第三方库: numpy.os.torch.cv2 一.Dataload.py的编写 该部分的主要工作是完成数据的预处理.训练集测试集的划分以及数据集的读取,即得到train_dataload ...

  2. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  3. 实战:使用Pytorch搭建分类网络(肺结节假阳性剔除)

    实战:使用Pytorch搭建分类网络(肺结节假阳性剔除) 阅前可看: 实战:使用yolov3完成肺结节检测(Luna16数据集)及肺实质分割 其中的脚本资源getMat.py文件是对肺结节进行切割. ...

  4. Pytorch复现FCN网络

    Pytorch复现FCN网络详解(可复现) 1.环境配置 windows10,pytorch=1.3,python=3.6 参考博客:https://github.com/wkentaro/pytor ...

  5. Pytorch搭建LeNet5网络

    本讲目标:   介绍Pytorch搭建LeNet5网络的流程. Pytorch八股法搭建LeNet5网络 1.LeNet5网络介绍 2.Pytorch搭建LeNet5网络 2.1搭建LeNet网络 2 ...

  6. 使用PyTorch搭建ResNet50网络

    ResNet18的搭建请移步:使用PyTorch搭建ResNet18网络并使用CIFAR10数据集训练测试 ResNet34的搭建请移步:使用PyTorch搭建ResNet34网络 ResNet101 ...

  7. pytorch 搭建 VGG 网络

    目录 1. VGG 网络介绍 2. 搭建VGG 网络 3. code 1. VGG 网络介绍 VGG16 的网络结构如图: VGG 网络是由卷积层和池化层构成基础的CNN 它的CONV卷积层的参数全部 ...

  8. pytorch搭建孪生网络比较人脸相似性

    参考文献: 神经网络学习小记录52--Pytorch搭建孪生神经网络(Siamese network)比较图片相似性_Bubbliiiing的博客-CSDN博客_神经网络图片相似性 Python - ...

  9. 基于pytorch的FCN网络简单实现

    参考知乎专栏实现FCN网络https://zhuanlan.zhihu.com/p/32506912 import torch from torch import nn import torch.nn ...

最新文章

  1. node sqlite 插入数据_安卓手机中的应用数据都保存在哪些文件中?
  2. Citrix通用打印服务器配置
  3. 【组队学习】【24期】Docker教程
  4. pip._vendor.urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='f 的解决办法
  5. Python Tkinter 常用控件空间位置摆放
  6. 【2021年】通过vue-cli创建electron项目
  7. 栏目图片 栏目描述_昕街拍|长期福利栏目来啦,秀街拍赢礼品!
  8. [翻译]asp.net ajax xml-script教程(二)
  9. abp vnext中swagger使用小结
  10. 寻找安全漏洞?谨慎为之
  11. PBR理论基础3.1:基于图像的光照(下)
  12. java redis源码分析,慢谈 Redis 实现分布式锁 以及 Redisson 源码解析
  13. 设计者模式之GOF23命令模式
  14. java安卓游戏源码下载_77个安卓游戏 android源码
  15. 推荐几款绿色无广告良心软件
  16. 8021什么意思_无线网络标准IEEE802.11n是什么意思
  17. B站视频下载扩展工具
  18. Django项目部署至华为云服务器
  19. 银行大数据应用场景:客户画像如何做?
  20. 1011: 【基础】空心六边形

热门文章

  1. ScarCruft不断进化,引入蓝牙收割机
  2. [开源]Fre 发布 0.5 版本,更新 diff-patch 和 proxy 方案
  3. 一篇看懂C#中的Task任务_初级篇
  4. 四、正则表达式:匹配开头与结尾
  5. 双人游戏, 双人冒险小游戏 ,双人闯关小游戏
  6. Boos直聘行业数据获取、json解析
  7. Windows找不到文件‘gpedit.msc’。请确认文件名是否正确后,再试一次。
  8. “经济型”Win8.1 4G平板电脑
  9. Win10系统图标显示不正常解决方法
  10. user mapping not found for