一、TFRecord文件书写效率对比(单线程和多线程对比)

1、准备工作

# Author : Hellcat
# Time   : 18-1-15'''
import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
'''import os
import glob
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as pltnp.set_printoptions(threshold=np.inf)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)def _int64_feature(value):"""生成整数数据属性"""return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))def _bytes_feature(value):"""生成字符型数据属性"""return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

2、单线程TFR文件写入

def image2TFR_single_thread(path='./Data_Set/cartoon_faces',with_label=False):# 获取图片名称以及数量# 等价于image_names = glob.glob(path+'/*')# 使用next可以直接取出迭代器中的元素image_names = next(os.walk(path))[2]num_file = len(image_names)# 定义每个文件中放入多少数据instances_per_shard = 10000# 定义写多少个文件(数据量大时可以写入多个文件加速)num_shards = num_file // instances_per_shard + 1for file_i in range(num_shards):# 文件名命名规则file_name = './TFRecord_Output/{0}.tfrecords_{1}_of_{2}_st'\.format(path.split('/')[-1], file_i+1, num_shards)# 书写器初始化writer = tf.python_io.TFRecordWriter(file_name)for index, image_name in enumerate(image_names[file_i*instances_per_shard:(file_i+1)*instances_per_shard]):image_data = plt.imread(os.path.join(path, image_name))if with_label == True:pass# TODO# 如果有标签,则在这里添加确定标签的规则,注意非one_hot# label = ……image_raw = image_data.tostring()example = tf.train.Example(features=tf.train.Features(feature={'image': _bytes_feature(image_raw),# 'label': _int64_feature(label)}))writer.write(example.SerializeToString())# 书写器关闭writer.close()

3、多线程TFR文件写入

def image2TFR_multiple_threads(path='./Data_Set/cartoon_faces',with_label=False):# 获取图片名称以及数量# 等价于image_names = glob.glob(path+'/*')# 使用next可以直接取出迭代器中的元素image_names = next(os.walk(path))[2]num_file = len(image_names)# 定义每个文件中放入多少数据instances_per_shard = 10000# 定义写多少个文件(数据量大时可以写入多个文件加速)num_shards = num_file // instances_per_shard + 1file_names = ['./TFRecord_Output/{0}.tfrecords_{1}_of_{2}_mt'.format(path.split('/')[-1], file_i+1, num_shards) for file_i in range(num_shards)]def _TFR_write():for file_name in file_names:file_names.remove(file_name)writer = tf.python_io.TFRecordWriter(file_name)num = 0for image_name in image_names:num += 1if num > instances_per_shard:breakimage_names.remove(image_name)image_data = plt.imread(os.path.join(path, image_name))if with_label == True:pass# TODO# 如果有标签,则在这里添加确定标签的规则,注意非one_hot# label = ……image_raw = image_data.tostring()example = tf.train.Example(features=tf.train.Features(feature={'image': _bytes_feature(image_raw),# 'label': _int64_feature(label)}))writer.write(example.SerializeToString())writer.close()threads = []t1 = threading.Thread(target=_TFR_write, name='resize_img_thread:0')threads.append(t1)t2 = threading.Thread(target=_TFR_write, name='resize_img_thread:1')threads.append(t2)for t in threads:t.start()for t in threads:t.join()

4、测试部分

if __name__=='__main__':import datetimeimport threadingfor i in range(15):time1 = datetime.datetime.now()image2TFR_multiple_threads()time2 = datetime.datetime.now()image2TFR_single_thread()time3 = datetime.datetime.now()print('mul:', time2-time1)print('sin:', time3-time2)print('_*_'*10)

5、部分输出

mul: 0:00:25.779139
sin: 0:00:26.312438
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:27.203649
sin: 0:00:27.982487
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:31.193418
sin: 0:00:28.735610
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:28.414592
sin: 0:00:30.207631
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:27.999488
sin: 0:00:29.683136
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:28.659919
sin: 0:00:28.534984
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:30.366691
sin: 0:00:31.014559
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:28.288918
sin: 0:00:29.142247
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:29.861579
sin: 0:00:29.329732
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:28.854213
sin: 0:00:33.794422
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:28.010327
sin: 0:00:29.163616
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:27.773299
sin: 0:00:29.312738
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:27.815851
sin: 0:00:28.715579
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:27.889409
sin: 0:00:28.157235
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:28.143782
sin: 0:00:28.988136
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:27.533430
sin: 0:00:30.000925
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:28.158601
sin: 0:00:29.448665
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:27.839638
sin: 0:00:28.908899
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:27.922513
sin: 0:00:28.757721
_*__*__*__*__*__*__*__*__*__*_
mul: 0:00:31.227687
sin: 0:00:29.576041
_*__*__*__*__*__*__*__*__*__*_

可能是数据量不够大的原因,多线程没有明显的优势,可能写入文件数增加会更好,但个人感觉由于涉及到写入文件句柄操作这不是个适合使用多线程加速的任务。

二、TFRecord实际使用框架

总的原则,把可以修改的超参数啊、路径啊什么的单独提出来,不要放在程序中,那样使用时想要修改会及其繁琐,且易出错

1、包导入以及超参数设定

# Author : Hellcat
# Time   : 18-1-15"""
import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1"
"""import os
import glob
import numpy as np
import tensorflow as tf
from scipy.misc import imread, imresizenp.set_printoptions(threshold=np.inf)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)# 读取数据文件的轮数
NUM_EPOCHS = 1
# TFR保存图像尺寸
IMAGE_HEIGHT = 227
IMAGE_WIDTH = 227
IMAGE_DEPTH = 3
# 训练batch尺寸
BATCH_SIZE = 2
# 定义每个TFR文件中放入多少条数据
INSTANCES_PER_SHARD = 10000
# 图片文件存放路径
IMAGE_PATH = './Data_Set/cartoon_faces'
# 图片文件和标签清单保存文件
IMAGE_LABEL_LIST = 'images_&_labels.txt'
# TFR文件保存路径
TFR_PATH = './TFRecord_Output'

2、文件清单生成

def filename_list(path=IMAGE_PATH):"""文件清单生成:param path:图像路径,path下直接是图片 :return: txt文件,每一行内容是:路径图片名+若干空格+类别标签数字+\n"""# 获取图片名称以及数量# 等价于image_names = glob.glob(path+'/*')# 使用next可以直接取出迭代器中的元素file_names = next(os.walk(path))[2]with open(IMAGE_LABEL_LIST, 'w') as f:for file_name in file_names:f.write(path+'/'+file_name+' '+'1'+'\n')

3、TFR文件生成

def image_to_TFR(image_and_label=IMAGE_LABEL_LIST,image_height=IMAGE_HEIGHT,image_width=IMAGE_WIDTH):"""从清单读取图片并生成TFR文件:param image_and_label: txt图片清单:param image_height: 保存如TFR文件的图片高度:param image_width: 保存TFR文件的图片宽度"""def _int64_feature(value):"""生成整数数据属性"""return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))def _bytes_feature(value):"""生成字符型数据属性"""return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))with open(image_and_label, 'r') as f:lines = f.readlines()image_paths = [image_path.strip('\n').split(' ')[0] for image_path in lines]labels = [image_path.strip('\n').split(' ')[-1] for image_path in lines]# 如下操作会报错,因为忽略了指针问题,第一次readlines后指针到达文件末尾,第二次readlines什么都read不到# image_paths = [image_path.strip('\n').split(' ')[0] for image_path in f.readlines()]# labels = [image_path.strip('\n').split(' ')[-1] for image_path in f.readlines()]num_file = len(image_paths)# 定义写多少个文件(数据量大时可以写入多个文件加速)num_shards = num_file // INSTANCES_PER_SHARD + 1for file_i in range(num_shards):# 文件名命名规则file_name = os.path.join(TFR_PATH, '{0}.tfrecords_{1}_of_{2}')\.format(image_paths[0].split('/')[-2], file_i+1, num_shards)print('正在生成文件: ', file_name)# 书写器初始化writer = tf.python_io.TFRecordWriter(file_name)for index, image_path in enumerate(image_paths[file_i*INSTANCES_PER_SHARD:(file_i+1)*INSTANCES_PER_SHARD]):image_data = imread(os.path.join(image_path))image_data = imresize(image_data, (image_height, image_width))image_raw = image_data.tostring()example = tf.train.Example(features=tf.train.Features(feature={'image': _bytes_feature(image_raw),'label': _int64_feature(int(labels[index]))}))writer.write(example.SerializeToString())# 书写器关闭writer.close()

4、读取TFR文件并生成batch数据

本函数最后的images和labels可以作为return,直接送入网络参与训练

def batch_from_TFR(image_height=IMAGE_HEIGHT,image_width=IMAGE_WIDTH,image_depth=IMAGE_DEPTH):"""从TFR文件读取batch数据"""if not os.path.exists(TFR_PATH):os.makedirs(TFR_PATH)'''读取TFR数据并还原为uint8的图片'''file_names = glob.glob(os.path.join(TFR_PATH, '{0}.tfrecords_*_of_*').format(IMAGE_PATH.split('/')[-1]))filename_queue = tf.train.string_input_producer(file_names, num_epochs=NUM_EPOCHS, shuffle=True)reader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue)features = tf.parse_single_example(serialized_example,features={'image': tf.FixedLenFeature([], tf.string),'label': tf.FixedLenFeature([], tf.int64)})image = features['image']image_decode = tf.decode_raw(image, tf.uint8)# 解码会变为一维数组,所以这里设定shape时需要设定为一维数组image_decode.set_shape([image_height*image_width*image_depth])image_decode = tf.reshape(image_decode, [image_height, image_width, image_depth])label = tf.cast(features['label'], tf.int32)'''图像预处理''''''生成batch图像'''# 随机获得batch_size大小的图像和labelimages, labels = tf.train.shuffle_batch([image_decode, label],batch_size=BATCH_SIZE,num_threads=1,capacity=1000 + 3 * BATCH_SIZE,  # 队列最大容量min_after_dequeue=1000)

5、包含在上面batch函数中的测试模块

    # 测试部分print(images)sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)img = sess.run(images)[0]import matplotlib.pyplot as pltplt.imshow(img)coord.request_stop()coord.join(threads)

测试结果,

6、启动部分

if __name__ == '__main__':import datetimetime1 = datetime.datetime.now()# filename_list()# image_to_TFR()batch_from_TFR()time2 = datetime.datetime.now()print(time2-time1)

从测试部分的运行注意到设计tf的队列操作时,局部变量初始化sess.run(tf.global_variables_initializer())是必须的,否则会报错(『TensorFlow』问题整理)。

转载于:https://www.cnblogs.com/hellcat/p/8287831.html

『TensorFlow』TFR数据预处理探究以及框架搭建相关推荐

  1. python 动漫卡通人物图片大全,『TensorFlow』DCGAN生成动漫人物头像_下

    一.计算图效果以及实际代码实现 计算图效果 实际模型实现 相关介绍移步我的github项目. 二.生成器与判别器设计 生成器 相关参量, 噪声向量z维度:100 标签向量y维度:10(如果有的话) 生 ...

  2. 『TensorFlow』专题汇总

    TensorFlow函数查询 『TensorFlow』0.x_&_1.x版本框架改动汇总 『TensorFlow』函数查询列表_数值计算 『TensorFlow』函数查询列表_张量属性调整 『 ...

  3. 『TensorFlow』函数查询列表_张量属性调整

    博客园 首页 新随笔 新文章 联系 订阅 管理 『TensorFlow』函数查询列表_张量属性调整 数据类型转换Casting 操作 描述 tf.string_to_number (string_te ...

  4. 『TensorFlow』模型保存和载入方法汇总

    一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 参数名称 功能说明 默认值 var_list Saver中存储变 ...

  5. 『TensorFlow』模型载入方法汇总

    『TensorFlow』第七弹_保存&载入会话_霸王回马 一.TensorFlow常规模型加载方法 保存模型 tf.train.Saver()类,.save(sess, ckpt文件目录)方法 ...

  6. 『TensorFlow』命令行参数解析

    argparse很强大,但是我们未必需要使用这么繁杂的东西,TensorFlow自己封装了一个简化版本的解析方式,实际上是对argparse的封装 脚本化调用tensorflow的标准范式: impo ...

  7. Tensorflow nmt的数据预处理过程

    tensorflow nmt的数据预处理过程 在tensorflow/nmt项目中,训练数据和推断数据的输入使用了新的Dataset API,应该是tensorflow 1.2之后引入的API,方便数 ...

  8. 『TensorFlow』第七弹_保存载入会话_霸王回马

    首更: 由于TensorFlow的奇怪形式,所以载入保存的是sess,把会话中当前激活的变量保存下来,所以必须保证(其他网络也要求这个)保存网络和载入网络的结构一致,且变量名称必须一致,这是caffe ...

  9. 『TensorFlow』DCGAN生成动漫人物头像_下

    『TensorFlow』以GAN为例的神经网络类范式 『cs231n』通过代码理解gan网络&tensorflow共享变量机制_上 『TensorFlow』通过代码理解gan网络_中 一.计算 ...

  10. 『TensorFlow』数据读取类_data.Dataset

    一.资料 参考原文: TensorFlow全新的数据读取方式:Dataset API入门教程 API接口简介: TensorFlow的数据集 二.背景 注意,在TensorFlow 1.3中,Data ...

最新文章

  1. 数据结构与算法:算法简介
  2. 【JVM】Java对象创建的流程步骤
  3. 为了节省能量,人类演化出了“不合规律”的大脑神经元 | Nature
  4. 拼装机器人感想_学习制作机器人的感想作文500字6篇
  5. 局域网读取文件_教你windows局域网如何设置共享文件
  6. xml突然变成空白_“侏罗纪中期”出现了型增转变填补食肉性恐龙体型发展当中的空白...
  7. 转:修改Content Server管理员密码 - [Documentum 实施开发记录]
  8. 年发5篇Science,现入选中国榜“35岁以下科技创新35人”!
  9. 方德系统装exe文件_国产x86处理器+中科方德定制Linux 完美运行exe
  10. oracle 日期转换成毫秒数,ORACLE:毫秒与日期的相互转换,获取某天的信息
  11. 启动mysql时显示:/tmp/mysql.sock 不存在的解决方法
  12. Jsrender初体验
  13. BMFont制作美术字体包教包会
  14. lisp型材库_STMX 1.3.2 发布,高性能的 Common Lisp 库
  15. 算术平均数、几何平均数、调和平均数的、标准差、方差、正态分布、异常值噪声处理
  16. python中怎么取小数点后两位函数_python中round函数保留两位小数的方法
  17. 动态为Spring Boot项目中所有自定义的Controller添加过滤器的两种方法
  18. 【ESP32】VSCode+Arduino+Platformio 如何使用ESP32上的PSRAM
  19. 20190422每周精品之认知
  20. 初学入门YOLOv5手势识别之制作并训练自己的数据集

热门文章

  1. 应用测试一(烤地瓜)——> 隐藏数据
  2. ioi 赛制_《Produce48》现坑爹赛制 网友称这波操作令人窒息
  3. 系统地学习打字(个人见解)
  4. 【win 10】win 10:远程连接 Windows 服务器工具下载、安装和使用:PowerShell server Putty —— win 10 之间通过 ssh 登录
  5. 硬件电路设计笔记-电平转换电路
  6. Because You Loved Me歌词
  7. 音声合成:音高、泛音、谐波、基频 到底是什么概念?
  8. 洛谷刷题C语言:第一次,第二次,成交!、Bessie‘s Secret Pasture S、金币、Bookshelf B、东南西北
  9. 资源屋分享两款导航网站源码 支持自动收录、自动审核、自动检测友链功能
  10. sklearn.neighbors