NiN模型

  • 1. NiN模型介绍
    • 1.1 NiN模型结构
    • 1.2 NiN结构与VGG结构的对比
  • 2. PyTorch实现
    • 2.1 导入相应的包
    • 2.2 定义NiN block
    • 2.3 全局最大池化层
    • 2.4 训练网络

1. NiN模型介绍

1.1 NiN模型结构

NiN模型即Network in Network模型,最早是由论文Network In Network(Min Lin, ICLR2014).提出的。这篇文章有两个很重要的观点:

  • 1×1卷积的使用
    文中提出使用mlpconv网络层替代传统的convolution层。mlp层实际上是卷积加传统的mlp(多层感知机),因为convolution是线性的,而mlp是非线性的,后者能够得到更高的抽象,泛化能力更强。在跨通道(cross channel,cross feature map)情况下,mlpconv等价于卷积层+1×1卷积层,所以此时mlpconv层也叫cccp层(cascaded cross channel parametric pooling)。
  • CNN网络中不使用FC层(全连接层)
    文中提出使用Global Average Pooling取代最后的全连接层,因为全连接层参数多且易过拟合。做法即移除全连接层,在最后一层(文中使用mlpconv)层,后面加一层Average Pooling层。
    以上两点,之所以重要,在于,其在较大程度上减少了参数个数,确能够得到一个较好的结果。而参数规模的减少,不仅有利用网络层数的加深(由于参数过多,网络规模过大,GPU显存等不够用而限制网络层数的增加,从而限制模型的泛化能力),而且在训练时间上也得到改进。
    线性卷积层和mlpconv层的区别如图所示:

    下图是NiN网络结构:

    第一个卷积核是11x11x3x96,因此在一个patch块上卷积的输出是1x1x96的feature map(一个96维的向量)。 在其后又接了一个MLP层,输出仍然是96。 因此这个MLP层就等价于一个1 x 1 的卷积层, 这样工程上任然按照之前的方式实现,不需要额外工作。
1.2 NiN结构与VGG结构的对比

LeNet、AlexNet和VGG:

  • 先以由卷积层构成的模块充分抽取 空间特征,再以由全连接层构成的模块来输出分类结果。

NiN:

  • 串联多个由卷积层和“全连接”层构成的小⽹络来构建⼀个深层⽹络。
    ⽤了输出通道数等于标签类别数的NiN块,然后使⽤全局平均池化层对每个通道中所有元素求平均并直接⽤于分类。

1×1卷积核作用

  • 放缩通道数:通过控制卷积核的数量达到通道数的放缩。
  • 增加非线性。1×1卷积核的卷积过程相当于全连接层的计算过程,并且还加入了非线性激活函数,从而可以增加网络的非线性。
  • 计算参数少

NiN网络的特点

  • NiN重复使⽤由卷积层和代替全连接层的1×1卷积层构成的NiN块来构建深层⽹络。
  • NiN去除了容易造成过拟合的全连接输出层,而是将其替换成输出通道数等于标签类别数 的NiN块和全局平均池化层。
  • NiN的以上设计思想影响了后⾯⼀系列卷积神经⽹络的设计。

2. PyTorch实现

2.1 导入相应的包
import time
import torch
from torch import nn, optim
import torchvision
import numpy as np
import sys
sys.path.append("/home/kesci/input/")
import d2lzh1981 as d2l
import os
import torch.nn.functional as F
2.2 定义NiN block
'''
参数:
in_channels: 输入通道数
out_channels:输出通道数
kernel_size: 卷积核尺寸
stride: 步幅
padding: 填充
'''
def nin_block(in_channels, out_channels, kernel_size, stride, padding):blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU())return blk
2.3 全局最大池化层
# 已保存在d2lzh_pytorch
class GlobalAvgPool2d(nn.Module):# 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现def __init__(self):super(GlobalAvgPool2d, self).__init__()def forward(self, x):return F.avg_pool2d(x, kernel_size=x.size()[2:])
# 构建NiN网络
net = nn.Sequential(nin_block(1, 96, kernel_size=11, stride=4, padding=0),nn.MaxPool2d(kernel_size=3, stride=2),nin_block(96, 256, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=3, stride=2),nin_block(256, 384, kernel_size=3, stride=1, padding=1),nn.MaxPool2d(kernel_size=3, stride=2), nn.Dropout(0.5),# 标签类别数是10nin_block(384, 10, kernel_size=3, stride=1, padding=1),GlobalAvgPool2d(), # 将四维的输出转成二维的输出,其形状为(批量大小, 10)d2l.FlattenLayer())

生成随机输入X,观察网络的结构:

X = torch.rand(1, 1, 224, 224)
for name, blk in net.named_children(): X = blk(X)print(name, 'output shape: ', X.shape)

输出结果为:

2.4 训练网络
batch_size = 128
# 如出现“out of memory”的报错信息,可减小batch_size或resize
#train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
# 定义学习率和训练周期
lr, num_epochs = 0.002, 5
# 优化函数为Adam
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

动手学深度学习(PyTorch实现)(十)--NiN模型相关推荐

  1. 动手学深度学习(PyTorch实现)(八)--AlexNet模型

    AlexNet模型 1. AlexNet模型介绍 1.1 AlexNet的特点 1.2 AlexNet的结构 1.3 AlexNet参数数量 2. AlexNet的PyTorch实现 2.1 导入相应 ...

  2. 动手学深度学习(PyTorch实现)(七)--LeNet模型

    LeNet模型 1. LeNet模型 2. PyTorch实现 2.1 模型实现 2.2 获取数据与训练 1. LeNet模型 LeNet分为卷积层块和全连接层块两个部分.下面我们分别介绍这两个模块. ...

  3. 动手学深度学习(PyTorch实现)(十二)--批量归一化(BatchNormalization)

    批量归一化-BatchNormalization 1. 前言 2. 批量归一化的优势 3. BN算法介绍 4. PyTorch实现 4.1 导入相应的包 4.2 定义BN函数 4.3 定义BN类 5. ...

  4. 动手学深度学习(PyTorch实现)(十三)--ResNet模型

    ResNet模型 1. ResNet介绍 2. ResNet结构 3. ResNet的PyTorch实现 3.1 导入所需要的包 3.2 构建ResNet网络 3.3 开始训练 注:本文部分内容参考博 ...

  5. 动手学深度学习(PyTorch实现)(十一)--GoogLeNet模型

    GoogLeNet模型 1. GoogLeNet介绍 1.1 背景 1.2 GoogLeNet网络结构 2. PyTorch实现 2.1 导入相应的包 2.2 定义Inception块结构 2.3 定 ...

  6. 动手学深度学习(PyTorch实现)(九)--VGGNet模型

    VGGNet模型 1. VGGNet模型介绍 1.1 VGGNet的结构 1.2 VGGNet结构举例 2. VGGNet的PyTorch实现 2.1 导入相应的包 2.2 基本网络单元block 2 ...

  7. 【动手学深度学习PyTorch版】6 权重衰退

    上一篇移步[动手学深度学习PyTorch版]5 模型选择 + 过拟合和欠拟合_水w的博客-CSDN博客 目录 一.权重衰退 1.1 权重衰退 weight decay:处理过拟合的最常见方法(L2_p ...

  8. 【动手学深度学习PyTorch版】19 网络中的网络 NiN

    上一篇请移步[动手学深度学习PyTorch版]18 使用块的网络 VGG_水w的博客-CSDN博客 目录 一.网络中的网络 NiN 1.1 NiN ◼ 全连接层的问题 ◼ 大量的参数会带来很多问题 ◼ ...

  9. 伯禹公益AI《动手学深度学习PyTorch版》Task 05 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 05 学习笔记 Task 05:卷积神经网络基础:LeNet:卷积神经网络进阶 微信昵称:WarmIce 昨天打了一天的<大革 ...

最新文章

  1. 我都陪你坐了一天了,你好歹说句话啊!吖的,谱也忒大了。。。
  2. Android 保存图片到系统及相关问题的解决方案
  3. 在DevExpress中使用CameraControl控件进行摄像头图像采集
  4. python语音播报-用Python写一个语音播放软件
  5. 为EasyUI 的Tab 标签添加右键菜单
  6. 前置交换机数据交换_我们的数据科学交换所
  7. query登录linux命令,在Linux系统中使用sqlcmd命令连接与查询SQL Server
  8. SQL.H 通过此文件寻找sqlAPI编程的一种捷径
  9. spring-boot-starter-parent的主要作用
  10. Linux下mysql主从复制配置(CentOS7)
  11. 【SELinux】vendor_file_contexts没有被编译到vendor/etc/selinux/路径下
  12. 橱柜衣柜 sketchup草图大师设计全屋定制家具意义?谈单拆单生产一起做了?-有屋软件
  13. 消息中间件TongLinkQ(TLQ)使用总结——记那几天趟过的坑
  14. 守望先锋中的netcode_如何跟踪守望先锋中的化妆品和事件物品
  15. 音频系统POP音的原理和解决方法
  16. Lattice系列FPGA
  17. 解决sysman.mgmt_task_qtable ORA-600 kdsgrp1错误
  18. 国外部分音乐人工智能/音乐科技研究机构科研项目简介
  19. 了解数据的发展历程--大数据简史
  20. XMind商业思维导图——市场营销!

热门文章

  1. SpringBoot+EHcache实现缓存
  2. java基本语法心得_Java学习笔记(一)——基础语法(上)
  3. 2021年下半年网络工程师下午真题及答案解析
  4. HTML5---新标签与特性
  5. Cookie、token、session的区别是什么?
  6. 猴子选大王--约瑟夫问题浅析
  7. java jar包与配置文件的写法
  8. 【2】最简单的Laravel5.1程序分析
  9. DEDE留言板调用导航的方法
  10. [转]Objective-C 语言特性