目录

一 主要特点:

1 类别变量编码-Order Target Statistics方法

2 文本型变量编码处理

3 类别型特征交叉-FM

4 无偏提升-Ordered Boosting

5 使用对称树作为基模型,加速运行

6 不一样的缺失值处理

二 原理详解

1 类别变量处理-Order Target Statistics方法

2 文本型变量处理

3 特征交叉-FM

4 无偏提升-Ordered Boosting

5 对称树

6 缺失值处理

三 官网示例

1 CatBoostRegressor:

2 CatBoostClassifier 含类别型变量

3 使用Pools数据类型

四 catboost参数详解及实战


对于类别型变量而言,xgb需要先自行编码、才能输入模型;lgb极大地简化了一步,只需要将相应的变量列转化为category、或指定类别型变量名即可输入模型;catboost进一步处理,不仅嵌入了对类别型变量的处理,并附带类别型特征交叉功能、还加入了部分文本数据的处理。本文深入浅出地详解catboost,全篇通俗易懂帮助大家掌握原理。

官网文档:https://catboost.ai/en/docs/

一 主要特点:

1 类别变量编码-Order Target Statistics方法

2 文本型变量编码处理

3 类别型特征交叉-FM

4 无偏提升-Ordered Boosting

5 使用对称树作为基模型,加速运行

6 不一样的缺失值处理

二 原理详解

1 类别变量处理-Order Target Statistics方法

Target Statistics 是很常用的一种单调性编码方法,以音乐流派特征(包含rock、indie、pop等值)为例,rock的编码结果即为rock样本对应的平均bad_rate(y标签均值;y=1样本比例等),这样对每个特征值bad_rate越高、编码结果越大。

Target Statistics编码有一个小问题,类别型特征的部分特征值的样本量很少时,极端情况下假设pop的样本量只有1个、标签为1,则pop特征值的编码结果为1,但是当我们划分训练测试集时、训练集pop特征值的编码结果为1,而测试集中存在多个pop样本、且对应的y标签bad_rate远小于1时,这样就存在了以点概面的问题;针对这个问题可以在上式中添加整体数据集的bad_rate让编码结果扁平化、更加稳定

Order Target Statistics方法,以例子介绍更加清晰直观

(1)fn特征为音乐流派,包含rock、indie、pop等值,把该列类别型变量转化为数值型变量

(2)对数据集随机排列,生成多个随机序列的数据集(默认为4个)

(3)假设此时所有样本中rock、indie、pop对应的标签均值(y=1的比例)均为0.05,类别型变量按照以下公式编码为数值型变量

countInClass:当前序列的数据集中,x样本之前所有样本中,同类别、且同y标签的样本频数;以上图为例,第4条样本(Object=4)fn=rock,之前三条样本中rock有2条、同为rock且y标签一致的有1条,所以此处countInClass=1

prior:先验值,二分类下等同于当前特征值对应标签均值(风控中为bad_rate),此处即为0.05

totalCount:当前序列的数据集中,x样本之前所有样本中,同类别样本频数;以上图为例,第4条样本(Object=4)fn=rock,之前三条样本中rock有2条,则此处totalCount=2

综上,以上图第4条样本(Object=4)为例,avg_target=(1+0.05)/(2+1)=0.35

在上面例子中可以看到,同一特征值、同样的y标签下,前后的编码结果差异可能会很大,如第1、7条样本,此时排序越靠前的样本、对应的编码越不合理,因为该条样本之前没有多少样本作为编码的基础。

针对这个问题,catboost对数据随机排序多次(在第2步中提到过)、生成多个随机序列的数据集,在每个子模型中,会随机选择一个序列的数据集进行编码,这样可以缓解排序在前面样本的类别变量编码不合理的问题

2 文本型变量处理

(1)分词-将原始文本作为字符串,按照文本空格分词(以英文空格分词为例)

(2)生成词典:

token type包含两种字母级Letter、词级Word

对于文本“abra cadabra”,Letter分词结果{a, b, c, d, r};Word分词结果{'abra', 'cadabra'}

对于整列的文本数据,分词之后所有唯一词组成词典、并用索引对词典的词进行编码

(3)字符串转数字编码

对于文本数据列:

根据空格进行分词

汇总生成词典

数值化编码

(4)数值特征

数值特征依赖于分词结果,支持以下几种形式

a、词袋Bag of words:反映是否包含某个词(词编码)的布尔特征,生成的特征数量等同于词典大小

b、top_tokens_count:指定词频top的n个词,生成相应的布尔特征,生成的特征数量等同于n

c、朴素贝叶斯:多项式朴素贝叶斯模型,生成的特征数量等同于y标签的类别数;为了避免y标签信息泄露(穿越),这个模型会在多个数据集上在线计算生成(类似于CTR估计)

d、BM25:搜索引擎用于排名目的的一种函数,用于估计文档的相关性。为了避免目标泄漏,该模型在多个数据集排列上在线计算(类似于CTR估计)

3 特征交叉-FM

顾名思义,Catboost的主要贡献点就在类别型变量上,不仅嵌入了支持类别型变量的编码转换,还支持类别型变量的特征交叉、生成高阶交叉特征(支持任意阶数交叉,可通过参数限制最高阶数)

catboost子树分裂时,根结点的分裂特征会从原始的数值型变量和类别型编码过后的变量中选择,之后的分裂过程中,分裂特征会从包含交叉特征的全部变量中选择

例子:musical style和musical genre的二阶交叉特征

4 无偏提升-Ordered Boosting

引子,blending模型融合的思想大致是这样的,第一层先使用train1数据集构建n个模型,然后使用这些模型对另一批样本train2预测打分、得到n组模型分特征;第二层再使用train2数据集(特征集中包含n组模型分特征)训练模型。其中在第二层的模型训练中更换数据集是为了避免两层模型对同一个训练集数据过度学习、而造成过拟合。

在串行训练下,xgb、lgb在训练后面的树时,需要计算模型在所有样本上的一阶梯度和二阶梯度、用于训练后面的树;其中问题也就是子树对于同一批样本训练的结果、再次用于下一步的残差训练,存在对同一个训练集数据过度学习、而造成过拟合的风险。

catboost训练的每一颗子树,都使用的是其中一个随机排序过后的数据集,对于单个样本、只使用序号在它前面的样本训练子树,然后用模型来计算样本上的一阶梯度和二阶梯度、构建后面的树。在这个思路下,可以减少梯度的估计误差。对应参数:boosting_type,取值Ordered(排序梯度提升)、Plain(经典梯度提升)

5 对称树

catboost使用对称决策树作为子树,以y轴对称,分裂生成的每一层左右两侧节点(分裂特征、分裂节点)相同,一方面可以在一定程度上避免过拟合,另一方面可以加速预测

6 缺失值处理

xgb训练时会把缺失值分别放在左侧子节点和右侧子节点结算信息增益、保留增益较大的方向(详情可参考历史文章);与xgb不同,catboost对于缺失值的处理三个模式:

(1)Forbidden:禁用缺失值,当包含缺失值的数据集用于训练catboost模型时会报错

(2)Min:缺失值被处理为该列的最小值,这样可以让树模型将缺失值和其他值分列开

(3)Max:缺失值被处理为该列的最大值,这样可以让树模型将缺失值和其他值分列开

三 官网示例

1 CatBoostRegressor:

from catboost import CatBoostRegressor
# Initialize datatrain_data = [[1, 4, 5, 6],[4, 5, 6, 7],[30, 40, 50, 60]]eval_data = [[2, 4, 6, 8],[1, 4, 50, 60]]train_labels = [10, 20, 30]
# Initialize CatBoostRegressor
model = CatBoostRegressor(iterations=2,learning_rate=1,depth=2)
# Fit model
model.fit(train_data, train_labels)
# Get predictions
preds = model.predict(eval_data)

2 CatBoostClassifier 含类别型变量

from catboost import CatBoostClassifier
# Initialize data
cat_features = [0, 1]
train_data = [["a", "b", 1, 4, 5, 6],["a", "b", 4, 5, 6, 7],["c", "d", 30, 40, 50, 60]]
train_labels = [1, 1, -1]
eval_data = [["a", "b", 2, 4, 6, 8],["a", "d", 1, 4, 50, 60]]# Initialize CatBoostClassifier
model = CatBoostClassifier(iterations=2,learning_rate=1,depth=2)
# Fit model
model.fit(train_data, train_labels, cat_features)
# Get predicted classes
preds_class = model.predict(eval_data)
# Get predicted probabilities for each class
preds_proba = model.predict_proba(eval_data)
# Get predicted RawFormulaVal
preds_raw = model.predict(eval_data, prediction_type='RawFormulaVal')

3 使用Pools数据类型

from catboost import CatBoostClassifier, Pooltrain_data = Pool([[[0.1, 0.12, 0.33], [1.0, 0.7], 2, "male"],[[0.0, 0.8, 0.2], [1.1, 0.2], 1, "female"],[[0.2, 0.31, 0.1], [0.3, 0.11], 2, "female"],[[0.01, 0.2, 0.9], [0.62, 0.12], 1, "male"]],label = [1, 0, 0, 1],cat_features=[3],embedding_features=[0, 1]
)eval_data = Pool([[[0.2, 0.1, 0.3], [1.2, 0.3], 1, "female"],[[0.33, 0.22, 0.4], [0.98, 0.5], 2, "female"],[[0.78, 0.29, 0.67], [0.76, 0.34], 2, "male"],],label = [0, 1, 1],cat_features=[3],embedding_features=[0, 1]
)model = CatBoostClassifier(iterations=10)model.fit(train_data, eval_set=eval_data)
preds_class = model.predict(eval_data)

四 catboost参数详解及实战

catboost参数详解及实战(强推)

获取更多理论知识与代码分享,欢迎关注公众号:Python风控模型与数据分析

Catboost原理详解相关推荐

  1. CRF(条件随机场)与Viterbi(维特比)算法原理详解

    摘自:https://mp.weixin.qq.com/s/GXbFxlExDtjtQe-OPwfokA https://www.cnblogs.com/zhibei/p/9391014.html C ...

  2. LVS原理详解(3种工作方式8种调度算法)--老男孩

    一.LVS原理详解(4种工作方式8种调度算法) 集群简介 集群就是一组独立的计算机,协同工作,对外提供服务.对客户端来说像是一台服务器提供服务. LVS在企业架构中的位置: 以上的架构只是众多企业里面 ...

  3. jQuery中getJSON跨域原理详解

    详见:http://blog.yemou.net/article/query/info/tytfjhfascvhzxcytp28 jQuery中getJSON跨域原理详解 前几天我再开发一个叫 河蟹工 ...

  4. nginx配置文件及工作原理详解

    nginx配置文件及工作原理详解 1 nginx配置文件的结构 2 nginx工作原理 1 nginx配置文件的结构 1)以下是nginx配置文件默认的主要内容: #user nobody; #配置用 ...

  5. EMD算法之Hilbert-Huang Transform原理详解和案例分析

    目录 Hilbert-Huang Transform 希尔伯特-黄变换 Section I 人物简介 Section II Hilbert-Huang的应用领域 Section III Hilbert ...

  6. 图像质量损失函数SSIM Loss的原理详解和代码具体实现

    本文转自微信公众号SIGAI 文章PDF见: http://www.tensorinfinity.com/paper_164.html http://www.360doc.com/content/19 ...

  7. 深入剖析Redis系列(三) - Redis集群模式搭建与原理详解

    前言 在 Redis 3.0 之前,使用 哨兵(sentinel)机制来监控各个节点之间的状态.Redis Cluster 是 Redis 的 分布式解决方案,在 3.0 版本正式推出,有效地解决了 ...

  8. 【Android架构师java原理详解】二;反射原理及动态代理模式

    前言: 本篇为Android架构师java原理专题二:反射原理及动态代理模式 大公司面试都要求我们有扎实的Java语言基础.而很多Android开发朋友这一块并不是很熟练,甚至半路初级底子很薄,这给我 ...

  9. SVM分类器原理详解

    SVM分类器原理详解 标签: svm文本分类java 2015-08-21 11:51 2399人阅读 评论(0) 收藏 举报  分类: 数据挖掘 文本处理(16)  机器学习 分类算法(10)  目 ...

最新文章

  1. 推荐GitHub上几个比较热门的开源项目,记得收藏下!!!
  2. lunix 命令积累
  3. Java黑皮书课后题第4章:*4.15(电话键盘)电话上的国际标准字母/数字映射如下所示。编写程序,提示用户输入一个小写或大写字母,然后显示对应数字。对于非字母输入,提示非法输入
  4. 成员函数指针与高性能的C++委托(上篇)
  5. 拉斯维加斯算法结合八皇后问题
  6. git cherry pick
  7. java 二分搜索获得大于目标数的第一位_Java后端架构师技术图谱,你都了解多少?...
  8. 小记安装ElasticSearch遇到的小坑
  9. 打造自己的Android源码学习环境之三:在虚拟机中安装Ubuntu(下)
  10. 20170403_Windows网络编程视频学习1
  11. MacOS Ventura 13.0 Beta6 (22A5331f) 带 OC 0.8.4 三分区原版黑苹果镜像
  12. flask中jinjia2模板引擎的使用详解3
  13. Python的简单介绍
  14. 看了阿里找数据分析师的新规则,真让人头皮发麻!
  15. 利用Python打造短链接服务
  16. ubantu与CentOS虚拟机之间搭建GRE隧道
  17. python中文朗读_python语音朗读
  18. 【RC延迟电路 RC充电电路】 multisim 14.0仿真 参数计算
  19. ctfshow刷题日记sql注入篇
  20. MySQL数据库-设置数据完整性

热门文章

  1. mysql工具都有什么作用是什么_Navicat for MySQL是什么
  2. python指定圆心画圆
  3. HTB- Armageddon
  4. 国庆最佳旅游城市必属它了
  5. GaN图腾柱无桥 Boost PFC(单相)七-PFC占空比前馈
  6. centos6.5环境openldap实战之ldap配置详解及web管理工具lam(ldap-account-manager)使用详解...
  7. 对豆瓣进行爬虫来获取相关数据(分别保存到Excel表格和sqlite中)
  8. 解决微信昵称含有表情符号等插入到数据库报错问题
  9. HTML5 canvas 画布
  10. 【备忘】关于ssh为什么会失败的原因总结?下次记得来找。