TensorFlow数据读取方式:Dataset API

Datasets:一种为TensorFlow 模型创建输入管道的新方式。把数组、元组、张量等转换成DatasetV1Adapter格式

Dataset api有方法加载和操作数据,并将其输入到您的模型中。Dataset api与Estimators api很匹配。
下图是tensorflow API的完整架构图:

Datasets API是由以下图中所示的类组成:

其中:
**Dataset:**基类,包含创建和转换数据集的方法。还允许您从内存中的数据或Python中初始化数据集
生产者。
TextLineDataset: 从文本文件中读取行 (txt,csv…)。
TFRecordDataset: 从TFRecord文件 读取记录。
FixedLengthRecordDataset: 从二进制文件中读取固定大小的记录 。
Iterator: 提供一次访问一个数据集元素的方法
总之,Datasets API实现了从内存或者硬盘文件中加载数据组成数据集,同时对数据集进行一系列变换操作,最终将数据集提供给其他API使用的一系列功能。下面,本文就将从这三个方面对Datasets API进行介绍。

加载数据形成数据集

dataset = tf.data.Dataset.from_tensor_slices(data #数据)#按第一维度进行切分,返回dataset形式数据#shapes:切分后的维度#tf.float64:最小元素的类型

从内存或迭代器中加载数据:

数据集的单个元素包含一个或多个tf.Tensor对象,叫做分量。可能是单个张量,
张量的元组,或张量的嵌套元组。除了元组,您还可以使用collections.namedtuple或dictionary
将字符串映射到张量以表示数据集的单个元素。

# 一个张量 :
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random_uniform([4, 10]))
# 张量的元组:
dataset2 =    tf.data.Dataset.from_tensor_slices((tf.random_uniform([4]),tf.random_uniform([4,100])))# 张量的元组, mnist_data是一个objection
images = mnist_data.train.images.reshape([-1, 28, 28, 1])
labels = mnist_data.train.labels
dataset = tf.contrib.data.Dataset.from_tensor_slices((images, labels))# 张量的嵌套元组:
dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
# 一个集合.可命名元组或者一个字典映射字符串成张量
dataset = tf.data.Dataset.from_tensor_slices({"a": tf.random_uniform([4]),"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})

读取数组

import tensorflow as tf
import numpy as np
# 创建dataset
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))#<DatasetV1Adapter shapes: (), types: tf.float64>
# 实例化了一个Iterator
iterator = dataset.make_one_shot_iterator()#<tensorflow.python.data.ops.iterator_ops.Iterator object at 0x000002016B501CC0>
# 从iterator里取出一个元素
one_element = iterator.get_next()# <tf.Tensor 'IteratorGetNext:0' shape=() dtype=float64>
with tf.Session() as sess:for i in range(5):print(sess.run(one_element))# 1.0# 2.0# 3.0# 4.0# 5.0

读取矩阵

import tensorflow.contrib.eager as tfe
import tensorflow as tf
import numpy as np
tfe.enable_eager_execution()
dataset = tf.data.Dataset.from_tensor_slices(np.array([[[1, 2, 3,],[4, 5, 6]],[[7, 8, 9,],[10,11,12]]]))
print("dataset:",dataset)#<DatasetV1Adapter shapes: (2, 3), types: tf.int32>
for one_element in tfe.Iterator(dataset):print(one_element)#tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)#tf.Tensor([6  7  8  9 10], shape=(5,), dtype=int32)#tf.Tensor([11 12 13 14 15], shape=(5,), dtype=int32)

读取字典

import tensorflow.contrib.eager as tfe
import tensorflow as tf
import numpy as np
tfe.enable_eager_execution()
dataset = tf.data.Dataset.from_tensor_slices({"a": np.array([[1, 2, 3, 4, 5],[6, 7, 8, 9, 10],[11, 12, 13, 14, 15]]),  # 维度3x5"b": np.array([1.0, 2.0, 3.0])  # 维度3x0,注意:第一维度3要与上面相同})
print("dataset:",dataset)#<DatasetV1Adapter shapes: {a: (5,), b: ()}, types: {a: tf.int32, b: tf.float64}>
for one_element in tfe.Iterator(dataset):print(one_element)#{'a': <tf.Tensor: id=9, shape=(5,), dtype=int32, numpy=array([1, 2, 3, 4, 5])>, 'b': <tf.Tensor: id=10, shape=(), dtype=float64, numpy=1.0>}# {'a': <tf.Tensor: id=13, shape=(5,), dtype=int32, numpy=array([ 6,  7,  8,  9, 10])>, 'b': <tf.Tensor: id=14, shape=(), dtype=float64, numpy=2.0>}# {'a': <tf.Tensor: id=17, shape=(5,), dtype=int32, numpy=array([11, 12, 13, 14, 15])>, 'b': <tf.Tensor: id=18, shape=(), dtype=float64, numpy=3.0>}

文本文件:

filepaths = ["/var/data/file1.txt", "/var/data/file2.txt"]
dataset = tf.data.TextLineDataset(filepaths)

读取图片image

import tensorflow as tf
import glob
import os
# 函数的功能时将filename对应的图片文件读进来,并缩放到统一的大小
def _parse_function(filename,):image_string = tf.read_file(filename)image_decoded = tf.image.decode_bmp(image_string,)# 根据图片类型,选择以下函数# image_decoded = tf.image.decode_png(image_string)# image_decoded = tf.image.decode_jepg(image_string)# image_decoded = tf.image.decode_git(image_string)image_resized = tf.image.resize_images(image_decoded,[64, 64]#图像维度变化)return image_resized
def walk_type(path, file_type):paths = glob.glob(os.path.join(path,#存放图片的文件夹路径file_type # 文件类型))# path下所有file_type类型的文件的路径列表return paths
paths = walk_type('dataset/*/','*.bmp')#图片路径列表
filenames = tf.constant(paths)
dataset = tf.data.Dataset.from_tensor_slices((filenames))
dataset = dataset.map(_parse_function)
dataset = dataset.shuffle(buffer_size=1000).batch(1).repeat(10)

tfrecords文件:

filepaths = ["/data/file1.tfrecord", "/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filepaths)

二进制文件:

filepaths = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]
image_bytes = image.height * image.width * image.depth
record_bytes = label_bytes + image_bytes
dataset = tf.data.FixedLengthRecordDataset(filepaths,record_bytes)

Datasets API支持一系列的变换操作

Datasets API支持 repeat、map、shuffle、batch等变换。

(1)repeat是将整个数据集重复多次,相当于一个或多个epoch,接受的参数数字代表repeat次数,若为空则无限重复

# Repeat infinitely.
dataset = tf.data.TFRecordDataset(filenames).repeat()

(2)map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset。通常用于数据变换或者解析与编码文件数据。

def parser(self, serialized_example):"""解析单个tf.Example变成图片和标签张量."""features = tf.parse_single_example(serialized_example,features={'image': tf.FixedLenFeature([], tf.string),'label': tf.FixedLenFeature([], tf.int64),})image = tf.decode_raw(features['image'], tf.uint8)image.set_shape([DEPTH * HEIGHT * WIDTH])# 从[depth * height * width]变换维度到[depth, height, width].image = tf.cast(tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]),tf.float32)label = tf.cast(features['label'], tf.int32)# 自定义预处理image = self.preprocess(image)return image, label
dataset = dataset.map(self.parser, num_threads=batch_size, output_buffer_size=2 * batch_size)
def decode_csv(line):parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])label = parsed_line[-1:] # 最后一个元素是labeldel parsed_line[-1] # 删除最后一个元素features = parsed_line # 所有的(除了最后一个元素)都是特性d = dict(zip(feature_names, features)), labelreturn d
dataset = (tf.data.TextLineDataset(file_path) # 读文本文件.skip(1) # 跳过标题行.map(decode_csv)) # 通过应用decode_csv fn转换每个元素

(3)shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小。单位是以图片(张量)为单位,而不是byte;

# 可能的洗牌记录
if subset == 'train' or shuffle:min_queue_examples = int(Cifar10DataSet.num_examples_per_epoch(subset) * 0.4)# 确保容量足够大,以提供良好的随机洗牌。dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)

(4)batch就是将多个元素组合成batch,接受一个batch_size的参数。

# Batch 组合
dataset = dataset.batch(batch_size)

迭代器

构建了表示输入数据的数据集之后,下一步是创建一个迭代器来访问来自该数据集的元素。Dataset API目前支持以下迭代器,以提高复杂程度:

one-shot :

一次性迭代器是迭代器的最简单形式,它只支持遍历数据集一次,不需要显式初始化。一次性迭代器 处理现有基于队列的输入管道支持的几乎所有情况 ,但是它们不支持参数化。

iterator = dataset.make_one_shot_iterator()
image_batch, label_batch = iterator.get_next()
#返回image_batch label_batch
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
for i in range(100):value = sess.run(next_element)assert i == value

initializable

可初始化迭代器要求您运行显式迭代器。在使用它之前的初始化操作。作为对这种不便的交换,它允许您使用一个或多个tf.placeholder()张量对数据集的定义进行参数化,当您初始化迭代器时可以提供这些张量。

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
# 在具有10个元素的数据集上初始化一个迭代器。
sess.run(iterator.initializer, feed_dict={max_value: 10})
for i in range(10):value = sess.run(next_element)assert i == value
# 在具有100个元素的数据集上初始化相同的迭代器。
sess.run(iterator.initializer, feed_dict={max_value: 100})
for i in range(100):value = sess.run(next_element)assert i == value
# 将训练数据加载到两个NumPy数组中
features = data["features"]
labels = data["labels"]
# 假设“features”的每一行都与“labels”对应
assert features.shape[0] == labels.shape[0]
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# ['dataset'上的其他转换…...]
dataset = ...
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels})

reinitializable

可重新初始化的迭代器可以从多个不同的数据集对象初始化 。 例如,您可能有一个训练输入管道,它使用对输入图像的随机扰动来改进泛化,还有一个验证输入管道,它评估未修改数据的预测。这些管道通常使用具有相同结构的不同数据集对象(即,对于每个组件具有相同的类型和兼容的形状)。
最后看两个比较完整的例子:
要在tf.estimator.Estimator的input_fn中使用数据集,我们建议使用Dataset.make_one_shot_iterator()。例如:

#解析数据
def parser(self,serialized_example):"""解析单个tf.Example变成图片和标签张量."""features = tf.parse_single_example(serialized_example,features={'image': tf.FixedLenFeature([], tf.string),'label': tf.FixedLenFeature([], tf.int64),})image = tf.decode_raw(features['image'], tf.uint8)image.set_shape([DEPTH * HEIGHT * WIDTH])# 从[深度*高度*宽度]到[深度、高度、宽度].image = tf.cast(tf.transpose(tf.reshape(image, [DEPTH, HEIGHT, WIDTH]), [1, 2, 0]),tf.float32)label = tf.cast(features['label'], tf.int32)# 自定义预处理.image = self.preprocess(image)return image, label
#变换数据
def preprocess(self,image):"""在[高度,宽度,深度]布局中对单个图像进行预处理."""if subset == 'train' and shuffle:# Pad 4像素在每个尺寸的特征图,在小批量完成image = tf.image.resize_image_with_crop_or_pad(image, 40, 40)image = tf.random_crop(image, [HEIGHT, WIDTH, DEPTH])image = tf.image.random_flip_left_right(image)# 因为这些操作不是可交换的,所以可以考虑将它们的操作顺序随机化。#注意:由于per_image_standardization会使平均值为零并使stddev单元为零,所以这可能不会有任何影响,请参阅tensorflow #1458。distorted_image = tf.image.random_brightness(distorted_image,max_delta=63)distorted_image = tf.image.random_contrast(distorted_image,lower=0.2, upper=1.8)# 减去平均值,除以像素的方差float_image = tf.image.per_image_standardization(distorted_image)return image
def input_fn(self,data_dir,batch_size,subset):if subset in ['train', 'validation', 'eval']:filepaths = [os.path.join(data_dir, subset + '.tfrecords')]else:raise ValueError('Invalid data subset "%s"' % subset)dataset = tf.contrib.data.TFRecordDataset(filepaths).repeat()# 解析记录.dataset = dataset.map(self.parser, num_threads=batch_size, output_buffer_size=2 * batch_size)# 可能的洗牌记录.if subset == 'train':min_queue_examples = int(Cifar10DataSet.num_examples_per_epoch(subset) * 0.4)# 确保容量足够大,以提供良好的随机变换。dataset = dataset.shuffle(buffer_size=min_queue_examples + 3 * batch_size)# Batch it up.dataset = dataset.batch(batch_size)iterator = dataset.make_one_shot_iterator()image_batch, label_batch = iterator.get_next()return image_batch, label_batch

tf.train.MonitoredTrainingSession API简化了在分布式设置中运行TensorFlow的许多方面 , MonitoredTrainingSession使用tf.errors.OutOfRangeError表示训练已经完成,因此要与tf.dataAPI一起使用,我们建议使用Dataset.make_one_shot_iterator()。例如:

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(32)
dataset = dataset.repeat(num_epochs)
iterator = dataset.make_one_shot_iterator()
next_example, next_label = iterator.get_next()
loss = model_function(next_example, next_label)
training_op = tf.train.AdagradOptimizer(...).minimize(loss)
with tf.train.MonitoredTrainingSession(...) as sess:while not sess.should_stop():sess.run(training_op)

查看dataset:DatasetV1Adapter的方法

import tensorflow as tf
import numpy as np
# 把数组转化为dataset模式:DatasetV1Adapter
data = tf.data.Dataset.from_tensor_slices(np.array([1, 2, 3, 4, 5]))
print("data:",data) #<DatasetV1Adapter shapes: (), types: tf.int32>
# 建立迭代器,并进行迭代操作(把data从DatasetV1Adapter格式转化为tensor格式)
element = data.make_one_shot_iterator().get_next()
print('element:',element) #Tensor("IteratorGetNext:0", shape=(), dtype=int32)
with tf.Session() as sess:try:while True:print(sess.run(element))except tf.errors.OutOfRangeError:print("Out range !")

TensorFlow数据读取方式:Dataset API,以及如何查看dataset:DatasetV1Adapter的方法相关推荐

  1. TensorFlow全新的数据读取方式:Dataset API入门教程

    Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服务于数据读取,构建输入数据的pipeline. 此前,在TensorFlow中读取数据一般有两种方法: 1.使用pl ...

  2. TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制

    TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和tf.data.Dataset机制 之前写了一篇博客,关于<Tensorflow生成自己的 ...

  3. XML数据读取方式性能比较(一)

    几个月来,疑被SOA,一直在和XML操作打交道,SQL差不多又忘光了.现在已经知道,至少有四种常用人XML数据操作方式(好像Java差不多),不过还没有实际比较过这些方式各有哪些特点或优劣.正好看到网 ...

  4. OJ平台常用数据读取方式

    OJ平台常用数据读取方式 1.C格式 读取判断: scanf函数返回值: 1.大于0时,表示正确接收的参数个数. 2.等于0时,表示输入不匹配,无法正确输入任何值. 3.等于-1时,表示EOF(end ...

  5. TensorFlow学习笔记(二十二) tensorflow数据读取方法总结

    本文PDF文件下载地址:http://download.csdn.net/download/qq_36330643/9938438 Tensorflow的数据读取有三种方式: 1. Preloaded ...

  6. linux 读取大量图片 内存,10 张图帮你搞定 TensorFlow 数据读取机制

    导读 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解 ...

  7. 十图详解TensorFlow数据读取机制(附代码)

    在学习TensorFlow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下 ...

  8. tensorflow数据读取机制

    原博客地址:https://zhuanlan.zhihu.com/p/27238630 代码地址:https://github.com/hzy46/Deep-Learning-21-Examples/ ...

  9. Tensorflow数据读取之tfrecord

    文章目录 tfrecord tfrecord的使用流程 写入tfrecord文件 读取tfrecord文件 tfrecord中的数据格式 tfrecord中对于变长数据和定长数据的处理 tfrecor ...

最新文章

  1. C# 运算符的优先级
  2. 汇编call指令详解_我也能写出雷军的的代码吗?最好的汇编语言入门教程在这里!...
  3. 基于linux操作系统Mysql的基本操作(一)
  4. 【Java 网络编程】TCP 服务器端 客户端 简单示例
  5. [ATF]-TEE/REE系统切换时ATF的寄存器的保存和恢复
  6. word文档怎么限制编辑(禁止编辑、只读)?
  7. python中csv模块读写文件
  8. 代码编辑器Sublime Text 3 免费使用方法与简体中文汉化包下载
  9. kali下生成web端后门
  10. ef core mysql 生成迁移失败_EFCore + MySql codeFirst 迁移 Migration出现的问题
  11. Oracle数据库导入csv文件(sqlldr命令行)
  12. 神奇_程序cmd命令窗口运行会自动停止_加解决方案---Linux工作笔记045
  13. SQL卸载重装实例名重复问题
  14. 有关正则的知识点梳理
  15. antd vue 位置变动 固钉消失_使用vue封装固钉Affix组件,滚动条到底部时自动吸附,离开底部时自动相对窗口固定...
  16. java进制转换所有方法_Java进制转换方法整理
  17. vue组件eleme 时间选择器问题
  18. 调试华为MML协议备忘
  19. 墨刀怎么注册_墨刀的使用
  20. 分享学JavaScript的第七天

热门文章

  1. 使用Linux服务器搭建个人深度学习环境
  2. glob.glob 函数读取文件
  3. 局部敏感哈希(Locality Sensitive Hashing)二三问[2]
  4. linux平台的实验描述,基于LINUX的操作系统实验平台的设计与实现
  5. supervisor 守护php,laravel队列之Supervisor守护进程(centos篇)
  6. 在 RedHat 使用 gdc-client 下载 TCGA 数据
  7. 如何在 Linux 中查看目录大小?
  8. 钱海丰:农药污染下的土壤微生态响应与风险预测​(今晚7点半)
  9. QIIME 2教程. 14数据评估和质控q2-quality-control(2021.2)
  10. 最后2周 | 高级转录组分析和R语言数据可视化第十一期 (报名线上课还可免费参加线下课)...