原文链接: tf data 常用操作

上一篇: es2018 常用 特性

下一篇: WOMic 使用wifi 将手机作为电脑麦克风音频输入

使用zip 组合数据集

import tensorflow as tf
import numpy as npa = np.arange(10)
b = np.arange(10)
a = tf.data.Dataset().from_tensor_slices(a)
b = tf.data.Dataset().from_tensor_slices(b)# 如果通过下面的方式 输出 ((array([7, 2, 6]),), (array([7, 2, 6]),))
# a = tf.data.Dataset().from_tensor_slices((a,))
# b = tf.data.Dataset().from_tensor_slices((b,))
data = tf.data.Dataset.zip((a, b))
data = data.repeat(-1).shuffle(16).batch(3)iterator = data.make_initializable_iterator()
batch = iterator.get_next()
with tf.Session() as sess:sess.run(iterator.initializer)for i in range(10):# (array([3, 8, 7]), array([3, 8, 7]))print(sess.run(batch))

flat_map 与 map 的区别 map返回处理后的数据即可,flat_map 需要返回一个dataset

flat_map最终输出是单个数据对

下面这种写法会报错

data2 = data2.repeat(-1).flat_map(flat_fun).batch(3)def flat_map(self, map_func):"""Maps `map_func` across this dataset and flattens the result.Use `flat_map` if you want to make sure that the order of your datasetstays the same. For example, to flatten a dataset of batches into adataset of their elements:def map(self, map_func, num_parallel_calls=None):"""Maps `map_func` across the elements of this dataset.This transformation applies `map_func` to each element of this dataset, andreturns a new dataset containing the transformed elements, in the sameorder as they appeared in the input.

但是map可以采用上面的写法

import tensorflow as tf
import numpy as npdef map_fun(x, y):print('map ', x, y)  # map  Tensor("arg0:0", shape=(), dtype=int32) Tensor("arg1:0", shape=(), dtype=int32)return x, ydef flat_fun(x, y):print('flat ', x, y)  # flat  Tensor("arg0:0", shape=(?,), dtype=int32) Tensor("arg1:0", shape=(?,), dtype=int32)return tf.data.Dataset().from_tensor_slices((x, y))a = np.arange(10)
b = np.arange(10)data1 = tf.data.Dataset().from_tensor_slices((a, b))
data1 = data1.repeat(-1).batch(3).map(map_fun)
iterator1 = data1.make_initializable_iterator()
batch1 = iterator1.get_next()data2 = tf.data.Dataset().from_tensor_slices((a, b))
data2 = data2.repeat(-1).batch(3).flat_map(flat_fun)
iterator2 = data2.make_initializable_iterator()
batch2 = iterator2.get_next()with tf.Session() as sess:sess.run([iterator1.initializer, iterator2.initializer])for i in range(10):print(sess.run(batch1))for i in range(10):print(sess.run(batch2))(array([0, 1, 2]), array([0, 1, 2]))
(array([3, 4, 5]), array([3, 4, 5]))
(array([6, 7, 8]), array([6, 7, 8]))
(array([9, 0, 1]), array([9, 0, 1]))
(array([2, 3, 4]), array([2, 3, 4]))
(array([5, 6, 7]), array([5, 6, 7]))
(array([8, 9, 0]), array([8, 9, 0]))
(array([1, 2, 3]), array([1, 2, 3]))
(array([4, 5, 6]), array([4, 5, 6]))
(array([7, 8, 9]), array([7, 8, 9]))
(0, 0)
(1, 1)
(2, 2)
(3, 3)
(4, 4)
(5, 5)
(6, 6)
(7, 7)
(8, 8)
(9, 9)

消耗 NumPy 数组

如果您的所有输入数据都适合存储在内存中,则根据输入数据创建 Dataset 的最简单方法是将它们转换为 tf.Tensor对象,并使用 Dataset.from_tensor_slices()。

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:features = data["features"]labels = data["labels"]# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]dataset = tf.data.Dataset.from_tensor_slices((features, labels))

请注意,上面的代码段会将 features 和 labels 数组作为 tf.constant() 指令嵌入在 TensorFlow 图中。这样非常适合小型数据集,但会浪费内存,因为会多次复制数组的内容,并可能会达到 tf.GraphDef 协议缓冲区的 2GB 上限。

作为替代方案,您可以根据 tf.placeholder() 张量定义 Dataset,并在对数据集初始化 Iterator 时馈送 NumPy 数组。

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:features = data["features"]labels = data["labels"]# Assume that each row of `features` corresponds to the same row as `labels`.
assert features.shape[0] == labels.shape[0]features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels})

消耗 TFRecord 数据

tf.data API 支持多种文件格式,因此您可以处理那些不适合存储在内存中的大型数据集。例如,TFRecord 文件格式是一种面向记录的简单二进制格式,很多 TensorFlow 应用采用此格式来训练数据。通过 tf.data.TFRecordDataset 类,您可以将一个或多个 TFRecord 文件的内容作为输入管道的一部分进行流式传输。

# Creates a dataset that reads all of the examples from two files.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)

TFRecordDataset 初始化程序的 filenames 参数可以是字符串、字符串列表,也可以是字符串 tf.Tensor。因此,如果您有两组分别用于训练和验证的文件,则可以使用 tf.placeholder(tf.string) 来表示文件名,并使用适当的文件名初始化迭代器:

filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)  # Parse the record into tensors.
dataset = dataset.repeat()  # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()# You can feed the initializer with the appropriate filenames for the current
# phase of execution, e.g. training vs. validation.# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})

消耗 CSV 数据

CSV 文件格式是用于以纯文本格式存储表格数据的常用格式。tf.contrib.data.CsvDataset 类提供了一种从符合 RFC 4180 的一个或多个 CSV 文件中提取记录的方法。给定一个或多个文件名以及默认值列表后,CsvDataset 将生成一个元素元组,元素类型对应于为每个 CSV 记录提供的默认元素类型。像 TFRecordDataset 和 TextLineDataset 一样,CsvDataset 将接受 filenames(作为 tf.Tensor),因此您可以通过传递 tf.placeholder(tf.string) 进行参数化。

# Creates a dataset that reads all of the records from two CSV files, each with
# eight float columns
filenames = ["/var/data/file1.csv", "/var/data/file2.csv"]
record_defaults = [tf.float32] * 8  # Eight required float columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)

如果某些列为空,则可以提供默认值而不是类型。

默认情况下,CsvDataset 生成文件的每一列或每一行,这可能是不可取的;例如,如果文件以应忽略的标题行开头,或如果输入中不需要某些列。可以分别使用 header 和 select_cols 参数移除这些行和字段。

# Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values
record_defaults = [[0.0]] * 8
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults)# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [[0.0]] * 2  # Only provide defaults for the selected columns
dataset = tf.contrib.data.CsvDataset(filenames, record_defaults, header=True, select_cols=[2, 4])

解析 tf.Example 协议缓冲区消息

许多输入管道都从 TFRecord 格式的文件中提取 tf.train.Example 协议缓冲区消息(例如这种文件使用 tf.python_io.TFRecordWriter 编写而成)。每个 tf.train.Example 记录都包含一个或多个 “特征”,输入管道通常会将这些特征转换为张量。

# Transforms a scalar string `example_proto` into a pair of a scalar string and
# a scalar integer, representing an image and its label, respectively.
def _parse_function(example_proto):features = {"image": tf.FixedLenFeature((), tf.string, default_value=""),"label": tf.FixedLenFeature((), tf.int64, default_value=0)}parsed_features = tf.parse_single_example(example_proto, features)return parsed_features["image"], parsed_features["label"]# Creates a dataset that reads all of the examples from two files, and extracts
# the image and label features.
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_parse_function)

解码图片数据并调整其大小

在用真实的图片数据训练神经网络时,通常需要将不同大小的图片转换为通用大小,这样就可以将它们批处理为具有固定大小的数据。

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def _parse_function(filename, label):image_string = tf.read_file(filename)image_decoded = tf.image.decode_jpeg(image_string)image_resized = tf.image.resize_images(image_decoded, [28, 28])return image_resized, label# A vector of filenames.
filenames = tf.constant(["/var/data/image1.jpg", "/var/data/image2.jpg", ...])# `labels[i]` is the label for the image in `filenames[i].
labels = tf.constant([0, 37, ...])dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)

使用 tf.py_func() 应用任意 Python 逻辑

为了确保性能,我们建议您尽可能使用 TensorFlow 指令预处理数据。不过,在解析输入数据时,调用外部 Python 库有时很有用。为此,请在 Dataset.map() 转换中调用 tf.py_func() 指令。

import cv2# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):image_decoded = cv2.imread(filename.decode(), cv2.IMREAD_GRAYSCALE)return image_decoded, label# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):image_decoded.set_shape([None, None, None])image_resized = tf.image.resize_images(image_decoded, [28, 28])return image_resized, labelfilenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
labels = [0, 37, 29, 1, ...]dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(lambda filename, label: tuple(tf.py_func(_read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)

简单的批处理

最简单的批处理形式是将数据集中的 n 个连续元素堆叠为一个元素。Dataset.batch() 转换正是这么做的,它与 tf.stack() 运算符具有相同的限制(被应用于元素的每个组件):即对于每个组件 i,所有元素的张量形状都必须完全相同。

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)iterator = batched_dataset.make_one_shot_iterator()
next_element = iterator.get_next()print(sess.run(next_element))  # ==> ([0, 1, 2,   3],   [ 0, -1,  -2,  -3])
print(sess.run(next_element))  # ==> ([4, 5, 6,   7],   [-4, -5,  -6,  -7])
print(sess.run(next_element))  # ==> ([8, 9, 10, 11],   [-8, -9, -10, -11])

使用填充批处理张量

上述方法适用于具有相同大小的张量。不过,很多模型(例如序列模型)处理的输入数据可能具有不同的大小(例如序列的长度不同)。为了解决这种情况,可以通过 Dataset.padded_batch() 转换来指定一个或多个会被填充的维度,从而批处理不同形状的张量。

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=[None])iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()print(sess.run(next_element))  # ==> [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(sess.run(next_element))  # ==> [[4, 4, 4, 4, 0, 0, 0],#      [5, 5, 5, 5, 5, 0, 0],#      [6, 6, 6, 6, 6, 6, 0],#      [7, 7, 7, 7, 7, 7, 7]]

您可以通过 Dataset.padded_batch() 转换为每个组件的每个维度设置不同的填充,并且可以采用可变长度(在上面的示例中用 None 表示)或恒定长度。也可以替换填充值,默认设置为 0。

tf data 常用操作相关推荐

  1. 记录 之 tf.data进行数据集处理常用的几个函数介绍

    1.tf.data.Dataset.from_tensor_slices(),前面的文章我们有介绍过. 2.tf.data.Dataset.shuffle(buffer_size = n) 这里的sh ...

  2. dataframe 添加一行_R语言Data Frame数据框常用操作

    来源 | R友舍 Data Frame一般被翻译为数据框,感觉就像是R中的表,由行和列组成,与Matrix不同的是,每个列可以是不同的数据类型,而Matrix是必须相同的. Data Frame每一列 ...

  3. Tensorflow读取数据-tf.data.TFRecordDataset

    tensorflow TFRecords文件的生成和读取方法 文章目录 tensorflow TFRecords文件的生成和读取方法 1. TFRecords说明 2.关键API 2.1 tf.io. ...

  4. TensorFlow常用操作:代码示例

    1,定义矩阵代码示例: import tensorflow as tftf.zeros([3,4]) #定义3行4列元素均为0的矩阵tensor=tf.constant([1,2,3,4])#定义一维 ...

  5. openlayers map获取全部feature_tf2.0基础-tf.data与tf.feature_column

    7.2.1 tf.data 使用 tf.data API 可以轻松处理大量数据.不同的数据格式以及复杂的转换.tf.data API 在 TensorFlow 中引入了两个新的抽象类: tf.data ...

  6. IOS沙盒Files目录说明和常用操作

    - (BOOL)application:(UIApplication *)application didFinishLaunchingWithOptions:(NSDictionary *)launc ...

  7. 用Python轻松搞定Excel中的20个常用操作

    来源 |早起Python(ID: zaoqi-python) Excel与Python都是数据分析中常用的工具,本文将使用动态图(Excel)+代码(Python)的方式来演示这两种工具是如何实现数据 ...

  8. python和R对dataframe的常用操作的差异:head、tail、shape、dim、nrow、ncol、descirbe、summary、str

    python和R对dataframe的常用操作的差异:head.tail.shape.dim.nrow.ncol.descirbe.summary.str # python df.head() # R ...

  9. UITableView取消选中颜色、常用操作

    UITableView取消选中颜色.常用操作 使用空白view取代cell - (UITableViewCell *)tableView:(UITableView *)tableView cellFo ...

  10. phoenix的元数据一般存在哪里_Phoenix常用操作记录-阿里云开发者社区

    Apache Phoenix 常用操作 基础知识 1****.****Phoenix 主要技术点 a.将SQL转化为HBase Scan,将结果封装为JDBC Result Set. b.表的元数据保 ...

最新文章

  1. Makefile和Shell学习笔记
  2. 21天养成习惯?不一定
  3. 3-08. 栈模拟队列(25)(ZJU_PAT 模拟)
  4. Qt Creator指定运行设置
  5. python 更改输出的颜色_python 输出指定的颜色
  6. 前端学习(1514):vue-router使用步骤
  7. npm run dev/build/serve
  8. Python标准库socketserver实现UDP协议时间服务器
  9. hibernate 里面 mysql dialect 配置
  10. 新建UE4 c++类
  11. 基于ABBYY SDK 实现java版本 Hello 功能!
  12. Maven整合SSH项目(六)
  13. NSIS安装包制做软件的常用小技巧
  14. Part Ⅵ Transportation 交通??
  15. kafka 消费机制
  16. 学mei私聊问我【DISTINCT】关键字有什么作用?查一个字段和多个字段的区别?
  17. 译:Convolutional Two-Stream Network Fusion for Video Action Recognition
  18. CSAPP Lab:attacklab
  19. 数据库分库分表中间件对比(很全)
  20. 【JavaWeb学习】CSS(样式和布局)

热门文章

  1. xbox360使用_适用于Xbox 360的HD-DVD Player
  2. Java基础面试题50题
  3. 白话区块链 之4: 区块链分类与架构
  4. HTTP网络协议四:HTTP报文及报文字段说明
  5. 汉高软件服务器安装系统,如何安装了如指掌眼镜管理系统的服务器和客户端,还需要安装什么软件的?...
  6. Bean 工厂和Application contexts有什么区别?
  7. MySQL数据库锁机制
  8. 计算机ping使用的端口,Windows7系统中怎么Ping端口?利用telnet命令Ping 端口的方法...
  9. 使用git进行word版本管理
  10. windows10一键修改开机动画