PyTorch MNIST 实现

  • 概述
  • 获取数据
  • 网络模型
  • train 函数
  • test 函数
  • main 函数
  • 完整代码

概述

MNIST 包含 0~9 的手写数字, 共有 60000 个训练集和 10000 个测试集. 数据的格式为单通道 28*28 的灰度图.

获取数据

def get_data():
    """Build and return the MNIST train/test DataLoaders.

    Returns:
        (train_loader, test_loader): batched loaders over the MNIST
        training set (60k images) and test set (10k images). Relies on
        the module-level ``batch_size`` constant.
    """
    # Shared preprocessing: PIL image -> tensor, then normalize with the
    # canonical MNIST mean/std.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Training set (downloaded to ./data on first use).
    train = torchvision.datasets.MNIST(
        root="./data", train=True, download=True, transform=transform)
    # FIX: shuffle the training data each epoch; the original iterated
    # the dataset in a fixed order, which hurts SGD convergence.
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)

    # Test set (evaluation order does not matter, so no shuffling).
    test = torchvision.datasets.MNIST(
        root="./data", train=False, download=True, transform=transform)
    test_loader = DataLoader(test, batch_size=batch_size)

    # Return the batched training and test loaders.
    return train_loader, test_loader

网络模型

class Model(torch.nn.Module):def __init__(self):super(Model, self).__init__()# 卷积层self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))# Dropout层self.dropout1 = torch.nn.Dropout(0.25)self.dropout2 = torch.nn.Dropout(0.5)# 全连接层self.fc1 = torch.nn.Linear(9216, 128)self.fc2 = torch.nn.Linear(128, 10)def forward(self, x):"""前向传播"""# [b, 1, 28, 28] => [b, 32, 26, 26]out = self.conv1(x)out = F.relu(out)# [b, 32, 26, 26] => [b, 64, 24, 24]out = self.conv2(out)out = F.relu(out)# [b, 64, 24, 24] => [b, 64, 12, 12]out = F.max_pool2d(out, 2)out = self.dropout1(out)# [b, 64, 12, 12] => [b, 64 * 12 * 12] => [b, 9216]out = torch.flatten(out, 1)# [b, 9216] => [b, 128]out = self.fc1(out)out = F.relu(out)# [b, 128] => [b, 10]out = self.dropout2(out)out = self.fc2(out)output = F.log_softmax(out, dim=1)return output

train 函数

def train(model, epoch, train_loader):
    """Run one training epoch over ``train_loader``.

    Args:
        model: the network to train (parameters updated in place).
        epoch: current epoch index, used only for log output.
        train_loader: DataLoader yielding (images, labels) batches.

    Uses the module-level ``optimizer`` and ``use_cuda`` globals.
    """
    # Training mode: enables dropout layers.
    model.train()
    # FIX: move the model to the GPU once, before the loop; the original
    # called model.cuda() on every single step.
    if use_cuda:
        model = model.cuda()
    for step, (x, y) in enumerate(train_loader):
        if use_cuda:
            x, y = x.cuda(), y.cuda()
        optimizer.zero_grad()           # clear gradients from the last step
        output = model(x)               # [b, 10] log-probabilities
        loss = F.nll_loss(output, y)    # negative log-likelihood loss
        loss.backward()                 # back-propagate
        optimizer.step()                # apply the parameter update
        if step % 50 == 0:
            # FIX: print the scalar value (loss.item()) instead of the
            # tensor repr.
            print('Epoch: {}, Step {}, Loss: {}'.format(epoch, step, loss.item()))

test 函数

def test(model, test_loader):
    """Evaluate ``model`` on the test set and print the accuracy.

    Args:
        model: the trained network.
        test_loader: DataLoader over the MNIST test set.

    Uses the module-level ``use_cuda`` global.
    """
    # Evaluation mode: disables dropout layers.
    model.eval()
    # FIX: move the model to the GPU once, outside the batch loop; the
    # original called model.cuda() on every batch.
    if use_cuda:
        model = model.cuda()
    correct = 0  # number of correctly classified samples
    with torch.no_grad():  # no gradients needed for evaluation
        for x, y in test_loader:
            if use_cuda:
                x, y = x.cuda(), y.cuda()
            output = model(x)
            # Predicted class = index of the max log-probability.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(y.view_as(pred)).sum().item()
    # Accuracy as a percentage over the whole test dataset.
    accuracy = correct / len(test_loader.dataset) * 100
    print("Test Accuracy: {}%".format(accuracy))

main 函数

def main():
    """Entry point: fetch the data loaders, then alternate one training
    epoch with one evaluation pass for ``iteration_num`` epochs."""
    train_loader, test_loader = get_data()
    for epoch in range(iteration_num):
        print("\n================ epoch: {} ================".format(epoch))
        train(network, epoch, train_loader)
        test(network, test_loader)

完整代码


完整代码:

import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader


class Model(torch.nn.Module):
    """Small CNN for MNIST: two conv layers, max-pool, two FC layers.

    Input  : [b, 1, 28, 28] normalized grayscale images.
    Output : [b, 10] log-probabilities (log_softmax), for F.nll_loss.
    """

    def __init__(self):
        super(Model, self).__init__()
        # Convolutional feature extractor.
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
        # Dropout layers for regularization.
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)
        # Fully connected classifier head.
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Forward pass: return per-class log-probabilities."""
        # [b, 1, 28, 28] => [b, 32, 26, 26]
        out = F.relu(self.conv1(x))
        # [b, 32, 26, 26] => [b, 64, 24, 24]
        out = F.relu(self.conv2(out))
        # [b, 64, 24, 24] => [b, 64, 12, 12]
        out = F.max_pool2d(out, 2)
        out = self.dropout1(out)
        # [b, 64, 12, 12] => [b, 9216]
        out = torch.flatten(out, 1)
        # [b, 9216] => [b, 128]
        out = F.relu(self.fc1(out))
        # [b, 128] => [b, 10]
        out = self.fc2(self.dropout2(out))
        return F.log_softmax(out, dim=1)


# Hyper-parameters.
batch_size = 64         # samples per training batch
learning_rate = 0.0001  # Adam learning rate
iteration_num = 5       # number of epochs

network = Model()  # instantiate the network
print(network)     # debug: dump the layer structure

# Optimizer over the network parameters.  Adam state is created lazily on
# the first step, so a later in-place .cuda() move of the model is safe.
optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)

# GPU acceleration flag.
use_cuda = torch.cuda.is_available()
print("是否使用 GPU 加速:", use_cuda)


def get_data():
    """Build and return the MNIST train/test DataLoaders."""
    # Shared preprocessing: tensor conversion + MNIST mean/std normalization.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # Training set (downloaded to ./data on first use).
    train_set = torchvision.datasets.MNIST(
        root="./data", train=True, download=True, transform=transform)
    # FIX: shuffle the training data each epoch (better SGD convergence).
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Test set; evaluation order does not matter.
    test_set = torchvision.datasets.MNIST(
        root="./data", train=False, download=True, transform=transform)
    test_loader = DataLoader(test_set, batch_size=batch_size)
    return train_loader, test_loader


def train(model, epoch, train_loader):
    """Run one training epoch; uses the global optimizer/use_cuda."""
    model.train()  # enable dropout
    # FIX: move the model to the GPU once, not on every step.
    if use_cuda:
        model = model.cuda()
    for step, (x, y) in enumerate(train_loader):
        if use_cuda:
            x, y = x.cuda(), y.cuda()
        optimizer.zero_grad()         # clear previous gradients
        output = model(x)             # [b, 10] log-probabilities
        loss = F.nll_loss(output, y)  # negative log-likelihood loss
        loss.backward()               # back-propagate
        optimizer.step()              # update parameters
        if step % 50 == 0:
            # FIX: print the scalar loss value, not the tensor repr.
            print('Epoch: {}, Step {}, Loss: {}'.format(epoch, step, loss.item()))


def test(model, test_loader):
    """Evaluate on the test set and print accuracy; uses global use_cuda."""
    model.eval()  # disable dropout
    # FIX: move the model to the GPU once, outside the batch loop.
    if use_cuda:
        model = model.cuda()
    correct = 0  # correctly classified samples
    with torch.no_grad():  # no gradients needed for evaluation
        for x, y in test_loader:
            if use_cuda:
                x, y = x.cuda(), y.cuda()
            output = model(x)
            # Predicted class = index of the max log-probability.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(y.view_as(pred)).sum().item()
    accuracy = correct / len(test_loader.dataset) * 100
    print("Test Accuracy: {}%".format(accuracy))


def main():
    """Fetch data, then alternate training and evaluation per epoch."""
    train_loader, test_loader = get_data()
    for epoch in range(iteration_num):
        print("\n================ epoch: {} ================".format(epoch))
        train(network, epoch, train_loader)
        test(network, test_loader)


if __name__ == "__main__":
    main()

输出结果:

Model(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)
是否使用 GPU 加速: True

================ epoch: 0 ================
Epoch: 0, Step 0, Loss: 2.3131277561187744
Epoch: 0, Step 50, Loss: 1.0419045686721802
Epoch: 0, Step 100, Loss: 0.6259541511535645
Epoch: 0, Step 150, Loss: 0.7194482684135437
Epoch: 0, Step 200, Loss: 0.4020516574382782
Epoch: 0, Step 250, Loss: 0.6890509128570557
Epoch: 0, Step 300, Loss: 0.28660136461257935
Epoch: 0, Step 350, Loss: 0.3277580738067627
Epoch: 0, Step 400, Loss: 0.2750288248062134
Epoch: 0, Step 450, Loss: 0.28428223729133606
Epoch: 0, Step 500, Loss: 0.3514065444469452
Epoch: 0, Step 550, Loss: 0.23386947810649872
Epoch: 0, Step 600, Loss: 0.25338059663772583
Epoch: 0, Step 650, Loss: 0.1743898093700409
Epoch: 0, Step 700, Loss: 0.35752204060554504
Epoch: 0, Step 750, Loss: 0.17575909197330475
Epoch: 0, Step 800, Loss: 0.20604261755943298
Epoch: 0, Step 850, Loss: 0.17389622330665588
Epoch: 0, Step 900, Loss: 0.3188241124153137
Test Accuracy: 96.56%

================ epoch: 1 ================
Epoch: 1, Step 0, Loss: 0.23558208346366882
Epoch: 1, Step 50, Loss: 0.13511177897453308
Epoch: 1, Step 100, Loss: 0.18823786079883575
Epoch: 1, Step 150, Loss: 0.2644936144351959
Epoch: 1, Step 200, Loss: 0.145077645778656
Epoch: 1, Step 250, Loss: 0.30574971437454224
Epoch: 1, Step 300, Loss: 0.2386859953403473
Epoch: 1, Step 350, Loss: 0.08346735686063766
Epoch: 1, Step 400, Loss: 0.10480977594852448
Epoch: 1, Step 450, Loss: 0.07280707359313965
Epoch: 1, Step 500, Loss: 0.20928426086902618
Epoch: 1, Step 550, Loss: 0.20455852150917053
Epoch: 1, Step 600, Loss: 0.10085935145616531
Epoch: 1, Step 650, Loss: 0.13476189970970154
Epoch: 1, Step 700, Loss: 0.19087043404579163
Epoch: 1, Step 750, Loss: 0.0981522724032402
Epoch: 1, Step 800, Loss: 0.1961515098810196
Epoch: 1, Step 850, Loss: 0.041140712797641754
Epoch: 1, Step 900, Loss: 0.250461220741272
Test Accuracy: 98.03%

================ epoch: 2 ================
Epoch: 2, Step 0, Loss: 0.09572553634643555
Epoch: 2, Step 50, Loss: 0.10370486229658127
Epoch: 2, Step 100, Loss: 0.17737184464931488
Epoch: 2, Step 150, Loss: 0.1570713371038437
Epoch: 2, Step 200, Loss: 0.07462178170681
Epoch: 2, Step 250, Loss: 0.18744900822639465
Epoch: 2, Step 300, Loss: 0.09910508990287781
Epoch: 2, Step 350, Loss: 0.08929706364870071
Epoch: 2, Step 400, Loss: 0.07703761011362076
Epoch: 2, Step 450, Loss: 0.10133732110261917
Epoch: 2, Step 500, Loss: 0.1314031481742859
Epoch: 2, Step 550, Loss: 0.10394387692213058
Epoch: 2, Step 600, Loss: 0.11612939089536667
Epoch: 2, Step 650, Loss: 0.17494803667068481
Epoch: 2, Step 700, Loss: 0.11065669357776642
Epoch: 2, Step 750, Loss: 0.061209067702293396
Epoch: 2, Step 800, Loss: 0.14715790748596191
Epoch: 2, Step 850, Loss: 0.03930797800421715
Epoch: 2, Step 900, Loss: 0.18030673265457153
Test Accuracy: 98.46000000000001%

================ epoch: 3 ================
Epoch: 3, Step 0, Loss: 0.09266342222690582
Epoch: 3, Step 50, Loss: 0.0414913073182106
Epoch: 3, Step 100, Loss: 0.2152961939573288
Epoch: 3, Step 150, Loss: 0.12287424504756927
Epoch: 3, Step 200, Loss: 0.13468700647354126
Epoch: 3, Step 250, Loss: 0.11967387050390244
Epoch: 3, Step 300, Loss: 0.11301510035991669
Epoch: 3, Step 350, Loss: 0.037447575479745865
Epoch: 3, Step 400, Loss: 0.04699449613690376
Epoch: 3, Step 450, Loss: 0.05472381412982941
Epoch: 3, Step 500, Loss: 0.09839300811290741
Epoch: 3, Step 550, Loss: 0.07964356243610382
Epoch: 3, Step 600, Loss: 0.08182843774557114
Epoch: 3, Step 650, Loss: 0.05514759197831154
Epoch: 3, Step 700, Loss: 0.13785190880298615
Epoch: 3, Step 750, Loss: 0.062480345368385315
Epoch: 3, Step 800, Loss: 0.120387002825737
Epoch: 3, Step 850, Loss: 0.04458726942539215
Epoch: 3, Step 900, Loss: 0.17119190096855164
Test Accuracy: 98.55000000000001%

================ epoch: 4 ================
Epoch: 4, Step 0, Loss: 0.08094145357608795
Epoch: 4, Step 50, Loss: 0.05615215748548508
Epoch: 4, Step 100, Loss: 0.07766406238079071
Epoch: 4, Step 150, Loss: 0.07915271818637848
Epoch: 4, Step 200, Loss: 0.1301635503768921
Epoch: 4, Step 250, Loss: 0.12118984013795853
Epoch: 4, Step 300, Loss: 0.073218435049057
Epoch: 4, Step 350, Loss: 0.04517696052789688
Epoch: 4, Step 400, Loss: 0.08493026345968246
Epoch: 4, Step 450, Loss: 0.03904269263148308
Epoch: 4, Step 500, Loss: 0.09386837482452393
Epoch: 4, Step 550, Loss: 0.12583576142787933
Epoch: 4, Step 600, Loss: 0.09053893387317657
Epoch: 4, Step 650, Loss: 0.06912104040384293
Epoch: 4, Step 700, Loss: 0.1502612829208374
Epoch: 4, Step 750, Loss: 0.07162325084209442
Epoch: 4, Step 800, Loss: 0.10512275993824005
Epoch: 4, Step 850, Loss: 0.028180215507745743
Epoch: 4, Step 900, Loss: 0.08492615073919296
Test Accuracy: 98.69%

PyTorch 手把手教你实现 MNIST 数据集相关推荐

  1. Pytorch:手把手教你搭建简单的卷积神经网络(CNN),实现MNIST数据集分类任务

    关于一些代码里的解释,可以看我上一篇发布的文章,里面有很详细的介绍!!! 可以依次把下面的代码段合在一起运行,也可以通过jupyter notebook分次运行 第一步:基本库的导入 import n ...

  2. 【Pytorch分布式训练】在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练

    文章目录 普通单卡训练-GPU 普通单卡训练-CPU 分布式训练-GPU 分布式训练-CPU 租GPU服务器相关 以下代码示例基于:在MNIST数据集上训练一个简单CNN网络,将其改成分布式训练. 普 ...

  3. 【目标检测实战】目标检测实战之一--手把手教你LMDB格式数据集制作!

    文章目录 1 目标检测简介 2 lmdb数据制作 2.1 VOC数据制作 2.2 lmdb文件生成 lmdb格式的数据是在使用caffe进行目标检测或分类时,使用的一种数据格式.这里我主要以目标检测为 ...

  4. pytorch训练GAN的代码(基于MNIST数据集)

    论文:Generative Adversarial Networks 作者:Ian J. Goodfellow 年份:2014年 从2020年3月多开始看网络,这是我第一篇看并且可以跑通代码的论文,简 ...

  5. 【玩转华为云】手把手教你利用ModelArts实现垃圾自动分类

    本篇推文共计2000个字,阅读时间约3分钟. 华为云-华为公司倾力打造的云战略品牌,2011年成立,致力于为全球客户提供领先的公有云服务,包含弹性云服务器.云数据库.云安全等云计算服务,软件开发服务, ...

  6. 实例 :手把手教你用PyTorch快速准确地建立神经网络(附4个学习用例)

    作者:Shivam Bansal:翻译:陈之炎:校对:丁楠雅: 本文约5600字,建议阅读30+分钟. 本文中,我们将探讨PyTorch的全部内容.我们将不止学习理论,还包括编写4个不同的用例,看看P ...

  7. 手把手教你用YOLOv5训练自己的数据集(从Windows环境配置到模型部署)

    [小白CV]手把手教你用YOLOv5训练自己的数据集(从环境配置到模型部署) 本文禁止转载 前言: 1. 安装Anaconda: 2. 创建虚拟环境: 3. 安装pytorch: 4. 下载源码和安装 ...

  8. B站教学 手把手教你使用YOLOV5之口罩检测项目 最全记录详解 ( 深度学习 / 目标检测 / pytorch )

    目录 一.环境搭建 pytorch的下载 测试(cmd窗口中) pycharm下测试(要配置pycharm中的虚拟环境) 二.数据标注 下载labor image 使用labelimg进行图片标注 划 ...

  9. PyTorch 手把手搭建(MNIST)神经网络

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者:知乎 ...

最新文章

  1. 遗留应用现代化场景:如何正确使用RESTful API
  2. eosjs v20 中文文档
  3. BZOJ 3639: Query on a tree VII LCT_set维护子树信息
  4. MySQL 索引方式
  5. canvas绘制阴影
  6. Update operation on extension field created by AET
  7. java做报表_一步一步使用POI做java报表
  8. c语言画谢宾斯基三角形
  9. 第四次作业----刘滔
  10. Android Learning:数据存储方案归纳与总结
  11. Linux Shell数值比较和字符串比较及相关
  12. python调用nmap扫描全端口_Python-通过调用Nmap来进行端口扫描
  13. iphone 扩容测试软件,拯救iPhone 12 64G!闪迪打造的扩容神器上手:轻松省钱
  14. Cherno_游戏引擎系列教程(1):1~16
  15. sony6000正在连接服务器,极速对焦+11张每秒连拍 Sony A6000评测
  16. 快搜搜:让你辞职原因有哪些?
  17. mysql 批量造假数据
  18. android dp不同高度,Android获取屏幕的宽度和高度(dp)
  19. Redis与传统sql数据库的区别
  20. 英文自我介绍(考研/校内面试/复试)

热门文章

  1. brew常用命令总结
  2. 发票OCR识别助力移动财务报销
  3. 通过爬取微博评论,发现好看的小姐姐...
  4. C#大作业——回合制游戏模拟
  5. 知其然也要知其所以然---Kernel上报电量UEvent事件流程分析
  6. Android 禁止截屏录屏
  7. 笨方法学python: ex20, 函数和文件
  8. Resin的几个常用配置
  9. 泛谈Flash文件系统
  10. 声网 X Watch Party 如何在线上一起欢快的边看电影边吐槽?