0.介绍

Squeezenet网址
torchvision.model.squeeze官方文档
主要思想:堆叠Fire模块,每个Fire模块,分别采用1x1和3x3两个分支,最后做拼;,每个Fire的尺寸不变,channel数不变或增加;每个stage的Fire模块之间用nn.MaxPool2d进行下采样;使用卷积层代替FC层,channel数为类别数

1.源码

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.hub import load_state_dict_from_url__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
model_urls = {'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth','squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
}class Fire(nn.Module): #Fire模块def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):super(Fire, self).__init__()self.inplanes = inplanesself.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)self.squeeze_activation = nn.ReLU(inplace=True)self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)self.expand1x1_activation = nn.ReLU(inplace=True)self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)self.expand3x3_activation = nn.ReLU(inplace=True)def forward(self, x):x = self.squeeze(x)x = self.squeeze_activation(x)return torch.cat([self.expand1x1_activation(self.expand1x1(x)),self.expand3x3_activation(self.expand3x3(x))], 1)class SqueezeNet(nn.Module):def __init__(self, version='1.0', num_classes=1000):super(SqueezeNet, self).__init__()self.num_classes = num_classesif version == '1_0':self.features = nn.Sequential(nn.Conv2d(3, 96, kernel_size=7, stride=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(96, 16, 64, 64),Fire(128, 16, 64, 64),Fire(128, 32, 128, 128),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(256, 32, 128, 128),Fire(256, 48, 192, 192),Fire(384, 48, 192, 192),Fire(384, 64, 256, 256),nn.MaxPool2d(kernel_size=3, stride=2 ,ceil_mode=True),Fire(512, 64, 256, 256),)elif version == '1_1':self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(64, 16, 64, 64),Fire(128, 16, 64, 64),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(128, 32, 128, 128),Fire(256, 32, 128, 128),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(256, 48, 192, 192),Fire(384, 48, 192, 192),Fire(384, 64, 256, 256),Fire(512, 64, 256, 256),)else:raise ValueError("Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected".format(version=version))#使用卷积代替全连接层final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)self.classifier = nn.Sequential(nn.Dropout(0.5),final_conv,nn.ReLU(inplace=True),nn.AdaptiveAvgPool2d((1,1)))for m in self.modules():if isinstance(m, nn.Conv2d):if m is final_conv:init.normal_(m.weight, mean=0.0, std=0.01)else:init.kaiming_uniform_(m.weight)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):x = self.features(x)x = self.classifier(x)return x.view(x.size(0), self.num_classes)def _squeezenet(version, pretrained, progress, **kwargs):model = SqueezeNet(version, **kwargs)if pretrained:arch = 'squeezenet' + versionstate_dict = load_state_dict_from_url(model_urls[arch],progress=progress)model.load_state_dict(state_dict)return modeldef squeezenet1_0(pretrained=False, progress=True, **kwargs):return _squeezenet('1_0', pretrained, progress, **kwargs)def squeezenet1_1(pretrained=False, progress=True, **kwargs):return _squeezenet('1_1', pretrained, progress, **kwargs)

2.一些用法

2.1 torch.cat

torch.cat([self.expand1x1_activation(self.expand1x1(x)),self.expand3x3_activation(self.expand3x3(x))], 1)
#按照第一个维度(channel维度)对[]内的Tensor进行拼接

2.2 nn.MaxPool2d()

nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
class torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1,return_indices=False, ceil_mode=False)
# kernel_size(int or tuple) - max pooling的窗口大小
#stride(int or tuple, optional) - max pooling的窗口移动的步长。默认值是kernel_size
#padding(int or tuple, optional) - 输入的每一条边补充0的层数
#dilation(int or tuple, optional) – 一个控制窗口中元素步幅的参数
#return_indices - 如果等于True,会返回输出最大值的序号,对于上采样操作会有帮助
#ceil_mode - 如果等于True,计算输出信号大小的时候,会使用向上取整,代替默认的向下取整的操作

2.3 使用全卷积代替全连接层

#使用全卷积代替FC层
self.classifier = nn.Sequential(nn.Dropout(0.5),final_conv,nn.ReLU(inplace=True),nn.AdaptiveAvgPool2d((1,1)))
def forward(self, x):x = self.features(x)x = self.classifier(x)return x.view(x.size(0), self.num_classes)
#即先采用AdaptiveAvgPool2D,将size变为1x1,channel数=num_classes,再做resize

Pytorch源码学习之四:torchvision.models.squeezenet相关推荐

  1. PyTorch源码解读之torchvision.models

    PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets.torchvision.models.torchvisi ...

  2. PyTorch源码学习系列 - 1.初识

    本系列文章会优先发布于微信公众号和知乎,欢迎大家关注 微信公众号:小飞怪兽屋 知乎: PyTorch源码学习系列 - 1.初识 - 知乎 (zhihu.com) 目录 本系列的目的 PyTorch是什 ...

  3. OpenFire源码学习之四:openfire的启动流程

    openfire启动 ServerStarter 启动流程图: 启动的总入口在ServerStarter的main方法中.通过上图首先它会先加载它所需要的jar文件.最后通过java反射机制将XMPP ...

  4. Pytorch ResNet源码学习

    一,残差网络架构 1,残差学习单元 上图左对应的是浅层网络(18层,34层),而右图对应的是深层网络(50,101,152). 1. 左图为基本的residual block,residual map ...

  5. 基于Pytorch源码对SGD、momentum、Nesterov学习

    目前神经网络的监督学习过程通常为: 数据加载(load)进神经网络 经过网络参数对数据的计算,得出预测值(predict) 根据预测值与标注值(label)之间的差距,产生损失(loss) 通过反向传 ...

  6. 小白学习pytorch源码(二):setup.py最详细解读

    小白学习pytorch源码(二) pytorch setup.py最全解析 setup.py与setuptools setup.py最详细解读 setup.py 环境检查 setup.py setup ...

  7. PyTorch 源码解读之 torch.serialization torch.hub

    作者 | 123456 来源 | OpenMMLab 编辑 | 极市平台 导读 本文解读基于PyTorch 1.7版本,对torch.serialization.torch.save和torch.hu ...

  8. Transformer-XL解读(论文 + PyTorch源码)

    前言 目前在NLP领域中,处理语言建模问题有两种最先进的架构:RNN和Transformer.RNN按照序列顺序逐个学习输入的单词或字符之间的关系,而Transformer则接收一整段序列,然后使用s ...

  9. pytorch 测试每一类_DeepFM全方面解析(附pytorch源码)

    写在前面 最近看了DeepFM这个模型.把我学习的思路和总结放上来给大家和未来的自己做个参考和借鉴.文章主要希望能串起学习DeepFM的各个环节,梳理整个学习思路.以"我"的角度浅 ...

最新文章

  1. 项目管理过程中应注意的问题
  2. 分享5个免费的在线 SQL 数据库环境,简直太方便了!
  3. node 使用 download-git-repo 下载 github 代码
  4. 必须要改变这样的生活
  5. NIPS 2016 | Best Paper, Dual Learning, Review Network, VQA 等论文选介
  6. 1 第一次画PCB总结
  7. 对渠道流量异常情况的分析
  8. Wordcounter,使用Lambdas和Fork / Join计算Java中的单词数
  9. [导入]C#好书盘点【月儿原创】
  10. python和arduino串口通信_利用串行通信实现python与arduino的同步
  11. 全流分析取证:高级威胁哪里跑?!
  12. 期望E==>加权均值(每个元素×它们各自的概率)
  13. (48)VHDL实现8位奇偶校验电路(process语句语句)
  14. 【转载】如何把Mysql5.5数据库的数据导入到MSSql 数据库中【mysql-connector-odbc-3.51.28-win32】...
  15. 学习V神的手把手教你写脚本引擎 一
  16. 64qam星座图matlab,基于MATLAB的QAM 眼图和星座图
  17. python 小世界网络
  18. office二级笔记
  19. CSS实现导航条图片的翻转菜单
  20. Oracle默认内置账户介绍,SYS与SYSTEM两个账户的区别

热门文章

  1. 回溯_leetcode.17.电话号码的字母组合
  2. 前端 JavaScript 原型和原型链
  3. 你可知vivo手机5大黑科技?如果连这都不知道的话,那可太浪费了
  4. 博士申请 | 美国佛罗里达州立大学计算机系王广老师招收人工智能全奖博士生...
  5. SQLServer2008服务无法启动
  6. 线性规划两阶段求解方法
  7. 盲盒交友变现系统/脱单盲盒/一元交友/存取小纸条盲盒/分销功能/盲盒交友小程序
  8. 剪辑过程中的思考与总结(持续更新ing)
  9. 锐龙intel服务器性能,CPU选购 详解锐龙和英特尔性能对比
  10. 小米10至尊纪念版与OPPO Find X2 Pro哪个好