文章目录

  • 一.算法描述
    • (1)W的确定
    • (2)阈值的确定
    • (3)Fisher线性判别的决策规则
  • 二.数据描述
    • 1.iris数据
    • 2.sonar数据
  • 三.鸢尾花数据集例子
  • 四.python代码推导
    • 1.数据生成
    • 2.fisher算法实现
    • 3.判定类别
    • 4.绘图
  • 五.理解和心得
  • 六.参考链接

一.算法描述

Fisher线性判别分析的基本思想:选择一个投影方向(线性变换,线性组合),将高维问题降低到一维问题来解决,同时变换后的一维数据满足每一类内部的样本尽可能聚集在一起,不同类的样本相隔尽可能地远。
Fisher线性判别分析,就是通过给定的训练数据,确定投影方向W和阈值w0, 即确定线性判别函数,然后根据这个线性判别函数,对测试数据进行测试,得到测试数据的类别。
线性判别函数的一般形式可表示成 :

其中

Fisher选择投影方向W的原则,即使原样本向量在该方向上的投影能兼顾类间分布尽可能分开,类内样本投影尽可能密集的要求。 如下为具体步骤:

(1)W的确定

各类样本均值向量mi

样本类内离散度矩阵 和总类内离散度矩阵

样本类间离散度矩阵

在投影后的一维空间中,各类样本均值

样本类内离散度和总类内离散度

样本类间离散度

Fisher准则函数为

(2)阈值的确定

是个常数,称为阈值权,对于两类问题的线性分类器可以采用下属决策规则:
令 则:

如果g(x)>0,则决策x属于w1 ;如果g(x)<0,则x属于w2 ;如果g(x)=0,则可将x任意分到某一类,或拒绝。

(3)Fisher线性判别的决策规则

Fisher准则函数满足两个性质:
1.投影后,各类样本内部尽可能密集,即总类内离散度越小越好。
2.投影后,各类样本尽可能离得远,即样本类间离散度越大越好。
根据这个性质确定准则函数,根据使准则函数取得最大值,可求出

这就是Fisher判别准则下的最优投影方向。
最后得到决策规则



对于某一个未知类别的样本向量x,如果y=WT·x>y0,则x∈w1;否则x∈w2。

二.数据描述

1.iris数据

IRIS数据集以鸢尾花的特征作为数据来源,数据集包含150个数据集,有4维,分为3 类,每类50个数据,每个数据包含4个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。

2.sonar数据

Sonar数据集包含208个数据集,有60维,分为2类,第一类为98个数据,第二类为110个数据,每个数据包含60个属性,是在数据挖掘、数据分类中非常常用的测试集、训练集。

“群内离散度”要求的是距离越远越好;而“群间离散度”的距离越近越好

三.鸢尾花数据集例子

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
path=r'Iris.csv'
df = pd.read_csv(path, header=0)
Iris1=df.values[0:50,0:4]
Iris2=df.values[50:100,0:4]
Iris3=df.values[100:150,0:4]
m1=np.mean(Iris1,axis=0)
m2=np.mean(Iris2,axis=0)
m3=np.mean(Iris3,axis=0)
s1=np.zeros((4,4))
s2=np.zeros((4,4))
s3=np.zeros((4,4))
for i in range(0,30,1):a=Iris1[i,:]-m1a=np.array([a])b=a.Ts1=s1+np.dot(b,a)
for i in range(0,30,1):c=Iris2[i,:]-m2c=np.array([c])d=c.Ts2=s2+np.dot(d,c) #s2=s2+np.dot((Iris2[i,:]-m2).T,(Iris2[i,:]-m2))
for i in range(0,30,1):a=Iris3[i,:]-m3a=np.array([a])b=a.Ts3=s3+np.dot(b,a)
sw12=s1+s2
sw13=s1+s3
sw23=s2+s3
#投影方向
a=np.array([m1-m2])
sw12=np.array(sw12,dtype='float')
sw13=np.array(sw13,dtype='float')
sw23=np.array(sw23,dtype='float')
#判别函数以及T
#需要先将m1-m2转化成矩阵才能进行求其转置矩阵
a=m1-m2
a=np.array([a])
a=a.T
b=m1-m3
b=np.array([b])
b=b.T
c=m2-m3
c=np.array([c])
c=c.T
w12=(np.dot(np.linalg.inv(sw12),a)).T
w13=(np.dot(np.linalg.inv(sw13),b)).T
w23=(np.dot(np.linalg.inv(sw23),c)).T
#print(m1+m2) #1x4维度  invsw12 4x4维度  m1-m2 4x1维度
T12=-0.5*(np.dot(np.dot((m1+m2),np.linalg.inv(sw12)),a))
T13=-0.5*(np.dot(np.dot((m1+m3),np.linalg.inv(sw13)),b))
T23=-0.5*(np.dot(np.dot((m2+m3),np.linalg.inv(sw23)),c))
kind1=0
kind2=0
kind3=0
newiris1=[]
newiris2=[]
newiris3=[]
for i in range(30,49):x=Iris1[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)kind1=kind1+1elif g12<0 and g23>0:newiris2.extend(x)elif g13<0 and g23<0 :newiris3.extend(x)
#print(newiris1)
for i in range(30,49):x=Iris2[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)elif g12<0 and g23>0:newiris2.extend(x)kind2=kind2+1elif g13<0 and g23<0 :newiris3.extend(x)
for i in range(30,50):x=Iris3[i,:]x=np.array([x])g12=np.dot(w12,x.T)+T12g13=np.dot(w13,x.T)+T13g23=np.dot(w23,x.T)+T23if g12>0 and g13>0:newiris1.extend(x)elif g12<0 and g23>0:     newiris2.extend(x)elif g13<0 and g23<0 :newiris3.extend(x)kind3=kind3+1
correct=(kind1+kind2+kind3)/60
print("样本类内离散度矩阵S1:",s1,'\n')
print("样本类内离散度矩阵S2:",s2,'\n')
print("样本类内离散度矩阵S3:",s3,'\n')
print('-----------------------------------------------------------------------------------------------')
print("总体类内离散度矩阵Sw12:",sw12,'\n')
print("总体类内离散度矩阵Sw13:",sw13,'\n')
print("总体类内离散度矩阵Sw23:",sw23,'\n')
print('-----------------------------------------------------------------------------------------------')
print('判断出来的综合正确率:',correct*100,'%')

四.python代码推导

1.数据生成

scikit-learn的接口来生成数据:

from sklearn.datasets import make_multilabel_classification
import numpy as npx, y = make_multilabel_classification(n_samples=20, n_features=2,n_labels=1, n_classes=1,random_state=2)  # 设置随机数种子,保证每次产生相同的数据。# 根据类别分个类
index1 = np.array([index for (index, value) in enumerate(y) if value == 0])  # 获取类别1的indexs
index2 = np.array([index for (index, value) in enumerate(y) if value == 1])  # 获取类别2的indexsc_1 = x[index1]   # 类别1的所有数据(x1, x2) in X_1
c_2 = x[index2]  # 类别2的所有数据(x1, x2) in X_2

2.fisher算法实现

def cal_cov_and_avg(samples):u1 = np.mean(samples, axis=0)cov_m = np.zeros((samples.shape[1], samples.shape[1]))for s in samples:t = s - u1cov_m += t * t.reshape(2, 1)return cov_m, u1
def fisher(c_1, c_2):cov_1, u1 = cal_cov_and_avg(c_1)cov_2, u2 = cal_cov_and_avg(c_2)s_w = cov_1 + cov_2u, s, v = np.linalg.svd(s_w)  # 奇异值分解s_w_inv = np.dot(np.dot(v.T, np.linalg.inv(np.diag(s))), u.T)return np.dot(s_w_inv, u1 - u2)

3.判定类别

def judge(sample, w, c_1, c_2):u1 = np.mean(c_1, axis=0)u2 = np.mean(c_2, axis=0)center_1 = np.dot(w.T, u1)center_2 = np.dot(w.T, u2)pos = np.dot(w.T, sample)return abs(pos - center_1) < abs(pos - center_2)
w = fisher(c_1, c_2)  # 调用函数,得到参数w
out = judge(c_1[1], w, c_1, c_2)   # 判断所属的类别
print(out)

4.绘图

在jupyter下面绘制需要添加以下代码:

%matplotlib inline

这样才能在当前的jupyter下显示出图片

import matplotlib.pyplot as plt
%matplotlib inline
plt.scatter(c_1[:, 0], c_1[:, 1], c='#99CC99')
plt.scatter(c_2[:, 0], c_2[:, 1], c='#FFCC00')
line_x = np.arange(min(np.min(c_1[:, 0]), np.min(c_2[:, 0])),max(np.max(c_1[:, 0]), np.max(c_2[:, 0])),step=1)line_y = - (w[0] * line_x) / w[1]
plt.plot(line_x, line_y)
plt.show()

运行结果:

五.理解和心得

Fisher判别法是判别分析的方法之一。Fisher判别法是一种投影方法,把高维空间的点向低维空间投影。在原来的坐标系下,可能很难把样品分开,而投影后可能区别明显。一般说,可以先投影到一维空间(直线)上,如果效果不理想,在投影到另一条直线上(从而构成二维空间),依此类推。每个投影可以建立一个判别函数。

六.参考链接

没有显示图片问题解决方案
fisher判别分析原理+python实现
波波
个人博客链接

python代码完成Fisher判别相关推荐

  1. Fisher判别的推导概念和过程+python代码实现(三分类)

    python代码完成Fisher判别的推导 一.Fisher算法的主要思想 二.Fisher数学算法步骤 ①计算各类样本均值向量 m i m_i mi​, m i m_i mi​是各个类的均值, N ...

  2. 机器学习--python代码实现基于Fisher的线性判别(鸢尾花数据集的分类)

    一.线性分类–判断该函数属于哪一类 先上例题,然后我会通过两种方法来判断该函数属于哪一类 1.图解法 定义 对于多类问题:模式有 ω1 ,ω2 , - , ωm 个类别,可分三种情况: 第一种情况:每 ...

  3. fisher判别_经典模式识别:Fisher线性判别

    本文将介绍Fisher线性判别的原理和具体实践,阅读时间约8分钟,关注公众号可在后台领取数据集资源哦^-^ Fisher线性判别 1.背景介绍 生活中我们往往会遇到具有高维特性的数据,如个人信息,天气 ...

  4. 【分类算法】Logistic算法原理、标准评分卡开发流程、python代码案例

    [博客地址]:https://blog.csdn.net/sunyaowu315 [博客大纲地址]:https://blog.csdn.net/sunyaowu315/article/details/ ...

  5. knn算法python代码_K-最近邻分类算法(KNN)及python实现

    一.引入 问题:确定绿色圆是属于红色三角形.还是蓝色正方形? KNN的思想: 从上图中我们可以看到,图中的数据集是良好的数据,即都打好了label,一类是蓝色的正方形,一类是红色的三角形,那个绿色的圆 ...

  6. 基于深度学习的人脸识别与管理系统(UI界面增强版,Python代码)

    摘要:人脸检测与识别是机器视觉领域最热门的研究方向之一,本文详细介绍博主自主设计的一款基于深度学习的人脸识别与管理系统.博文给出人脸识别实现原理的同时,给出Python的人脸识别实现代码以及PyQt设 ...

  7. 李航统计学习方法----感知机章节学习笔记以及python代码

    目录 1 感知机模型 2 感知机学习策略 2.1 数据集的线性可分性 2.2 感知机学习策略 3 感知机学习算法 3.1 感知机学习算法的原始形式 3.2 感知机算法的对偶形式 4 感知机算法pyth ...

  8. 数据挖掘学习——SOM网络聚类算法+python代码实现

    目录 1.SOM简述 2.SOM训练过程 (1)初始化 (2)采样(抽取样本点) (3)竞争 (4)合作和适应(更新权重值) (5)重复 3.python 代码实现 (1)初始化 (2)计算样本点和权 ...

  9. 2022年数模国赛C题(岭回归、区间预测、矩阵热力图、Fisher判别分类模型)——总结心得(附最后一次数模经历,Matlab\SPSS\Lingo的理解综合)

    文章目录 一.国赛 二.国赛代码展示 1.1 问题一 1.2 问题二 1.3 问题三 1.4 问题四 三.对于软件的理解 3.1 Matlab 3.1.1 表格的读取 3.1.2 元胞数组的相关函数的 ...

最新文章

  1. C++读写ini配置文件GetPrivateProfileString()WritePrivateProfileString()
  2. Selenium高亮页面对象
  3. 数据中心架构有哪些组件?
  4. centos6.5 安装docker方法
  5. python mysql l链式查询_使用python flask sqlacalchemy orm在PostgreSQL中联接查询
  6. 板邓:wordpress自定义登录页面实现用户登录
  7. 难以置信的美丽,世界的数学结构
  8. css自动换行加前置_StudyNode -- CSS
  9. js 调用webservice接口
  10. php空间搭建tcshare,新秀网 - 宝塔面板搭建天翼云盘目录列表TCShare
  11. NOIP1998车站
  12. TIOBE 12 月编程语言排行榜:Python 夺回前三,Go 跌出前十
  13. 关于调用ArcGIS中GP工具.Erase、SymDiff
  14. vue 获取安卓原生方法_H5-vue与原生Android、ios交互获取相册图片
  15. 位图转矢量图工具,快和模糊图片说白白
  16. Stereo Matching文献笔记之(九):经典算法Semi-Global Matching(SGM)之神奇的HMI代价计算~
  17. Linux系统之磁盘管理
  18. 数据分析师需要考试或考证吗?
  19. 【读书笔记->统计学】04-02 利用概率理论预测和决策-条件概率、概率树、全概率公式、贝叶斯定理、相关与独立概念简介
  20. 计算机视觉领域摄像头布置,几种深度摄像头简介 | 增强视觉 | 计算机视觉 增强现实...

热门文章

  1. 实现对mysql增删改查_Java语言实现对MySql数据库中数据的增删改查操作的代码
  2. 定制婚礼小程序开发功能
  3. python有什么颜色_Python中常见颜色记录
  4. sencha touch总结
  5. android获取uid,Android获得UID的办法
  6. openstack虚拟机的热迁移和疏散
  7. FSR薄膜压力传感器使用教程
  8. 纽约大学理工学院:MULTIMEDIA SIGNAL COMPRESSION: SPEECH AND
  9. html5 模仿语音聊天气泡,HTML5实现对话气泡动画方法
  10. 图形学-着色(Blinn-Phong模型)