《统计学习方法》——感知机
引言
本文主要介绍了感知机引入,并且附有一些实例代码和恰当的图片,让大家更容易理解。
首先介绍下什么是感知机。
感知机
感知机是二类分类的线性模型,其输入为实例的特征向量,输出为实例的类别,类别取+1,-1二值。
感知机对应于输入空间(特征空间)中将实例划分为正负两类的分离超平面,属于判别模型。
感知机分为原始形式和对偶形式。
感知机模型
假设输入空间时X⊆RnX \subseteq R^nX⊆Rn,输出空间时Y={+1,−1}Y = \{+1,-1\}Y={+1,−1}。由输入空间到输出空间的如下函数:
f(x)=sign(w⋅x+b)f(x) = sign(w \cdot x + b) f(x)=sign(w⋅x+b)
其中w∈Rnw \in R^nw∈Rn(属于nnn维实数集)叫权值(weight,或叫权重,通常是个向量),b∈Rb \in Rb∈R,叫偏置(或偏差)。
w⋅xw \cdot xw⋅x表示这两个向量的内积。signsignsign是符号函数,即
{+1,x≥0−1,x<0\left\{ \begin{array}{c} +1, x \geq 0 \\ -1, x < 0 \end{array} \right. {+1,x≥0−1,x<0
感知机是一种线性分类模型,其假设空间时定义在特征空间中的所有线性分类模型(或线性分类器),即函数集合 {f∣f(x)=w⋅x+b}\{f|f(x) = w \cdot x + b\}{f∣f(x)=w⋅x+b}
它的几何解释是线性方程w⋅x+b=0w \cdot x + b = 0w⋅x+b=0,该方程对应于特征空间RnR^nRn的一个超平面SSS,www是超平面的法向量,bbb是超平面的截距。
超平面是n维欧氏空间中余维度等于一的线性子空间,也就是必须是(n-1)维度。这是平面中的直线、空间中的平面之推广(n大于3才被称为“超”平面)。
该超平面将特征空间划分为两个部分。位于两部分的点分贝被分为正、负两类。因此,超平面SSS称为分离超平面。
感知机学习的训练数据集(实例的特征向量及类别)是这样的:
T={(x1,y1),(x2,y2),⋯,(xN,yn)}T = \{(x_1,y_1),(x_2,y_2),\cdots,(x_N,y_n)\} T={(x1,y1),(x2,y2),⋯,(xN,yn)}
感知机学习策略
给定上面的数据集TTT,如果存在某个超平面S:w⋅x+b=0S : w \cdot x +b =0S:w⋅x+b=0能将数据集的正实例点和负实例点完全正确地划分到超平面的两侧,即对所有yi=+1y_i=+1yi=+1的实例iii,都有w⋅xi+b>0w \cdot x_i + b >0w⋅xi+b>0,对所有yi=−1y_i= -1yi=−1的实例iii有w⋅xi+b<0w \cdot x_i + b <0w⋅xi+b<0 ,则称数据集TTT为线性可分数据集;否则,线性不可分。
假设训练数据集是线性可分的,感知机学习的目标是求得一个能够将训练集正实例点和负实例点完全正确分开的超平面。为了找出这样的超平面,即确定感知机模型参数w,bw,bw,b,需要确定一个学习策略,即定义损失函数并最小化它。
损失函数一般是计算所有误分类点到超平面SSS的总距离。为此,首先写出输入空间中任一点x0x_0x0到超平面的距离:
1∣∣w∣∣∣w⋅x0+b∣\frac{1}{||w||}|w \cdot x_0 + b| ∣∣w∣∣1∣w⋅x0+b∣
这里∣∣w∣∣||w||∣∣w∣∣是www的L2L_2L2范数,或者说是模。
我们知道点(x,y)(x,y)(x,y)到直线Ax+by+C=0Ax+by+C =0Ax+by+C=0的距离为
Ax+by+CA2+B2\frac{Ax+by+C}{\sqrt{A^2+B^2}} A2+B2Ax+by+C
可以看到它的分母是A2+B2A^2 +B^2A2+B2的开根号,也就是说,分母和截距无关。
拓展到nnn维空间,w⋅x+b=0w \cdot x + b = 0w⋅x+b=0,得
∣w⋅x+b∣∣∣w∣∣\frac{|w\cdot x +b|}{||w||} ∣∣w∣∣∣w⋅x+b∣
∣∣w∣∣=w12+w22+⋯+wn2||w|| = \sqrt{w_1^2 + w_2^2 + \cdots + w_n^2}∣∣w∣∣=w12+w22+⋯+wn2因为距离是正的,所以取绝对值就得到了上面的式子。
对于误分类的数据(xi,yi)(x_i,y_i)(xi,yi)来说,有−yi(w⋅xi+b)>0-y_i(w \cdot x_i + b) > 0−yi(w⋅xi+b)>0。因为当w⋅xi+b>0w \cdot x_i + b >0w⋅xi+b>0时,yi=−1y_i = -1yi=−1(注意是误分类点);而当w⋅xi+b<0w \cdot x_i + b < 0w⋅xi+b<0时,yi=+1y_i = +1yi=+1。它们乘起来在乘以−1-1−1一定是大于000的。
因此,我们把上面得到的距离公式乘以−yi-y_i−yi就得到了误分类点到超平面的距离公式:
−1∣∣w∣∣yi∣w⋅xi+b∣-\frac{1}{||w||}y_i|w \cdot x_i + b| −∣∣w∣∣1yi∣w⋅xi+b∣
因为yyy的取值无非就是+1,−1+1,-1+1,−1,然后上面又说了这个式子是一定大于0的,因此这个公式是成立的。
这样,假设超平面SSS的误分类点集合为MMM,那么所有误分类点到超平面SSS的总距离为:
−1∣∣w∣∣∑xi∈Myi(w⋅xi+b)-\frac{1}{||w||} \sum_{x_i \in M}y_i(w \cdot x_i + b) −∣∣w∣∣1xi∈M∑yi(w⋅xi+b)
如果不考虑1∣∣w∣∣\frac{1}{||w||}∣∣w∣∣1的话,就得到了感知机sign(w⋅x+b)sign(w \cdot x +b)sign(w⋅x+b)学习的损失函数:
L(w,b)=−∑xi∈Myi(w⋅xi+b)L(w,b) = -\sum_{x_i \in M}y_i(w \cdot x_i + b) L(w,b)=−xi∈M∑yi(w⋅xi+b)
注意,MMM是误分类点的集合,如果是正确分类的点,其损失直接为0,就不需要参与计算。
感知机学习的策略是在假设空间中选取使损失函数最小的模型参数w,bw,bw,b
感知机学习算法
感知机学习算法是误分类驱动的,采用随机梯度下降法。首先,任取一个超平面w0,b0w_0,b_0w0,b0,然后用梯度下降法不断地极小化目标函数L(w,b)=−∑xi∈Myi(w⋅xi+b)L(w,b) = -\sum_{x_i \in M}y_i(w \cdot x_i + b)L(w,b)=−∑xi∈Myi(w⋅xi+b)。极小化过程一次随机选取一个误分类点使其梯度下降。
损失函数L(w,b)L(w,b)L(w,b)的梯度由:
∇wL(w,b)=−∑xi∈Myixi∇bL(w,b)=−∑xi∈Myi\nabla_w L(w,b) = - \sum_{x_i \in M} y_ix_i \\ \nabla_b L(w,b) = - \sum_{x_i \in M} y_i ∇wL(w,b)=−xi∈M∑yixi∇bL(w,b)=−xi∈M∑yi
简单的证明一下:
∂L(w,b)∂w=∂(−∑xi∈Myi(w⋅xi+b))∂w=∂(−∑xi∈Myi(w⋅xi))∂w+∂(−∑xi∈Myib)∂w=−∑xi∈Myixi\begin{aligned} \frac{\partial L(w,b)}{\partial w} &= \frac{ \partial ( -\sum_{x_i \in M}y_i(w \cdot x_i + b))}{\partial w} \\ &= \frac{ \partial ( -\sum_{x_i \in M}y_i(w \cdot x_i ))}{\partial w} + \frac{ \partial ( -\sum_{x_i \in M}y_ib)}{\partial w} \\ &=- \sum_{x_i \in M} y_ix_i \end{aligned} ∂w∂L(w,b)=∂w∂(−∑xi∈Myi(w⋅xi+b))=∂w∂(−∑xi∈Myi(w⋅xi))+∂w∂(−∑xi∈Myib)=−xi∈M∑yixi
因为yi,by_i,byi,b是与www无关的,所以微分就是000,对bbb求微分也同理。
随机选一个误分类点(xi,yi)(x_i,y_i)(xi,yi)对w,bw,bw,b进行更新:
w←w+ηyixib←b+ηyiw \leftarrow w + \eta y_ix_i \\ b \leftarrow b + \eta y_i \\ w←w+ηyixib←b+ηyi
其中η\etaη是学习率。
学习算法原始形式为:
- 选取初值w0,b0w_0,b_0w0,b0;
- 在训练数据集中选取数据(xi,yi)(x_i,y_i)(xi,yi);
- 如果yi(w⋅xi+b)≤0y_i(w \cdot x_i + b) \leq 0yi(w⋅xi+b)≤0 ,通过w←w+ηyixi,b←b+ηyiw \leftarrow w + \eta y_ix_i,b \leftarrow b + \eta y_iw←w+ηyixi,b←b+ηyi更新:
- 跳转置2,直到训练集中没有误分类点
直观上的解释是:当一个实例点被误分类,也就是位于分离超平面的错误一侧时,调整w,bw,bw,b的值,使分离超平面向该实例点的一侧移动,以减少误分类点与超平面的距离,直到超平面越过该误分类点使其被正确分类。
我们用iris数据集为例,用sepal length,sepal width作为特征。
# load data
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt# 加载数据
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['label'] = iris.targetdf.columns = ['sepal length', 'sepal width', 'petal length', 'petal width', 'label']plt.scatter(df[:50]['sepal length'], df[:50]['sepal width'], label='0')
plt.scatter(df[50:100]['sepal length'], df[50:100]['sepal width'], label='1')plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
只取里面的类别0和1,将它们的散点图画出来:
class Perceptron:def __init__(self):self.w = Noneself.b = Nonedef predict(self,X):# X (m,n_x)# w (1,n_x)y = np.dot(X,self.w.T) + self.by[y >= 0] = 1y[y < 0] = -1return ydef fit(self,X_train,y_train,iters=10,learning_rate = 0.1):# X_train (number of,feature_numbers) (m,n_x)n_x = X_train.shape[1]self.w = np.zeros((1,n_x))self.b = 0m = X_train.shape[0]for k in range(iters):wrong_count = 0for i in range(m):# x (1,n_x)# y (1,1)x,y = X_train[[i]],y_train[[i]]if y * (np.dot(self.w,x.T) + self.b) <= 0:wrong_count += 1self.w = self.w + learning_rate * y * xself.b = self.b + learning_rate * np.squeeze(y)if wrong_count == 0:breakprint('finish iter %d, accuracy:%.2f%%' % (k, (m - wrong_count) * 100.0/m))return self.w,self.bdef score(self,X_test,y_test):predict = self.predict(X_test)return np.sum(predict == y_test) / X_test.shape[0]
模型写好了,接下来画出预测的结果:
data = np.array(df.iloc[:100, [0, 1, -1]])#将这100数据转成np数组,最后一列是类别
X, y = data[:,:-1], data[:,-1] # 抽出两个特征放入x,抽出类别放入y
y = np.array([1 if i == 1 else -1 for i in y]).reshape(-1,1) # 把类别 0 转换成 -1p = Perceptron()
w,b = p.fit(X,y,iters = 1000) #需要迭代700次左右才能得到结果x_points = np.linspace(4, 7,10) # 在4和7之间等距离生成10个点
y_ = -(w[0][0] *x_points + b)/w[0][1] # w₀x₀ + w₁y + b = 0 => y = -(w₀x₀ + b)/w₁ 这里把x₁用y来代替,因为要画到二维坐标系上,就是两个维度,另一个当成y
plt.plot(x_points, y_)plt.plot(data[:50, 0], data[:50, 1], 'bo', color='blue', label='-1')
plt.plot(data[50:100, 0], data[50:100, 1], 'bo', color='orange', label='1')
plt.xlabel('sepal length')
plt.ylabel('sepal width')
plt.legend()
plt.show()
算法的收敛性
对于线性可分数据集感知机学习算法原始形式收敛,就是可以得到正确划分的分离超平面。
下面来证明,先将偏置bbb并入权重向量www,记作w^=(wT,b)T\hat w = (w^T,b)^Tw^=(wT,b)T,同样输入向量加入常数111,记作x^=(xT,1)T\hat x = (x^T,1)^Tx^=(xT,1)T,此时有
w^⋅x^=w⋅x+b\hat w \cdot \hat x = w \cdot x + bw^⋅x^=w⋅x+b
(1)由于训练数据集是线性可分的,存在超平面可将训练数据集完全正确分开,取此超平面为w^opt⋅x^=wopt⋅x+bopt=0\hat w_{opt} \cdot \hat x = w_{opt} \cdot x + b_{opt}= 0w^opt⋅x^=wopt⋅x+bopt=0,为了限定这个超平面,使∣∣w^opt∣∣=1||\hat w_{opt}|| = 1∣∣w^opt∣∣=1。
由于对于所有的iii都分类正确,即有
yi(w^opt⋅x^i)=yi(wopt⋅xi+bopt)>0y_i (\hat w_{opt} \cdot \hat x_i) = y_i(w_{opt} \cdot x_i + b_{opt}) > 0yi(w^opt⋅x^i)=yi(wopt⋅xi+bopt)>0
定义最小值γ\gammaγ为
γ=mini{yi(wopt⋅xi+bopt)}\gamma = \min_{i}\{ y_i(w_{opt} \cdot x_i + b_{opt}) \}γ=imin{yi(wopt⋅xi+bopt)}
所以有yi(w^opt⋅x^i)=yi(wopt⋅xi+bopt)≥γ(2.8)y_i (\hat w_{opt} \cdot \hat x_i) = y_i(w_{opt} \cdot x_i + b_{opt}) \geq \gamma \tag{2.8}yi(w^opt⋅x^i)=yi(wopt⋅xi+bopt)≥γ(2.8)
书上定理2.1第一部分证明完毕
(2) 感知机算法从w^0=0\hat w_0 = 0w^0=0开始,如果实例被误分类,则更新权重。令w^k−1\hat w_{k-1}w^k−1是第kkk个误分类实例之前的扩充权重向量,即
w^k−1=(wk−1T,bk−1)T\hat w_{k-1} = (w^T_{k-1},b_{k-1})^Tw^k−1=(wk−1T,bk−1)T
那么第kkk个误分类实例使得yi(w^k−1⋅x^i)=yi(wk−1⋅xi+bk−1)≤0(2.10)y_i (\hat w_{k-1} \cdot \hat x_i) = y_i(w_{k-1} \cdot x_i + b_{k-1}) \leq 0 \tag{2.10}yi(w^k−1⋅x^i)=yi(wk−1⋅xi+bk−1)≤0(2.10)
若(xi,yi)(x_i,y_i)(xi,yi)是被w^k−1=(wk−1T,bk−1)T\hat w_{k-1} = (w^T_{k-1},b_{k-1})^Tw^k−1=(wk−1T,bk−1)T误分类的数据,则www和bbb会通过下式更新得到wk,bkw_k,b_kwk,bk:
wk←wk−1+ηyixibk←bk−1+ηyiw_k \leftarrow w_{k-1} + \eta y_ix_i \\ b_k \leftarrow b_{k-1} + \eta y_i wk←wk−1+ηyixibk←bk−1+ηyi
即扩充权重向量为w^k=w^k−1+ηyix^i(2.11)\hat w_k = \hat w_{k-1} + \eta y_i\hat x_i \tag{2.11}w^k=w^k−1+ηyix^i(2.11)
由式(2.11)(2.11)(2.11)及(2.8)(2.8)(2.8)得
w^k⋅w^opt=w^k−1⋅w^opt+ηyi(w^opt⋅x^i)≥w^k−1⋅w^opt+ηγ\hat w_k \cdot \hat w_{opt} = \hat w_{k-1}\cdot \hat w_{opt} + \eta y_i(\hat w_{opt} \cdot \hat x_i) \geq \hat w_{k-1}\cdot \hat w_{opt} + \eta \gamma w^k⋅w^opt=w^k−1⋅w^opt+ηyi(w^opt⋅x^i)≥w^k−1⋅w^opt+ηγ
继续应用式(2.11)(2.11)(2.11),得到wk−2,wk−3,⋯,w0w_{k-2},w_{k-3},\cdots ,w_0wk−2,wk−3,⋯,w0,比如w^k−1=w^k−2+ηyix^i\hat w_{k-1} = \hat w_{k-2} + \eta y_i\hat x_iw^k−1=w^k−2+ηyix^i
w^k⋅w^opt=w^k−1⋅w^opt+ηyi(w^opt⋅x^i)≥w^k−1⋅w^opt+ηγ≥w^k−2⋅w^opt+ηyi(w^opt⋅x^i)+ηγ=w^k−2⋅w^opt+2ηγ⋯≥w^0⋅w^opt+kηγ\hat w_k \cdot \hat w_{opt} = \hat w_{k-1}\cdot \hat w_{opt} + \eta y_i(\hat w_{opt} \cdot \hat x_i) \geq \hat w_{k-1}\cdot \hat w_{opt} + \eta \gamma \geq \hat w_{k-2}\cdot \hat w_{opt} +\eta y_i(\hat w_{opt} \cdot \hat x_i) + \eta \gamma = \hat w_{k-2}\cdot \hat w_{opt} + 2 \eta \gamma \cdots \geq \hat w_{0}\cdot \hat w_{opt} + k\eta \gamma w^k⋅w^opt=w^k−1⋅w^opt+ηyi(w^opt⋅x^i)≥w^k−1⋅w^opt+ηγ≥w^k−2⋅w^opt+ηyi(w^opt⋅x^i)+ηγ=w^k−2⋅w^opt+2ηγ⋯≥w^0⋅w^opt+kηγ
因为w^0=0\hat w_0 = 0w^0=0,所以上式化简为 w^k⋅w^opt≥kηγ(2.12)\hat w_k \cdot \hat w_{opt} \geq k\eta \gamma \tag{2.12}w^k⋅w^opt≥kηγ(2.12)
由式(2.11)(2.11)(2.11)和式(2.10)(2.10)(2.10)可得 ∣∣w^k∣∣2=∣∣w^k−1∣∣2+2ηyi(w^k−1⋅x^i)+η2∣∣x^i∣∣2≤∣∣w^k−1∣∣2+η2∣∣x^i∣∣2||\hat w_k||^2 = ||\hat w_{k-1}||^2 + 2 \eta y_i (\hat w_{k-1} \cdot \hat x_i) + \eta^2||\hat x_i||^2 \leq ||\hat w_{k-1}||^2 + \eta^2||\hat x_i||^2 ∣∣w^k∣∣2=∣∣w^k−1∣∣2+2ηyi(w^k−1⋅x^i)+η2∣∣x^i∣∣2≤∣∣w^k−1∣∣2+η2∣∣x^i∣∣2
令RRR为所有x^\hat xx^中长度最大的,即R=max1≤i≤N∣∣x^i∣∣R = \max_{1 \leq i \leq N} ||\hat x_i||R=1≤i≤Nmax∣∣x^i∣∣
上式令∣∣x^i∣∣||\hat x_i||∣∣x^i∣∣取最大值RRR即有
∣∣w^k∣∣2≤∣∣w^k−1∣∣2+η2R2≤∣∣w^k−2∣∣2+2η2R2≤⋯≤kη2R2(2.13)||\hat w_k||^2 \leq ||\hat w_{k-1}||^2 + \eta^2 R^2 \leq ||\hat w_{k-2}||^2 + 2 \eta^2 R^2 \leq \cdots \leq k\eta^2R^2 \tag{2.13} ∣∣w^k∣∣2≤∣∣w^k−1∣∣2+η2R2≤∣∣w^k−2∣∣2+2η2R2≤⋯≤kη2R2(2.13)
结合不等式(2.12)(2.12)(2.12)和(2.13)(2.13)(2.13)可得
kηγ≤w^k⋅w^optk\eta \gamma \leq \hat w_k \cdot \hat w_{opt}kηγ≤w^k⋅w^opt
由柯西不等式可知w^k⋅w^opt≤∣∣wk∣∣⋅∣∣w^opt∣∣=∣∣wk∣∣\hat w_k \cdot \hat w_{opt} \leq ||w_k|| \cdot ||\hat w_{opt} || = ||w_k||w^k⋅w^opt≤∣∣wk∣∣⋅∣∣w^opt∣∣=∣∣wk∣∣ 因为∣∣w^opt∣∣=1||\hat w_{opt}|| = 1∣∣w^opt∣∣=1
所以上式为kηγ≤∣∣w^k∣∣≤kηRk\eta \gamma \leq ||\hat w_k|| \leq \sqrt{k} \eta Rkηγ≤∣∣w^k∣∣≤kηR
kηγ≤kηRk≤Rγk≤(Rγ)2k \bcancel{\eta} \gamma \leq \sqrt{k} \bcancel{\eta} R \\ \sqrt{k} \leq \frac{R}{\gamma} \\ k \leq \left (\frac{R}{\gamma} \right)^2 kηγ≤kηRk≤γRk≤(γR)2
证明完毕。
定理表明,当训练数据集线性可分时,误分类的次数kkk是有上界的,经过有限次搜索可以找到将训练数据完全正确分开的分离超平面。
感知机学习算法的对偶形式
向量的内积就是向量对应元素乘积之和,内积的结果是个标量。
对偶形式的基本想法是,将www和bbb表示为实例xix_ixi和标记yiy_iyi的线性组合的形式,通过求解其系数而求得www和bbb。其实就是将当前的www用之前的累计更新来表示。
不失一般性,设初始值w0,b0w_0,b_0w0,b0均为000,对误分类点(xi,yi)(x_i,y_i)(xi,yi)通过
w←w+ηyixib←b+ηyiw \leftarrow w + \eta y_ix_i \\ b \leftarrow b + \eta y_i w←w+ηyixib←b+ηyi
如果某个误分类点(xi,yi)(x_i,y_i)(xi,yi)被多次选中,那么它对当前的参数就进行了nnn次更新,就会使w,bw,bw,b改变如下:
Δw=nηyixiΔb=nηyi\Delta w = n\eta y_ix_i \\ \Delta b = n\eta y_i Δw=nηyixiΔb=nηyi
对偶形式的参数更新就是更新某个误分类点参与更新的次数,就是更新这个nnn。更新过程中被选中的每个误分类点都可以如上表示,将它们全部加起来,就可以表示由初始的w0,b0w_0,b_0w0,b0到当前的w,bw,bw,b之间的总变化,则www的表示形式如下:
w=w0+n1ηy1x1+n2ηy2x2+⋯+niηyixiw = w_0 + n_1\eta y_1x_1 + n_2\eta y_2x_2 + \cdots + n_i\eta y_ix_i w=w0+n1ηy1x1+n2ηy2x2+⋯+niηyixi
这个式子包含了训练集所有的点,如果某个点未被选中,则n=0n=0n=0。我们把上式中的niηn_i\etaniη改写为αi\alpha_iαi,而上面说了,初始值w0,b0w_0,b_0w0,b0均为000,所以可以写成:
w=∑i=1Nαiyixiw = \sum_{i=1}^N \alpha_iy_ix_i w=i=1∑Nαiyixi
因此上面的感知机模型可以写成下面的式子:
f(x)=sign(∑j=1Nαjyjxj⋅x+b)f(x) = sign (\sum_{j=1}^N\alpha_jy_jx_j \cdot x + b) f(x)=sign(j=1∑Nαjyjxj⋅x+b)
对偶形式的算法如下:
- 设置初值α←0,b←0\alpha \leftarrow 0,b \leftarrow 0α←0,b←0;
- 在训练集中选取数据(xi,yi)(x_i,y_i)(xi,yi);
- 如果yi(∑j=1Nαjyjxj⋅xi+b)≤0y_i(\sum_{j=1}^N\alpha_jy_jx_j \cdot x_i + b) \leq 0yi(∑j=1Nαjyjxj⋅xi+b)≤0, 则αi←αi+η,b←b+ηyi\alpha _i \leftarrow \alpha_i + \eta,b \leftarrow b + \eta y_iαi←αi+η,b←b+ηyi;
- 跳转到2直到没有误分类数据
bbb的更新方式和原始形式一样,这没问题。我们重点来看下α\alphaα的更新方式。
α\alphaα表示的是niηn_i\etaniη,当我们用上面的方式进行更新时,其实就是使nin_ini增加了111,相当于
∑i=1Nαiyixi\sum_{i=1}^N \alpha_iy_ix_i∑i=1Nαiyixi 增加了ηyixi\eta y_ix_iηyixi,也就是:
∑i=1Nαiyixi+ηyixi\sum_{i=1}^N \alpha_iy_ix_i + \eta y_ix_i i=1∑Nαiyixi+ηyixi
再结合w=∑i=1Nαiyixiw = \sum_{i=1}^N \alpha_iy_ix_iw=∑i=1Nαiyixi
这里的更新就等价于w←w+ηyixiw \leftarrow w + \eta y_ix_iw←w+ηyixi
上面的文字如果看不懂的话,我们用例2.2来做个说明。
首先给出Gram矩阵的定义:
nnn维欧式空间中任意k(k≤n)k(k \leq n)k(k≤n)个向量α1,α2,⋯,αk\alpha_1,\alpha_2,\cdots,\alpha_kα1,α2,⋯,αk的内积所组成的矩阵
((α1,α1)(α1,α2)⋯(α1,αk)(α2,α1)(α2,α2)⋯(α2,αk)⋯⋯⋯⋯(αk,α1)(αk,α2)⋯(αk,αk))\left( \begin{matrix} (\alpha_1,\alpha_1)&(\alpha_1,\alpha_2)& \cdots (\alpha_1,\alpha_k)\\ (\alpha_2,\alpha_1)&(\alpha_2,\alpha_2)& \cdots (\alpha_2,\alpha_k)\\ \cdots & \cdots & \cdots \cdots \\ (\alpha_k,\alpha_1)&(\alpha_k,\alpha_2)& \cdots (\alpha_k,\alpha_k) \end{matrix} \right) ⎝⎜⎜⎛(α1,α1)(α2,α1)⋯(αk,α1)(α1,α2)(α2,α2)⋯(αk,α2)⋯(α1,αk)⋯(α2,αk)⋯⋯⋯(αk,αk)⎠⎟⎟⎞
就是这个kkk个向量的Gram矩阵。
以例2.2推导如下。从原始形式中我们可以知道。www的更新过程。
第一次更新是x1y1=((3,3)T,1)x_1y_1=((3,3)^T,1)x1y1=((3,3)T,1)点不能使yi(w⋅xi+b)y_i(w \cdot x_i + b)yi(w⋅xi+b) 大于零,所以w1=w0+x1y1w_1=w_0+x_1y_1w1=w0+x1y1
第二次更新是x3y3=((1,1)T,−1)x_3y_3=((1,1)^T,-1)x3y3=((1,1)T,−1)点, w2=w1+x3y3w_2=w_1+x_3y_3w2=w1+x3y3
第三次更新是x3y3=((1,1)T,−1)x_3y_3=((1,1)^T,-1)x3y3=((1,1)T,−1)点,所以 w3=w2+x3y3w_3=w_2+x_3y_3w3=w2+x3y3
第四次更新是x3y3=((1,1)T,−1)x_3y_3=((1,1)^T,-1)x3y3=((1,1)T,−1)点,所以 w4=w3+x3y3w_4=w_3+x_3y_3w4=w3+x3y3
第五次更新是x1y1=((3,3)T,1)x_1y_1=((3,3)^T,1)x1y1=((3,3)T,1)点,所以 w5=w4+x1y1w_5=w_4+x_1y_1w5=w4+x1y1
第六次更新是x3y3=((1,1)T,−1)x_3y_3=((1,1)^T,-1)x3y3=((1,1)T,−1)点,所以 w6=w5+x3y3w_6=w_5+x_3y_3w6=w5+x3y3
第七次更新是x3y3=((1,1)T,−1)x_3y_3=((1,1)^T,-1)x3y3=((1,1)T,−1)点,所以 w7=w6+x3y3w_7=w_6+x_3y_3w7=w6+x3y3
最终我们得到w7=(1,1)Tw_7=(1,1)^Tw7=(1,1)T,b7=−3b_7=-3b7=−3
w7⋅x+b7=x(1)+x(2)−3w_7 \cdot x + b_7 = x^{(1)} + x^{(2)} -3w7⋅x+b7=x(1)+x(2)−3
从上面可以总结得出
w7=w6+x3y3w_7=w_6+x_3y_3w7=w6+x3y3
w7=w5+x3y3+x3y3w_7=w_5+x_3y_3+x_3y_3w7=w5+x3y3+x3y3
w7=w4+x1y1+x3y3+x3y3w_7=w_4+x_1y_1+x_3y_3+x_3y_3w7=w4+x1y1+x3y3+x3y3
w7=w3+x3y3+x1y1+x3y3+x3y3w_7=w_3+x_3y_3+x_1y_1+x_3y_3+x_3y_3w7=w3+x3y3+x1y1+x3y3+x3y3
w7=w2+x3y3+x3y3+x1y1+x3y3+x3y3w_7=w_2+x_3y_3+x_3y_3+x_1y_1+x_3y_3+x_3y_3w7=w2+x3y3+x3y3+x1y1+x3y3+x3y3
w7=w1+x3y3+x3y3+x3y3+x1y1+x3y3+x3y3w_7=w_1+x_3y_3+x_3y_3+x_3y_3+x_1y_1+x_3y_3+x_3y_3w7=w1+x3y3+x3y3+x3y3+x1y1+x3y3+x3y3
w7=w0+x1y1+x3y3+x3y3+x3y3+x1y1+x3y3+x3y3w_7=w_0+x_1y_1+x_3y_3+x_3y_3+x_3y_3+x_1y_1+x_3y_3+x_3y_3w7=w0+x1y1+x3y3+x3y3+x3y3+x1y1+x3y3+x3y3
最终w7=2x1y1+5x3y3w_7 = 2x_1y_1 + 5x_3y_3w7=2x1y1+5x3y3
也就是对偶形式中的
w=∑i=1Nαiyixiw = \sum_{i=1}^N \alpha_iy_ix_i w=i=1∑Nαiyixi
同理可以得出bbb
b=∑i=1Nαiyib = \sum_{i=1}^N \alpha_iy_i b=i=1∑Nαiyi
例2.2的误分条件可以写成
yi(∑j=1Nαjyjxj⋅xi+b)=yi((α1y1x1+α2y2x2+α3y3x3)⋅xi+b)y_i \left( \sum_{j=1}^N \alpha_jy_jx_j \cdot x_i + b \right) = y_i\left((\alpha_1y_1x_1 + \alpha_2y_2x_2 + \alpha_3y_3x_3) \cdot x_i + b\right) yi(j=1∑Nαjyjxj⋅xi+b)=yi((α1y1x1+α2y2x2+α3y3x3)⋅xi+b)
因为这里N=3N=3N=3,有三个实例。
yi((α1y1x1+α2y2x2α3y3x3)⋅xi+b)=yi(α1y1x1⋅xi+α2y2x2⋅xi+α3y3x3⋅xi+b)y_i\left((\alpha_1y_1x_1 + \alpha_2y_2x_2 \alpha_3y_3x_3) \cdot x_i + b\right) = y_i\left(\alpha_1y_1x_1\cdot x_i + \alpha_2y_2x_2 \cdot x_i +\alpha_3y_3x_3\cdot x_i + b\right)yi((α1y1x1+α2y2x2α3y3x3)⋅xi+b)=yi(α1y1x1⋅xi+α2y2x2⋅xi+α3y3x3⋅xi+b)
因为该例中的y1=1,y2=1,y3=−1y_1 = 1,y_2 =1,y_3=-1y1=1,y2=1,y3=−1,所以可以写成
yi(α1(x1⋅xi)+α2(x2⋅xi)−α3(x3⋅xi)+b)y_i\left(\alpha_1(x_1\cdot x_i) + \alpha_2(x_2 \cdot x_i) - \alpha_3(x_3\cdot x_i) + b\right)yi(α1(x1⋅xi)+α2(x2⋅xi)−α3(x3⋅xi)+b)
公式中的x1⋅xi,x2⋅xi,x3⋅xix_1\cdot x_i,x_2 \cdot x_i,x_3\cdot x_ix1⋅xi,x2⋅xi,x3⋅xi即为向量的内积,将所有实例点的内积以矩阵的形式存储,该矩阵就被称为Gram矩阵。
下面编码实现例2.2的求解过程:
def compute_res(X,alpha,gram,b,i):res = 0for j in range(X.shape[0]):res += alpha[j] * y[j] * gram[j][i]return y[i] * (res + b) # b不在求和里面def fit(X,y,alpha,iters = 1000,eta=1,b=0):# 计算gram矩阵m = X.shape[0]gram = np.array([np.dot(xi,xj) for xi in X for xj in X]).reshape(m,-1)for _ in range(iters):updated = Falsefor i in range(m):res = compute_res(X,alpha,gram,b,i)if res <= 0:updated = Truealpha[i] = alpha[i] + etab = b + y[i]print('x_%d , alpha: %s , b:%d' %(i+1,alpha,b))if updated == False:breakw = np.zeros((1,X.shape[1]))b = 0for i in range(m):w = w + alpha[i] * y[i] * X[i]b = b + alpha[i] * y[i]return w,b
上面定义了两个简单的函数
X = np.array([[3,3],[4,3],[1,1]])
y = [1,1,-1]
m = X.shape[0]
alpha = [0] * m
fit(X,y,alpha)
最后以手写数字识别为例应用下上面写的原始感知机。
以手写数字识别为例
import pickle
import gzip# Third-party libraries
import numpy as npdef load_data():f = gzip.open('./datasets/mnist.pkl.gz', 'rb')training_data, validation_data, test_data = pickle.load(f, encoding='latin-1')f.close()return (training_data, validation_data, test_data)def load_data_wrapper():tr_d, va_d, te_d = load_data()X_train = np.array([np.reshape(x, (784, 1)) for x in tr_d[0]])y_train = np.array([two_labels(y) for y in tr_d[1]])#validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]#validation_data = zip(validation_inputs, va_d[1])X_test = np.array([np.reshape(x, (784, 1)) for x in te_d[0]])y_test = np.array([two_labels(y) for y in te_d[1]])return X_train, y_train, X_test, y_testdef two_labels(i):return 1 if i >= 5 else -1# one-hot vector
def vectorized_result(j):e = np.zeros((10, 1))e[j] = 1.0return e
上面是加载MNIST数据集的代码,将小于5个数字归为类别-1,不小于5的数字归为类别1,转换成二分类问题。
import numpy as np
import mnist_loaderX_train,y_train,X_test,y_test = mnist_loader.load_data_wrapper()
X_train = X_train.reshape(X_train.shape[0],-1)
y_train = y_train.reshape(y_train.shape[0],-1)p = Perceptron()
p.fit(X_train,y_train)
最后看下准确率有82%
参考
- 李航.统计学习方法第二版
- https://www.cnblogs.com/santian/p/4351756.html
《统计学习方法》——感知机相关推荐
- 统计学习方法|感知机原理剖析及实现
欢迎直接到我的博客查看最近文章:www.pkudodo.com.更新会比较快,评论回复我也能比较快看见,排版也会更好一点. 原始blog链接: http://www.pkudodo.com/2018/ ...
- 统计学习方法 --- 感知机模型原理及c++实现
参考博客 Liam Q博客 和李航的<统计学习方法> 感知机学习旨在求出将训练数据集进行线性划分的分类超平面,为此,导入了基于误分类的损失函数,然后利用梯度下降法对损失函数进行极小化,从而 ...
- 统计学习方法-感知机概括和补充
前言 <统计学习方法>第二版出了有段时间了,最近得空可以拜读一下.之前看第一版的时候还是一年多以前,那个时候看的懵懵懂懂的,很吃力.希望这一次能够有所收获,能够收获新的东西,这些文章只是用 ...
- 统计学习方法感知机(附简单模型代码)
1. 感知机模型 输入为实例的特征向量, 输出为实例的类别, 取+1和-1:感知机对应于输入空间中将实例划分为正负两类的分离超平面, 属于判别模型:导入基于误分类的损失函数:利用梯度下降法对损失函数进 ...
- 统计学习方法——感知机
1. 感知机原理 感知机是一个二类分类模型,输入为实例的特征向量,输出为实例的类别,取值为+1和-1. 定义1-1:数据的线性可分性 假设训练数据集是线性可分的,感知机学习的目标是求得一个能够将训练数 ...
- 李航统计学习方法----感知机章节学习笔记以及python代码
目录 1 感知机模型 2 感知机学习策略 2.1 数据集的线性可分性 2.2 感知机学习策略 3 感知机学习算法 3.1 感知机学习算法的原始形式 3.2 感知机算法的对偶形式 4 感知机算法pyth ...
- 复习02统计学习方法(感知机perceptron machine)---图片版
- 【李航统计学习方法】感知机模型
目录 一.感知机模型 二.感知机的学习策略 三.感知机学习算法 感知机算法的原始形式 感知机模型的对偶形式 参考文献 本章节根据统计学习方法,分为模型.策略.算法三个方面来介绍感知机模型. 首先介绍感 ...
- 统计学习方法02—感知机
目录 1. 简单了解感知机 2. 从简博士处学习整理的笔记(感知机) 2.1 模型介绍与学习策略 2.2 梯度下降算法 2.2.1 随机梯度下降代码 2.3 感知机的原始形式 2.4 感知机的对偶形 ...
- 手写实现李航《统计学习方法》书中全部算法
点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 来源:专知 [导读]Dod-o的手写实现李航<统计学习方法>书中全部算法, ...
最新文章
- 从入门到精通的Java进阶学习笔记整理,不愧是大佬
- 会签是什么意思_后宫为什么要争宠·六
- mysql 字段相同条数_用sql语句统计数据库某个字段中相同的数据有多少条?
- html5网页动画总结--jQuery旋转插件jqueryrotate
- 当我们扩张时——技术商业策略圆桌第一弹
- 前端利用CryptoJS进行AES对称加解密(16进制编码)
- 作者:鄂世嘉,男,同济大学博士生,CCF学生会员。
- 广州软件性能测试培训,Loadrunner企业级性能测试课程 广州八神软件性能测试实战教程 炼数性能测试视频...
- HTML中常见问题汇总贴
- Android开源项目推荐之「网络请求哪家强」
- 机器学习第二回总结——多变量线性回归
- C语言 - 判断素数
- 计算机桌面声音图标,win7桌面右下角的小喇叭音量图标不见了怎么办?
- 兄弟连 40 期 临行时刻
- 8G的U盘变成4M解决方法
- 2022年天津专升本报考专业对口限制目录,升本专业课如何备考~
- kotlinx.serialization处理Json解析
- 热门小程序拆盲盒3D特效版开发
- php 导出excel (html),php两种导出excel的方法
- Python matplotlib 中填充颜色