最近在准备做 HW04,在读 transformer 的源码的时候发现 attention score 的 torch.matmul() 的奇妙设置,故有此篇文章进行分享。

前言碎碎念:

一开始我以为 torch.matmul 所做的工作就是简单的矩阵相乘,即:假设我们有两个矩阵 AB,它们的 size 分别为 (m, n)(n, p),那么 A x B 的 size 为 (m, p)。然后我看了眼官方文档的例子:

tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
>> torch.Size([10, 3, 5])

大大的问号冒了出来 : ),这也能乘?

文章的代码文件:notebook 代码

文章目录

  • 前期工作
  • input_d = other_d = 1(两个 Tensor 皆为 1 维)
  • input_d = other_d = 2 (两个 Tensor 皆为 2 维)
  • input_d = 1, other_d = 2
  • input_d = 2, other_d = 1
  • input_d > 2 or other_d > 2
    • input_d > 2 and other_d = 2
    • input_d > 2 and other_d = 1
    • input_d > 2 and other_d >2 (多维 Tensor)
  • 拓展阅读

下面结合官方文档提供一些例子给大家理解。

torch.matmul(input, other, ***, out=None) → Tensor

两个张量的矩阵乘积,具体行为取决于张量的维度,如下所示。

这里为了描述方便,用 input_dother_d 分别指代 input.dim()other.dim(),使用 torch.randint() 替代 torch.randn() 方便印证。

前期工作

import torch# 固定 torch 的随机数种子,以便重现结果
torch.manual_seed(0)# 打印信息
def print_info(A, B):print(f"A: {A}\nB: {B}")print(f"A 的维度: {A.dim()},\t B 的维度: {B.dim()}")print(f"A 的元素总数: {A.numel()},\t B 的元素总数: {B.numel()}")print(f"torch.matmul(A, B): {torch.matmul(A, B)}")print(f"torch.matmul(A, B).size(): {torch.matmul(A, B).size()}")

input_d = other_d = 1(两个 Tensor 皆为 1 维)

此时就是我们常说的点积(dot product),返回标量。注意,这里是维度为 1,而不是元素总数。

A = torch.randint(0, 5, size=(2,))
B = torch.randint(0, 5, size=(2,))print_info(A, B)
>> A: tensor([4, 4])
>> B: tensor([3, 0])
>> A 的维度: 1,   B 的维度: 1
>> A 的元素总数: 2,     B 的元素总数: 2
>> torch.matmul(A, B) = 12
>> torch.matmul(A, B).size() = torch.Size([])

input_d = other_d = 2 (两个 Tensor 皆为 2 维)

返回矩阵乘积的结果。

A = torch.randint(0, 5, size=(2, 1))
B = torch.randint(0, 5, size=(1, 2))print_info(A, B)
>> A: tensor([[3],
>>         [4]])
>> B: tensor([[2, 3]])
>> A 的维度: 2,   B 的维度: 2
>> A 的元素总数: 2,     B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[ 6,  9],
>>         [ 8, 12]])
>> torch.matmul(A, B).size() = torch.Size([2, 2])

input_d = 1, other_d = 2

按照广播机制(boardcasting)进行处理,即:从 size 的尾部开始一一比对,如果维度不够,则扩展一维,令初始值为 1 再进行计算。计算完之后移除扩展的维度,用下面的例子来说就是扩展成 (1, 2) 后,(1, 2) * (2, 2) => (1, 2) => (2, )

A = torch.randint(0, 5, size=(2, ))
B = torch.randint(0, 5, size=(2, 2))print_info(A, B)
>> A: tensor([2, 3])
>> B: tensor([[1, 1],
>>         [1, 4]])
>> A 的维度: 1,   B 的维度: 2
>> A 的元素总数: 2,     B 的元素总数: 4
>> torch.matmul(A, B) = tensor([ 5, 14])
>> torch.matmul(A, B).size() = torch.Size([2])

input_d = 2, other_d = 1

返回矩阵与向量的乘积。

# 这里使用上一次的矩阵和向量,方便对照
print_info(B, A)
>> A: tensor([[1, 1],
>>         [1, 4]])
>> B: tensor([2, 3])
>> A 的维度: 2,   B 的维度: 1
>> A 的元素总数: 4,     B 的元素总数: 2
>> torch.matmul(A, B) = tensor([ 5, 14])
>> torch.matmul(A, B).size() = torch.Size([2])

input_d > 2 or other_d > 2

以 input_d > 2 为例,维度不匹配就通过广播机制扩展,最后结果上删除掉扩展的维度。

个人理解:对于 dim >= 2 的 tensor 来说最后两维被看作矩阵的行和列,其余(如果存在)被看作 batch。

对于非矩阵(non-matrix)维度也是进行广播处理的,以 A.size() = (j, 1, m, n) 和 B.size() =(k, n, m) 为例,j x 1 和 k 是非矩阵维度,也就是 batch 维度,torch.matmul(A, B).size() = (j, k, m, m)。

input_d > 2 and other_d = 2

矩阵部分:(1, 2) * (2, 1)

A = torch.randint(0, 5, size=(2, 1, 2))
B = torch.randint(0, 5, size=(2, 1))print_info(A, B)
>> A: tensor([[[3, 1]],
>>
>>         [[1, 3]]])
>> B: tensor([[4],
>>         [3]])
>> A 的维度: 3,   B 的维度: 2
>> A 的元素总数: 4,     B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[[15]],
>>
>>         [[13]]])

input_d > 2 and other_d = 1

这里可以看成单拎出 A 的最后 2 维与 B 做 input_d = 2 和 other_d = 1 的乘法:(1, 2) * (2, ),具体细节可以回看上面对应的部分。

A = torch.randint(0, 5, size=(2, 1, 2))
B = torch.randint(0, 5, size=(2, ))print_info(A, B)
>> A: tensor([[[1, 4]],
>>
>>         [[1, 4]]])
>> B: tensor([4, 1])
>> A 的维度: 3,   B 的维度: 1
>> A 的元素总数: 4,     B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[8],
>>         [8]])
>> torch.matmul(A, B).size() = torch.Size([2, 1])

input_d > 2 and other_d >2 (多维 Tensor)

广播部分:(2, 1, *, *) => (2, 2, *, *)。矩阵部分:(2, 1) * (1, 2)

A = torch.randint(0, 5, size=(2, 1, 2, 1))
B = torch.randint(0, 5, size=(2, 1, 2))print_info(A, B)
>> A: tensor([[[[4],
>>           [4]]],
>>
>>
>>         [[[4],
>>           [0]]]])
>> B: tensor([[[1, 2]],
>>
>>         [[3, 0]]])
>> A 的维度: 4,   B 的维度: 3
>> A 的元素总数: 4,     B 的元素总数: 4
>> torch.matmul(A, B) = tensor([[[[ 4,  8],
>>           [ 4,  8]],
>>
>>          [[12,  0],
>>           [12,  0]]],
>>
>>
>>         [[[ 4,  8],
>>           [ 0,  0]],
>>
>>          [[12,  0],
>>           [ 0,  0]]]])
>> torch.matmul(A, B).size() = torch.Size([2, 2, 2, 2])

在往下翻之前不妨思考一下 torch.matmul(B, A).size() 等于多少。

print_info(B, A)
>> A: tensor([[[1, 2]],
>>
>>         [[3, 0]]])
>> B: tensor([[[[4],
>>           [4]]],
>>
>>
>>         [[[4],
>>           [0]]]])
>> A 的维度: 3,   B 的维度: 4
>> A 的元素总数: 4,     B 的元素总数: 4
>> torch.matmul(A, B) = tensor([[[[12]],
>>
>>          [[12]]],
>>
>>
>>         [[[ 4]],
>>
>>          [[12]]]])
>> torch.matmul(A, B).size() = torch.Size([2, 2, 1, 1])

拓展阅读

Broadcasting

torch.matmul() 详解相关推荐

  1. torch.unsqueeze和 torch.squeeze() 详解

    1. torch.unsqueeze 详解 torch.unsqueeze(input, dim, out=None) 作用:扩展维度 返回一个新的张量,对输入的既定位置插入维度 1 注意: 返回张量 ...

  2. pytorch稀疏张量模块torch.sparse详解

      torch.sparse是一个专门处理稀疏张量的模块.通常,张量会按一定的顺序连续地进行存取.但是,对于一个存在很多空值的稀疏张量来说,顺序存储的效率显得较为低下.因此,pytorch推出了稀疏张 ...

  3. torch.roll() 详解

    torch.roll(input, shifts, dims=None) input (Tensor) – the input tensor. shifts (int or tuple of pyth ...

  4. pytorch拼接函数:torch.stack()和torch.cat()--详解及例子

    原文链接: https://blog.csdn.net/xinjieyuan/article/details/105205326 https://blog.csdn.net/xinjieyuan/ar ...

  5. 【PyTorch系例】torch.Tensor详解和常用操作

    学习教材: 动手学深度学习 PYTORCH 版(DEMO) (https://github.com/ShusenTang/Dive-into-DL-PyTorch) PDF 制作by [Marcus ...

  6. Pytorch中, torch.einsum详解。

    爱因斯坦简记法:是一种由爱因斯坦提出的,对向量.矩阵.张量的求和运算的求和简记法. 在该简记法当中,省略掉的部分是:1)求和符号与2)求和号的下标 省略规则为:默认成对出现的下标(如下例1中的i和例2 ...

  7. torch unsqueeze()详解

    Torch官网解释: torch.unsqueeze(input, dim) → Tensor Returns a new tensor with a dimension of size one in ...

  8. 【torch.argmax与torch.max详解】

    Pytorch常用函数 一.torch.max 1.调用方式 2.相关介绍 3.代码实例及图示理解 二.torch.argmax 1.调用方式 2.相关介绍 3.代码实例及图示理解 三.torch.m ...

  9. torch.load() 、torch.load_state_dict() 详解

最新文章

  1. HDU 3966 Aragorn's Story (树链点权剖分,成段修改单点查询)
  2. BZOJ1226 SDOI2009学校食堂(状压dp)
  3. 年轻人,你为什么来阿里做技术?
  4. Android开发之和风天气篇:1、获取天气信息
  5. 全向轮机器人应用平台
  6. javahost:使用虚拟DNS省掉开发环境配置hosts文件
  7. three.js 学习1
  8. 计算机图片文档怎么着,【电脑知识】怎样将图片转换成word文档
  9. IDEA 顶部导航栏(Main Menu)不见了怎么办?
  10. 硬件工程师的真实前途我说出来可能你们不信
  11. ubuntu hashcat 安装
  12. 姜小白的Python日记Day13 jason序列化与开发规范
  13. Windows HANDLE是什么
  14. python 管理 交换机_用python 脚本控制telnet登录交换机
  15. element table 表格实现上移、下移
  16. 用算符优先法对算术表达式求值(六)
  17. 计算机三级网络技术笔记
  18. springboot以FTP方式上传文件到远程服务器
  19. 如何给电脑桌面进行壁纸更换
  20. 蓝绿发布、红黑发布、灰度发布你都分得清吗

热门文章

  1. 打官司证人证言有用吗?
  2. instagram akp_如何备份您的社交媒体帐户-Facebook,Twitter,Google +和Instagram
  3. jieba读取txt文档并进行分词、词频统计,输出词云图
  4. 传入神经和传出神经图片,神经网络图片预处理
  5. 华为又来黑马招聘了!薪资待遇令人眼馋!
  6. vm12制作的centos8虚拟机上搭建k8s一次实践
  7. 微信小程序自定义类似微信联系人组件
  8. Linux中rz -y命令和rz -E命令的区别
  9. java程序员竞赛_广东省Java程序员竞赛
  10. 父类卡子类卡java,Java 问答:终极父类(三)——finalize()和 getClass()