XGBoost算法
• XGBoost是陈天奇等人开发的一个开源机器学习项目,高效地实现了GBDT算法并进行了算法和工程上的许多改进,被广泛应用在Kaggle竞赛及其他许多机器学习竞赛中并取得了不错的成绩。

• XGBoost的基学习器除了可以是CART(这个时候就是GBDT)也可以是线性分类器,而GBDT只能是CART。

• XGBoost的目标函数的近似用了二阶泰勒展开,模型优化效果更好。

• XGBoost在代价函数中加入了正则项,用于控制模型的复杂度(正则项的方式不同,如果你仔细点话,GBDT是一种类似于缩减系数,而XGBoost类似于L2正则化项)。

• XGBoost借鉴了随机森林的做法,支持特征抽样,不仅防止过拟合,还能减少计算

• XGBoost工具支持并行化

• 综合来说Xgboost的运算速度和算法精度都会优于GBDT

具体的算法细节,且看之前学习的内容
我们得先安装xgboost库,pip install xgboost

我们直接来看下代码看看怎么玩的。
第一步,我们加载数据。

import pandas as pd
import numpy as  np
import matplotlib.pyplot as plt
from time import time
import datetimefrom xgboost import XGBRegressor as XGBR  # xgboost模块
from sklearn.ensemble import RandomForestRegressor as RFR  # 随机森林模块
from sklearn.linear_model import LinearRegression as LR  # 线性回归模块
from sklearn.datasets import load_boston  # 使用波士顿房价进行回归试验预测
from sklearn.model_selection import KFold, cross_val_score, train_test_split
from sklearn.metrics import mean_squared_error as MSE  # 评估指标是均方误差# 第一步,导入数据
data = load_boston()
x, y = data.data, data.target
print(x.shape)  # (506, 13)
print(y.shape)X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=420)import sklearn
print(sklearn.metrics.SCORERS.keys())  # 查看所有的模型评估指标

第二步骤:我们先实例化一个模型看看

# 第二步:建立一个xgboost模型,来看看是啥样子的。
xgbr = XGBR(n_estimators=100, silent=True).fit(X_train, y_train)  # 100个弱回归器,silent 会打印运行过程
print(xgbr.score(X_test, y_test))  # 查看分数,分数是R2误差
print(xgbr.feature_importances_)  # 查看各个属性的重要程度
# print(xgbr.predict(X_test))  # 进行预测
print(MSE(y_true=y_test, y_pred=xgbr.predict(X_test)))  # 查看下均方误差

第三步骤:和其他模型,随机森林和线性回归做个对比

# 第三步:看看在交叉验证的情况下,看看xgboost和random、线性回归看看区别
# 先看看xgboost
xgbr = XGBR(n_estimators=100, silent=True)  # 如果开启参数slient:在数据巨大,预料到算法运行会非常缓慢的时候可以使用这个参数来监控模型的训练进度
# 不严谨的交叉验证可以使用全部数据,避免把我的测试数据提起给暴露了
# 严谨的交叉验证就是使用部分数据,比如训练数据
res1 = cross_val_score(xgbr, X_train, y_train, cv=5, scoring='neg_mean_squared_error').mean()  # 必须得导入没有训练过的模型,切记啊!!!
print('XGBR cross val score is %f' % res1)# 来看看随机森林
rfr = RFR(n_estimators=100)
res2 = cross_val_score(rfr, X_train, y_train, cv=5, scoring='neg_mean_squared_error').mean()  # 必须得导入没有训练过的模型,切记啊!!!
print('RFR cross val score is %f' % res2)# 来看看线性回归模型
lr = LR()
res3 = cross_val_score(lr, X_train, y_train, cv=5, scoring='neg_mean_squared_error').mean()  # 必须得导入没有训练过的模型,切记啊!!!
print('LR cross val score is %f' % res3)

下面我们看看在不同规模的数据集下,训练分数和测试分数的大小的变化曲线。

# 第四步:调整参数,通过学习曲线来观察xgboost在boston数据上的表现
def plot_learing_curve(estimators, title, X, y,ax=None,ylim=None,cv=None,n_jobs=None):from sklearn.model_selection import learning_curve# 根据不同的数据集样本数目,来看看训练分数和测试分数train_sizes, train_scores, test_scores = learning_curve(estimators, X, y,shuffle=True, cv=cv, random_state = 23,n_jobs=n_jobs)if ax is None:ax = plt.gca()else:ax = plt.figure()ax.set_title(title)ax.set_xlabel('Training examples')ax.set_ylabel('Score')ax.plot(train_sizes, np.mean(train_scores, axis=1), 'o-', c='r', label='Training score')ax.plot(train_sizes, np.mean(test_scores, axis=1), 'o-', c='g', label='Test score')ax.grid()ax.legend()return axcv = KFold(n_splits=5, shuffle=True, random_state=42)  ## shuffle 表示把数据集打乱, cv就是把数据集分成几份
plot_learing_curve(XGBR(n_estimators=100, random_state=42, silent=True), "XGB", X_train, y_train, cv=cv)
plt.show()

上图中,横坐标是训练样本的数目,很明显,我们目前的模型在训练集上的准确度很高很高,但是在测试数据上表现的不是很好,说明模型处于过拟合的状态了,
那么我们接下来就是要调整参数了啊,第一个需要调整的就是n_estimators。

# 绘制n_estimators的学习曲线
axis_x = range(10, 1010, 40)
nmse = []  # 储存负均方误差
cv = KFold(n_splits=5, shuffle=True, random_state=42)  ## shuffle 表示把数据集打乱, cv就是把数据集分成几份
for i in axis_x:reg = XGBR(n_estimators=i, random_state=42, silent=True)nmse.append(cross_val_score(reg, X_train, y_train, cv=cv, scoring='neg_mean_squared_error').mean())
print(axis_x[nmse.index(max(nmse))], max(nmse))
plt.figure(figsize=(20, 8))
plt.plot(axis_x, nmse, c='r', label='XGB')
plt.legend()
plt.show()

此时呢,当n_estimators = 130的时候就差不多优比较大的值了

《scikit-learn》xgboost相关推荐

  1. python 高维数据_用Sci-kit learn和XGBoost进行多类分类:Brainwave数据案例研究

    在机器学习中,高维数据的分类问题非常具有挑战性.有时候,非常简单的问题会因为这个"维度诅咒"问题变得非常复杂.在本文中,我们将了解不同分类器的准确性和性能是如何变化的. 理解数据 ...

  2. 《english learn》

    中文 english 你好 hello 早上/下午/晚上好 good morning/afterning/evening/night 你是比尔吗 are you bill 是的,我是 yes ,i a ...

  3. 机器学习与Scikit Learn学习库

    摘要: 本文介绍机器学习相关的学习库Scikit Learn,包含其安装及具体识别手写体数字案例,适合机器学习初学者入门Scikit Learn. 在我科研的时候,机器学习(ML)是计算机科学领域中最 ...

  4. 【Book 118】《How We Learn》

    [Book 118]<How We Learn>

  5. 泰晤士报下载_《泰晤士报》和《星期日泰晤士报》新闻编辑室中具有指标的冒险活动-第1部分:问题

    泰晤士报下载 TLDR: Designing metrics that help you make better decisions is hard. In The Times and The Sun ...

  6. k近邻算法python解读_Python3《机器学习实战》学习笔记(一):k-近邻算法(史诗级干货长文)...

    运行平台: Windows IDE: Sublime text3 一.简单k-近邻算法 本文将从k-近邻 1.k-近邻法简介 k近邻法(k-nearest neighbor, k-NN)是1967年由 ...

  7. 《机器学习实战》——kNN(k近邻算法)

    原作者写的太好了,包括排版都特别整齐(其中有一个错误之处就是在约会网站配对效果判定的时候,列表顺序不对,导致结果有误,这里我已做出修改) 原作者和出处:http://blog.csdn.net/c40 ...

  8. 吴恩达《Machine Learning》精炼笔记 8:聚类 KMeans 及其 Python实现

    作者 | Peter 编辑 | AI有道 系列文章: 吴恩达<Machine Learning>精炼笔记 1:监督学习与非监督学习 吴恩达<Machine Learning>精 ...

  9. 吴恩达《Machine Learning》Jupyter Notebook 版笔记发布!图解、公式、习题都有了

    在我很早之前写过的文章<机器学习如何入门>中,就首推过吴恩达在 Coursera 上开设的<Machine Learning>课程.这门课最大的特点就是基本没有复杂的数学理论和 ...

  10. 关于《重启人工智能》11条建议的思考

    来源:人机与认知实验室 马库斯和欧内斯特·戴维斯在他们的新书<重启人工智能>(Rebooting AI)中主张开辟一条新的前进道路.他们相信,我们离获得这样的通用智能还差得很远,但他们也相 ...

最新文章

  1. 201403-4 无线网络
  2. 长期使用中型Access数据库的一点经验
  3. 李彦宏喊你来坐出租车,无人驾驶的那种;百度还要继续搞芯片,联手华为的那种...
  4. Android 混淆文件project.properties和proguard-project.txt
  5. 状态标志寄存器--EFLAGS
  6. python第三十课--异常(raise关键字)
  7. js判断浏览器\屏幕分辨率(转载)
  8. 【LDA学习系列】Latent Dirichlet Allocation主题模型理解
  9. UNIX Domain Socket(UDS)是什么?同一台主机间进程间通信
  10. 【干货】和你谈谈数据分析报告
  11. 为什么沿梯度方向,函数变化最快???
  12. 报错:此版本的SQL Server Data Tools与此计算机中安装的数据库运行时组件不兼容...
  13. 墨天轮2022年新春发布会暨年度数据库颁奖盛典即将开启!
  14. Android 系统优化(35)---Android 中如何计算 App 的启动时间?
  15. 流媒体服务器NTV Media Server G3性能测试
  16. 将一个js项目改造成vue项目
  17. Silverlight图片处理——(伸展,裁剪,蒙版)
  18. mysql怎么增加字数_数据库字段如何设置最大字数
  19. 滑窗口统计基因组GC含量的分布
  20. OpenHarmony短信验证码及倒计时实现

热门文章

  1. 《技术的潜能:商业颠覆、创新与执行》一一2.12决心、愿望和耐力
  2. 使用Advanced Installer将.exe程序重新封装为.msi程序
  3. 在Oracle DG Standby库上启用flashback database功能
  4. [PL/SQL]使用存储过程实现导出指定数据到文件(仿EXP)|转|
  5. linux文件三种时间及stat的用法
  6. Exchange 企业邮件与Windows安全应用 — Exchange 2007 收件人管理
  7. 业务逻辑数据层SqlDataSourcesql的输入参数
  8. 在 git hooks 中运行 npm script
  9. Golang 实现tcp转发代理
  10. CentOS 7添加开机启动服务/脚本