7.3 TensorFlow笔记(基础篇):加载数据之从队列中读取
前言
整体步骤
在TensorFlow中进行模型训练时,在官网给出的三种读取方式,中最好的文件读取方式就是将利用队列进行文件读取,而且步骤有两步:
1. 把样本数据写入TFRecords二进制文件
2. 从队列中读取数据
读取TFRecords文件步骤
使用队列读取数TFRecords 文件 数据的步骤
1. 创建张量,从二进制文件读取一个样本数据
2. 创建张量,从二进制文件随机读取一个mini-batch
3. 把每一批张量传入网络作为输入点
TensorFlow使用TFRecords文件训练样本的步骤
- 在生成文件名的序列中,设定epoch数量
- 训练时,设定为无穷循环
- 在读取数据时,如果捕捉到错误,终止
source code
tensorflow-master\tensorflow\examples\how_tos\reading_data\fully_connected_reader.py(1.2.1)
CODE
代码与解析
解析主要在注释中,最后一个模块if __name__ == '__main__':
的运行,建议参考’http://blog.csdn.net/fontthrone/article/details/76735591’
import tensorflow as tf
import os# from tensorflow.contrib.learn.python.learn.datasets import mnist
# 注意上面的这个mnist 与 example 中的 mnist 是不同的,本文件中请使用下面的那个 mnistos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import argparse
import os.path
import sys
import timefrom tensorflow.examples.tutorials.mnist import mnist# Basic model parameters as external flags.
FLAGS = None# This part of the code is added by FontTian,which comes from the source code of tensorflow.examples.tutorials.mnist
# The MNIST images are always 28x28 pixels.
# IMAGE_SIZE = 28
# IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE# Constants used for dealing with the files, matches convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'def read_and_decode(filename_queue):reader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue)features = tf.parse_single_example(serialized_example,# Defaults are not specified since both keys are required.# 必须写明faetures 中的 key 的名称features={'image_raw': tf.FixedLenFeature([], tf.string),'label': tf.FixedLenFeature([], tf.int64),})# Convert from a scalar string tensor (whose single string has# length mnist.IMAGE_PIXELS) to a uint8 tensor with shape# [mnist.IMAGE_PIXELS].# 将一个标量字符串张量(其单个字符串的长度是mnist.image像素) # 0 维的Tensor# 转换为一个带有形状mnist.图像像素的uint8张量。 # 一维的Tensorimage = tf.decode_raw(features['image_raw'], tf.uint8)# print(tf.shape(image)) # Tensor("input/Shape:0", shape=(1,), dtype=int32)image.set_shape([mnist.IMAGE_PIXELS])# print(tf.shape(image)) # Tensor("input/Shape_1:0", shape=(1,), dtype=int32)# OPTIONAL: Could reshape into a 28x28 image and apply distortions# here. Since we are not applying any distortions in this# example, and the next step expects the image to be flattened# into a vector, we don't bother.# Convert from [0, 255] -> [-0.5, 0.5] floats.image = tf.cast(image, tf.float32) * (1. / 255) - 0.5# print(tf.shape(image)) # Tensor("input/Shape_2:0", shape=(1,), dtype=int32)# Convert label from a scalar uint8 tensor to an int32 scalar.label = tf.cast(features['label'], tf.int32)# print(tf.shape(label)) # Tensor("input/Shape_3:0", shape=(0,), dtype=int32)return image, label# 使用 tf.train.shuffle_batch 将前面生成的样本随机化,获得一个最小批次的张量
def inputs(train, batch_size, num_epochs):"""Reads input data num_epochs times.Args:train: Selects between the training (True) and validation (False) data.batch_size: Number of examples per returned batch.num_epochs: Number of times to read the input data, or 0/None totrain forever.Returns:A tuple (images, labels), where:* images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]in the range [-0.5, 0.5].* labels is an int32 tensor with shape [batch_size] with the true label,a number in the range [0, mnist.NUM_CLASSES).Note that an tf.train.QueueRunner is added to the graph, whichmust be run using e.g. tf.train.start_queue_runners().输入参数:train: Selects between the training (True) and validation (False) data.batch_size: 训练的每一批有多少个样本num_epochs: 读取输入数据的次数, or 0/None 表示永远训练下去返回结果:A tuple (images, labels), where:* images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]范围: [-0.5, 0.5].* labels is an int32 tensor with shape [batch_size] with the true label,范围: [0, mnist.NUM_CLASSES).注意 : tf.train.QueueRunner 被添加进 graph, 它必须用 tf.train.start_queue_runners() 来启动线程."""if not num_epochs: num_epochs = Nonefilename = os.path.join(FLAGS.train_dir,TRAIN_FILE if train else VALIDATION_FILE)with tf.name_scope('input'):# tf.train.string_input_producer 返回一个 QueueRunner,里面有一个 FIFQueuefilename_queue = tf.train.string_input_producer([filename], num_epochs=num_epochs)# 如果样本数据很大,可以分成若干文件,把文件名列表传入# Even when reading in multiple threads, share the filename queue.image, label = read_and_decode(filename_queue)# Shuffle the examples and collect them into batch_size batches.# (Internally uses a RandomShuffleQueue.)# We run this in two threads to avoid being a bottleneck.images, sparse_labels = tf.train.shuffle_batch([image, label], batch_size=batch_size, num_threads=2,capacity=1000 + 3 * batch_size,# Ensures a minimum amount of shuffling of examples.# 留下一部分队列,来保证每次有足够的数据做随机打乱min_after_dequeue=1000)return images, sparse_labelsdef run_training():"""Train MNIST for a number of steps."""# Tell TensorFlow that the model will be built into the default Graph.with tf.Graph().as_default():# Input images and labels.images, labels = inputs(train=True, batch_size=FLAGS.batch_size,num_epochs=FLAGS.num_epochs)# 构建一个从推理模型来预测数据的图logits = mnist.inference(images,FLAGS.hidden1,FLAGS.hidden2)# Add to the Graph the loss calculation.# 定义损失函数loss = mnist.loss(logits, labels)# 将模型添加到图操作中train_op = mnist.training(loss, FLAGS.learning_rate)# 初始化变量的操作init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())# Create a session for running operations in the Graph.# 在图中创建一个用于运行操作的会话sess = tf.Session()# 初始化变量,注意:string_input_product 内部创建了一个epoch计数器sess.run(init_op)# Start input enqueue threads.coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)try:step = 0while not coord.should_stop():start_time = time.time()# Run one step of the model. The return values are# the activations from the `train_op` (which is# discarded) and the `loss` op. To inspect the values# of your ops or variables, you may include them in# the list passed to sess.run() and the value tensors# will be returned in the tuple from the call._, loss_value = sess.run([train_op, loss])duration = time.time() - start_time# Print an overview fairly often.if step % 100 == 0:print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,duration))step += 1except tf.errors.OutOfRangeError:print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))finally:# 通知其他线程关闭coord.request_stop()# Wait for threads to finish.coord.join(threads)sess.close()def main(_):run_training()if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--learning_rate',type=float,default=0.01,help='Initial learning rate.')parser.add_argument('--num_epochs',type=int,default=2,help='Number of epochs to run trainer.')parser.add_argument('--hidden1',type=int,default=128,help='Number of units in hidden layer 1.')parser.add_argument('--hidden2',type=int,default=32,help='Number of units in hidden layer 2.')parser.add_argument('--batch_size',type=int,default=100,help='Batch size.')parser.add_argument('--train_dir',type=str,default='/tmp/data',help='Directory with the training data.')FLAGS, unparsed = parser.parse_known_args()tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
运行结果
Step 0: loss = 2.31 (0.106 sec)
Step 100: loss = 2.14 (0.016 sec)
Step 200: loss = 1.91 (0.016 sec)
Step 300: loss = 1.69 (0.016 sec)
Step 400: loss = 1.28 (0.016 sec)
Step 500: loss = 1.02 (0.016 sec)
Step 600: loss = 0.70 (0.016 sec)
Step 700: loss = 0.71 (0.016 sec)
Step 800: loss = 0.71 (0.016 sec)
Step 900: loss = 0.49 (0.016 sec)
Step 1000: loss = 0.58 (0.016 sec)
Done training for 2 epochs, 1100 steps.
相关
- 把样本数据写入TFRecords二进制文件 : http://blog.csdn.net/fontthrone/article/details/76727412
- TensorFlow笔记(基础篇):加载数据之预加载数据与填充数据:http://blog.csdn.net/fontthrone/article/details/76727466
- python中的argparse模块:http://blog.csdn.net/fontthrone/article/details/76735591
7.3 TensorFlow笔记(基础篇):加载数据之从队列中读取相关推荐
- 7.1 TensorFlow笔记(基础篇):加载数据之预加载数据与填充数据
TensorFlow加载数据 TensorFlow官方共给出三种加载数据的方式: 1. 预加载数据 2. 填充数据 预加载数据的缺点: 将数据直接嵌在数据流图中,当训练数据较大时,很消耗内存.填充的方 ...
- 7.2 TensorFlow笔记(基础篇): 生成TFRecords文件
前言 在TensorFlow中进行模型训练时,在官网给出的三种读取方式,中最好的文件读取方式就是将利用队列进行文件读取,而且步骤有两步: 1. 把样本数据写入TFRecords二进制文件 2. 从队列 ...
- 6.1 Tensorflow笔记(基础篇):队列与线程
前言 在Tensorflow的实际应用中,队列与线程是必不可少,主要应用于数据的加载等,不同的情况下使用不同的队列,主线程与其他线程异步进行数据的训练与读取,所以队列与线程的知识也是Tensorflo ...
- Android-入门学习笔记-使用 CursorLoader 加载数据
3 使用这个代码片段开始练习 也可以参考 Codepath 教程 高级内容补充: 你是否在思考ArrayAdapter's 的 getView() 方法和CursorAdapter 的 newView ...
- 可视化 | Echarts基础异步加载数据交互组件数据集
目录 1. ECharts 简介 2. ECharts 安装 3. ECharts 配置语法 4. ECharts 图饼 5. ECharts 样式设置 6. ECharts 异步加载数据 7. EC ...
- 1.1 Tensorflow笔记(基础篇): 图与会话,变量
图与会话 import tensorflow as tf import os# 取消打印 cpu,gpu选择等的各种警告 # 设置TF_CPP_MIN_LOG_LEVEL 的等级,1.1.0以后设置2 ...
- Launcher3源码分析(LauncherModel加载数据)
LauncherModel继承BroadcastReceiver,显然是一个广播接收者.在上一篇Launcher的启动中讲到桌面数据的加载工作是在LauncherModel中执行的,那么它是如何加载数 ...
- Cesium开发基础笔记总结(加载影像、加载地形数据、加载矢量)
Cesium开发基础笔记总结 学习总结于GIS李胜老师博客 Cesium开发基础01加载影像数据 加载影像数据 Cesium中的影像图层类: 无论是二维地图还是三维地图,如果缺少了底图影像或电子地图, ...
- Google Map 开发笔记——基础篇(Javascript )
Google Map 开发笔记--基础篇 说明: 一.使用入门: 1.在您需要显示地图的 html 页面嵌入这段 script 2.地图 DOM 元素 3.初始化地图 二.地图画点.线.面 1.标记( ...
最新文章
- 策略模式(封装一系列的功能,使之可以相互替换)
- PHP基本连接数据库
- Image Pro Plus测量组织平均厚度
- 【AWSL】之Linux源代码编译及配置yum源(tar 解包、./configure配置软件模块、make)
- Jmeter报告优化之New XSL stylesheet
- laravel如何生成swagger接口文档
- CSS实现垂直居中的方法
- asp不能做到的是什么
- swift-自定义Alert
- 配置管理工具SVN的使用
- UML用例图的画法详细介绍【软件工程】
- qq邮箱收件服务器用户名密码,iphone6/6s+设置QQ邮箱时显示用户名或密码错误的解决方法介绍...
- 2017年全球IDC、光器件、100G及400G数通模块市场预测
- Linux文件比较工具
- 从React专利事件看开源软件许可
- 常用的平方根算法详解与实现
- 真正从零开始,TensorFlow详细安装入门图文教程!(linux)
- python 经典类与新式类
- 360音乐搜索使用讲解说明
- Android及IOS微信5,Android版微信5.0今日正式发布 与iOS版略不同