本文主要介绍如何使用Pytorch中的爱因斯坦求和(einsum),掌握einsum的基本用法。

einsum的安装

在安装pytorch的虚拟环境下输入以下命令:

pip install opt_einsum

爱因斯坦求和约定

在数学中,爱因斯坦求和约定是一种标记法,也称为Einstein Summation Convention,在处理关于坐标的方程式时十分有效。简单来说,爱因斯坦求和就是简化掉求和式中的求和符号 ,这样就会使公式更加简洁,如

三条基本规则

einsum实现矩阵乘法的例子如下:

a = torch.randn(2, 3)
b = torch.randn(3, 4)
c = torch.mm(a, b)
d = torch.einsum("ik, kj->ij", [a, b])
print("a:{} \nb:{}".format(a, b))
print("c:{} \nd:{}".format(c, d))# Output:
a:tensor([[ 1.7128,  0.2671, -1.5735],[ 0.6192,  0.0096,  1.3178]])
b:tensor([[ 0.0595, -1.3128,  1.6158,  0.0901],[ 0.9183,  1.2884,  0.6276, -0.3407],[ 1.2795,  1.1721,  0.7161,  1.6859]])
c:tensor([[-1.6661, -3.7489,  1.8083, -2.5894],[ 1.7317,  0.7440,  1.9502,  2.2741]])
d:tensor([[-1.6661, -3.7489,  1.8083, -2.5894],[ 1.7317,  0.7440,  1.9502,  2.2741]])

可以看到,c和d的输出值是一样的,c中比较好理解, torch.mm(mat1, mat2, out=None) 实现的是对矩阵mat1mat2进行相乘。 如果mat1 是一个n×m张量,mat2 是一个 m×p 张量,将会输出一个 n×p 张量out。

那么d中呢?首先来看einsum的API:torch.einsum(equation*operands) → Tensor。

  • 第一个参数为equation,即d中的,它表示了输入张量和输出张量的维度,equation中箭头左边表示输入张量,以逗号来分割每个输入张量,箭头右边则表示输出张量。表示维度的字符只能是26个英文字母,即'a'~'z',这儿用的是i、j、k。
  • 第二个参数为*operands,表示实际输入的张量列表,其数量必须要和equation中的输入张量对应,即箭头左侧有多少个张量,那么你第二个参数的数量就必须有多少个。同时每个张量的子equation的字符个数要与张量的真实维度对应,即本文主要介绍如何使用Pytorch中的爱因斯坦求和(einsum),掌握einsum的基本用法。
equation  中的字符也可以理解为索引,就是输出张量的某个位置的值,是怎么从输入张量中得到的,比如上面矩阵乘法的输出 d 的某个点 d[i, j] 的值是通过 a[i, k] 和 b[i, k] 沿着 k 这个维度做内积得到的。

三条规则:

  • 规则一:equation 箭头左边,在不同输入之间重复出现的索引表示,把输入张量沿着该维度做乘法操作,比如还是以上面矩阵乘法为例, "ik,kj->ij",k 在输入中重复出现,所以就是把 a 和 b 沿着 k 这个维度作相乘操作;
  • 规则二:只出现在 equation 箭头左边的索引,表示中间计算结果需要在这个维度上求和,即求和索引。(求和索引:只出现在箭头左边的索引,表示中间计算结果需要这个维度上求和之后才能得到输出,比如上面的例子就是 k;)
  • 规则三:equation 箭头右边的索引顺序可以是任意的,比如上面的 "ik,kj->ij" 如果写成 "ik,kj->ji",那么就是返回输出结果的转置,用户只需要定义好索引的顺序,转置操作会在 einsum 内部完成。
# 规则三示例x = torch.randn(2, 3)
y = torch.randn(3, 4)m = torch.einsum("ik, kj->ij", x, y)
n = torch.einsum("ik, kj->ji", x, y)
print("a:{} \nb:{}".format(m, n))# Output:
a:tensor([[-1.0836, -0.2650, -1.7384, -0.5368],[ 1.1246, -0.2049,  1.5340,  0.6870]])
b:tensor([[-1.0836,  1.1246],[-0.2650, -0.2049],[-1.7384,  1.5340],[-0.5368,  0.6870]])

特殊规则

  • equation 也可以不写包括箭头在内的右边部分,那么在这种情况下,输出张量的维度会根据默认规则推导。就是把输入中只出现一次的索引取出来,然后按字母表顺序排列,比如上面的矩阵乘法 "ik,kj->ij" 也可以简化为 "ik,kj",根据默认规则,输出就是 "ij" 与原来一样;
  • equation 中支持 "..." 省略号,用于表示用户并不关心的索引,比如只对一个高维张量的最后两维做转置可以这么写:
t = torch.randn(1, 3, 5, 7, 9)
res = torch.einsum('...ij->...ji', t)
print(res.size())# Output:
torch.Size([1, 3, 5, 9, 7])

einsum例子

  • 提取矩阵对角线元素

# 构造一个tensor a
a = torch.arange(9).reshape(3, 3)   # .reshape(3, 3)等价于.view(3, 3)print(a)# 法一:提取矩阵对角线元素
diag1 = torch.einsum('ii->i', a)
print(diag1)# 法二:torch.diagonal(tensor, offset):对tensor取对角线元素,offset为偏移量,0为主对角线,1为主对角线下一个对角线
diag2 = torch.diagonal(a, 0)
print(diag2)# 法三:通过numpy,双重for
out = np.empty((3,), dtype=np.int32)
for i in range(0, 3):sum = 0for inner in range(0, 1):sum += a.numpy()[i, i]out[i] = sum
print(out)
  • 矩阵转置

a = torch.arange(6).view(2, 3)
print("a: ", a)
a_trans1 = torch.einsum('ij->ji', a)# torch.transpose(Tensor,dim0,dim1):transpose()一次只能在两个维度间进行转置
a_trans2 = torch.transpose(a, 0, 1)print("a_trans1:{}\na_trans2:{}".format(a_trans1, a_trans2))
  • permute高维张量转置

# 高维张量转置(两种方法)
b = torch.randn(2, 4, 6, 3, 8)b_trans1 = torch.einsum('...ij->...ji', b)b_trans2 = b.permute(0, 1, 2, 4, 3)
print("shape1:\n{}\nshape2:\n{}".format(b_trans1.shape, b_trans2.size()))# Output:
shape1:
torch.Size([2, 4, 6, 8, 3])
shape2:
torch.Size([2, 4, 6, 8, 3])
  • sum求和

a = torch.arange(6).view(2, 3)# 矩阵所有元素求和
sum1 = torch.einsum('ij->', a)
sum2 = torch.sum(a)
print("a:{}\nsum1:{}, sum2:{}".format(a, sum1, sum2))# Output:
a:tensor([[0, 1, 2],[3, 4, 5]])
sum1:15, sum2:15
  • 按列求和

# 矩阵按列求和
a = torch.arange(6).view(2, 3)sum3 = torch.einsum('ij->j', a)
sum4 = torch.sum(a, dim=0)
print(sum3, sum4)#Output:
tensor([3, 5, 7]) tensor([3, 5, 7])

参考文章:

https://zhuanlan.zhihu.com/p/71639781

一文学会 Pytorch 中的 einsum

Pytorch中的einsum相关推荐

  1. Pytorch中, torch.einsum详解。

    爱因斯坦简记法:是一种由爱因斯坦提出的,对向量.矩阵.张量的求和运算的求和简记法. 在该简记法当中,省略掉的部分是:1)求和符号与2)求和号的下标 省略规则为:默认成对出现的下标(如下例1中的i和例2 ...

  2. PyTorch 中的傅里叶卷积

    欢迎关注 "小白玩转Python",发现更多 "有趣" 注意: 在这个 Github repo 中提供了1D.2D 和3D Fourier 卷积的完整方法.我还 ...

  3. pytorch中调整学习率的lr_scheduler机制

    pytorch中调整学习率的lr_scheduler机制 </h1><div class="clear"></div><div class ...

  4. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  5. PyTorch中的MIT ADE20K数据集的语义分割

    PyTorch中的MIT ADE20K数据集的语义分割 代码地址:https://github.com/CSAILVision/semantic-segmentation-pytorch Semant ...

  6. PyTorch中nn.Module类中__call__方法介绍

    在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[-, Any] = ...

  7. 利用 AssemblyAI 在 PyTorch 中建立端到端的语音识别模型

    作者 | Comet 译者 | 天道酬勤,责编 | Carol 出品 | AI 科技大本营(ID:rgznai100) 这篇文章是由AssemblyAI的机器学习研究工程师Michael Nguyen ...

  8. 实践指南 | 用PyTea检测 PyTorch 中的张量形状错误

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨陈萍.泽南 来源丨机器之心 编辑丨极市平台 导读 韩国首尔大学 ...

  9. 实践教程 | 浅谈 PyTorch 中的 tensor 及使用

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者 | xiaopl@知乎(已授权) 来源 | https://z ...

最新文章

  1. 94年出生,6篇SCI,一作发Science,你还不放下手上玩的泥巴
  2. 公司数据部培训讲义:ArcMap数字化培训教程
  3. 万字总结 MySQL核心知识,赠送25连环炮
  4. 微软IE8浏览器个性化设置技巧
  5. vue main.js 导入文件报错Module build failed: Error: No PostCSS Config found in:
  6. mysql char null_关于mysql设置varchar 字段的默认值''和null的区别,以及varchar和char的区别...
  7. Java集合之Properties
  8. WPF实现背景透明磨砂,并通过HandyControl组件实现弹出等待框
  9. zabbix v3.0安装部署【转】
  10. 西铁院云计算机室与应用,关于开展“云桌面应用”技术服务的通知
  11. python编写时钟代码_python Tkinter 编写时钟
  12. matplotlib 颜色板
  13. 书籍分析实例:哈利波特的分词及人物关系
  14. 【Kettle】date类型不能被excel输出
  15. 安卓手机访问 ubuntu 共享的方法
  16. pdf转换器下载使用步骤
  17. c语言矩阵连乘递归算法,动态规划求解矩阵连乘问题
  18. 从定制 Ghost 镜像聊聊优化 Dockerfile
  19. java参数配置jconsole_jconsole 配置详解
  20. markdown的标题设置自动添加序号

热门文章

  1. 乐高计算机游戏泡泡龙教案,疯狂泡泡龙(400关)
  2. SAP access 破解
  3. 怎样编写java程序
  4. 【spring】依赖注入之@Autowired依赖注入
  5. 小学生灯谜计算机,小学生谜语大全
  6. Activity详解2
  7. Vlan间通信原理(HCIA)
  8. 数据可视化笔记7 网络数据可视化
  9. Base64 编码原理及代码实现
  10. 【读书笔记】【程序员的自我修养 -- 链接、装载与库(三)】函数调用与栈(this指针、返回值传递临时对象构建栈、运行库与多线程、_main函数、系统调用与中断向量表、Win32、可变参数、大小端