附加一个链接关于DatasetAPI:https://zhuanlan.zhihu.com/p/30751039

本文面对三种常常遇到的情况,总结三种读取数据的方式,分别用于处理单张图片、大量图片,和TFRecorder读取方式。并且还补充了功能相近的tf函数。

处理单张图片

我们训练完模型之后,常常要用图片测试,有的时候,我们并不需要对很多图像做测试,可能就是几张甚至一张。这种情况下没有必要用队列机制。

  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. def read_image(file_name):
  4. img = tf.read_file(filename=file_name) #默认读取格式为uint8
  5. print("img 的类型是",type(img));
  6. img = tf.image.decode_jpeg(img,channels=0) # channels 为1得到的是灰度图,为0则按照图片格式来读
  7. return img
  8. def main( ):
  9. with tf.device("/cpu:0"):
  10. img_path='./1.jpg'
  11. img=read_image(img_path)
  12. with tf.Session() as sess:
  13. image_numpy=sess.run(img)
  14. print(image_numpy)
  15. print(image_numpy.dtype)
  16. print(image_numpy.shape)
  17. plt.imshow(image_numpy)
  18. plt.show()
  19. if __name__=="__main__":
  20. main()

img_path是文件所在地址包括文件名称,地址用相对地址或者绝对地址都行

输出结果为:

  1. img 的类型是 <class 'tensorflow.python.framework.ops.Tensor'>
  2. [[[196 219 209]
  3. [196 219 209]
  4. [196 219 209]
  5. ...
  6. [[ 71 106 42]
  7. [ 59 89 39]
  8. [ 34 63 19]
  9. ...
  10. [ 21 52 46]
  11. [ 15 45 43]
  12. [ 22 50 53]]]
  13. uint8
  14. (675, 1200, 3)

和tf.read_file用法相似的函数还有tf.gfile.FastGFile  tf.gfile.GFile,只是要指定读取方式是'r' 还是'rb' 。

需要读取大量图像用于训练

这种情况就需要使用Tensorflow队列机制。首先是获得每张图片的路径,把他们都放进一个list里面,然后用string_input_producer创建队列,再用tf.WholeFileReader读取。具体请看下例:

  1. def get_image_batch(data_file,batch_size):
  2. data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
  3. #这个num_epochs函数在整个Graph是local Variable,所以在sess.run全局变量的时候也要加上局部变量。 filenames_queue=tf.train.string_input_producer(data_names,num_epochs=50,shuffle=True,capacity=512)
  4. reader=tf.WholeFileReader()
  5. _,img_bytes=reader.read(filenames_queue)
  6. image=tf.image.decode_png(img_bytes,channels=1) #读取的是什么格式,就decode什么格式
  7. #解码成单通道的,并且获得的结果的shape是[?, ?,1],也就是Graph不知道图像的大小,需要set_shape
  8. image.set_shape([180,180,1]) #set到原本已知图像的大小。或者直接通过tf.image.resize_images
  9. image=tf.image.convert_image_dtype(image,tf.float32)
  10. #预处理 下面的一句代码可以换成自己想使用的预处理方式
  11. #image=tf.divide(image,255.0)
  12. return tf.train.batch([image],batch_size)

这里的date_file是指文件夹所在的路径,不包括文件名。第一句是遍历指定目录下的文件名称,存放到一个list中。当然这个做法有很多种方法,比如glob.glob,或者tf.train.match_filename_once

全部代码如下:

  1. import tensorflow as tf
  2. import os
  3. def read_image(data_file,batch_size):
  4. data_names=[os.path.join(data_file,k) for k in os.listdir(data_file)]
  5. filenames_queue=tf.train.string_input_producer(data_names,num_epochs=5,shuffle=True,capacity=30)
  6. reader=tf.WholeFileReader()
  7. _,img_bytes=reader.read(filenames_queue)
  8. image=tf.image.decode_jpeg(img_bytes,channels=1)
  9. image=tf.image.resize_images(image,(180,180))
  10. image=tf.image.convert_image_dtype(image,tf.float32)
  11. return tf.train.batch([image],batch_size)
  12. def main( ):
  13. img_path=r'F:\dataSet\WIDER\WIDER_train\images\6--Funeral' #本地的一个数据集目录,有足够的图像
  14. img=read_image(img_path,batch_size=10)
  15. image=img[0] #取出每个batch的第一个数据
  16. print(image)
  17. init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
  18. with tf.Session() as sess:
  19. sess.run(init)
  20. coord = tf.train.Coordinator()
  21. threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  22. try:
  23. while not coord.should_stop():
  24. print(image.shape)
  25. except tf.errors.OutOfRangeError:
  26. print('read done')
  27. finally:
  28. coord.request_stop()
  29. coord.join(threads)
  30. if __name__=="__main__":
  31. main()

这段代码可以说写的很是规整了。注意到init里面有对local变量的初始化,并且因为用到了队列,当然要告诉电脑什么时候队列开始, tf.train.Coordinator 和 tf.train.start_queue_runners 就是两个管理队列的类,用法如程序所示。

输出如下:

  1. (180, 180, 1)
  2. (180, 180, 1)
  3. (180, 180, 1)
  4. (180, 180, 1)
  5. (180, 180, 1)

与 tf.train.string_input_producer相似的函数是 tf.train.slice_input_producer。 tf.train.slice_input_producer和tf.train.string_input_producer的第一个参数形式不一样。等有时间再做一个二者比较的博客

对TFRecorder解码获得图像数据

其实这块和上一种方式差不多的,更重要的是怎么生成TFRecorder文件,这一部分我会补充到另一篇博客上。

仍然使用 tf.train.string_input_producer。

  1. import tensorflow as tf
  2. import matplotlib.pyplot as plt
  3. import os
  4. import cv2
  5. import numpy as np
  6. import glob
  7. def read_image(data_file,batch_size):
  8. files_path=glob.glob(data_file)
  9. queue=tf.train.string_input_producer(files_path,num_epochs=None)
  10. reader = tf.TFRecordReader()
  11. print(queue)
  12. _, serialized_example = reader.read(queue)
  13. features = tf.parse_single_example(
  14. serialized_example,
  15. features={
  16. 'image_raw': tf.FixedLenFeature([], tf.string),
  17. 'label_raw': tf.FixedLenFeature([], tf.string),
  18. })
  19. image = tf.decode_raw(features['image_raw'], tf.uint8)
  20. image = tf.cast(image, tf.float32)
  21. image.set_shape((12*12*3))
  22. label = tf.decode_raw(features['label_raw'], tf.float32)
  23. label.set_shape((2))
  24. # 预处理部分省略,大家可以自己根据需要添加
  25. return tf.train.batch([image,label],batch_size=batch_size,num_threads=4,capacity=5*batch_size)
  26. def main( ):
  27. img_path=r'F:\python\MTCNN_by_myself\prepare_data\pnet*.tfrecords' #本地的几个tf文件
  28. img,label=read_image(img_path,batch_size=10)
  29. image=img[0]
  30. init=[tf.global_variables_initializer(),tf.local_variables_initializer()]
  31. with tf.Session() as sess:
  32. sess.run(init)
  33. coord = tf.train.Coordinator()
  34. threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  35. try:
  36. while not coord.should_stop():
  37. print(image.shape)
  38. except tf.errors.OutOfRangeError:
  39. print('read done')
  40. finally:
  41. coord.request_stop()
  42. coord.join(threads)
  43. if __name__=="__main__":
  44. main()

在read_image函数中,先使用glob函数获得了存放tfrecord文件的列表,然后根据TFRecord文件是如何存的就如何parse,再set_shape

这里有必要提醒下parse的方式。我们看到这里用的是tf.decode_raw ,因为做TFRecord是将图像数据string化了,数据是串行的,丢失了空间结果。从features中取出image和label的数据,这时就要用 tf.decode_raw  解码,得到的结果当然也是串行的了,所以set_shape 成一个串行的,再reshape。这种方式是取决于你的编码TFRecord方式的。

再举一种例子:

  1. reader=tf.TFRecordReader()
  2. _,serialized_example=reader.read(file_name_queue)
  3. features = tf.parse_single_example(serialized_example, features={
  4. 'data': tf.FixedLenFeature([256,256], tf.float32),
  5. 'label': tf.FixedLenFeature([], tf.int64),
  6. 'id': tf.FixedLenFeature([], tf.int64)
  7. })
  8. img = features['data']
  9. label =features['label']
  10. id = features['id']

这个时候就不需要任何解码了。因为做TFRecord的方式就是直接把图像数据append进去了。

TensorFlow中读取图像数据的三种方式(转)相关推荐

  1. python读取图像数据流_浅谈TensorFlow中读取图像数据的三种方式

    本文面对三种常常遇到的情况,总结三种读取数据的方式,分别用于处理单张图片.大量图片,和TFRecorder读取方式.并且还补充了功能相近的tf函数. 1.处理单张图片 我们训练完模型之后,常常要用图片 ...

  2. python csv库,Python 中导入csv数据的三种方法

    Python 中导入csv数据的三种方法,具体内容如下所示: 1.通过标准的Python库导入CSV文件: Python提供了一个标准的类库CSV文件.这个类库中的reader()函数用来导入CSV文 ...

  3. ios网络学习------4 UIWebView的加载本地数据的三种方式

    ios网络学习------4 UIWebView的加载本地数据的三种方式 分类: IOS2014-06-27 12:56 959人阅读 评论(0) 收藏 举报 UIWebView是IOS内置的浏览器, ...

  4. C#读取Excel数据的几种方式(包含大量数据读取)

    C#读取Excel数据的几种方式(包含大量数据读取) C#读取Excel数据的几种方式(包含大量数据读取) OleDB方式 COM组件的方式 NPOI方式读取(此处未测试,参考其他博文) 常用的Exc ...

  5. discard connection丢失数据_python kafka 生产者发送数据的三种方式

    python kafka 生产者发送数据的三种方式 发送方式 同步发送 发送数据耗时最长 有发送数据的状态,不会丢失数据,数据可靠性高 以同步的方式发送消息时,一条一条的发送,对每条消息返回的结果判断 ...

  6. android sqlite使用之模糊查询数据库数据的三种方式

    android sqlite使用之模糊查询数据库数据的三种方式 android应用开发中常常需要记录一下数据,而在查询的时候如何实现模糊查询呢?很少有文章来做这样的介绍,所以这里简单的介绍下三种sql ...

  7. 【Matlab系列】MATLAB中显示输出数据的四种方式

    DATE: 2019-11-13 1.参考 MATLAB 显示输出数据的三种方式 Matlab之print,fprint,fscanf,disp函数 2.改变数据格式 当数据重复再命令行窗口时,整数以 ...

  8. Spark读取Hive数据的两种方式与保存数据到HDFS

    Spark读取Hive数据的两种方式与保存数据到HDFS Spark读取Hive数据的方式主要有两种 1. 通过访问hive metastore的方式,这种方式通过访问hive的metastore元数 ...

  9. mysql数据库删除数据的三种方式:

    mysql数据库删除数据的三种方式: delete from table where 直接删除表中的某一行数据,并且同时将该行的删除操作作为事务记录在日志中保存以便进行进行回滚操作.所以delete相 ...

最新文章

  1. Linux内核初始化阶段内存管理的几种阶段
  2. java aio事件模型_IO模型之AIO代码及其实践详解
  3. 改革以来计算机应用发展总结,计算机应用专业课程改革总结.doc
  4. java连接各数据库的语句
  5. mysql-front权限管理_mysql 用户及权限管理 小结
  6. Ubuntu中安装Eclipse的SVN插件——subclipse
  7. 精英赛上线|冠军万元奖金
  8. 算法-冒泡排序和快速排序
  9. liunx查询进程下的线程
  10. 15种TBProAudio音乐插件合集包 2021最新
  11. laravel使用artisan报错SQLSTATE[42S02]: Base table or view not found: 1146
  12. back to wuxi
  13. 添加mysql.h头文件
  14. 【路径大全】iphone所有文件路径。CYDIA
  15. 计算机动画的教育应用研究,计算机动画技术在高校CAI课件制作中的应用研究
  16. Linux登录mysql密码正确被拒绝访问
  17. 读懂K线,就能理解期货股票交易中人性的期望、猜疑、幻想、贪婪、恐惧...
  18. Smart Construction:工程机械正在加速进入智能建设时代
  19. 推荐上百本优质大数据书籍,附必读清单(大数据宝藏)
  20. 树莓派从零开始快速入门系列汇总

热门文章

  1. django xadmin ForeignKey display
  2. php判断直线相交,zoj 1158 判断2线段完全相交
  3. Google maps API开发(一)(转)
  4. react脚手架_react脚手架create-react-app安装与使用
  5. matlab 状态空间的波特图,MATLAB:对于状态空间方程的系统辨识
  6. R语言向matlab转化,我有一段MATLAB的程序,现在想转换成R语言代码
  7. php adodb使用,ADODB类使用_PHP教程
  8. Lua中实现类似C#的事件机制
  9. C#中实现byte[]与任意对象互换(服务端通讯专用)
  10. 修改linux端口22,修改LINUX 默认的22端口