python实现ID3

具体的决策树原理在此就不再赘述,可自行百度或者看我之前写的:https://blog.csdn.net/weixin_38273255/article/details/88752468
在这里主要列出我使用的代码,和一些学习时候的心得。

心得

代码在网上其实都有一大片一大片的,但是自己觉得不太符合自己想要的要求,所以就找了网上的代码,并做了一些修改。很多地方肯定会有毛病,还望见谅。
我想要的是能够对连续型数据做处理,但是ID3是对离散型数据做处理,网上的示例代码也是如此,之前在学习原理的时候看到说可以对连续数据做离散处理:

  • 对连续数据排序
  • 将排好的序列等分(让每个区间落入的数据量基本一致)
  • 重新对区间赋值
  • 离散化完成

可能这个也存在问题,但是就这么实现了。

代码

  1. 导入数据
  2. 对数据离散化
  3. 创建决策树
  4. 绘制决策树图像
  5. 输入测试集测试

下下面就列出我的代码,代码主要是对iris数据集做出了处理,其他数据集可自行修改。
main.py

# -*- coding: utf-8 -*-
import creat as ct
import draw as dr
import yanzheng as yz
from sklearn import datasets
import numpy as np
import lisanhua as lsh
import random#加载iris数据集
iris = datasets.load_iris()
all_data = iris.data[:,:]
all_target = iris.target[:]
labels = iris.feature_names[:]#常量定义
n = 150#数据集总数
m = int(n*2/3)#创建用的数据量
q = 4#数据维度
l = 7#离散化个数#对数据离散化
a = []
all_data,a = lsh.lsh(all_data,l)#将target和数据合并
all_data = all_data.tolist()
all_target = all_target.tolist()
for i in range(len(all_target)):all_data[i].append(all_target[i])#将数据打乱
random.shuffle(all_data)#创建决策树数据集
cj_data = all_data[:m]#创建决策树
myTree=ct.createTree(cj_data,labels)#创建验证数据集
all_data = np.array(all_data)#转化为numpy
yz_target = np.array(all_data[m:n,q:q+1])
yz_data = np.array(all_data[m:n,:q])
yz_labels = np.array(iris.feature_names[:])#验证决策树正确率
yz_shu = yz.yanzheng(myTree,yz_data,yz_labels,yz_target)
yz_bfb = float(yz_shu)/(n-m)#结果反馈
print(myTree)
print(yz_shu)
print(yz_bfb)
dr.createPlot(myTree)

lisanhua.py

import numpy as npdef lsh(data,num):a = []for i in range(len(data[0])):b = []data1 = data[:,i]l = len(data1)data1.sort()for k in range(num):b.append(data1[int(k*l/num)])for j in range(len(data)):if data[j,i] >= b[-1]:data[j,i] = b[-1]continuefor q in range(1,num):if data[j,i] < b[q] and data[j,i] >= b[q-1]:data[j,i] = b[q-1]a.append(b)return data,a

creat.py

 import math
import operatordef calcShannonEnt(dataset):numEntries = len(dataset)labelCounts = {}for featVec in dataset:currentLabel = featVec[-1]if currentLabel not in labelCounts.keys():labelCounts[currentLabel] = 0labelCounts[currentLabel] +=1shannonEnt = 0.0for key in labelCounts:prob = float(labelCounts[key])/numEntriesshannonEnt -= prob*math.log(prob, 2)return shannonEntdef CreateDataSet():dataset = [[1, 1, 'yes' ],[1, 1, 'yes' ],[1, 0, 'no'],[0, 1, 'no'],[0, 1, 'no']]labels = ['no surfacing', 'flippers']return dataset, labelsdef splitDataSet(dataSet, axis, value):retDataSet = []for featVec in dataSet:if featVec[axis] == value:reducedFeatVec = featVec[:axis]reducedFeatVec.extend(featVec[axis+1:])retDataSet.append(reducedFeatVec)return retDataSetdef chooseBestFeatureToSplit(dataSet):numberFeatures = len(dataSet[0])-1baseEntropy = calcShannonEnt(dataSet)bestInfoGain = 0.0;bestFeature = -1;for i in range(numberFeatures):featList = [example[i] for example in dataSet]uniqueVals = set(featList)newEntropy =0.0for value in uniqueVals:subDataSet = splitDataSet(dataSet, i, value)prob = len(subDataSet)/float(len(dataSet))newEntropy += prob * calcShannonEnt(subDataSet)infoGain = baseEntropy - newEntropyif(infoGain > bestInfoGain):bestInfoGain = infoGainbestFeature = ireturn bestFeaturedef majorityCnt(classList):classCount ={}for vote in classList:if vote not in classCount.keys():classCount[vote]=0classCount[vote]=1sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)return sortedClassCount[0][0]def createTree(dataSet, labels):classList = [example[-1] for example in dataSet]if classList.count(classList[0])==len(classList):return classList[0]if len(dataSet[0])==1:return majorityCnt(classList)bestFeat = chooseBestFeatureToSplit(dataSet)bestFeatLabel = labels[bestFeat]myTree = {bestFeatLabel:{}}del(labels[bestFeat])featValues = [example[bestFeat] for example in dataSet]uniqueVals = set(featValues)for value in uniqueVals:subLabels = labels[:]myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)return myTree

draw.py


# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
decisionNode=dict(boxstyle="sawtooth",fc="0.8")
leafNode=dict(boxstyle="round4",fc="0.8")
arrow_args=dict(arrowstyle="<-")#计算树的叶子节点数量
def getNumLeafs(myTree):numLeafs=0firstStr=myTree.keys()[0]secondDict=myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__=='dict':numLeafs+=getNumLeafs(secondDict[key])else: numLeafs+=1return numLeafs#计算树的最大深度
def getTreeDepth(myTree):maxDepth=0firstStr=myTree.keys()[0]secondDict=myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__=='dict':thisDepth=1+getTreeDepth(secondDict[key])else: thisDepth=1if thisDepth>maxDepth:maxDepth=thisDepthreturn maxDepth#画节点
def plotNode(nodeTxt,centerPt,parentPt,nodeType):createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',\xytext=centerPt,textcoords='axes fraction',va="center", ha="center",\bbox=nodeType,arrowprops=arrow_args)#画箭头上的文字
def plotMidText(cntrPt,parentPt,txtString):lens=len(txtString)xMid=(parentPt[0]+cntrPt[0])/2.0-lens*0.002yMid=(parentPt[1]+cntrPt[1])/2.0createPlot.ax1.text(xMid,yMid,txtString)def plotTree(myTree,parentPt,nodeTxt):numLeafs=getNumLeafs(myTree)depth=getTreeDepth(myTree)firstStr=myTree.keys()[0]cntrPt=(plotTree.x0ff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.y0ff)plotMidText(cntrPt,parentPt,nodeTxt)plotNode(firstStr,cntrPt,parentPt,decisionNode)secondDict=myTree[firstStr]plotTree.y0ff=plotTree.y0ff-1.0/plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__=='dict':plotTree(secondDict[key],cntrPt,str(key))else:plotTree.x0ff=plotTree.x0ff+1.0/plotTree.totalWplotNode(secondDict[key],(plotTree.x0ff,plotTree.y0ff),cntrPt,leafNode)plotMidText((plotTree.x0ff,plotTree.y0ff),cntrPt,str(key))plotTree.y0ff=plotTree.y0ff+1.0/plotTree.totalDdef createPlot(inTree):fig=plt.figure(1,facecolor='white')fig.clf()axprops=dict(xticks=[],yticks=[])createPlot.ax1=plt.subplot(111,frameon=False,**axprops)plotTree.totalW=float(getNumLeafs(inTree))plotTree.totalD=float(getTreeDepth(inTree))plotTree.x0ff=-0.5/plotTree.totalWplotTree.y0ff=1.0plotTree(inTree,(0.5,1.0),'')plt.show()

yanzheng.py

import numpy as npdef  one(myTree,data,labels):if type(myTree) == int:return myTreekey = myTree.keys()[0]keys = key.split('<=')if data[labels == keys[0]] <= float(keys[1]):return one(myTree[myTree.keys()[0]][1],data,labels)else:return one(myTree[myTree.keys()[0]][0],data,labels)def getResult(myTree,data,labels):result = []for elem in data:result.append(one(myTree,elem,labels))return resultdef yanzheng(myTree,data,labels,target):count = 0result = getResult(myTree,data,labels)for i in range(len(result)):if(result[i] == target[i]):count += 1return count

结果

结果见:
https://blog.csdn.net/weixin_38273255/article/details/88981203

python实现ID3相关推荐

  1. 在西瓜数据集上用Python实现ID3决策树算法完整代码

    在西瓜数据集上用Python实现ID3决策树算法完整代码 文章目录 1.决策树算法代码ID3.py 2.可视化决策树代码visual_decision_tree.py 3.贴几张运行结果图 1.生成的 ...

  2. python决策树id3算法_python实现决策树ID3算法

    一.决策树概论 决策树是根据训练数据集,按属性跟类型,构建一棵树形结构.可以按照这棵树的结构,对测试数据进行分类.同时决策树也可以用来处理预测问题(回归). 二.决策树ID3的原理 有多种类型的决策树 ...

  3. 决策树 基于python实现ID3,C4.5,CART算法

    实验目录 实验环境 简介 决策树(decision tree) 信息熵 信息增益(应用于ID3算法) 信息增益率(在C4.5算法中使用) 基尼指数(被用于CART算法) 实验准备 数据集 算法大体流程 ...

  4. python实现ID3决策树及随机森林

    前言 数据集及完整代码:https://download.csdn.net/download/qq_49101550/20464229 一.对数据集进行训练集和测试集划分   从数据集中无放回地随机抽 ...

  5. python决策树id3算法_决策树ID3算法预测隐形眼睛类型--python实现

    标签: 本节讲解如何预测患者需要佩戴的隐形眼镜类型. 1.使用决策树预测隐形眼镜类型的一般流程 (1)收集数据:提供的文本文件(数据来源于UCI数据库) (2)准备数据:解析tab键分隔的数据行 (3 ...

  6. python中id3决策树算法_ID3决策树算法实现(Python版)

    1 #-*- coding:utf-8 -*- 2 3 from numpy import * 4 importnumpy as np5 importpandas as pd6 from math i ...

  7. 决策树算法python实现_决策树之python实现ID3算法(例子)

    1 #-*- coding: utf-8 -*- 2 from math importlog3 importoperator4 importpickle5 ''' 6 输入:原始数据集.子数据集(最后 ...

  8. python决策树id3算法_Python3 决策树ID3算法实现

    1 #-*- coding: utf-8 -*- 2 """ 3 Created on Wed Jan 24 19:01:40 20184 5 @author: mark ...

  9. ID3决策树 Python实现 + sklearn库决策树模型的应用

    本文介绍机器学习中决策树算法的python实现过程 共介绍两类方法: (1)亲手实习Python ID3决策树经典算法 (2)利用sklearn库实现决策树算法 关于决策树的原理,指路:机器学习 第四 ...

最新文章

  1. ES6变量常量字符串数值
  2. 中秋祝福网页制作_中秋节祝福语不知怎么写?这3个小程序让你的祝福更精美...
  3. 如何绕过浏览器的弹窗拦截机制
  4. 【MySQL】触发器:让指定某一字段的值等于id
  5. 只有一个显示器但是显示两个显示器_小米34寸曲面显示器深度体验 办公体验极佳 但是还有个大弱点...
  6. asp.net自定义控件的使用
  7. windows服务与其他进程使用MemoryMappedFile
  8. 数据科学入门与实战:玩转pandas之四
  9. python实现matlab_用matlab和python实现符号运算
  10. Java对象创建的过程及对象的内存布局与访问定位
  11. oppoJava面试!mysql客户端安装包
  12. 如何用VB编程实现关闭WINDOWS窗口?
  13. Linux 进程之如何查看进程详情?
  14. 独家揭秘:暴利的黑帽SEO行业
  15. 如何处理微信域名防封
  16. java语言有很多的优点和特点_Java语言具有许多优点和特点,下列选项中()反映了Java程序并行机制的特点 (5.0分)_学小易找答案...
  17. Circular RNA的产生机制、功能及RNA-seq数据鉴定方法
  18. VM虚拟机安装orcle数据库
  19. 用计算机如何计算投资回收期,动态投资回收期怎么算的?
  20. 路由器怎么用自己的笔记本电脑进行配置

热门文章

  1. ftp服务器备份手机文件,ftp服务器文件自动备份
  2. html网页开发入门
  3. node php聊天室,利用socket.io实现多人聊天室(基于Nodejs)
  4. 关于“#define REG_MEM_BASE (*(volatile unsigend long *)(PA_BAES + 0x00000050))”语句的解析
  5. 解析Token工具类
  6. 2023年5种最适合网络安全工程师学习运用的编程语言
  7. 【SSL/TLS】准备工作:HTTPS服务器部署:Nginx部署
  8. 【HLS教程】HLS入门与精通
  9. 带你了解跨站请求伪造(CSRF),具体代码实现
  10. 【JavaWeb开发】Referer防盗链的详解