This is my second post about tf.data. The first one covered tf.data in detail for TF v1, but v1 and v2 are incompatible in many places, so this post takes a look at what is new in the v2 tf.data module.

TensorFlow version: 2.1.0

First, here is the link to my TF v1 tf.data post: 《TensorFlow tf.data 导入数据(tf.data官方教程)》

Table of Contents

  • Building data input pipelines with tf.data
    • 1. Basics
      • 1.1 Dataset structure
    • 2. Reading input data
      • 2.1 Reading NumPy arrays
      • 2.2 Reading data from a Python generator
      • 2.3 Reading TFRecord data
      • 2.4 Reading text data
      • 2.5 Reading CSV data
      • 2.6 Reading data from sets of files
    • 3. Batching dataset elements
      • 3.1 Simple batching (direct stacking)
      • 3.2 Padding tensors to a uniform size, then batching
    • 4. Training workflows
      • 4.1 Repeating data for multiple epochs
      • 4.2 Randomly shuffling input data
    • 5. Preprocessing data
      • 5.1 Preprocessing with Dataset.map()
      • 5.2 Preprocessing with non-TF functions
      • 5.3 Parsing tf.Example protocol buffer messages
      • 5.4 Time series windowing
        • 5.4.1 Using batch
        • 5.4.2 Using window
      • 5.5 Resampling
        • 5.5.1 Datasets.sampling
        • 5.5.2 experimental.rejection_resample
    • 6. Using tf.data with high-level APIs
      • 6.1 Using tf.data with tf.keras
      • 6.2 Using tf.data with tf.estimator

Building data input pipelines with tf.data

The tf.data API lets you build input pipelines that are simple and highly reusable, and it can express very complex pipelines. For example, the pipeline for an image model might aggregate data from files in a distributed file system, apply random perturbations to each image, and merge randomly selected images into a batch for training. The pipeline for a text model might involve extracting symbols from raw text, converting them to embedding identifiers with a lookup table, and batching together sequences of different lengths. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations.

The tf.data API introduces the tf.data.Dataset abstraction: a sequence of elements, where each element consists of one or more components. For example, in an image pipeline an element might be a single training example made up of an image and its label.

There are two ways to create a dataset:

  • build a Dataset from data in memory, or from one or more files on disk
  • transform an existing Dataset to obtain a new Dataset (see the sketch after the imports below)
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

np.set_printoptions(precision=4)
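A minimal sketch of the two creation routes mentioned above, using only standard tf.data calls (the specific values are illustrative):

# Step 1: build a source Dataset from in-memory data.
source_ds = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])

# Step 2: derive a new Dataset by chaining transformations.
transformed_ds = source_ds.map(lambda x: x * 2).batch(2)

for batch in transformed_ds:
    print(batch.numpy())   # [2 4], [6 8], [10 12]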

1. Basics

Building an input pipeline usually starts from a data source. If your data is stored in memory, you can create a Dataset with tf.data.Dataset.from_tensors() or tf.data.Dataset.from_tensor_slices(). If your data is stored in TFRecord format, use tf.data.TFRecordDataset().
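The difference between the two in-memory constructors, as a minimal sketch: from_tensors treats its input as a single element, while from_tensor_slices slices the input along its first dimension.

ds_single = tf.data.Dataset.from_tensors([[1, 2], [3, 4]])        # 1 element of shape (2, 2)
ds_sliced = tf.data.Dataset.from_tensor_slices([[1, 2], [3, 4]])  # 2 elements of shape (2,)

print(ds_single.element_spec)   # TensorSpec(shape=(2, 2), dtype=tf.int32, name=None)
print(ds_sliced.element_spec)   # TensorSpec(shape=(2,), dtype=tf.int32, name=None)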

Once you have a Dataset object, you can transform it into a new Dataset by chaining method calls on it.

A Dataset is a Python iterable, so you can consume its elements with a for loop:

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset

<TensorSliceDataset shapes: (), types: tf.int32>

for elem in dataset:
    print(elem.numpy())

8
3
0
8
2
1

Or explicitly create a Python iterator with iter and consume its elements with next:

it = iter(dataset)

print(next(it).numpy())

8

Alternatively, dataset elements can be consumed with the reduce() transformation, which reduces all elements to a single result. The example below shows how to use reduce to compute the sum of a dataset of integers.

print(dataset.reduce(0, lambda state, value: state + value).numpy())

22

1.1 Dataset structure

A Dataset consists of elements that all share the same (nested) structure, and each element is made of components that can be represented by a tf.TypeSpec (commonly Tensor, SparseTensor, RaggedTensor, TensorArray, or Dataset).

The Dataset.element_spec property lets you inspect the type of each element's components. It returns a nested structure of tf.TypeSpec objects that matches the structure of the dataset's elements. For example:

dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))

dataset1.element_spec

TensorSpec(shape=(10,), dtype=tf.float32, name=None)

dataset2 = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([4]),
     tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2.element_spec

(TensorSpec(shape=(), dtype=tf.float32, name=None),
 TensorSpec(shape=(100,), dtype=tf.int32, name=None))

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3.element_spec

(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
 (TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))

# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(
    tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))

dataset4.element_spec

SparseTensorSpec(TensorShape([3, 4]), tf.int32)

# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type

tensorflow.python.framework.sparse_tensor.SparseTensor
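The other component types listed above behave the same way. As a sketch (analogous to the sparse-tensor example, assuming ragged support in your TF build), a dataset holding a RaggedTensor reports a RaggedTensorSpec:

# Dataset containing a ragged tensor.
dataset5 = tf.data.Dataset.from_tensors(tf.ragged.constant([[1, 2], [3], [4, 5, 6]]))

print(dataset5.element_spec)   # RaggedTensorSpec(...)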

Dataset transformations support datasets of any structure. When using Dataset.map(), Dataset.flat_map(), and Dataset.filter(), which apply a function to each element, the element structure determines the arguments of that function:

dataset1 = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))

dataset1

<TensorSliceDataset shapes: (10,), types: tf.int32>

for z in dataset1:
    print(z.numpy())

[6 7 1 1 5 6 7 8 7 6]
[8 3 3 7 9 3 8 4 8 4]
[2 3 6 9 4 2 1 8 1 6]
[6 7 1 9 6 2 4 7 9 1]

dataset2 = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([4]),
     tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))

dataset2

<TensorSliceDataset shapes: ((), (100,)), types: (tf.float32, tf.int32)>

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))

dataset3

<ZipDataset shapes: ((10,), ((), (100,))), types: (tf.int32, (tf.float32, tf.int32))>

for a, (b, c) in dataset3:
    print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))

shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)

Note: Naming the components of an element is often convenient (for example, when the components represent different features). Besides tuples, you can also use a namedtuple (collections.namedtuple) or a dictionary to represent a single element of a Dataset.

dataset = tf.data.Dataset.from_tensor_slices(
    {"a": tf.random.uniform([4]),
     "b": tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)})

dataset.element_spec

{‘a’: TensorSpec(shape=(), dtype=tf.float32, name=None), ‘b’: TensorSpec(shape=(100,), dtype=tf.int32, name=None)}
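Likewise, a collections.namedtuple can serve as the element structure; a minimal sketch:

import collections

Example = collections.namedtuple('Example', ['feature', 'label'])

dataset_nt = tf.data.Dataset.from_tensor_slices(
    Example(feature=tf.random.uniform([4, 10]), label=tf.range(4)))

dataset_nt.element_spec   # Example(feature=TensorSpec(shape=(10,), ...), label=TensorSpec(shape=(), ...))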

2. Reading input data

2.1 Reading NumPy arrays

See Loading NumPy arrays for more examples.

If all of your input data fits in memory, the simplest way to create a Dataset is with Dataset.from_tensor_slices().

train, test = tf.keras.datasets.fashion_mnist.load_data() # out is np array

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

images, labels = train
images = images/255

dataset = tf.data.Dataset.from_tensor_slices((images, labels))  # auto convert np array to constant tensor
dataset

<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>

Note: The snippet above embeds the features and labels arrays into the TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory, because the contents of the arrays are copied multiple times, and it can run into the 2 GB limit of the tf.GraphDef protocol buffer.
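If the arrays are large enough for this to matter, one workaround (a sketch, not the only option) is to feed them through a generator so the data is supplied at runtime instead of being embedded in the graph:

def gen_from_arrays():
    # Yield one (image, label) pair at a time from the in-memory arrays.
    for img, lab in zip(images, labels):
        yield img, lab

dataset = tf.data.Dataset.from_generator(
    gen_from_arrays,
    output_types=(tf.float64, tf.uint8),
    output_shapes=((28, 28), ()))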

2.2 Reading data from a Python generator

Another common data source is a Python generator.

Note: While using a Python generator is convenient, this approach has limited portability and scalability. It must run in the same Python process that created the generator, and it is still subject to the Python GIL.

def count(stop):
    i = 0
    while i < stop:
        yield i
        i += 1

for n in count(5):
    print(n)

0
1
2
3
4

Dataset.from_generator converts a generator into a fully functional tf.data.Dataset. The constructor takes a callable as input rather than an iterator, so the generator can be restarted when it reaches its end. It also accepts an optional args argument, which is passed to the callable as its arguments.

The output_types argument is required because tf.data builds a tf.Graph internally, and graph edges require a tf.dtype.

ds_counter = tf.data.Dataset.from_generator(count, args=[25], output_types=tf.int32, output_shapes=())

for count_batch in ds_counter.repeat().batch(10).take(10):
    print(count_batch.numpy())

[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]

The output_shapes argument is not required, but specifying it is strongly recommended, because many TensorFlow operations do not support tensors of unknown rank. If the length of a particular axis is unknown or variable, set it to None in output_shapes.

Note that the same rules about output_shapes and output_types apply to other dataset methods as well.

Here is an example generator that returns tuples of arrays, where the second array is a vector of unknown length:

def gen_series():  # generator
    i = 0
    while True:
        size = np.random.randint(0, 10)
        yield i, np.random.normal(size=(size,))  # variable-length array
        i += 1

for i, series in gen_series():
    print(i, ":", str(series))
    if i > 5:
        break

0 : [ 1.9201 0.2124 -0.3383 -0.1141 0.7749 -0.1499]
1 : []
2 : [ 0.5885 -1.1092 0.4577 2.2978 -1.1854]
3 : [-1.7452 1.0516]
4 : []
5 : []
6 : [-0.8563 -1.2055 -0.291 1.0448 0.1486 1.0402 1.8017]

The first array is int32 with shape (); the second array is float32 with shape (None,).

ds_series = tf.data.Dataset.from_generator(
    gen_series,
    output_types=(tf.int32, tf.float32),   # required
    output_shapes=((), (None,)))           # optional, but strongly recommended (see above)

ds_series

<FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float32)>

Now the tf.data.Dataset is built. Note, however, that when batching a dataset with variable shapes, you need to use Dataset.padded_batch.

ds_series_batch = ds_series.shuffle(20).padded_batch(10, padded_shapes=([], [None]))

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())

[ 6 1 10 0 3 17 12 9 5 23]

[[ 0.5812 -0.825 0.6075 -1.3856 -0.8151 -1.1908 0. 0. ]
[-0.7208 0.0611 0.0084 0.6592 0.8364 0.8327 -0.7164 0.8826]
[ 0.0391 -2.0019 0.4077 0.9304 0. 0. 0. 0. ]
[ 0.4397 -0.0901 -0.4993 0.3485 0.2481 0. 0. 0. ]
[ 0.0346 0. 0. 0. 0. 0. 0. 0. ]
[-1.0478 0. 0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0. 0. 0. ]
[ 0.3163 0. 0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0. 0. 0. ]]

Note: As of TensorFlow 2.2 the padded_shapes argument is no longer required; the default behavior is to pad all axes to the longest in the batch.

ds_series_batch = ds_series.shuffle(20).padded_batch(10)

For a more realistic example, try wrapping preprocessing.image.ImageDataGenerator as a tf.data.Dataset.

First download the data:

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 5s 0us/step

Create the image.ImageDataGenerator:

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers))

Found 3670 images belonging to 5 classes.

print(images.dtype, images.shape)
print(labels.dtype, labels.shape)

float32 (32, 256, 256, 3)
float32 (32, 5)

ds = tf.data.Dataset.from_generator(
    img_gen.flow_from_directory, args=[flowers],
    output_types=(tf.float32, tf.float32),
    output_shapes=([32, 256, 256, 3], [32, 5]))

ds

<FlatMapDataset shapes: ((32, 256, 256, 3), (32, 5)), types: (tf.float32, tf.float32)>

2.3 Reading TFRecord data

See Loading TFRecords for an end-to-end example.

The tf.data API supports a variety of file formats, so you can process datasets that are too large to fit in memory. For example, the TFRecord file format is a simple record-oriented binary format that many TensorFlow applications use for training data. The tf.data.TFRecordDataset class lets you stream the contents of one or more TFRecord files as input to a data pipeline.

The example below uses the test file from the French Street Name Signs (FSNS) dataset:

# Creates a dataset that reads all of the examples from two files.
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
7905280/7904079 [==============================] - 0s 0us/step

The filenames argument of TFRecordDataset can be a string, a list of strings, or a tf.Tensor of strings. So if you have two sets of files, one for training and one for validation, you can create a factory method that takes filenames as an argument and produces the dataset (see the sketch below).
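A minimal sketch of such a factory function; train_files and val_files are hypothetical lists of TFRecord paths:

def make_dataset(filenames, batch_size=32):
    # filenames may be a string, a list of strings, or a string tf.Tensor.
    ds = tf.data.TFRecordDataset(filenames=filenames)
    return ds.batch(batch_size)

# train_ds = make_dataset(train_files)
# val_ds = make_dataset(val_files)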

dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset

<TFRecordDatasetV2 shapes: (), types: tf.string>

Many TensorFlow projects use serialized tf.train.Example records in their TFRecord files. These need to be decoded before they can be inspected:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

parsed.features.feature['image/text']

bytes_list {
value: “Rue Perreyon”
}

2.4 Reading text data

See Loading Text for an end-to-end example.

Many datasets are distributed as one or more text files. tf.data.TextLineDataset extracts lines from one or more text files: given the filenames, it produces one string-valued element per line of those files.

directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']

file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
819200/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
811008/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
811008/807992 [==============================] - 0s 0us/step

dataset = tf.data.TextLineDataset(file_paths)

Here are the first few lines of the first file:

for line in dataset.take(5):
    print(line.numpy())

b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus’ son;"
b’His wrath pernicious, who ten thousand woes’
b"Caused to Achaia’s host, sent many a soul"
b’Illustrious into Ades premature,’
b’And Heroes gave (so stood the will of Jove)’

Dataset.interleave lets you alternate between files as you read, which makes it easier to mix the files together.

files_ds = tf.data.Dataset.from_tensor_slices(file_paths)

lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)

for i, line in enumerate(lines_ds.take(9)):
    if i % 3 == 0:
        print()
    print(line.numpy())

b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus’ son;"
b"\xef\xbb\xbfOf Peleus’ son, Achilles, sing, O Muse,"
b’\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought’

b’His wrath pernicious, who ten thousand woes’
b’The vengeance, deep and deadly; whence to Greece’
b’countless ills upon the Achaeans. Many a brave soul did it send’

b"Caused to Achaia’s host, sent many a soul"
b’Unnumbered ills arose; which many a soul’
b’hurrying down to Hades, and many a hero did it yield a prey to dogs and’

By default, a TextLineDataset yields every line of every file, which may not be what you want, for example if a file starts with a header line or contains comments. Such lines can be removed with Dataset.skip() or Dataset.filter().

Here the Titanic dataset is used to demonstrate stripping the header row and filtering to keep only survivors:

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)

Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step

for line in titanic_lines.take(10):
    print(line.numpy())

b’survived,sex,age,n_siblings_spouses,parch,fare,class,deck,embark_town,alone’
b’0,male,22.0,1,0,7.25,Third,unknown,Southampton,n’
b’1,female,38.0,1,0,71.2833,First,C,Cherbourg,n’
b’1,female,26.0,0,0,7.925,Third,unknown,Southampton,y’
b’1,female,35.0,1,0,53.1,First,C,Southampton,n’
b’0,male,28.0,0,0,8.4583,Third,unknown,Queenstown,y’
b’0,male,2.0,3,1,21.075,Third,unknown,Southampton,n’
b’1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n’
b’1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n’
b’1,female,4.0,1,1,16.7,Third,G,Southampton,n’

def survived(line):
    return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)

for line in survivors.take(10):
    print(line.numpy())

b’1,female,38.0,1,0,71.2833,First,C,Cherbourg,n’
b’1,female,26.0,0,0,7.925,Third,unknown,Southampton,y’
b’1,female,35.0,1,0,53.1,First,C,Southampton,n’
b’1,female,27.0,0,2,11.1333,Third,unknown,Southampton,n’
b’1,female,14.0,1,0,30.0708,Second,unknown,Cherbourg,n’
b’1,female,4.0,1,1,16.7,Third,G,Southampton,n’
b’1,male,28.0,0,0,13.0,Second,unknown,Southampton,y’
b’1,female,28.0,0,0,7.225,Third,unknown,Cherbourg,y’
b’1,male,28.0,0,0,35.5,First,A,Southampton,y’
b’1,female,38.0,1,5,31.3875,Third,unknown,Southampton,n’

2.5 Reading CSV data

See Loading CSV Files, and Loading Pandas DataFrames for more examples.

CSV is a popular file format for storing tabular data in plain text.

For example:

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")  # download the data
df = pd.read_csv(titanic_file, index_col=None)
df.head()
survived sex age n_siblings_spouses parch fare class deck embark_town alone
0 0 male 22.0 1 0 7.2500 Third unknown Southampton n
1 1 female 38.0 1 0 71.2833 First C Cherbourg n
2 1 female 26.0 0 0 7.9250 Third unknown Southampton y
3 1 female 35.0 1 0 53.1000 First C Southampton n
4 0 male 28.0 0 0 8.4583 Third unknown Queenstown y

If your data fits in memory, Dataset.from_tensor_slices can take a dictionary as input, which makes importing this kind of data very convenient:

titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
    for key, value in feature_batch.items():
        print("  {!r:20s}: {}".format(key, value))

‘survived’ : 0
‘sex’ : b’male’
‘age’ : 22.0
‘n_siblings_spouses’: 1
‘parch’ : 0
‘fare’ : 7.25
‘class’ : b’Third’
‘deck’ : b’unknown’
‘embark_town’ : b’Southampton’
‘alone’ : b’n’

A more flexible approach is to read the data from disk as needed.

The tf.data module provides methods for extracting records from one or more CSV files that conform to RFC 4180.

The experimental.make_csv_dataset function is a high-level API for reading CSV files. It supports column-type inference and many other features, such as batching and shuffling, to simplify usage.

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")

for feature_batch, label_batch in titanic_batches.take(1):
    print("'survived': {}".format(label_batch))
    print("features:")
    for key, value in feature_batch.items():
        print("  {!r:20s}: {}".format(key, value))

‘survived’: [1 1 0 0]
features:
  ‘sex’ : [b’female’ b’female’ b’male’ b’male’]
  ‘age’ : [28. 24. 29. 28.]
  ‘n_siblings_spouses’: [0 1 0 0]
  ‘parch’ : [0 0 0 0]
  ‘fare’ : [ 7.2292 26. 30. 7.725 ]
  ‘class’ : [b’Third’ b’Second’ b’First’ b’Third’]
  ‘deck’ : [b’unknown’ b’unknown’ b’D’ b’unknown’]
  ‘embark_town’ : [b’Cherbourg’ b’Southampton’ b’Southampton’ b’Queenstown’]
  ‘alone’ : [b’y’ b’n’ b’y’ b’y’]

If you only need a subset of the columns, use the select_columns argument:

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])

for feature_batch, label_batch in titanic_batches.take(1):
    print("'survived': {}".format(label_batch))
    for key, value in feature_batch.items():
        print("  {!r:20s}: {}".format(key, value))

‘survived’: [1 0 1 0]
  ‘fare’ : [ 10.5 7.25 23. 106.425]
  ‘class’ : [b’Second’ b’Third’ b’Second’ b’First’]

There is also a lower-level experimental.CsvDataset class. It provides finer-grained control but does not support column-type inference; instead, you must specify the type of each column.

titanic_types = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string]

dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types, header=True)

for line in dataset.take(10):
    print([item.numpy() for item in line])

[0, b'male', 22.0, 1, 0, 7.25, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
[1, b'female', 26.0, 0, 0, 7.925, b'Third', b'unknown', b'Southampton', b'y']
[1, b'female', 35.0, 1, 0, 53.1, b'First', b'C', b'Southampton', b'n']
[0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
[0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 27.0, 0, 2, 11.1333, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 14.0, 1, 0, 30.0708, b'Second', b'unknown', b'Cherbourg', b'n']
[1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']
[0, b'male', 20.0, 0, 0, 8.05, b'Third', b'unknown', b'Southampton', b'y']

If some columns may be empty, this low-level interface lets you provide default values instead of column types. (The %%writefile line below is an IPython magic command and only works inside IPython.)

%%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,

Writing missing.csv

# Creates a dataset that reads all of the records from two CSV files, each with
# four float columns which may have missing values.
record_defaults = [999, 999, 999, 999]
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults)
dataset = dataset.map(lambda *items: tf.stack(items))
dataset

<MapDataset shapes: (4,), types: tf.int32>

for line in dataset:
    print(line.numpy())

[  1   2   3   4]
[999   2   3   4]
[  1 999   3   4]
[  1   2 999   4]
[  1   2   3 999]
[999 999 999 999]

By default, a CsvDataset yields every column of every line of the file, which may not be desirable, for example if you want to ignore a header line at the start of the file or drop some columns; use the header and select_cols arguments for that.

# Creates a dataset that reads all of the records from two CSV files with
# headers, extracting float data from columns 2 and 4.
record_defaults = [999, 999] # Only provide defaults for the selected columns
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults, select_cols=[1, 3])
dataset = dataset.map(lambda *items: tf.stack(items))
dataset

<MapDataset shapes: (2,), types: tf.int32>

for line in dataset:
    print(line.numpy())

[2 4]
[2 4]
[999 4]
[2 4]
[2 999]
[999 999]

2.6 Reading data from sets of files

Many datasets consist of a collection of files, where each file stores a single example.

flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)

The root directory contains a folder for each class:

for item in flowers_root.glob("*"):
    print(item.name)

sunflowers
daisy
LICENSE.txt
roses
tulips
dandelion

The files in each class folder are the examples of that class:

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

for f in list_ds.take(5):
    print(f.numpy())

b’/home/kbuilder/.keras/datasets/flower_photos/roses/2980099495_cf272e90ca_m.jpg’
b’/home/kbuilder/.keras/datasets/flower_photos/sunflowers/14678298676_6db8831ee6_m.jpg’
b’/home/kbuilder/.keras/datasets/flower_photos/tulips/485266837_671def8627.jpg’
b’/home/kbuilder/.keras/datasets/flower_photos/daisy/7377004908_5bc0cde347_n.jpg’
b’/home/kbuilder/.keras/datasets/flower_photos/dandelion/9726260379_4e8ee66875_m.jpg’

Use tf.io.read_file to read the file and extract the label from the path, returning (image, label) pairs:

def process_path(file_path):
    label = tf.strings.split(file_path, '/')[-2]
    return tf.io.read_file(file_path), label

labeled_ds = list_ds.map(process_path)

for image_raw, label_text in labeled_ds.take(1):
    print(repr(image_raw.numpy()[:100]))
    print()
    print(label_text.numpy())

b’\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xfe\x00\x1ccmp3.10.3.2Lq3 0xad6b4f35\x00\xff\xdb\x00C\x00\x03\x02\x02\x03\x02\x02\x03\x03\x03\x03\x04\x03\x03\x04\x05\x08\x05\x05\x04\x04\x05\n\x07\x07\x06\x08\x0c\n\x0c\x0c\x0b\n\x0b\x0b\r\x0e\x12\x10\r\x0e\x11\x0e\x0b\x0b\x10’

b’roses’

3. Batching dataset elements

3.1 Simple batching (direct stacking)

The simplest form of batching stacks n consecutive elements of a dataset into a single element. Dataset.batch() does exactly that, with the same restrictions as the tf.stack() operator applied to each component of the elements: for each component i, all elements must have exactly the same shape.

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

for batch in batched_dataset.take(4):
    print([arr.numpy() for arr in batch])

[array([0, 1, 2, 3]), array([ 0, -1, -2, -3])]
[array([4, 5, 6, 7]), array([-4, -5, -6, -7])]
[array([ 8, 9, 10, 11]), array([ -8, -9, -10, -11])]
[array([12, 13, 14, 15]), array([-12, -13, -14, -15])]

With Dataset.batch the batch dimension is unknown, because the last batch may not be full. Note the Nones in the shape:

batched_dataset

<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.int64)>

Use the drop_remainder argument to drop that last batch and get full shape propagation:

batched_dataset = dataset.batch(7, drop_remainder=True)
batched_dataset

<BatchDataset shapes: ((7,), (7,)), types: (tf.int64, tf.int64)>

3.2 Padding tensors to a uniform size, then batching

The recipe above works for tensors that all have the same size. Many models, however (for example sequence models), work with input data whose size can vary (for example sequences of different lengths). To handle this case, Dataset.padded_batch() lets you batch tensors of different shapes by specifying one or more dimensions in which they may be padded.

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=(None,))

for batch in dataset.take(2):
    print(batch.numpy())
    print()

[[0 0 0]
 [1 0 0]
 [2 2 0]
 [3 3 3]]

[[4 4 4 4 0 0 0]
 [5 5 5 5 5 0 0]
 [6 6 6 6 6 6 0]
 [7 7 7 7 7 7 7]]

Dataset.padded_batch() lets you set different padding for each dimension of each component, and each dimension may be variable length (indicated by None) or constant length. You can also override the padding value, which defaults to 0.
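For example, a small sketch (reusing the range dataset built above) that pads the variable-length component with -1 instead of the default 0; the padding value must match the component's dtype, int64 here:

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))

# Pad the single variable-length component with -1 instead of 0.
dataset = dataset.padded_batch(4, padded_shapes=(None,),
                               padding_values=tf.constant(-1, dtype=tf.int64))

for batch in dataset.take(2):
    print(batch.numpy())
    print()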

4. Training workflows

4.1 Repeating data for multiple epochs

The tf.data API offers two main ways to process data for multiple epochs.

  • The simplest is to use Dataset.repeat().

The example below demonstrates this:

titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
def plot_batch_sizes(ds):
    batch_sizes = [batch.shape[0] for batch in ds]
    plt.bar(range(len(batch_sizes)), batch_sizes)
    plt.xlabel('Batch number')
    plt.ylabel('Batch size')

If you call Dataset.repeat() with no arguments, the input repeats indefinitely.

Dataset.repeat concatenates its repetitions seamlessly, without signaling the end of one epoch or the beginning of the next. Because of this, applying Dataset.batch after Dataset.repeat yields batches that straddle epoch boundaries:

titanic_batches = titanic_lines.repeat(3).batch(128)

plot_batch_sizes(titanic_batches)


If you need clear epoch separation, put Dataset.batch before Dataset.repeat:

titanic_batches = titanic_lines.batch(128).repeat(3)

plot_batch_sizes(titanic_batches)


If you want to perform a custom computation (for example, collecting statistics) at the end of each epoch, the simplest approach is to restart the dataset iteration in each epoch:

epochs = 3
dataset = titanic_lines.batch(128)

for epoch in range(epochs):
    for batch in dataset:
        print(batch.shape)
    print("End of epoch: ", epoch)

(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch: 0
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch: 1
(128,)
(128,)
(128,)
(128,)
(116,)
End of epoch: 2

4.2 Randomly shuffling input data

Dataset.shuffle() maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer.

Note: Larger buffer_sizes shuffle more thoroughly, but they can use a lot of memory and take significant time to fill (no elements are produced until the buffer is full). If this becomes a problem, consider using Dataset.interleave across files instead, as sketched below.
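A sketch of that idea, assuming the data is sharded across several text files (file_paths here is a hypothetical list of shard paths): shuffling the file order and interleaving their records gives coarse mixing, so a much smaller element-level buffer suffices.

files_ds = tf.data.Dataset.from_tensor_slices(file_paths).shuffle(len(file_paths))

# Read from several files at once, then finish with a small in-memory shuffle.
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=4)
lines_ds = lines_ds.shuffle(buffer_size=1000)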

Add an index to the dataset so you can see the effect:

lines = tf.data.TextLineDataset(titanic_file)
counter = tf.data.experimental.Counter()

dataset = tf.data.Dataset.zip((counter, lines))
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(20)
dataset

<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.string)>

Since buffer_size is 100 and the batch size is 20, the first batch contains no elements with an index above 120.

n, line_batch = next(iter(dataset))
print(n.numpy())

[ 92 84 52 3 27 100 44 26 2 63 54 93 69 97 10 101 32 65
109 40]

As with batch, the order of shuffle relative to repeat matters:

Dataset.shuffle does not signal the end of an epoch until the shuffle buffer is empty.

shuffle before repeat:

So with a shuffle placed before a repeat, every element of one epoch is shown before moving on to the next; only then does data from the next epoch enter the shuffle buffer:

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.shuffle(buffer_size=100).batch(10).repeat(2)

print("Here are the item ID's near the epoch boundary:\n")

for n, line_batch in shuffled.skip(60).take(5):
    print(n.numpy())

Here are the item ID’s near the epoch boundary:

[523 318 510 467 627 433 514 594 454 560]
[596 566 205 613 493 570 615 411 556 496]
[598 528 623 559 299 473 391 536]
[41 14 51 3 97 70 34 99 63 52]
[ 49 69 104 0 112 90 38 88 11 83]

shuffle_repeat = [n.numpy().mean() for n, line_batch in shuffled]
plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.ylabel("Mean item ID")
plt.legend()


repeat before shuffle:

With a repeat placed before a shuffle, data from the start of the next epoch enters the shuffle buffer while data from the end of the current epoch is still there, so elements from the two epochs get shuffled together at the boundary:

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.repeat(2).shuffle(buffer_size=100).batch(10)

print("Here are the item ID's near the epoch boundary:\n")

for n, line_batch in shuffled.skip(55).take(15):
    print(n.numpy())

Here are the item ID’s near the epoch boundary:

[545 576 610 588 0 595 582 10 597 495]
[353 540 7 490 440 563 559 27 600 504]
[624 476 25 519 608 525 477 30 560 363]
[468 34 3 32 47 22 609 449 627 20 ]
[611 599 577 541 62 13 601 606 15 18 ]
[26 43 607 434 73 616 55 552 57 6 ]
[587 544 584 1 16 51 596 614 21 50 ]
[39 46 76 40 78 71 37 28 2 69 ]
[574 24 88 12 543 100 89 68 445 83 ]
[441 619 557 97 113 96 38 79 613 92 ]
[29 414 65 462 537 232 126 118 75 11 ]
[87 121 80 585 114 72 99 112 102 589]
[77 61 542 369 8 133 129 567 136 344]
[81 91 139 128 49 66 565 64 152 90 ]
[538 494 154 547 131 147 166 158 111 165]

repeat_shuffle = [n.numpy().mean() for n, line_batch in shuffled]

plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.plot(repeat_shuffle, label="repeat().shuffle()")
plt.ylabel("Mean item ID")
plt.legend()

5. Preprocessing data

Dataset.map(f) applies the function f to every element of the dataset and returns the transformed dataset. It is extremely handy for data preprocessing.

Note:
      The arguments and return values of f must all be tf.Tensor objects.
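A trivial sketch of that rule: the function receives each element's components as tf.Tensor objects and must return tf.Tensor objects.

ds = tf.data.Dataset.range(5)
ds = ds.map(lambda x: x * 2 + 1)   # x is a tf.Tensor; the result is also a tf.Tensor

for x in ds:
    print(x.numpy())   # 1 3 5 7 9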

5.1 Preprocessing with Dataset.map()

When training a neural network on real-world image data, you often need to convert images to a common size so that they can be batched together. This section demonstrates how to use Dataset.map() to decode and resize images.

Again using the flowers dataset:

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

Write a function that parses each element of list_ds:

# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def parse_image(filename):
    parts = tf.strings.split(filename, '/')
    label = parts[-2]

    image = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [128, 128])
    return image, label

Test that the function works as expected:

file_path = next(iter(list_ds))
image, label = parse_image(file_path)

def show(image, label):
    plt.figure()
    plt.imshow(image)
    plt.title(label.numpy().decode('utf-8'))
    plt.axis('off')

show(image, label)


Map parse_image over the whole dataset list_ds:

images_ds = list_ds.map(parse_image)

for image, label in images_ds.take(2):  # look at 2 examples to verify correctness
    show(image, label)

5.2 Preprocessing with non-TF functions

Preprocessing with functions that are not built into TensorFlow performs worse than using built-in TF ops (Python is not as fast as C++, and crossing the language boundary is itself a bottleneck), so use built-in TF functions for preprocessing wherever possible. Still, it is sometimes convenient to call Python library functions; you can do that inside Dataset.map() via tf.py_function().

For example, suppose you want to rotate images by an arbitrary angle, but tf.image only provides tf.image.rot90, which is not very useful for data augmentation.

Note: tensorflow_addons has a TF-compatible rotate in tensorflow_addons.image.rotate.
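A sketch of that alternative (assuming the tensorflow_addons package is installed; tfa.image.rotate takes the angle in radians and runs as a regular TF op, so it could be used directly inside Dataset.map):

import numpy as np
import tensorflow_addons as tfa

image, label = next(iter(images_ds))               # any decoded image tensor
angle = np.random.uniform(-30, 30) * np.pi / 180   # random angle, converted to radians
rotated = tfa.image.rotate(image, angles=angle)    # TF-native rotation
show(rotated, label)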

To implement the random rotation described above with a plain Python library, use scipy.ndimage.rotate:

import scipy.ndimage as ndimage

def random_rotate_image(image):
    image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
    return image
image, label = next(iter(images_ds))
image = random_rotate_image(image)
show(image, label)

Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).


To use random_rotate_image inside Dataset.map, you need to describe the returned shapes and types:

def tf_random_rotate_image(image, label):
    im_shape = image.shape
    [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])
    image.set_shape(im_shape)
    return image, label

rot_ds = images_ds.map(tf_random_rotate_image)

for image, label in rot_ds.take(2):
    show(image, label)

Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).

5.3 Parsing tf.Example protocol buffer messages

Many input pipelines extract tf.train.Example protocol buffer messages from TFRecord files. Each tf.train.Example record contains one or more "features", and the input pipeline typically converts these features into tensors.

fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames = [fsns_test_file])
dataset

<TFRecordDatasetV2 shapes: (), types: tf.string>

You can work with tf.train.Example protos outside of a tf.data.Dataset to understand the data:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

feature = parsed.features.feature
raw_img = feature['image/encoded'].bytes_list.value[0]
img = tf.image.decode_png(raw_img)
plt.imshow(img)
plt.axis('off')
_ = plt.title(feature["image/text"].bytes_list.value[0])

raw_example = next(iter(dataset))
def tf_parse(eg):
    example = tf.io.parse_example(
        eg[tf.newaxis], {
            'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
            'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
        })
    return example['image/encoded'][0], example['image/text'][0]
img, txt = tf_parse(raw_example)
print(txt.numpy())
print(repr(img.numpy()[:20]), "...")

b'Rue Perreyon'
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02X' ...

decoded = dataset.map(tf_parse)
decoded

<MapDataset shapes: ((), ()), types: (tf.string, tf.string)>

image_batch, text_batch = next(iter(decoded.batch(10)))
image_batch.shape

TensorShape([10])

5.4 Time series windowing

For an end to end time series example see: Time series forecasting.

Time series data is often organized with the time axis intact.

Use Dataset.range to simulate a time series:

range_ds = tf.data.Dataset.range(100000)

Typically, models based on this kind of data want a contiguous time slice.

The simplest approach is to batch the data:

5.4.1 Using batch

batches = range_ds.batch(10, drop_remainder=True)

for batch in batches.take(5):
    print(batch.numpy())

[0 1 2 3 4 5 6 7 8 9 ]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]

To make dense one-step-ahead predictions, you can shift the features and labels one step relative to each other:

def dense_1_step(batch):
    # Shift features and labels one step relative to each other.
    return batch[:-1], batch[1:]

predict_dense_1_step = batches.map(dense_1_step)

for features, label in predict_dense_1_step.take(3):
    print(features.numpy(), " => ", label.numpy())

[0 1 2 3 4 5 6 7 8 ] => [1 2 3 4 5 6 7 8 9 ]
[10 11 12 13 14 15 16 17 18] => [11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28] => [21 22 23 24 25 26 27 28 29]

To predict a whole window instead of a fixed offset, you can split the batches into two parts:

batches = range_ds.batch(15, drop_remainder=True)

def label_next_5_steps(batch):
    return (batch[:-5],   # Inputs: all but the last 5 steps
            batch[-5:])   # Labels: the last 5 steps

predict_5_steps = batches.map(label_next_5_steps)

for features, label in predict_5_steps.take(3):
    print(features.numpy(), " => ", label.numpy())

[0 1 2 3 4 5 6 7 8 9 ] => [10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24] => [25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39] => [40 41 42 43 44]

To allow some overlap between the features of one batch and the labels of another, use Dataset.zip:

feature_length = 10
label_length = 5

features = range_ds.batch(feature_length, drop_remainder=True)
labels = range_ds.batch(feature_length).skip(1).map(lambda labels: labels[:-5])

predict_5_steps = tf.data.Dataset.zip((features, labels))

for features, label in predict_5_steps.take(3):
    print(features.numpy(), " => ", label.numpy())

[0 1 2 3 4 5 6 7 8 9 ] => [10 11 12 13 14]
[10 11 12 13 14 15 16 17 18 19] => [20 21 22 23 24]
[20 21 22 23 24 25 26 27 28 29] => [30 31 32 33 34]

5.4.2 Using window

While Dataset.batch works, there are situations that call for finer control. The Dataset.window method gives you complete control, but requires some care: it returns a Dataset of Datasets. See Section 1.1 for details.

window_size = 5

windows = range_ds.window(window_size, shift=1)

for sub_ds in windows.take(5):
    print(sub_ds)

<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>

The Dataset.flat_map method takes a dataset of datasets and flattens it into a single dataset:

for x in windows.flat_map(lambda x: x).take(30):
    print(x.numpy(), end=' ')

WARNING:tensorflow:AutoGraph could not transform <function at 0x7f973007e6a8> and will run it as-is.
Cause: could not parse the source code:

for x in windows.flat_map(lambda x: x).take(30):

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function at 0x7f973007e6a8> and will run it as-is.
Cause: could not parse the source code:

for x in windows.flat_map(lambda x: x).take(30):

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 3 4 5 6 7 4 5 6 7 8 5 6 7 8 9

In nearly every case, you will want to batch each sub-dataset first:

def sub_to_batch(sub):
    return sub.batch(window_size, drop_remainder=True)

for example in windows.flat_map(sub_to_batch).take(5):
    print(example.numpy())

[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]

Now you can see that the shift argument controls how far each window moves.

Putting all of this together, you can build the following function:

def make_window_dataset(ds, window_size=5, shift=1, stride=1):
    windows = ds.window(window_size, shift=shift, stride=stride)

    def sub_to_batch(sub):
        return sub.batch(window_size, drop_remainder=True)

    windows = windows.flat_map(sub_to_batch)
    return windows

ds = make_window_dataset(range_ds, window_size=10, shift=5, stride=3)

for example in ds.take(10):
    print(example.numpy())

[0 3 6 9 12 15 18 21 24 27]
[5 8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34 37]
[15 18 21 24 27 30 33 36 39 42]
[20 23 26 29 32 35 38 41 44 47]
[25 28 31 34 37 40 43 46 49 52]
[30 33 36 39 42 45 48 51 54 57]
[35 38 41 44 47 50 53 56 59 62]
[40 43 46 49 52 55 58 61 64 67]
[45 48 51 54 57 60 63 66 69 72]

Then, extracting labels is as easy as before:

dense_labels_ds = ds.map(dense_1_step)

for inputs, labels in dense_labels_ds.take(3):
    print(inputs.numpy(), "=>", labels.numpy())

[0 3 6 9 12 15 18 21 24] => [3 6 9 12 15 18 21 24 27]
[5 8 11 14 17 20 23 26 29] => [8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34] => [13 16 19 22 25 28 31 34 37]

5.5 Resampling

When working with a dataset that is very class-imbalanced, you may want to resample it. tf.data provides two methods for doing this. The credit card fraud dataset is a good example of this kind of problem.

Note: For a full tutorial, see Imbalanced data.

zip_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',
    fname='creditcard.zip',
    extract=True)

csv_path = zip_path.replace('.zip', '.csv')

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip
69156864/69155632 [==============================] - 2s 0us/step

creditcard_ds = tf.data.experimental.make_csv_dataset(
    csv_path, batch_size=1024, label_name="Class",
    # Set the column types: 30 floats and an int.
    column_defaults=[float()]*30+[int()])

Check the class distribution; the classes are highly skewed:

def count(counts, batch):
    features, labels = batch
    class_1 = labels == 1
    class_1 = tf.cast(class_1, tf.int32)

    class_0 = labels == 0
    class_0 = tf.cast(class_0, tf.int32)

    counts['class_0'] += tf.reduce_sum(class_0)
    counts['class_1'] += tf.reduce_sum(class_1)

    return counts

counts = creditcard_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func=count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

fractions = counts/counts.sum()
print(fractions)

[0.995 0.005]

A common approach to training with an imbalanced dataset is to balance it, and tf.data includes a few methods for doing so:

5.5.1 Datasets.sampling

One approach to resampling a dataset is to use sample_from_datasets. This works well when you have a separate data.Dataset for each class.

Here, a filter is used to generate each class's dataset from the credit card fraud data:

negative_ds = (
    creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==0)
    .repeat())

positive_ds = (
    creditcard_ds
    .unbatch()
    .filter(lambda features, label: label==1)
    .repeat())

WARNING:tensorflow:AutoGraph could not transform <function at 0x7f9730114598> and will run it as-is.
Cause: could not parse the source code:

.filter(lambda features, label: label==0)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function at 0x7f9730114598> and will run it as-is.
Cause: could not parse the source code:

.filter(lambda features, label: label==0)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function at 0x7f97301149d8> and will run it as-is.
Cause: could not parse the source code:

.filter(lambda features, label: label==1)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING: AutoGraph could not transform <function at 0x7f97301149d8> and will run it as-is.
Cause: could not parse the source code:

.filter(lambda features, label: label==1)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

for features, label in positive_ds.batch(10).take(1):
    print(label.numpy())

[1 1 1 1 1 1 1 1 1 1]

To combine the two with tf.data.experimental.sample_from_datasets, pass the datasets and the weight for each:

balanced_ds = tf.data.experimental.sample_from_datasets(
    [negative_ds, positive_ds], [0.5, 0.5]).batch(10)

Now the dataset produces examples of each class with 50/50 probability:

for features, labels in balanced_ds.take(10):
    print(labels.numpy())

[1 0 1 1 0 1 0 0 0 0]
[1 1 0 0 0 0 0 1 0 1]
[1 1 0 0 0 1 0 0 1 1]
[1 0 0 1 0 1 1 0 0 0]
[0 0 1 0 0 0 0 1 1 1]
[0 1 1 1 1 0 0 1 0 1]
[0 0 0 0 1 0 1 1 1 1]
[0 0 0 1 1 1 0 0 0 1]
[1 1 0 1 1 1 1 1 1 0]
[1 1 1 1 0 1 0 0 1 1]

5.5.2 experimental.rejection_resample

One problem with experimental.sample_from_datasets is that it needs a separate tf.data.Dataset per class. That can be built with Dataset.filter, but it means the data gets loaded twice.

The data.experimental.rejection_resample function can be applied to a dataset to rebalance it while loading the data only once. Elements are dropped from the dataset to achieve balance.

data.experimental.rejection_resample takes a class_func argument. The class_func is applied to each dataset element and is used to determine which class an example belongs to for the purposes of balancing.

The elements of creditcard_ds are already (features, label) pairs, so class_func just needs to return those labels:

def class_func(features, label):
    return label

The resampler also needs a target distribution, and optionally an estimate of the initial distribution:

resampler = tf.data.experimental.rejection_resample(
    class_func, target_dist=[0.5, 0.5], initial_dist=fractions)

The resampler deals with individual examples, so you must unbatch the dataset before applying it:

resample_ds = creditcard_ds.unbatch().apply(resampler).batch(10)

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/experimental/ops/resampling.py:156: Print (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2018-08-20.
Instructions for updating:
Use tf.print instead of tf.Print. Note that tf.print returns a no-output operator that directly prints the output. Outside of defuns or eager mode, this operator will not be executed unless it is directly specified in session.run or used as a control dependency for other operators. This is only a concern in graph mode. Below is an example of how to ensure tf.print executes in graph mode:

The resampler returns (class, example) pairs created from the output of class_func. In this case, the example was already a (feature, label) pair, so use map to drop the extra copy of the labels:

balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)

Now the dataset produces examples of each class with 50/50 probability:

for features, labels in balanced_ds.take(10):
    print(labels.numpy())

[1 1 1 1 0 0 1 0 0 0]
[1 1 1 1 1 1 1 1 0 1]
[1 1 0 1 0 0 0 1 0 0]
[0 0 0 1 1 0 0 1 1 0]
[1 0 1 0 0 1 1 0 1 0]
[1 1 0 1 0 1 0 0 1 0]
[0 1 1 1 0 1 1 1 1 1]
[0 0 1 0 1 0 0 1 0 1]
[1 1 0 1 1 0 0 1 1 1]
[1 1 0 0 0 1 0 1 1 0]

6. Using tf.data with high-level APIs

6.1 Using tf.data with tf.keras

The tf.keras API greatly simplifies creating and using machine learning models. Its .fit(), .evaluate(), and .predict() APIs accept tf.data datasets as input. Here is a quick setup:

train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)

fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

Model.fit and Model.evaluate both expect data together with labels:

model.fit(fmnist_train_ds, epochs=2)

Epoch 1/2
WARNING:tensorflow:Layer flatten is casting an input tensor from dtype float64 to the layer’s dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call tf.keras.backend.set_floatx('float64'). To change just this layer, pass dtype=‘float64’ to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

1875/1875 [==============================] - 3s 2ms/step - loss: 0.6013 - accuracy: 0.7970
Epoch 2/2
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4617 - accuracy: 0.8418

<tensorflow.python.keras.callbacks.History at 0x7f97801f1588>

As you can see above, tf.keras works well with tf.data.

If the input pipeline you pass to .fit() calls Dataset.repeat(), you also need to pass the steps_per_epoch argument to .fit().

model.fit(fmnist_train_ds.repeat(), epochs=2, steps_per_epoch=20)

Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4650 - accuracy: 0.8422
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 0.3897 - accuracy: 0.8797

<tensorflow.python.keras.callbacks.History at 0x7f97801f1908>

For evaluation, you can pass the evaluation dataset directly:

loss, accuracy = model.evaluate(fmnist_train_ds)
print("Loss :", loss)
print("Accuracy :", accuracy)

1875/1875 [==============================] - 3s 2ms/step - loss: 0.4423 - accuracy: 0.8473
Loss : 0.44227170944213867
Accuracy : 0.847266674041748

For long or repeating datasets, set the number of evaluation steps:

loss, accuracy = model.evaluate(fmnist_train_ds.repeat(), steps=10)
print("Loss :", loss)
print("Accuracy :", accuracy)

10/10 [==============================] - 0s 2ms/step - loss: 0.4557 - accuracy: 0.8188
Loss : 0.45573288202285767
Accuracy : 0.8187500238418579

The labels are not required when calling Model.predict.

predict_ds = tf.data.Dataset.from_tensor_slices(images).batch(32)
result = model.predict(predict_ds, steps = 10)
print(result.shape)

(320, 10)

If your dataset does contain labels, predict ignores them automatically.

result = model.predict(fmnist_train_ds, steps = 10)
print(result.shape)

(320, 10)

6.2 Using tf.data with tf.estimator

To use a Dataset in the input_fn of a tf.estimator.Estimator, simply make sure that input_fn returns the Dataset.

The official tutorial is a bit thin here; I recommend reading 《TensorFlow Estimator 官方文档之----Dataset for Estimator》, which explains in more detail how to use tf.data with tf.estimator.

import tensorflow_datasets as tfds

def train_input_fn():
    titanic = tf.data.experimental.make_csv_dataset(
        titanic_file, batch_size=32,
        label_name="survived")
    titanic_batches = (
        titanic.cache().repeat().shuffle(500)
        .prefetch(tf.data.experimental.AUTOTUNE))
    return titanic_batches
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third'])
age = tf.feature_column.numeric_column('age')
import tempfile
model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {’_model_dir’: ‘/tmp/tmp7xfmvz5w’, ‘_tf_random_seed’: None, ‘_save_summary_steps’: 100, ‘_save_checkpoints_steps’: None, ‘_save_checkpoints_secs’: 600, ‘_session_config’: allow_soft_placement: true
graph_options {
rewrite_options {
meta_optimizer_iterations: ONE
}
}
, ‘_keep_checkpoint_max’: 5, ‘_keep_checkpoint_every_n_hours’: 10000, ‘_log_step_count_steps’: 100, ‘_train_distribute’: None, ‘_device_fn’: None, ‘_protocol’: None, ‘_eval_distribute’: None, ‘_experimental_distribute’: None, ‘_experimental_max_worker_delay_secs’: None, ‘_session_creation_timeout_secs’: 7200, ‘_service’: None, ‘_cluster_spec’: ClusterSpec({}), ‘_task_type’: ‘worker’, ‘_task_id’: 0, ‘_global_id_in_cluster’: 0, ‘_master’: ‘’, ‘_evaluation_master’: ‘’, ‘_is_chief’: True, ‘_num_ps_replicas’: 0, ‘_num_worker_replicas’: 1}

model = model.train(input_fn=train_input_fn, steps=100)

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:560: Layer.add_variable (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version.
Instructions for updating:
Please use layer.add_weight method instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:143: calling Constant.init (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0…
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp7xfmvz5w/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0…
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100…
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmp7xfmvz5w/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100…
INFO:tensorflow:Loss for final step: 0.5968354.

result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
    print(key, ":", value)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-03-28T01:27:11Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp7xfmvz5w/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.65018s
INFO:tensorflow:Finished evaluation at 2020-03-28-01:27:11
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.684375, accuracy_baseline = 0.603125, auc = 0.73216105, auc_precision_recall = 0.6447562, average_loss = 0.60841894, global_step = 100, label/mean = 0.396875, loss = 0.60841894, precision = 0.76, prediction/mean = 0.31196585, recall = 0.2992126
INFO:tensorflow:Saving ‘checkpoint_path’ summary for global step 100: /tmp/tmp7xfmvz5w/model.ckpt-100
accuracy : 0.684375
accuracy_baseline : 0.603125
auc : 0.73216105
auc_precision_recall : 0.6447562
average_loss : 0.60841894
label/mean : 0.396875
loss : 0.60841894
precision : 0.76
prediction/mean : 0.31196585
recall : 0.2992126
global_step : 100

for pred in model.predict(train_input_fn):
    for key, value in pred.items():
        print(key, ":", value)
    break

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp7xfmvz5w/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-0.1131]
logistic : [0.4717]
probabilities : [0.5283 0.4717]
class_ids : [0]
classes : [b’0’]
all_class_ids : [0 1]
all_classes : [b’0’ b’1’]


Note: This article is based on the official TensorFlow guide to importing data with tf.data: Learn > Guide > tf.data.

Updated March 29, 2020.
