Mixture Density Networks

最近看论文经常会看到在模型中引入不确定性(Uncertainty)。尤其是MDN(Mixture Density Networks)在World Model这篇文章多次提到。之前只是了解了个大概。翻了翻原版论文和一些相关资料进行了整理。

1. 直观理解:

混合密度网络通常作为神经网络的最后处理部分。将某种分布(通常是高斯分布)按照一定的权重进行叠加,从而拟合最终的分布。

如果选择高斯分布的MDN,那么它和GMM(高斯混合模型 Gaussian Mixture Model)有着相同的效果。但是他们有着很明显的区别:

  • MDN的均值方差每个模型的权重是通过神经网络产生的,利用最大似然估计作为Loss函数进行反向传播从而确定网络的权重(也就是确定一个较好的高斯分布参数)

  • GMM的均值方差每个模型的权重是通过估计出来的,通常使用EM算法来通过不断迭代确定。

    GMM的详解以及为什么要用EM而不是极大似然估计来优化参数,请见这个博客

总之,MDN的思想与GMM一样,将模型混合的思想与神经网络相结合。在回归问题上通常都有很好的表现。例如,论文中提到的一个翻转的x,t翻转的例子:

  1. 如果x是训练数据,t是我们的label:

    普通的神经网络,使用sum-of-squares error作为loss可以得到一个较好的拟合效果。

  2. 同样的数据,将x和t的数据翻转(原来x的数据作为标签,原来t的数据作为训练集, tmp = x, x = t, t = tmp):

    使用sum-of-squares error作为loss似乎并没有捕捉到我们的走势。

  3. MDN效果如何呢

    先上效果图(来自原版论文)。下图绘制的是可能性最大的点(分布的均值)。可见基本上可以捕捉到这个趋势。

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-oPgn4RpM-1605340386543)(Untitled.assets/image-20201114140657278.png)]

    在输出的分布内进行采样获取预测,图片来自:

2. 算法细节

2.1. 结构

参数化表示:

CCC :要混合的分布个数。是用户需要制定的参数。例如我们需要混合5个高斯分布作为最终结果,那么C = 5;

α\alphaα :每个分布的权重参数。网络输出的参数

DDD: 某一种被混合的分布, 如果是高斯分布,那么KaTeX parse error: Undefined control sequence: \cal at position 1: \̲c̲a̲l̲ ̲D 就应该用 NNN表示。

λ\lambdaλ:分布的一些参数,高斯分布则包括μ\muμ和σ\sigmaσ。网络输出的参数

需要注意的是:混合的分布可以是任意的。

以高斯分布为例,网络结构如下:

  • α\alphaα (alpha)的和应该等于1,即∑cCαc=1\sum^{C}_{c} \alpha_c = 1∑cC​αc​=1。 所以我们可以在使用softmax激活函数来解决。
  • σ\sigmaσ(sigma)>0。 可以保证这个的方法有很多,在Mixture Density Networks中使用指数激活:σ=exp(z)\sigma = exp(z)σ=exp(z)。指数可能会引起数值不稳定,出现无穷大。可以使用变种的ELU [3],即σ=ELU(σ)+1\sigma = ELU(\sigma)+1σ=ELU(σ)+1
  • μ\muμ 的范围是否要确定区间,可以根据实际问题。例如价格预测,不可能出现负的,就可以选择相关的激活函数来固定区间大于0.

2.2 Loss设计:

损失函数使用的极大似然估计。极大似然估计认为我们采样出来的都是那些出现概率最大的数。所以我们希望我们需要最大化的似然函数为(这里使用了平均值,即每个分布的似然函数大小):

极大似然估计公式:L(θ)=L(x1,x2...xn;θ)=∏i=1np(xi;θ)L(\theta) = L(x_1,x_2...x_n ; \theta) = \prod_{i = 1 } ^n p(x_i; \theta)L(θ)=L(x1​,x2​...xn​;θ)=∏i=1n​p(xi​;θ)。用多个分布混合,则p(xi;θ)=∑kKakpk(xi;θ)p(x_i;\theta) = \sum_k ^K a_k p_k(x_i ; \theta)p(xi​;θ)=∑kK​ak​pk​(xi​;θ)。 下式中 xix_ixi​为yn∣xny_n|x_nyn​∣xn​

L(θ)=1N∏nN∑kKakpk(yn∣xn)ln(L(θ))=1N∑nNlog⁡{∑kKαkpk(yn∣xn)}L(\theta) = \frac{1}{N} \prod_n ^N \sum_k ^K a_k p_k(y_n|x_n) \\ ln(L(\theta)) =\frac{1}{N} \sum_n ^N \log \{ \sum_k ^K \alpha_k p_k(y_n|x_n)\} L(θ)=N1​n∏N​k∑K​ak​pk​(yn​∣xn​)ln(L(θ))=N1​n∑N​log{k∑K​αk​pk​(yn​∣xn​)}

N 样本总数

K 分布的数量

aka_kak​ 是当前分布的权重

pkp_kpk​ 是当前分布的概率

$ \sum_k ^K a_k p_k(y_n|x_n)$ 就是xnx_nxn​样本出现的概率。对应似然函数中的p(xi;θ)p(x_i; \theta)p(xi​;θ)。 是k个分布按照权重α\alphaα累加的结果。

优化器一般都是梯度下降,用来最小化目标函数,所以我们要在上式加一个负号,作为优化函数,这样就是梯度上升最大化上式。
Loss(θ)=−ln(L(θ))Loss(\theta) = -ln(L(\theta)) Loss(θ)=−ln(L(θ))
如果是N个高斯分布,那么我们的损失函数:
Loss(θ)=−1N∑1Nlog⁡{∑kαkN(yn∣μk,σk2)}Loss(\theta) = -\frac{1}{N} \sum_1 ^N \log \{\sum_k \alpha_k N(y_n|\mu_k,\sigma^2_k)\} Loss(θ)=−N1​1∑N​log{k∑​αk​N(yn​∣μk​,σk2​)}

N(y∣μ,σ2)=12πσ2e−(x−μ)22σ2N(y|\mu,\sigma^2) = \frac{1}{\sqrt{2 \pi \sigma^2}} e^{\frac{-(x-\mu)^2}{2\sigma^2}} N(y∣μ,σ2)=2πσ2​1​e2σ2−(x−μ)2​

3. 总结

MDN实现简单,而且可以直接模块化的连接到神经网络的后端。他的结果可以得到一个概率范围,相对有deterministic类只输出一个结果,往往有更好的健壮性。[3][4]中有相关代码实现。

4. reference:

[1]. Christopher M. Bishop, Mixture Density Networks (1994)

[2]. Blog-详解EM算法与混合高斯模型(Gaussian mixture model, GMM)

[3]. Blog-A Hitchhiker’s Guide to Mixture Density Networks

[4]. Blog-Mixture Density Networks

论文阅读23 - Mixture Density Networks(MDN)混合密度网络理论分析相关推荐

  1. Graph Mixture Density Networks 图混合密度网络

    是一类新的机器学习模型,可以适应条件为任意拓扑图的多模态输出分布.通过结合混合模型和图表示学习的思想,我们解决了一类更广泛的依赖结构化数据的具有挑战性的条件密度估计问题.我们在一个利用随机图进行随机流 ...

  2. [论文阅读] (23)恶意代码作者溯源(去匿名化)经典论文阅读:二进制和源代码对比

    <娜璋带你读论文>系列主要是督促自己阅读优秀论文及听取学术讲座,并分享给大家,希望您喜欢.由于作者的英文水平和学术能力不高,需要不断提升,所以还请大家批评指正,非常欢迎大家给我留言评论,学 ...

  3. Multiple Object Tracking with Mixture Density Networks for Trajectory Estimation 详细解读

    文章目录 简介(abstract) 介绍(introduction) 相关工作(Related Work) 轨迹预测(Trajectory Estimation) Mixture Density Ne ...

  4. 论文阅读2018-Deep Convolutional Neural Networks for breast cancer screening 重点:利用迁移学习三个网络常规化进行分类

    论文阅读2018-Deep Convolutional Neural Networks for breast cancer screening 摘要:我们探讨了迁移学习的重要性,并通过实验确定了在训练 ...

  5. 论文阅读:Regularizing Deep Networks with Semantic Data Augmentation

    论文阅读:Regularizing Deep Networks with Semantic Data Augmentation 动机 特征空间的语义变换 Implicit semantic data ...

  6. 【论文阅读】深度强化学习的攻防与安全性分析综述

    文章目录 一.论文信息 二.论文结构 三.论文内容 摘要 1 深度强化学习方法 2 深度强化学习的攻击方法 2.1 基于观测的攻击 4 深度强化学习的安全性分析 5 应用平台与安全性评估指标 5.1 ...

  7. 使用Pytorch简单实现混合密度网络(Mixture Density Network, MDN)

    本文主要参考自: https://github.com/sksq96/pytorch-mdn/blob/master/mdn.ipynb https://blog.otoro.net/2015/11/ ...

  8. 论文阅读-VERY DEEP CONVOLUTIONAL NETWORKS FOR LARGE-SCALE IMAGE RECOGNITION

    作者: Karen Simonyan et al. 日期: 2015 类型: conference article 来源: ICLR 评价: veyr deep networks 论文链接: http ...

  9. 论文阅读——Quantizing deep convolutional networks for efficient inference: A whitepaper

    Quantizing deep convolutional networks for efficient inference: A whitepaper Abstract 本文针对如何对卷积神经网络的 ...

最新文章

  1. 按树型显示BOM的结构
  2. gitpythonapi_GitPython 使用基础
  3. FewRel 2.0数据集:以近知远,以一知万,少次学习新挑战
  4. SpringBoot基础篇Bean之条件注入之注解使用
  5. Ubuntu将在明年推出平板及手机系统
  6. Docker-compose 安装Minio 最新版本
  7. 头条终面:写个消息中间件
  8. impdb导入oracle,impdp导入.dmp到oracle
  9. powerdesigner辅助导入导出excel文件
  10. hive体系架构以及各个组件的作用
  11. 【优化分类】基于matlab遗传算法优化支持向量机分类(多输入多分类)【含Matlab源码 QF003期】
  12. 大学计算机课程复习--汇编语言
  13. 官方解决方案:WPS for Mac 云字体删除的问题,Mac版WPS已下载云字体无法删除的问题
  14. Verilog学习 | 数字下变频与脉冲压缩的综合仿真与硬件实现
  15. 2021-01-16
  16. pd对焦速度_PDAF对焦技术原理解析及生产应用
  17. ElasticSearch for GIS应用
  18. h3c 抓包么 能通过debug_H3C debugging 使用技巧
  19. P2916 [USACO08NOV]Cheering up the Cow G 题解
  20. 语音/视频转文字的工具选择它-不仅仅是好用还免费

热门文章

  1. 网络安全笔记--文件上传1(文件上传基础、常见后端验证、黑名单、白名单、后端绕过方式)
  2. Excel基础教程(1)
  3. 数据结构也不是那么没意思之中序二叉树+二叉树转伪双向循环链表
  4. 谈下我曾经做过的一个心理健康管理系统
  5. android qq截屏快捷键是什么,手机截屏的快捷键是什么,超过3种截图的快捷键操作方法!...
  6. python怎么实现模糊找色_Python下尝试实现图片的高斯模糊化
  7. YOLOv5的参数IOU与PR曲线,F1 score
  8. 优秀课程案例:使用Scratch制作扫雷插旗排雷完整版
  9. BLE(11)—— 细说 Initiating
  10. 网易云信第三方接口调用超详细Demo