MINIS手写体识别相当于机器学习中的Hello world,作为深度学习领域的入门任务是初学者的首选。

数据集下载 download.py

1. MINIST数据集

MINIST数据集是一个图片压缩包,包含大量手写数字图片;
其中一个数据样本有两部分构成:手写图片和label;

数据集每个样本都是(1,784)维度的图片,为了方便展示就像下图一样画成二维矩阵,实际上后面要将向量转换为28*28的二维矩阵;

2. 读取数据

TensorFlow为了教学MNIST而提前编制了程序,所以只需两行代码就可自动下载数据集文件

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# 从MNIST_data/中读取MNIST数据,这条语句在数据不存在时,会自动执行下载

运行后可以在相应目录下查看到如下数据压缩包

3. 打印数据集基本信息

数据集可分为:

训练集 测试集 验证集
minist.train minist.test minist.validation
55,000 10,000 5,000
训练模型 多次使用调参 评估模型

将MINIST保存为图片 save_pic.py

1. 创建文件夹保存原始图片

save_dir = 'MNIST_data/raw/'
if os.path.exists(save_dir) is False:  # 判断save_dir文件是否存在os.makedirs(save_dir) # os.makedirs()方法用于递归创建目录

2. reshape图片

 # 注意:mnist.train.images[i, :]就表示第i张图片(序号从0开始)image_array = mnist.train.images[i, :]# TensorFlow中的MNIST图片是一个784维的向量,我们重新把它还原为28x28维的图像image_array = image_array.reshape(28, 28)

3. 文件命名

# 保存文件的格式为 mnist_train_0.jpg, mnist_train_1.jpg, ... ,mnist_train_19.jpgfilename = save_dir + 'mnist_train_%d.jpg' % i

4.转换为图片存储

    # 将image_array保存为图片:array - > image#scipy.misc.toimage(image_array, cmin=0.0, cmax=1.0).save(filename)# AttributeError: module 'scipy.misc' has no attribute 'toimage'Image.fromarray((image_array * 255).astype('uint8'), mode='L').convert('RGB').save(filename)

运行后可以在相应目录下看到存储下来的如下图片:

打印MINIST数据集图片的标签 label.py

1. 独热编码

one - hot读热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0;独热编码常用于表示拥有有限个可能值的字符串或标识符;

2. argmax(array, axis)

numpy.argmax(array, axis) 函数:
array是矩阵,axis是0或者1;axis默认为0
其中,0表示的是按行比较返回最大值的索引,1表示按列比较返回最大值的索引
eg: one_dim_array = np.array([1, 4, 5, 3, 7, 2, 6])
print(np.argmax(one_dim_array)) # 4

3. 打印label

MINIST数据集每个样本的label是(1*10)维的数组,采用独热编码表示,形如(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)

# 看前20张训练图片的label
for i in range(20):one_hot_label = mnist.train.labels[i, :]# 通过np.argmax我们可以直接获得原始的labellabel = np.argmax(one_hot_label)print(f'mnist_train_{i}.jpg label: {label}')# mnist_train_18.jpg label: 6

Softmax回Softmax_regression.py

1. Softmax回归模型

Softmax回归是一个线性的多类分类模型,我们期望对输入的图片计算出它属于某个类别的概率,比如90%的概率是2,5%的概率是5,那么最终会输出这张图片的标签为5;因此,一张图片对于每一个数字的吻合度可以被softmax函数转换成为一个概率值。softmax函数可以定义为:

由下图所示,对于输入Xi加权求和,再分别加上一个偏置量,最后输入到Softmax函数中:

其矩阵表现形式如下:


下面举一个例子理解:
假设我们现在需要预估房价。输出y是房子价格,输入Xi就是多种影响因素,可能是大小,距离,教育资源等等,类比MINIST问题就是输入图片的像素值;W为各影响因素的权,b为偏移量;

2. 设置占位符

  1. tf.placeholder( dtype, shape=None, name )
    dtype是数据类型;shape是数据形状,默认一维,也可二维;name可有可无
  2. palceholder只暂时存储变量,传值过程在sess.run()中进行
x = tf.placeholder(tf.float32, [None, 784])
# 用于得到传递进来的待识别的训练图片
W = tf.Variable(tf.zeros([784, 10]))
# 相当于一层神经网络上的参数
b = tf.Variable(tf.zeros([10]))
# 偏置向量

2. 预测函数

y是模型的输出,y_是实际的图像标签,用one-hot表示

# y=softmax(Wx + b),y表示模型的输出
y = tf.nn.softmax(tf.matmul(x, W) + b)
# tf.matmual()表示两个矩阵相乘# y_是实际的图像标签,同样以占位符表示
y_ = tf.placeholder(tf.float32, [None, 10])

3. 损失函数

根据y和y_来构造交叉熵损失函数

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y)))

4. 梯度下降

使用梯度下降的方式去迭代使得损失函数达到最小;
GradientDescentOptimizer(0.01)是梯度下降的封装函数,设置0.01的初始学习速率;
minimize(cross_entropy)表示最小化损失函数;

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

5. 定义会话Session

上面所有的步骤只是定义了一个框架,Tensorflow中要求所有的数据计算都要在定义的会话中进行;

with tf.Session() as sess:tf.global_variables_initializer().run()# 初始化所有变量for _ in range(1000):# 迭代1000次batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})  # feed_dict给使用placeholder创建出来的tensor赋值

6. 查看准确率

  1. equal()函数:逐元素比较是否相等,返回结一个比较结果(Ture/False)

  2. reduce_mean()函数:计算张量的各个维度上的元素的平均值

  3. cast()函数:张量数据类型转换,将布尔型转换为float32
    例如:
    [True,False,True,True]可以用[1,0,1,1]表示,精度 为0.75
    0.75 = 3/4

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
# 用测试集的数据去得到准确率

源代码

github源代码

基于Tensorflow的MINIST手写体识别相关推荐

  1. 基于CNn的MINIST手写体识别

    深度学习的上机作业: 基于CNN卷积神经网络的MINIST手写体识别 版本:python-3.9,tensorflow-2.9 目录 MINIST数据集 训练CNN卷积神经网络 使用训练好的模型进行预 ...

  2. 基于tensorflow、CNN网络识别花卉的种类(图像识别)

    基于tensorflow.CNN网络识别花卉的种类 这是一个图像识别项目,基于 tensorflow,现有的 CNN 网络可以识别四种花的种类.适合新手对使用 tensorflow进行一个完整的图像识 ...

  3. 猫狗大战——基于TensorFlow的猫狗识别(2)

    微信公众号:龙跃十二 我是小玉,一个平平无奇的小天才! 上篇文章我们说了关于猫狗大战这个项目的一些准备工作,接下来,我们看看具体的代码详解. 猫狗大战--基于TensorFlow的猫狗识别(1) 文件 ...

  4. 基于tensorflow的minst手写体数字识别

    引言 TensorFlow 是一个采用数据流图,用于数值计算的开源软件库.它是一个不严格的"神经网络"库,可以利用它提供的模块搭建大多数类型的神经网络.它可以基于CPU或GPU运行 ...

  5. python神经网络库识别验证码_基于TensorFlow 使用卷积神经网络识别字符型图片验证码...

    本项目使用卷积神经网络识别字符型图片验证码,其基于TensorFlow 框架.它封装了非常通用的校验.训练.验证.识别和调用 API,极大地减低了识别字符型验证码花费的时间和精力. 项目地址:http ...

  6. 基于Tensorflow实现声纹识别

    前言 本章介绍如何使用Tensorflow实现简单的声纹识别模型,首先你需要熟悉音频分类,没有了解的可以查看这篇文章<基于Tensorflow实现声音分类>.基于这个知识基础之上,我们训练 ...

  7. 基于TensorFlow的简单验证码识别

    TensorFlow 可以用来实现验证码识别的过程,这里识别的验证码是图形验证码,首先用标注好的数据来训练一个模型,然后再用模型来实现这个验证码的识别. 生成验证码 首先生成验证码,这里使用 Pyth ...

  8. 基于keras的mnist手写体识别程序

    大家好 我是来自河北大学 心电组的一名研一的学生,本篇文章是我对mnist识别学习的认识和分享. 本文主要用来给想要用keras搭建网络识别mnist的同学一个引导. 有错误的地方请大家指正 我会虚心 ...

  9. 基于TensorFlow Lite的人声识别在端上的实现

    通过TensorFlow Lite,移动终端.IoT设备可以在端上实现声音识别,这可以应用在安防.医疗监护等领域.来自阿里巴巴闲鱼技术互动组仝辉和上叶通过TensorFlow Lite实现了一套完整的 ...

  10. 猫狗大战——基于TensorFlow的猫狗识别(1)

    微信公众号:龙跃十二 我是小玉,一个平平无奇的小天才! 简介: 关于猫狗识别是机器学习和深度学习的一个经典实例,下来小玉把自己做的基于CNN卷积神经网络利用Tensorflow框架进行猫狗的识别的程序 ...

最新文章

  1. [ZZ]知名互联网公司Python的16道经典面试题及答案
  2. Google App Engine给我们带来了什么?
  3. img summernote 加类_控制好情绪 的动态 - SegmentFault 思否
  4. [云炬ThinkPython阅读笔记]2.3 表达式和语句
  5. 移动开发架构之MVVM模式
  6. 【HDU - 1281 】棋盘游戏 (经典的二分图匹配,匈牙利算法,枚举删除顶点,必须边,关建边)
  7. Centos7下更改docker镜像和容器的默认路径
  8. Code Access Security (CAS)
  9. mysql 时间序列可视化工具_mysql – 从from到to条目创建时间序列
  10. esxi存储(外部共享存储)- Open FIle
  11. LXC源码编译测试(五)
  12. Linux下安装gcc和g++
  13. ecshop验证码无法显示
  14. vue如何加载html字符串_VUE渲染后端返回含有script标签的html字符串示例
  15. Ubuntu下逻辑坏道解决方案
  16. 道德经和译文_道德经全文和译文
  17. O2O口号容易运营难
  18. CES 2019上芯片巨头们的争夺焦点:光线追踪、“永远”在线PC、汽车
  19. 选择器和字体的设置7.22
  20. dz论坛附件在服务器中的位置,Discuz! 远程附件设置图文说明

热门文章

  1. 安装tomcat时出错:failed to install tomcat6 service问题的解决方法
  2. 希尔伯特变换产生负频率解决方法
  3. 高颜值:Redis官方可视化工具,功能强大!
  4. win10系统64位安装与配置java环境,安装使用citespace经验
  5. Spring代码实例系列-绪论
  6. Apache Tomcat7.0 Tomcat7启动不了的解决问题
  7. linux系统下在ubuntu20.04安装matlab2017总结
  8. qa 芯片测试_关于半导体设备测试,看这一篇就够了
  9. Unable to start LiveReload server
  10. 会议会展活动管理软件可实现哪些功能