本文完整代码在 https://github.com/jiemojiemo/Tensorflow_Demo/blob/master/image_loader.ipynb

Tensorflow图像处理以及数据读取

本人研究的方向是图像处理,这个领域几乎被深度学习的方法给统治了,例如图像去燥、图像超分辨、图像增强等等。在模拟实现相关论文的方法时,我发现最难的部分并不是深度学习的网络,而是如何构建你的训练集。通常,为了构建一个可训练的数据库我需要做:

  1. 上网找到论文提到的图像数据库,或者自己收集图像
  2. 对图像进行处理,构建训练所需的数据库,处理的方式各种各样,包括

    • 图像颜色域的变化,例如RGB转HSV,RGB转Gray等等
    • 图像大小的缩放,例如将不同大小的图像缩放为大小一致的图像
    • 提前图像块(image patch),就是从整张图像中,扣下小块(例如32*32)的小图像,这么做主要是因为可以增加训练数据的量,另外小块的图像训练起来速度更快,image patch的合理性是基于卷积神经网络的感受野(Receptive field)通常不会太大
    • 数据增强(Data augmentation),例如将图像上下翻转,左右翻转,裁剪,旋转等等。这里有一篇Keras-5 基于 ImageDataGenerator 的 Data Augmentation实现可以让大家大致明白什么是Data augmentation
  3. 为了训练,给深度网络喂数据,我还需要写batch generator,就是用来生成一个batch的东西

一般,我们对整个训练过程有两种方案
1. 构建数据库的部分是独立,也就是说我们对找到的图像做预处理,将预处理的结果保存起来,这就算构建好训练的数据库了,然后训练时从这数据库里直接拿数据进行训练
2. 训练时实时地预处理一个batch的图像,将处理的结果作为训练的输入

第一种方法将训练集的构建和网络的训练分开,并且将预处理结果存在电脑中,这样做训练的代码会比较简单,且直接读入处理好的数据能让训练速度更快,当然,不足的地方就是不够灵活,如果预处理的方式改变了(例如,原本是RGB转HSV,现在我要RGB转Gray),那么需要重新构建一个数据库,造成硬盘空间的浪费

第二种方法虽然训练速度不如第一种,但是足够灵活,我们主要关注第二种方法。

在TensorFlow中,图像处理主要由tf.image模块支持,batch generator主要用tf.data.Dataset实现,下面我们来看看整个流程的具体实现

1 获取所有图片的路径

很明显,如果训练集很大,图片很多,我们无法一次读取所有图片进行训练,因此我们先找到所有图片的路径,在需要读取图片时再根据路径读取图片

import glob
# images_dir 下存放着需要预处理的图像
images_dir = '/home/public/butterfly/dataset_detection/JPEGImages/'# 查找图片文件, 根据具体数据集自由添加各种图片格式(jpg, jpeg, png, bmp等等)
images_paths = glob.glob(images_dir+'*.jpg')
images_paths += glob.glob(images_dir+'*.jpeg')
images_paths += glob.glob(images_dir+'*.png')
print('Find {} images, the first 10 image paths are:'.format(len(images_paths)))
for path in images_paths[:10]:print(path)
Find 717 images, the first 10 image paths are:
/home/public/butterfly/dataset_detection/JPEGImages/IMG_001000.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000969.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000805.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000158.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_001017.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_001155.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_001404.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000202.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000568.jpg
/home/public/butterfly/dataset_detection/JPEGImages/IMG_000022.jpg
import numpy as np
# split training set and test data
test_split_factor = 0.2
n_test_path = int(len(images_paths)*test_split_factor)
# 转出numpy数据,方便使用
train_image_paths = np.asarray(images_paths[:-n_test_path])
test_image_paths = np.asarray(images_paths[-n_test_path:])
print('Number of train set is {}'.format(train_image_paths.shape[0]))
print('Number of test set is {}'.format(test_image_paths.shape[0]))
Number of train set is 574
Number of test set is 143

2. Batch Generator

我们将使用tf.data.Dataset来实现batch generator,这里借鉴了一篇博客 TensorFlow全新的数据读取方式:Dataset API入门教程。我们直接上代码,具体解释请看注释

def gaussian_noise_layer(input_image, std):noise = tf.random_normal(shape=tf.shape(input_image), mean=0.0, stddev=std, dtype=tf.float32)noise_image = tf.cast(input_image, tf.float32) + noisenoise_image = tf.clip_by_value(noise_image, 0, 1.0)return noise_imagedef parse_data(filename):'''导入数据,进行预处理,输出两张图像,分别是输入图像和目标图像(例如,在图像去噪中,输入的是一张带噪声图像,目标图像是无噪声图像)Args:filaneme, 图片的路径Returns:输入图像,目标图像'''# 读取图像image = tf.read_file(filename)# 解码图片image = tf.image.decode_image(image)# 数据预处理,或者数据增强,这一步根据需要自由发挥# 随机提取patchimage = tf.random_crop(image, size=(100,100, 3))# 数据增强,随机水平翻转图像image = tf.image.random_flip_left_right(image)# 图像归一化image = tf.cast(image, tf.float32) / 255.0# 加噪声n_image =gaussian_noise_layer(image, 0.5)return n_image, image
def train_generator(batchsize, shuffle=True):'''生成器,用于生产训练数据Args:batchsize,训练的batch sizeshuffle, 是否随机打乱batchReturns:训练需要的数据'''with tf.Session() as sess:# 创建数据库train_dataset = tf.data.Dataset().from_tensor_slices((train_image_paths))# 预处理数据train_dataset = train_dataset.map(parse_data)# 设置 batch sizetrain_dataset = train_dataset.batch(batchsize)# 无限重复数据train_dataset = train_dataset.repeat()# 洗牌,打乱if shuffle:train_dataset = train_dataset.shuffle(buffer_size=4)# 创建迭代器train_iterator = train_dataset.make_initializable_iterator()sess.run(train_iterator.initializer)train_batch = train_iterator.get_next()# 开始生成数据while True:try:x_batch, y_batch = sess.run(train_batch)yield (x_batch, y_batch)except:# 如果没有  train_dataset = train_dataset.repeat()# 数据遍历完就到end了,就会抛出异常train_iterator = train_dataset.make_initializable_iterator()sess.run(train_iterator.initializer)train_batch = train_iterator.get_next()x_batch, y_batch = sess.run(train_batch)yield (x_batch, y_batch)
import matplotlib.pyplot as plt
%matplotlib inline
#%config InlineBackend.figure_format='retina'# 显示图像
def view_samples(samples, nrows, ncols, figsize=(5,5)):fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, sharey=True, sharex=True)for ax, img in zip(axes.flatten(), samples):ax.axis('off')ax.set_adjustable('box-forced')im = ax.imshow(img, aspect='equal')plt.subplots_adjust(wspace=0, hspace=0)plt.show()return fig, axes
# 测试一下我们的代码
train_gen = train_generator(16)iteration = 5
for i in range(iteration): noise_x, x = next(train_gen)_ = view_samples(noise_x, 4,4)_ = view_samples(x, 4, 4)

总结

TensorFlow提供了一整套图像预处理以及数据生成的机制,我们实现了一个简单的常用的数据处理框架,总结为三步
1. 获取所有图片的路径
2. 写好预处理的代码(parse_data)
3. 定义好数据生成器
基于以上的流程,稍微加以修改就能够应对大部分训练要求

Tensorflow图像处理以及数据读取相关推荐

  1. TensorFlow全新的数据读取方式:Dataset API入门教程

    Dataset API是TensorFlow 1.3版本中引入的一个新的模块,主要服务于数据读取,构建输入数据的pipeline. 此前,在TensorFlow中读取数据一般有两种方法: 1.使用pl ...

  2. tensorflow随笔-文件数据读取

    # -*- coding:utf-8 -*- import tensorflow as tf fn_queue=tf.train.string_input_producer(["winequ ...

  3. 『TensorFlow』数据读取类_data.Dataset

    一.资料 参考原文: TensorFlow全新的数据读取方式:Dataset API入门教程 API接口简介: TensorFlow的数据集 二.背景 注意,在TensorFlow 1.3中,Data ...

  4. linux 读取大量图片 内存,10 张图帮你搞定 TensorFlow 数据读取机制

    导读 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解 ...

  5. tensorflow 1.0 学习:十图详解tensorflow数据读取机制

    本文转自:https://zhuanlan.zhihu.com/p/27238630 在学习tensorflow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找 ...

  6. 十图详解TensorFlow数据读取机制(附代码)

    在学习TensorFlow的过程中,有很多小伙伴反映读取数据这一块很难理解.确实这一块官方的教程比较简略,网上也找不到什么合适的学习材料.今天这篇文章就以图片的形式,用最简单的语言,为大家详细解释一下 ...

  7. TF学习——TF数据读取:TensorFlow中数据读这三张图片的5个epoch +把读取的结果重新存到read 文件夹中

    TF学习--TF数据读取:TensorFlow中数据读这三张图片的5个epoch +把读取的结果重新存到read 文件夹中 目录 实验展示 代码实现 实验展示 代码实现 1.如果设置shuffle为T ...

  8. TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制

    TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和tf.data.Dataset机制 之前写了一篇博客,关于<Tensorflow生成自己的 ...

  9. tensorflow数据读取机制

    原博客地址:https://zhuanlan.zhihu.com/p/27238630 代码地址:https://github.com/hzy46/Deep-Learning-21-Examples/ ...

最新文章

  1. Bresenham 生成直线
  2. python快速排序函数_python算法-快速排序
  3. 条款一:尽量使用const、inline而不是#define
  4. 如何通过adb命令将apk包导入到安卓设备上
  5. Android Webview 设置Cookie问题
  6. 人的声音是可以通过训练而改变的吗?
  7. excel转置怎么操作_EXCEL转置的方法介绍,这种函数80%的人没用过,教你如何转置996...
  8. 如何快速进行十进制二进制转换
  9. html5 答题源码脚本,自动答题脚本教程及源码分享(无视分辨率)
  10. 关于amd cpu超频 个人心得
  11. elementui实现横向时间轴_element ui step组件在另一侧加时间轴显示
  12. 新手网站制作教程:网站建设流程及步骤有哪些?
  13. vue 中父子组件传递通信(看图就会了,皮卡皮卡)
  14. APP代码打包成apk文件
  15. uo和o的区别和用法_韵母o和uo的区别
  16. 移动互联网时代电商如何突围?
  17. 16.火星文转换 C#
  18. 网络规划综合实验(思科模拟器)
  19. Flink SqlServer CDC 连接器的使用
  20. 2021年,定个小目标,排名1w以内

热门文章

  1. cocos creator粒子不变色_隐秘的物理粒子系统与渲染 !Cocos Creator LiquidFun !
  2. 计算机专业笔记本需要小键盘吗,笔记本电脑小键盘数字0不能用
  3. markdown 代码块背景色_markdown 绘图利器之graphviz
  4. 【论文阅读】基于未知传播模型的信息源检测问题 2017年AAAI国际先进人工智能协会
  5. 有什么办法可以判断页面是静态还是动态?_你知道seo到底是什么吗?该怎么优化?...
  6. 网页数据分页显示php,PHP网页设计例子:用PHP3完成MySQL数据的分页显示
  7. 动态二维数组外圈元素值的和_C语言 | 用指向元素的指针变量输出二维数组元素的值...
  8. redhat 6.4 mysql_redhat6.4 安装 MySQL 5.6.27
  9. python打开文件报错无效序列_解决Python 写文件报错TypeError的问题
  10. c语言结构体在内存中的存储,C语言结构体在内存中的存储情况探究------内存对齐...