本文主要提供源码的一些思路.具体源码可查看知乎:知乎
简单来说随机森林就是生成N颗CART树,通过bootstrap的方式,有放回可重复的从原始数据集M里选出一部分数据m,总共生成N份这样的数据给N颗CART树去做训练,同时设定每棵树选用数据集中的最大的特征数,也是可重复的选取.最后的结果通过投票表决决定最终结果.
代码如下(没有运行,只是看了下思路):

# -*- coding: utf-8 -*-import numpy as np
from decision_tree_model import ClassificationTree
from sklearn import datasets
from sklearn.model_selection import train_test_splitclass RandomForest():"""Random Forest classifier. Uses a collection of classification trees thattrains on random subsets of the data using a random subsets of the features.Parameters:-----------n_estimators: int树的数量The number of classification trees that are used.max_features: int每棵树选用数据集中的最大的特征数The maximum number of features that the classification trees are allowed touse.min_samples_split: int每棵树中最小的分割数,比如 min_samples_split = 2表示树切到还剩下两个数据集时就停止The minimum number of samples needed to make a split when building a tree.min_gain: float每棵树切到小于min_gain后停止The minimum impurity required to split the tree further.max_depth: int每棵树的最大层数The maximum depth of a tree."""def __init__(self, n_estimators=100, min_samples_split=2, min_gain=0,max_depth=float("inf"), max_features=None):self.n_estimators = n_estimators #树的数量self.min_samples_split = min_samples_split #每棵树中最小的分割数,比如 min_samples_split = 2表示树切到还剩下两个数据集时就停止self.min_gain = min_gain   #每棵树切到小于min_gain后停止self.max_depth = max_depth  #每棵树的最大层数self.max_features = max_features #每棵树选用数据集中的最大的特征数self.trees = []# 建立森林(bulid forest)for _ in range(self.n_estimators):tree = ClassificationTree(min_samples_split=self.min_samples_split, min_impurity=self.min_gain,max_depth=self.max_depth)self.trees.append(tree)def fit(self, X, Y):# 训练,每棵树使用随机的数据集(bootstrap)和随机的特征# every tree use random data set(bootstrap) and random featuresub_sets = self.get_bootstrap_data(X, Y)n_features = X.shape[1]if self.max_features == None:self.max_features = int(np.sqrt(n_features))for i in range(self.n_estimators):# 生成随机的特征# get random featuresub_X, sub_Y = sub_sets[i]idx = np.random.choice(n_features, self.max_features, replace=True)sub_X = sub_X[:, idx]self.trees[i].fit(sub_X, sub_Y)self.trees[i].feature_indices= idxprint("tree", i, "fit complete")def predict(self, X):y_preds = []for i in range(self.n_estimators):idx = self.trees[i].feature_indicessub_X = X[:, idx]y_pre = self.trees[i].predict(sub_X)y_preds.append(y_pre)y_preds = np.array(y_preds).Ty_pred = []for y_p in y_preds:# np.bincount()可以统计每个索引出现的次数# np.argmax()可以返回数组中最大值的索引# cheak np.bincount() and np.argmax() in numpy Docsy_pred.append(np.bincount(y_p.astype('int')).argmax())return y_preddef get_bootstrap_data(self, X, Y):# 通过bootstrap的方式获得n_estimators组数据# get int(n_estimators) datas by bootstrapm = X.shape[0] #行数Y = Y.reshape(m, 1)# 合并X和Y,方便bootstrap (conbine X and Y)X_Y = np.hstack((X, Y)) #np.vstack():在竖直方向上堆叠/np.hstack():在水平方向上平铺np.random.shuffle(X_Y) #随机打乱data_sets = []for _ in range(self.n_estimators):idm = np.random.choice(m, m, replace=True) #在range(m)中,有重复的选取 m个数字bootstrap_X_Y = X_Y[idm, :]bootstrap_X = bootstrap_X_Y[:, :-1]bootstrap_Y = bootstrap_X_Y[:, -1:]data_sets.append([bootstrap_X, bootstrap_Y])return data_setsif __name__ == '__main__':data = datasets.load_digits()X = data.datay = data.targetX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=2)print("X_train.shape:", X_train.shape)print("Y_train.shape:", y_train.shape)clf = RandomForest(n_estimators=100)clf.fit(X_train, y_train)y_pred = clf.predict(X_test)accuracy = accuracy_score(y_test, y_pred)print('y_test:{}\ty_pred:{}'.format(y_test, y_pred))print("Accuracy:", accuracy)

[机器学习]随机森林源码(python)相关推荐

  1. spark ml 随机森林源码笔记二

    书接上回,该分析run方法了,有1000多行,该方法主要是根据数据和参数,训练生成一组树,就是决策森林 开始先干了一件事 val metadata =       DecisionTreeMetada ...

  2. python 营销应用_随机森林算法入门(python),,它可以用于市场营销对客户

    随机森林算法入门(python),,它可以用于市场营销对客户 目录 1 什么是随机森林 1.1 集成学习 1.2 随机决策树 1.3 随机森林 1.4 投票 2 为什么要用它 3 使用方法 3.1 变 ...

  3. [源码]python Scapy Ftp密码嗅探

    [源码]python Scapy Ftp密码嗅探 原理很简单,FTP密码明文传输的 截取tcp 21端口User和Pass数据即可 Scapy框架编译程序较大(一个空程序都25M),所以就不提供exe ...

  4. 网站随机背景音乐源码

    介绍: 随机背景音乐源码 复制下载好的txt文件里的代码放在网页里就可以了,用户每次新访问背景音乐都会自动变换. 网盘下载地址: http://kekewangLuo.cc/VUv07tPWmRy0 ...

  5. [附源码]Python计算机毕业设计SSM绩效考核管理系统(程序+LW)

    [附源码]Python计算机毕业设计SSM绩效考核管理系统(程序+LW) 项目运行 环境配置: Jdk1.8 + Tomcat7.0 + Mysql + HBuilderX(Webstorm也行)+ ...

  6. [附源码]Python计算机毕业设计SSM即刻实时预约排队系统(程序+LW)

    [附源码]Python计算机毕业设计SSM即刻实时预约排队系统(程序+LW) 项目运行 环境配置: Jdk1.8 + Tomcat7.0 + Mysql + HBuilderX(Webstorm也行) ...

  7. 扩增子16S/ITS/18S微生物多样性课程更新-机器学习随机森林分析

    扩增子16S/ITS/18S微生物多样性课程更新-机器学习随机森林分析 机器学习或者人工智能(AI)是当前计算机领域研究的热点.然而,最近越来越多的研究者开始尝试将 AI 应用于另一个热门领域--微生 ...

  8. python美女源代码_随机美女写真网页源码+python源程序

    释放双眼,带上耳机,听听看~! 源码介绍 美图网站千千万,美图自己说了算!本源码由@香谢枫林 开发,首页图片做了浏览器窗口自适应,最大化占满PC浏览器和移动浏览器的窗口,并且防止出现滚动条. 功能介绍 ...

  9. [python] 机器学习 随机森林算法RandomForestRegressor

    前言 随机森林Python版本有很可以调用的库,使用随机森林非常方便,主要用到以下的库 sklearn Scikit learn 也简称 sklearn, 是机器学习领域当中最知名的 python 模 ...

最新文章

  1. 2022-2028年中国农用塑料薄膜行业市场研究及前瞻分析报告
  2. python opencv 彩色图非局部平均去噪
  3. 一次给女朋友转账引发我对分布式事务的思考
  4. Matlab学习笔记:画图多重设置
  5. linux检查swap配置,Linux环境下swap配置方法
  6. niosii spi 外部_NIOS II SPI详解 如何使用SPI方式传输
  7. 从Memcache转战Redis,聊聊缓存使用填过的“坑”
  8. http://127.0.0.1/thinkphp5/public/index/teacher/delete/id/1.html 这样的URL下,页面收不到get参数...
  9. 投影html连接电脑,电脑怎么连接投影仪?投影仪的详细安装使用教程
  10. 关于对DataTable进行操作的几个例子总结
  11. [译]C语言实现一个简易的Hash table(5)
  12. Spring Data JPA
  13. 不光荣的“革命”——“甘露之变”后的晚唐政治
  14. STM32F429+W25Q256+TouchFGX
  15. php日志在哪,php日志在哪
  16. “爆炸图!“ArcGIS中制作一张好看的爆炸分析图(附练习数据)
  17. React新手入门学习
  18. 华为云人脸识别服务 FRS 之初体验
  19. Linux服务器挂掉,使之自动重启脚本
  20. 前端(以Vue为例)webpack打包后dist文件包如何部署到django后台中

热门文章

  1. 极验滑块识别-通用滑块识别
  2. 处理mysql启动报错Table 'mysql.plugin' doesn't exis
  3. windows下同一个显卡配置多个CUDA工具包以及它们之间的切换
  4. python 报错'tuple' object does not support item assignment
  5. SuperMap GIS 9D 产品白皮书v1.0
  6. CorelDRAW破解版是如何一步一步坑人的
  7. BS EN 438-5装饰用薄板材层压板材的分类和规范
  8. Glyphs 2 for Mac(字体设计编辑软件)
  9. win7如何安装无线网卡驱动程序?具体安装步骤
  10. OpenGL ES之GLSL实现多种“马赛克滤镜”效果