Pytorch实现ResNet50网络结构,包含ResNet18,ResNet34,ResNet50,ResNet101,ResNet152
创建各版本的ResNet模型,ResNet18,ResNet34,ResNet50,ResNet101,ResNet152
原文地址: https://arxiv.org/pdf/1512.03385.pdf
论文就不解读了,大部分解读都是翻译,看的似懂非懂,自己搞懂就行了。
最近想着实现一下经典的网络结构,看了原文之后,根据原文代码结构开始实现。
起初去搜了下各种版本的实现,发现很多博客都是错误百出,有些博文都发布几年了,错误还是没人发现,评论区几十号人不知道是真懂还是装懂,颇有些无奈啊。
因此打算自己手动实现网络结构,锻炼下自己的代码能力,也加深对网络结构的理解。
写完之后也很欣慰,毕竟一直认为自己是个菜鸡,最近竟然接连不断的发现很多博文的错误之处,而且很多人看后都没发现的,想想自己似乎还有点小水平。
最后在一套代码里,实现了各版本ResNet,为了方便。
其实最后还是觉得应该每个网络分开写比较好。因为不同版本的网络内部操作是有很大差异的,本文下面的代码是将ResidualBlock和 BottleNeckBlock分开写的,但是在维度的变换上差异还是很复杂,一方面想提高代码的复用性,另一方面也受制于复杂度。所以最后写出的算不上高复用性的精简代码。勉强能用。关于ResNet的结构,除各版本分开写之外,重复的block其实也可以分开写,因为BottleNeckBlock的维度变换太复杂,参数变换多,能分开就分开,复杂度小的地方可以复用。
以下是网络结构和实现代码,检验后都是对的;水平有限,如发现有错误,欢迎评论告知!
1 残差结构图
2 VGG-19与ResNet34结构比较
3 ResNet各版本的结构
4 代码实现各版本
import torch.nn as nn
from torch.nn import functional as Fclass ResNetModel(nn.Module):"""实现通用的ResNet模块,可根据需要定义"""def __init__(self, num_classes=1000, layer_num=[],bottleneck = False):super(ResNetModel, self).__init__()#conv1self.pre = nn.Sequential(#in 224*224*3nn.Conv2d(3,64,7,2,3,bias=False), #输入通道3,输出通道64,卷积核7*7*64,步长2,根据以上计算出padding=3#out 112*112*64nn.BatchNorm2d(64), #输入通道C = 64nn.ReLU(inplace=True), #inplace=True, 进行覆盖操作# out 112*112*64nn.MaxPool2d(3,2,1), #池化核3*3,步长2,计算得出padding=1;# out 56*56*64)if bottleneck: #resnet50以上使用BottleNeckBlockself.residualBlocks1 = self.add_layers(64, 256, layer_num[0], 64, bottleneck=bottleneck)self.residualBlocks2 = self.add_layers(128, 512, layer_num[1], 256, 2,bottleneck)self.residualBlocks3 = self.add_layers(256, 1024, layer_num[2], 512, 2,bottleneck)self.residualBlocks4 = self.add_layers(512, 2048, layer_num[3], 1024, 2,bottleneck)self.fc = nn.Linear(2048, num_classes)else: #resnet34使用普通ResidualBlockself.residualBlocks1 = self.add_layers(64,64,layer_num[0])self.residualBlocks2 = self.add_layers(64,128,layer_num[1])self.residualBlocks3 = self.add_layers(128,256,layer_num[2])self.residualBlocks4 = self.add_layers(256,512,layer_num[3])self.fc = nn.Linear(512, num_classes)def add_layers(self, inchannel, outchannel, nums, pre_channel=64, stride=1, bottleneck=False):layers = []if bottleneck is False:#添加大模块首层, 首层需要判断inchannel == outchannel ?#跨维度需要stride=2,shortcut也需要1*1卷积扩维layers.append(ResidualBlock(inchannel,outchannel))#添加剩余nums-1层for i in range(1,nums):layers.append(ResidualBlock(outchannel,outchannel))return nn.Sequential(*layers)else: #resnet50使用bottleneck#传递每个block的shortcut,shortcut可以根据是否传递pre_channel进行推断#添加首层,首层需要传递上一批blocks的channellayers.append(BottleNeckBlock(inchannel,outchannel,pre_channel,stride))for i in range(1,nums): #添加n-1个剩余blocks,正常通道转换,不传递pre_channellayers.append(BottleNeckBlock(inchannel,outchannel))return nn.Sequential(*layers)def forward(self, x):x = self.pre(x)x = self.residualBlocks1(x)x = self.residualBlocks2(x)x = self.residualBlocks3(x)x = self.residualBlocks4(x)x = F.avg_pool2d(x, 7)x = x.view(x.size(0), -1)return self.fc(x)class ResidualBlock(nn.Module):'''定义普通残差模块resnet34为普通残差块,resnet50为瓶颈结构'''def __init__(self, inchannel, outchannel, stride=1, padding=1, shortcut=None):super(ResidualBlock, self).__init__()#resblock的首层,首层如果跨维度,卷积stride=2,shortcut需要1*1卷积扩维if inchannel != outchannel:stride= 2shortcut=nn.Sequential(nn.Conv2d(inchannel,outchannel,1,stride,bias=False),nn.BatchNorm2d(outchannel))# 定义残差块的左部分self.left = nn.Sequential(nn.Conv2d(inchannel, outchannel, 3, stride, padding, bias=False),nn.BatchNorm2d(outchannel),nn.ReLU(inplace=True),nn.Conv2d(outchannel, outchannel, 3, 1, padding, bias=False),nn.BatchNorm2d(outchannel),)#定义右部分self.right = shortcutdef forward(self, x):out = self.left(x)residual = x if self.right is None else self.right(x)out = out + residualreturn F.relu(out)class BottleNeckBlock(nn.Module):'''定义resnet50的瓶颈结构'''def __init__(self,inchannel,outchannel, pre_channel=None, stride=1,shortcut=None):super(BottleNeckBlock, self).__init__()#首个bottleneck需要承接上一批blocks的输出channelif pre_channel is None: #为空则表示不是首个bottleneck,pre_channel = outchannel #正常通道转换else: # 传递了pre_channel,表示为首个block,需要shortcutshortcut = nn.Sequential(nn.Conv2d(pre_channel,outchannel,1,stride,0,bias=False),nn.BatchNorm2d(outchannel))self.left = nn.Sequential(#1*1,inchannelnn.Conv2d(pre_channel, inchannel, 1, stride, 0, bias=False),nn.BatchNorm2d(inchannel),nn.ReLU(inplace=True),#3*3,inchannelnn.Conv2d(inchannel,inchannel,3,1,1,bias=False),nn.BatchNorm2d(inchannel),nn.ReLU(inplace=True),#1*1,outchannelnn.Conv2d(inchannel,outchannel,1,1,0,bias=False),nn.BatchNorm2d(outchannel),nn.ReLU(inplace=True),)self.right = shortcutdef forward(self,x):out = self.left(x)residual = x if self.right is None else self.right(x)return F.relu(out+residual)if __name__ == '__main__':# channel_nums = [64,128,256,512,1024,2048]num_classes = 6#layers = 18, 34, 50, 101, 152layer_nums = [[2,2,2,2],[3,4,6,3],[3,4,6,3],[3,4,23,3],[3,8,36,3]]#选择resnet版本,# resnet18 ——0;resnet34——1,resnet-50——2,resnet-101——3,resnet-152——4i = 3;bottleneck = i >= 2 #i<2, false,使用普通的ResidualBlock; i>=2,true,使用BottleNeckBlockmodel = ResNetModel(num_classes,layer_nums[i],bottleneck)print(model)
Pytorch实现ResNet50网络结构,包含ResNet18,ResNet34,ResNet50,ResNet101,ResNet152相关推荐
- 【pytorch】ResNet18、ResNet20、ResNet34、ResNet50网络结构与实现
文章目录 ResNet主体 BasicBlock ResNet18 ResNet34 ResNet20 Bottleneck Block ResNet50 ResNet到底解决了什么问题 选取经典的早 ...
- 通过和resnet18和resnet50理解PyTorch的ResNet模块
文章目录 模型介绍 resnet18模型流程 总结 resnet50 总结 resnet和resnext的框架基本相同的,这里先学习下resnet的构建,感觉高度模块化,很方便.本文算是对 PyTor ...
- 小白入门计算机视觉系列——ReID(二):baseline构建:基于PyTorch的全局特征提取网络(Finetune ResNet50+tricks)
ReID(二):baseline构建:基于PyTorch的全局特征提取网络(Finetune ResNet50+tricks) 本次带来的是计算机视觉中比较热门的重点的一块,行人重识别(也叫Perso ...
- resnet50网络结构_Resnet50详解与实践(基于mindspore)
1. 简述 Resnet是残差网络(Residual Network)的缩写,该系列网络广泛用于目标分类等领域以及作为计算机视觉任务主干经典神经网络的一部分,典型的网络有resnet50, resne ...
- Resnet-50网络结构详解
解决的问题: 梯度消失,深层网络难训练. 因为梯度反向传播到前面的层,重复相乘可能使梯度无穷小.结果就是,随着网络的层数更深,其性能趋于饱和,甚至迅速下降. 关于为什么残差结构(即多了一条跳跃连接线后 ...
- RNN知识+LSTM知识+encoder-decoder+ctc+基于pytorch的crnn网络结构
一.基础知识: 下图是一个循环神经网络实现语言模型的示例,可以看出其是基于当前的输入与过去的输入序列,预测序列的下一个字符. 序列特点就是某一步的输出不仅依赖于这一步的输入,还依赖于其他步的输入或输出 ...
- ResNet50 网络结构搭建(PyTorch)
ResNet50是一个经典的特征提取网络结构,虽然Pytorch已有官方实现,但为了加深对网络结构的理解,还是自己动手敲敲代码搭建一下.需要特别说明的是,笔者是以熟悉网络各层输出维度变化为目的的,只对 ...
- resnet50网络结构_学习笔记(一):分析resnet源码理解resnet网络结构
最近在跑实验的过程中一直在使用resnet50和resnet34,为了弄清楚网络的结构和原理的实现,打开resnet的源码进行了学习. 残差网络学习的原理 针对神经网络过深而导致的学习准确率饱和甚至是 ...
- resnet50网络结构_AAAI2020 | 利用网络结构关系加速NAS+Layer
这是我在海康威视研究院实习的工作,被AAAI-2020接收为Spotlight. 论文地址:http://xxx.itp.ac.cn/pdf/2002.12580v1 引子 上一个阶段的网络结构搜索( ...
- 【深度学习】resnet-50网络结构
最近许多目标检测网络的backbone都有用到resnet-50的部分结构,于是找到原论文,看了一下网络结构,在这里做一个备份,需要的时候再来看看. 整体结构 layer0 首先是layer0,这部分 ...
最新文章
- 修改windows系統下xampp中apache端口被其他程式占用的問題
- go语言学习笔记(2)命令源码文件
- 关于主键的设计、primary key
- jdk8 string::_JDK 12的String :: transform方法的简要但复杂的历史
- java逸出_Java并发编程 - 对象的共享
- MySQL主键自增长报duplicate_MySQL使用on duplicate key update引起主键不连续自增
- GANs最新综述论文: 生成式对抗网络及其变种如何有用【附pdf下载】
- python中的pyinstaller库_Python(00):PyInstaller库,打包成exe基本介绍
- android:fillviewport=true 不起作用,无法在android模拟器中滚动
- 用python实现的的手写数字识别器
- 蓝桥杯 ALGO-71 算法训练 比较字符串
- 深入浅出 Javascript API(二)--地图显示与基本操作
- c# 匿名用戶登錄以後的事件處理
- MDK5如何生成bin文件
- 我的爷爷(知识渊博的下乡知青)
- PPT中表格的插入与结构调整
- ctc decoder
- python修改表格居中_python修改表格居中_CSS样式更改——列表、表格和轮廓
- js日期时间格式化yyyy-mm-dd hh:ii:ss
- python爬取高清动图