前言:

1、数字识别是计算机从纸质文档、照片或其他来源接收、理解并识别可读的数字的能力,目前比较受关注的是手写数字识别。手写数字识别是一个典型的图像分类问题,已经被广泛应用于汇款单号识别、手写邮政编码识别等领域,大大缩短了业务处理时间,提升了工作效率和质量。

2、MNIST是深度学习领域标准、易用的、成熟的手写数字识别模型数据集,包含50 000条训练样本和10 000条测试样本。

一、数据加载

1、使用paddle自带的数据集,加载数据非常的方便,就一行代码:

    # 加载训练集 batch_size 设为 16
train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'),batch_size=16,shuffle=True)

二、模型设计

1、新建文件MNIST.py

在该文件中定义模型和图像归一化处理函数

import paddleclass MNIST(paddle.nn.Layer):def __init__(self):super(MNIST, self).__init__()# 定义一层全连接层,输出维度是1self.fc = paddle.nn.Linear(in_features=784, out_features=1)# 定义网络结构的前向计算过程def forward(self, inputs):outputs = self.fc(inputs)return outputs# 图像归一化函数,将数据范围为[0, 255]的图像归一化到[0, 1]
def norm_img(img):# 验证传入数据格式是否正确,img的shape为[batch_size, 28, 28]assert len(img.shape) == 3batch_size, img_h, img_w = img.shape[0], img.shape[1], img.shape[2]# 归一化图像数据img = img / 255# 将图像形式reshape为[batch_size, 784]img = paddle.reshape(img, [batch_size, img_h*img_w])return img

三、模型训练

1、新建文件HandWriteNum.py

2、训练之后直接保存模型参数

#加载飞桨和相关类库
import paddle
import paddle.nn.functional as F
import numpy as np
from MNIST import MNIST
from MNIST import norm_img
# 设置数据读取器,API自动读取MNIST数据训练集
train_dataset = paddle.vision.datasets.MNIST(mode='train')train_data0 = np.array(train_dataset[0][0])
train_label_0 = np.array(train_dataset[0][1])paddle.vision.set_image_backend('cv2')
# 声明网络结构
model = MNIST()
# 启动训练模式
model.train()# 加载训练集 batch_size 设为 16
train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'),batch_size=16,shuffle=True)
# 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001
opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())EPOCH_NUM = 100
for epoch in range(EPOCH_NUM):for batch_id, data in enumerate(train_loader()):images = norm_img(data[0]).astype('float32')labels = data[1].astype('float32')#前向计算的过程predicts = model(images)# 计算损失loss = F.square_error_cost(predicts, labels)avg_loss = paddle.mean(loss)#每训练了1000批次的数据,打印下当前Loss的情况if batch_id % 1000 == 0:print("epoch_id: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, avg_loss.numpy()))#后向传播,更新参数的过程avg_loss.backward()opt.step()opt.clear_grad()paddle.save(model.state_dict(), './mnist.pdparams')

四、模型验证

1、新建文件HandWriteNumEvalTest.py

2、加载模型测试数据

3、加载模型训练参数

4、输出验证结果

from MNIST import MNIST
import numpy as np
import paddle# 定义预测过程
model = MNIST()
params_file_path = 'mnist.pdparams'
# 加载模型参数
param_dict = paddle.load(params_file_path)
model.load_dict(param_dict)
model.eval()# 加载测试集 batch_size 设为 1
test_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='test'))
print(len(test_loader.dataset))
success = 0
error = 0
for data in test_loader.dataset:#image = norm_img(data[0]).astype('float32')image = data[0]"""plt.imshow(image, cmap=plt.cm.binary)plt.axis('on') # 关掉坐标轴为 offplt.title('image') # 图像题目plt.show()"""image = np.array(image).reshape(1, -1).astype(np.float32)# 图像归一化,保持和数据集的数据范围一致image = 1 - image / 255label = data[1].astype('int32')[0]result = model(paddle.to_tensor(image))result = result.numpy().astype('int32')[0][0];if (label == result) :success = success + 1else:error = error + 1#  预测输出取整,即为预测的数字,打印结果
print("本次预测的正确的数量是{}, 错误的数量是{}".format(success, error))

从测试数据的结果来看,本次成功的数据是0,一万条数据,没有一个是预测成功的。

总结:

1、模型选择的不对,一切努力都白费,没法使用单层线性网络来预测数字识别。

菜菜学paddle第一篇:单层网络构建手写数字识别相关推荐

  1. TensorFlow第三步 :单层网络-Mnist手写数字识别

    一.载入数据Mnist,并检验数据 # coding=utf-8 import os os.environ["TF_CPP_MIN_LOG_LEVEL"]='2' # 只显示 wa ...

  2. 基于Python的BP网络实现手写数字识别

    资源下载地址:https://download.csdn.net/download/sheziqiong/86790047 资源下载地址:https://download.csdn.net/downl ...

  3. 【零基础】从零开始学神经网络《python神经网络编程》——手写数字识别实战

    文章目录 前言 一.机器学习是什么,深度学习是什么? 二.对NN,CNN,RNN,GNN,GAN的名词解释 三.详细介绍神经网络(NN) 1.认识神经网络 2.神经元 3.激活函数 4.权重--连接的 ...

  4. 【DCGAN】生成对抗网络,手写数字识别

    基于paddle,aistudio的DCGAN 主要用于记录自己学习经历. 1   导入必要的包 import os import random import paddle import paddle ...

  5. 单层神经网络实现手写数字识别

    Mnist手写数字识别 前言 Mnist数据集可以从官网下载,网址: http://yann.lecun.com/exdb/mnist/ 下载下来的数据集被分成两部分: 55000行的训练数据集(mn ...

  6. 深度学习第一周 tensorflow实现mnist手写数字识别

  7. 【手写数字识别】基于Lenet网络实现手写数字识别附matlab代码

    1 内容介绍 当今社会,人工智能得到快速发展,而模式识 别作为人工智能的一个重要应用领域也得到了飞 速发展,它利用计算机通过计算的方法根据样本的 特征对样本进行分类,其中的光学字符识别技术受 到广大研 ...

  8. 深度学习笔记:01快速构建一个手写数字识别系统以及张量的概念

    深度学习笔记:01快速构建一个手写数字识别系统 神经网络代码最好运行在GPU中,但是对于初学者来说运行在GPU上成本太高了,所以先运行在CPU中,就是慢一些. 一.安装keras框架 使用管理员模式打 ...

  9. 动手学PaddlePaddle(4):MNIST(手写数字识别)

    本次练习将使用 PaddlePaddle 来实现三种不同的分类器,用于识别手写数字.三种分类器所实现的模型分别为 Softmax 回归.多层感知器.卷积神经网络. 您将学会 实现一个基于Softmax ...

最新文章

  1. Bootstrap表单验证插件bootstrapValidator使用方法整理
  2. 全栈技术实践经历告诉你:开发一个商城小程序要多少钱?
  3. RQNOJ 83 魔兽世界
  4. spring_通过Spring Boot了解H2 InMemory数据库
  5. optXXX方法,optBoolean
  6. 重新探讨一下《APEX英雄》系统设计的亮点
  7. Leetcode643.Maximum Average Subarray I子数组的最大平均数1
  8. C/C++ 中变量的声明、定义、初始化的区别
  9. 字节面试官:如何实现Ajax并发请求控制
  10. 猜数(二分、线段树)
  11. RNN循环神经网络实现预测比特币价格过程详解
  12. Jenkins ssh 发布jar 时区不对
  13. 思考XSS攻击和跨站伪造请求CSRF
  14. python中response对象的属性_关于python:AttributeError:’HTTPResponse’对象没有属性’split’...
  15. 信息学竞赛 c语言 pascal,pascal信息学竞赛教程
  16. GOM引擎 mirserver服务端各文件夹注解
  17. linux nvidia显卡驱动安装教程,LINUX的NVIDIA显卡驱动安装
  18. (Cys-RGD)包被CdTe量子|3-巯基丙酸(MPA)包被近红外发光CdTe量子
  19. 消费金融及物流概念介绍
  20. 在进行原理图编译的时候提示警告:Net has no driving source

热门文章

  1. 2018年中国新零售市场研究报告——概念、模式与案例【附下载】
  2. 今年我真要去 Google 工作了
  3. elasticsearch查询搜索命令大全
  4. MySQL 8.0.15备份还原 MySQL 5.7.17
  5. 计算机专业过硬的大学,高质量文书+过硬背景,迎来美国东北大学计算机专业硕士offer...
  6. 你的applicationContext.getResources(source)为什么只拿到了一个配置文件?
  7. iOS和安卓如何打开第三方APP?
  8. 传奇GOM原版引擎支持光柱吗?
  9. 织网模板html5,织网的文言文
  10. addEventListener 事件监听方式