文章目录

  • 数据并行读取
    • 创建文件名列表
    • 创建文件名队列
    • 创建Reader & Decoder
      • CSV file
      • TFRecords file
      • Any format file
    • 创建样例队列
  • 创建批样例数据
  • 填充数据节点
  • CIFAR-10数据集示例

用户处理数据集的典型流程是:

  1. 输入数据集从文件系统读取到内存中;
  2. 将其转换为模型所需要的格式;
  3. 以某种形式传入到数据流图中,开始模型训练。

一般采用的数据读取方式有两种:

  • Large-scale Dataset:一般由由大量数据文件组成。因为数据规模太大,所以无法一次性全部加载到内存。但是,如果每进行一步模型训练就加载一次所需的batch data,这将阻塞模型的训练过程。为了减小数据读取对模型训练效率的影响,常用的方法是通过多线程并行读取数据。
  • Tiny Dataset:由较少数据文件组成,能够在数据模型开始前一次性将所有data load到memory中。

数据并行读取

large scale dataset一般无法一次性加载到内存中进行处理,如ImageNet数据集。当处理如此规模的数据集时,TF提供了以输入流水线方式从多个文件中并行读取数据的方法,这使得模型训练时能够有充足的数据能够feed to graph。
它的主要步骤如下:

  1. 创建文件名列表;
  2. 创建文件名队列;
  3. 创建Reader & Decoder;
  4. 创建样例队列;

下图展示了完整的并行数据读取流水线:

上图样例队列和批样例队列中的元素名,指示哪个文件中的第几个数据记录(即第几行)。

理解并行流水线数据读取过程的关键是掌握文件名队列和样例队列。文件名队列为读取数据文件提供了一个缓冲区,样例队列为数据feed to graph提供了缓冲区。

创建文件名列表

文件名列表是指输入数据集中所有文件的名称所构成的列表。列表中的元素可能是在本地文件系统上的文件位置,也可能是共享文件系统或分布式文件系统上的统一资源标识符(URI)。

两种创建文件名列表的方法:

  • python list:如果文件名的个数不多,或者文件命名遵循rules,那么用户可以直接使用list存储文件名。
  • tf.train.match_filenames_once():该方法在graph创建了一个获取文件名列表的操作,它的输入是文件名列表的匹配模式,输出是一个存储了符合该匹配模式的文件名列表variable。在初始化全局变量是,该文件名列变量也会被初始化。

创建文件名队列

一般使用tf.train.string_input_producer()创建文件名队列,它的输入是之前创建的文件名列表,输出是一个先入先出的queue。

epoch:完整遍历一次输入数据集即为一个epoch。
训练模型需要反复遍历整个输入数据集,以不断更新模型参数。

用户可以通过tf.train.string_input_producer()的输入参数num_epoches设置模型的最大训练周期数。但每次新的epoch,我们希望模型数据顺序是变化的,以防止模型过拟合。因此,可以设置tf.train.string_input_producer()的输入参数 shuffle = True,此时程序就可以打乱每个epoch内的文件名顺序。

tf.train.string_input_producer()所有输入参数:

  • string_tensor:存储文件名列表的字符串张量
  • num_epochs:最大训练周期
  • shuffle:是否打乱文件顺序
  • seed:随机化种子,用于文件打乱
  • capacity:filename queue的容量(长度
  • shareed_name:多个sessions见共享的文件名队列
  • name:创建文件名队列操作的名称
  • cancel_op:取消队列操作

创建Reader & Decoder

Reader的功能就是读取数据,Decoder的功能是将数据转换为张量格式。两者都与数据文件格式有关。下表给出了TF推荐的三种数据文件格式及其对应Reader&Decoder。

文件格式 Reader类型 Decoder类型
CSV file tf.TextLineReader() tf.decode_csv
TFRecords file tf.TFRecordReader() tf.parse_single_example
自由格式文件 tf.FixedLengthRecordReader() tf.decode_raw

一般流程:首先创建数据文件对应的Reader,然后从文件名队列中取出文件名,并传入Reader.read()方法,最后使用对应的Decoder将数据记录中的每一列数据都转换为张量格式。

CSV file

字符分隔值(Comma-Seperated Values, CSV)文件是以纯文本形式存储表格数据。CSV的一般标准是:

CSV由多条数据记录组成,数据记录之间以某种换行符进行分隔。每条记录由多个字段组成,字段间通常以制表符或逗号分隔。所有记录拥有相同的字段序列格式。

以读取多个记录收入支出表(file1.csv & file2.csv)为例,展示TF如何读取CSV file。其中部分表格内容如下所示:

year month income outgo
2021 1 40000 20000
2021 2 42000 19000

每条数据记录包含四个字段:year, month, income, outgo。示例代码如下:

# create filename queue
filename_list = ['file1.csv', 'file2.csv']
filename_queue = tf.train.string_input_producer(filename_list)
# create Reader for csv file
reader = tf.TextLineReader()
# read one row from csv file
_, value = reader.read(filename_queue)
# setting default value
record_defaults = [[2021], [0], [0.0], [0.0]]
# transfer data to Tensor with tf.decode_csv
year, month, income, outgo = tf.decode_csv(value, record_defaults)features = tf.stack([year, month, income, outgo])

注,Reader.read()方法只能读取一行数据记录,Reader.read_up_to()方法能够一次读取多条数据,通过设置其num_records参数,可以显式地制定一次读取的数据记录数量。

tf.decode_csv()方法中的record_defaults参数,是为了给数据记录中的某些不合法或不存在的字段填充默认值,以确保程序正常执行。注意,以上的Reader.read()方法还有tf.decode_csv()方法返回的都是graph ops,而不是real data,用户需要通过session.run()才能获得data。

TFRecords file

TFRecords文件存储的是有结构的序列化字符块,他是TF推荐的standard file format。但是一般数据集的annotations都没有采用该类数据格式,因此我们在这里不做更多介绍。

Any format file

自由格式文件是用户自定义的二进制文件。它的存储对象是字符串,每条数据记录都是一个固定长度的字节块。因此如果要想正确识别和转换二进制文件中的数据记录,必须使用tf.FixedLengthRecordReader()读取二进制文件中固定长度的字节块,然后使用tf.decode_raw()方法将读取的字符串转换为张量。tf.FixedLengthRecordReader()tf.TextLineReader()均继承自ReaderBase类,都支持一次读取多条记录的方法。

tf.decode_raw()方法的功能是将字符串转换为张量,其prototype如下所示:

tf.decode_raw(bytes, out_type, little_endian=None, name=None)

创建样例队列

在“CSV file”小结,我们得到了year, month, income, outgo四个特征张量。在会话执行时,为了使计算机任务顺利获得输入数据,我们需要使用tf.train.start_queue_runners()方法启动执行入队操作的所有线程,具体包括文件名入队到filename_queue的操作,样例入队到样例队列的操作。

示例代码如下:

init_op = tf.global_variables_initializer()
with tf.Session() as sess:sess.run(init_op)# 启动所有执行入队操作的后台线程tf.train.start_queue_runners(sess=sess)for i in range(2):example = sess.run(features)print(example)

上述代码中的features是之前创建的数据读取、解析的操作。同时上述代码并不适合生产环境,因为其较差的容错性:无人管理队列操作后台线程的生命周期,任何线程出现异常都会导致程序崩溃。为了解决该问题,可以使用tf.train.Coordinator()方法构建管理多线程生命周期的协调器。它会监控TF所有后台线程,但其中某个线程出现异常时,Coordinator.should_stop()将返回True,使for循环结束。然后执行finally中Coordinator.request_stop()方法,请求所有线程安全退出。

需要注意的是,使用Coordinator管理multi-threads之前,需要先执行tf.local_variables_initializer()方法对其进行初始化。所以使用tf.group()方法将tf.local_variables_initializer()tf.global_variables_initializer()聚合生成整个程序的初始化操作init_op。

示例代码如下:

import tensorflow as tf
# create filename_queue and setting epochs = 5
filename_list = ['file1.csv', 'file2.csv']
filename_queue = tf.train.string_input_producer(filename_list , num_epochs=5)...# aggregate local and global initialization ops
init_op = tf.group(tf.local_variables_initializer(),tf.global_variables_initializer())with tf.Session() as sess:init_op.run()coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess, coord=coord)print("=> Threads: ", threads)try:for i in range(10):if not coord.should_stop():example = sess.run(features)print(example)except tf.errors.OutofRangeError:print('=> Catch OutofRangeError')finally:# request to stop all the threads in backgroundcoord.request_stop()print('=> Finish reading ...')coord.join(threads)

创建批样例数据

通过上节内容,我们成功获得了数据样例,但是需要将这些样例聚合成批数据才能用于模型训练、评估和推理使用。TF提供的tf.train.shuffle_batch()方法不仅能够使用样例创建批数据,而且能够在打包过程中打乱样例顺序。增加随机性。

示例代码:

filename_queue = ...
examples = ...# batch queue settings
batch_size = 16
min_after_dequeque = 10000 # 样例队列中出队的样例个数
capacity = min_after_dequeque + 3 * batch_size # 批数据队列容量
# create batch queue
batch_queue = tf.train.shuffle_batch([examples], batch_size=batch_size, capacity=capacity, min_after_dequeque=min_after_deque)

tf.train.shuffle_batch()除了上面使用的参数外,常用的还有设置线程个数的num_threads参数,设置随机化种子的seed参数,以及设置多条样例入队的enqueue_many参数。

填充数据节点

使用批数据训练的模型基本上都是用feed数据节点的方法,他不需要读取完整的数据集,有效减少了内存开销。同时,基于并行输入流水线的数据读取方法保证了实时性,与将全部数据预加载到内存中,对训练结果没有明显差距。

CIFAR-10数据集示例

CIFAR-10数据集总共包含60000张32x32x3的图像,图片总共有10类,每一类6000张图片, 下载地址。整个数据集被分为6个批数据,每一批数据包含1W张图片,其中5W张用于模型训练,1W张用于模型测试。

在CIFAR-10数据集中,一条数据记录由类别标签和图像数据两部分组成。单张图片需要3072个字节,类别标签1个字节。因此,CIFAR-10数据集中单条记录占用3073个字节,他们以二进制数据文件格式存储。

示例代码:

import tensorflow as tf# label length
LABEL_BYTES = 1
# image size
IMAGE_SIZE = 32
# image channel
IMAGE_CHANNEL = 3
# image data length
IMAGE_BYTES = IMAGE_SIZE * IMAGE_SIZE * IMAGE_CHANNEL
# classes num
NUM_CLASSES = 10def read_cifar10(data_file, batch_size):"""input params:data_file: CIFAR-10 data filebatch_size: the size of batchreturns:images: images batch following format [batch_size, IMAGE_SIZE, IMAGE_SIZE]labels: labels batch following format [batch_size, NUM_CLASSES]"""record_bytes = LABEL_BYTES + IMAGE_BYTES# create filename listdata_files = tf.gfile.Glob(data_file)

《深入了解TensorFlow》笔记——Chapter 4.1 输入数据集相关推荐

  1. 中国大学MOOC-人工智能实践:Tensorflow笔记-课程笔记 Chapter6

    本篇博客为学习中国大学MOOC-人工智能实践:Tensorflow笔记课程时的个人笔记记录.具体课程情况可以点击链接查看.(这里推一波中国大学MOOC,很好的学习平台,质量高,种类全,想要学习的话很有 ...

  2. tensorflow笔记:多层CNN代码分析

    tensorflow笔记系列:  (一) tensorflow笔记:流程,概念和简单代码注释  (二) tensorflow笔记:多层CNN代码分析  (三) tensorflow笔记:多层LSTM代 ...

  3. tensorflow笔记:流程,概念和简单代码注释

    tensorflow是google在2015年开源的深度学习框架,可以很方便的检验算法效果.这两天看了看官方的tutorial,极客学院的文档,以及综合tensorflow的源码,把自己的心得整理了一 ...

  4. 人工智能实践:TensorFlow笔记学习(八)—— 卷积神经网络实践

    大纲 7.1  复现已有的卷积神经网络 7.2  用vgg16实现图片识别 目标 掌握复现已有网络,用vgg16实现图片识别 7.1  复现已有的卷积神经网络 VGGNet是Karen simonya ...

  5. 人工智能实践:TensorFlow笔记学习(六)—— 全连接网络实践

    输入手写数字输出识别结果 大纲 6.1 输入手写数字图片输出识别结果 6.2 制作数据集 目标 1.实现断点续训 2.输入真实图片,输出预测结果 3.制作数据集,实现特定应用 6.1  输入手写数字图 ...

  6. 人工智能实践:TensorFlow笔记学习(五)—— 全连接网络基础

    MNIST数据集输出手写数字识别准确率 大纲 5.1 MNIST数据集 5.2 模块化搭建神经网络 5.3 手写数字识别准确率输出 目标 利用MNIST数据集巩固模块化搭建神经网路的八股,实践前向传播 ...

  7. 人工智能实践:TensorFlow笔记学习(四)—— 神经网络优化

    神经网络优化  大纲 4.1 损失函数 4.2 学习率 4.3 滑动平均 4.4 正则化 4.5 神经网络搭建八股 目标 掌握神经网络优化方法 4.1 损失函数 神经元模型:用数学公式表示为:,f为激 ...

  8. 人工智能实践:TensorFlow笔记学习(三)——TensorFlow框架

    搭建神经网络 大纲 3.1 张量.计算图.会话 3.2 前向传播 3.3 反向传播 目标 搭建神经网络,总结搭建八股 3.1 张量.计算图.会话 一.基本概念 基于Tensorflow的NN:用张量表 ...

  9. 人工智能实践:TensorFlow笔记学习(二)—— Python语法串讲

    Python语法串讲  大纲 2.1 Linux指令.HelloWorld 2.2 列表.元组.字典 2.3 条件语句 2.4 循环语句 2.5 turtle模块 2.6 函数.模块.包 2.7 类. ...

最新文章

  1. Algorithm之RS:RS常用的一些库
  2. IQueryable 和 IEnumerable
  3. UVa11205 The Broken Pedometer
  4. 上传SVN丢失.a文件的问题
  5. jacoco入门_Android jacoco 代码覆盖率测试入门
  6. 【蓝桥杯单片机】NE555在CT107D上的使用
  7. C语言 gcc 动态库
  8. 画对数幅频曲线_耳机频响曲线如何看(中)--耳机和音箱对频响和失真的要求的差异...
  9. prepared statement mysql_MySQL之 Statement实现及PreparedStatement实现
  10. 重磅丨继人工智能大会后《AIOS链上人工智能白皮书》正式发布,核心应用即将开启
  11. 国内交易平台关闭了,教你如何把Zcash(zec)兑换成人民币
  12. a标签去掉下划线以及字体颜色
  13. 计算机网络技术广告,屏蔽QQ广告和迷你首页广告
  14. [Office] 公务员WPS Excel常用的一些技巧方法
  15. 《算法笔记》——基础篇习题选择结构
  16. uniapp授权登录微信支付宝小程序获取code和基础信息
  17. c语言中,exit(1)是什么意思?
  18. CVPR 2022|精准高效估计多人3D姿态,美图北航联合提出分布感知式单阶段模型...
  19. 五证 两书 三表
  20. 入坑esp-01s 1.3寸OLED时钟及天气显示(二)

热门文章

  1. 18. Ubuntu弹出移动硬盘
  2. 28推精英会专访IT博主卢松松
  3. Bitly缩短网址服务 - Blog透视镜
  4. ffmpeg-日志输出av_log()
  5. 在.NET中使用DirectShow
  6. 安卓获取摄像头帧率_一种基于android终端提高USB2.0摄像头高分辨率高帧率的方法_2015109767475_说明书_专利查询_专利网_钻瓜专利网...
  7. 连接器类型vh_连接器中的类型如PH、XH、SM等都是什么意思?
  8. 机器学习2022笔记(一)—— 机器学习相关规定
  9. pygame 飞机大战子弹的编写(三)自定义子弹位置、速度、角度
  10. BDG邦德外汇:财富的时间载体