

For the TF v1 version of this guide, see my earlier post "TensorFlow tf.data: Importing Data (the official tf.data tutorial)".


  • Building data input pipelines with tf.data
    • 1. Basic mechanics
      • 1.1 Dataset structure
    • 2. Reading input data
      • 2.1 Consuming NumPy arrays
      • 2.2 Consuming Python generators
      • 2.3 Consuming TFRecord data
      • 2.4 Consuming text data
      • 2.5 Consuming CSV data
      • 2.6 Consuming sets of files
    • 3. Batching dataset elements
      • 3.1 Simple batching (stacking)
      • 3.2 Batching tensors with padding
    • 4. Training workflows
      • 4.1 Processing multiple epochs
      • 4.2 Randomly shuffling input data
    • 5. Preprocessing data
      • 5.1 Preprocessing with Dataset.map()
      • 5.2 Preprocessing with non-TensorFlow functions
      • 5.3 Parsing tf.Example protocol buffer messages
      • 5.4 Time series windowing
        • 5.4.1 Using batch
        • 5.4.2 Using window
      • 5.5 Resampling
        • 5.5.1 Datasets.sampling
        • 5.5.2 experimental.rejection_resample
    • 6. Using tf.data with high-level APIs
      • 6.1 Using tf.data with tf.keras
      • 6.2 Using tf.data with tf.estimator

Building data input pipelines with tf.data

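The tf.data API lets you build input pipelines that are simple and highly reusable, yet can express very complex data flows. For example, the pipeline for an image model might aggregate data from files in a distributed file system, apply random perturbations to each image, and merge randomly selected images into a batch for training. The pipeline for a text model might extract symbols from raw text data, convert them to embedding identifiers with a lookup table, and batch together sequences of different lengths. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations.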

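The tf.data API introduces the tf.data.Dataset abstraction, which represents a sequence of elements in which each element consists of one or more components. For example, in an image pipeline an element might be a single training example made up of an image tensor and its label.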


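There are two distinct ways to create a Dataset:

  • a data source builds a Dataset from data stored in memory or in one or more files on disk;
  • a data transformation builds a new Dataset from one or more existing Dataset objects.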
from __future__ import absolute_import, division, print_function, unicode_literals

import pathlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

np.set_printoptions(precision=4)

1. Basic mechanics


Once you have a Dataset object, you can transform it into a new Dataset by chaining method calls on it.

A Dataset is a Python iterable, so you can consume its elements with a for loop:

dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
dataset

<TensorSliceDataset shapes: (), types: tf.int32>

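for elem in dataset:
    print(elem.numpy())  # prints 8, 3, 0, 8, 2, 1 (one value per line)

Alternatively, create a Python iterator explicitly with iter() and consume its elements with next():

it = iter(dataset)
print(next(it).numpy())  # 8

Elements can also be consumed with Dataset.reduce(), which folds all elements into a single result. The following computes the sum of the dataset:

print(dataset.reduce(0, lambda state, value: state + value).numpy())  # 22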


1.1 Dataset structure

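A Dataset contains elements that all have the same (nested) structure. Each individual component of that structure can be of any type representable by tf.TypeSpec, most commonly tf.Tensor, tf.sparse.SparseTensor, tf.RaggedTensor, tf.TensorArray, or tf.data.Dataset. The Dataset.element_spec property lets you inspect the type of each element component: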


dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))
dataset1.element_spec

TensorSpec(shape=(10,), dtype=tf.float32, name=None)

dataset2 = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([4]),
     tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))
dataset2.element_spec

(TensorSpec(shape=(), dtype=tf.float32, name=None),
 TensorSpec(shape=(100,), dtype=tf.int32, name=None))

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
dataset3.element_spec

(TensorSpec(shape=(10,), dtype=tf.float32, name=None),
 (TensorSpec(shape=(), dtype=tf.float32, name=None),
  TensorSpec(shape=(100,), dtype=tf.int32, name=None)))

# Dataset containing a sparse tensor.
dataset4 = tf.data.Dataset.from_tensors(
    tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]))
dataset4.element_spec

SparseTensorSpec(TensorShape([3, 4]), tf.int32)

# Use value_type to see the type of value represented by the element spec
dataset4.element_spec.value_type


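Dataset transformations support datasets of any structure. When you use Dataset.map(), Dataset.flat_map(), or Dataset.filter(), which apply a function to each element, the element structure determines the arguments of that function: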

dataset1 = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform([4, 10], minval=1, maxval=10, dtype=tf.int32))
dataset1

<TensorSliceDataset shapes: (10,), types: tf.int32>

for z in dataset1:
    print(z.numpy())

[6 7 1 1 5 6 7 8 7 6]
[8 3 3 7 9 3 8 4 8 4]
[2 3 6 9 4 2 1 8 1 6]
[6 7 1 9 6 2 4 7 9 1]

dataset2 = tf.data.Dataset.from_tensor_slices(
    (tf.random.uniform([4]),
     tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))
dataset2

<TensorSliceDataset shapes: ((), (100,)), types: (tf.float32, tf.int32)>

dataset3 = tf.data.Dataset.zip((dataset1, dataset2))
dataset3

<ZipDataset shapes: ((10,), ((), (100,))), types: (tf.int32, (tf.float32, tf.int32))>

for a, (b, c) in dataset3:
    print('shapes: {a.shape}, {b.shape}, {c.shape}'.format(a=a, b=b, c=c))

shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
shapes: (10,), (), (100,)
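To make the rule concrete, here is a small pure-Python model of how the element structure turns into the mapped function's arguments. It is an illustration only, not the tf.data implementation, and the helper names call_with_element and map_model are hypothetical. It assumes the behavior described above: a plain tuple element is unpacked into positional arguments, while any other structure (a single value or a dict) arrives as one argument.

```python
def call_with_element(fn, element):
    """Call fn the way Dataset.map() would for this element structure (model)."""
    if type(element) is tuple:  # plain tuples are unpacked into positional args
        return fn(*element)
    return fn(element)  # single values and dicts are passed as one argument


def map_model(elements, fn):
    """Apply fn to every element, like an eager (non-lazy) Dataset.map()."""
    return [call_with_element(fn, e) for e in elements]


# Elements shaped like those of dataset3 above: (a, (b, c)).
elements = [([1] * 10, (0.5, [7] * 100)) for _ in range(4)]

# The nested (b, c) tuple stays packed as the second argument.
shapes = map_model(elements, lambda a, bc: (len(a), len(bc[1])))
print(shapes)  # [(10, 100), (10, 100), (10, 100), (10, 100)]

# A dict element is received as a single dict argument.
print(map_model([{"a": 0.1, "b": [3] * 100}], lambda d: sorted(d)))  # [['a', 'b']]
```

This is why the loop above can write `for a, (b, c) in dataset3`, and why a function mapped over dataset3 would take two parameters rather than three.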

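It is often convenient to give names to the components of an element, for example when they represent different features of a training example. In addition to tuples, you can use a collections.namedtuple or a dictionary to represent a single Dataset element.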

dataset = tf.data.Dataset.from_tensor_slices(
    {"a": tf.random.uniform([4]),
     "b": tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)})
dataset.element_spec

{'a': TensorSpec(shape=(), dtype=tf.float32, name=None), 'b': TensorSpec(shape=(100,), dtype=tf.int32, name=None)}
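How from_tensor_slices turns such a dictionary into elements can be sketched in pure Python. This is a hypothetical model, not the TensorFlow code: Python lists stand in for tensors. Every component is sliced along its first axis, element i keeps the same dict/tuple nesting, and all components must agree on the size of that first axis.

```python
def first_dim(structure):
    """Length of the first axis shared by every component (lists model tensors)."""
    if isinstance(structure, dict):
        sizes = {first_dim(v) for v in structure.values()}
    elif isinstance(structure, tuple):
        sizes = {first_dim(v) for v in structure}
    else:
        return len(structure)
    if len(sizes) != 1:
        raise ValueError("all components must have the same first dimension")
    return sizes.pop()


def slice_structure(structure, i):
    """Take row i of every component while keeping the dict/tuple nesting."""
    if isinstance(structure, dict):
        return {k: slice_structure(v, i) for k, v in structure.items()}
    if isinstance(structure, tuple):
        return tuple(slice_structure(v, i) for v in structure)
    return structure[i]


def from_tensor_slices_model(structure):
    return [slice_structure(structure, i) for i in range(first_dim(structure))]


data = {"a": [0.1, 0.2, 0.3, 0.4], "b": [[1, 2], [3, 4], [5, 6], [7, 8]]}
elements = from_tensor_slices_model(data)
print(len(elements))  # 4
print(elements[0])    # {'a': 0.1, 'b': [1, 2]}
```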

2. Reading input data

2.1 Consuming NumPy arrays

See Loading NumPy arrays for more examples.

If all of your input data fits in memory, the simplest way to create a Dataset from it is to convert it to tf.Tensor objects with Dataset.from_tensor_slices().

train, test = tf.keras.datasets.fashion_mnist.load_data()  # returns NumPy arrays

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

images, labels = train
images = images / 255

# NumPy arrays are converted to constant tensors automatically.
dataset = tf.data.Dataset.from_tensor_slices((images, labels))
dataset

<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>

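Note: The snippet above embeds the images and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but it wastes memory, because the contents of the arrays are copied multiple times, and it can hit the 2GB limit of the tf.GraphDef protocol buffer.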

2.2 Consuming Python generators


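Caution: Although this approach is convenient, it has limited portability and scalability. It must run in the same Python process that created the generator, and it is still subject to the Python GIL.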

def count(stop):
    i = 0
    while i < stop:
        yield i
        i += 1

for n in count(5):
    print(n)  # prints 0, 1, 2, 3, 4 (one value per line)




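Dataset.from_generator converts a Python generator into a Dataset. The args argument is passed to the generator callable as its arguments:

ds_counter = tf.data.Dataset.from_generator(
    count, args=[25], output_types=tf.int32, output_shapes=())

for count_batch in ds_counter.repeat().batch(10).take(10):
    print(count_batch.numpy())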

[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 0 1 2 3 4]
[ 5 6 7 8 9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 0 1 2 3 4]
[ 5 6 7 8 9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]

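The output_shapes argument is not required, but it is highly recommended, because many TensorFlow operations do not support tensors of unknown rank. If the length of a particular axis is unknown or variable, set it to None in output_shapes.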



def gen_series():  # generator
    i = 0
    while True:
        size = np.random.randint(0, 10)
        yield i, np.random.normal(size=(size,))  # array of shape (size,)
        i += 1

for i, series in gen_series():
    print(i, ":", str(series))
    if i > 5:
        break

0 : [ 1.9201 0.2124 -0.3383 -0.1141 0.7749 -0.1499]
1 : []
2 : [ 0.5885 -1.1092 0.4577 2.2978 -1.1854]
3 : [-1.7452 1.0516]
4 : []
5 : []
6 : [-0.8563 -1.2055 -0.291 1.0448 0.1486 1.0402 1.8017]


ds_series = tf.data.Dataset.from_generator(
    gen_series,
    output_types=(tf.int32, tf.float32),  # required
    output_shapes=((), (None,)))          # optional, but recommended (see above)
ds_series

<FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float32)>

The tf.data.Dataset is now built. Note, however, that when you batch a dataset whose elements have variable shapes, you need to use Dataset.padded_batch:

ds_series_batch = ds_series.shuffle(20).padded_batch(10, padded_shapes=([], [None]))

ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())

[ 6 1 10 0 3 17 12 9 5 23]

[[ 0.5812 -0.825 0.6075 -1.3856 -0.8151 -1.1908 0. 0. ]
[-0.7208 0.0611 0.0084 0.6592 0.8364 0.8327 -0.7164 0.8826]
[ 0.0391 -2.0019 0.4077 0.9304 0. 0. 0. 0. ]
[ 0.4397 -0.0901 -0.4993 0.3485 0.2481 0. 0. 0. ]
[ 0.0346 0. 0. 0. 0. 0. 0. 0. ]
[-1.0478 0. 0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0. 0. 0. ]
[ 0.3163 0. 0. 0. 0. 0. 0. 0. ]
[ 0. 0. 0. 0. 0. 0. 0. 0. ]]

Note: Since TensorFlow 2.2, the padded_shapes argument is no longer required. By default, all axes are padded to the longest in the batch.

ds_series_batch = ds_series.shuffle(20).padded_batch(10)



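The next example wraps preprocessing.image.ImageDataGenerator as a tf.data.Dataset. First, download the data:

flowers = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)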

Downloading data from https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
228818944/228813984 [==============================] - 5s 0us/step

Create an image.ImageDataGenerator:

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)
images, labels = next(img_gen.flow_from_directory(flowers))

Found 3670 images belonging to 5 classes.

print(images.dtype, images.shape)
print(labels.dtype, labels.shape)

float32 (32, 256, 256, 3)
float32 (32, 5)

ds = tf.data.Dataset.from_generator(
    img_gen.flow_from_directory, args=[flowers],
    output_types=(tf.float32, tf.float32),
    output_shapes=([32, 256, 256, 3], [32, 5]))
ds

<FlatMapDataset shapes: ((32, 256, 256, 3), (32, 5)), types: (tf.float32, tf.float32)>

2.3 Consuming TFRecord data

See Loading TFRecords for an end-to-end example.

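The tf.data API supports a variety of file formats, so you can process large datasets that do not fit in memory. For example, the TFRecord file format is a simple record-oriented binary format that many TensorFlow applications use for training data. The tf.data.TFRecordDataset class lets you stream the contents of one or more TFRecord files as part of an input pipeline.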

Here is an example using the test file from French Street Name Signs (FSNS):

# Download the FSNS test file (a single TFRecord file).
fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001
7905280/7904079 [==============================] - 0s 0us/step

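The filenames argument of TFRecordDataset can be a string, a list of strings, or a tf.Tensor of strings. So if you have two sets of files, one for training and one for validation, you can write a factory function that builds the dataset and takes the filenames as an input argument: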

dataset = tf.data.TFRecordDataset(filenames=[fsns_test_file])
dataset

<TFRecordDatasetV2 shapes: (), types: tf.string>


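Many TensorFlow projects store serialized tf.train.Example records in their TFRecord files. These need to be decoded before they can be inspected:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

parsed.features.feature['image/text']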

bytes_list {
  value: "Rue Perreyon"
}

2.4 Consuming text data

See Loading Text for an end to end example.

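Many datasets are distributed as one or more text files. tf.data.TextLineDataset provides an easy way to extract lines from them: given one or more filenames, it produces one string-valued element per line of those files.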

directory_url = 'https://storage.googleapis.com/download.tensorflow.org/data/illiad/'
file_names = ['cowper.txt', 'derby.txt', 'butler.txt']

file_paths = [
    tf.keras.utils.get_file(file_name, directory_url + file_name)
    for file_name in file_names
]

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/cowper.txt
819200/815980 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/derby.txt
811008/809730 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/illiad/butler.txt
811008/807992 [==============================] - 0s 0us/step

dataset = tf.data.TextLineDataset(file_paths)


for line in dataset.take(5):
    print(line.numpy())

b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b'His wrath pernicious, who ten thousand woes'
b"Caused to Achaia's host, sent many a soul"
b'Illustrious into Ades premature,'
b'And Heroes gave (so stood the will of Jove)'


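To alternate lines between files, use Dataset.interleave, which also makes it easier to shuffle files together. Here are the first, second and third lines from each translation:

files_ds = tf.data.Dataset.from_tensor_slices(file_paths)
lines_ds = files_ds.interleave(tf.data.TextLineDataset, cycle_length=3)

for i, line in enumerate(lines_ds.take(9)):
    if i % 3 == 0:
        print()
    print(line.numpy())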

b"\xef\xbb\xbfAchilles sing, O Goddess! Peleus' son;"
b"\xef\xbb\xbfOf Peleus' son, Achilles, sing, O Muse,"
b'\xef\xbb\xbfSing, O goddess, the anger of Achilles son of Peleus, that brought'

b'His wrath pernicious, who ten thousand woes'
b'The vengeance, deep and deadly; whence to Greece'
b'countless ills upon the Achaeans. Many a brave soul did it send'

b"Caused to Achaia's host, sent many a soul"
b'Unnumbered ills arose; which many a soul'
b'hurrying down to Hades, and many a hero did it yield a prey to dogs and'
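The interleave above opens all three files and takes one line from each in turn, which is why the first lines of the three translations come out together. The order can be reproduced with a simplified pure-Python model. The function name is hypothetical and this is not the tf.data implementation: the model assumes every input fits in the cycle (len(sources) <= cycle_length, as with the three files here), so it does not model replacing an exhausted input with a new one.

```python
from itertools import islice


def interleave_model(sources, block_length=1):
    """Round-robin over the sources, taking block_length items from each in turn.

    Simplification: all sources are open at once (len(sources) <= cycle_length).
    """
    iterators = [iter(s) for s in sources]
    out = []
    while iterators:
        still_open = []
        for it in iterators:
            block = list(islice(it, block_length))
            out.extend(block)
            if len(block) == block_length:  # a short block means the source ran dry
                still_open.append(it)
        iterators = still_open
    return out


# Hypothetical stand-ins for the lines of the three translations.
cowper = ["c1", "c2", "c3"]
derby = ["d1", "d2"]
butler = ["b1", "b2", "b3", "b4"]
print(interleave_model([cowper, derby, butler]))
# ['c1', 'd1', 'b1', 'c2', 'd2', 'b2', 'c3', 'b3', 'b4']
```

Once a shorter file runs out, the remaining files keep alternating, which is the same round-robin idea interleave applies to the TextLineDatasets above.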

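By default, a TextLineDataset yields every line of each file, which may not be what you want, for example if the file starts with a header line or contains comments. These lines can be removed with the Dataset.skip() or Dataset.filter() transformations. In the Titanic example below, skip(1) drops the header line and filter() keeps only the survivors.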


titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)

Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv
32768/30874 [===============================] - 0s 0us/step

for line in titanic_lines.take(10):
    print(line.numpy())


def survived(line):
    # The first CSV field is the "survived" flag; keep rows where it is not "0".
    return tf.not_equal(tf.strings.substr(line, 0, 1), "0")

survivors = titanic_lines.skip(1).filter(survived)

for line in survivors.take(10):
    print(line.numpy())


2.5 Consuming CSV data

See Loading CSV Files, and Loading Pandas DataFrames for more examples.



titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")  # download the data
df = pd.read_csv(titanic_file, index_col=None)
df.head()

   survived     sex   age  n_siblings_spouses  parch     fare  class     deck  embark_town alone
0         0    male  22.0                   1      0   7.2500  Third  unknown  Southampton     n
1         1  female  38.0                   1      0  71.2833  First        C    Cherbourg     n
2         1  female  26.0                   0      0   7.9250  Third  unknown  Southampton     y
3         1  female  35.0                   1      0  53.1000  First        C  Southampton     n
4         0    male  28.0                   0      0   8.4583  Third  unknown   Queenstown     y


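If the data fits in memory, Dataset.from_tensor_slices also works on dictionaries, so the DataFrame can be imported directly:

titanic_slices = tf.data.Dataset.from_tensor_slices(dict(df))

for feature_batch in titanic_slices.take(1):
    for key, value in feature_batch.items():
        print("  {!r:20s}: {}".format(key, value))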

  'survived'          : 0
  'sex'               : b'male'
  'age'               : 22.0
  'n_siblings_spouses': 1
  'parch'             : 0
  'fare'              : 7.25
  'class'             : b'Third'
  'deck'              : b'unknown'
  'embark_town'       : b'Southampton'
  'alone'             : b'n'


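The tf.data module provides methods to extract records from one or more CSV files that comply with RFC 4180. The high-level tf.data.experimental.make_csv_dataset function reads sets of CSV files and supports column type inference, batching and shuffling: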


titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived")

for feature_batch, label_batch in titanic_batches.take(1):
    print("'survived': {}".format(label_batch))
    print("features:")
    for key, value in feature_batch.items():
        print("  {!r:20s}: {}".format(key, value))

'survived': [1 1 0 0]
features:
  'sex'               : [b'female' b'female' b'male' b'male']
  'age'               : [28. 24. 29. 28.]
  'n_siblings_spouses': [0 1 0 0]
  'parch'             : [0 0 0 0]
  'fare'              : [ 7.2292 26. 30. 7.725 ]
  'class'             : [b'Third' b'Second' b'First' b'Third']
  'deck'              : [b'unknown' b'unknown' b'D' b'unknown']
  'embark_town'       : [b'Cherbourg' b'Southampton' b'Southampton' b'Queenstown']
  'alone'             : [b'y' b'n' b'y' b'y']


Use the select_columns argument if you only need a subset of the columns:

titanic_batches = tf.data.experimental.make_csv_dataset(
    titanic_file, batch_size=4,
    label_name="survived", select_columns=['class', 'fare', 'survived'])

for feature_batch, label_batch in titanic_batches.take(1):
    print("'survived': {}".format(label_batch))
    for key, value in feature_batch.items():
        print("  {!r:20s}: {}".format(key, value))

'survived': [1 0 1 0]
  'fare'              : [ 10.5 7.25 23. 106.425]
  'class'             : [b'Second' b'Third' b'Second' b'First']


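There is also a lower-level tf.data.experimental.CsvDataset class that gives finer-grained control. It does not infer column types; instead, you specify the type of each column:

titanic_types = [tf.int32, tf.string, tf.float32, tf.int32, tf.int32, tf.float32, tf.string, tf.string, tf.string, tf.string]
dataset = tf.data.experimental.CsvDataset(titanic_file, titanic_types, header=True)

for line in dataset.take(10):
    print([item.numpy() for item in line])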

[1, b'female', 38.0, 1, 0, 71.2833, b'First', b'C', b'Cherbourg', b'n']
[0, b'male', 28.0, 0, 0, 8.4583, b'Third', b'unknown', b'Queenstown', b'y']
[0, b'male', 2.0, 3, 1, 21.075, b'Third', b'unknown', b'Southampton', b'n']
[1, b'female', 4.0, 1, 1, 16.7, b'Third', b'G', b'Southampton', b'n']


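If some columns may be empty, this lower-level interface lets you provide default values instead of column types. The cell below uses the IPython cell magic %%writefile to create a small test file. %%writefile only works in IPython/Jupyter, and it must be the first line of its cell:

%%writefile missing.csv
1,2,3,4
,2,3,4
1,,3,4
1,2,,4
1,2,3,
,,,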

Writing missing.csv

# Creates a dataset that reads all of the records from missing.csv, which has
# four integer columns that may contain missing values.
record_defaults = [999, 999, 999, 999]
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults)
dataset = dataset.map(lambda *items: tf.stack(items))
dataset

<MapDataset shapes: (4,), types: tf.int32>

for line in dataset:
    print(line.numpy())

[1 2 3 4 ]
[999 2 3 4 ]
[1 999 3 4 ]
[1 2 999 4 ]
[1 2 3 999]
[999 999 999 999]


# Creates a dataset that reads only the 2nd and 4th columns (indices 1 and 3)
# of missing.csv.
record_defaults = [999, 999]  # Only provide defaults for the selected columns
dataset = tf.data.experimental.CsvDataset("missing.csv", record_defaults, select_cols=[1, 3])
dataset = dataset.map(lambda *items: tf.stack(items))
dataset

<MapDataset shapes: (2,), types: tf.int32>

for line in dataset:
    print(line.numpy())

[2 4]
[2 4]
[999 4]
[2 4]
[2 999]
[999 999]
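The way empty fields turn into 999 can be reproduced without TensorFlow. The following is a pure-Python model using the standard csv module, not CsvDataset itself. The helper name read_csv_with_defaults is hypothetical, and the file contents are the ones reconstructed from the printed rows above. Empty fields take the column default, other fields are converted to the default's type, and select_cols keeps only the listed columns.

```python
import csv
import io

# The contents of missing.csv, reconstructed from the printed rows above.
MISSING_CSV = "1,2,3,4\n,2,3,4\n1,,3,4\n1,2,,4\n1,2,3,\n,,,\n"


def read_csv_with_defaults(text, record_defaults, select_cols=None):
    """Model of CsvDataset: empty fields take the column default, other fields
    are converted to the default's type; select_cols keeps only those columns."""
    rows = []
    for fields in csv.reader(io.StringIO(text)):
        if select_cols is not None:
            fields = [fields[i] for i in select_cols]
        if len(fields) != len(record_defaults):
            raise ValueError("need one default per (selected) column")
        rows.append([type(d)(f) if f != "" else d
                     for f, d in zip(fields, record_defaults)])
    return rows


for row in read_csv_with_defaults(MISSING_CSV, [999] * 4):
    print(row)
print(read_csv_with_defaults(MISSING_CSV, [999, 999], select_cols=[1, 3]))
# [[2, 4], [2, 4], [999, 4], [2, 4], [2, 999], [999, 999]]
```

Because the defaults are Python ints, the model also infers the column type from them, mirroring how the int defaults above make CsvDataset produce tf.int32 columns.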

2.6 Consuming sets of files


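Many datasets are distributed as a set of files, where each file is one example. The flowers dataset is organized this way:

flowers_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
flowers_root = pathlib.Path(flowers_root)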


The root directory contains one subdirectory per class:

for item in flowers_root.glob("*"):
    print(item.name)



The files inside each class directory are the examples:

list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))

for f in list_ds.take(5):
    print(f.numpy())


Read the data with tf.io.read_file, extract the label from the path, and return (image, label) pairs:

def process_path(file_path):
    # The class name is the parent directory (assumes '/' path separators).
    label = tf.strings.split(file_path, '/')[-2]
    return tf.io.read_file(file_path), label

labeled_ds = list_ds.map(process_path)

for image_raw, label_text in labeled_ds.take(1):
    print(repr(image_raw.numpy()[:100]))
    print()
    print(label_text.numpy())

b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xfe\x00\x1ccmp3.10.3.2Lq3 0xad6b4f35\x00\xff\xdb\x00C\x00\x03\x02\x02\x03\x02\x02\x03\x03\x03\x03\x04\x03\x03\x04\x05\x08\x05\x05\x04\x04\x05\n\x07\x07\x06\x08\x0c\n\x0c\x0c\x0b\n\x0b\x0b\r\x0e\x12\x10\r\x0e\x11\x0e\x0b\x0b\x10'


3. Batching dataset elements

3.1 Simple batching (stacking)

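The simplest form of batching stacks n consecutive elements of a dataset into a single element. The Dataset.batch() transformation does exactly this, with the same constraints as the tf.stack() operator, applied to each component of the elements: for each component i, all elements must have a tensor of exactly the same shape.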

inc_dataset = tf.data.Dataset.range(100)
dec_dataset = tf.data.Dataset.range(0, -100, -1)
dataset = tf.data.Dataset.zip((inc_dataset, dec_dataset))
batched_dataset = dataset.batch(4)

for batch in batched_dataset.take(4):
    print([arr.numpy() for arr in batch])

[array([0, 1, 2, 3]), array([ 0, -1, -2, -3])]
[array([4, 5, 6, 7]), array([-4, -5, -6, -7])]
[array([ 8, 9, 10, 11]), array([ -8, -9, -10, -11])]
[array([12, 13, 14, 15]), array([-12, -13, -14, -15])]

While tf.data tries to propagate shape information, the default settings of Dataset.batch() result in an unknown batch size, because the last batch may not be full. Note the None values in the shape:

batched_dataset



<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.int64)>


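Use the drop_remainder argument to ignore that last batch and get full shape propagation:

batched_dataset = dataset.batch(7, drop_remainder=True)
batched_dataset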

<BatchDataset shapes: ((7,), (7,)), types: (tf.int64, tf.int64)>
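The reason for the None becomes clear with some arithmetic: 100 elements in batches of 7 give 14 full batches plus one batch of 2 (100 = 14 × 7 + 2). Below is a minimal pure-Python model of that grouping. The name batch_model is hypothetical and this is not the tf.data code.

```python
def batch_model(elements, batch_size, drop_remainder=False):
    """Model of Dataset.batch(): consecutive groups of batch_size elements."""
    batches = [elements[i:i + batch_size] for i in range(0, len(elements), batch_size)]
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the final partial batch
    return batches


kept = batch_model(list(range(100)), 7)
print(len(kept), len(kept[-1]))  # 15 2  -> the last batch is short
dropped = batch_model(list(range(100)), 7, drop_remainder=True)
print(len(dropped), {len(b) for b in dropped})  # 14 {7}
```

With the partial batch possible, no single static batch size fits every batch. Dropping it is what lets the shape become (7,).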

3.2 Batching tensors with padding

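The recipe above works for tensors that all have the same size. However, many models (for example sequence models) work with input data of varying size (for example sequences of different lengths). To handle this case, the Dataset.padded_batch() transformation lets you batch tensors of different shapes by specifying one or more dimensions in which they may be padded.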

dataset = tf.data.Dataset.range(100)
dataset = dataset.map(lambda x: tf.fill([tf.cast(x, tf.int32)], x))
dataset = dataset.padded_batch(4, padded_shapes=(None,))

for batch in dataset.take(2):
    print(batch.numpy())
    print()

[[0 0 0]
 [1 0 0]
 [2 2 0]
 [3 3 3]]

[[4 4 4 4 0 0 0]
 [5 5 5 5 5 0 0]
 [6 6 6 6 6 6 0]
 [7 7 7 7 7 7 7]]

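Dataset.padded_batch() lets you set different padding for each dimension of each component, and each padded dimension can be variable-length (signified by None) or constant-length. You can also override the padding value, which defaults to 0.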
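For the variable-length (None) case, the padding rule is: within each batch, pad every element to the longest one using the padding value. Here is a pure-Python model of that rule for 1-D elements only. padded_batch_model is a hypothetical helper, not the TensorFlow implementation, and it does not cover constant-length or multi-dimensional padding.

```python
def padded_batch_model(sequences, batch_size, padding_value=0):
    """Model of padded_batch for 1-D elements: pad each batch to its longest item."""
    batches = []
    for start in range(0, len(sequences), batch_size):
        batch = sequences[start:start + batch_size]
        width = max(len(s) for s in batch)
        batches.append([list(s) + [padding_value] * (width - len(s)) for s in batch])
    return batches


# Same elements as above: element x is x repeated x times (like tf.fill([x], x)).
seqs = [[x] * x for x in range(8)]
first, second = padded_batch_model(seqs, 4)
print(first)  # [[0, 0, 0], [1, 0, 0], [2, 2, 0], [3, 3, 3]]
print(padded_batch_model([[1], [2, 2]], 2, padding_value=-1))  # [[[1, -1], [2, 2]]]
```

Each batch gets its own width (3, then 7), which matches the two batches printed earlier.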

4. Training workflows

4.1 Processing multiple epochs

The tf.data API offers two main ways to process multiple epochs of the same data:

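  • The simplest way is the Dataset.repeat() transformation.
  • Alternatively, loop over the epochs in Python and iterate over the dataset once per epoch, as at the end of this section.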


titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
titanic_lines = tf.data.TextLineDataset(titanic_file)
def plot_batch_sizes(ds):
    batch_sizes = [batch.shape[0] for batch in ds]
    plt.bar(range(len(batch_sizes)), batch_sizes)
    plt.xlabel('Batch number')
    plt.ylabel('Batch size')



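# repeat() before batch(): the epochs are concatenated, so batches can
# straddle epoch boundaries (only the very last batch may be smaller).
titanic_batches = titanic_lines.repeat(3).batch(128)
plot_batch_sizes(titanic_batches)

# batch() before repeat(): each epoch is batched separately, so every epoch
# ends with its own (possibly smaller) batch.
titanic_batches = titanic_lines.batch(128).repeat(3)
plot_batch_sizes(titanic_batches)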


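epochs = 3
dataset = titanic_lines.batch(128)

# Looping over epochs in Python lets you run custom code (for example,
# collecting statistics) at the end of each epoch.
for epoch in range(epochs):
    for batch in dataset:
        print(batch.shape)
    print("End of epoch: ", epoch)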

End of epoch: 0
End of epoch: 1
End of epoch: 2
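The two plots differ only in where the short batches fall, and that can be computed by hand. Below is a pure-Python sketch of the batch sizes in both orders. The line count N_LINES = 628 (a header plus 627 rows) is an assumption used for illustration; the point is the pattern, not the exact number.

```python
def batch_sizes(n_elements, batch_size):
    """Sizes of the batches Dataset.batch() would produce (no drop_remainder)."""
    full, rest = divmod(n_elements, batch_size)
    return [batch_size] * full + ([rest] if rest else [])


def repeat_then_batch(n_lines, epochs, batch_size):
    # Epochs are concatenated first, so one batch can mix two epochs.
    return batch_sizes(n_lines * epochs, batch_size)


def batch_then_repeat(n_lines, epochs, batch_size):
    # Each epoch is batched on its own and ends with its own short batch.
    return batch_sizes(n_lines, batch_size) * epochs


N_LINES = 628  # assumed line count of train.csv, for illustration only
print(repeat_then_batch(N_LINES, 3, 128))  # fourteen 128s, then 92
print(batch_then_repeat(N_LINES, 3, 128))  # [128, 128, 128, 128, 116] three times
```

If your model needs every batch to come from a single epoch, batch before repeating. If you want as few short batches as possible, repeat first.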

4.2 Randomly shuffling input data




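Dataset.shuffle() maintains a fixed-size buffer and chooses the next element uniformly at random from that buffer. To see the effect, add an index to each element:

lines = tf.data.TextLineDataset(titanic_file)
counter = tf.data.experimental.Counter()

dataset = tf.data.Dataset.zip((counter, lines))
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(20)
dataset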

<BatchDataset shapes: ((None,), (None,)), types: (tf.int64, tf.string)>


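Because buffer_size is 100 and the batch size is 20, the first batch contains no element with an index above 119:

n, line_batch = next(iter(dataset))
print(n.numpy())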

[ 92 84 52 3 27 100 44 26 2 63 54 93 69 97 10 101 32 65
109 40]

The order of shuffle and repeat matters:



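Dataset.shuffle() does not signal the end of an epoch until its buffer is empty. So a shuffle placed before a repeat shows every element of one epoch before moving on to the next, and only then do elements of the next epoch enter the shuffle buffer: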

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.shuffle(buffer_size=100).batch(10).repeat(2)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(60).take(5):
    print(n.numpy())

Here are the item ID's near the epoch boundary:

[523 318 510 467 627 433 514 594 454 560]
[596 566 205 613 493 570 615 411 556 496]
[598 528 623 559 299 473 391 536]
[41 14 51 3 97 70 34 99 63 52]
[ 49 69 104 0 112 90 38 88 11 83]

shuffle_repeat = [n.numpy().mean() for n, line_batch in shuffled]
plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.ylabel("Mean item ID")


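A repeat placed before a shuffle blurs the epoch boundaries. As the current epoch ends, elements from the start of the next epoch enter the shuffle buffer and are shuffled together with the last elements of the current one: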

dataset = tf.data.Dataset.zip((counter, lines))
shuffled = dataset.repeat(2).shuffle(buffer_size=100).batch(10)

print("Here are the item ID's near the epoch boundary:\n")
for n, line_batch in shuffled.skip(55).take(15):
    print(n.numpy())

Here are the item ID's near the epoch boundary:

[545 576 610 588 0 595 582 10 597 495]
[353 540 7 490 440 563 559 27 600 504]
[624 476 25 519 608 525 477 30 560 363]
[468 34 3 32 47 22 609 449 627 20 ]
[611 599 577 541 62 13 601 606 15 18 ]
[26 43 607 434 73 616 55 552 57 6 ]
[587 544 584 1 16 51 596 614 21 50 ]
[39 46 76 40 78 71 37 28 2 69 ]
[574 24 88 12 543 100 89 68 445 83 ]
[441 619 557 97 113 96 38 79 613 92 ]
[29 414 65 462 537 232 126 118 75 11 ]
[87 121 80 585 114 72 99 112 102 589]
[77 61 542 369 8 133 129 567 136 344]
[81 91 139 128 49 66 565 64 152 90 ]
[538 494 154 547 131 147 166 158 111 165]

repeat_shuffle = [n.numpy().mean() for n, line_batch in shuffled]

plt.plot(shuffle_repeat, label="shuffle().repeat()")
plt.plot(repeat_shuffle, label="repeat().shuffle()")
plt.ylabel("Mean item ID")
plt.legend()
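Both effects follow from how a shuffle buffer works, which can be sketched in pure Python. This is a model of the buffer behavior described above, not tf.data's internal code, and all function names are hypothetical. One consequence of the model is that an element can come out at most buffer_size - 1 positions earlier than its input position, so a buffer much smaller than the dataset only shuffles locally.

```python
import random


def shuffle_model(stream, buffer_size, rng):
    """Shuffle-buffer model: keep buffer_size items, emit a random one and
    put the next input item in its place; drain the buffer at the end."""
    buffer = []
    for item in stream:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    rng.shuffle(buffer)
    yield from buffer


def shuffle_then_repeat(n, epochs, buffer_size, seed=0):
    rng = random.Random(seed)
    out = []
    for _ in range(epochs):  # the buffer is emptied before the next epoch starts
        out.extend(shuffle_model(range(n), buffer_size, rng))
    return out


def repeat_then_shuffle(n, epochs, buffer_size, seed=0):
    rng = random.Random(seed)
    stream = (i for _ in range(epochs) for i in range(n))  # one long stream
    return list(shuffle_model(stream, buffer_size, rng))


s_r = shuffle_then_repeat(600, 2, 100)
r_s = repeat_then_shuffle(600, 2, 100)
print(sorted(s_r[:600]) == list(range(600)))  # True: epoch 1 is complete first
print(r_s[500:520])  # near the boundary, small IDs from epoch 2 can show up
```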

5. Preprocessing data


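The Dataset.map(f) transformation produces a new dataset by applying a function f to each element of the input dataset. The function f takes the tf.Tensor objects that represent a single input element and returns the tf.Tensor objects that represent a single element of the new dataset.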

5.1 Preprocessing with Dataset.map()



list_ds = tf.data.Dataset.list_files(str(flowers_root/'*/*'))


# Reads an image from a file, decodes it into a dense tensor, and resizes it
# to a fixed shape.
def parse_image(filename):
    parts = tf.strings.split(filename, '/')
    label = parts[-2]

    image = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image)
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, [128, 128])
    return image, label


file_path = next(iter(list_ds))
image, label = parse_image(file_path)

def show(image, label):
    plt.figure()
    plt.imshow(image)
    plt.title(label.numpy().decode('utf-8'))
    plt.axis('off')

show(image, label)


images_ds = list_ds.map(parse_image)

for image, label in images_ds.take(2):  # inspect 2 examples to check the result
    show(image, label)

5.2 Preprocessing with non-TensorFlow functions



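For performance, prefer TensorFlow operations for preprocessing whenever possible. When you do need an external Python library, wrap the call in tf.py_function() inside Dataset.map(). The example below applies a random rotation with scipy.ndimage.

Note: tensorflow_addons provides a TensorFlow-compatible rotate function, tensorflow_addons.image.rotate.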


import scipy.ndimage as ndimage

def random_rotate_image(image):
    # Rotate by a random angle in [-30, 30) degrees, keeping the original shape.
    image = ndimage.rotate(image, np.random.uniform(-30, 30), reshape=False)
    return image
image, label = next(iter(images_ds))
image = random_rotate_image(image)
show(image, label)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


def tf_random_rotate_image(image, label):
    im_shape = image.shape
    # py_function loses static shape information, so restore it afterwards.
    [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])
    image.set_shape(im_shape)
    return image, label

rot_ds = images_ds.map(tf_random_rotate_image)

for image, label in rot_ds.take(2):
    show(image, label)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

5.3 Parsing tf.Example protocol buffer messages


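Many input pipelines read tf.train.Example protocol buffer messages from TFRecord files. Each tf.train.Example record contains one or more features, and the pipeline typically converts these features into tensors:

fsns_test_file = tf.keras.utils.get_file("fsns.tfrec", "https://storage.googleapis.com/download.tensorflow.org/data/fsns-20160927/testdata/fsns-00000-of-00001")
dataset = tf.data.TFRecordDataset(filenames=[fsns_test_file])
dataset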

<TFRecordDatasetV2 shapes: (), types: tf.string>

You can work with tf.train.Example protos outside of a tf.data.Dataset to understand the data:

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

feature = parsed.features.feature
raw_img = feature['image/encoded'].bytes_list.value[0]
img = tf.image.decode_png(raw_img)
plt.imshow(img)
plt.axis('off')
_ = plt.title(feature["image/text"].bytes_list.value[0])

To do the same parsing with TensorFlow ops, so that it can run inside Dataset.map(), use tf.io.parse_example:

raw_example = next(iter(dataset))

def tf_parse(eg):
    example = tf.io.parse_example(
        eg[tf.newaxis], {
            'image/encoded': tf.io.FixedLenFeature(shape=(), dtype=tf.string),
            'image/text': tf.io.FixedLenFeature(shape=(), dtype=tf.string)
        })
    return example['image/encoded'][0], example['image/text'][0]

img, txt = tf_parse(raw_example)
print(txt.numpy())
print(repr(img.numpy()[:20]), "...")

b'Rue Perreyon'
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x02X' ...

decoded = dataset.map(tf_parse)
decoded

<MapDataset shapes: ((), ()), types: (tf.string, tf.string)>

image_batch, text_batch = next(iter(decoded.batch(10)))


5.4 Time series windowing

For an end to end time series example see: Time series forecasting.



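Time series data is often organized with the time axis intact. A simple Dataset.range is enough to demonstrate:

range_ds = tf.data.Dataset.range(100000)

Typically, models built on this kind of data want contiguous time slices.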



5.4.1 Using batch

batches = range_ds.batch(10, drop_remainder=True)

for batch in batches.take(5):
    print(batch.numpy())

[0 1 2 3 4 5 6 7 8 9 ]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]

To make dense predictions one step into the future, shift the features and labels by one step relative to each other:

def dense_1_step(batch):
    # Shift features and labels one step relative to each other.
    return batch[:-1], batch[1:]

predict_dense_1_step = batches.map(dense_1_step)

for features, label in predict_dense_1_step.take(3):
    print(features.numpy(), " => ", label.numpy())

[0 1 2 3 4 5 6 7 8 ] => [1 2 3 4 5 6 7 8 9 ]
[10 11 12 13 14 15 16 17 18] => [11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28] => [21 22 23 24 25 26 27 28 29]


To predict a whole window instead of a single step, split each batch into two parts:

batches = range_ds.batch(15, drop_remainder=True)

def label_next_5_steps(batch):
    return (batch[:-5],  # Inputs: all but the last 5 steps
            batch[-5:])  # Labels: the last 5 steps

predict_5_steps = batches.map(label_next_5_steps)

for features, label in predict_5_steps.take(3):
    print(features.numpy(), " => ", label.numpy())

[0 1 2 3 4 5 6 7 8 9 ] => [10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24] => [25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39] => [40 41 42 43 44]

To allow some overlap between the features of one batch and the labels of another, use Dataset.zip:

feature_length = 10
label_length = 5

features = range_ds.batch(feature_length, drop_remainder=True)
labels = range_ds.batch(feature_length).skip(1).map(lambda labels: labels[:label_length])

predict_5_steps = tf.data.Dataset.zip((features, labels))

for features, label in predict_5_steps.take(3):
    print(features.numpy(), " => ", label.numpy())

[0 1 2 3 4 5 6 7 8 9 ] => [10 11 12 13 14]
[10 11 12 13 14 15 16 17 18 19] => [20 21 22 23 24]
[20 21 22 23 24 25 26 27 28 29] => [30 31 32 33 34]
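The pairing can be checked with plain Python lists. This is a sketch of the same batch, skip, zip logic under the assumption that a list chunker behaves like Dataset.batch; chunk is a hypothetical helper and series is a short stand-in for range_ds. Batch i supplies the features, and the first label_length items of batch i + 1 supply its labels.

```python
def chunk(seq, size, drop_remainder=False):
    """Model of Dataset.batch() on a Python list."""
    chunks = [seq[i:i + size] for i in range(0, len(seq), size)]
    if drop_remainder and chunks and len(chunks[-1]) < size:
        chunks.pop()
    return chunks


feature_length, label_length = 10, 5
series = list(range(100))  # a short stand-in for range_ds

features = chunk(series, feature_length, drop_remainder=True)
# skip(1): the labels of batch i come from the start of batch i + 1.
labels = [c[:label_length] for c in chunk(series, feature_length)[1:]]

# Like Dataset.zip, zip() stops at the shorter input (9 pairs here).
pairs = list(zip(features, labels))
print(len(pairs))  # 9
print(pairs[0])    # ([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, 13, 14])
```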

5.4.2 Using window


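Dataset.batch works, but sometimes you need finer control. Dataset.window gives you complete control, but requires some care: it returns a Dataset of Datasets.

window_size = 5

windows = range_ds.window(window_size, shift=1)
for sub_ds in windows.take(5):
    print(sub_ds)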

<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>
<_VariantDataset shapes: (), types: tf.int64>

The Dataset.flat_map method can take a dataset of datasets and flatten it into a single dataset:

for x in windows.flat_map(lambda x: x).take(30):
    print(x.numpy(), end=' ')

WARNING:tensorflow:AutoGraph could not transform <function at 0x7f973007e6a8> and will run it as-is.
Cause: could not parse the source code:

for x in windows.flat_map(lambda x: x).take(30):

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 3 4 5 6 7 4 5 6 7 8 5 6 7 8 9


In nearly all cases you will want to batch each window first:

def sub_to_batch(sub):
    return sub.batch(window_size, drop_remainder=True)

for example in windows.flat_map(sub_to_batch).take(5):
    print(example.numpy())

[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]



The shift argument controls how far each window moves, and stride sets the spacing of elements within a window. Putting this together:

def make_window_dataset(ds, window_size=5, shift=1, stride=1):
    windows = ds.window(window_size, shift=shift, stride=stride)

    def sub_to_batch(sub):
        return sub.batch(window_size, drop_remainder=True)

    windows = windows.flat_map(sub_to_batch)
    return windows

ds = make_window_dataset(range_ds, window_size=10, shift=5, stride=3)

for example in ds.take(10):
    print(example.numpy())

[0 3 6 9 12 15 18 21 24 27]
[5 8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34 37]
[15 18 21 24 27 30 33 36 39 42]
[20 23 26 29 32 35 38 41 44 47]
[25 28 31 34 37 40 43 46 49 52]
[30 33 36 39 42 45 48 51 54 57]
[35 38 41 44 47 50 53 56 59 62]
[40 43 46 49 52 55 58 61 64 67]
[45 48 51 54 57 60 63 66 69 72]
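The indices in these windows follow a simple formula: window k starts at k * shift, and its items are spaced stride apart, so a full window spans (window_size - 1) * stride + 1 input positions. The pure-Python sketch below models this and reproduces the first rows above. window_model is a hypothetical helper, not tf.data code, and like sub_to_batch it keeps only complete windows.

```python
def window_model(seq, size, shift=1, stride=1):
    """Complete windows of `size` items spaced `stride` apart, one starting every
    `shift` items; partial windows are dropped, as sub_to_batch does above."""
    span = (size - 1) * stride + 1  # distance from first to last item, inclusive
    return [seq[start:start + span:stride]
            for start in range(0, len(seq) - span + 1, shift)]


w = window_model(list(range(100)), size=10, shift=5, stride=3)
print(w[0])  # [0, 3, 6, 9, 12, 15, 18, 21, 24, 27]
print(w[1])  # [5, 8, 11, 14, 17, 20, 23, 26, 29, 32]
```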


Then it is easy to extract labels, as before:

dense_labels_ds = ds.map(dense_1_step)

for inputs, labels in dense_labels_ds.take(3):
    print(inputs.numpy(), "=>", labels.numpy())

[0 3 6 9 12 15 18 21 24] => [3 6 9 12 15 18 21 24 27]
[5 8 11 14 17 20 23 26 29] => [8 11 14 17 20 23 26 29 32]
[10 13 16 19 22 25 28 31 34] => [13 16 19 22 25 28 31 34 37]

5.5 Resampling



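When a dataset is very class-imbalanced, you may want to resample it. tf.data provides two ways to do this. The credit card fraud dataset is a good example of this kind of problem:

zip_path = tf.keras.utils.get_file(
    origin='https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip',
    fname='creditcard.zip',
    extract=True)

csv_path = zip_path.replace('.zip', '.csv')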

Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/creditcard.zip
69156864/69155632 [==============================] - 2s 0us/step

creditcard_ds = tf.data.experimental.make_csv_dataset(
    csv_path, batch_size=1024, label_name="Class",
    # Set the column types: 30 floats and an int.
    column_defaults=[float()] * 30 + [int()])


Now check the class distribution of the first 10 batches. It is highly skewed:

def count(counts, batch):
    features, labels = batch
    class_1 = labels == 1
    class_1 = tf.cast(class_1, tf.int32)

    class_0 = labels == 0
    class_0 = tf.cast(class_0, tf.int32)

    counts['class_0'] += tf.reduce_sum(class_0)
    counts['class_1'] += tf.reduce_sum(class_1)

    return counts

counts = creditcard_ds.take(10).reduce(
    initial_state={'class_0': 0, 'class_1': 0},
    reduce_func=count)

counts = np.array([counts['class_0'].numpy(),
                   counts['class_1'].numpy()]).astype(np.float32)

fractions = counts / counts.sum()
print(fractions)

[0.995 0.005]


5.5.1 Datasets.sampling



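One way to resample an imbalanced dataset is sample_from_datasets, which fits best when you have a separate Dataset for each class. Here, filter is used to build them from the credit card fraud data:

negative_ds = (
    creditcard_ds
    .unbatch()
    .filter(lambda features, label: label == 0)
    .repeat())
positive_ds = (
    creditcard_ds
    .unbatch()
    .filter(lambda features, label: label == 1)
    .repeat())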

WARNING:tensorflow:AutoGraph could not transform <function at 0x7f9730114598> and will run it as-is.
Cause: could not parse the source code:

.filter(lambda features, label: label==0)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function at 0x7f97301149d8> and will run it as-is.
Cause: could not parse the source code:

.filter(lambda features, label: label==1)

This error may be avoided by creating the lambda in a standalone statement.

To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert

for features, label in positive_ds.batch(10).take(1):
    print(label.numpy())

[1 1 1 1 1 1 1 1 1 1]


balanced_ds = tf.data.experimental.sample_from_datasets([negative_ds, positive_ds], [0.5, 0.5]).batch(10)


for features, labels in balanced_ds.take(10):
    print(labels.numpy())

[1 0 1 1 0 1 0 0 0 0]
[1 1 0 0 0 0 0 1 0 1]
[1 1 0 0 0 1 0 0 1 1]
[1 0 0 1 0 1 1 0 0 0]
[0 0 1 0 0 0 0 1 1 1]
[0 1 1 1 1 0 0 1 0 1]
[0 0 0 0 1 0 1 1 1 1]
[0 0 0 1 1 1 0 0 0 1]
[1 1 0 1 1 1 1 1 1 0]
[1 1 1 1 0 1 0 0 1 1]
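Conceptually, the pipeline above builds one endless stream per class (unbatch, filter, repeat) and then, for every output element, picks one of the streams at random according to the weights. The plain-Python sketch below mimics that observable behavior with itertools and random; it is not how tf.data implements sample_from_datasets, and the toy data and the helper sample_from_streams are hypothetical.

```python
import itertools
import random

# Hypothetical toy (features, label) pairs, deliberately imbalanced:
# only ids 0, 50, 100 and 150 are positive.
examples = [(i, 1 if i % 50 == 0 else 0) for i in range(200)]

# Rough equivalent of .unbatch().filter(...).repeat(): one endless stream per class.
negative_stream = itertools.cycle([e for e in examples if e[1] == 0])
positive_stream = itertools.cycle([e for e in examples if e[1] == 1])

def sample_from_streams(streams, weights, rng):
    """Forever: choose a stream according to `weights`, then yield its next element."""
    while True:
        stream = rng.choices(streams, weights=weights, k=1)[0]
        yield next(stream)

rng = random.Random(0)  # fixed seed so the sketch is reproducible
balanced = sample_from_streams([negative_stream, positive_stream], [0.5, 0.5], rng)

# Like balanced_ds.take(10) with batch(10): 10 groups of 10 labels.
for step in range(10):
    batch = itertools.islice(balanced, 10)
    print([label for _, label in batch])
```

Because each class stream is repeated, the few positive examples are recycled over and over; that is the trade-off of balancing by sampling from repeated datasets.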

5.5.2 experimental.rejection_resample




The elements of creditcard_ds are already (features, label) pairs, so class_func only needs to return the label:

def class_func(features, label):
    return label


resampler = tf.data.experimental.rejection_resample(class_func, target_dist=[0.5, 0.5], initial_dist=fractions)


resample_ds = creditcard_ds.unbatch().apply(resampler).batch(10)

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/data/experimental/ops/resampling.py:156: Print (from tensorflow.python.ops.logging_ops) is deprecated and will be removed after 2018-08-20.
Instructions for updating:
Use tf.print instead of tf.Print. Note that tf.print returns a no-output operator that directly prints the output. Outside of defuns or eager mode, this operator will not be executed unless it is directly specified in session.run or used as a control dependency for other operators. This is only a concern in graph mode. Below is an example of how to ensure tf.print executes in graph mode:

The resampler creates (class, example) pairs from the output of class_func. Here, example is already a (features, label) pair, so use map to drop the extra copy of the label:

balanced_ds = resample_ds.map(lambda extra_label, features_and_label: features_and_label)


for features, labels in balanced_ds.take(10):
    print(labels.numpy())

[1 1 1 1 0 0 1 0 0 0]
[1 1 1 1 1 1 1 1 0 1]
[1 1 0 1 0 0 0 1 0 0]
[0 0 0 1 1 0 0 1 1 0]
[1 0 1 0 0 1 1 0 1 0]
[1 1 0 1 0 1 0 0 1 0]
[0 1 1 1 0 1 1 1 1 1]
[0 0 1 0 1 0 0 1 0 1]
[1 1 0 1 1 0 0 1 1 1]
[1 1 0 0 0 1 0 1 1 0]
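rejection_resample balances the data by discarding elements rather than repeating them. A standard rejection-sampling rule, which I assume is close to what tf.data uses internally, keeps an element of class i with probability proportional to target_i / initial_i, scaled so that the rarest class is always kept. The sketch below is plain Python; acceptance_probs is a hypothetical helper, and the simulation draws synthetic labels from the fractions measured earlier.

```python
import random

def acceptance_probs(initial_dist, target_dist):
    """Probability of keeping an element of each class so the kept stream follows target_dist."""
    ratios = [t / i for t, i in zip(target_dist, initial_dist)]
    max_ratio = max(ratios)
    # Scale so the most under-represented class is always kept (probability 1).
    return [r / max_ratio for r in ratios]

initial = [0.995, 0.005]  # class fractions measured earlier
target = [0.5, 0.5]
accept = acceptance_probs(initial, target)
print(accept)  # class 0 kept about 0.5% of the time, class 1 always kept

# Simulate: draw labels from the initial distribution, keep each with its acceptance probability.
rng = random.Random(42)
draws = 200_000
kept = []
for _ in range(draws):
    label = 1 if rng.random() < initial[1] else 0
    if rng.random() < accept[label]:
        kept.append(label)

print(len(kept), sum(kept) / len(kept))  # about 1% kept, roughly half of them positive
```

Only about 1% of the simulated draws survive, which illustrates why this approach has to read through a lot of majority-class data to produce each balanced batch.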

6. Using tf.data with high-level APIs

6.1 Using tf.data with tf.keras

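The tf.keras API greatly simplifies creating and running machine learning models. Its .fit(), .evaluate() and .predict() APIs all accept tf.data datasets as input. Here is a simple dataset and model setup: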

train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)

fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
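Because the last Dense(10) layer has no softmax, the model outputs raw logits, and from_logits=True tells SparseCategoricalCrossentropy to apply the softmax itself. For an integer label y the per-example loss is logsumexp(logits) - logits[y]. A plain-Python sketch of that formula (sparse_ce_from_logits is a hypothetical name; Keras' own implementation differs in detail):

```python
import math

def sparse_ce_from_logits(logits, label):
    """Cross-entropy of an integer label against raw (unnormalized) logits.

    -log(softmax(logits)[label]) == logsumexp(logits) - logits[label];
    subtracting the max first keeps exp() from overflowing.
    """
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Equal logits give a uniform softmax, so the loss is log(10) whatever the label is.
print(sparse_ce_from_logits([0.0] * 10, 3))   # about 2.302585
# A confident, correct prediction gives a small loss.
print(sparse_ce_from_logits([0.0] * 9 + [8.0], 9))
```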

Model.fit and Model.evaluate both just need a dataset of (features, label) pairs:

model.fit(fmnist_train_ds, epochs=2)

Epoch 1/2
WARNING:tensorflow:Layer flatten is casting an input tensor from dtype float64 to the layer’s dtype of float32, which is new behavior in TensorFlow 2. The layer has dtype float32 because its dtype defaults to floatx.

If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.

To change all layers to have dtype float64 by default, call tf.keras.backend.set_floatx('float64'). To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

1875/1875 [==============================] - 3s 2ms/step - loss: 0.6013 - accuracy: 0.7970
Epoch 2/2
1875/1875 [==============================] - 3s 2ms/step - loss: 0.4617 - accuracy: 0.8418

<tensorflow.python.keras.callbacks.History at 0x7f97801f1588>
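The 1875/1875 in the progress bar is the number of batches in one pass over the data. For a finite dataset, Dataset.batch(batch_size) yields ceil(N / batch_size) batches, or floor(N / batch_size) with drop_remainder=True. A small arithmetic check, assuming the 60,000-image Fashion-MNIST training split used above (batches_per_epoch is a hypothetical helper):

```python
import math

def batches_per_epoch(num_examples, batch_size, drop_remainder=False):
    """Number of batches Dataset.batch() produces from a finite dataset."""
    if drop_remainder:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)

# Fashion-MNIST has 60,000 training images; the pipeline above uses batch(32).
print(batches_per_epoch(60_000, 32))  # 1875, matching "1875/1875" in the log
# With a size that does not divide evenly, the last, smaller batch is kept by default.
print(batches_per_epoch(60_001, 32), batches_per_epoch(60_001, 32, drop_remainder=True))  # 1876 1875
```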



model.fit(fmnist_train_ds.repeat(), epochs=2, steps_per_epoch=20)

Epoch 1/2
20/20 [==============================] - 0s 2ms/step - loss: 0.4650 - accuracy: 0.8422
Epoch 2/2
20/20 [==============================] - 0s 2ms/step - loss: 0.3897 - accuracy: 0.8797

<tensorflow.python.keras.callbacks.History at 0x7f97801f1908>
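After .repeat() the dataset never ends, so Keras cannot detect the end of an epoch on its own; steps_per_epoch defines how many batches count as one epoch. The plain-Python sketch below models this with itertools.cycle and islice. It assumes a single iterator is reused across epochs, so the second epoch continues where the first stopped; that is my understanding of Keras' behavior with infinite datasets, not something the output above shows. The batch names are hypothetical.

```python
import itertools

# Hypothetical stand-in for fmnist_train_ds: 5 numbered batches.
batches = [f"batch{i}" for i in range(5)]

# Like .repeat(): an endless stream that never signals end-of-epoch.
stream = itertools.cycle(batches)

# Each "epoch" pulls exactly steps_per_epoch batches from the same iterator.
steps_per_epoch = 3
epochs = [list(itertools.islice(stream, steps_per_epoch)) for _ in range(2)]
print(epochs)
# [['batch0', 'batch1', 'batch2'], ['batch3', 'batch4', 'batch0']]
```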


loss, accuracy = model.evaluate(fmnist_train_ds)
print("Loss :", loss)
print("Accuracy :", accuracy)

1875/1875 [==============================] - 3s 2ms/step - loss: 0.4423 - accuracy: 0.8473
Loss : 0.44227170944213867
Accuracy : 0.847266674041748


loss, accuracy = model.evaluate(fmnist_train_ds.repeat(), steps=10)
print("Loss :", loss)
print("Accuracy :", accuracy)

10/10 [==============================] - 0s 2ms/step - loss: 0.4557 - accuracy: 0.8188
Loss : 0.45573288202285767
Accuracy : 0.8187500238418579


predict_ds = tf.data.Dataset.from_tensor_slices(images).batch(32)
result = model.predict(predict_ds, steps=10)
print(result.shape)

(320, 10)


result = model.predict(fmnist_train_ds, steps=10)
print(result.shape)

(320, 10)
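The (320, 10) shape follows from the arguments: steps=10 batches of 32 images give 320 rows, and Dense(10) gives 10 columns. When predict is given fmnist_train_ds, which also yields labels, the labels are ignored. The hypothetical helper below only does the row arithmetic for a batched dataset:

```python
def predict_rows(num_examples, batch_size, steps):
    """Rows produced when taking `steps` batches from a batched dataset."""
    full_batches, remainder = divmod(num_examples, batch_size)
    # The dataset yields full batches, plus one partial batch if there is a remainder.
    batch_sizes = [batch_size] * full_batches + ([remainder] if remainder else [])
    return sum(batch_sizes[:steps])

print(predict_rows(60_000, 32, steps=10))  # 320 rows, i.e. result.shape == (320, 10)
print(predict_rows(50, 32, steps=10))      # 50: this dataset has only 2 batches (32 + 18)
```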

6.2 Using tf.data with tf.estimator


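The official guide covers this part only briefly. For a more detailed explanation of how to use tf.data with tf.estimator, I recommend the post "TensorFlow Estimator 官方文档之----Dataset for Estimator" (TensorFlow Estimator official docs: Dataset for Estimator).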

import tempfile

def train_input_fn():
    # titanic_file: path to the Titanic CSV, defined in the CSV-reading section
    titanic = tf.data.experimental.make_csv_dataset(
        titanic_file, batch_size=32,
        label_name="survived")
    titanic_batches = (
        titanic.cache().repeat().shuffle(500)
        .prefetch(tf.data.experimental.AUTOTUNE))
    return titanic_batches

embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third'])
age = tf.feature_column.numeric_column('age')

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp7xfmvz5w', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
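The two categorical feature columns turn strings into integer ids in different ways: categorical_column_with_hash_bucket hashes the value into one of 32 buckets (so unseen towns still get an id, at the cost of possible collisions), while categorical_column_with_vocabulary_list looks the value up in a fixed list. The plain-Python sketch below only illustrates those two ideas. It uses MD5 rather than TensorFlow's own fingerprint hash, so its bucket ids will not match TensorFlow's; the helper names and sample strings are hypothetical, and returning -1 for out-of-vocabulary values assumes the column's default settings.

```python
import hashlib

def hash_bucket(value, num_buckets):
    """Map a string to one of num_buckets ids via a hash (MD5 here, for illustration only)."""
    digest = hashlib.md5(value.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "little") % num_buckets

def vocab_lookup(value, vocabulary, default_value=-1):
    """Map a string to its index in a fixed vocabulary; unknown strings get default_value."""
    return vocabulary.index(value) if value in vocabulary else default_value

towns = ["Southampton", "Cherbourg", "Queenstown", "unknown"]
print({t: hash_bucket(t, 32) for t in towns})  # every id is in [0, 32)
print([vocab_lookup(c, ["First", "Second", "Third"]) for c in ["Third", "First", "Crew"]])  # [2, 0, -1]
```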

model = model.train(input_fn=train_input_fn, steps=100)

WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/feature_column/feature_column_v2.py:560: Layer.add_variable (from tensorflow.python.keras.engine.base_layer_v1) is deprecated and will be removed in a future version.
Instructions for updating:
Please use layer.add_weight method instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/ftrl.py:143: calling Constant.init (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0…
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp7xfmvz5w/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0…
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100…
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmp7xfmvz5w/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100…
INFO:tensorflow:Loss for final step: 0.5968354.
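The first logged loss, 0.6931472, is ln 2. That is what binary cross-entropy gives when every logit is 0 (probability 0.5 for both classes), which is consistent with the linear model starting from all-zero weights; treat that initialization as an assumption here. A quick plain-Python check with made-up labels:

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def binary_cross_entropy(label, logit):
    """Log loss of a 0/1 label given a raw logit."""
    p = sigmoid(logit)
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

# All-zero weights and bias -> logit 0 -> p = 0.5 for every example,
# so the mean loss is ln 2 regardless of the (hypothetical) labels.
labels = [0, 1, 1, 0, 0]
mean_loss = sum(binary_cross_entropy(y, 0.0) for y in labels) / len(labels)
print(mean_loss)  # about 0.693147, cf. "loss = 0.6931472, step = 0"
```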

result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
    print(key, ":", value)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2020-03-28T01:27:11Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp7xfmvz5w/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.65018s
INFO:tensorflow:Finished evaluation at 2020-03-28-01:27:11
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.684375, accuracy_baseline = 0.603125, auc = 0.73216105, auc_precision_recall = 0.6447562, average_loss = 0.60841894, global_step = 100, label/mean = 0.396875, loss = 0.60841894, precision = 0.76, prediction/mean = 0.31196585, recall = 0.2992126
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmp7xfmvz5w/model.ckpt-100
accuracy : 0.684375
accuracy_baseline : 0.603125
auc : 0.73216105
auc_precision_recall : 0.6447562
average_loss : 0.60841894
label/mean : 0.396875
loss : 0.60841894
precision : 0.76
prediction/mean : 0.31196585
recall : 0.2992126
global_step : 100
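The evaluation metrics are mutually consistent and can be turned back into a confusion matrix. Assuming evaluation saw exactly 10 steps × 32 examples = 320 rows (that is, every batch from train_input_fn is full), the counts below are inferred from the logged label/mean, recall and precision, then checked against accuracy and accuracy_baseline (the accuracy of always predicting the majority class):

```python
# Assumed: 10 evaluation steps with batch_size=32, i.e. 320 examples.
n = 10 * 32
label_mean, precision, recall, accuracy = 0.396875, 0.76, 0.2992126, 0.684375

positives = round(label_mean * n)       # examples with survived == 1
tp = round(recall * positives)          # true positives
predicted_pos = round(tp / precision)   # examples predicted as positive
fp = predicted_pos - tp                 # false positives
fn = positives - tp                     # false negatives
tn = n - tp - fp - fn                   # true negatives

print(tp, fp, fn, tn)                      # 38 12 89 181
print((tp + tn) / n)                       # matches accuracy, about 0.684375
print(max(positives, n - positives) / n)   # matches accuracy_baseline, about 0.603125
```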

for pred in model.predict(train_input_fn):
    for key, value in pred.items():
        print(key, ":", value)
    break

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmp7xfmvz5w/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-0.1131]
logistic : [0.4717]
probabilities : [0.5283 0.4717]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']
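Each prediction dict is derived from the single logit: logistic is sigmoid(logit), probabilities is [1 - logistic, logistic], and class_ids is the index of the larger probability. A plain-Python check using the printed logit (which is itself rounded, so the result agrees with the log only to about three decimals):

```python
import math

logit = -0.1131                      # "logits : [-0.1131]" from the output above
p = 1.0 / (1.0 + math.exp(-logit))   # "logistic": probability of class 1
probabilities = [1.0 - p, p]         # [P(class 0), P(class 1)]
class_id = max(range(2), key=lambda i: probabilities[i])

print(round(p, 3), [round(q, 3) for q in probabilities], class_id)  # 0.472 [0.528, 0.472] 0
```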

Note: this article is based on the official TensorFlow guide to importing data with tf.data (Learn > Guide > tf.data).


