上篇简单的Tensorflow解决MNIST手写体数字识别可扩展性并不好。例如计算前向传播的函数需要将所有的变量都传入,当神经网络的结构变得复杂、参数更多时,程序的可读性变得非常差。而且这种方式会导致程序中有大量的冗余代码。还有就是由于没有持久化训练好的模型。当程序退出时,训练好的模型就无法再使用了,这导致得到的模型无法被重用更严重的是神经网络模型的训练时间都比较长,如果在训练程序中程序死机了,那样没有保存训练好的中间结果会浪费大量的时间和资源。所以,在训练过程中需要每隔一段时间保存一次模型训练的中间结果。

下面的代码将训练和测试分成两个独立的程序,这可以使得每一个组件更加灵活。除了将不同功能模块分开,本节还将前向传播的过程抽象成一个单独的库函数。因为神经网络的前向传播过程在训练和测试过程中都会用到,所以通过库函数的方式使用起来既可以更加方便,又可以保证训练和测试过程中使用的前向传播方法一定是一致的。

下面的代码是重构之后的程序来解决MNIST问题。重构之后的代码会拆分为3个程序。第一个是mnist_inference.py,它定义了前向传播的过程以及神经网络中的参数。第二个是mnist_train.py,它定义了神经网络的训练过程。第三个是mnist_eval.py,它定义了测试过程。

下面的代码都是由jupyter notebook生成的。

1. mnist_inference.py

# coding: utf-8
#定义了前向传播的过程和神经网络中的参数 import tensorflow as tf
#  1. 定义神经网络结构相关的参数。
INPUT_NODE = 784 # 输入层的节点数
OUTPUT_NODE = 10# 输出层的节点数
LAYER1_NODE = 500 # 隐藏层的节点数# #### 2. 通过tf.get_variable函数来获取变量。# 通过tf. get_variable函数来获取变量:在训练神经网络时会创建这些变量,在测试时会通过保存的模型保存这些变量的取值。现在更加方便的是由于可以在
# 变量加载时将滑动平均变量重命名,所以可以直接通过同样的名字在训练时使用变量本身,而在测试时使用变量的滑动平均值。在这个函数中也会将变量的
# 正则化损失加入损失函数
def get_weight_variable(shape, regularizer): # 此处的shape为[784x500]weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))  # 变量初始化函数:tf.truncated_normal_initializer# 当给出了正则化生成函数时,将当前变量的正则化损失加入名字为losses的集合,在这里使用了add_to_collection函数将一个张量加入一个集合,而这个# 集合的名称为losses。这是自定义的集合,不在tensorflow自动管理的集合列表内if regularizer != None: tf.add_to_collection('losses', regularizer(weights))return weights# #### 3. 定义神经网络的前向传播过程。
def inference(input_tensor, regularizer):# 声明第一层神经网络的变量并完成前向传播过程with tf.variable_scope('layer1'):  # 要通过tf.get_variable获取一个已经创建的变量,需要通过 tf.variable_scope函数来生成一个上下文管理器。# 这里通过 tf.get_variable和 tf.variable没有本质的区别,因为在训练或测试中没有在同一个程序中多次调用这个函数。如果在同一个程序中多次调用# 在第一次调用后需要将reuse参数设置为trueweights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer) # 权重biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0)) # 偏置layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)  # tf.nn.relu非线性激活函数# 类似的声明第二层神经网络的变量并完成前向传播过程with tf.variable_scope('layer2'):weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))layer2 = tf.matmul(layer1, weights) + biases# 返回最后前向传播的结果return layer2# 在这段代码中定义了神经网络的前向传播算法。无论是训练还是测试,都可以直接调用此函数,而不用关心具体的神经网络结构。

2. mnist_train.py

# coding: utf-8# #### 使用定义好的前向传播过程,以下是神经网络的训练程序
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 加载mnist_inference.py中定义的常量和前向传播的函数
import mnist_inference
import os# #### 1. 定义神经网络结构相关的参数。
BATCH_SIZE = 100 # 一个训练batch中的训练数据个数。个数越小越接近随机梯度下降;数字越大时,训练越接近梯度下降
LEARNING_RATE_BASE = 0.8 # 基础的学习率
LEARNING_RATE_DECAY = 0.99 # 学习率的衰减率
REGULARIZATION_RATE = 0.0001 # 描述模型复杂度的正则化项在损失函数中的系数
TRAINING_STEPS = 30000# 训练轮数
MOVING_AVERAGE_DECAY = 0.99  # 滑动平均衰减率
# 模型保存的路径和文件名
MODEL_SAVE_PATH = "/home/lilong/desktop/ckptt/"
MODEL_NAME = "model.ckpt"# #### 2. 定义训练过程。
def train(mnist):# 定义输入输出placeholder(placeholder机制用于提供输入数据,该占位符中的数据只有在运行时才指定)x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')# 这里使用L2正则化,tf.contrib.layers.l2_regularizer会返回一个函数,这个函数可以计算一个给定参数的L2正则化项的值regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)# 直接使用mnist_inference.py中定义的前向传播函数y = mnist_inference.inference(x, regularizer)  global_step = tf.Variable(0, trainable=False)# 定义损失函数、学习率、滑动平均操作以及训练过程。# 定义指数滑动平均的类,初始化给点了衰减率0.99和控制衰减率的变量global_stepvariable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)variables_averages_op = variable_averages.apply(tf.trainable_variables()) # 定义一个更新变量滑动平均的操作# 定义交叉熵损失:因为交叉熵一般和softmax回归一起使用,所以 tf.nn.sparse_softmax_cross_entropy_with_logits函数对这两个功能进行了封装。# 这里使用该函数进行加速交叉熵的计算,第一个参数是不包括softmax层的前向传播结果。第二个参数是训练数据的正确答案,这里得到的是正确答案的# 正确编号。cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))# 计算当前batch中所有样例的交叉熵平均值cross_entropy_mean = tf.reduce_mean(cross_entropy)# 总损失等于交叉熵损失和正则化损失的和loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))# 设置指数衰减的学习率learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY,staircase=True)# 这里使用指数衰减的学习率。在minimize中传入global_step将会自动更新global_step参数,从而使学习率得到相应的更新train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)# 在训练神经网络时,每过一遍数据既需要通过反向传播来更新神经神经网络的参数,又需要更新每一个参数的滑动平均值,这里的 tf.control_dependencieswith tf.control_dependencies([train_step, variables_averages_op]):train_op = tf.no_op(name='train')# 初始化TensorFlow持久化类。saver = tf.train.Saver()with tf.Session() as sess:tf.global_variables_initializer().run()# 在训练过程中不再测试模型在验证数据上的表现,验证和测试的过程将会有一个独立的程序来完成。for i in range(TRAINING_STEPS):xs, ys = mnist.train.next_batch(BATCH_SIZE)_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})# 每1000轮保存一次模型if i % 1000 == 0:# 输出当前的训练情况。这里只输出了模型在当前训练batch上的损失函数大小,通过损失函数的大小可以大概了解训练的情况。在验证数据数据# 上的正确率会有一个单独的程序来完成。print("After %d training step(s), loss on training batch is %g." % (step, loss_value))# 保存当前的模型。这里给出了global_step参数,这样可以让每个被保存模型的文件名末尾加上训练的轮数。saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)# #### 3. 主程序入口。
def main(argv=None):# "/home/lilong/desktop/MNIST_data/"# mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)mnist = input_data.read_data_sets("/home/lilong/desktop/MNIST_data/", one_hot=True)train(mnist)if __name__ == '__main__':main()

运行结果:

Extracting /home/lilong/desktop/MNIST_data/train-images-idx3-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/train-labels-idx1-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/t10k-labels-idx1-ubyte.gz
After 1 training step(s), loss on training batch is 3.12471.
After 1001 training step(s), loss on training batch is 0.239917.
After 2001 training step(s), loss on training batch is 0.151938.
After 3001 training step(s), loss on training batch is 0.135801.
After 4001 training step(s), loss on training batch is 0.11508.
After 5001 training step(s), loss on training batch is 0.101712.
After 6001 training step(s), loss on training batch is 0.096526.
After 7001 training step(s), loss on training batch is 0.0867542.
After 8001 training step(s), loss on training batch is 0.0778042.
After 9001 training step(s), loss on training batch is 0.0693044.
After 10001 training step(s), loss on training batch is 0.0648921.
After 11001 training step(s), loss on training batch is 0.0598342.
After 12001 training step(s), loss on training batch is 0.0602573.
After 13001 training step(s), loss on training batch is 0.0580158.
After 14001 training step(s), loss on training batch is 0.0491354.
After 15001 training step(s), loss on training batch is 0.0492541.
After 16001 training step(s), loss on training batch is 0.045001.
After 17001 training step(s), loss on training batch is 0.0457389.
After 18001 training step(s), loss on training batch is 0.0468493.
After 19001 training step(s), loss on training batch is 0.0440138.
After 20001 training step(s), loss on training batch is 0.0405837.
After 21001 training step(s), loss on training batch is 0.0393501.
After 22001 training step(s), loss on training batch is 0.0451467.
After 23001 training step(s), loss on training batch is 0.0376411.
After 24001 training step(s), loss on training batch is 0.0366882.
After 25001 training step(s), loss on training batch is 0.0394025.
After 26001 training step(s), loss on training batch is 0.0351238.
After 27001 training step(s), loss on training batch is 0.0339706.
After 28001 training step(s), loss on training batch is 0.0376363.
After 29001 training step(s), loss on training batch is 0.0388179.

3. mnist_eval.py

# coding: utf-8import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 加载mnist_inference.py和mnist_train.py中定义的常量和函数
import mnist_inference
import mnist_train# #### 1. 每10秒加载一次最新的模型
# 加载的时间间隔:每10秒加载一次新的模型,并在测试数据上测试最新模型的正确率
EVAL_INTERVAL_SECS = 10def evaluate(mnist):with tf.Graph().as_default() as g:# 定义输入输出的格式x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}# 直接通过调用封装好的函数来计算前向传播结果。因为测试时不关注正则化的值,所以这里用于计算正则化损失的函数被设置为noney = mnist_inference.inference(x, None)# 使用前向传播的结果计算正确率。如果需要对未来的样例进行分类,使用tf.argmax()就可以得到输入样例的预测类别了correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))# 通过变量重命名的方式来加载模型,这样在前向传播的过程中就不需要调用求滑动平均的函数来获取平均值了。这样就可以完全共用mnist_inference.py# 中定义的前向传播过程variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)variables_to_restore = variable_averages.variables_to_restore()saver = tf.train.Saver(variables_to_restore)# 每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化while True:with tf.Session() as sess:#  tf.train.get_checkpoint_state函数会通过checkpoint文件自找到目录中最新的文件名ckpt = tf.train.get_checkpoint_state("/home/lilong/desktop/ckptt/") # ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。if ckpt and ckpt.model_checkpoint_path:# 加载模型saver.restore(sess, ckpt.model_checkpoint_path)# 通过文件名得到模型保存时迭代的轮数(split('/')[-1].split('-')[-1]:正则表达式)global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]accuracy_score = sess.run(accuracy, feed_dict=validate_feed)print("After %s training step(s), validation accuracy = %g" % (global_step, accuracy_score))else:print('No checkpoint file found')returntime.sleep(EVAL_INTERVAL_SECS)# ###  主程序def main(argv=None):mnist = input_data.read_data_sets("/home/lilong/desktop/MNIST_data/", one_hot=True)evaluate(mnist)if __name__ == '__main__':main()

本测试代码会每隔10秒运行一次,每次运行都是读取最新保存的模型。并在MNIST验证数据集上计算模型的正确率。注意这里如果运行完训练程序后再单独运行该测试程序会得到如下的运行结果:

Extracting /home/lilong/desktop/MNIST_data/train-images-idx3-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/train-labels-idx1-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting /home/lilong/desktop/MNIST_data/t10k-labels-idx1-ubyte.gz
INFO:tensorflow:Restoring parameters from /home/lilong/desktop/ckptt/model.ckpt-29001
After 29001 training step(s), validation accuracy = 0.9846
INFO:tensorflow:Restoring parameters from /home/lilong/desktop/ckptt/model.ckpt-29001
After 29001 training step(s), validation accuracy = 0.9846
INFO:tensorflow:Restoring parameters from /home/lilong/desktop/ckptt/model.ckpt-29001
After 29001 training step(s), validation accuracy = 0.9846
INFO:tensorflow:Restoring parameters from /home/lilong/desktop/ckptt/model.ckpt-29001
After 29001 training step(s), validation accuracy = 0.9846
INFO:tensorflow:Restoring parameters from /home/lilong/desktop/ckptt/model.ckpt-29001
After 29001 training step(s), validation accuracy = 0.9846---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-15-c2f081a58572> in <module>()5 6 if __name__ == '__main__':
----> 7     main()<ipython-input-15-c2f081a58572> in main(argv)2    # mnist = input_data.read_data_sets("../../../datasets/MNIST_data", one_hot=True)3     mnist = input_data.read_data_sets("/home/lilong/desktop/MNIST_data/", one_hot=True)
----> 4     evaluate(mnist)5 6 if __name__ == '__main__':<ipython-input-14-2c48dfb7e249> in evaluate(mnist)35                     print('No checkpoint file found')36                     return
---> 37             time.sleep(EVAL_INTERVAL_SECS)KeyboardInterrupt:

从运行结果结果看出:最新模型始终是同一个,所以这里是离线的测试,要想达到在线的效果应该在运行mnist_train.py的同时也运行mnist_eval.py。但是这里必须等到产生训练模型后再开始运行测试程序,否则会输出提示:No checkpoint file found

在线运行的效果如下:

训练模型的过程:

与此同时测试过程:

本示例中最关键的就是:

# 通过变量重命名的方式来加载模型,这样在前向传播的过程中就不需要调用求滑动平均的函数来获取平均值了。# 这样就可以完全共用mnist_inference.py中定义的前向传播过程,这里是关键。       variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)variables_to_restore = variable_averages.variables_to_restore()saver = tf.train.Saver(variables_to_restore)

这里才是为什么可以把训练和测试分开的原因,关于变量重命名、模型保存、重载可以参考:https://blog.csdn.net/lilong117194/article/details/81742536

《Tensorflow实战Google深度学习框架》-——5.5 最佳实践样例程序

Tensorflow 改进的MNIST手写体数字识别相关推荐

  1. Tensorflow解决MNIST手写体数字识别

    这里给出的代码是来自<Tensorflow实战Google深度学习框架>,以供参考和学习. 首先这个示例应用了几个基本的方法: 使用随机梯度下降(batch) 使用Relu激活函数去线性化 ...

  2. 基于MNIST手写体数字识别--含可直接使用代码【Python+Tensorflow+CNN+Keras】

    基于MNIST手写体数字识别--[Python+Tensorflow+CNN+Keras] 1.任务 2.数据集分析 2.1 数据集总体分析 2.2 单个图片样本可视化 3. 数据处理 4. 搭建神经 ...

  3. 基于TensorFlow卷积神经网络的手写体数字识别

    一.卷积神经网络(CNN) 二.LeNet 三.代码 1.Mnist手写体训练并测试 2.可视化 四.数据集分析 五.结果分析 1.准确率 2.可视化测试 一.卷积神经网络(CNN) 参考:https ...

  4. 全连神经网络的经典实战--MNIST手写体数字识别

    mnist数据集 MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片:它也包含每一张图片对应的标签,告诉我们这个是数字几.比如,上面这四张图片的标签分别是5,0,4,1. 在本章中,我们 ...

  5. 计算机视觉:mnist手写体数字识别

    一.mnist数据描述 MNIST数据集是28×28像素的灰度手写数字图片,其中数字的范围从0到9 具体如下所示(参考自Tensorflow官方文档): 二.原理   受Hubel和Wiesel对猫视 ...

  6. 随机森林算法(RandomForest)实现MNIST手写体数字识别

    一.准备: 第三方库 sklearn 二.代码: # -*- coding: utf-8 -*- # @Time : 2018/8/21 9:35 # @Author : Barry # @File ...

  7. MNIST手写体数字识别数据集

    一.总体介绍 1.1 什么是机器识别手写数字? 1.2 MNIST数据集是什么? (1)该数据集包含60,000个用于训练的示例和10,000个用于测试的示例. (2)数据集包含了0-9共10类手写数 ...

  8. 支持向量机(SVM)实现MNIST手写体数字识别

    一.SVM算法简述 支持向量机即Support Vector Machine,简称SVM.一听这个名字,就有眩晕的感觉.支持(Support).向量(Vector).机器(Machine),这三个毫无 ...

  9. keras框架下的深度学习(一)手写体数字识别

    文章目录 前言 一.keras的介绍及其操作使用 二.手写题数字识别 1.介绍 2.对数据的预处理 3.搭建网络框架 4.编译 5.循环训练 6.测试训练的网络模 7.总代码 三.附:梯度下降算法 1 ...

最新文章

  1. opencv meanStdDev
  2. linux ssh 提示 too many authentication failures for root root的身份验证失败太多 解决办法
  3. STM32库中几个重要的文件说明
  4. 体验产品一 | 悦动圈VS咕咚竞品分析报告
  5. CRM Fiori offline技术实现:js/createStores.js
  6. 贝叶斯公式理解与应用
  7. 纯php实现中秋博饼游戏(2):掷骰子并输出结果
  8. 网利友联迈入敏捷开发新时代
  9. 享有盛誉的PHP高级教程
  10. Substrate Tutorials:Start a Private Network (multi-node)
  11. 新能源汽车应该何去何从?
  12. mysql分页语句解释,mysql语句分页limit什么意思
  13. 开发落网电台windows phone 8应用的计划(6)
  14. 基于决策树的电网负荷预测
  15. 区块链游戏为何如此火?大概是因为投机者和“韭菜”太多
  16. [JZOJ4763] 【NOIP2016提高A组模拟9.7】旷野大计算
  17. mybatis-plus模板
  18. Flex布局常用的一些属性及解释
  19. PowerDesigner介绍与使用
  20. 40句让你坦露心声的经典句子(转)

热门文章

  1. Qt Creator创建一个移动应用程序
  2. C语言包含字母的2D面板中搜索给定的单词的算法(附完整源码)
  3. C++ Opengl Fog(雾)源码
  4. java监控对话框是否关闭_java – 检查是否可以安全地关闭对话框
  5. protobuf string类型_Protobuf3 使用其他消息类型
  6. 玩转NumPy——split()函数使用详解
  7. Android:WiFi连接之一
  8. 页面伪静态化 java_UrlRewrite 伪静态化页面
  9. 06_特征选择,特征选择的原因,sklearn特征选择API
  10. nginx日志切割并使用flume-ng收集日志