例:有没有心脏病?

-------1胸痛---------2血液循环良好-----------3.动脉阻塞-------------4心脏病(HD)

01.------No---------------No----------------------No------------------No
02.------Yes--------------Yes---------------------Yes-----------------Yes
03.------Yes--------------Yes---------------------No------------------No


每次选取一个可判断的条件,然后进行单一条件的决策树转化
1.Chest Pain:如是否有胸痛
------|--------|----
----HD-----HD-----
-----|--------|------
Yes|No—Yes|No------是否患有心脏病
105|39-----34|125–对应的人数
2.Good Blood Circulation
–|------|----
HD-----HD-----
|–|-----|–|------
Yes|No-Yes|No------
37 127 100 33

3.Blocked Arteries
–|------|----
HD-----HD-----
|–|-----|–|------
Yes|No-Yes|No------
92 31 45 129

我们把并非100%正确分类的要做"impure"的分类(着重关注点实在leaf node)





# 1.Chest Pain:如是否有胸痛
# ------|--------|----
# ----HD-----HD-----
# -----|--------|------
# Yes|No---Yes|No------是否患有心脏病
# 105|39-----34|125--对应的人数# 概述:有胸痛/有心脏病的有105个  有胸痛/无心脏病的有的39个
#  无胸痛/有心脏病的有34个   无胸痛无心脏病的有125个1-(105/(105+39))**2-(39/(105+39))**2

1-(34/(34+125))**2-(125/(34+125))**2

# 加权处理:
(144/(144+159))*0.395 +(159/(144+159))*0.336


以此类推, 1.Chest Pain = 0.364
2.Good Blood Circulation = 0.360
3.Blocked Arteries = 0.381

有上述得出Good Blood Circulation = 0.360 最小(最纯)

选定根节点之后
Good Blood Circulation
–|------|----
HD-----HD-----
|–|-----|–|------
Yes|No-Yes|No------
37 127 100 33

1.Chest Pain
–|------|----
HD-----HD-----
|–|-----|–|------
Yes|No-Yes|No------
13 98 24 29
这个时候可以算出在Chest Pain下面的Gini =0.3

2.Blocked Arteries
–|------|----
HD-----HD-----
|–|-----|–|------
Yes|No-Yes|No------
24 25 13 102
这个时候可以算出在Blocked Arteries下面的Gini =0.290

-----Good Blood Circulation-----
-------|-----------|--------------
------BA---------CP------------
----|-----|------|-----|----------
—CP-13/102–92/3–BA--------
-|–|-------------|–|------
17/3–7/22--------8/0—0/30—

思考: 1.如果你得到的是数值的数据如何计算Gini?(rank一遍,计算平均值,通过小于等于来分类,*没有必要将最大的一个数值包括,因为无法分类) 2.如果你得到的是程度数值(比如:按照喜欢程度1234)的数据如何计算Gini?(rank一遍,通过小于等于来分类,没有必要将最大的一个数值包括,因为无法分类) 3.如果你得到的是调查问卷的数据如何计算Gini?(通过排列组合来分类,*没有必要将包括所有的组合计算在内,因为无法分类)

基尼指数代码

def calculate_the_gini_index(groups, classes):# 计算有多少实例个数n_instances = float(sum([len(groups) for group in groups]))# 把每一个group里面的加权gini计算出来gini = 0.0for group in groups:size = float(len(groups))# *注意,这里不能除以0,所以我们要考虑到分母为0的情况if size == 0:continuescore = 0.0for class_val in classes:p = [row[-1] for row in group].count(class_val) / sizescore += p * p# 这里做了一个加权处理gini += (1 - score) * (size / n_instances)return gini
# 两个类别的最坏情况
worst_case_for_two_classes = [[[1, 1], [1, 0]], [[1, 1], [1, 0]]]print(calculate_the_gini_index(worst_case_for_two_classes, [0, 1]))
# 两个类别的最佳情况
best_case_for_two_classes = [[[1, 0], [1, 0]],[[1, 1], [1, 1]]]print(calculate_the_gini_index(best_case_for_two_classes, [0, 1]))

def test_split(index, value, dataset): # 左右切分函数left, right = list(), list()for row in dataset:if row[index] < value:left.append(row)else:right.append(row)return left, right
# index,value,groups数据较多,所以选用dict
def get_split(dataset):class_values = list(set(row[-1] for row in dataset))posi_index, posi_value, posi_score, posi_groups = 888, 888, 888, Nonefor index in range(len(dataset[0]) - 1):for row in dataset:groups = test_split(index, row[index], dataset)gini = calculate_the_gini_index(groups, class_values)if gini < posi_score:posi_index, posi_value, posi_score, posi_groups = index, row[index], gini, groupsreturn {'The Best Index is': posi_index, 'The Best Value is': posi_value, 'The Best Groups is': posi_groups}
# 测试
dataset = [[2.1, 1.1, 0],[3.4, 2.5, 0],[1.3, 5.8, 0],[1.9, 8.6, 0],[3.7, 6.2, 0],[8.8, 1.1, 1],[9.6, 3.4, 1],[10.2, 7.4, 1],[7.7, 8.8, 1],[9.7, 6.9, 1]]split = get_split(dataset)
print(split)

最好的分割

def test_split(index, value, dataset):left, right = list(), list()for row in dataset:if row[index] < value:left.append(row)else:right.append(row)return left, rightdef calculate_the_gini_index(groups, classes):# 计算有多少实例n_instances = float(sum([len(group) for group in groups]))# 把每一个group里面的加权gini计算出来gini = 0.0for group in groups:size = float(len(group))# *注意,这里不能除以0,所以我们要考虑到分母为0的情况if size == 0:continuescore = 0.0for class_val in classes:p = [row[-1] for row in group].count(class_val) / sizescore += p * p# 这个做了一个加权处理gini += (1 - score) * (size / n_instances)return ginidef get_split(dataset):class_values = list(set(row[-1] for row in dataset))posi_index, posi_value, posi_score, posi_groups = 888, 888, 888, Nonefor index in range(len(dataset[0]) - 1):for row in dataset:groups = test_split(index, row[index], dataset)gini = calculate_the_gini_index(groups, class_values)print("X%d < %.3f Gini=%.3f" % ((index + 1), row[index], gini))if gini < posi_score:posi_index, posi_value, posi_score, posi_groups = index, row[index], gini, groupsreturn {'index': posi_index, 'value': posi_value, 'groups': posi_groups}dataset = [[2.1, 1.1, 0],[3.4, 2.5, 0],[1.3, 5.8, 0],[1.9, 8.6, 0],[3.7, 6.2, 0],[8.8, 1.1, 1],[9.6, 3.4, 1],[10.2, 7.4, 1],[7.7, 8.8, 1],[9.7, 6.9, 1]]split = get_split(dataset)
print('Split:[X%d < %.3f]' % ((split['index'] + 1), split['value']))

建立回归树

# 1.root node
# 2.recursive split
# 3.terminal node (为了解决over-fitting的问题,减少整个tree的深度/高度,以及必须规定最小切分单位)
# 4.finish building the treedef test_split(index, value, dataset):left, right = list(), list()for row in dataset:if row[index] < value:left.append(row)else:right.append(row)return left, rightdef calculate_the_gini_index(groups, classes):# 计算有多少实例n_instances = float(sum([len(group) for group in groups]))# 把每一个group里面的加权gini计算出来gini = 0.0for group in groups:size = float(len(group))# *注意,这里不能除以0,所以我们要考虑到分母为0的情况if size == 0:continuescore = 0.0for class_val in classes:p = [row[-1] for row in group].count(class_val) / sizescore += p * p# 这个做了一个加权处理gini += (1 - score) * (size / n_instances)return ginidef get_split(dataset):class_values = list(set(row[-1] for row in dataset))posi_index, posi_value, posi_score, posi_groups = 888, 888, 888, Nonefor index in range(len(dataset[0]) - 1):for row in dataset:groups = test_split(index, row[index], dataset)gini = calculate_the_gini_index(groups, class_values)print("X%d < %.3f Gini=%.3f" % ((index + 1), row[index], gini))if gini < posi_score:posi_index, posi_value, posi_score, posi_groups = index, row[index], gini, groupsreturn {'index': posi_index, 'value': posi_value, 'groups': posi_groups}def determine_the_terminal(group):outcomes = [row[-1] for row in group]return max(set(outcomes), key=outcomes.count)# 1.把数据进行切分(分为左边与右边),原数据删除掉
# 2.检查非空以及满足我们的我们设置的条件(深度/最小切分单位/非空)
# 3.一直重复类似寻找root node的操作,一直到最末端def split(node, max_depth, min_size, depth):# 做切分,并删除掉原数据left, right = node['groups']del (node['groups'])# 查看非空if not left or not right:node['left'] = node['right'] = determine_the_terminal(left + right)return# 检查最大深度是否超过if depth >= max_depth:node['left'], node['right'] = determine_the_terminal(left), determine_the_terminal(right)return# 最小分类判断与左侧继续向下分类if len(left) <= min_size:node['left'] = determine_the_terminal(left)else:node['left'] = get_split(left)split(node['left'], max_depth, min_size, depth + 1)# 最小分类判断与右侧继续向下分类if len(right) <= min_size:node['right'] = determine_the_terminal(right)else:node['right'] = get_split(right)split(node['right'], max_depth, min_size, depth + 1)# 最终建立决策树def build_the_regression_tree(train, max_depth, min_size):root = get_split(train)split(root, max_depth, min_size, 1)return root# 通过CLI可视化的呈现类树状结构便于感性认知
def print_our_tree(node, depth=0):if isinstance(node, dict):print('%s[X%d < %.3f]' % ((depth * '-', (node['index'] + 1), node['value'])))print_our_tree(node['left'], depth + 1)print_our_tree(node['right'], depth + 1)else:print('%s[%s]' % ((depth * '-', node)))def make_prediction(node, row):if row[node['index']] < node['value']:if isinstance(node['left'], dict):return make_prediction(node['left'], row)else:return node['left']else:if isinstance(node['right'], dict):return make_prediction(node['right'], row)else:return node['right']dataset = [[2.1, 1.1, 0],[3.4, 2.5, 0],[1.3, 5.8, 0],[1.9, 8.6, 0],[3.7, 6.2, 0],[8.8, 1.1, 1],[9.6, 3.4, 1],[10.2, 7.4, 1],[7.7, 8.8, 1],[9.7, 6.9, 1]]tree = build_the_regression_tree(dataset, 3, 1)
print_our_tree(tree)decision_tree_stump = {'index': 0, 'right': 1, 'value': 9.3, 'left': 0}
for row in dataset:prediction = make_prediction(decision_tree_stump, row)print("What is expected data : %d , Your prediction is %d " % (row[-1], prediction))

从零开始数据科学与机器学习算法-分类与决策树-06相关推荐

  1. 从零开始数据科学与机器学习算法-KNN分类算法-07

    KNN概念 物以类聚 1.k--超参数(hyper-parameter) 2.k最好为奇数(no even number , better be odd) 3.k大小有学问: k太小:outliers ...

  2. 从零开始数据科学与机器学习算法-数据预处理与基准模型-01

    读取数据函数 from csv import reader # 导入库 def read_csv(the_name_of_file_to_be_read): # 定义数据读取函数file = open ...

  3. 从零开始数据科学与机器学习算法-学习向量量化(Learning_Vector_Quantization)-08

    LVQ概述 通常,我们使用LVQ方法用在分类问题上. codebook vector(是一系列数字,与你训练数据里的input与output相关的特征一样) 例: 1.class 0,1 2.widt ...

  4. 从零开始数据科学与机器学习算法-朴素贝叶斯-07

    朴素贝叶斯概念 例子:邮件分类问题: N = (12/17)*(5/11)*(3/11) S = (5/17)*(2/7)*(1/7)print(N) print(S) # N>S 我们可以判断 ...

  5. 从零开始数据科学与机器学习算法-人工神经网络与反向传播-09

    概述 rectifier其实就是一种模仿生物的激活机制的函数 (activation function) 常见的激活函数 https://en.wikipedia.org/wiki/Rectifier ...

  6. 从零开始数据科学与机器学习算法-简单感知器-05

    如下图给定的一组数据可以通过一条线分割成两个不同的类别称之为Linearly_Separable 如下图有明显特征但是不能通过线性进行切分称为线性不可分 我们可以在拿到数据后进行基本的判断,然后确定是 ...

  7. 从零开始数据科学与机器学习算法-线性回归-02

    简单线性回归 import numpy as np import pandas as pd import matplotlib.pyplot as plt import seaborn as sns ...

  8. 从零开始数据科学与机器学习算法-知识点补充-00

    知识拓展-python与统计学 1.Descriptive statistics 描述性统计 2.Inferential statistics 推断性统计:步骤如下: sample样本(sample ...

  9. 从零开始数据科学与机器学习算法-集成算法-10

    概述 把各种model综合起来--让预测更准确.更加稳定(做平均) 在随机森林里面的超参数(hyper-parameter): 1.对于每一棵树,要选取特性(features),假设总共有n个feat ...

最新文章

  1. MacBook Pro新版上市
  2. U2NET目标显著性检测,抠图去背景效果倍儿棒
  3. Linux 交换文件已存在解决办法
  4. ogre研究之第一个程序(二)
  5. POJ 1651 Multiplication Puzzle(类似矩阵连乘 区间dp)
  6. 两万字深度介绍分布式系统原理!【收藏版】
  7. 使用MySQL UDFs来调用gearman分布式任务分发系统
  8. 计算机二级c语言考试真题及答案详解,2021全国计算机二级C语言程序设计历年真题及答案节选...
  9. 小学计算机片段教学案例,小学信息技术教学案例分析(张擘)
  10. 基于SSM的论坛系统
  11. IPV6之DHCPV6
  12. MATLBA中最小二乘支持向量机原理+实例分析
  13. 汽车电工及电子技术基础【3】
  14. Flowers Sky Wallet First school/Primary school holiday Getting up early Reading Home c
  15. offer和面经分享(内含offer截图)
  16. 民宿运营经验分享:玩转自我营销,带动流量持续增长
  17. MLY翻译 -- 2.How to use this book to help your team?
  18. 工作中使用了一些触发器
  19. Linux记录-sysctl.conf优化方案
  20. Rust轻量级I/O库mio

热门文章

  1. mysql 英文占几个字符_MySQL 数据库 varchar 到底可以存多少个汉字,多少个英文呢?我们来搞搞清楚...
  2. error.html mp4,HTML Video error用法及代码示例
  3. android运动轨迹怎么画,Android 利用三阶贝塞尔曲线绘制运动轨迹的示例
  4. 怎么在android中定义泛型,android – 如何在GSON TypeToken类中使用自定义泛型?
  5. go uint64 转 字符_Go的基本数据类型入门看这一篇就差不多了
  6. 0xc000007b应用程序无法正常启动_应用程序无法正常启动0xc0000142
  7. nginx 反向代理web应用将https请求转成http请求时,必须注意事项
  8. java程序语句的理解,[每日学习笔记][2012.07.10]使用Java理解程序逻辑(六)
  9. springcloud实体类抽离
  10. mysql升级回退_Mysql 升级、用户与授权,