前言

整体步骤

在TensorFlow中进行模型训练时,在官网给出的三种读取方式,中最好的文件读取方式就是将利用队列进行文件读取,而且步骤有两步:
1. 把样本数据写入TFRecords二进制文件
2. 从队列中读取数据

读取TFRecords文件步骤

使用队列读取数TFRecords 文件 数据的步骤
1. 创建张量,从二进制文件读取一个样本数据
2. 创建张量,从二进制文件随机读取一个mini-batch
3. 把每一批张量传入网络作为输入点

TensorFlow使用TFRecords文件训练样本的步骤

  1. 在生成文件名的序列中,设定epoch数量
  2. 训练时,设定为无穷循环
  3. 在读取数据时,如果捕捉到错误,终止

source code

tensorflow-master\tensorflow\examples\how_tos\reading_data\fully_connected_reader.py(1.2.1)

CODE

代码与解析

解析主要在注释中,最后一个模块if __name__ == '__main__':的运行,建议参考’http://blog.csdn.net/fontthrone/article/details/76735591’

import tensorflow as tf
import os# from tensorflow.contrib.learn.python.learn.datasets import mnist
# 注意上面的这个mnist 与 example 中的 mnist 是不同的,本文件中请使用下面的那个 mnistos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import argparse
import os.path
import sys
import timefrom tensorflow.examples.tutorials.mnist import mnist# Basic model parameters as external flags.
FLAGS = None# This part of the code is added by FontTian,which comes from the source code of tensorflow.examples.tutorials.mnist
# The MNIST images are always 28x28 pixels.
# IMAGE_SIZE = 28
# IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE# Constants used for dealing with the files, matches convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'def read_and_decode(filename_queue):reader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue)features = tf.parse_single_example(serialized_example,# Defaults are not specified since both keys are required.# 必须写明faetures 中的 key 的名称features={'image_raw': tf.FixedLenFeature([], tf.string),'label': tf.FixedLenFeature([], tf.int64),})# Convert from a scalar string tensor (whose single string has# length mnist.IMAGE_PIXELS) to a uint8 tensor with shape# [mnist.IMAGE_PIXELS].# 将一个标量字符串张量(其单个字符串的长度是mnist.image像素) # 0 维的Tensor# 转换为一个带有形状mnist.图像像素的uint8张量。 # 一维的Tensorimage = tf.decode_raw(features['image_raw'], tf.uint8)# print(tf.shape(image)) # Tensor("input/Shape:0", shape=(1,), dtype=int32)image.set_shape([mnist.IMAGE_PIXELS])# print(tf.shape(image)) # Tensor("input/Shape_1:0", shape=(1,), dtype=int32)# OPTIONAL: Could reshape into a 28x28 image and apply distortions# here.  Since we are not applying any distortions in this# example, and the next step expects the image to be flattened# into a vector, we don't bother.# Convert from [0, 255] -> [-0.5, 0.5] floats.image = tf.cast(image, tf.float32) * (1. / 255) - 0.5# print(tf.shape(image)) # Tensor("input/Shape_2:0", shape=(1,), dtype=int32)# Convert label from a scalar uint8 tensor to an int32 scalar.label = tf.cast(features['label'], tf.int32)# print(tf.shape(label)) # Tensor("input/Shape_3:0", shape=(0,), dtype=int32)return image, label# 使用 tf.train.shuffle_batch 将前面生成的样本随机化,获得一个最小批次的张量
def inputs(train, batch_size, num_epochs):"""Reads input data num_epochs times.Args:train: Selects between the training (True) and validation (False) data.batch_size: Number of examples per returned batch.num_epochs: Number of times to read the input data, or 0/None totrain forever.Returns:A tuple (images, labels), where:* images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]in the range [-0.5, 0.5].* labels is an int32 tensor with shape [batch_size] with the true label,a number in the range [0, mnist.NUM_CLASSES).Note that an tf.train.QueueRunner is added to the graph, whichmust be run using e.g. tf.train.start_queue_runners().输入参数:train: Selects between the training (True) and validation (False) data.batch_size: 训练的每一批有多少个样本num_epochs: 读取输入数据的次数, or 0/None 表示永远训练下去返回结果:A tuple (images, labels), where:* images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]范围: [-0.5, 0.5].* labels is an int32 tensor with shape [batch_size] with the true label,范围: [0, mnist.NUM_CLASSES).注意 :  tf.train.QueueRunner 被添加进 graph, 它必须用 tf.train.start_queue_runners() 来启动线程."""if not num_epochs: num_epochs = Nonefilename = os.path.join(FLAGS.train_dir,TRAIN_FILE if train else VALIDATION_FILE)with tf.name_scope('input'):# tf.train.string_input_producer 返回一个 QueueRunner,里面有一个 FIFQueuefilename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)# 如果样本数据很大,可以分成若干文件,把文件名列表传入# Even when reading in multiple threads, share the filename queue.image, label = read_and_decode(filename_queue)# Shuffle the examples and collect them into batch_size batches.# (Internally uses a RandomShuffleQueue.)# We run this in two threads to avoid being a bottleneck.images, sparse_labels = tf.train.shuffle_batch([image, label], batch_size=batch_size, num_threads=2,capacity=1000 + 3 * batch_size,# Ensures a minimum amount of shuffling of examples.# 留下一部分队列,来保证每次有足够的数据做随机打乱min_after_dequeue=1000)return images, sparse_labelsdef run_training():"""Train MNIST for a number of steps."""# Tell TensorFlow that the model will be built into the default Graph.with tf.Graph().as_default():# Input images and labels.images, labels = inputs(train=True, batch_size=FLAGS.batch_size,num_epochs=FLAGS.num_epochs)# 构建一个从推理模型来预测数据的图logits = mnist.inference(images,FLAGS.hidden1,FLAGS.hidden2)# Add to the Graph the loss calculation.# 定义损失函数loss = mnist.loss(logits, labels)# 将模型添加到图操作中train_op = mnist.training(loss, FLAGS.learning_rate)# 初始化变量的操作init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())# Create a session for running operations in the Graph.# 在图中创建一个用于运行操作的会话sess = tf.Session()# 初始化变量,注意:string_input_product 内部创建了一个epoch计数器sess.run(init_op)# Start input enqueue threads.coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)try:step = 0while not coord.should_stop():start_time = time.time()# Run one step of the model.  The return values are# the activations from the `train_op` (which is# discarded) and the `loss` op.  To inspect the values# of your ops or variables, you may include them in# the list passed to sess.run() and the value tensors# will be returned in the tuple from the call._, loss_value = sess.run([train_op, loss])duration = time.time() - start_time# Print an overview fairly often.if step % 100 == 0:print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,duration))step += 1except tf.errors.OutOfRangeError:print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))finally:# 通知其他线程关闭coord.request_stop()# Wait for threads to finish.coord.join(threads)sess.close()def main(_):run_training()if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--learning_rate',type=float,default=0.01,help='Initial learning rate.')parser.add_argument('--num_epochs',type=int,default=2,help='Number of epochs to run trainer.')parser.add_argument('--hidden1',type=int,default=128,help='Number of units in hidden layer 1.')parser.add_argument('--hidden2',type=int,default=32,help='Number of units in hidden layer 2.')parser.add_argument('--batch_size',type=int,default=100,help='Batch size.')parser.add_argument('--train_dir',type=str,default='/tmp/data',help='Directory with the training data.')FLAGS, unparsed = parser.parse_known_args()tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

运行结果

Step 0: loss = 2.31 (0.106 sec)
Step 100: loss = 2.14 (0.016 sec)
Step 200: loss = 1.91 (0.016 sec)
Step 300: loss = 1.69 (0.016 sec)
Step 400: loss = 1.28 (0.016 sec)
Step 500: loss = 1.02 (0.016 sec)
Step 600: loss = 0.70 (0.016 sec)
Step 700: loss = 0.71 (0.016 sec)
Step 800: loss = 0.71 (0.016 sec)
Step 900: loss = 0.49 (0.016 sec)
Step 1000: loss = 0.58 (0.016 sec)
Done training for 2 epochs, 1100 steps.

相关

  1. 把样本数据写入TFRecords二进制文件 : http://blog.csdn.net/fontthrone/article/details/76727412
  2. TensorFlow笔记(基础篇):加载数据之预加载数据与填充数据:http://blog.csdn.net/fontthrone/article/details/76727466
  3. python中的argparse模块:http://blog.csdn.net/fontthrone/article/details/76735591

7.3 TensorFlow笔记(基础篇):加载数据之从队列中读取相关推荐

  1. 7.1 TensorFlow笔记(基础篇):加载数据之预加载数据与填充数据

    TensorFlow加载数据 TensorFlow官方共给出三种加载数据的方式: 1. 预加载数据 2. 填充数据 预加载数据的缺点: 将数据直接嵌在数据流图中,当训练数据较大时,很消耗内存.填充的方 ...

  2. 7.2 TensorFlow笔记(基础篇): 生成TFRecords文件

    前言 在TensorFlow中进行模型训练时,在官网给出的三种读取方式,中最好的文件读取方式就是将利用队列进行文件读取,而且步骤有两步: 1. 把样本数据写入TFRecords二进制文件 2. 从队列 ...

  3. 6.1 Tensorflow笔记(基础篇):队列与线程

    前言 在Tensorflow的实际应用中,队列与线程是必不可少,主要应用于数据的加载等,不同的情况下使用不同的队列,主线程与其他线程异步进行数据的训练与读取,所以队列与线程的知识也是Tensorflo ...

  4. Android-入门学习笔记-使用 CursorLoader 加载数据

    3 使用这个代码片段开始练习 也可以参考 Codepath 教程 高级内容补充: 你是否在思考ArrayAdapter's 的 getView() 方法和CursorAdapter 的 newView ...

  5. 可视化 | Echarts基础异步加载数据交互组件数据集

    目录 1. ECharts 简介 2. ECharts 安装 3. ECharts 配置语法 4. ECharts 图饼 5. ECharts 样式设置 6. ECharts 异步加载数据 7. EC ...

  6. 1.1 Tensorflow笔记(基础篇): 图与会话,变量

    图与会话 import tensorflow as tf import os# 取消打印 cpu,gpu选择等的各种警告 # 设置TF_CPP_MIN_LOG_LEVEL 的等级,1.1.0以后设置2 ...

  7. Launcher3源码分析(LauncherModel加载数据)

    LauncherModel继承BroadcastReceiver,显然是一个广播接收者.在上一篇Launcher的启动中讲到桌面数据的加载工作是在LauncherModel中执行的,那么它是如何加载数 ...

  8. Cesium开发基础笔记总结(加载影像、加载地形数据、加载矢量)

    Cesium开发基础笔记总结 学习总结于GIS李胜老师博客 Cesium开发基础01加载影像数据 加载影像数据 Cesium中的影像图层类: 无论是二维地图还是三维地图,如果缺少了底图影像或电子地图, ...

  9. Google Map 开发笔记——基础篇(Javascript )

    Google Map 开发笔记--基础篇 说明: 一.使用入门: 1.在您需要显示地图的 html 页面嵌入这段 script 2.地图 DOM 元素 3.初始化地图 二.地图画点.线.面 1.标记( ...

最新文章

  1. 策略模式(封装一系列的功能,使之可以相互替换)
  2. PHP基本连接数据库
  3. Image Pro Plus测量组织平均厚度
  4. 【AWSL】之Linux源代码编译及配置yum源(tar 解包、./configure配置软件模块、make)
  5. Jmeter报告优化之New XSL stylesheet
  6. laravel如何生成swagger接口文档
  7. CSS实现垂直居中的方法
  8. asp不能做到的是什么
  9. swift-自定义Alert
  10. 配置管理工具SVN的使用
  11. UML用例图的画法详细介绍【软件工程】
  12. qq邮箱收件服务器用户名密码,iphone6/6s+设置QQ邮箱时显示用户名或密码错误的解决方法介绍...
  13. 2017年全球IDC、光器件、100G及400G数通模块市场预测
  14. Linux文件比较工具
  15. 从React专利事件看开源软件许可
  16. 常用的平方根算法详解与实现
  17. 真正从零开始,TensorFlow详细安装入门图文教程!(linux)
  18. python 经典类与新式类
  19. 360音乐搜索使用讲解说明
  20. Android及IOS微信5,Android版微信5.0今日正式发布 与iOS版略不同

热门文章

  1. Mac环境下Redis的安装与配置
  2. linux之vsftpd配置
  3. Ubuntu 14.04 hadoop单机安装
  4. SQL Server中Rollup关键字使用技巧
  5. Linux文件与目录的rwx权限
  6. 为什么 RestTemplate 那么棒,看这篇就够了!
  7. 300 行代码带你搞懂 Java 多线程!
  8. 详解 Java 的八大基本类型,写得非常好!
  9. Netty 实战:如何编写一个麻小俱全的 web 容器
  10. JAVA多线程:线程创建过程以及生命周期总结