利用keras进行图像增强
在深度学习中,数据短缺是我们经常面临的一个问题,虽然现在有不少公开数据集,但跟大公司掌握的海量数据集相比,数量上仍然偏少,而某些特定领域的数据采集更是非常困难。根据之前的学习可知,数据量少带来的最直接影响就是过拟合。那有没有办法在现有少量数据基础上,降低或解决过拟合问题呢?
答案是有的,就是数据增强技术。我们可以对现有的数据,如图片数据进行平移、翻转、旋转、缩放、亮度增强等操作,以生成新的图片来参与训练或测试。这种操作可以将图片数量提升数倍,由此大大降低了过拟合的可能。本文将详解图像增强技术在Keras中的原理和应用。
https://blog.csdn.net/jacke121/article/details/79245732
相关参数描述:http://keras-cn.readthedocs.io/en/latest/preprocessing/image/
其中validation_split参数(官方上使用方法未描述):设置训练集与验证集的比例。
要与flow_from_directory或flow函数配合。在函数中subset参数中设置为'training' 或者 'validation',生成对应的数据集。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
|
一、Keras中的ImageDataGenerator类
图像增强的官网地址是:https://keras.io/preprocessing/image/ ,API使用相对简单,功能也很强大。
先介绍的是ImageDataGenerator类,这个类定义了图片该如何进行增强操作,其API及参数定义如下:
keras.preprocessing.image.ImageDataGenerator(featurewise_center=False, #输入值按照均值为0进行处理samplewise_center=False, #每个样本的均值按0处理featurewise_std_normalization=False, #输入值按照标准正态化处理samplewise_std_normalization=False, #每个样本按照标准正态化处理 zca_whitening=False, # 是否开启增白zca_epsilon=1e-06, rotation_range=0, #图像随机旋转一定角度,最大旋转角度为设定值width_shift_range=0.0, #图像随机水平平移,最大平移值为设定值。若值为小于1的float值,则可认为是按比例平移,若大于1,则平移的是像素;若值为整型,平移的也是像素;假设像素为2.0,则移动范围为[-1,1]之间height_shift_range=0.0, #图像随机垂直平移,同上brightness_range=None, # 图像随机亮度增强,给定一个含两个float值的list,亮度值取自上下限值间shear_range=0.0, # 图像随机修剪zoom_range=0.0, # 图像随机变焦 channel_shift_range=0.0, fill_mode='nearest', #填充模式,默认为最近原则,比如一张图片向右平移,那么最左侧部分会被临近的图案覆盖cval=0.0, horizontal_flip=False, #图像随机水平翻转vertical_flip=False, #图像随机垂直翻转rescale=None, #缩放尺寸preprocessing_function=None, data_format=None, validation_split=0.0, dtype=None)
下文将以mnist和花类的数据集进行图片操作,其中花类(17种花,共1360张图片)数据集可见我的百度网盘: https://pan.baidu.com/s/1YDA_VOBlJSQEijcCoGC60w 。让我们以直观地方式看看各参数能带来什么样的图片变化。
随机旋转
我们可用mnist数据集对图片进行随机旋转,旋转的最大角度由参数定义。
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as KK.set_image_dim_ordering('th')(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype('float32')# 创建图像生成器,指定对图像操作的内容
datagen = ImageDataGenerator(rotation_range=90)
# 图像生成器要训练的数据
datagen.fit(train_data)# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):for i in range(0, 9):# 创建一个 3*3的九宫格,以显示图片pyplot.subplot(330 + 1 + i)pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))pyplot.show()break
生成结果为:
随机平移
我们可用花类数据集对图片进行随机平移,可以在垂直和水平方向上平移,平移最大值由参数定义。
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras.preprocessing.image import array_to_imgIMAGE_SIZE = 224
NUM_CLASSES = 17
TRAIN_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/train'
TEST_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/test'
FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy','Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower','Tigerlily', 'tulip', 'WindFlower']# 创建图像生成器,指定对图像操作的内容,平移的最大比例为50%
train_datagen = ImageDataGenerator(width_shift_range=0.5, height_shift_range=0.5)# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for X_batch, y_batch in train_datagen.flow_from_directory(directory=TRAIN_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=9, classes=FLOWER_CLASSES):for i in range(0, 9):pyplot.subplot(330 + 1 + i)pyplot.imshow(array_to_img(X_batch[i]))pyplot.show()break
生成结果为:
可以观察到,图片除了实现平移外,其原来的位置都被最近的图案给填充,因为默认给的填充方式是nearest。
随机亮度调整
我们可用花类数据集对图片进行随机亮度调整,亮度范围由参数定义。
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras.preprocessing.image import array_to_imgIMAGE_SIZE = 224
NUM_CLASSES = 17
TRAIN_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/train'
TEST_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/test'
FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy','Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower','Tigerlily', 'tulip', 'WindFlower']# 创建图像生成器,指定对图像操作的内容,亮度范围在0.1~10之间随机选择
train_datagen = ImageDataGenerator(brightness_range=[0.1, 10])# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for X_batch, y_batch in train_datagen.flow_from_directory(directory=TRAIN_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=9, classes=FLOWER_CLASSES):for i in range(0, 9):pyplot.subplot(330 + 1 + i)pyplot.imshow(array_to_img(X_batch[i]))pyplot.show()break
生成结果为:
随机焦距调整
我们可用mnist数据集对图片进行随机焦距调整,焦距调整值由参数定义。
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as KK.set_image_dim_ordering('th')(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype('float32')# 创建图像生成器,指定对图像操作的内容,焦距值在0.1~1之间
datagen = ImageDataGenerator(zoom_range=[0.1, 1])
# 图像生成器要训练的数据
datagen.fit(train_data)# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):for i in range(0, 9):# 创建一个 3*3的九宫格,以显示图片pyplot.subplot(330 + 1 + i)pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))pyplot.show()break
生成结果为:
可以看出这跟相机调焦一样,可以放大或缩小焦距。
随机翻转
我们可用花类数据集对图片进行随机翻转。
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras.preprocessing.image import array_to_imgIMAGE_SIZE = 224
NUM_CLASSES = 17
TRAIN_PATH = '/home/hutao/Documents/tensorflow/images/17flowerclasses/train'
TEST_PATH = '/home/hutao/Documents/tensorflow/images/17flowerclasses/test'
FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy','Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower','Tigerlily', 'tulip', 'WindFlower']# 创建图像生成器,指定对图像操作的内容,图片随机翻转
train_datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True)# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for X_batch, y_batch in train_datagen.flow_from_directory(directory=TRAIN_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE),batch_size=9, classes=FLOWER_CLASSES):for i in range(0, 9):pyplot.subplot(330 + 1 + i)pyplot.imshow(array_to_img(X_batch[i]))pyplot.show()break
生成结果为:
从上图可看出,有些图片水平翻转了,有些是垂直翻转了。
ZCA图像增白
说实在我不太清楚该技术有何用,用花类图片实验结果显示zca不支持,可以用mnist数据集来看看效果。
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as KK.set_image_dim_ordering('th')(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype('float32')# 创建图像生成器,指定对图像操作的内容,增白图片
datagen = ImageDataGenerator(zca_whitening=True)
# 图像生成器要训练的数据
datagen.fit(train_data)# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):for i in range(0, 9):# 创建一个 3*3的九宫格,以显示图片pyplot.subplot(330 + 1 + i)pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))pyplot.show()break
生成结果为:
特征标准化
特征标准化的含义是使图片的像素均值为0,标准差为1,不过我试了多次,直观效果不明显。
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import backend as KK.set_image_dim_ordering('th')(train_data, train_label), (test_data, test_label) = mnist.load_data()
train_data = train_data.reshape(train_data.shape[0], 1, 28, 28)
train_data = train_data.astype('float32')# 创建图像生成器,指定对图像操作的内容,允许图片标准化处理
datagen = ImageDataGenerator(featurewise_center=True, featurewise_std_normalization=True)
# 图像生成器要训练的数据
datagen.fit(train_data)# 这是个图像生成迭代器,是可以无限生成各种新图片,我们指定每轮迭代只生成9张图片
for batch_data, batch_label in datagen.flow(train_data, train_label, batch_size=9):for i in range(0, 9):# 创建一个 3*3的九宫格,以显示图片pyplot.subplot(330 + 1 + i)pyplot.imshow(batch_data[i].reshape(28, 28), cmap=pyplot.get_cmap('gray'))pyplot.show()break
生成结果为:
就个人而言,我倾向于在图像增强中使用旋转、亮度调整、翻转和平移操作。
二、Keras如何进行图像增强数据训练
在之前的文章中我已经展现过数据增强的使用。在Keras中,增强图片有三种来源:
- 图片来源于已知数据集,如mnist、cifar,数据格式为numpy格式;
- 图片来源于我们自己搜集的图片,如本文引入的花类数据集,其图片为jpg、png等格式;
- 图片来源于panda数据集;
其中数据来源已知数据集,其操作方法如下:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)datagen = ImageDataGenerator(featurewise_center=True,featurewise_std_normalization=True,rotation_range=20,width_shift_range=0.2,height_shift_range=0.2,horizontal_flip=True)#生成器绑定训练集
datagen.fit(x_train)# 模型绑定生成器,并不停地迭代产生数据,可指定迭代次数,假设图片总数为1000张,batch默认为32,则每次迭代需要产生1000/32=32个步骤
history = model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),steps_per_epoch=len(x_train) / 32, epochs=epochs)
数据来源图片集,其操作方法如下:
batch_size = 32
# 迭代50次
epochs = 50
# 依照模型规定,图片大小被设定为224
IMAGE_SIZE = 224
TRAIN_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/train'
TEST_PATH = '/home/yourname/Documents/tensorflow/images/17flowerclasses/test'
FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy','Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower','Tigerlily', 'tulip', 'WindFlower']# 使用数据增强
train_datagen = ImageDataGenerator(rotation_range=90)
# 可指定输出图片大小,因为深度学习要求训练图片大小保持一致
train_generator = train_datagen.flow_from_directory(directory=TRAIN_PATH,target_size=(IMAGE_SIZE, IMAGE_SIZE),classes=FLOWER_CLASSES)
test_datagen = ImageDataGenerator()
test_generator = test_datagen.flow_from_directory(directory=TEST_PATH, target_size=(IMAGE_SIZE, IMAGE_SIZE),classes=FLOWER_CLASSES)# 运行模型
history = model.fit_generator(train_generator, epochs=epochs, validation_data=test_generator)
需要说明的是,这些增强图片都是在内存中实时批量迭代生成的,不是一次性被读入内存,这样可以极大地节约内存空间,加快处理速度。若想保留中间过程生成的增强图片,可以在上述方法中添加保存路径等参数,此处不再赘述。
三、结论
本文介绍了如何在Keras中使用图像增强技术,对图片可以进行各种操作,以生成数倍于原图片的增强图片集。这些数据集可帮助我们有效地对抗过拟合问题,更好地生成理想的模型。
https://www.cnblogs.com/hutao722/p/10075150.html
https://keras-cn.readthedocs.io/en/latest/preprocessing/image/#imagedatagenerator
https://www.colabug.com/3426054.html
https://blog.csdn.net/weixin_43790591/article/details/84455226
https://blog.csdn.net/qq_29133371/article/details/54927266
深度学习图片分类增强数据集的方法汇总
1.随机切割,图片翻转,旋转,等等很多手段都可以增加训练集,提高泛化能力.
2. Resampling 或者增加噪声等等,人工合成更多的样本.
3.对小样本数据进行仿射变换、切割、旋转、加噪等各种处理,可以生成更多样本.
4.用GAN生成数据提供给数据集.
5.找个Imagenet数据集上训练好的的模型,冻结最后一层或者最后几层,然后迁移学习+fine tuning,图片数量少,做一些翻转,变化,剪切,白化等等.
6.
第一种思路是数据增强,也就是用随机应对随机。既然狗子的位置在照片中不固定,那就将原始的图片随机的裁剪一下,旋转一下,将图像的颜色做一些微调,总之就是想象一个熊孩子打开ps修改了每张狗子的照片,给你留下了一堆看起来和原始的训练数据差不多的照片作为新的训练集
7.
水平翻转Flip
随机裁剪、平移变换Crops/Scales
颜色、光照变换
最为常用的是:像素颜色抖动、旋转、剪切、随机裁剪、水平翻转、镜头拉伸和镜头矫正等。
利用keras进行图像增强相关推荐
- Keras .ImageDataGenerator图像增强用法大全以及如何和模型结合起来(有代码)
图像增强 在做图像任务时,我们常常需要图像增强.今天来讲解下keras中的图像增强 ImageDataGenerator 官网 https://keras.io/api/preprocessing/i ...
- python训练好的图片验证_利用keras加载训练好的.H5文件,并实现预测图片
我就废话不多说了,直接上代码吧! import matplotlib matplotlib.use('Agg') import os from keras.models import load_mod ...
- 利用Keras构建自动编码器
利用Keras构建自动编码器 我们在这份学习指南中将回答有关自动编码器的一些常见问题,除此之外,我们也会给出下述模型的代码示例: 基于全连接层的简单自动编码器 稀疏自动编码器 深度全连接自动编码器 深 ...
- python 多分类模型优化_【Python与机器学习】:利用Keras进行多类分类
多类分类问题本质上可以分解为多个二分类问题,而解决二分类问题的方法有很多.这里我们利用Keras机器学习框架中的ANN(artificial neural network)来解决多分类问题.这里我们采 ...
- win10下基于anaconda利用keras开展16系显卡GTX1650的GPU神经网络计算
win10下基于anaconda利用keras开展16系显卡GTX1650的GPU神经网络计算 虽然安装了双系统,但ubantu的确是不太常用,所以还是尝试一下win10下的GPU神经网络计算.从实践 ...
- 利用keras搭建CNN完成图片分类
文章目录 一.简介 二.流程 1.数据处理 2.神经网络搭建 3.训练 4.预测 三.参考 一.简介 本文旨在通过一些简单的案例,学习如何通过keras搭建CNN.从数据读取,数据处理,神经网络搭建, ...
- 利用keras搭建神经网络拟合非线性函数
神经网络有着一个非常奇妙的结构,它的数学原理虽然相对简单,但是能做的事情却不少,数学家已经证明,具有2层(输入层除外)和非线性激活函数的神经网络,只要在这些层中有足够多的神经元,就可以近似任何函数(严 ...
- 利用keras破解captcha验证码
本文参考了知乎上的一篇文章,只做了少许改动,感觉挺好玩的,自己实现了一下,准确率比原作者的要高一些.如果想要了解原创文章的话,请移步知乎:使用深度学习来破解captcha验证码 本文通过keras深度 ...
- 基于Keras的生成对抗网络(1)——利用Keras搭建简单GAN生成手写体数字
目录 0.前言 一.GAN结构 二.函数代码 2.1 生成器Generator 2.2 判别器Discriminator 2.3 train函数 三.结果演示 四.完整代码 五.常见问题汇总 0.前言 ...
最新文章
- TLSNotary中心化预言机(1) TLS1.1协议
- docker image 实践之容器化 ganglia
- php phpanalysis2.0,使用phpAnalysis打造PHP应用非侵入式性能分析器
- python opencv输出mp4_10分钟学会使用YOLO及Opencv实现目标检测
- c语言查找字符串au,几个C语言词汇不懂,望老鸟们相助(俺是新手哦)
- 基于法律罪行知识图谱的智能预判与客服问答
- python 姓名用*替换_学会用python截取你的姓名
- Qt|设计模式工作笔记-对单例模式进一步的理解(静态加单例实现专门收发UDP对象)
- iphone查看html源码的app,使用扩展App在Safari上查看源代码
- Java IO _打印流
- 川大计算机学梡分数线,2017四川大学历年录取分数线
- JavaScript简单的数据总计怎么做?
- JZOJ2020年8月11日提高组T3 页
- Cookie的路径设置(很重要)
- 学计算机买电脑看什么,学长学姐很后悔,当初买电脑时就该看看这篇攻略!
- 渗透测试业务逻辑测试汇总—专项篇
- 【nginx http flv 】ATC追踪:播放器拉流的调用堆栈及时间戳打印1
- 教育界杂志教育界杂志社教育界编辑部2022年第10期目录
- C#之FIFO算法实现页面置换算法
- Cache;高速缓冲存储器