Contents

1. ID3 algorithm

2. Using the sklearn API — model saving and loading works


1. ID3 algorithm

The code below implements building a decision tree, plotting it, and saving and reloading the tree.

However, using the tree for prediction raises an error.

Classification code:

# Classify with the decision tree
def classify(inputTree, featLabels, testVec):
    firstStr = inputTree.keys()[0]  # Python 2 style; in Python 3 use list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

UnboundLocalError: local variable 'classLabel' referenced before assignment

Not yet resolved. The cause: classLabel is only assigned inside the loop when a branch key equals the sample's feature value, so if no key matches, the final return references a variable that was never bound.
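One way to avoid the crash, sketched below on a toy tree (my own fix, not the original author's code): give classLabel a default and compare keys as strings, since a json.dump/json.load round trip turns dict keys into strings.

```python
# Sketch of a fixed classify(); assumes the same nested-dict tree format as above.
def classify_fixed(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]        # feature tested at this node
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)      # column of that feature in the sample
    classLabel = None                           # default: no branch matched
    for key in secondDict.keys():
        if str(testVec[featIndex]) == str(key): # tolerate str vs int keys after JSON I/O
            if isinstance(secondDict[key], dict):
                classLabel = classify_fixed(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

tree = {'A': {0: 'no', 1: {'B': {0: 'no', 1: 'yes'}}}}
print(classify_fixed(tree, ['A', 'B'], [1, 1]))  # yes
print(classify_fixed(tree, ['A', 'B'], [2, 0]))  # None instead of UnboundLocalError
```

Returning None for an unseen feature value is a design choice; falling back to the majority class of the training set would be another reasonable option.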

Full code:

from dataProcess import loaddatasets
from math import log
import operator
import json
import numpy as np
from paths import abs_path

# Dataset
def createDataSet(xlsPath):
    """Create the dataset."""
    datas, labels = loaddatasets(xlsPath)
    labels = labels.reshape(-1, 1)
    dataSet = np.hstack((datas, labels))
    dataSet = dataSet.tolist()
    featureName = ['能效设计', '含油舱底水污染控制', '污油污染控制', '餐饮污水控制', '生活污水控制', '发动机排气污染物控制',
                   '制冷剂', '灭火剂', '垃圾污染控制', '防止噪声污染', '应用比例', '振动', '噪声', '有害物质的禁用和限用']
    # Return the dataset and the name of each feature dimension
    return dataSet, featureName

# Split the dataset
def splitDataSet(dataSet, axis, value):
    """Split the dataset on a given feature.
    :param axis: index of the feature to split on
    :param value: value of that feature
    :return: all instances with that value (the feature column itself is removed)
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reduceFeatVec = featVec[:axis]  # drop this feature column
            reduceFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet

def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        # count occurrences of each element in classList
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # sort by count, descending
    return sortedClassCount[0][0]  # the most frequent element in classList

# Shannon entropy
# Always measures the uncertainty of the class label
def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of the label variable Y in the training set."""
    numEntries = len(dataSet)  # number of instances
    labelCounts = {}
    for featVec in dataSet:  # count label frequencies over all instances
        currentLabel = featVec[-1]  # the last column is the label
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)  # log base 2
    return shannonEnt

# Conditional entropy
def calcConditionalEntropy(dataSet, i, featList, uniqueVals):
    """Compute the conditional entropy of Y given feature x_i.
    :param i: feature dimension i
    :param featList: list of values of feature i
    :param uniqueVals: set of distinct values of feature i
    :return: conditional entropy
    """
    ce = 0.0
    for value in uniqueVals:
        subDataSet = splitDataSet(dataSet, i, value)
        prob = len(subDataSet) / float(len(dataSet))  # maximum-likelihood probability estimate
        ce += prob * calcShannonEnt(subDataSet)  # conditional entropy ∑ p · H(Y|X=x_i)
    return ce

# Information gain
def calcInformationGain(dataSet, baseEntropy, i):
    """Compute the information gain of feature i: g(dataSet | X_i).
    :param baseEntropy: entropy of Y over the whole dataset
    """
    featList = [example[i] for example in dataSet]  # values of feature i
    uniqueVals = set(featList)  # distinct values only
    newEntropy = calcConditionalEntropy(dataSet, i, featList, uniqueVals)  # conditional entropy
    infoGain = baseEntropy - newEntropy  # information gain = entropy - conditional entropy
    return infoGain

# Algorithm skeleton
def chooseBestFeatureToSplitByID3(dataSet):
    """Choose the best feature to split the dataset on."""
    numFeatures = len(dataSet[0]) - 1  # the last column is the class
    baseEntropy = calcShannonEnt(dataSet)  # entropy of the whole dataset
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # iterate over all feature dimensions
        infoGain = calcInformationGain(dataSet, baseEntropy, i)  # gain for this feature
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature  # index of the best feature

def createTree(dataSet, featureName, chooseBestFeatureToSplitFunc=chooseBestFeatureToSplitByID3):
    """Build the decision tree.
    :param featureName: name of each feature dimension
    :return: decision tree (nested dict)
    """
    classList = [example[-1] for example in dataSet]  # list of class labels
    if classList.count(classList[0]) == len(classList):
        return classList[0]  # stop splitting when all labels are identical
    if len(dataSet[0]) == 1:  # only one column left: return the majority class
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplitFunc(dataSet)  # index of the best feature
    bestFeatLabel = featureName[bestFeat]  # its name
    myTree = {bestFeatLabel: {}}  # nested dict keyed by feature name
    del featureName[bestFeat]
    # recurse on each value of the chosen feature
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = featureName[:]  # copy
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

# Prediction
##################### original code (raises an error) #####################
def predict_(inputTree, featLabels, testVec):  # args: decision tree, feature names, sample to classify
    global classLabel
    firstStr = list(inputTree.keys())[0]  # feature tested at the root
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)  # position of the root feature among the feature names
    for key in list(secondDict.keys()):
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = predict_(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

####################### my own code (does not run either) #######################
def predict(tree, featureNames, testVec):
    """
    :param tree: decision tree
    :param featureNames: feature names
    :param testVec: sample vector to classify
    """
    def predict(secondTree, featureNames, testVec):
        global predict_label
        # reached a leaf when the next node is not a dict
        if type(secondTree).__name__ != "dict":
            predict_label = secondTree
            return
        # otherwise it is a dict: keep descending
        elif type(secondTree).__name__ == "dict":
            rootName = list(secondTree.keys())[0]  # root feature name
            rootValue = testVec[featureNames.index(rootName)]  # its value in the sample
            secondTree = secondTree[rootName][str(rootValue)]  # follow the matching branch
            predict(secondTree, featureNames, testVec)
    global predict_label
    rootName = list(tree.keys())[0]  # root feature name
    rootValue = testVec[featureNames.index(rootName)]  # its value in the sample
    secondTree = tree[rootName][str(rootValue)]  # follow the matching branch
    predict_label = predict(secondTree, featureNames, testVec)
    print(predict_label)
    return predict_label

# Model accuracy
def evalute(testDataSets, featureName, tree):
    """Measure accuracy on a test set.
    :param testDataSets: test instances (features plus label in the last column)
    :param tree: tree built from the training set
    :return: accuracy
    """
    testLabelList = []  # labels
    testDataList = []   # feature vectors
    for example in testDataSets:
        testDataList.append(example[:-1])
        testLabelList.append(example[-1])
    errorCnt = 0  # error counter
    for i in range(len(testDataList)):
        # compare the prediction against the true label
        if testLabelList[i] != predict(tree, featureName, testDataList[i]):
            errorCnt += 1
    return 1 - errorCnt / len(testDataList)

# Extend the JSON encoder so numpy arrays can be serialized
class NpEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()  # arrays are stored as plain lists
        else:
            return super(NpEncoder, self).default(obj)

# Save the model
def save_model(tree, model_save_path):
    json_str = json.dumps(tree, indent=4, cls=NpEncoder)
    with open(model_save_path, "w", encoding="utf-8") as f:
        f.write(json_str)

# Load the model
def load_model(model_path):
    with open(model_path, 'r', encoding="UTF-8") as f:
        tree = json.load(f)
    return tree

################################### plotting ###################################
import matplotlib.pyplot as plt

# Text-box and arrow styles
decisionNode = dict(boxstyle="round4", color='#3366FF')  # decision-node style
leafNode = dict(boxstyle="circle", color='#FF6633')  # leaf-node style
arrow_args = dict(arrowstyle="<-", color='g')  # arrow style

# Draw an annotated arrow
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

# Count the leaf nodes
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# Compute the depth of the tree
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# Label the edge between parent and child nodes
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the edge
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # draw the decision node
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

################################### plotting ###################################

if __name__ == '__main__':
    # # Train the model
    # print("loading the dataset")
    dataSets, featureName = createDataSet(abs_path + "\\data\\min_datas.xlsx")
    # print("building the decision tree")
    # mytree = createTree(dataSets, featureName)
    # print("tree:", mytree)
    # print("saving the tree")
    # save_model(mytree, abs_path + "\\data\\decitionTree.json")
    # Evaluate the model
    print("evaluating the tree:")
    print("loading the saved tree")
    tree = load_model(abs_path + "\\data\\decitionTree.json")
    print(tree)
    acc = evalute(dataSets, featureName, tree)
    print("accuracy: %.2f" % acc)
    print("saving the tree")
    save_model(tree, abs_path + "\\data\\decitionTree_%.2f.json" % acc)
    # Predict with the model
    testdata = [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]
    print("test sample:", testdata)
    print("true label:", 0)
    predict_label = predict(tree, featureName, testdata)
    print("predicted label:", predict_label)
    # ############################ plot the tree ###############################
# from pylab import *
# mpl.rcParams['font.sans-serif'] = ['SimHei']  # default font, so Chinese labels render
# mpl.rcParams['axes.unicode_minus'] = False  # keep '-' from rendering as a box when saving figures
# # plot the tree
# createPlot(tree)
# ###################################################################
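The entropy and information-gain helpers above can be sanity-checked on a tiny dataset. The sketch below is self-contained (standard library only); the toy data and the inlined, simplified helpers are mine, for illustration:

```python
from math import log

# Toy dataset: two binary features, last column is the class label.
dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

def calcShannonEnt(dataSet):
    """Entropy of the class label: H(Y) = -sum p * log2(p)."""
    counts = {}
    for featVec in dataSet:
        counts[featVec[-1]] = counts.get(featVec[-1], 0) + 1
    return -sum((c / len(dataSet)) * log(c / len(dataSet), 2) for c in counts.values())

def infoGain(dataSet, i):
    """Information gain g(D, X_i) = H(Y) - H(Y | X_i)."""
    base = calcShannonEnt(dataSet)
    ce = 0.0
    for v in set(row[i] for row in dataSet):
        subset = [row[:i] + row[i + 1:] for row in dataSet if row[i] == v]  # split and drop column i
        ce += len(subset) / len(dataSet) * calcShannonEnt(subset)           # weighted sub-entropy
    return base - ce

print(round(calcShannonEnt(dataSet), 3))  # 0.971: entropy with 2 'yes' vs 3 'no'
print(round(infoGain(dataSet, 0), 3))     # 0.42
print(round(infoGain(dataSet, 1), 3))     # 0.171: feature 0 wins the ID3 split
```

ID3 would therefore split on feature 0 first, exactly what chooseBestFeatureToSplitByID3 computes over all columns.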

2. Using the sklearn API — model saving and loading works

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2022/1/4 11:57
# @Author  : @linlianqin
# @Site    :
# @File    : test33.py
# @Software: PyCharm
# @description:
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text

def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)

iris = load_iris()
import numpy as np
decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
decision_tree = decision_tree.fit(iris.data, iris.target)
print(decision_tree)
print("trained model")
for data in iris.data:
    data = data.reshape(1, -1)
    print(decision_tree.predict(data))
    break

storeTree(decision_tree, '12.pkl')
tree = grabTree('12.pkl')
print("加载出来的模型")
for data in iris.data:
    data = data.reshape(1, -1)
    print(tree.predict(data))  # predict with the model loaded from disk
    break
print(tree)
r = export_text(decision_tree, feature_names=iris['feature_names'])
print(r)

Output:

DecisionTreeClassifier(max_depth=2, random_state=0)
trained model
[0]
loaded model
[0]
DecisionTreeClassifier(max_depth=2, random_state=0)
|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) >  0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) >  1.75
|   |   |--- class: 2

Note: the pickle files must be opened in binary mode ('wb' / 'rb'); otherwise you get:

TypeError: write() argument must be str, not bytes
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename, 'rb')
    return pickle.load(fr)

Refactored code:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2021/12/31 20:35
# @Author  : @linlianqin
# @Site    :
# @File    : decisionTree_.py
# @Software: PyCharm
# @description:

# Decision-tree classification with sklearn; note the data handling

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_text
from sklearn.metrics import precision_score, accuracy_score, recall_score, roc_curve
import pickle
from dataProcess import loaddatasets
from paths import abs_path
import numpy as np

# Create the dataset
def createDataSet(xlsPath):
    datas, levels = loaddatasets(xlsPath)
    levels = levels.reshape(-1, 1)
    return datas, levels

# Build the decision tree
def createTree(datas, labels):
    decision_tree = DecisionTreeClassifier()
    decision_tree = decision_tree.fit(datas, labels)
    return decision_tree

# Predict
def predict(tree, testVec):
    predict_label = tree.predict(testVec)
    return predict_label[0]

# Save the model
def storeTree(inputTree, filename):
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()

# Load the model
def grabTree(filename):
    fr = open(filename, 'rb')
    return pickle.load(fr)

# Model evaluation
def model_s(y_predict, y_test):
    # acc = accuracy_score(y_test, y_predict)
    # precision = precision_score(y_test, y_predict, average='macro')
    # recall = recall_score(y_test, y_predict, average='macro')
    print("fraction correct:", np.sum(y_predict == y_test) / len(y_test))
    print("accuracy:", accuracy_score(y_test, y_predict))
    print("precision:", precision_score(y_test, y_predict, average='macro'))
    print("recall:", recall_score(y_test, y_predict, average='macro'))
    # return acc, precision, recall

# Write the tree to a text file
def writeIntoTxt(tree, featureName, filename):
    r = export_text(tree, feature_names=featureName)
    with open(filename, 'w') as f:
        f.writelines(r)

if __name__ == '__main__':
    from paths import abs_path
    xlsPath = abs_path + "\\data\\min_datas.xlsx"
    model_path = abs_path + "\\data\\decisionTree.pkl"
    txt_path = abs_path + "\\data\\decisionTree.txt"
    testVec = [0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0]
    testVec = np.array(testVec).reshape(1, -1)
    featureNames = ['energy', 'oil-1', 'oil-2', "water-1", "water-2", "air", "cold", "fire",
                    "pm", "rubbish", "noise", "posi", "CEAR", "vibration", "noise1", "bad"]
    print("testVec:", testVec)
    print("featureNames:\n", featureNames)
    print("load datasets")
    datas, labels = createDataSet(xlsPath)
    print("building Tree")
    mytree = createTree(datas, labels)
    print("mytree:\n", mytree)
    print('save model......')
    storeTree(mytree, model_path)
    print('test.......')
    predict_label = predict(mytree, testVec)
    print("predict label:", predict_label)
    print('load model......')
    loadTree = grabTree(model_path)
    print('test load tree......')
    predict_label = predict(loadTree, testVec)
    print("predict label:", predict_label)
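The sklearn metrics that model_s prints can be cross-checked by hand. The sketch below (helper name and toy labels are mine, not from sklearn) computes accuracy and macro-averaged precision/recall the same way `accuracy_score`, `precision_score(..., average='macro')`, and `recall_score(..., average='macro')` do:

```python
# Hand-rolled accuracy and macro precision/recall for cross-checking sklearn.
def macro_metrics(y_true, y_pred):
    labels = sorted(set(y_true) | set(y_pred))
    acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    precisions, recalls = [], []
    for c in labels:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))  # true positives for class c
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))  # false positives
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))  # false negatives
        precisions.append(tp / (tp + fp) if tp + fp else 0.0)
        recalls.append(tp / (tp + fn) if tp + fn else 0.0)
    # macro averaging: unweighted mean over classes
    return acc, sum(precisions) / len(labels), sum(recalls) / len(labels)

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
acc, prec, rec = macro_metrics(y_true, y_pred)
print(acc, prec, rec)
```

Macro averaging weights every class equally regardless of how many samples it has, which is why model_s passes average='macro' for the multi-class labels here.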
