LightGBM, the Competition Powerhouse: A Summary of Common Operations!
Datawhale insights
Author: 阿水, Beihang University, Datawhale member
LightGBM is a fast, parallelizable tree-boosting framework in the same family as XGBoost. It integrates several ensemble-learning ideas and improves on XGBoost's node-splitting implementation, giving lower memory usage and faster training.
LightGBM documentation: https://lightgbm.readthedocs.io/en/latest/
Parameter reference: https://lightgbm.readthedocs.io/en/latest/Parameters.html
This article covers the topics below; see the end of the article for how to get the original code.
1 Installation
2 Usage
2.1 Defining the Dataset
2.2 Model Training
2.3 Saving and Loading Models
2.4 Inspecting Feature Importance
2.5 Continuing Training
2.6 Adjusting Hyperparameters During Training
2.7 Custom Loss Functions
3 Tuning Methods
Manual tuning
Grid search
Bayesian optimization
1 Installation
Installing LightGBM is straightforward, and on Linux it is easy to enable GPU training. Prefer installing from pip first; fall back to building from source if that fails.
Installation: from source
git clone --recursive https://github.com/microsoft/LightGBM
cd LightGBM
mkdir build ; cd build
cmake ..
# MPI version, faster distributed training:
# cmake -DUSE_MPI=ON ..
# GPU version, faster training:
# cmake -DUSE_GPU=1 ..
make -j4
Installation: via pip
# default version
pip install lightgbm
# MPI version
pip install lightgbm --install-option=--mpi
# GPU version
pip install lightgbm --install-option=--gpu
Note: recent versions of pip have removed --install-option; if these commands fail, refer to the official installation guide linked above.
2 Usage
In Python, LightGBM provides two interfaces: the native API and the Scikit-learn API. Both can handle training and validation. The native API is more flexible, so choose according to personal preference.
2.1 Defining the Dataset
import json
import numpy as np
import pandas as pd
import lightgbm as lgb

df_train = pd.read_csv('https://cdn.coggle.club/LightGBM/examples/binary_classification/binary.train', header=None, sep='\t')
df_test = pd.read_csv('https://cdn.coggle.club/LightGBM/examples/binary_classification/binary.test', header=None, sep='\t')
W_train = pd.read_csv('https://cdn.coggle.club/LightGBM/examples/binary_classification/binary.train.weight', header=None)[0]
W_test = pd.read_csv('https://cdn.coggle.club/LightGBM/examples/binary_classification/binary.test.weight', header=None)[0]

y_train = df_train[0]
y_test = df_test[0]
X_train = df_train.drop(0, axis=1)
X_test = df_test.drop(0, axis=1)
num_train, num_feature = X_train.shape

# create dataset for lightgbm
# if you want to re-use data, remember to set free_raw_data=False
lgb_train = lgb.Dataset(X_train, y_train,
                        weight=W_train, free_raw_data=False)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train,
                       weight=W_test, free_raw_data=False)
2.2 Model Training
params = {
    'boosting_type': 'gbdt',
    'objective': 'binary',
    'metric': 'binary_logloss',
    'num_leaves': 31,
    'learning_rate': 0.05,
    'feature_fraction': 0.9,
    'bagging_fraction': 0.8,
    'bagging_freq': 5,
    'verbose': 0
}

# generate feature names
feature_name = ['feature_' + str(col) for col in range(num_feature)]

gbm = lgb.train(params,
                lgb_train,
                num_boost_round=10,
                valid_sets=lgb_train,  # eval training data
                feature_name=feature_name,
                categorical_feature=[21])
2.3 Saving and Loading Models
# save model to file
gbm.save_model('model.txt')

print('Dumping model to JSON...')
model_json = gbm.dump_model()
with open('model.json', 'w+') as f:
    json.dump(model_json, f, indent=4)
2.4 Inspecting Feature Importance
# feature names
print('Feature names:', gbm.feature_name())

# feature importances
print('Feature importances:', list(gbm.feature_importance()))
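To read the output more easily, the names and importances can be zipped together and sorted. A small pure-Python sketch with made-up values standing in for `gbm.feature_name()` and `gbm.feature_importance()`:

```python
# stand-in values for gbm.feature_name() / gbm.feature_importance()
names = ['feature_0', 'feature_1', 'feature_2']
importances = [12, 45, 3]

# sort feature/importance pairs in descending order of importance
ranked = sorted(zip(names, importances), key=lambda t: t[1], reverse=True)
print(ranked)  # → [('feature_1', 45), ('feature_0', 12), ('feature_2', 3)]
```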
2.5 Continuing Training
# continue training
# init_model accepts:
# 1. model file name
# 2. Booster()
gbm = lgb.train(params,
                lgb_train,
                num_boost_round=10,
                init_model='model.txt',
                valid_sets=lgb_eval)
print('Finished 10 - 20 rounds with model file...')
2.6 Adjusting Hyperparameters During Training
# decay learning rates
# learning_rates accepts:
# 1. list/tuple with length = num_boost_round
# 2. function(curr_iter)
gbm = lgb.train(params,
                lgb_train,
                num_boost_round=10,
                init_model=gbm,
                learning_rates=lambda iter: 0.05 * (0.99 ** iter),
                valid_sets=lgb_eval)
print('Finished 20 - 30 rounds with decay learning rates...')

# change other parameters during training
gbm = lgb.train(params,
                lgb_train,
                num_boost_round=10,
                init_model=gbm,
                valid_sets=lgb_eval,
                callbacks=[lgb.reset_parameter(bagging_fraction=[0.7] * 5 + [0.6] * 5)])
print('Finished 30 - 40 rounds with changing bagging_fraction...')
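The `learning_rates` lambda above defines a geometric decay schedule; its first few values can be computed directly to see the effect:

```python
# learning-rate schedule from the lambda above: 0.05 * 0.99**iter
schedule = [0.05 * (0.99 ** i) for i in range(5)]
print([round(lr, 6) for lr in schedule])
# → [0.05, 0.0495, 0.049005, 0.048515, 0.04803]
```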
2.7 Custom Loss Functions
# self-defined objective function
# f(preds: array, train_data: Dataset) -> grad: array, hess: array
# log likelihood loss
def loglikelihood(preds, train_data):
    labels = train_data.get_label()
    preds = 1. / (1. + np.exp(-preds))
    grad = preds - labels
    hess = preds * (1. - preds)
    return grad, hess

# self-defined eval metric
# f(preds: array, train_data: Dataset) -> name: string, eval_result: float, is_higher_better: bool
# binary error
# NOTE: with a customized objective, the default prediction value is the raw margin
# This may make built-in evaluation metrics calculate wrong results
# For example, with log likelihood loss the prediction is the score before the logistic transformation
# Keep this in mind when using customization
def binary_error(preds, train_data):
    labels = train_data.get_label()
    preds = 1. / (1. + np.exp(-preds))
    return 'error', np.mean(labels != (preds > 0.5)), False

gbm = lgb.train(params,
                lgb_train,
                num_boost_round=10,
                init_model=gbm,
                fobj=loglikelihood,
                feval=binary_error,
                valid_sets=lgb_eval)
print('Finished 40 - 50 rounds with self-defined objective function and eval metric...')
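As a sanity check, the analytic gradient and hessian returned by `loglikelihood` can be verified against finite differences of the binary log loss (a numpy-only sketch, independent of LightGBM):

```python
import numpy as np

def logloss(raw, label):
    # binary log loss as a function of the raw margin
    p = 1. / (1. + np.exp(-raw))
    return -(label * np.log(p) + (1 - label) * np.log(1 - p))

raw, label, eps = 0.3, 1.0, 1e-4
p = 1. / (1. + np.exp(-raw))
grad = p - label          # as returned by loglikelihood()
hess = p * (1. - p)       # as returned by loglikelihood()

# central finite differences of the loss
num_grad = (logloss(raw + eps, label) - logloss(raw - eps, label)) / (2 * eps)
num_hess = (logloss(raw + eps, label) - 2 * logloss(raw, label)
            + logloss(raw - eps, label)) / eps ** 2
print(abs(grad - num_grad), abs(hess - num_hess))  # both tiny
```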
3 Tuning Methods
Manual tuning
For Faster Speed
- Use bagging by setting bagging_fraction and bagging_freq
- Use feature sub-sampling by setting feature_fraction
- Use small max_bin
- Use save_binary to speed up data loading in future learning
- Use parallel learning; refer to the Parallel Learning Guide in the LightGBM docs
For Better Accuracy
- Use large max_bin (may be slower)
- Use small learning_rate with large num_iterations
- Use large num_leaves (may cause over-fitting)
- Use bigger training data
- Try dart
Deal with Over-fitting
- Use small max_bin
- Use small num_leaves
- Use min_data_in_leaf and min_sum_hessian_in_leaf
- Use bagging by setting bagging_fraction and bagging_freq
- Use feature sub-sampling by setting feature_fraction
- Use bigger training data
- Try lambda_l1, lambda_l2 and min_gain_to_split for regularization
- Try max_depth to avoid growing deep trees
- Try extra_trees
- Try increasing path_smooth
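The over-fitting checklist above can be collected into one params dict for `lgb.train`. All values below are illustrative starting points, not recommendations from the original text:

```python
# illustrative anti-over-fitting configuration (tune per dataset)
params_regularized = {
    'objective': 'binary',
    'max_bin': 63,                    # small max_bin
    'num_leaves': 15,                 # small num_leaves
    'min_data_in_leaf': 50,
    'min_sum_hessian_in_leaf': 10.0,
    'bagging_fraction': 0.8,          # bagging ...
    'bagging_freq': 5,                # ... applied every 5 iterations
    'feature_fraction': 0.8,          # feature sub-sampling
    'lambda_l1': 0.1,                 # L1 regularization
    'lambda_l2': 0.1,                 # L2 regularization
    'min_gain_to_split': 0.1,
    'max_depth': 6,                   # cap tree depth
}
print(sorted(params_regularized))
```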
Grid search
from sklearn.model_selection import GridSearchCV

lg = lgb.LGBMClassifier(silent=False)
param_dist = {"max_depth": [4, 5, 7],
              "learning_rate": [0.01, 0.05, 0.1],
              "num_leaves": [300, 900, 1200],
              "n_estimators": [50, 100, 150]}

grid_search = GridSearchCV(lg, n_jobs=-1, param_grid=param_dist,
                           cv=5, scoring="roc_auc", verbose=5)
grid_search.fit(train, y_train)
grid_search.best_estimator_, grid_search.best_score_
Bayesian optimization
import warnings
import time
warnings.filterwarnings("ignore")
from bayes_opt import BayesianOptimization

def lgb_eval(max_depth, learning_rate, num_leaves, n_estimators):
    params = {"metric": 'auc'}
    params['max_depth'] = int(max(max_depth, 1))
    params['learning_rate'] = np.clip(learning_rate, 0, 1)  # clip into [0, 1]
    params['num_leaves'] = int(max(num_leaves, 1))
    params['n_estimators'] = int(max(n_estimators, 1))
    cv_result = lgb.cv(params, d_train, nfold=5, seed=0,
                       verbose_eval=200, stratified=False)
    return 1.0 * np.array(cv_result['auc-mean']).max()

lgbBO = BayesianOptimization(lgb_eval, {'max_depth': (4, 8),
                                        'learning_rate': (0.05, 0.2),
                                        'num_leaves': (20, 1500),
                                        'n_estimators': (5, 200)},
                             random_state=0)
lgbBO.maximize(init_points=5, n_iter=50, acq='ei')
print(lgbBO.max)
To get the code for this article, reply with 【lgb】 in the backend to download the Notebook!