Inception Code Walkthrough

Contents

  • Overview
  • Inception network architecture diagrams
  • Inception network framework
  • Inception code analysis

Overview

Compared with AlexNet and VGG, which rose to prominence earlier, Inception makes the following important changes:
1) It moves away from the "straight-through" network topology, splitting the work of one large convolution kernel across several smaller kernels;
2) A further benefit of this is that features are obtained at different scales and then fused, so the extracted features carry richer semantic information;
3) It introduces the 1x1 convolution, a convenient way to change the number of channels, so that the feature maps produced at different scales can be concatenated together once their channel counts are transformed (a minimal sketch follows below).
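To make points 2) and 3) concrete, here is a minimal two-branch sketch (illustrative only, not taken from the Inception source): each branch sees a different receptive field, and 1x1 convolutions set the channel counts so that the branch outputs can be concatenated along the channel dimension.

import torch
from torch import nn

class TwoBranchBlock(nn.Module):
    # hypothetical mini-block: one 1x1 branch, one 1x1 -> 3x3 branch
    def __init__(self, in_channels):
        super().__init__()
        self.branch1x1 = nn.Conv2d(in_channels, 32, kernel_size=1)
        self.branch3x3 = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=1),     # cheap channel reduction
            nn.Conv2d(16, 32, kernel_size=3, padding=1),   # larger receptive field, same spatial size
        )

    def forward(self, x):
        # identical spatial sizes make channel-wise concatenation valid
        return torch.cat([self.branch1x1(x), self.branch3x3(x)], dim=1)  # 32 + 32 = 64 channels

y = TwoBranchBlock(64)(torch.randn(1, 64, 35, 35))
print(y.shape)  # torch.Size([1, 64, 35, 35])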

Inception network architecture diagrams

1) The naive version of Inception v1

2) Inception v1 with an added 1x1 convolution to change the channel count

3) The different block types of Inception v2
a) Two 3x3 kernels replacing one 5x5 kernel (see the parameter-count check after this list)

b) An n x n convolution factorized into a cascade of n x 1 and 1 x n convolutions (with 1 x 1 convolutions at the branch entries)

c) The "widened" Inception structure
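To see why replacing one 5x5 kernel with two stacked 3x3 kernels (structure a) pays off, compare parameter counts; both arrangements cover the same 5x5 receptive field. The channel count C below is only illustrative:

C = 192
params_5x5 = 5 * 5 * C * C               # 921,600 weights (bias ignored)
params_two_3x3 = 2 * 3 * 3 * C * C       # 663,552 weights
print(1 - params_two_3x3 / params_5x5)   # 0.28 -> about 28% fewer parameters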

Inception network framework

Inception code analysis

The code below is essentially torchvision's Inception v3 implementation, annotated step by step.

from collections import namedtuple
import warnings
import torch
from torch import nn, Tensor
import torch.nn.functional as F
# stand-in for torchvision's internal `from .utils import load_state_dict_from_url`
from torch.hub import load_state_dict_from_url
from typing import Callable, Any, Optional, Tuple, List
from torchsummary import summary

__all__ = ['Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']

# weights of the pretrained Inception model
model_urls = {'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'}

InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': Tensor, 'aux_logits': Optional[Tensor]}

# alias kept for backward compatibility
_InceptionOutputs = InceptionOutputs


def inception_v3(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> "Inception3":
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' in kwargs:
            original_aux_logits = kwargs['aux_logits']
            kwargs['aux_logits'] = True
        else:
            original_aux_logits = True
        # a pretrained model is loaded, so skip the (slow) weight initialization
        kwargs['init_weights'] = False
        model = Inception3(**kwargs)
        state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
                                              progress=progress)
        model.load_state_dict(state_dict)
        if not original_aux_logits:
            model.aux_logits = False
            model.AuxLogits = None
        return model
    return Inception3(**kwargs)


class Inception3(nn.Module):

    def __init__(self,
                 num_classes: int = 1000,
                 aux_logits: bool = True,
                 transform_input: bool = False,
                 inception_blocks: Optional[List[Callable[..., nn.Module]]] = None,
                 init_weights: Optional[bool] = None) -> None:
        super(Inception3, self).__init__()
        # the seven block types the network is assembled from
        if inception_blocks is None:
            inception_blocks = [BasicConv2d, InceptionA, InceptionB, InceptionC,
                                InceptionD, InceptionE, InceptionAux]
        if init_weights is None:
            warnings.warn('The default weight initialization of inception_v3 will be changed in future releases of '
                          'torchvision. If you wish to keep the old behavior (which leads to long initialization times'
                          ' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)
            init_weights = True
        assert len(inception_blocks) == 7
        conv_block = inception_blocks[0]
        inception_a = inception_blocks[1]
        inception_b = inception_blocks[2]
        inception_c = inception_blocks[3]
        inception_d = inception_blocks[4]
        inception_e = inception_blocks[5]
        inception_aux = inception_blocks[6]
        self.aux_logits = aux_logits
        self.transform_input = transform_input
        # the stem: plain convolutions and max pooling with varying kernel sizes
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.Mixed_5b = inception_a(192, pool_features=32)
        self.Mixed_5c = inception_a(256, pool_features=64)
        self.Mixed_5d = inception_a(288, pool_features=64)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128)
        self.Mixed_6c = inception_c(768, channels_7x7=160)
        self.Mixed_6d = inception_c(768, channels_7x7=160)
        self.Mixed_6e = inception_c(768, channels_7x7=192)
        # the auxiliary classifier branches off the 17 x 17 feature map
        self.AuxLogits: Optional[nn.Module] = None
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280)
        self.Mixed_7c = inception_e(2048)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout()
        # the classifier head
        self.fc = nn.Linear(2048, num_classes)
        # per-layer weight initialization
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    import scipy.stats as stats
                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                    X = stats.truncnorm(-2, 2, scale=stddev)
                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                    values = values.view(m.weight.size())
                    with torch.no_grad():
                        m.weight.copy_(values)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x: Tensor) -> Tensor:
        # re-normalize each input channel from ImageNet statistics to the
        # range expected by the pretrained Google weights
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = self.maxpool1(x)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = self.maxpool2(x)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux: Optional[Tensor] = None
        if self.AuxLogits is not None:
            if self.training:
                aux = self.AuxLogits(x)
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # adaptive average pooling
        x = self.avgpool(x)
        # N x 2048 x 1 x 1
        x = self.dropout(x)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x: Tensor, aux: Optional[Tensor]) -> InceptionOutputs:
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x

    def forward(self, x: Tensor) -> InceptionOutputs:
        x = self._transform_input(x)
        x, aux = self._forward(x)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)
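Note the asymmetry in forward: in training mode with aux_logits enabled, the model returns the InceptionOutputs namedtuple, while in eval mode it returns a plain logits tensor. A quick check, assuming the class as reconstructed above (init_weights=True needs scipy installed):

model = Inception3(num_classes=1000, aux_logits=True, init_weights=True)

model.train()
out = model(torch.randn(2, 3, 299, 299))
print(type(out).__name__)   # InceptionOutputs

model.eval()
out = model(torch.randn(2, 3, 299, 299))
print(out.shape)            # torch.Size([2, 1000]) -- plain tensor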
class InceptionA(nn.Module):

    def __init__(self, in_channels: int, pool_features: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionA, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # build the four parallel branches of the InceptionA block
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        # concatenate the outputs of the different scales; equivalently:
        # torch.cat((branch1x1, branch5x5, branch3x3dbl, branch_pool), dim=1)
        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)
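The concatenated output of InceptionA has 64 + 64 + 96 + pool_features channels, which explains the in_channels values passed in Inception3 above: Mixed_5b(192, pool_features=32) produces 64 + 64 + 96 + 32 = 256 channels, so Mixed_5c takes 256; Mixed_5c in turn produces 64 + 64 + 96 + 64 = 288 channels for Mixed_5d. A quick sanity check:

block = InceptionA(in_channels=192, pool_features=32)
y = block(torch.randn(1, 192, 35, 35))
print(y.shape)  # torch.Size([1, 256, 35, 35])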
class InceptionB(nn.Module):

    def __init__(self, in_channels: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionB, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # build the three parallel branches of the InceptionB block
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):

    def __init__(self, in_channels: int, channels_7x7: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionC, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # build the four parallel branches of the InceptionC block
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):

    def __init__(self, in_channels: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionD, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # build the three parallel branches of the InceptionD block
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionE(nn.Module):

    def __init__(self, in_channels: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionE, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)

    def _forward(self, x: Tensor) -> List[Tensor]:
        # build the four parallel branches of the InceptionE block; the 3x3
        # branches are themselves "widened" into parallel 1x3 and 3x1 halves
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x: Tensor) -> Tensor:
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


# the auxiliary classifier side branch of Inception
class InceptionAux(nn.Module):

    def __init__(self, in_channels: int, num_classes: int,
                 conv_block: Optional[Callable[..., nn.Module]] = None) -> None:
        super(InceptionAux, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
        self.conv1 = conv_block(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01   # picked up by the truncated-normal init above
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x: Tensor) -> Tensor:
        # N x 768 x 17 x 17
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # N x 768 x 5 x 5
        x = self.conv0(x)
        # N x 128 x 5 x 5
        x = self.conv1(x)
        # N x 768 x 1 x 1
        # adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 768 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 768
        x = self.fc(x)
        # N x 1000 (num_classes)
        return x
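During training, the auxiliary logits are usually folded into the total loss with a small weight; the 0.4 factor below is a common convention (e.g. in PyTorch fine-tuning recipes), not something this code prescribes:

criterion = nn.CrossEntropyLoss()
images = torch.randn(4, 3, 299, 299)
labels = torch.randint(0, 1000, (4,))

model.train()
logits, aux_logits = model(images)   # InceptionOutputs unpacks like a tuple
loss = criterion(logits, labels) + 0.4 * criterion(aux_logits, labels)
loss.backward()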
class BasicConv2d(nn.Module):

    def __init__(self, in_channels: int, out_channels: int, **kwargs: Any) -> None:
        super(BasicConv2d, self).__init__()
        # the bias is omitted because the following BatchNorm has its own shift
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)
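The torchsummary import at the top suggests the author inspected the per-layer output shapes with summary. A sketch of that usage (disabling aux_logits keeps the output a plain tensor, which torchsummary handles):

model = Inception3(num_classes=1000, aux_logits=False, init_weights=True)
model.eval()
summary(model, (3, 299, 299), device='cpu')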
