1. MNIST 数据集加载

MNIST 数据集可以从MNIST官网下载。也可以通过 Tensorflow 提供的 input_data.py进行载入。

由于上述方法下载数据集比较慢,我已经把下载好的数据集上传到CSDN资源中,可以直接下载。

将下载好的数据集放到目录C:/Users/Administrator/.spyder-py3/MNIST_data/下。目录可以根据自己的喜好变换,只是代码中随之改变即可。

通过运行Tensorflow 提供的代码加载数据集:

MNIST数据集包含55000样本的训练集,5000样本的验证集,10000样本的测试集。 input_data.py 已经将下载好的数据集解压、重构图片和标签数据来组成新的数据集对象。

图像是28像素x28像素大小的灰度图片。空白部分全部为0,有笔迹的地方根据颜色深浅有0~1的取值,因此,每个样本有28x28=784维的特征,相当于展开为1维。

所以,训练集的特征是一个 55000x784 的 Tensor,第一纬度是图片编号,第二维度是图像像素点编号。而训练集的 Label(图片代表的是0~9中哪个数)是一个 55000x10 的 Tensor,10是10个种类的意思,进行 one-hot 编码 即只有一个值为1,其余为0,如数字0,对于 label 为[1,0,0,0,0,0,0,0,0,0]。

2. Softmax Regression 算法

数字都是0~9之间的,一共有10个类别,当对图片进行预测时,Softmax Regression 会对每一种类别估算一个概率,并将概率最大的那个数字作为结果输出。

Softmax Regression 将可以判定为某类的特征相加,然后将这些特征转化为判定是这一个类的概率。我们对图片的所以像素求一个加权和。如某个像素的灰度值大代表很有可能是数字n,这个像素权重就很大,反之,这个权重很有可能为负值。

特征公式:

bi" role="presentation" style=" box-sizing: border-box; outline: 0px; display: inline; line-height: normal; word-spacing: normal; overflow-wrap: break-word; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border-width: 0px; border-style: initial; border-color: initial; ">bibi 为偏置值,就是这个数据本身的一些倾向。

然后用 softmax 函数把这些特征转换成概率

对所有特征计算 softmax,并进行标准化(所有类别输出的概率值和为1):

判定为第 i 类的概率为:

Softmax Regression 流程如下:

转换为矩阵乘法:

写成公式如下:

3.实现模型

import tensorflow as tfsess = tf.InteractiveSession()x = tf.placeholder(tf.float32, [None, 784])W = tf.Variable(tf.zeros([784,10]))b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x,W) + b)
  • 1

  • 2

  • 3

  • 4

  • 5

  • 6

首先载入 Tensorflow 库,并创建一个新的 InteractiveSession ,之后的运算默认在这个 session 中。

  • placeholder:输入数据的地方,None 代表不限条数的输入,每条是784维的向量

  • Variable:存储模型参数,持久化的

4.训练模型

我们定义一个 loss 函数来描述模型对问题的分类精度。 Loss 越小,模型越精确。这里采用交叉熵:


其中,y 是我们预测的概率分布, y’ 是实际的分布。

y_ = tf.placeholder(tf.float32, [None,10])cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
  • 1

  • 2

定义一个 placeholder 用于输入正确值,并计算交叉熵。

接着采用随机梯度下降法,步长为0.5进行训练。

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
  • 1

训练模型,让模型循环训练1000次,每次随机从训练集去100条样本,以提高收敛速度。

for i in range(1000):  batch_xs, batch_ys = mnist.train.next_batch(100)  train_step.run({x: batch_xs, y_: batch_ys})
  • 1

  • 2

  • 3

5.评估模型

我们通过判断实际值和预测值是否相同来评估模型,并计算准确率,准确率越高,分类越精确。

6.总结

实现的整个流程:

  1. 定义算法公式,也就是神经网络前向传播时的计算。

  2. 定义 loss ,选定优化器,并指定优化器优化 loss。

  3. 迭代地对数据进行训练。

  4. 在测试集或验证集上对准确率进行评测。

7.全部代码

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data

# 获取数据mnist = input_data.read_data_sets("C:/Users/Administrator/.spyder-py3/MNIST_data/", one_hot=True)

print('训练集信息:')print(mnist.train.images.shape,mnist.train.labels.shape)print('测试集信息:')print(mnist.test.images.shape,mnist.test.labels.shape)print('验证集信息:')print(mnist.validation.images.shape,mnist.validation.labels.shape)

# 构建图sess = tf.InteractiveSession()x = tf.placeholder(tf.float32, [None, 784])W = tf.Variable(tf.zeros([784,10]))b = tf.Variable(tf.zeros([10]))

y = tf.nn.softmax(tf.matmul(x,W) + b)

y_ = tf.placeholder(tf.float32, [None,10])cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

# 进行训练tf.global_variables_initializer().run()

for i in range(1000):  batch_xs, batch_ys = mnist.train.next_batch(100)  train_step.run({x: batch_xs, y_: batch_ys})

# 模型评估correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

print('MNIST手写图片准确率:')print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))

运用cnn实现手写体(mnist)数字识别_实现 MNIST 手写数字识别相关推荐

  1. python开源数字识别_[转]:手写数字识别系统之数字提取

    引言 所谓数字分割就是指将经过二值化后的图像中的单个数字区域进行提取的过程.数字分割在数字识别中是一个必不可少的关键步骤,只有能够将数字进行准确的提取,才能将其一一识别. 数字分割的方法 数字分割的方 ...

  2. 【手写数字识别】RBM神经网络手写数字识别【含GUI Matlab源码 1109期】

    ⛄一.手写数字识别技术简介 1 案例背景 手写体数字识别是图像识别学科下的一个分支,是图像处理和模式识别研究领域的重要应用之一,并且具有很强的通用性.由于手写体数字的随意性很大,如笔画粗细.字体大小. ...

  3. 【手写数字识别】Fisher分类手写数字识别 【含Matlab源码 505期】

    ⛄一.Fisher分类手写数字识别简介 1引言 手写体数字识别在过去的几十年里一直是模式识别领域的研究热点,在手写较多的领域如邮政编码.统计报表.财务报表.支票的数字识别等方面有广泛应用.专家.学者提 ...

  4. python手写字体程序_深度学习---手写字体识别程序分析(python)

    我想大部分程序员的第一个程序应该都是"hello world",在深度学习领域,这个"hello world"程序就是手写字体识别程序. 这次我们详细的分析下手 ...

  5. 使用ps完成手写数字图片(用于验证手写数字模型或制作数据集)

    首先,我们要新建一个28*28的画板. 然后给画板重命名为写的数字. 新建一个28*28的矩形,颜色填充为黑色,让其填满整个画板. 选中矩形,然后选择画笔工具,点击矩形,栅格化. 选中画板,按住Ctr ...

  6. matlab基于SVM的手写字体识别,机器学习SVM--基于手写字体识别

    每一行代表一个手写字体图像,最大值为16,大小64,然后最后一列为该图片的标签值. import numpy as np from sklearn import svm import matplotl ...

  7. 使用线性回归识别sklearn中的手写数字digit

    从昨天晚上,到今天上午12点半左右吧,一直在调这个代码.最开始训练的时候,老是说loss:nan 查了资料,因为是如果损失函数使用交叉熵,如果预测值为0或负数,求log的时候会出错.需要对预测结果进行 ...

  8. MNIST——手写数字识别数据集

    MNIST数据集由Yann LeCun搜集,是一个大型的手写体数字数据库,通常用于训练各种图像处理系统,也被广泛用于机器学习领域的训练和测试.MNIST数字文字识别数据集数据量不会太多,而且是单色的图 ...

  9. PYQT5+CNN(TensorFlow-keras)做一个简单的手写数字识别PC端图形化小程序

    目录 前言 一.功能介绍 1.画板识别 2.图片识别 二.UI设计 1.整体设计思想 2.颜色设计 3.Logo 设计 4.按钮设计 三.算法介绍 1.图片预处理 2.数字分割和显示 3.识别算法 4 ...

  10. 基于tensorflow2.0利用CNN与线性回归两种方法实现手写数字识别

    CNN实现手写数字识别 导入模块和数据集 import os import tensorflow as tf from tensorflow import keras from tensorflow. ...

最新文章

  1. 【Android Protobuf 序列化】Protobuf 使用 ( protobuf-gradle-plugin 插件简介 | Android Studio 中配置插件 | AS 中编译源文件 )
  2. mysql select union_MySQL SELECT语法(四)UNION语法详解
  3. springmvc基础学习3---注解简单理解
  4. Codeforces Round #667 (Div. 3)
  5. 在众多编程语言中,你可知哪种语言的安全性更高,安全漏洞最少?
  6. 安卓开发_自定义控件_界面的简单侧滑
  7. 《消息队列》函数讲解
  8. 雷军:小米12 Pro全球首发索尼IMX707
  9. html代码id,浅谈html中id和name的区别实例代码
  10. java和C#的相同之处笔记
  11. Qt QSsh 使用 windows Qt实现ssh客户端
  12. 2022爱分析·虚拟化活动实践报告
  13. 建行网银盾无法识别怎么办
  14. Java获取本机外网ip地址的方法
  15. html2pdf无法导出图片解决方案(2020版)
  16. linux 自启动 快捷键,linux自定义快捷键、文件打开方式、文件快捷方式、启动器及开机启动...
  17. 罗技M590优联无法使用的问题解决
  18. CAD 偏移和复制、移动的区别
  19. Eclipse中配置python环境
  20. (java桌面应用程序)淘金者游戏及玩法介绍

热门文章

  1. wordpress插件列表
  2. Q102:光线追踪场景(2)——PLYs(多种模型汇集)
  3. 问题二十五:为什么有时候XnView无法显示PPM图片?
  4. java mvc接收 时间_Springmvc 如何接收java8的时间localDateTime。
  5. Error Could not open client transport with JDBC Uri jdbchive2hadoop10210000 Failed to open new sessi
  6. java线程池测试,Java线程池【测试Markdown样式】
  7. 机器学习 - [集成学习]Bagging算法的编程实现
  8. Java编程基础 - 泛型
  9. 计算机报警声 一高一低,有报警声电脑问题怎么处理 有报警声电脑问题处理方法【介绍】...
  10. java 订阅 kafka_尝试从kafka(0.10版本)访问kafka(0.90版本)时订阅方法抛出错误...