tensorflow教程 开始——数据集:快速了解 tf.data
参考文章:数据集:快速了解
数据集:快速了解
tf.data
- 从 numpy 数组读取内存数据。
- 逐行读取 csv 文件。
基本输入
学习如何获取数组的片段,是开始学习 tf.data 最简单的方式。
Premade Estimators
def train_input_fn(features, labels, batch_size):"""一个用来训练的输入函数"""# 将输入值转化为数据集。dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))# 混排、重复、批处理样本。dataset = dataset.shuffle(1000).repeat().batch(batch_size)# 返回数据集return dataset
下面我们来对这个函数做更仔细的分析。
参数
这个函数一共需要三个参数。如果一个参数的期望类型是 “array” (数组),那么它将可以接受几乎所有可以用 numpy.array 来转化为数组的值。我们可以看到只有一个例外:tuple,它对 Datasets 有特殊的含义。
- features:一个形如 {‘feature_name’:array} 的数据字典(或者是 DataFrame),它包含了原始的输入特征。
- labels:一个包含每个样本的 label 的数组。
- batch_size:一个指示所需批量大小的整数。
在 premade_estimator.py 中,我们使用 iris_data.load_data() 函数来检索虹膜数据。
你可以运行该函数,并按如下方式解压结果:
import iris_data# 获取数据
train, test = iris_data.load_data()
features, labels = train
然后用像下面这样的一行代码,将数据传递给 input 函数:
batch_size=100
iris_data.train_input_fn(features, labels, batch_size)
让我们来具体看看 train_input_fn() 函数。
(数组)片段
TF Layers 教程:构建卷积神经网络
返回这个 Dataset 的代码如下所示:
train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = trainmnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)
张量
<TensorSliceDataset shapes: (28,28), types: tf.uint8>
上述的 Dataset 表示数组的简单集合,但数据集比这更复杂。Dataset 可以透明地处理任何嵌套的字典或元组组合(或者 namedtuple)。
例如,将 irls 的 features 转换为标准 python 字典之后,你可以将数组字典转换为字典的 Dataset,如下所示:
dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)
<TensorSliceDatasetshapes: {SepalLength: (), PetalWidth: (),PetalLength: (), SepalWidth: ()},types: {SepalLength: tf.float64, PetalWidth: tf.float64,PetalLength: tf.float64, SepalWidth: tf.float64}
>
张量
iris 的第一行 train_input_fn 使用相同的功能,但是增加了一层结构。它创建了一个包含 (features_dict, label) 数据对的数据集。
以下代码表明,标签是类型为 int64 的标量:
# 将输入转化为数据集。
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)
<TensorSliceDatasetshapes: ({SepalLength: (), PetalWidth: (),PetalLength: (), SepalWidth: ()},()),types: ({SepalLength: tf.float64, PetalWidth: tf.float64,PetalLength: tf.float64, SepalWidth: tf.float64},tf.int64)>
操作
目前,Dataset 会按照固定顺序遍历数据一次,且一次只能生成一个元素。在可以用于训练之前,它需要进一步的处理。幸运的是,tf.data.Dataset 类提供了方法来让数据为训练作出更好的准备。train_input_fn 的下一行代码就利用了几个这样的方法:
# 样本的混排、重复、批处理。
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
tf.data.Dataset.shuffle
tf.data.Dataset.repeat
tf.data.Dataset.batch
print(mnist_ds.batch(100))
<BatchDatasetshapes: (?, 28, 28),types: tf.uint8>
注意,因为最后一个批次将会有比较少的元素,因此数据集的批量大小是不确定的。
在 train_input_fn 中,批处理之后,数据集 包含元素们的一维向量,这些一维向量的前面部分是:
print(dataset)
<TensorSliceDatasetshapes: ({SepalLength: (?,), PetalWidth: (?,),PetalLength: (?,), SepalWidth: (?,)},(?,)),types: ({SepalLength: tf.float64, PetalWidth: tf.float64,PetalLength: tf.float64, SepalWidth: tf.float64},tf.int64)>
返回
此时,Dataset 包含 (features_dict, labels) 对。这是 train 和 evaluate 方法所期望的格式,因此 input_fn 将返回数据集。
在使用 predict 方法时,可以/应该省略 labels。
读取 CSV 文件
tf.data
如下对 iris_data.maybe_download 函数的调用,将会在必要的时候下载数据,并返回结果文件的路径:
import iris_data
train_path, test_path = iris_data.maybe_download()
iris_data.csv_input_fn 函数包括了一个用 Dataset 解析 csv 文件的替代方案。
让我们来看看如何构建一个兼容 Estimator 的、可以读取本地文件的输入函数。
建立 Dataset
tf.data.Dataset.skip
ds = tf.data.TextLineDataset(train_path).skip(1)
建立一个 csv 行解析器
我们从建立一个可以解析一行的函数开始。
tf.decode_csv
tf.decode_csv
# 描述文本列的元数据
COLUMNS = ['SepalLength', 'SepalWidth','PetalLength', 'PetalWidth','label']
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]
def _parse_line(line):# 将行解码到 fields 中fields = tf.decode_csv(line, FIELD_DEFAULTS)# 将结果打包成字典features = dict(zip(COLUMNS,fields))# 将标签从特征中分离label = features.pop('label')return features, label
解析多行
tf.data.Dataset.map
这个 map 方法接受一个 map_func 参数,这个参数描述了 Dataset 中的每一个元素应该如何被转化。
tf.data.Dataset.map
因此,为了在多行数据被从 csv 文件中读取出来的时候解析它们,我们为 map 方法提供 _parse_line 函数:
ds = ds.map(_parse_line)
print(ds)
<MapDataset
shapes: ({SepalLength: (), PetalWidth: (), ...},()),
types: ({SepalLength: tf.float32, PetalWidth: tf.float32, ...},tf.int32)>
现在,数据集中包含的是 (features, label) 数据对,而不是简单的字符串标量了。
iris_data.csv_input_fn 函数的余下部分和 Basic input 中介绍的 iris_data.train_input_fn 函数相同。
实践
这个函数可以作为 iris_data.train_input_fn 的替代。它可以像如下这样,来给 estimator 提供数据:
train_path, test_path = iris_data.maybe_download()# 所有的输入都是数字
feature_columns = [tf.feature_column.numeric_column(name)for name in iris_data.CSV_COLUMN_NAMES[:-1]]# 构建 estimator
est = tf.estimator.LinearClassifier(feature_columns,n_classes=3)
# 训练 estimator
batch_size = 100
est.train(steps=1000,input_fn=lambda : iris_data.csv_input_fn(train_path, batch_size))
Estimator 期望 input_fn 没有任何参数。要解除这个限制,我们使用 lambda 来捕获参数并提供预期的接口。
总结
为了从不同的数据源中便捷的读取数据,tf.data 模块提供了类和函数的集合。除此之外,tf.data 有简单并且强大的方法,来应用各种标准和自定义转换。
现在你已经基本了解了如何为 Estimator 高效的获取数据。(作为扩展)接下来可以思考如下的文档:
- 创建定制化 Estimator
- 底层 API 编程介绍
- 数据导入
tensorflow教程 开始——数据集:快速了解 tf.data相关推荐
- 【Tensorflow教程笔记】常用模块 tf.function :图执行模式
基础 TensorFlow 基础 TensorFlow 模型建立与训练 基础示例:多层感知机(MLP) 卷积神经网络(CNN) 循环神经网络(RNN) 深度强化学习(DRL) Keras Pipeli ...
- 【Tensorflow教程笔记】常用模块 tf.train.Checkpoint :变量的保存与恢复
基础 TensorFlow 基础 TensorFlow 模型建立与训练 基础示例:多层感知机(MLP) 卷积神经网络(CNN) 循环神经网络(RNN) 深度强化学习(DRL) Keras Pipeli ...
- TensorFlow tf.data 导入数据(tf.data官方教程) * * * * *
原文链接:https://blog.csdn.net/u014061630/article/details/80728694 TensorFlow版本:1.10.0 > Guide > I ...
- 【Tensorflow 2.0 正式版教程】tf.data.Dataset的基本使用方法
Tensorflow 2.0中提供了专门用于数据输入的接口tf.data.Dataset,可以简洁高效的实现数据的读入.打乱(shuffle).增强(augment)等功能.下面以一个简单的实例讲解该 ...
- Tensorflow2.* 加载和预处理数据之用 tf.data 加载磁盘图片数据(4)
Tensorflow2.* 机器学习基础知识篇: 对服装图像进行分类 使用Tensorflow Hub对未处理的电影评论数据集IMDB进行分类 Keras 机器学习基础知识之对预处理的电影评论文本分类 ...
- openlayers map获取全部feature_tf2.0基础-tf.data与tf.feature_column
7.2.1 tf.data 使用 tf.data API 可以轻松处理大量数据.不同的数据格式以及复杂的转换.tf.data API 在 TensorFlow 中引入了两个新的抽象类: tf.data ...
- tensorflow tf.data.Dataset.from_tensor_slices() (创建一个“数据集”,其元素是给定张量的切片)
from tensorflow\python\data\ops\dataset_ops.py @staticmethoddef from_tensor_slices(tensors):"&q ...
- 【小白学习tensorflow教程】一、tensorflow基本操作、快速构建线性回归和分类模型
@Author:Runsen TF 目前发布2.5 版本,之前阅读1.X官方文档,最近查看2.X的文档.tensorflow是非常强的工具,生态庞大. tensorflow提供了Keras的分支,这里 ...
- Tensorflow读取数据-tf.data.TFRecordDataset
tensorflow TFRecords文件的生成和读取方法 文章目录 tensorflow TFRecords文件的生成和读取方法 1. TFRecords说明 2.关键API 2.1 tf.io. ...
最新文章
- [汇编] 001基础知识-什么是汇编
- (转)淘淘商城系列——实现添加商品功能
- 图网络笔记-知识补充与node2vec代码注解
- valuable的用法_词汇精选:valuable的用法和辨析
- poj 1151 hdu 1542 Atlantis 线段树扫描线(详细讲解)
- mac如何删除用户或者群组
- GPS测量中所涉及的时间系统
- 计算机显卡驱动不匹配,显卡驱动与系统不兼容?尤其是老电脑
- 常用颜色中英文名称及RGB数值表
- zookeeper之ZkClient使用,java电子书阅读器开发
- win10总显示打印机未连接服务器,win10安装打印机一直未响应。。。
- CreateProcess的用法
- DXP画图快捷键等资料
- JVM调优专题-JVM调优参数
- 海信电视power android,海信85U9E评测:基于Android 9.0深度优化的电视系统
- SimpleDateFormat.parse()方法中的时区设置缺陷
- base64编码转码
- datatable invalid json format
- 新版掩日免杀——搭配CS使用测试
- 计算机专业毕业设计选题
热门文章
- 医学院计算机社发展,医学院计算机教学创新思路.docx
- pip install pygame_使用 Python 和 Pygame 模块构建一个游戏框架!
- 超级直播sop直播源.zip_双11首场虚拟直播,天猫超级直播开创直播新玩法
- 【MM模块】Basic Invoice Verification 发票校验
- 【PP主数据】工作中心介绍
- 【重复制造精讲】定义重复制造参数文件
- Function实现ALV Table六:页眉页脚
- think in java i o_5.[Think in Java笔记]Java IO系统
- mysql转化为GaussDB,GaussDB(openGauss)宣布开源,性能超越 MySQL 与 PostgreSQL
- python导出csv不带引号的句子_不带双引号写入CSV文件