Author: Qian Yang, School of Management, Hefei University of Technology. Email: 1563178220@qq.com. The content may contain errors; discussion and corrections are welcome.
Reproduction without the author's permission is prohibited.
The following are my personal study notes. They were written in LaTeX, and retyping them as plain text would be tedious, so they are included here as screenshots.

Personal notes

[The screenshots of the LaTeX notes from the original post are not reproduced here.]

ID3 Code Walkthrough in Java

Data Organization

The code below comes from: https://github.com/agangotia/ID3

When implementing an algorithm in a programming language, the first step is to organize the input data. The data format used was shown as a screenshot in the original post. The training set contains six attributes, with the last column being the decision variable. The core code that reads data in this format is:
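The original screenshot of the data format is not reproduced here, but its shape can be inferred from the parser below: the header line alternates an attribute name with a numeric token (which the reader skips), and every subsequent line contains six integer attribute values followed by a 0/1 class label. A hypothetical file consistent with that parser (names and values invented purely for illustration) might look like:

```
A1 3 A2 3 A3 2 A4 3 A5 2 A6 2
2 1 0 2 1 0 1
0 2 1 1 0 1 0
1 0 1 0 1 1 1
```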

// Read the data set
public void prepareMatrix(String FileNameToRead, int PercentageOfDataToLEarnFrom) {
    BufferedReader br = null;
    try {
        br = new BufferedReader(new FileReader(FileNameToRead));
        int NumberOfColoumns = 0;
        // Reading header (the first line)
        {
            String line = br.readLine();
            StringTokenizer st = new StringTokenizer(line);
            boolean notNum = true;
            while (st.hasMoreElements()) {
                if (notNum) { // true: this token is an attribute name
                    Headers.add((String) st.nextElement());
                    notNum = false;
                } else {
                    st.nextElement(); // skip the numeric token
                    notNum = true;
                }
            }
        }
        // Finally append the class column -- the header now has 7 entries
        Headers.add("Class");
        NumberOfColoumns = Headers.size();
        coloumns = NumberOfColoumns;
        { // now read the data rows
            String line = br.readLine();
            while (line != null) {
                StringTokenizer st = new StringTokenizer(line);
                int[] tempCol = new int[NumberOfColoumns];
                int tempIndex = 0;
                while (st.hasMoreElements()) {
                    tempCol[tempIndex++] = Integer.parseInt((String) st.nextElement());
                }
                rows.add(tempCol);
                line = br.readLine();
            }
        }
        // The truncating part: keep only the requested percentage of rows
        int rowsAfterTrunc = (int) ((PercentageOfDataToLEarnFrom * (rows.size())) / 100);
        if (rowsAfterTrunc != rows.size()) {
            for (int i = rows.size() - 1; i > rowsAfterTrunc; i--) {
                rows.remove(i);
            }
        }
        Numrows = rows.size();
    } catch (FileNotFoundException e) {
        e.printStackTrace();
    } catch (IOException e) {
        e.printStackTrace();
    } finally {
        try {
            br.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    // Print the rows that were just read (debug output)
    for (int i = 0; i < rows.size(); i++) {
        int[] arr = rows.get(i);
        for (int j = 0; j < arr.length; j++) {
            System.out.print(arr[j]);
            System.out.print("\t");
        }
        System.out.println("");
    }
}

This code shows how the space-separated attribute names, together with the class label, are added to the Headers collection. In addition, because the ID3 algorithm must compute the information gain of each attribute, the data also needs to be organized column by column. A collection is used for this:

ArrayList<int[]> rows; // rows in the matrix

Each stored row contains every attribute column as well as the final decision-variable column.
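The `fillArray` method used below to extract one column of this row-major matrix is not shown in the excerpt. Based on its call sites (copying column `i` of every row into a pre-sized one-dimensional array), a plausible sketch of the idea is the following; the class name `MatrixDemo` is invented for illustration:

```java
import java.util.ArrayList;
import java.util.Arrays;

public class MatrixDemo {
    // Hypothetical reconstruction of MatrixData.fillArray: copy column
    // `colIndex` of every row into the pre-sized target array.
    static void fillArray(ArrayList<int[]> rows, int[] target, int colIndex) {
        for (int i = 0; i < rows.size(); i++) {
            target[i] = rows.get(i)[colIndex];
        }
    }

    public static void main(String[] args) {
        ArrayList<int[]> rows = new ArrayList<>();
        rows.add(new int[]{1, 0, 1});
        rows.add(new int[]{2, 1, 0});
        rows.add(new int[]{3, 0, 1});

        int[] column = new int[rows.size()];
        fillArray(rows, column, 0); // extract the first attribute column
        System.out.println(Arrays.toString(column)); // [1, 2, 3]
    }
}
```

This is how a row-major `ArrayList<int[]>` can still serve the column-wise computations that information gain requires.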

Core Algorithm Code

/**
 *
 */
package com.qian.id3;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * @author Anupam Gangotia
 * Profile::http://en.gravatar.com/gangotia
 * github::https://github.com/agangotia
 */
/**
 * The learner class: it learns from the training values and stores the
 * result in a decision tree.
 */
public class ID3Learner {
    String FileNameToRead;           // file name of the training data set
    int PercentageOfDataToLEarnFrom; // percentage of input lines to read

    // Parameterized constructor
    public ID3Learner(String FileName, int percs) {
        FileNameToRead = FileName;
        PercentageOfDataToLEarnFrom = percs;
    }

    /*
     * Function :: startLearning. Reads the file values into an MxN matrix.
     * The matrix data is then split into a set of training vectors and a
     * final-class vector, since the ID3 algorithm takes a set of training
     * vectors and a final-class vector as input. This function internally
     * calls learnTree, which is the implementation of the ID3 algorithm.
     */
    public TreeNode startLearning() {
        if (FileNameToRead == null) {
            System.out.println("---- Error ------");
            System.out.println("---- Please Specify test data set ------");
        }
        if (PercentageOfDataToLEarnFrom < 0) {
            System.out.println("---- Error ------");
            System.out.println("---- Please Specify %correctly ------");
        }
        MatrixData matrix = new MatrixData(); // prepare a new matrix datatype
        matrix.prepareMatrix(FileNameToRead, PercentageOfDataToLEarnFrom);

        // Training data, excluding the decision variable
        HashMap<String, int[]> setTrainingVector = new HashMap<String, int[]>();
        // Build the set of training vectors: loop over the attribute columns only
        for (int i = 0; i < matrix.coloumns - 1; i++) {
            // store the data column by column
            int[] trainingVector = new int[matrix.Numrows];
            matrix.fillArray(trainingVector, i);
            setTrainingVector.put(matrix.Headers.get(i), trainingVector);
        }
        // The decision variable (final-class vector)
        int[] FinalClass = new int[matrix.Numrows];
        matrix.fillArray(FinalClass, matrix.coloumns - 1);

        // Initialize the tree
        TreeNode rootNode = new TreeNode();
        rootNode.setAtrvalue(-1); // since it is the root node

        // Call the ID3 implementation: attributes, class vector, tree node, data
        learnTree(setTrainingVector, FinalClass, rootNode, matrix);
        return rootNode;
    }

    /*
     * Function :: learnTree. Recursive function; an exact copy of the ID3
     * algorithm (http://en.wikipedia.org/wiki/ID3_algorithm). Generates a
     * decision tree recursively. Parameters:
     * 1. a HashMap of training vectors :: HashMap<String, int[]> setTrainingVector
     * 2. a vector of final classes :: int[] FinalClass
     * 3. the decision-tree node :: TreeNode node
     * 4. the matrix datatype used to construct the training vectors :: MatrixData matrix
     */
    public void learnTree(HashMap<String, int[]> setTrainingVector,
            int[] FinalClass, TreeNode node, MatrixData matrix) {
        // Check whether all examples belong to a single class
        if (checkFinalClass(FinalClass, 0)) {
            // If all examples are 0, return the single-node tree with label 0
            node.fClass = 0;
            return;
        } else if (checkFinalClass(FinalClass, 1)) { // if all examples are 1
            node.fClass = 1;
            return;
        }
        // If only one attribute is left, label the node with the majority class
        if (setTrainingVector.entrySet().size() == 1) {
            int cPos = getCountPositives(FinalClass);
            int cNeg = FinalClass.length - cPos;
            if (cPos >= cNeg) {
                node.fClass = 0;
            } else {
                node.fClass = 1;
            }
            return;
        } else {
            /* Choose the splitting attribute by information gain */
            // information gain of each attribute
            HashMap<String, Double> attributesGains = new HashMap<String, Double>();
            // unique values seen for each attribute
            HashMap<String, ArrayList<Integer>> mapAttributesValuesInListUnique =
                    new HashMap<String, ArrayList<Integer>>();
            // entropy of the whole sample
            double entropyS = getEntropy(FinalClass);

            // setTrainingVector maps each attribute to its column of values,
            // stored in a one-dimensional array
            for (Map.Entry entry : setTrainingVector.entrySet()) {
                // count of positive examples per attribute value
                HashMap<Integer, Integer> atrPositive = new HashMap<Integer, Integer>();
                // count of negative examples per attribute value
                HashMap<Integer, Integer> atrNegative = new HashMap<Integer, Integer>();
                ArrayList<Integer> atrUnique = new ArrayList<Integer>();
                // the column of values for this attribute
                int[] trainingClass = (int[]) entry.getValue();
                for (int i = 0; i < trainingClass.length; i++) {
                    addOnlyUnique(atrUnique, trainingClass[i]);
                    if (FinalClass[i] == 0) { // class 0 is treated as positive
                        // tally positives and negatives for each attribute value
                        if (atrPositive.containsKey(trainingClass[i])) {
                            atrPositive.put(trainingClass[i],
                                    atrPositive.get(trainingClass[i]) + 1);
                        } else {
                            atrPositive.put(trainingClass[i], 1);
                        }
                    } else { // FinalClass is negative
                        if (atrNegative.containsKey(trainingClass[i])) {
                            atrNegative.put(trainingClass[i],
                                    atrNegative.get(trainingClass[i]) + 1);
                        } else {
                            atrNegative.put(trainingClass[i], 1);
                        }
                    }
                }
                mapAttributesValuesInListUnique.put((String) entry.getKey(), atrUnique);

                // Compute the entropy, and from it the gain, for this attribute
                {
                    double gain = entropyS;
                    // each unique value of the attribute
                    for (int tempAttr : atrUnique) {
                        double entropyTemp = 0.0;
                        int positives = 0;
                        int negatives = 0;
                        // positives for this value
                        if (atrPositive.get(tempAttr) != null)
                            positives = atrPositive.get(tempAttr);
                        // negatives for this value
                        if (atrNegative.get(tempAttr) != null)
                            negatives = atrNegative.get(tempAttr);
                        double val1 = (double) (positives) / (positives + negatives);
                        double val2 = (double) (negatives) / (positives + negatives);
                        // entropy of this subset, from the entropy formula
                        entropyTemp = -(val1 * log2(val1)) - (val2 * log2(val2));
                        // accumulate the information gain
                        gain = gain - ((((double) positives + negatives)
                                / trainingClass.length) * entropyTemp);
                    }
                    // store this attribute's information gain
                    attributesGains.put((String) entry.getKey(), gain);
                }
            } // loop over attributes ends

            /*
             * Find the attribute with the maximum information gain. This is a
             * hand-written scan by the original author; Collections.sort()
             * could be used instead.
             */
            String attributeWithMAxGain = "";
            double maxGainValue = 0.0;
            int indexToChoose = 0;
            for (Map.Entry entry : setTrainingVector.entrySet()) {
                double tempGain = attributesGains.get((String) entry.getKey());
                if (indexToChoose == 0) {
                    maxGainValue = tempGain;
                    attributeWithMAxGain = (String) entry.getKey();
                    indexToChoose++;
                }
                if (tempGain > maxGainValue) {
                    maxGainValue = tempGain;
                    attributeWithMAxGain = (String) entry.getKey();
                }
            } // loop ends

            // Fill in this node
            node.setAttributeName(attributeWithMAxGain);
            node.setfClass(-1);
            node.setGain(maxGainValue);

            // Recurse downwards: one child per unique value of the chosen attribute
            ArrayList<Integer> atrUniqueValuesForAttrMaxGain =
                    mapAttributesValuesInListUnique.get(attributeWithMAxGain);
            for (int tempAtrUniqueValue : atrUniqueValuesForAttrMaxGain) {
                TreeNode NodeChild = new TreeNode();
                NodeChild.setAtrvalue(tempAtrUniqueValue); // since it is a child node
                node.getBranches().add(NodeChild);
                MatrixData matrixChild = matrix.splitMatrix(attributeWithMAxGain,
                        tempAtrUniqueValue);
                // build the child's set of training vectors
                HashMap<String, int[]> setTrainingVectorChild = new HashMap<String, int[]>();
                for (int i = 0; i < matrixChild.coloumns - 1; i++) {
                    int[] trainingVectorChild = new int[matrixChild.Numrows];
                    matrixChild.fillArray(trainingVectorChild, i);
                    setTrainingVectorChild.put(matrixChild.Headers.get(i),
                            trainingVectorChild);
                }
                // and the child's final-class vector
                int[] FinalClassChild = new int[matrixChild.Numrows];
                matrixChild.fillArray(FinalClassChild, matrixChild.coloumns - 1);
                learnTree(setTrainingVectorChild, FinalClassChild, NodeChild,
                        matrixChild);
            }
            return;
        }
    }

    // Function: checkFinalClass
    // Returns true if every entry in FinalClass equals valueToChecked
    public boolean checkFinalClass(int[] FinalClass, int valueToChecked) {
        for (int i = 0; i < FinalClass.length; i++) {
            if (FinalClass[i] != valueToChecked)
                return false;
        }
        return true;
    }

    // Function: getCountPositives
    // Returns the count of positives (class 0) in the final-class vector
    public int getCountPositives(int[] FinalClass) {
        int countPos = 0;
        for (int i = 0; i < FinalClass.length; i++) {
            if (FinalClass[i] == 0)
                countPos++;
        }
        return countPos;
    }

    // Compute the information entropy of a class vector
    public double getEntropy(int[] vector) {
        double entropy = 0.0;
        int positives = 0;
        int negatives = 0;
        for (int i = 0; i < vector.length; i++) {
            if (vector[i] == 0) { // class 0 is treated as positive
                positives++;
            } else {
                negatives++;
            }
        }
        double val1 = (double) (positives) / (positives + negatives);
        double val2 = (double) (negatives) / (positives + negatives);
        entropy = -(val1 * log2(val1)) - (val2 * log2(val2));
        return entropy;
    }

    // Function: log2
    // Returns log base 2 (defined as 0 for non-positive input)
    public static double log2(double num) {
        if (num <= 0)
            return 0.0;
        return (Math.log(num) / Math.log(2));
    }

    // Function: addOnlyUnique
    // Adds a value to the list only if it is not already present
    public void addOnlyUnique(ArrayList<Integer> data, int val) {
        if (!data.contains(val))
            data.add(val);
    }
}
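The `TreeNode` class itself is not reproduced in this walkthrough. From the members the learner touches (`setAtrvalue`, `setAttributeName`, `setfClass`, `setGain`, `getBranches`, and the public `fClass` field), a minimal compatible sketch would be the following; the real class in the repository may differ in details:

```java
import java.util.ArrayList;

// Minimal reconstruction of the TreeNode used by ID3Learner, inferred
// from its call sites; field names follow the accessors used above.
public class TreeNode {
    public int fClass = -1;       // leaf label (0/1), or -1 for internal nodes
    private int atrvalue;         // attribute value on the incoming branch
    private String attributeName; // attribute this node splits on
    private double gain;          // information gain of the chosen split
    private ArrayList<TreeNode> branches = new ArrayList<TreeNode>();

    public void setAtrvalue(int v) { atrvalue = v; }
    public int getAtrvalue() { return atrvalue; }
    public void setAttributeName(String n) { attributeName = n; }
    public String getAttributeName() { return attributeName; }
    public void setfClass(int c) { fClass = c; }
    public void setGain(double g) { gain = g; }
    public double getGain() { return gain; }
    public ArrayList<TreeNode> getBranches() { return branches; }
}
```

Internal nodes carry the splitting attribute and its gain, while leaves carry only a class label; the child list gives the tree its recursive shape.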

The code above is fully commented. The reader should pay particular attention to how recursion is used to select the splitting attribute at each level of the tree.
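To see the entropy and gain arithmetic in isolation, here is a small self-contained example, independent of the classes above, that computes the entropy of a binary label vector and the information gain of one attribute using the same formulas as getEntropy and learnTree (note that, as in the original code, class 0 is counted as "positive"):

```java
public class GainDemo {
    // log base 2, defined as 0 for non-positive input, as in ID3Learner.log2
    static double log2(double x) { return x <= 0 ? 0.0 : Math.log(x) / Math.log(2); }

    // Entropy of a 0/1 label vector, matching getEntropy above
    static double entropy(int[] labels) {
        int pos = 0;
        for (int v : labels) if (v == 0) pos++; // class 0 is "positive"
        double p = (double) pos / labels.length;
        double q = 1.0 - p;
        return -(p * log2(p)) - (q * log2(q));
    }

    public static void main(String[] args) {
        int[] labels = {0, 0, 1, 1};         // two of each class
        System.out.println(entropy(labels)); // 1.0: maximum uncertainty

        // Information gain of an attribute that splits the labels perfectly
        int[] attr = {5, 5, 7, 7};
        double gain = entropy(labels);
        for (int value : new int[]{5, 7}) {
            int n = 0, pos = 0;
            for (int i = 0; i < labels.length; i++) {
                if (attr[i] == value) { n++; if (labels[i] == 0) pos++; }
            }
            double p = (double) pos / n, q = 1.0 - p;
            double subEntropy = -(p * log2(p)) - (q * log2(q));
            // subtract the weighted subset entropy, as in learnTree
            gain -= ((double) n / labels.length) * subEntropy;
        }
        System.out.println(gain); // 1.0: a perfect split removes all uncertainty
    }
}
```

An attribute whose values line up exactly with the class labels achieves the maximum possible gain (the full entropy of the sample), which is why learnTree's scan for the maximum gain picks it first.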
