前言

本文主要描述了如何使用现在热度和关注度比较高的Pytorch(深度学习框架)构建一个简单的卷积神经网络,并对MNIST数据集进行了训练和测试。MNIST数据集是一个28*28的手写数字图片集合,使用测试集来验证训练出的模型对手写数字的识别准确率。

PyTorch资料

PyTorch的官方文档链接:PyTorch documentation,在这里不仅有 API的说明还有一些经典的实例可供参考。

PyTorch官网论坛:vision,里面会有很大资料分享和一些热门问题的解答。

PyTorch搭建神经网络实践

在一开始导入需要导入PyTorch的两个核心库文件torch和torchvision,这两个库基本包含了PyTorch会用到的许多方法和函数

import 

其中值得一提的是torchvision的datasets可以很方便的自动下载数据集,这里使用的是MNIST数据集。另外的COCO,ImageNet,CIFCAR等数据集也可以很方的下载并使用,导入命令也非常简单

data_train = datasets.MNIST(root = "./data/",transform=transform,train = True,download = True)data_test = datasets.MNIST(root="./data/",transform = transform,train = False)

root指定了数据集存放的路径,transform指定导入数据集时需要进行何种变换操作,train设置为True说明导入的是训练集合,否则为测试集合。

transform里面还有很多好的方法,可以用在图片资源较少的数据集做Data Argumentation操作,这里只是做了个简单的Tensor格式转换和Batch Normalize

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])

数据下载完成后还需要做数据装载操作

data_loader_train = torch.utils.data.DataLoader(dataset=data_train,batch_size = 64,shuffle = True)data_loader_test = torch.utils.data.DataLoader(dataset=data_test,batch_size = 64,shuffle = True)

batch_size设置了每批装载的数据图片为64个,shuffle设置为True在装载过程中为随机乱序

下图为一个batch数据集(64张图片)的显示,可以看出来都为28*28的1维图片

MNIST数据集图片预览

完成数据装载后就可以构建核心程序了,这里构建的是一个包含了卷积层和全连接层的神经网络,其中卷积层使用torch.nn.Conv2d来构建,激活层使用torch.nn.ReLU来构建,池化层使用torch.nn.MaxPool2d来构建,全连接层使用torch.nn.Linear来构建

class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),torch.nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),torch.nn.MaxPool2d(stride=2,kernel_size=2))self.dense = torch.nn.Sequential(torch.nn.Linear(14*14*128,1024),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(1024, 10))def forward(self, x):x = self.conv1(x)x = x.view(-1, 14*14*128)x = self.dense(x)return x

其中定义了torch.nn.Dropout(p=0.5)防止模型的过拟合

forward函数定义了前向传播,其实就是正常卷积路径。首先经过self.conv1(x)卷积处理,然后进行x.view(-1, 14*14*128)压缩扁平化处理,最后通过self.dense(x)全连接进行分类

之后就是对Model对象进行调用,然后定义loss计算使用交叉熵,优化计算使用Adam自动化方式,最后就可以开始训练了

model = Model()
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

在训练前可以查看神经网络架构了,print输出显示如下

Model ((conv1): Sequential ((0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU ()(2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU ()(4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)))(dense): Sequential ((0): Linear (25088 -> 1024)(1): ReLU ()(2): Dropout (p = 0.5)(3): Linear (1024 -> 10))
)

定义训练次数为5次,开始跑神经网络,训练完成后输入测试集合得到的结果如下

Epoch 0/5
----------
Loss is:0.0003, Train Accuracy is:99.4167%, Test Accuracy is:98.6600
Epoch 1/5
----------
Loss is:0.0002, Train Accuracy is:99.5967%, Test Accuracy is:98.9200
Epoch 2/5
----------
Loss is:0.0002, Train Accuracy is:99.6667%, Test Accuracy is:98.7700
Epoch 3/5
----------
Loss is:0.0002, Train Accuracy is:99.7133%, Test Accuracy is:98.9600
Epoch 4/5
----------
Loss is:0.0001, Train Accuracy is:99.7317%, Test Accuracy is:98.7300

从结果上看还不错,训练准确率最高达到了99.73%,测试最高准确率为98.96%。结果有轻微的过拟合迹象,如果使用更加健壮的卷积模型测试集会取得更加好的结果。

随机对几张测试集的图片进行预测,并做可视化展示

Predict Label is: [3, 4, 9, 3]
Real Label is: [3, 4, 9, 3]

训练完成后还可以保存训练得到的参数,方便下次导入后可供直接使用

torch.save(model.state_dict(), "model_parameter.pkl")

完整代码链接:JaimeTang/Pytorch-and-mnist(model_parameter.pkl文件较大未做上传)


微信公众号:PyMachine

pytorch dropout_PyTorch初探MNIST数据集相关推荐

  1. TypeError: 'module' object is not callable (pytorch在进行MNIST数据集预览时出现的错误)

    在使用pytorch在对MNIST数据集进行预览时,出现了TypeError: 'module' object is not callable的错误: 上报错信息图如下: 从图中可以看出,报错位置为第 ...

  2. 十分钟搞懂Pytorch如何读取MNIST数据集

    前言 本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧- 正文 在阅读教程书籍<深度学习入门之Pytorch>时,文中是如此加载MNIST手写数字训练集的: ...

  3. pytorch保存准确率_初学Pytorch:MNIST数据集训练详解

    前言 本文讲述了如何使用Pytorch(一种深度学习框架)构建一个简单的卷积神经网络,并使用MNIST数据集(28*28手写数字图片集)进行训练和测试.针对过程中的每个步骤都尽可能的给出了详尽的解释. ...

  4. python数据集的预处理_关于Pytorch的MNIST数据集的预处理详解

    关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等. 操作系统:ubuntu ...

  5. pytorch训练GAN的代码(基于MNIST数据集)

    论文:Generative Adversarial Networks 作者:Ian J. Goodfellow 年份:2014年 从2020年3月多开始看网络,这是我第一篇看并且可以跑通代码的论文,简 ...

  6. Pytorch 实现全连接神经网络/卷积神经网络训练MNIST数据集,并将训练好的模型在制作自己的手写图片数据集上测试

    使用教程 代码下载地址:点我下载 模型在训练过程中会自动显示训练进度,如果您的pytorch是CPU版本的,代码会自动选择CPU训练,如果有cuda,则会选择GPU训练. 项目目录说明: CNN文件夹 ...

  7. 【已补蓝奏云链接】PyTorch中MNIST数据集(附datasets.MNIST离线包)下载慢/安装慢的解决方案

    一.问题背景 在学习MNIST数据集手写数字识别demo的时候,笔者碰到了一些问题,现记录如下: 1.必须先确保torchvision已经正确安.如何安装torchvision?请参考PyTorch/ ...

  8. 【Pytorch分布式训练】在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练

    文章目录 普通单卡训练-GPU 普通单卡训练-CPU 分布式训练-GPU 分布式训练-CPU 租GPU服务器相关 以下代码示例基于:在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练. 普 ...

  9. 基于pytorch的MNIST数据集的四层CNN,测试准确率99.77%

    基于pytorch的MNIST数据集的四层CNN,测试准确率99.77% MNIST数据集 环境配置 文件存储结构 代码 引入库 调用GPU 初始化变量 导入数据集并进行数据增强 导入测试集 加载测试 ...

最新文章

  1. 0基础学python看什么书-0基础学Python入门书籍应该看什么?
  2. ECMAScript 6 入门
  3. 个人JS体系整理(二)
  4. POJ 3253 -- Fence Repair
  5. 三丰三坐标编程基本步骤_三丰三坐标CRYSTA APEX S776
  6. Angularjs基础(三)
  7. 日记——2019-03-12
  8. Asp.net 无限级分类
  9. 不知道这十项Linux常识,就别说自己玩过Linux!
  10. 开源中国社区(OsChina.NET) 8月第3周 精彩回顾
  11. ffmpeg drawtext 背景_8款电视背景墙:电视背景墙这样装,不仅省钱还作用多!效果大不一样!...
  12. 我的自学ROS历程3-3-Vsual Studio code安装
  13. 图像处理领域术语英文对应
  14. 红米 k30 pro 刷入欧版和小米钱包/商店
  15. 【Github分享】GitHub 上值得收藏的100个精选前端项目!
  16. arm解锁 j-flash_J-Link固件烧录以及使用J-Flash向arm硬件板下载固件程序(示例代码)...
  17. Linux学习-67-日志服务器设置和日志分析工具(logwatch)安装及使用
  18. 用U盘启动WinPE全新安装原版XP系统--有关pe装系统
  19. 使用NetBox实现ASP网页封装为EXE教程
  20. 计算机主板就一亮关机了,我的电脑开机主板灯亮一下就关了,然后又亮,一直

热门文章

  1. Android 播放raw文件夹下音频文件,本地MP3文件播放,播放云端MP3文件,获取MP3文件播放时长
  2. Android百度云推送接入,附完整代码
  3. js jquery Ajax同步
  4. Android蓝牙4.0的数据通讯
  5. mysql数据库面试总结
  6. (Java)Integer类的其他常用方法
  7. middle函数C语言,C语言函数调用栈(三)
  8. 《App后台开发运维与架构实践》第2章 App后台基础技术
  9. Xcode添加pch文件
  10. 偏函数 匿名函数 高阶函数 map filter reduce