构建一个决策树分类模型,实现对鸢尾花的分类

1.lris数据集介绍:

鸢尾花数据集是机器学习领域中非常经典的一个分类数据集。数据集全名为:Iris Data Set,总共包含150行数据。

每一行由4个特征值及一个目标值(类别变量)组成。

其中4个特征值分别是:萼片长度、萼片宽度、花瓣长度、花瓣宽度

目标值为3种不同类别的鸢尾花:山鸢尾、变色鸢尾、维吉尼亚鸢尾

2.读取数据

Iris数据集里是一个矩阵,每一列代表了萼片或花瓣的长宽,一共4列,每一列代表某个被测量的鸢尾植物,一共采样了150条记录。

from sklearn.datasets import load_iris  # 导入方法类iris = load_iris()  #导入数据集iris
iris_feature = iris.data    #特征数据
iris_target = iris.target   #分类数据
print (iris.data)          #输出数据集
print (iris.target)        #输出真实标签
print (len(iris.target) )
print (iris.data.shape )   #150个样本 每个样本4个特征#输出结果如下:
[[5.1 3.5 1.4 0.2][4.9 3.  1.4 0.2][4.7 3.2 1.3 0.2][4.6 3.1 1.5 0.2][5.  3.6 1.4 0.2][5.4 3.9 1.7 0.4][4.6 3.4 1.4 0.3][5.  3.4 1.5 0.2][4.4 2.9 1.4 0.2][4.9 3.1 1.5 0.1][5.4 3.7 1.5 0.2][4.8 3.4 1.6 0.2][4.8 3.  1.4 0.1][4.3 3.  1.1 0.1][5.8 4.  1.2 0.2][5.7 4.4 1.5 0.4][5.4 3.9 1.3 0.4][5.1 3.5 1.4 0.3][5.7 3.8 1.7 0.3][5.1 3.8 1.5 0.3][5.4 3.4 1.7 0.2][5.1 3.7 1.5 0.4][4.6 3.6 1.  0.2][5.1 3.3 1.7 0.5][4.8 3.4 1.9 0.2][5.  3.  1.6 0.2][5.  3.4 1.6 0.4][5.2 3.5 1.5 0.2][5.2 3.4 1.4 0.2][4.7 3.2 1.6 0.2][4.8 3.1 1.6 0.2][5.4 3.4 1.5 0.4][5.2 4.1 1.5 0.1][5.5 4.2 1.4 0.2][4.9 3.1 1.5 0.2][5.  3.2 1.2 0.2][5.5 3.5 1.3 0.2][4.9 3.6 1.4 0.1][4.4 3.  1.3 0.2][5.1 3.4 1.5 0.2][5.  3.5 1.3 0.3][4.5 2.3 1.3 0.3][4.4 3.2 1.3 0.2][5.  3.5 1.6 0.6][5.1 3.8 1.9 0.4][4.8 3.  1.4 0.3][5.1 3.8 1.6 0.2][4.6 3.2 1.4 0.2][5.3 3.7 1.5 0.2][5.  3.3 1.4 0.2][7.  3.2 4.7 1.4][6.4 3.2 4.5 1.5][6.9 3.1 4.9 1.5][5.5 2.3 4.  1.3][6.5 2.8 4.6 1.5][5.7 2.8 4.5 1.3][6.3 3.3 4.7 1.6][4.9 2.4 3.3 1. ][6.6 2.9 4.6 1.3][5.2 2.7 3.9 1.4][5.  2.  3.5 1. ][5.9 3.  4.2 1.5][6.  2.2 4.  1. ][6.1 2.9 4.7 1.4][5.6 2.9 3.6 1.3][6.7 3.1 4.4 1.4][5.6 3.  4.5 1.5][5.8 2.7 4.1 1. ][6.2 2.2 4.5 1.5][5.6 2.5 3.9 1.1][5.9 3.2 4.8 1.8][6.1 2.8 4.  1.3][6.3 2.5 4.9 1.5][6.1 2.8 4.7 1.2][6.4 2.9 4.3 1.3][6.6 3.  4.4 1.4][6.8 2.8 4.8 1.4][6.7 3.  5.  1.7][6.  2.9 4.5 1.5][5.7 2.6 3.5 1. ][5.5 2.4 3.8 1.1][5.5 2.4 3.7 1. ][5.8 2.7 3.9 1.2][6.  2.7 5.1 1.6][5.4 3.  4.5 1.5][6.  3.4 4.5 1.6][6.7 3.1 4.7 1.5][6.3 2.3 4.4 1.3][5.6 3.  4.1 1.3][5.5 2.5 4.  1.3][5.5 2.6 4.4 1.2][6.1 3.  4.6 1.4][5.8 2.6 4.  1.2][5.  2.3 3.3 1. ][5.6 2.7 4.2 1.3][5.7 3.  4.2 1.2][5.7 2.9 4.2 1.3][6.2 2.9 4.3 1.3][5.1 2.5 3.  1.1][5.7 2.8 4.1 1.3][6.3 3.3 6.  2.5][5.8 2.7 5.1 1.9][7.1 3.  5.9 2.1][6.3 2.9 5.6 1.8][6.5 3.  5.8 2.2][7.6 3.  6.6 2.1][4.9 2.5 4.5 1.7][7.3 2.9 6.3 1.8][6.7 2.5 5.8 1.8][7.2 3.6 6.1 2.5][6.5 3.2 5.1 2. ][6.4 2.7 5.3 1.9][6.8 3.  5.5 2.1][5.7 2.5 5.  2. ][5.8 2.8 5.1 2.4][6.4 3.2 5.3 2.3][6.5 3.  5.5 1.8][7.7 3.8 6.7 2.2][7.7 2.6 6.9 2.3][6.  2.2 5.  1.5][6.9 3.2 5.7 2.3][5.6 2.8 4.9 2. ][7.7 2.8 6.7 2. ][6.3 2.7 4.9 1.8][6.7 3.3 5.7 2.1][7.2 3.2 6.  1.8][6.2 2.8 4.8 1.8][6.1 3.  4.9 1.8][6.4 2.8 5.6 2.1][7.2 3.  5.8 1.6][7.4 2.8 6.1 1.9][7.9 3.8 6.4 2. ][6.4 2.8 5.6 2.2][6.3 2.8 5.1 1.5][6.1 2.6 5.6 1.4][7.7 3.  6.1 2.3][6.3 3.4 5.6 2.4][6.4 3.1 5.5 1.8][6.  3.  4.8 1.8][6.9 3.1 5.4 2.1][6.7 3.1 5.6 2.4][6.9 3.1 5.1 2.3][5.8 2.7 5.1 1.9][6.8 3.2 5.9 2.3][6.7 3.3 5.7 2.5][6.7 3.  5.2 2.3][6.3 2.5 5.  1.9][6.5 3.  5.2 2. ][6.2 3.4 5.4 2.3][5.9 3.  5.1 1.8]]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 00 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 11 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 22 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 22 2]
150
(150, 4)
 

data是150*4的矩阵,对应着150条鸢尾花数据(每条4个数据:包括萼片和花瓣的长宽)

target是一个数组,存储了data中每条数据属于哪类鸢尾植物,所以数组长度是150

因为共有3类鸢尾花,所以0,1,2分别代表了山鸢尾花、杂色鸢尾花、维吉尼亚鸢尾花

3.数据可视化

调用pandas扩展包进行绘图。

首先绘制直方图,展现了花瓣、花萼的长和宽的特征数量,纵坐标表示汇总的数量,横坐标表示对应的长度

通过调用hist()函数实现

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris  # 导入方法类iris = load_iris()  #导入数据集iris
iris_feature = iris.data    #特征数据
iris_target = iris.target   #分类数据
#print (iris.data)          #输出数据集
#print (iris.target)        #输出真实标签
#print (len(iris.target) )
#print (iris.data.shape )   #150个样本 每个样本4个特征import pandas
#导入数据集iris
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
names = ['sepal-length', 'sepal-width', 'petal-length', 'petal-width', 'class']
dataset = pandas.read_csv(url, names=names) #读取csv数据
print(dataset.describe())
#直方图 histograms
dataset.hist()
plt.show()#dataset.describe()输出如下:sepal-length  sepal-width  petal-length  petal-width
count    150.000000   150.000000    150.000000   150.000000
mean       5.843333     3.054000      3.758667     1.198667
std        0.828066     0.433594      1.764420     0.763161
min        4.300000     2.000000      1.000000     0.100000
25%        5.100000     2.800000      1.600000     0.300000
50%        5.800000     3.000000      4.350000     1.300000
75%        6.400000     3.300000      5.100000     1.800000
max        7.900000     4.400000      6.900000     2.500000

4.训练和分类

首先对从sklearn中导入决策树分类器,对数据集进行训练和分类

from sklearn import tree
from sklearn.tree import DecisionTreeClassifier      #导入决策树DTC包
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris  # 导入方法类iris = load_iris()  #导入数据集iris
iris_feature = iris.data    #特征数据
iris_target = iris.target   #分类数据clf = DecisionTreeClassifier()      # 所以参数均置为默认状态
clf.fit(iris.data, iris.target)     # 使用训练集训练模型
print(clf)
predicted = clf.predict(iris.data)    #使用模型对测试集进行预测
print(predicted)
print("精度是:{:.3f}".format(clf.score(iris.data, iris.target)))#输出如下:
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,max_features=None, max_leaf_nodes=None,min_impurity_decrease=0.0, min_impurity_split=None,min_samples_leaf=1, min_samples_split=2,min_weight_fraction_leaf=0.0, presort=False,random_state=None, splitter='best')
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 00 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 11 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 22 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 22 2]
精度是:1.000

因为叶结点都是纯的,输的深的很大,足以完美的记住训练数据的所有标签。

之前有线性模型也做个鸢尾花分类,线性模型的精度约为95%线性模型——鸢尾花分类

如果我们不限制决策树的深度,他的深度和复杂度都会变得很大。

银次未剪枝的树容易过度拟合,对新数据的泛化能力不佳。

我们将预剪枝应用到决策树上,这可以在完美拟合训练数据之前阻止树的展开。

一种选择是,在树到达一定深度后停止树的展开。代码如下:

clf = DecisionTreeClassifier(max_depth=3,random_state=0)#输出精度:

这意味着只能连续问4个问题。限制树的深度可以减少过拟合。

这会降低训练集精度,但是可以提高测试集的精度

(也就是训练出来的模型精度低了,但是预测的时候精度高了,这肯定是好的么)

5.可视化决策树

当我们不限制树的深度时:

# 引入数据集
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier      #导入决策树DTC包
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris  # 导入方法类iris = load_iris()  #导入数据集iris
iris_feature = iris.data    #特征数据
iris_target = iris.target   #分类数据clf = DecisionTreeClassifier()      # 所以参数均置为默认状态
clf.fit(iris.data, iris.target)     # 使用训练集训练模型
#print(clf)
predicted = clf.predict(iris.data)
#print(predicted)
print("精度是:{:.3f}".format(clf.score(iris.data, iris.target)))
# viz code 可视化 制作一个简单易读的PDF
from sklearn.externals.six import StringIO
import pydot
#需要安装pydot包,用Anaconda Prompt安装,需要先安装graphviz再安装pydot,命令如下:
# conda install graphviz
# conda install pydot
dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data,feature_names=iris.feature_names,class_names=iris.target_names,filled=True, rounded=True,special_characters=True)
graph = pydot.graph_from_dot_data(dot_data.getvalue())
# print(len(graph))  # 1
# print(graph)  # [<pydot.Dot object at 0x000001F7BD1A9630>]
# print(graph[0])  # <pydot.Dot object at 0x000001F7BD1A9630>
# graph.write_pdf("iris.pdf")
graph[0].write_pdf("iris.pdf")#输出如下:
精度是:1.000

我们可以利用export_graphviz()函数将树可视化,并输出成pdf,如下图:

当我们限制树的深度为3时:精度是:0.973

clf = DecisionTreeClassifier(max_depth=3,random_state=0)

6.数据集多类分类

决策树实现类是DecisionTreeClassifier,能够执行数据集的多类分类。

输入参数为两个数组x[n_samples,n_features]和X[n_samples],

x为训练数据,X为训练数据的标记数据

把分类好的数据集绘制散点图,使用Matplotlib模块

# 引入数据集
from sklearn import tree
from sklearn.tree import DecisionTreeClassifier      #导入决策树DTC包
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris  # 导入方法类iris = load_iris()  #导入数据集iris
iris_feature = iris.data    #特征数据
iris_target = iris.target   #分类数据clf = DecisionTreeClassifier()      # 所以参数均置为默认状态
clf.fit(iris.data, iris.target)     # 使用训练集训练模型
#print(clf)
predicted = clf.predict(iris.data)
#print(predicted)# 获取花卉两列数据集
X = iris.data
L1 = [x[0] for x in X]
#print(L1)
L2 = [x[1] for x in X]
#print (L2)#绘图
plt.scatter(X[:50, 0], X[:50, 1], color='red', marker='o', label='setosa')
plt.scatter(X[50:100, 0], X[50:100, 1], color='blue', marker='x', label='versicolor')
plt.scatter(X[100:, 0], X[100:, 1], color='green', marker='s', label='Virginica')
plt.title("DTC")
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.xticks(())
plt.yticks(())
plt.legend(loc=2)
plt.show()#输出如下:
[5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9, 5.4, 4.8, 4.8, 4.3, 5.8, 5.7, 5.4, 5.1, 5.7, 5.1, 5.4, 5.1, 4.6, 5.1, 4.8, 5.0, 5.0, 5.2, 5.2, 4.7, 4.8, 5.4, 5.2, 5.5, 4.9, 5.0, 5.5, 4.9, 4.4, 5.1, 5.0, 4.5, 4.4, 5.0, 5.1, 4.8, 5.1, 4.6, 5.3, 5.0, 7.0, 6.4, 6.9, 5.5, 6.5, 5.7, 6.3, 4.9, 6.6, 5.2, 5.0, 5.9, 6.0, 6.1, 5.6, 6.7, 5.6, 5.8, 6.2, 5.6, 5.9, 6.1, 6.3, 6.1, 6.4, 6.6, 6.8, 6.7, 6.0, 5.7, 5.5, 5.5, 5.8, 6.0, 5.4, 6.0, 6.7, 6.3, 5.6, 5.5, 5.5, 6.1, 5.8, 5.0, 5.6, 5.7, 5.7, 6.2, 5.1, 5.7, 6.3, 5.8, 7.1, 6.3, 6.5, 7.6, 4.9, 7.3, 6.7, 7.2, 6.5, 6.4, 6.8, 5.7, 5.8, 6.4, 6.5, 7.7, 7.7, 6.0, 6.9, 5.6, 7.7, 6.3, 6.7, 7.2, 6.2, 6.1, 6.4, 7.2, 7.4, 7.9, 6.4, 6.3, 6.1, 7.7, 6.3, 6.4, 6.0, 6.9, 6.7, 6.9, 5.8, 6.8, 6.7, 6.7, 6.3, 6.5, 6.2, 5.9]
[3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1, 3.7, 3.4, 3.0, 3.0, 4.0, 4.4, 3.9, 3.5, 3.8, 3.8, 3.4, 3.7, 3.6, 3.3, 3.4, 3.0, 3.4, 3.5, 3.4, 3.2, 3.1, 3.4, 4.1, 4.2, 3.1, 3.2, 3.5, 3.6, 3.0, 3.4, 3.5, 2.3, 3.2, 3.5, 3.8, 3.0, 3.8, 3.2, 3.7, 3.3, 3.2, 3.2, 3.1, 2.3, 2.8, 2.8, 3.3, 2.4, 2.9, 2.7, 2.0, 3.0, 2.2, 2.9, 2.9, 3.1, 3.0, 2.7, 2.2, 2.5, 3.2, 2.8, 2.5, 2.8, 2.9, 3.0, 2.8, 3.0, 2.9, 2.6, 2.4, 2.4, 2.7, 2.7, 3.0, 3.4, 3.1, 2.3, 3.0, 2.5, 2.6, 3.0, 2.6, 2.3, 2.7, 3.0, 2.9, 2.9, 2.5, 2.8, 3.3, 2.7, 3.0, 2.9, 3.0, 3.0, 2.5, 2.9, 2.5, 3.6, 3.2, 2.7, 3.0, 2.5, 2.8, 3.2, 3.0, 3.8, 2.6, 2.2, 3.2, 2.8, 2.8, 2.7, 3.3, 3.2, 2.8, 3.0, 2.8, 3.0, 2.8, 3.8, 2.8, 2.8, 2.6, 3.0, 3.4, 3.1, 3.0, 3.1, 3.1, 3.1, 2.7, 3.2, 3.3, 3.0, 2.5, 3.0, 3.4, 3.0]

不同颜色的点代表不同的种类。

决策树模型——鸢尾花分类相关推荐

  1. 《scikit-learn》决策树之鸢尾花分类

    有了上一博客的基础,我们来看看怎么操作鸢尾花的分裂问题.也是做一个简单的总结和回顾. 直接整代码了. from sklearn import tree from sklearn.datasets im ...

  2. 决策树实现鸢尾花分类

    介绍 在这篇博客中,我们使用以下几个库来实现决策树算法 scikit-learn机器学习库 scikit-learn最先是由David Cournapeau在2007年发起的一个Google Summ ...

  3. python决策树分类鸢尾花_基于决策树—鸢尾花分类

    决策树算法广泛应用于:语音识别.医疗诊断.客户关系管理.模式识别.专家系统等,在实际工作中,必须根据数据类型的特点及数据集的大小,选择合适的算法. 本文选择经典案例--<鸢尾花分类> 一. ...

  4. 机器学习(五)常用分类模型(K最近邻、朴素贝叶斯、决策树)和分类评价指标

    机器学习(五)常用分类模型(K最近邻.朴素贝叶斯.决策树)和分类评价指标 文章目录 机器学习(五)常用分类模型(K最近邻.朴素贝叶斯.决策树)和分类评价指标 综述 常用分类模型 K最近邻模型 朴素贝叶 ...

  5. 利用sklearn库决策树模型对iris数据多分类并进行评估

    1.导入所需要的库 from sklearn.tree import DecisionTreeClassifier from sklearn.datasets import load_iris 2.加 ...

  6. 鸢尾花分类与直方图、散点图的绘制及可视化决策树

    一.IRIS鸢尾花 鸢尾花有三个亚属,分别是山鸢尾(Iris-setosa).变色鸢尾(Iris-versicolor)和维吉尼亚鸢尾(Iris-virginica) 数据集一共包含4个特征变量,1个 ...

  7. 基于sklearn的鸢尾花分类模型

    1.鸢尾花数据获取及查看 可以通过sklearn直接获取数据集: from sklearn.datasets import load_iris import matplotlib.pyplot as ...

  8. 【机器学习】决策树案例二:利用决策树进行鸢尾花数据集分类预测

    利用决策树进行鸢尾花数据集分类预测 2 利用决策树进行鸢尾花数据集分类预测 2.1 导入模块与加载数据 2.2 划分数据 2.3 模型创建与应用 2.4 模型可视化 手动反爬虫,禁止转载: 原博地址 ...

  9. 决策树实战项目-鸢尾花分类

    决策树实战项目-鸢尾花分类 一.实验介绍 1.1 实验内容 决策树是机器学习中一种简单而又经典的算法.本次实验将带领了解决策树的基本原理,并学习使用 scikit-learn 来构建一个决策树分类模型 ...

最新文章

  1. 如何手工展开函数栈来定位问题
  2. 分析 C# 2.0 新特性 -- 范型(Generics)
  3. Careercup - Microsoft面试题 - 5428361417457664
  4. C++中set和map的erase用法
  5. 服务器显示不明用户远程过,服务器显示不明用户远程过
  6. HTML学习笔记--HTML的语法【1】
  7. Python基础篇【第二篇】:运算符
  8. 5.Magento资源配置(Setup Resource)
  9. Python反编译apk,获取各类信息
  10. Java静态代理、动态代理
  11. 天才小毒妃 第917章 深藏不露大财主
  12. CANoe软件使用(一)——软件界面介绍
  13. SSM+酒店管理系统的设计和实现 毕业设计-附源码260839
  14. cesium 缩放_cesium 缩放中心点控制
  15. 机器学习 数据集划分 训练集 验证集 测试集
  16. mysql 触发器 模板_MySQL 触发器例子(两张表同步增加和删除)
  17. 磁盘分区MBR和GPT格式详解(Linux)
  18. 12.关于uniapp小程序设置页面背景色无效的问题及解决方案
  19. Item 5:Know what functions C++ silently writes and calls
  20. C语言中图形函数及其用法

热门文章

  1. 音视频技术开发周刊 | 263
  2. 计算机系统大作业 程序人生-Hello’s P2P
  3. 看完你就知道原因了,这3类人不适合做自媒体,看看是不是你自己
  4. 采购员的主要职责是什么?
  5. 基于Hyperlynx VX.2.5 的DDR3仿真之一:Verifying That the Software Recognizes Your Design Correctly
  6. python自动化运维脚本(仅供参考)
  7. K-Means聚类算法原理及其python和matlab实现
  8. echarts给柱状图某个柱子设置颜色
  9. 拼多多进军美国市场是为国内电商人铺路还是强走了最后的蛋糕?
  10. scala中sorted,sortby,sortwith的用法(转)