©作者 | 小舟、陈萍

来源 | 机器之心

来自多伦多大学和斯坦福大学的研究者开发了一种在连续深度贝叶斯神经网络中进行近似推理的实用方法

把神经网络的限制视为无限多个残差层的组合,这种观点提供了一种将其输出隐式定义为常微分方程 ODE 的解的方法。连续深度参数化将模型的规范与其计算分离。虽然范式的复杂性增加了,但这种方法有几个好处:(1)通过指定自适应计算的容错,可以以细粒度的方式用计算成本换取精度;(2)通过及时运行动态 backward 来重建反向传播所需中间状态的激活函数,可以使训练的内存成本显著降低。

另一方面,对神经网络的贝叶斯处理改动了典型的训练 pipeline,不再执行点估计,而是推断参数的分布。虽然这种方法增加了复杂性,但它会自动考虑模型的不确定性——可以通过模型平均来对抗过拟合和改进模型校准,尤其是对于分布外数据。

近日,来自多伦多大学和斯坦福大学的一项研究表明贝叶斯连续深度神经网络的替代构造具有一些额外的好处,开发了一种在连续深度贝叶斯神经网络中进行近似推理的实用方法。该论文的一作是多伦多大学 Vector Institute 的本科学生 Winnie Xu,二作是 NeurIPS 2018 最佳论文的一作陈天琦,他们的导师 David Duvenaud 也是论文作者之一。

论文地址:

https://arxiv.org/pdf/2102.06559.pdf

项目地址:

https://github.com/xwinxu/bayesian-sde

具体来说,该研究考虑了无限深度贝叶斯神经网络每层分别具有未知权重的限制,提出一类称为 SDE-BNN(SDE- Bayesian neural network )的模型。该研究表明,使用 Li 等人(2020)描述的基于可扩展梯度的变分推理方案可以有效地进行近似推理。

在这种方法中,输出层的状态由黑盒自适应随机微分方程(SDE 求解器计算,并训练模型以最大化变分下界。下图将这种神经 SDE 参数化与标准神经 ODE 方法进行了对比。这种方法保持了训练贝叶斯神经 ODE 的自适应计算和恒定内存成本。

无限深度贝叶斯神经网络(BNN)

标准离散深度残差网络可以被定义为以下形式的层的组合:

其中 t 是层索引,表示 t 层隐藏单元激活向量,输入 h_0 = x,表示 t 层的参数,在离散设置中该研究通过设置并将极限设为来构建残差网络的连续深度变体。这样产生一个微分方程,该方程将隐藏单元进化描述为深度 t 的函数。由于标准残差网络每层使用不同的权重进行参数化,因此该研究用 w_t 表示第 t 层的权重。此外该研究还引入一个超网络(hypernetwork) f_w,它将权重的变化指定为深度和当前权重的函数。然后将隐藏单元激活函数的进化和权重组合成一个微分方程:

权重先验过程:该研究使用 Ornstein-Uhlenbeck (OU) 过程作为权重先验,该过程的特点是具有漂移(drift)和弥散(diffusion)的 SDE:

权重近似后验使用另一个具有以下漂移函数的 SDE 隐式地进行参数化:

然后该研究在给定输入下评估了该网络需要边缘化权重和隐藏单元轨迹(trajectory)。这可以通过简单的蒙特卡罗方法来完成,从后验过程中采样权重路径 {w_t},并在给定采样权重和输入的情况下评估网络激活函数 {h_t}。这两个步骤都需要求解一个微分方程,两步可以通过调用增强状态 SDE 的单个 SDE 求解器同时完成:

为了让网络拟合数据,该研究最大化由无限维 ELBO 给出的边缘似然(marginal likelihood)的下限:

采样权重、隐藏激活函数和训练目标都是通过一次调用自适应 SDE 求解器同时计算的。

减小方差的梯度估计

该研究使用 STL(sticking the landing) 估计器来替换 path 空间 KL 中的原始估计器以适应 SDE 设置:

等式 (12) 中的第二项是鞅(martingale),期望值为零。在之前的工作中,研究者仅对第一项进行了蒙特卡罗估计,但该研究发现这种方法不一定会减少梯度的方差,如下图 4 所示。

因为该研究提出的近似后验可以任意表达,研究者推测如果参数化网络 f_w 的表达能力足够强,该方法可在训练结束时实现任意低的梯度方差。

图 4 显示了多个梯度估计器的方差,该研究将 STL 与「完全蒙特卡罗(Full Monte Carlo)」估计进行了比较。图 4 显示,当匹配指数布朗运动时,STL 获得的方差比其他方案低。下表 4 显示了训练性能的改进。

实验

该研究的实验设置如下表所示,该研究在 MNIST 和 CIFAR-10 上进行了 toy 回归、图像分类任务,此外他们还研究了分布外泛化任务:

为了对比求解器与 adjoint 的反向传播,研究者比较了固定和自适应步长的 SDE 求解器,并比较了 Li 等人提出的随机 adjoint 之间的比较, 图 5 显示了这两种方法具有相似的收敛性:

1D 回归

该研究首先验证了 SDE-BNN 在 1D 回归问题上的表现。以弥散过程的样本为条件,来自 1D SDE-BNN 的每个样本都是从输入到输出的双向映射。这意味着从 1D SDE-BNN 采样的每个函数都是单调的。为了能够对非单调函数进行采样,该研究使用初始化为零的 2 个额外维度来增加状态。图 2 显示了模型在合成的非单调 1D 数据集上学习了相当灵活的近似后验。

图像分类

表 1 给出了图像分类实验的结果。SDE-BNN 通常优于基线,由结果可得虽然连续深度神经 ODE (ODEnet) 模型可以在标准残差网络上实现类似的分类性能,但校准(calibration)较差。

图 6a 展示了 SDE-BNN 的性能,图 6b 显示具有相似准确率但比神经 ODE 校准更好的结果。

表 1 用预期校准误差量化了模型的校准。SDE-BNN 似乎比神经 ODE 和平均场 ResNet 基线能更好地校准。

下图 7 显示了损坏测试集上相对于未损坏数据的误差,表明随着扰动严重性级别的增加以及表 1 中总结的总体误差度量,mCE 稳步增加。在 CIFAR10 和 CIFAR10-C 上,SDE-BNN 和 SDE -BNN + STL 模型实现了比基线更低的整体测试误差和更好的校准。

与标准基线(ResNet32 和 MF ResNet32)相比,SDE-BNN 的绝对损坏误差(CE)降低了约 4.4%。域外输入的学习不确定性的有效性表明,尽管没有在多种形式的损坏上进行训练,但 SDE-BNN 对观测扰动也更加稳健。

更多阅读

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

结合随机微分方程,多大Duvenaud团队提出无限深度贝叶斯神经网络相关推荐

  1. 实验一:贝叶斯神经网络及其如何用随机梯度马尔可夫链蒙特卡洛有效训练

    0.实验环境搭建: 源代码获取: 来源一:google 来源二:web 来源三:github 环境: conda create --name python36_google_deep python=3 ...

  2. TPAMI 2021 华为诺亚悉尼大学陶大程团队提出多功能卷积,助力轻量级网络

    关注公众号,发现CV技术之美 0 写在前面 在本文中,作者提出了一种用于构造高效卷积神经网络的多功能滤波器 ,并应用于各种视觉识别任务中.考虑到硬件上运行高效的深度学习模型的需求,研究者们已经开发了许 ...

  3. 基于6种监督学习(逻辑回归+决策树+随机森林+SVM+朴素贝叶斯+神经网络)的毒蘑菇分类

    公众号:尤而小屋 作者:Peter 编辑:Peter 大家好,我是Peter~ 本文是kaggle案例分享的第3篇,赛题的名称是:Mushroom Classification,Safe to eat ...

  4. 谷歌Jeff Dean团队提出利用深度学习对「电子健康记录」数据进行分析,可提高医疗诊断预测的准确性

    原文来源:arXiv 作者:Alvin Rajkomar.Eyal Oren.Kai Chen.Andrew M. Dai.Nissan Hajaj.Peter J. Liu.Xiaobing Liu ...

  5. 录音降噪哪家强?搜狗西工大联合团队DNS挑战赛夺冠

    边策 发自 凹非寺  量子位 报道 | 公众号 QbitAI 近日,全球语音顶级会议Interspeech 2020公布了"深度降噪挑战赛"(Deep Noise Suppress ...

  6. 清华团队通过监督贝叶斯嵌入,对单细胞染色质可及性数据进行细胞类型注释...

    本文约3200字,建议阅读9分钟 本文介绍了清华团队在单细胞技术的最新进展. 单细胞技术的最新进展使得能够在细胞水平上表征表观基因组异质性.鉴于细胞数量呈指数增长,迫切需要用于自动细胞类型注释的计算方 ...

  7. 实战: 对GBDT(lightGBM)分类任务进行贝叶斯优化, 并与随机方法对比

    目录: 一. 数据预处理 1.1 读取&清理&切割数据 1.2 标签的分布 二. 基础模型建立 2.1 LightGBM建模 2.2 默认参数的效果 三. 设置参数空间 3.* 参数空 ...

  8. UTA研究团队提出首个3D点云+GAN新方法,让机器人“眼神”更犀利 | AI日报

    韩国NAVER AI LAB重新标注128万张ImageNet图片:多标签,全面提升模型性能 ImageNet是机器学习社区最流行的图像分类基准数据集,包含超过1400张标注图像.该数据集由斯坦福教授 ...

  9. ICCV 华人团队提出会创作的Paint Transformer,网友反驳:这也要用神经网络?

    来源:新智元 [导读]神经网络相关论文逐渐变成 AI 领域的主流,但那些任务真的需要神经网络这个技术吗?最近ICCV上一篇文章在reddit上分享后引发热议,网友吐槽最多的就是:明明50行代码就能搞定 ...

最新文章

  1. R语言广义线性模型Logistic回归案例代码
  2. Android 程序适应多种多分辨率
  3. LOL手游诺手对线技巧,上分率提高60%,战神玩家推荐玩法
  4. 数据库-优化-pt-query-digest使用简介
  5. Oracle存储过程快速入门
  6. Android控件学习笔记之 ListView
  7. 递推DP URAL 1586 Threeprime Numbers
  8. whoami 显示“我是谁”
  9. android 阿拉伯语下布局,android设计的布局在阿拉伯语下界面错乱的解决方法
  10. Postman测试Soap协议接口
  11. Mac Chrome搜索引擎突然变成了Yahoo?!SearchToolHelper控制了我的搜索引擎
  12. 学习TypeScrip3(接口和对象类型)
  13. 什么是https证书,有什么优势?
  14. “掌上迎新”,这个学校把5400+新生安排的明明白白
  15. compare和compareTo方法的区别
  16. word “域” 插入图片目录
  17. 学习uni-app记录
  18. 力扣第314周赛第三题
  19. 通过J-Flash回读取芯片的固件程序
  20. 今年的双11,真的是太太太太太太香了!万元好礼回馈......

热门文章

  1. linux 核间通讯rpmsg架构分析
  2. oracle11关闭账户验证,Windows下Oracle11g中使用外部操作系统账户验证
  3. jedis set集合 java,使用Jedis操作String、List、Set、Map等常见数据 | zifangsky的个人博客...
  4. a类学科计算机,最全名单来了!上海交大25个学科获评A类学科
  5. linux操作系统原理_Linux内核分析-操作系统是如何工作的(二)
  6. 6个超炫酷的HTML5电子书翻页动画【转】
  7. 外刊晨读 2018 年 年 5 月 月 15 日
  8. opentesty--luasocket 安装
  9. CentOS7 安装NFS SSH免密码登陆
  10. Linux内核--网络协议栈深入分析(二)--sk_buff的操作函数