Tensorflow知识整理(二)——数据持久化

  • 数据读取的三种方法
  • 数据预处理
    • TFRecord格式介绍
    • 多线程输入数据框架
      • 队列
      • 多线程辅助函数/类
    • 输入文件队列
    • 组合训练数据(Batching)
    • 数据集(Dataset)

数据读取的三种方法

  (1)供给数据(feeding)(2)从文件中读取数据(3)多管线输入

数据预处理

  TensorFlow支持以TFRecord格式存储数据。

TFRecord格式介绍

  TFRecord文件中的数据是通过tf.train.Example Protocol Buffer格式存储的。
【写数据】以tf.train.Example Protocol Buffer格式写入数据
(1)构造写入文件类型的属性
(2)创建一个writer来写TFRecord文件
(3)将TFTFRecord转换成一个example
(4)将一个个example写入TFRecord文件

# 生成整数型的属性
def _int64_feature(value):return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))# 生成字符串型的属性
def _bytes_feature(value):return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))# 输出TFRecord文件地址
filename = ""
# 创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for i in range(train_size):# TFRecord文件中的数据在此存储example = tf.train.Example(features = tf.train.Features(feature = {"key1":_bytes_feature(value1),"key2":_int64_feature(value2)}))# 将一个example写入TFRecord文件writer.write(record = example.SerializeToString())
writer.close()

【读数据】
(1)创建一个reader来读取TFRecord文件中的example
(2)解析读入的一个example
(3)转换需要的数据类型

# 创建一个reader来读取TFRecord文件中的样例
reader = tf.TFRecordReader()
filename = ""
filename_queue = tf.train.string_input_producer([filename])_, serialized_example = reader.read(filename_queue)
# 解析读入的一个样例
# 多个样例使用tf.parse_example
features = tf.parse_single_example(serialized_example,features = {# 解析结果为一个tensor# 解析数据的格式要跟写入数据的格式一致"key1":tf.FixedLenFeature([], tf.string),"key2":tf.FixedLenFeature([], tf.int64)})
# 转换数据类型
# 字符串解析成图像对应的像素数组
images = tf.decode_raw(features['key1'], tf.uint8)
# tf.int64 -> tf.int32
label = tf.cast(features['key2'], tf.int32)

[注意]
(1)tf.FixedLenFeature()得到的解析结果为Tensor,而tf.VarLenFeature()得到解析结果为SparseTensor
(2)tf.parse_single_example()解析时每次读入一个样例;而tf.parse_example()读入多个样例
(3)解析数据的格式要跟写入数据的格式一致

多线程输入数据框架

队列

  队列与多线程:tf中,队列不仅是一种数据结构,它提供了多线程机制,队列也是TF多线程输入数据处理框架的基础,队列是异步计算张量取值的一个重要机制。
  TF中,队列是计算图上有状态的节点
  修改队列状态的操作:Enqueue、EnqueueMany、Dequeue
【两种队列】tf.FIFOQueue()和tf.RandomShuffleQueue()

  1. tf.FIFOQueue():先入先出队列
init = q.enqueue_many(([0,10],)) # enqueue_many表示一次入队多个元素
#出队
x = q.dequeue()
y = x + 1
# 入队
q_inc = q.enqueue(y)with tf.Session() as sess:init.run()for _ in range(5):v, _ = sess.run([x, q_inc])print(v)
  1. tf.RandomShuffleQueue():会将队列中的元素打乱,每次出队列操作得到的是从当前队列所有元素中随机选择的一个【在训练NN时希望每次使用的训练数据尽量随机】

多线程辅助函数/类

  tf.Coordinator和tf.QueueRunner完成多线程协同功能
【tf.Coordinator类】
  主要用于协同多个线程一起停止,提供了should_stop,request_stop,join三个函数

函数 功能
should_stop 如果值为True,当前线程需要退出
request_stop 线程可以调用该函数,使得should_stop函数的返回值为True,来通知其他线程退出
join 等待线程退出,才能继续往下执行
# 线程中运行的程序
def MyLoop(coord, work_id):while not coord.should_stop():# 随机退出if np.random.rand() < 0.1:print("Stoping from id: %d\n" % work_id)coord.request_stop()else:print("Working on id\n" % work_id)time.sleep(1)
# 创建类协同多个线程
coord = tf.train.Coordinator()
# 产生5个线程
threads = [threading.Thread(target = MyLoop, args = (coord, i, )) for i in range(5)]
# 启动线程
for t in threads:t.start()
#等待线程退出
coord.join(threads)

【tf.QueueRunner类】
  主要用于启动多个线程来操作同一个队列

queue = tf.FIFOQueue(100, "float")enqueue_op = queue.enqueue([tf.random_normal([1])])# 表示需要启动5个线程,每个线程中运行的是enqueue_op操作, 没有指定集合则加入默认集合
qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)# 将定义的QueueRunner加入TensorFlow上指定的集合。
# 将刚定义的qr加入默认的tf.GraphKeys.QUEUE_RUNNERS集合
tf.train.add_queue_runner(qr)#定义出队操作
out_tensor = queue.dequeue()
with tf.Session() as sess:coord = tf.train.Coordinator()# 启动线程threads = tf.train.start_queue_runners(sess = sess, coord = coord)# 获取队列中的值for _ in range(3): print(sess.run(out_tensor)[0])# 主线程通知各线程退出coord.request_stop()coord.join(threads)

输入文件队列

  一个TFRecord可以存储多个训练样例, 数据可以分成多个TFRecord文件来提高处理效率,TF提供tf.train.match_filenames_once函数来获取符合一个正则表达式的所有文件。得到的文件列表可以通过tf.train.string_input_producer([filename])管理。
tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, seed=None, capacity=32, shared_name=None, name=None, cancel_op=None)
  使用输入的文件创建一个输入队列,支持多线程操作
  每次调用文件读取函数时,这个函数会从输入队列中出队一个文件并从该文件中读取数据,输入队列会将队列中文件平均分给不同的线程。当一个输入队列中的所有文件都被处理完后,它会将初始化时提供的文件列表中的文件全部重新加入队列中。shuffle参数控制是否随机打乱文件列表中文件出队的顺序。

参数名 参数说明
string_tensor 一维字符串张量(文件名数组)
num_epochs 循环次数,如果指定,则在产生OutOfRange错误前,循环num_epochs次,如不指定,则无限次循环
shuffle 是否在每个epoch内随机打乱顺序seed随机种子
capacity 队列的容量
shared_name 如果设置,则此队列将在多个会话的给定名称下共享。对具有此队列的设备打开的所有会话都可以通过shared_name访问它。在分布式设置中使用它意味着只有能够访问此操作的其中一个会话才能看到每个名称。
name 操作的名称
cancel_op 取消队列的操作
输出:返回一个队列,该队列的一个QueueRunner加入到当前Graph的QUEUE_RUNNER集合中

组合训练数据(Batching)

【目的】将多个样例组织成batch
【函数】tf.train.batch 和tf.train.shuffle_batch
   tf.train.shuffle_batch:不同的线程会读取同一个文件,缺点:文件的样例比较相似
  tf.train.shuffle_batch_join:不同的线程会读取不同的文件,缺点:多线程读取多个文件可能导致过多的硬盘寻址,从而使得读取效率降低。

【功能】生成一个队列,队列的入队操作是生成单个样例的方法,而每次出队得到的是一个batch的样例,唯一的区别在于是否将数据顺序打乱。

import tensorflow as tf############# 完整的输入处理框架 ########################
# 正则匹配文件
files = tf.train.match_filenames_once("data/data.tfrecords-*")
filename_queue = tf.train.string_input_producer(files, shuffle = False)reader = tf.TFRecordReader()
_, Serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(Serialized_example,features = {'image':tf.FixedLenFeature([], tf.string),'label':tf.FixedLenFeature([], tf.int64),'height':tf.FixedLenFeature([], tf.int64),'width': tf.FixedLenFeature([], tf.int64),'channels': tf.FixedLenFeature([], tf.int64)})
image, label = features['image'], features['label']
height, width = features['height'], features['width']
channels = features['channels']
decoded_image = tf.decode_raw(image, tf.uint8)
decoded_image.set_shape = ([height, width, channels])image_size = 229
distorted_image = preprocess_for_train(decoded_image, image_size, image_size, None)# 最小的队列长度
min_after_dequeue = 10000
# 一个batch 中样例的个数
batch_size = 100
# 设置队列的容量
capacity = min_after_dequeue + 3 * batch_size
# 组合样例,不打乱顺序
image_batch, label_batch = tf.train.batch([distorted_image, label], batch_size = batch_size, capacity = capacity)
# 组合样例,打乱顺序
# min_after_dequeue参数限制了出队时队列中元素的最小个数。因为当队列中元素太少时,随机打乱样例顺序的作用就不大了。
image_batch, label_batch = tf.train.shuffle_batch([distorted_image, label], batch_size = batch_size, capacity = capacity,min_after_dequeue = min_after_dequeue)learning_rate = 0.01
logit = inference(image_batch)
loss = calc_loss(logit, label_batch)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)with tf.Session() as sess:sess.run(tf.global_variables_initializer(), tf.local_variables_initializer())coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess =sess, coord = coord)TRAINING_ROUNDS = 5000for i in range(TRAINING_ROUNDS):sess.run(train_step)coord.request_stop()coord.join(threads)

数据集(Dataset)

  更高层的数据处理框架,每个数据集代表一个数据来源,数据集也是计算图上的一个节点。
  如果数据集很大,无法装入内存中,所以dataset提供迭代器功能,每次get_next读入下一个输入数据。
【dataset创建】

  1. 从张量中创建
    dataset = tf.data.Dataset.from_tensor_slices(input_data)
  2. 从文本文件中创建
    dataset = tf.data.TextLineDataset(input_files)
  3. 从TFRecord格式中创建数据集,需要传入一个解析器(自定义)
    input_files = tf.data.TFRecordDataset(input_files)
    dataset = dataset.map(parser) #用map使之作用在每个example上

【dataset迭代器】

迭代器 功能
dataset.make_one_shot_iterator() 数据集的参数必须已经确定,所以不需要特别的初始化过程
dataset.make_initializable_iterator() 如果需要用placeholder来初始化数据集
data.make_reinitializable_iterator() 多次initialize,用于遍历不同的数据来源
data.make_feedable_iterator() 可以用feed_dict方式动态指定运行哪个iterator

【shuffle】
dataset = dataset.shuffle(buffer_size) # 随机打乱顺序buffer_size 等效于tf.train.shuffle_batch的min_after_dequeue参数
【batch】
dataset = dataset.batch(batch_size) # 将数据组合成batch
【重复】
dataset = dataset.repeat(N) # 将数据集重复N份,每一份数据称为一个epoch

import tensorflow as tf# 创建数据集
# 从张量中创建数据集
input_data = [1,2,3,5,8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)
# 从文本文件中创建数据集
# input_files = ["",""]
# dataset = tf.data.TextLineDataset(input_files)
# 从TFRecord格式中创建数据集,需要传入一个解析器
def parser(record):features = tf.parse_single_example(record, features = {'feat1': tf.FixedLenFeature([], tf.int64)'feat2': tf.FixedLenFeature([], tf.int64)})return features['feat1'], features['feat2']
input_files = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parser)
# iterator = dataset.make_one_shot_iterator()
# feat1, feat2 = iterator.get_next()# 创建一个迭代器用于遍历数据集
iterator = dataset.make_one_shot_iterator()
# get_next()返回一个输入数据的张量
x = iterator.get_next()
y = x * x
with tf.Session() as sess:for i in range(len(input_data)):print(sess.run(y))iterator = dataset.make_initializable_iterator()
with tf.Session() as sess:sess.run(iterator.initializer, feed_dict = {})

Tensorflow知识整理(二)——数据持久化相关推荐

  1. 【机器学习知识整理二】处理分类数据、处理文本、处理日期和时间

    系列文章目录 上一篇:机器学习知识整理一:数据加载.数据整理.数值型数据处理 文章目录 系列文章目录 前言 一.处理分类数据 1. 对nominal型分类特征编码 2. 对ordinal分类特征编码 ...

  2. pytorch基础知识整理(二)数据加载

    pytorch数据加载组件位于torch.utils.data中. from torch.utils.data import DataLoader, Dataset, Sampler 1, torch ...

  3. JavaScript入门知识整理二

    文章目录 一.基础知识 二.事件 示例一.鼠标点击按钮触发add2()函数的执行 示例二.文本框 示例三.页面加载事件(onload) 三.DOM对象,控制HTML元素 示例一.将HTML分解成DOM ...

  4. js事件(Event)知识整理

    鼠标事件 鼠标移动到目标元素上的那一刻,首先触发mouseover  之后如果光标继续在元素上移动,则不断触发mousemove  如果按下鼠标上的设备(左键,右键,滚轮--),则触发mousedow ...

  5. mysql不能持久存储数据的是_数据持久化存储

    一数据持久化存储-csv文件 1.作用 将爬取的数据存放到本地的csv文件中 2.使用流程 1.导入模块2.打开csv文件3.初始化写入对象4.写入数据(参数为列表)importcsv with op ...

  6. unity 一些有用的碎片知识整理 之 二 (之 四 持续更新中...)

    -- 系列文章链接 Unity 一些有用的碎片知识整理 之 一 点击可跳转链接 Unity 一些有用的碎片知识整理 之 三 点击可跳转链接 Unity 一些有用的碎片知识整理 之 四 点击可跳转链接 ...

  7. iOS开发面试知识整理 – OC基础 (二)

    iOS | 面试知识整理 – OC基础 (二) 1.C和 OC 如何混编 xcode可以识别一下几种扩展名文件: .m文件,可以编写 OC语言 和 C 语言代码 .cpp: 只能识别C++ 或者C语言 ...

  8. abap alv新增行数据_ABAP_ALV_最好教程 最全知识整理.doc

    ABAP_ALV_最好教程 最全知识整理 ABAP ALV 知识整理 亿力科技 ABAP开发组 目 录 一.ALV简介3 1.简介3 2.ALV_GRID介绍3 3.其它描述3 二.开发ALV的基本流 ...

  9. 数据库知识整理 - 关系数据库标准语言SQL(一)- SQL概述与数据定义

    主要内容 SQL概述 1. SQL的发展 2. SQL的优点 3. SQL的基本概念 SQL的常用数据类型 数据定义 1. 模式的定义与删除 2. 基本表的定义.删除和修改 3. 索引的建立和删除 S ...

最新文章

  1. ArduinoYun教程之通过网络为Arduino Yun编程
  2. 金山游戏研发改制新进展:计划本周发出正式公告
  3. python语言教程-Python 基础教程
  4. ListView属性设置
  5. 解决 Eclipse 项目有红感叹号的方法
  6. P1407-[国家集训队]稳定婚姻【tarjan,强连通分量】
  7. Modelarts与无感识别技术生态总结(浅出版)
  8. Angular-Observable和RxJS
  9. 2021-02-02 天地图图层类型
  10. Linux-eval命令
  11. Spring整合MyBatis之SqlSession对象的产生
  12. Oracle 系列 统计信息详解(Statistic)
  13. 威廉 哈特 史密斯《当你抚触》
  14. L1-049. 天梯赛座位分配
  15. 在头条号和西瓜视频发布视频,播放量20万,却是零收益?
  16. 如何调用外部webservice 接口来发送短信
  17. Q-M法(列表法)化简 C++ 实现
  18. Java程序 switch语句
  19. 计算机D盘无法读取,电脑d盘打不开怎么办_解决电脑d盘无法打开的方法
  20. 从985非科班到网易伏羲CV算法岗

热门文章

  1. vue项目搭建(二)
  2. 大智慧显示切换服务器,大智慧怎么改界面 大智慧改界面教程
  3. 一些简单的js技术 实现点击 的js隐藏显示
  4. 东方博宜OJ 1265 - 【入门】爱因斯坦的数学题
  5. 流体动力学控制方程(详细推导)
  6. 双十二有哪些数码好物值得入手、双十二必买数码好物清单
  7. 聊天记录:李维、左轻侯、周爱民谈Diamondback
  8. 2022大学生寒假社会实践活动稿件怎样向新闻媒体投稿?
  9. STM32的中断向量表是干什么的?到底有什么用?它放在哪里?
  10. RRDTool (比较全)