tensorflow TFRecords文件的生成和读取方法

文章目录

  • tensorflow TFRecords文件的生成和读取方法
    • 1. TFRecords说明
    • 2.关键API
      • 2.1 tf.io.TFRecordWriter类
      • 2.2 tf.train.Example
      • `tf.Example` 的数据类型
      • 2.3 tf.train.Features
      • 2.4 tf.train.Feature
      • 2.5 tf.train.Int64List, tf.train.BytesList, tf.train.FloatList
      • 2.6 tf.io.TFRecordReader类
      • 2.7 tf.data.TFRecordDataset类
        • 对Dataset中的元素做变换:Transformation
          • **(1)map**
          • **(2)batch**
          • **(3)shuffle**
          • **(4)repeat**
    • 3.写操作
      • 例子1:鸢尾花数据集-写操作
      • 例子2:
    • 4.读操作
      • 例子1 鸢尾花数据集-读操作
    • 5.一些问题

备注:本文参考https://zhuanlan.zhihu.com/p/31992460( 公众号《人工智能技术干货》,专注深度学习与计算机视觉!)

​ 本文参考https://www.tensorflow.org/api_docs/python/tf/io/TFRecordWriter (Tensorflow官方教程)

​ 本文参考https://zhuanlan.zhihu.com/p/43356309 【0.1】Tensorflow踩坑记之tf.data

​ 本文参考https://zhuanlan.zhihu.com/p/30751039 TensorFlow全新的数据读取方式:Dataset API入门教程

https://blog.csdn.net/qq_36556054/article/details/102872885?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-0.pc_relevant_default&spm=1001.2101.3001.4242.1&utm_relevant_index=3

1. TFRecords说明

TFRecords是一种tensorflow的内定标准文件格式,其实质是二进制文件,遵循protocol buffer(PB)协议,其后缀一般为tfrecord。TFRecords文件方便复制和移动,能够很好的利用内存,无需单独标记文件,适用于大量数据的顺序读取,是tensorflow“从文件里读取数据”的一种官方推荐方法!本篇文章,我将整理tensorflow TFRecords文件生成和读取的方法, 分为两个部分,每个部分分别介绍并附带例程!

TFrecord 是TensorFlow使用的一种数据格式,他可以把多个训练的图片许多信息压缩在一个文件中,用特殊的方式存储和读取,通过tf.dataset 这个API进行快速的读取和写入。具体使用官方教材,参考官方文档,里面有具体的使用方法,最近又出了高阶的使用方法,等流程跑通了再继续优化,Tensflow公众号-tf.data API,构建高性能 TensorFlow 输入管道。

将数据处理成TFRecord的形式,是tensorflow官方推荐的一种文件可以。使用这种文件格式是官方推荐,具体原因如下:

  • 该文件格式方便复制和移动,能够很好的利用内存;
  • 支持String,Float,Int类型的数据,方便存储结构化的标注数据;

2.关键API

2.1 tf.io.TFRecordWriter类

把记录写入到TFRecords文件的类.

__init__(path,options=None)

作用:创建一个TFRecordWriter对象,这个对象就负责写记录到指定的文件中去了.
参数:
path: TFRecords 文件路径
options: (可选) TFRecordOptions对象

close()

作用:关闭对象.

write(record)

作用:把字符串形式的记录写到文件中去.
参数:
record: 字符串,待写入的记录

2.2 tf.train.Example

tf.Example 的数据类型

这个类是非常重要的,在Tensorflow中样本数据的序列化保存一般采用tfrecord的文件格式,其根本原因在文章开头就已经描述。在tfrecord文件中,实质上是一堆tf.train.Example的集合。

tf.train.Example的数据集包含三个类:tf.train.Feature -> tf.train.Features -> tf.train.Example三个类。具体的关系如下:

example = tf.train.Example(features=tf.train.Features(feature={...}))

具体例子如下:

record_bytes = tf.train.Example(features=tf.train.Features(feature={"x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),"y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),}))

上述描述,大家对Example这个类有一个基本的了解。

属性:

features : 是一个tf.train.Features

函数:

__init__(**kwargs)

这个函数是初始化函数,会生成一个Example对象,一般我们使用的时候,是传入一个tf.train.Features对象进去.

SerializeToString()

作用:把example序列化为一个字符串,因为在写入到TFRcorde的时候,write方法的参数是字符串的.

2.3 tf.train.Features

作用:初始化Features对象,一般我们是传入一个字典,字典的键是一个字符串,表示名字,字典的值是一个tf.train.Feature对象.

tf.train.Features(feature ={"x":tf.train.Feature(bytes_list=tf.train.BytesList=(value=[])),"y":tf.train.Feature(int64_list=tf.train.Int64List(value = []))
})

2.4 tf.train.Feature

包含属性:

bytes_list 对应的对象是:tf.train.BytesList(value=[])
float_list 对应的对象是:tf.train.FloatList(value=[])
int64_list 对应的对象是:tf.train.Int64List(value=[])

2.5 tf.train.Int64List, tf.train.BytesList, tf.train.FloatList

这三个数据类型

  1. tf.train.BytesList(可强制转换自以下类型)
  • string
  • byte
  1. tf.train.FloatList(可强制转换自以下类型)
  • float (float32)
  • double (float64)
  1. tf.train.Int64List(可强制转换自以下类型)
  • bool
  • enum
  • int32
  • uint32
  • int64
  • uint6

2.6 tf.io.TFRecordReader类

To create an input pipeline, you must start with a data source. For example, to construct a Dataset from data in memory, you can use tf.data.Dataset.from_tensors() or tf.data.Dataset.from_tensor_slices(). Alternatively, if your input data is stored in a file in the recommended TFRecord format, you can use tf.data.TFRecordDataset().

将原始的特征数据处理成结构化的tfrecord数据集。

feature -> Features -> Example, 三者按顺序为包含关系

example = tf.train.Example(features=tf.train.Features(feature={...}))
tf_writer.write(example.SerializeToString())  # 序列化写入tfrecord

2.7 tf.data.TFRecordDataset类

Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服务于数据读取,构建输入数据的pipeline。

对Dataset中的元素做变换:Transformation

**Dataset支持一类特殊的操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset。**通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作。

常用的Transformation有:

  • map
  • batch
  • shuffle
  • repeat

下面就分别进行介绍。

(1)map

map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0
(2)batch

batch就是将多个元素组合成batch,如下面的程序将dataset中的每个元素组成了大小为32的batch:

dataset = dataset.batch(32)
(3)shuffle

shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:

dataset = dataset.shuffle(buffer_size=10000)
(4)repeat

repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch:

dataset = dataset.repeat(5)

如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出tf.errors.OutOfRangeError异常:

dataset = dataset.repeat()

3.写操作

下面以两个例子来说明一下写操作。

例子1:鸢尾花数据集-写操作

第一个例子是将鸢尾花的数据集iris.csv,将这部分数据写入到iris.tfrecords中。

import tensorflow as tfflags = tf.flags
FLAGS = flags.FLAGSflags.DEFINE_string("input","../data/iris.csv","输入")
flags.DEFINE_string("output","../data/iris.bin","输入")def main(_):inputfile = FLAGS.inputoutputfile = FLAGS.outputwrite = tf.io.TFRecordWriter(outputfile)idx = -1with tf.io.gfile.GFile(inputfile,'r') as reader:for line in reader:idx += 1   #为了跳过第一行的descriptionif idx < 1 :continuesplits = line.split(",")x = [float(i) for i in splits[:-1]]y = [int(splits[-1])]example = tf.train.Example(features=tf.train.Features(feature={'x': tf.train.Feature(float_list=tf.train.FloatList(value=x)  # 方括号表示输入为list,一般tf.train.FloatList被用来处理浮点数),'y': tf.train.Feature(int64_list=tf.train.Int64List(value=y)  # B_data本身就是列表,一般tf.train.Int64List被用来处理整数)}))print(example.SerializeToString())write.write(example.SerializeToString())write.close()if __name__ == "__main__":tf.app.run()

例子2:


4.读操作

在TensorFlow 1.3中,Dataset API是放在contrib包中的:

tf.contrib.data.Dataset

而在TensorFlow 1.4中,Dataset API已经从contrib包中移除,变成了核心API的一员:

tf.data.Dataset

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RmKWUi95-1643547854392)(…/…/…/Images/TF-DataSet.png)]

例子1 鸢尾花数据集-读操作

读取数据中,

import tensorflow as tfflags = tf.flagsFLAGS = flags.FLAGSflags.DEFINE_string("input" , "../data/iris.bin" , "input file")def decord_fn(dataset):features = tf.io.parse_single_example(dataset ,{'x':tf.io.FixedLenFeature([4] , tf.float32 ) ,'y':tf.io.FixedLenFeature([1] , tf.int64 )})return featuresdef main(_):dataset = tf.data.TFRecordDataset(FLAGS.input).map(decord_fn)dataset = dataset.batch(10)NUM_EPOCHS = 2dataset = dataset.repeat(NUM_EPOCHS)iterator = dataset.make_one_shot_iterator()next_element = iterator.get_next()with tf.Session() as sess:sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))#sess.run(iterator.initializer)while True:try:x_train = sess.run(next_element)  # 取出设定的print(x_train)# print("x = {x:.4f},  y = {y:.4f}".format(x_train['x'],x_train['y']))except tf.errors.OutOfRangeError:break
if __name__ == "__main__":#flags.mark_flag_as_required("input")#flags.mark_flag_as_required("output")tf.app.run()
import tensorflow as tfflags = tf.flagsFLAGS = flags.FLAGSflags.DEFINE_string("input" , "../data/iris.bin" , "input file")def decord_fn(dataset):features_dict  = {"x":tf.io.FixedLenFeature([4] , tf.float32 ),"y":tf.io.FixedLenFeature([1] , tf.int64)}results = tf.io.parse_single_example(dataset , features=features_dict)return  results
def main(_):dataset = tf.data.TFRecordDataset(FLAGS.input)dataset = dataset.map(decord_fn)dataset = dataset.batch(10)NUM_EPOCHS = 2dataset = dataset.repeat(NUM_EPOCHS)iterator = dataset.make_initializable_iterator()next_element = iterator.get_next()global_init = tf.global_variables_initializer()local_init = tf.local_variables_initializer()with tf.Session() as sess:sess.run([global_init , local_init])sess.run(iterator.initializer)while True:try:out = sess.run(next_element)print(out)except tf.errors.OutOfRangeError:break
if __name__ == "__main__":#flags.mark_flag_as_required("input")#flags.mark_flag_as_required("output")tf.app.run()

5.一些问题

在使用tfrecords的时候主要遇到一些问题,主要包括:

一是处理大量数据的的时候文件太大(一万张图片有700M),并且生成的太慢,

第二就是训练的时候将图片value和label打印出来,发现不对应。 但写入的时候label和image并没有错乱,也排查了读取的代码以及数据类型的问题,但是的确会出现image和label不对应的问题,我想问一下您在制作且用于训练的时候显示过吗,出现过这种情况吗?能够确定自己的image和label是对应的吗

Tensorflow读取数据-tf.data.TFRecordDataset相关推荐

  1. TensorFlow 2.0 - tf.data.Dataset 数据预处理 猫狗分类

    文章目录 1 tf.data.Dataset.from_tensor_slices() 数据集建立 2. Dataset.map(f) 数据集预处理 3. Dataset.prefetch() 并行处 ...

  2. java套接字客户端_使用Java从客户端套接字读取数据(Read data from a client socket in Java)...

    使用Java从客户端套接字读取数据(Read data from a client socket in Java) 我编写了从客户端套接字发送/接收数据的代码. 发送数据步骤已成功完成,但是当我想从套 ...

  3. tensorflow2读取数据P4: tf.data.TFRecordDataset创建Dataset

    为啥要用tfrecord 使用tfrecord是为了更高效地读取数据,这种方式比较适合数据量大且数据集相对稳定的情况.tfrecord将数据存储成了二进制记录序列的格式,这格式比较对计算机的胃口,读取 ...

  4. tensorflow随笔-读写数据tf.data

    模块:tf.data 定义在tensorflow/_api/v1/data/init.py 输入管道的tf.data.Dataset API 模块 实验模块:构建输入管道的实验API. 类 class ...

  5. tensorflow2数据读取P3: tf.data.Dataset.from_generator通过preprocessing.image.ImageDataGenerator构造Dataset

    tf.data.Dataset.from_generator通过preprocessing.image.ImageDataGenerator构造Dataset 虽然自己定义生成器就可以构建datase ...

  6. 云端TensorFlow读取数据IO的高效方式

    低效的IO方式 最近通过观察PAI平台上TensoFlow用户的运行情况,发现大家在数据IO这方面还是有比较大的困惑,主要是因为很多同学没有很好的理解本地执行TensorFlow代码和分布式云端执行T ...

  7. tensorflow基础:tf.data.Dataset.from_tensor_slices()

    tf.data.Dataset.from_tensor_slices() 语义解释:from_tensor_slices,从张量的切片读取数据. 工作原理:将输入的张量的第一个维度看做样本的个数,沿其 ...

  8. 记录 之 tensorflow函数:tf.data.Dataset.from_tensor_slices

    tf.data.Dataset.from_tensor_slices(),是常见的数据处理函数,它的作用是将给定的元组(turple).列表(list).张量(tensor)等特征进行特征切片.切片的 ...

  9. tensorflow基础:tf.data.Dataset.from_tensor_slices() 与 tf.data.Dataset.from_generator()的异同

    tf.data.Dataset.from_tensor_slices(tensor): -->将tensor沿其第一个维度切片,返回一个含有N个样本的数据集(假设tensor的第一个维度为N). ...

最新文章

  1. Nancy in .Net Core学习笔记 - 初识Nancy
  2. it程序员刷题 面试 中文网站收集
  3. linux 协议栈之socket,Linux协议栈之BSD和INET socket层(一)
  4. [JSOI2007]文本生成器
  5. intellij-IDE运行Java程序报错:java: -source 1.5 中不支持 lambda 表达式 有用
  6. PLSQL 安装+配置( Oracle数据库连接工具 )
  7. 大数据及hadooop简介
  8. 前端 鼠标一次移动半个像素_Web前端(二):CSS3
  9. 亚马逊警用刷脸计划小小受阻,但原因并不是贝佐斯妥协
  10. NRF52832学习笔记
  11. 《考勤信息管理系统》数据库 课程设计
  12. 计算机求和公式IFEROR,IFERROR函数详解_Excel公式教程
  13. 计算机组装实验老毛桃u盘报告,老毛桃winpe硬盘安装版制作教程
  14. mysql 重建索引,mysql优化之索引重建
  15. emqx配置ssl/tsl实现双向认证
  16. Facebook的新算法可以预测出你的贫富阶级
  17. Caesar加密与解密
  18. mysql查询删除的数据历史记录_查询数据库各种历史记录
  19. 获取上个月的第一天和最后一天和当前月最后一天
  20. Appium JAVA ios 设备 AUT not install

热门文章

  1. mybatis逆向工程和批量插入
  2. python 生成动态库_Python 项目转.so动态库
  3. 破解boson netsim for ccnp 7.06(测试内链接有效)
  4. 如何通过nodejs快速搭建一个服务器
  5. 单纯形法和对偶单纯形法
  6. 人在新加坡,刚下飞机,原地失业!上交大佬刚到新加坡,就被虾皮取消了offer,作者发声了......
  7. 可惜,离职在家“苦修”一年半最终还是与字节offer擦肩而过
  8. Solidworks_ Flexnet_Server怎么删除?
  9. 密码学在信息安全领域的应用
  10. d3.js——图形缩放平移操作