转自:https://www.cnblogs.com/hellcat/p/8146748.html#_label0_3

目录

  • 程序介绍

    • 包导入
    • TFRecord录入格式转换
    • TFRecord文件写入测试
    • TFRecord文件读取测试
    • TFRecord文件批量生成
    • TFRecord文件读取测试
  • batch和batch_join的说明
    • 文件准备
    • 单个Reader,单个样本
    • 单个Reader,多个样本
    • 多Reader,多个样本

回到顶部

程序介绍

包导入

1

2

3

4

5

6

7

8

9

10

11

12

# Author : Hellcat

# Time   : 17-12-29

import os

import numpy as np

np.set_printoptions(threshold=np.inf)

import tensorflow as tf

config = tf.ConfigProto()

config.gpu_options.allow_growth = True

sess = tf.Session(config=config)

from tensorflow.examples.tutorials.mnist import input_data

TFRecord录入格式转换

TFRecord的录入格式是确定的,整数或者二进制,在train函数中能查看所有可以接受类型

1

2

3

4

5

6

7

8

def _int64_feature(value):

    """生成整数数据属性"""

    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):

    """生成字符型数据属性"""

    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

TFRecord文件写入测试

将mnist数据以每张图片为单位写入同一个TFR文件,

实际上就是每次把一个图片相关信息都写入,注意文件类型,二级制数据需要以string的格式保存

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

def TFRecord_write():

    """将mnist数据集写入TFR文件"""

    mnist = input_data.read_data_sets('./Data_Set/Mnist_data',

                                      dtype=tf.uint8,one_hot=True)

    images = mnist.train.images

    labels = mnist.train.labels

    pixels = images.shape[1]                     # 784

    num_examples = mnist.train.num_examples      # 55000

    # TFRecords文件地址

    filename = './TFRecord_Output/mnist_train.tfrecords'

    if not os.path.exists('./TFRecord_Output/'):

        os.makedirs('./TFRecord_Output/')

    # 创建一个writer书写文件

    writer = tf.python_io.TFRecordWriter(filename)

    for index in range(num_examples):

        # 提取单张图像矩阵并转换为字符串

        image_raw = images[index].tostring()

        # 将单张图片相关数据写入TFR文件

        example = tf.train.Example(features=tf.train.Features(feature={

            'pixels':  _int64_feature(pixels),

            'label':   _int64_feature(np.argmax(labels[index])),

            'img_raw': _bytes_feature(image_raw)

        }))

        writer.write(example.SerializeToString())  # 序列化为字符串

    writer.close()

调用,

1

2

if __name__=='__main__':

    TFRecord_write()

输出如下,

TFRecord文件读取测试

实际的读取基本单位和存入的基本单位是一一对应的,当然也可以复数读取,但是由于tf后续有batch拼接的函数,所以意义不大

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

def TFRecord_read():

    """从TFR文件读取mnist数据集合"""

    # 创建一个reader读取文件

    reader = tf.TFRecordReader()

    # 创建读取文件队列维护文件列表

    filename_queue = tf.train.string_input_producer(['./TFRecord_Output/mnist_train.tfrecords'])

    # 读取数据

    # 每次读取一个

    # _, serialized_example = reader.read(filename_queue)

    # 每次读取多个

    _, serialized_example = reader.read_up_to(filename_queue,10)

    # 解析样例

    # 解析函数选择必须和上面读取函数选择相一致

    # 解析单个样例

    # features = tf.parse_single_example(

    # 同时解析所有样例

    features = tf.parse_example(

        serialized_example,

        features={

            'img_raw': tf.FixedLenFeature([],tf.string),

            'pixels':    tf.FixedLenFeature([],tf.int64),

            'label':    tf.FixedLenFeature([],tf.int64),

        })

    # 解析二进制数据格式,将之按照uint8格式解析

    images = tf.decode_raw(features['img_raw'],tf.uint8)

    labels = tf.cast(features['label'],tf.int32)

    pixels = tf.cast(features['pixels'],tf.int32)

    batch_size = 2

    capacity = 1000 + 3 * batch_size

    images.set_shape([10,784])

    labels.set_shape(10)

    pixels.set_shape(10)

    image_batch, label_batch, pixel_batch = tf.train.batch(

        [images, labels, pixels], batch_size=batch_size, capacity=capacity)

    # 线程控制器

    coord = tf.train.Coordinator()

    threads = tf.train.start_queue_runners(sess=sess,coord=coord) # 这里指代的是读取数据的线程,如果不加的话队列一直挂起

    for i in range(10):

        # print(images, labels, pixels)

        # print(sess.run(images))

        image, label, pixel = sess.run([image_batch,label_batch,pixel_batch])

        # image, label, pixel = sess.run([images,labels,pixels])

        print(image.shape,label,pixel)

输出,

拼接batch尺寸为2,每次读取10个数据

可以看到,这里batch尺寸指定的实际上是读取次数

(2, 10, 784)

[[7 3 4 6 1 8 1 0 9 8]
 [0 3 1 2 7 0 2 9 6 0]]

[[784 784 784 784 784 784 784 784 784 784]
 [784 784 784 784 784 784 784 784 784 784]]
……

注意读取数目和解析数目选择的函数是要对应的,

1

2

3

4

5

6

7

8

9

10

11

12

# 读取数据

# 每次读取一个

# _, serialized_example = reader.read(filename_queue)

# 每次读取多个,这里指定10个

_, serialized_example = reader.read_up_to(filename_queue,10)

# 解析样例

# 解析函数选择必须和上面读取函数选择相一致

# 解析单个样例

# features = tf.parse_single_example()

# 同时解析所有样例

features = tf.parse_example()

值得注意的是这句,

1

threads = tf.train.start_queue_runners(sess=sess,coord=coord)

虽然后续未必会调用(coord实际上还是会调用用于协调停止),但实际上控制着队列的数据读取部分的启动,注释掉后会导致队列有出无进进而挂起。

TFRecord文件批量生成

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

def TFR_gen():

    """TFR样例数据生成"""

    # 定义写多少个文件(数据量大时可以写入多个文件加速)

    num_shards = 2

    # 定义每个文件中放入多少数据

    instances_per_shard = 2

    for i in range(num_shards):

        file_name = './TFRecord_Output/data.tfrecords-{}-of-{}'.format(i,num_shards)

        writer = tf.python_io.TFRecordWriter(file_name)

        for j in range(instances_per_shard):

            example = tf.train.Example(features=tf.train.Features(feature={

                'i':_int64_feature(i),

                'j':_int64_feature(j),

                'list':_bytes_feature(bytes([1,2,3]))

            }))

            writer.write(example.SerializeToString())  # 序列化为字符串

        writer.close()

输出如下,

TFRecord文件读取测试

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

def TFR_load():

    """批量载入TFR数据"""

    # 匹配文件名

    files = tf.train.match_filenames_once('./TFRecord_Output/data.tfrecords-*')

    import glob

    # files = glob.glob('./TFRecord_Output/data.tfrecords-*')

    # 载入文件名

    filename_queue = tf.train.string_input_producer(files,shuffle=True)

    reader = tf.TFRecordReader()

    _,serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(

        serialized_example,

        features={

            'i':tf.FixedLenFeature([],tf.int64),

            'j':tf.FixedLenFeature([],tf.int64),

            'list':tf.FixedLenFeature([],tf.string)

        })

    '''

    # tf.train.match_filenames_once操作中产生了变量

    # 值得注意的是局部变量,需要用下面的初始化函数初始化

    sess.run(tf.local_variables_initializer())

    print(sess.run(files))

    coord = tf.train.Coordinator()

    threads = tf.train.start_queue_runners(sess=sess,coord=coord)

    for i in range(6):

        print(sess.run([features['i'],features['j']]))

    coord.request_stop()

    coord.join(threads)

    '''

    example, label, array = features['i'], features['j'], features['list']

    # 每个batch的中样例的个数

    batch_size = 3

    # 队列中样例的个数

    capacity = 1000 + 3 * batch_size

    suffer = False

    # batch操作实际代指的就是数据读取和预处理操作

    if suffer is not True:

        example_batch, label_batch, array_batch = tf.train.batch(

            [example, label, array], batch_size=batch_size, capacity=capacity)

    else:

        # 不同线程处理各自的文件

        # 随机包含各个线程选择文件名的随机和文件内部数据读取的随机

        example_batch, label_batch, array_batch = tf.train.shuffle_batch(

            [example, label, array], batch_size=batch_size, capacity=capacity,

            min_after_dequeue=30)

    sess.run(tf.local_variables_initializer())

    coord = tf.train.Coordinator()

    threads = tf.train.start_queue_runners(sess=sess, coord=coord)  # 这里指代的是读取数据的线程,如果不加的话队列一直挂起

    for i in range(2):

        cur_example_batch, cur_label_batch, cur_array_batch = sess.run([example_batch, label_batch, array_batch])

        print(cur_example_batch, cur_label_batch, cur_array_batch)

    coord.request_stop()

    coord.join(threads)

注意下面介绍,

1

2

3

# tf.train.match_filenames_once操作中产生了变量

# 值得注意的是局部变量,需要用下面的初始化函数初始化

sess.run(tf.local_variables_initializer())

batch生成的两个函数如下,

1

2

3

4

5

6

7

8

9

10

11

suffer = False

# batch操作实际代指的就是数据读取和预处理操作

if suffer is not True:

   example_batch, label_batch, array_batch = tf.train.batch(

       [example, label, array], batch_size=batch_size, capacity=capacity)

else:

    # 不同线程处理各自的文件

    # 随机包含各个线程选择文件名的随机和文件内部数据读取的随机

    example_batch, label_batch, array_batch = tf.train.shuffle_batch(

        [example, label, array], batch_size=batch_size, capacity=capacity,

        min_after_dequeue=30)

  • 单一文件多线程,那么选用tf.train.batch(需要打乱样本,有对应的tf.train.shuffle_batch)
  • 多线程多文件的情况,一般选用tf.train.batch_join来获取样本(打乱样本同样也有对应的tf.train.shuffle_batch_join使用)

回到顶部

batch和batch_join的说明

文件准备

1

2

3

4

5

6

7

$ echo -e "Alpha1,A1\nAlpha2,A2\nAlpha3,A3" > A.csv 

$ echo -e "Bee1,B1\nBee2,B2\nBee3,B3" > B.csv 

$ echo -e "Sea1,C1\nSea2,C2\nSea3,C3" > C.csv 

$ cat A.csv 

Alpha1,A1 

Alpha2,A2 

Alpha3,A3 

单个Reader,单个样本

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

import tensorflow as tf 

# 生成一个先入先出队列和一个QueueRunner 

filenames = ['A.csv', 'B.csv', 'C.csv'

filename_queue = tf.train.string_input_producer(filenames, shuffle=False

# 定义Reader 

reader = tf.TextLineReader() 

key, value = reader.read(filename_queue) 

# 定义Decoder 

example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']]) 

# 运行Graph 

with tf.Session() as sess: 

    coord = tf.train.Coordinator()  #创建一个协调器,管理线程 

    threads = tf.train.start_queue_runners(coord=coord)  #启动QueueRunner, 此时文件名队列已经进队。 

    for i in range(10): 

        print example.eval()   #取样本的时候,一个Reader先从文件名队列中取出文件名,读出数据,Decoder解析后进入样本队列。 

    coord.request_stop() 

    coord.join(threads) 

# outpt 

# Alpha1 

# Alpha2 

# Alpha3 

# Bee1 

# Bee2 

# Bee3 

# Sea1 

# Sea2 

# Sea3 

# Alpha1 

单个Reader,多个样本

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

import tensorflow as tf 

filenames = ['A.csv', 'B.csv', 'C.csv']

## filenames = tf.train.match_filenames_once('.\data\*.csv')

filename_queue = tf.train.string_input_producer(filenames, shuffle=False

reader = tf.TextLineReader() 

key, value = reader.read(filename_queue) 

example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']]) 

# 使用tf.train.batch()会多加了一个样本队列和一个QueueRunner。Decoder解后数据会进入这个队列,再批量出队。 

# 虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。 

example_batch, label_batch = tf.train.batch( 

      [example, label], batch_size=5

with tf.Session() as sess: 

    coord = tf.train.Coordinator() 

    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(10): 

        print example_batch.eval() 

    coord.request_stop() 

    coord.join(threads) 

# output 

# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2'] 

# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1'] 

# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3'] 

# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2'] 

# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1'] 

# ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3'] 

# ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2'] 

# ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1'] 

# ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3'] 

# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']

多Reader,多个样本

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

import tensorflow as tf 

filenames = ['A.csv', 'B.csv', 'C.csv'

filename_queue = tf.train.string_input_producer(filenames, shuffle=False

reader = tf.TextLineReader() 

key, value = reader.read(filename_queue) 

record_defaults = [['null'], ['null']] 

example_list = [tf.decode_csv(value, record_defaults=record_defaults) 

                  for _ in range(2)]  # Reader设置为2 

# 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。 

example_batch, label_batch = tf.train.batch_join( 

      example_list, batch_size=5

with tf.Session() as sess: 

    coord = tf.train.Coordinator() 

    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(10): 

        print example_batch.eval() 

    coord.request_stop() 

    coord.join(threads) 

# output 

# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2'] 

# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1'] 

# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3'] 

# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2'] 

# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1'] 

# ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3'] 

# ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2'] 

# ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1'] 

# ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3'] 

# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2'] 

tf.train.batchtf.train.shuffle_batch'数是单个Reader读取,但是可以多线程。tf.train.batch_join'tf.train.shuffle_batch_join可设置多Reader读取,每个Reader使用一个线程。至于两种方法的效率,单Reader时,2个线程就达到了速度的极限。多Reader时,2个Reader就达到了极限。所以并不是线程越多越快,甚至更多的线程反而会使效率下降。

在这个例子中, 虽然只使用了一个文件名队列, 但是TensorFlow依然能保证多个文件阅读器从同一次迭代(epoch)的不同文件中读取数据,知道这次迭代的所有文件都被开始读取为止。(通常来说一个线程来对文件名队列进行填充的效率是足够的)

另一种替代方案是: 使用tf.train.shuffle_batch 函数,设置num_threads的值大于1。 这种方案可以保证同一时刻只在一个文件中进行读取操作(但是读取速度依然优于单线程),而不是之前的同时读取多个文件。这种方案的优点是:

  • 避免了两个不同的线程从同一个文件中读取同一个样本。
  • 避免了过多的磁盘搜索操作。

简单来说,

单一文件多线程,那么选用tf.train.batch(需要打乱样本,有对应的tf.train.shuffle_batch)

多线程多文件的情况,一般选用tf.train.batch_join来获取样本(打乱样本同样也有tf.train.shuffle_batch_join)

[转]『TensorFlow』读书笔记_TFRecord学习相关推荐

  1. 『TensorFlow』SSD源码学习_其二:基于VGG的SSD网络前向架构

    Fork版本项目地址:SSD 参考自集智专栏 一.SSD基础 在分类器基础之上想要识别物体,实质就是 用分类器扫描整张图像,定位特征位置 .这里的关键就是用什么算法扫描,比如可以将图片分成若干网格,用 ...

  2. 『算法』读书笔记 1.4算法分析 Part1

    Chapter 1 本章结构 1.1Java语法 1.2数据抽象 1.3集合类抽象数据类型:背包 (Bags) .队列 (Queues) .栈 (Stacks) 1.4算法分析 1.5连通性问题-Ca ...

  3. python 动漫卡通人物图片大全,『TensorFlow』DCGAN生成动漫人物头像_下

    一.计算图效果以及实际代码实现 计算图效果 实际模型实现 相关介绍移步我的github项目. 二.生成器与判别器设计 生成器 相关参量, 噪声向量z维度:100 标签向量y维度:10(如果有的话) 生 ...

  4. 『重构--改善既有代码的设计』读书笔记----序

    作为C++的程序员,我从大学就开始不间断的看书,看到如今上班,也始终坚持每天多多少少阅读技术文章,书看的很多,但很难有一本书,能让我去反复的翻阅.但唯独『重构--改善既有代码的设计』这本书让我重复看了 ...

  5. 『TensorFlow』模型保存和载入方法汇总

    一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 参数名称 功能说明 默认值 var_list Saver中存储变 ...

  6. 『TensorFlow』模型载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  7. 『TensorFlow』命令行参数解析

    argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...

  8. 『TensorFlow』函数查询列表_张量属性调整

    博客园 首页 新随笔 新文章 联系 订阅 管理 『TensorFlow』函数查询列表_张量属性调整 数据类型转换Casting 操作 描述 tf.string_to_number (string_te ...

  9. 『TensorFlow』专题汇总

    TensorFlow函数查询 『TensorFlow』0.x_&_1.x版本框架改动汇总 『TensorFlow』函数查询列表_数值计算 『TensorFlow』函数查询列表_张量属性调整 『 ...

最新文章

  1. ES6 -Set 和 Map 数据结构
  2. python 生成html表的报告_pytest文档7-pytest-html生成html报告
  3. echat 图表动态数据生成,渲染,自定义Y轴坐标值
  4. Jerry在2020 SAP全球技术大会的分享:SAP Spartacus技术介绍的文字版
  5. 一个实际的sonar代码检查的配置文件
  6. 【渝粤教育】国家开放大学2019年春季 289法理学 参考试题
  7. Xshell 无法连接虚拟机中的ubuntu的问题
  8. c++ sleep函数_《PHP扩展开发》-hook-(hook原来的sleep)
  9. 集群系统服务器,Web集群服务器及管理系统
  10. (42)Verilog HDL 打两拍设计
  11. JavaEE 13个核心技术规范
  12. 挣多少钱让你觉得生存无忧,有底气做感兴趣的事?
  13. Ubuntu下备份系统的方法
  14. logistic回归分析优点_二元Logistic回归
  15. 电路串联和并联图解_串联和并联的电路图怎么画
  16. 20211003:数字滤波器前置知识,sinc函数与Sa函数
  17. HTML5吃豆豆游戏开发实战(三)2d碰撞检测、重构
  18. 世界疫情实时动态 + pyecharts可视化
  19. Knativa 基于流量的灰度发布和自动弹性实践
  20. vue路由嵌套无法渲染 页面空白

热门文章

  1. python 给文本加下划线_untiy3dUGUI实现text文本下划线
  2. mysql5.5 vsftpd_vsftpd-2.0.5+mysql-5.5+pam_mysql构建虚拟用户访问
  3. php 图片保存到本地文件,php 远程图片保存到本地的函数类
  4. java功能模块_Java 13功能
  5. 无法使用struts2注释_带有注释且没有struts.xml文件的Struts 2 Hello World示例
  6. 熊猫DataFrame from_dict()–字典到DataFrame
  7. sql exists_SQL Exists运算符–终极指南
  8. sql面试题问答题_SQL面试问答
  9. 转:C++中STL用法总结
  10. 找工作?最容易遇到的Java面试题