Kaiming He的深度残差网络(ResNet)在深度学习的发展中起到了很重要的作用,ResNet不仅一举拿下了当年CV下多个比赛项目的冠军,更重要的是这一结构解决了训练极深网络时的梯度消失问题。

首先来看看ResNet的网络结构,这里选取的是ResNet的一个变种:ResNet34。ResNet的网络结构如图所示,可见除了最开始的卷积池化和最后的池化全连接之外,网络中有很多结构相似的单元,这些重复单元的共同点就是有个跨层直连的shortcut。ResNet中将一个跨层直连的单元称为Residual block,其结构如图所示,左边部分是普通的卷积网络结构,右边是直连,但如果输入和输出的通道数不一致,或其步长不为1,那么就需要有一个专门的单元将二者转成一致,使其可以相加。

另外我们可以发现Residual block的大小也是有规律的,在最开始的pool之后有连续的几个一模一样的Residual block单元,这些单元的通道数一样,在这里我们将这几个拥有多个Residual block单元的结构称之为layer,注意和之前讲的layer区分开来,这里的layer是几个层的集合。

考虑到Residual block和layer出现了多次,我们可以把它们实现为一个子Module或函数。这里我们将Residual block实现为一个子moduke,而将layer实现为一个函数。下面是实现代码,规律总结如下:

  • 对于模型中的重复部分,实现为子module或用函数生成相应的modulemake_layer
  • nn.Module和nn.Functional结合使用
  • 尽量使用nn.Seqential

from torch import  nn
import torch as t
from torch.nn import  functional as F
class ResidualBlock(nn.Module):'''实现子module: Residual Block'''def __init__(self, inchannel, outchannel, stride=1, shortcut=None):super(ResidualBlock, self).__init__()self.left = nn.Sequential(nn.Conv2d(inchannel,outchannel,3,stride, 1,bias=False),nn.BatchNorm2d(outchannel),nn.ReLU(inplace=True),nn.Conv2d(outchannel,outchannel,3,1,1,bias=False),nn.BatchNorm2d(outchannel) )self.right = shortcutdef forward(self, x):out = self.left(x)residual = x if self.right is None else self.right(x)out += residualreturn F.relu(out)class ResNet(nn.Module):'''实现主module:ResNet34ResNet34 包含多个layer,每个layer又包含多个residual block用子module来实现residual block,用_make_layer函数来实现layer'''def __init__(self, num_classes=1000):super(ResNet, self).__init__()# 前几层图像转换self.pre = nn.Sequential(nn.Conv2d(3, 64, 7, 2, 3, bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(3, 2, 1))# 重复的layer,分别有3,4,6,3个residual blockself.layer1 = self._make_layer( 64, 64, 3)self.layer2 = self._make_layer( 64, 128, 4, stride=2)self.layer3 = self._make_layer( 128, 256, 6, stride=2)self.layer4 = self._make_layer( 256, 512, 3, stride=2)#分类用的全连接self.fc = nn.Linear(512, num_classes)def _make_layer(self,  inchannel, outchannel, block_num, stride=1):'''构建layer,包含多个residual block'''shortcut = nn.Sequential(nn.Conv2d(inchannel,outchannel,1,stride, bias=False),nn.BatchNorm2d(outchannel))layers = []layers.append(ResidualBlock(inchannel, outchannel, stride, shortcut))for i in range(1, block_num):layers.append(ResidualBlock(outchannel, outchannel))return nn.Sequential(*layers)def forward(self, x):x = self.pre(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = F.avg_pool2d(x, 7)x = x.view(x.size(0), -1)return self.fc(x)
model = ResNet()
input  = t.autograd.Variable(t.randn(1, 3, 224, 224))
o = model(input)
print(o)

与PyTorch配套的图像工具包torchvision已经实现了深度学习中大多数经典的模型,其中就包括ResNet34,读者可以通过下面两行代码使用:

from torchvision import models
model = models.resnet34()

本例中ResNet34的实现就是参考了torchvision中的实现并做了简化

PyTorch实战福利从入门到精通之五——搭建ResNet相关推荐

  1. PyTorch实战福利从入门到精通之三——autograd

    autograd 反向传播过程需要手动实现.这对于像线性回归等较为简单的模型来说,还可以应付,但实际使用中经常出现非常复杂的网络结构,此时如果手动实现反向传播,不仅费时费力,而且容易出错,难以检查.t ...

  2. PyTorch实战福利从入门到精通之一——PyTorch框架安装

    使用conda安装是最不容易出错的,在pytroch的官网可以选择自己需要的操作系统.python版本.cuda版本的pytorch框架. 之后复制下面的命令就可以了 安装完这个还要安个numpy p ...

  3. PyTorch实战福利从入门到精通之六——线性回归

    一元线性回归 一元线性模型非常简单,假设我们有变量 xix_ixi​ 和目标 yiy_iyi​,每个 i 对应于一个数据点,希望建立一个模型 y^i=wxi+b\hat{y}_i = w x_i + ...

  4. PyTorch实战福利从入门到精通之四——卷积神经网络CIFAR-10图像分类

    在本教程中,我们将使用CIFAR10数据集.它有类别:"飞机"."汽车"."鸟"."猫"."鹿".& ...

  5. PyTorch实战福利从入门到精通之八——深度卷积神经网络(AlexNet)

    在LeNet提出后的将近20年里,神经网络一度被其他机器学习方法超越,如支持向量机.虽然LeNet可以在早期的小数据集上取得好的成绩,但是在更大的真实数据集上的表现并不尽如人意.一方面,神经网络计算复 ...

  6. PyTorch实战福利从入门到精通之七——卷积神经网络(LeNet)

    卷积神经网络就是含卷积层的网络.介绍一个早期用来识别手写数字图像的卷积神经网络:LeNet [1].这个名字来源于LeNet论文的第一作者Yann LeCun.LeNet展示了通过梯度下降训练卷积神经 ...

  7. PyTorch实战福利从入门到精通之九——数据处理

    在解决深度学习问题的过程中,往往需要花费大量的精力去处理数据,包括图像.文本.语音或其它二进制数据等.数据的处理对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,更会提高模型效果.考虑到这 ...

  8. PyTorch实战福利从入门到精通之二——Tensor

    Tensor又名张量,也是Tensorflow等框架中的重要数据结构.它可以是一个数(标量),一维数组(向量),二维数组或更高维数组.Tensor支持GPU加速. 创建Tensor 几种常见创建Ten ...

  9. 【Python】Python实战从入门到精通之五 -- 教你使用文件写入

    本文是<Python实战从入门到精通>系列之第5篇 [Python]Python实战从入门到精通之一 -- 教你深入理解Python中的变量和数据类型 [Python]Python实战从入 ...

最新文章

  1. Linux那些事儿 之 戏说USB(28)设备的生命线(十一)
  2. 让你完全理解base64是怎么回事
  3. 【转】iOS开发24:使用SQLite3存储和读取数据
  4. 1.1 编程语言介绍
  5. 【20161108】总结
  6. L2与L1正则化理解
  7. 户外演出系统服务器,演艺灯光系统
  8. GPS Programming Tips for Windows Mobile
  9. 通达信公式改写成python代码
  10. Android 仿百合网超火爆社交app首页滑动效果
  11. Summary——CrowdPose: Efficient Crowded Scenes Pose Estimation and A New Benchmark
  12. 看完这篇文章APP关键词覆盖增加70000|互联网行业公会
  13. 25.JavaScript的Symbol类型、隐藏属性、全局注册表
  14. 【单片机笔记】基于2G、4G通信的物联网数据方案及扫码支付方案
  15. 区间估计Bootstraping/Jackknife
  16. Redux以及Flux介绍
  17. 5.编写程序,建立一个含有5名学生成绩的文件:stu1.txt, 解释说明:为了避免测试代码时,反复从屏幕输入数据样例,我这里将数据存在f2.txt文件中,使用freopen()函数访问并读出数据
  18. rg1 蓝光危害rg0_LED蓝光危害评价的最新标准及测试方案介绍
  19. MongoDB年终大会转移至线上进行 | 周五参会指南
  20. Elastic与阿里云助力汽车及出行产业数字化转型

热门文章

  1. 风控五大模型、三大风险指的是什么--几大模型PD、LGD、评分模型都有哪些细节点
  2. fastclick源码简析
  3. 获取客户端的IP地址
  4. kafka自定义序列化器
  5. 字符串之String类
  6. LeetCode Online Judge 题目C# 练习 - Search in Rotated Sorted Array II
  7. ionic3 生命周期
  8. 2019/2/17 Python今日收获
  9. Mybatis笔记 – Po映射类型
  10. 酒精测试仪检定设备设计与验证