编辑:陈萍

损失函数是机器学习里最基础也是最为关键的一个要素,其用来评价模型的预测值和真实值不一样的程度。最为常见的损失函数包括平方损失、指数损失、log 对数损失等损失函数。这里回顾了一种新的损失函数,通过引入鲁棒性作为连续参数,该损失函数可以使围绕最小化损失的算法得以推广,其中损失的鲁棒性在训练过程中自动自我适应,从而提高了基于学习任务的性能。

这篇文章对 CVPR 2019 的一篇论文《A General and Adaptive Robust Loss Function》进行了回顾性综述,主要讲述了为机器学习问题开发鲁棒以及自适应的损失函数。论文作者为谷歌研究院的研究科学家 Jon Barron。

论文地址:https://arxiv.org/pdf/1701.03077.pdf

异常值(Outlier)与鲁棒损失

考虑到机器学习问题中最常用的误差之一——均方误差(Mean Squared Error, MSE),其形式为:(y-x)²。该损失函数的主要特征之一是:与小误差相比,对大误差的敏感性较高。并且,使用 MSE 训练出的模型将偏向于减少最大误差。例如,3 个单位的单一误差与 1 个单位的 9 个误差同等重要。

下图为使用 Scikit-Learn 创建的示例,演示了在有 / 无异常值影响的情况下,拟合是如何在一个简单数据集中变化的。

MSE 以及异常值的影响。

如上图所示,包含异常值的拟合线(fit line)受到异常值的较大影响,但是优化问题应要求模型受内点(inlier)的影响更大。在这一点上,你可能认为平均绝对误差(Mean Absolute Error, MAE)会优于 MSE,因为 MAE 对大误差的敏感性较低。也不尽然。目前有各种类型的鲁棒损失(如 MAE),对于特定问题,可能需要测试各种损失。

所以,这篇论文引入一个泛化的损失函数,其鲁棒性可以改变,并且可以在训练网络的同时训练这个超参数,以提升网络性能。与网格搜索(grid-search)交叉验证寻找最优损失函数相比,这种损失函数花费的时间更少。让我们从下面的几个定义开始讲解:

鲁棒性与自适应损失函数的一般形式:

公式 1:鲁棒性损失,其中α为超参数,用来控制鲁棒性。

α控制损失函数的鲁棒性。c 可以看作是一个尺度参数,在 x=0 邻域控制弯曲的尺度。由于α作为超参数,我们可以看到,对于不同的α值,损失函数有着相似的形式。

公式 2:不同α值对应不同的自适应性损失。

在α=0 和α=2 时,损失函数是未定义的,但利用极限可以实现近似。从α=2 到α=1,损失函数平稳地从 L2 损失过渡到 L1 损失。对于不同的α值,我们可以绘制不同的损失函数,如下图 2 所示。

导数对于优化损失函数非常重要。下面研究一下这个损失函数的一阶导数,我们知道,梯度优化涉及到导数。对于不同的α值,x 的导数如下所示。下图 2 还绘制了不同α的导数和损失函数。

公式 3:鲁棒损失(表达式 1)对于不同的α的值相对于 x 的导数

自适应损失及其导数

下图对于理解此损失函数及其导数非常重要。在下图 2 中,尺度参数 c 固定为 1.1。当 x = 6.6 时,可以将其视为 x = 6×c。可以得出以下有关损失及其导数的推论:

1. 当 x、α和 c>0 时,损失函数是光滑的,因此适合于基于梯度的优化;

2. 损失函数总是在原点为零,并且在 | x |>0 时单调增加。损失的单调性也可以与损失的对数进行比较;

3. 损失也随着α的增加而单调增加。此属性对于损失函数的鲁棒性很重要,因为可以从较高的α值开始,然后在优化过程中逐渐减小(平滑)以实现鲁棒的估计,从而避免局部最小值;

4. 当 | x |

5. 对于α= 2,导数始终与残差的大小成正比。通常,这是 MSE(L2)损失的特性;

6. 对于α=1(L1 损失),我们看到导数的幅度在 | x |>c 之外饱和至一个常数值(正好是 1/c)。这意味着残差的影响永远不会超过一个固定的量;

7. 对于αc 而减小。这意味着当残差增加时,它对梯度的影响较小,因此异常值在梯度下降过程中的影响较小。

图 2:损失函数及其导数与α的关系。

图 3:自适应损失函数(左)及其导数(右)的曲面图。

鲁棒损失的实现:Pytorch 和 Google Colab

关于鲁棒损失的理论掌握了,怎么实现呢?使用的代码在 Jon Barron 的 GitHub 项目「robust_loss_pytorch」中稍加修改。此外还创建了一个动画来描述随着迭代次数的增加,自适应损失如何找到最佳拟合线。

GitHub 地址:https://github.com/jonbarron/arom_loss_pytorch

不需要克隆存储库,我们可以使用 Colab 中的 pip 在本地安装它。

此外还创建了一个简单的线性数据集,包括正态分布的噪声和异常值。

首先,由于使用了 Pythorch 库,利用 torch 将 x, y 的 numpy 数组转换为张量。

其次,使用 pytorch 模块定义线性回归类,如下所示:

接下来,用线性回归模型拟合自创建的线性数据集,首先使用损失函数的一般形式。这里使用一个固定值α(α=2.0),它在整个优化过程中保持不变。正如在α=2.0 时看到的,损失函数等效 L2 损失,这对于包括异常值在内的问题不是最优的。对于优化,使用学习率为 0.01 的 Adam 优化器。

利用鲁棒损失函数的一般形式和固定α值,可以得到拟合线。原始数据、真直线(生成数据点时使用的具有相同斜率和偏差的线,排除异常值)和拟合线如下图 4 所示:

图 4:一般损失函数

损失函数的一般形式不允许α发生变化,因此必须手动微调α参数或执行网格搜索进行微调。

此外,正如上图所示,由于使用了 L2 损失,拟合受到异常值的影响。这是一般的情况,但如果使用损失函数的自适应版本,会发生什么呢?调用自适应损失模块,并初始化α,让α在每个迭代步骤中自适应。

此外,还有一些额外的代码使用 celluloid 模块,见下图 5。在这里,可以清楚地看到,随着迭代次数的增加,自适应损失如何找到最佳拟合线。这个结果接近真实的线,对于异常值的影响可以忽略不计。

图 5:自适应损失函数如何达到最佳拟合的动画。

未定义与 struct 类型的输入参数相对应的函数 fetch_引入鲁棒性作为连续参数,这种新损失函数实现了自适应、随时变换...相关推荐

  1. matlab中未定义与 ‘cell‘ 类型的输入参数相对应的运算符 ‘+‘ 的解决方案

    在函数文件中写入以下内容 function re=fun(a,b,varargin) if nargin == 2re=a+b; elseif nargin==3c=varargin(1);re = ...

  2. matlab 手工实现normalize函数 未定义与 ‘double‘ 类型的输入参数相对应的函数 ‘normalize‘

    matlab自带的normalize函数有时候总抽风不好使: 未定义与 'double' 类型的输入参数相对应的函数 'normalize' 不过考虑到这个东西本身也不难,无非就是这么个公式:Xi−μ ...

  3. matlab没有int函数,matlab 未定义与 'char' 类型的输入参数相对应的函数 'int'。

    最后你那zd句plot(int(k),double(s));里的int(k);不对.你如果想以k为横坐标,直接把int去掉,如回果非想要变成整数可以用floor(k);注:fix:向零取整 floor ...

  4. matlab参数数目不足lorenz,求指导,MATLAB程序,老是提示这个“未定义与 'double' 类型的输入参数相对应的函数 'genfisl'。”错误,...

    答:不要帖图,否则看不出是咋回事,要程序可调试.利用, >> help genfisl genfisl not found. Use the Help browser search fie ...

  5. 未定义与 double 类型的输入参数相对应的函数 eval_点评一下鸿蒙os的时钟计算函数...

    鸿蒙os liteos-m版,是面向嵌入式的分支,看代码 arch 目录下,有 cortex m4 架构的支持的代码. cortex m4相对于其他mcu芯片的优势,支持浮点.dsp等运算,适合某些需 ...

  6. 用MATLAB仿真SCARA机器人,报错:未定义与 ‘char‘ 类型的输入参数相对应的函数 ‘Link‘

    在使用MATLAB仿真SCARA机器人的时候报错如下 是因为我们没有安装Robotics Tools工具箱,解决办法: https://blog.csdn.net/u011831805/article ...

  7. matlab中的uint8函数,未定义与 'uint8' 类型的输入参数相对应的函数 'fitnessty'

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 程序如下: clear all clc tic popsize=15; lanti=10; maxgen=50; cross_rate=0.4; muta ...

  8. matlab int8 函数,未定义与 'uint8' 类型的输入参数相对应的函数 'fitnessty'

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 程序如下: clear all clc tic popsize=15; lanti=10; maxgen=50; cross_rate=0.4; muta ...

  9. matlab boundaries和fchcode函数无法执行的解决办法 未定义与 'double' 类型的输入参数相对应的函数 'boundaries'

    在测试代码时发现,自己的matlab无法执行Freeman链码函数: boundaries和fchcode函数都无法正常运行: 需要在自己的工作目录中添加如下函数: boundaries   fchc ...

最新文章

  1. PHP executable not found. Install PHP 7 and add it to your PATH or set the php.executablePath settin
  2. 如何将txt文档插入sql2000数据库
  3. 数组、链表、Hash(转)
  4. mybatis的快速入门
  5. Circular Dance
  6. Android笔记 apk的反编译 | 更新于2017/7/25
  7. linux怎么升级python2.7,linux升级python到2.7
  8. centos6.5 nginx1.8 php mysql,CentOS6.5 源码安装Nginx1.8 + PHP7.0.6 + MySQL5.7.12
  9. SEO内容为王之如何创造伪原创
  10. Redis+Nginx+设计模式+Spring全家桶+Dubbo+阿里P8技术精选文档
  11. 广播的接收与U盘广播
  12. 高级商务办公软件应用【10】
  13. 数据库中的日期相减_sql中两个日期相减
  14. mybatis小结(1)
  15. 探究腾讯云TCA和阿里acp的区别
  16. Mybatis-Plus用纯注解完成一对多多对多查询
  17. free_rtos系统基本配置
  18. c35是什么意思_混凝土c25、c30、c35分别代表什么意思?
  19. 【路径规划】基于matlab蚁群算法栅格地图路径规划及避障【含Matlab源码 2088期】
  20. 以太坊:事件日志 (Event Logs)

热门文章

  1. Selenium脚本编写技巧和窍门
  2. 服务器日志记录_5种改善服务器日志记录的技术
  3. cassandra使用心得_使用Cassandra和Nutch爬网
  4. 邪恶的Java技巧使JVM忘记检查异常
  5. hibernate批量查询_使用Hibernate批量获取
  6. Java命令行界面(第25部分):JCommando
  7. 如何使用Spring设置安全的REST API
  8. mapreduce文本排序_MapReduce:通过数据密集型文本处理
  9. ejb jsf jpa_完整的WebApplication JSF EJB JPA JAAS –第2部分
  10. gkz cloud sql_使用Cloud SQL的Google App Engine全文搜索