python 多分类算法_深入理解GBDT多分类算法
我的个人微信公众号:Microstrong
微信公众号ID:MicrostrongAI
微信公众号介绍:Microstrong(小强)同学主要研究机器学习、深度学习、计算机视觉、智能对话系统相关内容,分享在学习过程中的读书笔记!期待您的关注,欢迎一起学习交流进步!目录:
1. GBDT多分类算法
1.1 Softmax回归的对数损失函数
1.2 GBDT多分类原理
2. GBDT多分类算法实例
3. 手撕GBDT多分类算法
3.1 用Python3实现GBDT多分类算法
3.2 用sklearn实现GBDT多分类算法
4. 总结
5. Reference
本文的主要内容概览:
1. GBDT多分类算法
1.1 Softmax回归的对数损失函数
当使用逻辑回归处理多标签的分类问题时,如果一个样本只对应于一个标签,我们可以假设每个样本属于不同标签的概率服从于几何分布,使用多项逻辑回归(Softmax Regression)来进行分类:
其中,
为模型的参数,而
可以看作是对概率的归一化。一般来说,多项逻辑回归具有参数冗余的特点,即将
同时加减一个向量后预测结果不变,因为
,所以
。
假设从参数向量
中减去向量
,这时每一个
都变成了
。此时假设函数变成了以下公式:
从上式可以看出,从
中减去
完全不影响假设函数的预测结果,这表明前面的Softmax回归模型中存在冗余的参数。特别地,当类别数为2时,
利用参数冗余的特点,我们将所有的参数减去
,上式变为:
其中
。而整理后的式子与逻辑回归一致。因此,多项逻辑回归实际上是二分类逻辑回归在多标签分类下的一种拓展。
当存在样本可能属于多个标签的情况时,我们可以训练
个二分类的逻辑回归分类器。第
个分类器用以区分每个样本是否可以归为第
类,训练该分类器时,需要把标签重新整理为“第
类标签”与“非第
类标签”两类。通过这样的办法,我们就解决了每个样本可能拥有多个标签的情况。
在二分类的逻辑回归中,对输入样本
分类结果为类别1和0的概率可以写成下列形式:
其中,
是模型预测的概率值,
是样本对应的类标签。
将问题泛化为更一般的多分类情况:
由于连乘可能导致最终结果接近0的问题,一般对似然函数取对数的负数,变成最小化对数似然函数。
补充:交叉熵
假设
和
是关于样本集的两个分布,其中
是样本集的真实分布,
是样本集的估计分布,那么按照真实分布
来衡量识别一个样本所需要编码长度的期望(即,平均编码长度):
如果用估计分布
来表示真实分布
的平均编码长度,应为:
这是因为用
来编码的样本来自于真实分布
,所以期望值
中的概率是
。而
就是交叉熵。
可以看出,在多分类问题中,通过最大似然估计得到的对数似然损失函数与通过交叉熵得到的交叉熵损失函数在形式上相同。
1.2 GBDT多分类原理
将GBDT应用于二分类问题需要考虑逻辑回归模型,同理,对于GBDT多分类问题则需要考虑以下Softmax模型:
其中
是
个不同的CART回归树集成。每一轮的训练实际上是训练了
棵树去拟合softmax的每一个分支模型的负梯度。softmax模型的单样本损失函数为:
这里的
是样本label在k个类别上作one-hot编码之后的取值,只有一维为1,其余都是0。由以上表达式不难推导:
可见,这
棵树同样是拟合了样本的真实标签与预测概率之差,与GBDT二分类的过程非常类似。下图是Friedman在论文中对GBDT多分类给出的伪代码:
根据上面的伪代码具体到多分类这个任务上面来,我们假设总体样本共有
类。来了一个样本
,我们需要使用GBDT来判断
属于样本的哪一类。
第一步我们在训练的时候,是针对样本
每个可能的类都训练一个分类回归树。举例说明,目前样本有三类,也就是
,样本
属于第二类。那么针对该样本的分类标签,其实可以用一个三维向量
来表示。
表示样本不属于该类,
表示样本属于该类。由于样本已经属于第二类了,所以第二类对应的向量维度为
,其它位置为
。
针对样本有三类的情况,我们实质上在每轮训练的时候是同时训练三颗树。第一颗树针对样本
的第一类,输入为
。第二颗树输入针对样本
的第二类,输入为
。第三颗树针对样本
的第三类,输入为
。这里每颗树的训练过程其实就CART树的生成过程。在此我们参照CART生成树的步骤即可解出三颗树,以及三颗树对
类别的预测值
, 那么在此类训练中,我们仿照多分类的逻辑回归 ,使用Softmax 来产生概率,则属于类别
的概率为:
并且我们可以针对类别
求出残差
;类别
求出残差
;类别
求出残差
。
然后开始第二轮训练,针对第一类输入为
, 针对第二类输入为
,针对第三类输入为
。继续训练出三颗树。一直迭代M轮。每轮构建3颗树。
当
时,我们其实应该有三个式子:
当训练完以后,新来一个样本
,我们要预测该样本类别的时候,便可以有这三个式子产生三个值
。样本属于某个类别的概率为:
2. GBDT多分类算法实例
(1)数据集
(2)模型训练阶段
首先,由于我们需要转化3个二分类的问题,所以需要先做一步one-hot:参数设置:
学习率:learning_rate = 1
树的深度:max_depth = 2
迭代次数:n_trees = 5
首先对所有的样本,进行初始化
,就是各类别在总样本集中的占比,结果如下表。
注意:在Friedman论文里全部初始化为0,但在sklearn里是初始化先验概率(就是各类别的占比),这里我们用sklearn中的方法进行初始化。
1)对第一个类别
拟合第一颗树
。
首先,利用公式
计算概率。
其次,计算负梯度值,以
为例
:
同样地,计算其它样本可以有下表:
接着,寻找回归树的最佳划分节点。在GBDT的建树中,可以采用如MSE、MAE等作为分裂准则来确定分裂点。本文采用的分裂准则是MSE,具体计算过程如下。遍历所有特征的取值,将每个特征值依次作为分裂点,然后计算左子结点与右子结点上的MSE,寻找两者加和最小的一个。
比如,选择
作为分裂点时
。
左子结点上的集合的MSE为:
右子节点上的集合的MSE为:
比如选择
作为分裂点时
。
对所有特征计算完后可以发现,当选择
做为分裂点时,可以得到最小的MSE,
。
下图展示以
为分裂点的
拟合一颗回归树的示意图:
然后,我们的树满足了设置,还需要做一件事情,给这棵树的每个叶子节点分别赋一个参数
(也就是我们文章提到的
),来拟合残差。
最后,更新
可得下表:
至此第一个类别(类别0)的第一颗树拟合完毕,下面开始拟合第二个类别(类别1)的第一颗树。
2)对第二个类别
拟合第一颗树
。
首先,利用
计算概率。
其次,计算负梯度值,以
为例
:
同样地,计算其它样本可以有下表:
然后,以
为分裂点的
拟合一颗回归树,可计算得到叶子节点:
,
最后,更新
可得下表:
至此第二个类别(类别1)的第一颗树拟合完毕。然后再拟合第三个类别(类别2)的第一颗树,过程也是重复上述步骤,所以这里就不再重复了。在拟合完所有类别的第一颗树后就开始拟合第二颗树。反复进行,直到训练了M轮。
3. 手撕GBDT多分类算法
3.1 用Python3实现GBDT多分类算法
需要的Python库:
pandas、PIL、pydotplus、matplotlib
其中pydotplus库会自动调用Graphviz,所以需要去Graphviz官网下载graphviz-2.38.msi安装,再将安装目录下的bin添加到系统环境变量,最后重启计算机。
3.2 用sklearn实现GBDT多分类算法
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
'''调参:loss:损失函数。有deviance和exponential两种。deviance是采用对数似然,exponential是指数损失,后者相当于AdaBoost。n_estimators:最大弱学习器个数,默认是100,调参时要注意过拟合或欠拟合,一般和learning_rate一起考虑。learning_rate:步长,即每个弱学习器的权重缩减系数,默认为0.1,取值范围0-1,当取值为1时,相当于权重不缩减。较小的learning_rate相当于更多的迭代次数。subsample:子采样,默认为1,取值范围(0,1],当取值为1时,相当于没有采样。小于1时,即进行采样,按比例采样得到的样本去构建弱学习器。这样做可以防止过拟合,但是值不能太低,会造成高方差。init:初始化弱学习器。不使用的话就是第一轮迭代构建的弱学习器.如果没有先验的话就可以不用管由于GBDT使用CART回归决策树。以下参数用于调优弱学习器,主要都是为了防止过拟合max_feature:树分裂时考虑的最大特征数,默认为None,也就是考虑所有特征。可以取值有:log2,auto,sqrtmax_depth:CART最大深度,默认为Nonemin_sample_split:划分节点时需要保留的样本数。当某节点的样本数小于某个值时,就当做叶子节点,不允许再分裂。默认是2min_sample_leaf:叶子节点最少样本数。如果某个叶子节点数量少于某个值,会同它的兄弟节点一起被剪枝。默认是1min_weight_fraction_leaf:叶子节点最小的样本权重和。如果小于某个值,会同它的兄弟节点一起被剪枝。一般用于权重变化的样本。默认是0min_leaf_nodes:最大叶子节点数'''
gbdt = GradientBoostingClassifier(loss='deviance', learning_rate=1, n_estimators=5, subsample=1
, min_samples_split=2, min_samples_leaf=1, max_depth=2
, init=None, random_state=None, max_features=None
, verbose=0, max_leaf_nodes=None, warm_start=False
)
train_feat = np.array([[6],
[12],
[14],
[18],
[20],
[65],
[31],
[40],
[1],
[2],
[100],
[101],
[65],
[54],
])
train_label = np.array([[0], [0], [0], [0], [0], [1], [1], [1], [1], [1], [2], [2], [2], [2]]).ravel()
test_feat = np.array([[25]])
test_label = np.array([[0]])
print(train_feat.shape, train_label.shape, test_feat.shape, test_label.shape)
gbdt.fit(train_feat, train_label)
pred = gbdt.predict(test_feat)
print(pred, test_label)
4. 总结
在本文中,我们首先从Softmax回归引出GBDT的多分类算法原理;其次用实例来讲解GBDT的多分类算法;然后不仅用Python3实现GBDT多分类算法,还用sklearn实现GBDT多分类算法;最后简单的对本文做了一个总结。至此,GBDT用于解决回归任务、二分类任务和多分类任务就完整的深入理解了一遍。
5. Reference
【1】Friedman J H. Greedy function approximation: a gradient boosting machine[J]. Annals of statistics, 2001: 1189-1232.
【2】《推荐系统算法实践》,黄美灵著。
【3】《百面机器学习》,诸葛越主编、葫芦娃著。
【9】GBDT算法用于分类问题 - hunter7z的文章 - 知乎,地址:https://zhuanlan.zhihu.com/p/46445201
python 多分类算法_深入理解GBDT多分类算法相关推荐
- python 二分类的实例_深入理解GBDT二分类算法
我的个人微信公众号:Microstrong 微信公众号ID:MicrostrongAI 微信公众号介绍:Microstrong(小强)同学主要研究机器学习.深度学习.计算机视觉.智能对话系统相关内容, ...
- 深入理解GBDT二分类算法
我的个人微信公众号: Microstrong 微信公众号ID: MicrostrongAI 微信公众号介绍: Microstrong(小强)同学主要研究机器学习.深度学习.计算机视觉.智能对话系统相关 ...
- 深入理解GBDT多分类算法
目录: GBDT多分类算法 1.1 Softmax回归的对数损失函数 1.2 GBDT多分类原理 GBDT多分类算法实例 手撕GBDT多分类算法 3.1 用Python3实现GBDT多分类算法 3.2 ...
- Python_机器学习_算法_第1章_K-近邻算法
Python_机器学习_算法_第1章_K-近邻算法 文章目录 Python_机器学习_算法_第1章_K-近邻算法 K-近邻算法 学习目标 1.1 K-近邻算法简介 学习目标 1 什么是K-近邻算法 1 ...
- java jvm垃圾回收算法_深入理解JVM虚拟机2:JVM垃圾回收基本原理和算法
本文转自互联网,侵删 本系列文章将整理到我在GitHub上的<Java面试指南>仓库,更多精彩内容请到我的仓库里查看 喜欢的话麻烦点下Star哈 文章将同步到我的个人博客: www.how ...
- weka java 分类算法_使用Weka快速实践机器学习算法
[译者注]在当下人工智能火爆发展的局面,每时每刻都有新的技术在诞生,但如果你是一个新手,Weka或许能帮助你直观.快速的感受机器学习带来的解决问题的新思路. Weka使机器学习的应用变得简单.高效并且 ...
- xgboost算法_工业大数据:分析算法
一. 应用背景 大数据分析模型的研究可以分为3个层次,即描述分析(探索历史数据并描述发生了什么).预测分析(未来的概率和趋势)和规范分析(对未来的决策给出建议).工业大数据分析的理论和技术研究仍处于起 ...
- 错误录入 算法_如何使用验证错误率确定算法输出之间的关系
错误录入 算法 Monument (www.monument.ai) enables you to quickly apply algorithms to data in a no-code inte ...
- smoteenn算法_基于EasyEnsemble算法和SMOTE算法的不均衡数据分类方法与流程
本发明涉及不均衡数据二分类技术领域,尤其涉及一种基于EasyEnsemble算法和SMOTE算法的不均衡数据二分类方法. 背景技术: 数据不均衡指的是在一个样本数据集中,某一类的样本数远少于其他类的样 ...
最新文章
- 98页PPT,看懂阿里、小米、京东、美团的组织架构和战略变迁!
- 大年初一,今年的春晚你看了吗?
- linux xampp常见问题
- C# 判断两张图片是否一致,极快速
- SDUT_2119 数据结构实验之链表四:有序链表的归并
- rgb fusion检测不到显卡_买不到RX 6800XT就装不了机解不了馋?我看未必
- 腾讯极客挑战赛邀你“码上种树”
- JDK1.8源码下载及获取、导入IDEA阅读、配置JDK源码
- mysql 备份如何使用_如何使用命令来备份和还原MySQL数据库
- 销售的基本功(倾听、提问、聊天)
- 第二十九篇、UICollectionView瀑布流
- mysql之jdbc连接数据库和sql注入的问题
- KMP(看毛片)NEXT数组模板
- 机械硬盘和固态硬盘之间的区别
- Request method XXX not supported
- Kruise Rollout: 让所有应用负载都能使用渐进式交付
- 在Android手机上将Minecraft国际版地图存档导入中国版(亲测有效)
- html5如何获取音频资源6,【已解决】如何从喜马拉雅的页面中获取到mp3音频文件...
- 基于JAVA Frame的太阳系行星运转系统
- 【转】中国与华尔街不同的投行人生
热门文章
- 将xls表格文件转为xml文件以及json格式的文件,支持2010xls解析
- Excel——快速批量统一边框和字体格式
- 魅族Flyme 8终于来了,1000+优化升级体验,适配27款机型
- web网页设计实例作业HTML+CSS+JavaScript蔬菜水果商城购物设计毕业论文源码
- python将txt转为字符串_Python 玩转生僻字
- QC协议+华为FCP+三星AFC快充取电5V9V芯片FS2601应用
- python进阶——自动驾驶寻找车道
- 用js 中数学对象(Math对象)随机点击明星名字
- 新手入门应该懂的Linux 细节知识
- 360计算机报名支付不了,Win7电脑装不上360安全卫士怎么办?完美解决方法看这里...