效果图:

原始数据文件:

lense.txt
young    myope   no  reduced no lenses
young   myope   no  normal  soft
young   myope   yes reduced no lenses
young   myope   yes normal  hard
young   hyper   no  reduced no lenses
young   hyper   no  normal  soft
young   hyper   yes reduced no lenses
young   hyper   yes normal  hard
pre myope   no  reduced no lenses
pre myope   no  normal  soft
pre myope   yes reduced no lenses
pre myope   yes normal  hard
pre hyper   no  reduced no lenses
pre hyper   no  normal  soft
pre hyper   yes reduced no lenses
pre hyper   yes normal  no lenses
presbyopic  myope   no  reduced no lenses
presbyopic  myope   no  normal  no lenses
presbyopic  myope   yes reduced no lenses
presbyopic  myope   yes normal  hard
presbyopic  hyper   no  reduced no lenses
presbyopic  hyper   no  normal  soft
presbyopic  hyper   yes reduced no lenses
presbyopic  hyper   yes normal  no lenses

treestore.py代码

存储恢复tree

#-*-coding:utf-8-*-def  storeTree(inputTree, filename):"""Function:   存储决策树Args:       inputTree:树信息filename:文件名称Returns:    无"""#导入模块import pickle#新建文件,一定要加b属性,否则可能报错:#TypeError: write() argument must be str, not bytesfw = open(filename, 'wb')#写入数据pickle.dump(inputTree, fw)#关闭文件fw.close()def grabTree(filename):"""Function:   读取决策树Args:       filename:文件名称Returns:    pickle.load(fr):树信息"""#导入模块import pickle#打开文件,写入属性一致,否则可能报错:#UnicodeDecodeError: 'gbk' codec can't decode byte 0x80 in position 0: illegal multibyte sequencefr = open(filename, 'rb')#导出数据return pickle.load(fr)

treeplot1.py绘制tree图像代码

# _*_ coding: UTF-8 _*_import matplotlib.pyplot as plt"""绘决策树的函数"""
decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # 定义分支点的样式
leafNode = dict(boxstyle="round4", fc="0.8")  # 定义叶节点的样式
arrow_args = dict(arrowstyle="<-")  # 定义箭头标识样式# 计算树的叶子节点数量
def getNumLeafs(myTree):numLeafs = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':numLeafs += getNumLeafs(secondDict[key])else:numLeafs += 1return numLeafs# 计算树的最大深度
def getTreeDepth(myTree):maxDepth = 0firstStr = list(myTree.keys())[0]secondDict = myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':thisDepth = 1 + getTreeDepth(secondDict[key])else:thisDepth = 1if thisDepth > maxDepth:maxDepth = thisDepthreturn maxDepth# 画出节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \xytext=centerPt, textcoords='axes fraction', va="center", ha="center", \bbox=nodeType, arrowprops=arrow_args)# 标箭头上的文字
def plotMidText(cntrPt, parentPt, txtString):lens = len(txtString)xMid = (parentPt[0] + cntrPt[0]) / 2.0 - lens * 0.002yMid = (parentPt[1] + cntrPt[1]) / 2.0createPlot.ax1.text(xMid, yMid, txtString)def plotTree(myTree, parentPt, nodeTxt):numLeafs = getNumLeafs(myTree)depth = getTreeDepth(myTree)firstStr = list(myTree.keys())[0]cntrPt = (plotTree.x0ff + \(1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.y0ff)plotMidText(cntrPt, parentPt, nodeTxt)plotNode(firstStr, cntrPt, parentPt, decisionNode)secondDict = myTree[firstStr]plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__ == 'dict':plotTree(secondDict[key], cntrPt, str(key))else:plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalWplotNode(secondDict[key], \(plotTree.x0ff, plotTree.y0ff), cntrPt, leafNode)plotMidText((plotTree.x0ff, plotTree.y0ff) \, cntrPt, str(key))plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalDdef createPlot(inTree):fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)plotTree.totalW = float(getNumLeafs(inTree))plotTree.totalD = float(getTreeDepth(inTree))plotTree.x0ff = -0.5 / plotTree.totalWplotTree.y0ff = 1.0plotTree(inTree, (0.5, 1.0), '')plt.show()if __name__=='__main__':

id3的决策树代码

from math import log
import operator
import numpy as np
import pandas as pd
from pandas import DataFrame, Series
import treeplot1
import treestore# 计算数据的熵(entropy)-原始熵
def dataentropy(data, feat):lendata = len(data)  # 数据条数labelCounts = {}  # 数据中不同类别的条数for featVec in data:category = featVec[-1]  # 每行数据的最后一个字(叶子节点)if category not in labelCounts.keys():labelCounts[category] = 0labelCounts[category] += 1  # 统计有多少个类以及每个类的数量entropy = 0for key in labelCounts:prob = float(labelCounts[key]) / lendata  # 计算单个类的熵值entropy -= prob * log(prob, 2)  # 累加每个类的熵值return entropy# 处理后导入数据数据
def Importdata(datafile):dataa = pd.read_excel(datafile)  # datafile是excel文件,所以用read_excel,如果是csv文件则用read_csv# 将文本中不可直接使用的文本变量替换成数字productDict = {'高': 1, '一般': 2, '低': 3, '帅': 1, '丑': 3, '胖': 3, '瘦': 1, '是': 1, '否': 0}dataa['income'] = dataa['收入'].map(productDict)  # 将每一列中的数据按照字典规定的转化成数字dataa['hight'] = dataa['身高'].map(productDict)dataa['look'] = dataa['长相'].map(productDict)dataa['shape'] = dataa['体型'].map(productDict)dataa['is_meet'] = dataa['是否见面'].map(productDict)data = dataa.iloc[:, 5:].values.tolist()  # 取量化后的几列,去掉文本列b = dataa.iloc[0:0, 5:-1]labels = b.columns.values.tolist()  # 将标题中的值存入列表中return data, labels# 按某个特征value分类后的数据
def splitData(data, i, value):splitData = []for featVec in data:if featVec[i] == value:rfv = featVec[:i]rfv.extend(featVec[i + 1:])splitData.append(rfv)return splitData# 选择最优的分类特征
def BestSplit(data):numFea = len(data[0]) - 1  # 计算一共有多少个特征,因为最后一列一般是分类结果,所以需要-1baseEnt = dataentropy(data, -1)  # 定义初始的熵,用于对比分类后信息增益的变化bestInfo = 0bestFeat = -1for i in range(numFea):featList = [rowdata[i] for rowdata in data]uniqueVals = set(featList)newEnt = 0for value in uniqueVals:subData = splitData(data, i, value)  # 获取按照特征value分类后的数据prob = len(subData) / float(len(data))newEnt += prob * dataentropy(subData, i)  # 按特征分类后计算得到的熵info = baseEnt - newEnt  # 原始熵与按特征分类后的熵的差值,即信息增益if (info > bestInfo):  # 若按某特征划分后,若infoGain大于bestInf,则infoGain对应的特征分类区分样本的能力更强,更具有代表性。bestInfo = info  # 将infoGain赋值给bestInf,如果出现比infoGain更大的信息增益,说明还有更好地特征分类bestFeat = i  # 将最大的信息增益对应的特征下标赋给bestFea,返回最佳分类特征return bestFeat# 按分类后类别数量排序,取数量较大的
def majorityCnt(classList):c_count = {}for i in classList:if i not in c_count.keys():c_count[i] = 0c_count[i] += 1ClassCount = sorted(c_count.items(), key=operator.itemgetter(1), reverse=True)  # 按照统计量降序排序return ClassCount[0][0]  # reverse=True表示降序,因此取[0][0],即最大值# 建树
def createTree(data, labels):classList = [rowdata[-1] for rowdata in data]  # 取每一行的最后一列,分类结果(1/0)if classList.count(classList[0]) == len(classList):return classList[0]if len(data[0]) == 1:return majorityCnt(classList)bestFeat = BestSplit(data)  # 根据信息增益选择最优特征bestLab = labels[bestFeat]myTree = {bestLab: {}}  # 分类结果以字典形式保存del (labels[bestFeat])featValues = [rowdata[bestFeat] for rowdata in data]uniqueVals = set(featValues)for value in uniqueVals:subLabels = labels[:]myTree[bestLab][value] = createTree(splitData(data, bestFeat, value), subLabels)return myTree# 选择最优的分类特征C.45算法
def BestSplitc45(data):numFea = len(data[0])-1#计算一共有多少个特征,因为最后一列一般是分类结果,所以需要-1baseEnt = dataentropy(data,-1)   # 定义初始的熵,用于对比分类后信息增益的变化bestGainRate = 0bestFeat = -1for i in range(numFea):featList = [rowdata[i] for rowdata in data]uniqueVals = set(featList)newEnt = 0for value in uniqueVals:subData = splitData(data,i,value)#获取按照特征value分类后的数据prob =len(subData)/float(len(data))newEnt +=prob*dataentropy(subData,i)  # 按特征分类后计算得到的熵info = baseEnt - newEnt  # 原始熵与按特征分类后的熵的差值,即信息增益splitonfo = dataentropy(subData,i) #分裂信息if splitonfo == 0:#若特征值相同(eg:长相这一特征的值都是帅),即splitonfo和info均为0,则跳过该特征continueGainRate = info/splitonfo #计算信息增益率if (GainRate>bestGainRate):   # 若按某特征划分后,若infoGain大于bestInf,则infoGain对应的特征分类区分样本的能力更强,更具有代表性。bestGainRate=GainRate #将infoGain赋值给bestInf,如果出现比infoGain更大的信息增益,说明还有更好地特征分类bestFeat = i #将最大的信息增益对应的特征下标赋给bestFea,返回最佳分类特征return bestFeatdef classify(inputTree, featLabels, testVec):"""Function:   使用决策树的分类函数Args:       inputTree:树信息featLabels:标签列表testVec:测试数据Returns:    classLabel:分类标签"""#第一个关键字为第一次划分数据集的类别标签,附带的取值表示子节点的取值firstStr = list(inputTree.keys())[0]#新的树,相当于脱了一层皮secondDict = inputTree[firstStr]#将标签字符串转为索引featIndex = featLabels.index(firstStr)#遍历整棵树for key in secondDict.keys():#比较testVec变量中的值与树节点的值if testVec[featIndex] == key:#判断子节点是否为字典类型,进而得知是否到达叶子结点if type(secondDict[key]).__name__=='dict':#没到达叶子结点,则递归调用classify()classLabel = classify(secondDict[key], featLabels, testVec)else:#到达叶子结点,则分类结果为当前节点的分类标签classLabel = secondDict[key]#返回分类标签return classLabelif __name__ == '__main__':# datafile = u'E:\\pythondata\\lense.txt'  # 文件所在位置,u为防止路径中有中文名称# datafile='lense.txt'fr = open("lense.txt")lenses = [inst.strip().split("\t") for inst in fr.readlines()]lensesLabels = ["age", "prescript", "astigmatic", "tearRte"]# data, labels = Importdata(datafile)  # 导入数据lensesTree=createTree(lenses, lensesLabels)print(lensesTree)  # 输出决策树模型结果#tree的文件存储测试恢复测试treestore.storeTree(lensesTree, 'classifierStorage.txt')lensesTree1=treestore.grabTree('classifierStorage.txt')print(lensesTree1)treeplot1.createPlot(lensesTree)#隐形眼镜类型决策分析lensesLabels = ["age", "prescript", "astigmatic", "tearRte"]  #这个必须再次给值,createtree会改变label,再次校验提示缺少最后一个元素glasstype=classify(lensesTree, lensesLabels, ['presbyopic', 'hyper', 'no', 'normal'])print(glasstype)

python:ai第五课:决策树的python实现调试,tree存储恢复,tree图绘制相关推荐

  1. 视频教程-Python创意编程视频课CSDN版-Python

    Python创意编程视频课CSDN版 20多年的编程磨砺,对编程技术有着执着的追求,开发有大量Python课程资源,如<哪吒学Python启蒙篇>.<哪吒学Python初级篇> ...

  2. 【Python】第五课 函数

    5.1 什么是函数 函数其实就是将一些需要经常执行和运用的代码进行整合打包起来,当需要用的时候直接调用即可,无需再花时间进行重新编写,这样可以加快开发项目的进度,缩短项目的开发周期.python也给我 ...

  3. python培训第五课

    定时器: import threadingdef run():print("定时器启动了")t2 = threading.Timer(5, run) # run函数运行五秒后再次运 ...

  4. 十五、Python第十五课——测试代码

    (请先看这篇文章:https://blog.csdn.net/GenuineMonster/article/details/104495419) 也许你听过软件测试?编写函数或类时,可以为其编写对应的 ...

  5. Python爬虫第五课 购票项目爬虫实战

    12306购票抓包分析以及任务分解 学习目标: 了解 12306抓包过程 1.1 抓包分析 使用谷歌浏览器或fiddler等抓包工具完成登陆以及购票操作,进行抓包,根据 具有业务作用 或 被set-c ...

  6. Python爬虫第五课:存储数据

    一.存储数据的方式 同样是存储数据的文件,"csv"格式文件和Excel文件有什么区别呢? 1. CSV文件 我们知道json是特殊的字符串.其实,csv也是一种字符串文件的格式, ...

  7. python基础第五课--从字符串中打包和解包大整数(小白piao分享)

    #4.从字符串打包和解包大整数 #将一个字符串解包成一个大整数,将一个大整数打包成一个字符串 #4.1 解决方案: #假设程序需要处理一个有16个元素的字节串,其中保存着一个128位的大整数 data ...

  8. python数据分析实战五_简单的python数据分析实战——黑五销售数据分析

    黑色星期五(通过消费者行为进行销售研究) 背景描述:关于零售商店中黑色星期五的55万个观测数据集.它包含不同类型的数字或分类变量,包含缺失值. 1.理解数据 数据包含538K行,12列.各列含义如下: ...

  9. python数据分析方法五种_加速Python数据分析的10个简单技巧(上)

    总有一些小贴士和技巧在编程领域是非常有用的.有时,一个小技巧可以节省时间甚至可以挽救生命.一个小的快捷方式或附加组件有时会被证明是天赐之物,并能真正提高生产力.因此,我总结了一些我最喜欢的一些贴士和技 ...

最新文章

  1. JSP 自定义标签介绍
  2. matlab显示YCrCb的图像,【Matlab系列】读取并显示YUV视频文件
  3. 路由器mstp多域配置举例
  4. Struts2 关于返回type=chain的用法.
  5. 深入Redis内部-Redis 源码讲解
  6. Mysql升级过程的问题
  7. 2019牛客暑期多校训练营(第八场)
  8. xp系统开机自检很久_电脑开机不能自检的五大原因
  9. 巴斯卡三角形 and 數字位移
  10. python socket编程实现的简单tcp迭代server
  11. mysql数据库的体系结构包括什么组件_MySQL数据库的体系结构
  12. [Git GitHub] Windows下安装git,从0开始搭建git环境(配置环境变量+设置git-ssh key...配置)(超全版)
  13. SQL取分组中前、后几条数据
  14. linux创建django项目,Ubuntu 16.04下配置Django项目
  15. APP适配安卓手机刘海屏
  16. vue组件样式scoped
  17. 【数据结构笔记20】图的定义,图的表示:邻接矩阵与邻接表
  18. 7.2 重入锁(ReentrantLock)
  19. 一个定时器的普通实现,多进程实现和多线程实现的对比
  20. 字符串低位优先排序真的只能排序字符串相同的字符么?

热门文章

  1. 将kali Linux系统的语言切换为中文
  2. 真爱空间网络办公OA
  3. vue抽出组件并传值
  4. 前端angular与服务器端nodejs实现从mysql数据库读取数据实现前后端交互实例
  5. PyeCharts绘制K线图(续)
  6. 时光倒流童年_使用Microsoft Excel时光倒流
  7. Python标准库之Turtle
  8. JSP简介以及常见动态网站开发技术(Asp.net、Php、Jsp)
  9. STM32 温度采集及WIFI电路设计
  10. 从c语言到Python (4)循环语句