ShuffleNetV2 Code Walkthrough
Table of Contents
- ShuffleNetV2 Code Walkthrough
- Overview
- ShuffleNetV2 Network Structure Diagram
- ShuffleNetV2 Architecture Parameters
- ShuffleNetV2 Code Details
Overview
ShuffleNetV2 was published at ECCV 2018 and targets model compression and acceleration. It relies on two main techniques: depthwise separable convolution and channel shuffle. Depthwise separable convolution reduces the parameter count and speeds up computation, while channel shuffle lets features in different channel groups exchange information, yielding richer semantic features.
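To make the two tricks concrete, here is a minimal sketch (not from the original post; all layer sizes are chosen arbitrarily for illustration). It compares the parameter count of a plain 3x3 convolution against a depthwise 3x3 plus pointwise 1x1 pair, and shuffles a tiny tensor with two groups:

```python
import torch
import torch.nn as nn

# Depthwise separable convolution: a 3x3 depthwise conv (groups = in_channels)
# followed by a 1x1 pointwise conv, versus a plain 3x3 conv.
plain = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
separable = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3, padding=1, groups=64, bias=False),  # depthwise
    nn.Conv2d(64, 128, kernel_size=1, bias=False),                       # pointwise
)
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(plain), count(separable))  # 73728 vs 8768: roughly 8x fewer parameters

# Channel shuffle: reshape (B, C, H, W) -> (B, G, C/G, H, W), swap the group
# and channel axes, then flatten back, so channels from different groups interleave.
x = torch.arange(8).view(1, 8, 1, 1)   # channels numbered 0..7
g = 2
y = x.view(1, g, 8 // g, 1, 1).transpose(1, 2).contiguous().view(1, 8, 1, 1)
print(y.flatten().tolist())            # [0, 4, 1, 5, 2, 6, 3, 7]
```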
This series of posts concentrates on analyzing the code; readers who want a deeper understanding of the principles behind ShuffleNetV2 can refer to this link.
ShuffleNetV2 Network Structure Diagram
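The network stacks two kinds of units from the paper, both of which the code below refers to: unit (c), the stride-1 block, which splits the input channels in half, passes one half straight through, and convolves the other; and unit (d), the stride-2 downsampling block, which convolves both branches and concatenates them, doubling the channel count while halving the spatial resolution. Both units end with a channel shuffle.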
ShuffleNetV2 Architecture Parameters
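The parameters of the four variants, summarized here from the `stages_repeats` and `stages_out_channels` arguments that the factory functions at the end of the code pass to the `ShuffleNetV2` constructor:

| Variant | stages_repeats | stages_out_channels |
|---------|----------------|---------------------|
| shufflenet_v2_x0_5 | [4, 8, 4] | [24, 48, 96, 192, 1024] |
| shufflenet_v2_x1_0 | [4, 8, 4] | [24, 116, 232, 464, 1024] |
| shufflenet_v2_x1_5 | [4, 8, 4] | [24, 176, 352, 704, 1024] |
| shufflenet_v2_x2_0 | [4, 8, 4] | [24, 244, 488, 976, 2048] |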
ShuffleNetV2 Code Details
```python
import torch
import torch.nn as nn
from torch import Tensor
from .utils import load_state_dict_from_url
from typing import Callable, Any, List

# The ShuffleNet variants this file exposes
__all__ = [
    'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
    'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
]

# Pretrained ShuffleNet weights
model_urls = {
    'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
    'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
    'shufflenetv2_x1.5': None,
    'shufflenetv2_x2.0': None,
}


# Shuffle the channels so that features from different channel groups can
# exchange information, enriching the semantics
def channel_shuffle(x, groups):
    # x is in BCHW format
    batchsize, num_channels, height, width = x.size()
    # Grouped convolution; ShuffleNetV2 splits the channels into two groups,
    # i.e. groups = 2
    channel_per_group = num_channels // groups

    # Reshape x to (B, G, C/G, H, W)
    x = x.view(batchsize, groups, channel_per_group, height, width)

    # Swap dimensions 1 and 2 of x
    x = torch.transpose(x, 1, 2).contiguous()

    # Flatten back; the output has the same BCHW shape as the input
    x = x.view(batchsize, -1, height, width)

    return x


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride):
        super(InvertedResidual, self).__init__()
        if not (1 <= stride <= 3):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = oup // 2
        # branch_features << 1 doubles branch_features (a left shift by one bit)
        assert (self.stride != 1) or (inp == branch_features << 1)

        # branch1 and branch2 are the left and right branches of unit (d)
        # in the ShuffleNetV2 paper
        # Left branch
        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(inp),
                nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            self.branch1 = nn.Sequential()

        # Right branch
        self.branch2 = nn.Sequential(
            nn.Conv2d(inp if (self.stride > 1) else branch_features, branch_features,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_size=3,
                                stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1,
                      padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        # When stride = 1 this is unit (c) of the ShuffleNetV2 paper: half of
        # the input is wired directly to the output
        if self.stride == 1:
            # x.chunk(2, dim=1) splits x into two chunks along dimension 1;
            # for a BCHW input that means splitting along the channel axis
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            # When stride > 1 this is unit (d): both branches apply a 3x3
            # depthwise separable convolution and 1x1 convolutions, and the
            # results are concatenated
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out


class ShuffleNetV2(nn.Module):
    def __init__(self, stages_repeats, stages_out_channels, num_classes=1000,
                 inverted_residual=InvertedResidual):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError('expected stages_repeats as list of 3 positive ints')
        if len(stages_out_channels) != 5:
            raise ValueError('expected stages_out_channels as list of 5 positive ints')
        self._stage_out_channels = stages_out_channels

        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential
        stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(
                stage_names, stages_repeats, self._stage_out_channels[1:]):
            # Each stage starts with a stride-2 downsampling unit, followed by
            # stride-1 units
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for i in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # Chain the stages of the ShuffleNetV2 architecture
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # global pool
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _shufflenetv2(arch: str, pretrained: bool, progress: bool,
                  *args: Any, **kwargs: Any) -> ShuffleNetV2:
    model = ShuffleNetV2(*args, **kwargs)

    if pretrained:
        model_url = model_urls[arch]
        if model_url is None:
            raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
        else:
            # Load the pretrained weights
            state_dict = load_state_dict_from_url(model_url, progress=progress)
            model.load_state_dict(state_dict)

    return model


# Different ShuffleNetV2 variants use different output channel counts
def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
                         [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)


def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
                         [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)


def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
                         [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)


def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
    return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
                         [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
```
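As a quick sanity check, here is a minimal smoke test (not from the original post). It assumes the code above is importable as part of a package; for standalone use, the relative `.utils` import can be swapped for `from torch.hub import load_state_dict_from_url`:

```python
# Build the 1.0x model without pretrained weights and run a dummy batch
# through it to verify the output shape.
model = shufflenet_v2_x1_0(pretrained=False)
model.eval()

x = torch.randn(2, 3, 224, 224)   # a dummy batch of two 224x224 RGB images
with torch.no_grad():
    logits = model(x)
print(logits.shape)               # torch.Size([2, 1000])
```

The spatial resolution shrinks from 224 to 7 on the way through (conv1 and maxpool each halve it, and each stage's first unit halves it again), after which `x.mean([2, 3])` performs the global pooling before the final fully connected layer.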