像 PyTorch 或 TensorFlow 这样通用的自动微分框架是非常有用和高效的,而且在大多数情况下,几乎不需要再写一些更专门化的东西。然而本文作者构建了一个自动微分库,以高效地计算小批量数据上的训练。此外,作者还详细描述了在构建自动微分库中的过程与思考,是理解自动微分理念的优秀博文。

我最近开始写自己的 autodiff 程序包。这篇博客文章记录了我一路以来学到的东西,并把它当成 Julia Evans 的「穷人版」博客文章。

因为有许多博客文章都解释了自动微分的机制,比我解释的要好得多,所以这里我跳过了解释。此外,在构建神经网络结构方面还有其他一些有趣的文章,因此,虽然我的库遵循非常相似的模式(静态计算图和依赖类型),但我并没有过多地关注类型系统。

最后,如果你想直接跳转到代码部分,最终的结果在以下 GitHub 地址,同时还有一个基于神经网络的 FizzBuzz 解决方案。

自动微分代码:https://github.com/maciejkula/wyrm

FizzBuzz:https://github.com/maciejkula/fizzbuzz

动机

关于为什么我想要有自己的 autodiff/backprop 框架,而不是直接用 PyTorch 或者 TensorFlow 的原因有以下几点。

1.PyTorch 和 TF 在拟合每个 x 小批量所需计算量很少的模型时非常慢。因为在计算机视觉问题中,每个小批量处理的计算量非常大,以至于框架开销几乎不成问题。但这在矩阵分解式的模型中却不可忽略,这种模型在推荐系统中是有用的。且即使在 GPU 上,拟合这些模型也很慢。

2. 我希望能够用我的 autodiff 库像 Python 包那样以最小的依赖关系来编写和构造模型。这个库能够生成一个相当小的、且独立的二进制文件,这是相对于繁琐的 TF 和 PyTorch 依赖的优势。

3. 这是一个有趣的学习经验,并且让我更详细地了解神经网络库的内部工作机制。

受到对推荐模型(也可能是 NLP) 有效的轻量级解决方案需求的启发,我编写了一系列的设计约束条件(design constraints)。

1. 我希望框架能够自然地支持稀疏梯度:即绝大多数梯度都为零的情况。这在 NLP 和使用大型嵌入层的推荐模型中非常常见。在任何给定的小批量中,只有很小一部分嵌入层被使用,其余记录的梯度均为零。在执行梯度更新时能够跳过零对于快速创建这些模型非常重要。

2. 我希望除实际计算之外,框架有最小的开销。因为我主要想要拟合小的、稀疏的模型,所以开销是关键。在 PyTorch 中,此类模型的运行时间以 Python 中的循环为主要开销。为了避免这种情况,我的库必须在它的拟合循环中放弃 Python,并且需要完全用编译语言编写以充分利用编译器优化的性质。

3. 模型图必须逐个定义,就像 Chainer 或者 PyTorch 一样。这种方法的可用性和可调试性对我来说是非常有价值的,以至于我甚至不想回到 TensorFlow 的处理方式。同时,我很高兴图形一旦被定义就是静态的。这有助于保持较小的开销:我可以分配一次中间计算缓冲区并继续使用它们,而不是写一个复杂的缓冲池系统(或者,更糟糕的是,在每次传递的时候不断地分配和释放内存)。

4. 我希望性能可以与可用 CPU 内核的数量大致呈线性关系。这意味着在整个图形的层次上进行并行化,而不是对单独的操作。每个计算线程将有它自己的计算图副本,但在更新时写入共享参数缓冲区。这实际上是 Hogwild! 方法,这个方法中多个计算线程同时更新共享参数缓冲区而没有任何锁定。只要梯度相对稀疏,就可以在模型质量下降很少的情况下进行近线性的缩放。

这里还列出了一些我现在不想添加或不太关心的事情:

1.GPU 支持。我主要想要拟合小型模型(或者至少有很多参数但每个小批量的计算很少的模型)。

2.CNNs,或者,实际上具有两个维度以上的张量。

考虑到需求(和非需求)列表,我们就能自然地得出一些设计决策。

1. 整个事情将用一种编译语言(compiled language)编写,这种编译语言能够生成没有运行时间的本地共享对象,模型也将用相同的语言来定义。

2. 这个语言将会是 Rust,这是一门令人惊叹的语言,而且非常适合这种任务。因此下面的许多东西都有 Rust 的味道。然而,我所描述的设计权衡在 C++、其他静态类型和 AOT 编译的编程语言中是相同的。

3. 我将会使用反向模式自动微分。这样,我可以很容易地通过多输入的任意(静态)计算图进行反向传播。

在编写库时,我经常想到 API,我希望能够将这个微分库公开并获得社区的帮助。在这种情况下,我想写如下内容:

  1. let slope = Parameter::new(1.0);

  2. let intercept = Parameter::new(0.0);

  3. let x = Input::new(3.0);

  4. let y = Input::new(2.0 * 3.0 + 1.0);

  5. let loss = (y — (slope * x + intercept)).square();

  6. loss.backward();

并让它工作。

准备工作完成之后,我们可以进入有趣的部分:弄清楚如何实现计算图。

表示计算图

我们选择什么样的数据结构来表示计算图?我了解有以下两种方案:

1. 基于向量:所有计算节点都被连续地存储在一个向量中,并使用索引来寻址它们的父节点。例如,在创建输入节点时,对象 InputNode 被压入向量,且索引为 0。如果随后将该节点平方,SquareNode 将被压入索引为 1 的分量,并知道它的父节点是索引 0。在正向传播过程中,SquareNode 将使用该索引来获取其输入的值。

2. 基于图形。节点被放置在内存中的任意位置,并用指向其父节点的索引来维护计算图的结构。(向量表示可以看作是图模型的线性化。)

基于矢量的方法有很多优点。

1. 所有的节点都在同一个地方。他们连续地储存在内存中,可能会减少内存的寻址问题。

2. 他们的所有权很容易解释。这使得克隆计算图图非常简单:只需克隆节点向量即可。这一点很重要,因为我依靠于为我的并行处理方法提供多个图的副本。

3. 节点按拓扑顺序排列。我们可以通过简单地沿着向量向前迭代来正确地执行前向传播,且没有重复的工作。

但是它也有缺点。

我们在节点向量中存储了什么类型的对象是不清楚的。所有的节点类型都不一样(不同的大小),但向量都是同质的类型。Rust 为这种问题提供了两种解决方案,但是都不是特别令人满意。

第一种是枚举(sum 类型;ADTs; tagged unions)。我们定义一个 Node 类型作为所有可能的节点类型的集合,并将其储存在节点向量中。这样,所有的节点就具有相同的类型了。但我们仍然需要将 Node 的方法从封装的 Node 类型分配到所包含的内部节点。这可以通过模式匹配(联合类型标签上的 switch 语句)完成;有 Rust 对模式识别和宏的支持,编写必要的代码是轻而易举的。

但是,这种做法会增加运行时间成本。每次我们使用一个节点,我们需要经过一个 switch 语句来解决内部类型问题。原则上,优化编译器会将这种代码编译成跳转表(jump tables)。实际上,在我的实验中为分配代码生成的程序集仅仅是对所有可能性的线性扫描,强加了与框架支持的具体节点类型数量呈线性关系的分配成本。更糟的是,编译器不愿意内联 switch 本身和被调用的函数。前者是因为它增加了分支预测的失误,后者增加了函数调用的开销。(最近的分值预测攻击加剧了这个问题: compiler mitigations 可能会导致像这样的间接指令更加昂贵。)

对节点向量使用 sum 类型的最后一个缺点是它会导致一个封闭的系统(类似于 Scala『s 的 封闭特性):库的下游用户不能添加新的节点类型。

另一种方法是用 Rust 的运行时多态机制(polymorphism mechanism): trait objects。trait objects 是对目标具体类型进行抽象的一种方法:我们将他们隐藏在指向数据的指针和他们方法表的后面,而不是将结构存储在内联中。调用方法时,我们跳转到 vtable,找到函数并执行。通过使用 trait ojbects,我们将这些 fat pointers 放到节点向量中而不是节点自身里面。

然而,这种解决方案恰恰引入了我们开始时想要避免的那种间接性。此外,它完全否认了编译器在内联方面做的努力:被调用的函数直到运行时才知道。

那么基于图的设计呢?在这里,每个节点都在内存中被放置在自己的位置,并且可以通过索引指向其祖先。因为每个节点可以重复使用任意次数,我用 Rust 中的 Rc<T>相当于 C++中的 shared_ptr。

这种方法的一个直接缺点是模糊了图的所有权结构,使克隆和序列化/反序列化变得困难:因为节点可以被重复利用,单纯的克隆/反序列化将导致创建相同节点的多个副本。

第二个缺点是缺少一个容易获得的拓扑排序:前向和后向传递都递归地完成,而且必须小心地避免重复计算共享子图的值。

使用图形表达的优点是在编译时已知任何节点的父节点类型。每一个节点在其父节点类型上是(递归地)通用的:添加两个 InputNodes 将会产生一个 AddNode<InputNode, InputNode>。将其添加到另一个输入节点会产生 AddNode<AddNode<InputNode, InputNode>,InputNode>等等。除了一个在类型系统中表现更好的设计,这给了我分配和内联的静态方法。

结果

使用一些非正式的基准,基于图的方法比基于向量的方法快大约 30%。最后的结果可以在我很普通的双核笔记本上,20 毫秒内在 Movielens 100K 数据集上完整地运行一个 BPR 学习-排序分解模型。此外,它的性能会随着处理器内核的增加而线性增长。

除了底层的图形结构之后,这里还利用了很多优化。

1. 我用 Rust 的 SIMD 内在函数进行了很多操作,如向量点积和标量加法。

2. 对于大多数操作,我假定 C 为连续矩阵并直接在底层数据上迭代,而不是用 ndarrays 迭代方法。事实证明,这样做要快得多,大概是因为它允许 LLVM 自动对向量实现向量化。

3. 事实证明,LLVM 足够智能,能够自动向量化大部分不涉及缩减步骤(主要是赋值)的数值循环。与(2)结合起来看,这种方法使得很多的数值循环以最小的优化努力获得更高的效率。

仍然有很多方法可以使计算速度更快。

1. 此时,代码在正向传递中不会缓存任何子图的结果:如果一个节点在正向传递中被用了两次,所有依赖它的计算将会执行两次。这可以通过一个简单的拓扑排序算法很容易的解决,一旦评估了它们的值,就将该节点标记为已评估。

2. 类似地,在后向传递中梯度被直接传递给参数节点。如果一个节点被多次使用,这意味着在逐步向下传递梯度时做了不必要的工作。累积所有的梯度并且只递归一次将节省这项工作。

3. 对输入有一些不必要的复制,在可能的情况下更好的利用索引应该会产生一些小的性能收益。

下一步是什么

我写了(并继续维护)很多的开源 Python ML 包。这些模型是在 Cython 中手工编写的,尽管它们表现的很好,但是扩展它们是困难的。部分是因为 Cython 的局限性,另一部分的原因在于手动派生更新规则所需的努力。

我希望这个库(或它的一些变体)可以使这个任务变得简单一些,并且可以让我更轻松地实现复杂模型以将它们作为独立的 Python 包发布出去。

附录

结果表明,当图形表达应用到递归神经网络时有一些问题:在递归的每一步,结果类型的复杂度增加,导致了相当奇怪的类型:

  1. Variable<nodes::LogNode<nodes::SoftmaxNode<nodes::DotNode<layers::recurrent::LSTMCellHidden<layers::recurrent::LSTMCellState<layers::recurrent::LSTMCellSt

  2. ate<layers::recurrent::LSTMCellState<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>, layers::recurrent::LSTMCellHidden<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::Par

  3. ameterNode>>, nodes::IndexNode<nodes::ParameterNode>>, layers::recurrent::LSTMCellHidden<layers::recurrent::LSTMCellState<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>, layers::recu

  4. rrent::LSTMCellHidden<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>>, layers::recurrent::LSTMCellHid

  5. den<layers::recurrent::LSTMCellState<layers::recurrent::LSTMCellState<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>, layers::recurrent::LSTMCellHidden<nodes::InputNode, nodes::Input

  6. Node, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>>, layers::recurrent::LSTMCellHidden<layers::recurrent::LSTMCellState<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nod

  7. es::ParameterNode>>, layers::recurrent::LSTMCellHidden<nodes::InputNode, nodes::InputNode, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>>, nodes::IndexNode<nodes::ParameterNode>

  8. >, nodes::IndexNode<nodes::ParameterNode>>, nodes::ParameterNode>>>>

不用说,在经过一些迭代步骤后,编译器放弃了。这可以通过实现一个融合的 LSTM 单元来解决,而不是将其从更简单的操作中组装起来,或者选择通过 trait objects 选择性擦除。目前为止,我已经使用了第二种方案:通过将每个 LSTM 单元的输出值装入 trait object 来将其具体类型删除。

原文链接:https://medium.com/@maciejkula/building-an-autodifferentiation-library-9ccf32c7a658

干货丨从概念到实践,我们该如何构建自动微分库相关推荐

  1. 干货丨从概念到案例:初学者须知的十大机器学习算法

    本文先为初学者介绍了必知的十大机器学习(ML)算法,并且我们通过一些图解和实例生动地解释这些基本机器学习的概念.我们希望本文能为理解机器学习(ML)基本算法提供简单易读的入门概念. 机器学习模型 在& ...

  2. 干货丨先搞懂这八大基础概念,再谈机器学习入门

    翻译 | AI科技大本营 参与 | 林椿眄 准备好开始AI了吗?可能你已经开始了在机器学习领域的实践学习,但是依然想要扩展你的知识并进一步了解那些你听过却没有时间了解的话题. 这些机器学习的专业术语能 ...

  3. 干货丨渗透测试常用方法总结,大神之笔!

    干货丨渗透测试常用方法总结,大神之笔! 一.渗透流程 信息收集 漏洞验证/漏洞攻击 提权,权限维持 日志清理 信息收集 一般先运行端口扫描和漏洞扫描获取可以利用的漏洞.多利用搜索引擎 端口扫描 有授权 ...

  4. 干货丨【看图识算法】这是你见过最简单的 “算法说明书”

    文章来源:新智元 [导读]像阅读宜家的安装说明书一样学习算法,是怎样的体验?不伦瑞克工业大学的三名研究者制作了这份"算法说明书",简明传神地解释了一些基本算法,一起来看图说话. Q ...

  5. 5u以太网用交换机连接电脑_干货丨如何用自己的电脑直接连接NUS打印机

    期末考试又来啦,有好多同学在后台问小助手怎么把复习资料打印下来. 有同学会抢答了:这题我会!拿着U盘去图书馆就可以打印了呀! 没错,用U盘打印是大家最常用的方式.但小小的U盘不仅容易弄丢,还容易发生各 ...

  6. 干货丨如何准确找到剪辑点?后期剪辑进阶必看

    干货丨如何准确找到剪辑点?后期剪辑进阶必看 初级剪辑师在剪辑时可能会聚焦在已有影视素材的整理,但优秀的剪辑师关注的点应该更加巧妙细致,要真正做好剪辑,知道"什么时候该剪"很重要. ...

  7. AFC中央计算机系统图,干货丨城市地铁AFC系统由哪些部分组成的?如何运作?

    原标题:干货丨城市地铁AFC系统由哪些部分组成的?如何运作? AFC系统的全称是Automatic Fare Collection System(城市轨道交通自动售检票系统),是基于计算机.通信.网络 ...

  8. ad设置塞孔_干货丨PCB线路板过孔堵上,到底是什么学问?

    原标题:干货丨PCB线路板过孔堵上,到底是什么学问? 1.BGA位在阻焊为什么要塞孔?接收标准是什么? 答:首先阻焊塞孔是为了保护过孔的使用寿命,因为BGA位所需塞的孔一般孔径都比较小,在0.2--0 ...

  9. 电气simulink常用模块_干货丨16种常用模块电路分析,工程师的必备~

    电路图一大张,看似复杂,但也都是由一小块一小块的功能模块组成的.因此要根据大的功能先划分成块,再在块里面看是通过什么电路形式实现的,有些起辅助作用,有些起主要作用.下面小编给大家整理了16种常用的模块 ...

最新文章

  1. INS-20802 PRVF-9802 PRVF-5184 PRVF-5186 After Successful Upgradeto 11gR2 Grid Infrastructure
  2. SDNU 1464.最大最小公倍数(思维)
  3. 处理上百万条的数据库如何提高处理查询速度
  4. java arcgis server_ArcGIS Server Java 开发实战---自定义command
  5. 基于I2C协议的EEPROM驱动控制
  6. lvs负载均衡—高可用集群(keepalived)
  7. ajax跨浏览器初始化,使用Ajax的jQuery localStorage的跨浏览器
  8. AtCoder Beginner Contest 185
  9. php $GLOBALS 超全局变量的理解
  10. 以.a(a为后缀)的文件类型是啥鸭?
  11. TLS协议、PKI、CA
  12. 网易云音乐推荐中的用户行为序列深度建模
  13. FL Studio20.9水果软件高级中文版电音编曲
  14. win10易升_win10性能模式是什么?怎么开启?
  15. 移动端后台管理系统框架
  16. 金融安全视角农民投资理财的实证研究——以X县为例
  17. 深度学习在图像识别上的应用
  18. 序列标注 | (4) Hierarchically-Refined Label Attention Network for Sequence Labeling
  19. FasterR-CNN,R-FCN,SSD,FPN,RetinaNet,YOLOv3速度和准确性比较
  20. 猪八戒网站上的骗子为什么这么多

热门文章

  1. GPT-3等三篇论文获NeurIPS2020最佳论文奖 | AI日报
  2. 2021全球产品经理大会蓄势待发!
  3. 用numpy做图像处理
  4. 奇异值的物理意义是什么?强大的矩阵奇异值分解(SVD)及其应用
  5. openSUSE中文输入的安装和设置
  6. 清华大学大数据研究中心给您拜年啦!
  7. 五大自动化测试的 Python 框架
  8. 独家 | 规范性分析的实用介绍(附R语言案例研究演示代码)
  9. DeepMind论文:深度压缩感知,新框架提升GAN性能(附链接)
  10. Regularized Evolution for Image Classifier Architecture Search--阅读笔记