MNIST手写体识别训练和测试模型下载地址:
MNIST手写体模型下载

MNIST手写体识别,标签编码为独热(one-hot)编码

One-Hot编码,又称为一位有效编码,主要是采用N位状态寄存器来对N个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候只有一位有效。
One-Hot编码是分类变量作为二进制向量的表示。这首先要求将分类值映射到整数值。然后,每个整数值被表示为二进制向量,除了整数的索引之外,它都是零值,它被标记为1。

导入相关包

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import matplotlib.pyplot as plt
import numpy as np

numpy安装:
pip install numpy
matplotlib安装:
pip install matplotlib

MNIST图像读取

 mnist = input_data.read_data_sets("data/MNIST/", one_hot=True)# mnist 中每张图片共有28*28=784个像素点

变量定义

 x = tf.placeholder(tf.float32, [None, 784], name='x')# 0-9 一共十个数字-》十个类别y = tf.placeholder(tf.float32, [None, 10], name='y')# 定义变量w = tf.Variable(tf.zeros([784.10]), name='w')b = tf.Variable(tf.zeros([10]), name='b')# 使用单个神经元,进行前向计算forward = tf.matmul(x, w) + b# 使用softmax对结果集进行分类pred = tf.nn.softmax(forward)# 训练次数train_epochs = 50# 单次训练样本数(批次大小)batch_size = 10# 一轮训练有多少批次total_batch = int(mnist.train.num_examples / batch_size)learning_rate = 0.01# 显示粒度display_step = 1

定义损失函数和优化器

 # 定义交叉熵损失函数loss_function = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))# 定义优化器,梯度下降optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)

定义准确率

    # 检查预测类别tf.argmax(pred,1) 与实际类别tf.argmax(y,1)的匹配情况correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))# 准确率,将布尔值转化为浮点数,并计算平均值accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

定义Tensorflow会话

 sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)

模型训练

    for epoch in range(train_epochs):for batch in range(total_batch):# 读取批次数据xs, ys = mnist.train.next_batch(batch_size)# 执行批次训练sess.run(optimizer, feed_dict={x: xs, y: ys})# total_batch个批次训练完成后,使用验证数据计算误差与准确率,验证集未分批loss, acc = sess.run([loss_function, accuracy],feed_dict={x: mnist.validation.images, y: mnist.validation.labels})# 打印训练过程中的详细信息if (epoch + 1) % display_step == 0:print("Train Epoch:", '%02d' % (epoch + 1), 'Loss=', '{:.9f}'.format(loss), 'Accuracy=','{:.4f}'.format(acc))

图像可视化函数

def plot_images_labels_prediction(images,  # 图像列表labels,  # 标签列表prediction,  # 预测值列表index,  # 从第index个开始显示num=10):  # 缺省一次显示10幅fig = plt.gcf()  # 获取当前图标,Get Current Figurefig.set_size_inches(10, 12)  # 1英寸等于2.54cmif num > 25:num = 25  # 最多显示25个子图for i in range(0, num):ax = plt.subplot(5, 5, i + 1)  # 获取当前要处理的子图# 显示第index个图像ax.imshow(np.reshape(images[index], (28, 28)), cmap='binary')# 构建该图上要显示的titletitle = "label=" + str(np.argmax(labels[index]))if len(prediction) > 0:title += ",predict=" + str(prediction[index])# 显示图上的title信息ax.set_title(title, fontsize=10)# 不限是坐标轴ax.set_xticks([])ax.set_yticks([])index += 1plt.show()

该过程代码基于Tensorflow 1.0完成,Tensorflow 1.0安装:

  1. 通过Anaconda完成安装:
    # 创建名称为tf-1.0的conda虚拟Python环境,并指定Python版本为3.5conda create -n tf-1.0 python=3.5# 激活tf-1.0环境conda activate tf-1.0# 查找tensorflow版本号conda search tensorflow# 安装指定版本的tensorflowconda install tensorflow=1.9
  1. 通过pip安装:
 # 安装指定版本的tensorflow,默认安装tensorflow - 2.0pip install tensorflow==1.9

Tensorflow实现MNIST手写识别相关推荐

  1. tensorflow saver_机器学习入门(6):Tensorflow项目Mnist手写数字识别-分析详解

    本文主要内容:Ubuntu下基于Tensorflow的Mnist手写数字识别的实现 训练数据和测试数据资料:http://yann.lecun.com/exdb/mnist/ 前面环境都搭建好了,直接 ...

  2. mnist手写数字识别python_Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】...

    本文实例讲述了Python tensorflow实现mnist手写数字识别.分享给大家供大家参考,具体如下: 非卷积实现 import tensorflow as tf from tensorflow ...

  3. Tensorflow之基于MNIST手写识别的入门介绍

    Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...

  4. 基于tensorflow的MNIST手写字识别

    一.卷积神经网络模型知识要点卷积卷积 1.卷积 2.池化 3.全连接 4.梯度下降法 5.softmax 本次就是用最简单的方法给大家讲解这些概念,因为具体的各种论文网上都有,连推导都有,所以本文主要 ...

  5. 最终章 | TensorFlow战Kaggle“手写识别达成99%准确率

    刘颖,某互联网创业公司COO,技术出身,做产品里最懂运营的. 这是一个TensorFlow的系列文章,本文是第三篇,在这个系列中,你讲了解到机器学习的一些基本概念.TensorFlow的使用,并能实际 ...

  6. 使用mllib完成mnist手写识别任务

    使用mllib完成mnist手写识别任务 小提示,通过restart命令重启已经退出了的容器 sudo docker restart <contain id> 完成识别任务准备工作 从以下 ...

  7. Mnist手写识别项目常见问题及解决方法

    环境搭建 问题1:在Visual Studio 2019中的"扩展"管理中搜索不到"AI工具" 解决方法:因为"AI工具"插件不支持Visu ...

  8. 深度学习笔记(MNIST手写识别)

    先看了点花书,后来觉得有点枯燥去看了b站up主六二大人的pytorch深度学习实践的课,对深度学习的理解更深刻一点,顺便做点笔记,记录一些我认为重要的东西,便于以后查阅. 一. 机器学习基础 学习的定 ...

  9. TensorFlow的MNIST手写数字分类问题

    一.简介MNIST TensorFlow编程学习的入门一般都是基于MNIST手写数字数据集和Cifar(包括cifar-10和cifar-100)数据集,因为它们都比较小,一般的设备即可进行训练和测试 ...

最新文章

  1. LeetCode:Remove Nth Node From End of List
  2. SAP freelancer夫妻并不难!你也可以!
  3. redis单机版安装
  4. 光耦驱动单向可控硅_华越国际一文带路:可控硅触发设计技巧
  5. chmod命令详解使用格式和方法
  6. Java基础-hashMap原理剖析
  7. JKD16正式发布,新特新一览
  8. JAVA锁之可重入锁和递归锁及示例代码
  9. 测试环境搭建流程_软件测试流程
  10. indesign使用教程,如何将图形添加到项目?
  11. formatter function (value,row,index){} 参数的含义
  12. html网页背景图像失真,CSS实现页面背景图片模糊内容不模糊的方法
  13. python英文词频统计软件_英语词频统计软件功能介绍
  14. access数据库剔除重复项_使用Access数据库的站长看过来——如何自动去掉数据库中的重复文章...
  15. 扒一扒能加速互联网的QUIC协议
  16. Mac系统升级后导致AS不能使用SVN
  17. Phoenix FD Maya 软件插件
  18. (四)JMockit 的API:@Injectable 与 @Mocked的不同--基础篇
  19. Linux桌面系统x11原理简介
  20. 完全模拟FIFA2014世界杯 原创求顶!

热门文章

  1. 如何实现文件共享,文件共享的设置方法-镭速
  2. CTFshow web入门——文件上传
  3. 大数据背景下的数据融合
  4. 发个ZKW线段树板子测试一下代码高亮
  5. Linux驱动编程(驱动程序基石)(上)
  6. 2019年一级消防工程师报考通知
  7. 三步学会制作一个小程序
  8. C语言中const void *a是什么意思
  9. JAVA UTC时间和本地时间
  10. cad图纸怎么转换成pdf格式?