GoogLeNet代码解读

目录

    • GoogLeNet代码解读
  • 概述
  • GooLeNet网络结构图
    • 1)从输入到第一层inception
    • 2)从第2层inception到第4层inception
    • 3)从第5层inception到第7层inception
    • 4)从第8层inception到输出
  • GooLeNet架构搭建
  • 代码细节分析

概述

GooLeNet网络结构图

1)从输入到第一层inception

2)从第2层inception到第4层inception

3)从第5层inception到第7层inception

4)从第8层inception到输出

GooLeNet架构搭建

代码细节分析

from collections import namedtuple
import warnings
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from .utils import load_state_dict_from_url
from typing import Callable, Any, Optional, Tuple, List
# 可供下载的googlenet预训练模型名称
__all__ = ['GoogLeNet','googlenet','GoogLeNetOutputs','_GoogLeNetOutputs']
# 预训练权重下载
model_urls = {'googlenet':'https://download.pytorch.org/models/googlenet-1378be20.pth',}
GoogLeNetOutputs = namedtuple('GoogLeNetOutputs',['logits','aux_logits2','aux_logits1'])
GoogLeNetOutputs.__annotations__ = {'logits': Tensor, 'aux_logits2': Optional[Tensor],'aux_logits1': Optional[Tensor]}
_GoogLeNetOutputs = GoogLeNetOutputsdef googlenet(pretrained = False, progress = True, **kwargs):if pretrained:if 'transform_input' not in kwargs:kwargs['transform_input'] = Trueif 'aux_logits' not in kwargs:kwargs['aux_logits'] = Falseif kwargs['aux_logits']:warnings.warn('auxiliary heads in the pretrained googlenet model are NOT pretrained, ''so make sure to train them')orginal_aux_logits = kwargs['aux_logits']kwargs['aux_logits'] = Truekwargs['init_weights'] = Falsemodel = GoogLeNet(**kwargs)# 下载googlenet模型并加载state_dict = load_state_dict_from_url(model_urls['googlenet'],progress = progress)model.load_state_dict(state_dict)if not original_aux_logits:model.aux_logits = Falsemodel.aux1 = Nonemodel.aux2 = Nonereturn modelreturn GoogLeNet(**kwargs)class GoogLeNet(nn.Module):__constants__ = ['aux_logits','transform_input']def __init__(self,num_classes = 1000,aux_logits = True,trandform_input = False,init_weights = None,blocks = None):super(GoogLeNet,self).__init__()if blocks is None:blocks = [BasicConv2d, Inception, InceptionAux]if init_weights is None:warnings.warn('The default weight initialization of GoogleNet will be changed in future releases of ''torchvision. If you wish to keep the old behavior (which leads to long initialization times'' due to scipy/scipy#11299), please set init_weights=True.', FutureWarning)init_weights = Trueassert len(blocks)==3conv_block = blocks[0]inception_block = blocks[1]inception_aux_block = blocks[2]self.aux_logits = aux_logitsself.transform_input = transform_input# 从输入到第一层inception的卷积、池化处理self.conv1 = conv_block(3,64,kernel_size = 7, stride = 3, padding = 3)self.maxpool1 = nn.MaxPool2d(3,stride = 2, ceil_mode = True)self.conv2 = conv_block(64,64,kernel_size = 1)self.conv3 = conv_block(64,192,kernel_size = 3, padding = 1)self.maxpool2 = nn.MaxPool2d(3,stride = 2, ceil_mode = True)# 一系列的inception模块self.inception3a = inception_block(192,64,96,128,16,32,32)self.inception3b = inception_block(256, 128, 128, 192, 32, 96, 64)self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)self.inception4a = inception_block(480, 192, 96, 208, 16, 48, 64)self.inception4b = inception_block(512, 160, 112, 224, 24, 64, 64)self.inception4c = inception_block(512, 128, 128, 256, 24, 64, 64)self.inception4d = inception_block(512, 112, 144, 288, 32, 64, 64)self.inception4e = inception_block(528, 256, 160, 320, 32, 128, 128)self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)self.inception5a = inception_block(832, 256, 160, 320, 32, 128, 128)self.inception5b = inception_block(832, 384, 192, 384, 48, 128, 128)# 辅助分类模块if aux_logits:self.aux1 = inception_aux_block(512, num_classes)self.aux2 = inception_aux_block(528, num_classes)else:self.aux1 = None  # type: ignore[assignment]self.aux2 = None  # type: ignore[assignment]# 平均池化、dropout防止过拟合self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.dropout = nn.Dropout(0.2)self.fc = nn.Linear(1024, num_classes)if init_weights:self._initialize_weights()def _initialize_weights(self) -> None:# 初始化权重和偏置参数for m in self.modules():if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):import scipy.stats as statsX = stats.truncnorm(-2, 2, scale=0.01)values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)values = values.view(m.weight.size())with torch.no_grad():m.weight.copy_(values)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)# 给input增加一个维度并作中心化def _transform_input(self, x: Tensor) -> Tensor:if self.transform_input:x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5x = torch.cat((x_ch0, x_ch1, x_ch2), 1)return x# 构建googlenet网络def _forward(self, x: Tensor) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:# N x 3 x 224 x 224x = self.conv1(x)# N x 64 x 112 x 112x = self.maxpool1(x)# N x 64 x 56 x 56x = self.conv2(x)# N x 64 x 56 x 56x = self.conv3(x)# N x 192 x 56 x 56x = self.maxpool2(x)# N x 192 x 28 x 28x = self.inception3a(x)# N x 256 x 28 x 28x = self.inception3b(x)# N x 480 x 28 x 28x = self.maxpool3(x)# N x 480 x 14 x 14x = self.inception4a(x)# N x 512 x 14 x 14aux1: Optional[Tensor] = Noneif self.aux1 is not None:if self.training:aux1 = self.aux1(x)x = self.inception4b(x)# N x 512 x 14 x 14x = self.inception4c(x)# N x 512 x 14 x 14x = self.inception4d(x)# N x 528 x 14 x 14aux2: Optional[Tensor] = Noneif self.aux2 is not None:if self.training:aux2 = self.aux2(x)x = self.inception4e(x)# N x 832 x 14 x 14x = self.maxpool4(x)# N x 832 x 7 x 7x = self.inception5a(x)# N x 832 x 7 x 7x = self.inception5b(x)# N x 1024 x 7 x 7x = self.avgpool(x)# N x 1024 x 1 x 1x = torch.flatten(x, 1)# N x 1024x = self.dropout(x)x = self.fc(x)# N x 1000 (num_classes)return x, aux2, aux1@torch.jit.unuseddef eager_outputs(self, x: Tensor, aux2: Tensor, aux1: Optional[Tensor]) -> GoogLeNetOutputs:if self.training and self.aux_logits:return _GoogLeNetOutputs(x, aux2, aux1)else:return x   # type: ignore[return-value]def forward(self, x: Tensor) -> GoogLeNetOutputs:x = self._transform_input(x)x, aux1, aux2 = self._forward(x)aux_defined = self.training and self.aux_logitsif torch.jit.is_scripting():if not aux_defined:warnings.warn("Scripted GoogleNet always returns GoogleNetOutputs Tuple")return GoogLeNetOutputs(x, aux2, aux1)else:return self.eager_outputs(x, aux2, aux1)# inception模块
class Inception(nn.Module):def __init__(self,in_channels: int,ch1x1: int,ch3x3red: int,ch3x3: int,ch5x5red: int,ch5x5: int,pool_proj: int,conv_block: Optional[Callable[..., nn.Module]] = None) -> None:super(Inception, self).__init__()if conv_block is None:conv_block = BasicConv2dself.branch1 = conv_block(in_channels, ch1x1, kernel_size=1)self.branch2 = nn.Sequential(conv_block(in_channels, ch3x3red, kernel_size=1),conv_block(ch3x3red, ch3x3, kernel_size=3, padding=1))self.branch3 = nn.Sequential(conv_block(in_channels, ch5x5red, kernel_size=1),# Here, kernel_size=3 instead of kernel_size=5 is a known bug.# Please see https://github.com/pytorch/vision/issues/906 for details.conv_block(ch5x5red, ch5x5, kernel_size=3, padding=1))self.branch4 = nn.Sequential(nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),conv_block(in_channels, pool_proj, kernel_size=1))def _forward(self, x: Tensor) -> List[Tensor]:branch1 = self.branch1(x)branch2 = self.branch2(x)branch3 = self.branch3(x)branch4 = self.branch4(x)outputs = [branch1, branch2, branch3, branch4]return outputsdef forward(self, x: Tensor) -> Tensor:outputs = self._forward(x)return torch.cat(outputs, 1)# 辅助的inception模块,用于分类
class InceptionAux(nn.Module):def __init__(self,in_channels: int,num_classes: int,conv_block: Optional[Callable[..., nn.Module]] = None) -> None:super(InceptionAux, self).__init__()if conv_block is None:conv_block = BasicConv2dself.conv = conv_block(in_channels, 128, kernel_size=1)self.fc1 = nn.Linear(2048, 1024)self.fc2 = nn.Linear(1024, num_classes)def forward(self, x: Tensor) -> Tensor:# aux1: N x 512 x 14 x 14, aux2: N x 528 x 14 x 14x = F.adaptive_avg_pool2d(x, (4, 4))# aux1: N x 512 x 4 x 4, aux2: N x 528 x 4 x 4x = self.conv(x)# N x 128 x 4 x 4x = torch.flatten(x, 1)# N x 2048x = F.relu(self.fc1(x), inplace=True)# N x 1024x = F.dropout(x, 0.7, training=self.training)# N x 1024x = self.fc2(x)# N x 1000 (num_classes)return x# 将卷积、bn、激活封装成一个函数,其实这里不封装也行,分成3步来写
class BasicConv2d(nn.Module):def __init__(self,in_channels: int,out_channels: int,**kwargs: Any) -> None:super(BasicConv2d, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)self.bn = nn.BatchNorm2d(out_channels, eps=0.001)def forward(self, x: Tensor) -> Tensor:x = self.conv(x)x = self.bn(x)return F.relu(x, inplace=True)

GoogLeNet代码解读相关推荐

  1. 200行代码解读TDEngine背后的定时器

    作者 | beyondma来源 | CSDN博客 导读:最近几周,本文作者几篇有关陶建辉老师最新的创业项目-TdEngine代码解读文章出人意料地引起了巨大的反响,原以为C语言已经是昨日黄花,不过从读 ...

  2. 装逼一步到位!GauGAN代码解读来了

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:游璐颖,福州大学,Datawhale成员 AI神笔马良 如何装逼一 ...

  3. Unet论文解读代码解读

    论文地址:http://www.arxiv.org/pdf/1505.04597.pdf 论文解读 网络 架构: a.U-net建立在FCN的网络架构上,作者修改并扩大了这个网络框架,使其能够使用很少 ...

  4. Lossless Codec---APE代码解读系列(二)

    APE file 一些概念 APE代码解读系列(一) APE代码解读系列(三) 1. 先要了解APE compression level APE主要有5level, 分别是: CompressionL ...

  5. RT-Thread 学习笔记(五)—— RTGUI代码解读

    ---恢复内容开始--- RT-Thread 版本:2.1.0 RTGUI相关代码解读,仅为自己学习记录,若有错误之处,请告知maoxudong0813@163.com,不胜感激! GUI流程: ma ...

  6. vins 解读_代码解读 | VINS 视觉前端

    AI 人工智能 代码解读 | VINS 视觉前端 本文作者是计算机视觉life公众号成员蔡量力,由于格式问题部分内容显示可能有问题,更好的阅读体验,请查看原文链接:代码解读 | VINS 视觉前端 v ...

  7. BERT:代码解读、实体关系抽取实战

    目录 前言 一.BERT的主要亮点 1. 双向Transformers 2.句子级别的应用 3.能够解决的任务 二.BERT代码解读 1. 数据预处理 1.1 InputExample类 1.2 In ...

  8. shfflenetv2代码解读

    shufflenetv2代码解读 目录 shufflenetv2代码解读 概述 shufflenetv2网络结构图 shufflenetv2架构参数 shufflenetv2代码细节分析 概述 shu ...

  9. Inception代码解读

    Inception代码解读 目录 Inception代码解读 概述 Inception网络结构图 inception网络结构框架 inception代码细节分析 概述 inception相比起最开始兴 ...

最新文章

  1. 2014计算机三级网络技术,2014计算机三级网络技术综合题解题思路
  2. 配置SQL Server 2008 镜像
  3. eclipse svn插件安装_Visual SVN和Tortoise SVN的安装简单使用汉化
  4. 正在进行时 Order 1
  5. python123循环结构_来学Python啦,大话循环结构~
  6. PingingLab传世经典系列《CCNA完全配置宝典》-2.7 EIGRP基本配置
  7. 小白用python处理excel文件-刚入门的小白用Python操作excel表格!使工作效率提升一倍不止!...
  8. 阿里云iot事业部一面面经
  9. Unity的C#编程教程_17_Variables 挑战 3 折扣计算器
  10. php guzzle 上传文件,Guzzle 使用文档
  11. win10文件资源管理器卡死未响应的完美解决方法
  12. RAID磁盘阵列详解与维护
  13. 苹果电脑开机慢怎么办 苹果笔记本开机特别慢的处理方法
  14. 深度解析名企项目研发管理成功之路
  15. 拆解「千言数据集:文本相似度」竞赛第一背后的故事
  16. access 分组序号_二级Access数据库备考笔记之报表排序和分组
  17. iOS 第三方登录之 新浪微博登录
  18. 【背包专题】01背包
  19. 电脑远程qq怎么连接服务器未响应,win10系统打开qq提示未响应需要联机检查的还原技巧...
  20. 根据UI图设计的大小换算REM单位以及大屏页面全屏展示

热门文章

  1. cauchy problem of 1st order PDE from Partial Differential Equations
  2. 2021 第三封拒信 来自牛津大学自主智能机器和系统 Autonomous Intelligent Machines and Systems
  3. 快速了解和使用Photon Server
  4. mysql5.6使用profile工具分析sql
  5. 破天荒第一遭 安全公司因玩忽职守被客户告上法庭
  6. 7、ReadWriteLock
  7. (诊断)处理错误fatal error: Python.h: No such file or directory
  8. 深入浅出jQuery (五) 如何自定义UI-Dialog?
  9. 使用个性化Profile代替Session
  10. poj 2985(并查集+线段树求K大数)