前两篇文章分别介绍了卷积层和池化层,卷积和池化是卷积神经网络必备的两大基础。本文我们将介绍一个早期用来识别手写数字图像的卷积神经网络:LeNet[1]。LeNet名字来源于论文的第一作者Yann LeCun。1989年,LeNet使用卷积神经网络和梯度下降法,使得手写数字识别达到当时领先水平。这个奠基性的工作第一次将卷积神经网络推上历史舞台,为世人所知。由于LeNet的出色表现,在很多ATM取款机上,LeNet被用来识别数字字符。

本文基于PyTorch和TensorFlow 2的代码已经放在了我的GitHub上:https://github.com/luweizheng/machine-learning-notes/tree/master/neural-network/cnn。

网络模型结构

LeNet的网络结构如下图所示。

LeNet分为卷积层块和全连接层块两个部分。

卷积层块里的基本单位是卷积层后接最大池化层:卷积层用来识别图像里的空间模式,如线条和物体局部,之后的最大池化层则用来降低卷积层对位置的敏感性。卷积层块由卷积层加池化层两个这样的基本单位重复堆叠构成。在卷积层块中,每个卷积层都使用5×5的窗口,并在输出上使用Sigmoid激活函数。整个模型的输入是1维的黑白图像,图像尺寸为28×28。第一个卷积层输出通道数为6,第二个卷积层输出通道数则增加到16。这是因为第二个卷积层比第一个卷积层的输入的高和宽要小,所以增加输出通道使两个卷积层的参数尺寸类似。卷积层块的两个最大池化层的窗口形状均为2×2,且步幅为2。由于池化窗口与步幅形状相同,池化窗口在输入上每次滑动所覆盖的区域互不重叠。

我们通过PyTorch的Sequential类来实现LeNet模型。

class LeNet(nn.Module):    def __init__(self):        super(LeNet, self).__init__()                # 输入 1 * 28 * 28        self.conv = nn.Sequential(            # 卷积层1            # 在输入基础上增加了padding,28 * 28 -> 32 * 32            # 1 * 32 * 32 -> 6 * 28 * 28            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2), nn.Sigmoid(),            # 6 * 28 * 28 -> 6 * 14 * 14            nn.MaxPool2d(kernel_size=2, stride=2), # kernel_size, stride            # 卷积层2            # 6 * 14 * 14 -> 16 * 10 * 10             nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), nn.Sigmoid(),            # 16 * 10 * 10 -> 16 * 5 * 5            nn.MaxPool2d(kernel_size=2, stride=2)        )        self.fc = nn.Sequential(            # 全连接层1            nn.Linear(in_features=16 * 5 * 5, out_features=120), nn.Sigmoid(),            # 全连接层2            nn.Linear(in_features=120, out_features=84), nn.Sigmoid(),            nn.Linear(in_features=84, out_features=10)        )    def forward(self, img):        feature = self.conv(img)        output = self.fc(feature.view(img.shape[0], -1))        return output复制代码

我们有必要梳理一下模型各层的参数。输入形状为通道数为1的图像(1维黑白图像),尺寸为28×28,经过第一个5×5的卷积层,卷积时上下左右都使用了2个元素作为填充,输出形状为:(28 - 5 + 4 + 1) × (28 - 5 + 4 + 1) = 28 × 28。第一个卷积层输出共6个通道,输出形状为:6 × 28 × 28。最大池化层核大小2×2,步幅为2,高和宽都被折半,形状为:6 × 14 × 14。第二个卷积层的卷积核也为5 × 5,但是没有填充,所以输出形状为:(14 - 5 + 1) × (14 - 5 + 1) = 10 × 10。第二个卷积核的输出为16个通道,所以变成了 16 × 10 × 10。经过最大池化层后,高和宽折半,最终为:16 × 5 × 5。

卷积层块的输出形状为(batch_size, output_channels, height, width),在本例中是(batch_size, 16, 5, 5),其中,batch_size是可以调整大小的。当卷积层块的输出传入全连接层块时,全连接层块会将一个batch中每个样本变平(flatten)。原来是形状是:(通道数 × 高 × 宽),现在直接变成一个长向量,向量长度为通道数 × 高 × 宽。在本例中,展平后的向量长度为:16 × 5 × 5 = 400。全连接层块含3个全连接层。它们的输出个数分别是120、84和10,其中10为输出的类别个数。

训练模型

基于上面的网络,我们开始训练模型。我们使用Fashion-MNIST作为训练数据集,很多框架,比如PyTorc提供了Fashion-MNIST数据读取的模块,我做了一个简单的封装:

def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):    """Use torchvision.datasets module to download the fashion mnist dataset and then load into memory."""    trans = []    if resize:        trans.append(torchvision.transforms.Resize(size=resize))    trans.append(torchvision.transforms.ToTensor())        transform = torchvision.transforms.Compose(trans)    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)    if sys.platform.startswith('win'):        num_workers = 0      else:        num_workers = 4    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)    return train_iter, test_iter复制代码

load_data_fashion_mnist()方法返回训练集和测试集。

在训练过程中,我们希望看到每一轮迭代的准确度,构造一个evaluate_accuracy方法,计算当前一轮迭代的准确度(模型预测值与真实值之间的误差大小):

def evaluate_accuracy(data_iter, net, device=None):    if device is None and isinstance(net, torch.nn.Module):        device = list(net.parameters())[0].device    acc_sum, n = 0.0, 0    with torch.no_grad():        for X, y in data_iter:            if isinstance(net, torch.nn.Module):                # set the model to evaluation mode (disable dropout)                net.eval()                 # get the acc of this batch                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()                # change back to train mode                net.train()             n += y.shape[0]    return acc_sum / n复制代码

接着,我们可以构建一个train()方法,用来训练神经网络:

def try_gpu(i=0):    if torch.cuda.device_count() >= i + 1:        return torch.device(f'cuda:{i}')    return torch.device('cpu')def train(net, train_iter, test_iter, batch_size, optimizer, num_epochs, device=try_gpu()):    net = net.to(device)    print("training on", device)    loss = torch.nn.CrossEntropyLoss()    batch_count = 0    for epoch in range(num_epochs):        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0        for X, y in train_iter:            X = X.to(device)            y = y.to(device)            y_hat = net(X)            l = loss(y_hat, y)            optimizer.zero_grad()            l.backward()            optimizer.step()            train_l_sum += l.cpu().item()            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()            n += y.shape[0]            batch_count += 1        test_acc = evaluate_accuracy(test_iter, net)        if epoch % 10 == 0:            print(f'epoch {epoch + 1} : loss {train_l_sum / batch_count:.3f}, train acc {train_acc_sum / n:.3f}, test acc {test_acc:.3f}')复制代码

在整个程序的主逻辑中,设置必要的参数,读入训练和测试数据并开始训练:

def main():    batch_size = 256    lr, num_epochs = 0.9, 100    net = LeNet()    optimizer = torch.optim.SGD(net.parameters(), lr=lr)        # load data    train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)    # train    train(net, train_iter, test_iter, batch_size, optimizer, num_epochs)复制代码

小结

  1. LeNet是一个最简单的卷积神经网络,卷积神经网络包含卷积块部分和全连接层部分。
  2. 卷积块包括一个卷积层和一个池化层。

参考文献

  1. LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11), 2278-2324.
  2. http://d2l.ai/chapter_convolutional-neural-networks/lenet.html
  3. https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter05_CNN/5.5_lenet

atm取款机的简单程序代码_LeNet:一个简单的卷积神经网络PyTorch实现相关推荐

  1. c语言常用的代码,初学C语言常用简单程序代码;

    <初学C语言常用简单程序代码;>由会员分享,可在线阅读,更多相关<初学C语言常用简单程序代码;(16页珍藏版)>请在人人文库网上搜索. 1.初学C语言常用简单程序代码素数的筛选 ...

  2. java 模拟电梯_请使用的Java的多线程知识来编写一个程序,实现一个简单的摩天大楼的电梯模型程序是以一座摩天大楼的多个电梯为背景,用线程、流程控制、随机函数等知识来模拟它。2、电梯的描述:...

    请使用的Java的多线程知识来编写一个程序,实现一个简单的摩天大楼的电梯模型 程序是以一座摩天大楼的多个电梯为背景,用线程.流程控制.随机函数等知识来模拟它. 2. 电梯的描述: 电梯是日常生活中经常 ...

  3. 入门攻略丨教你用低代码实现一个简单的页面跳转功能

    一.介绍 HUAWEI DevEco Studio(后文简称:IDE)自2020年9月首次发布以来,经10次迭代升级,不断为HarmonyOS应用开发增强能力.3月31日,IDE再度升级到DevEco ...

  4. keras构建卷积神经网络_通过此简单教程学习在网络上构建卷积神经网络

    keras构建卷积神经网络 by John David Chibuk 约翰·大卫·奇布克(John David Chibuk) 通过此简单教程学习在网络上构建卷积神经网络 (Learn to buil ...

  5. EEGNet:一个小型的卷积神经网络,用于基于脑电的脑机接口

    脑机接口(BCI)利用神经活动作为控制信号,可以与计算机直接通信.这种神经信号通常从各种研究充分的脑电图(EEG)信号中选择.对于给定的脑机接口(BCI)范式,特征提取器和分类器是针对其所期望的脑电图 ...

  6. atm取款机的简单程序代码

    /* *********atm取款机********** */#include<iostream> #include<stdlib.h> using namespace std ...

  7. python的简单程序代码_小白学编程?从一个简单的程序开始学习Python编程

    笔者思虑再三还是决定选择图文(因为百家的视频发布画质真不怎么样[囧]). 笔者学习编程的时间也挺长的,因为业余,因为时间不多,各种原因,自学编程的路特别难走.然后笔者发现,自己能为小白贡献一些力量,然 ...

  8. python编写程序输出诗句_Python文本处理简介:44行代码编写一个简单的隐藏诗生成器,python,入门,藏头诗...

    想必最近大家家庭群里最近都会看到这么一张图: 一惊,这什么玩意儿???后来一搜会发现里面不同的诗句来自于不同的古诗,嘛,这不是很好玩的一件事情吗?这次我们使用Github的唐诗宋词dataset:ht ...

  9. python博弈论代码_使用 40 多行的 Python 代码实现一个简单的演化过程

    Python部落(python.freelycode.com)组织翻译,禁止转载,欢迎转发. 在纳米比亚的 PyCon 会议上,我发表了一篇名为 <使用 Python 解决"升级版的剪 ...

最新文章

  1. 根据搜索来路 弹出相应广告
  2. Oracle RAC错误之--oifcfg错误案例
  3. 刻骨铭心的startActivityForResult三级跳获得第三个Activity中返回的数据
  4. 淺談Raid Cache Memory上應用的問題和實踐
  5. ERP的风险及其预防
  6. keepAliveTime和线程工厂
  7. rdf mysql持久化l_Jena 利用数据库保存,持久化本体
  8. 检查坏道右键单击盘符/属性/工具中的查错。
  9. 安卓页面布局中android:gravity与android:layout_gravity的区别
  10. python数字雨_用Python实现黑客帝国代码雨效果(3种方式)
  11. 基于单片机的音乐盒系统设计(#0435)
  12. Python中文字符串,变成英文字符串
  13. 冯诺依曼机与现代计算机的比较
  14. 一个程序猿三个月没有找到工作转去开滴滴
  15. 如何搜索自己CSDN博客中的文章
  16. 旧电脑安装黑群晖(5.1-5022)
  17. oracle中空值的替换,oracle中空值替换,精度空值(保留小数位数),时间转换
  18. Spark 运行架构与原理
  19. TCP滑动窗口模拟实战
  20. Autosar软件架构

热门文章

  1. ASP.NET AJAX - Timer控件之摆放位置的影响
  2. 真格量化——中性策略交易期权
  3. Diango博客--13.将“视图函数”类转化为“类视图”
  4. android生命周期_Android开发 View的生命周期结合代码详解
  5. 【Django】文件上传以及celery的使用
  6. 线程自动退出_C++基础 多线程笔记(一)
  7. 选择排序 冒泡排序 二分查找
  8. 计算几何——圆卡精度cf1059D
  9. 【uoj#207】共价大爷游长沙 随机化+LCT维护子树信息
  10. OpenGL学习笔记-坐标系统