前言

本文讲述了如何使用Pytorch(一种深度学习框架)构建一个简单的卷积神经网络,并使用MNIST数据集(28*28手写数字图片集)进行训练和测试。针对过程中的每个步骤都尽可能的给出了详尽的解释。

有什么问题可以评论区留言。欢迎各路大神指教。

导入包

import 

其中cv2需要安装库opencv,用于图片可视化

导入数据集

train_dataset = datasets.MNIST(root = 'data/', train = True, transform = transforms.ToTensor(), download = True)
test_dataset = datasets.MNIST(root = 'data/', train = False, transform = transforms.ToTensor(), download = True)

使用torchvision中的datasets自动下载数据集

root表示存放在当前目录下'data'文件夹中

train=True表示导入的是训练数据;train=False表示导入的是测试数据。

transform表示对每个数据进行的变化,这里是将其变为Tensor。Tensor是pytorch中存储数据的主要格式,类似于numpy,两者可相互转换。

dowload表示是否下载数据

数据装载

train_loader = DataLoader(dataset = train_dataset, batch_size = 100, shuffle = True)
test_loader = DataLoader(dataset = test_dataset, batch_size= 100, shuffle = True

使用DataLoader加载数据集。

dataset表示加载的数据集。

batch_size表示将多少个数据划分为一个batch,也就是一次性喂给模型多少个数据。

shuffle表示是否打乱数据顺序。

DataLoader还有其他参数,有兴趣可以自行搜索。

数据可视化

images

使用iter(train_loader)得到train_loader的迭代对象,next()得到迭代对象的值,并且将迭代对象指向下一个值。

make_grid将若干副图像拼接成一副,nrow表示每一行多少图像,padding表示子图像直接的距离。

拼接之后的图像第一维是channel数3,通过transpose将其变换到第三维。

使用cv2中的imshow展示图像,并等待按下任意键后图像消失。

模型定义

class Model(torch.nn.Module):def __init__(self) :super(Model, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1, 64, 3, 1, 1),torch.nn.ReLU(),torch.nn.Conv2d(64, 128, 3, 1, 1),torch.nn.ReLU(),torch.nn.MaxPool2d(2, 2)) self.dense = torch.nn.Sequential(torch.nn.Linear(14*14*128, 1024),torch.nn.ReLU(),torch.nn.Dropout(p = 0.5),torch.nn.Linear(1024, 10))def forward(self, x) :x = self.conv1(x)x = x.view(-1, 14*14*128)x = self.dense(x)return x

定义的模型Model从torch.nn.Module继承而来,初始化时需要初始化父类。

第一个卷积网络由两个卷积层、两个ReLU层、一个MaxPool层构成,第一个卷积层使用了64个3*3的卷积核,步长为1,填充数为1,计算得输出大小为

(
),有64个channel。同理第二层输出大小为28,有128个channel。第三层使用MaxPool函数,2*2的卷积核,步长为2,输出为
。输出总大小为

然后将第一个网络压缩到一维,输入第二个dense层。Linear是类似于Ax+b的函数,可以设置bias=True or False表示是否有偏置值b。Linear的输出大小是1024。Dropout确保没有过拟合。最后通过Linear再输出10维(0~9是十个数字)。

在forward里面定义前向传播,首先经过第一个卷积网络,然后压缩到一维,最终输入dense层获得最终结果。

损失函数与优化方法

device 

使用cuda进行训练,loss定义为交叉熵,使用Adam方法进行优化。

交叉熵(Cross Entropy)的公式为

Adam(Adaptive moment estimation)方法是RMSProp和Momentum方法的结合。

RMSProp方法是一种自适应调整学习率的方法,根据遗忘因子累加之前所有梯度平方,更新学习率。

(以下定义

Momentum是通过上一次的更新发现增强或削弱梯度。如果梯度方向与上一次相同,则加强;不同,则削减。

Adam则是将RMSProp方法中的

换做一阶矩估计,
换做二阶矩估计,并修正其偏差。

训练

if __name__ == "__main__":epochs = 5for epoch in range(epochs) :# trainsum_loss = 0.0train_correct = 0for data in train_loader:inputs, lables = datainputs, lables = Variable(inputs).cuda(), Variable(lables).cuda()optimizer.zero_grad()outputs = model(inputs)loss = cost(outputs, lables)loss.backward()optimizer.step()_, id = torch.max(outputs.data, 1)sum_loss += loss.datatrain_correct += torch.sum(id == lables.data)print('[%d,%d] loss:%.03f' % (epoch + 1, epochs, sum_loss / len(train_loader)))print('        correct:%.03f%%' % (100 * train_correct / len(train_dataset)))

训练了5次,每次计算loss和准确度。

zero_grad()将上一个batch的梯度清零,以免梯度累加造成错误。

利用backward()进行反向传播计算梯度,optimizer.step()进行梯度下降。

torch.max(a,b)对a中的固定第b维的情况下,计算最大值,返回最大值及其索引。这里是固定outputs的列,对行求最大值。outputs返回的值可以看作是归属每个类的概率,取最大概率作为最终结果。

最后累加loss并计算准确度。

测试

model.eval()test_correct = 0for data in test_loader:inputs, lables = datainputs, lables = Variable(inputs).cuda(), Variable(lables).cuda()outputs = model(inputs)_, id = torch.max(outputs.data, 1)test_correct += torch.sum(id == lables.data)print("correct:%.3f%%" % (100 * test_correct / len(test_dataset)))

测试需要用到model的eval()模式,以免将测试数据也用于训练。

最终结果

1,5] loss:0.004correct:99.000%
[2,5] loss:0.005correct:99.000%
[3,5] loss:0.004correct:99.000%
[4,5] loss:0.003correct:99.000%
[5,5] loss:0.003correct:99.000%
correct:98.000%

从结果上看还可以,准确率有98%。

可以使用数据增强等方法提高训练准确度。

除此之外,也可以将训练结果保存在本地,下次训练直接从文件中load结果。

torch.save(model.state_dict(), "parameter.pkl") #save
model.load_state_dict(torch.load('parameter.pkl')) #load

pytorch保存准确率_初学Pytorch:MNIST数据集训练详解相关推荐

  1. 【小白学习PyTorch教程】十一、基于MNIST数据集训练第一个生成性对抗网络

    「@Author:Runsen」 GAN 是使用两个神经网络模型训练的生成模型.一种模型称为生成网络模型,它学习生成新的似是而非的样本.另一个模型被称为判别网络,它学习区分生成的例子和真实的例子. 生 ...

  2. 508任务一:用pytorch简单实现LeNet5网络对MNIST数据集训练

    看了一些pytorch教学视频,结合别人的代码,按自己的喜好写出来了比较简单的实现,其实还可以把loss数据绘个表,还可以在训练时2加个循环,多训练几次. import torch import to ...

  3. 京东主图怎么保存原图_京东自营怎么做?详解京东平台操作方法

    随着2020年实体行业的受损,越来越多企业想要做出改变以打破目前的僵局.于是大家都盯上了受疫情冲击最小的电商行业.但是做电商自己完全没经验啊,该怎么办呢? 这时候就有人想到了京东自营,一个你只要往仓库 ...

  4. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  5. DL之DNN:自定义MultiLayerNet【6*100+ReLU,SGD】对MNIST数据集训练进而比较【多个超参数组合最优化】性能

    DL之DNN:自定义MultiLayerNet[6*100+ReLU,SGD]对MNIST数据集训练进而比较[多个超参数组合最优化]性能 目录 输出结果 设计思路 核心代码 输出结果 val_acc: ...

  6. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别、模型评估(99.4%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别.模型评估(99.4%) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 netwo ...

  7. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别并预测(超过99%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别并预测(超过99%) 目录 输出结果 设计思路 核心代码 输出结果 准确度都在99%以上 1.出错记录 ...

  8. DL之CNN可视化:利用SimpleConvNet算法【3层,im2col优化】基于mnist数据集训练并对卷积层输出进行可视化

    DL之CNN可视化:利用SimpleConvNet算法[3层,im2col优化]基于mnist数据集训练并对卷积层输出进行可视化 导读 利用SimpleConvNet算法基于mnist数据集训练并对卷 ...

  9. DL之DNN:利用MultiLayerNetExtend模型【6*100+ReLU+SGD,dropout】对Mnist数据集训练来抑制过拟合

    DL之DNN:利用MultiLayerNetExtend模型[6*100+ReLU+SGD,dropout]对Mnist数据集训练来抑制过拟合 目录 输出结果 设计思路 核心代码 更多输出 输出结果 ...

最新文章

  1. ios Develop mark
  2. 海思3559与全志a83t比较
  3. 效率系列(四) VS常用快捷键
  4. c 语言epc编码如何解开,EPC编码结构
  5. Django Cookie于Session
  6. java学习(145):file常用方法2
  7. go语言linux下开发工具,LiteIDE 开发工具指南 (Go语言开发工具)
  8. 【class2src】Decompiler
  9. Java学习笔记10(面向对象三:接口)
  10. 关于Spring3 MVC的 HttpMediaTypeNotSupportedException
  11. 服务器cp所以型号,云服务器cp
  12. 极客大学架构师训练营 听课总结 - 架构视图,设计文档 -- 第二课
  13. 生产追溯系统-IQC来料检验
  14. 进程间通讯的5种方式
  15. 【数学 博弈论】JZOJ_3339 wyl8899和法法塔的游戏
  16. 经历“海潮效应”,云图如何成为智能家居界的苹果?
  17. J-Link RTT使用
  18. KISSY基础篇乄KISSY之IO前奏
  19. pthread_cancel pthread_testcancel测试
  20. Ocean Color数据批量下载——海洋物理分布式活动档案中心PO.DAAC

热门文章

  1. mysql查询分数前三个_Mysql 单表查询各班级总分前三名
  2. Java集合查找Map,java:使用hashmap或其他一些java集合创建查找...
  3. 如何将一个字典转换为玲阶矩阵_基础渲染系列(一)图形学的基石——矩阵
  4. 网传:Vue涉及国家安全漏洞?尤雨溪亲自发文回应!
  5. 京东加淘宝,羊毛有点大
  6. 每日一皮:Bug 变 Feature !惊不惊喜,意不意外,刺不刺激!
  7. Java中如何锁文件
  8. 1024 大促书单丨神券在手,快乐我有
  9. kali Linux 屏幕旋转,MSF基础命令新手指南
  10. RandLA-Net测试