最近也在做这个,核心代码给你,如下:

import matplotlib.pyplot as plt

import numpy as np

from sklearn.datasets import make_classification

from sklearn.linear_model import LogisticRegression

from sklearn.ensemble import (RandomTreesEmbedding, RandomForestClassifier, GradientBoostingClassifier)

from sklearn.preprocessing import OneHotEncoder

from sklearn.model_selection import train_test_split

from sklearn.metrics import roc_curve,auc

from sklearn.externals import joblib

from sklearn.externals.six import StringIO

from sklearn import tree

import pydotplus

n_estimator = 4 # the number of base trees

X, y = make_classification(n_samples=100,n_features=4)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,random_state=0)

X_train, X_train_lr, y_train, y_train_lr = train_test_split(X_train,y_train,test_size=0.3)

grd = GradientBoostingClassifier(n_estimators=n_estimator,max_depth=3)

grd.fit(X_train, y_train)

score_feature = grd.feature_importances_

print(score_feature)

joblib.dump(grd,'gbdt_model/grd_model.m',compress = 3)

m1 = grd.apply(X_train)[:, :, 0]

# print (m1)

grd_enc = OneHotEncoder()

grd_lm = LogisticRegression()

grd_enc.fit(m1)

joblib.dump(grd_enc,'gbdt_model/grd_enc_model.m',compress = 3)

grd_lm.fit(grd_enc.transform(grd.apply(X_train_lr)[:, :, 0]), y_train_lr)

joblib.dump(grd_lm,'gbdt_model/grd_lm_model.m',compress = 3)

# save the training model

m2 = grd_enc.transform(grd.apply(X_train_lr)[:, :, 0])

print (m2.toarray())

# grd = joblib.load('gbdt_model/grd_model.m')

# grd_enc = joblib.load('gbdt_model/grd_enc_model.m')

# grd_lm = joblib.load('gbdt_model/grd_lm_model.m')

y_pred_grd_lm = grd_lm.predict_proba(grd_enc.transform(grd.apply(X_test)[:, :, 0]))

print grd_lm.coef_

#print y_pred_grd_lm

acc = grd_lm.score(grd_enc.transform(grd.apply(X_test)[:, :, 0]),y_test)

print "acc is ",acc

fpr_grd_lm, tpr_grd_lm, _ = roc_curve(y_test, y_pred_grd_lm[:,1])

roc_auc = auc(fpr_grd_lm,tpr_grd_lm)

print "auc is ",roc_auc

dot_data = StringIO()

tree.export_graphviz(grd.estimators_[0,0],out_file = dot_data,node_ids=True,filled=True,rounded=True,special_characters=True)

graph = pydotplus.graph_from_dot_data(dot_data.getvalue())

graph.write_pdf("yang.pdf")

print('Visible tree plot saved as pdf.')

python处理grd格式文件_python sklearn中,GBDT模型训练之后,可以查看模型中树的分裂路径吗?...相关推荐

  1. python如何打开npy文件_python实现npy格式文件转换为txt文件操作

    如下代码会将npy的格式数据读出,并且输出来到控制台: import numpy as np ##设置全部数据,不输出省略号 import sys np.set_printoptions(thresh ...

  2. Python:pmml格式文件的简介、安装、使用方法(利用python将机器学习模型转为Java常用的pmml格式文件)之详细攻略

    Python:pmml格式文件的简介.安装.使用方法(利用python将机器学习模型转为Java常用的pmml格式文件)之详细攻略 目录 pmml格式文件的简介 1.PMML结构 pmml安装 pmm ...

  3. python 读取csv文件转成字符串,python实现csv格式文件转为asc格式文件的方法

    一.背景描述 csv格式文件是一种类似于excel的文件格式 asc格式文件是一种可以用text打开的文本文件 csv转asc本来可以用arcgis顺利完成,但由于csv数据量太大(744万行),ar ...

  4. csv加header python_用python处理csv格式文件

    用python处理csv格式文件 在各种平台上获取数据时,我们常常获得的是csv格式的文件.csv格式是一种逗号分隔值的文件格式,它并不是非常reader-friendly.所幸,python标准库中 ...

  5. python操作excel格式文件

    python操作excel格式文件 1. 读数据 2. 写excel 3.操作整合 1. 读数据 安装包 pip install openpyxl 导包 from openpyxl import lo ...

  6. python操作XML格式文件

    python操作XML格式文件 python操作XML格式文件 1. 读取文件和内容 2.读取节点数据 3.修改和删除节点 4.构建文档 python操作XML格式文件 可扩展标记语言,是一种简单的数 ...

  7. AI加速信息和知识获取速度,使用Python对MD格式文件和HTML网页进行内容摘要,2023年4月AI网页内容摘要工具大全

    在信息时代,获取知识变得至关重要.然而,有时候信息的数量是如此之大,以至于人类无法有效处理.这就是人工智能(AI)能够做出贡献的地方.通过AI,我们可以快速地找到并理解文章的核心观点和重要信息.下面将 ...

  8. python用os.system打开wav文件_使用python读取wav格式文件

    ** 使用python读取wav格式文件 ** - 基本概念 [采样频率] 即取样频率, 指每秒钟取得声音样本的次数.采样频率越高,声音的质量也就越好,声音的还原也就越真实,但同时它占的资源比较多.由 ...

  9. python获取的html转换为json,python读取XML格式文件并转为json格式

    XML文件如下: 红楼梦书名> 曹雪芹作者> 描述贾宝玉和林黛玉的爱情故事主要内容> 人民文学出版社出版社> 图书> 一.python读取XML格式文件代码: impor ...

  10. python操作xlsx格式文件

    python操作xlsx格式文件 一.准备工作 二 .xlrd库读取 三.pandas库读取 1.安装pandas: pip install pandas 2.代码如下 3.操作行列 一.准备工作 二 ...

最新文章

  1. android学习从模仿开始 —— 模仿UI 导航帖
  2. 如何在迭代时从列表中删除项目?
  3. hdu1261 字串数(排列组合、大整数)
  4. 进入公司前与Boss的会谈话
  5. 高考特长计算机2017,2017年北京理工大学计算机学院申请竞赛获奖与特长生推荐.PDF...
  6. Bootstrp--一个导航面板切换的实用例子
  7. jsp中jquery传值给Java_jsp中利用jquery+ajax在前后台之间传递json格式参数
  8. jdk8永久代从方法区移除的验证
  9. java正则表达式提取字符串中的数字
  10. 第十三期:消灭 Java 代码的“坏味道”
  11. php中冒号是什么,在PHP中:(双冒号)和-(箭头)有什么区别?
  12. C#格式化字符串净化代码的方法
  13. 厦门信息集团与EMC战略合作共建智慧厦门
  14. 比特币价格跌破3万美元登上彭博社头版
  15. 在继续之前,如何暂停我的Shell脚本一秒钟?
  16. 图像扩充边界_使用机器学习来索引数十亿图像中的文本
  17. ojdbc14:11.2.0.1.0出错
  18. 设计模式优秀文章集合
  19. Ultra Compare 8 文本比较乱码问题 解决
  20. 【服务器数据恢复】IBM某型号服务器VMware虚拟机误删除的数据恢复案例

热门文章

  1. 从 Codable 到 Swift 元编程
  2. 二进制拆弹phase0
  3. 导出excel换行问题,一个单元格多张图片问题,数组对象去重处理,计算属性传参
  4. sklearn基于轮廓系数来选择n_clusters
  5. win10无限蓝屏_WIN10无限重启怎么解决,现在开不了机
  6. 创建不带参数的存储过程
  7. 西游记中孙悟空大闹天宫时期被孙悟空打败的
  8. 前端小白入门之css
  9. 360校招之圈地运动
  10. ABAP 新特性 - CORRESPONDING