resnet中最中重要的就是代码残差块,

def forward(self,x)identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsampleout += identityout = self.relu(out)return out

残差块结构:

左边的就叫做BasicBlock,右边就叫bottleneck

BasicBlock

class BasicBlock(nn.Module):expansion =1def __init__(self, inplanes, planes, stride=1, downsample=None):super(BasicBlock,self).__init__()self.conv1 = conv3x3(inplanes,planes,stride)self.bn1 = nn.BatchNorm2d(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes,planes)self.bn2 = nn.BatchNorm2d(planes)self.downsample = downsampleself.stride =stride
def forward(self,x)identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsampleout += identityout = self.relu(out)return out

bottleneck

注意Res18、Res34用的是BasicBlock,其余用的是Bottleneck

resnet18: ResNet(BasicBlock, [2, 2, 2, 2])

resnet34: ResNet(BasicBlock, [3, 4, 6, 3])

resnet50:ResNet(Bottleneck, [3, 4, 6, 3])

resnet101:ResNet(Bottleneck, [3, 4, 23, 3])

resnet152:ResNet(Bottleneck, [3, 8, 36, 3])

expansion = 4,因为Bottleneck中每个残差结构输出维度都是输入维度的4倍

class Bottleneck(nn.Module):
    expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

def forward(self, x):
        identity = x

out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

out = self.conv3(out)
        out = self.bn3(out)

if self.downsample is not None:
            identity = self.downsample(x)

out += identity
        out = self.relu(out)

return out

ResNet类

几个关键点:

1.在残差结构之前,先对原始224 x 224的图片处理,在经过7 x 7的大卷积核、BN、ReLU、最大池化之后得到56 x 56 x 64的feature map
2.从layer1、layer2、layer3、layer4的定义可以看出,第一个stage不会减小feature map,其余都会在stage的第一层用步长2的3 x 3卷积进行feature map长和宽减半
3._make_layer函数中downsample对残差结构的输入进行升维,直接1 x 1卷积再加上BN即可,后面BasicBlock类和Bottleneck类用得到
4.最后的池化层使用的是自适应平均池化,而非论文中的全局平均池化

class ResNet(nn.Module):

def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

return nn.Sequential(*layers)

def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

return x

resent代码详解相关推荐

  1. 【CV】Pytorch一小时入门教程-代码详解

    目录 一.关键部分代码分解 1.定义网络 2.损失函数(代价函数) 3.更新权值 二.训练完整的分类器 1.数据处理 2. 训练模型(代码详解) CPU训练 GPU训练 CPU版本与GPU版本代码区别 ...

  2. html5代码转换为视频,HTML5中的视频代码详解

    摘要 腾兴网为您分享:HTML5中的视频代码详解,智学网,云闪付,易推广,小红书等软件知识,以及360win10,流量魔盒,fitbit,上港商城,安卓2.3.7,全民惠,五年级下册英语单词表图片,t ...

  3. js php base64,JavaScript实现Base64编码与解码的代码详解

    本篇文章给大家分享的是jJavaScript实现Base64编码与解码的代码详解,内容挺不错的,希望可以帮助到有需要的朋友 一.加密解密方法使用//1.加密 var str = '124中文内容'; ...

  4. yii mysql 事务处理_Yii2中事务的使用实例代码详解

    前言 一般我们做业务逻辑,都不会仅仅关联一个数据表,所以,会面临事务问题. 数据库事务(Database Transaction) ,是指作为单个逻辑工作单元执行的一系列操作,要么完全地执行,要么完全 ...

  5. 代码详解|tensorflow实现 聊天AI--PigPig养成记(1)

    Chapter1.代码详解 完整代码github链接,Untitled.ipynb文件内. [里面的测试是还没训练完的时候测试的,今晚会更新训练完成后的测试结果] 修复了网上一些代码的bug,解决了由 ...

  6. vue build text html,Vue中v-text / v-HTML使用实例代码详解_放手_前端开发者

    废话少说,代码如下所述: /p> 显示123 /p> 补充:vuejs {{}},v-text 和 v-html的区别 {{message}} let app = new Vue({ el ...

  7. sift计算描述子代码详解_代码详解——如何计算横向误差?

    在路径跟踪控制的论文中,我们常会看到判断精确性的指标,即横向误差和航向误差,那么横向误差和航向误差如何获得? 在前几期代码详解中,参考路径和实际轨迹均由To Workspace模块导出,如图所示: 那 ...

  8. 委托与事件代码详解与(Object sender,EventArgs e)详解

    委托与事件代码详解 using System; using System.Collections.Generic; using System.Text; namespace @Delegate //自 ...

  9. python怎么画条形图-python绘制条形图方法代码详解

    1.首先要绘制一个简单的条形图 import numpy as np import matplotlib.pyplot as plt from matplotlib import mlab from ...

  10. python代码大全表解释-python操作列表的函数使用代码详解

    python的列表很重要,学习到后面你会发现使用的地方真的太多了.最近在写一些小项目时经常用到列表,有时其中的方法还会忘哎! 所以为了复习写下了这篇博客,大家也可以来学习一下,应该比较全面和详细了 列 ...

最新文章

  1. PL/SQL 操作数据库常见脚本
  2. 利用iframe与Response.Flush实现进度展示效果
  3. 基于visual Studio2013解决C语言竞赛题之0502最小数替换
  4. 稀疏矩阵十字链表类java_稀疏矩阵的十字链表存储表示
  5. 最全的B端产品经理干货知识(2)
  6. jQuery 基础事件
  7. Ubantu install jdk
  8. MATLAB R2016a 简单介绍
  9. 49个Excel常用技巧
  10. Android——ViewHolder的作用与用法
  11. python opencv将图片转为灰度图
  12. 第二节 物料清单(BOM)
  13. Android 7.0 ----- Direct Boot模式(AppClock)
  14. 计算机中线性结构定义,数据结构基本概念
  15. iPhone设置整点报时提醒
  16. Postgre SQL 中的时间格式
  17. 如何将html内容解码,3.5.3 对HTML进行编码和解码
  18. JS实现Web网页打印功能(IE)
  19. python写入文件乱码\u559c\u6b22\u4e00\u4e2a\u4eba
  20. 常用各个手机屏幕分辨率归纳。iphone5/iphone7/iphone7 plus/iphoneX/Android 分辨率大小归纳

热门文章

  1. Accurate, Large Minibatch SGD
  2. 数据--第23课 - 队列的优化实现
  3. 负载均衡常见问题之会话保持-粘滞会话(Sticky Sessions)
  4. 【持久化框架】SpringMVC+Spring4+Mybatis3集成,开发简单Web项目+源码下载 【转】...
  5. linux date 得到指定 datemonth 月的 开始一天 结束一天
  6. 编译HG255D的openwrt固件
  7. 更轻松的获取APK文件安装时间
  8. 笔记二:云上传与调用获取openid
  9. 【题解】CF#713 E-Sonya Partymaker
  10. 第5课 混合编程和芯片手册阅读