英文题目:Intelligible Models for Classification and Regression
中文题目:可理解的分类和回归模型
论文地址:https://www.doc88.com/p-41099846725043.html
领域:模型可解释性,广义加性模型,机器学习
发表时间:2012
作者:Yin Lou,Rich Caruana(模型可解释性大佬),康耐尔大学,微软
出处:KDD
被引量:256
代码和数据:https://github.com/interpretml/interpret
阅读时间:220819

读后感

加性模型的准确性优于线性模型,差于梯度决策树和深度学习模型.它在模型精度和可解释性间取平衡.其核心原理是针对单个特征建立模型(可以是非线性模型),然后把这些模型加在一起形成最终模型.本文描述了具体实现方法.

介绍

复杂模型虽然预测精度高,但可解释性较差,因为很难判断单个特征在复杂模型中的贡献度.

本文目标是建立尽量准确且可解释的模型,让用户可以理解每个特征的贡献度.使用广义加性模型(GAMs)方法,其基本算法如下:

将 g 称为 link 函数,f 称为 shape 函数,g 和 f 可以是任何函数,比如非线性函数,对于单个特征建模 f,可以有很高的复杂度,但特很之间组合比较简单,只是简单的叠加关系.

比如下式就是一个加性模型的示例:
y = x 1 + x 2 2 + x 3 + l o g ( x 4 ) + e x p ( x 5 ) + 2 s i n ( x 6 ) + ϵ y=x_1+x^2_2+\sqrt{x^3}+log(x_4)+exp(x_5)+2sin(x_6)+\epsilon y=x1​+x22​+x3 ​+log(x4​)+exp(x5​)+2sin(x6​)+ϵ
对应的每个特征影响的图-1所示,可以分别看到每个特征对y的影响.

每个shape函数都梲以是非线性的,这也是加性模型效果优于线性模型的原因.表-1展示了各种模型的基本公式:

方法

设数据集中有N个实例,每个实例有n个特征{xi1…xin},标签为yi.目标是构建函数F(x),最小化损失函数L(y,F(x)).

具体实现方法涉及两个维度,对于单特征训练的shape模型,一般使用样条函数或者树模型(图-4中的纵向);对于shape模型的组合训练方法(图-4中的横向),即如何训练整体模型,则可选用最小二乘法,boosting和backfitting方法.

shape函数

文中提到的shape函数有样条函数和集成树函数,所有shape函数只涉及单个特征作为输入.

样条函数

样条是一种特殊的函数,由多项式分段定义.比如三次样条中的每一段都由三次多项式表示,且整体是一条光滑的曲线,三次多项式形如:
y = a i + b i x + c i x 2 + d i x 3 y=a_i+b_ix+c_ix^2+d_ix^3 y=ai​+bi​x+ci​x2+di​x3
文中的样条插值是设置维度为d的回归样条函数:

树和集成树模型

使用二叉树和集成二叉树方法,用叶节点数描述树的复杂度.树模型的每个分叉是对同一特征的不同值范围进行切分.支持的树包括:Single Tree,Bagged Trees,Bossted Trees,Bossted Baaged Trees.后面的实验中将首字体作为其方法的缩写.

训练整体模型

以用以下方法训练整体模型,用最小二乘训练样条函数,用梯度提升和回修训练树模型.

最小二乘法

最小二乘法可以很好的训练线性模型,这里将bk(x)看成特征,训练拟合参数Bk.另外,还加入了平滑系数 λ.实验中将该方法称为惩罚最小二乘,记作P-LS.对于逻辑回归问题,使用样条被简化为用不同的基拟合逻辑回归,方法称为惩罚迭代重加权最小二乘,记作 P-IRLS。

Boosting梯度提升法

在每一个迭代中,循环地依次训练所有特征,具体方法如下:

  • line 1: 将三个shape函数初值设为 0
  • line 2: 一共10次迭代:M=10
  • line 3: 遍历所有特征:假设一共三个特征n=3
  • line 4: 这里构造了一个数据集合R,对于所有实例 i=1…N,它的自变量是实例中是第j个特征xij,因变量是将每个实例i代入当前所有 f 后计算预测值,然后计算预测与真值y的残差.
  • line 5:学习Shape函数S,它利用第j个特征x,训练S(x)用于拟合R,之所以是Boosting,是因为它拟合的不是y本身,而是拟合残差
  • line 6:利用拟合的残差函数S调整更新第 j 个特征拟合函数 fj

Backfitting回修法

回修法是之前拟合加性模型的主流方法,它与梯度提升方法非常类似,差别在于第4行和第6行,在第4行,回修法的fk不包含它自己本身对应的第fj;而第6行,它直接用S替换fj.对比可以看出梯度提升拟合的是残差,而回修法拟合的是fj本身,因此,可能随着数据不同,回修的波动相对比较大,最终可能验证以收敛.

实验

图-3对比了梯度提升和回修方法对回归(a,b,c)和分类(d,e,f)的建模效果,可以看到,叶节点过多,在训练集中效果好,但在测试集上效果差,回修法效果相对不稳定.

图-5对比了使用不同Shape函数的效果,样条函数由于追求拟合曲线的平滑,在数据较少的位置拟合效果较差,这可能是由于样条过于平滑,学不出细节.相对来说树模型效果更好.

表-5展示了主实验结果,这里使用了6个回归的数据集,从实验结果的均值可以看到,复杂模型效果最好,加性模型中,BST-bagTRX,即梯度提升的Bagging树,X表示随机设置叶节点数.

比较有意思的是,在BST-bagTR类中,叶节点2-4,效果最好,这可能是由于叶节点太多可能造成过拟合.

图-6展示了回归中各个模型的偏差和方差,偏差描述了模型预测结果和实际y之间的差异,方差用于评价子学习器学出结果的一致性,以评价稳定性(常用交叉验证的方法测量方差).可以看到对于所有数据集,位置于中间偏左的梯度提升+树模型效果都最好.

扩展阅读

GA2M

Accurate Intelligible Models with Pairwise Interactions
同一作者写的另一篇基于GAM的优化,将基于单个特征的加性模型扩展为基于特征组合的加性模型.核心公式如下:

其核心在于如何选择和优化特征组合,实验证明在有些情况下比lightgbm更好.

论文阅读_广义加性模型_GAMs相关推荐

  1. 广义典型相关分析_数学建模/机器学习:广义加性模型(GAM)及其Python实现

    笔者做过国赛也做过美赛,其中一类典型问题就是分析相关性,从而进行预测或者其他操作.这类问题通常情况下属于比较常规的问题,一般通过matlab或SPSS分析相关性,得到一个较好的数值即可. 然而有的时候 ...

  2. python 广义线性模型_数学建模/机器学习:广义加性模型(GAM)及其Python实现

    笔者做过国赛也做过美赛,其中一类典型问题就是分析相关性,从而进行预测或者其他操作.这类问题通常情况下属于比较常规的问题,一般通过matlab或SPSS分析相关性,得到一个较好的数值即可. 然而有的时候 ...

  3. 广义加性模型和树模型

    广义加性模型 传统线性模型所面临的问题: 在现实生活中,变量的作用通常不是线性的. 广义加性模型是一种自由灵活的统计模型,它可以用来探测到非线性回归的影响.模型如下: E(Y|X1,...,Xp)=α ...

  4. R语言广义加性模型(GAMs:Generalized Additive Model)建模:数据加载、划分数据、并分别构建线性回归模型和广义线性加性模型GAMs、并比较线性模型和GAMs模型的性能

    R语言广义加性模型(GAMs:Generalized Additive Model)建模:数据加载.划分数据.并分别构建线性回归模型和广义线性加性模型GAMs.并比较线性模型和GAMs模型的性能 目录

  5. R语言广义加性模型GAMs:可视化每个变量的样条函数、样条函数与变量与目标变量之间的平滑曲线比较、并进行多变量的归一化比较、测试广义线性加性模型GAMs在测试集上的表现(防止过拟合)

    R语言广义加性模型GAMs:可视化每个变量的样条函数.样条函数与变量与目标变量之间的平滑曲线比较.并进行多变量的归一化比较.测试广义线性加性模型GAMs在测试集上的表现(防止过拟合) 目录

  6. R语言广义加性模型(generalized additive models,GAMs):使用广义线性加性模型GAMs构建logistic回归

    R语言广义加性模型(generalized additive models,GAMs):使用广义线性加性模型GAMs构建logistic回归 目录

  7. R语言广义加性模型GAMs分析温度、臭氧环境数据绘制偏回归图与偏残差图

    最近我们被客户要求撰写关于广义加性模型的研究报告,包括一些图形和统计输出. 视频:R语言广义相加模型(GAM)在电力负荷预测中的应用 拓端tecdat:R语言广义相加模型(GAM)在电力负荷预测中的应 ...

  8. R语言惩罚逻辑回归、线性判别分析LDA、广义加性模型GAM、多元自适应回归样条MARS、KNN、二次判别分析QDA、决策树、随机森林、支持向量机SVM分类优质劣质葡萄酒十折交叉验证和ROC可视化

    最近我们被客户要求撰写关于葡萄酒的研究报告,包括一些图形和统计输出. 介绍 数据包含有关葡萄牙"Vinho Verde"葡萄酒的信息.该数据集有1599个观测值和12个变量,分别是 ...

  9. R语言使用mgcv包的gam函数拟合广义加性模型回归模型:使用predict函数和训练好的模型进行预测推理、使用ggplot2可视化预测值和实际值的曲线进行对比分析

    R语言使用mgcv包的gam函数拟合广义加性模型回归模型:使用predict函数和训练好的模型进行预测推理.使用ggplot2可视化预测值和实际值的曲线进行对比分析 目录

最新文章

  1. 利用python安装opencv_Linux下安装OpenCV+Python支持
  2. MySQL 慢查询日志分析及可视化结果
  3. 1024-程序员节快乐!给大家发福利啦!以及向大家讲述节日由来
  4. 在Ubuntu为Android硬件抽象层(HAL)模块编写JNI方法提供Java访问硬件服务接口 6...
  5. mysql多数据源事务_多数据源一致性事务解决方案
  6. 酷狗笔试题:补齐左括号(栈)
  7. [转]Windows 性能监视器工具-perfmon
  8. VMware 下安装centos7,无法进入图形化界面
  9. Java的API帮助文档
  10. 沃特玛采集均衡模块_采集均衡模块以及电池管理系统_2016212573884_说明书_专利查询_专利网_钻瓜专利网...
  11. 北京业内网友见面会,及其他
  12. 数字格式化 * 有一个小数,123.45678 要求保留两位
  13. 1314520用计算机怎么算,表白公式数学公式抖音 抖音1314520怎么计算,快用计算器表白?...
  14. CentOS下删除和安装JDK
  15. 无源物联网的定义、特点和优势
  16. 低度酒的诸神之战,能分出胜负吗?
  17. FS2222可调过压过流芯片IC,40V耐压过压保护可调OVP可调OCP
  18. AI写作机器人-ai文章生成器在线
  19. 计算机组成原理源码,计算机组成原理源码两位乘课程设计报告.docx
  20. VUE指令大全(详解)

热门文章

  1. JVisualVM、Visual GC
  2. Eclipse使用指南
  3. 计算机复制教程,教你如何使用电脑复制粘贴快捷键
  4. 2017年终总结,毕业和工作
  5. net程序员面试题,基本上是基础概念题
  6. 思岚科技定位导航技术凸显 成为服务机器人企业首选品牌
  7. Intouch学习笔记—新建工程
  8. 什么是嵌入式?嵌入式开发怎么学
  9. Wayland协议解析 一 什么是Wayland
  10. 前端基础知识点-每天一个基本知识点(100+个前端小知识,你是否都知道?)