作者 | Marat Dukhan from Google Research

译者 | 凯隐

责编 | Jane

出品 | AI科技大本营(ID: rgznai100)

【导读】本文介绍的内容主要聚焦Google 的一项最新工作:改变基于 GEMM 实现的 CNN底层算法提出的新方法。通用矩阵乘法(General Matrix Multiply, GEMM)是广泛用于线性代数、机器学习、统计学等各个领域的常见底层算法,其实现了基本的矩阵与矩阵相乘的功能,因此算法效率直接决定了所有上层模型性能,目前主流的卷积算法都是基于GEMM来实现的。来自谷歌的Peter Vajda在ECV2019中提出了一种全新的间接卷积算法,用于改进GEMM在实现卷积操作时存在的一些缺点,进而提升计算效率。

通用矩阵乘法

GEMM是基础线性代数子程序库(Basic Linear Algebra Subprograms, BLAS)中的一个函数。BLAS提供了实现矩阵和向量基本运算的函数,最早于1979年由C.L.LAWSON提出。BLAS的发展大致可以分为三个阶段(levels)的历程,这和函数定义,出版顺序,以及算法中多项式的阶数以及复杂性有关,第一阶段只包含与向量(vector)有关的运算,第二阶段添加了向量与矩阵进行运算的操作,第三阶段添加了矩阵与矩阵之间的运算,前两个阶段的BLAS都是用于向量处理器的,而第三阶段适用于矩阵处理器,所以BLAS的发展和硬件的发展密不可分。GEMM属于第三阶段的算法,正式公布于1990年,其迭代更新形式为:

其中A和B可以进行转置或hermitian共轭转置,而A、B和C都可以被忽略(be strided),因此实际上这个公式就表示了任意矩阵之间所有可能的加法和乘法组合,例如最基本的A*B,可以将α置1,C置为全0矩阵即可,这也是其通用性的表现。

由于矩阵乘法相对于向量-向量乘法以及向量-矩阵乘法,有更低的时间复杂度,效率更高,因此其广泛用于许多科学任务中,与之相关的GEMM算法成为了目前BLAS设计者的主要优化对象。例如可以将A和B分解为分块矩阵,使得GEMM可以递归实现。有关GEMM的详细信息可以参见[1][2][3]。如何对GEMM进行优化,是BLAS相关工作的研究热点。

基于 GEMM 的卷积算法及其缺点

卷积神经网络(CNN)在CV问题中的表现很出色,有多种在算法层面对齐进行实现的方法:直接卷积算法,采用7层循环,快速卷积算法,利用傅里叶变换来进行卷积,以及基于GEMM的卷积算法。

通过将卷积操作用矩阵乘法来代替,进而使用GEMM算法来间接进行卷积操作,这使得卷积操作可以在任何包含GEMM的平台上进行,并且受益于矩阵乘法的高效性,任何针对GEMM的改进和研究都能有助于卷积运算效率的提升,从而提高模型的运算速度,因此目前大部分主流的神经网络框架,例如Tensorflow、Pytorch和Caffe都使用基于GEMM的方法来在底层代码中实现卷积。

具体的,基于GEMM的卷积方法需要借助于 im2col或im2row buffer来内存转换,使得数据格式满足GEMM算法的输入要求,从而将卷积操作转化为GEMM操作,然而这个转换过程是一个计算开销和内存开销都比较大的过程,特别是在输入channel数较小时,这个过程会在整个卷积过程中占有很大的比例。简言之,就是在卷积过程中,每个pixel都会被多次重复的转换,这是不必要的计算开销。因此有许多工作都在对这一过程进行改进,本文工作提出了一种改进算法——间接卷积算法(Indirect Convolution algorithm),主要有以下两个优点:

1、去掉了im2row的转换过程,这使得算法性能有了巨大的提升(up to 62%)。

2、用了一个更小的indirection buffer来代替原来的im2row buffer。不同于im2row buffer的大小随着输入channel数线性增加,indirection buffer没有这个特性,因此indirection buffer的内存占用特性非常有利于输入channel数较多时的卷积操作。

间接卷积算法

原始的GEMM通过如下计算来不断迭代进行矩阵运算操作并输出矩阵:

其中A是输入张量,B是一个常量滤波器,C是输出矩阵,在传统的im2col+GEMM算法中,通常α=1而β=0,原始GEMM操作示意图如下:

图1 原始GEMM操作

其中 im2col buffer 代表矩阵A,filter tensor 代表矩阵B,A和B的乘积就是输出copy表示将输入的张量展开为一个二维矩阵,也就是im2col buffer。可以看到buffer的每一行则是由固定个数(步长)的pixel展开成一维的向量组成的,这些pixel都在原始tensor中的一个patch内,在经过和filter tensor相乘后,由于矩阵行列相乘得到一个元素,因此这几个pixel的信息都被整合成了一个值,也就是对他们进行了卷积操作。最后在输出矩阵C中,行数rows代表输出的像素点个数,columns代表输出的channel数。可以看到buffer的columns是和输入channel数有关的。

为了降低buffer带来的开销,作者提出了一种间接矩阵乘法的思想,不把输入的tensor直接展开并存储在buffer中,而只是在buffer中存放每个pixel在input tensor的坐标,也就是从存数据变成了存地址(类似于指针pointer思想),这样不管channel数有多少,存的地址信息始终只有二维,极大的降低了buffer的计算和存储开销,如下图:

图2 indirect convolution

当然,由于buffer中存的是地址信息,因此不能直接和filter做矩阵乘法,所以就只能通过在buffer的行间进行循环,根据该行的pointer找到对应的输入数据,再将输入数据与kernel相乘,并与之前循环的结果拼接起来,从而间接的实现矩阵乘法,因此叫做indirection buffer。

对于不同的卷积步长,只需要将不同步长对应的卷积patch位置确定即可。而对于padding策略,将指向填充位置的pointer对应的输入pixel的向量值全部设置为0。

间接卷积算法的缺点

间接卷积算法作为GEMM-BASED CNN算法的一种改进,能极大的提升计算效率,但是存在以下几个限制:

1. 这个算法是为NHWC layout设计的,也就是说应用范围比较窄,不能和目前的主流方法相比。

2. 算法适用于前向传播中的卷积操作,而在反向传播中作用不大,不及基于col2im和row2im的算法。

3. 具有和GEMM相同的缺点,在深度小卷积核的卷积操作中效率并不好。

实验测试结果

Efficient Deep Learning for Computer Vision主要聚焦于如何将深度学习部署到移动设备上,因此本文的工作主要在移动设备和移动芯片上进行测试,结果如下:

可以看到一旦步长增加,那么Indirect convolution带来的性能提升就会明显下降,这是因为步长越大,在原始的GEMM算法中重复计算的量就会减小,因此原始GEMM的性能本身就会提升,而indirect convolution并不受益于步长增加。

延伸介绍:Efficient Deep Learning for Computer Vision Workshop

目前CV方向主流的研究都着重于如何提升算法和模型性能,并不是太注重模型速度,运算时间,内存消耗等与运算资源有关的性能指标,这不利于将模型部署在类似于移动设备等计算资源有限的平台上。CVPR的这个workshop主要关注评估模型的计算开销和存储开销有关的指标,以及如何将其应用到移动设备上,相关团队隶属于谷歌研究院,详见[4]。

参考资料:

[1] https://spatial-lang.org/gemm

[2] https://en.wikipedia.org/wiki/Vector_processor

[3] https://petewarden.com/2015/04/20/why-gemm-is-at-the-heart-of-deep-learning/

[4] https://sites.google.com/view/ecv2019/home

原文链接:

https://arxiv.org/abs/1907.02129

(*本文为 AI科技大本营编译文章,转载请联系1092722531

精彩推荐

“只讲技术,拒绝空谈!”2019 AI开发者大会将于9月6日-7日在北京举行,这一届AI开发者大会有哪些亮点?一线公司的大牛们都在关注什么?AI行业的风向是什么?2019 AI开发者大会,倾听大牛分享,聚焦技术实践,和万千开发者共成长。

目前,大会盲订票限量发售中~扫码购票,领先一步!

推荐阅读

  • 码农们的「血与泪」:新零售「全渠道中台」的前世今身

  • 腾讯拥抱开源:首次公布开源路线图,技术研发向共享、复用和开源迈进

  • 混合云发展之路:前景广阔,巨头混战

  • 干货 | Python后台开发的高并发场景优化解决方案

  • 5G 浪潮来袭!程序员在风口中有何机遇?

  • 这次又坑多少人? 深度解析 Dash 钱包"关键"漏洞!

  • 壕!两万多名腾讯员工获 51 万港元股票奖励

你点的每个“在看”,我都认真当成了喜欢

基于GEMM实现的CNN底层算法被改?Google提出全新间接卷积算法相关推荐

  1. em算法 实例 正态分布_Petuum提出序列生成学习算法通用框架

    近日,来自人工智能创业公司 Petuum 的研究人员发表论文,提出序列生成学习算法的通用框架--广义的熵正则化策略优化框架(Generalized Entropy-Regularized Policy ...

  2. ICLR 2022 | 绝艺学会打麻将,腾讯AI Lab提出全新策略优化算法ACH

    感谢阅读腾讯AI Lab微信号第144篇文章.本文介绍「绝艺」在二人麻将游戏环境取得的进展,相关算法及benchmark已开源,论文被机器学习国际顶会 ICLR 2022 接收. 「绝艺」是腾讯AI ...

  3. 【Nature重磅】OpenAI科学家提出全新强化学习算法,推动AI向智能体进化

    深度强化学习实验室 官网:http://www.neurondance.com/ 论坛:http://deeprl.neurondance.com/ 编辑:DeepRL 近年来,人工智能(AI)在强化 ...

  4. 【清华伯克利】提出全新算法RPG,通过奖励随机化发现多智能体游戏中多样性策略行为。

    深度强化学习实验室 官网:http://www.neurondance.com/ 论坛:http://deeprl.neurondance.com/ 作者:本文转载自机器之心 编辑.排版:DeepRL ...

  5. CV之CNN:基于tensorflow框架采用CNN(改进的AlexNet,训练/评估/推理)卷积神经网络算法实现猫狗图像分类识别

    CV之CNN:基于tensorflow框架采用CNN(改进的AlexNet,训练/评估/推理)卷积神经网络算法实现猫狗图像分类识别 目录 基于tensorflow框架采用CNN(改进的AlexNet, ...

  6. CV之IC之AlexNet:基于tensorflow框架采用CNN卷积神经网络算法(改进的AlexNet,训练/评估/推理)实现猫狗分类识别案例应用

    CV之IC之AlexNet:基于tensorflow框架采用CNN卷积神经网络算法(改进的AlexNet,训练/评估/推理)实现猫狗分类识别案例应用 目录 基于tensorflow框架采用CNN(改进 ...

  7. 大话卷积神经网络CNN,小白也能看懂的深度学习算法教程,全程干货建议收藏!...

    来源 | 程序员管小亮 本文创作的主要目的,是对时下最火最流行的深度学习算法的基础知识做一个简介,作者看过许多教程,感觉对小白不是特别友好,尤其是在踩过好多坑之后,于是便有了写这篇文章的想法. 由于文 ...

  8. CNN可视化最新研究方法进展(附结构、算法)

    译者 | reason_W 责编 | 明 明 出品 | AI科技大本营(公众号ID:rgznai100) [AI科技大本营导读]深度学习一直被看做是一个难以解释的"黑匣子".一方面 ...

  9. 基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络

    基于pytorch使用实现CNN 如何使用pytorch构建CNN卷积神经网络 所用工具 文件结构: 数据: 代码: 结果: 改进思路 拓展 本文是一个基于pytorch使用CNN在生物信息学上进行位 ...

最新文章

  1. api可以主动采集用户数据吗_数据埋点采集的那些事儿
  2. RMAN 备份与恢复 实例
  3. VS2010中使用JSONCPP方法
  4. 金山“云”上音乐节 —— 一文带你看懂如何支持一场线上演出
  5. java客户端程序用什么自动化测试_五大Java自动化测试框架
  6. 【AI视野·今日NLP 自然语言处理论文速览 第十五期】Fri, 25 Jun 2021
  7. linetv_linetv台湾版官方下载|line tv安卓版下载_v1.0.10_9ht安卓下载
  8. 分屏 投影显示 PPT
  9. ubuntu20关闭自动更新
  10. ecshop的dwt模板文件
  11. 解码方法( dfs | dp )
  12. 想看懂资管行业?不清楚有哪些资管产品怎么行!
  13. 多核与多个CPU啥区别
  14. pycharm在ubuntu中不能输入中文的问题
  15. RAID中有一块硬盘离线的情况下应该对其采取强制上线操作么?
  16. rcu锁原理以及rcu example学习
  17. java判断一天是星期几_java判断今天星期几
  18. Scrapy中的crawlspider爬虫
  19. 公式推导出创意,阿里妈妈“AI智能文案”通过图灵测试!
  20. 运营商线路细分_国际电信运营商电信市场细分的比较分析

热门文章

  1. Linux Shell 脚本限制ssh最大用户登录数
  2. STL学习系列九:Map和multimap容器
  3. 开源硬件:极客们的伟大理想
  4. 2009-徘徊-开场白
  5. Java中集合类型线程安全性
  6. Mysql将SQL查询结果以字符串形式返回
  7. SpringCloud 面试题,最新SpringCloud 面试题,2020 SpringCloud 面试题
  8. 使用c3p0对mysql进行增删改查_c3p0连接池连接数据库 并增删改查
  9. python 抛出异常raise
  10. CF 1093 E. Intersection of Permutations