个人网站:红色石头的机器学习之路
CSDN博客:红色石头的专栏
知乎:红色石头
微博:RedstoneWill的微博
GitHub:RedstoneWill的GitHub
微信公众号:AI有道(ID:redstonewill)

本文所有的源代码均放在了我的GitHub上,需要的点击文末「阅读原文」获取。如果对你有用的话,别忘了Fork和Star哦!

什么是感知机「Perceptron」

PLA全称是Perceptron Linear Algorithm,即线性感知机算法,属于一种最简单的感知机(Perceptron)模型。

感知机模型是机器学习二分类问题中的一个非常简单的模型。它的基本结构如下图所示:

其中,xixix_i是输入,wiwiw_i表示权重系数,bbb表示偏移常数。感知机的线性输出为:

scores=∑iNwixi+b" role="presentation">scores=∑iNwixi+bscores=∑iNwixi+b

scores=\sum_i^Nw_ix_i+b

为了简化计算,通常我们将bbb作为权重系数的一个维度,即w0" role="presentation" style="position: relative;">w0w0w_0。同时,将输入xxx扩展一个维度,为1。这样,上式简化为:

scores=∑iN+1wixi" role="presentation">scores=∑iN+1wixiscores=∑iN+1wixi

scores=\sum_i^{N+1}w_ix_i

scoresscoresscores是感知机的输出,接下来就要对scoresscoresscores进行判断:

  • 若scores≥0scores≥0scores\geq0,则y^=1y^=1\hat y=1(正类)

  • 若scores<0scores<0scores,则y^=−1y^=−1\hat y=-1(负类)

以上就是线性感知机模型的基本概念,简单来说,它由线性得分计算阈值比较两个过程组成,最后根据比较结果判断样本属于正类还是负类。

PLA理论解释

对于二分类问题,可以使用感知机模型来解决。PLA的基本原理就是逐点修正,首先在超平面上随意取一条分类面,统计分类错误的点;然后随机对某个错误点就行修正,即变换直线的位置,使该错误点得以修正;接着再随机选择一个错误点进行纠正,分类面不断变化,直到所有的点都完全分类正确了,就得到了最佳的分类面。

利用二维平面例子来进行解释,第一种情况是错误地将正样本(y=1)分类为负样本(y=-1)。此时,wx<0wx<0wx,即www与x" role="presentation" style="position: relative;">xxx的夹角大于90度,分类线lll的两侧。修正的方法是让夹角变小,修正w" role="presentation" style="position: relative;">www值,使二者位于直线同侧:

w:=w+x=w+yxw:=w+x=w+yx

w:=w+x=w+yx

修正过程示意图如下所示:

第二种情况是错误地将负样本(y=-1)分类为正样本(y=1)。此时,wx>0wx>0wx>0,即www与x" role="presentation" style="position: relative;">xxx的夹角小于90度,分类线lll的同一侧。修正的方法是让夹角变大,修正w" role="presentation" style="position: relative;">www值,使二者位于直线两侧:

w:=w−x=w+yxw:=w−x=w+yx

w:=w-x=w+yx

修正过程示意图如下所示:

经过两种情况分析,我们发现PLA每次www的更新表达式都是一样的:w:=w+yx" role="presentation" style="position: relative;">w:=w+yxw:=w+yxw:=w+yx。掌握了每次www的优化表达式,那么PLA就能不断地将所有错误的分类样本纠正并分类正确。

数据准备

导入数据

数据集存放在’../data/’目录下,该数据集包含了100个样本,正负样本各50,特征维度为2。

import numpy as np
import pandas as pddata = pd.read_csv('./data/data1.csv', header=None)
# 样本输入,维度(100,2)
X = data.iloc[:,:2].values
# 样本输出,维度(100,)
y = data.iloc[:,2].values

数据分类与可视化

下面我们在二维平面上绘出正负样本的分布情况。

import matplotlib.pyplot as pltplt.scatter(X[:50, 0], X[:50, 1], color='blue', marker='o', label='Positive')
plt.scatter(X[50:, 0], X[50:, 1], color='red', marker='x', label='Negative')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend(loc = 'upper left')
plt.title('Original Data')
plt.show()

PLA算法

特征归一化

首先分别对两个特征进行归一化处理,即:

X=X−μσ" role="presentation">X=X−μσX=X−μσ

X=\frac{X-\mu}{\sigma}

其中,μμ\mu是特征均值,σσ\sigma是特征标准差。

# 均值
u = np.mean(X, axis=0)
# 方差
v = np.std(X, axis=0)X = (X - u) / v# 作图
plt.scatter(X[:50, 0], X[:50, 1], color='blue', marker='o', label='Positive')
plt.scatter(X[50:, 0], X[50:, 1], color='red', marker='x', label='Negative')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend(loc = 'upper left')
plt.title('Normalization data')
plt.show()

直线初始化

# X加上偏置项
X = np.hstack((np.ones((X.shape[0],1)), X))
# 权重初始化
w = np.random.randn(3,1)

显示初始化直线位置:

# 直线第一个坐标(x1,y1)
x1 = -2
y1 = -1 / w[2] * (w[0] * 1 + w[1] * x1)
# 直线第二个坐标(x2,y2)
x2 = 2
y2 = -1 / w[2] * (w[0] * 1 + w[1] * x2)
# 作图
plt.scatter(X[:50, 1], X[:50, 2], color='blue', marker='o', label='Positive')
plt.scatter(X[50:, 1], X[50:, 2], color='red', marker='x', label='Negative')
plt.plot([x1,x2], [y1,y2],'r')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend(loc = 'upper left')
plt.show()

由上图可见,一般随机生成的分类线,错误率很高。

计算scores,更新权重

接下来,计算scores,得分函数与阈值0做比较,大于零则y^=1y^=1\hat y=1,小于零则y^=−1y^=−1\hat y=-1

s = np.dot(X, w)
y_pred = np.ones_like(y)    # 预测输出初始化
loc_n = np.where(s < 0)[0]    # 大于零索引下标
y_pred[loc_n] = -1

接着,从分类错误的样本中选择一个,使用PLA更新权重系数www。

# 第一个分类错误的点
t = np.where(y != y_pred)[0][0]
# 更新权重w
w += y[t] * X[t, :].reshape((3,1))

迭代更新训练

更新权重w" role="presentation" style="position: relative;">www是个迭代过程,只要存在分类错误的样本,就不断进行更新,直至所有的样本都分类正确。(注意,前提是正负样本完全可分)

for i in range(100):s = np.dot(X, w)y_pred = np.ones_like(y)loc_n = np.where(s < 0)[0]y_pred[loc_n] = -1num_fault = len(np.where(y != y_pred)[0])print('第%2d次更新,分类错误的点个数:%2d' % (i, num_fault))if num_fault == 0:breakelse:t = np.where(y != y_pred)[0][0]w += y[t] * X[t, :].reshape((3,1))

迭代完毕后,得到更新后的权重系数ww<script type="math/tex" id="MathJax-Element-37">w</script>,绘制此时的分类直线是什么样子。

# 直线第一个坐标(x1,y1)
x1 = -2
y1 = -1 / w[2] * (w[0] * 1 + w[1] * x1)
# 直线第二个坐标(x2,y2)
x2 = 2
y2 = -1 / w[2] * (w[0] * 1 + w[1] * x2)
# 作图
plt.scatter(X[:50, 1], X[:50, 2], color='blue', marker='o', label='Positive')
plt.scatter(X[50:, 1], X[50:, 2], color='red', marker='x', label='Negative')
plt.plot([x1,x2], [y1,y2],'r')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.legend(loc = 'upper left')
plt.show()

其实,PLA算法的效率还算不错,只需要数次更新就能找到一条能将所有样本完全分类正确的分类线。所以得出结论,对于正负样本线性可分的情况,PLA能够在有限次迭代后得到正确的分类直线。

总结与疑问

本文导入的数据本身就是线性可分的,可以使用PCA来得到分类直线。但是,如果数据不是线性可分,即找不到一条直线能够将所有的正负样本完全分类正确,这种情况下,似乎PCA会永远更新迭代下去,却找不到正确的分类线。

对于线性不可分的情况,该如何使用PLA算法呢?我们下次将对PLA进行改进和优化。

阅读原文

更多干货文章请关注公众号:AI有道(ID:redstonewill)

一看就懂的感知机算法PLA相关推荐

  1. 【附源码】一看就懂的感知机算法PLA

    AI有道 一个有情怀的公众号 本文所有的源代码均放在了我的GitHub上,需要的点击文末「阅读原文」获取.如果对你有用的话,别忘了Fork和Star哦! 什么是感知机「Perceptron」 PLA全 ...

  2. 机器学习之路(四)之感知机算法PLA

    基本介绍: 如果训练集数据是两个互不相交的凸集的子集,那么可以找到一个支撑超平面将两个子集分开.那么,如果这个超平面是一个n维的线性方程,就称之为线性分类器.线性分类器是最简单和最基本的分类器,也是最 ...

  3. 字符串:你看的懂的KMP算法(带验证)

    前言 KMP算法可以说说许多学习算法的同学的第一道坎,要么是领会不到KMP算法的思想,要么是知道思想写不出代码,网上各种查找.关于算法的书籍上也都有KMP算法的实现,可为啥自己写不出来呢?博主看得大话 ...

  4. 你也能看得懂的python算法书pdf_你也能看得懂的Python算法书最新章节_王硕著_掌阅小说网...

    1.2 三大结构 Python语言中有三大结构:循序.分支和循环.这三种结构分别适用于不同的情况,一个复杂的程序中常常同时包含这三种结构. 1.2.1 循序结构 说到"循序"这个词 ...

  5. 《你也能看得懂的Python算法书》学习笔记(四)

    在学习完哈希算法之后,我们开始学习深度优先遍历算法.深度优先遍历算法是经典的图论算法,从某个节点v出发开始搜索,不断搜索直到该节点的所有边都被遍历完.当节点v的所有边都被遍历以后,深度优先遍历算法需要 ...

  6. 【一看就懂的图解算法】简单选择排序

    简单选择排序 冒泡排序是将最大的元素往后面排,简单选择排序是将小的元素往前面排 算法思想: 1.将第一个元素和其余元素进行对比,如果第一个元素和其他元素相比,第一个元素大,则交换,一轮下来,最小的元素 ...

  7. 通俗解释优化的线性感知机算法:Pocket PLA

    个人网站:红色石头的机器学习之路 CSDN博客:红色石头的专栏 知乎:红色石头 微博:RedstoneWill的微博 GitHub:RedstoneWill的GitHub 微信公众号:AI有道(ID: ...

  8. 一层循环时间复杂度_算法的时间与空间复杂度(一看就懂)

    算法(Algorithm)是指用来操作数据.解决程序问题的一组方法.对于同一个问题,使用不同的算法,也许最终得到的结果是一样的,但在过程中消耗的资源和时间却会有很大的区别. 那么我们应该如何去衡量不同 ...

  9. 感知机算法(一)PLA

    文章目录 1.什么是感知机「Perceptron」 2.PLA理论解释 3.数据准备 导入数据 数据分类与可视化 直线初始化 计算scores,更新权重 迭代更新训练 4.缺点分析 5.全部代码: 6 ...

最新文章

  1. 独家 | 如何解决深度学习泛化理论
  2. 学界 | 量化深度强化学习算法的泛化能力
  3. 本地实现ES6转ES5代码——gulpfile配置文件
  4. QT + OpenCV + MinGW 在windows下配置开发环境
  5. arcgis oracle trace,ArcGIS应用Oracle Spatial特征分析
  6. Nacos(一)之简介
  7. 2019阿里巴巴技术面试题集锦(含答案)
  8. 基础计算机b卷,计算机应用基础B卷.doc
  9. 安装中文VS2008 SP1 和.NETFRAMEWORK 3.5SP1后智能提示是英文的解决办法
  10. MATLAB中如何生成指定范围的随机整数向量
  11. 三大无线技术 —— WiFi、蓝牙、HomeRF(无线网卡、WPAN)
  12. Linux文本三剑客超详细教程---grep、sed、awk
  13. 10Gb以太网——数据中心的未来
  14. 2019最新论文阅读-BlazeFace:面向移动设备的实时人脸检测
  15. 网络运维经验分享01
  16. 解析新浪微博表情包的一套js代码
  17. ArcMap教程:合并ShapeFile中多个要素
  18. 共享硬盘没有权限访问计算机,win7系统访问磁盘共享没有权限的解决方法
  19. Gmail,OutLook邮箱基于Oauth2.0协议授权登录邮箱客户端
  20. IE和Firefox浏览器下javascript、CSS兼容性研究

热门文章

  1. 12.10课堂学习----实例化、构造方法案例
  2. ANDROID L——Material Design综合应用(Demo)
  3. Linux基本网路配置及软件包的安装
  4. Ogre读取中文路径名的文件失败的解决办法
  5. VC调用matlab中定义的.m文件中的函数的实例
  6. hdu 3046(最小割最大流)
  7. poj 1151(线段树求面积并)
  8. hdu 1227(二维dp)
  9. NYOJ 1085 数单词 (AC自动机模板题)
  10. oracle学习-PL SQL 存储过程中循环