卷积神经网络就是含卷积层的网络。介绍一个早期用来识别手写数字图像的卷积神经网络:LeNet [1]。这个名字来源于LeNet论文的第一作者Yann LeCun。LeNet展示了通过梯度下降训练卷积神经网络可以达到手写数字识别在当时最先进的结果。这个奠基性的工作第一次将卷积神经网络推上舞台,为世人所知。LeNet的网络结构如下图所示。

LeNet分为卷积层块和全连接层块两个部分。下面我们分别介绍这两个模块。

卷积层块里的基本单位是卷积层后接最大池化层:卷积层用来识别图像里的空间模式,如线条和物体局部,之后的最大池化层则用来降低卷积层对位置的敏感性。卷积层块由两个这样的基本单位重复堆叠构成。在卷积层块中,每个卷积层都使用$5\times 5$的窗口,并在输出上使用sigmoid激活函数。第一个卷积层输出通道数为6,第二个卷积层输出通道数则增加到16。这是因为第二个卷积层比第一个卷积层的输入的高和宽要小,所以增加输出通道使两个卷积层的参数尺寸类似。卷积层块的两个最大池化层的窗口形状均为$2\times 2$,且步幅为2。由于池化窗口与步幅形状相同,池化窗口在输入上每次滑动所覆盖的区域互不重叠。

卷积层块的输出形状为(批量大小, 通道, 高, 宽)。当卷积层块的输出传入全连接层块时,全连接层块会将小批量中每个样本变平(flatten)。也就是说,全连接层的输入形状将变成二维,其中第一维是小批量中的样本,第二维是每个样本变平后的向量表示,且向量长度为通道、高和宽的乘积。全连接层块含3个全连接层。它们的输出个数分别是120、84和10,其中10为输出的类别个数。

下面我们通过Sequential类来实现LeNet模型。

import time
import torch
from torch import nn, optimimport sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv = nn.Sequential(nn.Conv2d(1, 6, 5), # in_channels, out_channels, kernel_sizenn.Sigmoid(),nn.MaxPool2d(2, 2), # kernel_size, stridenn.Conv2d(6, 16, 5),nn.Sigmoid(),nn.MaxPool2d(2, 2))self.fc = nn.Sequential(nn.Linear(16*4*4, 120),nn.Sigmoid(),nn.Linear(120, 84),nn.Sigmoid(),nn.Linear(84, 10))def forward(self, img):feature = self.conv(img)output = self.fc(feature.view(img.shape[0], -1))return output

接下来查看每个层的形状。

net = LeNet()
print(net)

输出

LeNet((conv): Sequential((0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(1): Sigmoid()(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(4): Sigmoid()(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(fc): Sequential((0): Linear(in_features=256, out_features=120, bias=True)(1): Sigmoid()(2): Linear(in_features=120, out_features=84, bias=True)(3): Sigmoid()(4): Linear(in_features=84, out_features=10, bias=True))
)

可以看到,在卷积层块中输入的高和宽在逐层减小。卷积层由于使用高和宽均为5的卷积核,从而将高和宽分别减小4,而池化层则将高和宽减半,但通道数则从1增加到16。全连接层则逐层减少输出个数,直到变成图像的类别数10。

下面我们来实验LeNet模型。实验中,我们使用Fashion-MNIST作为训练数据集。

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size=batch_size)

因为卷积神经网络计算比多层感知机要复杂,建议使用GPU来加速计算

def evaluate_accuracy(data_iter, net, device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')):acc_sum, n = 0.0, 0with torch.no_grad():for X, y in data_iter:if isinstance(net, torch.nn.Module):net.eval() # 评估模式, 这会关闭dropoutacc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()net.train() # 改回训练模式else: # 自定义的模型,  不考虑GPUif('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数# 将is_training设置成Falseacc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() n += y.shape[0]return acc_sum / n

确保计算使用的数据和模型同在内存或显存上。

def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()batch_count = 0for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, start = 0.0, 0.0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))

学习率采用0.001,训练算法使用Adam算法,损失函数使用交叉熵损失函数。

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0072, train acc 0.322, test acc 0.584, time 3.7 sec
epoch 2, loss 0.0037, train acc 0.649, test acc 0.699, time 1.8 sec
epoch 3, loss 0.0030, train acc 0.718, test acc 0.724, time 1.7 sec
epoch 4, loss 0.0027, train acc 0.741, test acc 0.746, time 1.6 sec
epoch 5, loss 0.0024, train acc 0.759, test acc 0.759, time 1.7 sec

小结

  • 卷积神经网络就是含卷积层的网络。
  • LeNet交替使用卷积层和最大池化层后接全连接层来进行图像分类。

PyTorch实战福利从入门到精通之七——卷积神经网络(LeNet)相关推荐

  1. PyTorch实战福利从入门到精通之四——卷积神经网络CIFAR-10图像分类

    在本教程中,我们将使用CIFAR10数据集.它有类别:"飞机"."汽车"."鸟"."猫"."鹿".& ...

  2. PyTorch实战福利从入门到精通之五——搭建ResNet

    Kaiming He的深度残差网络(ResNet)在深度学习的发展中起到了很重要的作用,ResNet不仅一举拿下了当年CV下多个比赛项目的冠军,更重要的是这一结构解决了训练极深网络时的梯度消失问题. ...

  3. PyTorch实战福利从入门到精通之三——autograd

    autograd 反向传播过程需要手动实现.这对于像线性回归等较为简单的模型来说,还可以应付,但实际使用中经常出现非常复杂的网络结构,此时如果手动实现反向传播,不仅费时费力,而且容易出错,难以检查.t ...

  4. PyTorch实战福利从入门到精通之一——PyTorch框架安装

    使用conda安装是最不容易出错的,在pytroch的官网可以选择自己需要的操作系统.python版本.cuda版本的pytorch框架. 之后复制下面的命令就可以了 安装完这个还要安个numpy p ...

  5. PyTorch实战福利从入门到精通之六——线性回归

    一元线性回归 一元线性模型非常简单,假设我们有变量 xix_ixi​ 和目标 yiy_iyi​,每个 i 对应于一个数据点,希望建立一个模型 y^i=wxi+b\hat{y}_i = w x_i + ...

  6. PyTorch实战福利从入门到精通之八——深度卷积神经网络(AlexNet)

    在LeNet提出后的将近20年里,神经网络一度被其他机器学习方法超越,如支持向量机.虽然LeNet可以在早期的小数据集上取得好的成绩,但是在更大的真实数据集上的表现并不尽如人意.一方面,神经网络计算复 ...

  7. PyTorch实战福利从入门到精通之九——数据处理

    在解决深度学习问题的过程中,往往需要花费大量的精力去处理数据,包括图像.文本.语音或其它二进制数据等.数据的处理对训练神经网络来说十分重要,良好的数据处理不仅会加速模型训练,更会提高模型效果.考虑到这 ...

  8. PyTorch实战福利从入门到精通之二——Tensor

    Tensor又名张量,也是Tensorflow等框架中的重要数据结构.它可以是一个数(标量),一维数组(向量),二维数组或更高维数组.Tensor支持GPU加速. 创建Tensor 几种常见创建Ten ...

  9. 从入门到精通:卷积神经网络初学者指南

    转载自:http://www.jiqizhixin.com/article/1363?utm_source=tuicool&utm_medium=referral 这是一篇向初学者讲解卷积神经 ...

最新文章

  1. Markdown语法整理
  2. vb冒泡排序法流程图_VB算法-冒泡排序教案
  3. jQuery 表格实现
  4. Unix整理笔记——安全性——里程碑M13
  5. 算法导论读书笔记(7)
  6. [转载] Linux进程基础
  7. 7 php程序的调试方法_PHP 程序员的调试技术
  8. xsmax是大黑边?_苹果iPhone11和xsmax,8p x xr xs怎么选?干货分享!
  9. Hyper-V虚拟光纤通道
  10. linux ata4 serror,linux – 如何将kern.log错误消息中的ataX.0标识符映射到实际的/ dev / sdY设备?...
  11. UML之教学管理系统——4、Rational Rose画活动图
  12. pwntcha库的安装依赖
  13. python|爬虫|爬取豆瓣自己账号下的观影记录并可视化
  14. 奇数点偶数点fft的matlab,电子科大 数字信号处理实验2_FFT的实现
  15. 微信每日定时推送消息新闻到群聊或朋友
  16. matlab雷达目标回波仿真
  17. 一个假冒的序列号被被用来注册internetdownload manager。IDM正在退出解决办法
  18. 利用爬虫下载批量图片
  19. oracle数据库报错,ORA-01652:无法通过128(在表空间TEMP中)扩展temp段
  20. 企业在建站前需要了解的七点

热门文章

  1. SAP License:SAP顾问你算哪根葱?
  2. SAP License:SAP的公司间销售
  3. 记-ItextPDF+freemaker 生成PDF文件---导致服务宕机
  4. spring的一些概念及优点
  5. 2019 Power BI最Top50面试题,助你面试脱颖而出系列中
  6. 模式三工厂——开花结果
  7. luogu P3178 [HAOI2015]树上操作
  8. 数据结构开发(6):静态单链表的实现
  9. jhipster详解
  10. [新手学Java]使用beanUtils控制javabean