为了加深对resnet的理解,参考他人代码重写了一份pytorch版的resnet50。

注意resnet50中,只有131卷积,即Basic

1. ConvBlock

卷积层进入的模块,第一层要利用步长为2进行下采样。

# 每个stage第一个卷积模块,主要进行下采样
class ConvBlock(nn.Module):  def __init__(self, in_channel, f, filters, s):# resnet50 只有 131卷积# filters传入的是一个 [filter1, filter2, filter3] 列表super(ConvBlock,self).__init__()F1, F2, F3 = filtersself.stage = nn.Sequential(nn.Conv2d(in_channel,F1,1,stride=s, padding=0, bias=False), # 1*1卷积nn.BatchNorm2d(F1),nn.ReLU(True),nn.Conv2d(F1,F2,f,stride=1, padding=True, bias=False), # 3*3 卷积nn.BatchNorm2d(F2),nn.ReLU(True),nn.Conv2d(F2,F3,1,stride=1, padding=0, bias=False), # 1*1 卷积nn.BatchNorm2d(F3),)# 短路部分,从输入in_channel直接卷积到输出F3,恒等映射self.shortcut_1 = nn.Conv2d(in_channel, F3, 1, stride=s, padding=0, bias=False)self.batch_1 = nn.BatchNorm2d(F3)self.relu_1 = nn.ReLU(True)def forward(self, X):X_shortcut = self.shortcut_1(X)X_shortcut = self.batch_1(X_shortcut)X = self.stage(X)# 残差加和X = X + X_shortcutX = self.relu_1(X)return X

2. stage内部的卷积模块

# stage内部的卷积模块
class IndentityBlock(nn.Module):def __init__(self, in_channel, f, filters):super(IndentityBlock,self).__init__()F1, F2, F3 = filtersself.stage = nn.Sequential(nn.Conv2d(in_channel,F1,1,stride=1, padding=0, bias=False),nn.BatchNorm2d(F1),nn.ReLU(True),nn.Conv2d(F1,F2,f,stride=1, padding=True, bias=False),nn.BatchNorm2d(F2),nn.ReLU(True),nn.Conv2d(F2,F3,1,stride=1, padding=0, bias=False),nn.BatchNorm2d(F3),)self.relu_1 = nn.ReLU(True)def forward(self, X):X_shortcut = XX = self.stage(X)X = X + X_shortcutX = self.relu_1(X)return X

3. ResNet核心

resnet50:包含了50个conv2d操作,共分5个stage。stage又称作残差块

stage1: 1层卷积,1层池化
stage2: 三个131卷积(1*1, 3*3, 1*1), 共9层卷积
stage3: 四个131卷积,共12层卷积
stage4: 六个131卷积,共18层卷积
stage5: 三个131卷积,9层
fc: 全连接conv2d = 1+9+12+18+9+1 = 50
class ResModel(nn.Module):'''resnet50:包含了50个conv2d操作,共分5个stagestage1: 1层卷积,1层池化stage2: 三个131卷积(1*1, 3*3, 1*1), 共9层卷积stage3: 四个131卷积,共12层卷积stage4: 六个131卷积,共18层卷积stage5: 三个131卷积,9层fc: 全连接conv2d = 1+9+12+18+9+1 = 50'''def __init__(self, n_class):'''ConvBlock相比IndentityBlock,多了使用stride下采样的环节,且Identity不进行bn所以每个stage都是通过ConvBlock进入,决定是否下采样'''super(ResModel,self).__init__()self.stage1 = nn.Sequential(nn.Conv2d(3,64,7,stride=2, padding=3, bias=False), # 3通道输入,64个7*7卷积,步长2nn.BatchNorm2d(64),nn.ReLU(True),nn.MaxPool2d(3,2,padding=1),) self.stage2 = nn.Sequential(ConvBlock(64, f=3, filters=[64, 64, 256], s=1),IndentityBlock(256, 3, [64, 64, 256]),IndentityBlock(256, 3, [64, 64, 256]),) # 3个卷积self.stage3 = nn.Sequential(ConvBlock(256, f=3, filters=[128, 128, 512], s=2),IndentityBlock(512, 3, [128, 128, 512]),IndentityBlock(512, 3, [128, 128, 512]),IndentityBlock(512, 3, [128, 128, 512]),) # 4个卷积块self.stage4 = nn.Sequential(ConvBlock(512, f=3, filters=[256, 256, 1024], s=2),IndentityBlock(1024, 3, [256, 256, 1024]),IndentityBlock(1024, 3, [256, 256, 1024]),IndentityBlock(1024, 3, [256, 256, 1024]),IndentityBlock(1024, 3, [256, 256, 1024]),IndentityBlock(1024, 3, [256, 256, 1024]),) # 6个卷积块self.stage5 = nn.Sequential(ConvBlock(1024, f=3, filters=[512, 512, 2048], s=2),IndentityBlock(2048, 3, [512, 512, 2048]),IndentityBlock(2048, 3, [512, 512, 2048]),) # 3个卷积块# 平均池化self.pool = nn.AvgPool2d(2,2,padding=1)self.fc = nn.Sequential(nn.Linear(32768,n_class)  # 输入参数等于 = 2048 * 4 * 4, 4是最后一层池化输出大小)def forward(self, X):out = self.stage1(X)out = self.stage2(out)out = self.stage3(out)out = self.stage4(out)out = self.stage5(out)out = self.pool(out)# 输出的尺寸是 [32, 2048, 4, 4]out = out.view(batches, -1)# 调整后 [32, 32768]out = self.fc(out)return out

手写ResNet50——pytorch相关推荐

  1. 全手写resnet50,分分钟识别“十二生肖“图片

    大家好啊,我是董董灿. 前天,我完全手写的算法,并手搭的神经网络,终于成功的识别出一张图片. 点击链接查看出猫现场:我出猫了,第一阶段完成! 成功地识别出来猫,意味着我搭建的整个流程跑通了. 这个流程 ...

  2. 机器学习之神经网络的公式推导与python代码(手写+pytorch)实现

    文章目录 前言 神经网络公式推导 参数定义 前向传播(forward) 反向传播(backward) 隐藏层和输出层的权重更新 输入层和隐藏层的权重更新 代码实现 python手写实现 pytorch ...

  3. PyTorch基础与简单应用:构建卷积神经网络实现MNIST手写数字分类

    文章目录 (一) 问题描述 (二) 设计简要描述 (三) 程序清单 (四) 结果分析 (五) 调试报告 (六) 实验小结 (七) 参考资料 (一) 问题描述 构建卷积神经网络实现MNIST手写数字分类 ...

  4. 使用Pytorch实现手写数字识别

    使用Pytorch实现手写数字识别 1. 思路和流程分析 流程: 准备数据,这些需要准备DataLoader 构建模型,这里可以使用torch构造一个深层的神经网络 模型的训练 模型的保存,保存模型, ...

  5. 实例:手写 CUDA 算子,让 Pytorch 提速 20 倍

    作者丨PENG Bo@知乎(已授权) 来源丨https://zhuanlan.zhihu.com/p/476297195 编辑丨极市平台 本文的代码,在 win10 和 linux 均可直接编译运行: ...

  6. PyTorch基础入门五:PyTorch搭建多层全连接神经网络实现MNIST手写数字识别分类

    )全连接神经网络(FC) 全连接神经网络是一种最基本的神经网络结构,英文为Full Connection,所以一般简称FC. FC的准则很简单:神经网络中除输入层之外的每个节点都和上一层的所有节点有连 ...

  7. 深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别

    一.前言 MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一. 二.MNIST数据集 ...

  8. 【深度学习】我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!

    今天我们将使用 Pytorch 来实现 LeNet-5 模型,并用它来解决 MNIST数据集的识别. 正文开始! 一.使用 LeNet-5 网络结构创建 MNIST 手写数字识别分类器 MNIST是一 ...

  9. 基于PyTorch框架的多层全连接神经网络实现MNIST手写数字分类

    多层全连接神经网络实现MNIST手写数字分类 1 简单的三层全连接神经网络 2 添加激活函数 3 添加批标准化 4 训练网络 5 结论 参考资料 先用PyTorch实现最简单的三层全连接神经网络,然后 ...

最新文章

  1. Windows Server 2012 存储 (四) SMB 对SQL 数据库和Hyper-V的支持
  2. vss6 forgot admin password
  3. 基于ArcGIS JS API 的点击查询功能
  4. BZOJ 3143 Luogu P3232 [HNOI2013]游走 (DP、高斯消元)
  5. SAP CRM里note界面默认语言的决定逻辑
  6. Php如何过360拦截,PHP常见漏洞修复文件-360漏洞修复插件
  7. 成功者都在用的“成功咒语”
  8. 虚拟机VMware下安装Linux系统,Python3.7之TensorFlow安装
  9. 基于JavaScript技术完成单击事件完成显示和隐藏
  10. display:none与visible:hidden的区别 ?
  11. 程序员面试难题,在你结婚的时候领导要求你30分钟归队,你会如何
  12. JavaScript学习笔记(六)
  13. C语言应该增加交换值的关键字或语法
  14. Atitit 架构师之道 attilax著 1.1. 认和评估系统需求, 2 1.2. 给出开发规范 2 1.3. ,搭建系统实现的核心构架, 2 1.4. 扫清主要难点的技术人员 2 1.5. 核
  15. 【数据融合】基于DS证据理论实现数据融合附matlab代码
  16. 微信公众号内推送模板消息
  17. 教你修改Win7系统的登录界面背景
  18. Edge浏览器驱动更新
  19. 华为防火墙笔记-网络地址转化NAT
  20. c语言自定义sum函数,c语言自定义函数

热门文章

  1. php开发中常用函数总结
  2. python开发微信公众号开发教程百度云_Python开发微信公众号后台(系列一)
  3. PyCharm安装教程(简单又实用)
  4. div 隐藏和显示方式
  5. Harris-Benedict等式
  6. 询盘还盘等国际贸易(转)
  7. 开发者解决当前和未来挑战?英特尔On技术创新峰会中国在线会议来了丨Intel Innovation
  8. Powerbuilder遍历treeview
  9. 解决xshell 中文乱码
  10. Android之mp3播放器开发过程