Pytorch源码学习之四:torchvision.models.squeezenet
0.介绍
Squeezenet网址
torchvision.model.squeeze官方文档
主要思想:堆叠Fire模块,每个Fire模块,分别采用1x1和3x3两个分支,最后做拼;,每个Fire的尺寸不变,channel数不变或增加;每个stage的Fire模块之间用nn.MaxPool2d进行下采样;使用卷积层代替FC层,channel数为类别数
1.源码
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.hub import load_state_dict_from_url__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']
model_urls = {'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth','squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
}class Fire(nn.Module): #Fire模块def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):super(Fire, self).__init__()self.inplanes = inplanesself.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)self.squeeze_activation = nn.ReLU(inplace=True)self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)self.expand1x1_activation = nn.ReLU(inplace=True)self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)self.expand3x3_activation = nn.ReLU(inplace=True)def forward(self, x):x = self.squeeze(x)x = self.squeeze_activation(x)return torch.cat([self.expand1x1_activation(self.expand1x1(x)),self.expand3x3_activation(self.expand3x3(x))], 1)class SqueezeNet(nn.Module):def __init__(self, version='1.0', num_classes=1000):super(SqueezeNet, self).__init__()self.num_classes = num_classesif version == '1_0':self.features = nn.Sequential(nn.Conv2d(3, 96, kernel_size=7, stride=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(96, 16, 64, 64),Fire(128, 16, 64, 64),Fire(128, 32, 128, 128),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(256, 32, 128, 128),Fire(256, 48, 192, 192),Fire(384, 48, 192, 192),Fire(384, 64, 256, 256),nn.MaxPool2d(kernel_size=3, stride=2 ,ceil_mode=True),Fire(512, 64, 256, 256),)elif version == '1_1':self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, stride=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(64, 16, 64, 64),Fire(128, 16, 64, 64),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(128, 32, 128, 128),Fire(256, 32, 128, 128),nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),Fire(256, 48, 192, 192),Fire(384, 48, 192, 192),Fire(384, 64, 256, 256),Fire(512, 64, 256, 256),)else:raise ValueError("Unsupported SqueezeNet version {version}: 1_0 or 1_1 expected".format(version=version))#使用卷积代替全连接层final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)self.classifier = nn.Sequential(nn.Dropout(0.5),final_conv,nn.ReLU(inplace=True),nn.AdaptiveAvgPool2d((1,1)))for m in self.modules():if isinstance(m, nn.Conv2d):if m is final_conv:init.normal_(m.weight, mean=0.0, std=0.01)else:init.kaiming_uniform_(m.weight)if m.bias is not None:init.constant_(m.bias, 0)def forward(self, x):x = self.features(x)x = self.classifier(x)return x.view(x.size(0), self.num_classes)def _squeezenet(version, pretrained, progress, **kwargs):model = SqueezeNet(version, **kwargs)if pretrained:arch = 'squeezenet' + versionstate_dict = load_state_dict_from_url(model_urls[arch],progress=progress)model.load_state_dict(state_dict)return modeldef squeezenet1_0(pretrained=False, progress=True, **kwargs):return _squeezenet('1_0', pretrained, progress, **kwargs)def squeezenet1_1(pretrained=False, progress=True, **kwargs):return _squeezenet('1_1', pretrained, progress, **kwargs)
2.一些用法
2.1 torch.cat
torch.cat([self.expand1x1_activation(self.expand1x1(x)),self.expand3x3_activation(self.expand3x3(x))], 1)
#按照第一个维度(channel维度)对[]内的Tensor进行拼接
2.2 nn.MaxPool2d()
nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
class torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1,return_indices=False, ceil_mode=False)
# kernel_size(int or tuple) - max pooling的窗口大小
#stride(int or tuple, optional) - max pooling的窗口移动的步长。默认值是kernel_size
#padding(int or tuple, optional) - 输入的每一条边补充0的层数
#dilation(int or tuple, optional) – 一个控制窗口中元素步幅的参数
#return_indices - 如果等于True,会返回输出最大值的序号,对于上采样操作会有帮助
#ceil_mode - 如果等于True,计算输出信号大小的时候,会使用向上取整,代替默认的向下取整的操作
2.3 使用全卷积代替全连接层
#使用全卷积代替FC层
self.classifier = nn.Sequential(nn.Dropout(0.5),final_conv,nn.ReLU(inplace=True),nn.AdaptiveAvgPool2d((1,1)))
def forward(self, x):x = self.features(x)x = self.classifier(x)return x.view(x.size(0), self.num_classes)
#即先采用AdaptiveAvgPool2D,将size变为1x1,channel数=num_classes,再做resize
Pytorch源码学习之四:torchvision.models.squeezenet相关推荐
- PyTorch源码解读之torchvision.models
PyTorch框架中有一个非常重要且好用的包:torchvision,该包主要由3个子包组成,分别是:torchvision.datasets.torchvision.models.torchvisi ...
- PyTorch源码学习系列 - 1.初识
本系列文章会优先发布于微信公众号和知乎,欢迎大家关注 微信公众号:小飞怪兽屋 知乎: PyTorch源码学习系列 - 1.初识 - 知乎 (zhihu.com) 目录 本系列的目的 PyTorch是什 ...
- OpenFire源码学习之四:openfire的启动流程
openfire启动 ServerStarter 启动流程图: 启动的总入口在ServerStarter的main方法中.通过上图首先它会先加载它所需要的jar文件.最后通过java反射机制将XMPP ...
- Pytorch ResNet源码学习
一,残差网络架构 1,残差学习单元 上图左对应的是浅层网络(18层,34层),而右图对应的是深层网络(50,101,152). 1. 左图为基本的residual block,residual map ...
- 基于Pytorch源码对SGD、momentum、Nesterov学习
目前神经网络的监督学习过程通常为: 数据加载(load)进神经网络 经过网络参数对数据的计算,得出预测值(predict) 根据预测值与标注值(label)之间的差距,产生损失(loss) 通过反向传 ...
- 小白学习pytorch源码(二):setup.py最详细解读
小白学习pytorch源码(二) pytorch setup.py最全解析 setup.py与setuptools setup.py最详细解读 setup.py 环境检查 setup.py setup ...
- PyTorch 源码解读之 torch.serialization torch.hub
作者 | 123456 来源 | OpenMMLab 编辑 | 极市平台 导读 本文解读基于PyTorch 1.7版本,对torch.serialization.torch.save和torch.hu ...
- Transformer-XL解读(论文 + PyTorch源码)
前言 目前在NLP领域中,处理语言建模问题有两种最先进的架构:RNN和Transformer.RNN按照序列顺序逐个学习输入的单词或字符之间的关系,而Transformer则接收一整段序列,然后使用s ...
- pytorch 测试每一类_DeepFM全方面解析(附pytorch源码)
写在前面 最近看了DeepFM这个模型.把我学习的思路和总结放上来给大家和未来的自己做个参考和借鉴.文章主要希望能串起学习DeepFM的各个环节,梳理整个学习思路.以"我"的角度浅 ...
最新文章
- 项目管理过程中应注意的问题
- 分享5个免费的在线 SQL 数据库环境,简直太方便了!
- node 使用 download-git-repo 下载 github 代码
- 必须要改变这样的生活
- NIPS 2016 | Best Paper, Dual Learning, Review Network, VQA 等论文选介
- 1 第一次画PCB总结
- 对渠道流量异常情况的分析
- Wordcounter,使用Lambdas和Fork / Join计算Java中的单词数
- [导入]C#好书盘点【月儿原创】
- python和arduino串口通信_利用串行通信实现python与arduino的同步
- 全流分析取证:高级威胁哪里跑?!
- 期望E==>加权均值(每个元素×它们各自的概率)
- (48)VHDL实现8位奇偶校验电路(process语句语句)
- 【转载】如何把Mysql5.5数据库的数据导入到MSSql 数据库中【mysql-connector-odbc-3.51.28-win32】...
- 学习V神的手把手教你写脚本引擎 一
- 64qam星座图matlab,基于MATLAB的QAM 眼图和星座图
- python 小世界网络
- office二级笔记
- CSS实现导航条图片的翻转菜单
- Oracle默认内置账户介绍,SYS与SYSTEM两个账户的区别
热门文章
- 回溯_leetcode.17.电话号码的字母组合
- 前端 JavaScript 原型和原型链
- 你可知vivo手机5大黑科技?如果连这都不知道的话,那可太浪费了
- 博士申请 | 美国佛罗里达州立大学计算机系王广老师招收人工智能全奖博士生...
- SQLServer2008服务无法启动
- 线性规划两阶段求解方法
- 盲盒交友变现系统/脱单盲盒/一元交友/存取小纸条盲盒/分销功能/盲盒交友小程序
- 剪辑过程中的思考与总结(持续更新ing)
- 锐龙intel服务器性能,CPU选购 详解锐龙和英特尔性能对比
- 小米10至尊纪念版与OPPO Find X2 Pro哪个好