Fisher线性判别

  • 1.Fisher线性判别步骤
  • 2.Fisher判别实现代码
  • 3.Fisher分类器

1.Fisher线性判别步骤

Fisher线性判别分析的基本思想:选择一个投影方向(线性变换,线性组合),将高维问题降低到一维问题来解决,同时变换后的一维数据满足每一类内部的样本尽可能聚集在一起,不同类的样本相隔尽可能地远。
Fisher线性判别分析,就是通过给定的训练数据,确定投影方向W和阈值w0, 即确定线性判别函数,然后根据这个线性判别函数,对测试数据进行测试,得到测试数据的类别。
Fisher判别分析是要实现有最大的类间距离,以及最小的类内距离。
线性判别函数的一般形式可表示成
g(X)=WTX+w0g(X)=W^TX+w_{0} g(X)=WTX+w0​

Fisher选择投影方向W的原则,即使原样本向量在该方向上的投影能兼顾类间分布尽可能分开,类内样本投影尽可能密集的要求。
(1)W的确定

各类样本均值向量mi

样本类内离散度矩阵和总类内离散度矩阵Sw


样本类间离散度矩阵Sb

在投影后的一维空间中,各类样本均值

样本类内离散度和总类内离散度

样本类间离散度

Fisher准则函数为max

(2)阈值的确定
是个常数,称为阈值权,对于两类问题的线性分类器可以采用下属决策规则:

如果g(x)>0,则决策x属于W1;如果g(x)<0,则决策x属于W2;如果g(x)=0,则可将x任意分到某一类,或拒绝。

(3)Fisher线性判别的决策规则
①投影后,各类样本内部尽可能密集,即总类内离散度越小越好。
②投影后,各类样本尽可能离得远,即样本类间离散度越大越好。
根据这两个性质,可求出

这就是Fisher判别准则下的最优投影方向。最后得到决策规则,如果



对于某一个未知类别的样本向量x,如果y=WT·x>y0,则x∈w1;否则x∈w2。
(4)“群内离散度”与“群间离散度”
“群内离散度”要求的是距离越远越好;而“群间离散度”的距离越近越好
由上可知:“群内离散度”(样本类内离散矩阵)的计算公式为

因为每一个样本有多维数据,因此需要将每一维数据代入公式计算后最后在求和即可得到样本类内离散矩阵。存在多个样本,重复该计算公式即可算出每一个样本的类内离散矩阵
“群间离散度”(总体类离散度矩阵)的计算公式为

例如鸢尾花数据集,会将其分为三个样本,因此就会得到三个总体类离散度矩阵,三个总体类离散度矩阵根据上述公式计算即可。
例题

2.Fisher判别实现代码

#导入库
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
#构建训练集
path=r'iris.data'
df = pd.read_csv(path, header=0)
Iris1=df.values[0:50,0:4]
Iris2=df.values[50:100,0:4]
Iris3=df.values[100:150,0:4]
#构建样本类内离散度矩阵
m1=np.mean(Iris1,axis=0)
m2=np.mean(Iris2,axis=0)
m3=np.mean(Iris3,axis=0)
s1=np.zeros((4,4))
s2=np.zeros((4,4))
s3=np.zeros((4,4))
for i in range(0,30,1):a=Iris1[i,:]-m1a=np.array([a])b=a.Ts1=s1+np.dot(b,a)
for i in range(0,30,1):c=Iris2[i,:]-m2c=np.array([c])d=c.Ts2=s2+np.dot(d,c)
for i in range(0,30,1):a=Iris3[i,:]-m3a=np.array([a])b=a.Ts3=s3+np.dot(b,a)
sw12=s1+s2
sw13=s1+s3
sw23=s2+s3
#投影方向
a=np.array([m1-m2])
sw12=np.array(sw12,dtype='float')
sw13=np.array(sw13,dtype='float')
sw23=np.array(sw23,dtype='float')
#判别函数以及T
a=m1-m2
a=np.array([a])
a=a.T
b=m1-m3
b=np.array([b])
b=b.T
c=m2-m3
c=np.array([c])
c=c.T
w12=(np.dot(np.linalg.inv(sw12),a)).T
w13=(np.dot(np.linalg.inv(sw13),b)).T
w23=(np.dot(np.linalg.inv(sw23),c)).T
T12=-0.5*(np.dot(np.dot((m1+m2),np.linalg.inv(sw12)),a))
T13=-0.5*(np.dot(np.dot((m1+m3),np.linalg.inv(sw13)),b))
T23=-0.5*(np.dot(np.dot((m2+m3),np.linalg.inv(sw23)),c))
#通过判别函数进行判别,求解正确率
kind1=0
kind2=0
kind3=0
newiris1=[]
newiris2=[]
newiris3=[]
for i in range(30,49):x=Iris1[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)kind1=kind1+1elif g12<0 and g23>0:newiris2.extend(x)elif g13<0 and g23<0 :newiris3.extend(x)
for i in range(30,49):x=Iris2[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)elif g12<0 and g23>0:newiris2.extend(x)kind2=kind2+1elif g13<0 and g23<0 :newiris3.extend(x)
for i in range(30,49):x=Iris3[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)elif g12<0 and g23>0:     newiris2.extend(x)elif g13<0 and g23<0 :newiris3.extend(x)kind3=kind3+1
correct=(kind1+kind2+kind3)/60
print("样本类内离散度矩阵S1:",s1,'\n')
print("样本类内离散度矩阵S2:",s2,'\n')
print("样本类内离散度矩阵S3:",s3,'\n')
print("总体类内离散度矩阵Sw12:",sw12,'\n')
print("总体类内离散度矩阵Sw13:",sw13,'\n')
print("总体类内离散度矩阵Sw23:",sw23,'\n')
print('判断出来的综合正确率:',correct*100,'%')


学习下来我觉得Fisher判别有点繁琐,但是Fisher判别是最基础的线性判别方法。还有其他的线性判别方法,例如贝叶斯、BP神经网络、K-means、决策树等线性判别,这几种方法都较简单些。

3.Fisher分类器

from sklearn import model_selection
from sklearn import datasets
from sklearn import discriminant_analysis
#用莺尾花数据集
def load_data():iris=datasets.load_iris()return model_selection.train_test_split(iris.data,iris.target,test_size=0.25,random_state=0,stratify=iris.target)  #返回为: 一个元组,依次为:训练样本集、测试样本集、训练样本的标记、测试样本的标记
def test_LinearDiscriminantAnalysis(*data):x_train,x_test,y_train,y_test=datalda=discriminant_analysis.LinearDiscriminantAnalysis()lda.fit(x_train,y_train)print('Coefficients:%s, intercept %s'%(lda.coef_,lda.intercept_))#输出权重向量和bprint('Score: %.2f' % lda.score(x_test, y_test))#测试集print('Score: %.2f' % lda.score(x_train, y_train))#训练集
x_train,x_test,y_train,y_test=load_data()
test_LinearDiscriminantAnalysis(x_train,x_test,y_train,y_test)

结果:

在测试集上预测准确率为100%,而在训练集上预测准确率为97%,所以说即使训练过后,误差也还是存在的。
监督降维技术
该数据集是原始的数据集经过Fisher的投影

from sklearn import model_selection
from sklearn import datasets
from sklearn import discriminant_analysis
#用莺尾花数据集
def load_data():iris=datasets.load_iris()return model_selection.train_test_split(iris.data,iris.target,test_size=0.25,random_state=0,stratify=iris.target)  #返回为: 一个元组,依次为:训练样本集、测试样本集、训练样本的标记、测试样本的标记
def plot_LDA(converted_X,y):'''绘制经过 LDA 转换后的数据:param converted_X: 经过 LDA转换后的样本集:param y: 样本集的标记:return:  None'''from mpl_toolkits.mplot3d import Axes3Dimport matplotlib.pyplot as plt fig=plt.figure()ax=Axes3D(fig)colors='rgb'markers='o*s'for target,color,marker in zip([0,1,2],colors,markers):pos=(y==target).ravel()X=converted_X[pos,:]ax.scatter(X[:,0], X[:,1], X[:,2],color=color,marker=marker,label="Label %d"%target)ax.legend(loc="best")fig.suptitle("Iris After LDA")plt.show()
import numpy as np
x_train,x_test,y_train,y_test=load_data()
X=np.vstack((x_train,x_test))#沿着竖直方向将矩阵堆叠起来,把训练与测试的数据放一起来看
Y=np.vstack((y_train.reshape(y_train.size,1),y_test.reshape(y_test.size,1)))#沿着竖直方向将矩阵堆叠起来
lda = discriminant_analysis.LinearDiscriminantAnalysis()
lda.fit(X, Y)
converted_X=np.dot(X,np.transpose(lda.coef_))+lda.intercept_
plot_LDA(converted_X,Y)

结果:

从结果图确实可以看出,Fisher可以实现降维。

线性分类器——Fisher线性判别相关推荐

  1. 【机器学习】Fisher线性判别与线性感知机

    来源 | AI小白入门 作者 | 文杰 编辑 | yuquanle 原文链接 Fisher线性判别与线性感知机 ​ Fisher线性判别和线性感知机都是针对分类任务,尤其是二分类,二者的共同之处在于都 ...

  2. 高效计算基础与线性分类器

    高效计算基础与线性分类器 标签: 深度学习线性分类器 2016-06-29 16:01 131人阅读 评论(0) 收藏 举报 本文章已收录于: 分类: 深度学习笔记(1) 作者同类文章X 版权声明:本 ...

  3. 【火炉炼AI】机器学习008-简单线性分类器解决二分类问题

    [火炉炼AI]机器学习008-简单线性分类器解决二分类问题 (本文所使用的Python库和版本号: Python 3.5, Numpy 1.14, scikit-learn 0.19, matplot ...

  4. 【计算机视觉与深度学习】线性分类器(一)

    目录 从线性分类器开始 线性分类器的定义 线性分类器的决策步骤 线性分类器的矩阵表示 线性分类器的wiT\bm w_i^TwiT​如何理解 线性分类器的决策边界 线性分类器的损失函数 损失函数的定义 ...

  5. 计算机视觉与深度学习第三章:线性分类器

    计算机视觉与深度学习 本文按照北京邮电大学计算机学院鲁鹏老师的计算机视觉与深度学习课程按章节进行整理,需要的同学可借此系统学习该课程详尽知识~ 第三章 线性分类器 计算机视觉与深度学习 本节重点 一. ...

  6. 论文阅读笔记:为什么深度神经网络的训练无论多少次迭代永远有效?可能类内分布已经坍缩为一个点,模型已经崩溃为线性分类器

    论文阅读笔记:Prevalence of neural collapse during the terminalphase of deep learning training,深度学习训练末期普遍的神 ...

  7. cs231n学习记录-理解线性分类器

    目录 前言: 内容简介: 一.NN分类器的缺点 二.什么是线性分类器 三.线性分类器的原理 四.从图像到标签分值的参数化映射 五:解释这个"b" 六.将线性分类器看做模板匹配 七. ...

  8. 基于fisher线性判别法的分类器设计

    0.引言说明 这篇文章实际上是楼主上的模式识别课程的课堂报告,楼主偷懒把东西直接贴出来了.选择fisher判别法的原因主要是想学习一下这个方法,这个方法属于线性判别法,操作起来和lda判别法近乎没啥区 ...

  9. Fisher 线性分类器--转

    原文地址:http://blog.csdn.net/htyang725/article/details/6571550 Fisher 线性分类器由R.A.Fisher在1936年提出,至今都有很大的研 ...

  10. matlab fisher检验,FISHER线性判别MATLAB实现.doc

    FISHER线性判别MATLAB实现 Fisher线性判别上机实验报告 班级: 学号: 姓名: 算法描述 Fisher线性判别分析的基本思想:选择一个投影方向(线性变换,线性组合),将高维问题降低到一 ...

最新文章

  1. 基数排序python实现
  2. 一篇文章搞定GVIM(根据工作经验持续更新)
  3. Sed教程(三):模式缓冲区、模式范围
  4. AssetBundle
  5. (14)FPGA触发器与寄存器区别
  6. Repeater使用:绑定时 结合 前台JS及后台共享方法
  7. python 生成器迭代器
  8. 虹软java接摄像头_java基于虹软sdk实现人脸识别(demo)
  9. 冒泡排序最佳情况的时间复杂度,为什么是O(n)
  10. UZCMS镜像站群旗舰版镜像程序SEO源码
  11. 数据分析越来越火,如何做一个靠谱的职业规划?
  12. linux连接库参数-l,gcc编译时,什么时候需要用-l参数指明连接库?
  13. Python 数组高级索引
  14. as_completed函数用例
  15. wds(无线分布式系统)
  16. 二进制转8421bcd码_码制 || BCD码 || 格雷码 || 奇偶校验码 || 字母数字码 || 数电
  17. python delta_Python 函数
  18. 电子竞技作为一项全新的竞技体育项目,近年来发展迅猛,未来发展趋势
  19. Android平台开发指导(Android Porting Guide)
  20. Excel文件有密码可以取消掉嘛?

热门文章

  1. 计算机控制系统与常规仪表控制系统的主要异同点,计算机控制技术复习资料-20210711112641.doc-原创力文档...
  2. 本人亲测,实用安装Oracle VM VirtualBox教程
  3. Python Selenium IE 上传文件和 处理网页对话框showModalDailog模态对话框
  4. ubuntu20.04离线安装rabbitvcs
  5. 三因子两水平doe_温故而知新 | DOE实验设计学习系列之(三):多因子DOE的魅力 (附视频)...
  6. FFmpeg压缩音频和添加字幕的命令
  7. 阿里代码检查p3c插件使用
  8. HTML5前端开发实战04-儿童摄影
  9. 智慧工厂3D物联网可视化建模管理系统
  10. arm架构安装rxtx_树莓派JAVA开发串口(RXTX)和GPIO(pi4j)