Tensor Flow官方网站上提供三种读取数据的方法

1. 预加载数据:在Tensor Flow图中定义常量或变量来保存所有数据,将数据直接嵌到数据图中,当训练数据较大时,很消耗内存。

x1=tf.constant([0,1])

x2=tf.constant([1,0])

y=tf.add(x1,x2)

2.填充数据:使用sess.run()的feed_dict参数,将Python产生的数据填充到后端,之前的MNIST数据集就是通过这种方法。也有消耗内存,数据类型转换耗时的缺点。

3. 从文件读取数据:从文件中直接读取,让队列管理器从文件中读取数据。分为两步

先把样本数据写入TFRecords二进制文件

再从队列中读取

TFRecord是TensorFlow提供的一种统一存储数据的二进制文件,能更好的利用内存,更方便的复制和移动,并且不需要单独的标记文件。下面通过代码来将MNIST转换成TFRecord的数据格式,其他数据集也类似。

#生成整数型的属性

def _int64_feature(value):

return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

#生成字符串型的属性

def _bytes_feature(value):

return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert_to(data_set,name):

'''

将数据填入到tf.train.Example的协议缓冲区(protocol buffer)中,将协议缓冲区序列

化为一个字符串,通过tf.python_io.TFRecordWriter写入TFRecords文件

'''

images=data_set.images

labels=data_set.labels

num_examples=data_set.num_examples

if images.shape[0]!=num_examples:

raise ValueError ('Imagessize %d does not match label size %d.'\

%(images.shape[0],num_examples))

rows=images.shape[1] #28

cols=images.shape[2] #28

depth=images.shape[3] #1 是黑白图像

filename = os.path.join(FLAGS.directory, name + '.tfrecords')

#使用下面语句就会将三个文件存储为一个TFRecord文件,当数据量较大时,最好将数据写入多个文件

#filename="C:/Users/dbsdz/Desktop/TF练习/TFRecord"

print('Writing',filename)

writer=tf.python_io.TFRecordWriter(filename)

for index in range(num_examples):

image_raw=images[index].tostring() #将图像矩阵化为一个字符串

#写入协议缓冲区,height、width、depth、label编码成int 64类型,image——raw编码成二进制

example=tf.train.Example(features=tf.train.Features(feature={

'height':_int64_feature(rows),

'width':_int64_feature(cols),

'depth':_int64_feature(depth),

'label':_int64_feature(int(labels[index])),

'image_raw':_bytes_feature(image_raw)}))

writer.write(example.SerializeToString()) #序列化字符串

writer.close()

上面程序可以将MNIST数据集中所有的训练数据存储到三个TFRecord文件中。结果如下图

从队列中TFRecord文件,过程分三步

1. 创建张量,从二进制文件中读取一个样本

2. 创建张量,从二进制文件中随机读取一个mini-batch

3. 把每一批张量传入网络作为输入节点

具体代码如下

def read_and_decode(filename_queue): #输入文件名队列

reader=tf.TFRecordReader()

_,serialized_example=reader.read(filename_queue)

#解析一个example,如果需要解析多个样例,使用parse_example函数

features=tf.parse_single_example(

serialized_example,

#必须写明feature里面的key的名称

features={

#TensorFlow提供两种不同的属性解析方法,一种方法是tf.FixedLenFeature,

#这种方法解析的结果为一个Tensor。另一个方法是tf.VarLenFeature,

#这种方法得到的解析结果为SparseTensor,用于处理稀疏数据。

#这里解析数据的格式需要和上面程序写入数据的格式一致

'image_raw':tf.FixedLenFeature([],tf.string),#图片是string类型

'label':tf.FixedLenFeature([],tf.int64), #标记是int64类型

})

#对于BytesList,要重新进行编码,把string类型的0维Tensor变成uint8类型的一维Tensor

image = tf.decode_raw(features['image_raw'], tf.uint8)

image.set_shape([IMAGE_PIXELS])

#tensor("input/DecodeRaw:0",shape=(784,),dtype=uint8)

#image张量的形状为:tensor("input/sub:0",shape=(784,),dtype=float32)

image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

#把标记从uint8类型转换为int32类性

#label张量的形状为tensor(“input/cast_1:0",shape=(),dtype=int32)

label = tf.cast(features['label'], tf.int32)

return image,label

def inputs(train,batch_size,num_epochs):

#输入参数:

#train:选择输入训练数据/验证数据

#batch_size:训练的每一批有多少个样本

#num_epochs:过几遍数据,设置为0/None表示永远训练下去

'''

返回结果: A tuple (images,labels)

*images:类型为float,形状为【batch_size,mnist.IMAGE_PIXELS],范围【-0.5,0.5】。

*label:类型为int32,形状为【batch_size],范围【0,mnist.NUM_CLASSES]

注意tf.train.QueueRunner必须用tf.train.start_queue_runners()来启动线程

'''

if not num_epochs:num_epochs=None

#获取文件路径,即./MNIST_data/train.tfrecords,./MNIST_data/validation.records

filename=os.path.join(FLAGS.train_dir,TRAIN_FILE if train else VALIDATION_FILE)

with tf.name_scope('input'):

#tf.train.string_input_producer返回一个QueueRunner,里面有一个FIFOQueue

filename_queue=tf.train.string_input_producer(#如果样本量很大,可以分成若干文件,把文件名列表传入

[filename],num_epochs=num_epochs)

image,label=read_and_decode(filename_queue)

#随机化example,并把它们整合成batch_size大小

#tf.train.shuffle_batch生成了RandomShuffleQueue,并开启两个线程

images,sparse_labels=tf.train.shuffle_batch(

[image,label],batch_size=batch_size,num_threads=2,

capacity=1000+3*batch_size,

min_after_dequeue=1000) #留下一部分队列,来保证每次有足够的数据做随机打乱

return images,sparse_labels

最后,构建一个三层的神经网络,包含两层卷积层以及一层使用SoftMax层,附上完整代码如下

# -*- coding: utf-8 -*-

"""

Created on Sun Apr 8 11:06:16 2018

@author: dbsdz

https://blog.csdn.net/xy2953396112/article/details/54929073

"""

import tensorflow as tf

import os

import time

import math

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Basic model parameters as external flags.

flags = tf.app.flags

flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')

flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')

flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')

flags.DEFINE_integer('batch_size', 100, 'Batch size. '

'Must divide evenly into the dataset sizes.')

flags.DEFINE_string('train_dir', 'Mnist_data/', 'Directory to put the training data.')

flags.DEFINE_string('directory', './MNIST_data',

'Directory to download data files and write the '

'converted result')

flags.DEFINE_integer('validation_size', 5000,

'Number of examples to separate from the training '

'data for the validation set.')

flags.DEFINE_integer('num_epochs',10,'num_epochs set')

FLAGS = tf.app.flags.FLAGS

IMAGE_SIZE = 28

IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE #图片像素728

TRAIN_FILE = "train.tfrecords"

VALIDATION_FILE="validation.tfrecords"

#生成整数型的属性

def _int64_feature(value):

return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

#生成字符串型的属性

def _bytes_feature(value):

return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def convert_to(data_set,name):

'''

将数据填入到tf.train.Example的协议缓冲区(protocol buffer)中,将协议缓冲区序列

化为一个字符串,通过tf.python_io.TFRecordWriter写入TFRecords文件

'''

images=data_set.images

labels=data_set.labels

num_examples=data_set.num_examples

if images.shape[0]!=num_examples:

raise ValueError ('Imagessize %d does not match label size %d.'\

%(images.shape[0],num_examples))

rows=images.shape[1] #28

cols=images.shape[2] #28

depth=images.shape[3] #1 是黑白图像

filename = os.path.join(FLAGS.directory, name + '.tfrecords')

#使用下面语句就会将三个文件存储为一个TFRecord文件,当数据量较大时,最好将数据写入多个文件

#filename="C:/Users/dbsdz/Desktop/TF练习/TFRecord"

print('Writing',filename)

writer=tf.python_io.TFRecordWriter(filename)

for index in range(num_examples):

image_raw=images[index].tostring() #将图像矩阵化为一个字符串

#写入协议缓冲区,height、width、depth、label编码成int 64类型,image——raw编码成二进制

example=tf.train.Example(features=tf.train.Features(feature={

'height':_int64_feature(rows),

'width':_int64_feature(cols),

'depth':_int64_feature(depth),

'label':_int64_feature(int(labels[index])),

'image_raw':_bytes_feature(image_raw)}))

writer.write(example.SerializeToString()) #序列化字符串

writer.close()

def inference(images, hidden1_units, hidden2_units):

with tf.name_scope('hidden1'):

weights = tf.Variable(

tf.truncated_normal([IMAGE_PIXELS, hidden1_units],

stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),name='weights')

biases = tf.Variable(tf.zeros([hidden1_units]),name='biases')

hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)

with tf.name_scope('hidden2'):

weights = tf.Variable(

tf.truncated_normal([hidden1_units, hidden2_units],

stddev=1.0 / math.sqrt(float(hidden1_units))),

name='weights')

biases = tf.Variable(tf.zeros([hidden2_units]),

name='biases')

hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)

with tf.name_scope('softmax_linear'):

weights = tf.Variable(

tf.truncated_normal([hidden2_units,FLAGS.num_epochs],

stddev=1.0 / math.sqrt(float(hidden2_units))),name='weights')

biases = tf.Variable(tf.zeros([FLAGS.num_epochs]),name='biases')

logits = tf.matmul(hidden2, weights) + biases

return logits

def lossFunction(logits, labels):

labels = tf.to_int64(labels)

cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(

logits=logits, labels=labels, name='xentropy')

loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')

return loss

def training(loss, learning_rate):

tf.summary.scalar(loss.op.name, loss)

optimizer = tf.train.GradientDescentOptimizer(learning_rate)

global_step = tf.Variable(0, name='global_step', trainable=False)

train_op = optimizer.minimize(loss, global_step=global_step)

return train_op

def read_and_decode(filename_queue): #输入文件名队列

reader=tf.TFRecordReader()

_,serialized_example=reader.read(filename_queue)

#解析一个example,如果需要解析多个样例,使用parse_example函数

features=tf.parse_single_example(

serialized_example,

#必须写明feature里面的key的名称

features={

#TensorFlow提供两种不同的属性解析方法,一种方法是tf.FixedLenFeature,

#这种方法解析的结果为一个Tensor。另一个方法是tf.VarLenFeature,

#这种方法得到的解析结果为SparseTensor,用于处理稀疏数据。

#这里解析数据的格式需要和上面程序写入数据的格式一致

'image_raw':tf.FixedLenFeature([],tf.string),#图片是string类型

'label':tf.FixedLenFeature([],tf.int64), #标记是int64类型

})

#对于BytesList,要重新进行编码,把string类型的0维Tensor变成uint8类型的一维Tensor

image = tf.decode_raw(features['image_raw'], tf.uint8)

image.set_shape([IMAGE_PIXELS])

#tensor("input/DecodeRaw:0",shape=(784,),dtype=uint8)

#image张量的形状为:tensor("input/sub:0",shape=(784,),dtype=float32)

image = tf.cast(image, tf.float32) * (1. / 255) - 0.5

#把标记从uint8类型转换为int32类性

#label张量的形状为tensor(“input/cast_1:0",shape=(),dtype=int32)

label = tf.cast(features['label'], tf.int32)

return image,label

def inputs(train,batch_size,num_epochs):

#输入参数:

#train:选择输入训练数据/验证数据

#batch_size:训练的每一批有多少个样本

#num_epochs:过几遍数据,设置为0/None表示永远训练下去

'''

返回结果: A tuple (images,labels)

*images:类型为float,形状为【batch_size,mnist.IMAGE_PIXELS],范围【-0.5,0.5】。

*label:类型为int32,形状为【batch_size],范围【0,mnist.NUM_CLASSES]

注意tf.train.QueueRunner必须用tf.train.start_queue_runners()来启动线程

'''

if not num_epochs:num_epochs=None

#获取文件路径,即./MNIST_data/train.tfrecords,./MNIST_data/validation.records

filename=os.path.join(FLAGS.train_dir,TRAIN_FILE if train else VALIDATION_FILE)

with tf.name_scope('input'):

#tf.train.string_input_producer返回一个QueueRunner,里面有一个FIFOQueue

filename_queue=tf.train.string_input_producer(#如果样本量很大,可以分成若干文件,把文件名列表传入

[filename],num_epochs=num_epochs)

image,label=read_and_decode(filename_queue)

#随机化example,并把它们整合成batch_size大小

#tf.train.shuffle_batch生成了RandomShuffleQueue,并开启两个线程

images,sparse_labels=tf.train.shuffle_batch(

[image,label],batch_size=batch_size,num_threads=2,

capacity=1000+3*batch_size,

min_after_dequeue=1000) #留下一部分队列,来保证每次有足够的数据做随机打乱

return images,sparse_labels

def run_training():

with tf.Graph().as_default():

#输入images和labels

images,labels=inputs(train=True,batch_size=FLAGS.batch_size,

num_epochs=3) #num_epochs就是训练的轮数

#构建一个从推理模型来预测数据的图

logits=inference(images,FLAGS.hidden1,FLAGS.hidden2)

loss=lossFunction(logits,labels) #定义损失函数

#Add to the Graph operations that train the model

train_op=training(loss,FLAGS.learning_rate)

#初始化参数,特别注意:string——input_producer内部创建了一个epoch计数变量

#归入tf.graphkey.local_variables集合中,必须单独用initialize_local_variables()初始化

init_op=tf.group(tf.global_variables_initializer(),

tf.local_variables_initializer())

sess=tf.Session()

sess.run(init_op)

#Start input enqueue threads

coord =tf.train.Coordinator()

threads=tf.train.start_queue_runners(sess=sess,coord=coord)

try:

step=0

while not coord.should_stop(): #进入永久循环

start_time=time.time()

_,loss_value=sess.run([train_op,loss])

#每100次训练输出一次结果

if step % 100 ==0:

duration=time.time()-start_time

print('Step %d: loss=%.2f (%.3f sec)'%(step,loss_value,duration))

step+=1

except tf.errors.OutOfRangeError:

print('Done training for %d epochs,%d steps.'%(FLAGS.num_epochs,step))

finally:

coord.request_stop()#通知其他线程关闭

coord.join(threads)

sess.close()

def main(unused_argv):

#获取数据

data_sets=input_data.read_data_sets(FLAGS.directory,dtype=tf.uint8,reshape=False,

validation_size=FLAGS.validation_size)

#将数据转换成tf.train.Example类型,并写入TFRecords文件

convert_to(data_sets.train,'train')

convert_to(data_sets.validation,'validation')

convert_to(data_sets.test,'test')

print('convert finished')

run_training()

if __name__ == '__main__':

tf.app.run()

运行结果如图

以上这篇TFRecord格式存储数据与队列读取实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

python如何读取tfrecord_TFRecord格式存储数据与队列读取实例相关推荐

  1. 在python中使用json格式存储数据

    在python中使用json格式存储数据 代码如下: import jsonlist1 = [{'A': [1, 2, 3, 4, 5, 6], 'B': [3, 4, 5, 6, 7]},{'C': ...

  2. livechart 只显示 y 值_基于Python语言的SEGY格式地震数据读取与显示编程

    敬请关注<地学新视野> 摘要:本文简单介绍了SEG-Y地震数据文件格式,以及如何用Python语言编写读写SEG-Y格式的地震数据并绘制地震剖面,其中用到了Segyio和matplotli ...

  3. 读取BIL格式高光谱数据——C/C++

    在上一篇博客中,我们提到如何读取头文件.这篇博客将要谈到如何进一步读取高光谱数据本身,这里以BIL格式为例. 什么是BIL呢?BIL的全称为Band Interleave by Line,连续保存的是 ...

  4. Mysql中使用json格式存储数据好吗?

    在最近的一次项目开发过程中,在数据表设计阶段,对是否用json格式存储某些数据我们产生了分歧.以往项目中对此点比较随意,导致数据表中有些json格式数据体积很大,层次很深,我担心这会降低数据查询和解析 ...

  5. python怎么读取csv的一部分数据_python批量读取csv文件 如何用python将csv文件中的数据读取成数组...

    如何用python把多个csv文件数据处理后汇总到新csv文件你看这月光多温柔,小编转头还能看见你,一切从未坍塌. 可以用pandas读取数据,首先把文件方同一个文件价里,然后对当前文件价的所有内容循 ...

  6. Python:处理yaml格式的数据太简单了,真怕你学会了

    一.思考 1.什么是yaml? 不是标记语言 对用户极其友好 数据序列化标准 跨语言 所有编程语言都支持 跨平台 所有平台都支持 Windows.linux.Mac 格式简单 比json小姐姐穿得更少 ...

  7. python数据分析报告的格式_Python数据报表之Excel操作模块用法分析

    本文实例讲述了Python数据报表之Excel操作模块用法.分享给大家供大家参考,具体如下: 一 点睛 Excel是当今最流行的电子表格处理软件,支持丰富的计算函数及图表,在系统运营方面广泛用于运营数 ...

  8. 读取csv格式的数据

    1.直接上代码,关键是会用 2.代码如下: <?php #添加推荐到英文站 $file = fopen('code.csv','r'); while ($data = fgetcsv($file ...

  9. Spark读取Parquet格式的数据为Dataframe

    SaveMode指定文件保存时的模式: OverWrite 覆盖 Append 追加 ErrorIfExists 如果存在就报错 Ignore 如果存在就忽略 val spark = SparkSes ...

最新文章

  1. 7个小技巧,老板再也无法留我加班了...
  2. 统计学习方法第五章作业:ID3/C4.5算法分类决策树、平方误差二叉回归树代码实现
  3. 为Visual studio 2008 添加汇编工程模板
  4. 06列表的常用基本操作
  5. asp.net面试的题目
  6. SpringSecurity认证流程分析
  7. 第十七章 特殊类成员
  8. 怎么利用计算机计算潮流计算,电力系统潮流计算的目的是什么
  9. linux gst-launch 播放视频旋转,【视频开发】Gstreamer中一些gst-launch常用命令
  10. jQuery插件开发之windowScroll
  11. 打印连续数字 java_java多线程连续打印字母数字问题
  12. 20211102:数字滤波器按照实现结构的分类及其优缺点总结
  13. OpenCV中的reshape
  14. unity3d的uGUI基本操作
  15. SAP各模块常用数据库表大全--->常用表
  16. java程序怎么混淆,使用混淆器,保护你的java程序,混淆java
  17. 通过气象站API获取天气信息
  18. PHP中?是什么意思,有什么用?
  19. 统计一组名字中每个姓出现的次数
  20. 18.10.20日报

热门文章

  1. 洛谷P4956 [COCI2017-2018#6] Davor
  2. 【LaTex】2.1 分式、乘方与开方
  3. 【WY】数据分析 — Matplotlib 阶段一 :基础语法 六 —— 图表保存
  4. React获取url参数
  5. 【单例模式】JAVA懒汉式和饿汉式简单实现
  6. Redis的基本数据类型
  7. POJ9273:PKU2506Tiling_递推+高精度
  8. 2019年9月-最新2000个国内高匿代理ip
  9. 论文阅读——WaveNet: A Generative Model for Raw Audio
  10. ubuntu18.04试玩openproject