一,残差网络架构

1,残差学习单元

上图左对应的是浅层网络(18层,34层),而右图对应的是深层网络(50,101,152)。

1. 左图为基本的residual block,residual mapping为两个64通道的3x3卷积,输入输出均为64通道,可直接相加。该block主要使用在相对浅层网络,比如ResNet-34;

2. 右图为针对深层网络提出的block,称为“bottleneck” block,主要目的就是为了降维。首先通过一个1x1卷积将256维通道(channel)降到64通道,最后通过一个256通道的1x1卷积恢复。

ResNet使用两种残差单元,其目的主要就是为了降低参数的数目。

2,残差学习好在哪里?

随着网络加深,梯度消失,模型准确率会先上升然后达到饱和,再持续增加深度时则会导致准确率下降。

残差跳跃式的结构,打破了传统的神经网络n-1层的输出只能给n层作为输入的惯例,使某一层的输出可以直接跨过几层作为后面某一层的输入,其意义在于为叠加多层网络而使得整个学习模型的错误率不降反升的难题提供了新的方向。至此,神经网络的层数可以超越之前的约束,达到几十层、上百层甚至千层,为高级语义特征提取和分类提供了可行性。

3,ResNet改进版本

新的残差学习单元比以前更容易训练且泛化性更强。

二,两种残差单元对应的Pytorch源码

  • a,两个卷积层的残差单元
class BasicBlock(nn.Module):expansion = 1def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None):super(BasicBlock, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dif groups != 1 or base_width != 64:raise ValueError('BasicBlock only supports groups=1 and base_width=64')if dilation > 1:raise NotImplementedError("Dilation > 1 not supported in BasicBlock")# Both self.conv1 and self.downsample layers downsample the input when stride != 1self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = norm_layer(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = norm_layer(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out
  • b, 有瓶颈块的三层残差单元
class Bottleneck(nn.Module):# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)# while original implementation places the stride at the first 1x1 convolution(self.conv1)# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.# This variant is also known as ResNet V1.5 and improves accuracy according to# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.expansion = 4def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,base_width=64, dilation=1, norm_layer=None):super(Bottleneck, self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dwidth = int(planes * (base_width / 64.)) * groups# Both self.conv2 and self.downsample layers downsample the input when stride != 1self.conv1 = conv1x1(inplanes, width)self.bn1 = norm_layer(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = norm_layer(width)self.conv3 = conv1x1(width, planes * self.expansion)self.bn3 = norm_layer(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return out

c,整个ResNet架构

对输入图像使用一个步长为2,卷积核为7×7的卷积层(紧接着BN,Relu),紧接着是4个残差学习块,不同深度的残差网络包含的每层残差学习块的个数不同(参考图1残差网络架构中的参数),最后是一个全局池化层和一个全连接层。

class ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000):self.inplanes = 64super(ResNet, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2)self.layer3 = self._make_layer(block, 256, layers[2], stride=2)self.layer4 = self._make_layer(block, 512, layers[3], stride=2)self.avgpool = nn.AvgPool2d(7, stride=1)self.fc = nn.Linear(512 * block.expansion, num_classes)

感谢解读

https://blog.csdn.net/chenyuping333/article/details/82344334

ResNet网络全部源码https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py

Pytorch ResNet源码学习相关推荐

  1. resnet50网络结构_学习笔记(一):分析resnet源码理解resnet网络结构

    最近在跑实验的过程中一直在使用resnet50和resnet34,为了弄清楚网络的结构和原理的实现,打开resnet的源码进行了学习. 残差网络学习的原理 针对神经网络过深而导致的学习准确率饱和甚至是 ...

  2. PyTorch源码学习系列 - 1.初识

    本系列文章会优先发布于微信公众号和知乎,欢迎大家关注 微信公众号:小飞怪兽屋 知乎: PyTorch源码学习系列 - 1.初识 - 知乎 (zhihu.com) 目录 本系列的目的 PyTorch是什 ...

  3. 【学习笔记】Keras库下的resnet源码分析

    Keras库下的resnet源码应用及解读 其实我也不知道这种东西有没有写下来的必要,但是跑代码的时候总摸鱼总归是不好的.虽然很简单,不过我也大概做个学习记录,写给小白看的.源码来自keras文档,大 ...

  4. Shiro源码学习之二

    接上一篇 Shiro源码学习之一 3.subject.login 进入login public void login(AuthenticationToken token) throws Authent ...

  5. Shiro源码学习之一

    一.最基本的使用 1.Maven依赖 <dependency><groupId>org.apache.shiro</groupId><artifactId&g ...

  6. mutations vuex 调用_Vuex源码学习(六)action和mutation如何被调用的(前置准备篇)...

    前言 Vuex源码系列不知不觉已经到了第六篇.前置的五篇分别如下: 长篇连载:Vuex源码学习(一)功能梳理 长篇连载:Vuex源码学习(二)脉络梳理 作为一个Web前端,你知道Vuex的instal ...

  7. vue实例没有挂载到html上,vue 源码学习 - 实例挂载

    前言 在学习vue源码之前需要先了解源码目录设计(了解各个模块的功能)丶Flow语法. src ├── compiler # 把模板解析成 ast 语法树,ast 语法树优化,代码生成等功能. ├── ...

  8. 2021-03-19Tomcat源码学习--WebAppClassLoader类加载机制

    Tomcat源码学习--WebAppClassLoader类加载机制 在WebappClassLoaderBase中重写了ClassLoader的loadClass方法,在这个实现方法中我们可以一窥t ...

  9. jQuery源码学习之Callbacks

    jQuery源码学习之Callbacks jQuery的ajax.deferred通过回调实现异步,其实现核心是Callbacks. 使用方法 使用首先要先新建一个实例对象.创建时可以传入参数flag ...

最新文章

  1. springMvc+mybatis+spring 整合 包涵整合activiti 基于maven
  2. java final关键字_终于明白 Java 为什么要加 final 关键字了!
  3. Logistic回归——二分类 —— matlab
  4. 文昌帝君 -- 《文昌帝君阴骘文》
  5. 环形链表得golang实现
  6. DataStream API及源算子
  7. linux远程控制本地用户登录,linux 本地无法登录 远程可以登陆的解决办法
  8. P1379 八数码难题
  9. html 5拜年贺卡,HTML5+CSS3实现春节贺卡
  10. 恩施机器人编程_恩施武汉机器人激光切割机
  11. python自动交易软件排名_量化投资软件排名 哪个量化交易软件最好用
  12. mars老师android开发视频教程5季+java4android视频教程
  13. 亲测可用[转]官方17ce老毛子Padavan华硕固件router插件安装方法|集成不占空间k1斐讯k2...
  14. js执行机制经典面试题(一)
  15. oracle dimension的探究(维度)
  16. 新浪微博2020界校招笔试-算法工程师
  17. 某软件测试大纲,软件测试(验收)大纲
  18. 多邻国-英语学习笔记
  19. VLC捕获网络摄像头视频(rtsp协议)
  20. Behavior Designer 干货总结

热门文章

  1. C语言函数:内存函数memmove()以及实现与使用。
  2. ABAP 基本类型 强制转换
  3. Git中tag的作用
  4. 自己动手写编译器:汤普森构造法
  5. 动态规划练习题(3)开餐馆
  6. Sams Teach Yourself SQL in 10 Minutes, Third Edition
  7. 阿里都在用的线上问题定位工具【收藏备用】
  8. 江苏高考状元在华尔街年薪1亿美元
  9. Sql Server 使用 SET NOCOUNT { ON | OFF}
  10. 我爱大自然教案计算机,我爱大自然大班教案