之前的文章,稍微讲了一下Estimator的用法,也提到Estimator的数据处理使用的是tf.data这两个模块是Tensorflow初学者必须掌握的内容。现在,就让我们从大的概念入手,来慢慢理解tf.data的用法

转载请注明出处

推荐官方文档:https://tensorflow.google.cn/programmers_guide/datasets

tf.data的作用

在机器学习过程中,对数据的获取、过滤、使用、存储是很重要的一个内容,因为数据可能是不完整的、有杂质的、来源不同的。面对海量数据,我们当然不可能每次都手动整合。Tensorflow框架下,对数据的处理使用的是tf.data,它可以帮助我们以多种方式获取数据、灵活的处理数据和保存数据,使我们能够把更多的精力专注在算法的逻辑上。下面就让我们一起来学习。

tf.data获取数据的方式

这里着着重理解Dataset的概念
Dataset是存储Tensor结构的类,它可以保存一批Tensor结构,以供模型来训练或者测试。这里,Tensor结构是自己定义的,可以有多种格式。
Dataset获取数据的方式有多种,可以从Tensor获取,也可以从另一个Dataset转换而来,我们暂时只讲从Tensor获取。
用到的接口为:

tf.data.Dataset.from_tensor_slices()

这个接口允许我们传递一个或多个Tensor结构给Dataset,因为默认把Tensor的第一个维度作为数据数目的标识,所以要保持数据结构中第一维的一致性,用代码说明一下:

dataset = tf.data.Dataset.from_tensor_slices({"a": tf.random_uniform([4]),"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types)  # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes)  # ==> "{'a': (), 'b': (100,)}"

这里包含如下信息:
1、该接口可以接受一个字典变量。实际上,该接口接受任何Iterator
2、第一个维度被认为是数据的数量,可以看到,观察数据的shapes的时候,只显示第一维以后的,为什么呢,因为第一维被认为是数据的数量,所以不参与构成shapes

Dataset输出数据的方式

make_one_shot_iterator迭代器

有进就有出,那么数据怎么从Dataset出来呢,代码如下:

dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(10,3))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()with tf.Session() as sess:for i in range(10):value = sess.run(next_element)print(i, value)output:
0 [ 0.78891609  0.31016679 -2.22261044]
1 [ 3.06918115  0.14014906  0.86654045]
2 [ 2.08348332  0.57866576 -0.66946627]
3 [-1.28344434  1.96287407  0.70896466]
4 [-1.28056116 -0.65352575  0.39975416]
5 [-0.70007014 -0.94034185  1.02308444]
6 [ 0.70819506 -0.56918389  0.75509558]
7 [ 0.26925763 -0.18980865 -0.90350774]
8 [ 1.45644465 -1.13308881 -0.37863782]
9 [ 0.4485246  -0.48737583 -0.40142893]

这里,我们先用numpy生成随机数据并储存在Dataset,之后,是用了dataset.make_one_shot_iterator()迭代器来读取数据。one_shot迭代器人如其名,意思就是数据输出一次后就丢弃了。
这就构成了数据进出的一种方式,下面,我们多了解几种数据输出的迭代器

make_initializable_iterator 迭代器

可初始化迭代器允许Dataset中存在占位符,这样可以在数据需要输出的时候,再进行feed操作。实验代码如下:

max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()with tf.Session() as sess:
# Initialize an iterator over a dataset with 10 elements.sess.run(iterator.initializer, feed_dict={max_value: 10})#需要取数据的时候才将需要的参数feed进去for i in range(10):value = sess.run(next_element)assert i == value# Initialize the same iterator over a dataset with 100 elements.sess.run(iterator.initializer, feed_dict={max_value: 100})#feed了不同的参数for i in range(100):value = sess.run(next_element)assert i == value

reinitializable 迭代器

这个迭代器构造方式是根据数据的shapes和type,所以只要shapes和type相同,就可以接受不同的数据源来进行初始化,且可以反复初始化,见以下代码:

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,training_dataset.output_shapes)
next_element = iterator.get_next()training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)with tf.Session() as sess:# Run 20 epochs in which the training dataset is traversed, followed by the# validation dataset.for _ in range(20):#可以重复初始化# Initialize an iterator over the training dataset.sess.run(training_init_op)#每次初始化的数据可以不同for _ in range(100):sess.run(next_element)# Initialize an iterator over the validation dataset.sess.run(validation_init_op)#初始化了另一组数据for _ in range(50):sess.run(next_element)

Iterator.from_string_handle 迭代器

可以看到,reinitializable 已经具有较强的灵活性了,但是它还是每次加载数据都需要重新初始化,有没有可能省掉这一步呢,是可以的,Iterator.from_string_handle通过feed初始化句柄的方式,取得了更高的灵活性,代码如下

# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())# Loop forever, alternating between training and validation.
while True:# Run 200 steps using the training dataset. Note that the training dataset is# infinite, and we resume from where we left off in the previous `while` loop# iteration.for _ in range(200):sess.run(next_element, feed_dict={handle: training_handle})# Run one pass over the validation dataset.sess.run(validation_iterator.initializer)for _ in range(50):sess.run(next_element, feed_dict={handle: validation_handle})

从Dataset初始化Dataset

讲完了如何读取数据,我们再回过头来讲获取数据的另一种方法:从Dataset获取。
之所以这么安排,是因为要结合输出才能理解这种获取方法的意义。
为了从dataset中初始化,这里有三个接口:

Dataset.map
Dataset.flat_map
Dataset.filter

这三个接口从字面上就很好理解,map就是对于给定Dataset中的每一个元素,都执行一次map操作,而flat_map就是既执行了map,还对数据进行了一次扁平化,也就是降维,而filter就是进行了一次过滤, 我们直接从代码的角度看以看这三个接口怎么用。

代码虽然很大一段,需要理解的东西却很少, 首先,我们定义了一个3*2*3的随机多维数组,可以看到
正常的输出就是输出3个2*3的数组
而map的作用我这里写的是各数加一,所以输出的是3个2*3的各数加一的数组
flat_map的作用是降维,所以输出的是6个1*3的数组
filter的作用是过滤,所以输出的是我的过滤内容:[0][0]元素大于0.8的,由于最有一个数组其元素为0.76103773,被过滤掉了

with tf.Session() as sess:np.random.seed(0)#持有种子,使得每次随机出来的数组是一样的normal_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3))np.random.seed(0)map_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3)).map(map_func=lambda x:x+1)#各数加一np.random.seed(0)flat_map_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3)).flat_map(map_func=lambda x:tf.data.Dataset.from_tensor_slices(x))#输出的还是原来的x,但是降维了np.random.seed(0)filter_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3)).filter(lambda x:x[0][0] > 0.8)#进行了一次过滤iterator1 = tf.data.Iterator.from_structure(normal_dataset.output_types,normal_dataset.output_shapes)iterator2 = tf.data.Iterator.from_structure(map_dataset.output_types,map_dataset.output_shapes)iterator3 = tf.data.Iterator.from_structure(flat_map_dataset.output_types,flat_map_dataset.output_shapes)iterator4 = tf.data.Iterator.from_structure(filter_dataset.output_types,filter_dataset.output_shapes)next_element1 = iterator1.get_next()next_element2 = iterator2.get_next()next_element3 = iterator3.get_next()next_element4 = iterator4.get_next()training_init_op1 = iterator1.make_initializer(normal_dataset)training_init_op2 = iterator2.make_initializer(map_dataset)training_init_op3 = iterator3.make_initializer(flat_map_dataset)training_init_op4 = iterator4.make_initializer(filter_dataset)print("normal:")sess.run(training_init_op1)for _ in range(3):print(sess.run(next_element1))print("map:")sess.run(training_init_op2)for _ in range(3):print(sess.run(next_element2))print("falt_map:")sess.run(training_init_op3)for _ in range(6):print(sess.run(next_element3))print("filter:")sess.run(training_init_op4)for _ in range(2):print(sess.run(next_element4))output:normal:
[[ 1.76405235  0.40015721  0.97873798][ 2.2408932   1.86755799 -0.97727788]]
[[ 0.95008842 -0.15135721 -0.10321885][ 0.4105985   0.14404357  1.45427351]]
[[ 0.76103773  0.12167502  0.44386323][ 0.33367433  1.49407907 -0.20515826]]
map:
[[ 2.76405235  1.40015721  1.97873798][ 3.2408932   2.86755799  0.02272212]]
[[ 1.95008842  0.84864279  0.89678115][ 1.4105985   1.14404357  2.45427351]]
[[ 1.76103773  1.12167502  1.44386323][ 1.33367433  2.49407907  0.79484174]]#各数加一了
falt_map:
[ 1.76405235  0.40015721  0.97873798]
[ 2.2408932   1.86755799 -0.97727788]
[ 0.95008842 -0.15135721 -0.10321885]
[ 0.4105985   0.14404357  1.45427351]
[ 0.76103773  0.12167502  0.44386323]
[ 0.33367433  1.49407907 -0.20515826]
#这里降维成一维数组了
filter:
[[ 1.76405235  0.40015721  0.97873798][ 2.2408932   1.86755799 -0.97727788]]
[[ 0.95008842 -0.15135721 -0.10321885][ 0.4105985   0.14404357  1.45427351]]#最后一个被过滤掉了

到这里,就基本讲述了一下Dataset的输入输出方法,篇幅有限,这篇博文就到这里,之后会另开一篇,写一写数据的消费等更高级的操作!

Dataset的用法简析相关推荐

  1. Python中的基本函数及其常用用法简析

    分享Python中的基本函数及其常用用法简析,首先关于函数的解释函数是为了达到某种目的而采取的行为,函数是可重复使用的,用来实现某个单一功能或者功能片段的代码块,简单来说就是由一系列的程序语句组成的程 ...

  2. fuser 用法简析

    fuser 用法简析 fuser [功能]  fuser 可以显示出当前哪个程序在使用磁盘上的某个文件.挂载点.甚至网络端口,并给出程序进程的详细信息. [描述]  fuser显示使用指定文件或者文件 ...

  3. PrintWriter用法简析

    public class PrintWriterextends Writer 向文本输出流打印对象的格式化表示形式.此类实现在 PrintStream 中的所有 print 方法.它不包含用于写入原始 ...

  4. 【C语言】typedef的用法简析

    前言 C语言中typedef可以为一个数据类型定义别名(可以理解为人类的绰号),用来替代基本数据类型,数组类型,指针类型,自定义的结构体类型,枚举类型等等:这样使用可以让我们编码方便.下面我来看它在结 ...

  5. java入门学习笔记(二)—— Eclipse入门学习之快捷键、java语言基础知识之各类关键字及其用法简析

    一.Eclipse入门学习 1. 快捷键 对于一个编辑器,快捷键必不可少,是十分好用且有效的工具. 对于一个初学者,首先掌握了如下快捷键. (很多通用的快捷键不多说) Ctrl + / -- 注释当前 ...

  6. android之descendantFocusability用法简析

    2019独角兽企业重金招聘Python工程师标准>>> listView的Item被抢焦点,这是开发中很常见的一个问题,项目中的listview不仅仅是简单的文字,常常需要自己定义l ...

  7. 坑爹的属性,android:descendantFocusability用法简析

    开发中很常见的一个问题,项目中的listview不仅仅是简单的文字,常常需要自己定义listview,自己的Adapter去继承 BaseAdapter,在adapter中按照需求进行编写,问题就出现 ...

  8. android:descendantFocusability用法简析

    开发中很常见的一个问题,项目中的listview不仅仅是简单的文字,常常需要自己定义listview,自己的Adapter去继承BaseAdapter,在adapter中按照需求进行编写,问题就出现了 ...

  9. hog函数的用法 python_Python中的基本函数及常用用法简析

    函数 解释 函数是为了达到某种目的而采取的行为,函数是可重复使用的,用来实现某个单一功能或者功能片段的代码块,简单来说就是由一系列的程序语句组成的程序段落. 函数存在的意义: 1. 提高代码的复用性 ...

最新文章

  1. 习题8-6 删除字符 (20 分)
  2. Promise 简介
  3. linux 下启动jar小程序
  4. Angular依赖注入的一个例子和注入原理单步调试
  5. hdu 5094 Maze
  6. 多元线性回归分析matlab实验报告,利用MATLAB进行多元线性回归.ppt
  7. Linux创建逻辑卷
  8. SPAN Switched Port Analyzer 单臂路由
  9. junit单元测试报错Failed to load ApplicationContext,但是项目发布到tomcat浏览器访问没问题...
  10. 模块化方案esl以及amd的依赖方式
  11. SSH项目搭建-02-配置文件
  12. Atitti v5住宿服务部后勤部建设指引指南.docx
  13. 使用小丸工具箱进行极限视频压缩
  14. EMACS 使用手册
  15. 通常网站当中的关键词密度如何控制呢
  16. 基于站长之家(CNZZ)的网站流量统计分析 (以vue代码为例)
  17. Github开始强制使用PAT(Personal Access Token)了
  18. 塞规公差带图_塞规公差计算表
  19. 实时音视频数据传输协议介绍
  20. 系统架构图编写(概要设计)

热门文章

  1. Java程序员应该如何提升自己呢
  2. Linux - samba实现Linux与windows文件共享——共享文件夹目标文件访问权限被拒绝解决方案(超详细,看不懂你怪我)
  3. KML转geojson在线工具和数据抽稀工具
  4. 数据驱动 - ddt
  5. 叶胜超:V SYSTEMS(VSYS)--人人可发币的公链项目!
  6. 机柜风扇 的组成及如何正确安装 机柜散热风扇
  7. Web自动化测试-Protractor基础(二)
  8. 中集集团人工智能企业中集飞瞳,拿产品说话的全球航运港口人工智能高科技独角兽,全球第一家完成200万次人工智能集装箱验箱的AI企业
  9. 艾美智能影库服务器ip,家庭影院播放器;影库 篇一:艾美影库MS-300 到底怎么样?...
  10. 会考计算机考试查询成绩查询,会考成绩(学业水平考试成绩查询系统)