tfds.load()和tf.data.Dataset的简介

tfds.load()有以下参数

tfds.load(name, split=None, data_dir=None, batch_size=None, shuffle_files=False,download=True, as_supervised=False, decoders=None, read_config=None,with_info=False, builder_kwargs=None, download_and_prepare_kwargs=None,as_dataset_kwargs=None, try_gcs=False
)

重要参数如下:

  • name 数据集的名字
  • split 对数据集的切分
  • data_dir 数据的位置或者数据下载的位置
  • batch_size 批道数
  • shuffle_files 打乱
  • as_supervised 返回元组(默认返回时字典的形式的)

1.数据的切分

# 拿数据集中训练集(数据集默认划分为train,test)
train_ds = tfds.load('mnist', split='train')# 两部分都拿出来
train_ds, test_ds = tfds.load('mnist', split=['train', 'test'])# 两部分都拿出来,并合成一个
train_test_ds = tfds.load('mnist', split='train+test')# 从训练集的10(含)到20(不含)
train_10_20_ds = tfds.load('mnist', split='train[10:20]')# 训练集的前10%
train_10pct_ds = tfds.load('mnist', split='train[:10%]')# 训练集的前10%和后80%
train_10_80pct_ds = tfds.load('mnist', split='train[:10%]+train[-80%:]')#---------------------------------------------------
# 10%的交错验证集:
# 没批验证集拿训练集的10%:
# [0%:10%], [10%:20%], ..., [90%:100%].
vals_ds = tfds.load('mnist', split=[f'train[{k}%:{k+10}%]' for k in range(0, 100, 10)
])
# 训练集拿90%:
# [10%:100%] (验证集为 [0%:10%]),
# [0%:10%] + [20%:100%] (验证集为 [10%:20%]), ...,
# [0%:90%] (验证集为 [90%:100%]).
trains_ds = tfds.load('mnist', split=[f'train[:{k}%]+train[{k+10}%:]' for k in range(0, 100, 10)
])

还有使用ReadInstruction API 切分的,效果跟上面一样

# The full `train` split.
train_ds = tfds.load('mnist', split=tfds.core.ReadInstruction('train'))# The full `train` split and the full `test` split as two distinct datasets.
train_ds, test_ds = tfds.load('mnist', split=[tfds.core.ReadInstruction('train'),tfds.core.ReadInstruction('test'),
])# The full `train` and `test` splits, interleaved together.
ri = tfds.core.ReadInstruction('train') + tfds.core.ReadInstruction('test')
train_test_ds = tfds.load('mnist', split=ri)# From record 10 (included) to record 20 (excluded) of `train` split.
train_10_20_ds = tfds.load('mnist', split=tfds.core.ReadInstruction('train', from_=10, to=20, unit='abs'))# The first 10% of train split.
train_10_20_ds = tfds.load('mnist', split=tfds.core.ReadInstruction('train', to=10, unit='%'))# The first 10% of train + the last 80% of train.
ri = (tfds.core.ReadInstruction('train', to=10, unit='%') +tfds.core.ReadInstruction('train', from_=-80, unit='%'))
train_10_80pct_ds = tfds.load('mnist', split=ri)# 10-fold cross-validation (see also next section on rounding behavior):
# The validation datasets are each going to be 10%:
# [0%:10%], [10%:20%], ..., [90%:100%].
# And the training datasets are each going to be the complementary 90%:
# [10%:100%] (for a corresponding validation set of [0%:10%]),
# [0%:10%] + [20%:100%] (for a validation set of [10%:20%]), ...,
# [0%:90%] (for a validation set of [90%:100%]).
vals_ds = tfds.load('mnist', [tfds.core.ReadInstruction('train', from_=k, to=k+10, unit='%')for k in range(0, 100, 10)])
trains_ds = tfds.load('mnist', [(tfds.core.ReadInstruction('train', to=k, unit='%') +tfds.core.ReadInstruction('train', from_=k+10, unit='%'))for k in range(0, 100, 10)])

2.返回的对象

返回的对象是一个tf.data.Dataset或者和一个tfds.core.DatasetInfo(如果有的话)

3.指定目录

指定目录十分简单(默认会放到用户目录下面)

train_ds = tfds.load('mnist', split='train',data_dir='~/user')

4.获取img和label

因为返回的是一个tf.data.Dataset对象,我们可以在对其进行迭代之前对数据集进行操作,以此来获取符合我们要求的数据。

tf.data.Dataset有以下几个重要的方法:

4.1 shuffle

数据的打乱

shuffle(buffer_size, seed=None, reshuffle_each_iteration=None
)
#随机重新排列此数据集的元素。
#该数据集用buffer_size元素填充缓冲区,然后从该缓冲区中随机采样元素,将所选元素替换为新元素。为了实现完美
#的改组,需要缓冲区大小大于或等于数据集的完整大小。
#例如,如果您的数据集包含10,000个元素但buffer_size设置为1,000个,则shuffle最初将仅从缓冲区的前1,000
#个元素中选择一个随机元素。选择一个元素后,其缓冲区中的空间将由下一个(即1,001个)元素替换,并保留1,000个#元素缓冲区。
#reshuffle_each_iteration控制随机播放顺序对于每个时期是否应该不同。

4.2 batch

批道大小(一批多少个数据),迭代的是时候根据批道数放回对应的数据量

batch(batch_size, drop_remainder=False
)dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3, drop_remainder=True)
list(dataset.as_numpy_iterator())

返回的是一个Dataset

4.3 map

用跟普通的map方法差不多,目的是对数据集操作

map(map_func, num_parallel_calls=None, deterministic=None
)dataset = Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1)
list(dataset.as_numpy_iterator())

返回的是一个Dataset

4.4 as_numpy_iterator

返回一个迭代器,该迭代器将数据集的所有元素转换为numpy。

使用as_numpy_iterator检查你的数据集的内容。要查看元素的形状和类型,请直接打印数据集元素,而不要使用 as_numpy_iterator

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset:print(element)
#tf.Tensor( 1 , shape = ( ) , dtype = int32 )
#tf.Tensor ( 2 , shape = ( ) . dtype = int32 )
#tf.Tensor ( 3 , shape = ( ) , dtype = int32 )dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():print(element)
#1
#2
#3

4.5 对数据集操作示例

通过下面的写法可以获取符合格式的数据:

#先用map()将img进行resize,然后进打乱,然后设定迭代的放回的batch_size
dataset_train = dataset_train.map(lambda img, label: (tf.image.resize(img, (224, 224)) / 255.0, label)).shuffle(1024).batch(batch_size)#因为是测试集,所以不打乱,只是把img进行resize
dataset_test = dataset_test.map(lambda img, label: (tf.image.resize(img, (224, 224)) / 255.0, label)).batch(batch_size)

对数据进行迭代:

for images, labels in dataset_train:labels_pred = model(images, training=True)loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=labels, y_pred=labels_pred)loss = tf.reduce_mean(loss)········

tfds.load()和tf.data.Dataset的简介相关推荐

  1. 使用tf.data.Dataset加载numpy数据

    Mnist数据集 0~9的手写体图片,该数据默认已经将数据分成训练集和测试集.训练集有60000张图片,测试集有10000张图片. 导入必要库 import tensorflow as tf from ...

  2. tf.data.Dataset.interleave

    本文对tf.data.Dataset.interleave该方法做点笔记. 在tensorflow中数据处理通道中,有一个方法interleave,tf定义如下: interleave(map_fun ...

  3. tf.data.Dataset.from_tensor_slices 的用法

    将python列表和numpy数组转换成tensorflow的dataset 只有dataset才能被model.fit函数训练 import tensorflow as tf import nump ...

  4. tf.data.Dataset 用法

    tf.data.DatasetAPI支持写入的描述性和高效的输入管线.Dataset用法遵循一个常见模式: 从输入数据创建源数据集. 应用数据集转换来预处理数据. 迭代数据集并处理元素. 迭代以流式方 ...

  5. tf.data.Dataset.from_tensor_slices() 详解

    函数原型: tf.data.Dataset.from_tensor_slices(tensors, name=None ) 官网地址:https://www.tensorflow.org/api_do ...

  6. tensorflow tf.data.Dataset.from_tensor_slices() (创建一个“数据集”,其元素是给定张量的切片)

    from tensorflow\python\data\ops\dataset_ops.py @staticmethoddef from_tensor_slices(tensors):"&q ...

  7. 记录 之 tensorflow函数:tf.data.Dataset.from_tensor_slices

    tf.data.Dataset.from_tensor_slices(),是常见的数据处理函数,它的作用是将给定的元组(turple).列表(list).张量(tensor)等特征进行特征切片.切片的 ...

  8. TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制

    TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和tf.data.Dataset机制 之前写了一篇博客,关于<Tensorflow生成自己的 ...

  9. tensorflow学习笔记:tf.data.Dataset,from_tensor_slices(),shuffle(),batch()的用法

    tf.data.Dataset.from_tensor_slices: 它的作用是切分传入Tensor的第一个维度,生成相应的dataset. 例1: dataset = tf.data.Datase ...

最新文章

  1. 程序员创业前要做哪些准备?
  2. Java多线程中join方法详解
  3. Spring MVC-08循序渐进之国际化(AcceptHeaderLocaleResolver)
  4. 电商的折扣体系是如何搭建的?
  5. 深度学习问题解决:Check failed: stream-parent()-GetConvolveAlgorithms( conv_parameters.ShouldIncludeWinogra
  6. 关于数据库插入中文乱码问题
  7. 基于node.js+MongoDB+elementui的分页接口以及页面实现
  8. Eclipse导入项目:No projects are found to import
  9. html总结:文本框填满表格
  10. Spring JSF集成
  11. MySQL追加注释或者大量修改注释
  12. vue-router组件重用 路由切换时的问题
  13. C++ template —— 模板中的名称(三)
  14. 查看Linux版本命令
  15. 【信息系统项目管理师】2018年下半年信息系统项目管理师上午综合知识真题
  16. 财务软件服务器装什么系统,财务软件用哪种云服务器
  17. C语言:Fibonacci数列打印
  18. python开发一个PC屏幕监控软件(2000块的道德底线)
  19. 功率单位mw和dbm的换算总结
  20. eclipse 中用svn共享项目

热门文章

  1. 用Log Parser Studio分析IIS日志
  2. KMP算法的JavaScript实现
  3. 一个通过数据库镜像实现SPS 2003门户快速备份与恢复的操作手册
  4. 《星际争霸 / StarCraft 》超经典游戏!
  5. 作为一名程序员,谁没跳过槽,“6”招让你“空降”大厂
  6. 全排列的生成算法:字典序法
  7. PHP操作使用Redis
  8. Bug邮件队列插入不了
  9. 怎么排号_春节将至,那些过年不回家的人们都是怎么过年的?
  10. 批量删除txt文档内容命令_Linux@实用操作命令