选自medium

作者:Jesus Rodriguez

机器之心编译

编辑:Panda

Facebook 提出了一种可高效训练包含数十亿节点和数万亿边的图模型的框架 BigGraph 并开源了其 PyTorch 实现。本文将解读它的创新之处,解析它能从大规模图网络高效提取知识的原因。

图(graph)是机器学习应用中最基本的数据结构之一。具体来说,图嵌入方法是一种无监督学习方法,可使用本地图结构来学习节点的表征。社交媒体预测、物联网模式检测或药物序列建模等主流场景中的训练数据可以很自然地表征为图结构。其中每一种场景都可以轻松得到具有数十亿相连节点的图。图结构非常丰富且具有与生俱来的导向能力,因此非常适合机器学习模型。尽管如此,图结构却非常复杂,难以进行大规模扩展应用。也因此,现代深度学习框架对大规模图数据结构的支持仍非常有限。

Facebook 推出过一个框架 PyTorch BigGraph:https://github.com/facebookresearch/PyTorch-BigGraph,它能更快更轻松地为 PyTorch 模型中的超大图结构生成图嵌入。

某种程度上讲,图结构可视为有标注训练数据集的一种替代,因为节点之间的连接可用于推理特定的关系。这种方法遵照无监督图嵌入方法的模式,它可以学习图中每个节点的向量表征,其具体做法是优化节点对的嵌入,使得之间有边相连的节点对的嵌入比无边相连的节点对的嵌入更近。这类似于在文本上训练的 word2vec 的词嵌入的工作方式。

当应用于大型图结构时,大多数图嵌入方法的结果都相当局限。举个例子,如果一个模型有 20 亿个节点,每个节点有 100 个嵌入参数(用浮点数表示),则光是存储这些参数就需要 800 GB 内存,因此很多标准方法都超过了典型商用服务器的内存容量。这是深度学习模型面临的一大挑战,也是 Facebook 开发 BigGraph 框架的原因。

PyTorch BigGraph

PyTorch BigGraph(PBG)的目标是扩展图嵌入模型,使其有能力处理包含数十亿节点和数万亿边的图。PBG 为什么有能力做到这一点?因为它使用了四大基本构建模块:

  1. 图分区,这让模型不必完全载入到内存中。

  2. 在每台机器上的多线程计算

  3. 在多台机器上的分布式执行(可选),所有操作都在图上不相连的部分进行

  4. 分批负采样,当每条边 100 个负例时,可实现每台机器每秒处理超过 100 万条边。

通过将图结构分区为随机划分的 P 个分区,使得可将两个分区放入内存中,PBG 解决了传统图嵌入方法的一些短板。举个例子,如果一条边的起点在分区 p1,终点在分区 p2,则它会被放入 bucket (p1, p2)。然后,在同一模型中,根据源节点和目标节点将这些图节点划分到 P2 bucket。完成节点和边的分区之后,可以每次在一个 bucket 内执行训练。bucket (p1, p2) 的训练仅需要将分区 p1 和 p2 的嵌入存储到内存中。PBG 结构能保证 bucket 至少有一个之前已训练的嵌入分区。

PBG 的另一大创新是训练机制的并行化和分布式。PBG 使用 PyTorch 自带的并行化机制实现了一种分布式训练模型,这用到了前面描述的模块分区结构。在这个模型中,各个机器会协调在不相交的 bucket 上进行训练。这会用到一个锁服务器(lock server),其负责将 bucket 分派给工作器(worker),从而尽可能地减少不同机器之间的通信。每台机器都可以使用不同的 bucket 并行地训练模型。

在上图中,机器 2 中的 Trainer 模块向机器 1 上的锁服务器请求了一个 bucket,这会锁定该 bucket 的分区。然后该 trainer 会保存它不再使用的所有分区并从共享分区服务器载入它需要的新分区,此时它可以将自己的旧分区释放回锁服务器。然后边会从一个共享文件系统载入,并在没有线程内同步的情况下在多个线程上进行训练。在一个单独的线程中,仅有少量共享参数会与一个共享参数服务器持续同步。模型检查点偶尔会从 trainer 写入到共享文件系统中。这个模型允许使用至多 P/2 台机器时,让一组 P 个 bucket 并行化。

PBG 一项不那么直接的创新是使用了分批负采样技术。传统的图嵌入模型会沿真正例边将随机的「错误」边构建成负训练样本。这能显著提升训练速度,因为仅有一小部分权重必须使用每个新样本进行更新。但是,负例样本最终会为图的处理引入性能开销,并最终会通过随机的源或目标节点「损害」真正的边。PBG 引入了一种方法,即复用单批 N 个随机节点以得到 N 个训练边的受损负例样本。相比于其它嵌入方法,这项技术让我们能以很低的计算成本在每条边对应的许多负例上进行训练。

要增加在大型图上的内存效率和计算资源,PBG 利用了单批 Bn 个采样的源或目标节点来构建多个负例。在典型的设置中,PBG 会从训练集取一批 B=1000 个正例,然后将其分为 50 条边一个的块。来自每个块的目标(与源等效)嵌入会与从尾部实体类型均匀采样的 50 个嵌入相连。50 个正例与 200 个采样节点的外积等于 9900 个负例。

分批负采样方法可直接影响模型的训练速度。如果没有分批,训练的速度就与负例的数量成反比。分批训练可改善方程,得到稳定的训练速度。

Facebook 使用 LiveJournal、Twitter 数据和 YouTube 用户互动数据等不同的数据集评估了 PBG。此外,PBG 还使用 Freebase 知识图谱进行了基准测试,该知识图谱包含超过 1.2 亿个节点和 27 亿条边。另外还使用 Freebase 的一个小子集 FB15k 进行了测试,FB15k 包含 15000 个节点和 600000 条边,常被用作多关系嵌入方法的基准。FB15k 实验表明 PBG 的表现与当前最佳的图嵌入模型相近。但是,当在完整的 Freebase 数据集上评估时,PBG 的内存消耗得到了 88% 的改善。

PBG 是首个可扩展的、能训练和处理包含数十亿节点和数万亿边的图数据的方法。PBG 的首个实现已经开源,未来应该还会有更有意思的贡献。

原文链接:https://medium.com/dataseries/facebooks-pygraph-is-an-open-source-framework-for-capturing-knowledge-in-large-graphs-b52c0fb902e8

© THE END

转载请联系 机器之心 公众号获得授权

投稿或寻求报道:content@jiqizhixin.com

可高效训练超大规模图模型,PyTorch BigGraph是如何做到的?相关推荐

  1. 训练超大规模图模型,PyTorchBigGraph如何做到?

    Facebook 提出了一种可高效训练包含数十亿节点和数万亿边的图模型的框架 BigGraph 并开源了其 PyTorch 实现.本文将解读它的创新之处,解析它能从大规模图网络高效提取知识的原因. 图 ...

  2. Pytorch 实现全连接神经网络/卷积神经网络训练MNIST数据集,并将训练好的模型在制作自己的手写图片数据集上测试

    使用教程 代码下载地址:点我下载 模型在训练过程中会自动显示训练进度,如果您的pytorch是CPU版本的,代码会自动选择CPU训练,如果有cuda,则会选择GPU训练. 项目目录说明: CNN文件夹 ...

  3. AI 图片截取、ffmpeg使用及安装, anaconda环境,图片标注(labelme),模型训练(yolov5),CUDA+Pytorch安装及版本相关问题

    AI 图片截取(ffmpeg), anaconda环境,图片标注(labelme),模型训练(yolov5),CUDA+Pytorch安装及版本相关问题 一.截取有效图片 录制RTSP视频脚本 #!/ ...

  4. Pytorch训练Bilinear CNN模型笔记

    Pytorch训练Bilinear CNN模型笔记 注:一个项目需要用到机器学习,而本人又是一个python小白,根据老师的推荐,然后在网上查找了一些资料,终于实现了目的. 参考文献: Caltech ...

  5. pytorch 驱动不兼容_解决Pytorch 加载训练好的模型 遇到的error问题

    这是一个非常愚蠢的错误 debug的时候要好好看error信息 提醒自己切记好好对待error!切记!切记! -----------------------分割线---------------- py ...

  6. Pytorch——保存训练好的模型参数

    文章目录 1.前言 2.torch.save(保存模型) 3.torch.load整个网络 4.torch.load网络参数(只提取参数) 5.调用三个函数 1.前言 训练好了一个模型, 我们当然想要 ...

  7. R语言使用caret包的knnreg函数拟合KNN回归模型:使用predict函数和训练好的模型进行预测推理、使用plot函数可视化线图对比预测值和实际值曲线

    R语言使用caret包的knnreg函数拟合KNN回归模型:使用predict函数和训练好的模型进行预测推理.使用plot函数可视化线图对比预测值和实际值曲线 目录

  8. 【Pytorch】加载torchvision中预训练好的模型并修改默认下载路径(使用models.__dict__[model_name]()读取)

    说明 使用torchvision.model加载预训练好的模型时,发现默认下载路径在系统盘下面的用户目录下(这个你执行的时候就会发现),即C:\用户名\.cache\torch\.checkpoint ...

  9. Nebula 在 Akulaku 智能风控的实践:图模型的训练与部署

    本文整理自 Akulaku 反欺诈团队在 nMeetup·深圳场的演讲,B站视频见:https://www.bilibili.com/video/BV1nQ4y1B7Qd 这次主要来介绍下 Nebul ...

最新文章

  1. 复制订阅服务器和 AlwaysOn 可用性组 (SQL Server)
  2. android studio 打开github开源代码
  3. 系统进程间的同步机制
  4. cgi备份还原和ghost有什么区别_装系统教程!如何用ghost安装系统(下)!小白也能变装机大神!...
  5. boost::statechart模块实现状态转换测试
  6. primefaces_PrimeFaces:在动态生成的对话框中打开外部页面
  7. Font Awesome 中文网
  8. 机器学习中数据清洗预处理入门完整指南
  9. Cannot find module 'less-bundle-promise'
  10. “分类垃圾桶”成交同比涨超七成 塑料概念股集体波动
  11. tar解压时遇到tar: Skipping to next header
  12. markdown显示箭头方法
  13. maven scope范围
  14. 光伏机器人最前线_送水、送药、送餐!哈市这些地方率先用上AI配送机器人(视频)...
  15. win10系统dnf安装不上服务器,win10系统玩不了DNF的解决方法
  16. Java初学者作业——定义客户类(Customer),客户类的属性包括:姓名、年龄、电话、余额、账号和密码;方法包括:付款。
  17. python 卡方分布函数_推断统计分析(二):python验证三大抽样分布
  18. Sun jdk、Openjdk、Icedtea jdk关系
  19. 基于Python实现对房价的预测
  20. python 视图对象_Python之路【第二十八篇】:django视图层、模块层

热门文章

  1. Wiki为什么会流行
  2. 数据结构与算法--线性表(顺序表)
  3. 鼠标按键获取感兴趣区域 2
  4. 【Codeforces】1065B Vasya and Isolated Vertices (无向图的)
  5. Meta 开发 AI 语音助手,用于创建虚拟世界和实时翻译
  6. 11 款可替代 top 命令的工具!
  7. 滴滴联合比亚迪:首款定制网约车D1发布
  8. WAIC汇聚全球顶级科学家,畅谈人工智能的未来挑战与突破
  9. 利用MTCNN和FaceNet实现人脸检测和人脸识别 | CSDN博文精选
  10. 买不到回家的票,都是“抢票加速包”惹的祸?