网络中的网络(NiN)

LeNet、AlexNet和VGG在设计上的共同之处是:先以由卷积层构成的模块充分抽取空间特征,再以由全连接层构成的模块来输出分类结果。其中,AlexNet和VGG对LeNet的改进主要在于如何对这两个模块加宽(增加通道数)和加深

网络中的网络(NiN)提出了另外一个思路,即串联多个由卷积层和“全连接”层构成的小网络来构建一个深层网络。

NiN块

卷积层的输入和输出通常是四维数组(样本,通道,高,宽),而全连接层的输入和输出则通常是二维数组(样本,特征)。如果想在全连接层后再接上卷积层,则需要将全连接层的输出变换为四维。NiN使用1×11\times 11×1卷积层来替代全连接层,从而使空间信息能够自然传递到后面的层中去。

NiN块是NiN中的基础块。它由一个卷积层加两个充当全连接层的1×11\times 11×1卷积层串联而成。其中第一个卷积层的超参数可以自行设置,而第二和第三个卷积层的超参数一般是固定的。

import time
import torch
from torch import nn, optimdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def nin_block(in_channels, out_channels, kernel_size, stride, padding):blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU())return blk

NiN模型

NiN使用卷积窗口形状分别为11×1111\times 1111×11、5×55\times 55×5和3×33\times 33×3的卷积层,每个NiN块后接一个步幅为2、窗口形状为3×33\times 33×3的最大池化层。

NiN去掉了AlexNet最后的3个全连接层,取而代之使用了输出通道数等于标签类别数的NiN块,然后使用全局平均池化层对每个通道中所有元素求平均并直接用于分类。(这里的全局平均池化层即窗口形状等于输入空间维形状的平均池化层).

NiN的这个设计的好处是可以显著减小模型参数尺寸,从而缓解过拟合。然而,该设计有时会造成获得有效模型的训练时间的增加。

import torch.nn.functional as F
class GlobalAvgPool2d(nn.Module):# 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现def __init__(self):super(GlobalAvgPool2d, self).__init__()def forward(self, x):return F.avg_pool2d(x, kernel_size=x.size()[2:])class FlattenLayer(torch.nn.Module):def __init__(self):super(FlattenLayer, self).__init__()def forward(self, x): # x shape: (batch, *, *, ...)return x.view(x.shape[0], -1)net = nn.Sequential(nin_block(1, 96, kernel_size=11, stride=4, padding=0),nn.MaxPool2d(kernel_size=3, stride=2),nin_block(96, 256, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=3, stride=2),nin_block(256, 384, kernel_size=3, stride=1, padding=1),nn.MaxPool2d(kernel_size=3, stride=2), nn.Dropout(0.5),# 标签类别数是10nin_block(384, 10, kernel_size=3, stride=1, padding=1),GlobalAvgPool2d(), # 将四维的输出转成二维的输出,其形状为(批量大小, 10)FlattenLayer())

简单说网络结构就是:
卷积(96个11∗11的核)(步长为4)(padding为0)→卷积(96个1∗1的核)(步长为1)→卷积(96个1∗1的核)(步长为1)→.降采样(最大池化)(3∗3的核,步长2)→.卷积(256个5∗5的核)(步长为1)(padding为2)→卷积(256个1∗1的核)(步长为1)→卷积(256个1∗1的核)(步长为1)→.降采样(最大池化)(3∗3的核,步长2)→.卷积(384个3∗3的核)(步长为1)(padding为1)→卷积(384个1∗1的核)(步长为1)→卷积(384个1∗1的核)(步长为1)→.降采样(最大池化)(3∗3的核,步长2)→.卷积(10个3∗3的核)(步长为1)(padding为1)→卷积(10个1∗1的核)(步长为1)→卷积(10个1∗1的核)(步长为1)→.降采样(平均池化)(x.size的核,步长1)\begin{matrix}卷积 \\ (96个11*11的核) \\(步长为4) \\(padding为0)\end{matrix} \rightarrow \begin{matrix}卷积 \\ (96个1*1的核) \\(步长为1)\end{matrix} \rightarrow \begin{matrix}卷积 \\ (96个1*1的核) \\(步长为1)\end{matrix} \rightarrow \\.\\ \begin{matrix}降采样(最大池化) \\ (3*3的核,步长2) \end{matrix}\rightarrow \\.\\ \begin{matrix}卷积 \\ (256个5*5的核) \\(步长为1)\\(padding为2)\end{matrix} \rightarrow \begin{matrix}卷积 \\ (256个1*1的核) \\(步长为1)\end{matrix} \rightarrow \begin{matrix}卷积 \\ (256个1*1的核) \\(步长为1)\end{matrix} \rightarrow \\.\\ \begin{matrix}降采样(最大池化) \\ (3*3的核,步长2) \end{matrix}\rightarrow \\.\\ \begin{matrix}卷积 \\ (384个3*3的核) \\(步长为1)\\(padding为1)\end{matrix} \rightarrow \begin{matrix}卷积 \\ (384个1*1的核) \\(步长为1)\end{matrix} \rightarrow \begin{matrix}卷积 \\ (384个1*1的核) \\(步长为1)\end{matrix} \rightarrow \\.\\ \begin{matrix}降采样(最大池化) \\ (3*3的核,步长2) \end{matrix}\rightarrow \\.\\ \begin{matrix}卷积 \\ (10个3*3的核) \\(步长为1)\\(padding为1)\end{matrix} \rightarrow \begin{matrix}卷积 \\ (10个1*1的核) \\(步长为1)\end{matrix} \rightarrow \begin{matrix}卷积 \\ (10个1*1的核) \\(步长为1)\end{matrix} \rightarrow \\.\\ \begin{matrix}降采样(平均池化) \\ (x.size的核,步长1) \end{matrix} 卷积(96个11∗11的核)(步长为4)(padding为0)​→卷积(96个1∗1的核)(步长为1)​→卷积(96个1∗1的核)(步长为1)​→.降采样(最大池化)(3∗3的核,步长2)​→.卷积(256个5∗5的核)(步长为1)(padding为2)​→卷积(256个1∗1的核)(步长为1)​→卷积(256个1∗1的核)(步长为1)​→.降采样(最大池化)(3∗3的核,步长2)​→.卷积(384个3∗3的核)(步长为1)(padding为1)​→卷积(384个1∗1的核)(步长为1)​→卷积(384个1∗1的核)(步长为1)​→.降采样(最大池化)(3∗3的核,步长2)​→.卷积(10个3∗3的核)(步长为1)(padding为1)​→卷积(10个1∗1的核)(步长为1)​→卷积(10个1∗1的核)(步长为1)​→.降采样(平均池化)(x.size的核,步长1)​

构建一个数据样本来查看每一层的输出形状。

X = torch.rand(1, 1, 224, 224)
for name, blk in net.named_children(): X = blk(X)print(name, 'output shape: ', X.shape)
0 output shape:  torch.Size([1, 96, 54, 54])
1 output shape:  torch.Size([1, 96, 26, 26])
2 output shape:  torch.Size([1, 256, 26, 26])
3 output shape:  torch.Size([1, 256, 12, 12])
4 output shape:  torch.Size([1, 384, 12, 12])
5 output shape:  torch.Size([1, 384, 5, 5])
6 output shape:  torch.Size([1, 384, 5, 5])
7 output shape:  torch.Size([1, 10, 5, 5])
8 output shape:  torch.Size([1, 10, 1, 1])
9 output shape:  torch.Size([1, 10])

训练:

def train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
batch_size = 128
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)lr, num_epochs = 0.002, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

(pytorch-深度学习系列)网络中的网络(NiN)相关推荐

  1. PyTorch 深度学习实战 | 基于生成式对抗网络生成动漫人物

    生成式对抗网络(Generative Adversarial Network, GAN)是近些年计算机视觉领域非常常见的一类方法,其强大的从已有数据集中生成新数据的能力令人惊叹,甚至连人眼都无法进行分 ...

  2. 笔记|(b站)刘二大人:pytorch深度学习实践(代码详细笔记,适合零基础)

    pytorch深度学习实践 笔记中的代码是根据b站刘二大人的课程所做的笔记,代码每一行都有注释方便理解,可以配套刘二大人视频一同使用. 用PyTorch实现线性回归 # 1.算预测值 # 2.算los ...

  3. 【动手学深度学习PyTorch版】19 网络中的网络 NiN

    上一篇请移步[动手学深度学习PyTorch版]18 使用块的网络 VGG_水w的博客-CSDN博客 目录 一.网络中的网络 NiN 1.1 NiN ◼ 全连接层的问题 ◼ 大量的参数会带来很多问题 ◼ ...

  4. Pytorch深度学习实战教程:UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 本文的开发环境如下: 开发环境:Windows 开发语言:Python3. ...

  5. Pytorch 深度学习实战教程(二):UNet语义分割网络

    本文 GitHub https://github.com/Jack-Cherish/PythonPark 已收录,有技术干货文章,整理的学习资料,一线大厂面试经验分享等,欢迎 Star 和 完善. 一 ...

  6. Pytorch深度学习实战教程(二):UNet语义分割网络

    1 前言 本文属于Pytorch深度学习语义分割系列教程. 该系列文章的内容有: Pytorch的基本使用 语义分割算法讲解 如果不了解语义分割原理以及开发环境的搭建,请看该系列教程的上一篇文章< ...

  7. pytorch | 深度学习分割网络U-net的pytorch模型实现

    原文:https://blog.csdn.net/u014722627/article/details/60883185 pytorch | 深度学习分割网络U-net的pytorch模型实现 这个是 ...

  8. 【深度学习】快照集成等网络训练优化算法系列

    [深度学习]快照集成等网络训练优化算法系列 文章目录 1 什么是快照集成? 2 什么是余弦退火学习率? 3 权重空间中的解决方案 4 局部与全局最优解 5 特别数据增强 6 机器学习中解决数据不平衡问 ...

  9. 手把手教你搭建pytorch深度学习网络

    总有人在后台问我,如今 TensorFlow 和 PyTorch 两个深度学习框架,哪个更流行? 就这么说吧,今年面试的实习生,问到常用的深度学习框架时,他们清一色的选择了「PyTorch」. 这并不 ...

  10. 点云深度学习系列博客(二): 点云配准网络PCRNet

    目录 一. 简介 二. 基础结构 三. 项目代码 四. 实验结果 总结 Reference 今天的点云深度学习系列博客为大家介绍一个用于点云配准的深度网络:PCRNet [1].凡是对点云相关应用有些 ...

最新文章

  1. 惊艳亮相!马斯克发布自研超算 Dojo 芯片、特斯拉人形机器人
  2. .net安装_无需安装Python,就可以在.NET里调用Python库
  3. list、tuple、set、dict 四大数据结构
  4. python对象编程例子-python 面向对象编程 类和实例
  5. local_listener 与 remote_listener 参数说明
  6. 关联规则(Association Rules)笔记
  7. 9.2 mnist_with_summaries tensorboard 可视化展示
  8. mongodb同时更新一条记录_MongoDB 存储和优化系列一
  9. c语言试题c组卡片换位,蓝桥杯 卡片换位 - 李韬|aitom|机器人|SLAM - OSCHINA - 中文开源技术交流社区...
  10. OpenShift 4 - Knative教程 (5) Eventing之Source和Sink
  11. 【C语言】C语言常量和变量
  12. oracle java耗cpu_ORACLE高手请看过来,CPU使用率100% (100分)
  13. LINUX命令之stat及显示的三个时间戳
  14. [转]AndroidManifest.xml文件详解
  15. Apache启用GZIP压缩网页传输方法
  16. 【好】Paxos以及分布式一致性的学习
  17. 滑动平均_善杰告诉您初中物理学滑动变阻器的各种作用
  18. 计算可靠度编制matlab,工程结构可靠度计算的Matlab实现
  19. 2008r2 请检查名称的拼写_甲状腺素、甲状腺激素、T3、T4…这些名称你分得清吗?...
  20. Application loader:ERROR ITMS-90168: The binary you uploaded was invalid.

热门文章

  1. android 模糊读取文件名_Android 从路径中获取文件名 | 学步园
  2. 学校计算机教室的用途,录播教室有什么功能,又有哪些用途
  3. LwIP应用开发笔记之一:LwIP无操作系统基本移植
  4. [运维]---linux机器一般监控用到的概念记录
  5. 手机连接投影机的步骤_投影机安装过程详解
  6. java程序怎么都不是一个_java运行的流程-怎么运行java程序编了一个程序不知道怎么运行郁闷啊后缀文件名是 爱问知识人...
  7. java接口文档生成工具_接口文档生成
  8. java包图标是文件_关于更换.jar文件默认图标
  9. python编辑elif显示错误_Python运行的17个时新手常见错误小结
  10. http://www.od85c.com.cn/html/,OllyDbg script for unpacking Enigma 4.xx and 5.xx