Inception Code Walkthrough
Table of Contents
- Inception Code Walkthrough
- Overview
- Inception Network Structure Diagrams
- Inception Network Architecture
- Detailed Code Analysis
Overview
Compared with the earlier AlexNet and VGG, Inception makes the following key changes:
1) It replaces the purely "straight-through" network topology: the work of one large convolution kernel is split across several smaller kernels running in parallel;
2) a further benefit of this design is that features at several different scales are extracted and then fused, so the resulting features carry richer semantic information;
3) it introduces 1x1 convolution kernels, which offer a cheap way to change the number of channels, so that the feature maps from different branches can be concatenated after their channel counts are adjusted.
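The role of the 1x1 convolutions in point 3 can be seen in a few lines of PyTorch. This is a minimal sketch; the channel counts are made up for illustration:

```python
import torch
from torch import nn

x = torch.randn(1, 192, 35, 35)                 # input feature map: N x C x H x W
reduce_a = nn.Conv2d(192, 64, kernel_size=1)    # 1x1 conv: 192 -> 64 channels
reduce_b = nn.Conv2d(192, 32, kernel_size=1)    # 1x1 conv: 192 -> 32 channels

branch_a = reduce_a(x)                          # 1 x 64 x 35 x 35
branch_b = reduce_b(x)                          # 1 x 32 x 35 x 35
# 1x1 convs leave the spatial size unchanged, so the branch outputs can be
# concatenated along dim=1 (the channel dimension):
fused = torch.cat([branch_a, branch_b], dim=1)
print(fused.shape)                              # torch.Size([1, 96, 35, 35])
```

This channel-wise concatenation is exactly what every Inception module's `forward` does with its branch outputs.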
Inception Network Structure Diagrams
1) The naive version of Inception v1
2) The Inception v1 variant that adds 1x1 convolutions to change the channel count
3) The different module types introduced in Inception v2:
a) replacing a 5x5 kernel with two stacked 3x3 kernels
b) factorizing an n x n convolution into a cascade of 1 x n and n x 1 convolutions
c) the "widened" Inception module
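Point (a) above can be checked directly: two stacked 3x3 convolutions cover the same 5x5 receptive field as a single 5x5 convolution while using fewer weights per channel pair (2 x 3 x 3 = 18 vs. 5 x 5 = 25). A minimal sketch with assumed channel sizes:

```python
import torch
from torch import nn

x = torch.randn(1, 64, 17, 17)

one_5x5 = nn.Conv2d(64, 64, kernel_size=5)
two_3x3 = nn.Sequential(
    nn.Conv2d(64, 64, kernel_size=3),
    nn.Conv2d(64, 64, kernel_size=3),
)

# Without padding, both shrink a 17x17 map to 13x13: each output pixel
# "sees" the same 5x5 input window in both cases.
print(one_5x5(x).shape)  # torch.Size([1, 64, 13, 13])
print(two_3x3(x).shape)  # torch.Size([1, 64, 13, 13])
```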
Inception Network Architecture
Detailed Code Analysis
from collections import namedtuple
import warnings
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url  # used by inception_v3 below
from typing import Callable, Any, Optional, Tuple, List
from torchsummary import summary
__all__ = ['Inception3','inception_v3','InceptionOutputs','_InceptionOutputs']
# URL of the pretrained Inception model weights
model_urls = {'inception_v3_google':'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',}
InceptionOutputs = namedtuple('InceptionOutputs',['logits','aux_logits'])
InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}
_InceptionOutputs = InceptionOutputs


def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' in kwargs:
            original_aux_logits = kwargs['aux_logits']
            kwargs['aux_logits'] = True
        else:
            original_aux_logits = True
        # A pretrained model is loaded, so init_weights is set to False
        kwargs['init_weights'] = False
        model = Inception3(**kwargs)
        state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
                                              progress=progress)
        model.load_state_dict(state_dict)
        if not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None
        return model
    return Inception3(**kwargs)


class Inception3(nn.Module):

    def __init__(
        self,
        num_classes: int = 1000,
        aux_logits: bool = True,
        transform_input: bool = False,
        inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
        init_weights: Optional[bool] = None,
    ) -> None:
        super(Inception3, self).__init__()
        # The different block types used to assemble the network
        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC,
                                InceptionD, InceptionE, InceptionAux]
        if init_weights is None:
            warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
            init_weights = True
        assert len(inception_blocks) == 7
        conv_block = inception_blocks[0]
        inception_a = inception_blocks[1]
        inception_b = inception_blocks[2]
        inception_c = inception_blocks[3]
        inception_d = inception_blocks[4]
        inception_e = inception_blocks[5]
        inception_aux = inception_blocks[6]
        self.aux_logits = aux_logits
        self.transform_input = transform_input
        # Stem: plain convolutions and pooling before the Inception modules;
        # the different Inception module types then use different kernel mixes
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout()
        # Final classifier
        self.fc = nn.Linear(2048, num_classes)
        # Layer-type-specific weight initialization
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    import scipy.stats as stats
                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                    X = stats.truncnorm(-2, 2, scale=stddev)
                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                    values = values.view(m.weight.size())
                    with torch.no_grad():
                        m.weight.copy_(values)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        # Re-normalize each channel from the ImageNet statistics to the
        # mean=0.5, std=0.5 scaling the original model was trained with
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux

    # @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]):
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x

    def forward(self, x: Tensor):
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)
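The arithmetic inside `_transform_input` is easy to verify: it converts a channel normalized with the ImageNet statistics (mean 0.485, std 0.229 for red) into the mean=0.5, std=0.5 scaling that the original Inception model was trained with. A sketch for the red channel:

```python
import torch

img = torch.rand(1, 1, 299, 299)                   # raw red-channel pixels in [0, 1]
x = (img - 0.485) / 0.229                          # ImageNet-normalized input
x_ch0 = x * (0.229 / 0.5) + (0.485 - 0.5) / 0.5    # the transform from _transform_input
expected = (img - 0.5) / 0.5                       # mean=0.5, std=0.5 normalization
print(torch.allclose(x_ch0, expected, atol=1e-6))  # True
```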
class InceptionA(nn.Module):

    def __init__(self, in_channels: int, pool_features: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionA, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # Build the four parallel branches of the InceptionA module
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        # The branch outputs at different scales are concatenated in forward(),
        # equivalent to torch.cat((branch1x1, branch5x5, branch3x3dbl, branch_pool), dim=1)
        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)
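A quick sanity check on InceptionA's output width: the concatenation yields 64 + 64 + 96 + pool_features channels, which is exactly where the 256 and 288 channel counts seen after the Mixed_5 blocks come from. The helper name below is made up for illustration:

```python
# Channels contributed by InceptionA's four branches: 1x1, 5x5, double-3x3, pool
def inception_a_out_channels(pool_features: int) -> int:
    return 64 + 64 + 96 + pool_features

print(inception_a_out_channels(32))  # 256 -> Mixed_5b output channels
print(inception_a_out_channels(64))  # 288 -> Mixed_5c / Mixed_5d output channels
```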
class InceptionB(nn.Module):

    def __init__(self, in_channels: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionB, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # Build the three parallel branches of the InceptionB module
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):

    def __init__(self, in_channels: int, channels_7x7: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionC, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # Build the four parallel branches of the InceptionC module; 7x7 convolutions
        # are factorized into 1x7 and 7x1 cascades
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):

    def __init__(self, in_channels: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionD, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # Build the three parallel branches of the InceptionD module
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionE(nn.Module):

    def __init__(self, in_channels: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionE, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # Build the four parallel branches of the InceptionE module; the 3x3 branches
        # are themselves split into parallel 1x3 and 3x1 halves (the "widened" structure)
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


# Inception's auxiliary (side-branch) classifier module
class InceptionAux(nn.Module):

    def __init__(self, in_channels: int, num_classes: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionAux, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01   # read by the truncated-normal init in Inception3
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x: Tensor) -> Tensor:
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x
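The first pooling step of the auxiliary head follows the usual output-size formula floor((17 - 5) / 3) + 1 = 5, which is where the 5x5 map in the shape comments comes from. A minimal check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 768, 17, 17)
pooled = F.avg_pool2d(x, kernel_size=5, stride=3)
print(pooled.shape)  # torch.Size([1, 768, 5, 5])
```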
class BasicConv2d(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super(BasicConv2d, self).__init__()
        # bias is disabled because the following BatchNorm has its own shift term
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)
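To see the conv-BN-ReLU pattern in action, here is a self-contained stand-in for `BasicConv2d` (the name `ConvBNReLU` is made up for this sketch) applied with the same arguments as the stem's `Conv2d_1a_3x3`:

```python
import torch
from torch import nn
import torch.nn.functional as F

# A minimal stand-in for BasicConv2d: conv -> BatchNorm -> ReLU
class ConvBNReLU(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_ch, eps=0.001)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)

block = ConvBNReLU(3, 32, kernel_size=3, stride=2)
out = block(torch.randn(2, 3, 299, 299))
# (299 - 3) // 2 + 1 = 149, matching the first shape comment in _forward
print(out.shape)  # torch.Size([2, 32, 149, 149])
```

Because every branch in every Inception module is built from this block, the whole network is batch-normalized and ReLU-activated throughout.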