Dataset的用法简析
之前的文章,稍微讲了一下Estimator的用法,也提到Estimator的数据处理使用的是tf.data这两个模块是Tensorflow初学者必须掌握的内容。现在,就让我们从大的概念入手,来慢慢理解tf.data的用法
转载请注明出处
推荐官方文档:https://tensorflow.google.cn/programmers_guide/datasets
tf.data的作用
在机器学习过程中,对数据的获取、过滤、使用、存储是很重要的一个内容,因为数据可能是不完整的、有杂质的、来源不同的。面对海量数据,我们当然不可能每次都手动整合。Tensorflow框架下,对数据的处理使用的是tf.data,它可以帮助我们以多种方式获取数据、灵活的处理数据和保存数据,使我们能够把更多的精力专注在算法的逻辑上。下面就让我们一起来学习。
tf.data获取数据的方式
这里着着重理解Dataset的概念
Dataset是存储Tensor结构的类,它可以保存一批Tensor结构,以供模型来训练或者测试。这里,Tensor结构是自己定义的,可以有多种格式。
Dataset获取数据的方式有多种,可以从Tensor获取,也可以从另一个Dataset转换而来,我们暂时只讲从Tensor获取。
用到的接口为:
tf.data.Dataset.from_tensor_slices()
这个接口允许我们传递一个或多个Tensor结构给Dataset,因为默认把Tensor的第一个维度作为数据数目的标识,所以要保持数据结构中第一维的一致性,用代码说明一下:
dataset = tf.data.Dataset.from_tensor_slices({"a": tf.random_uniform([4]),"b": tf.random_uniform([4, 100], maxval=100, dtype=tf.int32)})
print(dataset.output_types) # ==> "{'a': tf.float32, 'b': tf.int32}"
print(dataset.output_shapes) # ==> "{'a': (), 'b': (100,)}"
这里包含如下信息:
1、该接口可以接受一个字典变量。实际上,该接口接受任何Iterator
2、第一个维度被认为是数据的数量,可以看到,观察数据的shapes的时候,只显示第一维以后的,为什么呢,因为第一维被认为是数据的数量,所以不参与构成shapes
Dataset输出数据的方式
make_one_shot_iterator迭代器
有进就有出,那么数据怎么从Dataset出来呢,代码如下:
dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(10,3))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()with tf.Session() as sess:for i in range(10):value = sess.run(next_element)print(i, value)output:
0 [ 0.78891609 0.31016679 -2.22261044]
1 [ 3.06918115 0.14014906 0.86654045]
2 [ 2.08348332 0.57866576 -0.66946627]
3 [-1.28344434 1.96287407 0.70896466]
4 [-1.28056116 -0.65352575 0.39975416]
5 [-0.70007014 -0.94034185 1.02308444]
6 [ 0.70819506 -0.56918389 0.75509558]
7 [ 0.26925763 -0.18980865 -0.90350774]
8 [ 1.45644465 -1.13308881 -0.37863782]
9 [ 0.4485246 -0.48737583 -0.40142893]
这里,我们先用numpy生成随机数据并储存在Dataset,之后,是用了dataset.make_one_shot_iterator()
迭代器来读取数据。one_shot迭代器人如其名,意思就是数据输出一次后就丢弃了。
这就构成了数据进出的一种方式,下面,我们多了解几种数据输出的迭代器
make_initializable_iterator 迭代器
可初始化迭代器允许Dataset中存在占位符,这样可以在数据需要输出的时候,再进行feed操作。实验代码如下:
max_value = tf.placeholder(tf.int64, shape=[])
dataset = tf.data.Dataset.range(max_value)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()with tf.Session() as sess:
# Initialize an iterator over a dataset with 10 elements.sess.run(iterator.initializer, feed_dict={max_value: 10})#需要取数据的时候才将需要的参数feed进去for i in range(10):value = sess.run(next_element)assert i == value# Initialize the same iterator over a dataset with 100 elements.sess.run(iterator.initializer, feed_dict={max_value: 100})#feed了不同的参数for i in range(100):value = sess.run(next_element)assert i == value
reinitializable 迭代器
这个迭代器构造方式是根据数据的shapes和type,所以只要shapes和type相同,就可以接受不同的数据源来进行初始化,且可以反复初始化,见以下代码:
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(lambda x: x + tf.random_uniform([], -10, 10, tf.int64))
validation_dataset = tf.data.Dataset.range(50)# A reinitializable iterator is defined by its structure. We could use the
# `output_types` and `output_shapes` properties of either `training_dataset`
# or `validation_dataset` here, because they are compatible.
iterator = tf.data.Iterator.from_structure(training_dataset.output_types,training_dataset.output_shapes)
next_element = iterator.get_next()training_init_op = iterator.make_initializer(training_dataset)
validation_init_op = iterator.make_initializer(validation_dataset)with tf.Session() as sess:# Run 20 epochs in which the training dataset is traversed, followed by the# validation dataset.for _ in range(20):#可以重复初始化# Initialize an iterator over the training dataset.sess.run(training_init_op)#每次初始化的数据可以不同for _ in range(100):sess.run(next_element)# Initialize an iterator over the validation dataset.sess.run(validation_init_op)#初始化了另一组数据for _ in range(50):sess.run(next_element)
Iterator.from_string_handle 迭代器
可以看到,reinitializable 已经具有较强的灵活性了,但是它还是每次加载数据都需要重新初始化,有没有可能省掉这一步呢,是可以的,Iterator.from_string_handle通过feed初始化句柄的方式,取得了更高的灵活性,代码如下
# Define training and validation datasets with the same structure.
training_dataset = tf.data.Dataset.range(100).map(lambda x: x + tf.random_uniform([], -10, 10, tf.int64)).repeat()
validation_dataset = tf.data.Dataset.range(50)# A feedable iterator is defined by a handle placeholder and its structure. We
# could use the `output_types` and `output_shapes` properties of either
# `training_dataset` or `validation_dataset` here, because they have
# identical structure.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()# You can use feedable iterators with a variety of different kinds of iterator
# (such as one-shot and initializable iterators).
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
training_handle = sess.run(training_iterator.string_handle())
validation_handle = sess.run(validation_iterator.string_handle())# Loop forever, alternating between training and validation.
while True:# Run 200 steps using the training dataset. Note that the training dataset is# infinite, and we resume from where we left off in the previous `while` loop# iteration.for _ in range(200):sess.run(next_element, feed_dict={handle: training_handle})# Run one pass over the validation dataset.sess.run(validation_iterator.initializer)for _ in range(50):sess.run(next_element, feed_dict={handle: validation_handle})
从Dataset初始化Dataset
讲完了如何读取数据,我们再回过头来讲获取数据的另一种方法:从Dataset获取。
之所以这么安排,是因为要结合输出才能理解这种获取方法的意义。
为了从dataset中初始化,这里有三个接口:
Dataset.map
Dataset.flat_map
Dataset.filter
这三个接口从字面上就很好理解,map就是对于给定Dataset中的每一个元素,都执行一次map操作,而flat_map就是既执行了map,还对数据进行了一次扁平化,也就是降维,而filter就是进行了一次过滤, 我们直接从代码的角度看以看这三个接口怎么用。
代码虽然很大一段,需要理解的东西却很少, 首先,我们定义了一个3*2*3的随机多维数组,可以看到
正常的输出就是输出3个2*3的数组
而map的作用我这里写的是各数加一,所以输出的是3个2*3的各数加一的数组
flat_map的作用是降维,所以输出的是6个1*3的数组
filter的作用是过滤,所以输出的是我的过滤内容:[0][0]元素大于0.8的,由于最有一个数组其元素为0.76103773,被过滤掉了
with tf.Session() as sess:np.random.seed(0)#持有种子,使得每次随机出来的数组是一样的normal_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3))np.random.seed(0)map_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3)).map(map_func=lambda x:x+1)#各数加一np.random.seed(0)flat_map_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3)).flat_map(map_func=lambda x:tf.data.Dataset.from_tensor_slices(x))#输出的还是原来的x,但是降维了np.random.seed(0)filter_dataset = tf.data.Dataset.from_tensor_slices(np.random.randn(3,2,3)).filter(lambda x:x[0][0] > 0.8)#进行了一次过滤iterator1 = tf.data.Iterator.from_structure(normal_dataset.output_types,normal_dataset.output_shapes)iterator2 = tf.data.Iterator.from_structure(map_dataset.output_types,map_dataset.output_shapes)iterator3 = tf.data.Iterator.from_structure(flat_map_dataset.output_types,flat_map_dataset.output_shapes)iterator4 = tf.data.Iterator.from_structure(filter_dataset.output_types,filter_dataset.output_shapes)next_element1 = iterator1.get_next()next_element2 = iterator2.get_next()next_element3 = iterator3.get_next()next_element4 = iterator4.get_next()training_init_op1 = iterator1.make_initializer(normal_dataset)training_init_op2 = iterator2.make_initializer(map_dataset)training_init_op3 = iterator3.make_initializer(flat_map_dataset)training_init_op4 = iterator4.make_initializer(filter_dataset)print("normal:")sess.run(training_init_op1)for _ in range(3):print(sess.run(next_element1))print("map:")sess.run(training_init_op2)for _ in range(3):print(sess.run(next_element2))print("falt_map:")sess.run(training_init_op3)for _ in range(6):print(sess.run(next_element3))print("filter:")sess.run(training_init_op4)for _ in range(2):print(sess.run(next_element4))output:normal:
[[ 1.76405235 0.40015721 0.97873798][ 2.2408932 1.86755799 -0.97727788]]
[[ 0.95008842 -0.15135721 -0.10321885][ 0.4105985 0.14404357 1.45427351]]
[[ 0.76103773 0.12167502 0.44386323][ 0.33367433 1.49407907 -0.20515826]]
map:
[[ 2.76405235 1.40015721 1.97873798][ 3.2408932 2.86755799 0.02272212]]
[[ 1.95008842 0.84864279 0.89678115][ 1.4105985 1.14404357 2.45427351]]
[[ 1.76103773 1.12167502 1.44386323][ 1.33367433 2.49407907 0.79484174]]#各数加一了
falt_map:
[ 1.76405235 0.40015721 0.97873798]
[ 2.2408932 1.86755799 -0.97727788]
[ 0.95008842 -0.15135721 -0.10321885]
[ 0.4105985 0.14404357 1.45427351]
[ 0.76103773 0.12167502 0.44386323]
[ 0.33367433 1.49407907 -0.20515826]
#这里降维成一维数组了
filter:
[[ 1.76405235 0.40015721 0.97873798][ 2.2408932 1.86755799 -0.97727788]]
[[ 0.95008842 -0.15135721 -0.10321885][ 0.4105985 0.14404357 1.45427351]]#最后一个被过滤掉了
到这里,就基本讲述了一下Dataset的输入输出方法,篇幅有限,这篇博文就到这里,之后会另开一篇,写一写数据的消费等更高级的操作!
Dataset的用法简析相关推荐
- Python中的基本函数及其常用用法简析
分享Python中的基本函数及其常用用法简析,首先关于函数的解释函数是为了达到某种目的而采取的行为,函数是可重复使用的,用来实现某个单一功能或者功能片段的代码块,简单来说就是由一系列的程序语句组成的程 ...
- fuser 用法简析
fuser 用法简析 fuser [功能] fuser 可以显示出当前哪个程序在使用磁盘上的某个文件.挂载点.甚至网络端口,并给出程序进程的详细信息. [描述] fuser显示使用指定文件或者文件 ...
- PrintWriter用法简析
public class PrintWriterextends Writer 向文本输出流打印对象的格式化表示形式.此类实现在 PrintStream 中的所有 print 方法.它不包含用于写入原始 ...
- 【C语言】typedef的用法简析
前言 C语言中typedef可以为一个数据类型定义别名(可以理解为人类的绰号),用来替代基本数据类型,数组类型,指针类型,自定义的结构体类型,枚举类型等等:这样使用可以让我们编码方便.下面我来看它在结 ...
- java入门学习笔记(二)—— Eclipse入门学习之快捷键、java语言基础知识之各类关键字及其用法简析
一.Eclipse入门学习 1. 快捷键 对于一个编辑器,快捷键必不可少,是十分好用且有效的工具. 对于一个初学者,首先掌握了如下快捷键. (很多通用的快捷键不多说) Ctrl + / -- 注释当前 ...
- android之descendantFocusability用法简析
2019独角兽企业重金招聘Python工程师标准>>> listView的Item被抢焦点,这是开发中很常见的一个问题,项目中的listview不仅仅是简单的文字,常常需要自己定义l ...
- 坑爹的属性,android:descendantFocusability用法简析
开发中很常见的一个问题,项目中的listview不仅仅是简单的文字,常常需要自己定义listview,自己的Adapter去继承 BaseAdapter,在adapter中按照需求进行编写,问题就出现 ...
- android:descendantFocusability用法简析
开发中很常见的一个问题,项目中的listview不仅仅是简单的文字,常常需要自己定义listview,自己的Adapter去继承BaseAdapter,在adapter中按照需求进行编写,问题就出现了 ...
- hog函数的用法 python_Python中的基本函数及常用用法简析
函数 解释 函数是为了达到某种目的而采取的行为,函数是可重复使用的,用来实现某个单一功能或者功能片段的代码块,简单来说就是由一系列的程序语句组成的程序段落. 函数存在的意义: 1. 提高代码的复用性 ...
最新文章
- 习题8-6 删除字符 (20 分)
- Promise 简介
- linux 下启动jar小程序
- Angular依赖注入的一个例子和注入原理单步调试
- hdu 5094 Maze
- 多元线性回归分析matlab实验报告,利用MATLAB进行多元线性回归.ppt
- Linux创建逻辑卷
- SPAN Switched Port Analyzer 单臂路由
- junit单元测试报错Failed to load ApplicationContext,但是项目发布到tomcat浏览器访问没问题...
- 模块化方案esl以及amd的依赖方式
- SSH项目搭建-02-配置文件
- Atitti v5住宿服务部后勤部建设指引指南.docx
- 使用小丸工具箱进行极限视频压缩
- EMACS 使用手册
- 通常网站当中的关键词密度如何控制呢
- 基于站长之家(CNZZ)的网站流量统计分析 (以vue代码为例)
- Github开始强制使用PAT(Personal Access Token)了
- 塞规公差带图_塞规公差计算表
- 实时音视频数据传输协议介绍
- 系统架构图编写(概要设计)
热门文章
- Java程序员应该如何提升自己呢
- Linux - samba实现Linux与windows文件共享——共享文件夹目标文件访问权限被拒绝解决方案(超详细,看不懂你怪我)
- KML转geojson在线工具和数据抽稀工具
- 数据驱动 - ddt
- 叶胜超:V SYSTEMS(VSYS)--人人可发币的公链项目!
- 机柜风扇 的组成及如何正确安装 机柜散热风扇
- Web自动化测试-Protractor基础(二)
- 中集集团人工智能企业中集飞瞳,拿产品说话的全球航运港口人工智能高科技独角兽,全球第一家完成200万次人工智能集装箱验箱的AI企业
- 艾美智能影库服务器ip,家庭影院播放器;影库 篇一:艾美影库MS-300 到底怎么样?...
- 会考计算机考试查询成绩查询,会考成绩(学业水平考试成绩查询系统)