作者 | 永远在你身后

转载自知乎

【导读】einsum 全称 Einstein summation convention(爱因斯坦求和约定),又称为爱因斯坦标记法,是爱因斯坦 1916 年提出的一种标记约定,本文主要介绍了einsum 的应用。

简单的说,应用 einsum 就是省去求和式中的求和符号,例如下面的公式:

以 einsum 的写法就是:

后者将  符号给省去了,显得更加简洁;再比如:

上面两个栗子换成 einsum 的写法就变成:

在实现一些算法时,数学表达式已经求出来了,需要将之转换为代码实现,简单的一些还好,有时碰到例如矩阵转置、矩阵乘法、求迹、张量乘法、数组求和等等,若是以分别以 transopse、sum、trace、tensordot 等函数实现的话,不但复杂,还容易出错。

现在,这些问题你统统可以一个函数搞定,没错,就是 einsum,einsum 函数就是根据上面的标记法实现的一种函数,可以根据给定的表达式进行运算,可以替代但不限于以下函数:

矩阵求迹:trace求矩阵对角线:diag张量(沿轴)求和:sum张量转置:transopose矩阵乘法:dot张量乘法:tensordot向量内积:inner外积:outer

该函数在 numpy、tensorflow、pytorch 上都有实现,用法基本一样,定义如下:

equation 是字符串的表达式,operands 是操作数,是一个元组参数,并不是只能有两个,所以只要是能够通过 einsum 标记法表示的乘法求和公式,都可以用一个 einsum 解决,下面以 numpy 举几个栗子:

# 沿轴计算张量元素之和:
c = a.sum(axis=0)

上面的以 sum 函数的实现代码,设 为三维张量,上面代码用公式来表达的话就是:

换成 einsum 标记法:

然后根据此式使用 einsum 函数实现等价功能:

c = np.einsum('ijk->jk', a)
# 作用与 c = a.sum(axis=0) 一样

更进一步的,如果  不止是三维,可以将下标  换成省略号,以表示剩下的所有维度:

这种写法 pytorch 与 tensorflow 同样支持,如果不是很理解的话,可以查看其对应的公式:

# 矩阵乘法
c = np.dot(a, b)

矩阵乘法的公式为:

然后是 einsum 对应的实现:

最后再举一个张量乘法栗子:

# 张量乘法
c = np.tensordot(a, b, ([0, 1], [0, 1]))

如果  是三维的,对应的公式为:

对应的 einsum 实现:

下面以 numpy 做一下测试,对比 einsum 与各种函数的速度,这里使用 python 内建的 timeit 模块进行时间测试,先测试(四维)两张量相乘然后求所有元素之和,对应的公式为:

然后是测试代码:

from timeit import Timer
import numpy as np  # 定义两个全局变量
a = np.random.rand(64, 128, 128, 64)
b = np.random.rand(64, 128, 128, 64)   # 定义使用einsum与sum的函数
def einsum():   temp = np.einsum('ijkl,ijkl->', a, b) def npsum():    temp = (a * b).sum()   # 打印运行时间
print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(20))
print("npsum cost:", Timer("npsum()", "from __main__ import npsum").timeit(20))

上面 Timer 是 timeit 模块内的一个类

Timer(stmt, setup).timeit(number)  # stmt: 要测试的语句  # setup: 传入stmt的运行环境,比如stmt中要导入的模块等。 # 可以写一行语句,也可以写多行语句,写多行语句时要用分号;隔开语句 # number: 执行次数

将两个函数各执行 20 遍,最后的结果为,单位为秒:

einsum cost: 1.5560735
npsum cost: 8.0874927

可以看到,einsum 比 sum 快了几乎一个量级,接下来测试单个张量求和:

将上面的代码改一下:

def einsum():   temp = np.einsum('ijkl->', a) def npsum():    temp = a.sum()

相应的运行时间为:

einsum cost: 3.2716003
npsum cost: 6.7865246

还是 einsum 更快,所以哪怕是单个张量求和,numpy 上也可以用 einsum 替代,同样,求均值(mean)、方差(var)、标准差(std)也是一样。

接下来测试 einsum 与 dot 函数,首先列一下矩阵乘法的公式以以及 einsum表达式:

然后是测试代码:

a = np.random.rand(2024, 2024)
b = np.random.rand(2024, 2024) # einsum与dot比较
def einsum():   res = np.einsum('ik,kj->ij', a, b)    def dot():  res = np.dot(a, b) print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(20))
print("dot cost:", Timer("dot()", "from __main__ import dot").timeit(20)) # einsum cost: 80.2403851
# dot cost: 2.0842243

这就很尴尬了,比 dot 慢了 40 倍(并且差距随着矩阵规模的平方增加),这还怎么打天下?不过在 numpy 的实现里,einsum 是可以进行优化的,去掉不必要的中间结果,减少不必要的转置、变形等等,可以提升很大的性能,将 einsum 的实现改一下:

def einsum():  res = np.einsum('ik,kj->ij', a, b, optimize=True)

加了一个参数 optimize=True,官方文档上该参数是可选参数,接受4个值:

optimize 默认为 False,如果设为 True,这默认选择‘greedy(贪心)’方式,再看看速度:

einsum cost: 2.0330937
dot cost: 1.9866218

可以看到,通过优化,虽然还是稍慢一些,但是 einsum 的速度与 dot 达到了一个量级;不过 numpy 官方手册上有个 einsum_path,说是可以进一步提升速度,但是我在自己电脑上(i7-9750H)测试效果并不稳定,这里简单的介绍一下该函数的用法为:

path = np.einsum_path('ik,kj->ij', a, b)[0]
np.einsum('ik,kj->ij', a, b, optimize=path)

einsum_path 返回一个 einsum 可使用的优化路径列表,一般使用第一个优化路径;另外,optimize 及 einsum_path 函数只有 numpy 实现了, tensorflow 和 pytorch 上至少现在没有。

最后,再测试 einsum 与另一个常用的函数 tensordot,首先定义两个四维张量的及 tensordot 函数:

a = np.random.rand(128, 128, 64, 64)
b = np.random.rand(128, 128, 64, 64)   def tensordot():    res = np.tensordot(a, b, ([0, 1], [0, 1]))

该实现对应的公式为:

所以 einsum 函数的实现为:

def einsum():    res = np.einsum('ijkl,ijmn->klmn', a, b, optimize=True)

tensordot 也是链接到 BLAS 实现的函数,所以不加 optimize 肯定比不了,最后结果为:

print("einsum cost:", Timer("einsum()", "from __main__ import einsum").timeit(1))
print("tensordot cost:", Timer("tensordot()", "from __main__ import tensordot").timeit(1))    # einsum cost: 4.2361331
# tensordot cost: 4.2580409

测试了 10 多次,基本上速度一样,einsum 表现好一点的;不过说是一个函数打天下,肯定是做不到的,还有一些数组的分割、合并、指数、对数等功能没法实现,需要使用别的函数,其他的基本都可以用 einsum 来实现,简单而又高效。

经过进一步测试发现,优化反而出现速度降低的情况,例如:

def einsum():    temp = einsum('...->', a, optimize=True) def test(): temp = a.sum()

上面两中对数组求和的方法,当a是一维向量时,或者 a 是多维但是规模很小是,优化的 einsum 反而更慢,但是去掉 optimize 参数后表现比内置的 sum函数稍好,我认为优化是有一个固定的成本。

还有一个坑需要注意的是,有些情况的省略号不加 optimize 会报错,就拿上面的栗子而言:

np.einsum('...->', a, optimize=True)   # 正常运行
np.einsum('...->', a)   # 报错

很无奈,试了很多次,不加 optimize 就是会报错,但是并不是所有的省略号写法都需要加 optimize ,例如:

使用省略号实现上面两个公式并不需要加 optimize ,能够正常运行

np.einsum('i...->...', a)   # 正常
np.einsum('...,...->...', a, b)   # 正常

但是如果碰到下面的公式:

上式表示将 a 除第一个维度之外,剩下的维度全部累加,这种实现就必须要加 optimize。

再举一个栗子:

c = (a * b).sum()
# 如果不知道a, b的维数,使用einsum实现上面的功能也必须要加optimize
c = einsum('...,...->', a, b, optimize=True)

总结一下,在计算量很小时,优化因为有一定的成本,所以速度会慢一些;但是,既然计算量小,慢一点又怎样呢,而且使用优化之后,可以更加肆意的使用省略号写表达式,变量的维数也不用考虑了,所以建议无脑使用优化。

原文链接:

https://zhuanlan.zhihu.com/p/71639781

(*本文为AI科技大本营转载文章,转载请联系作者)

福利时刻

距离大会参与通道关闭还有 1 天,扫描下方二维码或点击阅读原文,马上参与!(学生票特享 598 元,团购票每人立减优惠,倒计时 1 天!)

推荐阅读

  • 从垃圾分类到千行百业,如何打响AI“落地战”?

  • 2亿日活,日均千万级视频上传,快手推荐系统如何应对技术挑战

  • 在图数据上做机器学习,应该从哪个点切入?

  • Docker容器化部署Python应用

  • AI 假冒老板骗取 24.3 万美元

  • 编程吸金榜:你排第几?网友神回应了!

  • 吴子宁:手握 280 多项专利的斯坦福技术先锋 | 人物志

  • 阿里云 CDN 业务基于边缘容器的云原生转型实践

你点的每个“在看”,我都认真当成了喜欢

einsum,一个函数走天下相关推荐

  1. 如何一个模型走天下?集成训练多数据集,打造通用目标检测模型方法详解

    在目标检测的实际应用中,常常会出现需要泛化的目标检测系统的情况.如城市安防中,需要目标检测系统能够检测足够多类别的目标,才能达到更好的安防效果. 但目前常用的目标检测数据集中包含的类别数量有限,使用单 ...

  2. 【5min+】 一个令牌走天下!.Net Core中的ChangeToken

    系列介绍 [五分钟的dotnet]是一个利用您的碎片化时间来学习和丰富.net知识的博文系列.它所包含了.net体系中可能会涉及到的方方面面,比如C#的小细节,AspnetCore,微服务中的.net ...

  3. 小米手环模拟门禁卡读卡失败_一个手环走天下?可以!

    目标:将门禁卡.考勤卡.会员卡.停车卡.电梯卡等等各种卡模拟进手机里,模拟后可用手机代替刷卡,无需root,不用电脑 背景介绍: 1. 前言 目前,IC卡已被广泛应用于身份识别.金融消费.安全认证等领 ...

  4. 一个函数打天下,einsum

    作者丨永远在你身后@知乎 来源丨https://zhuanlan.zhihu.com/p/71639781 编辑丨极市平台 einsum全称Einstein summation convention( ...

  5. 请编写一个函数,计算n*m的棋盘格子(n为横向的格子数,m为竖向的格子数)沿着各自边缘线从左上角走到右下角,总共有多少种走法,要求不能走回头路,即:只能往右和向下走,不能往左和往上走。

    请编写一个函数,计算n*m的棋盘格子(n为横向的格子数,m为竖向的格子数)沿着各自边缘线从左上角走到右下角,总共有多少种走法,要求不能走回头路,即:只能往右和向下走,不能往左和往上走. 递归实现: # ...

  6. if-else走天下,让CPU分支预测技术浮出水面

    关键字 圈复杂度 CPU分支预测机制 指令 吞吐量 IPS-每秒指令 GIPS-每秒十亿指令 延迟-皮秒 分支预测 if-else走天下 圈复杂度 void sort(int *A) { int i ...

  7. 进入编译器后,一个函数经历了什么?

    我是一个函数 我是一个函数,名叫str_upper,我可以把输入的字符串从小写变成大写.不信你看,我长这样: char* str_upper(char* str, int len) {char upp ...

  8. Long类型传到前端失去精度(2):Long类型不是实体类的某一个字段,Long类型是一个函数的返回值

    Long类型传到前端失去精度(2):Long类型不是实体类的某一个字段,Long类型是一个函数的返回值 又是转换Mybatis-Plus的一天,又遇到了之前熟悉的问题:Long类型传到前端失去精度.可 ...

  9. python 如何判断一个函数执行完成_Python核心编程的四大神兽迭代器、生成器 、闭包以及装饰器...

    本文将主要分为4大部分,分别介绍Python核心编程中的迭代器.生成器 .闭包以及装饰器. 生成器 生成器是生成一个值的特殊函数,它具有这样的特点:第一次执行该函数时,先从头按顺序执行,在碰到yiel ...

最新文章

  1. 【WebAPI No.5】Core WebAPI中的自定义格式化
  2. cdoj 1070 秋实大哥打游戏 带权并查集
  3. JSTL-EL表达式
  4. 一般图带权多重匹配(欧拉图+最小费用流)
  5. IOS-网络(监听网络状态)
  6. 【华为大咖分享】5.交付在云端-全云DevOps研发实践(后附PPT下载地址)
  7. oracle中PLSQL存储过程中如何使用逗号分隔的集合(逗号分隔字符串转换为一个集合)...
  8. VMware workstation 磁盘扩容
  9. 封装一个时间百分比多个数比较
  10. 动易sitefactory 3.0 模板标签系统
  11. 兆骑科创平台创新创业赛事路演,投融资服务
  12. 利用STM32F103精确控制步进电机
  13. javascript基础常识问答(五)
  14. 在浏览器输入URL,按下回车之后的流程?
  15. shell脚本:删除文本中的字母、找单词、筛选,匹配,删除,替换
  16. MUR6060PT-ASEMI快恢复二极管MUR6060PT
  17. 1985年全国计算机编程大赛,关于举办2021年“中国高校计算机大赛-团体程序设计天梯赛”校内选拔赛的通知...
  18. Docker  入门
  19. 2023 源支付码支付系统源码v3.0 二开修复版 全本地化
  20. 《抽样技术》第1章 绪论

热门文章

  1. java/android 设计模式学习笔记(1)--- 单例模式
  2. css中的垂直居中方法
  3. 谁登录了你的linux
  4. 杨学海:跨境电商新通道-进口保税直邮模式解析
  5. HTML Inspector – 帮助你编写高质量的 HTML 代码
  6. django 中文乱码或不识别
  7. 使用HttpClient实现跨服务图片下载
  8. linux差分备份,完全和差分备份的自动化模型
  9. 万能头文件#include<bits/stdc++.h>更新GCC10.2.0版本
  10. 数据库管理工具dbeaver