GoogLeNet模型

  • 1. GoogLeNet介绍
    • 1.1 背景
    • 1.2 GoogLeNet网络结构
  • 2. PyTorch实现
    • 2.1 导入相应的包
    • 2.2 定义Inception块结构
    • 2.3 定义GoogLeNet网络
    • 2.4 训练

1. GoogLeNet介绍

1.1 背景

GoogLeNet是谷歌(Google)研究出来的深度网络结构,为什么不叫“GoogleNet”,而叫“GoogLeNet”,据说是为了向“LeNet”致敬,因此取名为“GoogLeNet”。

2014年,GoogLeNetVGG是当年ImageNet挑战赛(ILSVRC14)的双雄,GoogLeNet获得了第一名、VGG获得了第二名,这两类模型结构的共同特点是层次更深了。VGG继承了LeNet以及AlexNet的一些框架结构,而GoogLeNet则做了更加大胆的网络结构尝试,虽然深度只有22层,但大小却比AlexNet和VGG小很多,GoogleNet参数为500万个,AlexNet参数个数是GoogleNet的12倍,VGGNet参数又是AlexNet的3倍,因此在内存或计算资源有限时,GoogleNet是比较好的选择;从模型结果来看,GoogLeNet的性能却更加优越。

1.2 GoogLeNet网络结构

谷歌提出了最原始Inception的基本结构,既能增加神经网络表现,又能保证计算资源的使用效率:

该结构将CNN中常用的卷积(1x1,3x3,5x5)、池化操作(3x3)堆叠在一起(卷积、池化后的尺寸相同,将通道相加),一方面增加了网络的宽度,另一方面也增加了网络对尺度的适应性。

网络卷积层中的网络能够提取输入的每一个细节信息,同时5x5的滤波器也能够覆盖大部分接受层的的输入。还可以进行一个池化操作,以减少空间大小,降低过度拟合。在这些层之上,在每一个卷积层后都要做一个ReLU操作,以增加网络的非线性特征

然而这个Inception原始版本,所有的卷积核都在上一层的所有输出上来做,而那个5x5的卷积核所需的计算量就太大了,造成了特征图的厚度很大,为了避免这种情况,在3x3前、5x5前、max pooling后分别加上了1x1的卷积核,以起到了降低特征图厚度的作用,这也就形成了Inception最初的的网络结构,如下图所示:

其中1x1的卷积核的作用:

  • 减少维度,并用于修正线性激活(ReLU):
    比如,上一层的输出为100x100x128,经过具有256个通道的5x5卷积层之后(stride=1,pad=2),输出数据为100x100x256,其中,卷积层的参数为128x5x5x256= 819200。而假如上一层输出先经过具有32个通道的1x1卷积层,再经过具有256个输出的5x5卷积层,那么输出数据仍为为100x100x256,但卷积参数量已经减少为128x1x1x32 + 32x5x5x256= 204800,大约减少了4倍。

GoogLeNet 由Inception基础块组成。 Inception块相当于⼀个有4条线路的⼦⽹络。它通过不同窗口形状的卷积层和最⼤池化层来并⾏抽取信息,并使⽤1×1卷积层减少通道数从而降低模型复杂度。 可以⾃定义的超参数是每个层的输出通道数,我们以此来控制模型复杂度。

GoogLeNet模型完整模型结构为:

2. PyTorch实现

2.1 导入相应的包
import time
import torch
from torch import nn, optim
import torchvision
import numpy as np
import sys
sys.path.append("/home/kesci/input/")
import d2lzh1981 as d2l
import os
import torch.nn.functional as F
2.2 定义Inception块结构
class Inception(nn.Module):# c1 - c4为每条线路里的层的输出通道数def __init__(self, in_c, c1, c2, c3, c4):super(Inception, self).__init__()# 线路1,单1 x 1卷积层self.p1_1 = nn.Conv2d(in_c, c1, kernel_size=1)# 线路2,1 x 1卷积层后接3 x 3卷积层self.p2_1 = nn.Conv2d(in_c, c2[0], kernel_size=1)self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)# 线路3,1 x 1卷积层后接5 x 5卷积层self.p3_1 = nn.Conv2d(in_c, c3[0], kernel_size=1)self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)# 线路4,3 x 3最大池化层后接1 x 1卷积层self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)self.p4_2 = nn.Conv2d(in_c, c4, kernel_size=1)# 每一层都要经过ReLU激活函数def forward(self, x):p1 = F.relu(self.p1_1(x))p2 = F.relu(self.p2_2(F.relu(self.p2_1(x))))p3 = F.relu(self.p3_2(F.relu(self.p3_1(x))))p4 = F.relu(self.p4_2(self.p4_1(x)))return torch.cat((p1, p2, p3, p4), dim=1)  # 在通道维上连结输出
2.3 定义GoogLeNet网络
# b1是7x7卷积层 + 3x3最大池化层
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
# b2是1x1卷积层 + 3x3卷积层 + 3x3最大池化层
b2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1),nn.Conv2d(64, 192, kernel_size=3, padding=1),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
# b3是Inception块x2 + 3x3最大池化层
b3 = nn.Sequential(Inception(192, 64, (96, 128), (16, 32), 32),Inception(256, 128, (128, 192), (32, 96), 64),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
# b4是Inception块x5 + 3x3最大池化层
b4 = nn.Sequential(Inception(480, 192, (96, 208), (16, 48), 64),Inception(512, 160, (112, 224), (24, 64), 64),Inception(512, 128, (128, 256), (24, 64), 64),Inception(512, 112, (144, 288), (32, 64), 64),Inception(528, 256, (160, 320), (32, 128), 128),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
# b5是Inception块x2 + 全局平均池化层
b5 = nn.Sequential(Inception(832, 256, (160, 320), (32, 128), 128),Inception(832, 384, (192, 384), (48, 128), 128),d2l.GlobalAvgPool2d())
# nn.Sequential()连接以上结构并添加一个全连接输出层
net = nn.Sequential(b1, b2, b3, b4, b5, d2l.FlattenLayer(), nn.Linear(1024, 10))
2.4 训练
net = nn.Sequential(b1, b2, b3, b4, b5, d2l.FlattenLayer(), nn.Linear(1024, 10))#batchsize=128
batch_size = 16
# 如出现“out of memory”的报错信息,可减小batch_size或resize
#train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

动手学深度学习(PyTorch实现)(十一)--GoogLeNet模型相关推荐

  1. 动手学深度学习(PyTorch实现)(八)--AlexNet模型

    AlexNet模型 1. AlexNet模型介绍 1.1 AlexNet的特点 1.2 AlexNet的结构 1.3 AlexNet参数数量 2. AlexNet的PyTorch实现 2.1 导入相应 ...

  2. 动手学深度学习(PyTorch实现)(七)--LeNet模型

    LeNet模型 1. LeNet模型 2. PyTorch实现 2.1 模型实现 2.2 获取数据与训练 1. LeNet模型 LeNet分为卷积层块和全连接层块两个部分.下面我们分别介绍这两个模块. ...

  3. 动手学深度学习(PyTorch实现)(十三)--ResNet模型

    ResNet模型 1. ResNet介绍 2. ResNet结构 3. ResNet的PyTorch实现 3.1 导入所需要的包 3.2 构建ResNet网络 3.3 开始训练 注:本文部分内容参考博 ...

  4. 动手学深度学习(PyTorch实现)(九)--VGGNet模型

    VGGNet模型 1. VGGNet模型介绍 1.1 VGGNet的结构 1.2 VGGNet结构举例 2. VGGNet的PyTorch实现 2.1 导入相应的包 2.2 基本网络单元block 2 ...

  5. 动手学深度学习(PyTorch实现)(十)--NiN模型

    NiN模型 1. NiN模型介绍 1.1 NiN模型结构 1.2 NiN结构与VGG结构的对比 2. PyTorch实现 2.1 导入相应的包 2.2 定义NiN block 2.3 全局最大池化层 ...

  6. 【动手学深度学习PyTorch版】6 权重衰退

    上一篇移步[动手学深度学习PyTorch版]5 模型选择 + 过拟合和欠拟合_水w的博客-CSDN博客 目录 一.权重衰退 1.1 权重衰退 weight decay:处理过拟合的最常见方法(L2_p ...

  7. 伯禹公益AI《动手学深度学习PyTorch版》Task 05 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 05 学习笔记 Task 05:卷积神经网络基础:LeNet:卷积神经网络进阶 微信昵称:WarmIce 昨天打了一天的<大革 ...

  8. 动手学深度学习Pytorch Task01

    深度学习目前以及未来都有良好的发展前景.正值疫情期间,报名参加了动手学深度学习pytorch版的公开课,希望在以后的学习生活中能够灵活运用学到的这些知识. 第一次课主要包含三个部分:线性回归.soft ...

  9. 伯禹公益AI《动手学深度学习PyTorch版》Task 07 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 07 学习笔记 Task 07:优化算法进阶:word2vec:词嵌入进阶 微信昵称:WarmIce 优化算法进阶 emmmm,讲实 ...

最新文章

  1. Kubernetes通过containerd访问registry的30443端口
  2. Python基础教程:字符串中split与rsplit的方法原理教程
  3. Spring注解方式配置切面类
  4. 指定端口传输_高速数字传输链路测试 - 高速数字电路仿真设计与测试技术发展趋势综述(二)...
  5. c++输出小数点后几位_2.1 怎么在屏幕上输出各种类型的数据
  6. FCN Caffe:可视化featureMaps和Weights(C++)、获取FCN结果
  7. 3 View - 状态保持 session
  8. Vue 组件开发 - 数据输入框组件
  9. 5-6pooling层
  10. java实现省市区的联动,chosen实现省市区三级联动
  11. 实录:VCS双机使用DiskReservation资源导致多路径失效
  12. App测试Android的闪退总结
  13. Android 2048游戏开发
  14. C# 调用C/C++动态链接库,结构体中的char*类型
  15. Oracle11g在Windows和Linux下imp导入表,exp导出表,sqluldr2导出表,sqlldr导入表
  16. 《编译原理》陈火旺——词法分析程序c语言实现完整版
  17. Linux部署单体架构,从单体式架构迁移到微服务架构:三个策略叙述
  18. 计算机语言怎么学,教你如何学习计算机编程语言
  19. 算法——0~1之间浮点实数的二进制表示
  20. 对Android apk 签名 --apksigner

热门文章

  1. mysql where substr_mysql – 在WHERE子句中使用substr的SELECT语句
  2. mysql前一天_mysql查询当天,前一天,一周,一个月
  3. java window的对象方法,[Java教程]如何真正重写window对象的方法_星空网
  4. Linux系统日常维护命令
  5. 信安教程第二版-第6章认证技术原理与应用
  6. 2小时c++与ros教学
  7. 用js实现千位分隔符
  8. laydate日期插件使用
  9. Android 控制ScrollView滚动到底部
  10. python多维数组初始化后赋值的问题