本文摘要 · 理论来源:【统计学习方法】第二章 感知机
· 技术支持:pandas(读csv)、matplotlib(画图)、numpy、sklearn.linear_model.Perceptron(感知机模型)、随机梯度下降思想
· 代码目的:利用手写、sklearn两种感知机模型,对鸢尾花数据集进行二分类
作者:CSDN 征途黯然.

  

一、鸢尾花(iris)数据集

  Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例。数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种。
下载地址:点击此处

二、代码描述

  1、首先,画出数据集中150个数据的前两个特征的散点分布图,我们观察到品种‘Iris-setosa’与‘Iris-versicolor’之间是线性可分的:


  2、然后,我们对以上两个品种共100条数据、2个维度进行二分类,利用我们自己定义的感知机模型,效果如下图:

  3、利用sklearn库提供的感知机模型,效果如下图,有一个点没有正确分类:

三、python代码(注释详细)

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import Perceptron"""自定义感知机模型"""
# 数据线性可分,二分类数据
# 此处为一元一次线性方程
class Model:def __init__(self):# 创建指定形状的数组,数组元素以 1 来填充self.w = np.ones(len(data[0]) - 1, dtype=np.float32)self.b = 0  # 初始w/b的值self.l_rate = 0.1# self.data = datadef sign(self, x, w, b):y = np.dot(x, w) + b  # 求w,b的值# Numpy中dot()函数主要功能有两个:向量点积和矩阵乘法。# 格式:x.dot(y) 等价于 np.dot(x,y) ———x是m*n 矩阵 ,y是n*m矩阵,则x.dot(y) 得到m*m矩阵return y# 随机梯度下降法# 随机梯度下降法(SGD),随机抽取一个误分类点使其梯度下降。根据损失函数的梯度,对w,b进行更新def fit(self, X_train, y_train):  # 将参数拟合 X_train数据集矩阵 y_train特征向量is_wrong = False# 误分类点的意思就是开始的时候,超平面并没有正确划分,做了错误分类的数据。while not is_wrong:wrong_count = 0  # 误分为0,就不用循环,得到w,bfor d in range(len(X_train)):X = X_train[d]y = y_train[d]if y * self.sign(X, self.w, self.b) <= 0:# 如果某个样本出现分类错误,即位于分离超平面的错误侧,则调整参数,使分离超平面开始移动,直至误分类点被正确分类。self.w = self.w + self.l_rate * np.dot(y, X)  # 调整w和bself.b = self.b + self.l_rate * ywrong_count += 1if wrong_count == 0:is_wrong = Truereturn 'Perceptron Model!'# 得分def score(self):pass# 导入数据集
df = pd.read_csv('./iris/Iris.csv', usecols=[1, 2, 3, 4, 5])# pandas打印表格信息
# print(df.info())# pandas查看数据集的头5条记录
# print(df.head())"""绘制训练集基本散点图,便于人工分析,观察数据集的线性可分性"""
# 表示绘制图形的画板尺寸为8*5
plt.figure(figsize=(8, 5))
# 散点图的x坐标、y坐标、标签
plt.scatter(df[:50]['SepalLengthCm'], df[:50]['SepalWidthCm'], label='Iris-setosa')
plt.scatter(df[50:100]['SepalLengthCm'], df[50:100]['SepalWidthCm'], label='Iris-versicolor')
plt.scatter(df[100:150]['SepalLengthCm'], df[100:150]['SepalWidthCm'], label='Iris-virginica')
plt.xlabel('SepalLengthCm')
plt.ylabel('SepalWidthCm')
# 添加标题 '鸢尾花萼片的长度与宽度的散点分布'
plt.title('Scattered distribution of length and width of iris sepals.')
# 显示标签
plt.legend()
plt.show()# 取前100条数据中的:前2个特征+标签,便于训练
data = np.array(df.iloc[:100, [0, 1, -1]])
# 数据类型转换,为了后面的数学计算
X, y = data[:, :-1], data[:, -1]
y = np.array([1 if i == 'Iris-setosa' else -1 for i in y])"""自定义感知机模型,开始训练"""
perceptron = Model()
perceptron.fit(X, y)
# 最终参数
print(perceptron.w, perceptron.b)
# 绘图
x_points = np.linspace(4, 7, 10)
y_ = -(perceptron.w[0] * x_points + perceptron.b) / perceptron.w[1]
plt.plot(x_points, y_)
plt.scatter(df[:50]['SepalLengthCm'], df[:50]['SepalWidthCm'], label='Iris-setosa')
plt.scatter(df[50:100]['SepalLengthCm'], df[50:100]['SepalWidthCm'], label='Iris-versicolor')
plt.xlabel('SepalLengthCm')
plt.ylabel('SepalWidthCm')
# 添加标题 '自定义感知机模型训练结果'
plt.title('Training results of Custom perceptron model.')
plt.legend()
plt.show()"""sklearn感知机模型,开始训练"""
# 使用训练数据进行训练
clf = Perceptron()
# 得到训练结果,权重矩阵
clf.fit(X, y)
# Weights assigned to the features.输出特征权重矩阵
# print(clf.coef_)
# 超平面的截距 Constants in decision function.
# print(clf.intercept_)
# 对测试集预测
# print(clf.predict([[6.0, 4.0]]))
# 对训练集评分
# print(clf.score(X, y))# 绘图
x_points = np.linspace(4, 7, 10)
y_ = -(clf.coef_[0][0] * x_points + clf.intercept_[0]) / clf.coef_[0][1]
plt.plot(x_points, y_)
plt.scatter(df[:50]['SepalLengthCm'], df[:50]['SepalWidthCm'], label='Iris-setosa')
plt.scatter(df[50:100]['SepalLengthCm'], df[50:100]['SepalWidthCm'], label='Iris-versicolor')
plt.xlabel('SepalLengthCm')
plt.ylabel('SepalWidthCm')
# 添加标题 'sklearn感知机模型训练结果'
plt.title('Training results of sklearn perceptron model.')
plt.legend()
plt.show()

【统计学习方法】感知机对鸢尾花(iris)数据集进行二分类相关推荐

  1. 【统计学习方法】线性可分支持向量机对鸢尾花(iris)数据集进行二分类

    本文摘要 · 理论来源:[统计学习方法]第七章 SVM · 技术支持:pandas(读csv).numpy.sklearn.svm.svm思想.matplotlib.pyplot(绘图) · 代码目的 ...

  2. 【统计学习方法】K近邻对鸢尾花(iris)数据集进行多分类

    本文摘要 · 理论来源:[统计学习方法]第三章 K近邻 · 技术支持:pandas(读csv).collections.Counter(统计).numpy.sklearn.neighbors.KNei ...

  3. sklearn基础篇(三)-- 鸢尾花(iris)数据集分析和分类

    后面对Sklearn的学习主要以<Python机器学习基础教程>和<机器学习实战基于scikit-learn和tensorflow>,两本互为补充进行学习,下面是开篇的学习内容 ...

  4. 统计学习方法-感知机概括和补充

    前言 <统计学习方法>第二版出了有段时间了,最近得空可以拜读一下.之前看第一版的时候还是一年多以前,那个时候看的懵懵懂懂的,很吃力.希望这一次能够有所收获,能够收获新的东西,这些文章只是用 ...

  5. 【统计学习方法】朴素贝叶斯对鸢尾花(iris)数据集进行训练预测

    本文摘要 · 理论来源:[统计学习方法]第三四章 朴素贝叶斯 · 技术支持:pandas(读csv).numpy.sklearn.naive_bayes.GaussianNB(高斯朴素贝叶斯模型).s ...

  6. 统计学习方法|感知机原理剖析及实现

    欢迎直接到我的博客查看最近文章:www.pkudodo.com.更新会比较快,评论回复我也能比较快看见,排版也会更好一点. 原始blog链接: http://www.pkudodo.com/2018/ ...

  7. [机器学习-sklearn]鸢尾花Iris数据集

    鸢尾花数据集 1. 鸢尾花Iris数据集介绍 2. Sklearn代码获取Iris 2. 描述性统计 3. 箱线图 4. 数据分布情况 1. 鸢尾花Iris数据集介绍 Iris flower数据集是1 ...

  8. ML之SVM:基于SVM(sklearn+subplot)的鸢尾花iris数据集的前两个特征(线性不可分的两个样本),判定鸢尾花是哪一种类型

    ML之SVM:基于SVM(sklearn+subplot)的鸢尾花iris数据集的前两个特征(线性不可分的两个样本),判定鸢尾花是哪一种类型 目录 输出结果 实现代码 输出结果 (1).黄色的点为支持 ...

  9. MAT之ELM:ELM实现鸢尾花(iris数据集)种类测试集预测识别正确率(better)结果对比

    MAT之ELM:ELM实现鸢尾花(iris数据集)种类测试集预测识别正确率(better)结果对比 目录 输出结果 实现代码 输出结果 实现代码 load iris_data.mat P_train ...

最新文章

  1. Flume Source 实例
  2. java web开发之 spring单元测试
  3. 河北大学计算机专业调剂,【计算机考研调剂】河北大学2021级硕士研究生预调剂信息统计的通知...
  4. 2020中国DevOps社区峰会(成都站),雄关漫道,砥砺前行
  5. shell脚本中常见的几个判断
  6. RocketMQ是怎么存储消息的?
  7. Android 8.0学习(8)---内核文件系统优化
  8. Python之网络编程(TCP套接字与UDP套接字)
  9. jdbc至sql server的两种常见方法
  10. 用python开发的运维管理系统_Python运维三十六式:用Python写一个简单的监控系统...
  11. oracle连接ORA-01017、ORA-12640
  12. OFD在线预览方案评测
  13. 1155 服务器芯片组,2014年主流主板芯片组分析
  14. Migration中的Collation Confliction
  15. 打造自己的HelloDrone 无人机APP过程《3》
  16. java集成极光推送实现Android的消息推送
  17. keras教程【2】编写CNN
  18. 猜数字游戏(小游戏编码)
  19. 远程连接MySQL, 10038问题
  20. Proteus仿真电路笔记

热门文章

  1. Anaconda | CentOS7 -解决 Python2和Python3共存
  2. AI提高药物发现效率 | ML,Supercomputers and Big Data
  3. 零基础入门学习Python(16)-函数1,Python的乐高积木
  4. 零基础入门学习Python(15)-序列
  5. 测序发展史,150年的风雨历程 (第二版)
  6. PBio-2018:如何设计可预测植物表型的微生物组
  7. R语言ggplot2可视化:ggplot2可视化直方图(histogram)并在直方图的顶部外侧(top upper)或者直方图内部添加数值标签
  8. R语言ggplot2可视化自定义图例标签间距实战:自定义图例标签间距、自定义图例与图像之间的间距
  9. R语言stringr包str_ends函数、str_starts函数起始、结束字符串判断实战
  10. R语言tidyr包pivot_longer函数、pivot_wider函数数据表变换实战(长表到宽表、宽表到长表)