I have previously classified MNIST with SVM and KNN, and the results looked pretty good. Today I try a decision tree on the same task and see how it compares.

In this Decision Tree experiment, the MNIST gray values are reduced to 0/1, which keeps the splits and the entropy computation simple.

Using a small amount of test data, I first measured the accuracy when the gray values are bucketed into more categories than just 0/1. The results:

int(pixel)/50:  37%
int(pixel)/64:  45.9%
int(pixel)/96:  52.3%
int(pixel)/128: 62.48%
int(pixel)/152: 59.1%
int(pixel)/176: 57.6%
int(pixel)/192: 54.0%

As the numbers show, binarizing the gray data, i.e. the 0/1 treatment, works best.
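Worth noting: every divisor of 128 or above already binarizes the 0-255 range (twice the divisor exceeds 255), just with a different threshold, so the last four rows above are really comparing binarization thresholds, and 128 wins. A quick Python 2 sketch (integer division, matching line2Mat in the listing below) of the buckets each divisor produces:

# Buckets each divisor produces on 0..255 gray values (Python 2 int division).
for divisor in [50, 64, 96, 128, 152, 176, 192]:
    buckets = set(pixel / divisor for pixel in range(256))
    print divisor, sorted(buckets)
# 50 -> [0..5], 64 -> [0..3], 96 -> [0, 1, 2];
# 128, 152, 176 and 192 all -> [0, 1], thresholded at the divisor.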

With the 0/1 treatment, the final results are as follows:

Train with 10k, test with 60k: 77.79%
Train with 60k, test with 10k: 87.3%
Time cost: 3 hours.

The final accuracy is 87.3%. That is a sizable gap from the 95%+ achieved by SVM and KNN, and the decision tree also takes considerably longer to run.

Note that this Decision Tree implementation does not prune the tree at all, so there is still room for improvement.

The Python code is at the very bottom. In it:

calcShannonEntropy(dataSet): computes the entropy of the data set from the class distribution of its rows, using Shannon's formula;
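For intuition, a minimal worked example of the Shannon formula H(D) = -sum_k p_k * log2(p_k) that this function implements, on hypothetical toy labels:

from math import log

# Six samples: four of class '1', two of class '7'.
labels = ['1', '1', '1', '1', '7', '7']
probs = [labels.count(c) / float(len(labels)) for c in set(labels)]
entropy = -sum(p * log(p, 2) for p in probs)
print entropy    # about 0.918 bits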

splitDataSet(dataSet, axis, value): returns the matrix formed by all rows whose value in dimension axis equals value (with that column removed). For each value found in dimension axis, computing the entropy of its splitDataSet matrix, weighting it by how often that value occurs, and summing gives the overall entropy after splitting the tree on dimension axis.
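A small sketch of what splitDataSet returns, on hypothetical toy rows (class label in column 0), assuming the functions from the listing below are loaded:

# Three toy rows: [class, pixel_a, pixel_b].
rows = [['1', 0, 1],
        ['7', 0, 0],
        ['1', 1, 1]]
# Rows whose column-1 value is 0, with column 1 removed:
print splitDataSet(rows, 1, 0)    # [['1', 1], ['7', 0]]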

chooseBestFeatureToSplit(dataSet): using splitDataSet, compares each dimension's post-split entropy against the entropy of the original matrix and returns the dimension with the largest information gain, i.e. the largest drop in entropy. The tree is split on that feature.
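The quantity being maximized is the information gain Gain(D, i) = H(D) - sum_v (|D_v|/|D|) * H(D_v). Continuing the toy rows above, the gain of splitting on column 1 would be computed like this:

base = calcShannonEntropy(rows)          # H(D), about 0.918
newEntropy = 0.0
for value in set(r[1] for r in rows):    # candidate split on column 1
    sub = splitDataSet(rows, 1, value)
    newEntropy += len(sub) / float(len(rows)) * calcShannonEntropy(sub)
print base - newEntropy                  # information gain, about 0.25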

createDecisionTree(dataSet, features): builds the decision tree recursively. If a leaf node cannot be classified cleanly, majorityCnt(classList) is used to take the most frequent class as the label.
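The resulting tree is a nested dict: each internal node maps a feature label to a dict of {feature value: subtree}, and leaves are class-label strings. A hypothetical fragment (the pixel indices are made up):

myTree = {'#350': {0: {'#514': {0: '1', 1: '7'}},
                   1: '0'}}
# classify() walks this dict: read the test vector's '#350' pixel,
# descend into the matching branch, stop when a string leaf is reached.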

The code:

#Decision tree for MNIST dataset by arthur503.
#Data format: 'class   label1:pixel    label2:pixel ...'
#Warning: no overfitting fix (the tree is not pruned)!
#
#Test change pixel data into more categories than 0/1:
#int(pixel)/50: 37%
#int(pixel)/64: 45.9%
#int(pixel)/96: 52.3%
#int(pixel)/128: 62.48%
#int(pixel)/152: 59.1%
#int(pixel)/176: 57.6%
#int(pixel)/192: 54.0%
#
#Result:
#Train with 10k, test with 60k: 77.79%
#Train with 60k, test with 10k: 87.3%
#Time cost: 3 hours.

from numpy import *
import operator

def calcShannonEntropy(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featureVec in dataSet:
        currentLabel = featureVec[0]
        if currentLabel not in labelCounts:
            labelCounts[currentLabel] = 1
        else:
            labelCounts[currentLabel] += 1
    shannonEntropy = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEntropy -= prob * log2(prob)
    return shannonEntropy

#get all rows whose axis item equals value.
def splitDataSet(dataSet, axis, value):
    subDataSet = []
    for featureVec in dataSet:
        if featureVec[axis] == value:
            reducedFeatureVec = featureVec[:axis]
            reducedFeatureVec.extend(featureVec[axis+1:])    #if axis == -1, this will cause error!
            subDataSet.append(reducedFeatureVec)
    return subDataSet

def chooseBestFeatureToSplit(dataSet):
    #Notice: Actually, index 0 is not a feature (it is the class label).
    numFeatures = len(dataSet[0])
    baseEntropy = calcShannonEntropy(dataSet)
    bestInfoGain = 0.0
    bestFeature = numFeatures - 1    #DO NOT use -1! or splitDataSet(dataSet, -1, value) will cause error!
    #feature index starts with 1 (not 0)!
    for i in range(1, numFeatures):
        featureList = [example[i] for example in dataSet]
        featureSet = set(featureList)
        newEntropy = 0.0
        for value in featureSet:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEntropy(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

#classify on leaf of decision tree.
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount:
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

#Create Decision Tree.
def createDecisionTree(dataSet, features):
    print 'create decision tree... length of features is:'+str(len(features))
    classList = [example[0] for example in dataSet]
    #stop if all samples in this branch share one class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    #no features left: vote for the majority class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeatureIndex = chooseBestFeatureToSplit(dataSet)
    bestFeatureLabel = features[bestFeatureIndex]
    myTree = {bestFeatureLabel:{}}
    del(features[bestFeatureIndex])
    featureValues = [example[bestFeatureIndex] for example in dataSet]
    featureSet = set(featureValues)
    for value in featureSet:
        subFeatures = features[:]
        myTree[bestFeatureLabel][value] = createDecisionTree(splitDataSet(dataSet, bestFeatureIndex, value), subFeatures)
    return myTree

def line2Mat(line):
    mat = line.strip().split(' ')
    for i in range(len(mat)-1):
        pixel = mat[i+1].split(':')[1]
        #change MNIST pixel data into 0/1 format.
        mat[i+1] = int(pixel)/128
    return mat

#return matrix as a list (instead of a matrix).
#features are the 28*28 pixels in the MNIST dataset.
def file2Mat(fileName):
    f = open(fileName)
    lines = f.readlines()
    matrix = []
    for line in lines:
        mat = line2Mat(line)
        matrix.append(mat)
    f.close()
    print 'Read file '+str(fileName)+' to array done! Matrix shape:'+str(shape(matrix))
    return matrix

#Classify test file.
def classify(inputTree, featureLabels, testVec):
    firstStr = inputTree.keys()[0]
    secondDict = inputTree[firstStr]
    featureIndex = featureLabels.index(firstStr)
    predictClass = '-1'
    for key in secondDict.keys():
        if testVec[featureIndex] == key:
            if type(secondDict[key]) == type({}):
                predictClass = classify(secondDict[key], featureLabels, testVec)
            else:
                predictClass = secondDict[key]
    return predictClass

def classifyTestFile(inputTree, featureLabels, testDataSet):
    rightCnt = 0
    for i in range(len(testDataSet)):
        classLabel = testDataSet[i][0]
        predictClassLabel = classify(inputTree, featureLabels, testDataSet[i])
        if classLabel == predictClassLabel:
            rightCnt += 1
        if i % 200 == 0:
            print 'num '+str(i)+'. ratio: '+str(float(rightCnt)/(i+1))
    return float(rightCnt)/len(testDataSet)

def getFeatureLabels(length):
    strs = []
    for i in range(length):
        strs.append('#'+str(i))
    return strs

#Normal file
trainFile = 'train_60k.txt'
testFile = 'test_10k.txt'
#Scaled file
#trainFile = 'train_60k_scale.txt'
#testFile = 'test_10k_scale.txt'
#Test file
#trainFile = 'test_only_1.txt'
#testFile = 'test_only_2.txt'

#train decision tree.
dataSet = file2Mat(trainFile)
#Actually, item 0 is the class, not a feature label.
featureLabels = getFeatureLabels(len(dataSet[0]))
print 'begin to create decision tree...'
myTree = createDecisionTree(dataSet, featureLabels)
print 'create decision tree done.'

#predict with decision tree.
testDataSet = file2Mat(testFile)
featureLabels = getFeatureLabels(len(testDataSet[0]))
rightRatio = classifyTestFile(myTree, featureLabels, testDataSet)
print 'total right ratio: '+str(rightRatio)

