Pytorch教程之torch.mm、torch.bmm、torch.matmul、masked_fill
文章目录
- 1、简介
- 2、torch.mm
- 3、torch.bmm
- 4、torch.matmul
- 5、masked_fill
1、简介
这几天正在看NLP中的注意力机制,代码中涉及到了一些关于张量矩阵乘法和填充一些代码,这里积累一下。主要参考了pytorch2.0的官方文档。
①torch.mm(input,mat2,*,out=None)
②torch.bmm(input,mat2,*,out=None)
③torch.matmul(input, other, *, out=None)
④Tensor.masked_fill
2、torch.mm
torch.mm语法为:
torch.mm(input, mat2, *, out=None) → Tensor
就是矩阵的乘法。如果输入input是(n,m),mat2是(m, p),则输出为(n, p)。
示例:
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)
-->tensor([[ 0.4851, 0.5037, -0.3633],[-0.0760, -3.6705, 2.4784]])
3、torch.bmm
torch.bmm语法为:
torch.bmm(input, mat2, *, out=None) → Tensor
- 功能:对存储在input和mat2矩阵中的批数量的矩阵进行乘积。
- 要求:input矩阵和mat2必须是三维的张量,且第一个维度即batch维度必须一样。
- 举例:如果input是一个(b, n , m)的张量,mat2是一个(b, m, p)张量,则输出形状为(b, n, p)
示例:
input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
res.size()
-->torch.Size([10, 3, 5])
解读:实际上刻画的就是一组矩阵与另一组张量矩阵的乘积,至于一组有多少个矩阵,由input和mat2的第一个输入维度决定,上述代码第一个维度为10,就代表着10个形状为(3, 4)的矩阵与10个形状为(4, 5)的矩阵分别对应相乘,得到10个形状为(3, 5)的矩阵。
4、torch.matmul
torch.matmul语法为:
torch.matmul(input, other, *, out=None) → Tensor
该函数刻画的是两个张量的乘积,且计算过程与张量的维度密切相关。
① 如果张量是一维的,输出结果是点乘,是一个标量。
a = torch.tensor([1,2,4])
b = torch.tensor([2,5,6])
print(torch.matmul(a, b))
print(a.shape)
--> tensor(36)
-->torch.Size([3])
注意:张量a.shape显示的是torch.Size([3]),只有一个维度,3是指这个维度中有3个数。
② 如果两个张量都是二维的,执行的是矩阵的乘法。
a = torch.tensor([[1,2,4], [6,2,1]])
b = torch.tensor([[2,5],[1,2],[6,8]
])
print(a.shape)
print(b.shape)
print(torch.matmul(a, b))
-->torch.Size([2, 3])
-->torch.Size([3, 2])
-->tensor([[28, 41],[20, 42]])
由上述示例可知,如果两个张量均为2维,那么其运算和torch.mm是一样的。
③如果第一个参数input是1维的,第二个参数是二维的,那么在计算时,在第一个参数前增加一个维度1,计算完毕之后再把这个维度去掉。
a = torch.tensor([1,2,4])
b = torch.tensor([[2,5],[1,2],[6,8]
])print(a.shape)
print(b.shape)
print(torch.matmul(a, b))
-->torch.Size([3])
-->torch.Size([3, 2])
-->tensor([28, 41])
如上所示,a只有一个维度,在进行计算时,变成了(1, 3),则变成了(1, 3)乘以(3, 2),变成(1, 2),最后在去掉1这个维度。
④如果第一个参数是2维的,第二个参数是1维的,则返回矩阵-向量乘积。
a = torch.tensor([1,2])
b = torch.tensor([[2,5],[1,2],[6,8]
])print(b.shape)
print(a.shape)
print(torch.matmul(b, a))
-->torch.Size([3, 2])
-->torch.Size([2])
-->tensor([12, 5, 22])
矩阵乘以张量,就是矩阵中的每一行都与这个张量相乘,最终得到一个一维的,大小为3的结果。
⑤多个维度
- 如果两个参数至少都是1维的,且有一个参数的维度N>2,则返回的是一个批矩阵的乘积(即把多出的那个维度看作batch即可,让每个batch后的矩阵与后面的张量相乘即可)。
- 如果第一个参数是1维的,则在它的维度前加上1,以便批量矩阵相乘并在之后删除。如果第二个参数是1维的,则将1追加到其维度,用于批处理矩阵倍数,然后删除。
- 举例:如果input形状是(j,1,n,n),other的张量形状是(k,n,n),那么输出张量的形状将会是(j,k,n,n)。
- 如果input形状是(j,1,n,m),other的张量形状是(k,m,p),那么输出张量的形状将会是(j,k,n,p)。
tensor1 = torch.randn(10, 3, 4, 5)
tensor2 = torch.randn(5, 4)
torch.matmul(tensor1, tensor2).size()
-->torch.Size([10, 3, 4, 4])
tensor1 = torch.randn(10, 3, 4, 5)
tensor2 = torch.randn(1, 5, 4)
torch.matmul(tensor1, tensor2).size()
-->torch.Size([10, 3, 4, 4])
tensor1 = torch.randn(10, 3, 4, 5)
tensor2 = torch.randn(1, 1, 5, 4)
torch.matmul(tensor1, tensor2).size()
-->torch.Size([10, 3, 4, 4])
仔细比较上述三个代码块,其最终的结果是一样的。可以简单记为如果两个维度不一致的话,多出的维度就看作是batch维,相当于在低维度前面增加一个维度。
5、masked_fill
语法为:
Tensor.masked_fill_(mask, value)
参数:
- mask(BoolTensor):布尔掩码
- value(float):用于填充的值。
mask是一个pytorch张量,元素是布尔值,value是要填充的值,填充规则是mask中取值为True的位置对应与需要填充的张量中的位置用value填充。
a = torch.tensor([[0, 8],[ 6, 8],[ 7, 1]
])mask = torch.tensor([[ True, False],[False, False],[False, True]
])
b = a.masked_fill(mask, -1e9)
print(b)
-->tensor([[-1000000000, 8],[ 6, 8],[ 7, -1000000000]])
Pytorch教程之torch.mm、torch.bmm、torch.matmul、masked_fill相关推荐
- pytorch教程之nn.Module类详解——使用Module类来自定义网络层
前言:前面介绍了如何自定义一个模型--通过继承nn.Module类来实现,在__init__构造函数中申明各个层的定义,在forward中实现层之间的连接关系,实际上就是前向传播的过程. 事实上,在p ...
- pytorch教程之nn.Module类详解——使用Module类来自定义模型
pytorch教程之nn.Module类详解--使用Module类来自定义模型_MIss-Y的博客-CSDN博客_nn是什么意思前言:pytorch中对于一般的序列模型,直接使用torch.nn.Se ...
- 【Pytorch】对比matual,mm和bmm函数
pytorch中提供了 matmul.mm和bmm等矩阵的乘法运算功能,但其具体计算细节和场景截然不同,应予以注意和区别. 1. torch.mm 该函数即为矩阵的乘法,torch.mm(tensor ...
- PyTorch教程之DCGAN
原文连接:DCGAN TUTORIAL 简介 本教程通过例程来介绍 DCGANs .我们使用名人照片来训练 GAN 网络使其能够生成新的名人. 这里使用的大部分代码都来自pytorch/example ...
- 【Pytorch学习】torch.mm()torch.matmul()和torch.mul()以及torch.spmm()
目录 1 引言 2 torch.mul(a, b) 3 torch.mm(a, b) 4 torch.matmul() 5 torch.spmm() 参考文献 1 引言 做深度学习过程中免不了使用 ...
- torch中乘法整理,*torch.mul()torch.mv()torch.mm()torch.dot()@torch.mutmal()
目录 *位置乘 torch.mul():数乘 torch.mv():矩阵向量乘法 torch.mm() 矩阵乘法 torch.dot() 点乘积 @操作 torch.matmul() *位置乘 符号* ...
- pytorch matmul和mm和bmm区别
pytorch中matmul和mm和bmm区别 matmul mm bmm 结论 先看下官网上对这三个函数的介绍. matmul mm bmm 顾名思义, 就是两个batch矩阵乘法. 结论 从官方文 ...
- 【pytorch】torch.mm,torch.bmm以及torch.matmul的使用
torch.mm torch.mm是两个矩阵相乘,即两个二维的张量相乘 如下面的例子 mat1 = torch.randn(2,3) print("mat1=", mat1)mat ...
- pytorch中torch.mul、torch.mm/torch.bmm、torch.matmul的区别
预备知识:矩阵的各种乘积 三者对比 torch.mul: 两个输入按元素相乘,内积 分两种情况,默认其中一个是tensor 另一个也是tensor,那么两者就要符合broadcasedt的规则 另一个 ...
最新文章
- 20个案例详解 Pandas 当中的数据统计分析与排序
- if,elif,else的关系 input print int的用法
- 关于序列化的 10 几个问题,你顶得住不?
- 做正确的事,正确的做事
- 用“夜间模式”模式(javascript书签)浏览网页
- Java黑皮书课后题第8章:***8.35(最大块)给定一个元素为0或者1的方阵,编写程序,找到一个元素都为1的最大的子方阵。程序提示用户输入矩阵的行数。然后显示最大的子方阵的第一个元素、行数
- 启动与停止mysql服务的命令
- oracle备份磁盘头,ASM 磁盘头信息备份
- MFC源码解读(一)最原始一个MFC程序,手写不用向导
- 中国系泊系统行业市场供需与战略研究报告
- 完美解决SAMSUNG Mobile USB CDC Composite Device安装失败 三星手机USB驱动失败。
- AES攻击方法 :差分密码分析 boomerang attack飞去来器攻击
- TortoiseGit状态图标不能正常显示的解决办法
- 天猫淘宝越来越难做了,为什么不考虑下跨境电商?
- MyBatis 常见面试题有哪些?
- 最全量子计算硬件概述(建议收藏)
- MIFI与随身wifi、wifi共享软件,玩坏wifi的几种方法
- Mysql数据库日常使用备注
- UBT8:ubuntu安装Java1.8
- 【每日新闻】Gartner:2017年CRM跃升为规模最大、增速最快的软件市场 | 中国科学家发现神奇半导体材料...
热门文章
- RFSoC应用笔记 - RF数据转换器 -09- RFSoC关键配置之RF-DAC内部解析(三)
- 腾讯 美团 字节 抖音 面经
- [置顶] 代码审查工具FxCop建议采用的规则总结
- 网页显示不全的原因css,css 页面显示不全怎么办
- 基于STM32单片机智能花盆控制系统设计(毕业设计资料)
- pycharm安装及添加桌面图标
- 2017服务器cpu性能排行,2017年手机处理器排名_CPU排行榜名单
- 蔡氏混沌电路matlab程序,蔡氏混沌电路简介——Chuaapos;s-Circut.pptx-全文可读
- java8 Stream 使用案例
- storyboard(故事版)新手教程 图文详解 2.为无约束的故事版添加约束