pytorch dropout_PyTorch初探MNIST数据集
前言:
本文主要描述了如何使用现在热度和关注度比较高的Pytorch(深度学习框架)构建一个简单的卷积神经网络,并对MNIST数据集进行了训练和测试。MNIST数据集是一个28*28的手写数字图片集合,使用测试集来验证训练出的模型对手写数字的识别准确率。
PyTorch资料:
PyTorch的官方文档链接:PyTorch documentation,在这里不仅有 API的说明还有一些经典的实例可供参考。
PyTorch官网论坛:vision,里面会有很大资料分享和一些热门问题的解答。
PyTorch搭建神经网络实践:
在一开始导入需要导入PyTorch的两个核心库文件torch和torchvision,这两个库基本包含了PyTorch会用到的许多方法和函数
import
其中值得一提的是torchvision的datasets可以很方便的自动下载数据集,这里使用的是MNIST数据集。另外的COCO,ImageNet,CIFCAR等数据集也可以很方的下载并使用,导入命令也非常简单
data_train = datasets.MNIST(root = "./data/",transform=transform,train = True,download = True)data_test = datasets.MNIST(root="./data/",transform = transform,train = False)
root指定了数据集存放的路径,transform指定导入数据集时需要进行何种变换操作,train设置为True说明导入的是训练集合,否则为测试集合。
transform里面还有很多好的方法,可以用在图片资源较少的数据集做Data Argumentation操作,这里只是做了个简单的Tensor格式转换和Batch Normalize
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])
数据下载完成后还需要做数据装载操作
data_loader_train = torch.utils.data.DataLoader(dataset=data_train,batch_size = 64,shuffle = True)data_loader_test = torch.utils.data.DataLoader(dataset=data_test,batch_size = 64,shuffle = True)
batch_size设置了每批装载的数据图片为64个,shuffle设置为True在装载过程中为随机乱序
下图为一个batch数据集(64张图片)的显示,可以看出来都为28*28的1维图片
完成数据装载后就可以构建核心程序了,这里构建的是一个包含了卷积层和全连接层的神经网络,其中卷积层使用torch.nn.Conv2d来构建,激活层使用torch.nn.ReLU来构建,池化层使用torch.nn.MaxPool2d来构建,全连接层使用torch.nn.Linear来构建
class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1,64,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),torch.nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),torch.nn.ReLU(),torch.nn.MaxPool2d(stride=2,kernel_size=2))self.dense = torch.nn.Sequential(torch.nn.Linear(14*14*128,1024),torch.nn.ReLU(),torch.nn.Dropout(p=0.5),torch.nn.Linear(1024, 10))def forward(self, x):x = self.conv1(x)x = x.view(-1, 14*14*128)x = self.dense(x)return x
其中定义了torch.nn.Dropout(p=0.5)防止模型的过拟合
forward函数定义了前向传播,其实就是正常卷积路径。首先经过self.conv1(x)卷积处理,然后进行x.view(-1, 14*14*128)压缩扁平化处理,最后通过self.dense(x)全连接进行分类
之后就是对Model对象进行调用,然后定义loss计算使用交叉熵,优化计算使用Adam自动化方式,最后就可以开始训练了
model = Model()
cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
在训练前可以查看神经网络架构了,print输出显示如下
Model ((conv1): Sequential ((0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU ()(2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU ()(4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)))(dense): Sequential ((0): Linear (25088 -> 1024)(1): ReLU ()(2): Dropout (p = 0.5)(3): Linear (1024 -> 10))
)
定义训练次数为5次,开始跑神经网络,训练完成后输入测试集合得到的结果如下
Epoch 0/5
----------
Loss is:0.0003, Train Accuracy is:99.4167%, Test Accuracy is:98.6600
Epoch 1/5
----------
Loss is:0.0002, Train Accuracy is:99.5967%, Test Accuracy is:98.9200
Epoch 2/5
----------
Loss is:0.0002, Train Accuracy is:99.6667%, Test Accuracy is:98.7700
Epoch 3/5
----------
Loss is:0.0002, Train Accuracy is:99.7133%, Test Accuracy is:98.9600
Epoch 4/5
----------
Loss is:0.0001, Train Accuracy is:99.7317%, Test Accuracy is:98.7300
从结果上看还不错,训练准确率最高达到了99.73%,测试最高准确率为98.96%。结果有轻微的过拟合迹象,如果使用更加健壮的卷积模型测试集会取得更加好的结果。
随机对几张测试集的图片进行预测,并做可视化展示
Predict Label is: [3, 4, 9, 3]
Real Label is: [3, 4, 9, 3]
训练完成后还可以保存训练得到的参数,方便下次导入后可供直接使用
torch.save(model.state_dict(), "model_parameter.pkl")
完整代码链接:JaimeTang/Pytorch-and-mnist(model_parameter.pkl文件较大未做上传)
微信公众号:PyMachine
pytorch dropout_PyTorch初探MNIST数据集相关推荐
- TypeError: 'module' object is not callable (pytorch在进行MNIST数据集预览时出现的错误)
在使用pytorch在对MNIST数据集进行预览时,出现了TypeError: 'module' object is not callable的错误: 上报错信息图如下: 从图中可以看出,报错位置为第 ...
- 十分钟搞懂Pytorch如何读取MNIST数据集
前言 本文用于记录使用pytorch读取minist数据集的过程,以及一些思考和疑惑吧- 正文 在阅读教程书籍<深度学习入门之Pytorch>时,文中是如此加载MNIST手写数字训练集的: ...
- pytorch保存准确率_初学Pytorch:MNIST数据集训练详解
前言 本文讲述了如何使用Pytorch(一种深度学习框架)构建一个简单的卷积神经网络,并使用MNIST数据集(28*28手写数字图片集)进行训练和测试.针对过程中的每个步骤都尽可能的给出了详尽的解释. ...
- python数据集的预处理_关于Pytorch的MNIST数据集的预处理详解
关于Pytorch的MNIST数据集的预处理详解 MNIST的准确率达到99.7% 用于MNIST的卷积神经网络(CNN)的实现,具有各种技术,例如数据增强,丢失,伪随机化等. 操作系统:ubuntu ...
- pytorch训练GAN的代码(基于MNIST数据集)
论文:Generative Adversarial Networks 作者:Ian J. Goodfellow 年份:2014年 从2020年3月多开始看网络,这是我第一篇看并且可以跑通代码的论文,简 ...
- Pytorch 实现全连接神经网络/卷积神经网络训练MNIST数据集,并将训练好的模型在制作自己的手写图片数据集上测试
使用教程 代码下载地址:点我下载 模型在训练过程中会自动显示训练进度,如果您的pytorch是CPU版本的,代码会自动选择CPU训练,如果有cuda,则会选择GPU训练. 项目目录说明: CNN文件夹 ...
- 【已补蓝奏云链接】PyTorch中MNIST数据集(附datasets.MNIST离线包)下载慢/安装慢的解决方案
一.问题背景 在学习MNIST数据集手写数字识别demo的时候,笔者碰到了一些问题,现记录如下: 1.必须先确保torchvision已经正确安.如何安装torchvision?请参考PyTorch/ ...
- 【Pytorch分布式训练】在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练
文章目录 普通单卡训练-GPU 普通单卡训练-CPU 分布式训练-GPU 分布式训练-CPU 租GPU服务器相关 以下代码示例基于:在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练. 普 ...
- 基于pytorch的MNIST数据集的四层CNN,测试准确率99.77%
基于pytorch的MNIST数据集的四层CNN,测试准确率99.77% MNIST数据集 环境配置 文件存储结构 代码 引入库 调用GPU 初始化变量 导入数据集并进行数据增强 导入测试集 加载测试 ...
最新文章
- 0基础学python看什么书-0基础学Python入门书籍应该看什么?
- ECMAScript 6 入门
- 个人JS体系整理(二)
- POJ 3253 -- Fence Repair
- 三丰三坐标编程基本步骤_三丰三坐标CRYSTA APEX S776
- Angularjs基础(三)
- 日记——2019-03-12
- Asp.net 无限级分类
- 不知道这十项Linux常识,就别说自己玩过Linux!
- 开源中国社区(OsChina.NET) 8月第3周 精彩回顾
- ffmpeg drawtext 背景_8款电视背景墙:电视背景墙这样装,不仅省钱还作用多!效果大不一样!...
- 我的自学ROS历程3-3-Vsual Studio code安装
- 图像处理领域术语英文对应
- 红米 k30 pro 刷入欧版和小米钱包/商店
- 【Github分享】GitHub 上值得收藏的100个精选前端项目!
- arm解锁 j-flash_J-Link固件烧录以及使用J-Flash向arm硬件板下载固件程序(示例代码)...
- Linux学习-67-日志服务器设置和日志分析工具(logwatch)安装及使用
- 用U盘启动WinPE全新安装原版XP系统--有关pe装系统
- 使用NetBox实现ASP网页封装为EXE教程
- 计算机主板就一亮关机了,我的电脑开机主板灯亮一下就关了,然后又亮,一直