实现RandomForest 随机森林

基于python的sklearn机器学习 类实现

平台
python3.7 Anaconda sklearn库及配套库
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.ensemble import BaggingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.metrics import cohen_kappa_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix    # 生成混淆矩阵函数
import matplotlib.pyplot as plt
import matplotlib as mpl
from sklearn.preprocessing import StandardScaler
from sklearn.externals import joblib#保存模型
import itertools
class Ctrain_forest:'''调用sklearn 实现Random Forest功能:画混淆矩阵输入数据实现训练保存模型到指定位置调用模型实现预测'''def plot_confusion_matrix(self,cm, classes,normalize=False,title='Confusion matrix',cmap=plt.cm.Blues,path="maxtix"):"""画混淆矩阵This function prints and plots the confusion matrix.Normalization can be applied by setting `normalize=True`.画图函数 输入:cm 矩阵 classes 输入str类型title 名字cmap [图的颜色设置](https://matplotlib.org/examples/color/colormaps_reference.html)"""if normalize:cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]print("Normalized confusion matrix")else:print('Confusion matrix, without normalization')plt.figure(figsize=(11,8))plt.imshow(cm, interpolation='nearest', cmap=cmap)plt.title(title)plt.colorbar()tick_marks = np.arange(len(classes))plt.xticks(tick_marks, classes, rotation=45)plt.yticks(tick_marks, classes)fmt = '.2f' if normalize else 'd'thresh = cm.max() / 2.for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):plt.text(j, i, format(cm[i, j], fmt),horizontalalignment="center",color="white" if cm[i, j] > thresh else "black")# plt.gca().set_xticks(tick_marks, minor=True)# plt.gca().set_yticks(tick_marks, minor=True)# plt.gca().xaxis.set_ticks_position('none')# plt.gca().yaxis.set_ticks_position('none')#plt.grid()# plt.gcf().subplots_adjust(bottom=0.1)# plt.tight_layout()plt.ylabel('True label')plt.xlabel('Predicted label')#解决中文显示plt.rcParams['font.sans-serif']=['SimHei']plt.rcParams['axes.unicode_minus'] = False    plt.savefig(path,dpi=500)  # plt.show()def train_forest(self,x,y,path):"""Random Foeset类输入:x、y以实现训练,path是保存训练过程的路径输出:clf 模型matrix 混淆矩阵dd classifi_reportkappa kappa系数acc_1 模型精度"""X_train,data1x,y_train,data1y = train_test_split(x,y,test_size=0.9,random_state=0)#寻找最优参数depth = np.arange(1,25,4)acc_list = []for d in depth:clf =RandomForestClassifier(bootstrap=True, class_weight="balanced", criterion='gini',max_depth=d*10+1, max_features='auto', max_leaf_nodes=None,min_impurity_decrease=0.0, min_impurity_split=None,min_samples_leaf=3, min_samples_split=3,min_weight_fraction_leaf=0.0, n_estimators=140*2+1, n_jobs=-1,oob_score=False, verbose=0, warm_start=False)clf.fit(X_train, y_train)y_pred_rf = clf.predict(data1x)acc=accuracy_score(data1y, y_pred_rf)acc_list.append(acc)print(accuracy_score(data1y, y_pred_rf))  #整体精度print(cohen_kappa_score(data1y, y_pred_rf))  #Kappa系数#画图mpl.rcParams['font.sans-serif'] = ['SimHei']plt.figure(facecolor='w')plt.plot(depth, acc_list, 'ro-', lw=1)plt.xlabel('随机森林决策树数量', fontsize=15)plt.ylabel('预测精度', fontsize=15)plt.title('随机森林决策树数量和过拟合', fontsize=18)plt.grid(True)plt.savefig(path,dpi=300)#plt.show()y_pred_rf = clf.predict(data1x)print(accuracy_score(data1y, y_pred_rf))  #整体精度#dist=data1y-y_pred_rfprint(cohen_kappa_score(data1y, y_pred_rf))  #Kappa系数matrix=confusion_matrix(data1y, y_pred_rf)kappa=cohen_kappa_score(data1y, y_pred_rf)dd=classification_report(data1y, y_pred_rf)acc_1=accuracy_score(data1y, y_pred_rf)"""# 特征重要性评定rnd_clf = RandomForestClassifier(n_estimators=500, n_jobs=-1)rnd_clf.fit(x, y)for name, score in zip(x, rnd_clf.feature_importances_):print(name, score)""" return clf,matrix,dd,kappa,acc_1def save_model(self,clf,src):"""保存模型到某处clf 模型src 路径"""joblib.dump(clf, src)def get_model_predit(self,data,src):"""调用模型实现预测输入原始数据src 模型路径返回预测值"""getsavemodel=joblib.load(src)predity=getsavemodel.predict(pd.DataFrame(data))return predity

运行结果:


基于python sklearn的 RandomForest随机森林 类实现相关推荐

  1. Spark 和 Python.sklearn:使用随机森林计算 feature_importance 特征重要性

    前言 在使用GBDT.RF.Xgboost等树类模型建模时,往往可以通过feature_importance 来返回特征重要性,本文以随机森林为例介绍其原理与实现.[ 链接:机器学习的特征重要性究竟是 ...

  2. matlab 随机森林算法_(六)如何利用Python从头开始实现随机森林算法

    博客地址:https://blog.csdn.net/CoderPai/article/details/96499505 点击阅读原文,更好的阅读体验 CoderPai 是一个专注于人工智能在量化交易 ...

  3. sklearn实战之随机森林

    sklearn实战系列: (1) sklearn实战之决策树 (2) sklearn实战之随机森林 (3) sklearn实战之数据预处理与特征工程 (4) sklearn实战之降维算法PCA与SVD ...

  4. 【详细代码注释】基于CNN卷积神经网络实现随机森林算法

    随机森林算法简介: 随机森林(Random Forest)是一种灵活性很高的机器学习算法. 它的底层是利用多棵树对样本进行训练并预测的一种分类器.在机器学习的许多领域都有广泛地应用. 例如构建医学疾病 ...

  5. 《菜菜的机器学习sklearn课堂》随机森林应用泛化误差调参实例

    随机森林 随机森林 - 概述 集成算法概述 sklearn中的集成算法 随机森林分类器 RandomForestClassifier 重要参数 控制基评估器的参数 n_estimators:基评估器的 ...

  6. Python进行决策树和随机森林

    Python进行决策树和随机森林 一.决策树 第一步,导入库: 第二步,导入数据: 第三步,数据预处理: 第四步,决策树: 第五步,决策树评价: 第六步,生成决策树图. 二.随机森林 第一步,随机森林 ...

  7. 基于蜣螂算法改进的随机森林回归算法 - 附代码

    基于蜣螂算法改进的随机森林回归算法 - 附代码 文章目录 基于蜣螂算法改进的随机森林回归算法 - 附代码 1.数据集 2.RF模型 3.基于蜣螂算法优化的RF 4.测试结果 5.Matlab代码 6. ...

  8. Python 利用SVM,KNN,随机森林进行预测

    Python 利用SVM,KNN,随机森林进行预测 工具:Pycharm,Win10,Python3.6.4 上图是我们的数据文件,最后一列是附近有无超市的标签,1代表有,-1代表没有.可以发现数据维 ...

  9. python椭圆形骨料_一种基于python再生混凝土三维随机球形骨料模型的构建方法与流程...

    本发明涉及建筑技术领域,尤其涉一种基于python再生混凝土三维随机球形骨料模型的构建方法. 背景技术: 再生混凝土是指利用再生粗骨料部分或者全部代替天然骨料配置而成的混凝土,再生混凝土技术的开发和利 ...

  10. 基于 Python 的横版 2D 动作类小游戏

    基于 Python 的横版 2D 动作类小游戏 游戏代码 游戏代码 游戏整体代码(基于 pygame 模块开发) // An highlighted block import pygame impor ...

最新文章

  1. 【控制】影响系统响应的因素
  2. 排序算法大集锦_合并排序_1(分治思想)
  3. mysql导出单表数据
  4. 女朋友来大姨妈怎么办?
  5. Ubuntu ssh 登陆问题
  6. php选择nginx还是apache,浅谈apache和nginx的rewrite的区别
  7. SQL优化的一些总结
  8. css中的伪类与伪元素的区别
  9. 第二次作业-Steam软件分析
  10. mongodb 副本集搭建
  11. html5怎么播放3gp,写了个html5播放视频的video控件,只支持mp4和3gp(android和ios默认支持的格式就写了这个)...
  12. 华三模拟器之OSPF实验
  13. 【汇正财经】股票上市交易的费用都有哪些?
  14. error: %preun(mysql-community-server-5.7.36-1.el6.x86_64) scriptlet failed
  15. DAX 第八篇:【翻译】数据沿袭(Data Lineage )
  16. Git零基础教程①:如何加速开源社区github的打开(2022版)
  17. 物联网学习笔记(一)
  18. USGS Landsat 8 Collection 2 Level 1数据正确姿势下载
  19. 程序员买房指南——LZ的三次买房和一次卖房经历
  20. springboot预约挂号小程序毕业设计毕设作品开题报告开题答辩PPT

热门文章

  1. 数据库JDBCUtil 工具类 增加连接池操作
  2. gg 修改器游戏被保护_GFX画质修改器120帧下载
  3. pandas提取某两列的值_Pandas进阶修炼120题第五期
  4. 在windows中使用scp命令将文件上传到远端服务器
  5. 从linux服务器上取文件,简介从Linux服务器上远程获取文件的几种方法
  6. python复制函数_Python numpy.copy函数方法的使用
  7. Vue:echarts异步加载数据显示
  8. php字符串转openssl格式,将OpenSSL生成的RSA公钥转换为OpenSSH格式(PHP)
  9. Web研发模式演变史
  10. C++_你真的知道++i 和 i++的区别吗?_左值/右值/右值引用