1、导入库

# 说明:由于windows下运行与tensorflow相关的程序会出现“.......supports AVX2.....”的 Warnning信息十分碍眼,于是在我的查阅中,可以通过导入os库对os下的方法environ进行如下配置可以消除 Warnning

1 import tensorflow as tf
2 import os
3 os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
4 import numpy as np

2、导入训练数据集

# train训练数据集、test测试数据集
# mnist数据单元:由训练数据集的图片和标签两者组成(★)
# one_hot:又名one_hot vectors, 该参数用来将数据集中的标量转换为向量;
#       某一位是1,其余各维度数字皆为0,所以数字n将表示一个只有在第n维度数字为1的10维向量某一位是1,其余各维度数字皆为0,所以数字n将表示#       一个只有在第n维度数字为1的10维向量
#       如标签0将表示:[1,0,0,0,0,0,0,0,0,0]

1 import input_data
2 mnist = input_data.read_data_sets("MNIST_data",one_hot=True)

3、构建模型
#输入图像
## placeholder 占位符,非特定的值。借以输入任意数量minist图像
## 将一张图平展成784维的向量,并用二维浮点数张量表示图
## None 表示第一个维度可以是任何长度的
## x 二维张量,拥有多个输入

1 x = tf.placeholder("float",[None,784])

#设置张量
## Variable 可修改的张量,存在于交互性操作的图中。用于计算输入值。
## w(似权重) 用来与784维(28x28)图片向量相乘得到10维的证据值向量,每一位对应不同数字类
## b(似偏移量) 直接加在输出上

1 w = tf.Variable(tf.zeros([784,10]))
2 b = tf.Variable(tf.zeros([10]))

#实现模型

1 y = tf.nn.softmax(tf.matmul(x,w)+b)

4、训练模型

#计算交叉熵
## 成本:评价模型是坏的(cost/loss)
## cross_entropy:交叉熵 (-Σy'log(y))衡量预测描述真相的低效性
## y_ 占位符,用来计算交叉熵,输入正确值
## 结论:我们建立的模型用来训练得出真实值y_.

1 y_ = tf.placeholder("float",[None,10])
2 cross_entropy = tf.reduce_sum(y_*tf.log(y))

5、降低成本
# 图:描述各个计算单元,自动使用反向传播算法有效地确定变量如何影响想要最小化的那个成本值的
# 反向传播算法:
# 选择优化算法不断改变变量以降低成本
# 梯度下降算法:简单的学习过程
# 最小化交叉熵:算法以0.01的学习速率最小化交叉熵

1 train_step=tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

6、初始化变量

1 init = tf.global_variables_initializer()

7、启动会话

1 sess = tf.Session()
2 sess.run(init)

8、开始训练模型
# batch:分批处理,这里指数据集中的批处理数据点
# 随机训练:使用一小部分的随机数据进行训练,这里指随机梯度下降训练
# next_batch:使每一次抓取的批处理数据点都是不同的,减小开销

1 for i in range(1000):
2     batch_xs, batch_ys = mnist.train.next_batch(100)
3     sess.run(train_step, feed_dict={x:batch_xs, y_:batch_ys})

9、评估模型
#预测正确标签
## tf.argmax: 给出tensor对象在某一维上的其数据最大的索引值
## 说明:标签向量由0和1组成,1便为最大索引值,索引位置就是类别标签
## tf.argmax(y,1) 预测到的标签值;tf.argmax(y_,1) 真实标签匹配

1 correct_predication = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))

#数值转换
## 将布尔值转换成浮点数,再取平均值

1 accuracy = tf.reduce_mean(tf.cast(correct_predication,"float"))

#计算
## 此处数据被喂的对象使correct_predication,最终输出的是accuracy

1 print(sess.run(accuracy,feed_dict={x:mnist.test.images, y_:mnist.test.labels}))

转载于:https://www.cnblogs.com/Quasimodo2018-0815/p/shawshankaha_tensorflow_mnist.html

tensorflow00:windows下训练并测试MNIST数字识别详细笔记相关推荐

  1. 在MNIST数据集上训练一个手写数字识别模型

    使用Pytorch在MNIST数据集上训练一个手写数字识别模型, 代码和参数文件 可下载 1.1 数据下载 import torchvision as tvtraining_sets = tv.dat ...

  2. TensorFlow Object Detection API(Windows下训练)

    本文为作者原创,转载请注明出处(http://www.cnblogs.com/mar-q/)by 负赑屃 最近事情比较多,前面坑挖的有点久,今天终于有时间总结一下,顺便把Windows下训练跑通.Li ...

  3. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别、模型评估(99.4%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别.模型评估(99.4%) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 netwo ...

  4. DL之CNN:利用自定义DeepConvNet【7+1】算法对mnist数据集训练实现手写数字识别并预测(超过99%)

    DL之CNN:利用自定义DeepConvNet[7+1]算法对mnist数据集训练实现手写数字识别并预测(超过99%) 目录 输出结果 设计思路 核心代码 输出结果 准确度都在99%以上 1.出错记录 ...

  5. 写给初学者的深度学习教程之 MNIST 数字识别

    一般而言,MNIST 数据集测试就是机器学习和深度学习当中的"Hello World"工程,几乎是所有的教程都会把它放在最开始的地方.这是因为,这个简单的工程包含了大致的机器学习流程,通过练习这个工程 ...

  6. TensorFlow解决MNIST数字识别问题

    TensorFlow解决MNIST数字识别问题 废话 这个MNIST数字识别问题是我实现的第一个神经网络,虽然过程基本上都是对着书上的代码敲,但还是对神经网络的训练过程有了一定的了解,同时也复习了前面 ...

  7. FCN网络的训练——以燃气表数字识别为例

    原文http://blog.csdn.net/hduxiejun/article/details/54234766 FCN网络的训练--以燃气表数字识别为例 目录 用 [TOC]来生成目录: FCN网 ...

  8. 深度学习算法优化系列十八 | TensorRT Mnist数字识别使用示例

    1. 前言 上一节对TensorRT做了介绍,然后科普了TensorRT优化方式以及讲解在Windows下如何安装TensorRT6.0,最后还介绍了如何编译一个官方给出的手写数字识别例子获得一个正确 ...

  9. Windows下使用Tesseract进行OCR文字识别

    Windows下使用Tesseract进行OCR文字识别 Tesseract最初由惠普实验室支持,用于电子版文字识别,1996年被移植到Windows上,1998年进行了C++化,在2005年Tess ...

最新文章

  1. Acwing900. 整数划分[计数类dp]:完全背包解法
  2. 看看阿里的考核尺度, 阿里人工资高是有原因的
  3. python模块之collections模块
  4. AddressBookUI.Framwork应用之ABPersonViewController, ABUnknownPersonViewController,ABNewPersonViewContro
  5. vue 发送ajax请求
  6. java怎么处理ajax请求,java怎么用ajax请求?jquery ajax请求后台的简单例子
  7. c mysql 编译_MySQL编译安装之cmake
  8. css的属性是变量是怎么表达,CSS自定义属性(变量)
  9. 单细胞测序分析之小技巧之for循环批量处理数据和出图
  10. java 监听者模式有啥用,监听者模式在系统中的应用 —— 事件总线
  11. 通过servlet来实现对Mysql进行连接、插入、修改、删除操作
  12. HDU 1754 I Hate It 基础线段树
  13. 国际象棋小麦python_python图形工具turtle绘制国际象棋棋盘
  14. win2008安装mysql8.0
  15. 宝塔利用同一个ip的不同端口号架设多个网站
  16. 服务器系统详细安装步骤
  17. mybatiplus的apply_mybatis-plus入门
  18. 小知识--电脑的快捷键
  19. 行列式的定义及简单计算
  20. 破解不加微信看朋友圈

热门文章

  1. paypal国际支付的对接,使用tp5开发paypal
  2. 雪碧图 以及 渐变色
  3. zedboard如何从PL端控制DDR读写(五)
  4. 区块链行业发展势如破竹 未来区块链金融值得瞩目
  5. 时间作为执行者的用例有前置条件吗
  6. Network and Distributed System Security (NDSS) Symposium 2017
  7. 初中学历的 00 后程序员,未来怎么办?
  8. 约瑟夫环问题,n个人围成一圈,依次按1、2.....m来报数,报数值为m的人出圈,求最后出圈的人和出圈的序列
  9. python中pyecharts绘制地图
  10. 对渗透新人的几点建议