文章目录

  • 准备工作
  • BasicBlock块
  • ResNet-18、34网络结构
  • 完整代码:
  • 小总结

准备工作

昨天把论文读完了,CNN基础论文 精读+复现---- ResNet(一)
,今天用pytorch复现一下。

之前论文中提到过ResNet有很多种,这里复现一下ResNet-18和ResNet34吧,这俩基本一样。

这两种残差块,左边是 18 和34层的,50,101,152用右边的残差快。

ResNet-18,只需要左边的残差块,这俩残差块都实现一下,整体网络实现ResNet-18。

BasicBlock块

按照上面左边的图, 结构很清晰: 卷积 -> BN -> Relu -> 卷积 -> BN。
这几层写出来先:

nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(outchannel),
nn.ReLU(inplace=True),
nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(outchannel)

然后定义一下右路的恒等映射,这个其实就是个空就行了,里面什么层都不放。

self.shortcut = nn.Sequential()

之后还有个很重要的东西,就是之前说过的 残差F(X)与自身输入x维度必须一致。

这里直接通过1 * 1 卷积核进行卷积升降维就ok,也记得要加BN。

if stride != 1 or inchannel != outchannel:#shortcut,这里为了跟2个卷积层的维度结构一致。self.shortcut = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(outchannel))

最后在汇总的时候是 两路相加然后再Relu:

out = self.left(x)
out = out + self.shortcut(x)
out = nn.relu(out)

放到一起就有了BasicBlock块:


import torch
import torch.nn as nn
#残差块ResBlock
class ResBlock(nn.Module):def __init__(self, inchannel, outchannel, stride=1):super(ResBlock, self).__init__()self.left = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),nn.BatchNorm2d(outchannel),nn.ReLU(inplace=True),nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),nn.BatchNorm2d(outchannel))self.shortcut = nn.Sequential()if stride != 1 or inchannel != outchannel:#shortcut,这里为了跟2个卷积层的维度结构一致。self.shortcut = nn.Sequential(nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(outchannel))def forward(self, x):out = self.left(x)out = out + self.shortcut(x)out = nn.relu(out)return out

另外一个 BottleNeck块 就不写了,跟这个差不多 就是都了一层而已。

ResNet-18、34网络结构

根据论文中的这张图:

ResNet-18和34 一共6个部分: 开头的卷积层,然后中间4块残差块,最后一个全连接层。

先将各层堆叠起来:

     # 一开始的卷积+池化层
self.pre = nn.Sequential(nn.Conv2d(3, 64, 7, 2, 3, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(3, 2, 1))#各层的残差块
self.layer1 = self._make_layer(64, 64, blocks[0])
self.layer2 = self._make_layer(64, 128, blocks[1], stride=2)
self.layer3 = self._make_layer(128, 256, blocks[2], stride=2)
self.layer4 = self._make_layer(256, 512, blocks[3], stride=2)# 最后的全连接层
self.fc = nn.Linear(512, num_classes)

将重复的残差块放到一起去:

def _make_layer(self, inchannel, outchannel, block_num, stride=1):# 重复的残差块shortcut = nn.Sequential(nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),nn.BatchNorm2d(outchannel),nn.ReLU())layers = []layers.append(ResidualBlock(inchannel, outchannel, stride, shortcut))for i in range(1, block_num):layers.append(ResidualBlock(outchannel, outchannel))return nn.Sequential(*layers)

改写一下上面实现的BasicBlock残差块让他适应 18和34层的ResNet。

class ResidualBlock(nn.Module):def __init__(self, inchannel, outchannel, stride=1, shortcut=None):super(ResidualBlock, self).__init__()self.left = nn.Sequential(nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),nn.BatchNorm2d(outchannel),nn.ReLU(inplace=True),nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),nn.BatchNorm2d(outchannel))self.right = shortcutdef forward(self, x):out = self.left(x)residual = x if self.right is None else self.right(x)out += residualreturn F.relu(out)

上面这段代码没事好说的,就是稍微改动了一下 最开始实现的残差块的参数。

整体forward一下:

def forward(self, x):x = self.pre(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = F.avg_pool2d(x, 7)x = x.view(x.size(0), -1)return self.fc(x)

完整代码:

将上面的汇总到一起看一下完整的 ResNet-18、34代码。
这里没有加数据集。

torchsummary 也是pytorch里的可视化方法。

import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torchsummary import summaryclass ResidualBlock(nn.Module):"""实现子module: Residual Block"""def __init__(self, inchannel, outchannel, stride=1, shortcut=None):super(ResidualBlock, self).__init__()self.left = nn.Sequential(nn.Conv2d(inchannel, outchannel, 3, stride, 1, bias=False),nn.BatchNorm2d(outchannel),nn.ReLU(inplace=True),nn.Conv2d(outchannel, outchannel, 3, 1, 1, bias=False),nn.BatchNorm2d(outchannel))self.right = shortcutdef forward(self, x):out = self.left(x)residual = x if self.right is None else self.right(x)out += residualreturn F.relu(out)class ResNet(nn.Module):"""实现主module:ResNet34ResNet34包含多个layer,每个layer又包含多个Residual block用子module来实现Residual block,用_make_layer函数来实现layer"""def __init__(self, blocks, num_classes=1000):super(ResNet, self).__init__()self.model_name = 'resnet34'# 前几层: 图像转换self.pre = nn.Sequential(nn.Conv2d(3, 64, 7, 2, 3, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(3, 2, 1))# 重复的layer,分别有3,4,6,3个residual blockself.layer1 = self._make_layer(64, 64, blocks[0])self.layer2 = self._make_layer(64, 128, blocks[1], stride=2)self.layer3 = self._make_layer(128, 256, blocks[2], stride=2)self.layer4 = self._make_layer(256, 512, blocks[3], stride=2)# 分类用的全连接self.fc = nn.Linear(512, num_classes)def _make_layer(self, inchannel, outchannel, block_num, stride=1):"""构建layer,包含多个residual block"""shortcut = nn.Sequential(nn.Conv2d(inchannel, outchannel, 1, stride, bias=False),nn.BatchNorm2d(outchannel),nn.ReLU())layers = []layers.append(ResidualBlock(inchannel, outchannel, stride, shortcut))for i in range(1, block_num):layers.append(ResidualBlock(outchannel, outchannel))return nn.Sequential(*layers)def forward(self, x):x = self.pre(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = F.avg_pool2d(x, 7)x = x.view(x.size(0), -1)return self.fc(x)def ResNet18():return ResNet([2, 2, 2, 2])def ResNet34():return ResNet([3, 4, 6, 3])if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = ResNet34()model.to(device)summary(model, (3, 224, 224))

输出各层的信息:

小总结

其实使用的话完全没必要自己写,可以直接 用API调用ResNet的预训练模型,然后修改最后的全连接层做迁移学习就行了,像下面这样。

import torchvision
model = torchvision.models.resnet18(pretrained=True)

总的来说ResNet代码复现起来非常潦草,可以看到我都不像之前那样加入数据集和可视化结果了,代码写一半,感觉自己太菜了,差好多东西, 我得继续去学习了,后面论文先不看了。完整代码我放到Github上了一份:https://github.com/shitbro6/paper

CNN基础论文 精读+复现---- ResNet(二)相关推荐

  1. CNN基础论文 精读+复现----LeNet5 (二)

    文章目录 复现准备 数据部分 搭建网络结构 C1层: S2层: C3层: S4层: C5层: F6层: Output层: 损失函数与优化器: 复现准备 论文开头的一些概念和思想已经分析完.没看过的可以 ...

  2. CNN基础论文 精读+复现----VGG(一)

    文章目录 前言 第1页 第2-3页 第四页 第五页 前言 原文Github地址:https://github.com/shitbro6/paper/blob/main/VGG.pdf 原文arxiv地 ...

  3. CNN基础论文 精读+复现----GoogleNet InceptionV1 (一)

    文章目录 前言 第1页 摘要与引言 第2页 文献综述 第3-4页 第4-5页 inception模块细节 第5-7页 GoogLeNet 第8页 训练细节 第8-10页 ILSVRC 2014 inc ...

  4. 【推荐系统论文精读系列】(二)--Factorization Machines

    文章目录 一.摘要 二.介绍 三.稀疏性下预测 四.分解机(FM) A. Factorization Machine Model B. Factorization Machines as Predic ...

  5. 进阶必备:CNN经典论文代码复现 | 附下载链接

    经常会看到类似的广告<面试算法岗,你被要求复现论文了吗?>不好意思,我真的被问过这个问题.当然也不是所有面试官都会问,究其原因,其实也很好理解.企业肯定是希望自己的产品是有竞争力,有卖点的 ...

  6. 李沐论文精读: ResNet 《Deep Residual Learning for Image Recognition》 by Kaiming He

    目录 1 摘要 主要内容 主要图表 2 导论 2.1为什么提出残差结构 2.2 实验验证 3 实验部分 3.1 不同配置的ResNet结构 3.2 残差结构效果对比 3.3 残差结构中,输入输出维度不 ...

  7. 【论文精读】resnet精读

    跟李沐学AI的b站视频视频-论文精读笔记第二期 包含resnet论文精读第一遍和论文精读第二遍

  8. 【推荐系统论文精读系列】(八)--Deep Crossing:Web-Scale Modeling without Manually Crafted Combinatorial Features

    文章目录 一.摘要 二.介绍 三.相关工作 四.搜索广告 五.特征表示 5.1 独立特征 5.2 组合特征 六.模型架构 6.1 Embedding层 6.2 Stacking层 6.3 Residu ...

  9. 【推荐系统论文精读系列】(五)--Neural Collaborative Filtering

    文章目录 一.摘要 二.介绍 三.准备知识 3.1 从隐式数据中进行学习 3.2 矩阵分解 四.神经协同过滤 4.1 总体框架 4.1.1 学习NCF 4.2 广义矩阵分解(GMF) 4.3 多层感知 ...

最新文章

  1. c语言表达逻辑量的方法,c语言中用什么表示逻辑量为真
  2. iframe嵌套网页
  3. AttributeError:module tensorflow no attribute app解决办法
  4. 【每日算法】基数排序算法
  5. unity ppr_智能自动PPR更改事件策略
  6. Golang 实现tcp转发代理
  7. 海康相机SDK+halcon17(64位)+MFC+VS(64位)联合开发遇到的问题(在使用GenImage3Extern将RGB数据转换为halcon图像时出现异常情况处理)
  8. gcc离线安装 ubuntu 不用编译_「ubuntu安装gcc」ubuntu18.04安装gcc详细步骤(附问题集) - seo实验室...
  9. HDU杭电操作系统实验报告-操作系统课程设计-咸鱼的自留地
  10. 用 js判断 一个数是否是素数(质数)_js 基础算法题(二)
  11. 真解决EasyUi的 select 使用 class=“easyui-combobox“ 样式绑定onSelect/onChange事件
  12. 常见的网络摄像机方案
  13. 块截断编码图像压缩技术
  14. [原创] Python3.6+request+beautiful 半次元Top100 爬虫实战,将小姐姐的cos美图获得
  15. PTA L1-049 天梯赛座位分配(20分)(python)
  16. PX90---Lags Backs
  17. HCNP——RIPv1和RIPv2概况
  18. Hacked by 1BYTE
  19. scratch3.0加载自己的作品最新版
  20. JavaEE 之 Mybatis

热门文章

  1. 雪淇MM最经典的10句话
  2. SYN480R 解码
  3. 电脑上的软件和硬件怎么区分?
  4. ASP.NET印刷行业印务管理系统,源码免费分享
  5. linux xenserver教程,XenServer 6.5安装图文教程
  6. 一篇搞懂ddt数据驱动测试
  7. C语言程序设计-现代方法 第二版 第3.2.3小节 分数相加
  8. leetcode-1109. 航班预订统计(C++|差分)
  9. BroadLink智能遥控器
  10. 用单片机解码红外遥控器