2019-12-30 13:04:12

人工智能顶会 ICLR 2020 将于明年 4 月 26 日于埃塞俄比亚首都亚的斯亚贝巴举行,不久之前,大会官方公布论文接收结果:在最终提交的 2594 篇论文中,有 687 篇被接收,接收率为 26.5%。本文介绍了华为诺亚方舟实验室被 ICLR 2020 接收的一篇满分论文。

论文地址:https://arxiv.org/pdf/1906.04477.pdf

因果研究作为下一个潜在的热点,已经吸引了机器学习/深度学习领域的的广泛关注,例如 Youshua Bengio 和 Fei-Fei Li 近期都有相关的工作。因果研究中一个经典的问题是「因果发现」问题——从被动可观测的数据中发现潜在的因果图结构。

在此论文中,华为诺亚方舟实验室因果研究团队将强化学习应用到打分法的因果发现算法中,通过基于自注意力机制的 encoder-decoder 神经网络模型探索数据之间的关系,结合因果结构的条件,并使用策略梯度的强化学习算法对神经网络参数进行训练,最终得到因果图结构。在学术界常用的一些数据模型中,该方法在中等规模的图上的表现优于其他方法,包括传统的因果发现算法和近期的基于梯度的算法。同时该方法非常灵活,可以和任意的打分函数结合使用。

模型定义和问题

我们假设以下常用的数据生成模型:给定一个有向无环图(DAG),每个节点对应一个随机变量,每个变量的观测值是图中父亲变量的函数加上一个独立的噪声,即

这里噪声 n_i 是联合独立的。如果所有的函数都是线性的且噪声是高斯的,则上述模型为标准的线性高斯模型。当函数为线性但噪声为非高斯函数时,上述模型为线性非高斯加性模型(LiNGAM),在一定的条件下是可以识别出真实的 DAG。

我们目前考虑所有的变量都是一维的实变量;给定一个合适的打分函数则可以直接扩展到多维变量的情形。在固定的函数和噪声分布下,我们的观测数据是根据上述模型在某个未知的 DAG 上独立采样得到。因果发现的目的就是使用这些观测的数据来推断真实的因果 DAG。

背景介绍

打分法是因果发现算法中一类常用的方法:给每个有向图打分(通常基于观测数据计算得到),然后在所有的 DAG 中进行搜索取得最好分数的 DAG:

尽管有很多已经深入研究的打分函数,例如基于线性高斯模型的 BIC/MDL 和 BGe 分数,但上述问题通常是 NP-hard 的,因为 DAG 条件是一个组合问题,并且可能的 DAG 数量的随着图节点的个数增加而超指数增加。为了解决这个问题,大多数已有方法都依赖于局部启发式算法。

例如,贪婪等价搜索(GES)在添加一条边时显式检查 DAG 约束是否满足。GES 在适当的假设和极限数据量的情况下可以找到具全局最优值,但在有限样本的情况下无法得到保证。

最近,也有工作在线性数据模型上对上述的无环条件提出了一个等价的可微分函数,再选择适当的损失函数(例如最小二乘损失),上述问题可以转换为关于带权值的邻接矩阵的连续优化问题。后续的工作也采用 ELBO 和 negative log-likelihood 作为损失函数,并使用神经网络对因果关系进行建模。但是很多已有的得分函数没有显式的表示或者是非常复杂的等价损失函数,这样和上述连续的方法结合会比较困难。

基于强化学习的因果发现算法

我们提出一种基于 RL 的方法来搜索 DAG,整体框架图如下所示。基于随机策略的 RL 可以在给定策略的不确定性信息的情况下自动确定要搜索的位置,同时可以通过奖励信号来及时更新。在合成数据集和真实数据集上的实验表明,基于强化学习的方法大大提高了搜索能力,并且不会影响打分函数的选择。

基于自注意力机制的 Encoder-Decoder 模型

如上图所示,我们采用 Transfomer 中基于自注意机制的 encoder,而 decoder 则是通过建立成对的 encoder 输出之间的关系来生成图的邻接矩阵。为了得到 0-1 的邻接矩阵,我们将每个 decoder 的输出通过 logistic-sigmoid 函数,然后使用 Bernoulli 分布进行采样。

我们也尝试了其他的 decoder,例如 bilinear model 以及 Transformer 中的 decoder。我们实验发现上图中 decoder 的效果最好,可能是因为它的参数量比较少、更容易训练来找到更好的 DAG,而基于自注意力机制的 encoder 已经提供了足够的交互来探索数据之间的因果关系。

Reward

传统的 GES 会在每次添加一条边时显式的检查图是否有环,我们使用打分函数和基于有环性质的惩罚项来设计 reward,并允许生成的图在每次迭代中变化多条边。具体的形式如下:

其中第一项是得分函数,用于衡量给定有向图和观测数据的匹配程度,其他两个正项则衡量某些「DAGness」(给定的有向图距无环的某种度量,例如所有环上的长度之和),lambda_1 和 lamba_2 是惩罚项的权重。通过选择适当的惩罚权重,最大化 reward 等价于之前打分法的问题的形式。但是两个问题等价并不意味着使用 RL 来最大化 reward 就可以直接取得很好的结果:实际中,我们发现较大的惩罚权重可能会妨碍 RL 的探索,得到的因果图的得分通常比较差,而较小的惩罚值将导致有环的图。同时,不同的打分函数可能具有非常不同的范围,而两个惩罚项的值与打分函数是没有关系的。因此,我们将所有的打分函数调整到一定范围,并为惩罚权重设计一种在线更新策略。详细内容可以参见论文的第 5 章。

Actor-Critic 优化参数

我们采用策略梯度和随机优化的方法来优化以下目标:

其中 A 中有向图对应的 0-1 邻接矩阵。我们使用 Actor-Critic 来进行训练,同时还加了熵正则项来鼓励探索。尽管策略梯度方法仅在一定条件下能保证局部收敛,但是通过惩罚项系数的设计,在我们的实验中 RL 算法得到的图都是无环的。

最终输出

由于我们关注的是寻找得分最好的 DAG,而不是 policy,因此我们记录了训练过程中生成的所有的有向图,并选择具有最佳 reward 的图作为输出结果。实际上由于有限的数据,图中会包含一些真图里边不存在的边,因此需要进一步的减枝处理。

我们可以根据损失函数或者打分函数,使用贪婪方法来进行减枝操作。我们删除一个父亲变量并计算相应的结果,如果损失函数或者打分函数效果没有变差或者是在预先设定的范围内,就接受减枝的操作并继续下去。对于线性模型,可以通过和阈值比较的方法来进行减枝。

实验结果
在此工作中,我们使用 BIC 打分函数,并假设附加性的高斯噪声(实际中噪声可能是非高斯的)。考虑两种情况:不同的噪声方差,等价于 negative log-likelihood 加上一个对边的个数的惩罚项作为打分函数;以及相等的噪声方差,将得到最小平方损失加上边的个数的惩罚项。它们分别表示为 RL-BIC 和 RL-BIC2。
我们的方法与传统方法(PC,GES,ICA-LiNGAM 和 CAM)以及最近基于梯度的方法(NOTEARS,DAG-GNN 和 GraN-DAG)在学术界常用的一些数据集上进行了比较。我们使用三个指标评估学到的图结构:错误发现率(FDR),正确率(TPR)和结构汉明距离(SHD)。SHD 是将得到的图转换为真实 DAG 的边添加,删除和反转操作的最少个数。

高斯和非高斯噪声的线性数据模型

我们首先考虑 12 个节点的有向图。图 2 显示了在一个线性高斯数据集上 RL-BIC2 的训练过程。我们采用 NOTEARS 和 DAG-GNN 在同样的数据集上使用的阈值来做减枝。在这个例子中,RL-BIC2 在训练过程中生成 683,784 个不同的图,远低于 12 个节点 DAG 的总数(约 5.22 * 10^26)。经过减枝的 DAG 和真实的图结构完全相同。

图 2:在线性高斯数据集上 RL-BIC2 的学习过程。

表 1 是我们在 LiNGAM 和线性高斯数据模型的实验结果。在该实验中,RL-BIC2 在两个数据模型上恢复了所有真实的因果图,而 RL-BIC 的表现稍差。尽管如此,在相同的 BIC 分数下,RL-BIC 在两个数据集上的表现均远好于 GES。

具有高斯过程的非线性模型

我们考虑一种非线性的数据模型,每个因果关系函数是从高斯过程中采样的一个函数。该问题被证明是可识别的,即可以从联合概率分布中识别出真实的图。我们使用和 GraN-DAG 一样的实验条件:10 个节点,40 条边的 DAG,并考虑 1000 个观测样本。实验结果如下表 3 所示。对于我们的方法,我们将高斯过程回归(GPR)与 RBF 核一起使用来建立因果关系模型。虽然观察到的数据是来自于高斯过程采样得到的函数,但这并不能保证具有相同核的 GPR 可以达到很好的结果。实际上,使用固定的核参数将导致严重的过度拟合,从而导致许多错误的边,这样训练结束最好 reward 对应的有向图通常不是 DAG。为此我们将数据归一化处理,并使用 median heuristics 来选择核参数。我们两种方法的表现都不错,其中 RL-BIC 的结果优于其他所有方法。

真实数据集

我们最后考虑 Sachs 数据集,通过蛋白质和磷脂的表达程度来发现蛋白质信号网络。我们将带有 RBF 内核的 GPR 应用于因果关系建模,对数据做归一化并使用基于 median heuristics 的核参数。我们使用和 CAM 及 Gran-DAG 中同样的减枝方法。实验结果见下表。与其他方法相比,RL-BIC 和 RL-BIC2 均取得了不错的结果。

结语

我们使用强化学习来搜索具有最佳分数的 DAG,其中 actor 是基于自注意力机制的 encoder-decoder 模型,而 reward 结合了预先给定的得分函数和两个惩罚项来得到无环图。在合成和真实数据集上,该方法均取得了很好的结果。在论文里,我们还展示了该方法在 30 节点的图上的效果,但是处理大规模的图(超过 50 个节点)仍然具有挑战性。尽管如此,许多实际的应用(例如 Sachs 数据集)的变量数都相对较少。此外,有可能将大的因果发现问题分解为较小的问题分别处理,基于先验知识或基于约束的方法也可以用来减少搜索空间。

当前的工作有几个未来改进的方向。在目前的实现中,打分函数的计算比训练神经网络会花费更多的时间,一个更有效率的打分函数将会大大提升目前算法的表现。其他 RL 算法也可以用来加速训练,例如 A3C。此外,我们观察到实验中使用的总迭代次数通常超过了需要的次数,我们也会研究如何进行 early stopping。

华为诺亚ICLR 2020满分论文:基于强化学习的因果发现算法相关推荐

  1. CORL: 基于变量序和强化学习的因果发现算法

    深度强化学习实验室 官网:http://www.neurondance.com/ 论坛:http://deeprl.neurondance.com/ 来源:诺亚实验室 华为诺亚方舟实验室.西安交通大学 ...

  2. 华为诺亚方舟郝建业:深度强化学习的三大挑战

    智源导读:近年来,深度强化学习技术在游戏人工智能领域.推荐系统.搜索系统.网络优化.供应链优化.自动驾驶和芯片设计等领域取得了大量成果. 华为诺亚方舟决策与推理实验室郝建业近期在北京智源大会上发表了题 ...

  3. 一种镜像生成式机器翻译模型:MGNMT | ICLR 2020满分论文解读

    MGNMT:镜像生成式NMT (ICLR 2020满分论文) 机构:南京大学,字节跳动 点此获取"论文链接" 一.摘要 常规的神经机器翻译(NMT)需要大量平行语料,这对于很多语种 ...

  4. ICLR 2020 多智能体强化学习论文总结

    ICLR 2020 多智能体强化学习论文总结 如有错误,欢迎指正 所引用内容链接 Multi-Agent RL 1.Multi-agent Reinforcement Learning For Net ...

  5. 基于强化学习的服务链映射算法

    2018年1月   <通信学报>    魏亮,黄韬,张娇,王泽南,刘江,刘韵洁 摘要 提出基于人工智能技术的多智能体服务链资源调度架构,设计一种基于强化学习的服务链映射算法.通过Q-lea ...

  6. [论文]基于强化学习的无模型水下机器人深度控制

    基于强化学习的无模型水下机器人深度控制 摘要 介绍 问题公式 A.水下机器人的坐标框架 B.深度控制问题 马尔科夫模型 A.马尔科夫决策 B.恒定深度控制MDP C.弯曲深度控制MDP D.海底追踪的 ...

  7. 交通计算机专业硕士论文,基于强化学习的交通拥堵控制方法研究-计算机技术专业论文.docx...

    基于强化学习的交通拥堵控制方法研究摘 基于强化学习的交通拥堵控制方法研究 摘要 由于汽车保有量的持续增长,交通拥堵问题已经成为世界各国城市发展中出 现的公共问题.单纯的基础设施建设能够在一定程度上缓解 ...

  8. 【实践】基于强化学习的 Contextual Bandits 算法在推荐场景中的应用

    文章作者:杨梦月.张露露 内容来源:滴滴科技合作 出品平台:DataFunTalk 导读:本文是对滴滴 AI Labs 和中科院大学联合提出的 WWW 2020 Research Track 的 Or ...

  9. 顶会速递 | ICLR 2020录用论文之元学习篇

    抽空为大家整理了人工智能顶会ICLR 2020录用的Meta learning 元学习相关的最新论文,感兴趣的朋友们赶紧Mark读起来吧! [1]. Meta-Q-Learning 链接 | http ...

最新文章

  1. 读书笔记 effective c++ Item 5 了解c++默认生成并调用的函数
  2. [导入]微软CSS,GCR半日游--学了一样东西,什么叫做灰头土脸
  3. HALCON示例程序optical_flow.hdev如何使用optical_flow_mg计算图像序列中的光流以及如何分割光流。
  4. 7招改善你的谷歌chrome浏览器
  5. CentOS7如何安装vsftpd
  6. Swift与OC混编过程中的配置
  7. JDBC中数据库连接池的使用与传统方式的比较
  8. 孔浩老师SpringMVC视频总结
  9. 如何制定目标 (转自我学网)
  10. mcu AD采样值和物理值
  11. html前端页面的字体大小,JQuery 改变页面字体大小的实现代码(实时改变网页字体大小)...
  12. Windows Server 2022 英文版、简体中文版下载 (updated Dec 2021)(2022 年 1 月发布)
  13. PS_1_认识主界面_新建文档(分辨率)_打开保存(序列动画)
  14. sdkman软件开发工具包管理器
  15. jira -workflow之父级任务关注人copy到子任务
  16. RCS(Real-time control systems) 库
  17. updog的一个bug修复 支持多线程 视频播放支持跳转
  18. 【QT学习】QRegExp类正则表达式(一文读懂)
  19. 2020年第四届计算机检测维修与数据恢复国赛模拟比赛
  20. Coding and Paper Letter(六十三)

热门文章

  1. 【126】TensorFlow 使用皮尔逊相关系数找出和标签相关性最大的特征值
  2. VIM 的方向键 h(左)、j(下)、k(上)、l(右)
  3. CV边缘检测索贝尔算子
  4. 智源社区2022新版体验:订阅讲座日历、关注顶尖专家、开启个人频道
  5. 新书 5 折腰斩!畅销技术类图书推荐
  6. 互联网刚刚年满50,发明它的那个人却「后悔」了
  7. 程序员成长路上的团队修炼之道
  8. Master 横扫围棋各路高手,是时候全面研究通用人工智能了!
  9. 在Python3.4中实现opencv3.1.0的安装配置
  10. 独家 | 成功开发者必备的5项软技能