Tensorflow-MNIST代码解析
MNIST是深度学习的入门demo,由6万张训练图片和1万张测试图片构成(数据集下载地址:https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/),每张图片都是28*28大小,而且都是黑白两色,这些图片是采集的不同的人手写从0到9的数字。TensorFlow将这个数据集和相关操作封装到了库中,下面为训练及效果评估代码。
import sys from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import numpy as np import matplotlib.pyplot as plt from pylab import mpl #设置plt显示中文字体,避免乱码 mpl.rcParams['font.sans-serif']=['Microsoft YaHei'] mpl.rcParams['axes.unicode_minus'] = False #读取训练集 mnist = input_data.read_data_sets("d:/share/MNIST_data/", one_hot=True) # x是特征值,1X784的一维向量 x = tf.placeholder(tf.float32, [None, 784]) # w表示每一个特征值(像素点)会影响结果的权重 W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros([10])) #y是预测值,包含10个元素的数组 y = tf.matmul(x, W) + b # y_是图片实际对应的值,包含10个元素的0/1数组,1代表对应的index数字 y_ = tf.placeholder(tf.float32, [None, 10] ) cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y)) train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) sess = tf.InteractiveSession() tf.global_variables_initializer().run() # mnist.train 训练数据,每次提取100张,循环6000次 for _ in range(6000):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})# 取得y的最大概率对应的数组索引来和y_的数组索引对比,如果索引相同,则表示预测正确 correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
用测试集来评估准确性
print(sess.run(accuracy, feed_dict={x: mnist.test.images,y_: mnist.test.labels}))
逐个测试图片查看其预测结果,将预测不准确的结果统计输出 pre_act=[] for i in range(0, len(mnist.test.images)):result = sess.run(correct_prediction,feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])})if not result:pre=sess.run(y, feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])})pre_list=max(pre.tolist())m=max(pre_list)pre_value=pre_list.index(m)# print('预测的值是:', pre_value)actual=sess.run(y_, feed_dict={x: np.array([mnist.test.images[i]]), y_: np.array([mnist.test.labels[i]])})actual_list=max(actual.tolist())actual_value=actual_list.index(1)#将预测值和实际值组合添加到数组保存pa='预测的值是:'+str(pre_value)+","+'实际的值是:'+str(actual_value)pre_act.append(pa)# print('实际的值是:', actual_value)display='预测的值是:'+str(pre_value)+'实际的值是:'+str(actual_value)#显示预测错误图片# one_pic_arr = np.reshape(mnist.test.images [i], (28, 28))# pic_matrix = np.matrix(one_pic_arr, dtype="float")# plt.imshow(pic_matrix)# plt.title(display)# plt.savefig('pic_matrix')# plt.show()# break #打印数组查看哪些测试图片预测错误及其真实值 print(pre_act) print("预测错误数量为:"+str(len(pre_act))+"测试数据集为:"+str(len(mnist.test.images)))
可以看到准确率只有0.9257,使用CNN准确率可以达到0.97以上。
Tensorflow-MNIST代码解析相关推荐
- Tensorflow YOLO代码解析(1)
YOLO (You Only Look Once:Unified,Real-Time Object Detection) 提出了一种实时端到端的目标检测算法,之前写过一份关于YOLO论文的解读,可供参 ...
- Graph Attention Network (GAT) 的Tensorflow版代码解析
文章目录 代码结构 参数设置 数据加载 特征预处理 模型定义 GAT核心定义:layers.py gat.py base_gattn.py 关于GAT的基本原理解析可查看另一篇博客: Graph At ...
- 代码解析深度学习系统编程模型:TensorFlow vs. CNTK
from: http://geek.csdn.net/news/detail/62429 CNTK是微软用于搭建深度神经网络的计算网络工具包,此项目已在Github上开源.因为我最近写了关于Tenso ...
- 【神经网络】(12) MobileNetV2 代码复现,网络解析,附Tensorflow完整代码
各位同学好,今天和大家分享一下如何使用 Tensorflow 复现谷歌轻量化神经网络 MobileNetV2. 在上一篇中我介绍了MobileNetV1,探讨了深度可分离卷积,感兴趣的可以看一下:ht ...
- Tensorflow 代码解析
Tensorflow 代码解析(一) Tensorflow 代码解析(二) Tensorflow 代码解析(三) Tensorflow 代码解析(四) Tensorflow 代码解析(五)
- First Steps with TensorFlow代码解析
注:本文的内容基本上都摘自tensorflow的官网,只不过官网中的这部分内容在国内访问不了,所以我只是当做一个知识的搬运工,同时梳理了一遍,方便大家查看.本文相关内容地址如下: https://de ...
- TensorFlow 概念的解析(i.e. 缘由)
TensorFlow 基本概念解析 -- TensorFlow Machine Learning Cookbook TensorFlow 中的基本概念解析 声明张量 tensorflow 中的主要数据 ...
- 【Python3】文本分类综合(rnn,cnn,word2vec,TfidfVectorizer),中文纠错代码解析(pycorrector)
文章目录 1.中文评论情感分析(keras+rnn) 1.1 需要的库 1.2 预训练词向量 1.3 词向量模型 1.4 训练语料 (数据集) 1.5 分词和tokenize 1.6 索引长度标准化 ...
- EfficientNet-V2 论文以及代码解析
参看视频:论文解析 论文地址:论文地址 源码地址:tensorflow官方源码 pytorch代码地址:pytorch代码 这里的代码解析参考的是博主噼里啪啦的源码,下面对EfficientNetV2 ...
- DeepLearning | 图注意力网络Graph Attention Network(GAT)论文、模型、代码解析
本篇博客是对论文 Velikovi, Petar, Cucurull, Guillem, Casanova, Arantxa,et al. Graph Attention Networks, 2018 ...
最新文章
- 转:Flutter Decoration背景设定(边框、圆角、阴影、形状、渐变、背景图像等)...
- 统计简单学_回归分析
- 9.逆向-函数调用约定
- 移动app部分机型无法唤起h5支付宝支付_用这段代码对App说:喂,醒醒!App,到你出场了!...
- SHD0新建屏幕变式
- Python找出某元素的索引下标
- mysql查询条件是小数 查不到6.28_28.mysql数据库之查询
- java mysql 常见框架_Java岗面试重点:Java+JVM+MySQL+框架+算法,金九银十涨薪全靠它...
- pandas按照字典格式替换dataframe的值
- bzoj 1651: [Usaco2006 Feb]Stall Reservations 专用牛棚【贪心+堆||差分】
- [数据分析工具] Pandas 不可不知的功能(一)
- Android自定义processor实现bindView功能
- 卸载java_Java面试必备——类的加载过程
- Android Binder机制介绍
- 几何图形识别 python_pygame能识别简单的几何图形吗?
- 服务器ip每天自动更换,IP地址经常更换,自动获取的IP上不了网怎么办?
- Android Beacon开发
- WPS新建文字分享微信.docx形式_高效神器:花 5 分钟输入文字,就能自动变成 PPT...
- 游戏建模控件Aspose.3D for Java最新版支持在Wavefront OBJ中添加点云支持
- 同花顺_代码解析_技术指标_T、U
热门文章
- 博学谷学习记录】超强总结,用心分享 | 架构师 JVM内核调优学习总结
- (11/∞)每日一练{1.将一张100元钞票换成等值的10元,5元,2元和1元的小钞,每次换成40张小钞,要求每一种小钞都要有,编程求出所有可能的换法总数输出并输出各换法的组合。}
- 2009 广联达软件笔试题目
- Introduction to 3D Game Programming with DirectX 12 学习笔记之 --- 第十三章:计算着色器(The Compute Shader)...
- flash开发中记录集锦
- 最省最小最牛软件推荐(每日更新)
- Xilinx Vivado和SDK安装
- LINUX下 Udev详解
- 银河麒麟下安装sshd服务(联网)
- java线程池面试题_java之线程池面试题