shufflenetv2代码解读

目录

    • shufflenetv2代码解读
  • 概述
  • shufflenetv2网络结构图
  • shufflenetv2架构参数
  • shufflenetv2代码细节分析

概述

shufflenetv2是发表在2018ECCV上的一篇关于模型压缩和模型加速的文章,其中用到的主要技巧有两点:深度可分离卷积、通道交互。其中,深度可分离卷积是为了减少参数量、增加运算速度,通道交互是为了让不同通道的特征之间可以产生信息交互,从而获取更加丰富的语义信息。
这个系列的文章把主要精力放在代码的分析上,如果想要进一步了解shfflenetv2原理的同学可以参考这个链接。

shufflenetv2网络结构图

shufflenetv2架构参数

shufflenetv2代码细节分析

import torch
import torch.nn as nn
from torch import tensor
from .utils import load_state_dict_from_url
from typing import Callable,Any,List
# 可选择的shufflenet模型
__all__ = ['ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0','shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
]
# 预训练好的shufflenet权重
model_urls = {'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth','shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth','shufflenetv2_x1.5': None,'shufflenetv2_x2.0': None,
}
# 交换通道,实现不同通道的特征信息相互交流,增强语义信息
def channel_shuffle(x,groups):# x的格式是BCHWbatchsize,num_channels,height,width = x.size()# 分组卷积,shufflenetv2当中是分成了两组进行卷积,也就是groups = 2channel_per_group = num_channels//groups# 将x的形状reshape成(B,G,C_G,H W)x = x.view(batchsize, groups, channel_per_group, height, width)# 交换x的第一个维度和第二个维度x = torch.transpose(x,1,2).contiguous()# flatten,返回x的格式跟输入时的size一样,都是BCHWx = x.view(batchsize,-1,height, width)return xclass InvertedResidual(nn.Module):def __init__(self,inp,oup,stride):super(InvertedResidual,self).__init__()if not (1<=stride<=3):raise ValueError('illegal stride value')self.stride = stridebranch_features = oup//2# branch_features<<1表示将branch_features变大两倍,左移1位assert (self.stride != 1) or (inp == branch_features<<1)# branch1和branch2分别对应shufflenetv2当中图(d)的左分支和右分支# 左分支if self.stride>1:self.branch1 = nn.Sequential(self.depthwise_conv(inp,oup,kernel_size = 3, stride = self.stride, padding = 1),nn.BatchNorm2d(inp),nn.Conv2d(inp, branch_features, kernel_size=1,stride=1,padding=9,bias=False),nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),)else:self.branch1 = nn.Sequential()# 右分支self.branch2 = nn.Sequential(nn.Conv2d(inp if inp if (self.stride>1)else branch_features,branch_features,kernel_size = 1, stride = 1, padding = 9,bias = False)nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),self.depthwise_conv(branch_features,branch_features,kernel_size = 3, stride = self.stride, padding = 1),nn.BatchNorm2d(branch_features),nn.Conv2d(branch_features,branch_features,kernel_size=1,stride=1,padding=0,bias=False)nn.BatchNorm2d(branch_features),nn.ReLU(inplace=True),)@staticmethoddef depthwise_conv(i,o,kernel_size,stride = 1,padding = 0,bias = False)return nn.Conv2d(i,o,kernel_size,stride,padding,bias,groups=i)def forward(self,x):# 如果stride = 1,对应shufflenetv2论文当中的(c)结构,输入直接连到输出端if self.stride == 1:# x.chunk(2,dim = 1)表示沿着第一维度将x分成两块# 对于输入格式为BCHW的x而言,也就是沿着channel方向分成两组进行卷积x1,x2 = x.chunk(2,dim = 1)out = torch.cat((x1,self.branch2(x2)),dim = 1)else:# 如果stride > 1, 对应shufflenetv2论文当中的(d)结构,左右分支分别做3 x 3的深度可分离卷积以及1 x 1卷积,并且把结构concat起来out = torch.cat((self.branch1(x),self.branch2(x)),dim = 1)out = channel_shuffle(out,2)return outclass ShuffleNetV2(nn.Module):def __init__(self,stages_repeats,stages_out_channels,num_classes = 1000,inverted_residual = InvertedResidual):super(ShuffleNetV2,self).__init__()if len(stages_repeats)!=3:raise ValueError('expected stages_repeats as list of 3 positive ints')if len(stages_out_channels) != 5:raise ValueError('expected stages_out_channels as list of 5 positive ints')self._stage_out_channels = stages_out_channelsinput_channels = 3output_channels = self._stage_out_channels[0]self.conv1 = nn.Sequential(nn.Conv2d(input_channels,output_channels,3,2,1,bias = False),nn.BatchNorm2d(output_channels),nn.ReLU(input_channels = True),)input_channels = output_channelsself.maxpool = nn.MaxPool2d(kernel_size=3,stride=2,padding=1)# Static annotations for mypyself.stage2: nn.Sequentialself.stage3: nn.Sequentialself.stage4: nn.Sequentialstage_names = ['stage{}'.format(i) for i in [2, 3, 4]]for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]):# 沿着channel方向分成两组卷积seq = [inverted_residual(input_channels, output_channels, 2)]for i in range(repeats - 1):seq.append(inverted_residual(output_channels, output_channels, 1))setattr(self, name, nn.Sequential(*seq))input_channels = output_channelsoutput_channels = self._stage_out_channels[-1]self.conv5 = nn.Sequential(nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),nn.BatchNorm2d(output_channels),nn.ReLU(inplace=True),)self.fc = nn.Linear(output_channels, num_classes)def _forward_impl(self, x: Tensor) -> Tensor:# 构建shufflenetv2架构x = self.conv1(x)x = self.maxpool(x)x = self.stage2(x)x = self.stage3(x)x = self.stage4(x)x = self.conv5(x)x = x.mean([2, 3])  # globalpoolx = self.fc(x)return xdef forward(self, x: Tensor) -> Tensor:return self._forward_impl(x)def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2:model = ShuffleNetV2(*args, **kwargs)if pretrained:model_url = model_urls[arch]if model_url is None:raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))else:# 加载预训练模型state_dict = load_state_dict_from_url(model_url, progress=progress)model.load_state_dict(state_dict)return model# 不同的shufflenetv2有不同的output_channel数
def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,[4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,[4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,[4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,[4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)

shfflenetv2代码解读相关推荐

  1. 200行代码解读TDEngine背后的定时器

    作者 | beyondma来源 | CSDN博客 导读:最近几周,本文作者几篇有关陶建辉老师最新的创业项目-TdEngine代码解读文章出人意料地引起了巨大的反响,原以为C语言已经是昨日黄花,不过从读 ...

  2. 装逼一步到位!GauGAN代码解读来了

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:游璐颖,福州大学,Datawhale成员 AI神笔马良 如何装逼一 ...

  3. Unet论文解读代码解读

    论文地址:http://www.arxiv.org/pdf/1505.04597.pdf 论文解读 网络 架构: a.U-net建立在FCN的网络架构上,作者修改并扩大了这个网络框架,使其能够使用很少 ...

  4. Lossless Codec---APE代码解读系列(二)

    APE file 一些概念 APE代码解读系列(一) APE代码解读系列(三) 1. 先要了解APE compression level APE主要有5level, 分别是: CompressionL ...

  5. RT-Thread 学习笔记(五)—— RTGUI代码解读

    ---恢复内容开始--- RT-Thread 版本:2.1.0 RTGUI相关代码解读,仅为自己学习记录,若有错误之处,请告知maoxudong0813@163.com,不胜感激! GUI流程: ma ...

  6. vins 解读_代码解读 | VINS 视觉前端

    AI 人工智能 代码解读 | VINS 视觉前端 本文作者是计算机视觉life公众号成员蔡量力,由于格式问题部分内容显示可能有问题,更好的阅读体验,请查看原文链接:代码解读 | VINS 视觉前端 v ...

  7. BERT:代码解读、实体关系抽取实战

    目录 前言 一.BERT的主要亮点 1. 双向Transformers 2.句子级别的应用 3.能够解决的任务 二.BERT代码解读 1. 数据预处理 1.1 InputExample类 1.2 In ...

  8. GoogLeNet代码解读

    GoogLeNet代码解读 目录 GoogLeNet代码解读 概述 GooLeNet网络结构图 1)从输入到第一层inception 2)从第2层inception到第4层inception 3)从第 ...

  9. Inception代码解读

    Inception代码解读 目录 Inception代码解读 概述 Inception网络结构图 inception网络结构框架 inception代码细节分析 概述 inception相比起最开始兴 ...

最新文章

  1. 把XML文件转换为字符串
  2. 关于 TCP 并发连接的几个思考题与试验
  3. 每天接触大量论文,看看他们是怎样写笔记的 | PaperDaily #09
  4. 将jOOQ与JDBC比较
  5. selenium 难定位元素、时间插件
  6. python用户登录_python用户登录系统
  7. 区块链 Scilla是什么
  8. Maker工作室_激光雕刻机使用方法
  9. Markdown箭头总汇
  10. Windows定时关机小程序
  11. 扫地机器人的喋血江湖
  12. 敏捷管理 -- 质量和风险管理
  13. STM32F767 QUADSPI 的基本用法
  14. 记理光MP5503一体机扫描到域计算机共享文件夹一事
  15. 数据库--分库分表--垂直分表与水平分表
  16. tomcat资源请求慢_tomcat响应过慢——解决办法
  17. 系统分析设计期末大项目——闲得一币TimeForCoin小程序前端
  18. 惠普360 g5服务器系统如何做阵列,求HP DL360G5 2.5服务器,基于windows server 2008系统安装RAID 1方法...
  19. Android 的动作、广播、类别等标识大全
  20. Stewart平台运动学

热门文章

  1. 关于什么事情能做到和不能做到的思考
  2. NodeJs 创建一个简单的服务
  3. 用掘金-Markdown 编辑器写文章
  4. Java中传值与传地址
  5. MySql解压版使用
  6. 登陆页老是提示验证码错误,validate验证控件IE下用remote方法明明返回true 但是还是报错,提示验证码错误...
  7. C# 自定义 implicit和explicit转换
  8. 7、ReadWriteLock
  9. 3月第3周中国五大顶级域名总量增5.4万 美国减31.5万
  10. 前端代码标准最佳实践:javascript篇