For an introduction to decision trees, see: http://blog.csdn.net/fengbingchun/article/details/78880934

A Python implementation of the CART (Classification and Regression Trees) algorithm is given at https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/ . It uses the Banknote Dataset; for a description of this dataset, see: http://blog.csdn.net/fengbingchun/article/details/78624358 . The code below is the original author's with minor changes so that it can be run directly:

# reference: https://machinelearningmastery.com/implement-decision-tree-algorithm-scratch-python/
#            http://zhuanlan.51cto.com/art/201702/531945.htm
# using CART (Classification and Regression Trees) for classification
# CART on the Bank Note dataset
from random import seed
from random import randrange
from csv import reader

# Load a CSV file
def load_csv(filename):
    # newline="" is what the csv module expects; the with-block closes the file
    with open(filename, "r", newline="") as file:
        # skip blank lines so that the later float conversion does not fail
        dataset = [row for row in reader(file) if row]
    return dataset

# Convert string column to float
def str_column_to_float(dataset, column):
    for row in dataset:
        row[column] = float(row[column].strip())

# Split a dataset into k folds
def cross_validation_split(dataset, n_folds):
    dataset_split = list()
    dataset_copy = list(dataset)
    # rows left over after the integer division are not assigned to any fold
    fold_size = len(dataset) // n_folds
    for i in range(n_folds):
        fold = list()
        while len(fold) < fold_size:
            index = randrange(len(dataset_copy))
            fold.append(dataset_copy.pop(index))
        dataset_split.append(fold)
    return dataset_split

# Calculate accuracy percentage
def accuracy_metric(actual, predicted):
    correct = 0
    for i in range(len(actual)):
        if actual[i] == predicted[i]:
            correct += 1
    return correct / float(len(actual)) * 100.0

# Evaluate an algorithm using a cross validation split
def evaluate_algorithm(dataset, algorithm, n_folds, *args):
    folds = cross_validation_split(dataset, n_folds)
    scores = list()
    for fold in folds:
        # train on every fold except the current one
        train_set = list(folds)
        train_set.remove(fold)
        train_set = sum(train_set, [])
        # copy the held-out rows and hide their class labels
        test_set = list()
        for row in fold:
            row_copy = list(row)
            test_set.append(row_copy)
            row_copy[-1] = None
        predicted = algorithm(train_set, test_set, *args)
        actual = [row[-1] for row in fold]
        accuracy = accuracy_metric(actual, predicted)
        scores.append(accuracy)
    return scores

# Split a dataset based on an attribute and an attribute value
def test_split(index, value, dataset):
    left, right = list(), list()
    for row in dataset:
        if row[index] < value:
            left.append(row)
        else:
            right.append(row)
    return left, right

# Calculate the Gini index for a split dataset
def gini_index(groups, classes):
    # count all samples at split point
    n_instances = float(sum([len(group) for group in groups]))  # total number of samples
    # sum weighted Gini index for each group
    gini = 0.0
    for group in groups:
        size = float(len(group))
        # avoid divide by zero
        if size == 0:
            continue
        score = 0.0
        # score the group based on the score for each class
        for class_val in classes:
            # row[-1] is the last column of each sample (row), i.e. the class label
            p = [row[-1] for row in group].count(class_val) / size
            score += p * p
        # weight the group score by its relative size
        gini += (1.0 - score) * (size / n_instances)
    return gini

# Select the best split point for a dataset
def get_split(dataset):
    class_values = list(set(row[-1] for row in dataset))  # the class labels: 0.0 and 1.0 for this dataset
    b_index, b_value, b_score, b_groups = 999, 999, 999, None
    for index in range(len(dataset[0])-1):  # index takes the values 0, 1, 2, 3
        for row in dataset:
            groups = test_split(index, row[index], dataset)
            gini = gini_index(groups, class_values)
            if gini < b_score:
                b_index, b_value, b_score, b_groups = index, row[index], gini, groups
    return {'index':b_index, 'value':b_value, 'groups':b_groups}  # returns a dict

# Create a terminal node value
def to_terminal(group):
    outcomes = [row[-1] for row in group]
    return max(set(outcomes), key=outcomes.count)

# Create child splits for a node or make terminal
def split(node, max_depth, min_size, depth):
    left, right = node['groups']
    del(node['groups'])
    # check for a no split
    if not left or not right:
        node['left'] = node['right'] = to_terminal(left + right)
        return
    # check for max depth
    if depth >= max_depth:
        node['left'], node['right'] = to_terminal(left), to_terminal(right)
        return
    # process left child
    if len(left) <= min_size:
        node['left'] = to_terminal(left)
    else:
        node['left'] = get_split(left)
        split(node['left'], max_depth, min_size, depth+1)
    # process right child
    if len(right) <= min_size:
        node['right'] = to_terminal(right)
    else:
        node['right'] = get_split(right)
        split(node['right'], max_depth, min_size, depth+1)

# Build a decision tree
def build_tree(train, max_depth, min_size):
    root = get_split(train)
    split(root, max_depth, min_size, 1)
    return root

# Make a prediction with a decision tree
def predict(node, row):
    if row[node['index']] < node['value']:
        if isinstance(node['left'], dict):
            return predict(node['left'], row)
        else:
            return node['left']
    else:
        if isinstance(node['right'], dict):
            return predict(node['right'], row)
        else:
            return node['right']

# Classification and Regression Tree Algorithm
def decision_tree(train, test, max_depth, min_size):
    tree = build_tree(train, max_depth, min_size)
    predictions = list()
    for row in test:
        prediction = predict(tree, row)
        predictions.append(prediction)
    return(predictions)

# Test CART on Bank Note dataset
seed(1)  # fix the random seed so that the fold assignment is reproducible
# load and prepare data
filename = '../../../data/database/BacknoteDataset/data_banknote_authentication.csv'
dataset = load_csv(filename)
# convert string attributes to floats
for i in range(len(dataset[0])):
    str_column_to_float(dataset, i)  # dataset is a list of lists (one per row) whose values are floats

# evaluate algorithm
n_folds = 5
max_depth = 5
min_size = 10
scores = evaluate_algorithm(dataset, decision_tree, n_folds, max_depth, min_size)
print('Scores: %s' % scores)
print('Mean Accuracy: %.3f%%' % (sum(scores)/float(len(scores))))
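
To make the split criterion concrete, here is a small standalone sketch that applies the same size-weighted Gini formula as gini_index in the listing to two hand-made splits. The toy rows (pure_split, mixed_split) are made up for illustration and are not taken from the Banknote Dataset.

```python
def gini_index(groups, classes):
    """Size-weighted Gini impurity of a candidate split."""
    n_instances = float(sum(len(group) for group in groups))
    gini = 0.0
    for group in groups:
        size = float(len(group))
        if size == 0:
            continue  # an empty group contributes nothing
        labels = [row[-1] for row in group]
        # sum of squared class proportions inside this group
        score = sum((labels.count(c) / size) ** 2 for c in classes)
        gini += (1.0 - score) * (size / n_instances)
    return gini

# Hypothetical toy rows in the form [feature, class_label]
pure_split = ([[1, 0], [2, 0]], [[8, 1], [9, 1]])   # each side holds one class
mixed_split = ([[1, 0], [2, 1]], [[8, 0], [9, 1]])  # each side is half and half

print(gini_index(pure_split, [0, 1]))   # 0.0
print(gini_index(mixed_split, [0, 1]))  # 0.5
```

A split that separates the two classes perfectly scores 0.0, while a split that leaves both sides evenly mixed scores 0.5, the worst value for two classes. get_split keeps the candidate with the lowest score.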

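The tree returned by build_tree is a nested dict: an internal node holds 'index', 'value', 'left' and 'right', and a leaf is simply a class label. The sketch below writes such a tree by hand to show how a row is routed through it. The thresholds are invented for illustration rather than learned, format_tree is a hypothetical helper that is not part of the listing, and predict is an iterative rewrite of the recursive version above.

```python
# A hand-written tree in the nested-dict layout used by the listing.
# The split values are made up for illustration, not learned from data.
tree = {
    'index': 0, 'value': 0.5,
    'left': {'index': 1, 'value': 5.0, 'left': 1.0, 'right': 0.0},
    'right': 0.0,
}

def predict(node, row):
    """Walk from the root to a leaf, going left when row[index] < value."""
    while isinstance(node, dict):
        branch = 'left' if row[node['index']] < node['value'] else 'right'
        node = node[branch]
    return node

def format_tree(node, depth=0):
    """Return one indented text line per node (hypothetical helper)."""
    pad = '  ' * depth
    if not isinstance(node, dict):
        return ['%s[%s]' % (pad, node)]
    lines = ['%s[X%d < %.3f]' % (pad, node['index'] + 1, node['value'])]
    lines += format_tree(node['left'], depth + 1)
    lines += format_tree(node['right'], depth + 1)
    return lines

print('\n'.join(format_tree(tree)))
print(predict(tree, [0.2, 3.0]))  # X1 < 0.5, then X2 < 5.0 -> 1.0
print(predict(tree, [0.9, 3.0]))  # X1 >= 0.5 -> 0.0
```

Printing the tree this way can help when checking what max_depth and min_size actually did to the structure.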

GitHub: https://github.com/fengbingchun/NN_Test


