2020年10月5号,依然在家学习。
今天是我写的第四个 Pytorch程序, 这一次我想把之前基于PyTorch实现的简易的传统的BP全连接神经网络改写成CNN网络,想看看对比和效果差异。

这一次我设计的是一个两个卷积层,两个全连接层的网络,模型如下描述。

# 输入数据维度是[batch_size, 1, 28, 28], 批量是batch_size, 每个Img的通道是1, 图片大小是28*28# 第一层卷积核维度是[1, 10, 5 ,5], 输入通道是1, 输出通道是10, 卷积核大小是 5*5, (默认padding=0, stride = 1)
# 第一层卷积后数据维度是: [batch_size, 10, 24, 24]
# 第一层池化层: 采用的是MaxPooling, 大小是2*2.
# 第一层池化后数据维度是: [batch_size, 10, 12, 12]
# 第一层的激活函数是Relu# 第二层卷积核维度是[10, 20, 5 ,5], 输入通道是10, 输出通道是20, 卷积核大小是 5*5, (默认padding=0, stride = 1)
# 第二层卷积后数据维度是: [batch_size, 20, 8, 8]
# 第二层池化层: 采用的是MaxPooling, 大小是2*2.
# 第二层池化后数据维度是: [batch_size, 20, 4, 4]
# 第二层的激活函数是Relu# 为了连接到FC层,需要将数据重新改变维度到 [batch_size, 20*4*4=320]# 第三层FC层: [batch_size, 320, 100]
# 第三层激活函数ReLu
# 第四层FC层: [batch_size, 100, 10]

核心步骤描述如下:

1:建立的网络模型如下:

结构十分清晰,完全就是按照上述的设计组建的模型,激活函数使用ReLu。

2:MNIST数据集采用运行阶段在网上下载的方式,如果指定目录已经存在该数据集,就会忽略掉download参数,跳过下载。

这里由于用到的是CNN网络,因此需要将mnist数据的图像转成[Channel, Height, Width] 的维度,我们选择的是[1, 28, 28],因此在下载数据的时候需要特殊处理一下。

3:损失函数使用之前学习的交叉熵损失函数,梯度下降算则随机梯度下降。

4:为了我们方便观察整个训练的过程,我们在每一次迭代结束,都会现场使用模型去测试数据集上现场运行一把,看看实际的预测效果如何,分别记录下每次迭代过程中训练损失值,训练准确度,测试损失值,测试准确度,并且房补画图展示出来。

话不多说,我就直接上代码实例,代码的注释我都是用中文直接写的。

# -*- coding: utf-8 -*"""
Created on Fri Jul 27 17:47:03 2018@author: Administrator
"""
import numpy as np
import torch
from torchvision.datasets import mnist  # 导入 pytorch 内置的 mnist 数据from torch import nn
from torch.autograd import Variablefrom torch.utils.data import DataLoader
import matplotlib.pyplot as plt# Step 1:============================准备数据===================
# 定义一个对图像像素数据的标准化处理函数
# 变换到0~255的范围,在变换到0~1的范围
# 对数据进行标准化
# 对图像数据从矩阵形式变成一个 W*H的一维向量
def data_tf(img):img = np.array(img, dtype='float32') / 255img = (img - 0.5) / 0.5  # 标准化,img = img.reshape((1, 28, 28))  # 形成图像数据,也就是矩阵数据img = torch.from_numpy(img)return img# 先来准备数据
# 使用内置函数下载 mnist 数据集,并且使用自定义的标准化函数对数据进行标准化
# download 参数是表明数据是要从网上下载么?如果该目录下已经存在数据集,就不会再下载了。
train_set = mnist.MNIST('./data', train = True, transform=data_tf, download = True)
test_set = mnist.MNIST('./data', train = False, transform=data_tf, download = True)
firstImg, firstImg_label = train_set[0]  # a为训练数据第一个的图像数据,a_label为训练数据第一个的标签
# 训练数据数量是60000
print(train_set)
# 测试数据数量是10000
print(test_set)
# 打印出第一个图像和其标签的值
print(firstImg.shape)
print(firstImg_label)# DataLoader本质上就是一个iterable(跟python的内置类型list等一样),并利用多进程来加速batch data的处理,使用yield来使用有限的内存
# 使用 pytorch 自带的 DataLoader 定义一个数据迭代器,也就是将数据进行排序标号,shuffle也就是打乱数据
# DataLoader是一个高效,简洁,直观的网络输入数据结构,便于使用和扩展
# 这种方式能加快数据计算速度,减少训练时间。
train_data = DataLoader(train_set, batch_size=64, shuffle=True)  # 训练数据
test_data = DataLoader(test_set, batch_size=128, shuffle=False)  # 测试数据
# 这里展示的是一个批量处理的数据,想象成之前学习的mini-batch,每次迭代处理一个小批量的数据。
# 训练数据是64个图像为一组数据,维度是[64, 1, 28, 28]
batch, batch_label = next(iter(train_data))
# 打印出一个批次的数图像和其标签,主要为了展示维度。
print(batch.shape)
print(batch_label.shape)# Step 2:============================定义模型===================
# 定义一个类,继承自 torch.nn.Module,torch.nn.Module是callable的类
# 这里设计的CNN模型如下:# 输入数据维度是[batch_size, 1, 28, 28], 批量是batch_size, 每个Img的通道是1, 图片大小是28*28# 第一层卷积核维度是[1, 10, 5 ,5], 输入通道是1, 输出通道是10, 卷积核大小是 5*5, (默认padding=0, stride = 1)
# 第一层卷积后数据维度是: [batch_size, 10, 24, 24]
# 第一层池化层: 采用的是MaxPooling, 大小是2*2.
# 第一层池化后数据维度是: [batch_size, 10, 12, 12]
# 第一层的激活函数是Relu# 第二层卷积核维度是[10, 20, 5 ,5], 输入通道是10, 输出通道是20, 卷积核大小是 5*5, (默认padding=0, stride = 1)
# 第二层卷积后数据维度是: [batch_size, 20, 8, 8]
# 第二层池化层: 采用的是MaxPooling, 大小是2*2.
# 第二层池化后数据维度是: [batch_size, 20, 4, 4]
# 第二层的激活函数是Relu# 为了连接到FC层,需要将数据重新改变维度到 [batch_size, 20*4*4=320]# 第三层FC层: [batch_size, 320, 100]
# 第三层激活函数ReLu
# 第四层FC层: [batch_size, 100, 10]class CNNModel(torch.nn.Module):def __init__(self):# 调用父类的初始化函数,必须要的super(CNNModel, self).__init__()# 两个卷积池化层,两个全连接层self.conv1 = nn.Sequential(nn.Conv2d(1, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU())self.conv2 = nn.Sequential(nn.Conv2d(10, 20, kernel_size=5), nn.MaxPool2d(2), nn.ReLU())self.fc1 = nn.Sequential(nn.Linear(320, 100), nn.ReLU())self.fc2 = nn.Sequential(nn.Linear(100, 10))def forward(self, img):# 得到这一次运算时多少批次的,也就是保留第一个batch_size这个维度值。batch_size = img.size(0)# 卷积和池化层img = self.conv1(img)img = self.conv2(img)# 维度转化img = img.view(batch_size, -1)# 全连接层img = self.fc1(img)img = self.fc2(img)return img# 创建和实例化一个整个模型类的对象
model = CNNModel()
# 打印出整个模型
print(model)# Step 3:============================定义损失函数和优化器===================
# 定义 loss 函数,这里用的是交叉熵损失函数(Cross Entropy),这种损失函数之前博文也讲过的。
criterion = nn.CrossEntropyLoss()
# 我们优先使用随机梯度下降,lr是学习率: 0.1
optimizer = torch.optim.SGD(model.parameters(), 1e-1)# Step 4:============================开始训练网络===================
# 为了实时观测效果,我们每一次迭代完数据后都会,用模型在测试数据上跑一次,看看此时迭代中模型的效果。
# 用数组保存每一轮迭代中,训练的损失值和精确度,也是为了通过画图展示出来。
train_losses = []
train_acces = []
# 用数组保存每一轮迭代中,在测试数据上测试的损失值和精确度,也是为了通过画图展示出来。
eval_losses = []
eval_acces = []for e in range(20):# 4.1==========================训练模式==========================train_loss = 0train_acc = 0model.train()   # 将模型改为训练模式# 每次迭代都是处理一个小批量的数据,batch_size是64for im, label in train_data:im = Variable(im)label = Variable(label)# 计算前向传播,并且得到损失函数的值out = model(im)loss = criterion(out, label)# 反向传播,记得要把上一次的梯度清0,反向传播,并且step更新相应的参数。optimizer.zero_grad()loss.backward()optimizer.step()# 记录误差train_loss += loss.item()# 计算分类的准确率_, pred = out.max(1)num_correct = (pred == label).sum().item()acc = num_correct / im.shape[0]train_acc += acctrain_losses.append(train_loss / len(train_data))train_acces.append(train_acc / len(train_data))# 4.2==========================每次进行完一个训练迭代,就去测试一把看看此时的效果==========================# 在测试集上检验效果eval_loss = 0eval_acc = 0model.eval()  # 将模型改为预测模式# 每次迭代都是处理一个小批量的数据,batch_size是128for im, label in test_data:im = Variable(im)  # torch中训练需要将其封装即Variable,此处封装像素即784label = Variable(label)  # 此处为标签out = model(im)  # 经网络输出的结果loss = criterion(out, label)  # 得到误差# 记录误差eval_loss += loss.item()# 记录准确率_, pred = out.max(1)  # 得到出现最大值的位置,也就是预测得到的数即0—9num_correct = (pred == label).sum().item()  # 判断是否预测正确acc = num_correct / im.shape[0]  # 计算准确率eval_acc += acceval_losses.append(eval_loss / len(test_data))eval_acces.append(eval_acc / len(test_data))print('epoch: {}, Train Loss: {:.6f}, Train Acc: {:.6f}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'.format(e, train_loss / len(train_data), train_acc / len(train_data),eval_loss / len(test_data), eval_acc / len(test_data)))plt.title('train loss')
plt.plot(np.arange(len(train_losses)), train_losses)
plt.plot(np.arange(len(train_acces)), train_acces)
plt.title('train acc')
plt.plot(np.arange(len(eval_losses)), eval_losses)
plt.title('test loss')
plt.plot(np.arange(len(eval_acces)), eval_acces)
plt.title('test acc')
plt.show()

这里有一些输出,我们解释下:

上图展示的是,原始图像数据中,训练数据有60000个,测试数据有10000个,其中第一个训练数据图像是[1, 28, 28]维度的向量,该图像代表的数字是5。经过DataLoader后,训练数据每一批量的数据是64个图像,是[64, 1, 28, 28]维度的矩阵。

上图展示此CNN模型的设计样貌。都是参数模型,一共有四层参数。两个卷积层,两个全连接层。

通过打印,以及结合最后看我们的画出来图像,可见此时模型的准确率已经达到了99.1%,比BP神经网络还要高一个百分点,已经将误差缩小到一个百分点内了。

《Pytorch - CNN模型》相关推荐

  1. ComeFuture英伽学院——2020年 全国大学生英语竞赛【C类初赛真题解析】(持续更新)

    视频:ComeFuture英伽学院--2019年 全国大学生英语竞赛[C类初赛真题解析]大小作文--详细解析 课件:[课件]2019年大学生英语竞赛C类初赛.pdf 视频:2020年全国大学生英语竞赛 ...

  2. ComeFuture英伽学院——2019年 全国大学生英语竞赛【C类初赛真题解析】大小作文——详细解析

    视频:ComeFuture英伽学院--2019年 全国大学生英语竞赛[C类初赛真题解析]大小作文--详细解析 课件:[课件]2019年大学生英语竞赛C类初赛.pdf 视频:2020年全国大学生英语竞赛 ...

  3. 信息学奥赛真题解析(玩具谜题)

    玩具谜题(2016年信息学奥赛提高组真题) 题目描述 小南有一套可爱的玩具小人, 它们各有不同的职业.有一天, 这些玩具小人把小南的眼镜藏了起来.小南发现玩具小人们围成了一个圈,它们有的面朝圈内,有的 ...

  4. 信息学奥赛之初赛 第1轮 讲解(01-08课)

    信息学奥赛之初赛讲解 01 计算机概述 系统基本结构 信息学奥赛之初赛讲解 01 计算机概述 系统基本结构_哔哩哔哩_bilibili 信息学奥赛之初赛讲解 02 软件系统 计算机语言 进制转换 信息 ...

  5. 信息学奥赛一本通习题答案(五)

    最近在给小学生做C++的入门培训,用的教程是信息学奥赛一本通,刷题网址 http://ybt.ssoier.cn:8088/index.php 现将部分习题的答案放在博客上,希望能给其他有需要的人带来 ...

  6. 信息学奥赛一本通习题答案(三)

    最近在给小学生做C++的入门培训,用的教程是信息学奥赛一本通,刷题网址 http://ybt.ssoier.cn:8088/index.php 现将部分习题的答案放在博客上,希望能给其他有需要的人带来 ...

  7. 信息学奥赛一本通 提高篇 第六部分 数学基础 相关的真题

    第1章   快速幂 1875:[13NOIP提高组]转圈游戏 信息学奥赛一本通(C++版)在线评测系统 第2 章  素数 第 3 章  约数 第 4 章  同余问题 第 5 章  矩阵乘法 第 6 章 ...

  8. 信息学奥赛一本通题目代码(非题库)

    为了完善自己学c++,很多人都去读相关文献,就比如<信息学奥赛一本通>,可又对题目无从下手,从今天开始,我将把书上的题目一 一的解析下来,可以做参考,如果有错,可以告诉我,将在下次解析里重 ...

  9. 信息学奥赛一本通(C++版) 刷题 记录

    总目录详见:https://blog.csdn.net/mrcrack/article/details/86501716 信息学奥赛一本通(C++版) 刷题 记录 http://ybt.ssoier. ...

  10. 最近公共祖先三种算法详解 + 模板题 建议新手收藏 例题: 信息学奥赛一本通 祖孙询问 距离

    首先什么是最近公共祖先?? 如图:红色节点的祖先为红色的1, 2, 3. 绿色节点的祖先为绿色的1, 2, 3, 4. 他们的最近公共祖先即他们最先相交的地方,如在上图中黄色的点就是他们的最近公共祖先 ...

最新文章

  1. Java 工具集 Hutool 4.0.8 发布
  2. 009_调色盘和高亮样式
  3. g2o求解BA 第10章
  4. Win7系统隐藏文件恢复的方法
  5. WORD中插入的公式与文字对不齐——公式比文字高——文字比公式低
  6. 资深面试官解答:大厂月薪过20K的测试工程师,都需要满足哪些要求?
  7. 基于DEAP库的python进化算法-7.多目标遗传算法NSGA-II
  8. 软件测试人员必备工具介绍--如何滚屏抓取图片-SnagIt篇(图)
  9. 只用一招!Python实现微信防撤回!
  10. 游戏图形引擎中 Shader Systen 的设计
  11. matlab 运动控制系统设计与实现,电力传动控制系统:运动控制系统
  12. Vtokendapp公链诠释
  13. APISpace 汉字转拼音API 方便好用
  14. 计算机主板运算放大器电路,常用运算放大器电路图 (全集)
  15. 【渝粤教育】国家开放大学2019年春季 1366英语教学理论与实践 参考试题
  16. UIPATH 结合 Python 识别 PDF 中的表格
  17. 臻图信息构建数字孪生港口船舶停靠管理系统,赋能港口创新发展
  18. MacOS Big Sur Beta 测评|使用体验|有哪些BUG?|如何安装?|实际体验如何?|WWDC2020
  19. 如何下载Office365离线安装程序包并手动安装?
  20. eclipse项目感叹号

热门文章

  1. Outlook 2013 电子邮件账户设置备份与恢复
  2. 在 java 的 bin 目录下,jdk 提供了很多使用的工具,下面学习一些监控和故障处理的工具。...
  3. 使用Telnet命令收发E-mail
  4. 《高质量c++/c编程指南》学习摘要
  5. 汇编语言---函数调用栈
  6. [转]动态加载jar文件
  7. linux运维初级课前实战随机考试题含答案(笔试+上机)
  8. EBGP邻居抖动问题
  9. Spring AMQP ActiveMQ教程
  10. 如何在CentOS上设置MariaDB Galera Cluster 10.0