一、前言

MNIST手写数字识别程序就不过多赘述了,这个程序在深度学习中的地位跟C语言中的Hello World地位并驾齐驱,虽然很基础,但很重要,是深度学习入门必备的程序之一。

二、MNIST数据集介绍

MNIST包括6万张28*28的训练样本,1万张测试样本。

三、PyTorch实现

3.1 定义超参数

# 定义超参数
BATCH_SIZE=512 #大概需要2G的显存
EPOCHS=20 # 总共训练批次
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 让torch判断是否使用GPU,建议使用GPU环境,因为会快很多

3.2 导入训练、测试数据

# 分别导入训练、测试数据,PyTorch中已经集成了MNIST数据集,我们只需要DataLoader导入即可
train_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=BATCH_SIZE, shuffle=True)test_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=BATCH_SIZE, shuffle=True)

3.3 搭建深度学习网络模型

class ConvNet(nn.Module):def __init__(self):super().__init__()# batch*1*28*28(每次会送入batch个样本,输入通道数1(黑白图像),图像分辨率是28x28)# 下面的卷积层Conv2d的第一个参数指输入通道数,第二个参数指输出通道数,第三个参数指卷积核的大小self.conv1 = nn.Conv2d(1, 10, 5) # 输入通道数1,输出通道数10,核的大小5self.conv2 = nn.Conv2d(10, 20, 3) # 输入通道数10,输出通道数20,核的大小3# 下面的全连接层Linear的第一个参数指输入通道数,第二个参数指输出通道数self.fc1 = nn.Linear(20*10*10, 500) # 输入通道数是2000,输出通道数是500self.fc2 = nn.Linear(500, 10) # 输入通道数是500,输出通道数是10,即10分类def forward(self,x):in_size = x.size(0) # 在本例中in_size=512,也就是BATCH_SIZE的值。输入的x可以看成是512*1*28*28的张量。out = self.conv1(x) # batch*1*28*28 -> batch*10*24*24(28x28的图像经过一次核为5x5的卷积,输出变为24x24)out = F.relu(out) # batch*10*24*24(激活函数ReLU不改变形状))out = F.max_pool2d(out, 2, 2) # batch*10*24*24 -> batch*10*12*12(2*2的池化层会减半)out = self.conv2(out) # batch*10*12*12 -> batch*20*10*10(再卷积一次,核的大小是3)out = F.relu(out) # batch*20*10*10out = out.view(in_size, -1) # batch*20*10*10 -> batch*2000(out的第二维是-1,说明是自动推算,本例中第二维是20*10*10)out = self.fc1(out) # batch*2000 -> batch*500out = F.relu(out) # batch*500out = self.fc2(out) # batch*500 -> batch*10out = F.log_softmax(out, dim=1) # 计算log(softmax(x))return out

3.4 确定要使用的优化算法

这里使用简单粗暴的Adam

model = ConvNet().to(DEVICE)
optimizer = optim.Adam(model.parameters())

3.5 定义训练函数

def train(model, device, train_loader, optimizer, epoch):model.train()for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device), target.to(device)optimizer.zero_grad()output = model(data)loss = F.nll_loss(output, target)loss.backward()optimizer.step()if(batch_idx+1)%30 == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))

3.6 定义测试函数

def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for data, target in test_loader:data, target = data.to(device), target.to(device)output = model(data)test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标correct += pred.eq(target.view_as(pred)).sum().item()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

3.7 训练并测试

for epoch in range(1, EPOCHS + 1):train(model, DEVICE, train_loader, optimizer, epoch)test(model, DEVICE, test_loader)

深度学习练手项目(一)-----利用PyTorch实现MNIST手写数字识别相关推荐

  1. 深度学习3—用三层全连接神经网络训练MNIST手写数字字符集

    上一篇文章:深度学习2-任意结点数的三层全连接神经网络 距离上篇文章过去了快四个月了,真是时光飞逝,之前因为要考博所以耽误了更新,谁知道考完博后之前落下的接近半个学期的工作是如此之多,以至于弄到现在才 ...

  2. 基于TensorFlow深度学习框架,运用python搭建LeNet-5卷积神经网络模型和mnist手写数字识别数据集,设计一个手写数字识别软件。

    本软件是基于TensorFlow深度学习框架,运用LeNet-5卷积神经网络模型和mnist手写数字识别数据集所设计的手写数字识别软件. 具体实现如下: 1.读入数据:运用TensorFlow深度学习 ...

  3. 深度学习-Pytorch:项目标准流程【构建、保存、加载神经网络模型;数据集构建器Dataset、数据加载器DataLoader(线性回归案例、手写数字识别案例)】

    1.拿到文本,分词,清晰数据(去掉停用词语): 2.建立word2index.index2word表 3.准备好预训练好的word embedding 4.做好DataSet / Dataloader ...

  4. 深度学习入门项目:PyTorch实现MINST手写数字识别

    完整代码下载[github地址]:https://github.com/lmn-ning/MNIST_PyTorch.git 目录 一.MNIST数据集介绍及下载地址 二.代码结构 三.代码 data ...

  5. 深度学习数字仪表盘识别_深度学习之手写数字识别项目(Sequential方法amp;Class方法进阶版)...

    此项目使用LeNet模型针对手写数字进行分类.项目中我们分别采用了顺序式API和子类方法两种方式构建了LeNet模型训练mnist数据集,并编写了给图识物应用程序用于手写数字识别. 一.LeNet模型 ...

  6. pytorch 预测手写体数字_深度学习之PyTorch实战(3)——实战手写数字识别

    如果需要小编其他论文翻译,请移步小编的GitHub地址 传送门:请点击我 如果点击有误:https://github.com/LeBron-Jian/DeepLearningNote 上一节,我们已经 ...

  7. 深度学习 卷积神经网络-Pytorch手写数字识别

    深度学习 卷积神经网络-Pytorch手写数字识别 一.前言 二.代码实现 2.1 引入依赖库 2.2 加载数据 2.3 数据分割 2.4 构造数据 2.5 迭代训练 三.测试数据 四.参考资料 一. ...

  8. 【卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10)】

    卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10) 在上一章已经完成了卷积神经网络的结构分析,并通过各个模块理解 ...

  9. 深度学习项目实战——手写数字识别项目

    摘要 本文将介绍的有关于的paddle的实战的相关的问题,并分析相关的代码的阅读和解释.并扩展有关于的python的有关的语言.介绍了深度学习步骤: 1. 数据处理:读取数据 和 预处理操作 2. 模 ...

最新文章

  1. uninque()用法
  2. C#温故而知新学习系列之面向对象编程—构造函数(七)
  3. 零基础Java学习之接口
  4. 神经网络的输入对迭代次数的影响
  5. CentOS6.8网络接口配置文件ifcfg-eth0
  6. P5304-[GXOI/GZOI2019]旅行者【最短路】
  7. 如何快速理解读懂他人代码(下)——技巧学习篇
  8. c++thread里暂停线程_Java线程的 6 种状态
  9. 菜鸟版JAVA设计模式-从抽象与实现说桥接模式
  10. XNA 如何使用字体绘制文字,Windows Phone 游戏开发
  11. MySQL常用命令集锦
  12. 周礼键君:《建郡八音》(拼音方案---只有四调,以近音调注)
  13. 人工智能第六章——约束满足问题(CSP)
  14. iftop流量实时查看
  15. 寸 金 难 买 寸 光 阴
  16. Spring整合Kafka
  17. Oracle项目管理系统之赢得值管理
  18. React Native之ScrollView控件详解
  19. UE4 RTS 框选功能实现
  20. 使用DNSPod解析Freenom域名

热门文章

  1. DIB位图(Bitmap)的读取和保存
  2. C++Primer中文版(第4版)第四章习题答案
  3. 虚成员(virtual)
  4. 2、SharePoint安装篇——之安装Microsoft Office SharePoint Server 2007
  5. Django入门:DoesNotExist: User matching query does not exist.
  6. linux查看服务器网络状态
  7. 看google三篇论文的感触
  8. PHP是弱类型还是强类型,php弱类型比较(松散比较) | CN-SEC 中文网
  9. 【Python】直接赋值、浅拷贝和深度拷贝解析
  10. 【实操】深度学习网络万万千,到底怎么把我的数据放进去?