1. 简介

贝叶斯神经网络不同于一般的神经网络,其权重参数是随机变量,而非确定的值。如下图所示:

也就是说,和传统的神经网络用交叉熵,mse等损失函数去拟合标签值相反,贝叶斯神经网络拟合后验分布。

这样做的好处,就是降低过拟合。

2. BNN模型

BNN 不同于 DNN,可以对预测分布进行学习,不仅可以给出预测值,而且可以给出预测的不确定性。这对于很多问题来说非常关键,比如:机器学习中著名的 Exploration & Exploitation (EE)的问题,在强化学习问题中,agent 是需要利用现有知识来做决策还是尝试一些未知的东西;实验设计问题中,用贝叶斯优化来调超参数,选择下一个点是根据当前模型的最优值还是利用探索一些不确定性较高的空间。比如:异常样本检测,对抗样本检测等任务,由于 BNN 具有不确定性量化能力,所以具有非常强的鲁棒性

概率建模:

在这里,选择似然分布的共轭分布,这样后验可以分析计算。 比如,beta分布的先验和伯努利分布的似然,会得到服从beta分布的后验。

由于共轭分布,需要对先验分布进行约束。因此,我们尝试使用采用和变分推断来近似后验分布。


神经网络: 使用全连接网络来拟合数据,相当于使用多个全连接网络。 但是神经网络容易过拟合,泛化性差;并且对预测的结果无法给出置信度。

BNN: 把概率建模和神经网络结合起来,并能够给出预测结果的置信度。

先验用来描述关键参数,并作为神经网络的输入。神经网络的输出用来描述特定的概率分布的似然。通过采样或者变分推断来计算后验分布。 同时,和神经网络不同,权重 W 不再是一个确定的值,而是一个概率分布。


BNN建模如下:

假设 NN 的网络参数为

是参数的先验分布,给定观测数据
,这里
是输入数据,
是标签数据。BNN 希望给出以下的分布:

也就是我们预测值为:

由于,

是随机变量,因此,我们的预测值也是个随机变量。

其中:

这里

是后验分布,
是似然函数,
是边缘似然。

从公式(1)中可以看出,用 BNN 对数据进行概率建模并预测的核心在于做高效近似后验推断,而 变分推断 VI 或者采样是一个非常合适的方法。

如果采样的话: 我们通过采样后验分布

来评估
, 每个样本计算
, 其中 f 是我们的神经网络。

正是我们的输出是一个分布,而不是一个值,我们可以估计我们预测的不确定度。

3. 基于变分推断的BNN训练

如果直接采样后验概率

来评估
的话,存在后验分布多维的问题,而变分推断的思想是使用简单分布去近似后验分布。

表示

, 每个权重
从正态分布
中采样。

希望

相近,并使用 KL 散度来度量这两个分布的距离。 也就是优化:

进一步推导:

公式中,

表示给定正态分布的参数后,权重参数的分布;
表示给定网络参数后,观测数据的似然;
表示权重的先验,这部分可以作为模型的正则化。

并且使用

来表示变分下界ELBO, 也就是公式(4)等价于最大化ELBO:

其中,

我们需要对公式(4)中的期望进行求导,但是,这里,我们使用对权重进行重参数的技巧:

其中,

.

于是,用

代 替
后有:

也就是说,我们可以通过 多个不同的

,求取
的平均值,来近似 KL 散度对
的求导。

此外,除了对

进行重采样之外,为了保证
参数取值范围包含这个实轴,对
进行重采样,可以令,

然后,

,这里的
已经和原来定义的
不一样了。

4. BNN实践

算法:

  1. 中采样,获得
  2. 分别计算
    . 其中,计算
    实际计算
    ,
    . 也就可以得到
  3. 重复更新参数
    .

Pytorch实现:

import 

这里是重复计算100次的平均值和100次平均值的97.5%大和2.5%小的区域线图(即置信度95%)。


参考:

  1. 变分推断;
  2. Weight Uncertainty in Neural Networks Tutorial;
  3. Bayesian Neural Networks;

优化概率神经网络_贝叶斯神经网络BNN(推导+代码实现)相关推荐

  1. pytorch贝叶斯网络_贝叶斯神经网络:2个在TensorFlow和Pytorch中完全连接

    pytorch贝叶斯网络 贝叶斯神经网络 (Bayesian Neural Net) This chapter continues the series on Bayesian deep learni ...

  2. 贝叶斯优化python包_贝叶斯优化

    万壑松风知客来,摇扇抚琴待留声 1. 文起 本篇文章记录通过 Python 调用第三方库,从而调用使用了贝叶斯优化原理的 Hyperopt 方法来进行超参数的优化选择.具体贝叶斯优化原理与相关介绍将在 ...

  3. 贝叶斯优化python包_贝叶斯全局优化(LightGBM调参)

    这里结合Kaggle比赛的一个数据集,记录一下使用贝叶斯全局优化和高斯过程来寻找最佳参数的方法步骤. 1.安装贝叶斯全局优化库 从pip安装最新版本 pip install bayesian-opti ...

  4. 基于变化点 copula 优化算法中的贝叶斯研究(Matlab代码实现)

  5. 贝叶斯深度神经网络_深度学习为何胜过贝叶斯神经网络

    贝叶斯深度神经网络 Recently I came across an interesting Paper named, "Deep Ensembles: A Loss Landscape ...

  6. ​贝叶斯神经网络最新综述

    ©PaperWeekly 原创 · 作者|尹娟 学校|北京理工大学博士生 研究方向|随机过程.复杂网络 论文标题:Bayesian Neural Networks: An Introduction a ...

  7. 贝叶斯神经网络计算核裂变碎片产额

    作者丨庞龙刚 单位丨华中师范大学 研究方向丨高能核物理.人工智能 今天介绍一篇北京大学物理系使用贝叶斯神经网络计算核裂变碎片产额的文章.这篇文章发表在 PRL 上,业内同行都很感兴趣.这里对我们大同行 ...

  8. 优化概率神经网络_Bayesian Neural Networks:贝叶斯神经网络

    贝叶斯神经网络,简单来说可以理解为通过为神经网络的权重引入不确定性进行正则化(regularization),也相当于集成(ensemble)某权重分布上的无穷多组神经网络进行预测. 本文主要基于 C ...

  9. 贝叶斯神经网络BNN

    反向传播网络在优化完毕后,其权重是一个固定的值,而贝叶斯神经网络把权重看成是服从均值为 μ ,方差为 δ 的高斯分布,每个权重服从不同的高斯分布,反向传播网络优化的是权重,贝叶斯神经网络优化的是权重的 ...

  10. 07. 贝叶斯神经网络

    算法思路 普通的神经网络的权值是确定的,而贝叶斯神经网络的权值是不确定的,他服从于一个概率分布,这便是贝叶斯神经网络和普通神经网络的差别. 可以简单认为,贝叶斯神经网络是无穷个神经网络的融合,不过给每 ...

最新文章

  1. 统计学习方法笔记(七)-线性支持向量机原理及python实现
  2. 最短路径-Dijkstra算法与Floyd算法
  3. mysql按照datetime精确查询_MySQL datetime字段查询按小时:分钟排序
  4. PAT甲级题目翻译+答案 AcWing(哈希表)
  5. response.setHeader()的用法
  6. lintcode 中等题:A + B Problem A + B 问题
  7. ad19原理图标注_AD19中原理图的模板如何进行编辑?
  8. 16速 java_不停歇的 Java 即将发布 JDK 16,新特性速览!
  9. git-bug分支-git-stash-工作代码与bug解决同时处理时解决模拟
  10. alisql mysql_AliSQL 5.6.32 vs MySQL 5.7.15抢鲜测试
  11. MT7621完美支持32M SPI Flash(W25Q256) 修复 soft reset fail
  12. iconfont添加新图标_IconFont图标引用的方法步骤(代码)
  13. LOL英雄联盟首页以及攻略页面制作
  14. Office2010安装出错1935
  15. java倒计时器_Java并发系列5--倒计时器CountDownLatch
  16. Gebru被辞退的背后真相:指出BERT的4大危害,威胁谷歌商业利益
  17. Flask教程(十九)SocketIO
  18. Map接口以及那些实现类
  19. 稳定智能的在线考试系统
  20. 如何下载并安装Firebug插件

热门文章

  1. 第八章第五题(代数:两个矩阵相加)(Algebra: adding two matrices)
  2. c语言临时内存变量释放,C语言中的内存分配与释放
  3. LGP970刷机心得
  4. windows2003 php 加速,window_Win 2003 加速****,微软的Windown Server 2003尽管它是 - phpStudy...
  5. MATLAB | MATLAB配色不够用 全网最全的colormap补充包来啦
  6. qq飞车手游服务器找不到了,QQ飞车手游服务器拉取失败是怎么回事
  7. 关于基本勾股数规律的探讨总结与例题!
  8. 用php求勾股数,勾股数(示例代码)
  9. php通过imap获取腾讯企业邮箱信息后的解码处理
  10. 【Android】关于WIFI局域网的手机摄像头当视频监控用实现方案详解