Pytorch搭建FCN网络
Pytorch搭建FCN网络
- 前言
- 原理
- 代码实现
前言
FCN 全卷积网络,用卷积层替代CNN的全连接层,最后通过转置卷积层得到一个和输入尺寸一致的预测结果:
原理
为了得到更好的分割结果,论文中提到了几种网络结构FCN-32s、FCN-16s、FCN-8s,如图所示:
特征提取骨干网络是一个VGG风格的网络,由多个vgg_block组成,每个vgg_block重复应用数个卷积层+ReLU层,再加上一个池化层。池化层会使宽高减半,也就是说在pool1、pool2、pool3、pool4、pool5之后,特征图的分辨率会分别变为原图像的12\frac {1} {2}21、14\frac {1} {4}41、18\frac {1} {8}81、116\frac {1} {16}161、132\frac {1} {32}321。然后,FCN将VGG原来的两个全连接层替换成卷积层conv6-7,此时得到的结果仍然为原图的132\frac {1} {32}321。
对于FCN-32s,将conv7通过转置卷积上采样32倍得到;
对于FCN-16s,将pool4与conv7上采样2倍得到的特征图进行相加融合,此时得到的特征图尺寸为原图的116\frac {1} {16}161,再上采样16倍得到。
对于FCN-8s,将pool3、pool4上采样2倍得到的特征图、conv7上采样4倍得到的特征图进行相加融合,此时得到的特征图尺寸为原图的18\frac {1} {8}81,再上采样8倍得到。
为什么要采取FCN-16s、FCN-8s这两种融合方式呢?这是因为CNN的浅层卷积学习到的是局部特征(边缘、纹理),深层卷积学习语义特征,提高分类性能。而语义分割对细节特征精度要求较高,对于FCN-32s,直接上采样无法恢复细节信息,因此自然想到将浅层网络学习到的特征与深层特征相融合,有了FCN-16s、FCN-8s。实验证明,确实是FCN-8s细节特征最丰富,分割效果最好。
代码实现
下面用Pytorch搭建FCN-8s的网络结构。首先是骨干网络,采用VGG16,注意将VGG原来的两个全连接层替换成卷积层conv6和conv7,并保存pool3、pool4、conv7的结果:
from torch import nn
from torchvision.models import vgg16def vgg_block(num_convs, in_channels, out_channels):blk = []for i in range(num_convs):if i == 0:blk.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))else:blk.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))blk.append(nn.ReLU(inplace=True))blk.append(nn.MaxPool2d(kernel_size=2, stride=2)) # 宽高减半return blkclass VGG16(nn.Module):def __init__(self, pretrained=True):super(VGG16, self).__init__()features = []features.extend(vgg_block(2, 3, 64))features.extend(vgg_block(2, 64, 128))features.extend(vgg_block(3, 128, 256))self.index_pool3 = len(features)features.extend(vgg_block(3, 256, 512))self.index_pool4 = len(features)features.extend(vgg_block(3, 512, 512))self.features = nn.Sequential(*features)self.conv6 = nn.Conv2d(512, 4096, kernel_size=1)self.relu = nn.ReLU(inplace=True)self.conv7 = nn.Conv2d(4096, 4096, kernel_size=1)# load pretrained params from torchvision.models.vgg16(pretrained=True)if pretrained:pretrained_model = vgg16(pretrained=pretrained)pretrained_params = pretrained_model.state_dict()keys = list(pretrained_params.keys())new_dict = {}for index, key in enumerate(self.features.state_dict().keys()):new_dict[key] = pretrained_params[keys[index]]self.features.load_state_dict(new_dict)def forward(self, x):pool3 = self.features[:self.index_pool3](x) # 1/8pool4 = self.features[self.index_pool3:self.index_pool4](pool3) # 1/16pool5 = self.features[self.index_pool4:](pool4) # 1/32conv6 = self.relu(self.conv6(pool5)) # 1/32conv7 = self.relu(self.conv7(conv6)) # 1/32return pool3, pool4, conv7
然后是FCN-8s,将pool3、pool4上采样2倍得到的特征图、conv7上采样4倍得到的特征图进行相加融合,再上采样8倍。
class FCN(nn.Module):def __init__(self, num_classes, backbone='vgg'):super(FCN, self).__init__()if backbone == 'vgg':self.features = VGG16()self.scores1 = nn.Conv2d(4096, num_classes, kernel_size=1)self.relu = nn.ReLU(inplace=True)self.scores2 = nn.Conv2d(512, num_classes, kernel_size=1)self.scores3 = nn.Conv2d(256, num_classes, kernel_size=1)self.upsample_8x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=8, stride=8)self.upsample_4x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=4, stride=4)self.upsample_2x = nn.ConvTranspose2d(num_classes, num_classes, kernel_size=2, stride=2)def forward(self, x):pool3, pool4, conv7 = self.features(x)conv7 = self.relu(self.scores1(conv7)) # 1×1卷积将通道数映射为类别数pool4 = self.relu(self.scores2(pool4)) # 1×1卷积将通道数映射为类别数pool3 = self.relu(self.scores3(pool3)) # 1×1卷积将通道数映射为类别数s = pool3 + self.upsample_2x(pool4) + self.upsample_4x(conv7) # 相加融合out_8s = self.upsample_8x(s) # 8倍上采样return out_8s
打印一下网络结构:
net = FCN(num_classes=21)
from torchsummary import summary
summary(net.cuda(), (3, 224, 224))
----------------------------------------------------------------Layer (type) Output Shape Param #
================================================================Conv2d-1 [-1, 64, 224, 224] 1,792ReLU-2 [-1, 64, 224, 224] 0Conv2d-3 [-1, 64, 224, 224] 36,928ReLU-4 [-1, 64, 224, 224] 0MaxPool2d-5 [-1, 64, 112, 112] 0Conv2d-6 [-1, 128, 112, 112] 73,856ReLU-7 [-1, 128, 112, 112] 0Conv2d-8 [-1, 128, 112, 112] 147,584ReLU-9 [-1, 128, 112, 112] 0MaxPool2d-10 [-1, 128, 56, 56] 0Conv2d-11 [-1, 256, 56, 56] 295,168ReLU-12 [-1, 256, 56, 56] 0Conv2d-13 [-1, 256, 56, 56] 590,080ReLU-14 [-1, 256, 56, 56] 0Conv2d-15 [-1, 256, 56, 56] 590,080ReLU-16 [-1, 256, 56, 56] 0MaxPool2d-17 [-1, 256, 28, 28] 0Conv2d-18 [-1, 512, 28, 28] 1,180,160ReLU-19 [-1, 512, 28, 28] 0Conv2d-20 [-1, 512, 28, 28] 2,359,808ReLU-21 [-1, 512, 28, 28] 0Conv2d-22 [-1, 512, 28, 28] 2,359,808ReLU-23 [-1, 512, 28, 28] 0MaxPool2d-24 [-1, 512, 14, 14] 0Conv2d-25 [-1, 512, 14, 14] 2,359,808ReLU-26 [-1, 512, 14, 14] 0Conv2d-27 [-1, 512, 14, 14] 2,359,808ReLU-28 [-1, 512, 14, 14] 0Conv2d-29 [-1, 512, 14, 14] 2,359,808ReLU-30 [-1, 512, 14, 14] 0MaxPool2d-31 [-1, 512, 7, 7] 0Conv2d-32 [-1, 4096, 7, 7] 2,101,248ReLU-33 [-1, 4096, 7, 7] 0Conv2d-34 [-1, 4096, 7, 7] 16,781,312ReLU-35 [-1, 4096, 7, 7] 0VGG16-36 [[-1, 256, 28, 28], [-1, 512, 14, 14], [-1, 4096, 7, 7]] 0Conv2d-37 [-1, 21, 7, 7] 86,037ReLU-38 [-1, 21, 7, 7] 0Conv2d-39 [-1, 21, 14, 14] 10,773ReLU-40 [-1, 21, 14, 14] 0Conv2d-41 [-1, 21, 28, 28] 5,397ReLU-42 [-1, 21, 28, 28] 0ConvTranspose2d-43 [-1, 21, 28, 28] 1,785ConvTranspose2d-44 [-1, 21, 28, 28] 7,077ConvTranspose2d-45 [-1, 21, 224, 224] 28,245
================================================================
Total params: 33,736,562
Trainable params: 33,736,562
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 233.14
Params size (MB): 128.69
Estimated Total Size (MB): 362.41
----------------------------------------------------------------
Pytorch搭建FCN网络相关推荐
- 关于使用Pytorch搭建FCN网络的笔记
需要准备的第三方库: numpy.os.torch.cv2 一.Dataload.py的编写 该部分的主要工作是完成数据的预处理.训练集测试集的划分以及数据集的读取,即得到train_dataload ...
- 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记
使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...
- 实战:使用Pytorch搭建分类网络(肺结节假阳性剔除)
实战:使用Pytorch搭建分类网络(肺结节假阳性剔除) 阅前可看: 实战:使用yolov3完成肺结节检测(Luna16数据集)及肺实质分割 其中的脚本资源getMat.py文件是对肺结节进行切割. ...
- Pytorch复现FCN网络
Pytorch复现FCN网络详解(可复现) 1.环境配置 windows10,pytorch=1.3,python=3.6 参考博客:https://github.com/wkentaro/pytor ...
- Pytorch搭建LeNet5网络
本讲目标: 介绍Pytorch搭建LeNet5网络的流程. Pytorch八股法搭建LeNet5网络 1.LeNet5网络介绍 2.Pytorch搭建LeNet5网络 2.1搭建LeNet网络 2 ...
- 使用PyTorch搭建ResNet50网络
ResNet18的搭建请移步:使用PyTorch搭建ResNet18网络并使用CIFAR10数据集训练测试 ResNet34的搭建请移步:使用PyTorch搭建ResNet34网络 ResNet101 ...
- pytorch 搭建 VGG 网络
目录 1. VGG 网络介绍 2. 搭建VGG 网络 3. code 1. VGG 网络介绍 VGG16 的网络结构如图: VGG 网络是由卷积层和池化层构成基础的CNN 它的CONV卷积层的参数全部 ...
- pytorch搭建孪生网络比较人脸相似性
参考文献: 神经网络学习小记录52--Pytorch搭建孪生神经网络(Siamese network)比较图片相似性_Bubbliiiing的博客-CSDN博客_神经网络图片相似性 Python - ...
- 基于pytorch的FCN网络简单实现
参考知乎专栏实现FCN网络https://zhuanlan.zhihu.com/p/32506912 import torch from torch import nn import torch.nn ...
最新文章
- node sqlite 插入数据_安卓手机中的应用数据都保存在哪些文件中?
- Citrix通用打印服务器配置
- 【组队学习】【24期】Docker教程
- pip._vendor.urllib3.exceptions.ReadTimeoutError: HTTPSConnectionPool(host='f 的解决办法
- Python Tkinter 常用控件空间位置摆放
- 【2021年】通过vue-cli创建electron项目
- 栏目图片 栏目描述_昕街拍|长期福利栏目来啦,秀街拍赢礼品!
- [翻译]asp.net ajax xml-script教程(二)
- abp vnext中swagger使用小结
- 寻找安全漏洞?谨慎为之
- PBR理论基础3.1:基于图像的光照(下)
- java redis源码分析,慢谈 Redis 实现分布式锁 以及 Redisson 源码解析
- 设计者模式之GOF23命令模式
- java安卓游戏源码下载_77个安卓游戏 android源码
- 推荐几款绿色无广告良心软件
- 8021什么意思_无线网络标准IEEE802.11n是什么意思
- B站视频下载扩展工具
- Django项目部署至华为云服务器
- 银行大数据应用场景:客户画像如何做?
- 1011: 【基础】空心六边形
热门文章
- ScarCruft不断进化,引入蓝牙收割机
- [开源]Fre 发布 0.5 版本,更新 diff-patch 和 proxy 方案
- 一篇看懂C#中的Task任务_初级篇
- 四、正则表达式:匹配开头与结尾
- 双人游戏, 双人冒险小游戏 ,双人闯关小游戏
- Boos直聘行业数据获取、json解析
- Windows找不到文件‘gpedit.msc’。请确认文件名是否正确后,再试一次。
- “经济型”Win8.1 4G平板电脑
- Win10系统图标显示不正常解决方法
- user mapping not found for