有几种方法可以将 Focal Loss 合并到多类分类器中。这是其中之一。

动机

许多现实世界的分类问题都有不平衡的类分布。当数据严重不平衡时,分类算法将开始做出有利于多数类的预测。有几种方法可以解决类别不平衡问题。

一种方法是分配与类频率成反比的样本权重,以增加较少频率类在损失函数中的贡献。

另一种方法是使用过采样/欠采样技术。为少数类生成人工样本的流行技术是合成少数类过采样技术 (SMOTE) 和自适应合成 (ADASYN),两者都包含在 imblearn Python 库中。

建议Ë ntly,使用焦距损失的目标函数的提出。该技术被Tsung-Yi Lin 等人用于二元分类问题。[1]。

在这篇文章中,我将演示如何将 Focal Loss 合并到 LightGBM 分类器中以进行多类分类。代码可在GitHub上找到。

二元分类

对于二元分类问题(标签 0/1),Focal Loss 函数定义如下:

Eq.1 焦点损失函数

其中_pₜ_是真实标签的函数。对于二元分类,该函数定义为:

Eq.2 类别概率

其中 pₜ 是通过将 sigmoid 函数应用于原始边距_z 获得的:_

Eq.3 用于将原始边距 z 转换为类概率 p 的 Sigmoid 函数

Focal Loss 可以解释为一个二元交叉熵函数乘以一个调制因子 (1- pₜ )^ γ,这减少了易于分类的样本的贡献。加权因子_aₜ_平衡了调制因子。引用作者的话:“当 γ = 2 时,与 CE 相比,分类为 pt = 0.9 的示例的损失将降低 100 倍,而当 pt ≈ 0.968 时,其损失将降低 1000 倍”。减少易于分类的示例的丢失,可以让训练更多地关注难以分类的示例”。

在 Max Halford 的博客 [ 2 ] 中可以找到一篇关于将 Focal Loss 纳入二元 LigthGBM 分类器的优秀文章。

多类分类

有几种方法可以将 Focal Loss 合并到多类分类器中。形式上,调制和加权因子应该应用于分类交叉熵。这种方法需要提供关于原始边距_z_的多类损失的一阶和二阶导数。

另一种方法是使用 One-vs-the-rest (OvR),其中为每个类_C_训练一个二元分类器。来自_C_类的数据被视为正数,所有其他数据都被视为负数。在这篇文章中,我使用了 OvR 方法,重用了哈尔福德开发的二元分类器,没有做任何修改。

下面显示的类 OneVsRestLightGBMWithCustomizedLoss 封装了该方法:

import numpy as np
from joblib import Parallel, delayed
from sklearn.multiclass import _ConstantPredictor
from sklearn.preprocessing import LabelBinarizer
from scipy import special
import lightgbm as lgbclass OneVsRestLightGBMWithCustomizedLoss:def __init__(self, loss, n_jobs=3):self.loss = lossself.n_jobs = n_jobsdef fit(self, X, y, **fit_params):self.label_binarizer_ = LabelBinarizer(sparse_output=True)Y = self.label_binarizer_.fit_transform(y)Y = Y.tocsc()self.classes_ = self.label_binarizer_.classes_columns = (col.toarray().ravel() for col in Y.T)if 'eval_set' in fit_params:# use eval_set for early stoppingX_val, y_val = fit_params['eval_set'][0]Y_val = self.label_binarizer_.transform(y_val)Y_val = Y_val.tocsc()columns_val = (col.toarray().ravel() for col in Y_val.T)self.results_ = Parallel(n_jobs=self.n_jobs)(delayed(self._fit_binary)(X, column, X_val, column_val, **fit_params) fori, (column, column_val) inenumerate(zip(columns, columns_val)))else:# eval set not availableself.results_ = Parallel(n_jobs=self.n_jobs)(delayed(self._fit_binary)(X, column, None, None, **fit_params) for i, columnin enumerate(columns))return selfdef _fit_binary(self, X, y, X_val, y_val, **fit_params):unique_y = np.unique(y)init_score_value = self.loss.init_score(y)if len(unique_y) == 1:estimator = _ConstantPredictor().fit(X, unique_y)else:fit = lgb.Dataset(X, y, init_score=np.full_like(y, init_score_value, dtype=float))if 'eval_set' in fit_params:val = lgb.Dataset(X_val, y_val, init_score=np.full_like(y_val, init_score_value, dtype=float),reference=fit)estimator = lgb.train(params=fit_params,train_set=fit,valid_sets=(fit, val),valid_names=('fit', 'val'),early_stopping_rounds=10,fobj=self.loss.lgb_obj,feval=self.loss.lgb_eval,verbose_eval=10)else:estimator = lgb.train(params=fit_params,train_set=fit,fobj=self.loss.lgb_obj,feval=self.loss.lgb_eval,verbose_eval=10)return estimator, init_score_valuedef predict(self, X):n_samples = X.shape[0]maxima = np.empty(n_samples, dtype=float)maxima.fill(-np.inf)argmaxima = np.zeros(n_samples, dtype=int)for i, (e, init_score) in enumerate(self.results_):margins = e.predict(X, raw_score=True)prob = special.expit(margins + init_score)np.maximum(maxima, prob, out=maxima)argmaxima[maxima == prob] = ireturn argmaximadef predict_proba(self, X):y = np.zeros((X.shape[0], len(self.results_)))for i, (e, init_score) in enumerate(self.results_):margins = e.predict(X, raw_score=True)y[:, i] = special.expit(margins + init_score)y /= np.sum(y, axis=1)[:, np.newaxis]return y

该类重新实现了 sklearn.multiclass 命名空间的 OneVsRestClassifier 类。重新实现原始 OneVsRestClassifier 类的主要原因是能够将附加参数转发给 fit 方法。当用户定义的评估函数没有改进时,这可用于传递评估集 (eval_set) 以停止训练,从而减少计算时间并避免过度拟合。

此外,该类使用通用 LightGBM 训练 API,这是在处理原始边距_z_和自定义损失函数时获得有意义的结果所必需的(有关更多详细信息,请参阅 [ 2 ])。如果没有这些限制,就可以更通用地实现该类,不仅可以接受任何损失函数,还可以接受任何实现 Scikit Learn 模型接口的模型。

该类的其他方法是 Scikit Learn 模型接口的一部分:fit、predict 和 predict_proba。在 predict 和 predict_proba 方法中,基本估计器返回原始边距_z_。请注意,当使用自定义损失函数时,LightGBM 返回原始边距_z_。类概率是使用 sigmoid 函数从边缘计算的,如等式所示。3.

一个例子

让我们首先创建一个包含 3 个类的人工不平衡数据集,其中 1% 的样本属于第一类,1% 属于第二类,98% 属于第三类。像往常一样,数据集被分为训练集和测试集。

X, y = make_classification(n_classes=3,n_samples=2000, n_features=2,n_informative=2,n_redundant =0,n_clusters_per_class=1,weights=[.01, .01, .98], flip_y=.01, random_state=42)le = preprocessing.LabelEncoder()
y_label = le.fit_transform(y)
X_train, X_test, y_train, y_test = train_test_split(X, y_label, test_size=0.30, random_state=42)classes =[]
labeles=np.unique(y_label)
for v in labeles:classes.append('Class '+ str(v))
print(classes)

然后将模型拟合到列车数据上。为了保持实验简单,没有使用提前停止。得到的混淆矩阵如下所示:

图 1 使用 LGBMClassifier 在测试集上的混淆矩阵

对于第一个实验,在测试集上获得了 0.990 的准确度和 0.676 的召回值。使用 OneVsRestLightGBMWithCustomizedLoss 分类器和 Focal Loss 重复相同的实验。

from OneVsRestLightGBMWithCustomizedLoss import *
from FocalLoss import FocalLoss #get the FocalLoss implementation from Halford's blog# Instantiate Focal loss
loss = FocalLoss(alpha=0.75, gamma=2.0)# Not using early stopping
clf = OneVsRestLightGBMWithCustomizedLoss(loss=loss)
clf.fit(X_train, y_train)# Using early stopping
#fit_params = {'eval_set': [(X_test, y_test)]}
#clf.fit(X_train, y_train, **fit_params)y_test_pred = clf.predict(X_test)
pred_accuracy_score = accuracy_score(y_test, y_test_pred)
pred_recall_score = recall_score(y_test, y_test_pred, average='macro')
print('prediction accuracy', pred_accuracy_score,' recall ', pred_recall_score)cnf_matrix = confusion_matrix(y_test, y_test_pred, labels=labeles)
plot_confusion_matrix(cnf_matrix, classes=classes,normalize=True,  title='Confusion matrix')

从上面的代码可以看出,损失函数完全可以在分类器之外配置,可以注入到类构造函数中。可以通过向 fit 方法提供包含 eval_set 的字典来打开提前停止,如上面的注释行所示。对于第二个实验,产生的混淆矩阵如下所示:

图 2 使用 OneVsRestLightGBMWithCustomizedLoss 分类器和 Focal Loss 在测试集上的混淆矩阵

在这种情况下,获得了 0.995 的准确度和 0.838 的召回值,比使用默认对数损失的第一个实验有所改进。这个结果从混淆矩阵中也很明显,其中 0 类的假阳性和 1 类的假阴性显着减少。

结论

在这篇文章中,我展示了一种通过使用 One-vs-the-rest (OvR) 方法将 Focal Loss 纳入多类分类器的方法。

通过使用 Focal Loss,不需要样本权重平衡或人工添加新样本来减少不平衡。在人工生成的多类不平衡数据集上,使用 Focal loss 增加了召回值并消除了少数类中的一些误报和漏报。

该方法的有效性必须通过探索现实世界的数据集来确认,其中噪声和非信息特征预计会影响分类结果。

QQ学习群:1026993837 领资料
Focal Loss 和 LightGBM 多分类就为大家介绍到这里,欢迎学习《Python数据分析与机器学习项目实战》bye!

Focal Loss 和 LightGBM 多分类应用-python实现相关推荐

  1. 前景背景样本不均衡解决方案:Focal Loss,GHM与PISA(附python实现代码)

    参考文献:Imbalance Problems in Object Detection: A Review 1 定义 在前景-背景类别不平衡中,背景占有很大比例,而前景的比例过小,这类问题是不可避免的 ...

  2. 无痛涨点!大白话讲解 Generalized Focal Loss

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨李翔 来源丨https://zhuanlan.zhihu.c ...

  3. Tensorlfow2.0 二分类和多分类focal loss实现和在文本分类任务效果评估

    Tensorlfow2.0 二分类和多分类focal loss实现和在文本分类任务效果评估 前言 二分类 focal loss 多分类 focal loss 测试结果 二分类focal_loss结果 ...

  4. 目标检测 | RetinaNet:Focal Loss for Dense Object Detection

    论文分析了 one-stage 网络训练存在的类别不平衡问题,提出能根据 loss 大小自动调节权重的 focal loss,使得模型的训练更专注于困难样本.同时,基于 FPN 设计了 RetinaN ...

  5. 目标检测分类损失函数——Cross entropy、Focal loss

    一.Cross Entropy Loss 交叉熵是用来判定实际的输出与期望的输出的接近程度,刻画的是实际输出与期望输出的距离,也就是交叉熵的值越小,两个概率分布就越接近. 1. CE Loss的定义 ...

  6. 焦点损失函数 Focal Loss 与 GHM

    文章来自公众号[机器学习炼丹术] 1 focal loss的概述 焦点损失函数 Focal Loss(2017年何凯明大佬的论文)被提出用于密集物体检测任务. 当然,在目标检测中,可能待检测物体有10 ...

  7. 论文翻译-SAFL A Self-Attention Scene Text Recognizer with Focal Loss

    论文翻译-SAFL A Self-Attention Scene Text Recognizer with Focal Loss 原文地址:https://ieeexplore.ieee.org/do ...

  8. 一、Focal Loss理论及代码实现

    文章目录 前言 一.基本理论 二.实现 1.公式 2.代码实现 1.基于二分类交叉熵实现. 2.知乎大佬的实现 前言 本文参考:几时见得清梦博主文章 参考原文:https://www.jianshu. ...

  9. 剖析Focal Loss损失函数: 消除类别不平衡+挖掘难分样本 | CSDN博文精选

    作者 | 图像所浩南哥 来源 | CSDN博客 论文名称:< Focal Loss for Dense Object Detection > 论文下载:https://arxiv.org/ ...

  10. NeurIPS 2020 | Focal Loss改进版来了!GFocal Loss:良心技术,无Cost涨点!

    本文作者:李翔 https://zhuanlan.zhihu.com/p/147691786 本文仅供学习参考,如有侵权,请联系删除! 论文地址:https://arxiv.org/abs/2006. ...

最新文章

  1. jfinal 动态切换orm 映射
  2. jquery easyUI分页dataGrid-Json
  3. Horizon View 6-客户端连接虚拟桌面⑹
  4. java16下载_java lombok下载
  5. 2022年Python数据分析的宝藏地带
  6. 10个高效Linux技巧及Vim命令对比
  7. 64位系统下,一个32位的程序究竟可以申请到多少内存,4GB还是更多
  8. rem是如何实现自适应布局的?
  9. 女子15000元网购兰基博尼跑车,上路就被查了,这操作真没见过!
  10. 练习题︱ python 协同过滤ALS模型实现:商品推荐 + 用户人群放大
  11. SAP MM模块之批次管理
  12. nextjs的发布,pm2发布nextjs项目
  13. SPFA算法模板(刘汝佳版)--Wormholes POJ - 3259
  14. jQuery的五种初始化加载写法
  15. 高德地图纠偏 php,驾车轨迹纠偏-轨迹纠偏-示例中心-JS API 示例 | 高德地图API
  16. 金山词霸2009牛津版下载地址
  17. DC学院数据分析师(入门)学习笔记----高级爬虫技巧
  18. 全球与中国pH控制剂市场现状及未来发展趋势(2022)
  19. WPS word表格中的神秘的底色
  20. 5G基带芯片之战现状:一二三分别对应联发科华为高通

热门文章

  1. javaScript原型链继承
  2. java 7 学习笔记_Java学习笔记7
  3. Nodejs ---- 升级到指定版本
  4. python两个函数中传递变量_如何在Python中向函数传递大量变量?
  5. 微信小程序 - 贝塞尔曲线(购物车效果)
  6. 总结命令----tar
  7. 如何在论坛里提高自己的从业水平
  8. apache的虚拟目录配置
  9. layui和js实现二级联动
  10. 局域网邮件服务器搭建地址薄更新,搭建局域网邮件服务器