一、导入必要的工具包

# 导入必要的工具包

import xgboost as xgb

# 计算分类正确率

from sklearn.metrics import accuracy_score

二、数据读取

XGBoost可以加载libsvm格式的文本数据,libsvm的文件格式(稀疏特征)如下:

1  101:1.2 102:0.03

0  1:2.1 10001:300 10002:400

...

每一行表示一个样本,第一行的开头的“1”是样本的标签。“101”和“102”为特征索引,'1.2'和'0.03' 为特征的值。

在两类分类中,用“1”表示正样本,用“0” 表示负样本。也支持[0,1]表示概率用来做标签,表示为正样本的概率。

下面的示例数据需要我们通过一些蘑菇的若干属性判断这个品种是否有毒。

UCI数据描述:http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/ ,

每个样本描述了蘑菇的22个属性,比如形状、气味等等(将22维原始特征用加工后变成了126维特征,

并存为libsvm格式),然后给出了这个蘑菇是否可食用。其中6513个样本做训练,1611个样本做测试。

注:libsvm格式文件说明如下 https://www.cnblogs.com/codingmengmeng/p/6254325.html

XGBoost加载的数据存储在对象DMatrix中

XGBoost自定义了一个数据矩阵类DMatrix,优化了存储和运算速度

DMatrix文档:http://xgboost.readthedocs.io/en/latest/python/python_api.html

数据下载地址:http://download.csdn.net/download/u011630575/10266113

# read in data,数据在xgboost安装的路径下的demo目录,现在我们将其copy到当前代码下的data目录

my_workpath = './data/'

dtrain = xgb.DMatrix(my_workpath + 'agaricus.txt.train')

dtest = xgb.DMatrix(my_workpath + 'agaricus.txt.test')

查看数据情况

dtrain.num_col()

dtrain.num_row()

dtest.num_row()

三、训练参数设置

max_depth: 树的最大深度。缺省值为6,取值范围为:[1,∞]

eta:为了防止过拟合,更新过程中用到的收缩步长。在每次提升计算之后,算法会直接获得新特征的权重。

eta通过缩减特征的权重使提升计算过程更加保守。缺省值为0.3,取值范围为:[0,1]

silent:取0时表示打印出运行时信息,取1时表示以缄默方式运行,不打印运行时信息。缺省值为0

objective: 定义学习任务及相应的学习目标,“binary:logistic” 表示二分类的逻辑回归问题,输出为概率。

其他参数取默认值。

# specify parameters via map

param = {'max_depth':2, 'eta':1, 'silent':0, 'objective':'binary:logistic' }

print(param)

四、训练模型

# 设置boosting迭代计算次数

num_round = 2

import time

starttime = time.clock()

bst = xgb.train(param, dtrain, num_round) # dtrain是训练数据集

endtime = time.clock()

print (endtime - starttime)

XGBoost预测的输出是概率。这里蘑菇分类是一个二类分类问题,输出值是样本为第一类的概率。

我们需要将概率值转换为0或1。

train_preds = bst.predict(dtrain)

train_predictions = [round(value) for value in train_preds]

y_train = dtrain.get_label() #值为输入数据的第一行

train_accuracy = accuracy_score(y_train, train_predictions)

print ("Train Accuary: %.2f%%" % (train_accuracy * 100.0))

五、测试

模型训练好后,可以用训练好的模型对测试数据进行预测

# make prediction

preds = bst.predict(dtest)

检查模型在测试集上的正确率

XGBoost预测的输出是概率,输出值是样本为第一类的概率。我们需要将概率值转换为0或1。

predictions = [round(value) for value in preds]

y_test = dtest.get_label()

test_accuracy = accuracy_score(y_test, predictions)

print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))

六、模型可视化

调用XGBoost工具包中的plot_tree,在显示

要可视化模型需要安装graphviz软件包

plot_tree()的三个参数:

1. 模型

2. 树的索引,从0开始

3. 显示方向,缺省为竖直,‘LR'是水平方向

from matplotlib import pyplot

import graphviz

xgb.plot_tree(bst, num_trees=0, rankdir= 'LR' )

pyplot.show()

#xgb.plot_tree(bst,num_trees=1, rankdir= 'LR' )

#pyplot.show()

#xgb.to_graphviz(bst,num_trees=0)

#xgb.to_graphviz(bst,num_trees=1)

七、代码整理

# coding:utf-8

import xgboost as xgb

# 计算分类正确率

from sklearn.metrics import accuracy_score

# read in data,数据在xgboost安装的路径下的demo目录,现在我们将其copy到当前代码下的data目录

my_workpath = './data/'

dtrain = xgb.DMatrix(my_workpath + 'agaricus.txt.train')

dtest = xgb.DMatrix(my_workpath + 'agaricus.txt.test')

dtrain.num_col()

dtrain.num_row()

dtest.num_row()

# specify parameters via map

param = {'max_depth':2, 'eta':1, 'silent':0, 'objective':'binary:logistic' }

print(param)

# 设置boosting迭代计算次数

num_round = 2

import time

starttime = time.clock()

bst = xgb.train(param, dtrain, num_round) # dtrain是训练数据集

endtime = time.clock()

print (endtime - starttime)

train_preds = bst.predict(dtrain) #

print ("train_preds",train_preds)

train_predictions = [round(value) for value in train_preds]

print ("train_predictions",train_predictions)

y_train = dtrain.get_label()

print ("y_train",y_train)

train_accuracy = accuracy_score(y_train, train_predictions)

print ("Train Accuary: %.2f%%" % (train_accuracy * 100.0))

# make prediction

preds = bst.predict(dtest)

predictions = [round(value) for value in preds]

y_test = dtest.get_label()

test_accuracy = accuracy_score(y_test, predictions)

print("Test Accuracy: %.2f%%" % (test_accuracy * 100.0))

# from matplotlib import pyplot

# import graphviz

import graphviz

# xgb.plot_tree(bst, num_trees=0, rankdir='LR')

# pyplot.show()

# xgb.plot_tree(bst,num_trees=1, rankdir= 'LR' )

# pyplot.show()

# xgb.to_graphviz(bst,num_trees=0)

# xgb.to_graphviz(bst,num_trees=1)

python xgboost用法_XGBoost使用教程(纯xgboost方法)一相关推荐

  1. python xgboost用法_XGBoost类库使用小结

    在XGBoost算法原理小结中,我们讨论了XGBoost的算法原理,这一片我们讨论如何使用XGBoost的Python类库,以及一些重要参数的意义和调参思路. 1. XGBoost类库概述 XGBoo ...

  2. python end用法_python中end的使用方法

    python中end的使用方法 发布时间:2020-06-17 09:47:13 来源:亿速云 阅读:178 这篇文章给大家分享的是有关python中end的使用方法,小编觉得挺实用的,因此分享给大家 ...

  3. python 包用法_Python 基础教程之包和类的用法

    Python 基础教程之包和类的用法 这篇文章主要介绍了 Python 基础教程之包和类的用法的相关资料, 需要的朋友可以参考下 Python 是一种面向对象.解释型计算机程序设计语言,由 Guido ...

  4. python boxplot用法_Python pandas.DataFrame.boxplot函数方法的使用

    DataFrame.boxplot(column = None,by = None,ax = None,fontsize = None,rot = 0,grid = True,figsize = No ...

  5. python xgb模型 预测_如何使用XGBoost模型进行时间序列预测

    字幕组双语原文:如何使用XGBoost模型进行时间序列预测 英语原文:How to Use XGBoost for Time Series Forecasting 翻译:雷锋字幕组(Shangru) ...

  6. xgboost算法_XGBoost算法可能会长期占据你的视野!

    点击上方关注,All in AI中国 我仍然记得十五年前第一份工作的第一天,我刚刚完成了我的研究生课程,并以分析师的身份加入了一家全球投资银行.在上班的第一天,我还是很紧张的,我经常会有的小动作就是会 ...

  7. python中globals用法_Python基础教程之内置函数locals()和globals()用法分析

    本文实例讲述了Python基础教程之内置函数locals()和globals()用法.分享给大家供大家参考,具体如下: 1. 这两个函数主要提供,基于字典的访问局部变量和全局变量的方式. python ...

  8. xgboost分类_XGBoost(Extreme Gradient Boosting)

    一.XGBoost在Ensemble Learning中的位置 机器学习中,有一类算法叫集成学习(Ensemble Learning),所谓集成学习,指将多个分类器的预测结果集成起来,作为最终预测结果 ...

  9. python怎样设置全局变量_Python教程之全局变量用法

    本文实例讲述了Python全局变量用法.分享给大家供大家参考,具体如下: 全局变量不符合参数传递的精神,所以,平时我很少使用,除非定义常量.今天有同事问一个关于全局变量的问题,才发现其中原来还有门道. ...

最新文章

  1. wxruby框架例子1
  2. 我用Python玩小游戏“跳一跳”,瞬间称霸了朋友圈!
  3. android 开发 矩形截屏插件,Android 上如何实现矩形区域截屏
  4. python统计输入学生的总分和平均分_C输入函数和成绩显示函数并计算每位同学总分和平均分对成绩排名输出.doc...
  5. ConcurrentHashMap源码解读,java大厂面试攻略
  6. c# 串口发送接收数据
  7. Java EE组件技术
  8. k8s 配置dashboard
  9. springmvc执行原理(基于组件)
  10. 运营小技能:订阅号文章排版教程(添加图片超链接、推文采集、往期推荐)
  11. 【矩阵论】矩阵的相似标准型(4)(5)
  12. 【DNSPOD】利用DNSPod实现动态域名解析【DDNS】
  13. python lncrna_【云计算】LncRNA生信分析案例
  14. 一键屏蔽百度热搜,专注工作!
  15. 6年主导3个项目,我终于成了别人眼中的大神
  16. 如何实现同一个ip下同一个80端口部署多个网站?
  17. 数据可视化项目(一)
  18. Linux系统中systemctl命令的使用
  19. Wacom 数位板 和冠 手绘笔 Photoshop MacOS 延时卡顿丢笔解决办法
  20. sql数据库本地服务器不显示,sql数据库本地服务器不显示

热门文章

  1. 都已经十岁的ApacheDubbo,还能再乘风破浪吗?
  2. 信用算力基于 RocketMQ 实现金融级数据服务的实践
  3. 用PyTorch创建一个图像分类器?So easy!(Part 2)
  4. MaxCompute Tunnel上传典型问题场景
  5. 为普及再助一把力!《2021年中国低代码/无代码市场研究报告》正式发布
  6. IBM在中国发布Cloud Paks,牵手神州数码,助力企业云转型步入“第二篇章”
  7. php 什么时候传引用,什么时候在PHP中使用传递引用?
  8. java实现对文件加解密操作
  9. 解决windows下Error:node with name rabbit already running on “XXX” 和管理页面打不开问题
  10. Apache JMeter 测试webservice接口 中文乱码