Tensorflow 程序读取数据一共有3种方法:

  • 供给数据(feeding):在程序运行的每一步,让Python代码来供给数据
  • 从文件读取数据: 让一个输入管线从文件中读取数据
  • 预加载数据:在tensorflow图中定义常量或变量来保存所有数据(适用于数据量小的时候)

一个典型的文件读取管线会包含下面这些步骤:

  1. 文件名列表
  2. 可配置的 文件名乱序(shuffling)
  3. 可配置的 最大训练迭代数(epoch limit)
  4. 文件名队列
  5. 针对输入文件格式的阅读器
  6. 纪录解析器
  7. 可配置的预处理器
  8. 样本队列
以下以tensorflow/models/image/cifar10/cifar10_input.py 为例分步说明
1.得到文件名列表
[python] view plain copy
  1. filenames=[os.path.join(data_dir,'data_batch_%d.bin'%i) for i in range(1,6)] #得到一个文件名列表
  2. for  f in filenames:
  3. if not tf.gfile.Exists(f):
  4. raise ValueError('Failed to find file: '+ f)

此处用list表示文件名列表,然后依次检验文件是否存在,以抛出异常

2.将文件名列表交给tf.train.string_input_producer函数,得到一个先入先出的队列(Queue),文件阅读器会需要它来读取数据

其中可配置参数中有shuffle,是bool值,判断要不要用乱序操作

[python] view plain copy
  1. filename_queue=tf.train.string_input_producer(filenames)#生成一个先入先出队列,需要用文件阅读器来读取其数据

3.得到文件名队列后,针对输入文件格式,创建阅读器进行读取

例如:若从CSV文件中读取数据,需要使用TextLineReader和decode_csv来进行读取和解码

若是CIFAR-10 dataset文件,因为每条记录的长度固定,一个字节的标签+3072像素数据

所以此处采用FixedLengthRecordReader()和decode_raw来进行读取和解码

每次read的执行都会从文件中读取一行内容, decode_csv 操作会解析这一行内容并将其转为张量列表。如果输入的参数有缺失,record_default参数可以根据张量的类型来设置默认值。

在调用run或者eval去执行read之前, 你必须调用tf.train.start_queue_runners来将文件名填充到队列。否则read操作会被阻塞到文件名队列中有值为止。(程序挂起就死掉了,只能强制结束进程)

tf.transpose 将张量的维度变化 ,

1.tensorflow 里面的 tensor是先从高维向低维算起的 
比如:’x’ is [[[1 2 3] 
[4 5 6]] 
[[7 8 9] 
[10 11 12]]] 
x的维数就是[2,2,3] 
而perm的意思就是将tensor对应的不同的维数之间变换 
比如perm = [2,1,0],则原tensor = [3, 2,2],全部倒过来 
perm = [0,2,1],则原tensor = [2,3,2], 后两维置换

tf.transpose(x, perm=[0, 2, 1]) ==> 
[[[1 4] 
[2 5] 
[3 6]] 
[[7 10] 
[8 11] 
[9 12]]]

而不是
[[[1 3 2] 
[4 6 5]] 
[[7 9 8] 
[10 12 11]]]

本例中将读取数据单独写到一个函数中

[python] view plain copy
  1. <strong>def read_cifar10(filename_queue):</strong>
  2. """Reads and parses(解析) examples from CIFAR10 data files
  3. Args:
  4. filename_queue:A queue of strings with the filenames to read from
  5. Returns:
  6. An object representing a single example, with the following fields:
  7. height:行数32
  8. width:列数32
  9. depth:颜色通道数3
  10. key:a scalar string Tensor describing the filename & record number for this example
  11. label: an int32 Tensor with the label in the range 0~9
  12. uint8 image: a [height, width, depth] uint8 Tensor with the image data
  13. """
  14. class CIFAR10Record:
  15. pass
  16. result=CIFAR10Record()
  17. #CIFAR10数据库中图片的维度
  18. label_bytes=1 #2 for CIFAR-100
  19. result.height=32
  20. result.width=32
  21. result.depth=3
  22. image_bytes=result.height*result.width*result.depth
  23. #每个记录都由一个字节的标签和3072字节的图像数据组成,长度固定
  24. record_bytes=label_bytes+image_bytes
  25. #read a record, getting filenames from the filename_queue
  26. reader=tf.FixedLengthRecordReader(record_bytes=record_bytes)
  27. result.key,value=reader.read(filename_queue)#注意这里read每次只读取一行!
  28. <span style="white-space:pre">      </span>
  29. #Convert from a string to a vector of uint8 that is record_bytes long
  30. record_bytes=tf.decode_raw(value,tf.uint8)#decode_raw可以将一个字符串转换为一个uint8的张量
  31. #The first bytes represent the label, which we convert from uint8->int32
  32. result.label=tf.cast(tf.strided_slice(record_bytes,[0],[label_bytes]),tf.int32)
  33. #将剩下的图像数据部分reshape为【depth,height,width】的形式
  34. depth_major=tf.reshape(tf.strided_slice(record_bytes,[label_bytes],[label_bytes+image_bytes]),[result.depth,result.height,result.width])
  35. #from【depth,height,width】to【height,width,depth】
  36. result.uint8image=tf.transpose(depth_major,[1,2,0])
  37. return result #返回的是一个类的对象!

read_cifar10返回了一个训练样本,包括result.label和reaule.uint8image两个数据成员

4.预处理

针对输入进来的一个样本,进行的预处理可以使加噪,添加失真,翻转等

[python] view plain copy
  1. read_input=read_cifar10(filename_queue)
  2. reshaped_image=tf.cast(read_input.uint8image,tf.float32)
  3. height=IMAGE_SIZE
  4. width=IMAGE_SIZE
  5. #Image processing for training the network. Note the many random
  6. #distrotions applied to the image  预处理图片,添加噪声,失真等。
  7. #Randomly crop(裁剪) a [height,width]section of the image
  8. distorted_image=tf.random_crop(reshaped_image,[height,width,3])
  9. #随机水平翻转图片,50%的概率
  10. distorted_image=tf.image.random_flip_left_right(distorted_image)
  11. #在图片的RGB加上一个随机的delta,delta<=max_dalta
  12. distorted_image=tf.image.random_brightness(distorted_image,max_delta=63)
  13. #contrast就是对图片的RGB三个channel,分别算出整张图片的mean,然后分别对相应channel的每一个像素做
  14. #x = (x-mean) * contrast_factor + mean 其中对于random_contrast函数,contrast_factor随机取自
  15. #[lower, upper]
  16. distorted_image=tf.image.random_contrast(distorted_image,lower=0.2,upper=1.8)
  17. #减去均值像素,并除以像素方差 (图片标准化)
  18. float_image=tf.image.per_image_standardization(distorted_image)
  19. #set the shapes of tensors
  20. float_image.set_shape([height,width,3])
  21. read_input.label.set_shape([1])
  22. #确保随机乱序有好的混合效果
  23. min_fraction_of_examples_in_queue=0.4
  24. min_queue_examples= int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN* min_fraction_of_examples_in_queue)
  25. print('Filling queue with %d CIFAR images before starting to train.'% min_queue_examples+'This will take a few minutes.')
  26. #Generate a batch of images and labels by building up a queue of examples
  27. return _generate_image_and_label_batch(float_image,read_input.label,min_queue_examples,batch_size,shuffle=True)

5.得到样本队列

在数据输入管线的末端,我们需要有另一个队列来执行输入样本的training, evaluation, inference,要对样本进行批处理

所以使用tf.train.shuffle_batch函数用16个不同的reader读到的样本组成batch返回

[python] view plain copy
  1. def _generate_image_and_label_batch(image,label,min_queue_examples,batch_size,shuffle):
  2. """
  3. Construct a queued batch of images and labels.
  4. :param image: 3-D Tensor of[Height, width,3] of type.float32
  5. :param label:1-D Tensor of type.int32
  6. :param min_queue_examples: int32,minimum number of samples to retain in the queue that provides of batches of examples
  7. :param batch_size: Number of images per batch
  8. :param shuffle: boolean indicating whether to use shuffling queue (乱序)
  9. :return:
  10. images: Images. 4D tensor of [batch_size,height,width,3]size
  11. labels: Labels.1D tensor of [batch_size]size
  12. """
  13. #Create a queue that shuffles the examples, and then
  14. #read 'batch_size' images +labels from the example queue\
  15. <strong>  num_preprocess_threads=16</strong>
  16. #读取图片加上预处理要花费不少时间,所以我们在16个独自线程上运行它们,which fill a TensorFlow queue
  17. #这种方案可以保证同一时刻只在一个文件中进行读取操作(但读取速度依然优于单线程),而不是同时读取多个文件
  18. #优点是:
  19. #避免了两个不同的线程从同一个文件中读取同一个样本
  20. #避免了过多的磁盘搜索操作
  21. if shuffle:
  22. #创建bathces of ‘batch_size’个图片和'batch_size'个labels
  23. images,label_batch=tf.train.shuffle_batch(
  24. [image,label],
  25. batch_size=batch_size,
  26. num_threads=num_preprocess_threads,
  27. capacity=min_queue_examples+3*batch_size,#capacity必须比min_after_dequeue大
  28. min_after_dequeue=min_queue_examples) #min_after_dequeue 定义了我们会从多大的buffer中随机采样
  29. #大的值意味着更好的乱序但更慢的开始,和更多内存占用
  30. else: #不乱序
  31. images,label_batch=tf.train.batch(
  32. [image,label],
  33. batch_size=batch_size,
  34. num_threads=num_preprocess_threads,
  35. capacity=min_queue_examples+3*batch_size)
  36. #display the training images in the visualizer
  37. tf.summary.image('images',images)
  38. return images,tf.reshape(label_batch,[batch_size])

线程和队列

队列就是tensorFlow图中的节点,这是一种有状态的节点,就像变量一样,其他节点可以修改它的内容。

具体来说,其他节点可以把新元素插入到队列后端(rear),也可以把前端元素删除

队列的使用:

队列类型有先进先出(FIFO Queue),或者是随机的(RandomShuffleQueue)

FIFO Que

创建一个先进先出队列,以及一个“出队,+1,入队”操作:

[python] view plain copy
  1. import tensorflow as tf
  2. #创建的图:一个先入先出队列,以及初始化,出队,+1,入队操作
  3. q = tf.FIFOQueue(3, "float")
  4. init = q.enqueue_many(([0.1, 0.2, 0.3],))
  5. x = q.dequeue()
  6. y = x + 1
  7. q_inc = q.enqueue([y])
  8. #开启一个session,session是会话,会话的潜在含义是状态保持,各种tensor的状态保持
  9. with tf.Session() as sess:
  10. sess.run(init)
  11. for i in range(2):
  12. sess.run(q_inc)
  13. quelen =  sess.run(q.size())
  14. for i in range(quelen):
  15. print (sess.run(q.dequeue()))

输出结果:

0.3

1.1

1.2

注意先入先出的规则!

一个典型的输入结构:是使用一个RandomShuffleQueue来作为模型训练的输入,多个线程准备训练样本,并且把这些样本推入队列,一个训练线程执行一个训练操作,此操作会从队列中移出最小批次的样本(mini-batches)

之前的例子中,入队操作都在主线程中进行,Session中可以多个线程一起运行。 在数据输入的应用场景中,入队操作从硬盘上读取输入,放到内存当中,速度较慢。 使用QueueRunner可以创建一系列新的线程进行入队操作,让主线程继续使用数据。如果在训练神经网络的场景中,就是训练网络和读取数据是异步的,主线程在训练网络,另一个线程在将数据从硬盘读入内存。

再举一个例子:

[python] view plain copy
  1. import tensorflow as tf
  2. import sys
  3. q=tf.FIFOQueue(1000,"float")
  4. #计数器
  5. counter=tf.Variable(0.0)
  6. #操作:给计数器加一
  7. increment_op=tf.assign_add(counter,tf.constant(1.0))
  8. #操作:将计数器加入队列
  9. enqueue_op=q.enqueue(counter)
  10. #创建一个队列管理器QueueRunner,用这两个操作向q中添加元素,目前我们只使用一个线程:
  11. qr=tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*1)
  12. #主线程:
  13. sess=tf.Session()
  14. sess.run(tf.global_variables_initializer())
  15. enqueue_threads=qr.create_threads(sess,start=True) #启动入队线程
  16. #主线程:
  17. for i in range(0,5):
  18. print(sess.run(q.dequeue()))

结果是:2.0
172.0
225.0
272.0
367.0

并不是如普通计数器一样的1,2,3,4,5,原因就是+1操作和入队操作不同步!可能+1操作执行了很多次之后,才会进行一次入队操作,并且出队结束后,本应程序要结束,但是因为入队线程没有显示结束,所以,整个程序就跟挂起一样,也结束不了。

Tensorflow 的session对象是支持多线程的,因此多个线程可以很方便地使用同一个会话(session),并且并行地执行操作。

然而,在Python程序实现并行运算并不容易,所有线程都必须被同步终止,异常必须能被正常地捕获并报告,会话终止的时候,队列必须能被正确地关闭。

所幸TensorFlow提供了两个类来帮助多线程的实现:tf.Coordinator和 tf.QueueRunner。从设计上这两个类必须被一起使用。Coordinator类可以用来同时停止多个工作线程并且向那个在等待所有工作线程终止的程序报告异常。QueueRunner类用来协调多个工作线程同时将多个张量推入同一个队列中。

使用tf.train.Coordinator来终止其他线程,Coordinator类主要有如下几个方法:

  • should_stop():如果线程应该停止则返回True
  • request_stop(<exception>):请求该线程停止
  • join(<list of threads>):等待被指定的线程终止
加入Coordinator后上述例子变为:
[python] view plain copy
  1. import tensorflow as tf
  2. import sys
  3. q=tf.FIFOQueue(1000,"float")
  4. #计数器
  5. counter=tf.Variable(0.0)
  6. #操作:给计数器加一
  7. increment_op=tf.assign_add(counter,tf.constant(1.0))
  8. #操作:将计数器加入队列
  9. enqueue_op=q.enqueue(counter)
  10. #创建一个队列管理器QueueRunner,用这两个操作向q中添加元素,目前我们只使用一个线程:
  11. qr=tf.train.QueueRunner(q,enqueue_ops=[increment_op,enqueue_op]*1)
  12. #主线程:
  13. sess=tf.Session()
  14. sess.run(tf.global_variables_initializer())
  15. coord=tf.train.Coordinator()
  16. enqueue_threads=qr.create_threads(sess,coord=coord,start=True) #启动入队线程,Coordinator是线程的参数
  17. #主线程:
  18. for i in range(0,5):
  19. print(sess.run(q.dequeue()))
  20. coord.request_stop() #通知其他线程关闭
  21. coord.join(enqueue_threads)#其他所有线程关之后,这一函数才能返回

返回结果为:3.0
28.0
48.0
73.0
94.0

Tensorflow 从bin文件中读取数据并相关推荐

  1. python读文件和写文件-python开发--从文件中读取数据和写入文件

    #! /usr/bin/env python -*- coding:utf-8 -*- """ @Author:gcan @Email:1528667112@qq.com ...

  2. TF从文件中读取数据

    从文件中读取数据 在TensorFlow中进行模型训练时,在官网给出的三种读取方式,中最好的文件读取方式就是将利用队列进行文件读取,而且步骤有两步: 把样本数据写入TFRecords二进制文件 从队列 ...

  3. mfc从文件中读取数据_Python 中的 bytes、str 以及 unicode 区别

    从Python发展历史谈起 Python3和Python2表示字符序列的方式有所不同. Python3字符序列的两种表示为byte和str.前者的实例包含原始的8位值,即原始的字节:后者的实例包括Un ...

  4. 向HBase中导入数据3:使用MapReduce从HDFS或本地文件中读取数据并写入HBase(增加使用Reduce批量插入)

    前面我们介绍了: 向HBase中导入数据1:查询Hive后写入 向HBase中导入数据2:使用MapReduce从HDFS或本地文件中读取数据并写入HBase(只使用Map逐条查询) 为了提高插入效率 ...

  5. vc++从txt文件中读取数据

    数值分析课上老师说要将数据写在txt文件上,然后让程序从txt文件中读取数据.让本来C++已经遗忘了很久的我们无从下手,在网上也查看了很多,发现大多都是扯淡,放在VC++编辑器上发现并不能运行,不知道 ...

  6. 【Python】从文件中读取数据

    从文件中读取数据 1.1 读取整个文件 要读取文件,需要一个包含几行文本的文件(文件PI_DESC.txt与file_reader.py在同一目录下) PI_DESC.txt 3.1415926535 ...

  7. python3中的zip_Python3实现将文件归档到zip文件及从zip文件中读取数据的方法

    ''''' Created on Dec 24, 2012 将文件归档到zip文件,并从zip文件中读取数据 @author: liury_lab ''' # 压缩成zip文件 from zipfil ...

  8. 如何用c语言从txt文件中读取数据

    用C语言从txt文件中读取数据,可以使用C标准库文件自带的文件接口函数进行操作. 一.打开文件: FILE *fopen(const char *filename, const char *mode) ...

  9. xlswriter-用于Excel 2007+XLSX文件中读取数据

    介绍: xlswriter是一个 ,可用于在 Excel 2007+ XLSX 文件中读取数据,插入多个工作表,写入文本.数字.公式.日期.图表.图片和超链接. 特点: 一.写入 100%兼容的Exc ...

最新文章

  1. VScode操作记录
  2. python插件安装包_python 离线安装插件包
  3. 第k大的数 java_一道算法题:第K大的数
  4. R语言与正态总体均值的区间估计
  5. java设计模式 订阅模式_Java中的外观设计模式
  6. ipv6怎么进行邻居学习_事实证明,我的邻居也想学习编码。 他们只是没有意识到这是可能的。
  7. dataframe中多列除以不同列_Python之DataFrame切片与索引实验
  8. AngularJS开发指南
  9. NSString (NSStringPathExtensions) 类的相关方法和属性梳理
  10. cocos2d 屏幕適配_cocos2dx 3.2 屏幕适配的理解
  11. ps数位板绘画遇到问题总结
  12. 单容水箱液位pid控制实验报告_过程控制实验-单容水箱液位控制系统
  13. 联想android手机驱动,Lenovo联想手机驱动
  14. 第三十一章 与昔一何殊勇怯(一之全)
  15. 牛顿迭代法实现平方根函数
  16. 数字化转型|利用现代技术和通信手段|改变企业为客户创造价值的方式
  17. CTF-代码审计(2)
  18. aardio部署_aardio学习笔记-变量与常量
  19. Docker之通过阿里云的镜像加速器快速拉取镜像到本地
  20. SAP 玻璃原片单位问题处理

热门文章

  1. 《jsp程序设计》智多星手机销售网后台设计
  2. 关系数据库-关系代数-数据库习题
  3. ubuntu / linuxmint 搜狗输入法安装后 fcitx configure找不到的解决办法
  4. linux mac地址远程开机,用MAC地址远程开机的开机棒你见过吗?
  5. 空间计量经济学(4)---空间滞后与空间杜宾误差模型
  6. windows10小技巧: 将手机投影到windows10上
  7. [学习SLAM] 3D可视化 只viz模块和pangolin
  8. 猿创征文| 我的开发者工具箱之数据分析师装备库
  9. 塞班java手机qq浏览器下载_手机QQ浏览器 for Symbian S60v3
  10. Linux从图形界面切换到文本界面快捷键不好用的解决方法