GDA Python代码

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2018/8/812:56
# @Author  : DaiPuWei
# E-Mail   : 771830171@qq.com
# @Site    : 湖北省荆州市公安县自强中学
# @File    : GDA.py
# @Software: PyCharmimport numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
import matplotlib as mpl
import matplotlib.pyplot as pltclass GDA:def __init__(self,train_data,train_label):"""这是GDA算法构造函数:param train_data: 训练数据:param train_label: 训练数据标签"""self.Train_Data = train_dataself.Train_Label = train_labelself.postive_num = 0                                                    # 正样本个数self.negetive_num = 0                                                   # 负样本个数postive_data = []                                                       # 正样本数组negetive_data = []                                                      # 负样本数组for (data,label) in zip(self.Train_Data,self.Train_Label):if label == 1:          # 正样本self.postive_num += 1postive_data.append(list(data))else:                   # 负样本self.negetive_num += 1negetive_data.append(list(data))# 计算正负样本的二项分布的概率row,col = np.shape(train_data)self.postive = self.postive_num*1.0/row                                 # 正样本的二项分布概率self.negetive = 1-self.postive                                          # 负样本的二项分布概率# 计算正负样本的高斯分布的均值向量postive_data = np.array(postive_data)negetive_data = np.array(negetive_data)postive_data_sum = np.sum(postive_data, 0)negetive_data_sum = np.sum(negetive_data, 0)self.mu_positive = postive_data_sum*1.0/self.postive_num                # 正样本的高斯分布的均值向量self.mu_negetive = negetive_data_sum*1.0/self.negetive_num              # 负样本的高斯分布的均值向量# 计算高斯分布的协方差矩阵positive_deta = postive_data-self.mu_positivenegetive_deta = negetive_data-self.mu_negetiveself.sigma = []for deta in positive_deta:deta = deta.reshape(1,col)ans = deta.T.dot(deta)self.sigma.append(ans)for deta in negetive_deta:deta = deta.reshape(1,col)ans = deta.T.dot(deta)self.sigma.append(ans)self.sigma = np.array(self.sigma)#print(np.shape(self.sigma))self.sigma = np.sum(self.sigma,0)self.sigma = self.sigma/rowself.mu_positive = self.mu_positive.reshape(1,col)self.mu_negetive = self.mu_negetive.reshape(1,col)def Gaussian(self, x, mean, cov):"""这是自定义的高斯分布概率密度函数:param x: 输入数据:param mean: 均值向量:param cov: 协方差矩阵:return: x的概率"""dim = np.shape(cov)[0]# cov的行列式为零时的措施covdet = np.linalg.det(cov + np.eye(dim) * 0.001)covinv = np.linalg.inv(cov + np.eye(dim) * 0.001)xdiff = (x - mean).reshape((1, dim))# 概率密度prob = 1.0 / (np.power(np.power(2 * np.pi, dim) * np.abs(covdet), 0.5)) * \np.exp(-0.5 * xdiff.dot(covinv).dot(xdiff.T))[0][0]return probdef predict(self,test_data):predict_label = []for data in test_data:positive_pro = self.Gaussian(data,self.mu_positive,self.sigma)negetive_pro = self.Gaussian(data,self.mu_negetive,self.sigma)if positive_pro >= negetive_pro:predict_label.append(1)else:predict_label.append(0)return predict_labeldef run_main():"""这是主函数"""# 导入乳腺癌数据breast_cancer = load_breast_cancer()data = np.array(breast_cancer.data)label = np.array(breast_cancer.target)data = MinMaxScaler().fit_transform(data)# 解决画图是的中文乱码问题mpl.rcParams['font.sans-serif'] = [u'simHei']mpl.rcParams['axes.unicode_minus'] = False# 分割训练集与测试集train_data,test_data,train_label,test_label = train_test_split(data,label,test_size=1/4)# 数据可视化plt.scatter(test_data[:,0],test_data[:,1],c = test_label)plt.title("乳腺癌数据集显示")plt.show()# GDA结果gda = GDA(train_data,train_label)test_predict = gda.predict(test_data)print("GDA的正确率为:",accuracy_score(test_label,test_predict))# 数据可视化plt.scatter(test_data[:,0],test_data[:,1],c = test_predict)plt.title("GDA分类结果显示")plt.show()# Logistic回归结果lr = LogisticRegression()lr.fit(train_data,train_label)test_predict = lr.predict(test_data)print("Logistic回归的正确率为:",accuracy_score(test_label,test_predict))# 数据可视化plt.scatter(test_data[:,0],test_data[:,1],c = test_predict)plt.title("Logistic回归分类结果显示")plt.show()if __name__ == '__main__':run_main()

结果分析

以下是上述程序的结果截图: 
 
明显看到,GDA的效果略微差于Logistic回归,这也证实了GDA的模型假设更强,当数据不是特别服从高斯分布时,效果略差于LR。LR更具备鲁棒性,实用性更强

高斯判别分析(GDA)Python代码相关推荐

  1. 高斯判别分析(GDA)和朴素贝叶斯(NB)

    生成模型和判别模型 监督学习一般学习的是一个决策函数y=f(x)y=f(x)y=f(x)或者是条件概率分布p(y∣x)p(y|x)p(y∣x). 判别模型直接用数据学习这个函数或分布,例如Linear ...

  2. 生成学习算法.高斯判别分析(GDA).GDA与Logistic模型

    http://blog.csdn.net/v1_vivian/article/details/52190572 <Andrew Ng 机器学习笔记>这一系列文章文章是我再观看Andrew ...

  3. 经典机器学习算法:高斯判别分析GDA

    高斯判别分析介绍 高斯判别分析 GDA GDA模型 模型求解 具体计算 高斯判别分析 GDA GDA:Guassian Discrimant Analysis 高斯判别分析属于两分类.软分类.概率生成 ...

  4. Gaussian Discriminative Analysis 高斯判别分析 GDA

    Gaussian Discriminative Analysis 高斯判别分析 GDA Multidimensional Gaussian Model z ∼ N ( μ ⃗ , Σ ) z \sim ...

  5. 高斯判别分析(GDA)——含python代码

    基本数学知识 多元正态分布   多元正态分布是正态分布在多维变量下的扩展,它的参数是一个均值向量(mean(mean(mean vector)μvector)μvector)\mu和协方差矩阵(cov ...

  6. 高斯判别分析GDA的简单python实现

    参考文章:https://blog.csdn.net/qq_30091945/article/details/81508055 作为机器学习的小白,最近将GDA给简单实现了,有很多不足的地方,欢迎大家 ...

  7. 线性分类(四)-- 高斯判别分析 GDA

    高斯判别分析(Gaussian Discriminant analysis,GDA),与之前的线性回归和Logistic回归从方法上讲有很大的不同,GDA是一种生成学习算法(Generative Le ...

  8. 高斯判别分析- GDA原理简介

    GDA是生成学习方法的一个典型代表. 判别学习方法是直接对P(y | x)进行建模,也就是说生成学习方法学到的是P(y | x)这样一个条件概率:另外一种判别学习方法是直接输出hθ(x)hθ(x)h_ ...

  9. 【备忘】高斯判别分析(GDA)参数手推记录

    高斯判别模型是通过最大化贝叶斯模型中的最大后验概率为目标进行训练模型,是一个非常典型的生成模型,假设服从高斯分布,服从伯努利分布,通过训练数据集来确定正态分布与伯努利分布中的各项参数完善模型.对于新的 ...

  10. 【机器学习】线性分类——高斯判别分析GDA(理论+图解+公式推导)

最新文章

  1. java integer int 比较_java Integer和int之间的比较问题是什么?
  2. JavaScript函数式编程学习
  3. 【干货】2014年iOS推广四大秘籍
  4. 1.3(java学习笔记)构造方法及重载
  5. wxWidgets:wxFont概览
  6. chromebook刷机_如何从Chromebook上的APK侧面加载Android应用
  7. android 8.0 用户体验优化--day02
  8. 我可以隐藏HTML5号码输入的旋转框吗?
  9. CPP-week fourteen
  10. wordpress上传文件报错的解决方法(413 Request Entity Too Large、超过upload_max_filesize文件中定义的php.ini值)
  11. 防抖 节流_坚持造轮子第二天 防抖与节流
  12. Thinkpad X230 黑苹果macOS 10.14 和10. 15驱动AR9285网卡
  13. 如何将MBR分区转换成GPT分区
  14. deepin更新启动项_Deepin修复启动项菜单---grub2启动修复
  15. 怎么样在家拍出好看的证件照?标准证件照拍摄技巧分享
  16. java实现业务模块的热插拔_如何来实现SpringBoot应用的JPA数据持久化和热插拔
  17. 简单的云平台基础环境的构建(一)
  18. jQuery的css()如何修改背景图片
  19. 习题整理(简单01背包 可用查并集2022/4/24)
  20. 正在为首次使用计算机做准备黑屏,Windows 10首次启动时意外重启计算机或遇到错误的解决方案...

热门文章

  1. 10秒钟脱口而出两位数的平方
  2. 国外程序员也都是996么?
  3. 我的NVIDIA开发者之旅-Jetson Nano 2gb教你怎么训练模型(完整的模型训练套路)
  4. 读懂微信:从1.0到7.0版本,一个主流IM社交工具的进化史...
  5. 浙大版《C语言程序设计实验与习题指导(第4版)》题目集 实验2-3-2 计算摄氏温度
  6. 利用tshark对网络数据包做进一步的分析
  7. 《众妙之门——Web用户体验设计与可用性测试》一2.2 从数字上看:行为应答...
  8. 使用服务网格提升应用和网络安全
  9. spring实战笔记_第4章
  10. pip install mysqlclient安装