Tensorflow读取数据-tf.data.TFRecordDataset
tensorflow TFRecords文件的生成和读取方法
文章目录
- tensorflow TFRecords文件的生成和读取方法
- 1. TFRecords说明
- 2.关键API
- 2.1 tf.io.TFRecordWriter类
- 2.2 tf.train.Example
- `tf.Example` 的数据类型
- 2.3 tf.train.Features
- 2.4 tf.train.Feature
- 2.5 tf.train.Int64List, tf.train.BytesList, tf.train.FloatList
- 2.6 tf.io.TFRecordReader类
- 2.7 tf.data.TFRecordDataset类
- 对Dataset中的元素做变换:Transformation
- **(1)map**
- **(2)batch**
- **(3)shuffle**
- **(4)repeat**
- 3.写操作
- 例子1:鸢尾花数据集-写操作
- 例子2:
- 4.读操作
- 例子1 鸢尾花数据集-读操作
- 5.一些问题
备注:本文参考https://zhuanlan.zhihu.com/p/31992460( 公众号《人工智能技术干货》,专注深度学习与计算机视觉!)
本文参考https://www.tensorflow.org/api_docs/python/tf/io/TFRecordWriter (Tensorflow官方教程)
本文参考https://zhuanlan.zhihu.com/p/43356309 【0.1】Tensorflow踩坑记之tf.data
本文参考https://zhuanlan.zhihu.com/p/30751039 TensorFlow全新的数据读取方式:Dataset API入门教程
https://blog.csdn.net/qq_36556054/article/details/102872885?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_baidulandingword~default-0.pc_relevant_default&spm=1001.2101.3001.4242.1&utm_relevant_index=3
1. TFRecords说明
TFRecords是一种tensorflow的内定标准文件格式,其实质是二进制文件,遵循protocol buffer(PB)协议,其后缀一般为tfrecord。TFRecords文件方便复制和移动,能够很好的利用内存,无需单独标记文件,适用于大量数据的顺序读取,是tensorflow“从文件里读取数据”的一种官方推荐方法!本篇文章,我将整理tensorflow TFRecords文件生成和读取的方法, 分为两个部分,每个部分分别介绍并附带例程!
TFrecord 是TensorFlow使用的一种数据格式,他可以把多个训练的图片许多信息压缩在一个文件中,用特殊的方式存储和读取,通过tf.dataset 这个API进行快速的读取和写入。具体使用官方教材,参考官方文档,里面有具体的使用方法,最近又出了高阶的使用方法,等流程跑通了再继续优化,Tensflow公众号-tf.data API,构建高性能 TensorFlow 输入管道。
将数据处理成TFRecord的形式,是tensorflow官方推荐的一种文件可以。使用这种文件格式是官方推荐,具体原因如下:
- 该文件格式方便复制和移动,能够很好的利用内存;
- 支持String,Float,Int类型的数据,方便存储结构化的标注数据;
2.关键API
2.1 tf.io.TFRecordWriter类
把记录写入到TFRecords文件的类.
__init__(path,options=None)
作用:创建一个TFRecordWriter对象,这个对象就负责写记录到指定的文件中去了.
参数:
path: TFRecords 文件路径
options: (可选) TFRecordOptions对象
close()
作用:关闭对象.
write(record)
作用:把字符串形式的记录写到文件中去.
参数:
record: 字符串,待写入的记录
2.2 tf.train.Example
tf.Example
的数据类型
这个类是非常重要的,在Tensorflow中样本数据的序列化保存一般采用tfrecord的文件格式,其根本原因在文章开头就已经描述。在tfrecord文件中,实质上是一堆tf.train.Example的集合。
tf.train.Example的数据集包含三个类:tf.train.Feature -> tf.train.Features -> tf.train.Example三个类。具体的关系如下:
example = tf.train.Example(features=tf.train.Features(feature={...}))
具体例子如下:
record_bytes = tf.train.Example(features=tf.train.Features(feature={"x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),"y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),}))
上述描述,大家对Example这个类有一个基本的了解。
属性:
features : 是一个tf.train.Features
函数:
__init__(**kwargs)
这个函数是初始化函数,会生成一个Example对象,一般我们使用的时候,是传入一个tf.train.Features对象进去.
SerializeToString()
作用:把example序列化为一个字符串,因为在写入到TFRcorde的时候,write方法的参数是字符串的.
2.3 tf.train.Features
作用:初始化Features对象,一般我们是传入一个字典,字典的键是一个字符串,表示名字,字典的值是一个tf.train.Feature对象.
tf.train.Features(feature ={"x":tf.train.Feature(bytes_list=tf.train.BytesList=(value=[])),"y":tf.train.Feature(int64_list=tf.train.Int64List(value = []))
})
2.4 tf.train.Feature
包含属性:
bytes_list 对应的对象是:tf.train.BytesList(value=[])
float_list 对应的对象是:tf.train.FloatList(value=[])
int64_list 对应的对象是:tf.train.Int64List(value=[])
2.5 tf.train.Int64List, tf.train.BytesList, tf.train.FloatList
这三个数据类型
tf.train.BytesList
(可强制转换自以下类型)
string
byte
tf.train.FloatList
(可强制转换自以下类型)
float
(float32
)double
(float64
)
tf.train.Int64List
(可强制转换自以下类型)
bool
enum
int32
uint32
int64
uint6
2.6 tf.io.TFRecordReader类
To create an input pipeline, you must start with a data source. For example, to construct a Dataset
from data in memory, you can use tf.data.Dataset.from_tensors()
or tf.data.Dataset.from_tensor_slices()
. Alternatively, if your input data is stored in a file in the recommended TFRecord format, you can use tf.data.TFRecordDataset()
.
将原始的特征数据处理成结构化的tfrecord数据集。
feature -> Features -> Example, 三者按顺序为包含关系
example = tf.train.Example(features=tf.train.Features(feature={...}))
tf_writer.write(example.SerializeToString()) # 序列化写入tfrecord
2.7 tf.data.TFRecordDataset类
Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服务于数据读取,构建输入数据的pipeline。
对Dataset中的元素做变换:Transformation
**Dataset支持一类特殊的操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset。**通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作。
常用的Transformation有:
- map
- batch
- shuffle
- repeat
下面就分别进行介绍。
(1)map
map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset,如我们可以对dataset中每个元素的值加1:
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0
(2)batch
batch就是将多个元素组合成batch,如下面的程序将dataset中的每个元素组成了大小为32的batch:
dataset = dataset.batch(32)
(3)shuffle
shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小:
dataset = dataset.shuffle(buffer_size=10000)
(4)repeat
repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(5)就可以将之变成5个epoch:
dataset = dataset.repeat(5)
如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出tf.errors.OutOfRangeError异常:
dataset = dataset.repeat()
3.写操作
下面以两个例子来说明一下写操作。
例子1:鸢尾花数据集-写操作
第一个例子是将鸢尾花的数据集iris.csv,将这部分数据写入到iris.tfrecords中。
import tensorflow as tfflags = tf.flags
FLAGS = flags.FLAGSflags.DEFINE_string("input","../data/iris.csv","输入")
flags.DEFINE_string("output","../data/iris.bin","输入")def main(_):inputfile = FLAGS.inputoutputfile = FLAGS.outputwrite = tf.io.TFRecordWriter(outputfile)idx = -1with tf.io.gfile.GFile(inputfile,'r') as reader:for line in reader:idx += 1 #为了跳过第一行的descriptionif idx < 1 :continuesplits = line.split(",")x = [float(i) for i in splits[:-1]]y = [int(splits[-1])]example = tf.train.Example(features=tf.train.Features(feature={'x': tf.train.Feature(float_list=tf.train.FloatList(value=x) # 方括号表示输入为list,一般tf.train.FloatList被用来处理浮点数),'y': tf.train.Feature(int64_list=tf.train.Int64List(value=y) # B_data本身就是列表,一般tf.train.Int64List被用来处理整数)}))print(example.SerializeToString())write.write(example.SerializeToString())write.close()if __name__ == "__main__":tf.app.run()
例子2:
4.读操作
在TensorFlow 1.3中,Dataset API是放在contrib包中的:
tf.contrib.data.Dataset
而在TensorFlow 1.4中,Dataset API已经从contrib包中移除,变成了核心API的一员:
tf.data.Dataset
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-RmKWUi95-1643547854392)(…/…/…/Images/TF-DataSet.png)]
例子1 鸢尾花数据集-读操作
读取数据中,
import tensorflow as tfflags = tf.flagsFLAGS = flags.FLAGSflags.DEFINE_string("input" , "../data/iris.bin" , "input file")def decord_fn(dataset):features = tf.io.parse_single_example(dataset ,{'x':tf.io.FixedLenFeature([4] , tf.float32 ) ,'y':tf.io.FixedLenFeature([1] , tf.int64 )})return featuresdef main(_):dataset = tf.data.TFRecordDataset(FLAGS.input).map(decord_fn)dataset = dataset.batch(10)NUM_EPOCHS = 2dataset = dataset.repeat(NUM_EPOCHS)iterator = dataset.make_one_shot_iterator()next_element = iterator.get_next()with tf.Session() as sess:sess.run((tf.global_variables_initializer(), tf.local_variables_initializer()))#sess.run(iterator.initializer)while True:try:x_train = sess.run(next_element) # 取出设定的print(x_train)# print("x = {x:.4f}, y = {y:.4f}".format(x_train['x'],x_train['y']))except tf.errors.OutOfRangeError:break
if __name__ == "__main__":#flags.mark_flag_as_required("input")#flags.mark_flag_as_required("output")tf.app.run()
import tensorflow as tfflags = tf.flagsFLAGS = flags.FLAGSflags.DEFINE_string("input" , "../data/iris.bin" , "input file")def decord_fn(dataset):features_dict = {"x":tf.io.FixedLenFeature([4] , tf.float32 ),"y":tf.io.FixedLenFeature([1] , tf.int64)}results = tf.io.parse_single_example(dataset , features=features_dict)return results
def main(_):dataset = tf.data.TFRecordDataset(FLAGS.input)dataset = dataset.map(decord_fn)dataset = dataset.batch(10)NUM_EPOCHS = 2dataset = dataset.repeat(NUM_EPOCHS)iterator = dataset.make_initializable_iterator()next_element = iterator.get_next()global_init = tf.global_variables_initializer()local_init = tf.local_variables_initializer()with tf.Session() as sess:sess.run([global_init , local_init])sess.run(iterator.initializer)while True:try:out = sess.run(next_element)print(out)except tf.errors.OutOfRangeError:break
if __name__ == "__main__":#flags.mark_flag_as_required("input")#flags.mark_flag_as_required("output")tf.app.run()
5.一些问题
在使用tfrecords的时候主要遇到一些问题,主要包括:
一是处理大量数据的的时候文件太大(一万张图片有700M),并且生成的太慢,
第二就是训练的时候将图片value和label打印出来,发现不对应。 但写入的时候label和image并没有错乱,也排查了读取的代码以及数据类型的问题,但是的确会出现image和label不对应的问题,我想问一下您在制作且用于训练的时候显示过吗,出现过这种情况吗?能够确定自己的image和label是对应的吗
Tensorflow读取数据-tf.data.TFRecordDataset相关推荐
- TensorFlow 2.0 - tf.data.Dataset 数据预处理 猫狗分类
文章目录 1 tf.data.Dataset.from_tensor_slices() 数据集建立 2. Dataset.map(f) 数据集预处理 3. Dataset.prefetch() 并行处 ...
- java套接字客户端_使用Java从客户端套接字读取数据(Read data from a client socket in Java)...
使用Java从客户端套接字读取数据(Read data from a client socket in Java) 我编写了从客户端套接字发送/接收数据的代码. 发送数据步骤已成功完成,但是当我想从套 ...
- tensorflow2读取数据P4: tf.data.TFRecordDataset创建Dataset
为啥要用tfrecord 使用tfrecord是为了更高效地读取数据,这种方式比较适合数据量大且数据集相对稳定的情况.tfrecord将数据存储成了二进制记录序列的格式,这格式比较对计算机的胃口,读取 ...
- tensorflow随笔-读写数据tf.data
模块:tf.data 定义在tensorflow/_api/v1/data/init.py 输入管道的tf.data.Dataset API 模块 实验模块:构建输入管道的实验API. 类 class ...
- tensorflow2数据读取P3: tf.data.Dataset.from_generator通过preprocessing.image.ImageDataGenerator构造Dataset
tf.data.Dataset.from_generator通过preprocessing.image.ImageDataGenerator构造Dataset 虽然自己定义生成器就可以构建datase ...
- 云端TensorFlow读取数据IO的高效方式
低效的IO方式 最近通过观察PAI平台上TensoFlow用户的运行情况,发现大家在数据IO这方面还是有比较大的困惑,主要是因为很多同学没有很好的理解本地执行TensorFlow代码和分布式云端执行T ...
- tensorflow基础:tf.data.Dataset.from_tensor_slices()
tf.data.Dataset.from_tensor_slices() 语义解释:from_tensor_slices,从张量的切片读取数据. 工作原理:将输入的张量的第一个维度看做样本的个数,沿其 ...
- 记录 之 tensorflow函数:tf.data.Dataset.from_tensor_slices
tf.data.Dataset.from_tensor_slices(),是常见的数据处理函数,它的作用是将给定的元组(turple).列表(list).张量(tensor)等特征进行特征切片.切片的 ...
- tensorflow基础:tf.data.Dataset.from_tensor_slices() 与 tf.data.Dataset.from_generator()的异同
tf.data.Dataset.from_tensor_slices(tensor): -->将tensor沿其第一个维度切片,返回一个含有N个样本的数据集(假设tensor的第一个维度为N). ...
最新文章
- Nancy in .Net Core学习笔记 - 初识Nancy
- it程序员刷题 面试 中文网站收集
- linux 协议栈之socket,Linux协议栈之BSD和INET socket层(一)
- [JSOI2007]文本生成器
- intellij-IDE运行Java程序报错:java: -source 1.5 中不支持 lambda 表达式 有用
- PLSQL 安装+配置( Oracle数据库连接工具 )
- 大数据及hadooop简介
- 前端 鼠标一次移动半个像素_Web前端(二):CSS3
- 亚马逊警用刷脸计划小小受阻,但原因并不是贝佐斯妥协
- NRF52832学习笔记
- 《考勤信息管理系统》数据库 课程设计
- 计算机求和公式IFEROR,IFERROR函数详解_Excel公式教程
- 计算机组装实验老毛桃u盘报告,老毛桃winpe硬盘安装版制作教程
- mysql 重建索引,mysql优化之索引重建
- emqx配置ssl/tsl实现双向认证
- Facebook的新算法可以预测出你的贫富阶级
- Caesar加密与解密
- mysql查询删除的数据历史记录_查询数据库各种历史记录
- 获取上个月的第一天和最后一天和当前月最后一天
- Appium JAVA ios 设备 AUT not install
热门文章
- mybatis逆向工程和批量插入
- python 生成动态库_Python 项目转.so动态库
- 破解boson netsim for ccnp 7.06(测试内链接有效)
- 如何通过nodejs快速搭建一个服务器
- 单纯形法和对偶单纯形法
- 人在新加坡,刚下飞机,原地失业!上交大佬刚到新加坡,就被虾皮取消了offer,作者发声了......
- 可惜,离职在家“苦修”一年半最终还是与字节offer擦肩而过
- Solidworks_ Flexnet_Server怎么删除?
- 密码学在信息安全领域的应用
- d3.js——图形缩放平移操作