tensorflow 训练数据载入

  • 1. tf.data.Dataset
  • 2. dataset 创建数据集的方式
    • 2.1 tf.data.Dataset.from_tensor_slices()
    • 2.2 tf.data.TextLineDataset()
    • 2.3 tf.data.FixedLengthRecordDataset()
    • 2.4 tf.data.TFRecordDataset()
  • 3. dateset 迭代操作iterator
    • 3.1 make_one_shot_iterator()
    • 3.2 make_initializable_iterator()
    • 3.3 reinitializable iterator()
    • 3.4 feedable iterator()
  • 4. dataset的map、batch、shuffle、repeat操作
  • 5. 非eager/eager 模式
    • 5.1 非eager模式demo
    • 5.2 eager模式demo

1. tf.data.Dataset

参考Google官方给出的Dataset API中的类图,Dataset 务于数据读取,构建输入数据的pipeline。
Dataset可以看作是相同类型“元素”的有序列表,可使用Iterator迭代获取Dataset中的元素。

2. dataset 创建数据集的方式

2.1 tf.data.Dataset.from_tensor_slices()

从tensor中创建数据集,数据集元素以tensor第一维度为划分。

import tensorflow as tf
import numpy as np
# 切分传入Tensor的第一个维度,生成相应的dataset。
dataset1 = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
# 如果传入字典,那切分结果就是字典按值切分,元素型如{"a":[1],"b":[x,x]}
dataset2 = tf.data.Dataset.from_tensor_slices({"a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),                                       "b": np.random.uniform(size=(5, 2))}
)

2.2 tf.data.TextLineDataset()

读取文件数据创建数据集,数据集元素为文件的每一行

2.3 tf.data.FixedLengthRecordDataset()

从一个文件列表和record_bytes中创建数据集,数据集元素是文件中固定字节数record_bytes的内容。

2.4 tf.data.TFRecordDataset()

读TFRecord文件创建数据集,数据集中的一条数据是一个TFExample。

dataset = tf.data.TFRecordDataset(filenames = [tfrecord_file_name]) # [tfrecord_file_name] tfrecord 文件列表

frecord 文件中的特征一般都经过tf.train.Example 序列化,在使用前需要先解码tf.train.Example.FromString()

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

3. dateset 迭代操作iterator

iterator是从Dataset对象中创建出来的,用于迭代取数据集中的元素。

3.1 make_one_shot_iterator()

dataset.make_one_shot_iterator()–只能从头到尾读取一次dataset。如果一个dataset中元素被读取完了再sess.run()的话,会抛出tf.errors.OutOfRangeError异常。因此可以在外界捕捉这个异常以判断数据是否读取完。

import tensorflow as tf
import numpy as np
# 切分传入Tensor的第一个维度,生成相应的dataset。如果传入字典,那切分结果就是字典按值切分
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
iterator = dataset.make_one_shot_iterator()    # 只能从头到尾读取一次
one_element = iterator.get_next()              # 从iterator里取出一个元素。
# 处于非Eager模式,所以one_element只是一个Tensor,并不是一个实际的值。调用sess.run(one_element)后,才能真正地取出一个值。
with tf.Session() as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")

3.2 make_initializable_iterator()

dataset.make_initializable_iterator()–支持placeholder dataset 的迭代操作,这可以方便通过参数快速定义新的Iterator。

 # limit相当于一个参数,它规定了Dataset中数的上限, 使用make_initializable_iterator
limit = tf.placeholder(dtype=tf.int32, shape=[])
dataset = tf.data.Dataset.from_tensor_slices(tf.range(start=0, limit=limit))
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:sess.run(iterator.initializer, feed_dict={limit: 10})for i in range(10):value = sess.run(next_element)assert i == value

sess.run(next_element) 每run一次, 数据迭代器指针就会往下移动一个。TF官网学习(9)–使用iterator注意事项

如果在dataset的构建时,一次性读入了所有的数据,会导致计算图变得很大,给传输、保存带来不便。make_initializable_iterator()支持placeholder 操作,仅在需要传输数据时再取数据。

# 从硬盘中读入两个Numpy数组
with np.load("/var/data/training_data.npy") as data:features = data["features"]labels = data["labels"]features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels})

3.3 reinitializable iterator()

dataset.reinitializable iterator() --待补

3.4 feedable iterator()

dataset.feedable iterator()–待补

4. dataset的map、batch、shuffle、repeat操作

map–接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset。

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0

batch–将多个元素组合成一个batch

dataset = dataset.batch(16)    # 将数据集划分为batch size为16的小批次

shuffle– 打乱dataset中的元素,参数buffersize。打乱的实现机理:从buffer_size 大小的部buffer中随机抽取元素,组成打乱后的数据集。buffer中被抽走的元素由原数据集中的后续元素补位置。 重复‘抽取-补充’这个过程,直至buffer为空。
会在batch之间打乱数据–疑问多tfrecord 文件是一次性构建数据集还是一条一条的构建

buffer_size 的大小详见tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解

dataset = dataset.shuffle(buffer_size=10000)

repeat– 将整个序列重复多次,用来处理机器学习中的epoch,假设原始数据是一个epoch,使用repeat(5)就可以将之变成5个epoch

dataset = dataset.repeat(5)

5. 非eager/eager 模式

5.1 非eager模式demo

在非Eager模式下,Dataset中读出的一个元素一般对应一个batch的Tensor,我们可以使用这个Tensor在计算图中构建模型。

import tensorflow as tf
import numpy as np
# 切分传入Tensor的第一个维度,生成相应的dataset。如果传入字典,那切分结果就是字典按值切分
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
iterator = dataset.make_one_shot_iterator()    # 只能从头到尾读取一次
one_element = iterator.get_next()              # 从iterator里取出一个元素。
# 处于非Eager模式,所以one_element只是一个Tensor,并不是一个实际的值。调用sess.run(one_element)后,才能真正地取出一个值。
with tf.Session() as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")

5.2 eager模式demo

在Eager模式下,Dataset建立Iterator的方式有所不同,此时通过读出的数据就是含有值的Tensor,方便调试。

import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
for one_element in tfe.Iterator(dataset):print(one_element)             # 可直接读取数据

参考文献:TensorFlow全新的数据读取方式:Dataset API入门教程


TensorFlow(2)-训练数据载入相关推荐

  1. 数据载入过慢?这里有一份TensorFlow加速指南

    王小新 编译自 Towards Data Science 量子位 出品 | 公众号 QbitAI 机器学习算法烂熟于心,网络结构顺手拈来,但是如果数据集载入时耗费大量时间,那整个训练时间就会大大增加. ...

  2. tensorflow 读取cifar_浅入浅出TensorFlow 4 - 训练CIFAR数据

    #coding=utf-8 import cifar10,cifar10_input import tensorflow as tf import numpy as np import time # ...

  3. 关于使用tensorflow object detection API训练自己的模型-补充部分(代码,数据标注工具,训练数据,测试数据)

    之前分享过关于tensorflow object detection API训练自己的模型的几篇博客,后面有人陆续碰到一些问题,问到了我解决方法.所以在这里补充点大家可能用到的东西.声明一下,本人专业 ...

  4. 【两个例子】Tensorflow+Inception-v3训练自己的数据+分析retrain源码

    [学习笔记]Tensorflow+Inception-v3训练自己的数据 https://www.cnblogs.com/EstherLjy/p/9861034.html TensorFlow学习笔记 ...

  5. tensorflow sigmoid 如何计算训练数据的正确率_量化训练:Quantization Aware Training in Tensorflow(一)...

    本文的内容包括对神经网络模型量化的基本介绍.对Tensorflow量化训练的理解与上手实操. 此外,后续系列还对量化训练中的by pass和batch norm两种情况进行补充解释,欢迎点击浏览,量化 ...

  6. 谷歌BERT预训练源码解析(一):训练数据生成

    目录 预训练源码结构简介 输入输出 源码解析 参数 主函数 创建训练实例 下一句预测&实例生成 随机遮蔽 输出 结果一览 预训练源码结构简介 关于BERT,简单来说,它是一个基于Transfo ...

  7. 仅50张图片训练数据的AI分类技术PK​,阿里拿下ECCV 2020竞赛冠军

    出品 | AI科技大本营(ID:rgznai100) 近日,两年一度的世界计算机视觉领域顶会ECCV 2020的各项挑战赛结果出炉,在图像分类赛中,阿里安全的高效AI分类技术超越三星.深兰科技.同济大 ...

  8. 动动手,用TensorFlow API训练出自己的目标检测模型

    TensorFlow内包含了一个强大的物体检测API,我们可以利用这API来训练自己的数据集实现特殊的目标检测. Dat Tran就分享了自己实现可爱的浣熊检测器的经历,在文章中作者把检测器的训练流程 ...

  9. 使用Tensorflow操作MNIST数据

    MNIST是一个非常有名的手写体数字识别数据集,在很多资料中,这个数据集都会被用作深度学习的入门样例.而TensorFlow的封装让使用MNIST数据集变得更加方便.MNIST数据集是NIST数据集的 ...

最新文章

  1. yolov5训练自己的数据集
  2. vsftpd的配置文件路径,是在哪里指定的?
  3. poj1988(判断一个结点下面有多少个结点,推荐)
  4. python下selenium测试报告整合
  5. uni-app使用input框 v-model双向绑定不起作用解决方案
  6. 有无关通配符的相等操作符
  7. C++的安全类型转换的讨论
  8. CVPR 2022 | 天大本科生论文入选!深度学习长尾分类新SOTA
  9. SpringBoot笔记整理(四)
  10. 云原生安全助力在线教育三分钟搞定安全防护
  11. 智能关机软件 c语言,智能关机软件
  12. 计算机病毒的危害主要体现于对计算机系统的信息破坏和,2014年中央电大专科信息技术应用理论题.doc...
  13. python中or关键字在变量赋值时的用法
  14. 人口吸引力超宁波、南京,这座背靠上海的小城开挂了?
  15. SCI各领域国际顶尖学术期刊一览
  16. 21天早睡早起习惯计划
  17. OSPF多实例路由防环检测功能介绍
  18. Windows系列系统 修改键盘默认对应键值(修改ctrl与fn位置, 解决键盘重要键损坏问题)
  19. 图书管理系统 C语言链表实现 学校大作业功能齐全(书籍信息以及用户信息保存在附带的txt文件中)
  20. can only accept this command while in the powered on state(iOS蓝牙 打开app后的第一次扫描要扫描两次)

热门文章

  1. 编译Mysql 5.5时报do_abi_check错误
  2. 基于C8051F040单片机的CAN总线测试模式研究
  3. Platform Builder实践之配置文件
  4. 服务器部署 配置jetty运行参数_Zookeeper+websocket实现对分布式服务器的实时监控...
  5. linux 网络相关,Linux系统管理员必备的21个网络相关监控
  6. cobol to java_cobol to java
  7. 【转】RocketMQ的一些特性(生产者消费者配置参数的含义)
  8. SharePoint学习札记[2] — MOSS2007体系结构概述
  9. 关于C#程序的单元测试
  10. linux磁盘永久挂载教程,linux 永久磁盘挂载