《统计学习方法》—— 感知机原始形式、感知机对偶形式的python3代码实现(三)
前言
在前两篇博客里面,我们分别介绍了感知机的原始形式和感知机的对偶形式。在这篇博客里面,我们将用python3对上述两种感知机算法进行实现。
注意:本文参考了@akirameiao的博客内容。数据放在本文最后,直接复制进文本,保存为.txt格式,各位大佬自取。
- 导入第三方库。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
- 导入数据,将数据保存为 data。
# 载入数据
def load_data(file):# 指定数据类型data_types = {'data1': np.float32, 'data2': np.float32, 'data3': np.float32, 'label': np.int16}# 数据读取,注意,这里的sep值,一定要为三个空格' 'data = pd.read_csv(file, sep=' ', header=None, names=['data1', 'data2', 'label'], dtype=data_types)# w*x+b = (w, b)*(x, 1),所以我们将特征向量x增加一维,为(x, 1)data.insert(2, 'data3', 1)return data# 这里的文件路径替换成各位大佬自己文件所在的绝对路径
data = load_data('../input/ganzhiji.txt')
- 数据的可视化。
data_plot = data.groupby('label')
for name, group in data_plot:plt.scatter(data=group, x='data1', y='data2', label=name)
plt.legend()
plt.show()
8. 感知机原始形式
现将算法复述如下:
- 输入:数据集 T={(x1,y1),(x2,y2),...,(xN,yN)}T=\{(x_1, y_1), (x_2, y_2), ..., (x_N, y_N)\}T={(x1,y1),(x2,y2),...,(xN,yN)};学习步长 η\etaη
- 输出:www 和 bbb;感知机模型 f(x)=sign(w⋅x+b)f(x)=sign(w\cdot x+b)f(x)=sign(w⋅x+b)
(1) 给定初值 (w0,b0)=(0,0)(w_0, b_0)=(0, 0)(w0,b0)=(0,0)
(2) 遍历数据集 TTT,找到第一个误分类点 (xi,yi)(x_i, y_i)(xi,yi),满足 yi(w⋅xi+b)<0y_i(w\cdot x_i+b)<0yi(w⋅xi+b)<0
(3) 更新 www 和 bbb,w←w+ηyixiw\leftarrow w+\eta y_ix_iw←w+ηyixi b←b+ηyib\leftarrow b + \eta y_ib←b+ηyi
(4) 回到步骤 (2),如果找不到误分类点,则终止算法
根据上述算法,可以写出
# 训练感知机模型
def perception(data, w_b, eta=1, wrongPoints_num=[]):wrong_nums = 1while True:if not wrong_nums:break# 当前w和b下,计算yi(w*xi+b)=yi(w, b) * (xi, 1)的值# 首先计算(w, b) * (xi, 1)data['wrong_point'] = data[['data1', 'data2', 'data3']].dot(w_b)# 再依次乘以yi,并保存进datadata['wrong_point'] = data['label'].mul(data['wrong_point'])# 所有的误分类点temp = data[data['wrong_point']<=0]# 计算yi(w*xi+b)<=0,也就是误分类点的数量wrong_nums = temp['wrong_point'].count()wrongPoints_num.append(wrong_nums)# 找出第一个误分类点,并更新w, bif wrong_nums:# 计算 eta*yi*xichange = eta * temp['label'].iloc[0] * temp.iloc[0, 0:3].valuesw_b = w_b + change#print('更新后的w和b为', w_b[:2], w_b[2])#print(w_b)return w_b[:2], w_b[2], wrongPoints_num
给定初始值,并运行程序
# 初始值
w_b = np.array([0, 0, 0])
eta = 0.1
wrongPoints_num = [] # 记录每次迭代的误分类点个数# 运行程序
w, b, wrongPoints_num = perception(data, w_b, eta, wrongPoints_num)
print('最终权重w为', w)
print('最终偏置b为', b)
可以作图看下结果
# 可视化
data_plot = data.groupby('label')
for name, group in data_plot:plt.scatter(data=group, x='data1', y='data2', label=name)# 直线
x = np.arange(4, 5.5, 0.1)
y = - w[0] / w[1] * x - b / w[1]
plt.plot(x, y)plt.legend()
plt.show()
算法迭代过程中,误分类点数量的变化曲线
plt.plot(np.arange(len(wrongPoints_num)), wrongPoints_num)
plt.xlabel('num of recursions')
plt.ylabel('num of wrong points')
plt.show()
9. 感知机对偶形式
现将算法复述如下:
- 输入:数据集 T={(x1,y1),(x2,y2),...,(xN,yN)}T=\{(x_1, y_1), (x_2, y_2), ..., (x_N, y_N)\}T={(x1,y1),(x2,y2),...,(xN,yN)};学习步长 η\etaη
- 输出:(n1,n2,...,nN)(n_1, n_2, ..., n_N)(n1,n2,...,nN);感知机模型 f(x)=sign(w⋅x+b)f(x)=sign(w\cdot x+b)f(x)=sign(w⋅x+b),其中,w=∑i=1Nniηyixiw=\sum_{i=1}^Nn_i\eta y_ix_iw=i=1∑Nniηyixi b=∑i=1Nniηyib=\sum_{i=1}^Nn_i\eta y_ib=i=1∑Nniηyi
(1) 给定初始值 (n1,n2,...,nN)=(0,0,...,0)(n_1, n_2, ..., n_N)=(0, 0, ..., 0)(n1,n2,...,nN)=(0,0,...,0)
(2) 遍历数据集 TTT,找出第一个误分类点 (xi,yi)(x_i, y_i)(xi,yi),满足
yi(∑j=1Nnjηyjxj⋅xi+∑j=1Nnjηyj)=yi∑j=1Nnjηyj(xj⋅xi+1)<0\begin{array}{lll} &&y_i(\sum_{j=1}^Nn_j\eta y_jx_j\cdot x_i+\sum_{j=1}^Nn_j\eta y_j)\\ &=& y_i\sum_{j=1}^Nn_j\eta y_j(x_j\cdot x_i+1)\\ &<&0 \end{array} =<yi(∑j=1Nnjηyjxj⋅xi+∑j=1Nnjηyj)yi∑j=1Nnjηyj(xj⋅xi+1)0
(3) 更新 nin_ini,ni←ni+1n_i\leftarrow n_i+1ni←ni+1
(4) 返回步骤(2),如果没有误分类点,则终止算法
由于在判断误分类点的时候,我们仅需要 xj⋅xix_j\cdot x_ixj⋅xi 的值,所以,我们可以提前计算内积,也就是提前算出Gram矩阵
G=[xi⋅xj]N×N\mathbf{G}=\left[x_i\cdot x_j \right]_{N\times N}G=[xi⋅xj]N×N
# 计算Gram矩阵
# 实际上,我们需要计算的是 [xi * xj + 1]
# 预处理 Gram矩阵
G = data.loc[:, ['data1', 'data2', 'data3']].values.dot(data.loc[:, ['data1', 'data2', 'data3']].values.T)# 再计算 向量[yj] 与 Gram矩阵的第i行[xi * xj + 1] 按照元素做乘法
G_hat = G * data['label'].values
下面,我们可以写出如下程序
def perception_dual(data, eta, G_hat, n, wrongPoints_num):wrong_num = 1while True:if not wrong_num:break# 遍历数据集,找到误分类点temp = eta * pd.DataFrame(G_hat * n).apply(sum, axis=1)data['wrong_points'] = data['label'].mul(temp)# 所有的误分类点wrong = data[data['wrong_points']<=0]# 误分类点个数wrong_num = wrong['wrong_points'].count()wrongPoints_num.append(wrong_num)# 找出第一个误分类点(xi, yi),更新 n_iif wrong_num:first_index = list(wrong.index)[0]n[first_index] += 1#print('更新第', first_index, '个数据点')#print('该数据点n_i=', n[first_index])return n, wrongPoints_num
给初值,运行程序
# 给初值
n = np.zeros(len(data))
eta = 1
wrongPoints_num = []# 运行程序
n, wrongPoints_num = perception_dual(data, eta, G_hat, n, wrongPoints_num)
# 根据 n_i,计算w和b
def w_b(data, eta, n):w = eta * data.loc[:, ['data1', 'data2']].mul(data['label'], axis=0).mul(n, axis=0).apply(sum, axis=0)b = eta * sum(data['label'] * n)return w.values, bw, b = w_b(data, eta, n)
作图,看看结果对不对
# 可视化
data_plot = data.groupby('label')
for name, group in data_plot:plt.scatter(data=group, x='data1', y='data2', label=name)# 直线
x = np.arange(4.4, 5.5, 0.1)
y = - w[0] / w[1] * x - b / w[1]
plt.plot(x, y)plt.legend()
plt.show()
再看一下每次迭代后的误分类点情况
plt.plot(np.arange(len(wrongPoints_num)), wrongPoints_num)
plt.xlabel('num of recursions')
plt.ylabel('num of wrong points')
plt.show()
至此,我们将感知机的原始形式、对偶形式的数学推导以及python3实现全部完成。
下一篇博客中,我们将继续介绍 k近邻方法。
数据:100个数据,直接复制保存为.txt文件
3.542485 1.977398 -1
3.018896 2.556416 -1
7.551510 -1.580030 1
2.114999 -0.004466 -1
8.127113 1.274372 1
7.108772 -0.986906 1
8.610639 2.046708 1
2.326297 0.265213 -1
3.634009 1.730537 -1
0.341367 -0.894998 -1
3.125951 0.293251 -1
2.123252 -0.783563 -1
0.887835 -2.797792 -1
7.139979 -2.329896 1
1.696414 -1.212496 -1
8.117032 0.623493 1
8.497162 -0.266649 1
4.658191 3.507396 -1
8.197181 1.545132 1
1.208047 0.213100 -1
1.928486 -0.321870 -1
2.175808 -0.014527 -1
7.886608 0.461755 1
3.223038 -0.552392 -1
3.628502 2.190585 -1
7.407860 -0.121961 1
7.286357 0.251077 1
2.301095 -0.533988 -1
-0.232542 -0.547690 -1
3.457096 -0.082216 -1
3.023938 -0.057392 -1
8.015003 0.885325 1
8.991748 0.923154 1
7.916831 -1.781735 1
7.616862 -0.217958 1
2.450939 0.744967 -1
7.270337 -2.507834 1
1.749721 -0.961902 -1
1.803111 -0.176349 -1
8.804461 3.044301 1
1.231257 -0.568573 -1
2.074915 1.410550 -1
-0.743036 -1.736103 -1
3.536555 3.964960 -1
8.410143 0.025606 1
7.382988 -0.478764 1
6.960661 -0.245353 1
8.234460 0.701868 1
8.168618 -0.903835 1
1.534187 -0.622492 -1
9.229518 2.066088 1
7.886242 0.191813 1
2.893743 -1.643468 -1
1.870457 -1.040420 -1
5.286862 -2.358286 1
6.080573 0.418886 1
2.544314 1.714165 -1
6.016004 -3.753712 1
0.926310 -0.564359 -1
0.870296 -0.109952 -1
2.369345 1.375695 -1
1.363782 -0.254082 -1
7.279460 -0.189572 1
1.896005 0.515080 -1
8.102154 -0.603875 1
2.529893 0.662657 -1
1.963874 -0.365233 -1
8.132048 0.785914 1
8.245938 0.372366 1
6.543888 0.433164 1
-0.236713 -5.766721 -1
8.112593 0.295839 1
9.803425 1.495167 1
1.497407 -0.552916 -1
1.336267 -1.632889 -1
9.205805 -0.586480 1
1.966279 -1.840439 -1
8.398012 1.584918 1
7.239953 -1.764292 1
7.556201 0.241185 1
9.015509 0.345019 1
8.266085 -0.230977 1
8.545620 2.788799 1
9.295969 1.346332 1
2.404234 0.570278 -1
2.037772 0.021919 -1
1.727631 -0.453143 -1
1.979395 -0.050773 -1
8.092288 -1.372433 1
1.667645 0.239204 -1
9.854303 1.365116 1
7.921057 -1.327587 1
8.500757 1.492372 1
1.339746 -0.291183 -1
3.107511 0.758367 -1
2.609525 0.902979 -1
3.263585 1.367898 -1
2.912122 -0.202359 -1
1.731786 0.589096 -1
2.387003 1.573131 -1
《统计学习方法》—— 感知机原始形式、感知机对偶形式的python3代码实现(三)相关推荐
- 《统计学习方法》—— 感知机对偶算法、推导以及python3代码实现(二)
前言 在前一篇博客 <统计学习方法>-- 感知机原理.推导以及python3代码实现(一) 里面,我们介绍了感知机原始形式以及具体推导.在这篇博客里面,我们将继续介绍感知机对偶形式以及py ...
- 统计学习方法笔记第二章-感知机
统计学习方法笔记第二章-感知机 2.1 感知机模型 2.2感知机学习策略 2.2.1数据集的线性可分型 2.2.2感知机学习策略 2.3感知机学习算法 2.3.1感知机算法的原始形式 2.3.2算法的 ...
- 李航 统计学习方法 第2章 感知机
第2章 感知机 介绍感知机模型, 叙述感知机的学习策略, 特别是损失函数; 最后介绍感知机学习算法,包括原始形式和对偶形式, 证明算法的收敛性. 感知机模型 f ( x ) = s i g n ( w ...
- 《统计学习方法》读书笔记——感知机(原理+代码实现)
传送门 <统计学习方法>读书笔记--机器学习常用评价指标 <统计学习方法>读书笔记--感知机(原理+代码实现) <统计学习方法>读书笔记--K近邻法(原理+代码实现 ...
- 统计学习方法第二章作业:感知机模型原始形式与对偶形式代码实现
原始形式实现 import numpy as np import matplotlib.pyplot as pltclass Perceptron_orginal:def __init__(self, ...
- 【李航统计学习方法】感知机模型
目录 一.感知机模型 二.感知机的学习策略 三.感知机学习算法 感知机算法的原始形式 感知机模型的对偶形式 参考文献 本章节根据统计学习方法,分为模型.策略.算法三个方面来介绍感知机模型. 首先介绍感 ...
- 复现经典:《统计学习方法》第 2 章 感知机
本文是李航老师的<统计学习方法>[1]一书的代码复现. 作者:黄海广[2] 备注:代码都可以在github[3]中下载. 我将陆续将代码发布在公众号"机器学习初学者", ...
- 机器学习理论《统计学习方法》学习笔记:第二章 感知机
<统计学习方法>学习笔记:第二章 感知机 2 感知机 2.1 感知机模型 2.2 感知机学习策略 2.2.1 数据的线性可分性 2.2.2 感知机学习策略 2.3 感知机学习算法 2.3. ...
- 统计学习方法 --- 感知机模型原理及c++实现
参考博客 Liam Q博客 和李航的<统计学习方法> 感知机学习旨在求出将训练数据集进行线性划分的分类超平面,为此,导入了基于误分类的损失函数,然后利用梯度下降法对损失函数进行极小化,从而 ...
最新文章
- android 去除启动广告_APP启动页广告去除
- Shell脚本的调试技术
- JDK各版本内容和新特性
- JavaScript Bitwise NOT Operator
- python中xrange和range的区别
- linux 如何在命令行下改系统时间
- Android之解决viewpage加载第3个fragment的时候,第一个fragment又重新构建问题
- REVERSE-PRACTICE-BUUCTF-25
- jenkins 插件目录_10 个 Jenkins 实战经验,助你轻松上手持续集成
- 使用 Python 实现鼠标键盘自动化
- mysql 函数修改无效_MySQL:无效使用组函数
- 控制反转_.NET Core ASP.NET Core Basic 12 控制反转与依赖注入
- 作业三——求左部分中的最大值减去右部分最大值的绝对值,最大是多少...
- pythonturtle魔法阵_python turtle 库绘制简单魔法阵
- mips中的li_MIPS学习笔记(一)
- postman批量调用接口操作步骤
- CALL入门篇一:CALL的本质
- PCA主成分分析法浅理解
- Sort sort =new Sort(Sort.Direction.ASC,“id“)
- 神奇的口袋【北京大学】
热门文章
- xdf文件改word_真正Txt 文本文件和Doc Word文件批量互转工具
- PyTorch 学习笔记(七):PyTorch的十个优化器
- Python之 while循环
- PAT 1012 数字分类 (20 分)(C语言)
- ubuntu 16.04 R 安装,卸载以及Rsudio
- 获取北京时间授时api stm32 esp8266获取北京时间、年月日、星期api GMT格林威时间转换北京时间
- maven的setting文件简单配置
- cookie、session和token
- 《ggplot2:数据分析与图形艺术》,读书笔记
- Linux命令格式及帮助命令详解