参考文档:https://mp.weixin.qq.com/s/1TtPWYqVkj2Gaa-3QrEG1A

这篇文章是在一个大家经常见到的数据集 MNIST 上实现一个简单的 CNN。我们会基于上一篇文章中的分类器,来讨论实现一个 CNN,需要在之前的内容上做出哪些升级。

这篇笔记的内容包含三个部分:读取 pytorch 自带的数据集并分割;实现一个 CNN 的网络结构;完成训练。这三个部分合起来完成了一个简单的浅层卷积神经网络,在 MNIST 上进行训练和测试。

1、读取自带数据集并分割

train_data = torchvision.datasets.MNIST(root='./mnist',train=True,transform=torchvision.transforms.ToTensor(),download=True
)

以这个为例,我们就可以知道如何从 pytorch 中直接使用自带的数据集。而且 pytorch 包含了很多常用的数据集(我们比较熟悉的如 MNIST,cifar家族,ImageNet 等),所以熟悉一下如何使用自带的数据集也是非常有帮助的。

导入自带的数据集,按照上面代码的读取路径为:torchvision.datasets,在这一步后面跟上选择自己想要的数据集,这里我们选择了 MNIST。

这里的几个参数也很简单,root 是读取数据的路径,如果没有就去下载;train 为 True 表示读取的是训练集,反之就是测试集;transform 表示了将图片转换为 tensor;最后的 download 参数来设置是否从网上下载数据集。

import torch.utils.data as Data
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)test_data = torchvision.datasets.MNIST(root='./mnist',train=False,transform=torchvision.transforms.ToTensor()
)
test_x = torch.unsqueeze(test_data.data, dim=1)/255.
test_y = test_data.targets

利用 torch 中 data 模块的 DataLoader 函数,我们可以设置一个数据加载器。设置的参数有三个,dataset 是前面读入的数据集;batch_size 就是表面意思,每批处理的数据的数量;shuffle 表示是否打乱数据。

下面的 test_data 和前面读取 train_data 是相同的方法,只不过将参数 train 调整为 False。接着 test_x 通过 unsqueeze 为数据加了一个维度,然后对所有数据除了一个 255。这里的 unsqueeze 的原因是,直接读取的数据没有通道数,所以加了这样一个维度。除以 255 是因为测试集的数据都是 0-255 之间,这样子可以将其压缩到 0-1 上。

至此我们完成了 MNIST 数据集的读取,同时分割好了训练集和测试集。而 pytorch 自带的其它数据集也可以参考这类方法进行读取。接下来我们开始学习如何在 pytorch 上定义一个 CNN 网络。

2、实现一个CNN的网络结构

class Net(torch.nn.Module):# def __init__(self, n_feature, n_hidden, n_output):#     super(Net, self).__init__()#     self.n_hidden = torch.nn.Linear(n_feature, n_hidden)#     self.out = torch.nn.Linear(n_hidden, n_output)def __init__(self, n_feature, n_hidden, n_output):super(Net, self).__init__()self.classify = torch.nn.Sequential(torch.nn.Linear(n_feature, n_hidden),torch.nn.ReLU(),torch.nn.Linear(n_hidden, n_output),)def forward(self, x_layer):# x_layer = torch.relu(self.n_hidden(x_layer))# x_layer = self.out_layer)x_layer = self.classify(x_layer)x_layer = torch.nn.functional.softmax(x_layer)return x_layer

在开始一个 CNN 开始前,我们先介绍一下 Sequential() 函数。来回想一下上篇文章的分类,我们的网络结构如上面注释的内容一样,我们需要在 init() 中定义每个网络结构,然后在 forward() 中再按照想要的顺序去一一调用。为了简化这个步骤,我们可以使用 torch.nn.Sequential() 来简化这个步骤。
可以看到在 init() 中我们定义了一个 classify,在后面的 Sequential() 函数中,我们按照原来的顺序,将我们需要的函数按序排好,在下面的 forward() 里面直接调用一个 classify() 就可以了。

现在我们明白了 Sequential() 的用法,接下来,我们看一下,如何使用 Sequential() 来定义 CNN 的网络结构。

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=2,padding=2,),nn.ReLU(),nn.MaxPool2d(2))self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2),nn.ReLU())self.out = nn.Linear(32 * 7 * 7, 10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(x.size(0), -1)output = self.out(x)return output

以第一个 conv1 为例,其中 Sequential() 包含了 3 个函数:nn.Conv2d(),nn.ReLU(),nn.MaxPool2d()。我们简单的说一下对应的内容。

nn.Conv2d() 是实现一个卷积层的函数,这也就是用 pytorch 这类框架的原因,相关的神经网络结构已经封装好,我们直接调用就可以了(但是真的不能只当个调包侠哦~~)。在这里我们给出来每个参数对应的名字:

  • in_channels 是输入的通道数,如果是 RGB 图像,这里的 in_channels 就是 3;

  • out_channels 是输出通道,也就是卷积核数量;

  • kernel_size 是卷积核的大小;

  • stride 就是字面意思,卷积过程中的步长;

  • padding 的参数表示如果做 padding 操作的话,填充的宽度。

nn.ReLU() 函数就是进行一个简单的 relu 激活层。

nn.MaxPool2d() 表示进行一个池化操作,其中的参数就是池化时的核大小。

这样就是一个标准的卷积操作了,先进行一个卷积层,再经过激活层,然后做一次池化,这三个步骤合起来我们封装为一个 conv1;conv2 也是同样的道理,只是最后去掉了池化层;然后两次卷积完了以后后面跟一个全连接层。

具体的参数计算方法我们就不做讨论了,这个是理论方面的基础要求,不是我们 pytorch 学习笔记的重点。

最后可以看到在 forward() 中就可以直接调用分别两个 conv1 和 conv2 就完成了操作,而如果没有使用 Sequential() 封装的话,我们就需要把这些激活层,池化层等等都一个一个写下来。

在最后 output 前,我们看到用了一个 x.view() 函数,这里是将第 0 维的 batch_size 保留,然后将其余的数据拉伸成 1 维数据,也就是从 (32,7,7)的维度拉伸为(32 * 7 * 7)。方便进行最后的全连接。

整个CNN的网络结构如下:

3、完成训练

cnn = CNN()optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()time_start = time.time()
for epoch in range(EPOCH):for step, (b_x, b_y) in enumerate(train_loader):output = cnn(b_x)loss = loss_func(output, b_y)optimizer.zero_grad()loss.backward()optimizer.step()# 打印训练过程if step % 50 == 0:test_output = cnn(test_x)pred_y = torch.max(test_output, 1)[1].data.numpy()accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))print('Epoch: ', epoch, '| train loss: %.4f' % loss.data, '| test accuracy: %.2f' % accuracy)

这里的训练过程和前面的分类器基本差不多,梯度优化方式选择了 Adam,损失函数还是交叉熵。在训练部分依然和上篇分类中的三板斧一样:梯度清空,误差传递,逐步更新。

唯一有区别的是,这次我们加入了一个训练过程,每过 50 步,打印一次当前训练的误差,以及当前训练的网络在测试集上的精度。输出结果为:

4、总结

本篇笔记总结了如何使用 pytorch 实现一个简单的 CNN 网络结构,并且在 pytorch 自带的数据集 MNIST 上进行训练和测试。这也是一个我们经常见到的入门数据集。

主要从三方面来帮助我们可以自己完成一个卷积神经网络:如何调用和下载 pytorch 自带的数据集;如何使用 Sequential 来实现一个 CNN 的网络的各部分;以及如何在数据集上进行训练和测试。

pytorch学习笔记(2):在MNIST上实现一个CNN相关推荐

  1. pytorch学习笔记(1):开始一个简单的分类器

    参考文档:https://mp.weixin.qq.com/s/wj8wxeaGblJijiHFZA6lXQ 回想了一下自己关于 pytorch 的学习路线,一开始找的各种资料,写下来都能跑,但是却没 ...

  2. tensorflow学习笔记五:mnist实例--卷积神经网络(CNN)

    mnist的卷积神经网络例子和上一篇博文中的神经网络例子大部分是相同的.但是CNN层数要多一些,网络模型需要自己来构建. 程序比较复杂,我就分成几个部分来叙述. 首先,下载并加载数据: import ...

  3. 深度学习入门之PyTorch学习笔记:卷积神经网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 4 卷积神经网络 4.1 主要任务及起源 4.2 卷积神经网络的原理和结构 4.2.1 卷积层 1. ...

  4. 深度学习入门之PyTorch学习笔记:多层全连接网络

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 3 多层全连接网络 3.1 PyTorch基础 3.2 线性模型 3.2.1 问题介绍 3.2.2 一维线性回归 3.2 ...

  5. pytorch学习笔记(4):tensorboard可视化

    参考文档:https://mp.weixin.qq.com/s/UYnBRU2b0InzM9H1xl4b4g 在之前的第二篇笔记中,我们实现了一个 CNN 网络,在 mnist 上通过两个卷积层完成分 ...

  6. PyTorch学习笔记(四):PyTorch基础实战

    PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...

  7. 莫烦pytorch学习笔记5

    莫烦pytorch学习笔记5 1 自编码器 2代码实现 1 自编码器 自编码,又称自编码器(autoencoder),是神经网络的一种,经过训练后能尝试将输入复制到输出.自编码器(autoencode ...

  8. PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call

    您的位置 首页 PyTorch 学习笔记系列 PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 发布: 2017年8月4日 7,195阅读 ...

  9. pytorch 学习笔记目录

    1 部分内容 pytorch笔记 pytorch模型中的parameter与buffer_刘文巾的博客-CSDN博客 pytorch学习笔记 torchnn.ModuleList_刘文巾的博客-CSD ...

最新文章

  1. linux工程师前景_小猿圈预测2019年Linux云计算发展前景
  2. 【Qt】Qt5.x移植后的环境配置(imx6)
  3. 制胜人工智能时代——企业人工智能应用现状分析(第三版)
  4. 【vs开发】向图形界面程序添加控制台
  5. Omi v1.0震撼发布 - 令人窒息的Web组件化框架
  6. js学习总结----浏览器滚动条卷去的高度scrolltop
  7. DT大数据梦工厂 第67讲
  8. Understanding Unix/Linux Programming-ls指令练习二
  9. es5 html片段拼接,es5的 reduce怎样用在拼接html字符串??? - 社区 - 妙味课堂
  10. 魔兽名字显示服务器,魔兽世界怀旧服服务器名称
  11. Excel合并单元格读取
  12. ZBrush的用途是什么
  13. 计算机系统应用是不是核心期刊,计算机系统应用
  14. 笔记本电脑键盘按键有两个功能,如何切换
  15. Docker Mirror
  16. to redirect to a secure protocol (like HTTPS) or allow insecure protocols.
  17. 把搜狗输入法词库导入Google拼音输入法
  18. 用css控制点击穿透
  19. 2021-05-30 win10 找不到wifi,提示适配器的驱动程序可能出现问题
  20. Thinkphp6如何跨域请求

热门文章

  1. 在 TMG 更新中心中使用 WSUS进行每日的定义更新
  2. 2013年5月16日星期四初始sqlserver附加数据库权限及maven和selenium
  3. 360扫地机原理大揭秘,竟还有无人驾驶技术?——浅析家用机器人SLAM方案
  4. 通过@Import注解把类注入容器的四种方式
  5. 最近一段时间遇到的费了时间的问题
  6. Spring Boot拦截器(WebMvcConfigurerAdapter)
  7. 手工编译Linux内核rpm包
  8. 用etcd实现比Redis更安全的分布式锁
  9. Laravel核心解读--Session源码解析
  10. MySQL面试题1:MySQL架构体系相关