一、线性判别-LDA

        线性分类:指存在一个线性方程可以把待分类数据分开,或者说用一个超平面能将正负样本区分开,表达式为y=wx,这里先说一下超平面,对于二维的情况,可以理解为一条直线,如一次函数。它的分类算法是基于一个线性的预测函数,决策的边界是平的,比如直线和平面。一般的方法有感知器,最小二乘法。

  非线性分类:指不存在一个线性分类方程把数据分开,它的分类界面没有限制,可以是一个曲面,或者是多个超平面的组合。

  LDA在模式识别领域(比如人脸识别,舰艇识别等图形图像识别领域)中有非常广泛的应用,因此我们有必要了解一下它的算法原理。不过在学习LDA之前,我们有必要将其与自然语言处理领域中的LDA区分开,在自然语言处理领域,LDA是隐含狄利克雷分布(Latent DIrichlet Allocation,简称LDA),它是一种处理文档的主题模型,我们本文讨论的是线性判别分析,因此后面所说的LDA均为线性判别分析。

二.线性分类算法-svm向量机

线性判别式分析(Linear Discriminant Analysis, LDA),也叫做Fisher线性判别(Fisher Linear Discriminant ,FLD),是模式识别的经典算法,它是在1996年由Belhumeur引入模式识别和人工智能领域的。性鉴别分析的基本思想是将高维的模式样本投影到最佳鉴别矢量空间,以达到抽取分类信息和压缩特征空间维数的效果,投影后保证模式样本在新的子空间有最大的类间距离和最小的类内距离,即模式在该空间中有最佳的可分离性。因此,它是一种有效的特征抽取方法。使用这种方法能够使投影后模式样本的类间散布矩阵最大,并且同时类内散布矩阵最小。就是说,它能够保证投影后模式样本在新的空间中有最小的类内距离和最大的类间距离,即模式在该空间中有最佳的可分离性。

三、实现

1.LDA算法-鸢尾花

代码

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets._samples_generator import make_classificationclass LDA():def Train(self, X, y):"""X为训练数据集,y为训练label"""X1 = np.array([X[i] for i in range(len(X)) if y[i] == 0])X2 = np.array([X[i] for i in range(len(X)) if y[i] == 1])# 求中心点mju1 = np.mean(X1, axis=0)  # mju1是ndrray类型mju2 = np.mean(X2, axis=0)# dot(a, b, out=None) 计算矩阵乘法cov1 = np.dot((X1 - mju1).T, (X1 - mju1))cov2 = np.dot((X2 - mju2).T, (X2 - mju2))Sw = cov1 + cov2# 计算ww = np.dot(np.mat(Sw).I, (mju1 - mju2).reshape((len(mju1), 1)))# 记录训练结果self.mju1 = mju1  # 第1类的分类中心self.cov1 = cov1self.mju2 = mju2  # 第2类的分类中心self.cov2 = cov2self.Sw = Sw  # 类内散度矩阵self.w = w  # 判别权重矩阵def Test(self, X, y):"""X为测试数据集,y为测试label"""# 分类结果y_new = np.dot((X), self.w)# 计算fisher线性判别式nums = len(y)c1 = np.dot((self.mju1 - self.mju2).reshape(1, (len(self.mju1))), np.mat(self.Sw).I)c2 = np.dot(c1, (self.mju1 + self.mju2).reshape((len(self.mju1), 1)))c = 1/2 * c2  # 2个分类的中心h = y_new - c# 判别y_hat = []for i in range(nums):if h[i] >= 0:y_hat.append(0)else:y_hat.append(1)# 计算分类精度count = 0for i in range(nums):if y_hat[i] == y[i]:count += 1precise = count / nums# 显示信息print("测试样本数量:", nums)print("预测正确样本的数量:", count)print("测试准确度:", precise)return precise
if '__main__' == __name__:# 产生分类数据n_samples = 500X, y = make_classification(n_samples=n_samples, n_features=2, n_redundant=0, n_classes=2,n_informative=1, n_clusters_per_class=1, class_sep=0.5, random_state=10)# LDA线性判别分析(二分类)lda = LDA()# 60% 用作训练,40%用作测试Xtrain = X[:299, :]Ytrain = y[:299]Xtest = X[300:, :]Ytest = y[300:]lda.Train(Xtrain, Ytrain)precise = lda.Test(Xtest, Ytest)# 原始数据plt.scatter(X[:, 0], X[:, 1], marker='o', c=y)plt.xlabel("x1")plt.ylabel("x2")plt.title("Test precise:" + str(precise))plt.show()

 2. 处理月亮数据集

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
class LDA():def Train(self, X, y):"""X为训练数据集,y为训练label"""X1 = np.array([X[i] for i in range(len(X)) if y[i] == 0])X2 = np.array([X[i] for i in range(len(X)) if y[i] == 1])# 求中心点mju1 = np.mean(X1, axis=0)  # mju1是ndrray类型mju2 = np.mean(X2, axis=0)# dot(a, b, out=None) 计算矩阵乘法cov1 = np.dot((X1 - mju1).T, (X1 - mju1))cov2 = np.dot((X2 - mju2).T, (X2 - mju2))Sw = cov1 + cov2# 计算ww = np.dot(np.mat(Sw).I, (mju1 - mju2).reshape((len(mju1), 1)))# 记录训练结果self.mju1 = mju1  # 第1类的分类中心self.cov1 = cov1self.mju2 = mju2  # 第1类的分类中心self.cov2 = cov2self.Sw = Sw  # 类内散度矩阵self.w = w  # 判别权重矩阵def Test(self, X, y):"""X为测试数据集,y为测试label"""# 分类结果y_new = np.dot((X), self.w)# 计算fisher线性判别式nums = len(y)c1 = np.dot((self.mju1 - self.mju2).reshape(1, (len(self.mju1))), np.mat(self.Sw).I)c2 = np.dot(c1, (self.mju1 + self.mju2).reshape((len(self.mju1), 1)))c = 1/2 * c2  # 2个分类的中心h = y_new - c# 判别y_hat = []for i in range(nums):if h[i] >= 0:y_hat.append(0)else:y_hat.append(1)# 计算分类精度count = 0for i in range(nums):if y_hat[i] == y[i]:count += 1precise = count / (nums+0.000001)# 显示信息print("测试样本数量:", nums)print("预测正确样本的数量:", count)print("测试准确度:", precise)return precise
if '__main__' == __name__:# 产生分类数据X, y = make_moons(n_samples=100, noise=0.15, random_state=42)# LDA线性判别分析(二分类)lda = LDA()# 60% 用作训练,40%用作测试Xtrain = X[:60, :]Ytrain = y[:60]Xtest = X[40:, :]Ytest = y[40:]lda.Train(Xtrain, Ytrain)precise = lda.Test(Xtest, Ytest)# 原始数据plt.scatter(X[:, 0], X[:, 1], marker='o', c=y)plt.xlabel("x1")plt.ylabel("x2")plt.title("Test precise:" + str(precise))plt.show()

3.对月亮数据集进行SVM分类

import matplotlib.pyplot as plt
from sklearn.pipeline import Pipeline
import numpy as np
import matplotlib as mpl
from sklearn.datasets import make_moons
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC
# 为了显示中文
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False#rc配置或rc参数,通过rc参数可以修改默认的属性,包括窗体大小、每英寸的点数、线条宽度、颜色、样式、坐标轴、坐标和网络属性、文本、字体等。
X, y = make_moons(n_samples=100, noise=0.15, random_state=42)#生成月亮数据集
def plot_dataset(X, y, axes):#绘制图形plt.plot(X[:, 0][y==0], X[:, 1][y==0], "bs")plt.plot(X[:, 0][y==1], X[:, 1][y==1], "g^")plt.axis(axes)plt.grid(True, which='both')plt.xlabel(r"$x_1$", fontsize=20)plt.ylabel(r"$x_2$", fontsize=20, rotation=0)plt.title("月亮数据",fontsize=20)
plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
plt.show()

4.月亮数据集的二分法:

import matplotlib.pyplot as plt
from sklearn.pipeline import Pipeline
import numpy as np
import matplotlib as mpl
from sklearn.datasets import make_moons
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC
# 为了显示中文
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False
X, y = make_moons(n_samples=100, noise=0.15, random_state=42)
def plot_dataset(X, y, axes):plt.plot(X[:, 0][y==0], X[:, 1][y==0], "bs")plt.plot(X[:, 0][y==1], X[:, 1][y==1], "g^")plt.axis(axes)plt.grid(True, which='both')plt.xlabel(r"$x_1$", fontsize=20)plt.ylabel(r"$x_2$", fontsize=20, rotation=0)plt.title("月亮数据",fontsize=20)
plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
polynomial_svm_clf = Pipeline([# 将源数据 映射到 3阶多项式("poly_features", PolynomialFeatures(degree=3)),# 标准化("scaler", StandardScaler()),# SVC线性分类器("svm_clf", LinearSVC(C=10, loss="hinge", random_state=42))])
polynomial_svm_clf.fit(X, y)
def plot_predictions(clf, axes):# 打表x0s = np.linspace(axes[0], axes[1], 100)x1s = np.linspace(axes[2], axes[3], 100)x0, x1 = np.meshgrid(x0s, x1s)X = np.c_[x0.ravel(), x1.ravel()]y_pred = clf.predict(X).reshape(x0.shape)y_decision = clf.decision_function(X).reshape(x0.shape)
#     print(y_pred)
#     print(y_decision)  plt.contourf(x0, x1, y_pred, cmap=plt.cm.brg, alpha=0.2)plt.contourf(x0, x1, y_decision, cmap=plt.cm.brg, alpha=0.1)
plot_predictions(polynomial_svm_clf, [-1.5, 2.5, -1, 1.5])
plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])
plt.show()

5. 改变C并运行(此处设置C的值分别为0.001,1,100,10000)

from sklearn.svm import SVC
gamma1, gamma2 = 0.1, 5
C1, C2 = 0.001, 100
hyperparams = (gamma1, C1), (gamma1, C2)
svm_clfs = []
for gamma, C in hyperparams:rbf_kernel_svm_clf = Pipeline([("scaler", StandardScaler()),("svm_clf", SVC(kernel="rbf", gamma=gamma, C=C))])rbf_kernel_svm_clf.fit(X, y)svm_clfs.append(rbf_kernel_svm_clf)
plt.figure(figsize=(11, 7))
for i, svm_clf in enumerate(svm_clfs):plt.subplot(221 + i)plot_predictions(svm_clf, [-1.5, 2.5, -1, 1.5])plot_dataset(X, y, [-1.5, 2.5, -1, 1.5])gamma, C = hyperparams[i]plt.title(r"$\gamma = {}, C = {}$".format(gamma, C), fontsize=16)
plt.tight_layout()
plt.show()

参考:线性判别准则和线性分类算法_m0_61811389的博客-CSDN博客

线性判别准则和线性分类算法相关推荐

  1. 线性判别准则与线性分类编程实践

    一.线性判别分析 (一)简介 线性判别分析(linear discriminant analysis,LDA)是对费舍尔的线性鉴别方法的归纳,这种方法使用统计学,模式识别和机器学习方法,试图找到两类物 ...

  2. 机器学习算法(九): 基于线性判别LDA模型的分类(基于LDA手写数字分类实践)

    机器学习算法(九): 基于线性判别模型的分类 1.前言:LDA算法简介和应用 1.1.算法简介 线性判别模型(LDA)在模式识别领域(比如人脸识别等图形图像识别领域)中有非常广泛的应用.LDA是一种监 ...

  3. 基于fisher线性判别法的分类器设计

    0.引言说明 这篇文章实际上是楼主上的模式识别课程的课堂报告,楼主偷懒把东西直接贴出来了.选择fisher判别法的原因主要是想学习一下这个方法,这个方法属于线性判别法,操作起来和lda判别法近乎没啥区 ...

  4. Python 分类问题研究-Fisher线性判别

    [实验目的] 1.掌握常见机器学习分类模型思想.算法,包括Fisher线性判别.KNN.朴素贝叶斯.Logistic回归.决策树等: 2.掌握Python编程实现分类问题,模型评价指标.计时功能.保存 ...

  5. 机器学习(六)分类模型--线性判别法、距离判别法、贝叶斯分类器

    机器学习(六)分类模型--线性判别法.距离判别法.贝叶斯分类器 首先我们了解常见的分类模型和算法有哪些 线性判别法 简单来说就是用一些规定来寻找某一条直线,用直线划分学习集,然后根据待测点在直线的哪一 ...

  6. 经典分类:线性判别分析模型!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:小雨姑娘,康涅狄格大学,Datawhale成员 这几天看了看SVM ...

  7. Fisher线性判别算法原理及实现 MATLAB

    Fisher线性判别算法原理及实现 MATLAB 一.Fisher判别器原理 二.代码实现 clc; close all; clear; %% 生成数据 rng(2020); %指定一个种子 mu1 ...

  8. 【数学与算法】支持向量机、线性判别 详细数学原理

    线性判别各种数学原理 推荐查看这篇博客:[线性分类器](一)线性判别 支持向量机各种数学原理 推荐查看这篇博客:[线性分类器](四)万字长文解释拉格朗日乘子与支持向量机

  9. R语言分类算法之线性判别分析(Linear Discriminant Analysis)

    1.线性判别原理解析 基本思想是"投影",即高纬度空间的点向低纬度空间投影,从而简化问题的处理.在原坐标系下,空间中的点可能很难被分开,如图8-1,当类别Ⅰ和类别Ⅱ中的样本点都投影 ...

最新文章

  1. stream流对象的理解及使用
  2. Android高效加载大图、多图解决方案,有效避免程序OOM
  3. C#线程系列讲座(3):线程池和文件下载服务器
  4. java的annotation_Java Annotation认知(包括框架图、详细介绍、示例说明)
  5. 把mac地址转换为标准mac地址
  6. java web 颜色灰色_网站动态变灰解决方案(java web项目网站)
  7. 【渝粤教育】国家开放大学2018年春季 0463-22T英语语音 参考试题
  8. SQL COALESCE函数和NULL
  9. Stanford CoreNLP 3.6.0 中文指代消解模块调用失败的解决方案
  10. L1-032. Left-pad-PAT团体程序设计天梯赛GPLT
  11. tensorflow之add_to_collection
  12. 13岁我们在做什么,现在20岁我又在做什么
  13. 淘宝API item_history_price - 获取商品历史价格信息
  14. 基于 Java 的 I Don’t Wanna Be The Bugger 冒险游戏【100010211】
  15. 腾讯在汉投资10亿 建设华中地区最大研发中心
  16. 斧乃木余接win10主题分享
  17. android基础知识13:AndroidManifest.xml文件解析【转载】
  18. windows的命令行(CMD)
  19. 怎样检查一张 SIMATIC 存储卡(SMC)有非一致性或者是格式错误?如何修复?
  20. 小程序动画animation向左移动效果

热门文章

  1. 【华为机试真题详解 Python实现】最差产品奖【2023 Q1 | 100分】
  2. Oauth2方式实现单点登录
  3. 英语日期中介词in on at for to的区别
  4. and门 simuilink_电力电子电路仿真-MATLAB和PSpice应用.PPT
  5. 全球与中国轮胎弦和轮胎面料市场运营动态及未来发展规划报告2022-2028年
  6. 大咖 | 舍恩伯格:相比“新石油”,大数据更应是削弱资本的“润滑脂”
  7. MIT离散数学二元关系笔记
  8. 使用mybatis plus自定义拦截器,实现数据权限
  9. TestPatten测试
  10. Jquery之ShowLoading遮罩组件