Image Data Augmentation In Keras

讨论的内容包括

  • Data Augmentation
  • ImageDataGenerator 的使用方法
  • 在cifar-10数据集上使用Data Augmentation

完整代码在 这里 下载

Data Augmentation

Data Aumentation(数据扩充)指的是在使用以下或者其他方法增加数据输入量。这里,我们特指图像数据。

  • 旋转 | 反射变换(Rotation/reflection): 随机旋转图像一定角度; 改变图像内容的朝向;
  • 翻转变换(flip): 沿着水平或者垂直方向翻转图像;
  • 缩放变换(zoom): 按照一定的比例放大或者缩小图像;
  • 平移变换(shift): 在图像平面上对图像以一定方式进行平移;
  • 可以采用随机或人为定义的方式指定平移范围和平移步长, 沿水平或竖直方向进行平移. 改变图像内容的位置;
  • 尺度变换(scale): 对图像按照指定的尺度因子, 进行放大或缩小; 或者参照SIFT特征提取思想, 利用指定的尺度因子对图像滤波构造尺度空间. 改变图像内容的大小或模糊程度;
  • 对比度变换(contrast): 在图像的HSV颜色空间,改变饱和度S和V亮度分量,保持色调H不变. 对每个像素的S和V分量进行指数运算(指数因子在0.25到4之间), 增加光照变化;
  • 噪声扰动(noise): 对图像的每个像素RGB进行随机扰动, 常用的噪声模式是椒盐噪声和高斯噪声;

Data Aumentation 有很多好处,比如数据量太小了,我们用数据扩充来增加训练数据,或者通过Data Aumentation防止过拟合的问题。

在Keras中,ImageDataGenerator就是干这个事情的,特别方便。接下来,我们就聊聊ImageDataGenerator的使用方法

from keras.preprocessing.image import ImageDataGenerator
from keras.preprocessing import image
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
Using TensorFlow backend.

ImageDataGenerator for Single image

ImageDataGenerator 参数很多,详见这里或者在python环境下输入ImageDataGenerator?,我们先看一个例子,这个例子将对一张图片进行数据扩充

# 指定参数
# rotation_range 旋转
# width_shift_range 左右平移
# height_shift_range 上下平移
# zoom_range 随机放大或缩小
img_generator = ImageDataGenerator(rotation_range = 90,width_shift_range = 0.2,height_shift_range = 0.2,zoom_range = 0.3)
# 导入并显示图片
img_path = './imgs/dog.jpg'
img = image.load_img(img_path)
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7fd738246d30>

flow()将会返回一个生成器,这个生成器用来扩充数据,每次都会产生batch_size个样本。
因为目前我们只导入了一张图片,因此每次生成的图片都是基于这张图片而产生的,可以看到结果,旋转、位移、放大缩小,统统都有。

flow()可以将产生的图片进行保存,详见 深度学习中的Data Augmentation方法和代码实现

生成图片的过程大概是这样的,并且可以一直一直一直无限循环的生成

# 将图片转为数组
x = image.img_to_array(img)
# 扩充一个维度
x = np.expand_dims(x, axis=0)
# 生成图片
gen = img_generator.flow(x, batch_size=1)# 显示生成的图片
plt.figure()
for i in range(3):for j in range(3):x_batch = next(gen)idx = (3*i) + jplt.subplot(3, 3, idx+1)plt.imshow(x_batch[0]/256)
x_batch.shape
(1, 160, 240, 3)

ImageDataGenerator for Multiple image

单张图片的数据扩展我们已经演示完毕了,但是通常情况下,我们应该是有一个不太大的训练集需要Data Aumentation或者为了防止过拟合,总之,就是对一组数据进行Data Aumentation。这里我们以cifar-10数据库做一个演示。

我们将进行一组实验,比较训练之后的测试结果:

  • cifar-10 20%数据
  • cifar-10 20%数据 + Data Augmentation
from keras.datasets import cifar10
from keras.layers.core import Dense, Flatten, Activation, Dropout
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential
from keras.utils import np_utils
(x_train, y_train),(x_test, y_test) = cifar10.load_data()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
(50000, 32, 32, 3) (50000, 1) (10000, 32, 32, 3) (10000, 1)
def preprocess_data(x):x /= 255x -= 0.5x *= 2return x
# 预处理
x_train = x_train.astype(np.float32)
x_test = x_test.astype(np.float32)x_train = preprocess_data(x_train)
x_test = preprocess_data(x_test)# one-hot encoding
n_classes = 10
y_train = np_utils.to_categorical(y_train, n_classes)
y_test = np_utils.to_categorical(y_test, n_classes)
# 取 20% 的训练数据
x_train_part = x_train[:10000]
y_train_part = y_train[:10000]print(x_train_part.shape, y_train_part.shape)
(10000, 32, 32, 3) (10000, 10)
# 建立一个简单的卷积神经网络
def build_model():model = Sequential()model.add(Conv2D(64, (3,3), input_shape=(32,32,3)))model.add(Activation('relu'))model.add(BatchNormalization(scale=False, center=False))model.add(Conv2D(32, (3,3)))model.add(Activation('relu'))model.add(MaxPooling2D((2,2)))model.add(Dropout(0.2))model.add(BatchNormalization(scale=False, center=False))model.add(Flatten())model.add(Dense(256))model.add(Activation('relu'))model.add(Dropout(0.2))model.add(BatchNormalization())model.add(Dense(n_classes))model.add(Activation('softmax'))return model
# 训练参数
batch_size = 128
epochs = 20

cifar-10 20%数据

model = build_model()
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(x_train_part, y_train_part, epochs=epochs, batch_size=batch_size, verbose=1, validation_split=0.1)
Train on 9000 samples, validate on 1000 samples
Epoch 1/20
9000/9000 [==============================] - 8s 844us/step - loss: 1.8075 - acc: 0.4040 - val_loss: 2.9955 - val_acc: 0.1150
Epoch 2/20
9000/9000 [==============================] - 3s 343us/step - loss: 1.2029 - acc: 0.5742 - val_loss: 3.0341 - val_acc: 0.1910
Epoch 3/20
9000/9000 [==============================] - 3s 342us/step - loss: 0.9389 - acc: 0.6690 - val_loss: 2.8508 - val_acc: 0.1580
...........
...........
Epoch 18/20
9000/9000 [==============================] - 5s 597us/step - loss: 0.0668 - acc: 0.9824 - val_loss: 1.6110 - val_acc: 0.5840
Epoch 19/20
9000/9000 [==============================] - 6s 629us/step - loss: 0.0681 - acc: 0.9826 - val_loss: 1.5807 - val_acc: 0.5980
Epoch 20/20
9000/9000 [==============================] - 5s 607us/step - loss: 0.0597 - acc: 0.9847 - val_loss: 1.6222 - val_acc: 0.5930
loss, acc = model.evaluate(x_test, y_test, batch_size=32)
print('Loss: ', loss)
print('Accuracy: ', acc)
10000/10000 [==============================] - 4s 444us/step
Loss:  1.65560287151
Accuracy:  0.6058

经过20轮的训练之后,在训练集上已经有98%以上的准确率,但是在测试集上只有60%左右的准确率,可以说是过拟合了,主要原因就是训练集太小了,无法达到很好的效果。那么接下来我们试试经过Data Augmentation之后的准确率如何

cifar-10 20%数据 + Data Augmentation

在进行Data Augmentation时要注意的就是:生成的数据是有意义的。比如说对于某些医疗图像,如果进行了旋转,那么这个数据就属于采样错误,是没用的了。因此,在设置生成参数时要结合实际的情况。

# 设置生成参数
img_generator = ImageDataGenerator(rotation_range = 20,width_shift_range = 0.2,height_shift_range = 0.2,zoom_range = 0.2)

下面的代码是一种“手动”的训练方式,Progbar是进度条,用于显示训练进度。

另外一种“自动”的方法,请参考 官网给的例子 中 model.fit_generator的用法

from keras.utils import generic_utilsmodel_2 = build_model()
model_2.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])# Data Augmentation后,数据变多了,因此我们需要更的训练次数
for e in range(epochs*4):print('Epoch', e)print('Training...')progbar = generic_utils.Progbar(x_train_part.shape[0])batches = 0for x_batch, y_batch in img_generator.flow(x_train_part, y_train_part, batch_size=batch_size, shuffle=True):loss,train_acc = model_2.train_on_batch(x_batch, y_batch)batches += x_batch.shape[0]if batches > x_train_part.shape[0]:breakprogbar.add(x_batch.shape[0], values=[('train loss', loss),('train acc', train_acc)])
Epoch 0
Training...
10000/10000 [==============================] - 13s 1ms/step - train loss: 2.0455 - train acc: 0.3187
Epoch 1
Training...
10000/10000 [==============================] - 10s 1ms/step - train loss: 1.7304 - train acc: 0.3857
Epoch 2
Training...
10000/10000 [==============================] - 10s 1ms/step - train loss: 1.6195 - train acc: 0.4220
Epoch 3
Training...
10000/10000 [==============================] - 10s 1ms/step - train loss: 1.5595 - train acc: 0.4417
.........
.........
Epoch 76
Training...
10000/10000 [==============================] - 9s 874us/step - train loss: 0.8809 - train acc: 0.6890
Epoch 77
Training...
10000/10000 [==============================] - 9s 891us/step - train loss: 0.8776 - train acc: 0.6949
Epoch 78
Training...
10000/10000 [==============================] - 9s 892us/step - train loss: 0.8723 - train acc: 0.6916
Epoch 79
Training...
10000/10000 [==============================] - 9s 892us/step - train loss: 0.8737 - train acc: 0.6919
loss, acc = model_2.evaluate(x_test, y_test, batch_size=32)
print('Loss: ', loss)
print('Accuracy: ', acc)
10000/10000 [==============================] - 5s 455us/step
Loss:  0.842164948082
Accuracy:  0.7057

哇塞!经过Data Augmentation之后,测试的准去率已经有70%,提高了10%。并且我相信继续增加训练次数准确率将会继续上升。

好的,实验到此结束,这里只是给出一个简单的Data Augmentation实现方法,ImageDataGenerator这个类,还有其他有趣的功能我们还没有用到,有兴趣的同学可以在 这里 进行详细的阅读。网络上也有很多关于Data Augenmentation的讨论,希望可以帮助到大家。

  • 使用深度学习(CNN)算法进行图像识别工作时,有哪些data augmentation 的奇技淫巧?
  • The Effectiveness of Data Augmentation in Image Classification using Deep
    Learning
  • keras面向小数据集的图像分类(VGG-16基础上fine-tune)实现(附代码)

Keras-5 基于 ImageDataGenerator 的 Data Augmentation实现相关推荐

  1. Keras Data augmentation(数据扩充)

    在深度学习中,我们经常需要用到一些技巧(比如将图片进行旋转,翻转等)来进行data augmentation, 来减少过拟合. 在本文中,我们将主要介绍如何用深度学习框架keras来自动的进行data ...

  2. Keras Image Data Augmentation 各参数详解

    图像深度学习任务中,面对小数据集,我们往往需要利用Image Data Augmentation图像增广技术来扩充我们的数据集,而keras的内置ImageDataGenerator很好地帮我们实现图 ...

  3. tf torch keras 数据增强 data augmentation

    数据增强 data augmentation 2017年11月14日 22:19:27 阅读数:7964

  4. Data Augmentation

    转自:https://zhuanlan.zhihu.com/p/30197320 图像深度学习任务中,面对小数据集,我们往往需要利用Image Data Augmentation图像增广技术来扩充我们 ...

  5. 5, Data Augmentation

    Intro 这是深度学习第5课 在本课程结束时,您将能够使用数据增强. 这个技巧让你看起来拥有的数据远远超过实际拥有的数据,从而产生更好的模型. Lesson [1] from IPython.dis ...

  6. Keras学习| ImageDataGenerator的参数

    Keras ImageDataGenerator的参数 from keras.preprocessing.image import ImageDataGenerator keras.preproces ...

  7. 【方法】数据增强(Data Augmentation)

    在训练过程中,网络优化是一方面,数据集的优化又是另一方面.数据集会存在各类样本不均匀的情况,也就是各类样本的数量不一样,有的甚至差别很大.为了让模型具有更强的鲁棒性,采用Data Augmentati ...

  8. 【Keras】基于SegNet和U-Net的遥感图像语义分割

    from:[Keras]基于SegNet和U-Net的遥感图像语义分割 上两个月参加了个比赛,做的是对遥感高清图像做语义分割,美其名曰"天空之眼".这两周数据挖掘课期末projec ...

  9. Dataset之DA:数据增强(Data Augmentation)的简介、方法、案例应用之详细攻略

    Dataset之DA:数据增强(Data Augmentation)的简介.方法.案例应用之详细攻略 目录 DA的简介 DA的方法 DA的案例应用 DA的简介 数据集增强主要是为了减少网络的过拟合现象 ...

最新文章

  1. python编程语言是什么-python是什么编程语言
  2. 介绍几个好用的android自定义控件
  3. Flink从入门到精通100篇(三)-如何利用InfluxDB+Grafana搭建Flink on YARN作业监控大屏环境
  4. 李宏毅深度学习作业二
  5. 从JDK 12删除原始字符串文字
  6. dataframe修改列名_python dataframe操作大全数据预处理过程(dataframe、md5)
  7. kingbase自带的驱动在哪_为什么别人家的广告语都能自带BUG?
  8. 团队冲刺第二阶段-9
  9. python 仪表盘 ppt_Python强大的pyecharts绘画优美图形lt;三gt;
  10. 蓝桥杯2016年C/C++ 混搭
  11. java 读取excel wps_安装WPS引发的excel上传问题
  12. javaScript = == ===的区别
  13. 全面的SVM理论讲解
  14. THINKPHP获取路径
  15. sample_venc解析
  16. 快解析内网穿透,速度快 不限速 不限流
  17. linux操作系统是什么,操作系统概述
  18. 做好拼多多的几个小技巧-拼多多出评技巧
  19. 2017-2018 ACM-ICPC, Asia Daejeon Regional Contest
  20. 毕业工作五年的总结和感悟(上)

热门文章

  1. docker compose mysql_docker-compose部署MySQL
  2. webpack打包压缩混淆_细说webpack系列 3. webpack-cli 零配置打包
  3. JavaMail---简介
  4. 作为文本内容空格的HTML标签,HTML.fromHtml在文本末尾添加空格?
  5. mysql old key files_mysql出现“Incorrect key file for table”解决办法
  6. Linux中断线程化的优势,记一个实时Linux的中断线程化问题
  7. 福师计算机应用基础期末,福师2015计算机应用基础》期末试卷A123
  8. fpga烧写bin文件_Altera FPGA烧写步骤及注意事项_骏龙科技
  9. 创建此对象的程序是quation_MathType出现此对象创建于Equation中的问题怎么办
  10. php 全局变量能定义数组吗,php数组声明、遍历、数组全局变量使用小结