文章目录

  • 1、简介
  • 2、torch.mm
  • 3、torch.bmm
  • 4、torch.matmul
  • 5、masked_fill

1、简介

这几天正在看NLP中的注意力机制,代码中涉及到了一些关于张量矩阵乘法和填充一些代码,这里积累一下。主要参考了pytorch2.0的官方文档。
①torch.mm(input,mat2,*,out=None)
②torch.bmm(input,mat2,*,out=None)
③torch.matmul(input, other, *, out=None)
④Tensor.masked_fill

2、torch.mm

torch.mm语法为:

torch.mm(input, mat2, *, out=None) → Tensor

就是矩阵的乘法。如果输入input是(n,m),mat2是(m, p),则输出为(n, p)。
示例:

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)
-->tensor([[ 0.4851,  0.5037, -0.3633],[-0.0760, -3.6705,  2.4784]])

3、torch.bmm

torch.bmm语法为:

torch.bmm(input, mat2, *, out=None) → Tensor
  • 功能:对存储在input和mat2矩阵中的批数量的矩阵进行乘积。
  • 要求:input矩阵和mat2必须是三维的张量,且第一个维度即batch维度必须一样。
  • 举例:如果input是一个(b, n , m)的张量,mat2是一个(b, m, p)张量,则输出形状为(b, n, p)

示例:

input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
res.size()
-->torch.Size([10, 3, 5])

解读:实际上刻画的就是一组矩阵与另一组张量矩阵的乘积,至于一组有多少个矩阵,由input和mat2的第一个输入维度决定,上述代码第一个维度为10,就代表着10个形状为(3, 4)的矩阵与10个形状为(4, 5)的矩阵分别对应相乘,得到10个形状为(3, 5)的矩阵。

4、torch.matmul

torch.matmul语法为:

torch.matmul(input, other, *, out=None) → Tensor

该函数刻画的是两个张量的乘积,且计算过程与张量的维度密切相关。

如果张量是一维的,输出结果是点乘,是一个标量。

a = torch.tensor([1,2,4])
b = torch.tensor([2,5,6])
print(torch.matmul(a, b))
print(a.shape)
--> tensor(36)
-->torch.Size([3])

注意:张量a.shape显示的是torch.Size([3]),只有一个维度,3是指这个维度中有3个数。
如果两个张量都是二维的,执行的是矩阵的乘法。

a = torch.tensor([[1,2,4], [6,2,1]])
b = torch.tensor([[2,5],[1,2],[6,8]
])
print(a.shape)
print(b.shape)
print(torch.matmul(a, b))
-->torch.Size([2, 3])
-->torch.Size([3, 2])
-->tensor([[28, 41],[20, 42]])

由上述示例可知,如果两个张量均为2维,那么其运算和torch.mm是一样的。
如果第一个参数input是1维的,第二个参数是二维的,那么在计算时,在第一个参数前增加一个维度1,计算完毕之后再把这个维度去掉。

a = torch.tensor([1,2,4])
b = torch.tensor([[2,5],[1,2],[6,8]
])print(a.shape)
print(b.shape)
print(torch.matmul(a, b))
-->torch.Size([3])
-->torch.Size([3, 2])
-->tensor([28, 41])

如上所示,a只有一个维度,在进行计算时,变成了(1, 3),则变成了(1, 3)乘以(3, 2),变成(1, 2),最后在去掉1这个维度。
如果第一个参数是2维的,第二个参数是1维的,则返回矩阵-向量乘积。

a = torch.tensor([1,2])
b = torch.tensor([[2,5],[1,2],[6,8]
])print(b.shape)
print(a.shape)
print(torch.matmul(b, a))
-->torch.Size([3, 2])
-->torch.Size([2])
-->tensor([12,  5, 22])

矩阵乘以张量,就是矩阵中的每一行都与这个张量相乘,最终得到一个一维的,大小为3的结果。
⑤多个维度

  • 如果两个参数至少都是1维的,且有一个参数的维度N>2,则返回的是一个批矩阵的乘积(即把多出的那个维度看作batch即可,让每个batch后的矩阵与后面的张量相乘即可)。
  • 如果第一个参数是1维的,则在它的维度前加上1,以便批量矩阵相乘并在之后删除。如果第二个参数是1维的,则将1追加到其维度,用于批处理矩阵倍数,然后删除。
  • 举例:如果input形状是(j,1,n,n),other的张量形状是(k,n,n),那么输出张量的形状将会是(j,k,n,n)。
  • 如果input形状是(j,1,n,m),other的张量形状是(k,m,p),那么输出张量的形状将会是(j,k,n,p)。
tensor1 = torch.randn(10, 3, 4, 5)
tensor2 = torch.randn(5, 4)
torch.matmul(tensor1, tensor2).size()
-->torch.Size([10, 3, 4, 4])
tensor1 = torch.randn(10, 3, 4, 5)
tensor2 = torch.randn(1, 5, 4)
torch.matmul(tensor1, tensor2).size()
-->torch.Size([10, 3, 4, 4])
tensor1 = torch.randn(10, 3, 4, 5)
tensor2 = torch.randn(1, 1, 5, 4)
torch.matmul(tensor1, tensor2).size()
-->torch.Size([10, 3, 4, 4])

仔细比较上述三个代码块,其最终的结果是一样的。可以简单记为如果两个维度不一致的话,多出的维度就看作是batch维,相当于在低维度前面增加一个维度。

5、masked_fill

语法为:

Tensor.masked_fill_(mask, value)

参数:

  • mask(BoolTensor):布尔掩码
  • value(float):用于填充的值。

mask是一个pytorch张量,元素是布尔值,value是要填充的值,填充规则是mask中取值为True的位置对应与需要填充的张量中的位置用value填充。

a = torch.tensor([[0, 8],[ 6, 8],[ 7,  1]
])mask = torch.tensor([[ True, False],[False, False],[False,  True]
])
b = a.masked_fill(mask, -1e9)
print(b)
-->tensor([[-1000000000,           8],[          6,           8],[          7, -1000000000]])

Pytorch教程之torch.mm、torch.bmm、torch.matmul、masked_fill相关推荐

  1. pytorch教程之nn.Module类详解——使用Module类来自定义网络层

    前言:前面介绍了如何自定义一个模型--通过继承nn.Module类来实现,在__init__构造函数中申明各个层的定义,在forward中实现层之间的连接关系,实际上就是前向传播的过程. 事实上,在p ...

  2. pytorch教程之nn.Module类详解——使用Module类来自定义模型

    pytorch教程之nn.Module类详解--使用Module类来自定义模型_MIss-Y的博客-CSDN博客_nn是什么意思前言:pytorch中对于一般的序列模型,直接使用torch.nn.Se ...

  3. 【Pytorch】对比matual,mm和bmm函数

    pytorch中提供了 matmul.mm和bmm等矩阵的乘法运算功能,但其具体计算细节和场景截然不同,应予以注意和区别. 1. torch.mm 该函数即为矩阵的乘法,torch.mm(tensor ...

  4. PyTorch教程之DCGAN

    原文连接:DCGAN TUTORIAL 简介 本教程通过例程来介绍 DCGANs .我们使用名人照片来训练 GAN 网络使其能够生成新的名人. 这里使用的大部分代码都来自pytorch/example ...

  5. 【Pytorch学习】torch.mm()torch.matmul()和torch.mul()以及torch.spmm()

    目录 1 引言 2 torch.mul(a, b) 3 torch.mm(a, b) 4 torch.matmul() 5 torch.spmm() 参考文献 1 引言   做深度学习过程中免不了使用 ...

  6. torch中乘法整理,*torch.mul()torch.mv()torch.mm()torch.dot()@torch.mutmal()

    目录 *位置乘 torch.mul():数乘 torch.mv():矩阵向量乘法 torch.mm() 矩阵乘法 torch.dot() 点乘积 @操作 torch.matmul() *位置乘 符号* ...

  7. pytorch matmul和mm和bmm区别

    pytorch中matmul和mm和bmm区别 matmul mm bmm 结论 先看下官网上对这三个函数的介绍. matmul mm bmm 顾名思义, 就是两个batch矩阵乘法. 结论 从官方文 ...

  8. 【pytorch】torch.mm,torch.bmm以及torch.matmul的使用

    torch.mm torch.mm是两个矩阵相乘,即两个二维的张量相乘 如下面的例子 mat1 = torch.randn(2,3) print("mat1=", mat1)mat ...

  9. pytorch中torch.mul、torch.mm/torch.bmm、torch.matmul的区别

    预备知识:矩阵的各种乘积 三者对比 torch.mul: 两个输入按元素相乘,内积 分两种情况,默认其中一个是tensor 另一个也是tensor,那么两者就要符合broadcasedt的规则 另一个 ...

最新文章

  1. 20个案例详解 Pandas 当中的数据统计分析与排序
  2. if,elif,else的关系 input print int的用法
  3. 关于序列化的 10 几个问题,你顶得住不?
  4. 做正确的事,正确的做事
  5. 用“夜间模式”模式(javascript书签)浏览网页
  6. Java黑皮书课后题第8章:***8.35(最大块)给定一个元素为0或者1的方阵,编写程序,找到一个元素都为1的最大的子方阵。程序提示用户输入矩阵的行数。然后显示最大的子方阵的第一个元素、行数
  7. 启动与停止mysql服务的命令
  8. oracle备份磁盘头,ASM 磁盘头信息备份
  9. MFC源码解读(一)最原始一个MFC程序,手写不用向导
  10. 中国系泊系统行业市场供需与战略研究报告
  11. 完美解决SAMSUNG Mobile USB CDC Composite Device安装失败 三星手机USB驱动失败。
  12. AES攻击方法 :差分密码分析 boomerang attack飞去来器攻击
  13. TortoiseGit状态图标不能正常显示的解决办法
  14. 天猫淘宝越来越难做了,为什么不考虑下跨境电商?
  15. MyBatis 常见面试题有哪些?
  16. 最全量子计算硬件概述(建议收藏)
  17. MIFI与随身wifi、wifi共享软件,玩坏wifi的几种方法
  18. Mysql数据库日常使用备注
  19. UBT8:ubuntu安装Java1.8
  20. 【每日新闻】Gartner:2017年CRM跃升为规模最大、增速最快的软件市场 | 中国科学家发现神奇半导体材料...

热门文章

  1. RFSoC应用笔记 - RF数据转换器 -09- RFSoC关键配置之RF-DAC内部解析(三)
  2. 腾讯 美团 字节 抖音 面经
  3. [置顶] 代码审查工具FxCop建议采用的规则总结
  4. 网页显示不全的原因css,css 页面显示不全怎么办
  5. 基于STM32单片机智能花盆控制系统设计(毕业设计资料)
  6. pycharm安装及添加桌面图标
  7. 2017服务器cpu性能排行,2017年手机处理器排名_CPU排行榜名单
  8. 蔡氏混沌电路matlab程序,蔡氏混沌电路简介——Chuaapos;s-Circut.pptx-全文可读
  9. java8 Stream 使用案例
  10. storyboard(故事版)新手教程 图文详解 2.为无约束的故事版添加约束