MNIST

  MNIST是一个非常简单的机器视觉数据集。如图,它由几万字28像素×28像素的手写数字组成,这些图片只包含灰度值信息。我们的任务是对这些手写数字的图片进行分类,转成0~9一共10类。

  首先对MNIST数据进行加载,然后查看mnist这个数据集的情况。

# 输入程序
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)print(mnist.train.images.shape,mnist.train.labels.shape)
print(mnist.test.images.shape,mnist.test.labels.shape)
print(mnist.validation.images.shape,mnist.validation.labels.shape)
# 运行结果
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(55000, 784) (55000, 10)
(10000, 784) (10000, 10)
(5000, 784) (5000, 10)

  one_hot 是指一个在大多数维度上为0的向量,在一维中是1。在这种情况下,第n个数字将被表示为一个在第n维中1的向量。如数字5对应[0,0,0,0,0,1,0,0,0,0]。

  可以看到,训练集有55000个样本,测试集有10000个样本,同时验证集有5000个样本。每个样本都有对应的标签label。


实现原理

  训练数据的特征是一个55000×784的Tensor,第一个维度是图片的编号,第二个维度是图片中像素点的编号,同时训练的数据Label是一个55000×10的Tensor,使用one-hot编码。

  Softmax Regression:Softmax Regression是Logistic回归的推广,处理多分类问题。



  损失函数(交叉熵)

  其中,y是预测的概率分布,y’是真实的概率分布,通常用来判断模型对真实概率分布估计的准确程度。


程序

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data#下载并加载数据
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)#数据与标签的占位
x = tf.placeholder(tf.float32,shape = [None,784])
y_actual = tf.placeholder(tf.float32,shape=[None,10])#初始化权重和偏置
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))#softmax回归,得到预测概率
y_predict = tf.nn.softmax(tf.matmul(x,W) + b)#求交叉熵得到残差
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_actual*tf.log(y_predict),reduction_indices=1))#梯度下降法使得残差最小,学习速率为0.01
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)#测试阶段,测试准确度计算
correct_prediction = tf.equal(tf.argmax(y_predict,1),tf.argmax(y_actual,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,'float'))#多个批次的准确度均值init = tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init)#训练,迭代1000次for i in range(1000):batch_xs,batch_ys = mnist.train.next_batch(100)#按批次训练,每批100行数据sess.run(train_step,feed_dict={x:batch_xs,y_actual:batch_ys})#执行训练if(i%100==0):#每训练100次,测试一次print("accuracy:",sess.run(accuracy,feed_dict={x: mnist.test.images, y_actual: mnist.test.labels}))

  运行结果

  


【TensorFlow】MNIST手写数字识别相关推荐

  1. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

  2. 将tensorflow训练好的模型移植到Android (MNIST手写数字识别)

    将tensorflow训练好的模型移植到Android (MNIST手写数字识别) [尊重原创,转载请注明出处]https://blog.csdn.net/guyuealian/article/det ...

  3. TensorFlow 2.0 mnist手写数字识别(CNN卷积神经网络)

    TensorFlow 2.0 (五) - mnist手写数字识别(CNN卷积神经网络) 源代码/数据集已上传到 Github - tensorflow-tutorial-samples 大白话讲解卷积 ...

  4. tensorflow saver_机器学习入门(6):Tensorflow项目Mnist手写数字识别-分析详解

    本文主要内容:Ubuntu下基于Tensorflow的Mnist手写数字识别的实现 训练数据和测试数据资料:http://yann.lecun.com/exdb/mnist/ 前面环境都搭建好了,直接 ...

  5. mnist手写数字识别python_Python tensorflow实现mnist手写数字识别示例【非卷积与卷积实现】...

    本文实例讲述了Python tensorflow实现mnist手写数字识别.分享给大家供大家参考,具体如下: 非卷积实现 import tensorflow as tf from tensorflow ...

  6. python cnn代码详解图解_基于TensorFlow的CNN实现Mnist手写数字识别

    本文实例为大家分享了基于TensorFlow的CNN实现Mnist手写数字识别的具体代码,供大家参考,具体内容如下 一.CNN模型结构 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5* ...

  7. 基于TensorFlow深度学习框架,运用python搭建LeNet-5卷积神经网络模型和mnist手写数字识别数据集,设计一个手写数字识别软件。

    本软件是基于TensorFlow深度学习框架,运用LeNet-5卷积神经网络模型和mnist手写数字识别数据集所设计的手写数字识别软件. 具体实现如下: 1.读入数据:运用TensorFlow深度学习 ...

  8. 《深度学习之TensorFlow》reading notes(3)—— MNIST手写数字识别之二

    文章目录 模型保存 模型读取 测试模型 搭建测试模型 使用模型 模型可视化 本文是在上一篇文章 <深度学习之TensorFlow>reading notes(2)-- MNIST手写数字识 ...

  9. AI常用框架和工具丨11. 基于TensorFlow(Keras)+Flask部署MNIST手写数字识别至本地web

    代码实例,基于TensorFlow+Flask部署MNIST手写数字识别至本地web,希望对您有所帮助. 文章目录 环境说明 文件结构 模型训练 本地web创建 实现效果 环境说明 操作系统:Wind ...

  10. MOOC网深度学习应用开发1——Tensorflow基础、多元线性回归:波士顿房价预测问题Tensorflow实战、MNIST手写数字识别:分类应用入门、泰坦尼克生存预测

    Tensorflow基础 tensor基础 当数据类型不同时,程序做相加等运算会报错,可以通过隐式转换的方式避免此类报错. 单变量线性回归 监督式机器学习的基本术语 线性回归的Tensorflow实战 ...

最新文章

  1. 虚拟机用户配置root权限
  2. centos 安装 py pyhs2
  3. 【架构】典型的 K8s 架构图-核心概念(简化)
  4. hibernate mysql 映射_Hibernate怎么不用配置mapping就使用数据库表映射实体
  5. matlab矩阵处理实验报告,matlab实验报告一二三
  6. C语言/ 运算符的优先级以及结合方向
  7. Socket通信的安全策略问题
  8. 安装mp4,mp3等媒体解码器
  9. 各类任务的数据集大数据库
  10. jquery Ajax回调函数
  11. 第六课:计算两数的GCF(最大公因数)(基于AndroidStudio3.2)
  12. 检察机关认定河北涞源反杀案为正当防卫 决定不起诉女生父母
  13. 数据库脏读、不可重复读、幻读以及对应的隔离级别
  14. bsoj 1512 金明的预算方案(树型DP)
  15. SSH连接时候出现 REMOTE HOST IDENTIFICATION HAS CHANGED
  16. Python解析百度地图各省市经纬度(二)
  17. 金字塔原理(6)- 确定逻辑顺序
  18. KX2 101-v2
  19. 如何优雅地删除Docker镜像和容器(超详细)
  20. 【空间数据库】传统数据模型(层次、网状、关系)和空间数据模型详解

热门文章

  1. 学计算机的副部级,中国31所副部级大学排名
  2. AcWing 1058. 股票买卖 V
  3. 自动驾驶—全局定位的学习笔记
  4. 使用maven构建多模块项目
  5. Python函数的静态变量
  6. 【Qt教程】1.9 - Qt5菜单栏、工具栏、状态栏、核心窗口、浮动窗口、QMainWindow
  7. MySQL学习记录 (三) ----- SQL数据定义语句(DDL)
  8. flex 注册监听器时传值
  9. 【Linux】Ubuntu 代理配置
  10. Dao层抽取BaseDao公共方法