背景

数据和特征决定了机器学习的上限,模型和算法只是不断逼近这个上限。

无论是做比赛还是做项目,都会遇到一个问题:类别不平衡。这与 数据分布不一致所带来的影响不太一样,前者会导致你的模型在训练过程中无法拟合所有类别的数据,也就是会弄混,后者则更倾向于导致模型泛华能力减弱。

举个例子,让你从一千张狗的图中找到放进去的一只猫,你看了一遍,由于狗的特征你观察的太多了,所以很难会及时分辨出哪只是猫(请忽略人的先验知识)。

下面给出两种解决办法:

1. 数据扩充

数据不平衡,某个类别的数据量太少,那就新增一些呗,简单直接。

但是,怎么增加?如果是实际项目且能够与数据源直接或方便接触的时候,就可以直接去采集新数据。如果是比赛,那就行不通了,最好的办法就是对数据做有效增强后进行扩充。

数据增强的手段:

  • 水平 / 竖直翻转

  • 90°,180°,270° 旋转

  • 翻转 + 旋转

  • 亮度,饱和度,对比度的随机变化

  • 随机裁剪(Random Crop)

  • 随机缩放(Random Resize)

  • 加模糊(Blurring)

  • 加高斯噪声(Gaussian Noise)

以上是我在实际过程中常用一些增强手段,但是除了前三种以外,其他的要慎重考虑。因为不同的任务场景下数据特征依赖不同,比如高斯噪声,在天池铝材缺陷检测竞赛中,如果高斯噪声增加不当,有些图片原本在采集的时候相机就对焦不准,导致工件难以看清,倘若再增加高斯模糊属性,基本就废了。

以前在做处理的时候,也是瞎凑一块,暴力堆数据,但是这样很容易导致噪声过大,从而影响模型效果。后来从 刘思聪大佬的竞赛分享中得到了启发(原文链接:Kaggle 求生:亚马逊热带雨林篇),以下是一些转移理解:

以下图为例

我们做数据增强一定要保证有效性,即不能跟原始数据特征差别太大也不能直接复制,旋转和翻转其实是保证了数据特征的旋转不变性能被模型学习到。就下面一张图而言,结合旋转和翻转,做了八次增强,效果如下:

即使我做了这么多次的旋转工作,模型能从第一张图中识别出雨林和河流,那理所当然从其他角度也能识别出。

在做旋转的时候,也有一个疑问,不做 90° 倍数的旋转不行吗?做 30° 倍数的旋转,最后得到的数据岂不是更多?

个人理解是这样的:一方面考虑存储和模型训练周期的影响,增益比太小,划不来;另一方面,我让模型从这八个角度去看一张图片理论来说已经把图片的旋转特征看了一遍了,这对深度学习模型而言已经足够了。

附上做旋转的代码:

from PIL import ImageEnhance

from PIL import Image

#原图

raw_image = Image.open("./raw_images/amazon.jpg")

#旋转90°倍数

rotate_90 = raw_image.rotate(90)

rotate_180 = raw_image.rotate(180)

rotate_270 = raw_image.rotate(270)

#旋转结合翻转

flip_vertical_raw = raw_image.transpose(Image.FLIP_TOP_BOTTOM)

flip_vertical_90 = rotate_90.transpose(Image.FLIP_TOP_BOTTOM)

flip_vertical_180 = rotate_180.transpose(Image.FLIP_TOP_BOTTOM)

flip_vertical_270 = rotate_270.transpose(Image.FLIP_TOP_BOTTOM)

#存储

flip_vertical_raw.save("./processed_images/flip_vertical_raw.jpg")

flip_vertical_90.save("./processed_images/flip_vertical_90.jpg")

flip_vertical_180.save("./processed_images/flip_vertical_180.jpg")

flip_vertical_270.save("./processed_images/flip_vertical_270.jpg")

raw_image.save("./processed_images/amazon.jpg")

rotate_90.save("./processed_images/rotate_90.jpg")

rotate_180.save("./processed_images/rotate_180.jpg")

rotate_270.save("./processed_images/rotate_270.jpg")

2. sampler

2.1 采样

如果说类别之间的差距过大,有效的数据增强方式肯定不能弥补这种严重的不平衡,这个时候就需要在模型训练过程中对采样过程进行处理了。常见的采样方式分为两种:过采样和欠采样,效果图如下 (图片来源见参考文献 2):

原理就是 “删图片” 和 “增加图片”,从而保证在训练过程中类别之间的数据量大致相同。所带来的影响如下

过采样:重复正比例数据,实际上没有为模型引入更多数据,过分强调正比例数据,会放大正比例噪音对模型的影响。

欠采样:丢弃大量数据,和过采样一样会存在过拟合的问题。

但总的来肯定是利大于弊

2.2 pytorch 权重采样

pytorch 在 DataLoader () 的时候可以传入 sampler ,这里只说一下加权采样

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)

源码:

class WeightedRandomSampler(Sampler):

r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

Arguments:

weights (sequence) : a sequence of weights, not necessary summing up to one

num_samples (int): number of samples to draw

replacement (bool): if ``True``, samples are drawn with replacement.

If not, they are drawn without replacement, which means that when a

sample index is drawn for a row, it cannot be drawn again for that row.

"""

def __init__(self, weights, num_samples, replacement=True):

if not isinstance(num_samples, _int_classes) or isinstance(num_samples, bool) or \

num_samples <= 0:

raise ValueError("num_samples should be a positive integeral "

"value, but got num_samples={}".format(num_samples))

if not isinstance(replacement, bool):

raise ValueError("replacement should be a boolean value, but got "

"replacement={}".format(replacement))

self.weights = torch.tensor(weights, dtype=torch.double)

self.num_samples = num_samples

self.replacement = replacement

def __iter__(self):

return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

def __len__(self):

return self.num_samples

使用方法:

import torch

from torch.utils.data import DataLoader,WeightedRandomSampler

from dataset import train_dataset

weights = torch.FloatTensor([1,2,2,4,4,1])

train_sampler = WeightedRandomSampler(weights,len(train_dataset),replacement=True)

train_sampler = DataLoader(train_dataset,sampler=sampler)

解释:

  • weights:指每一个类别在采样过程中得到权重大小(不要求综合为 1),权重越大的样本被选中的概率越大;

  • num_samples: 共选取的样本总数,待选取的样本数目一般小于全部的样本数目;

  • replacement :指定是否可以重复选取某一个样本,默认为 True,即允许在一个 epoch 中重复采样某一个数据。如果设为 False,则当某一类的样本被全部选取完,但其样本数目仍未达到 num_samples 时,sampler 将不会再从该类中选择数据,此时可能导致 weights 参数失效。

3. 损失函数加权

还有一种方法是在计算损失函数过程中,对每个类别的损失做加权,具体的方式如下

weights = torch.FloatTensor([1,1,8,8,4])

criterion = nn.BCEWithLogitsLoss(pos_weight=weights).cuda()

4. 其他方法

暂时没用到,如果有大佬有更好的办法,欢迎评论或联系我。

参考文献

[1] Kaggle 求生:亚马逊热带雨林篇

https://zhuanlan.zhihu.com/p/28084438

[2] Resampling strategies for imbalanced datasets

https://www.kaggle.com/rafjaa/resampling-strategies-for-imbalanced-datasets

[3] pytorch sampler 对数据进行采样

https://blog.csdn.net/TH_NUM/article/details/80877772

推荐阅读

超100亿中文数据,要造出中国自己的BERT!首个专为中文NLP打造的语言理解基准CLUE升级

nodejs 图片处理模块 rotate_如何针对数据不平衡做处理?相关推荐

  1. nodejs 图片处理模块 rotate_学会Pillow再也不用PS啦——Python图像处理库Pillow入门!...

    你在用什么软件进行图像处理呢?厌倦了鼠标和手指的拖拖点点,想不想用程序和代码进行图像的高效处理,Python作为简单高效又很强大的一门编程语言,对于图像的处理自然也是轻松拿下,听起来是不是很酷很极客, ...

  2. lightgbm 数据不平衡_不平衡数据下的机器学习(下)

    本文从不平衡学习的基础概念和问题定义出发,介绍了几类常见的不平衡学习算法和部分研究成果.总体来说,不平衡学习是一个很广阔的研究领域,但受笔者能力和篇幅的限制,本文仅对其中部分内容做了简单概述,有兴趣深 ...

  3. 三招提升数据不平衡模型的性能(附python代码)

    摘要: 本文的主要目标是处理数据不平衡问题.文中描述了用来克服数据不平衡问题的三种技术,分别是集成交叉验证.类别权重以及过大预测 . 对于深度学习而言,数据集非常重要,但在实际项目中,或多或少会碰见数 ...

  4. nodejs操作sqlserver数据_实例分析nodejs基于mssql模块连接sqlserver数据库的简单封装操作...

    本文主要介绍了nodejs基于mssql模块连接sqlserver数据库的简单封装操作,结合实例形式分析了nodejs中mssql模块的安装与操作sqlserver数据库相关使用技巧,需要的朋友可以参 ...

  5. nodejs图片总结

    nodejs图片总结 今天终于把朴灵老师写的<深入浅出Node.js>给学习完了, 这本书不是一本简单的Node入门书籍,它没有停留在Node介绍或者框架.库的使用层面上,而是从不同的视角 ...

  6. nodejs没有net模块_Node.js实战16:用http模块创建web服务器

    Nodejs的http模块,是基于net.server,经过c++二次封装,也是nodejs的核心模块. 功能比net.server更强,可解析和操作更多细节内容,如值.content-length. ...

  7. Nodejs的http模块

    一.http服务器 我们知道传统的HTTP服务器是由Aphche.Nginx.IIS之类的软件来搭建的,但是Nodejs并不需要,Nodejs提供了http模块,自身就可以用来构建服务器.例如,下面的 ...

  8. c++ 获取64位进程模块地址_针对银行木马BokBot核心模块的深入分析

    一.概述 BokBot恶意软件由LUNAR SPIDER恶意组织开发和运营,在2017年首次出现,CrowdStrike的Falcon Overwatch和Falcon Intelligenc团队对被 ...

  9. nodejs 向mongodB获取指定数目的数据

    nodejs 向mongodB获取指定数目的数据 原理:通过向nodejs服务器端发送请求,nodejs 收到请求向mongodB读取五条数据,在控制台中打印出来: nodejs段代码(新建一个文件s ...

最新文章

  1. GPT-4参数将达10兆!此表格预测全新语言模型参数将是GPT-3的57倍
  2. 让AI有道德!用AI的方式去发展AI
  3. mysql generator 命令_MyBatis Generator速查手册
  4. python json.loads()中文问题-python处理json数据中的中文
  5. 【剑指Offer】16重建二叉树
  6. java集合框架栈_自己实现集合框架(九):栈接口
  7. 在ABAP里模拟实现Java Spring的依赖注入
  8. aliyun折腾记录
  9. 攻略:需求评审怎样才能高效易懂?
  10. TFS多地办公时的处理
  11. AI大时代下,零基础进入人工智能领域该如何学习?
  12. idea修改回默认字体,设置 IntelliJ Idea 的中英文字体
  13. 根据列值删除Pandas中的DataFrame行
  14. OSError: Initializing from file failed
  15. dubbo内核简介(附部分源码解读)
  16. matlab自适应高斯滤波,[matlab] 自适应高斯滤波器在二维图像上的应用
  17. 2015年度APP分类
  18. 高盛发布区块链报告:从理论到实践(中文版)二
  19. 春秋航空航班查询API
  20. C语言中getchar()函数的用法

热门文章

  1. 清华大学2016年软件学院攻读工程硕士专业学位研究生培养方案
  2. 编写函数实现随机产生指定范围的整数的功能
  3. unity3d用鼠标拖动物体的一段代码
  4. java类和对象的基础(笔记)
  5. 开启Windows文件共享必须开启的两个服务
  6. GDAL源码剖析(五)之Python命令行程序
  7. linux安装了vnc服务器,Linux安装VNC服务及配置
  8. java matching_LeetCode第[44]题(Java):Wildcard Matching
  9. 配置跳转指定_http自动跳转https的配置方法
  10. Javascript - 面向对象