在《提高模型性能,你可以尝试这几招...》一文中,我们给出了几种提高模型性能的方法,但这篇文章是在训练数据集不变的前提下提出的优化方案。其实对于深度学习而言,数据量的多寡通常对模型性能的影响更大,所以扩充数据规模一般情况是一个非常有效的方法。

对于Google、Facebook来说,收集几百万张图片,训练超大规模的深度学习模型,自然不在话下。但是对于个人或者小型企业而言,收集现实世界的数据,特别是带标签的数据,将是一件非常费时费力的事。本文探讨一种技术,在现有数据集的基础上,进行数据增强(data augmentation),增加参与模型训练的数据量,从而提升模型的性能。

什么是数据增强

所谓数据增强,就是采用在原有数据上随机增加抖动和扰动,从而生成新的训练样本,新样本的标签和原始数据相同。这个也很好理解,对于一张标签为“狗”的图片,做一定的模糊、裁剪、变形等处理,并不会改变这张图片的类别。数据增强也不仅局限于图片分类应用,比如有如下图所示的数据,数据满足正态分布:

我们在数据集的基础上,增加一些扰动处理,数据分布如下:

数据就在原来的基础上增加了几倍,但整体上仍然满足正态分布。有人可能会说,这样的出来的模型不是没有原来精确了吗?考虑到现实世界的复杂性,我们采集到的数据很难完全满足正态分布,所以这样增加数据扰动,不仅不会降低模型的精确度,然而增强了泛化能力。

对于图片数据而言,能够做的数据增强的方法有很多,通常的方法是:

  • 平移
  • 旋转
  • 缩放
  • 裁剪
  • 切变(shearing)
  • 水平/垂直翻转
  • ...

上面几种方法,可能切变(shearing)比较难以理解,看一张图就明白了:

我们要亲自编写这些数据增强算法吗?通常不需要,比如keras就提供了批量处理图片变形的方法。

keras中的数据增强方法

keras中提供了ImageDataGenerator类,其构造方法如下:

ImageDataGenerator(featurewise_center=False,samplewise_center=False,featurewise_std_normalization = False,samplewise_std_normalization = False,zca_whitening = False,rotation_range = 0.,width_shift_range = 0.,height_shift_range = 0.,shear_range = 0.,zoom_range = 0.,channel_shift_range = 0.,fill_mode = 'nearest',cval = 0.0,horizontal_flip = False,vertical_flip = False,rescale = None,preprocessing_function = None,data_format = K.image_data_format(),
)
复制代码

参数很多,常用的参数有:

  • rotation_range: 控制随机的度数范围旋转。
  • width_shift_range和height_shift_range: 分别用于水平和垂直移位。
  • zoom_range: 根据[1 - zoom_range,1 + zoom_range]范围均匀将图像“放大”或“缩小”。
  • horizontal_flip:控制是否水平翻转。

完整的参数说明请参考keras文档。

下面一段代码将1张给定的图片扩充为10张,当然你还可以扩充更多:

image = load_img(args["image"])
image = img_to_array(image)
image = np.expand_dims(image, axis=0)aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1, height_shift_range=0.1,shear_range=0.2, zoom_range=0.2, horizontal_flip=True, fill_mode="nearest")aug.fit(image)imageGen = aug.flow(image, batch_size=1, save_to_dir=args["output"], save_prefix=args["prefix"],save_format="jpeg")total = 0
for image in imageGen:# increment out countertotal += 1if total == 10:break
复制代码

需要指出的是,上述代码的最后一个迭代是必须的,否在不会在output目录下生成图片,另外output目录必须存在,否则会出现一下错误:

Traceback (most recent call last):File "augmentation_demo.py", line 35, in <module>for image in imageGen:File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1526, in __next__return self.next(*args, **kwargs)File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1704, in nextreturn self._get_batches_of_transformed_samples(index_array)File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/keras_preprocessing/image.py", line 1681, in _get_batches_of_transformed_samplesimg.save(os.path.join(self.save_to_dir, fname))File "/data/ai/anaconda3/envs/keras/lib/python3.6/site-packages/PIL/Image.py", line 1947, in savefp = builtins.open(filename, "w+b")
FileNotFoundError: [Errno 2] No such file or directory: 'output/image_0_1091.jpeg'
复制代码

如下一张狗狗的图片:

经过数据增强技术处理之后,可以得到如下10张形态稍微不同的狗狗的图片,这相当于在原有数据集上增加了10倍的数据,其实我们还可以扩充得最多:

数据增强之后的比较

我们以MiniVGGNet模型为例,说明在其在17flowers数据集上进行训练的效果。17flowers是一个非常小的数据集,包含17中品类的花卉图案,每个品类包含80张图片,这对于深度学习而言,数据量实在是太小了。一般而言,要让深度学习模型有一定的精确度,每个类别的图片至少需要1000~5000张。这样的数据集可以很好的说明数据增强技术的必要性。

从网站上下载的17flowers数据,所有的图片都放在一个目录下,而我们通常训练时的目录结构为:

{类别名}/{图片文件}
复制代码

为此我写了一个organize_flowers17.py脚本。

在没有使用数据增强的情况下,在训练数据集和验证数据集上精度、损失随着训练轮次的变化曲线图:

可以看到,大约经过十几轮的训练,在训练数据集上的准确率很快就达到了接近100%,然而在验证数据集上的准确率却无法再上升,只能达到60%左右。这个图可以明显的看出模型出现了非常严重的过拟合。

如果采用数据增强技术呢?曲线图如下:

从图中可以看到,虽然在训练数据集上的准确率有所下降,但在验证数据集上的准确率有比较明显的提升,说明模型的泛化能力有所增强。

也许在我们看来,准确率从60%多增加到70%,只有10%的提升,并不是什么了不得的成绩。但要考虑到我们采用的数据集样本数量实在是太少,能够达到这样的提升已经是非常难得,在实际项目中,有时为了提升1%的准确率,都会花费不少的功夫。

总结

数据增强技术在一定程度上能够提高模型的泛化能力,减少过拟合,但在实际中,我们如果能够收集到更多真实的数据,还是要尽量使用真实数据。另外,数据增强只需应用于训练数据集,验证集上则不需要,毕竟我们希望在验证集上测试真实数据的准确。

以上实例均有完整的代码,点击阅读原文,跳转到我在github上建的示例代码。

另外,我在阅读《Deep Learning for Computer Vision with Python》这本书,在微信公众号后台回复“计算机视觉”关键字,可以免费下载这本书的电子版。

参考阅读

提高模型性能,你可以尝试这几招...

计算机视觉与深度学习,看这本书就够了

使用数据增强技术提升模型泛化能力相关推荐

  1. 【深度学习基础知识 - 25】提升模型泛化能力的方法

    提升模型泛化能力的方法 从数据角度上来说.可以通过数据增强.扩充训练集等方法提高泛化能力. 在训练策略上,可以增加每个batch size的大小,进而让模型每次迭代时见到更多数据,防止过拟合. 调整数 ...

  2. 深度学习——提升模型泛化能力的方法

    泛化能力指对同类型独立分布的新数据的预测结果是否符合我们的预期.我们常常用泛化能力来反应一个模型的好坏,将不同程度的泛化状态分为:欠拟合.拟合和过拟合. 提升模型的泛化能力可以从两个方面着手:数据集和 ...

  3. BatchFormer:有效提升数据稀缺场景的模型泛化能力|CVPR2022

    文 | 侯志@知乎(已授权) 源 | 极市平台 摘要 当前的深度神经网络尽管已经取得了巨大的成功,但仍然面临着来自于数据稀缺的各种挑战,比如数据不平衡,零样本分布,域适应等等. 当前已经有各种方法通过 ...

  4. 如何提升模型泛化能力

    2.正则化模型 3.增加模型深度 4.使用更多的数据.数据增强 5.提早结束训练 6.Droupout 7.Batch Normalize 希望 评论区增加

  5. 机器学习中模型泛化能力和过拟合现象(overfitting)的矛盾、以及其主要缓解方法正则化技术原理初探...

    1. 偏差与方差 - 机器学习算法泛化性能分析 在一个项目中,我们通过设计和训练得到了一个model,该model的泛化可能很好,也可能不尽如人意,其背后的决定因素是什么呢?或者说我们可以从哪些方面去 ...

  6. AutoAugment: Learning Augmentation Policies from Data(一种自动数据增强技术)

    谷歌大脑提出自动数据增强方法AutoAugment:可迁移至不同数据集 近日,来自谷歌大脑的研究者在 arXiv 上发表论文,提出一种自动搜索合适数据增强策略的方法 AutoAugment,该方法创建 ...

  7. gan 总结 数据增强_[NLP]聊一聊,预处理和数据增强技术

    在基于margin-loss的句子相似度这个项目中,为了验证想法,找不到开放数据集,因此自己从新浪爱问爬取了数据.自己爬的数据和学界开放的数据对比,数据显得非常脏.这里有三个含义:第一:数据不规范,比 ...

  8. 从难以普及的数据增强技术,看AI的性价比时代

    数据是AI训练的核心,这一点已经被确认再确认了.虽然数据驱动不是AI算法训练的唯一途径,但在产业中已经出现了很明显的趋势,那些数据丰富廉价的领域,就是会更容易孕育出AI技术.像是汉英之间的机器翻译能力 ...

  9. NLP中的数据增强技术综述

    NLP数据增强技术 1. 词汇替换 Lexical Substitution 基于词典的替换 Thesaurus-based substitution 基于词向量的替换 Word-Embeddings ...

最新文章

  1. JZOJ 5483. 【清华集训2017模拟11.26】简单路径
  2. jdk1.8 idea 项目报错spring验证不通过
  3. ASP.NET Core中HTTP管道和中间件的二三事
  4. equals null报错吗_轻轻松松教你搞定Java中的==和equals
  5. 局部遮荫光伏matlab,一种基于随机蛙跳全局搜索算法的局部阴影光伏阵列MPPT控制的制作方法...
  6. php oracle count,请教分析函数count
  7. jq post 表单提交文件_Power Query 中使用POST方法进行网络抓取的尝试
  8. ADO.NET调用存储过程
  9. openwrt刷回原厂固件_小米路由器4刷breed, pandavan,openwrt
  10. IE-LAB网络实验室:HCNP培训机构 HCIE培训中心 HCIE认证培训 HCNA培训 华为面试考试时需要注意什么
  11. 「 数学模型 」“灰色模型的研究步骤及五步建模思想”讲解
  12. 计算机网络培训心得PPT,ppt培训心得体会(精选3篇)
  13. iOS开发面试只需知道这些,技术基本通关!(网络篇)
  14. Englis - 英文字母和音标
  15. Stream流中常用的方法
  16. 阿童木录机固态硬盘MOV视频损坏修复
  17. pgpool-ii的安装与使用
  18. php对参数校验(名称、地址、掩码、日期、时间、端口)
  19. 获阿里云领投的数千万A轮融资,剑指混合云的ZStack还有更大的野心
  20. 【一起学UniGUI】--UniGUI的窗体和模块(7)

热门文章

  1. xxx.jar 中没有主清单属性
  2. C#用Zlib压缩或解压缩字节数组
  3. The mook jong 计数DP
  4. poj 2696 A Mysterious Function
  5. 承载辉煌历史 畅想无线未来
  6. 根文件系统构建(Buildroot 方式)
  7. 计算机专业运动会口号,运动会口号押韵有气势 计算机系霸气口号
  8. swig封装 c语言函数到python库,python swig 调用C/C++接口
  9. Integer对象范围(-128-127)之间(Integer. valueOf()方法)
  10. C语言二月天数计算,关于计算两个日期间天数的代码,大家来看看