本文代码已上传CSDN,点我下载

文章目录

  • 备注
  • 推荐阅读
  • 简介
  • 安装
  • 初试
  • 可视化
  • 决策树
  • 特征重要性
  • 最优模型
  • 调用GPU
  • 参考文献

备注

该库貌似仍不稳定,我在继续训练的时候找到一个BUGCan Not Training continuation(2020.2.27 版本0.21),现在已修复了

推荐阅读

MNIST & CatBoost保存模型并预测
快速掌握CatBoost基本用法

简介

CatBoost是一款高性能机器学习开源库,基于GBDT,由俄罗斯搜索巨头Yandex在2017年开源。

那么CatBoost与其他Boosting算法如LightGBM和XGBoost相比如何呢?

在质量上,无论是fine-tuned后还是默认情况下,CatBoost的loss优于其他三个框架。

在速度上,CatBoost在Epsilon和Higgs数据集上与对手进行了比较,在GPU训练下完胜对手,在CPU训练下与LightGBM平分秋色。

Epsilon数据集(二分类2001个特征)

Higgs数据集(二分类29个特征)

CatBoost特点有:

  1. 免调参高质量
  2. 支持类别特征
  3. 快速和可用GPU
  4. 提高准确性
  5. 快速预测

更多对比参见Battle of the Boosting Algos: LGB, XGB, Catboost,建议自己运行一遍,本人运行与原文有出入——XGBoost、LightGBM、Catboost对比

安装

GPU开箱即用,不用额外安装其他

pip install catboost

Jupyter可视化配置

pip install ipywidgets
jupyter nbextension enable --py widgetsnbextension

初试

CatBoost内置数据集Titanic,该数据集为二分类任务。


导入必要的包

from catboost.datasets import titanic
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split

读取数据集

# 数据集
titanic_train, titanic_test = titanic()
titanic_train.head(10)


有数据为空NaN,例如乘客编号为6的年龄。
有数据是离散值,例如姓名和船票编号。
认为对模型训练作用性不大,去掉。

remove = ['PassengerId', 'Name', 'Ticket', 'Cabin']
X = titanic_train.drop(remove, axis=1)  # 去掉无关信息
X = X.dropna(how='any', axis='rows')  # 去掉空值
y = X.pop('Survived')  # 标签
X.head()


结果如上,其中船舱等级、性别和登船码头(下标为0,1,6)显然为类别特征,而恰好CatBoost支持类别特征训练。

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

创建Pool对象,这是CatBoost自带的类,便于CatBoost库进行处理。
当然,CatBoost实现了sklearn的接口,直接使用pd.DataFrame类型的X_train, X_test, y_train, y_test训练也行。

# 定义池(CatBoost最快的处理方式)
cat_features = [0, 1, 6]  # 分类特征
train_pool = Pool(X_train, y_train, cat_features=cat_features)
test_pool = Pool(X_test, y_test, cat_features=cat_features)

定义CatBoost分类模型

# 定义模型
model = CatBoostClassifier()

训练,参数含义分别是:train_pool训练数据,eval_set验证集,plot可视化,silent不输出训练过程,use_best_model使用最优模型

# 训练
model.fit(train_pool, eval_set=test_pool, plot=True, silent=True, use_best_model=True)  #可视化,不输出过程,最优模型


查看最优结果和准确率

model.get_best_score()  # 最优loss
{'learn': {'Logloss': 0.14129628504561498},'validation': {'Logloss': 0.471373085990394}}
model.score(test_pool) #准确率
0.8111888111888111

最后保存模型

model.save_model('titanic.model') # 保存模型

加载模型

del model
model = CatBoostClassifier()
model.load_model('titanic.model')

查看测试集数据

print(X_test[:10])
print(y_test[:10])
     Pclass     Sex   Age  SibSp  Parch      Fare Embarked
641       1  female  24.0      0      0   69.3000        C
496       1  female  54.0      1      0   78.2667        C
262       1    male  52.0      1      1   79.6500        S
311       1  female  18.0      2      2  262.3750        C
551       2    male  27.0      0      0   26.0000        S
550       1    male  17.0      0      2  110.8833        C
279       3  female  35.0      1      1   20.2500        S
268       1  female  58.0      0      1  153.4625        S
110       1    male  47.0      0      0   52.0000        S
554       3  female  22.0      0      0    7.7750        S
641    1
496    1
262    0
311    1
551    0
550    1
279    1
268    1
110    0
554    1
Name: Survived, dtype: int64

使用模型进行预测

model.predict(X_test[:10])  #预测

可以看到前5个都对了,后5个错得有点多

array([1, 1, 0, 1, 0, 0, 0, 1, 0, 0], dtype=int64)

使用模型进行概率预测

model.predict_proba(X_test[:10])  #预测概率
array([[0.02731782, 0.97268218],[0.03240048, 0.96759952],[0.63710499, 0.36289501],[0.03272136, 0.96727864],[0.80136214, 0.19863786],[0.64224485, 0.35775515],[0.64860225, 0.35139775],[0.06276485, 0.93723515],[0.64481127, 0.35518873],[0.58364375, 0.41635625]])

继续训练

new_model = CatBoostClassifier()
new_model.fit(test_pool, plot=True, silent=True, init_model='titanic.model') # 继续训练

可视化

fit()时加入参数plot=True

model.fit(X_train, y_train, plot=True)

决策树

调用plot_tree()tree_idx为树的索引

model.plot_tree(tree_idx=0, pool=test_pool)

特征重要性

调用模型属性model.feature_importances_

for i,j in zip(X.columns, model.feature_importances_):print('{}: {:.2f}%'.format(i,j))
Pclass: 18.62%
Sex: 46.79%
Age: 12.47%
SibSp: 4.68%
Parch: 2.16%
Fare: 10.65%
Embarked: 4.63%
%matplotlib inline
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def feature_importances(df, model):max_num_features=10feature_importances = pd.DataFrame(columns = ['feature', 'importance'])feature_importances['feature'] = df.columnsfeature_importances['importance'] = model.feature_importances_feature_importances.sort_values(by='importance', ascending=False, inplace=True)feature_importances = feature_importances[:max_num_features]plt.figure(figsize=(12, 6));sns.barplot(x="importance", y="feature", data=feature_importances);plt.title('CatBoost features importance');
feature_importances(X, model)


看来最决定生死的前三个因素是性别、船舱等级和年龄。

最优模型

fit()时加入参数use_best_model=True

model.fit(X_train, y_train, use_best_model=True)

调用GPU

定义模型时加入参数task_type="GPU"

model = CatBoostClassifier(task_type="GPU")
model.fit(X_train, y_train)

如果需要GPU支持,系统编译器必须与CUDA Toolkit兼容。
若报错请自行编译CatBoost Build from source on Windows

参考文献

  1. CatBoost - open-source gradient boosting library
  2. Quick start - CatBoost. Documentation
  3. CatBoost tutorials
  4. 机器学习算法之Catboost
  5. MNIST & Catboost保存模型并预测

CatBoost快速入门相关推荐

  1. Shiro第一个程序:官方快速入门程序Qucickstart详解教程

    目录 一.下载解压 二.第一个Shiro程序 1. 导入依赖 2. 配置shiro配置文件 3. Quickstart.java 4. 启动测试 三.shiro.ini分析 四.Quickstart. ...

  2. 计算机入门新人必学,异世修真人怎么玩?新手快速入门必备技巧

    异世修真人怎么快速入门?最近新出来的一款文字修仙游戏,很多萌新不知道怎么玩?进小编给大家带来了游戏新手快速入门技巧攻略,希望可以帮到大家. 新手快速入门攻略 1.开局出来往下找婆婆,交互给点钱,旁边有 ...

  3. Spring Boot 2 快速教程:WebFlux 快速入门(二)

    2019独角兽企业重金招聘Python工程师标准>>> 摘要: 原创出处 https://www.bysocket.com 「公众号:泥瓦匠BYSocket 」欢迎关注和转载,保留摘 ...

  4. Apache Hive 快速入门 (CentOS 7.3 + Hadoop-2.8 + Hive-2.1.1)

    2019独角兽企业重金招聘Python工程师标准>>> 本文节选自<Netkiller Database 手札> 第 63 章 Apache Hive 目录 63.1. ...

  5. 《iOS9开发快速入门》——导读

    本节书摘来自异步社区<iOS9开发快速入门>一书中的目录,作者 刘丽霞 , 邱晓华,更多章节内容可以访问云栖社区"异步社区"公众号查看 目 录 前 言 第1章 iOS ...

  6. BIML 101 - ETL数据清洗 系列 - BIML 快速入门教程 - 序

    BIML 101 - BIML 快速入门教程 做大数据的项目,最花时间的就是数据清洗. 没有一个相对可靠的数据,数据分析就是无木之舟,无水之源. 如果你已经进了ETL这个坑,而且预算有限,并且有大量的 ...

  7. python scrapy菜鸟教程_scrapy学习笔记(一)快速入门

    安装Scrapy Scrapy是一个高级的Python爬虫框架,它不仅包含了爬虫的特性,还可以方便的将爬虫数据保存到csv.json等文件中. 首先我们安装Scrapy. pip install sc ...

  8. OpenStack快速入门

    OpenStack云计算快速入门(1) 该教程基于Ubuntu12.04版,它将帮助读者建立起一份OpenStack最小化安装.我是五岳之巅,翻译中多采用意译法,所以个别词与原版有出入,请大家谅解.我 ...

  9. Expression Blend实例中文教程(2) - 界面快速入门

    上一篇主要介绍Expression系列产品,另外概述了Blend的强大功能,本篇将用Blend 3创建一个新Silverlight项目,通过创建的过程,对Blend进行快速入门学习. 在开始使用Ble ...

  10. 图文并茂!60页PPT《快速入门python数据分析路线》(附链接)

    一个月不走弯路快速入门学python和python数据分析路线,呕心沥血加班加点做了2天,一共63页,该课件讲的都是路线中的核心知识,今天把该PPT分享给大家,能根据该课件提到的知识有针对性的学,做到 ...

最新文章

  1. JavaScript封装一个注册函数解决兼容问题
  2. 显示计算机硬盘驱动器更改,笔记本硬盘驱动器的字母怎么修改?笔记本修改硬盘驱动器字母的方法...
  3. 使用 NOR Flash 中的supervivi 下载裸机程序到NandFlash
  4. linux7 rpmdb 修复,Linux[CentOS 7]rpmdb open failed错误修复
  5. 修改element默认样式_ggplot2作图:修改主题元素的外观样式(整体修改)
  6. 新版chrome调整开发者工具位置方式改变
  7. arcgis渔网分割提取栅格图_【操作】ArcGIS中字段的合并、分割、提取
  8. 很多人不知道的中国高校“V9联盟”,另一领域的顶尖牛校!
  9. 如何在Windows11和Windows10上获取驱动程序更新
  10. LNK2019 无法解析的外部符号 __imp_CommandLineToArgvW,该符号在函数 WinMain 中被引用
  11. 程序员的进阶课-架构师之路(7)-树的概念
  12. FPGA常用总线IIC 与SPI选择策略
  13. ★LeetCode(812)——最大三角形面积(JavaScript)
  14. php4.0中文手册,服务 — CodeIgniter 4.0.0 中文手册|用户手册|用户指南|中文文档
  15. win11扩展任务栏没东西怎么办 windows11扩展任务栏没东西的解决方法
  16. 编辑器笔记——sublime text3 编译sass
  17. matlab quiver 箭头颜色,matlab – quiver3箭头颜色对应大小
  18. linux下创造进程指令,Linux系统创建一个新进程(下)
  19. 文本分析python和r_中文文本挖掘R语言和Python哪个好?
  20. 在 Google 工作十年后的感悟

热门文章

  1. 软件常见的各种版本英文缩写
  2. 爬取汽车之家所有汽车参数配置
  3. 泰国之旅随感(r1笔记第70天)
  4. 属性动画与图片三级缓存
  5. webgl点光源的漫反射
  6. WhatsApp网页版(电脑版)使用教程
  7. 浅析浏览器书签的导入和导出
  8. 达尔豪斯大学 计算机专业排名,加拿大留学计算机专业排名
  9. 音频怎么转换成mp3格式
  10. 全能扫描王的实现(python版本)- 目标检测图像矫正