文章目录

  • 前言
  • 1. Paddle手写数字识别过程
  • 2. Paddle手写数字识别训练与推理过程实现
  • 总结
  • 参考

前言

趁着国庆尾巴,复习了Paddle框架进行深度学习实践:手写数字识别,这里分享下模型实现。


1. Paddle手写数字识别过程

这里给大家分享下手写数字识别的主要步骤:

  1. 定义数据处理过程:定义MnistDataset类,继承自paddle.io.Dataset实现模型输入数据处理,与paddle.io.DataLoader配合使用,实现数据异步加载,提高模型训练速度;
  2. 定义深度学习模型:这里使用简单的多个卷积层、ReLU激活函数,池化层来提取图像特征,使用全连接层,Softmax实现图像分类;
  3. 训练配置:使用随机梯度下降SGD来优化模型参数,使用交叉熵作为分类损失函数。
  4. 训练过程:前向计算,损失计算,模型参数更新三个过程循环进行,直到达到优化目标,即损失值足够小;
  5. 保存模型:保存上述训练模型参数,以供推理阶段加载使用。

2. Paddle手写数字识别训练与推理过程实现

# 导入飞桨和其他相关库
import paddle
from paddle.nn import Conv2D, MaxPool2D, Linear
import paddle.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import gzip
import os
import json
import random
from PIL import Image# 创建一个类MnistDataset, 继承paddle.io.Dataset,配合DataLoader实现数据异步加载
class MnistDataset(paddle.io.Dataset):def __init__(self, mode='train'):datafile = './work/mnist.json.gz'data = json.load(gzip.open(datafile))# 划分数据集为训练集、验证集和测试集train_set, val_set, test_set = data[:3]# 图片高度和宽度self.IMG_ROWS, self.IMG_COLS = 28, 28if mode == 'train':# 训练数据集imgs, labels = train_set[:2]elif mode == 'valid':imgs, labels = val_set[:2]elif mode == 'eval':imgs, labels = test_set[:2]else:raise Exception("mode can only be one of [train, valid, eval]")# 校验数据imgs_length = len(imgs)assert len(imgs) == len(labels), \"length of train_imgs({}) should be the same with train_labels({})".format(len(imgs), len(labels))self.imgs = imgsself.labels = labelsdef __getitem__(self, idx):img = np.reshape(self.imgs[idx], [1, self.IMG_ROWS, self.IMG_COLS]).astype('float32')label = np.reshape(self.labels[idx], [1]).astype('int64')return img, labeldef __len__(self):return len(self.imgs)# 定义网络结构, 多层卷积神经网络
class MNIST_CNN(paddle.nn.Layer):def __init__(self):super(MNIST_CNN, self).__init__()# 定义卷积层self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)# 定义池化层self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)# 定义全连接层,输出维度为10self.fc = Linear(in_features=980, out_features=10)# 定义前向计算过程def forward(self, inputs):x = self.conv1(inputs)x = F.relu(x)x = self.max_pool1(x)x = self.conv2(x)x = F.relu(x)x = self.max_pool2(x)x = paddle.reshape(x, [x.shape[0], 980])x = self.fc(x)return x# 定义评估函数
def evaluation(model, val_loader):model.eval()acc_set = list()for batch_id, data in enumerate(val_loader()):images, labels = data[:2]images = paddle.to_tensor(images)labels = paddle.to_tensor(labels)pred = model(images)acc = paddle.metric.accuracy(input=pred, label=labels)acc_set.extend(acc.numpy())# 计算多个batch的准确率acc_val_mean = np.array(acc_set).mean()return acc_val_mean# 定义训练函数,使用交叉熵损失函数
def train(model, train_loader):model.train()opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())ce_loss = paddle.nn.loss.CrossEntropyLoss()EPOCH_NUM = 10for epoch_id in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):# 准备数据images, labels = data[:2]images = paddle.to_tensor(images)labels = paddle.to_tensor(labels)# 前向计算过程preds = model(images)# 损失计算过程loss = ce_loss(preds, labels)if batch_id % 200 == 0:print("epoch: {}, batch: {}, loss: {}".format(epoch_id, batch_id, loss.numpy()[0]))# 反向传播过程,计算各层梯度值loss.backward()# 网络参数更新opt.step()# 清空各层梯度值opt.clear_grad()# 保存模型参数paddle.save(model.state_dict(), './mnist.pdparams')def train_process():# 声明数据加载实例,使用训练模式,train_dataset = MnistDataset(mode='train')# 调用DataLoader生成一个批次数据迭代器,异步获取train_loader = paddle.io.DataLoader(train_dataset, batch_size=100, shuffle=True, drop_last=True)# 声明数据加载实例,使用验证集val_dataset = MnistDataset(mode='valid')val_loader = paddle.io.DataLoader(val_dataset, batch_size=128, drop_last=True)# 创建模型实例model = MNIST_CNN()# 启动训练过程train(model, train_loader)# 启动评估过程acc_train_mean = evaluation(model, train_loader)acc_val_mean = evaluation(model, val_loader)print('train acc:{}, val acc:{}'.format(acc_train_mean, acc_val_mean))# 读取本地图片,转变成模型输入格式
def load_image(img_path):# 读取图片,并转换为灰度图im = Image.open(img_path).convert('L')im = im.resize((28, 28), Image.ANTIALIAS)im = np.array(im).reshape(1,1, 28, 28).astype(np.float32)# 图像归一化im = im / 255return im# 定义预测过程
def predict_process():model = MNIST_CNN()params_file_path = './mnist.pdparams'img_path = './images/0.jpg'# 加载模型参数param_dict = paddle.load(params_file_path)model.load_dict(param_dict)# 加载数据model.eval()tensor_img = load_image(img_path)# 模型返回10个分类标签对应的概率results = model(paddle.to_tensor(tensor_img))# 取概率最大的标签作为预测输出label = np.argsort(results.numpy())print('本次预测数字:', label[0][-1])if __name__ == "__main__":train_process()predict_process()

总结

这里实践,让我意识到超参数调整的重要性,让我印象深刻的就是优化函数中学习率的调整,过大或过小值都会导致损失值下降变慢,进而导致训练时间长的问题。

手写数字识别数据集集源码请参考gitee链接:
https://gitee.com/dttrcv/paddle-practice/blob/master/HandWrittenDigitRec/baseline_loss.py

参考

https://www.paddlepaddle.org.cn/tutorials/projectdetail/2310369

Paddle实践:手写数字识别相关推荐

  1. paddle实现手写数字识别终章

    要点: 资源配置 训练调试 恢复训练 模型部署 参考文档: paddle官方文档 一 资源配置 1 概述 从前几节的训练看,无论是房价预测任务还是MNIST手写字数字识别任务,训练好一个模型不会超过1 ...

  2. 2.7mnist手写数字识别之训练调试与优化精讲(百度架构师手把手带你零基础实践深度学习原版笔记系列)

    2.7mnist手写数字识别之训练调试与优化精讲(百度架构师手把手带你零基础实践深度学习原版笔记系列) 目录 2.7mnist手写数字识别之训练调试与优化精讲(百度架构师手把手带你零基础实践深度学习原 ...

  3. 飞桨day-01手写数字识别实践(使用卷积神经网络)

    day-01高层API手写数字识别实践(使用卷积神经网络) AI Studio项目地址:https://aistudio.baidu.com/aistudio/projectdetail/150477 ...

  4. 基于Paddle的计算机视觉入门教程——第7讲 实战:手写数字识别

    B站教程地址 https://www.bilibili.com/video/BV18b4y1J7a6/ 任务介绍 手写数字识别是计算机视觉的一个经典项目,因为手写数字的随机性,使用传统的计算机视觉技术 ...

  5. 菜菜学paddle第一篇:单层网络构建手写数字识别

    前言: 1.数字识别是计算机从纸质文档.照片或其他来源接收.理解并识别可读的数字的能力,目前比较受关注的是手写数字识别.手写数字识别是一个典型的图像分类问题,已经被广泛应用于汇款单号识别.手写邮政编码 ...

  6. 深蓝学院第三章:基于卷积神经网络(CNN)的手写数字识别实践

    参看之前篇章的用全连接神经网络去做手写识别:https://blog.csdn.net/m0_37957160/article/details/114105389?spm=1001.2014.3001 ...

  7. 【项目实践】:KNN实现手写数字识别(附Python详细代码及注释)

    ↑ 点击上方[计算机视觉联盟]关注我们 本节使用KNN算法实现手写数字识别.KNN算法基本原理前边文章已经详细叙述,盟友们可以参考哦! 数据集介绍 有两个文件: (1)trainingDigits文件 ...

  8. 深度学习理论与实践第二章作业-FNN手写数字识别

    命名格式:按照课程网站中的课后作业要求 1. 根据Course02课程中对全连接神经网络的讲解,将缺失的全连接神经网络中代码块补全,并完成一次训练 需要填充的部分已经在第一部分的全连接神经网络代码中用 ...

  9. 基于飞桨实现手写数字识别2

    参考课程笔记:https://aistudio.baidu.com/aistudio/projectdetail/728143 上篇https://mp.csdn.net/console/editor ...

最新文章

  1. 折半插入排序稳定吗_C++实现经典算法--折半插入排序
  2. Observer设计模式【利用商品概念解释】
  3. 自动装配——@Autowired 构造器,参数,方法,属性都是从容器中获取参数组件的值||自定义组件想要使用Spring容器底层的一些组件 ApplicationContext,BeanFactory
  4. libevent安装总结 - jinfg2008的专栏 - 博客频道 - CSDN.NET
  5. Matlab | MATLAB编辑器:无法使用GBK编码保存文件,请改用UTF-8编码保存文件(问题解决)
  6. 前端之路(一)之W3C是什么?
  7. java弱口令生成1001无标题,教你批量生成自动发卡平台需要的卡密数据
  8. java list e 查找_源码(04) -- java.util.ListE
  9. npm 使用报错合集
  10. 不要小看日本的AI公司
  11. 集成电路制造工艺及设备
  12. 5W1h分析法分析---play框架
  13. C++中的delete——读书笔记
  14. 李宏毅机器学习课程自测练习题
  15. LintCode 1173.反转字符串
  16. Android长截图(五) - 遇到的坑
  17. java的clone你知道多少?
  18. 单细胞分析可视化工具盘点
  19. Js 几种刷新页面最快的方法
  20. LSTC LS-Opt 官方各版本下载

热门文章

  1. 关于POS系统可靠性开发的一些考虑:
  2. 【数字化】数字化转型成功的企业都发生了什么变化?
  3. java qq聊天界面代码,Java简易qq聊天,代码
  4. 关于web结合单目以及RGBD图像重建的设计(一)
  5. AVG杀毒软件是流氓软件,难以彻底干净卸载,会在系统里埋钩子(附卸载方法)
  6. delphi xe 保存图片到JPG的方法 BMP转JPG
  7. axios源码——工具函数utils.js
  8. 国内领先的30个知名b2b电子商务平台
  9. access2010版本的数据库
  10. Spring Boot 如何优雅的校验参数?