如今,深度学习模型处于持续的演进中,它们正变得庞大而复杂。研究者们通常通过组合现有的 TensorFlow 或 PyTorch 操作符来发现新的架构。然而,有时候,我们可能需要通过自定义的操作符来实现更多的优化。随着深度学习模型规模不断增长,为实际生产和可扩展训练设计专门优化的操作符将会变得更加重要。因此,本文作者学习了如何在 CUDA 中为 Transformer 编写一个 PyTorch 自定义层。

选自tunz,作者:Choongwoo Han,机器之心编译,参与:Geek AI、张倩。

性能分析

首先,我们需要对一种深度学习模型很熟悉,这样我们就可以找到其性能瓶颈,并查看在我们进行了优化之后有多大的提升。我们可以使用内置的 PyTorch 分析器,也可以使用通用的 python 分析器。我们将同时考察这两种方法。

torch.autograd.profiler

PyTorch 提供了一个名为「torch.autograd.profiler」的 API。我们可以通过如下方式使用该 API:

with torch.autograd.profiler.profile(use_cuda=True) as prof:# Execute ops here
print(prof)

接着,PyTorch 会自动找到每个操作符并衡量他们的性能。性能分析结果如下:

分析器显示出了每个操作符在 CPU 和 GPU 上花费的时间。分析结果很直观,并且看起来似乎很精确,但是我们很难分辨出每个操作符并将它们与我的源代码匹配起来。例如,上面的输出结果显示出了三个不同的「unsqueeze」操作符,但是我们并不知道它们是在哪里被调用的。因此,我转而使用其它的分析器来寻找性能的瓶颈点

逐行分析器

因为 PyTorch 是基于 python 编写的,所以我们也可以使用通用的 python 分析器。我找来了一个逐行分析器(https://github.com/rkern/line_profiler),它可以逐行分析一个 python 应用程序。在要分析的函数的顶部添加「@profiler」装饰器之后,我们可以在命令行中用「kernprof」替换「python」来运行分析器。此外,在 CUDA 的环境下,我们必须设置一个环境变量「CUDA_LAUNCH_BLOCKING」来同步对 CUDA 调用。

运行一个 epoch 的后分析多头注意力机制前馈函数的结果如上图所示。结果显示了测量每一行所花费的时间,因此我们可以很容易地找到需要优化的目标代码。我们将重点关注第 85、87 和 88 行中的掩码操作。它组合了多个操作符来模拟「掩码处理后的 softmax」操作:为 softmax 的掩码输入填充负无穷数,从而使 softmax 忽略它们。在本文中,我将尝试优化这些操作。请注意,它当前花费了函数执行时间的 19.1%(7.2 + 5.9 + 6.0),而第 86 行花费了 15.2% 的执行时间。让我们使用这个值作为对比基线。

还有另一个适合优化的地方:第 86 行和第 90 行中的矩阵乘法,因为它们的输入或输出都填满了许多 0。本文不会对此进行深入探讨。

掩码处理后的 Softmax

首先,我认为我们可以通过将运算过程封装进一个操作符中来优化掩码处理后的 softmax,因为执行多个操作符本身就会产生开销。每次调用每个独立的操作符时,对 CUDA 核函数的调用会产生开销,而主机和 GPU 之间的数据传输也需要时间。

我们将使用一个名为「MaskedSoftmax」的自定义 CUDA 操作符。我们将其直接简略地定义如下:

x 是一个softmax 函数数的输入张量,m 代表一个掩膜张量,s 是一个用于归一化的标量值。该方程与 softmax 类似,只是掩码处理后值被规定为零,并乘以归一化系数。下图显示了掩码处理后的 Softmax 的一个示例。掩码处理后的位置变为零,并且使用 softmax 计算出其余位置上的值。

第一版

我首先写了一个简单版的 Masked Softmax。它由三个与 softmax 具有相同计算流程的遍历组成:(1)找到一个输入的最大值,(2)计算指数运算的值的和,以及(3)将每个值作为输入计算出指数运算的值,用它们分别除以指数运算的值的和。与 softmax 的不同之处在于,它还会加载掩码值,如果掩码值为 1,则将每个对应位置上的输入值转换为零

template <typename scalar_t>
__global__ void __launch_bounds__(32) masked_softmax_cuda_forward_kernel(const scalar_t* __restrict__ input,const scalar_t* __restrict__ mask,scalar_t* __restrict__ output,unsigned int hidden_size,unsigned int m0, // size of mask dimension 0unsigned int m1, // size of mask dimension 1unsigned int m2, // size of mask dimension 2scalar_t scale) {// This threadIdx.x is a number between 0 and 31 because we only launched 32 threads.const int tid = threadIdx.x;// blockIdx.x, y, z are offsets of 0th, 1st, 2nd dimensions of input tensor.const unsigned int ibase = blockIdx.x * gridDim.y * gridDim.z * hidden_size +blockIdx.y * gridDim.z * hidden_size +blockIdx.z * hidden_size;const unsigned int mbase = blockIdx.x * (m0 > 1 ? m1 * m2 * hidden_size : 0) +blockIdx.y * (m1 > 1 ? m2 * hidden_size : 0) +blockIdx.z * (m2 > 1 ? hidden_size : 0);unsigned shfl_mask = __ballot_sync(0xffffffff, threadIdx.x < hidden_size);// Find a maximum input.scalar_t max_x = -FLT_MAX;for (unsigned int i = tid; i < hidden_size; i+=blockDim.x) {scalar_t m = mask[mbase + i];max_x = fmaxf(max_x, m == 0 ? input[ibase + i] * scale : -FLT_MAX);}// Reduce values in threads to find a global maximum number.for (unsigned int i = 16; i > 0; i >>= 1) {max_x = max(max_x, __shfl_xor_sync(shfl_mask, max_x, i));}// Find a sum of exponential inputs.scalar_t exp_sum = 0;for (unsigned int i = tid; i < hidden_size; i+=blockDim.x) {scalar_t m = mask[mbase + i];exp_sum += m == 0 ? std::exp(input[ibase + i] * scale - max_x) : 0;}// Reduce values in threads to find a global summation of exponential inputs.for (unsigned int i = 16; i > 0; i >>= 1) {exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i);}// Calculate outputs and save to global memory.for (unsigned int i = tid; i < hidden_size; i+=blockDim.x) {scalar_t m = mask[mbase + i];output[ibase + i] = m == 0 ? std::exp(input[ibase + i] * scale - max_x) / exp_sum : 0;}
}

CUDA 中有「warp」和「block」的概念。Warp 是一组 32 个线程,而一个 block 则包含多个 warp。每个 block 有一个共享的内存,任何线程都可以访问一个全局内存。每个线程使用不同的线程和 block 的 id 执行相同的核函数代码,因此每个核函数使用全局内存中的 id 查找和读取相关输入,并将每个输出保存到全局内存中。由于计算是分布式的,如果有需要,我们可能需要减少不同 block 或线程中的值。

在这个 softmax 的实现中,我们需要一个约简来获得值的和或最大值。由于访问全局/共享内存是 CUDA 核函数中常见的瓶颈,所以我试图绕开它。为此,我为每个 block 创建了一个 warp,并使用了「shuffle」函数。它使用寄存器进行 warp 内的通信,因此线程可以在不访问共享内存的情况下交换值。

for (unsigned int i = 16; i > 0; i >>= 1) {max_x = max(max_x, __shfl_xor_sync(shfl_mask, max_x, i));
}

通过这个自定义的操作符,掩码处理后的 softmax 占用执行时间的比例降至了 15%。这并不是一个巨大的提升,但无论如何也比之前要快一些了。

现在,内置的 PyTorch 分析器也显示出了这个自定义操作符的性能提升。因此,由于逐行的分析器需要用太长的时间进行性能分析,我将这个第一版的掩码处理后的 softmax 用作进行进一步优化的对比基线。

进一步的优化

正如我所提到的,对于全局内存的访问是一个主要的瓶颈。在一些假设条件下,我们可以最小化内存访问的次数。前面的第一版现在可以从全局内存中读取两种类型的值(掩码和输入)。用于归一化后的点乘注意力机制的掩码通常有如下所示的形式。

从最左或最右开始,它们是连续的,而基本的 transformer 只有从最左开始的三种形式。因此,我们不需要为每个输入加载掩码值。在读取每一行之前,加载一个表示掩码长度的值就足够了。

我们可以使用下面的代码直接将掩码转化为一种新的形式:

mask = mask.size(2) - mask.sum(dim=2, dtype=torch.int32)

接着,我们只需要首先加载掩码长度,将每个循环迭代与掩码长度相同的次数,并将其余输出设置为零。

// Load a mask length.const unsigned int mask_offset = blockIdx.x * (m0 > 1 ? m1 : 0) +blockIdx.z * (m1 > 1 ? 1 : 0);unsigned int mask_size = min(static_cast<unsigned int>(mask[mask_offset]),hidden_size);unsigned shfl_mask = __ballot_sync(0xffffffff, threadIdx.x < mask_size);scalar_t max_x = -FLT_MAX;// Iterate loop as much as the mask length.for (unsigned int i = tid; i < mask_size; i+=blockDim.x) {max_x = fmaxf(max_x, input[ibase + i] * scale);}for (unsigned int i = 16; i > 0; i >>= 1) {max_x = max(max_x, __shfl_xor_sync(shfl_mask, max_x, i));}scalar_t exp_sum = 0;for (unsigned int i = tid; i < mask_size; i+=blockDim.x) {exp_sum += std::exp(input[ibase + i] * scale - max_x);}for (unsigned int i = 16; i > 0; i >>= 1) {exp_sum += __shfl_xor_sync(shfl_mask, exp_sum, i);}// We initialized "output" to zero, so remaining outputs will be zero.for (unsigned int i = tid; i < mask_size; i+=blockDim.x) {output[ibase + i] = std::exp(input[ibase + i] * scale - max_x) / exp_sum;}

这样一来,这个操作就变得快多了。它现在只占用了执行时间的 9%。

掩码处理后的 Softmax(MaskedSoftmax)的执行时间现在比第一版快 2.5 倍。

我还检查了这种优化在多大程度上提高了整个训练的速度。我在 lm1b 数据集上训练了语言模型,并且测量了运行每个(碎片)epoch 的平均时间。第一个 CUDA 的版本比单纯组合 PyTorch 操作符的方法快了约 0.8%,第二个版本比原始版本快了约 1.8%。

结语

我在 CUDA 中编写了一个自定义的操作符并使 Transformer 的训练快了约 2%。我首先希望仅仅在 CUDA 中重写一个操作符来得到巨大的性能提升,但事与愿违。影响性能的因素有很多,但是我不可能找到每一个因素。此外,由于我对 CUDA 并不熟悉,我也遇到了很多 bug。代码越多,bug 越多。这使得我写了很多意想不到的测试代码。这是在提升模型性能和用于写代码的时间之间的一种折中。

编写一个自定义的操作符并没有我想象的那么简单,但是我可以从中学到许多关于 CUDA 如何工作的知识,以及诸如 block、线程、核函数、内存、同步、缓存这样的概念。我希望本文能够对那些想要入门 CUDA 性能优化的人有所帮助。

  • 完整代码:https://github.com/tunz/tcop-pytorch
  • 使用场景:https://github.com/tunz/transformer-pytorch.

原文链接:https://tunz.kr/post/5

java如何给一个链表定义和传值_如何在CUDA中为Transformer编写一个PyTorch自定义层...相关推荐

  1. python中case的用法_如何在Python中使用TestCase实现一个断言功能

    如何在Python中使用TestCase实现一个断言功能?相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题. Python TestCase断 ...

  2. matlab中GUI的属性检查器中的XLimMode是什么_如何在Matlab中使用GUI做一个简易音乐播放器? ---- (二)GUIDE...

    咕咕怪由于昨天有重要的事情所以咕了一天的文章 (感觉写得挺基础的,对各个部分有一定了解的童鞋可以直接跳过了解的部分 用Matlab做一个app有几种办法呢? 同样的,帮助文档告诉了我们答案:三种. 英 ...

  3. matlab figure函数_如何在Matlab中使用GUI做一个简易音乐播放器? ---- (六)控件间的数据传递...

    我纠结了两个星期是否要写这一章-最后决定还是要写一章收尾,来解释其中的控件间的数据传递问题. 在前五篇中,如果有童鞋跟上了我的思路或者做完了这样一个gui,会发现还有一个一直避开的遗留问题,就是将歌曲 ...

  4. java编程用户输入两个数字_编写一个Java应用程序,要求从键盘输入两个整数,计算这两个数据...,编写一个Java应用程序,要求从键盘输入一个数,判断该数是不...

    导航:网站首页 > 编写一个Java应用程序,要求从键盘输入两个整数,计算这两个数据...,编写一个Java应用程序,要求从键盘输入一个数,判断该数是不 编写一个Java应用程序,要求从键盘输入 ...

  5. Java黑皮书课后题第9章:*9.4(使用Random类)编写一个程序,创建一个种子为1000的Random对象,然后使用nextInt(100)方法显示0到100之间的前50个随机整数

    Java黑皮书课后题第9章:*9.4(使用Random类)编写一个程序,创建一个种子为1000的Random对象,然后使用nextInt方法显示0到100之间的前50个随机整数 题目 赘述 代码 题目 ...

  6. Java黑皮书课后题第6章:6.35(几何:五边形的面积)五边形的面积可以用如下公式计算。编写一个方法,使用下面的方法头返回五边形面积。编写一个主方法,提示用户输入五边形的边,然后显示它的面积

    6.35(几何:五边形的面积)五边形的面积可以用如下公式计算.编写一个方法,使用下面的方法头返回五边形面积.编写一个主方法,提示用户输入五边形的边,然后显示它的面积 题目 题目描述与运行示例 破题 代 ...

  7. Java黑皮书课后题第6章:*6.14(估算π)π可以使用下面的数列进行计算。编写一个方法,对于给定的i返回m(i),并编写一个测试程序,显示如下表格

    6.14(估算π)π可以使用下面的数列进行计算.编写一个方法,对于给定的i返回m,并编写一个测试程序,显示如下表格 题目 题目描述与运行示例 破题 代码 题目 题目描述与运行示例 6.14(估算π)π ...

  8. Java黑皮书课后题第1章:1.10(以英里计的平均速度)假设一个跑步者45分30秒跑了14千米。编写一个程序显示以每小时多少英里为单位的平均速度值

    题目 题目描述 1.10(以英里计的平均速度)假设一个跑步者45分30秒跑了14千米.编写一个程序显示以每小时多少英里为单位的平均速度值.(注意,1英里约等于1.6千米) 槽点 1.11需要转换的数值 ...

  9. 本关任务:编写一个Point类,有x、y两个属性。编写一个PointDemo类,并提供一个distance(Point p1,Point p2)方法用于计算两点之间的距离,实例化两个具体的Point对

    #java编程基础 以后会时常更新java编程题,分享所遇之难,答疑解惑,共同努力. 本关任务:编写一个Point类,有x.y两个属性.编写一个PointDemo类,并提供一个distance(Poi ...

最新文章

  1. Go 分布式学习利器(11)-- Go语言通过单链表 实现队列
  2. 用bitmap实现中位数的算法
  3. c++ mqtt客户端_MQTT安全性设计详解
  4. pytorch图像和张量的相互转换_[Pytorch]Pytorch的tensor变量类型转换
  5. 用二叉树来理解树状数组
  6. python中list的意思_list在python中是什么意思
  7. LeetCode 1480. 一维数组的动态和(前缀和)
  8. 心斋-------马克奥勒流
  9. 超链接显示网站 A,访问后进入网站 B
  10. 黄永成-thinkphp讲解-个人博客讲解26集
  11. MySQL执行多表联查时,报错ln aggregated query without GROUP BY
  12. Python字符串转义符大全
  13. wifi无线破解记录
  14. Android使用SurfaceView开发《捉小猪》小游戏 (一)
  15. Easy Excel生成压缩包文件,自定义表头样式
  16. 如何清理系统大量的残余文件和系统垃圾文件?(win10)
  17. 算法-狄克斯特拉算法
  18. 规则引擎——IBM ODM(ILog)——基本使用步骤
  19. Git更换关联的远端分支
  20. 一个傻子玩DNF的感人事迹(不看必后悔)

热门文章

  1. matlab里面如何保留小数特定位数(转载)
  2. 当表格列数太多时,怎么实现表格的横向滚动
  3. 服务器搬迁之后的准备工作和应对
  4. 安全公司:苹果iOS10备份功能安全性比iOS9差很多
  5. 数据中心的7个新兴发展趋势
  6. Android多线程断点下载
  7. 在线听音乐要收费,你愿意吗?
  8. ITAA在线试学用户使用说明
  9. C#中常用的经典文件操作方法
  10. Silverlight、JavaFX、Flex技术比较