西瓜书第三章:LDA(及详细Fisher实现),QDA的python实现[仅代码实现]

为了进行此实验,本人特地制作了一个训练集和一个测试集,这些测试集的参数如下:

  • 红点

    • N(1,05)
    • N(1,05)
  • 绿点
    • N(2,0.5)
    • N(2,1)

测试集具体形状如图所示:(不是训练集)
当然是用matlab生成的

如果我们采用LDA算法:

即使用python中的sklearn包LinearDiscriminantAnalysis
算法如下

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix,precision_score,accuracy_score,recall_score,f1_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import matplotlib.pyplot as plt
data=pd.read_csv("train.csv")
y_train=data['type'].tolist()
x_train=np.mat([data['x'].tolist(),data['y'].tolist()]).T
data2=pd.read_csv("new.csv")
y_test=data2['type'].tolist()
x_test=np.mat([data2['x'].tolist(),data2['y'].tolist()]).T
clf=LinearDiscriminantAnalysis()
clf.fit(x_train,y_train)
y_pred=clf.predict(x_test)
con=confusion_matrix(y_test, y_pred)
print(con)
print(accuracy_score(y_test, y_pred),precision_score(y_test, y_pred),recall_score(y_test,y_pred),f1_score(y_test, y_pred))
for i in range(0,len(y_test)):if y_pred[i]!=y_test[i]:if y_test[i]==1:plt.scatter(x_test[i,0],x_test[i,1], marker = '+', color = 'green', s = 40)#wrong1if y_test[i]==0:plt.scatter(x_test[i,0],x_test[i,1],marker = 'x', color = 'cyan', s = 40)#wrong0else:if y_test[i]==1:plt.scatter(x_test[i,0],x_test[i,1], marker = 'o', color = 'blue', s = 40)#right1if y_test[i]==0:plt.scatter(x_test[i,0],x_test[i,1], marker = 'v', color = 'magenta', s = 40)#right0
plt.legend(loc = 'best')
plt.show()

通过这个算法我们可以得到我们的最终判定结果

      其中,蓝色点是判断正确的正例,粉色三角是判断正确负例,天蓝色叉是判断错误的负例,绿色加号是判断错误的正例,可以知道,错误是难免的,我们只能找到一些尽可能正确的学习机

如果使用QDA算法

即借用QuadraticDiscriminantAnalysis包

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix,precision_score,accuracy_score,recall_score,f1_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis
import matplotlib.pyplot as plt
data=pd.read_csv("train.csv")
y_train=data['type'].tolist()
x_train=np.mat([data['x'].tolist(),data['y'].tolist()]).T
data2=pd.read_csv("new.csv")
y_test=data2['type'].tolist()
x_test=np.mat([data2['x'].tolist(),data2['y'].tolist()]).T
clf=QuadraticDiscriminantAnalysis()
clf.fit(x_train,y_train)
y_pred=clf.predict(x_test)
con=confusion_matrix(y_test, y_pred)
print(con)
print(accuracy_score(y_test, y_pred),precision_score(y_test, y_pred),recall_score(y_test,y_pred),f1_score(y_test, y_pred))
for i in range(0,len(y_test)):if y_pred[i]!=y_test[i]:if y_test[i]==1:plt.scatter(x_test[i,0],x_test[i,1], marker = '+', color = 'green', s = 40)#wrong1if y_test[i]==0:plt.scatter(x_test[i,0],x_test[i,1],marker = 'x', color = 'cyan', s = 40)#wrong0else:if y_test[i]==1:plt.scatter(x_test[i,0],x_test[i,1], marker = 'o', color = 'blue', s = 40)#right1if y_test[i]==0:plt.scatter(x_test[i,0],x_test[i,1], marker = 'v', color = 'magenta', s = 40)#right0
plt.legend(loc = 'best')
plt.show()

不得不说,这个包还是蛮好用的

我们会发现最终结果稍微好了一点

此时负例预测的更加正确了

最终我们不调用LinearDiscriminantAnalysis

进行所谓的Fisher判别

虽然借用了sklearn的一些功能,不过那无关紧要

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix,precision_score,accuracy_score,recall_score,f1_score
import matplotlib.pyplot as plt
def meanvector(x,y,num_of_type):nrow,ncol=x.shapenp.set_printoptions(precision=4)result=np.zeros((num_of_type,ncol))for i in range(0,ncol):tmp=x[:,i]for j in range(0,num_of_type):result[j,i]=np.mean(tmp[y==j],axis=1)result=result.Treturn resultdef sw(x,y,num_of_type):nrow,ncol=x.shapemean_vec=meanvector(x,y,num_of_type)part2=mean_vec[:,0]-mean_vec[:,1]sw=np.zeros((ncol,ncol))for i in range(0,ncol):for j in range(0,ncol):for k in range(0,nrow):type_this=y[k,0];sw[i,j]=sw[i,j]+(x[k,i]-mean_vec[type_this,i])*(x[k,j]-mean_vec[type_this,j])result=np.linalg.inv(sw)result=np.dot(result,part2)return result,mean_vecdata=pd.read_csv("train.csv")
y_train=np.mat(data['type'].tolist()).T
x_train=np.mat([data['x'].tolist(),data['y'].tolist()]).T
data2=pd.read_csv("new.csv")
y_test=np.mat(data2['type'].tolist()).T
x_test=np.mat([data2['x'].tolist(),data2['y'].tolist()]).Tw,mean_vec=sw(x_train,y_train,2)
w=w.T
mean0=w.dot(mean_vec[:,0])
mean1=w.dot(mean_vec[:,1])
mid=mean0/2+mean1/2
y_pred=np.zeros((len(y_test),1))
for i in range(0,len(y_test)):now=x_test[i,:].Tthis_mean=w.dot(now)if this_mean>mid:y_pred[i,0]=0else:y_pred[i,0]=1
con=confusion_matrix(y_test, y_pred)
print(con)
print(accuracy_score(y_test, y_pred),precision_score(y_test, y_pred),recall_score(y_test,y_pred),f1_score(y_test, y_pred))
for i in range(0,len(y_test)):if y_pred[i]!=y_test[i]:if y_test[i]==1:plt.scatter(x_test[i,0],x_test[i,1], marker = '+', color = 'green', s = 40)#wrong1if y_test[i]==0:plt.scatter(x_test[i,0],x_test[i,1],marker = 'x', color = 'cyan', s = 40)#wrong0else:if y_test[i]==1:plt.scatter(x_test[i,0],x_test[i,1], marker = 'o', color = 'blue', s = 40)#right1if y_test[i]==0:plt.scatter(x_test[i,0],x_test[i,1], marker = 'v', color = 'magenta', s = 40)#right0
plt.legend(loc = 'best')
plt.show()

我们最终得出了和LDA一样的答案,说明周志强在西瓜书中讲到的Fisher判别法和sklearn中的LDA算法是几乎相同的,因为我不是非常熟悉包内的内容,只能说,他们的效果是一样的。


可见QDA对于高斯分布数据的判断还是更优一些的,问题在于这种方法并不是很适合降维,所以主流的还是LDA占据一些优势,相关的资料也更多。但是他们的思想还是很近似的。

西瓜书第三章:LDA(及详细Fisher实现),QDA的python实现[仅代码实现]相关推荐

  1. 西瓜书第三章阅读笔记

    西瓜书第三章阅读笔记 第三章 线性模型 1.机器学习三要素 2.基本形式 3.线性回归 3.1 模型 3.2 策略 3.3 求解算法 4.对数几率回归 4.1 模型 4.2 策略 4.3 求解算法 5 ...

  2. 周志华-机器学习西瓜书-第三章习题3.3 编程实现对率回归

    本文为周志华机器学习西瓜书第三章课后习题3.3答案,编程实现对率回归,数据集为书本第89页的数据 使用tensorflow实现过程 # coding=utf-8 import tensorflow a ...

  3. 小白学机器学习西瓜书-第三章对数几率回归

    小白学机器学习西瓜书-第三章对数几率回归 3.3 对数几率回归 3.3.1 对数几率函数 3.3.1 估计参数 上一部分我们介绍了线性回归,包括简单的二元回归和多元回归,这两个主要解决的是拟合预测的问 ...

  4. 【吃瓜教程】周志华机器学习西瓜书第三章答案

    线性模型结构梳理 3.1 试析在什么情形下式3.2中不必考虑偏置项b 答案一: 偏置项b在数值上代表了自变量取0时,因变量的取值: 1.当讨论变量x对结果y的影响,不用考虑b: 2.可以用变量归一化( ...

  5. 周志华西瓜书第三章学习笔记

    第三章学习笔记 文章目录 第三章学习笔记 1.知识脉络 2.我的笔记 参考 1.知识脉络 2.我的笔记 这一章公式推导实在太多了,需要补充的推导过程也有很多,就不写电子档了.扩展公式推导和LDA部分补 ...

  6. 《机器学习》西瓜书第三章回归(南瓜书辅助)

    第三章 3.1一元线性回归 假如说现在有一个正态分布,正态分布由mu和sigama决定,极大似然估计就是用来确定正态分布的这两个参数的 3.2多元线性回归 对线性回归方程进行化简 将 b=wd+1∗1 ...

  7. 正则表达式引擎的构建——基于编译原理DFA(龙书第三章)——1 概述

    说明:本系列文章介绍的算法均来自编译原理(龙书)一书,如果读者对代码没有兴趣,只想了解算法思路,完全可以阅读龙书相关章节内容,比我讲得清晰透彻. 序: 啃编译原理半年以来,任然徘徊在前4章,其间反反复 ...

  8. 【编译原理】龙书第三章作业答案

    [编译原理]龙书第三章作业答案 练习3.1.1:根据3.1.2节中的讨论,将下面的C++程序划分成正确的词素序列.哪些词素应该有相关联的词法值?应该具有什么值? 答案: 左列为词素,右列为值,划分如下 ...

  9. 西瓜书第四章阅读笔记

    西瓜书第四章阅读笔记 1.基本概念 1.1 基本算法 1.2 信息熵 1.3 信息增益 2.ID3决策树 3.C4.5决策树 4.CART决策树 5.剪枝操作 6.连续与缺失值处理 7.多变量决策树 ...

最新文章

  1. swift 多线程GCD和延时调用
  2. 闭门沙龙招募:吃吃喝喝聊CG | 真格×量子位
  3. 今天才发现ff不支持navigate。
  4. 为 Visual Studio 安装数据库工具
  5. textaligncenter仍然不居中_戊唑醇和己唑醇都是杀菌剂,有啥不同?真正懂的人不多...
  6. java在图片下方写文字_Java画图给图片底部添加文字标题
  7. [thinkphp] 是如何输出一个页面的
  8. python运行结果图_[宜配屋]听图阁
  9. 用vector实现二维向量
  10. 回归、插值、逼近、拟合的区别
  11. Audiority TS-1 Transient Shaper for Mac(TS-1瞬态整形器)
  12. mysql sslcipher_解决mysql数据库创建用户报错Field 'ssl_cipher' doesn't have a default value...
  13. java aapt linux_Android:linux下aapt使用 | 学步园
  14. qt 两界面类操作另外一个界面的的ui控件;以及会出现的the class containing “ui::XXX”cound not be found...Please verify the .
  15. 总结2019,立2020flag
  16. 樊登读书赋能读后感_樊登读书会读后感01012019
  17. 微信小程序上传文件详解
  18. 蓝桥杯——单片机赛道
  19. 刚进职场的程序员,和工作了2、3年的程序员到底有什么不一样?
  20. 柯尼卡美能达306i提示更换感光鼓定影单元 清零方法

热门文章

  1. 你这一生还能陪妈妈几天?来看看
  2. Python笔记 | 数据筛选
  3. [Android Studio]微型技术报告-手机平台应用开发
  4. 认计算机电源,电脑硬件认识之什么是电脑的电源[图文]
  5. Markov Chains
  6. Vue:从单页面到工程化项目
  7. createjs初学-关于Ticker
  8. qemu图形界面linux,QEMU 简单几步搭建一个虚拟的ARM开发板
  9. 一文搞懂0.1UF和10UF电容并联使用技巧
  10. 直流电机控制 pwm 和 pid 算法