实例:手写 CUDA 算子,让 Pytorch 提速 20 倍
作者丨PENG Bo@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/476297195
编辑丨极市平台
本文的代码,在 win10 和 linux 均可直接编译运行:
https://github.com/BlinkDL/RWKV-CUDAgithub.com/BlinkDL/RWKV-CUDA
先看需提速的操作,在我的 RWKV 语言模型【 GitHub - BlinkDL/AI-Writer AI 写小说:https://github.com/BlinkDL/AI-Writer 】,类似 depthwise 一维卷积,伪代码:
w.shape = (C, T)
k.shape = (B, C, T)
out.shape = (B, C, T)
out[b][c][t] = eps + sum_u{ w[c][(T-1)-(t-u)] * k[b][c][u] } 这里 u 从 0 到 t
它的意义,是让 对 产生影响,具体的影响程度由 决定,且影响在每个通道 都不同。
用代码写(四重循环):
out = torch.empty((B, C, T), device='cuda')for b in range(B):for c in range(C):for t in range(T):s = epsfor u in range(0, t+1):s += w[c][(T-1)-(t-u)] * k[b][c][u]out[b][c][t] = sreturn out
这个操作,用 pytorch 只需一行,但实际速度不佳,尤其是,反向梯度很慢:
out = eps + F.conv1d(nn.ZeroPad2d((T-1, 0, 0, 0))(k), w.unsqueeze(1), groups=C)
因此,我们可以用 CUDA 手写算子。实际测试,正向和反向速度可以 20x。
而且,这里的代码还有很多优化的空间。还望各位 CUDA 高手指导如何进一步优化,多谢多谢。
如果你从未尝试给 pytorch 添加 CUDA 算子,可以先阅读下面这个教程:
godweiyang:熬了几个通宵,我写了份CUDA新手入门代码
下面我们看看,如何逐步优化 CUDA kernel 的写法。
1. 最简单的 CUDA Kernel 写法
最简单的写法,是直接在每个 thread 求和。这会有大量内存存取,因此效率很低。速度为 45 毫秒。但也比 pytorch 的 94 毫秒更快了。
Grid 和 Block:
dim3 gridDim(1, B * C);
dim3 blockDim(T); // 注意,我们先只在 T 分 thread,因为这样的代码简单,而且效率也够高
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
Kernel:
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const w, const F *__restrict__ const k, F *__restrict__ const x,const F eps, const int B, const int C, const int T)
{const int i = blockIdx.y;const int t = threadIdx.x;F s = eps;const F *__restrict__ const www = w + (i % C) * T + (T - 1) - t;const F *__restrict__ const kk = k + i * T;for (int u = 0; u <= t; u++){s += www[u] * kk[u];}x[i * T + t] = s;
}
2. 运用 shared memory 改善存取效率
优化 CUDA kernel 的第一步,是用 shared memory(就像矩阵乘法做 tiling)。速度提升到 17 毫秒。
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const w, const F *__restrict__ const k, F *__restrict__ const x,const F eps, const int B, const int C, const int T)
{const int i = blockIdx.y;const int t = threadIdx.x;__shared__ F ww[1024]; // 这里限制了 T <= 1024 因为我实际只会用到这么多__shared__ F kk[1024];ww[t] = w[(i % C) * T + t];kk[t] = k[i * T + t];__syncthreads();F s = eps;const F *__restrict__ const www = ww + (T - 1) - t;for (int u = 0; u <= t; u++){s += www[u] * kk[u];}x[i * T + t] = s;
}
我们在每个 CUDA thread,预先读取 w 和 k 进入 shared memory 中的 ww 和 kk,然后 __syncthreads() 等待全部读取完毕,然后可使用速度快得多的 ww 和 kk。
3. 将 thread 四合一,并运用 float4 告诉 nvcc 产生 SIMD 代码
优化 CUDA kernel 的第二步,可能是解决 bank conflict,不过,这个话题比较复杂。
我们看另一个简单易懂的步骤:将 thread 四合一,这通常是个好主意。速度提升到 14 毫秒。
Grid 和 Block:
dim3 gridDim(1, B * C);
dim3 blockDim(T >> 2); // 四合一,这里需要保证 T%4 == 0,因为我没有处理除不尽的情况
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
然后 CUDA 有个 float4 结构,是 4 个 float 合起来。如果用它,更容易让 nvcc 产生 SIMD 代码。
Kernel:
template <typename F>
__global__ void kernel_forward(const F *__restrict__ const w, const F *__restrict__ const k, F *__restrict__ const x,const F eps, const int B, const int C, const int T) {const int i = blockIdx.y;const int tt = threadIdx.x;const int t = tt << 2;__shared__ F wk[2048]; // 这里我们将 w 和 k 也合并了,以后会有好处((float4 *)wk)[tt] = ((float4 *)w)[(i % C) * (T >> 2) + tt];((float4 *)wk)[256 + tt] = ((float4 *)k)[i * (T >> 2) + tt];__syncthreads();float4 s = {eps, eps, eps, eps};const F *__restrict__ const ww = wk + T - t - 4;const F *__restrict__ const kk = wk + 1024;for (int u = 0; u <= t; u++) {F x = kk[u];s.x += ww[u + 3] * x;s.y += ww[u + 2] * x;s.z += ww[u + 1] * x;s.w += ww[u + 0] * x;}s.y += ww[t + 3] * kk[t + 1];s.z += ww[t + 2] * kk[t + 1];s.z += ww[t + 3] * kk[t + 2];s.w += ww[t + 1] * kk[t + 1];s.w += ww[t + 2] * kk[t + 2];s.w += ww[t + 3] * kk[t + 3];((float4 *)x)[i * (T >> 2) + tt] = s;
}
可见,四合一还有额外的好处:循环可以重用 k[u],进一步减少了内存读取。
4. 继续将 B 分组整合
@有了琦琦的棍子(//www.zhihu.com/people/581a2fcdf24763fbb9ec2900065986b4)指出,之前我们每个 thread 都只处理一行 T,但是,注意到 w 在 B 向是共享的,所以应该每个 thread 处理多个 w 重复的行。
我实验了代码,的确可以将正向速度提速几倍,速度提升到 3.4 毫秒。而对于反向,只有 grad_K 可利用重复的 w,所以效应弱一些。
dim3 gridDim(1, B * C / BF);
dim3 blockDim(T >> 2);
kernel_forward<<<gridDim, blockDim>>>(w, k, x, eps, B, C, T);
正向可以用 BF = 8,即,每个 thread 处理 8 个 B。反向似乎只适合 thread 处理 2 个 B。
// require T <= Tmax, T % 4 == 0, B % BF == 0, B % BB === 0 (Tmax and BF and BB are passed by compiler)#define F4(A, B) ((float4 *)(A))[(B) >> 2]template <typename F>
__global__ void kernel_forward(const F *__restrict__ const __w, const F *__restrict__ const __k, F *__restrict__ const x,const F eps, const int B, const int C, const int T) {const int i = blockIdx.y;const int ij = (B * C) / BF;const int t = threadIdx.x << 2;__shared__ F ww[Tmax];__shared__ F kk[Tmax * BF];F4(ww, t) = F4(__w, t + T * (i % C));#pragma unrollfor (int j = 0; j < BF; j++) {F4(kk, t + Tmax * j) = F4(__k, t + T * (i + ij * j));}__syncthreads();float4 s[BF];#pragma unrollfor (int j = 0; j < BF; j++) {s[j] = {eps, eps, eps, eps};}const F *__restrict__ const w = ww + T - t - 4;for (int u = 0; u <= t; u++) {#pragma unrollfor (int j = 0; j < BF; j++) {const F x = kk[u + Tmax * j];s[j].x += w[u + 3] * x;s[j].y += w[u + 2] * x;s[j].z += w[u + 1] * x;s[j].w += w[u + 0] * x;}}#pragma unrollfor (int j = 0; j < BF; j++) {const F *__restrict__ const k = kk + Tmax * j;s[j].y += w[t + 3] * k[t + 1];s[j].z += w[t + 2] * k[t + 1];s[j].z += w[t + 3] * k[t + 2];s[j].w += w[t + 1] * k[t + 1];s[j].w += w[t + 2] * k[t + 2];s[j].w += w[t + 3] * k[t + 3];F4(x, t + T * (i + ij * j)) = s[j];}
}
5. 对齐每个 thread 的任务长度
@有了琦琦的棍子同时指出,目前每个 thread 的任务长度不同(因为 t 不同),因此会降低效率(快的 thread 会等慢的 thread)。我预计这个改动可以让速度再提升一倍,稍后加入。
6. 进一步优化
下面怎么进一步优化?还请各位 CUDA 高手指导。可以先看看 B=32,C=768,T=768 的情况,多谢多谢。
本文的代码,在 win10 和 linux 均可直接编译运行:
https://github.com/BlinkDL/RWKV-CUDAgithub.com/BlinkDL/RWKV-CUDA
本文仅做学术分享,如有侵权,请联系删文。
3D视觉精品课程推荐:
1.面向自动驾驶领域的多传感器数据融合技术
2.面向自动驾驶领域的3D点云目标检测全栈学习路线!(单模态+多模态/数据+代码)
3.彻底搞透视觉三维重建:原理剖析、代码讲解、及优化改进
4.国内首个面向工业级实战的点云处理课程
5.激光-视觉-IMU-GPS融合SLAM算法梳理和代码讲解
6.彻底搞懂视觉-惯性SLAM:基于VINS-Fusion正式开课啦
7.彻底搞懂基于LOAM框架的3D激光SLAM: 源码剖析到算法优化
8.彻底剖析室内、室外激光SLAM关键算法原理、代码和实战(cartographer+LOAM +LIO-SAM)
9.从零搭建一套结构光3D重建系统[理论+源码+实践]
10.单目深度估计方法:算法梳理与代码实现
11.自动驾驶中的深度学习模型部署实战
12.相机模型与标定(单目+双目+鱼眼)
13.重磅!四旋翼飞行器:算法与实战
14.ROS2从入门到精通:理论与实战
重磅!3DCVer-学术论文写作投稿 交流群已成立
扫码添加小助手微信,可申请加入3D视觉工坊-学术论文写作与投稿 微信交流群,旨在交流顶会、顶刊、SCI、EI等写作与投稿事宜。
同时也可申请加入我们的细分方向交流群,目前主要有3D视觉、CV&深度学习、SLAM、三维重建、点云后处理、自动驾驶、多传感器融合、CV入门、三维测量、VR/AR、3D人脸识别、医疗影像、缺陷检测、行人重识别、目标跟踪、视觉产品落地、视觉竞赛、车牌识别、硬件选型、学术交流、求职交流、ORB-SLAM系列源码交流、深度估计等微信群。
一定要备注:研究方向+学校/公司+昵称,例如:”3D视觉 + 上海交大 + 静静“。请按照格式备注,可快速被通过且邀请进群。原创投稿也请联系。
▲长按加微信群或投稿
▲长按关注公众号
3D视觉从入门到精通知识星球:针对3D视觉领域的视频课程(三维重建系列、三维点云系列、结构光系列、手眼标定、相机标定、激光/视觉SLAM、自动驾驶等)、知识点汇总、入门进阶学习路线、最新paper分享、疑问解答五个方面进行深耕,更有各类大厂的算法工程人员进行技术指导。与此同时,星球将联合知名企业发布3D视觉相关算法开发岗位以及项目对接信息,打造成集技术与就业为一体的铁杆粉丝聚集区,近4000星球成员为创造更好的AI世界共同进步,知识星球入口:
学习3D视觉核心技术,扫描查看介绍,3天内无条件退款
圈里有高质量教程资料、答疑解惑、助你高效解决问题
觉得有用,麻烦给个赞和在看~
实例:手写 CUDA 算子,让 Pytorch 提速 20 倍相关推荐
- oracle最快访问行,平均提速20倍!Oracle 12c In-Memory最佳实践
来源:三墩IT人 订阅号 作者:唐小丹(浙江移动数据库管理员) 马力行(新炬网络数据库工程师) 一.IM特性简介 Oracle 12.1.0.2 引入了In-Memory Column Store(以 ...
- 基于卷积神经网络(cnn)的手写数字识别(PyTorch)
目录 1.1 卷积神经网络简介 1.2 神经网络 1.2.1 神经元模型 1.2.2 神经网络模型 1.3 卷积神经网络 1.3.1卷积的概念 1.3.2 卷积的计算过程 1.3.3 感受野 1.3. ...
- 手写中文数字识别PyTorch实现(全连接卷积神经网络)
尝试一下手写汉字的数字识别,分别采用全连接神经网络和卷积神经网络 这次准备的数据集有15000张图片,每张图片大小为64*64 训练集10500张图片,测试集4500张图片 全连接神经网络 我们先用上 ...
- 4.1 keras基础实例 手写数字识别
1)手写数据集 手写数据集是深度学习中,最基础应用最广泛的数据集. 手写数据集内置在keras中 import keras from keras import layers import matplo ...
- 提速20倍!谷歌AI发布TensorFlow 3D,智能汽车场景亲测好用
来源丨新智元 编辑丨极市平台 导读 Google AI发布了TensorFlow 3D,将3D深度学习能力引入TensorFlow,加入3D稀疏卷积网络,在Waymo Open数据集上的实验表明,这种 ...
- 提速20倍!谷歌AI发布TensorFlow 3D
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 来源丨新智元 编辑丨极市平台 随着自动驾驶汽车与机器人的深入发展,激 ...
- Multi-class Classication (多分类问题)实例--手写数字识别
本实例整理自斯坦福机器学习课程课后练习ex3 本例是对一个手写体的数据集(0-9)进行分类,也就是对原有的数据集进行训练,然后给定一个手写体,识别该手写体是数字几.其分类思想就是之前Andrew Ng ...
- caffe(4):mnist实例---手写数字识别
深度学习的第一个实例一般都是mnist,只要这个例子完全弄懂了,其它的就是举一反三的事了.由于篇幅原因,本文不具体介绍配置文件里面每个参数的具体函义,如果想弄明白的,请参看我以前的博文: 数据层及参数 ...
- Hinton NIPS2017 手写识别实现 TensorFlow \ Pytorch \ Keras
10月26日,深度学习元老Hinton的NIPS2017 Capsule论文<Dynamic Routing Between Capsules>终于在arxiv上发表.今天相关关于这篇论文 ...
最新文章
- 人工智能算法模型必会之——正则化方法综述
- POJ - 3476 A Game with Colored Balls---优先队列+链表(用数组模拟)
- 基于第三方开源库的OPC服务器开发指南(2)——LightOPC的编译及部署
- C++11: chrono
- 【LeetCode从零单排】No96	Unique Binary Search Trees
- 《数据科学与大数据分析——数据的发现 分析 可视化与表示》一2.3 第2阶段:数据准备...
- linux内核体系学习路径_Linux内核分析(一)linux体系简介|内核源码简介|内核配置编译安装...
- 39.django的ORM模型
- 【Transformer】Transformer中16个注意力头一定要比1个注意力头效果好吗?
- LVS部分调度算法的适应场景分析
- matlab投资组合权重,Matlab做投资组合最优化
- SpringBoot全局异常处理(三十)
- 手机设计软件有哪些(合集)
- 2019年4月27号,下雨杂谈
- Java初始化大乱斗
- 如何看待家长培训课?
- Numpy 计算男女生各科成绩统计指标
- 轻轻松松实现本地和云主机之间的文件上传下载
- 为什么类只能单继承,而接口可以多继承?
- 【含源码】用python做游戏有多简单好玩