1. 感知机

模型
感知机是根据输入实例的特征向量xxx对其进行二类分类的线性分类模型:

f(x)=sign⁡(w⋅x+b)f(x)=\operatorname{sign}(w \cdot x+b) f(x)=sign(w⋅x+b)

感知机模型对应于输入空间(特征空间)中的分离超平面w⋅x+b=0w \cdot x+b=0w⋅x+b=0

策略
感知机学习的策略是极小化损失函数:

min⁡w,bL(w,b)=−∑xi∈Myi(w⋅xi+b)\min _{w, b} L(w, b)=-\sum_{x_{i} \in M} y_{i}\left(w \cdot x_{i}+b\right) w,bmin​L(w,b)=−xi​∈M∑​yi​(w⋅xi​+b)

损失函数对应于误分类点到分离超平面的总距离。

方法
  感知机学习算法是基于 随机梯度下降法 的对损失函数的最优化算法,有原始形式和对偶形式。当训练数据集线性可分时,感知机学习算法是收敛的。当训练数据集线性可分时,感知机学习算法存在无穷多个解,其解由于不同的初值或不同的迭代顺序而可能有所不同。


2. 二分类模型

模型
f(x)=sign(w⋅x+b)f(x) = sign(w\cdot x + b)f(x)=sign(w⋅x+b)

sign⁡(x)={+1,x⩾0−1,x<0\operatorname{sign}(x)=\left\{\begin{array}{ll}{+1,} & {x \geqslant 0} \\ {-1,} & {x<0}\end{array}\right.sign(x)={+1,−1,​x⩾0x<0​

策略
给定训练集:

T={(x1,y1),(x2,y2),⋯,(xN,yN)}T=\left\{\left(x_{1}, y_{1}\right),\left(x_{2}, y_{2}\right), \cdots,\left(x_{N}, y_{N}\right)\right\}T={(x1​,y1​),(x2​,y2​),⋯,(xN​,yN​)}

定义感知机的损失函数 L(w,b)L(w, b)L(w,b):

L(w,b)=−∑xi∈Myi(w⋅xi+b)L(w, b)=-\sum_{x_{i} \in M} y_{i}\left(w \cdot x_{i}+b\right)L(w,b)=−∑xi​∈M​yi​(w⋅xi​+b)

算法

随即梯度下降法 Stochastic Gradient Descent

随机抽取一个误分类点使其梯度下降:

{w=w+ηyixib=b+ηyi\left\{\begin{array}{ll}{w = w + \eta y_{i}x_{i}} \\ {b = b + \eta y_{i}}\end{array}\right.{w=w+ηyi​xi​b=b+ηyi​​

当实例点被误分类,即位于分离超平面的错误侧,则调整www, bbb的值,使分离超平面向该无分类点的一侧移动,直至误分类点被正确分类


3. 感知机的实现

首先把我们需要的库导入;

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

3.1 数据集预处理

可以使用 sklearn 库中的 iris 鸢尾花数据集,使用方法如下:

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.target    # 添加一列
df.columns = ['sepal length','sepal width','petal length','petal width','label'
]
print(df)
print(df.shape) # (150, 5)
# 统计 label 的每种值的数量
print(df.label.value_counts())# 绘制散点图
plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')
plt.scatter(df[100:150]['sepal length'],df[100:150]['sepal width'], label='2')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
plt.show()

输出结果为:

     sepal length  sepal width  petal length  petal width  label
0             5.1          3.5           1.4          0.2      0
1             4.9          3.0           1.4          0.2      0
2             4.7          3.2           1.3          0.2      0
..            ...          ...           ...          ...    ...
147           6.5          3.0           5.2          2.0      2
148           6.2          3.4           5.4          2.3      2
149           5.9          3.0           5.1          1.8      2
[150 rows x 5 columns](150, 5)0    50
1    50
2    50
Name: label, dtype: int64

得到的结果如图:

利用这些数据我们可以得到训练数据的 X 和 y:

data = np.array(df.iloc[:100, [0, 1, -1]])
X, y = data[:, :-1], data[:, [-1]]
y = np.array([1 if i == 1 else -1 for i in y])

此时,鸢尾花样本有 100 个,每个样本有花瓣长度和花瓣宽度两个特征,分类的类别有 2 中,用 +1 和 -1 表示。

3.2 构建模型

根据上面的理论知识,可以利用随机梯度下降算法构建二类分类模型,其中 fit() 为拟合函数:

class Model:def __init__(self):self.W = np.zeros(X.shape[1], dtype=np.float32)self.b = 0self.lr = 0.1def sign(self, X, W, b):y = np.dot(X, W) + breturn 1 if y > 0 else -1def fit(self, X_train, y_train):success = Falsewhile not success:wrong_count = 0for i in range(len(X_train)):X = X_train[i]y = y_train[i]if y * self.sign(X, self.W, self.b) < 0:self.W += self.lr * np.dot(X, y)self.b += y * self.lrwrong_count += 1if wrong_count == 0:success = Truereturn 'Perception Model'def score(self):pass

3.3 预测分类

使用先前的模型先对数据训练,将分类结果绘制出来:

perception = Model()
res = perception.fit(X, y)
print(res)  # Perception Modelx_ = np.linspace(4, 7, 10)
y_ = -(perception.W[0] * x_ + perception.b) / perception.W[1]
plt.plot(x_, y_)plt.plot(data[:50, 0], data[:50, 1], 'bo', color='blue', label='0')
plt.plot(data[50:100, 0], data[50:100, 1], 'bo', color='orange', label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
plt.show()

结果如下:


4. sklearn

sklearn 的 Perception 模型(点击蓝色的字查看参数)为基于随机梯度下降 SGD 的分类模型:

下面使用 sklearn 的 Perceptron 模型:

from sklearn.linear_model import Perceptronclf = Perceptron(fit_intercept=True,max_iter=1000,shuffle=True,tol=None)
clf.fit(X, y)W = clf.coef_[0]
b = clf.intercept_x_ = np.linspace(4, 7, 10)
y_ = -(W[0] * x_ + b) / W[1]
plt.plot(x_, y_)plt.plot(data[:50, 0], data[:50, 1], 'bo', color='blue', label='0')
plt.plot(data[50:100, 0], data[50:100, 1], 'bo', color='orange', label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
plt.show()

其中构造函数的参数 fit_intercept 代表计算 b,max_iter 代表最大训练次数,shuffle 为 True 代表每次训练后清洗数据,表两次训练后误差下降的量小于 tol 时停止训练;

训练好的模型的参数存储在 coef_intercept_

结果如下:


参考内容:

  1. 李航机器学习
  2. lihang-code-master
  3. https://scikit-learn.org/

机器学习笔记 1 —— Perceptron相关推荐

  1. 李弘毅机器学习笔记:第五章—分类

    李弘毅机器学习笔记:第五章-分类 例子(神奇宝贝属性预测) 分类概念 神奇宝贝的属性(水.电.草)预测 回归模型 vs 概率模型 回归模型 其他模型(理想替代品) 概率模型实现原理 盒子抽球概率举例 ...

  2. 李弘毅机器学习笔记:第十六章—无监督学习

    李弘毅机器学习笔记:第十六章-无监督学习 1-of-N Encoding 词嵌入 基于计数的词嵌入 基于预测的词嵌入 具体步骤 共享参数 训练 Various Architectures 多语言嵌入 ...

  3. 李弘毅机器学习笔记:第七章—深度学习的发展趋势

    李弘毅机器学习笔记:第七章-深度学习的发展趋势 回顾一下deep learning的历史: 1958: Perceptron (linear model) 1969: Perceptron has l ...

  4. 迷人的神经网络——机器学习笔记1

    目录 迷人的神经网络--机器学习笔记1 第1章 神经元模型及网络结构 1.1 神经元模型 1.1.1 单输入神经元 1.1.2 激活函数 1.1.3 多输入神经元 1.2 网络结构 1.2.1 单层神 ...

  5. 机器学习笔记之前馈神经网络(一)基本介绍

    机器学习笔记之前馈神经网络--基本介绍 引言 从机器学习到深度学习 频率学派思想 贝叶斯学派思想 深度学习的发展过程 引言 从本节开始,将介绍前馈神经网络. 从机器学习到深度学习 在机器学习笔记开始- ...

  6. 一份520页的机器学习笔记!附下载链接

    点击上方"视学算法",选择"星标"公众号 第一时间获取价值内容 近日,来自SAP(全球第一大商业软件公司)的梁劲(Jim Liang)公开了自己所写的一份 52 ...

  7. 700 页的机器学习笔记火了!完整版开放下载

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 作者       梁劲(Jim Liang),来自SAP(全球第一大商业软件公司). 书籍特点       条理清晰 ...

  8. 机器学习笔记十四:随机森林

    在上一篇机器学习笔记十三:Ensemble思想(上)中,简要的提了一下集成学习的原理和两种主要的集成学习形式.  而在这部分要讲的随机森林,就算是其中属于bagging思路的一种学习方法.为了篇幅,b ...

  9. 机器学习笔记七:损失函数与风险函数

    一.损失函数 回顾之前讲过的线性回归模型,我们为了要学习参数使得得到的直线更好的拟合数据,我们使用了一个函数 这个函数就是比较模型得到的结果和"真实值"之间的"差距&qu ...

最新文章

  1. C#获取刚插入的数据的id
  2. redis 计数 java_redission计数器实现,redisTemplate计数器
  3. canvas基础之旅
  4. 美联储降息首日:资本市场反向操作 道指狂泻800点
  5. arduino 嗡鸣器 音乐_arduino蜂鸣器怎么输出指定的音乐
  6. [转载] Python字符串操作大全(一)
  7. 大型分布式网站术语分析
  8. Excel制作二维码、条形码?你肯定没见过
  9. Linux 下sha1加密
  10. C++如何开发验证码短信接口API
  11. java抽象类存在的意义
  12. PC版京东炸年兽活动一键做任务 全民自动炸年兽最新版1.1
  13. C盘各个文件的简单介绍
  14. win11升级后黑屏问题
  15. os.system和os.popen和commands
  16. 【Eternallyc】函数reverse的基本用法
  17. C++中怎么表示根号下的数字(用cmath中的sqrt()可以开根号)
  18. mysql中字段约束unique_什么是MySQL UNIQUE约束,我们如何将其应用于表的字段?
  19. 学习笔记-Matlab之多项式详解
  20. 在地址栏显示网站的图标

热门文章

  1. Delayed延时队列 来实现关闭已超时的任务或订单
  2. 百变星君Beta分布
  3. Ubuntu20.04 修改内核版本,降低版本, 锁定内核
  4. js函数Math.random()取某区间内的随机数公式推导
  5. 上海飞国内最远是哪里_嫦娥2号,距离地球最远?天问一号将打破记录
  6. 第13期:动态规划-dp题集
  7. 【Typescript】paths alias别名设置
  8. 【PCB布局布线】之蛇行等长布线(转)
  9. 小偷模拟器 Thief Simulator V20230207 最新中文学习版 单机游戏游戏下载免安装【3.27G】
  10. 杨焘鸣:沟通中合一架构的艺术