本节为吴恩达教授机器学习课程第四部分,生成学习算法(1),包括:多元正态分布,高斯判别分析模型GDA以及及GDA与逻辑回归的关系,并附上高斯判别分析的python实现代码

  之前关于该部分的学习算法尝试对给定x条件下y的分布建模,即对p(y∣x;θ)p(y|x;\theta)p(y∣x;θ)建模。比如逻辑回归对p(y∣x;θ)p(y|x;\theta)p(y∣x;θ)建模,且hθ(x)=g(θTx)h_{\theta}(x)=g(\theta^Tx)hθ​(x)=g(θTx),其中g是sigmoid函数。
  考虑一个二分类问题,判断一个动物是象(y=1)还是狗(y=0)。给定训练集,逻辑回归或者基础的感知机算法都是尝试找到一个决策边界来分离象和狗。当输入测试样本时,算法会看该样本位于决策边界的哪边,以此进行预测。
  本节介绍的生成学习算法不同于此,此算法首先建立一个大象的模型,再建立一个狗的模型。当输入测试样本时,我们将该样本分别与两个模型进行比对,更像那个模型就输出哪一类。
  前者如逻辑回归等算法,尝试对p(y∣x)p(y|x)p(y∣x)建模,称为判别学习算法。后者对p(x∣y)p(x|y)p(x∣y)和p(y)p(y)p(y)建模,称为生成学习算法,即p(x∣y=0)p(x|y=0)p(x∣y=0)对狗的特征进行建模,p(x∣y=1)p(x|y=1)p(x∣y=1)对象的特征进行建模。
  完成对p(x∣y)p(x|y)p(x∣y)和类别先验p(y)p(y)p(y)建模后,算法通过贝叶斯法则得到给定x条件下y的后验分布:

  其中分母p(x)=p(x∣y=1)p(y=1)+p(x∣y=0)p(y=0)p(x)=p(x|y=1)p(y=1)+p(x|y=0)p(y=0)p(x)=p(x∣y=1)p(y=1)+p(x∣y=0)p(y=0),但实际上我们并不需要计算这个分母因为:

1. 高斯判别分析GDA

  第一个生成学习算法即GDA,这个模型中,我们假设p(x∣y)p(x|y)p(x∣y)服从多元正态分布,在学习GDA之前,先学习一下多元正态分布的基本知识。

1.1 多元正态分布

  nnn维多元正态分布又称多元高斯分布,参数维均值向量μ∈Rn\mu \in R^nμ∈Rn和协方差矩阵Σ∈Rn×n\Sigma \in R^{n \times n}Σ∈Rn×n,并且Σ\SigmaΣ时非负对称半正定矩阵,多元高斯分布的概率密度可以写作:

  其中∣Σ∣|\Sigma|∣Σ∣表示矩阵的行列式。

正定矩阵和半正定矩阵

  对于服从多元正态分布的随机变量来说,有:

  又因为Cov(Z)可以如下计算

  即:

  所以有:

  下面给出高斯分布的一些图像以更好理解:

  上图中均值为 2x1维零向量,协方差矩阵从左到右依次为I、0.6I、2I(2×2)I、0.6I、2I(2 \times 2)I、0.6I、2I(2×2),第一个也称为标准分布。

  上图中均值也都为零向量,协方差矩阵从左到右依次为:

  对应地,其在底面上的投影为:

  变化协方差矩阵如下:

  得到的高斯分布底面投影如下所示:

  而下图表示的是协方差矩阵为单位矩阵时,变化均值得到的高斯分布图像:

  其中均值分别为:

1.2 高斯判别分析模型

  给定一个分类问题,输入的特征xxx是连续型随机变量,则可以使用GDA模型,使用多元正态分布对p(x∣y)p(x|y)p(x∣y)进行建模,模型为:

  分布可以写为:

  模型的参数为φ、Σ、μ0、μ1\varphi、\Sigma、\mu_0、\mu_1φ、Σ、μ0​、μ1​,数据的对数似然函数可以写为:

  这个对数似然函数也叫做joint likelihood(逻辑回归里的对数似然函数叫做conditional likelihood),通过最大化这个对数似然函数,我们可以得到参数的极大似然估计:

  其中μ0\mu_0μ0​的分子即对所有的负样本的特征x求均值。
  算法要做的事情其实如下图所示:

  图中画出了所有的训练样本和两个高斯分布的等值线(表示两个不同的类别),两个分布由于协方差矩阵相同所以等值线形状也相同,但是他们的均值不同的所以位置不同,图中的直线是一个决策边界,上面的点有p(y=1∣x)=0.5p(y=1|x)=0.5p(y=1∣x)=0.5,直线两边分别表示不同的类别。

1.3 GDA和逻辑回归的比较

  GDA和逻辑回归之间确实存在这某种有趣的关系,如果我们将p(y=1∣x;φ,μ0,μ1,Σ)p(y=1|x;\varphi,\mu_0,\mu_1,\Sigma)p(y=1∣x;φ,μ0​,μ1​,Σ)视为xxx的函数,那么有:

  其中θ\thetaθ是φ,μ0,μ1,Σ\varphi,\mu_0,\mu_1,\Sigmaφ,μ0​,μ1​,Σ的函数,这和逻辑回归的形式是相同的。
  然而GDA和逻辑回归在给定相同数据集时会训练得到不同的决策边界,哪一个更好呢?
  总的来说,GDA做了更强的建模假设,当假设正确或者接近正确时模型利用数据的效率更高(需要较少的训练数据就能的到很好的效果);逻辑回归建模假设相对较弱,因此对于偏差更加鲁棒,而且当数据并不服从高斯分布时,逻辑回归由于GDA。基于这个原因,逻辑回归在实践中比GDA更加常用。附上GDA的代码。

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
import matplotlib as mpl
import matplotlib.pyplot as pltclass GDA:def __init__(self,train_data,train_label):self.Train_Data = train_dataself.Train_Label = train_labelself.postive_num = 0  self.negetive_num = 0     postive_data = []negetive_data = []for (data,label) in zip(self.Train_Data,self.Train_Label):if label == 1:          # 正样本self.postive_num += 1postive_data.append(list(data))else:                   # 负样本self.negetive_num += 1negetive_data.append(list(data))row,col = np.shape(train_data)# 计算正负样本的高斯分布的均值向量postive_data = np.array(postive_data)negetive_data = np.array(negetive_data)postive_data_sum = np.sum(postive_data, 0)negetive_data_sum = np.sum(negetive_data, 0)self.mu_positive = postive_data_sum*1.0/self.postive_num                # 正样本的高斯分布的均值向量self.mu_negetive = negetive_data_sum*1.0/self.negetive_num              # 负样本的高斯分布的均值向量# 计算高斯分布的协方差矩阵positive_deta = postive_data-self.mu_positivenegetive_deta = negetive_data-self.mu_negetiveself.sigma = []for deta in positive_deta:deta = deta.reshape(1,col)ans = deta.T.dot(deta)self.sigma.append(ans)for deta in negetive_deta:deta = deta.reshape(1,col)ans = deta.T.dot(deta)self.sigma.append(ans)self.sigma = np.array(self.sigma)#print(np.shape(self.sigma))self.sigma = np.sum(self.sigma,0)self.sigma = self.sigma/rowself.mu_positive = self.mu_positive.reshape(1,col)self.mu_negetive = self.mu_negetive.reshape(1,col)def Gaussian(self, x, mean, cov):dim = np.shape(cov)[0]# cov的行列式为零时的措施covdet = np.linalg.det(cov + np.eye(dim) * 0.001)covinv = np.linalg.inv(cov + np.eye(dim) * 0.001)xdiff = (x - mean).reshape((1, dim))# 概率密度prob = 1.0 / (np.power(np.power(2 * np.pi, dim) * np.abs(covdet), 0.5)) * \np.exp(-0.5 * xdiff.dot(covinv).dot(xdiff.T))[0][0]return probdef predict(self,test_data):predict_label = []for data in test_data:positive_pro = self.Gaussian(data,self.mu_positive,self.sigma)negetive_pro = self.Gaussian(data,self.mu_negetive,self.sigma)if positive_pro >= negetive_pro:predict_label.append(1)else:predict_label.append(0)return predict_labeldef main():# 导入乳腺癌数据,scikit-learn自带的breast_cancer = load_breast_cancer()data = np.array(breast_cancer.data)label = np.array(breast_cancer.target)data = MinMaxScaler().fit_transform(data)## 解决画图是的中文乱码问题# mpl.rcParams['font.sans-serif'] = [u'simHei']#mpl.rcParams['axes.unicode_minus'] = False# 分割训练集与测试集train_data,test_data,train_label,test_label = train_test_split(data,label,test_size=3/7)# 数据可视化plt.scatter(test_data[:,0],test_data[:,1],c = test_label)plt.title("乳腺癌数据集显示")plt.show()# GDA结果gda = GDA(train_data,train_label)test_predict = gda.predict(test_data)print("GDA的正确率为:",accuracy_score(test_label,test_predict))# 数据可视化plt.scatter(test_data[:,0],test_data[:,1],c = test_predict)plt.title("GDA分类结果显示")plt.show()## Logistic回归结果# lr = LogisticRegression()# lr.fit(train_data,train_label)# test_predict = lr.predict(test_data)# print("Logistic回归的正确率为:",accuracy_score(test_label,test_predict))## 数据可视化# plt.scatter(test_data[:,0],test_data[:,1],c = test_predict)# plt.title("Logistic回归分类结果显示")# plt.show()if __name__ == '__main__':main()

欢迎扫描二维码关注微信公众号 深度学习与数学   [每天获取免费的大数据、AI等相关的学习资源、经典和最新的深度学习相关的论文研读,算法和其他互联网技能的学习,概率论、线性代数等高等数学知识的回顾]

吴恩达教授机器学习课程笔记【四】- 生成学习算法(1)高斯判别分析模型相关推荐

  1. 吴恩达《机器学习》笔记汇总

    根据学习进度,将课程分为15部分进行笔记,具体内容如下: 吴恩达机器学习(一)-- 简介 吴恩达机器学习(二)-- 线性回归 吴恩达机器学习(三)-- Logisitic回归 吴恩达机器学习(四)-- ...

  2. 吴恩达《机器学习》笔记(一)【线性回归梯度下降法】

    通过在网易云课堂学习吴恩达先生的<机器学习>课程,为了巩固自己的学习且方便读者们共同交流学习,特此做此学习笔记,希望与大家共勉. 吴恩达<机器学习>课程链接:https://s ...

  3. 【机器学习 吴恩达】CS229课程笔记notes3翻译-Part V支持向量机

    CS229 课程笔记 吴恩达 Part V 支持向量机 这部分展现了支持向量机(SVM)学习算法.SVM是最好的监督学习算法之一.为了讲述SVM,我们需要首先谈论边界和用大的间隔分离数据.接下来,我们 ...

  4. 【AI】吴恩达斯坦福机器学习中文笔记汇总

    1.吴恩达机器学习和深度学习课程的字幕翻译以及笔记整理参见: 以黄海广博士为首的一群机器学习爱好者发起的公益性质项目(http://www.ai-start.com). 2.黄海广博士公益项目介绍 h ...

  5. 吴恩达《机器学习》第四章:多元线性回归

    目录 四.多元线性回归 4.1 特征缩放 4.2 学习率α 4.4 特征和多项式 4.4 正规方程 四.多元线性回归 多特征下的假设形式: 4.1 特征缩放 特征缩放:Feature Scaling, ...

  6. Chapter5:Octave教程:AndrewNg吴恩达《机器学习》笔记

    文章目录 Chapter 5 : Octave 教程 5.1 基本操作 5.2 移动数据 5.3 计算数据 5.4 绘图数据 5.5 控制语句:for,while,if语句 5.6 向量化 5.7 工 ...

  7. Chapter1:监督学习、无监督学习:AndrewNg吴恩达《机器学习》笔记

    文章目录 Chapter 1 Introduction 1.1 Welcome 1.2 Definition 1.2.1 定义1: --from **Arthur Samuel** 1.2.2 定义2 ...

  8. 课程笔记|吴恩达Coursera机器学习 Week1 笔记-机器学习基础

    1 1. Introduction 1.1 Supervised Learning 已知输入x以及其对应的标签y,求解 f:x→y 回归 regression:输出的结果y是一个连续的变量 y=ℝ 分 ...

  9. 【机器学习 吴恩达】2022课程笔记(持续更新)

    一.机器学习 1.1 机器学习定义 计算机程序从经验E中学习,解决某一任务T,进行某一性能P,通过P测定在T上的表现因经验E而提高 eg:跳棋程序 E: 程序自身下的上万盘棋局 T: 下跳棋 P: 与 ...

  10. 吴恩达AI机器学习-01神经网络与深度学习week2下-神经网络基础 python中的广播

    ‼️博客为作者学习回顾知识点所用,并非商用,如有侵权,请联系作者删除‼️ 目录 2.15Python中的广播 python广播中的规则 2.16Python numpy 向量的注释 排除bug的技巧 ...

最新文章

  1. ant 使用常见问题
  2. 这里先发布一个,自己写得unityUI的适配的方案(插播)
  3. C# 时间+三位随机数
  4. HD 1525 Euclid's Game
  5. 从零点五开始用Unity做半个2D战棋小游戏(九)
  6. struts基础配置
  7. 疯狂动物消消乐html5游戏在线玩,疯狂动物消消乐免费
  8. 归档-软考部分科目的考察内容
  9. MySQL事务(脏读、不可重复读、幻读)
  10. 关于Timestamp的valueOf()方法
  11. Jenkins添加注册用户默认权限/Add a default authenticated user role
  12. 41. Understand implicit interfaces and compile-time polymorphism
  13. python监听键盘事件pyhook用法_python 监听键盘事件pyHook
  14. Android-adb获取当前前台进程
  15. 致敬柳传志三网合一的佳沃品牌之路
  16. vs2010环境下wincap的配置
  17. decimals数据格式化
  18. 港科夜闻|全国政协副主席梁振英先生率香港媒体高管团到访香港科大(广州)...
  19. 多功能跑步机外观及结构设计
  20. AVB源码学习(一):AVB2.0工作原理及编译配置

热门文章

  1. Elasticsearch 7.x Nested 嵌套类型查询 | ES 干货
  2. 《软件工程》团队第一阶段Sprint检查表
  3. IDEA2018全局搜索中搜索jar包/lib
  4. nginx root与alias区别
  5. 图解clientWidth,offsetWidth,scrollWidth,scrollTop
  6. Java Beanutils 配置
  7. POJ1546(进制转换)
  8. SVN服务器端安装过程出现“Custom action InstallWMISchemaExecute failed:无法启动服务,原因可能是已被禁用或与其相关联的设备没有启动。”
  9. vue中this.$set的用法
  10. JavaScript的7个位运算符