torch.mul() 、 torch.mm() 及torch.matmul()的区别

一、简介

  • torch.mul(a, b) 是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵;
  • torch.mm(a, b) 是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵。
  • torch.bmm() 强制规定维度和大小相同
  • torch.matmul() 没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作

二、具体使用

1、torch.mul(a, b)和torch.mm(a, b)

举例

import torcha = torch.rand(3, 4)
b = torch.rand(3, 4)
c = torch.rand(4, 5)print(torch.mul(a, b).size())  # 返回 1*2 的tensor
print(torch.mm(a, c).size())   # 返回 1*3 的tensor
print(torch.mul(a, c).size())  # 由于a、b维度不同,报错

输出

torch.Size([3, 4])
torch.Size([3, 5])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-27-aea68cb5481f> in <module>7 print(torch.mul(a, b).size())  # 返回 1*2 的tensor8 print(torch.mm(a, c).size())   # 返回 1*3 的tensor
----> 9 print(torch.mul(a, c).size())  # 由于a、b维度不同,报错
RuntimeError: The size of tensor a (4) must match the size of tensor b (5) at non-singleton dimension 1

2、torch.bmm()

参考:https://pytorch.org/docs/stable/torch.html#torch.bmm

torch.bmm(input, mat2, out=None) → Tensor
torch.bmm()是tensor中的一个相乘操作,类似于矩阵中的A*B。

参数:

input,mat2:两个要进行相乘的tensor结构,两者必须是3D维度的,每个维度中的大小是相同的。
output:输出结果

并且相乘的两个矩阵,要满足一定的维度要求:input(p,m,n) * mat2(p,n,a) ->output(p,m,a)。这个要求,可以类比于矩阵相乘。前一个矩阵的列等于后面矩阵的行才可以相乘。

举例

import torch
x = torch.rand(2,4,5)
y = torch.rand(2,5,7)
print(torch.bmm(x,y).size())

输出

torch.Size([2, 4, 7])

3、torch.matmul()

torch.matmul()也是一种类似于矩阵相乘操作的tensor联乘操作。但是它可以利用python 中的广播机制,处理一些维度不同的tensor结构进行相乘操作。这也是该函数与torch.bmm()区别所在。

参数:

input,other:两个要进行操作的tensor结构

output:结果

一些规则约定:

(1)若两个都是1D(向量)的,则返回两个向量的点积

(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D

(3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系。

输入

import torch
x = torch.rand(5) #1D
x1 = x.view(1,-1)
y = torch.rand(5,3) #2Dprint(x1.size())
print(x.size())
print(y.size())
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
print(torch.matmul(x1,y),'\n',torch.matmul(x,y).size())

输出

torch.Size([1, 5])
torch.Size([5])
torch.Size([5, 3])
tensor([1.5374, 1.3291, 1.8289]) torch.Size([3])
tensor([[1.5374, 1.3291, 1.8289]]) torch.Size([3])

(4)若input是2D,other是1D,则返回两者的点积结果。(个人觉得这块也可以理解成给other添加了维度,然后再去掉此维度,只不过维度是(3, )而不是规则(3)中的( ,4)了,但是可能就是因为内部机制不同,所以官方说的是点积而不是维度的升高和下降)

举例

import torch
x = torch.rand(3) #1D
x1 = x.view(-1,1)
y = torch.rand(5,3) #2Dprint(x1.size())
print(x.size())
print(y.size())
print(torch.matmul(y,x),'\n',torch.matmul(y,x).size())
print(torch.matmul(y,x1),'\n',torch.matmul(y,x1).size())

输出

torch.Size([3, 1])
torch.Size([3])
torch.Size([5, 3])
tensor([0.6472, 0.7025, 0.2358, 0.2873, 0.5696]) torch.Size([5])
tensor([[0.6472],[0.7025],[0.2358],[0.2873],[0.5696]]) torch.Size([5, 1])

(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)

  • (a)若input是1D,other是大于2D的,则类似于规则(3)
  • (b)若other是1D,input是大于2D的,则类似于规则(4)
  • (c)若input和other都是3D的,则与torch.bmm()函数功能一样
  • (d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)* other (k,m,p) = output(j,k,n,p)

言而总之,总而言之:matmul()根据输入矩阵自动决定如何相乘。低维根据高维需求,合理广播。
参考文献:https://www.jianshu.com/p/e277f7fc67b3

Pytorch矩阵乘法之torch.mul() 、 torch.mm() 及torch.matmul()的区别相关推荐

  1. pytorch矩阵乘法mm,bmm

    文章目录 矩阵维度 矩阵乘法 torch.mm torch.bmm torch.matmul 矩阵维度 首先需要确认多维矩阵每个维度的对应含义. a = torch.tensor([[[3.], [1 ...

  2. pytorch矩阵乘法总结

    文章目录 点乘 `torch.mul(a,b)` 二维矩阵乘 `torch.mm(a,b)` 三维矩阵乘 `torch.bmm(`a,b) 高维矩阵乘 `torch.matmul(a,b)` 点乘 t ...

  3. [PyTorch] 矩阵乘法

    参考 『PyTorch』矩阵乘法总结 1. * 两个张量在对应的位置上进行数值相乘. x = torch.randn(2, 2) y = x * x 2. torch.mm() 二维矩阵乘法 x = ...

  4. mm,bmm和matmul的区别

    参考 https://blog.csdn.net/leo_95/article/details/89946318 mm只能是矩阵乘法,即分别为(n×m)和(m×q) bmm是三维张量相乘,即(b×n× ...

  5. torch中乘法整理,*torch.mul()torch.mv()torch.mm()torch.dot()@torch.mutmal()

    目录 *位置乘 torch.mul():数乘 torch.mv():矩阵向量乘法 torch.mm() 矩阵乘法 torch.dot() 点乘积 @操作 torch.matmul() *位置乘 符号* ...

  6. 【Pytorch学习】torch.mm()torch.matmul()和torch.mul()以及torch.spmm()

    目录 1 引言 2 torch.mul(a, b) 3 torch.mm(a, b) 4 torch.matmul() 5 torch.spmm() 参考文献 1 引言   做深度学习过程中免不了使用 ...

  7. pytorch中torch.mul、torch.mm/torch.bmm、torch.matmul的区别

    预备知识:矩阵的各种乘积 三者对比 torch.mul: 两个输入按元素相乘,内积 分两种情况,默认其中一个是tensor 另一个也是tensor,那么两者就要符合broadcasedt的规则 另一个 ...

  8. PyTorch 笔记(05)— Tensor 基本运算(torch.abs、torch.add、torch.clamp、torch.div、torch.mul、torch.pow等)

    1. 函数汇总 Tensor 的基本运算会对 tensor 的每一个元素进行操作,此类操作的输入与输出形状一致,常用操作见下表所示. 对于很多操作,例如 div.mul.pow.fmod.等, PyT ...

  9. pytorch中tensor.mul()和mm()和matmul()

    tensor.mul tensor.mul和tensor * tensor 都是将矩阵的对应位置的元素相乘,因此要求维度相同,点乘 torch.mul(input, other, *, out=Non ...

最新文章

  1. Apache Kafka - Schema Registry
  2. 在图像中绘制基本形状和文字
  3. 连载:阿里巴巴大数据实践—实时技术
  4. JQuery中.css()与.addClass()设置样式的区别
  5. UnityShader26:运动模糊
  6. 关于transmission下载速度提升的小建议
  7. Java实现List集合去重的5种方式
  8. 美式英语口语中连读、略读,音变的技巧
  9. pytorch criterion踩坑小结
  10. 古墓丽影10linux,《古墓丽影:崛起》推出Linux系统版:Ubuntu 17.10可玩
  11. Jetson nano使用anaconda 2021-5-15
  12. OpenCV实现照片自动红眼去除
  13. yarn : 无法加载文件 ...Roaming\npm\yarn.ps1,因为在此系统上禁止运行脚本
  14. 我发现智能无人机课程里面讲了无人机建模这方面的理论知识
  15. Halo2学习笔记——背景资料之Elliptic curves(5)
  16. chrome无法打开无痕模式的解决方案
  17. SpringBoot+Vue实现前后端分离的小而学在线考试系统
  18. MATLAB线性回归实例 平炉炼钢
  19. 上海科技大学计算机浙江分数线,上海科技大学2020录取分数线 上海科技大学录取分数线各省汇总...
  20. jshell(jshell打不了汉字)

热门文章

  1. wangyin 一种新的操作系统设计
  2. 注册表包含了计算机哪些信息,注册表大概主要有哪些内容,分别有什么用?MSDN能找到其内容吗?...
  3. Mysql数据库学习笔记(1.创建数据表)ubuntu18.04
  4. IPC\DVS\DVR\NVR|XVR
  5. hbase shell命令2
  6. 伽利略全球卫星定位导航系统与GPS
  7. C/C++程序员转行人工智能
  8. 私域流量是什么?私域流量是如何运营的?
  9. 将jar包发布到本地maven仓库
  10. 最新 GitHub 上传本地项目代码 (main) (2022 更新)