线性二分类

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report
from sklearn import tree#载入数据
data = np.genfromtxt('.csv',delimiter=',')
x_data = data[:,:-1]
y_data = data[:,-1]plt.scatter(x_data[:,0],x_data[:,-1],c=y_data)
plt.show()#创建决策树模型
model = tree.DecisionTreeClassifier()
model.fit(x_data,y_data)#导出决策树
import graphvizdot_data = tree.export_graphviz(model,out_file=None,#特征的名字,要设置feature_names = ['x','y'],class_names=['label0','label1'],filled=True,rounded=True,special_characters=True)
graph = graphviz.Source(dot_data)
graph.render('cart_1')#获取数据值所在范围
x_min,x_max = x_data[:,0].min() - 1,x_data[:,0].max() + 1
y_min,y_max = x_data[:,1].min() - 1,x_data[:,1].max() + 1#生成网格矩阵
xx,yy = np.meshgrid(np.arange(x_min,x_max,0.02),np.arange(y_min,y_max,0.02))
z = model.predict(np.c_[xx.ravel(),yy.ravel()])
#扁平化,得到一个一个的点
#ravel和flatten类似,多维数据转一维,flatten不会改变原始数据,而ravel会
z = z.reshape(xx.shape)
#等高线图
#在这里,只有两个高度,0和1
cs = plt.contourf(xx,yy,z)
#样本散点图
plt.scatter(x_data[:,0],x_data[:,1],c=y_data)
plt.show()predictions = model.predict(x_data)
#查看分类的结果和正确率等等。
print(classification_report(predictions,y_data))

等高线图:

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report
from sklearn import tree
from sklearn.model_selection import train_test_split#载入数据
data = np.genfromtxt('test2.csv',delimiter=',')
x_data = data[:,:-1]
y_data = data[:,-1]plt.scatter(x_data[:,0],x_data[:,-1],c=y_data)
plt.show()#分割数据
x_train,x_test,y_train,y_test = train_test_split(x_data,y_data)#创建决策树模型
model = tree.DecisionTreeClassifier()
model.fit(x_data,y_data)#导出决策树
import graphvizdot_data = tree.export_graphviz(model,out_file=None,#特征的名字,要设置feature_names = ['x','y'],class_names=['label0','label1'],filled=True,rounded=True,special_characters=True)
graph = graphviz.Source(dot_data)
graph.render('cart_1')#获取数据值所在范围
x_min,x_max = x_data[:,0].min() - 1,x_data[:,0].max() + 1
y_min,y_max = x_data[:,1].min() - 1,x_data[:,1].max() + 1#生成网格矩阵
xx,yy = np.meshgrid(np.arange(x_min,x_max,0.02),np.arange(y_min,y_max,0.02))
z = model.predict(np.c_[xx.ravel(),yy.ravel()])
#扁平化,得到一个一个的点
#ravel和flatten类似,多维数据转一维,flatten不会改变原始数据,而ravel会
z = z.reshape(xx.shape)
#等高线图
#在这里,只有两个高度,0和1
cs = plt.contourf(xx,yy,z)
#样本散点图
plt.scatter(x_data[:,0],x_data[:,1],c=y_data)
plt.show()predictions = model.predict(x_train)
#查看分类的结果和正确率等等。
print(classification_report(predictions,y_train))predictions = model.predict(x_test)
print(classification_report(predictions,y_test))

画出的决策树非常的复杂,过拟合,最后所得结果正确率并不高。
所以我们对它进行剪枝。

剪枝代码:

#创建决策树模型
#max_depth,树的深度
#min_samples_split内部节点再划分所需的最小样本数,比如到了这个节点后发现样本数只有四,则不再往下分割了
model = tree.DecisionTreeClassifier(max_depth=4,min_samples_split=4)
model.fit(x_data,y_data)

决策树-线性二分类+非线性二分类相关推荐

  1. 决策树算法模型的归类与整理(ID3&C4.5&CART&线性二分类&非线性二分类)

    决策树算法模型的归类与整理(ID3&C4.5&CART&线性二分类&非线性二分类) 一. 总结摘要 决策树模型在监督学习中非常常见,可用于分类(二分类.多分类)和回归. ...

  2. ML:基于自定义数据集利用Logistic、梯度下降算法GD、LoR逻辑回归、Perceptron感知器、SVM支持向量机、LDA线性判别分析算法进行二分类预测(决策边界可视化)

    ML:基于自定义数据集利用Logistic.梯度下降算法GD.LoR逻辑回归.Perceptron感知器.支持向量机(SVM_Linear.SVM_Rbf).LDA线性判别分析算法进行二分类预测(决策 ...

  3. 数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC...

    全文链接:http://tecdat.cn/?p=27384 在本文中,数据包含有关葡萄牙"Vinho Verde"葡萄酒的信息(点击文末"阅读原文"获取完整代 ...

  4. 机器学习之支持向量机SVM之python实现ROC曲线绘制(二分类和多分类)

    目录 一.ROC曲线 二.TP.FP.TN.FN 三. python绘制ROC曲线(二分类) 1.思路 2.关键代码 3.完整代码 四. python绘制ROC曲线(多分类) 五.参考文献 一.ROC ...

  5. 机器学习之深度学习 二分类、多分类、多标签分类、多任务分类

    多任务学习可以运用到许多的场景. 首先,多任务学习可以学到多个任务的共享表示,这个共享表示具有较强的抽象能力,能够适应多个不同但相关的目标,通常可以使主任务获取更好的泛化能力. 此外,由于使用了共享表 ...

  6. R语言 | 二分类和多分类的逻辑回归实现

    目录 二分类逻辑回归 数据准备 模型构建 模型检验 多分类逻辑回归 二分类逻辑回归 首先,我先展示下我逻辑回归的总体代码,如果有基础的同志需要的话,可以直接修改数据和参数拿去用呀: library(l ...

  7. 六、(机器学习)-Adaboost提升树-二分类和多分类(最清晰最易懂)

    Adaboost提升树 一.bagging与boosting bagging即套袋法,通过对训练样本重新采样的方法得到不同的训练样本集,在这些新的训练样本集上分别训练学习器,最终合并每一个学习器的结果 ...

  8. 分类家族:二分类、多分类、多标签分类、多输出分类

    分类家族:二分类.多分类.多标签分类.多输出分类 目录 分类家族:二分类.多分类.多标签分类.多输出分类 二分类

  9. 机器学习 二分类分类阈值_分类指标和阈值介绍

    机器学习 二分类分类阈值_分类指标和阈值介绍_weixin_26752765的博客-CSDN博客 机器学习 二分类分类阈值_分类指标和阈值介绍_weixin_26752765的博客-CSDN博客

最新文章

  1. 迷你MVVM框架 avalonjs 学习教程14、事件绑定
  2. 迈克尔·乔丹,无可复制的篮球之神!
  3. 10行代码AC——1016 部分A+B (15分)
  4. opc服务器状态红叉,西门子S7-300与上位机通过OPC服务器的通讯设置分解.pdf
  5. python装饰器记录每一个函数的执行时间
  6. STM32F0xx_ADC采集电压配置详细过程
  7. vue 使用了浏览器的刷新之后报错_Electron-vue运行之后出现了文件浏览器
  8. 基于 HTML5 WebGL 的 3D 场景中的灯光效果
  9. 【体系结构】一条SQL语句经历了什么
  10. 6款主流PDF编辑器测试,快来看看哪一款最适合你吧
  11. 基于java的教师教学评价管理系统
  12. 约束的操作 - 增加 删除 禁止 启用
  13. 小程序跳转小程序,小程序跳转公众号,小程序跳转h5
  14. css中的vw/vh与%
  15. 电脑录屏怎么录视频?了解几个小技巧
  16. bilibili缓存文件在哪里_用这3招,彻底清除Windows10更新缓存,电脑高手必会
  17. 网站后台服务器进不去,网站进不去后台有什么原因啊?急
  18. jdbcTemplate打印sql
  19. NOI题库 scratch题解(部分)
  20. Unity3D如何开发最简单的VR游戏 vrPlus(神之眼)

热门文章

  1. watir-webdriver使用过程中异常
  2. 【递归】n个数的全排列
  3. 通过命令行编译器来编译运行程序
  4. T05 FX 试打报告
  5. Flash Video带宽估测
  6. Eeic Meyer on CSS 之 背景半透明效果
  7. ComponentArt控件分析之ComboBox(2)
  8. [转载] $CF290F$ 题解
  9. Java学习个人备忘录之接口
  10. 获取电脑系统当前时间