感想
首先我是首先看了一下莫凡pyhton教程中tensorflow python搭建自己的神经网络教程以及查看了官方的教程TensorFlow中文社区—MNIST进阶教程,这里面只是有简单的测试出来模型的准确率,如何使用这个模型去测试一张图片,并没有教程。做了这个呢希望可以帮助各位。
介绍一下Mnist手写字体
Mnist手写字体的官网是Yann LeCun’s website,下载后的实验数据集。具体的介绍可以参考TensorFlow中文社区和Mnist手写字体官网。我只想提一句就是数据是用行保存的。
训练cnn模型
具体的代码详解直接参考TensorFlow中文社区—MNIST进阶教程,代码多读几遍有好处。

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)import tensorflow as tfsess = tf.InteractiveSession()def weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1)return tf.Variable(initial)def bias_variable(shape):initial = tf.constant(0.1, shape=shape)return tf.Variable(initial)def conv2d(x, W):return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')def max_pool_2x2(x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
x_image = tf.reshape(x, [-1,28,28,1])#第一层卷积层
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)#第二层卷积层
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)#第一层全连接层
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)#输出层
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)#训练
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))sess.run(tf.initialize_all_variables())saver = tf.train.Saver()for i in range(2000):batch = mnist.train.next_batch(50)if i%100 == 0:train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_: batch[1], keep_prob: 1.0})print("step %d, training accuracy %g"%(i, train_accuracy))train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})save_path = saver.save(sess,'my_net/model.ckpt') print("test accuracy %g"%accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

使用Python将MNIST数据集转化为图片
首先将MNIST_data下的train-images-idx3-ubyte.gz和train-labels-idx1-ubyte.gz这两个东西解压。然后用python代码将其转换为图片。

import numpy as np
import struct  from PIL import Image
import os  data_file = 'somePath/train-images.idx3-ubyte' #需要修改的路径
# It's 47040016B, but we should set to 47040000B
data_file_size = 47040016
data_file_size = str(data_file_size - 16) + 'B'  data_buf = open(data_file, 'rb').read()  magic, numImages, numRows, numColumns = struct.unpack_from(  '>IIII', data_buf, 0)
datas = struct.unpack_from(  '>' + data_file_size, data_buf, struct.calcsize('>IIII'))
datas = np.array(datas).astype(np.uint8).reshape(  numImages, 1, numRows, numColumns)  label_file = 'somePath/train-labels.idx1-ubyte' #需要修改的路径  # It's 60008B, but we should set to 60000B
label_file_size = 60008
label_file_size = str(label_file_size - 8) + 'B'  label_buf = open(label_file, 'rb').read()  magic, numLabels = struct.unpack_from('>II', label_buf, 0)
labels = struct.unpack_from(  '>' + label_file_size, label_buf, struct.calcsize('>II'))
labels = np.array(labels).astype(np.int64)  datas_root = '/somePath/mnist_train' #需要修改的路径
if not os.path.exists(datas_root):  os.mkdir(datas_root)  for i in range(10):  file_name = datas_root + os.sep + str(i)  if not os.path.exists(file_name):  os.mkdir(file_name)  for ii in range(numLabels):  img = Image.fromarray(datas[ii, 0, 0:28, 0:28])  label = labels[ii]  file_name = datas_root + os.sep + str(label) + os.sep + \  'mnist_train_' + str(ii) + '.png'  img.save(file_name)  

模型测试
随便选择一张图片进行测试。


from PIL import Image, ImageFilter
import tensorflow as tf
import matplotlib.pyplot as pltdef imageprepare():"""This function returns the pixel values.The imput is a png file location."""file_name='mnist_train_7.png'#导入自己的图片地址#in terminal 'mogrify -format png *.jpg' convert jpg to pngim = Image.open(file_name).convert('L')plt.imshow(im)plt.show()tv = list(im.getdata()) #get pixel values#normalize pixels to 0 and 1. 0 is pure white, 1 is pure black.tva = [ (255-x)*1.0/255.0 for x in tv] print(tva)return tva"""This function returns the predicted integer.The imput is the pixel values from the imageprepare() function."""# Define the model (same as when creating the model file)
result = imageprepare()def weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1)return tf.Variable(initial)def bias_variable(shape):initial = tf.constant(0.1, shape=shape)return tf.Variable(initial)def conv2d(x, W):return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')def max_pool_2x2(x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],strides=[1, 2, 2, 1], padding='SAME')x = tf.placeholder("float", shape=[None, 784])
x_image = tf.reshape(x, [-1,28,28,1])#第一层卷积层
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)#第二层卷积层
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)#第一层全连接层
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)#第二层输出层
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)init = tf.initialize_all_variables()saver = tf.train.Saver()with tf.Session() as sess:sess.run(init)saver.restore(sess,'my_net/model.ckpt')prediction=tf.argmax(y_conv,1)predint=prediction.eval(feed_dict={x: [result],keep_prob: 1.0})print('recognize result:')print(predint[0])    

虽然基本上也已经完成了测试,但是发现有时测试不准确。主要原因是imageprepare()这个function并没有将图像的像素的值完全变为0或者1。在此记录一下。

参考
1.(http://blog.csdn.net/sparta_117/article/details/66965760)
2.(http://www.tensorfly.cn/tfdoc/tutorials/mnist_pros.html).

用tensorflow框架和Mnist手写字体,训练cnn模型以及测试一张手写字体相关推荐

  1. Tensorflow框架:卷积神经网络实战--Cifar训练集

    Cifar-10数据集包含10类共60000张32*32的彩色图片,每类6000张图.包括50000张训练图片和 10000张测试图片 代码分为数据处理部分和卷积网络训练部分: 数据处理部分: #该文 ...

  2. 谷歌推出量子机器学习框架TFQ-TensorFlow Quantum,一个可训练量子模型的机器学习框架...

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 编辑:Sophia 计算机视觉联盟  报道  | 公众号 CVLianMeng 转载于 :专知,谷歌 AI博士笔记系 ...

  3. 对于CNN的文献阅读和识别手写数字的复现

    摘要 一.文献阅读 1.题目 2.摘要 3.引言 4.CNN模型结构 5.实验过程 6.同GS算法的对比 二.CNN识别手写数字 1.两个性质 2.图像卷积 总结 摘要 在论文方面阅读了基于CNN网络 ...

  4. 基于TensorFlow的CNN模型——猫狗分类识别器(五)之训练和评估CNN模型

    注意:这是一个完整的项目,建议您按照完整的博客顺序阅读. 目录 三.训练和优化CNN模型 1.搭建训练主循环 2.训练时间的记录 3.早期终止机制 4.训练数据的可视化 5.训练数据的保存与加载 四. ...

  5. 【记录】本科毕设:基于树莓派的智能小车设计(使用Tensorflow + Keras 搭建CNN卷积神经网络 使用端到端的学习方法训练CNN)

    0 申明 这是本人2020年的本科毕业设计,内容多为毕设论文和答辩内容中挑选.最初的灵感来自于早前看过的一些项目(抱歉时间久远,只记录了这一个,见下),才让我萌生了做个机电(小车动力与驱动)和控制(树 ...

  6. python识别手写数字字体_基于tensorflow框架对手写字体MNIST数据集的识别

    本文我们利用python语言,通过tensorflow框架对手写字体MNIST数据库进行识别. 学习每一门语言都有一个"Hello World"程序,而对数字手写体数据库MNIST ...

  7. 【TensorFlow篇】--Tensorflow框架实现SoftMax模型识别手写数字集

    一.前述 本文讲述用Tensorflow框架实现SoftMax模型识别手写数字集,来实现多分类. 同时对模型的保存和恢复做下示例. 二.具体原理 代码一:实现代码 #!/usr/bin/python ...

  8. 基于tensorflow+RNN的MNIST数据集手写数字分类

    2018年9月25日笔记 tensorflow是谷歌google的深度学习框架,tensor中文叫做张量,flow叫做流. RNN是recurrent neural network的简称,中文叫做循环 ...

  9. Python(TensorFlow框架)实现手写数字识别系统

    手写数字识别算法的设计与实现 本文使用python基于TensorFlow设计手写数字识别算法,并编程实现GUI界面,构建手写数字识别系统.这是本人的本科毕业论文课题,当然,这个也是机器学习的基本问题 ...

最新文章

  1. enterprise portal
  2. Windows下通过MinGW进行WxWidgets的动态编译与静态编译
  3. Vue + Element UI——搜索框DEMO
  4. ARM-Button-Driver-硬件图
  5. 51单片机智能小车循迹完整程序_电气与信息工程学院双创协会开展循迹小车培训...
  6. php 字符串 大括号,PHP中的字符串大括号
  7. vc通过ADO连接sql server 2000的核心代码
  8. 使用mysqldump进行逻辑备份
  9. tcp协议报文和三次握手与四次挥手
  10. android 带箭头的框,带有工具提示箭头的Android PopupWindow
  11. 无法访问 函数不正确
  12. 安卓初学者笔记(四):用白话讲明白Activity是什么
  13. 跟极限编程创始人Kent Beck学编程
  14. 恢复训练记录20210809
  15. 有哪些老鸟程序员知道而新手不知道的小技巧?
  16. 互联网下半场新征程启航,AI、大数据等前沿科技助力传统零售产业转型
  17. RecyclerView源码学习笔记(一)构造函数和setLayoutManager方法
  18. Office2013 图标显示不正常的解决办法
  19. 服务器上搭建git仓库
  20. 信息流广告的核心是什么(信息流推广的核心操作和优化思路)

热门文章

  1. 安装 SQL Server 2005 时出现性能计数器要求安装错误的解决办法
  2. 跟着例子一步步学习redux+react-redux[转载]
  3. java--tomcat
  4. Ubuntu runlevel修改
  5. solaris 关闭、释放socket端口
  6. 【elasticsearch】ES启动报错 uncaught exception in thread [main]org.elasticsearch.bootstrap.Startup
  7. 软件测试--selenium脚本编写注意点(二)
  8. 基于 jmeter 的分布式性能测试实战
  9. selenium 环境搭建
  10. python概率论_概率论中常见分布总结以及python的scipy库使用