LeNet、AlexNet和VGG在设计上的共同之处是:先以由卷积层构成的模块充分抽取空间特征,再以由全连接层构成的模块来输出分类结果。其中,AlexNet和VGG对LeNet的改进主要在于如何对这两个模块加宽(增加通道数)和加深。本节我们介绍网络中的网络(NiN)。它提出了另外一个思路,即串联多个由卷积层和“全连接”层构成的小网络来构建一个深层网络。

1. NiN块

我们知道,卷积层的输入和输出通常是四维数组(样本,通道,高,宽),而全连接层的输入和输出则通常是二维数组(样本,特征)。如果想在全连接层后再接上卷积层,则需要将全连接层的输出变换为四维。1×11\times 11×1卷积层。它可以看成全连接层,其中空间维度(高和宽)上的每个元素相当于样本,通道相当于特征。因此,NiN使用1×11\times 11×1卷积层来替代全连接层,从而使空间信息能够自然传递到后面的层中去。图5.7对比了NiN同AlexNet和VGG等网络在结构上的主要区别。

NiN块是NiN中的基础块。它由一个卷积层加两个充当全连接层的1×11\times 11×1卷积层串联而成。其中第一个卷积层的超参数可以自行设置,而第二和第三个卷积层的超参数一般是固定的。

import time
import torch
from torch import nn, optimimport sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def nin_block(in_channels, out_channels, kernel_size, stride, padding):blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU())return blk

2.NiN模型

NiN是在AlexNet问世不久后提出的。它们的卷积层设定有类似之处。NiN使用卷积窗口形状分别为11×1111\times 1111×11、5×55\times 55×5和3×33\times 33×3的卷积层,相应的输出通道数也与AlexNet中的一致。每个NiN块后接一个步幅为2、窗口形状为3×33\times 33×3的最大池化层。

除使用NiN块以外,NiN还有一个设计与AlexNet显著不同:NiN去掉了AlexNet最后的3个全连接层,取而代之地,NiN使用了输出通道数等于标签类别数的NiN块,然后使用全局平均池化层对每个通道中所有元素求平均并直接用于分类。这里的全局平均池化层即窗口形状等于输入空间维形状的平均池化层。NiN的这个设计的好处是可以显著减小模型参数尺寸,从而缓解过拟合。然而,该设计有时会造成获得有效模型的训练时间的增加。

# 已保存在d2lzh_pytorch
import torch.nn.functional as F
class GlobalAvgPool2d(nn.Module):# 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现def __init__(self):super(GlobalAvgPool2d, self).__init__()def forward(self, x):return F.avg_pool2d(x, kernel_size=x.size()[2:])net = nn.Sequential(nin_block(1, 96, kernel_size=11, stride=4, padding=0),nn.MaxPool2d(kernel_size=3, stride=2),nin_block(96, 256, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=3, stride=2),nin_block(256, 384, kernel_size=3, stride=1, padding=1),nn.MaxPool2d(kernel_size=3, stride=2), nn.Dropout(0.5),# 标签类别数是10nin_block(384, 10, kernel_size=3, stride=1, padding=1),GlobalAvgPool2d(), # 将四维的输出转成二维的输出,其形状为(批量大小, 10)d2l.FlattenLayer())

我们构建一个数据样本来查看每一层的输出形状。

X = torch.rand(1, 1, 224, 224)
for name, blk in net.named_children(): X = blk(X)print(name, 'output shape: ', X.shape)

输出:

0 output shape:  torch.Size([1, 96, 54, 54])
1 output shape:  torch.Size([1, 96, 26, 26])
2 output shape:  torch.Size([1, 256, 26, 26])
3 output shape:  torch.Size([1, 256, 12, 12])
4 output shape:  torch.Size([1, 384, 12, 12])
5 output shape:  torch.Size([1, 384, 5, 5])
6 output shape:  torch.Size([1, 384, 5, 5])
7 output shape:  torch.Size([1, 10, 5, 5])
8 output shape:  torch.Size([1, 10, 1, 1])
9 output shape:  torch.Size([1, 10])

5.8.3 获取数据和训练模型

我们依然使用Fashion-MNIST数据集来训练模型。NiN的训练与AlexNet和VGG的类似,但这里使用的学习率更大。

batch_size = 128
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)lr, num_epochs = 0.002, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0101, train acc 0.513, test acc 0.734, time 260.9 sec
epoch 2, loss 0.0050, train acc 0.763, test acc 0.754, time 175.1 sec
epoch 3, loss 0.0041, train acc 0.808, test acc 0.826, time 151.0 sec
epoch 4, loss 0.0037, train acc 0.828, test acc 0.827, time 151.0 sec
epoch 5, loss 0.0034, train acc 0.839, test acc 0.831, time 151.0 sec

小结

  • NiN重复使用由卷积层和代替全连接层的1×11\times 11×1卷积层构成的NiN块来构建深层网络。
  • NiN去除了容易造成过拟合的全连接输出层,而是将其替换成输出通道数等于标签类别数的NiN块和全局平均池化层。
  • NiN的以上设计思想影响了后面一系列卷积神经网络的设计。

pytorch学习笔记(二十六):NIN相关推荐

  1. JVM 学习笔记二十六、JVM监控及诊断工具-GUI篇

    二十六.JVM监控及诊断工具-GUI篇 1.工具概述 使用上一张命令行工具或组合能帮您获取目标Java应用性能相关的基础信息,但他们存在下列局限: (1)无法获取方法级别的分析数据,如方法间的调用关系 ...

  2. Jenkins 持续集成 概念(学习笔记二十六)

    持续集成:提交.测试.构建.测试.部署 前不久接触了持续集成(Continuous Integration,CI). 一.持续集成是什么 首先说说"集成"的概念.在实际的软件开发中 ...

  3. 立创eda学习笔记二十六:手把手教你使用立创eda的官方教程

    可以通过以下办法找到教程: 1,在软件界面点帮助-使用教程 2,在网站首页-帮助-教程进入 如何使用教程: 这里是一级目录,其实对新手最有用的是前面3个部分,后面的仿真先不看. 常见问题里面不光是讲的 ...

  4. [傅里叶变换及其应用学习笔记] 二十六. 高维傅里叶变换的推导

    高维意味着函数中有多个变量,典型的高维傅里叶应用为图像处理. 一个二维图像的亮度(灰度)可以用$f(x_1,x_2)$来表示,以lena为例,图像平面作为$x_1,x_2$平面,灰度作为$z$轴,形成 ...

  5. pytorch学习笔记(十六):Parameters

    文章目录 1. 访问模型参数 2. 初始化模型参数 3. 自定义初始化方法 4. 共享模型参数 小结 本节将深入讲解如何访问和初始化模型参数,以及如何在多个层之间共享同一份模型参数. 我们先定义含单隐 ...

  6. Java学习笔记二十六:Java多态中的引用类型转换

    Java多态中的引用类型转换 引用类型转换: 1.向上类型转换(隐式/自动类型转换),是小类型到大类型的转换: 2.向下类型转换(强制类型转换),是大类型到小类型的转换: 3.instanceof运算 ...

  7. OpenCV学习笔记(十六)——CamShift研究 OpenCV学习笔记(十七)——运动分析和物体跟踪Video OpenCV学习笔记(十八)——图像的各种变换(cvtColor*+)imgproc

    OpenCV学习笔记(十六)--CamShift研究 CamShitf算法,即Continuously Apative Mean-Shift算法,基本思想就是对视频图像的多帧进行MeanShift运算 ...

  8. python数据挖掘学习笔记】十六.逻辑回归LogisticRegression分析鸢尾花数据

    但是很多时候数据是非线性的,所以这篇文章主要讲述逻辑回归及Sklearn机器学习包中的LogisticRegression算法 #2018-03-28 16:57:56 March Wednesday ...

  9. python分析鸢尾花数据_python数据挖掘学习笔记】十六.逻辑回归LogisticRegression分析鸢尾花数据...

    但是很多时候数据是非线性的,所以这篇文章主要讲述逻辑回归及Sklearn机器学习包中的LogisticRegression算法 #2018-03-28 16:57:56 March Wednesday ...

  10. PyTorch学习笔记(二)——回归

    PyTorch学习笔记(二)--回归 本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn ...

最新文章

  1. 一次公司内部的Tech Talk中涉及到的关于语言的发展问题
  2. C++11新特性之lambda表达式
  3. sdn框架的计算机网络管理,清华SDN实践--SDN 系统架构与数据中心应用
  4. #define va_arg(AP, TYPE)
  5. python 服务注册_python注册Windows服务
  6. 站在面试官角度,看求职与内卷
  7. 45 CO配置-控制-利润中心会计-维护控制范围设置
  8. c语言数据结构判断回文数,C++数据结构与算法之判断一个链表是否为回文结构的方法...
  9. VS2013 产品密钥 – 所有版本
  10. 5)Javascript设计模式:extends模式
  11. C++通过Wininet库提交POST信息登录到PHPChina中文开发者社区
  12. 关于%@ include file= %与jsp:include page=/jsp:include中的那些问题?
  13. 神经网络控制系统设计,神经网络技术及其应用
  14. c语言实验答案周信东综合程序设计,周信东主编最新版-C语言程序设计基础实验一实验报告.doc...
  15. 如何用微pe+msdn进行纯净重装Windows系统
  16. 创建一个员工类(Employee),其中包括:1) 4个私有属性:员工姓名(name)、员工年龄(age)、员工职位(position)、工资(salary)
  17. WINDOWS中hosts文件位置
  18. vue 手写签名_与众不同的手写签批
  19. 团体程序设计天梯赛-练习集 L2-015 互评成绩 (25分)
  20. 终极WordPress安全指南-分步指南(2020)

热门文章

  1. MVC html 控件扩展【转载】
  2. [原创] 共享两个有用的网页布局表格
  3. Vim-latex 插件 的安装
  4. group_concat 排序并取前三个
  5. python网络数据采集(伴奏曲)
  6. ES6--Decorator修饰器
  7. java_web学习(六) request对象中的get和post差异
  8. iOS开发-retain/assign/strong/weak/copy/mutablecopy/autorelease区别
  9. 衔接UI线程和管理后台工作线程的类(多线程、异步调用)[转]
  10. 【数据库系统设计】关系数据库标准语言SQL(3)