文章目录

  • 1. 提升方法AdaBoost算法
  • 2. AdaBoost算法训练误差分析
  • 3. AdaBoost算法的解释
  • 4. 提升树
  • 5. sklearn 实例

提升(boosting)方法是一种常用的统计学习方法,应用广泛且有效。

在分类问题中,它通过改变训练样本的权重,学习多个分类器,并将这些分类器进行线性组合,提高分类的性能

1. 提升方法AdaBoost算法

  • 思路:多个算法的判断结果综合
  • 弱学习方法容易获得,通过组合一系列弱学习方法,提升出来强学习方法
  • 大多数提升方法:改变训练数据的概率分布(权值分布)
  • 如何改变权值或概率分布:AdaBoost 的做法是,提高被前一轮弱分类器错误分类样本的权值,没有得到正确分类的数据,由于其权值的加大而受到后一轮的弱分类器的更大关注
  • 如何将弱分类器组合:AdaBoost 采取加权多数表决的方法。
    加大分类误差率小的弱分类器的权值,使其在表决中起较大的作用;减小分类误差率大的弱分类器的权值,使其在表决中起较小的作用。

AdaBoost 模型是弱分类器的线性组合:

f ( x ) = ∑ m = 1 M α m G m ( x ) f(x)=\sum_{m=1}^{M} \alpha_{m} G_{m}(x) f(x)=m=1∑M​αm​Gm​(x)

  • AdaBoost 算法每次迭代中,提高前一轮分类器错误分类数据的权值,降低被正确分类的数据的权值。
  • AdaBoost 将基本分类器的线性组合作为强分类器,给分类误差率小的基本分类器大的权值,给分类误差率大的基本分类器小的权值。

算法步骤

1)给每个训练样本( x 1 , x 2 , … . , x N x_{1},x_{2},….,x_{N} x1​,x2​,….,xN​)分配权重,初始权重 w 1 i w_{1i} w1i​ 均为 1 N \frac{1}{N} N1​。

2)针对带有权值的样本进行训练,得到基本分类器 G m G_m Gm​(初始模型为 G 1 G1 G1)。

3)计算模型 G m G_m Gm​ 的误分率 e m = ∑ i = 1 N w m i I ( y i ≠ G m ( x i ) ) e_m=\sum\limits_{i=1}^N w_{mi} I(y_i\not= G_m(x_i)) em​=i=1∑N​wmi​I(yi​​=Gm​(xi​))

4)计算模型 G m G_m Gm​ 的系数 α m = 1 2 ln ⁡ 1 − e m e m \alpha_m=\frac{1}{2} \ln \frac{1-e_m}{e_m} αm​=21​lnem​1−em​​

5)根据误分率 e e e 和当前权重向量 w m w_m wm​ 更新权重向量 w m + 1 w_{m+1} wm+1​
w m + 1 , i = w m , i Z m exp ⁡ ( − α m y i G m ( x i ) ) w_{m+1,i} = \frac{w_{m,i}}{Z_m} \exp (-\alpha_my_iG_m(x_i)) wm+1,i​=Zm​wm,i​​exp(−αm​yi​Gm​(xi​))
Z m Z_m Zm​ 是规范化因子, Z m = ∑ i = 1 M w m , i exp ⁡ ( − α m y i G m ( x i ) ) Z_m = \sum\limits_{i=1}^M w_{m,i} \exp (-\alpha_my_iG_m(x_i)) Zm​=i=1∑M​wm,i​exp(−αm​yi​Gm​(xi​))

6)计算组合模型 f ( x ) = ∑ m = 1 M α m G m ( x i ) f(x)=\sum\limits_{m=1}^M \alpha_mG_m(x_i) f(x)=m=1∑M​αm​Gm​(xi​) 的误分率。

7)当组合模型的误分率或迭代次数低于一定阈值,停止迭代;否则,回到步骤 2)

2. AdaBoost算法训练误差分析

书上有定理证明,AdaBoost 算法能在学习的过程中,不断减少训练误差。

AdaBoost 具有适应性,即它能适应弱分类器各自的训练误差率。这也是它的名称(适应的提升)的由来,Ada是Adaptive的简写。

3. AdaBoost算法的解释

AdaBoost 算法的一个解释是该算法实际是前向分步算法的一个实现。

在这个方法里,模型是加法模型,损失函数是指数损失,算法是前向分步算法。

每一步中极小化损失函数

( β m , γ m ) = arg ⁡ min ⁡ β , γ ∑ i = 1 N L ( y i , f m − 1 ( x i ) + β b ( x i ; γ ) ) \left(\beta_{m}, \gamma_{m}\right)=\arg \min _{\beta, \gamma} \sum_{i=1}^{N} L\left(y_{i}, f_{m-1}\left(x_{i}\right)+\beta b\left(x_{i} ; \gamma\right)\right) (βm​,γm​)=argβ,γmin​i=1∑N​L(yi​,fm−1​(xi​)+βb(xi​;γ))

得到参数 β m , γ m \beta_{m}, \gamma_{m} βm​,γm​。

4. 提升树

提升树是以分类树或回归树为基本分类器的提升方法。提升树被认为是统计学习中最有效的方法之一。

提升方法实际采用加法模型(即基函数的线性组合)与前向分步算法。

决策树为基函数的提升方法称为提升树(boosting tree)。

5. sklearn 实例

sklearn.ensemble.AdaBoostClassifier

class sklearn.ensemble.AdaBoostClassifier(base_estimator=None,
n_estimators=50, learning_rate=1.0, algorithm='SAMME.R',random_state=None)
  • algorithm:这个参数只有AdaBoostClassifier有。有两种Adaboost分类算法,SAMME和SAMME.R。
    主要区别是弱学习器权重的度量,SAMME使用分类效果作为弱学习器权重,而SAMME.R使用预测概率大小来作为弱学习器权重。
    由于SAMME.R使用了概率度量的连续值,迭代一般比SAMME快,因此AdaBoostClassifier的默认算法algorithm的值也是SAMME.R。
    注意使用了SAMME.R, 则弱分类学习器参数base_estimator必须限制使用支持概率预测的分类器。SAMME算法则没有这个限制。

  • n_estimators: AdaBoostClassifier和AdaBoostRegressor都有,是弱学习器的最大迭代次数,或者说最大的弱学习器的个数。一般来说n_estimators太小,容易欠拟合,n_estimators太大,又容易过拟合,一般选择一个适中的数值。默认是50。

  • learning_rate: AdaBoostClassifier和AdaBoostRegressor都有,即每个弱学习器的权重缩减系数ν

  • base_estimator:AdaBoostClassifier和AdaBoostRegressor都有,即弱分类学习器或者弱回归学习器。常用的是CART决策树或者神经网络MLP。

参考:https://github.com/fengdu78/lihang-code

# -*- coding:utf-8 -*-
# @Python Version: 3.7
# @Time: 2020/3/24 23:13
# @Author: Michael Ming
# @Website: https://michael.blog.csdn.net/
# @File: AdaBoost.py
# @Reference: https://github.com/fengdu78/lihang-code
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import AdaBoostClassifier
from sklearn.datasets import load_irisdata = [[0, 1, 3, -1],[0, 3, 1, -1],[1, 2, 2, -1],[1, 1, 3, -1],[1, 2, 3, -1],[0, 1, 2, -1],[1, 1, 2, 1],[1, 1, 1, 1],[1, 3, 1, -1],[0, 2, 1, -1]]
data = pd.DataFrame(np.array(data))
X, y = data.iloc[:, 0:-1], data.iloc[:, -1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
clf = AdaBoostClassifier(n_estimators=50, learning_rate=0.5)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))

提升方法(Boosting)相关推荐

  1. 提升方法(boosting)详解

    注:本文非笔者原创,原文转载自:http://www.sigvc.org/bbs/thread-727-1-1.html 提升方法是基于这样一种思想:对于一个复杂任务来说,将多个专家的判断进行适当的综 ...

  2. 提升方法boosting

    本文是<统计学习方法>李航著学习笔记.现在的数据科学比赛中用到的算法大杀器GBDT(gradient boosting decision tree)终于要出场了! 提升方法的基本思想是将学 ...

  3. 机器学习算法总结--提升方法

    参考自: <统计学习方法> 浅谈机器学习基础(上) Ensemble learning:Bagging,Random Forest,Boosting 简介 提升方法(boosting)是一 ...

  4. Boosting(提升方法)之GBDT

    一.GBDT的通俗理解 提升方法采用的是加法模型和前向分步算法来解决分类和回归问题,而以决策树作为基函数的提升方法称为提升树(boosting tree).GBDT(Gradient Boosting ...

  5. python 梯度提升树_梯度提升方法(Gradient Boosting)算法案例

    GradientBoost算法 python实现,该系列文章主要是对<统计学习方法>的实现. 完整的笔记和代码以上传到Github,地址为(觉得有用的话,欢迎Fork,请给作者个Star) ...

  6. 提升方法之AdaBoost算法

    提升方法之AdaBoost算法 作为非数学专业出身看到密密麻麻的数学公式刚开始真的是非常头疼.算法的物理逻辑的时候尚能理解,但是涉及到具体的数学公式实现就开始懵逼了:为什么要用这个公式,这个公式是怎么 ...

  7. 08_提升方法Boosting2_统计学习方法

    文章目录 三.GBDT算法 1.提升树算法 (1)算法三要素 (2)GBDT与AdaBoost区别 2.平方损失的提升树 3.梯度提升树(Gradient Boosting Decison Tree, ...

  8. 08_提升方法Boosting1_统计学习方法

    文章目录 一.Boosting和集成学习介绍 二.AdaBoost 1.AdaBoost算法 (1)AdaBoost算法的三要素 (2)AdaBoost模型定义 (3)AdaBoost损失函数定义 ( ...

  9. 机器学习之提升方法Adaboost算法

    文章目录 1.背景 2.基本原理 3.Adaboost算法 4.周志华老师Boosting25周年 5.Adaboost算法优缺点 6.Q&A 1.背景 集成学习(ensemble learn ...

最新文章

  1. 【iOS UIKit】UITableView属性及方法大全
  2. ceph 查看是否在恢复_Ceph的最佳实践
  3. [CF1082G]Petya and Graph
  4. java动态代理_Java 动态代理和依赖注入
  5. linux msleep 头文件,Linux延迟函数
  6. hdu5468 Puzzled Elena
  7. xsslabs靶机解题_web 攻击靶机解题过程
  8. 数据分析:复杂业务场景下,量化评估流程
  9. DBVisualizer 添加数据库JDBC驱动
  10. 调用 Windows 7 中英文混合朗读
  11. EditText属性设置
  12. python tkinter 实现图片格式批量转换小工具
  13. 无人机学习笔记 8 雷达工作波段划分
  14. 从键盘输入十个整数,统计非负数的个数,计算非负数的和
  15. AWS免费账号取消步骤
  16. FAT12模拟-C语言读取
  17. 浙大计算机城市学院联合培养,浙大城市学院 今日视点 城市学院2007届联合培养硕士研究生顺利毕业...
  18. 开到“十字路口”的共享汽车未来怎么走?
  19. WebIM 即时通信
  20. java一元二次方程程序设计实验报告_Java 组件及事件处理实训 实训2:编写一个窗体程序,用于计算一元二次方程...

热门文章

  1. javaSE<String和StringBuffer和StringBuider>day11
  2. threejs使用精灵图添加图片贴图
  3. 数据库—应用系统开发方法
  4. 荣联科技转型的一二三四五
  5. 什么是SSH 以及常见的ssh 功能
  6. Html Table 样式
  7. 【awk】awk 常用命令
  8. 安全技术与相关安全工具
  9. 智能财务报表OCR识别系统
  10. Python 中 concurrent.futures 模块使用说明