MNIST是深度学习的入门demo,由6万张训练图片和1万张测试图片构成(数据集下载地址:https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),每张图片都是28*28大小,而且都是黑白两色,这些图片是采集的不同的人手写从0到9的数字。TensorFlow将这个数据集和相关操作封装到了库中,下面为训练及效果评估代码。

import sys
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import matplotlib.pyplot as plt
from pylab import mpl
#设置plt显示中文字体,避免乱码
mpl.rcParams['font.sans-serif']=['Microsoft YaHei']
mpl.rcParams['axes.unicode_minus'] = False
#读取训练集
mnist = input_data.read_data_sets("d:/share/MNIST_data/", one_hot=True)
# x是特征值,1X784的一维向量
x = tf.placeholder(tf.float32, [None, 784])
# w表示每一个特征值(像素点)会影响结果的权重
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
#y是预测值,包含10个元素的数组
y = tf.matmul(x, W) + b
# y_是图片实际对应的值,包含10个元素的0/1数组,1代表对应的index数字
y_ = tf.placeholder(tf.float32, [None, 10] )
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
# mnist.train 训练数据,每次提取100张,循环6000次
for _ in range(6000):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})# 取得y的最大概率对应的数组索引来和y_的数组索引对比,如果索引相同,则表示预测正确
correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

用测试集来评估准确性

 print(sess.run(accuracy, feed_dict={x: mnist.test.images,y_: mnist.test.labels}))
逐个测试图片查看其预测结果,将预测不准确的结果统计输出
pre_act=[]
for i in range(0, len(mnist.test.images)):result = sess.run(correct_prediction,feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])})if not result:pre=sess.run(y, feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])})pre_list=max(pre.tolist())m=max(pre_list)pre_value=pre_list.index(m)# print('预测的值是:', pre_value)actual=sess.run(y_, feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])})actual_list=max(actual.tolist())actual_value=actual_list.index(1)#将预测值和实际值组合添加到数组保存pa='预测的值是:'+str(pre_value)+","+'实际的值是:'+str(actual_value)pre_act.append(pa)# print('实际的值是:', actual_value)display='预测的值是:'+str(pre_value)+'实际的值是:'+str(actual_value)#显示预测错误图片# one_pic_arr = np.reshape(mnist.test.images [i], (28, 28))# pic_matrix = np.matrix(one_pic_arr, dtype="float")# plt.imshow(pic_matrix)# plt.title(display)# plt.savefig('pic_matrix')# plt.show()# break
#打印数组查看哪些测试图片预测错误及其真实值
print(pre_act)
print("预测错误数量为:"+str(len(pre_act))+"测试数据集为:"+str(len(mnist.test.images)))

可以看到准确率只有0.9257,使用CNN准确率可以达到0.97以上。

Tensorflow-MNIST代码解析相关推荐

  1. Tensorflow YOLO代码解析(1)

    YOLO (You Only Look Once:Unified,Real-Time Object Detection) 提出了一种实时端到端的目标检测算法,之前写过一份关于YOLO论文的解读,可供参 ...

  2. Graph Attention Network (GAT) 的Tensorflow版代码解析

    文章目录 代码结构 参数设置 数据加载 特征预处理 模型定义 GAT核心定义:layers.py gat.py base_gattn.py 关于GAT的基本原理解析可查看另一篇博客: Graph At ...

  3. 代码解析深度学习系统编程模型:TensorFlow vs. CNTK

    from: http://geek.csdn.net/news/detail/62429 CNTK是微软用于搭建深度神经网络的计算网络工具包,此项目已在Github上开源.因为我最近写了关于Tenso ...

  4. 【神经网络】(12) MobileNetV2 代码复现,网络解析,附Tensorflow完整代码

    各位同学好,今天和大家分享一下如何使用 Tensorflow 复现谷歌轻量化神经网络 MobileNetV2. 在上一篇中我介绍了MobileNetV1,探讨了深度可分离卷积,感兴趣的可以看一下:ht ...

  5. Tensorflow 代码解析

    Tensorflow 代码解析(一) Tensorflow 代码解析(二) Tensorflow 代码解析(三) Tensorflow 代码解析(四) Tensorflow 代码解析(五)

  6. First Steps with TensorFlow代码解析

    注:本文的内容基本上都摘自tensorflow的官网,只不过官网中的这部分内容在国内访问不了,所以我只是当做一个知识的搬运工,同时梳理了一遍,方便大家查看.本文相关内容地址如下: https://de ...

  7. TensorFlow 概念的解析(i.e. 缘由)

    TensorFlow 基本概念解析 -- TensorFlow Machine Learning Cookbook TensorFlow 中的基本概念解析 声明张量 tensorflow 中的主要数据 ...

  8. 【Python3】文本分类综合(rnn,cnn,word2vec,TfidfVectorizer),中文纠错代码解析(pycorrector)

    文章目录 1.中文评论情感分析(keras+rnn) 1.1 需要的库 1.2 预训练词向量 1.3 词向量模型 1.4 训练语料 (数据集) 1.5 分词和tokenize 1.6 索引长度标准化 ...

  9. EfficientNet-V2 论文以及代码解析

    参看视频:论文解析 论文地址:论文地址 源码地址:tensorflow官方源码 pytorch代码地址:pytorch代码 这里的代码解析参考的是博主噼里啪啦的源码,下面对EfficientNetV2 ...

  10. DeepLearning | 图注意力网络Graph Attention Network(GAT)论文、模型、代码解析

    本篇博客是对论文 Velikovi, Petar, Cucurull, Guillem, Casanova, Arantxa,et al. Graph Attention Networks, 2018 ...

最新文章

  1. 转:Flutter Decoration背景设定(边框、圆角、阴影、形状、渐变、背景图像等)...
  2. 统计简单学_回归分析
  3. 9.逆向-函数调用约定
  4. 移动app部分机型无法唤起h5支付宝支付_用这段代码对App说:喂,醒醒!App,到你出场了!...
  5. SHD0新建屏幕变式
  6. Python找出某元素的索引下标
  7. mysql查询条件是小数 查不到6.28_28.mysql数据库之查询
  8. java mysql 常见框架_Java岗面试重点:Java+JVM+MySQL+框架+算法,金九银十涨薪全靠它...
  9. pandas按照字典格式替换dataframe的值
  10. bzoj 1651: [Usaco2006 Feb]Stall Reservations 专用牛棚【贪心+堆||差分】
  11. [数据分析工具] Pandas 不可不知的功能(一)
  12. Android自定义processor实现bindView功能
  13. 卸载java_Java面试必备——类的加载过程
  14. Android Binder机制介绍
  15. 几何图形识别 python_pygame能识别简单的几何图形吗?
  16. 服务器ip每天自动更换,IP地址经常更换,自动获取的IP上不了网怎么办?
  17. Android Beacon开发
  18. WPS新建文字分享微信.docx形式_高效神器:花 5 分钟输入文字,就能自动变成 PPT...
  19. 游戏建模控件Aspose.3D for Java最新版支持在Wavefront OBJ中添加点云支持
  20. 同花顺_代码解析_技术指标_T、U

热门文章

  1. 博学谷学习记录】超强总结,用心分享 | 架构师 JVM内核调优学习总结
  2. (11/∞)每日一练{1.将一张100元钞票换成等值的10元,5元,2元和1元的小钞,每次换成40张小钞,要求每一种小钞都要有,编程求出所有可能的换法总数输出并输出各换法的组合。}
  3. 2009 广联达软件笔试题目
  4. Introduction to 3D Game Programming with DirectX 12 学习笔记之 --- 第十三章:计算着色器(The Compute Shader)...
  5. flash开发中记录集锦
  6. 最省最小最牛软件推荐(每日更新)
  7. Xilinx Vivado和SDK安装
  8. LINUX下 Udev详解
  9. 银河麒麟下安装sshd服务(联网)
  10. java线程池面试题_java之线程池面试题