tfds.load()和tf.data.Dataset的简介
tfds.load()和tf.data.Dataset的简介
tfds.load()有以下参数
tfds.load(name, split=None, data_dir=None, batch_size=None, shuffle_files=False,download=True, as_supervised=False, decoders=None, read_config=None,with_info=False, builder_kwargs=None, download_and_prepare_kwargs=None,as_dataset_kwargs=None, try_gcs=False
)
重要参数如下:
- name 数据集的名字
- split 对数据集的切分
- data_dir 数据的位置或者数据下载的位置
- batch_size 批道数
- shuffle_files 打乱
- as_supervised 返回元组(默认返回时字典的形式的)
1.数据的切分
# 拿数据集中训练集(数据集默认划分为train,test)
train_ds = tfds.load('mnist', split='train')# 两部分都拿出来
train_ds, test_ds = tfds.load('mnist', split=['train', 'test'])# 两部分都拿出来,并合成一个
train_test_ds = tfds.load('mnist', split='train+test')# 从训练集的10(含)到20(不含)
train_10_20_ds = tfds.load('mnist', split='train[10:20]')# 训练集的前10%
train_10pct_ds = tfds.load('mnist', split='train[:10%]')# 训练集的前10%和后80%
train_10_80pct_ds = tfds.load('mnist', split='train[:10%]+train[-80%:]')#---------------------------------------------------
# 10%的交错验证集:
# 没批验证集拿训练集的10%:
# [0%:10%], [10%:20%], ..., [90%:100%].
vals_ds = tfds.load('mnist', split=[f'train[{k}%:{k+10}%]' for k in range(0, 100, 10)
])
# 训练集拿90%:
# [10%:100%] (验证集为 [0%:10%]),
# [0%:10%] + [20%:100%] (验证集为 [10%:20%]), ...,
# [0%:90%] (验证集为 [90%:100%]).
trains_ds = tfds.load('mnist', split=[f'train[:{k}%]+train[{k+10}%:]' for k in range(0, 100, 10)
])
还有使用ReadInstruction
API 切分的,效果跟上面一样
# The full `train` split.
train_ds = tfds.load('mnist', split=tfds.core.ReadInstruction('train'))# The full `train` split and the full `test` split as two distinct datasets.
train_ds, test_ds = tfds.load('mnist', split=[tfds.core.ReadInstruction('train'),tfds.core.ReadInstruction('test'),
])# The full `train` and `test` splits, interleaved together.
ri = tfds.core.ReadInstruction('train') + tfds.core.ReadInstruction('test')
train_test_ds = tfds.load('mnist', split=ri)# From record 10 (included) to record 20 (excluded) of `train` split.
train_10_20_ds = tfds.load('mnist', split=tfds.core.ReadInstruction('train', from_=10, to=20, unit='abs'))# The first 10% of train split.
train_10_20_ds = tfds.load('mnist', split=tfds.core.ReadInstruction('train', to=10, unit='%'))# The first 10% of train + the last 80% of train.
ri = (tfds.core.ReadInstruction('train', to=10, unit='%') +tfds.core.ReadInstruction('train', from_=-80, unit='%'))
train_10_80pct_ds = tfds.load('mnist', split=ri)# 10-fold cross-validation (see also next section on rounding behavior):
# The validation datasets are each going to be 10%:
# [0%:10%], [10%:20%], ..., [90%:100%].
# And the training datasets are each going to be the complementary 90%:
# [10%:100%] (for a corresponding validation set of [0%:10%]),
# [0%:10%] + [20%:100%] (for a validation set of [10%:20%]), ...,
# [0%:90%] (for a validation set of [90%:100%]).
vals_ds = tfds.load('mnist', [tfds.core.ReadInstruction('train', from_=k, to=k+10, unit='%')for k in range(0, 100, 10)])
trains_ds = tfds.load('mnist', [(tfds.core.ReadInstruction('train', to=k, unit='%') +tfds.core.ReadInstruction('train', from_=k+10, unit='%'))for k in range(0, 100, 10)])
2.返回的对象
返回的对象是一个tf.data.Dataset或者和一个tfds.core.DatasetInfo(如果有的话)
3.指定目录
指定目录十分简单(默认会放到用户目录下面)
train_ds = tfds.load('mnist', split='train',data_dir='~/user')
4.获取img和label
因为返回的是一个tf.data.Dataset对象,我们可以在对其进行迭代之前对数据集进行操作,以此来获取符合我们要求的数据。
tf.data.Dataset有以下几个重要的方法:
4.1 shuffle
数据的打乱
shuffle(buffer_size, seed=None, reshuffle_each_iteration=None
)
#随机重新排列此数据集的元素。
#该数据集用buffer_size元素填充缓冲区,然后从该缓冲区中随机采样元素,将所选元素替换为新元素。为了实现完美
#的改组,需要缓冲区大小大于或等于数据集的完整大小。
#例如,如果您的数据集包含10,000个元素但buffer_size设置为1,000个,则shuffle最初将仅从缓冲区的前1,000
#个元素中选择一个随机元素。选择一个元素后,其缓冲区中的空间将由下一个(即1,001个)元素替换,并保留1,000个#元素缓冲区。
#reshuffle_each_iteration控制随机播放顺序对于每个时期是否应该不同。
4.2 batch
批道大小(一批多少个数据),迭代的是时候根据批道数放回对应的数据量
batch(batch_size, drop_remainder=False
)dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3, drop_remainder=True)
list(dataset.as_numpy_iterator())
返回的是一个Dataset
4.3 map
用跟普通的map方法差不多,目的是对数据集操作
map(map_func, num_parallel_calls=None, deterministic=None
)dataset = Dataset.range(1, 6) # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1)
list(dataset.as_numpy_iterator())
返回的是一个Dataset
4.4 as_numpy_iterator
返回一个迭代器,该迭代器将数据集的所有元素转换为numpy。
使用as_numpy_iterator
检查你的数据集的内容。要查看元素的形状和类型,请直接打印数据集元素,而不要使用 as_numpy_iterator
。
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:print(element)
#tf.Tensor( 1 , shape = ( ) , dtype = int32 )
#tf.Tensor ( 2 , shape = ( ) . dtype = int32 )
#tf.Tensor ( 3 , shape = ( ) , dtype = int32 )dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():print(element)
#1
#2
#3
4.5 对数据集操作示例
通过下面的写法可以获取符合格式的数据:
#先用map()将img进行resize,然后进打乱,然后设定迭代的放回的batch_size
dataset_train = dataset_train.map(lambda img, label: (tf.image.resize(img, (224, 224)) / 255.0, label)).shuffle(1024).batch(batch_size)#因为是测试集,所以不打乱,只是把img进行resize
dataset_test = dataset_test.map(lambda img, label: (tf.image.resize(img, (224, 224)) / 255.0, label)).batch(batch_size)
对数据进行迭代:
for images, labels in dataset_train:labels_pred = model(images, training=True)loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=labels, y_pred=labels_pred)loss = tf.reduce_mean(loss)········
tfds.load()和tf.data.Dataset的简介相关推荐
- 使用tf.data.Dataset加载numpy数据
Mnist数据集 0~9的手写体图片,该数据默认已经将数据分成训练集和测试集.训练集有60000张图片,测试集有10000张图片. 导入必要库 import tensorflow as tf from ...
- tf.data.Dataset.interleave
本文对tf.data.Dataset.interleave该方法做点笔记. 在tensorflow中数据处理通道中,有一个方法interleave,tf定义如下: interleave(map_fun ...
- tf.data.Dataset.from_tensor_slices 的用法
将python列表和numpy数组转换成tensorflow的dataset 只有dataset才能被model.fit函数训练 import tensorflow as tf import nump ...
- tf.data.Dataset 用法
tf.data.DatasetAPI支持写入的描述性和高效的输入管线.Dataset用法遵循一个常见模式: 从输入数据创建源数据集. 应用数据集转换来预处理数据. 迭代数据集并处理元素. 迭代以流式方 ...
- tf.data.Dataset.from_tensor_slices() 详解
函数原型: tf.data.Dataset.from_tensor_slices(tensors, name=None ) 官网地址:https://www.tensorflow.org/api_do ...
- tensorflow tf.data.Dataset.from_tensor_slices() (创建一个“数据集”,其元素是给定张量的切片)
from tensorflow\python\data\ops\dataset_ops.py @staticmethoddef from_tensor_slices(tensors):"&q ...
- 记录 之 tensorflow函数:tf.data.Dataset.from_tensor_slices
tf.data.Dataset.from_tensor_slices(),是常见的数据处理函数,它的作用是将给定的元组(turple).列表(list).张量(tensor)等特征进行特征切片.切片的 ...
- TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制
TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和tf.data.Dataset机制 之前写了一篇博客,关于<Tensorflow生成自己的 ...
- tensorflow学习笔记:tf.data.Dataset,from_tensor_slices(),shuffle(),batch()的用法
tf.data.Dataset.from_tensor_slices: 它的作用是切分传入Tensor的第一个维度,生成相应的dataset. 例1: dataset = tf.data.Datase ...
最新文章
- 程序员创业前要做哪些准备?
- Java多线程中join方法详解
- Spring MVC-08循序渐进之国际化(AcceptHeaderLocaleResolver)
- 电商的折扣体系是如何搭建的?
- 深度学习问题解决:Check failed: stream-parent()-GetConvolveAlgorithms( conv_parameters.ShouldIncludeWinogra
- 关于数据库插入中文乱码问题
- 基于node.js+MongoDB+elementui的分页接口以及页面实现
- Eclipse导入项目:No projects are found to import
- html总结:文本框填满表格
- Spring JSF集成
- MySQL追加注释或者大量修改注释
- vue-router组件重用 路由切换时的问题
- C++ template —— 模板中的名称(三)
- 查看Linux版本命令
- 【信息系统项目管理师】2018年下半年信息系统项目管理师上午综合知识真题
- 财务软件服务器装什么系统,财务软件用哪种云服务器
- C语言:Fibonacci数列打印
- python开发一个PC屏幕监控软件(2000块的道德底线)
- 功率单位mw和dbm的换算总结
- eclipse 中用svn共享项目