文章首发于微信公众号《有三AI》

【从caffe到Tensorflow 1】io 操作

最近项目要频繁用到tensorflow,所以不得不认真研究下tensorflow而不是跟之前一样遇到了就搞一下了。

首先我觉得所有这些框架里面caffe是最清晰的,所以就算是学习tensorflow,我也会以caffe的思路去学习,这就是这个系列的用意。

今天是第1篇,咱们说io操作,也就是文件读取,载入内存。

01 Caffe的io操作

caffe的io,是通过在prototxt中定义数据输入,默认支持data,imagedata,hdf5data,window data等类型。Datalayer,输入是LMDB数据格式,image data 支持的是imagelist的数据格式。

对于LMDB来说,我们在caffe layer中配置准备好的二进制数据即可。

对于image data,我们准备一个data list,官方的imagedata是一个分类任务的list,格式为每行image,label,当然随着任务的不同我们可以自定义。比如分割任务image,mask。检测任务,image num of object, objectrect1,object rect2等。

典型的格式是这样:

具体的载入,就是在相关层的DataLayerSetUp函数中设置好输入大小,load_batch函数中,读取原始数据,再利用data_transform塞入内存。

当然caffe也可以自定义python层使用,不过我还是更习惯c++,何况这里比较的也是官方自带的layer。

从上面我们可以看出,caffe的io都是从文件中载入,只是文件的组织方式不同。

Tensorflow的io输入则要复杂,全面很多,我们参考tensorflow1.5的API。

http://link.zhihu.com/?target=https%3A//

www.tensorflow.org/api_docs/python/tf/data

02 Tensorflow的io操作

Tensorflow不止是读取文件这一种方法,它可以包含以下几种方式。

  • 预加载数据: 在TensorFlow图中定义常量或变量来保存所有数据(仅适用于数据量比较小的情况)

import tensorflow as tf
# 设计Graph
x1 = tf.constant([2, 3, 4])
x2 = tf.constant([4, 0, 1])
y = tf.add(x1, x2)
with tf.Session() as sess:print sess.run(y)

如上,x1,x2都是预加载好的数据。在设计Graph的时候,x1和x2就已经被定义成了两个有值的列表,在计算y的时候直接取x1和x2的值。这种方法的问题是将数据直接内嵌到Graph中,再把Graph传入Session中运行。当数据量比较大时,Graph的传输会遇到效率问题。

  • Feeding 它定义变量的时候用占位符替代数据,待运行的时候填充数据。

import tensorflow as tf
x1 = tf.placeholder(tf.int16)
x2 = tf.placeholder(tf.int16)
y = tf.add(x1, x2)
# 用Python产生数据
li1 = [2, 3, 4]
li2 = [4, 0, 1]
# 打开一个session --> 喂数据 --> 计算y
with tf.Session() as sess:print sess.run(y, feed_dict={x1: li1, x2: li2})

定义的时候,x1, x2只是占位符所以没有具体的值,运行的时候使用sess.run()中的feed_dict参数,将Python产生的数据喂给后端,并计算y。

  • Reading From File

前两种方法很方便,但是遇到大型数据的时候就会很吃力,即使是Feeding,中间环节的增加也是不小的开销,比如数据类型转换等等。而且,面对复杂类型的数据,也是处理不过来的。因此与caffe一样,tensorflow也是支持从文件中读取数据。

下面举一个利用队列读取硬盘中的数据到内存的例子:假如需要读取的数据存在一个list中。这篇博客举了一个很好的例子;

http://honggang.io/2016/08/19/tensorflow-data-reading/

在上图中,首先由一个单线程把文件名堆入队列,两个Reader同时从队列中取文件名并读取数据,Decoder将读出的数据解码后堆入样本队列。

利用了string_input_producer + tf.TextLineReader() +train.start_queue_runners来读取数据,string_input_producer的定义在

https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/python/training/input.py

string_input_producer(string_tensor,num_epochs=None,shuffle=True,seed=None,capacity=32,shared_name=None,name=None,cancel_op=None
)

从上面可见,可以指定num_epochs,是否shuffle等,这就是一个最简单的从文件中读取的例子了。

假设有文件A.csv如下:

Alpha1,A1Alpha2,A2Alpha3,A3

单个reader读取单个数据脚本如下;

import tensorflow as tf
filenames = ['A.csv'] 必须要以数组的形式
filename_queue =tf.train.string_input_producer(filenames,shuffle=False)
reader = tf.TextLineReader()# 定义Reader
key, value = reader.read(filename_queue)
# 定义Decoder
example, label = tf.decode_csv(value,record_defaults=[['null'], ['null']])
# 运行Graph
with tf.Session() as sess:coord = tf.train.Coordinator()  #创建一个协调器,管理线程threads =tf.train.start_queue_runners(coord=coord)  #启动QueueRunner, 此时文件名队列已经进队。for i in range(10):print example.eval()   #取样本的时候,一个Reader先从文件名队列中取出文件名,读出数据,Decoder解析后进入样本队列。coord.request_stop()coord.join(threads)

讲了上面的基础例子之后,我们开始看更复杂的例子。

上面的例子包含两类,一种是从placeholder读内存中的数据,一种是使用queue读硬盘中的数据,而1.3以后的Dataset API同时支持从内存和硬盘的读取。

它们支持多种类型的输入,分别是FixedLengthRecordDataset, TextLineDataset,TFRecordDataset类型的。

TextLineDataset:这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件,跟上面例子类似。

TFRecordDataset:这个函数是用来读TFRecord文件的,dataset中的每一个元素就是一个TFExample,这是很常用的。

FixedLengthRecordDataset:这个函数的输入是一个文件的列表和一个record_bytes,之后dataset的每一个元素就是文件中固定字节数record_bytes的内容。通常用来读取以二进制形式保存的文件,如CIFAR10数据集就是这种形式。

迭代器:提供了一种一次获取一个数据集元素的方法。

所有定义都在tensorflow/python/data/ops/readers.py中。

参考文章

https://zhuanlan.zhihu.com/p/30751039

我们先理解一下dataset是什么?

Dataset可以看作是相同类型“元素”的有序列表,而单个“元素”可以是向量,也可以是字符串、图片,甚至是tuple或者dict。先以最简单的,Dataset的每一个元素是一个数字为例:


import tensorflow as tf
import numpy as np
dataset =tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0,3.0, 4.0, 5.0]))

这样,我们就创建了一个dataset,这个dataset中含有5个元素,分别是1.0, 2.0, 3.0, 4.0, 5.0。

如何将这个dataset中的元素取出呢?方法是从Dataset中示例化一个Iterator,然后对Iterator进行迭代。

iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):print(sess.run(one_element))

对应的输出结果应该就是从1.0到5.0。语句iterator =dataset.make_one_shot_iterator()从dataset中实例化了一个Iterator,这个Iterator是一个“one shot iterator”,即只能从头到尾读取一次。one_element =iterator.get_next()表示从iterator里取出一个元素,调用sess.run(one_element)后,才能真正地取出一个值。

如果一个dataset中元素被读取完了,再尝试sess.run(one_element)的话,就会抛出tf.errors.OutOfRangeError异常,这个行为与使用队列方式读取数据的行为是一致的。在实际程序中,可以在外界捕捉这个异常以判断数据是否读取完,请参考下面的代码:

dataset =tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0,3.0, 4.0, 5.0]))
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:
try:
while True:print(sess.run(one_element))
except tf.errors.OutOfRangeError:print("end!")

dataset还可以有一些基本的数据变换操作,即transform操作,常见的有map,batch,shuffle,repeat

把数据+1dataset = dataset.map(lambda x: x + 1)

组合成batch,dataset = dataset.batch(32)

进行shuffle,dataset = dataset.shuffle(buffer_size=10000)

repeat 组成多个epoch,dataset = dataset.repeat(5)

03 来一个实例

理解了dataset之后,我们再看如何从文件中读取数据。由于tfrecord是非常常用的格式,下面我们就以这个为例。

假如我们有两个文件夹,一个是整理好的固定大小的图片,一个是对应label图片,这是一个分割任务,下面我们开始做。

  • 处理成tfrecord格式

首先,我们要把数据处理成tfrecord格式。

我们先定义一下存储格式:

直接贴完整代码了

import tensorflow as tf
import os
import sysdef _bytes_feature(value):returntf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))def _convert_to_example(image_buffer, mask_buffer,filename, mask_filename):example =tf.train.Example(features=tf.train.Features(feature={'image': _bytes_feature(image_buffer),'mask': _bytes_feature(mask_buffer),"filename":_bytes_feature(bytes(filename.encode("UTF-8"))),"mask_filename":_bytes_feature(bytes(mask_filename.encode("UTF-8")))#"filename": _bytes_feature(bytes(filename,encoding="UTF-8")),#"mask_filename":_bytes_feature(bytes(mask_filename,encoding="UTF-8"))}))return examplefiles = os.listdir(sys.argv[1])
mask_dir = sys.argv[2]
writer = tf.python_io.TFRecordWriter(sys.argv[3])
for file in files:filename = filemask_filename =os.path.join(mask_dir,filename.split('.')[0] + ".png")filename = os.path.join(sys.argv[1],filename)try:image_buffer = tf.gfile.FastGFile(filename,'rb').read()mask_buffer = tf.gfile.FastGFile(mask_filename,'rb').read()print "filename=",filenameexample = _convert_to_example(image_buffer,mask_buffer, filename, mask_filename)writer.write(example.SerializeToString())except StopIteration as e:print "error"

_convert_to_example这个函数,就是定义存储的格式;tf.gfile.FastGFile就是读取图片原始文件格式且不编解码,writer = tf.python_io.TFRecordWriter(sys.argv[3])是定义writer,写起来其实挺简单。

tf.train.Example是一个protocol buffer,定义在

https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/core/example/example.proto

将数据填入到Example后就可以序列化为一个字符串。一个Example中包含Features,Features里包含Feature,每一个feature其实就是一个字典,如上面的一个字典包含4个字段。

  • 读取tf.records

读取数据就可以使用tf.TFRecordReader的tf.parse_single_example解析器。它将Example protocolbuffer解析为张量。

简单的利用队列读取,可以采用下面的方法

filename_queue =tf.train.string_input_producer([filename])
reader = tf.TFRecordReader()_, serialized_example =reader.read(filename_queue)   #返回文件名和文件features =tf.parse_single_example(serialized_example,features={'label': tf.FixedLenFeature([], tf.int64),‘img_raw' : tf.FixedLenFeature([],tf.string),})
img = tf.decode_raw(features['image'], tf.uint8)
label = tf.decode_raw(features['mask'], tf.uint8)

不过,我们这里利用新的API的dataset来读取,更加高效。直接贴上代码如下:

上面定义过_convert_to_example,我们这里先定义一个读取格式。

def _extract_features(example):features = {"image": tf.FixedLenFeature((), tf.string),"mask": tf.FixedLenFeature((), tf.string)
}

获取一个example
parsed_example = tf.parse_single_example(example,features)
得到原始图并转换格式,set_shape是必须的,因为没有存储尺寸信息。

  images =tf.cast(tf.image.decode_jpeg(parsed_example["image"]), dtype=tf.float32)images.set_shape([224, 224, 3])masks =tf.cast(tf.image.decode_jpeg(parsed_example["mask"]), dtype=tf.float32) / 255.masks.set_shape([224, 224, 1])return images, masks

下面这个函数就是create迭代器了,在这里我们使用最简单的iterator,one-shot iterator来迭代,当然它只支持在一个dataset上迭代一次,不需要显式初始化。这里不需要怀疑epoch的问题,因为dataset.repeat(num_epoch)就会设置epoch数目,所以虽然只在dataset上迭代一次,但是已经遍历过数据epoch次。

def create_one_shot_iterator(filenames, batch_size,num_epoch):dataset = tf.data.TFRecordDataset(filenames)dataset = dataset.map(_extract_features)dataset = dataset.shuffle(buffer_size=batch_size)dataset = dataset.batch(batch_size)dataset = dataset.repeat(num_epoch)
return dataset.make_one_shot_iterator()

用的时候,就是

train_iterator = create_one_shot_iterator(train_files,train_batch_size, num_epoch=num_epochs)
next_images, next_masks = train_iterator.get_next()

当然读取出来之后可以做一些数据增强的操作。

就这样完毕!

感谢各位看官的耐心阅读,不足之处希望多多指教。后续内容将会不定期奉上,欢迎大家关注有三公众号 有三AI

【Tensorflow】io 操作相关推荐

  1. 【从caffe到Tensorflow 1】io 操作

    最近项目要频繁用到tensorflow,所以不得不认真研究下tensorflow而不是跟之前一样遇到了就搞一下了. 首先我觉得所有这些框架里面caffe是最清晰的,所以就算是学习tensorflow, ...

  2. Tensorflow线程队列与IO操作

    目录 Tensorflow线程队列与IO操作 1 线程和队列 1.1 前言 1.2 队列 1.3 队列管理器 1.4 线程协调器 2 文件读取 2.1 流程 2.2 文件读取API: 3 图像读取 3 ...

  3. TensorFlow常用操作:代码示例

    1,定义矩阵代码示例: import tensorflow as tftf.zeros([3,4]) #定义3行4列元素均为0的矩阵tensor=tf.constant([1,2,3,4])#定义一维 ...

  4. [零基础学JAVA]Java SE应用部分-27.Java IO操作(01)

    JAVA IO操作目标 本季知识点 1.File类 2.RandomAccessFile类 File类 在JAVA中所有的IO操作全部存放在java.io包中. File是唯一一个与文件本身有关的操作 ...

  5. java中的IO操作总结(一)

    转载:http://www.cnblogs.com/nerxious/archive/2012/12/15/2818848.html    所谓IO,也就是Input与Output的缩写.在java中 ...

  6. python和R文件IO操作对比及dataframe创建方式对比:read_csv、to_csv、write.csv、 data.frame、pd.DataFrame

    python和R文件IO操作对比及dataframe创建方式对比 很多工程师可能刚开始的时候只熟悉python或者R其中的一个进行数据科学相关的任务. 那么如果我们对比这学习可以快速了解语言设计背后的 ...

  7. CPU被夺走的三种状态 执行时间久了 IO操作让cpu等待 被优先级高的抢占

    CPU被夺走的三种状态   .                           cpu的三种状态之              运行态      就绪态等待被调用             阻塞态 i ...

  8. python中io.textio_Python文件读写概述(IO操作、文件读写、stringiobytesio、序列化),python,的,小,总结,StringIOBytesIO...

    IO操作 在进行文件的读写之前,需要说明几点.首先,运行的程序和读取的数据都会在内存中缓存. 进入到 程序或数据 内存 其次,用python程序进行文件的读写,需要创建一个小工具–文件流,用来处理数据 ...

  9. .NET中的IO操作基础介绍

    关于IO简介 .NET中的IO操作,经常需要调用一下几个类. 1.FileStream类        文件流类,负责大文件的拷贝,读写. 2.Path类                   Path ...

最新文章

  1. 谈谈微服务设计中的API网关模式
  2. 励研(LY) CRC16算法
  3. [C#]泛型与非泛型集合类的区别及使用例程,包括ArrayList,Hashtable,ListT,DictionaryTkey,Tvalue,SortedListTkey,Tvalue,...
  4. java.util.Date转换成java.sql.Date
  5. 设计模式是什么鬼(原型)
  6. 接口安全怎么处理_压瓦机设备噪声大怎么处理及压瓦机的安全使用
  7. Python 爬虫的工具列表大全
  8. Android 使用库项目时的一个特殊tip
  9. 代码检查工具系列——CheckStyle
  10. 抖音那种一道光闪过转场效果是怎么做的?
  11. websocket握手失败_WebSocket通信之握手协议
  12. 性能测试工程师职业现状分析
  13. 计算机屏幕出现条纹w7,电脑屏幕出现条纹,教您电脑屏幕出现条纹闪烁怎么解决...
  14. MySQL数据库基础-----多表查询
  15. python导出dxf图_利用Dxfwrite/ezdxf操作CAD文件!
  16. 动力电池系统介绍(十)——电压采样
  17. SYU新人提高90题1(基础题)
  18. three.js BIM可视化练习
  19. 【转】Protel99Se制作PCB基本流程
  20. 惠普w ndOws8的配置是多少,莫博士:旧电脑升级Win8是自己找罪受

热门文章

  1. 类与接口(四)方法重载解析
  2. hibernate常见错误及解决方法总结
  3. 01--MySQL自学教程:数据库MySQL安装和校验
  4. python入门基础系列_Python3基础系列-基本入门语法
  5. 【转】解决:ORA-19602: cannot backup or copy active file in NOARCHIVELOG mode
  6. 将solr安装到tomcat里
  7. python在子类中添加新的属性_pycharm实现在子类中添加一个父类没有的属性
  8. electron 样式不加载_electron-vue,启动之后没了样式和图片是什么原因呢,请大佬帮忙康康...
  9. c语言指针指向字符串单个,C语言 有没有可能调用一个指向字符串的函数指针?...
  10. 「 活动 」连续 3 天,企业容器应用实战营上海站来啦!