之前构建的模型在MNIST上只有91%的正确率,有点低,我们尝试一下使用卷积神经网络来改善效果。如果您不是很清楚什么是卷积神经网络的话,可以参考我的这篇文章:链接。

权重初始化

在创建模型之前,我们先来创建权重和偏置。一般来说,初始化时应加入轻微噪声,来打破对称性,防止零梯度的问题。因为我们用的是ReLu激活函数,所以用稍微大于0的值来初始化偏置能够避免节点输出恒为0的问题(dead neurons)。为了不在建立模型的时候反复做初始化操作,我们定义两个函数用于初始化。

def weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1) # 从截断的正态分布中输出随机值return tf.Variable(initial)
def bias_variable(shape):initial = tf.constant(0.1, shape=shape )return tf.Variable(initial)

卷积和池化

Tensorflow在卷积和池化上有很强的灵活性。我们如何处理边界,步长设置多大什么的。在这里我们的卷积使用步长(stride size)为1,边距(padding size)为0的模板,保证输出和输入是同一个大小。我们的池化用简单传统的2*2大小的模板做max pooling。为了代码更简洁,我们把这部分抽象成一个函数。

def conv2d(x, w):return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")
def max_pool_2x2(x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2,1], padding="SAME")

strides里面的每一量对应Wx上的移动步长,比如strides = [1, 2, 3, 4]批,每次移动batch的个数是1;每次移动in_height的数目是2;每次移动in_width的数目是3;每次移动in_channels的数目是4。当然,每次只应该移动一个量。注意,batch和in_channels一般每次只会移动1。所以一般形式是strides = [1, stride, stride, 1]。得到的结果,包括四个维度[batch, in_height, in_width, in_channels]ksize指对x的四个维度做池化时的大小。如ksize=[1, 2, 2, 1],池化的模板的每次一个batch,一个channel,长为2,宽为2。

padding可以用SAMEVALID两种方式:对于VALID,输出的形状计算如下:

对于SAME,输出的形状计算如下:

现在我们可以开始实现第一层了。它由一个卷积核接一个max pooling完成。卷积在每个5*5的patch中算出32个特征。权重是一个[5, 5, 1, 32]的张量,前面两个维度代表的是patch的大小,接着是输入的通道数,最后输出的是通道数目。输出对应一个同样大小的偏置向量。

w_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])

为了用这一层,我们把我们的输入图片x,变成一个4d向量,第2,3维对应图片的宽高,最后一维代表颜色通道,-1表示,它的大小信息由其它几组值确定。

x_image = tf.reshape(x, [-1, 28, 28, 1])

我们把x_image和权值向量进行卷积相乘,加上偏置,使用ReLu激活函数,最后max pooling。

h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
h_pool1 = max_pool_2x2(

为了构建一个更深的网络,我们会把几个类似的层堆叠起来,第二层卷积中,我们采用5*5的卷积核,希望得到64个特征。

w_conv2 = weight_variable([5, 5, 32, 64])
b_conv1 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv1)
h_pool2 = max_pool_2x2(h_conv2)

密集连接层

28/2/2=7,现在图片降维到7x7,我们加入一个有1024个神经元的全连接层,用于处理整个图片。我们把池化层输出的张量reshape成一些向量,乘上权重矩阵,加上偏置,使用ReLu激活。

w_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

Dropout

为了减少过拟合,我们在输出层之前加入dropout。我们用一个placeholder来代表一个神经元在dropout中被保留的概率。这样我们可以在训练过程中启用dropout,在测试过程中关闭dropout。Tensorflow的tf.nn.dropout操作会自动处理神经元输出值的scale。所以用dropout的时候不用考虑scale。

keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

最后我们在输出层添加softmax函数。

w_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2)

接下来我们需要对其进行训练和测试:

cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
correct_prediction = tf.equal(tf.arg_max(y_conv, 1), tf.arg_max(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.global_variables_initializer())
for i in range(20000):batch = mnist.train.next_batch(50)if i % 100 == 0:train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0})print("step %d, train_accuracy %g" % (i, train_accuracy))train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5})
print("test_accuracy %g" % accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

全部代码:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
import tensorflow as tf
sess = tf.InteractiveSession()
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
def weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1) # 从截断的正态分布中输出随机值return tf.Variable(initial)
def bias_variable(shape):initial = tf.constant(0.1, shape=shape )return tf.Variable(initial)
def conv2d(x, w):return tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")
def max_pool_2x2(x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2,1], padding="SAME")
w_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1, 28, 28, 1])
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)w_conv2 = weight_variable([5, 5, 32, 64])
b_conv1 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv1)
h_pool2 = max_pool_2x2(h_conv2)w_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)w_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2)cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
correct_prediction = tf.equal(tf.arg_max(y_conv, 1), tf.arg_max(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
sess.run(tf.global_variables_initializer())、
Saver = tf.train.Saver()
try:Saver.restore(sess, tf.train.latest_checkpoint("E://code_of_ocr/MNIST_CNN_TENSORFLOW/network_model"))print('success add the model')
except:sess.run(tf.global_variables_initializer())print('error of add the model')
for i in range(20000):batch = mnist.train.next_batch(50)if i % 100 == 0:train_accuracy = accuracy.eval(feed_dict={x:batch[0], y_:batch[1], keep_prob:1.0})print("step %d, train_accuracy %g" % (i, train_accuracy))train_step.run(feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5})
Saver.save(sess, "E://code_of_ocr/MNIST_CNN_TENSORFLOW/network_model/crack_capcha.model")
print("test_accuracy %g" % accuracy.eval(feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

输出结果如下:

参考:

https://blog.csdn.net/CY_TEC/article/details/52082647

https://blog.csdn.net/wuzqChom/article/details/74785643

我的微信公众号名称:深度学习与先进智能决策
微信公众号ID:MultiAgent1024
公众号介绍:主要研究强化学习、计算机视觉、深度学习、机器学习等相关内容,分享学习过程中的学习笔记和心得!期待您的关注,欢迎一起学习交流进步!

Tensorflow官方文档学习理解 (五)-卷积MNIST相关推荐

  1. 每天一小时python官方文档学习(五)————数据结构之元组、集合与字典

    昨天介绍完了最常用的列表,之后就是次常用的元组.集合与字典了. 5.3. 元组和序列 元组和之前讲过的列表有很多共同特性,例如索引和切片操作.实际上,他们是 序列 数据类型(list, tuple, ...

  2. tensorflow学习笔记十7:tensorflow官方文档学习 How to Retrain Inception's Final Layer for New Categories

    现代物体识别模型有数以百万计的参数,可能需要数周才能完全训练.学习迁移是一个捷径,很多这样的工作,以充分的训练模式的一组类ImageNet技术,并从现有的权重进行新课.在这个例子中,我们将从头再训练最 ...

  3. TensorFlow官方文档中的sub 和mul中的函数已经在API中改名了

    照着tensorflow 官方文档学习tensorflow时,出现问题: 第一,执行程序 #进入一个交互式Tensorflow会话 import tensorflow as tf sess = tf. ...

  4. TensorFlow 官方文档中文版发布啦(持续维护)

    TensorFlow 是 Google 研发的第二代人工智能学习系统,是 Google 为了帮助全球开发者们更加方便和高效地开发机器学习 (Machine Learning)和人工智能 (AI) 应用 ...

  5. Spring Boot 官方文档学习(一)入门及使用

    Spring Boot 官方文档学习(一)入门及使用 个人说明:本文内容都是从为知笔记上复制过来的,样式难免走样,以后再修改吧.另外,本文可以看作官方文档的选择性的翻译(大部分),以及个人使用经验及问 ...

  6. tensorflow官方文档_开源分享:最好的TensorFlow入门教程

    如果一门技术的学习曲线过于陡峭,那么我们在入门时的场景往往是,一鼓作气,没入门,再而衰,三而竭.演绎一出从入门到放弃的败走麦城. 今天发现一个入门TensorFlow的宝藏,迫不及待的分享给大家.这个 ...

  7. TensorFlow 官方文档中文版发布啦(持续维护) 1

    TensorFlow 是 Google 研发的第二代人工智能学习系统,是 Google 为了帮助全球开发者们更加方便和高效地开发机器学习 (Machine Learning)和人工智能 (AI) 应用 ...

  8. ZooKeeper官方文档学习笔记03-程序员指南03

    我的每一篇这种正经文章,都是我努力克制玩心的成果,我可太难了,和自己做斗争. ZooKeeper官方文档学习笔记04-程序员指南03 绑定 Java绑定 客户端配置参数 C绑定 陷阱: 常见问题及故障 ...

  9. ZooKeeper官方文档学习笔记01-zookeeper概述

    纠结了很久,我决定用官方文档学习 ZooKeeper概述 学习文档 学习计划 ZooKeeper:分布式应用程序的分布式协调服务 设计目标 数据模型和分层名称空间 节点和短命节点 有条件的更新和监视 ...

  10. R语言reshape2包-官方文档学习

    R语言reshape2包-官方文档学习 简介 核心函数 长数据与宽数据 宽数据 长数据 melt函数 meltarray meltdataframe meltdefault meltlist cast ...

最新文章

  1. LeetCode Text Justification(贪心)
  2. input python2.7_python 中的input
  3. 用户注册,用邮箱来验证用户是否存在
  4. dijkstra最短路径算法视频_java实现Dijkstra算法求最短路径
  5. python socket清空接受区_原始Python服务器
  6. 推荐系统的封闭和禁锢问题
  7. 生成 excel 直接用 httpServletResponse 输出
  8. python的属性访问,python:如何访问函数的属性
  9. C#编写程序操作数据库如何防止SQL注入漏洞的发生
  10. 虚拟机 linux 盘分小了,增加虚拟机硬盘分区大小
  11. JLU数据结构第七次上机实验解题报告
  12. 简单详细叙述FpGrowth算法思想(附python源码实现)
  13. 用 ANSYS/LS-DYNA 进行显式动力学仿真计算 (转帖,有修改)
  14. 机器视觉检测技术之颜色视觉工具应用
  15. NVIDIA GPU 运算能力列表
  16. 利用线性回归进行销售预测
  17. c c python的区别_python版本的区别 Cpython Jython pypy ?
  18. ubuntu16.04+Titan Xp的驱动官网上找不到
  19. win7 共享打印机后,客户端连接提示:打印机已删除(0x00000709)
  20. 几乎每个人都听说过三皇五帝,那么三皇五帝是否存在?又是谁呢?

热门文章

  1. mysql在安全模式下备份_win10安全模式下怎么备份数据?
  2. 二十三、K8s集群强化1-认证
  3. VRRP实现AC双机备份原理详解与配置实例
  4. [C语言循环应用]--打印字符金字塔
  5. 实现二叉树的遍历(递归与非递归)
  6. NYOJ--12--喷水装置(二)
  7. Solidity陷阱:以太坊的随机数生成
  8. 漫画 | Redis常见面试问题
  9. eclipse(Kepler Service Release 2)问题记录
  10. redis 条件查询