torch.matmul() 详解
最近在准备做 HW04,在读 transformer 的源码的时候发现 attention score 的 torch.matmul() 的奇妙设置,故有此篇文章进行分享。
前言碎碎念:
一开始我以为 torch.matmul 所做的工作就是简单的矩阵相乘,即:假设我们有两个矩阵
A
和B
,它们的 size 分别为(m, n)
和(n, p)
,那么 A x B 的 size 为 (m, p)。然后我看了眼官方文档的例子:tensor1 = torch.randn(10, 3, 4) tensor2 = torch.randn(10, 4, 5) torch.matmul(tensor1, tensor2).size()
>> torch.Size([10, 3, 5])
大大的问号冒了出来 : ),这也能乘?
文章的代码文件:notebook 代码
文章目录
- 前期工作
- input_d = other_d = 1(两个 Tensor 皆为 1 维)
- input_d = other_d = 2 (两个 Tensor 皆为 2 维)
- input_d = 1, other_d = 2
- input_d = 2, other_d = 1
- input_d > 2 or other_d > 2
- input_d > 2 and other_d = 2
- input_d > 2 and other_d = 1
- input_d > 2 and other_d >2 (多维 Tensor)
- 拓展阅读
下面结合官方文档提供一些例子给大家理解。
torch
.matmul
(input, other, ***, out=None) → Tensor两个张量的矩阵乘积,具体行为取决于张量的维度,如下所示。
这里为了描述方便,用 input_d
和 other_d
分别指代 input.dim()
和 other.dim()
,使用 torch.randint()
替代 torch.randn()
方便印证。
前期工作
import torch# 固定 torch 的随机数种子,以便重现结果
torch.manual_seed(0)# 打印信息
def print_info(A, B):print(f"A: {A}\nB: {B}")print(f"A 的维度: {A.dim()},\t B 的维度: {B.dim()}")print(f"A 的元素总数: {A.numel()},\t B 的元素总数: {B.numel()}")print(f"torch.matmul(A, B): {torch.matmul(A, B)}")print(f"torch.matmul(A, B).size(): {torch.matmul(A, B).size()}")
input_d = other_d = 1(两个 Tensor 皆为 1 维)
此时就是我们常说的点积(dot product)
,返回标量。注意,这里是维度为 1,而不是元素总数。
A = torch.randint(0, 5, size=(2,))
B = torch.randint(0, 5, size=(2,))print_info(A, B)
>> A: tensor([4, 4])
>> B: tensor([3, 0])
>> A 的维度: 1, B 的维度: 1
>> A 的元素总数: 2, B 的元素总数: 2
>> torch.matmul(A, B) = 12
>> torch.matmul(A, B).size() = torch.Size([])
input_d = other_d = 2 (两个 Tensor 皆为 2 维)
返回矩阵乘积的结果。
A = torch.randint(0, 5, size=(2, 1))
B = torch.randint(0, 5, size=(1, 2))print_info(A, B)
>> A: tensor([[3],
>> [4]])
>> B: tensor([[2, 3]])
>> A 的维度: 2, B 的维度: 2
>> A 的元素总数: 2, B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[ 6, 9],
>> [ 8, 12]])
>> torch.matmul(A, B).size() = torch.Size([2, 2])
input_d = 1, other_d = 2
按照广播机制(boardcasting)进行处理,即:从 size 的尾部开始一一比对,如果维度不够,则扩展一维,令初始值为 1 再进行计算。计算完之后移除扩展的维度,用下面的例子来说就是扩展成 (1, 2) 后,(1, 2) * (2, 2) => (1, 2) => (2, )
A = torch.randint(0, 5, size=(2, ))
B = torch.randint(0, 5, size=(2, 2))print_info(A, B)
>> A: tensor([2, 3])
>> B: tensor([[1, 1],
>> [1, 4]])
>> A 的维度: 1, B 的维度: 2
>> A 的元素总数: 2, B 的元素总数: 4
>> torch.matmul(A, B) = tensor([ 5, 14])
>> torch.matmul(A, B).size() = torch.Size([2])
input_d = 2, other_d = 1
返回矩阵与向量的乘积。
# 这里使用上一次的矩阵和向量,方便对照
print_info(B, A)
>> A: tensor([[1, 1],
>> [1, 4]])
>> B: tensor([2, 3])
>> A 的维度: 2, B 的维度: 1
>> A 的元素总数: 4, B 的元素总数: 2
>> torch.matmul(A, B) = tensor([ 5, 14])
>> torch.matmul(A, B).size() = torch.Size([2])
input_d > 2 or other_d > 2
以 input_d > 2 为例,维度不匹配就通过广播机制扩展,最后结果上删除掉扩展的维度。
个人理解:对于 dim >= 2 的 tensor 来说最后两维被看作矩阵的行和列,其余(如果存在)被看作 batch。
对于非矩阵(non-matrix)维度也是进行广播处理的,以 A.size() = (j, 1, m, n) 和 B.size() =(k, n, m) 为例,j x 1 和 k 是非矩阵维度,也就是 batch 维度,torch.matmul(A, B).size() = (j, k, m, m)。
input_d > 2 and other_d = 2
矩阵部分:(1, 2) * (2, 1)
A = torch.randint(0, 5, size=(2, 1, 2))
B = torch.randint(0, 5, size=(2, 1))print_info(A, B)
>> A: tensor([[[3, 1]],
>>
>> [[1, 3]]])
>> B: tensor([[4],
>> [3]])
>> A 的维度: 3, B 的维度: 2
>> A 的元素总数: 4, B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[[15]],
>>
>> [[13]]])
input_d > 2 and other_d = 1
这里可以看成单拎出 A 的最后 2 维与 B 做 input_d = 2 和 other_d = 1 的乘法:(1, 2) * (2, ),具体细节可以回看上面对应的部分。
A = torch.randint(0, 5, size=(2, 1, 2))
B = torch.randint(0, 5, size=(2, ))print_info(A, B)
>> A: tensor([[[1, 4]],
>>
>> [[1, 4]]])
>> B: tensor([4, 1])
>> A 的维度: 3, B 的维度: 1
>> A 的元素总数: 4, B 的元素总数: 2
>> torch.matmul(A, B) = tensor([[8],
>> [8]])
>> torch.matmul(A, B).size() = torch.Size([2, 1])
input_d > 2 and other_d >2 (多维 Tensor)
广播部分:(2, 1, *, *) => (2, 2, *, *)。矩阵部分:(2, 1) * (1, 2)
A = torch.randint(0, 5, size=(2, 1, 2, 1))
B = torch.randint(0, 5, size=(2, 1, 2))print_info(A, B)
>> A: tensor([[[[4],
>> [4]]],
>>
>>
>> [[[4],
>> [0]]]])
>> B: tensor([[[1, 2]],
>>
>> [[3, 0]]])
>> A 的维度: 4, B 的维度: 3
>> A 的元素总数: 4, B 的元素总数: 4
>> torch.matmul(A, B) = tensor([[[[ 4, 8],
>> [ 4, 8]],
>>
>> [[12, 0],
>> [12, 0]]],
>>
>>
>> [[[ 4, 8],
>> [ 0, 0]],
>>
>> [[12, 0],
>> [ 0, 0]]]])
>> torch.matmul(A, B).size() = torch.Size([2, 2, 2, 2])
在往下翻之前不妨思考一下 torch.matmul(B, A).size() 等于多少。
print_info(B, A)
>> A: tensor([[[1, 2]],
>>
>> [[3, 0]]])
>> B: tensor([[[[4],
>> [4]]],
>>
>>
>> [[[4],
>> [0]]]])
>> A 的维度: 3, B 的维度: 4
>> A 的元素总数: 4, B 的元素总数: 4
>> torch.matmul(A, B) = tensor([[[[12]],
>>
>> [[12]]],
>>
>>
>> [[[ 4]],
>>
>> [[12]]]])
>> torch.matmul(A, B).size() = torch.Size([2, 2, 1, 1])
拓展阅读
Broadcasting
torch.matmul() 详解相关推荐
- torch.unsqueeze和 torch.squeeze() 详解
1. torch.unsqueeze 详解 torch.unsqueeze(input, dim, out=None) 作用:扩展维度 返回一个新的张量,对输入的既定位置插入维度 1 注意: 返回张量 ...
- pytorch稀疏张量模块torch.sparse详解
torch.sparse是一个专门处理稀疏张量的模块.通常,张量会按一定的顺序连续地进行存取.但是,对于一个存在很多空值的稀疏张量来说,顺序存储的效率显得较为低下.因此,pytorch推出了稀疏张 ...
- torch.roll() 详解
torch.roll(input, shifts, dims=None) input (Tensor) – the input tensor. shifts (int or tuple of pyth ...
- pytorch拼接函数:torch.stack()和torch.cat()--详解及例子
原文链接: https://blog.csdn.net/xinjieyuan/article/details/105205326 https://blog.csdn.net/xinjieyuan/ar ...
- 【PyTorch系例】torch.Tensor详解和常用操作
学习教材: 动手学深度学习 PYTORCH 版(DEMO) (https://github.com/ShusenTang/Dive-into-DL-PyTorch) PDF 制作by [Marcus ...
- Pytorch中, torch.einsum详解。
爱因斯坦简记法:是一种由爱因斯坦提出的,对向量.矩阵.张量的求和运算的求和简记法. 在该简记法当中,省略掉的部分是:1)求和符号与2)求和号的下标 省略规则为:默认成对出现的下标(如下例1中的i和例2 ...
- torch unsqueeze()详解
Torch官网解释: torch.unsqueeze(input, dim) → Tensor Returns a new tensor with a dimension of size one in ...
- 【torch.argmax与torch.max详解】
Pytorch常用函数 一.torch.max 1.调用方式 2.相关介绍 3.代码实例及图示理解 二.torch.argmax 1.调用方式 2.相关介绍 3.代码实例及图示理解 三.torch.m ...
- torch.load() 、torch.load_state_dict() 详解
最新文章
- HDU 3966 Aragorn's Story (树链点权剖分,成段修改单点查询)
- BZOJ1226 SDOI2009学校食堂(状压dp)
- 年轻人,你为什么来阿里做技术?
- Android开发之和风天气篇:1、获取天气信息
- 全向轮机器人应用平台
- javahost:使用虚拟DNS省掉开发环境配置hosts文件
- three.js 学习1
- 计算机图片文档怎么着,【电脑知识】怎样将图片转换成word文档
- IDEA 顶部导航栏(Main Menu)不见了怎么办?
- 硬件工程师的真实前途我说出来可能你们不信
- ubuntu hashcat 安装
- 姜小白的Python日记Day13 jason序列化与开发规范
- Windows HANDLE是什么
- python 管理 交换机_用python 脚本控制telnet登录交换机
- element table 表格实现上移、下移
- 用算符优先法对算术表达式求值(六)
- 计算机三级网络技术笔记
- springboot以FTP方式上传文件到远程服务器
- 如何给电脑桌面进行壁纸更换
- 蓝绿发布、红黑发布、灰度发布你都分得清吗
热门文章
- 打官司证人证言有用吗?
- instagram akp_如何备份您的社交媒体帐户-Facebook,Twitter,Google +和Instagram
- jieba读取txt文档并进行分词、词频统计,输出词云图
- 传入神经和传出神经图片,神经网络图片预处理
- 华为又来黑马招聘了!薪资待遇令人眼馋!
- vm12制作的centos8虚拟机上搭建k8s一次实践
- 微信小程序自定义类似微信联系人组件
- Linux中rz -y命令和rz -E命令的区别
- java程序员竞赛_广东省Java程序员竞赛
- 父类卡子类卡java,Java 问答:终极父类(三)——finalize()和 getClass()