1、数据生成

利用随机数生成30个数据,包括坐标点和标签,并将生成的随机数据进行划分来确定标签

def create_random_data(w0,w1):samples = []lables = []for i in range(30):x0 = round(random.uniform(-1,1),2)x1 = round(random.uniform(-1,1),2)samples.append([x0,x1])if w0*x0 + w1*x1 > 3:lables.append(int(1))else:lables.append(int(-1))return np.array(samples),lables

2、PLA算法的python实现

# 新建一个感知机类
# shape()函数表示矩阵空间的维数
class Perceptron:def __init__(self,samples,lables):self.X = samplesself.Y = lablesself.W = np.zeros(2)self.a = 1 #learing rateself.b = 0self.train_count = 0self.max_train_count = 2000self.number_samples = self.X.shape[0]def sign(self,w,b,x):y = np.dot(x,w)+breturn int(y)def update(self,lable,data):tmp = self.a*lable*datatmp = tmp.reshape(self.W.shape)# update W and bself.W = tmp + self.Wself.b = self.b + self.a*lable# 训练def train(self):find_error_point = Falsetrain_count = 0while not find_error_point:count = 0train_count += 1if train_count > self.max_train_count:breakfor i in range(0,self.number_samples):temp = self.sign(self.W,self.b,self.X[i,:])if temp * self.Y[i] <=0 :print('误分类点为:',self.X[i,:],'此时的w和b为:',self.W,self.b)count += 1self.update(self.Y[i],self.X[i,:])if count == 0 and i == self.number_samples - 1:self.train_count = train_countprint("训练的次数为:",train_count)print('最终训练得到的w和b为:',self.W,self.b)find_error_point = Truereturn self.W,self.b# display

通过update函数实现更新权重W,train函数控制训练进程,找到一个错误点后,就利用错误点对权重W进行更新,在达到最大训练次数后退出训练。

3、Pocket算法的python实现

class perceptron_pocket():def __init__(self,sample,labels):self.X = sampleself.Y = labelsself.W = np.empty(2)    #init Wself.b = 0self.a = 1  #learning rateself.best_W = np.empty(2)self.best_b = 0self.train_count = 0self.max_count = 2000self.number_samples = self.X.shape[0]self.number_feature = self.X.shape[1]def sign(self,w,b,x):y = np.dot(x,w)+breturn int(y)def classify(self,w1,b1):mistake = []for i in range(self.number_samples):tempY = self.sign(w1,b1,self.X[i,:])if tempY * self.Y[i] <= 0:mistake.append(i)return mistakedef update(self,lable,data):tmp = self.a*lable*datatmp = tmp.reshape(self.W.shape)# update W and btemp_W = tmp + self.Wtemp_b = self.b + self.a*lableif len(self.classify(self.best_W,self.best_b))>=len(self.classify(temp_W,temp_b)):self.best_W = temp_Wself.best_b = temp_bself.W = temp_Wself.b = temp_bdef train(self):train_count = 0find_error_point = Falsewhile not find_error_point:mistake = self.classify(self.W,self.b)if len(mistake) == 0:self.train_count = train_countprint("训练的次数为:",train_count)print('最终训练得到的w和b为:',self.best_W,self.best_b)break# find a random mistaken = mistake[random.randint(0,len(mistake)-1)]# try to use (x,y) update Wself.update(self.Y[n],self.X[n,:])train_count += 1print('第',train_count,'次迭代误分类点为:', self.X[n, :], '此时的w和b为:', self.W, self.b)if train_count == self.max_count:print("训练的次数为:",train_count)print('最终训练得到的w和b为:',self.best_W,self.best_b)find_error_point = Truereturn self.best_W,self.best_b

classify函数判断该权重值W下分类的错误样本数量,update函数在样本分类错误数量更少的时候更新最优权重best_Wt+1,train函数存储所有的错误分类点,然后随机选取一个错误点更新当前权重Wt,然后进入update函数跟最优权重best_Wt比较,达到最大迭代次数后退出。

4、结果

左边是PLA算法分类结果图像,右边是Pocket算法分类结果图像

PLA算法最终训练得到的w为[ 12.9,-11.7],b为-3

Pocket算法最终训练得到的w为[ 349.11,-255.35],b为-73

还可以将数据维度增加到更大,例如:300

或者改变学习率(learning rate),例如:a = 0.1

PLA train_count =  545
Pocket train_count =  17354

可以看到当降低学习率的时候训练时间会变得更长,并且PLA算法的速度比Pocket算法更快。

完整代码:

# -*- encoding:utf-8 -*-import imp
from cProfile import label
from dis import dis
from turtle import color
from xmlrpc.client import TRANSPORT_ERROR
import numpy as np
import matplotlib.pyplot as plt
from numpy import *
import random#1、创建数据集
def createdata():samples0 = np.array([[0.10 ,-0.10],[0.30 , 0.60],[0.50 ,-0.20],[0.60 ,-0.25],[-0.10,-0.25],[-0.42,-0.30],[-0.50,-0.15],[-0.55,-0.12],[-0.70,-0.28],[-0.51, 0.22],[-0.48, 0.48],[-0.52, 0.47],[ 0.15, 0.63],[ 0.09, 0.81],[-0.68, 0.58]])labels0 = [1,1,1,1,1,1,-1,-1,-1,-1,-1,-1,-1,-1,-1]samples1 = np.array([[ 0.10,-0.10],[ 0.00, 0.75],[ 0.50,-0.20],[ 0.60,-0.25],[-0.10,-0.25],[-0.55, 0.30],[-0.50,-0.15],[-0.55,-0.12],[ 0.53,-0.28],[-0.51, 0.22],[-0.48, 0.48],[-0.52, 0.47],[ 0.15, 0.63],[ 0.09, 0.81],[-0.68, 0.58]])labels1 = [1,1,1,1,1,1,-1,-1,-1,-1,-1,-1,-1,-1,-1]return samples0,labels0def create_random_data(w0,w1):samples = []lables = []for i in range(30):x0 = round(random.uniform(-1,1),2)x1 = round(random.uniform(-1,1),2)samples.append([x0,x1])if w0*x0 + w1*x1 > 3:lables.append(int(1))else:lables.append(int(-1))return np.array(samples),lables# 新建一个感知机类
# shape()函数表示矩阵空间的维数
class Perceptron:def __init__(self,samples,lables):self.X = samplesself.Y = lablesself.W = np.zeros(2)self.a = 0.1 #learing rateself.b = 0self.train_count = 0self.max_train_count = 2000self.number_samples = self.X.shape[0]def sign(self,w,b,x):y = np.dot(x,w)+breturn int(y)def update(self,lable,data):tmp = self.a*lable*datatmp = tmp.reshape(self.W.shape)# update W and bself.W = tmp + self.Wself.b = self.b + self.a*lable# 训练def train(self):find_error_point = Falsetrain_count = 0while not find_error_point:count = 0train_count += 1if train_count > self.max_train_count:self.train_count = self.max_train_countbreakfor i in range(0,self.number_samples):temp = self.sign(self.W,self.b,self.X[i,:])if temp * self.Y[i] <=0 :print('误分类点为:',self.X[i,:],'此时的w和b为:',self.W,self.b)count += 1self.update(self.Y[i],self.X[i,:])if count == 0 and i == self.number_samples - 1:self.train_count = train_countprint("训练的次数为:",train_count)print('最终训练得到的w和b为:',self.W,self.b)find_error_point = Truereturn self.W,self.b# displayclass perceptron_pocket():def __init__(self,sample,labels):self.X = sampleself.Y = labelsself.W = np.empty(2)    #init Wself.b = 0self.a = 0.1  #learning rateself.best_W = np.empty(2)self.best_b = 0self.train_count = 0self.max_count = 20000self.number_samples = self.X.shape[0]self.number_feature = self.X.shape[1]def sign(self,w,b,x):y = np.dot(x,w)+breturn int(y)def classify(self,w1,b1):mistake = []for i in range(self.number_samples):tempY = self.sign(w1,b1,self.X[i,:])if tempY * self.Y[i] <= 0:mistake.append(i)return mistakedef update(self,lable,data):tmp = self.a*lable*datatmp = tmp.reshape(self.W.shape)# update W and btemp_W = tmp + self.Wtemp_b = self.b + self.a*lableif len(self.classify(self.best_W,self.best_b))>=len(self.classify(temp_W,temp_b)):self.best_W = temp_Wself.best_b = temp_bself.W = temp_Wself.b = temp_bdef train(self):train_count = 0find_error_point = Falsewhile not find_error_point:mistake = self.classify(self.W,self.b)if len(mistake) == 0:self.train_count = train_countprint("训练的次数为:",train_count)print('最终训练得到的w和b为:',self.best_W,self.best_b)break# find a random mistaken = mistake[random.randint(0,len(mistake)-1)]# try to use (x,y) update Wself.update(self.Y[n],self.X[n,:])train_count += 1print('第',train_count,'次迭代误分类点为:', self.X[n, :], '此时的w和b为:', self.W, self.b)if train_count == self.max_count:self.train_count = self.max_countprint("训练的次数为:",train_count)print('最终训练得到的w和b为:',self.best_W,self.best_b)find_error_point = Truereturn self.best_W,self.best_bdef display(dataset,W,b):#结果图len_approve = 0len_deny = 0for i in range(len(dataset)):if 15.1*dataset[i,:][0] - 12.67*dataset[i,:][1] > 4:len_approve += 1if len_approve == 1:plt.scatter(dataset[i,:][0],dataset[i,:][1],color = 'b',marker='o', label='Positive')else:plt.scatter(dataset[i,:][0],dataset[i,:][1],color = 'b',marker='o')else:len_deny +=1if len_deny == 1:plt.scatter(dataset[i,:][0],dataset[i,:][1],color = 'r',marker='x', label='Negative')else:plt.scatter(dataset[i,:][0],dataset[i,:][1],color = 'r',marker='x')# print("approve_points = ",approve_points)# print("deny_points = ",deny_points)# plt.scatter(approve_points[0], approve_points[1], color='blue', marker='o', label='Positive')# plt.scatter(deny_points[0], deny_points[1], color='red', marker='x', label='Negative')plt.xlabel('x')plt.ylabel('y')plt.legend(loc='upper left')plt.title('Scatter')plt.plot([-1,1],[(W[0]-b)/W[1],-1*(W[0]+b)/W[1]],'g')#画直线plt.show()passif __name__ == "__main__":print("start work!")# samples,lables = createdata()samples,lables = create_random_data(13,-11.3)print("samples = {},lables = {}".format(samples,lables))my_perceptron = Perceptron(samples=samples,lables=lables)my_perceptron.train()display(samples,my_perceptron.W,my_perceptron.b)my_pocket_perceptron = perceptron_pocket(sample=samples,labels=lables)my_pocket_perceptron.train()display(samples,my_pocket_perceptron.best_W,my_pocket_perceptron.best_b)print("PLA train_count = ",my_perceptron.train_count)print("Pocket train_count = ",my_pocket_perceptron.train_count)print("end work!")

参考:

https://www.jb51.net/article/131047.htm

用python实现PLA算法和Pocket算法相关推荐

  1. python使用Canny算法和HoughCiecle算法实现圆的检测与定位

    目录 一.实现原理 步骤1:使用Canny 算法提取图像边缘 高斯滤波 计算梯度 非极大值抑制 步骤2:在边缘图上利用Hough变换计算圆心与半径 二.具体代码 代码1:直接调用opencv库 代码2 ...

  2. 《OpenCv视觉之眼》Python图像处理十四 :Opencv图像轮廓提取之Scharr算法和Canny算法

    本专栏主要介绍如果通过OpenCv-Python进行图像处理,通过原理理解OpenCv-Python的函数处理原型,在具体情况中,针对不同的图像进行不同等级的.不同方法的处理,以达到对图像进行去噪.锐 ...

  3. 用Spark学习FP Tree算法和PrefixSpan算法

    在FP Tree算法原理总结和PrefixSpan算法原理总结中,我们对FP Tree和PrefixSpan这两种关联算法的原理做了总结,这里就从实践的角度介绍如何使用这两个算法.由于scikit-l ...

  4. 使用Apriori算法和FP-growth算法进行关联分析

    目录 1. 关联分析 2. Apriori原理 3. 使用Apriori算法来发现频繁集 4. 使用FP-growth算法来高效发现频繁项集 5. 示例:从新闻网站点击流中挖掘新闻报道 扩展阅读 系列 ...

  5. 数据结构与算法之美笔记——基础篇(下):图、字符串匹配算法(BF 算法和 RK 算法、BM 算法和 KMP 算法 、Trie 树和 AC 自动机)

    图 如何存储微博.微信等社交网络中的好友关系?图.实际上,涉及图的算法有很多,也非常复杂,比如图的搜索.最短路径.最小生成树.二分图等等.我们今天聚焦在图存储这一方面,后面会分好几节来依次讲解图相关的 ...

  6. dijkstra算法和A*算法

    转自: https://www.cnblogs.com/21207-iHome/p/6048969.html#undefined Dijkstra算法 迪杰斯特拉(Dijkstra)算法是典型的最短路 ...

  7. 关联规则挖掘算法: Aprior算法和Fpgrowth算法

      关联规则挖掘的目的是挖掘不同物品(item)之前的相关性,啤酒和尿布的故事就是一个典型的成功例子.关联规则挖掘思想简单高效,在广告推荐领域也有较多的应用,主要用于推荐模型落地前的流量探索以及构建规 ...

  8. BF算法和KMP算法

    给定两个字符串S和T,在主串S中查找子串T的过程称为串匹配(string matching,也称模式匹配),T称为模式.这里将介绍处理串匹配问题的两种算法,BF算法和KMP算法. BF算法 (暴力匹配 ...

  9. Algorithm:C++语言实现之字符串相关算法(字符串的循环左移、字符串的全排列、带有同个字符的全排列、串匹配问题的BF算法和KMP算法)

    Algorithm:C++语言实现之字符串相关算法(字符串的循环左移.字符串的全排列.带有同个字符的全排列.串匹配问题的BF算法和KMP算法) 目录 一.字符串的算法 1.字符串的循环左移 2.字符串 ...

最新文章

  1. 这本《Python+TensorFlow机器学习实战》给你送到家!
  2. php 显示数据库操作错误,php操作mysql数据库编码错误
  3. Java里的 for (;;) 与 while (true),哪个更快?
  4. 大数据预测实战-随机森林预测实战(三)-数据量对结果影响分析
  5. 云服务器网站301重定向跳转有什么作用?
  6. Gradle之全局配置
  7. OSPF的多域配置-要点总结
  8. [境内法规]中国人民银行关于分支行反洗钱工作的指导意见—银发[2005]56号
  9. 使用HTMLcss创建二级导航栏
  10. 武装突袭3多人服务器文件地图,武装突袭3地图文件夹 | 手游网游页游攻略大全...
  11. 鸡兔同笼——算法详解
  12. 时光里,我一个人的碎碎念。
  13. 使用yocs_velocity_smoother对机器人速度进行限制
  14. DANet Daul Attention位置和通道注意力(PAM&CAM)keras实现
  15. UWB高精度室内定位系统
  16. 惊呆了!我用 Python 可视化分析和预测了 2022 年 FIFA世界杯
  17. MJiOS底层笔记--OC对象本质
  18. 利用 Itchat 实现微信群发和关键词自动回复
  19. html居中小圆点点怎么打出来,目录的点怎么打 WORD目录里的点点怎么打
  20. su组件在什么窗口_草图大师Sketchup全窗口显示快捷键是什么呢?

热门文章

  1. Autodesk推出最新SketchBook Pro 7
  2. 鸿蒙系统win7,鸿蒙系统怎么样
  3. LRN 局部归一化处理
  4. python3.6.8安装LAC报错
  5. docker三种网络模式
  6. python中def的用法详解_Python3中def的用法
  7. windows系统“删库跑路“脚本bat
  8. 软件测试就只能挑Bug?绝对远远不止
  9. 忆——2017 International Genetically Engineered Machine Competition
  10. 对于iMazing电脑版需要输入许可证激活编号