Python实现决策树(Decision Tree)分类
关于决策树的简介可以参考: http://blog.csdn.net/fengbingchun/article/details/78880934
在 https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/ 中给出了CART(Classification and Regression Trees,分类回归树算法,简称CART)算法的Python实现,采用的数据集为Banknote Dataset,关于此数据集的介绍可以参考:http://blog.csdn.net/fengbingchun/article/details/78624358 ,这里在原作者的基础上,进行了略微改动,使其可以直接执行,code如下:
# reference: https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/
# http://zhuanlan.51cto.com/art/201702/531945.htm
# using CART(Classification and Regression Trees,分类回归树算法,简称CART算法)) for classification# CART on the Bank Note dataset
from random import seed
from random import randrange
from csv import reader# Load a CSV file
def load_csv(filename):file = open(filename, "r")lines = reader(file)dataset = list(lines)return dataset# Convert string column to float
def str_column_to_float(dataset, column):for row in dataset:row[column] = float(row[column].strip())# Split a dataset into k folds
def cross_validation_split(dataset, n_folds):dataset_split = list()dataset_copy = list(dataset)fold_size = int(len(dataset) / n_folds)for i in range(n_folds):fold = list()while len(fold) < fold_size:index = randrange(len(dataset_copy))fold.append(dataset_copy.pop(index))dataset_split.append(fold)return dataset_split# Calculate accuracy percentage
def accuracy_metric(actual, predicted):correct = 0for i in range(len(actual)):if actual[i] == predicted[i]:correct += 1return correct / float(len(actual)) * 100.0# Evaluate an algorithm using a cross validation split
def evaluate_algorithm(dataset, algorithm, n_folds, *args):folds = cross_validation_split(dataset, n_folds)scores = list()for fold in folds:train_set = list(folds)train_set.remove(fold)train_set = sum(train_set, [])test_set = list()for row in fold:row_copy = list(row)test_set.append(row_copy)row_copy[-1] = Nonepredicted = algorithm(train_set, test_set, *args)actual = [row[-1] for row in fold]accuracy = accuracy_metric(actual, predicted)scores.append(accuracy)return scores# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):left, right = list(), list()for row in dataset:if row[index] < value:left.append(row)else:right.append(row)return left, right# Calculate the Gini index for a split dataset
def gini_index(groups, classes):# count all samples at split pointn_instances = float(sum([len(group) for group in groups])) # 计算总的样本数# sum weighted Gini index for each groupgini = 0.0for group in groups:size = float(len(group))# avoid divide by zeroif size == 0:continuescore = 0.0# score the group based on the score for each classfor class_val in classes:p = [row[-1] for row in group].count(class_val) / size # row[-1]指每个样本(一行)中最后一列即类别score += p * p# weight the group score by its relative sizegini += (1.0 - score) * (size / n_instances)return gini# Select the best split point for a dataset
def get_split(dataset):class_values = list(set(row[-1] for row in dataset)) # class_values的值为: [0, 1]b_index, b_value, b_score, b_groups = 999, 999, 999, Nonefor index in range(len(dataset[0])-1): # index的值为: [0, 1, 2, 3]for row in dataset:groups = test_split(index, row[index], dataset)gini = gini_index(groups, class_values)if gini < b_score:b_index, b_value, b_score, b_groups = index, row[index], gini, groupsreturn {'index':b_index, 'value':b_value, 'groups':b_groups} # 返回字典数据类型# Create a terminal node value
def to_terminal(group):outcomes = [row[-1] for row in group]return max(set(outcomes), key=outcomes.count)# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):left, right = node['groups']del(node['groups'])# check for a no splitif not left or not right:node['left'] = node['right'] = to_terminal(left + right)return# check for max depthif depth >= max_depth:node['left'], node['right'] = to_terminal(left), to_terminal(right)return# process left childif len(left) <= min_size:node['left'] = to_terminal(left)else:node['left'] = get_split(left)split(node['left'], max_depth, min_size, depth+1)# process right childif len(right) <= min_size:node['right'] = to_terminal(right)else:node['right'] = get_split(right)split(node['right'], max_depth, min_size, depth+1)# Build a decision tree
def build_tree(train, max_depth, min_size):root = get_split(train)split(root, max_depth, min_size, 1)return root# Make a prediction with a decision tree
def predict(node, row):if row[node['index']] < node['value']:if isinstance(node['left'], dict):return predict(node['left'], row)else:return node['left']else:if isinstance(node['right'], dict):return predict(node['right'], row)else:return node['right']# Classification and Regression Tree Algorithm
def decision_tree(train, test, max_depth, min_size):tree = build_tree(train, max_depth, min_size)predictions = list()for row in test:prediction = predict(tree, row)predictions.append(prediction)return(predictions)# Test CART on Bank Note dataset
seed(1)
# load and prepare data
filename = '../../../data/database/BacknoteDataset/data_banknote_authentication.csv'
dataset = load_csv(filename)
# convert string attributes to integers
for i in range(len(dataset[0])):str_column_to_float(dataset, i) # dataset为嵌套列表的列表,类型为float# evaluate algorithm
n_folds = 5
max_depth = 5
min_size = 10
scores = evaluate_algorithm(dataset, decision_tree, n_folds, max_depth, min_size)
print('Scores: %s' % scores)
print('Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores))))
执行结果如下:
GitHub: https://github.com/fengbingchun/NN_Test
Python实现决策树(Decision Tree)分类相关推荐
- 决策树分类python代码_分类算法-决策树 Decision Tree
决策树(Decision Tree)是一个非参数的监督式学习方法,决策树又称为判定树,是运用于分类的一种树结构,其中的每个内部节点代表对某一属性的一次测试,每条边代表一个测试结果,叶节点代表某个类或类 ...
- 决策树(Decision Tree)_海洋动物分类
决策树(Decision Tree)是在已知各种情况发生概率的基础上,通过构成决策树来求取净现值的期望值大于等于零的概率,评价项目风险,判断其可行性的决策分析方法,是直观运用概率分析的一种图解法.由于 ...
- 算法杂货铺——分类算法之决策树(Decision tree)
算法杂货铺--分类算法之决策树(Decision tree) 2010-09-19 16:30 by T2噬菌体, 88978 阅读, 29 评论, 收藏, 编辑 3.1.摘要 在前面两篇文章中,分别 ...
- 分类Classification:决策树Decision Tree
目录 分类的定义 决策树Decision Tree 混乱衡量指标Gini index 决策树的特点 分类的定义 分类:建立一个学习函数(分类模型)将每个属性集合(x1,x2,...xn)对应到一组已定 ...
- Python数据挖掘入门与实践 第三章 用决策树预测获胜球队(一)pandas的数据预处理与决策树(Decision tree)
作为一个NBA球迷,看到这一章还是挺激动的. 不过内容有点难,研究了半天... 要是赌球的,用这章的预测+凯利公式,是不是就能提升赢钱概率了? 数据预处理 回归书本内容,既然要分析,首先需要有数据: ...
- 数据分类:决策树Decision Tree
背景 决策树(decision tree)是一种基本的分类和回归(后面补充一个回归的例子?)方法,它呈现的是一种树形结构,可以认为是if-then规则的集合.其其主要优点是模型具有很好的可读性,且分类 ...
- 决策树Decision Tree 及实现
本文基于python逐步实现Decision Tree(决策树),分为以下几个步骤: 加载数据集 熵的计算 根据最佳分割feature进行数据分割 根据最大信息增益选择最佳分割feature 递归构建 ...
- 机器学习算法实践:决策树 (Decision Tree)(转载)
前言 最近打算系统学习下机器学习的基础算法,避免眼高手低,决定把常用的机器学习基础算法都实现一遍以便加深印象.本文为这系列博客的第一篇,关于决策树(Decision Tree)的算法实现,文中我将对决 ...
- 决策树Decision Tree 和随机森林RandomForest基本概念(一)
文章目录 一.决策树介绍 1.1 什么是决策树 1.2 决策树种类 1.3 决策树学习过程 1.4 Entropy(熵) 1.5 information gain(信息增益) 1.6 信息论 1.8 ...
最新文章
- return,break,continue三者区别
- 假期第一次编程总结(改二)
- MAC使用homeBrew安装Redis
- 【Python】安装方法小结
- 小游戏一键跳转小程序任意页面
- Hinton胶囊网络代码正式开源,5天GitHub fork超1.4万
- mysql主从搭建_手把手教你搭建MySQL主从架构
- 《系统集成项目管理》第四章 项目管理一般知识
- IBAN和SWIFT代码有什么不同?
- emoji语言抽象话大全_当抽象话也成为一种暗语
- c语言二次方程的实根,C程序求二次方程的根
- Aligning Domain-Specific Distribution and Classifier for Cross-Domain Classification from Multiple
- 垂杨柳中学2021年高考成绩查询时间,2021年北京中考各学校分数线,历年北京中考分数线...
- 花花公子 243线SLOT
- 容器:forward_list用法及示例
- 席位分配问题——惯例Q值法和d'hondt法的MATLAB程序
- 谷歌浏览器无法使用谷歌翻译解决办法
- 让我受益终生的六个字:道、法、术、器、志、势
- 操作系统面试基础知识点
- 开发者分享在PC上制作iOS游戏的经验(下)
热门文章
- 【radar】毫米波雷达相关开源项目代码汇总(工具箱、仿真、2D毫米波检测、融合、4D毫米波检测、分割、SLAM、跟踪)(6)
- 度量学习:ArcFace算法和工程应用总结
- HDU - 5875 2016 ACM/ICPC 大连网络赛 H题 暴力
- 朴素贝叶斯预测是否为垃圾短信
- matlab实现移动平均
- Mat,Iplimage,vector,vector_vector_Point2f等类型之间的相互转换
- VS生成dll和lib库文件
- 【从零开始的ROS四轴机械臂控制】(一)- 实际模型制作、Solidworks文件转urdf与rviz仿真
- 在Ubuntu 14.04 64位上使用libpcap进行抓包和解包
- C# http 性能优化500毫秒到 60 毫秒