Tensorflow实现MNIST手写识别
MNIST手写体识别训练和测试模型下载地址:
MNIST手写体模型下载
MNIST手写体识别,标签编码为独热(one-hot)编码
One-Hot编码,又称为一位有效编码,主要是采用N位状态寄存器来对N个状态进行编码,每个状态都由他独立的寄存器位,并且在任意时候只有一位有效。
One-Hot编码是分类变量作为二进制向量的表示。这首先要求将分类值映射到整数值。然后,每个整数值被表示为二进制向量,除了整数的索引之外,它都是零值,它被标记为1。
导入相关包
import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
import matplotlib.pyplot as plt
import numpy as np
numpy安装:
pip install numpy
matplotlib安装:
pip install matplotlib
MNIST图像读取
mnist = input_data.read_data_sets("data/MNIST/", one_hot=True)# mnist 中每张图片共有28*28=784个像素点
变量定义
x = tf.placeholder(tf.float32, [None, 784], name='x')# 0-9 一共十个数字-》十个类别y = tf.placeholder(tf.float32, [None, 10], name='y')# 定义变量w = tf.Variable(tf.zeros([784.10]), name='w')b = tf.Variable(tf.zeros([10]), name='b')# 使用单个神经元,进行前向计算forward = tf.matmul(x, w) + b# 使用softmax对结果集进行分类pred = tf.nn.softmax(forward)# 训练次数train_epochs = 50# 单次训练样本数(批次大小)batch_size = 10# 一轮训练有多少批次total_batch = int(mnist.train.num_examples / batch_size)learning_rate = 0.01# 显示粒度display_step = 1
定义损失函数和优化器
# 定义交叉熵损失函数loss_function = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))# 定义优化器,梯度下降optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function)
定义准确率
# 检查预测类别tf.argmax(pred,1) 与实际类别tf.argmax(y,1)的匹配情况correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))# 准确率,将布尔值转化为浮点数,并计算平均值accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
定义Tensorflow会话
sess = tf.Session()init = tf.global_variables_initializer()sess.run(init)
模型训练
for epoch in range(train_epochs):for batch in range(total_batch):# 读取批次数据xs, ys = mnist.train.next_batch(batch_size)# 执行批次训练sess.run(optimizer, feed_dict={x: xs, y: ys})# total_batch个批次训练完成后,使用验证数据计算误差与准确率,验证集未分批loss, acc = sess.run([loss_function, accuracy],feed_dict={x: mnist.validation.images, y: mnist.validation.labels})# 打印训练过程中的详细信息if (epoch + 1) % display_step == 0:print("Train Epoch:", '%02d' % (epoch + 1), 'Loss=', '{:.9f}'.format(loss), 'Accuracy=','{:.4f}'.format(acc))
图像可视化函数
def plot_images_labels_prediction(images, # 图像列表labels, # 标签列表prediction, # 预测值列表index, # 从第index个开始显示num=10): # 缺省一次显示10幅fig = plt.gcf() # 获取当前图标,Get Current Figurefig.set_size_inches(10, 12) # 1英寸等于2.54cmif num > 25:num = 25 # 最多显示25个子图for i in range(0, num):ax = plt.subplot(5, 5, i + 1) # 获取当前要处理的子图# 显示第index个图像ax.imshow(np.reshape(images[index], (28, 28)), cmap='binary')# 构建该图上要显示的titletitle = "label=" + str(np.argmax(labels[index]))if len(prediction) > 0:title += ",predict=" + str(prediction[index])# 显示图上的title信息ax.set_title(title, fontsize=10)# 不限是坐标轴ax.set_xticks([])ax.set_yticks([])index += 1plt.show()
该过程代码基于Tensorflow 1.0完成,Tensorflow 1.0安装:
- 通过Anaconda完成安装:
# 创建名称为tf-1.0的conda虚拟Python环境,并指定Python版本为3.5conda create -n tf-1.0 python=3.5# 激活tf-1.0环境conda activate tf-1.0# 查找tensorflow版本号conda search tensorflow# 安装指定版本的tensorflowconda install tensorflow=1.9
- 通过pip安装:
# 安装指定版本的tensorflow,默认安装tensorflow - 2.0pip install tensorflow==1.9
Tensorflow实现MNIST手写识别相关推荐
- tensorflow saver_机器学习入门(6):Tensorflow项目Mnist手写数字识别-分析详解
本文主要内容:Ubuntu下基于Tensorflow的Mnist手写数字识别的实现 训练数据和测试数据资料:http://yann.lecun.com/exdb/mnist/ 前面环境都搭建好了,直接 ...
- mnist手写数字识别python_Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】...
本文实例讲述了Python tensorflow实现mnist手写数字识别.分享给大家供大家参考,具体如下: 非卷积实现 import tensorflow as tf from tensorflow ...
- Tensorflow之基于MNIST手写识别的入门介绍
Tensorflow是当下AI热潮下,最为受欢迎的开源框架.无论是从Github上的fork数量还是star数量,还是从支持的语音,开发资料,社区活跃度等多方面,他当之为superstar. 在前面介 ...
- 基于tensorflow的MNIST手写字识别
一.卷积神经网络模型知识要点卷积卷积 1.卷积 2.池化 3.全连接 4.梯度下降法 5.softmax 本次就是用最简单的方法给大家讲解这些概念,因为具体的各种论文网上都有,连推导都有,所以本文主要 ...
- 最终章 | TensorFlow战Kaggle“手写识别达成99%准确率
刘颖,某互联网创业公司COO,技术出身,做产品里最懂运营的. 这是一个TensorFlow的系列文章,本文是第三篇,在这个系列中,你讲了解到机器学习的一些基本概念.TensorFlow的使用,并能实际 ...
- 使用mllib完成mnist手写识别任务
使用mllib完成mnist手写识别任务 小提示,通过restart命令重启已经退出了的容器 sudo docker restart <contain id> 完成识别任务准备工作 从以下 ...
- Mnist手写识别项目常见问题及解决方法
环境搭建 问题1:在Visual Studio 2019中的"扩展"管理中搜索不到"AI工具" 解决方法:因为"AI工具"插件不支持Visu ...
- 深度学习笔记(MNIST手写识别)
先看了点花书,后来觉得有点枯燥去看了b站up主六二大人的pytorch深度学习实践的课,对深度学习的理解更深刻一点,顺便做点笔记,记录一些我认为重要的东西,便于以后查阅. 一. 机器学习基础 学习的定 ...
- TensorFlow的MNIST手写数字分类问题
一.简介MNIST TensorFlow编程学习的入门一般都是基于MNIST手写数字数据集和Cifar(包括cifar-10和cifar-100)数据集,因为它们都比较小,一般的设备即可进行训练和测试 ...
最新文章
- LeetCode:Remove Nth Node From End of List
- SAP freelancer夫妻并不难!你也可以!
- redis单机版安装
- 光耦驱动单向可控硅_华越国际一文带路:可控硅触发设计技巧
- chmod命令详解使用格式和方法
- Java基础-hashMap原理剖析
- JKD16正式发布,新特新一览
- JAVA锁之可重入锁和递归锁及示例代码
- 测试环境搭建流程_软件测试流程
- indesign使用教程,如何将图形添加到项目?
- formatter function (value,row,index){} 参数的含义
- html网页背景图像失真,CSS实现页面背景图片模糊内容不模糊的方法
- python英文词频统计软件_英语词频统计软件功能介绍
- access数据库剔除重复项_使用Access数据库的站长看过来——如何自动去掉数据库中的重复文章...
- 扒一扒能加速互联网的QUIC协议
- Mac系统升级后导致AS不能使用SVN
- Phoenix FD Maya 软件插件
- (四)JMockit 的API:@Injectable 与 @Mocked的不同--基础篇
- Linux桌面系统x11原理简介
- 完全模拟FIFA2014世界杯 原创求顶!