作者丨PENG Bo@知乎(已授权)

来源丨https://zhuanlan.zhihu.com/p/476297195

编辑丨极市平台

本文的代码,在 win10 和 linux 均可直接编译运行:

https://github.com/BlinkDL/RWKV-CUDAgithub.com/BlinkDL/RWKV-CUDA

先看需提速的操作,在我的 RWKV 语言模型【 GitHub - BlinkDL/AI-Writer AI 写小说:https://github.com/BlinkDL/AI-Writer 】,类似 depthwise 一维卷积,伪代码:

w.shape = (C, T)
k.shape = (B, C, T)
out.shape = (B, C, T)
out[b][c][t] = eps + sum_u{ w[c][(T-1)-(t-u)] * k[b][c][u] } 这里 u 从 0 到 t

它的意义,是让  对  产生影响,具体的影响程度由  决定,且影响在每个通道  都不同。

用代码写(四重循环):

out = torch.empty((B, C, T), device='cuda')for b in range(B):for c in range(C):for t in range(T):s = epsfor u in range(0, t+1):s += w[c][(T-1)-(t-u)] * k[b][c][u]out[b][c][t] = sreturn out

这个操作,用 pytorch 只需一行,但实际速度不佳,尤其是,反向梯度很慢:

out = eps + F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w.unsqueeze(1), groups=C)

因此,我们可以用 CUDA 手写算子。实际测试,正向和反向速度可以 20x。

而且,这里的代码还有很多优化的空间。还望各位 CUDA 高手指导如何进一步优化,多谢多谢。

如果你从未尝试给 pytorch 添加 CUDA 算子,可以先阅读下面这个教程:

godweiyang:熬了几个通宵,我写了份CUDA新手入门代码

下面我们看看,如何逐步优化 CUDA kernel 的写法。

1. 最简单的 CUDA Kernel 写法

最简单的写法,是直接在每个 thread 求和。这会有大量内存存取,因此效率很低。速度为 45 毫秒。但也比 pytorch 的 94 毫秒更快了。

Grid 和 Block:

dim3 gridDim(1, B * C);
dim3 blockDim(T); // 注意,我们先只在 T 分 thread,因为这样的代码简单,而且效率也够高
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);

Kernel:

template <typename F>
__global__ void kernel_forward(const F *__restrict__ const w, const F *__restrict__ const k, F *__restrict__ const x,const F eps, const int B, const int C, const int T)
{const int i = blockIdx.y;const int t = threadIdx.x;F s = eps;const F *__restrict__ const www = w + (i % C) * T + (T - 1) - t;const F *__restrict__ const kk = k + i * T;for (int u = 0; u <= t; u++){s += www[u] * kk[u];}x[i * T + t] = s;
}

2. 运用 shared memory 改善存取效率

优化 CUDA kernel 的第一步,是用 shared memory(就像矩阵乘法做 tiling)。速度提升到 17 毫秒。

template <typename F>
__global__ void kernel_forward(const F *__restrict__ const w, const F *__restrict__ const k, F *__restrict__ const x,const F eps, const int B, const int C, const int T)
{const int i = blockIdx.y;const int t = threadIdx.x;__shared__ F ww[1024]; // 这里限制了 T <= 1024 因为我实际只会用到这么多__shared__ F kk[1024];ww[t] = w[(i % C) * T + t];kk[t] = k[i * T + t];__syncthreads();F s = eps;const F *__restrict__ const www = ww + (T - 1) - t;for (int u = 0; u <= t; u++){s += www[u] * kk[u];}x[i * T + t] = s;
}

我们在每个 CUDA thread,预先读取 w 和 k 进入 shared memory 中的 ww 和 kk,然后 __syncthreads() 等待全部读取完毕,然后可使用速度快得多的 ww 和 kk。

3. 将 thread 四合一,并运用 float4 告诉 nvcc 产生 SIMD 代码

优化 CUDA kernel 的第二步,可能是解决 bank conflict,不过,这个话题比较复杂。

我们看另一个简单易懂的步骤:将 thread 四合一,这通常是个好主意。速度提升到 14 毫秒。

Grid 和 Block:

dim3 gridDim(1, B * C);
dim3 blockDim(T >> 2); // 四合一,这里需要保证 T%4 == 0,因为我没有处理除不尽的情况
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);

然后 CUDA 有个 float4 结构,是 4 个 float 合起来。如果用它,更容易让 nvcc 产生 SIMD 代码。

Kernel:

template <typename F>
__global__ void kernel_forward(const F *__restrict__ const w, const F *__restrict__ const k, F *__restrict__ const x,const F eps, const int B, const int C, const int T) {const int i = blockIdx.y;const int tt = threadIdx.x;const int t = tt << 2;__shared__ F wk[2048]; // 这里我们将 w 和 k 也合并了,以后会有好处((float4 *)wk)[tt] = ((float4 *)w)[(i % C) * (T >> 2) + tt];((float4 *)wk)[256 + tt] = ((float4 *)k)[i * (T >> 2) + tt];__syncthreads();float4 s = {eps, eps, eps, eps};const F *__restrict__ const ww = wk + T - t - 4;const F *__restrict__ const kk = wk + 1024;for (int u = 0; u <= t; u++) {F x = kk[u];s.x += ww[u + 3] * x;s.y += ww[u + 2] * x;s.z += ww[u + 1] * x;s.w += ww[u + 0] * x;}s.y += ww[t + 3] * kk[t + 1];s.z += ww[t + 2] * kk[t + 1];s.z += ww[t + 3] * kk[t + 2];s.w += ww[t + 1] * kk[t + 1];s.w += ww[t + 2] * kk[t + 2];s.w += ww[t + 3] * kk[t + 3];((float4 *)x)[i * (T >> 2) + tt] = s;
}

可见,四合一还有额外的好处:循环可以重用 k[u],进一步减少了内存读取。

4. 继续将 B 分组整合

@有了琦琦的棍子(//www.zhihu.com/people/581a2fcdf24763fbb9ec2900065986b4)指出,之前我们每个 thread 都只处理一行 T,但是,注意到 w 在 B 向是共享的,所以应该每个 thread 处理多个 w 重复的行。

我实验了代码,的确可以将正向速度提速几倍,速度提升到 3.4 毫秒。而对于反向,只有 grad_K 可利用重复的 w,所以效应弱一些。

dim3 gridDim(1, B * C / BF);
dim3 blockDim(T >> 2);
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);

正向可以用 BF = 8,即,每个 thread 处理 8 个 B。反向似乎只适合 thread 处理 2 个 B。

// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB === 0 (Tmax and BF and BB are passed by compiler)#define F4(A, B) ((float4 *)(A))[(B) >> 2]template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,const F eps, const int B, const int C, const int T) {const int i = blockIdx.y;const int ij = (B * C) / BF;const int t = threadIdx.x << 2;__shared__ F ww[Tmax];__shared__ F kk[Tmax * BF];F4(ww, t) = F4(__w, t + T * (i % C));#pragma unrollfor (int j = 0; j < BF; j++) {F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));}__syncthreads();float4 s[BF];#pragma unrollfor (int j = 0; j < BF; j++) {s[j] = {eps, eps, eps, eps};}const F *__restrict__ const w = ww + T - t - 4;for (int u = 0; u <= t; u++) {#pragma unrollfor (int j = 0; j < BF; j++) {const F x = kk[u + Tmax * j];s[j].x += w[u + 3] * x;s[j].y += w[u + 2] * x;s[j].z += w[u + 1] * x;s[j].w += w[u + 0] * x;}}#pragma unrollfor (int j = 0; j < BF; j++) {const F *__restrict__ const k = kk + Tmax * j;s[j].y += w[t + 3] * k[t + 1];s[j].z += w[t + 2] * k[t + 1];s[j].z += w[t + 3] * k[t + 2];s[j].w += w[t + 1] * k[t + 1];s[j].w += w[t + 2] * k[t + 2];s[j].w += w[t + 3] * k[t + 3];F4(x, t + T * (i + ij * j)) = s[j];}
}

5. 对齐每个 thread 的任务长度

@有了琦琦的棍子同时指出,目前每个 thread 的任务长度不同(因为 t 不同),因此会降低效率(快的 thread 会等慢的 thread)。我预计这个改动可以让速度再提升一倍,稍后加入。

6. 进一步优化

下面怎么进一步优化?还请各位 CUDA 高手指导。可以先看看 B=32,C=768,T=768 的情况,多谢多谢。

本文的代码,在 win10 和 linux 均可直接编译运行:

https://github.com/BlinkDL/RWKV-CUDAgithub.com/BlinkDL/RWKV-CUDA

本文仅做学术分享,如有侵权,请联系删文。

3D视觉精品课程推荐:

1.面向自动驾驶领域的多传感器数据融合技术

2.面向自动驾驶领域的3D点云目标检测全栈学习路线!(单模态+多模态/数据+代码)
3.彻底搞透视觉三维重建:原理剖析、代码讲解、及优化改进
4.国内首个面向工业级实战的点云处理课程
5.激光-视觉-IMU-GPS融合SLAM算法梳理和代码讲解
6.彻底搞懂视觉-惯性SLAM:基于VINS-Fusion正式开课啦
7.彻底搞懂基于LOAM框架的3D激光SLAM: 源码剖析到算法优化
8.彻底剖析室内、室外激光SLAM关键算法原理、代码和实战(cartographer+LOAM +LIO-SAM)

9.从零搭建一套结构光3D重建系统[理论+源码+实践]

10.单目深度估计方法:算法梳理与代码实现

11.自动驾驶中的深度学习模型部署实战

12.相机模型与标定(单目+双目+鱼眼)

13.重磅!四旋翼飞行器:算法与实战

14.ROS2从入门到精通:理论与实战

重磅!3DCVer-学术论文写作投稿 交流群已成立

扫码添加小助手微信,可申请加入3D视觉工坊-学术论文写作与投稿 微信交流群,旨在交流顶会、顶刊、SCI、EI等写作与投稿事宜。

同时也可申请加入我们的细分方向交流群,目前主要有3D视觉CV&深度学习SLAM三维重建点云后处理自动驾驶、多传感器融合、CV入门、三维测量、VR/AR、3D人脸识别、医疗影像、缺陷检测、行人重识别、目标跟踪、视觉产品落地、视觉竞赛、车牌识别、硬件选型、学术交流、求职交流、ORB-SLAM系列源码交流、深度估计等微信群。

一定要备注:研究方向+学校/公司+昵称,例如:”3D视觉 + 上海交大 + 静静“。请按照格式备注,可快速被通过且邀请进群。原创投稿也请联系。

▲长按加微信群或投稿

▲长按关注公众号

3D视觉从入门到精通知识星球:针对3D视觉领域的视频课程(三维重建系列、三维点云系列、结构光系列、手眼标定、相机标定、激光/视觉SLAM自动驾驶等)、知识点汇总、入门进阶学习路线、最新paper分享、疑问解答五个方面进行深耕,更有各类大厂的算法工程人员进行技术指导。与此同时,星球将联合知名企业发布3D视觉相关算法开发岗位以及项目对接信息,打造成集技术与就业为一体的铁杆粉丝聚集区,近4000星球成员为创造更好的AI世界共同进步,知识星球入口:

学习3D视觉核心技术,扫描查看介绍,3天内无条件退款

圈里有高质量教程资料、答疑解惑、助你高效解决问题

觉得有用,麻烦给个赞和在看~  

实例:手写 CUDA 算子,让 Pytorch 提速 20 倍相关推荐

  1. oracle最快访问行,平均提速20倍!Oracle 12c In-Memory最佳实践

    来源:三墩IT人 订阅号 作者:唐小丹(浙江移动数据库管理员) 马力行(新炬网络数据库工程师) 一.IM特性简介 Oracle 12.1.0.2 引入了In-Memory Column Store(以 ...

  2. 基于卷积神经网络(cnn)的手写数字识别(PyTorch)

    目录 1.1 卷积神经网络简介 1.2 神经网络 1.2.1 神经元模型 1.2.2 神经网络模型 1.3 卷积神经网络 1.3.1卷积的概念 1.3.2 卷积的计算过程 1.3.3 感受野 1.3. ...

  3. 手写中文数字识别PyTorch实现(全连接卷积神经网络)

    尝试一下手写汉字的数字识别,分别采用全连接神经网络和卷积神经网络 这次准备的数据集有15000张图片,每张图片大小为64*64 训练集10500张图片,测试集4500张图片 全连接神经网络 我们先用上 ...

  4. 4.1 keras基础实例 手写数字识别

    1)手写数据集 手写数据集是深度学习中,最基础应用最广泛的数据集. 手写数据集内置在keras中 import keras from keras import layers import matplo ...

  5. 提速20倍!谷歌AI发布TensorFlow 3D,智能汽车场景亲测好用

    来源丨新智元 编辑丨极市平台 导读 Google AI发布了TensorFlow 3D,将3D深度学习能力引入TensorFlow,加入3D稀疏卷积网络,在Waymo Open数据集上的实验表明,这种 ...

  6. 提速20倍!谷歌AI发布TensorFlow 3D

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 来源丨新智元 编辑丨极市平台 随着自动驾驶汽车与机器人的深入发展,激 ...

  7. Multi-class Classication (多分类问题)实例--手写数字识别

    本实例整理自斯坦福机器学习课程课后练习ex3 本例是对一个手写体的数据集(0-9)进行分类,也就是对原有的数据集进行训练,然后给定一个手写体,识别该手写体是数字几.其分类思想就是之前Andrew Ng ...

  8. caffe(4):mnist实例---手写数字识别

    深度学习的第一个实例一般都是mnist,只要这个例子完全弄懂了,其它的就是举一反三的事了.由于篇幅原因,本文不具体介绍配置文件里面每个参数的具体函义,如果想弄明白的,请参看我以前的博文: 数据层及参数 ...

  9. Hinton NIPS2017 手写识别实现 TensorFlow \ Pytorch \ Keras

    10月26日,深度学习元老Hinton的NIPS2017 Capsule论文<Dynamic Routing Between Capsules>终于在arxiv上发表.今天相关关于这篇论文 ...

最新文章

  1. 人工智能算法模型必会之——正则化方法综述
  2. POJ - 3476 A Game with Colored Balls---优先队列+链表(用数组模拟)
  3. 基于第三方开源库的OPC服务器开发指南(2)——LightOPC的编译及部署
  4. C++11: chrono
  5. 【LeetCode从零单排】No96 Unique Binary Search Trees
  6. 《数据科学与大数据分析——数据的发现 分析 可视化与表示》一2.3 第2阶段:数据准备...
  7. linux内核体系学习路径_Linux内核分析(一)linux体系简介|内核源码简介|内核配置编译安装...
  8. 39.django的ORM模型
  9. 【Transformer】Transformer中16个注意力头一定要比1个注意力头效果好吗?
  10. LVS部分调度算法的适应场景分析
  11. matlab投资组合权重,Matlab做投资组合最优化
  12. SpringBoot全局异常处理(三十)
  13. 手机设计软件有哪些(合集)
  14. 2019年4月27号,下雨杂谈
  15. Java初始化大乱斗
  16. 如何看待家长培训课?
  17. Numpy 计算男女生各科成绩统计指标
  18. 轻轻松松实现本地和云主机之间的文件上传下载
  19. 为什么类只能单继承,而接口可以多继承?
  20. 【含源码】用python做游戏有多简单好玩

热门文章

  1. 数据结构与算法——AVL树类的C++实现
  2. 使用XML作为配置表,WinForm程序读取配置表来动态显示控件
  3. Python 3 文件和字符编码
  4. 中国首个量子计算机诞生 中科院、阿里巴巴共同研发
  5. ASP.NET Core中的依赖注入(2):依赖注入(DI)
  6. 最小系统必须安装的组件(仅做参考)
  7. 为了拿捏 Redis 数据结构,我画了 40 张图
  8. 迄今为止程序员写过的最大Bug:亏损30亿、致6人死亡,甚至差点毁灭世界
  9. 不容错过的灰度发布系统架构设计
  10. 干货!如何设计实现一个通用的分布式事务框架?