很难写一个kernel就能同时在transpose的所有场景都最优,归纳transpose的几种常见场景可以针对性优化。这里只列出了transpose对轴变换的几种情况,没有考虑shape大小。因此在这几种场景上还应该考虑转置的轴 shape大小针对性优化。

场景1: batch 2D,perm:021

二维,三维,或者更高维的tensor,交换最内层的两个维度。这些都可以统一为batch 2D。对于2D矩阵的转置相当于batch=1,大于三维的tensor可以把两个最内层维度以外的所有维度合并看成一个维度。

一种大的021矩阵转置变换为小的矩阵转置方法:

比如输入是[4096,4096]的大矩阵进行转置,但是每个线程读取输入同一行的相邻位置,但是写出时写的是同一列的相邻位置,对于写的话,同一列跨的行长度太长对写回cache命中不友好,可以考虑降低跨的行长度来优化性能。

可以考虑把大的转置拆分为几个小的转置:

例如把MN转置为NM,考虑M和N都可以拆分为更小的维度:M拆分为M1M0,N拆分为N1N0

那么问题有MN转置为NM变为M1M0N1N0转置为N1N0M1M0

假设一次只能交换两个轴,那么可以通过如下三个步骤来实现:

M1M0N1N0->N1M1M0N0->N1M1N0M0->N1N0M1M0

其中第一步而第三部每次转置的都是一个tensor(例如下面的0213转置场景),这个一般都能实现的非常高效。

而第二步单元素的转置的维度从MN减小到了M0N0,更有利于缓存的利用。例如[4096,4096]可以拆分为[64,64,64,64]的大小来进行转置。

该方法可视化展示如下:相当于把大矩阵拆分为小矩阵,每个小矩阵独立转置,再把小矩阵看成一个整体转置一下。

要注意这个示意图中展示的是M1M0N1N0->M1N1M0N0->M1N1N0M0->N1M1N0M0,离最终要的还差一步。

场景2:0213

其特点是内部的两个相邻的维度进行交互,不包含最内层的一个或多个维度。跟上面一样,相邻不交换的维度可以合并看成一个整体,最外层的维度不足可以补1。

这个场景perm=[2, 0, 3, 1, 4],看上去同时转置了多个axes,但是由于shape元素1的特殊性,可以squeeze掉, 因此可以转换为[784, 3, 4, 12]到[3, 4, 784, 12]的transpose,可以使用0213的方法来解决。

删除transpose shape为1的算法

perm = [2, 0, 3, 1, 4]
in_shape = [1, 784, 1, 4, 12]rm_axes = []
for idx, elem in enumerate(in_shape):if elem == 1:rm_axes.append(idx)print("rm_axes:", rm_axes)def remove_axis(in_shape, perm, rm_axis):del in_shape[rm_axis]perm_rm_idx = -1for idx, elem in enumerate(perm):if elem == rm_axis:perm_rm_idx = idxif elem > rm_axis:perm[idx] = perm[idx]-1del perm[perm_rm_idx]for rm_axis in reversed(rm_axes):remove_axis(in_shape, perm, rm_axis)print("perm:", perm)
print("in_shape:", in_shape)

场景3:交换两个相邻的axes,但是其中一个axis对应的shape是1

这个场景并不需要transpose,只需要reshape即可。

使用上面的删除transpose shape为1的算法后,这种transpose的perm会变成[0,1,2,3,...] 可以非常简单的判断这个transpose实际上不需要进行任何操作,直接删除即可。

场景4:交换多个axes,但是部分perm是相邻的

这里perm=[1, 2, 0], 看上去交换了3个axes,实际上1x64这两个是一起交换的,可以合并成一个维度,这个问题就变成了上面的场景1。因此解决方案可以是合并一起变换的相邻轴,从而把问题简化。

场景5:其他

当然还有少量场景无法使用上面的方法来解决,例如这里输入shape第一个维度不是1的情况。

transpose算子优化的几种常见场景相关推荐

  1. MySQL 性能优化:8 种常见 SQL 错误用法!

    声明:转载自 MySQL 性能优化:8 种常见 SQL 错误用法! 1.LIMIT 语句 分页查询是最常用的场景之一,但也通常也是最容易出问题的地方.比如对于下面简单的语句,一般 DBA 想到的办法是 ...

  2. MySQL索引失效的几种常见场景

    前言 我们在使用MySQL查询数据的时候,总会遇见没有正确使用到索引的情况. 这里我们列举几种常见的,搜索条件使用了索引列却没有走索引的场景. (以下测试均在MySQL8.0.28中完成,且所有数据均 ...

  3. 数据库优化:8 种常见的SQL错误用法

    作者 | db匠 来源 | http://yq.aliyun.com/articles/72501 前言 MySQL在2016年仍然保持强劲的数据库流行度增长趋势.越来越多的客户将自己的应用建立在My ...

  4. mysql内链查询写法_网站内链优化与几种常见的结构优化方法

    在互联网的海洋沉淀了这么多年,经常会看见很多新人学着学着就放弃了,甚至有人还说做seo还不如去工地上班.真的是这样吗?其实不是这样的,很多行业有进人进来,就有老人离开,这属于自然规律.许多的站长做着做 ...

  5. 8种应用场景!嵌入式BI如何快速提升SaaS数据分析功能

    新一代信息技术的突飞猛进,给我们的工作方式带来了前所未有改变.与时俱进,拥抱数字化,远程办公正被越来越多的企业所热捧,成为当下最受欢迎的一种工作模式.而远程办公仅仅是SaaS应用的冰山一角而已.Saa ...

  6. 千万不要这样写代码!9种常见的OOM场景演示

    <Java虚拟机规范>里规定除了程序计数器外,虚拟机内存的其他几个运行时区域都有发生 OutOfMemoryError 异常的可能,我们本文就来演示一下这些错误的使用场景. 一. Stac ...

  7. es elasticsearch 几种常见查询场景 二次分组 java读取es的查询json文件

    大家好,我是烤鸭: es中几种常见的查询场景,使用java读取es的json文件进行查询. es 中文使用手册. https://www.elastic.co/guide/cn/elasticsear ...

  8. table表头固定4种方法_在常见的3种工资条场景中,教你4种批量打印工资条的方法...

    私信回复关键词[福利]~ 获取丰富办公资源,助你高效办公早下班! 打印工资条估计是财务老师的痛,要把一行行的数据,变成一条条的工资条. 数据很多,表头很复杂. 一个个复制粘贴?那是不可能的! 那怎么办 ...

  9. 几种常见的线程池及使用场景

    为什么要使用线程池? 创建线程和销毁线程的花销是比较大的,这些时间有可能比处理业务的时间还要长.这样频繁的创建线程和销毁线程,再加上业务工作线程,消耗系统资源的时间,可能导致系统资源不足.(我们可以把 ...

最新文章

  1. 一文介绍机器学习中的三种特征选择方法
  2. 哈佛新冠论文用百度写,川普很满意,英国媒体BBC都看不下去:好歹搜索方法要用对呀!...
  3. 千万级通用的分页存储过程
  4. Apache RocketMQ Meetup深圳首秀 引开源爱好者追捧
  5. 查找字符位置_Excel中查找字符第N次出现的位置信息,换个思路其实很简单
  6. 服务器控件开发——组合控件(5)
  7. 从Java到Kotlin(五)
  8. spark streaming 消费 kafka入门采坑解决过程
  9. 负载均衡沙龙活动第二期现场问答汇集
  10. max7219c语言,51单片机+MAX7219数码管显示C程序
  11. 2017年读书计划(一)
  12. 1.2-Nginx编译安装
  13. react-native 报错 RawText must be wrapped in an explicit Text component
  14. python使用ip代理抓取网页
  15. kepware modbus
  16. SPI操作flash MX25L64读写数据
  17. Java程序崩溃原因分析:错误日志分析及解决(Cannot allocate memory)
  18. 小猫钓鱼游戏(c++实现)
  19. Qt编写的SMTP客户端(库)
  20. Dubbo 使用 kryo 序列化

热门文章

  1. 树模型系列之集成树(Random Forest、Adaboost、GBDT)
  2. InsideSherpa虚拟实习-数据分析
  3. linux上10G日志怎么看,10G日志报错Heap size 2119K exceeds notification threshold (2048K)
  4. Java之@Autowired再分析
  5. 百乐嘉利宝在深圳设立新办事处及巧克力学院;药明巨诺在港交所主板上市 | 美通企业日报...
  6. Springboot @Validated参数校验
  7. 华为OD机试用Python实现 -【广播服务器】
  8. LeetCode 659. 分割数组为连续子序列
  9. 【PaddlePaddle】GAN基础
  10. C#和C++ 库的相互引用