深入浅出线性判别分析(LDA),从理论到代码实现
©作者|善财童子
学校|西北工业大学
研究方向|机器学习/射频微波
在知乎看到一篇讲解线性判别分析(LDA,Linear Discriminant Analysis)的文章,感觉数学概念讲得不是很清楚,而且没有代码实现。所以童子在参考相关文章的基础上在这里做一个学习总结,与大家共勉,欢迎各位批评指正~~
注意:在不加说明的情况下,所有公式的向量均是列向量,这个也会反映到代码中。
本文的基本思路来自以下文章:
https://www.adeveloperdiary.com/data-science/machine-learning/linear-discriminant-analysis-from-theory-to-code/
基本概念和目标
线性判别分析是一种很重要的分类算法,同时也是一种降维方法(这个我还没想懂)。和 PCA 一样,LDA 也是通过投影的方式达到去除数据之间冗余的一种算法。
如下图所示的 2 类数据,为了正确的分类,我们希望这 2 类数据投影之后,同类的数据尽可能的集中(距离近,有重叠),不同类的数据尽可能的分开(距离远,无重叠),左图的投影不好,因为 2 类数据投影后有重叠,而右图投影之后可以很好地进行分类,因为投影之后的 2 类数据之间几乎没有重叠,只是类内重叠得很厉害,而这正是我们想要的结果。
正交投影
因为 LDA 用到了投影,所以这里有必要科普一下投影的知识。以二维平面为例,如图所示
我们要计算向量 在 上的投影 ,很显然 与 成比例关系:,其中 是一个常数。我们使用向量正交的概念来求出这个常数 。在上图中,向量 , 与 垂直,它们的内积为 0,即 ,即
注意:对于两个向量 x 和 y, ,所以有 。
假设 w 是一个单位向量,则 ,这样,对于任意向量 x,其在 w 上的投影 可表示为:
其中, 是一个常数。
对于一个数据集 ,其中 ,i=1,2,3,...m 是 d 维列向量。同样假设 w 是一个单位向量,那么每一个 在 w 的投影是:
上述公式的 是叫做 在 w 上的偏移或者坐标。这一系列的值 表示我们做了一个映射 ,即通过投影,我们将 d 维向量降维到了 1 维。
投影数据的均值
为简化起见,我们先假设有 2 类数据,定义样本 :,其中 。
我们再定义 :
其中 是类别, 是所有类别为 的样本的集合。所有数据 投影到 w 后,求其均值:
其中, 是 数据集的均值,同理 的均值是 ,投影后的均值 。为了使投影之后数据可正确地分类,我们希望这 2 类数据的中心离得越远越好,也就是要使 最大,但是单独这个条件并不能保证能够正确地对每一个数据进行分类,我们还需要考虑每一类数据的方差,大的方差表示 2 类数据之间有重叠,小的方差表示 2 类数据之间没有重叠。
LDA 并没有直接使用方差的计算公式,而是采用如下的定义:
这个有个名称叫 scatter matrix,本文暂时将其翻译成散步矩阵吧。
总结一下,LDA 主要就两点:
(1)最大化各类数据中心的距离,也就是各类数据的均值之间的距离要最大;
(2)各类数据的散步矩阵之和要小,也就是每个类别中的数据尽可能地集中。
将上述两点整合在一起,得到一个优化公式:
这个公式也叫做 Fisher LDA,这样,LDA 的问题就是关于 最优化上述的公式。我们重写上述公式如下:
同理有 :
这样:
这样,LDA 目标优化函数就可以重写为:
对公式(9)关于 求导,并令其导数为 0,可得:
整理得:
公式(11)中 做了替代:, 是一个常量。如果 S 是非奇异矩阵,那么公式(11)左乘 得到:
最终,LDA 问题其实就是求 对应最大特征值,而我们前面要求的投影方向就是最大特征值对应的特征向量,我们将 LDA 问题化成了矩阵的特征值和特征向量的问题了。
上述推导针对二分类问题进行的,对于多分类问题, 矩阵的计算方式不变,而 矩阵需要采用如下的公式计算:
其中:
C 表示类别的个数; 表示第 i 类中样本的个数; 表示第 i 类样本的均值; 表示整个样本的均值。
关于矩阵微分可参考如下文章:
https://zhuanlan.zhihu.com/p/24709748
https://zhuanlan.zhihu.com/p/24863977
这里提醒一下,对 关于 x 求导的结果是 ,如果 A 是对称矩阵,即 ,则 。公式(10)中因为 B 和 S 都是对称矩阵(由它们的定义可以看出是对称矩阵),所以对 关于 w 求导的结果是 2Bw ,即 ,同理 。
代码实现
import numpy as np
from sklearn import datasetsfrom sklearn.datasets import make_blobs
import matplotlib.pyplot as pltclass MyLDA:def __init__(self):passdef fit(self, X, y):# 获取所有的类别labels = np.unique(y)#print(labels)means = []for label in labels:# 计算每一个类别的样本均值means.append(np.mean(X[y == label], axis=0))# 如果是二分类的话if len(labels) == 2:mu = (means[0] - means[1])mu = mu[:,None] # 转成列向量B = mu @ mu.Telse:total_mu = np.mean(X, axis=0)B = np.zeros((X.shape[1], X.shape[1]))for i, m in enumerate(means):n = X[y==i].shape[0]mu_i = m - total_mumu_i = mu_i[:,None] # 转成列向量B += n * np.dot(mu_i, mu_i.T)# 计算S矩阵S_t = []for label, m in enumerate(means):S_i = np.zeros((X.shape[1], X.shape[1]))for row in X[y == label]:t = (row - m)t = t[:,None] # 转成列向量S_i += t @ t.TS_t.append(S_i)S = np.zeros((X.shape[1], X.shape[1]))for s in S_t:S += s# S^-1B进行特征分解S_inv = np.linalg.inv(S)S_inv_B = S_inv @ Beig_vals, eig_vecs = np.linalg.eig(S_inv_B)#从大到小排序ind = eig_vals.argsort()[::-1]eig_vals = eig_vals[ind]eig_vecs = eig_vecs[:, ind]return eig_vecs#构造数据集
def make_data(centers=3, cluster_std=[1.0, 3.0, 2.5], n_samples=150, n_features=2): X, y = make_blobs(n_samples, n_features, centers, cluster_std)return X, yif __name__ == "__main__":X, y = make_data(2, [1.0, 3.0])print(X.shape)lda = MyLDA()eig_vecs = lda.fit(X, y)W = eig_vecs[:, :1]colors = ['red', 'green', 'blue']fig, ax = plt.subplots(figsize=(10, 8))for point, pred in zip(X, y):# 画出原始数据的散点图ax.scatter(point[0], point[1], color=colors[pred], alpha=0.5)# 每个数据点在W上的投影proj = (np.dot(point, W) * W) / np.dot(W.T, W)#画出所有数据的投影ax.scatter(proj[0], proj[1], color=colors[pred], alpha=0.5)plt.show()
4.1 2类2个特征
if __name__ == "__main__":X, y = make_data(2, [1.0, 3.0]) #rint(X.shape)lda = MyLDA()eig_vecs = lda.fit(X, y)W = eig_vecs[:, :1]colors = ['red', 'green', 'blue']fig, ax = plt.subplots(figsize=(10, 8))for point, pred in zip(X, y):# 画出原始数据的散点图ax.scatter(point[0], point[1], color=colors[pred], alpha=0.5)# 每个数据点在W上的投影proj = (np.dot(point, W) * W) / np.dot(W.T, W)#画出所有数据的投影ax.scatter(proj[0], proj[1], color=colors[pred], alpha=0.5)plt.show()
运行结果是:
可见,数据投影后在 1 维上可以很好的分类。
4.2 3类2个特征
if __name__ == "__main__":# 3类X, y = make_data([[2.0, 1.0], [15.0, 5.0], [31.0, 12.0]], [1.0, 3.0, 2.5])print(X.shape)lda = MyLDA()eig_vecs = lda.fit(X, y)W = eig_vecs[:, :1]colors = ['red', 'green', 'blue']fig, ax = plt.subplots(figsize=(10, 8))for point, pred in zip(X, y):# 画出原始数据的散点图ax.scatter(point[0], point[1], color=colors[pred], alpha=0.5)# 每个数据点在W上的投影proj = (np.dot(point, W) * W) / np.dot(W.T, W)#画出所有数据的投影ax.scatter(proj[0], proj[1], color=colors[pred], alpha=0.5)plt.show()
运行结果是:
4.3 3类4个特征
if __name__ == "__main__":#X, y = load_data(cols, load_all=True, head=True)X, y = make_data([[2.0, 1.0], [15.0, 5.0], [31.0, 12.0]], [1.0, 3.0, 2.5], n_features=4)print(X.shape)lda = MyLDA()eig_vecs = lda.fit(X, y)# 取前2个最大特征值对应的特征向量W = eig_vecs[:, :2]# 将数据投影到这两个特征向量上,从而达到降维的目的transformed = X @ Wplt.subplots(figsize=(10, 8))plt.scatter(transformed[:, 0], transformed[:, 1], c=y, cmap=plt.cm.Set1)plt.show()
运行结果如下:
对上述结果使用 sklearn 官方实现的 LDA 进行对比验证:
if __name__ == "__main__":X, y = make_data([[2.0, 1.0], [15.0, 5.0], [31.0, 12.0]], [1.0, 3.0, 2.5], n_features=4)print(X.shape)lda = MyLDA()eig_vecs = lda.fit(X, y)# 取前2个最大特征值对应的特征向量W = eig_vecs[:, :2]# 将数据投影到这两个特征向量上,从而达到降维的目的transformed = X @ Wplt.subplots(figsize=(10, 8))plt.scatter(transformed[:, 0], transformed[:, 1], c=y, cmap=plt.cm.Set1)plt.title('self-implementation')from sklearn.discriminant_analysis import LinearDiscriminantAnalysissk_lda = LinearDiscriminantAnalysis()sk_lda.fit(X, y)transformed = sk_lda.transform(X)plt.subplots(figsize=(10, 8))plt.scatter(transformed[:, 0], transformed[:, 1], c=y, cmap=plt.cm.Set1)plt.title("sklearn's offical implementation")plt.show()
左图是本文实现的 LDA 分类结果,右图是官方实现的 LDA 分类结果,可见,两者的结果是一致的。
总结
LDA 是一个很强大的工具,但它是一个有监督的分类算法,PCA 是一个无监督的算法,这是和 PCA 的一个很重要的区别。
更多阅读
#投 稿 通 道#
让你的论文被更多人看到
如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。
总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。
PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得或技术干货。我们的目的只有一个,让知识真正流动起来。
???? 来稿标准:
• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向)
• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接
• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志
???? 投稿邮箱:
• 投稿邮箱:hr@paperweekly.site
• 所有文章配图,请单独在附件中发送
• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通
????
现在,在「知乎」也能找到我们了
进入知乎首页搜索「PaperWeekly」
点击「关注」订阅我们的专栏吧
关于PaperWeekly
PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。
深入浅出线性判别分析(LDA),从理论到代码实现相关推荐
- 『矩阵论笔记』线性判别分析(LDA)最全解读+python实战二分类代码+补充:矩阵求导可以参考
线性判别分析(LDA)最全解读+python实战二分类代码! 文章目录 一.主要思想! 二.具体处理流程! 三.补充二中的公式的证明! 四.目标函数的求解过程! 4.1.优化问题的转化 4.2.拉格朗 ...
- sklearn实现lda模型_运用sklearn进行线性判别分析(LDA)代码实现
基于sklearn的线性判别分析(LDA)代码实现 一.前言及回顾 本文记录使用sklearn库实现有监督的数据降维技术--线性判别分析(LDA).在上一篇LDA线性判别分析原理及python应用(葡 ...
- 线性判别分析LDA—西瓜书课后题3.5—MATLAB代码
题目:编程实现线性判别分析LDA,给出西瓜数据集 3.0a上的结果 简单说就是找一个分离度最大的投影方向,把数据投射上去. clc clear all [num,txt]=xlsread('D:\机器 ...
- 数据分享|R语言逻辑回归、线性判别分析LDA、GAM、MARS、KNN、QDA、决策树、随机森林、SVM分类葡萄酒交叉验证ROC...
全文链接:http://tecdat.cn/?p=27384 在本文中,数据包含有关葡萄牙"Vinho Verde"葡萄酒的信息(点击文末"阅读原文"获取完整代 ...
- ML之NB:基于news新闻文本数据集利用纯统计法、kNN、朴素贝叶斯(高斯/多元伯努利/多项式)、线性判别分析LDA、感知器等算法实现文本分类预测
ML之NB:基于news新闻文本数据集利用纯统计法.kNN.朴素贝叶斯(高斯/多元伯努利/多项式).线性判别分析LDA.感知器等算法实现文本分类预测 目录 基于news新闻文本数据集利用纯统计法.kN ...
- 【机器学习】机器学习之线性判别分析(LDA)
目录 一.线性判别分析介绍 二.线性判别分析原理 1. 类内散度矩阵(within-class scatter matrix) 2. 类间散度矩阵(between-class scatter matr ...
- 线性分类(二)-- 线性判别分析 LDA
在机器学习领域,LDA是两个常用模型的简称:线性判别分析(Linear Discriminant Analysis) 和隐含狄利克雷分布(Latent Dirichlet Allocation).在自 ...
- 机器学习 周志华 课后习题3.5 线性判别分析LDA
机器学习 周志华 课后习题3.5 线性判别分析LDA 照着书上敲了敲啥都不会,雀食折磨 python代码 # coding=UTF-8 from numpy import * # 我安装numpy的时 ...
- lda 吗 样本中心化 需要_机器学习 —— 基础整理(四):特征提取之线性方法——主成分分析PCA、独立成分分析ICA、线性判别分析LDA...
本文简单整理了以下内容: (一)维数灾难 (二)特征提取--线性方法 1. 主成分分析PCA 2. 独立成分分析ICA 3. 线性判别分析LDA (一)维数灾难(Curse of dimensiona ...
- 07_数据降维,降维算法,主成分分析PCA,NMF,线性判别分析LDA
1.降维介绍 保证数据所具有的代表性特性或分布的情况下,将高维数据转化为低维数据. 聚类和分类都是无监督学习的典型任务,任务之间存在关联,比如某些高维数据的分类可以通过降维处理更好的获得. 降维过程可 ...
最新文章
- NetworkX系列教程(10)-算法之三:关键路径问题
- POJ - 3614 Sunscreen(贪心/二分图最大匹配-多重匹配/网络流-最大流)
- linux中sz和rz的使用,在服务器和本地之间传输数据
- 值得永久收藏的 C# 设计模式套路(三)
- Dubbo(七)之自动加载环境变量
- CentOS新增用户并授予sudo权限
- 不同电脑 命名管道_电脑键盘上的F1到F12,这些键都有哪些用处?用了5年总算明白了...
- KVO - 观察自定义属性值
- 【NATS streaming】NATS streaming 简介与安装
- Socket开发框架之框架设计及分析
- 编写爬虫遇到的问题总结
- C++中的 求模运算 和 求余运算
- Time::HiRes, sleep(), time()
- 如何写期望薪资、离职原因、求职意向?
- Android 仿微信录制短视频(不使用 FFmpeg)
- 一些奇奇怪怪小问题汇总
- rpm包制作之openssh8.7升级
- AssetBundle.Unload(false/true)
- 【人工智能项目】机器学习中文垃圾邮件分类任务
- 八年成就开发梦——IT精英中的活雷锋郭红俊
热门文章
- 树结构遍历节点名字提取,这里提取的是el-tree数据结构,封装成函数
- 20155226 2016-2017-2 《Java程序设计》第一周学习总结
- DOM操作中,遍历动态集合的注意事项。ex: elem.children
- awk多分隔符功能及wc命令案列及企业级应用
- Swif基础语法01
- 如何做一个优秀的销售代表
- 一个不错的SQL储存过程分页,储存过程+Repeater,如果只是浏览数据的话,快就一个字...
- sign python_python机器学习
- linux命令拉取windows的文件,find命令、文件名后缀以及Linux和Windows互传文件(示例代码)...
- oa部署mysql_oa系统部署