PyTorch中的两个张量的乘法可以分为两种:

  1. 两个张量对应的元素相乘(element-wise),在PyTorch中可以通过torch.mul函数(或者∗*∗运算符)实现

  2. 两个张量矩阵相乘(Matrix product),在PyTorch中可以通过torch.matmul函数实现

本文主要介绍两个张量的矩阵相乘。

语法为:

torch.matmul(input, other, out = None)

函数对input和other两个张量进行矩阵相乘。为了方便后续的讲解,将input记为a,将other记为b。

点积在数学中,又称数量积,是指接受在实数R上的两个1D张量并返回一个实数值0D张量的二元运算。
若1D张量a=[1,2],1D张量b=[3,4],则:
a⋅\cdot⋅b=1×\times× 3 + 2×\times× 4 = 11

  1. 若a为1D张量,b为1D张量,则返回两个张量的点积,则返回两个张量的点积(此时的torch.matmul不支持out参数)

举例如下:

import torch
a = torch.tensor([1, 2])
b = torch.tensor([3, 4])
result = torch.matmul(a, b)
print(result)

结果为:

(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"
tensor(11)
  1. 若a为2D张量,b为2D张量,则返回两个张量的矩阵乘积。

矩阵相乘最重要的方法是一般矩阵乘积,它只有在第一个2D张量(矩阵)的列数(column)和第二个2D张量(矩阵)的行数(row)相同时才有意义。
若2D张量a=[[1,2],[3,4]],2D张量b=[[5,6,7],[8,9,10]],则:
a×\times× b=[[21,24,27],[47,54,61]],2D张量a的形状为(2,2),而2D张量b的形状(2,3)。矩阵乘积的运算规则:

举例为:

import torch
a = torch.tensor([[1, 2],[3,4]])
b = torch.tensor([[5,6,7],[8,9,10]])
result = torch.matmul(a, b)
print(result)

结果展示为:

(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"
tensor([[21, 24, 27],[47, 54, 61]])
  1. 若a为1D张量,b为2D张量,torch.matmul函数:

首先,在1D张量a的前面插入一个长度为1的新维度变成2D张量;

然后,在满足第一个2D张量(矩阵)的列数(column)和第二个2D张量(矩阵)的行数(row)相同的条件下,两个2D张量矩阵乘积,否则会抛出错误;

最后,将矩阵乘积结果中长度为1的维度(前面插入的长度为1的新维度)删除作为最终torch.matmul函数返回的结果。

import torch
a = torch.tensor([1, 2])
b = torch.tensor([[5, 6, 7],[8, 9, 10]])
result = torch.matmul(a, b)
print(result, result.shape)

结果为:

(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"
tensor([21, 24, 27]) torch.Size([3])

简单来说,先将1D张量a扩展成2D张量,满足矩阵乘积的条件下,将两个2D张量进行矩阵乘积的运算。

此时得到的形状是(1,3)的2D张量,最后将前面插入长度为1的新维度删除即为最终torch.matmul(a, b)函数返回的结果。

  1. 若a为2D张量,b为1D张量,torch.matmul函数:

首先,在1D张量b的后面插入一个长度为1的新维度变成2D张量;

然后,在满足第一个2D张量(矩阵)的列数(column)和第二个2D张量(矩阵)的行数(row)相同的条件下,两个2D张量矩阵乘积,否则会抛出错误;

最后,将矩阵乘积结果中长度为1的维度(后面插入的长度为1的新维度)删除作为最终torch.matmul函数返回的结果;

import torch
b = torch.tensor([1, 2, 3])
a = torch.tensor([[5, 6, 7],[8, 9, 10]])
result = torch.matmul(a, b)
print(result, result.shape)

结果展示为:

(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"
tensor([38, 56]) torch.Size([2])

其中:

38 = 15+26+3*7

56 = 18+29+3*10

PyTorch中的matmul函数详解相关推荐

  1. PyTorch入门笔记-matmul函数详解

    PyTorch入门笔记-matmul函数详解 本文转载自:PyTorch入门笔记-matmul函数详解 - 腾讯云开发者社区-腾讯云 (tencent.com) 41409)]

  2. PyTorch中torch.norm函数详解

    torch.norm() 是 PyTorch 中的一个函数,用于计算输入张量沿指定维度的范数.具体而言,当给定一个输入张量 x 和一个整数 p 时,torch.norm(x, p) 将返回输入张量 x ...

  3. PyTorch中的topk函数详解

    听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index. 用法 torch.topk(input, k, dim=None, largest=True, sor ...

  4. timm 视觉库中的 create_model 函数详解

    timm 视觉库中的 create_model 函数详解 最近一年 Vision Transformer 及其相关改进的工作层出不穷,在他们开源的代码中,大部分都用到了这样一个库:timm.各位炼丹师 ...

  5. 【Pytorch】torch.argmax 函数详解

    文章目录 一.一个参数时的 torch.argmax 函数 1. 介绍 2. 实例 二.多个参数时的 torch.argmax 函数 1. 介绍 2. 实例 实例1:二维矩阵 实例2:三维矩阵 实例3 ...

  6. python getattr_Python中的getattr()函数详解:

    标签:Python中的getattr()函数详解: getattr(object, name[, default]) -> value Get a named attribute from an ...

  7. python input函数详解_对Python3中的input函数详解

    下面介绍python3中的input函数及其在python2及pyhton3中的不同. python3中的ininput函数,首先利用help(input)函数查看函数信息: 以上信息说明input函 ...

  8. Python中的bbox_overlaps()函数详解

    Python中的bbox_overlaps()函数详解 想要编写自己的目标检测算法,就需要掌握bounding box(边界框)之间的关系.在这之中,bbox_overlaps()函数是一个非常实用的 ...

  9. java的匿名函数_JAVA语言中的匿名函数详解

    本文主要向大家介绍了JAVA语言中的匿名函数详解,通过具体的内容向大家展示,希望对大家学习JAVA语言有所帮助. 一.使用匿名内部类 匿名内部类由于没有名字,所以它的创建方式有点儿奇怪.创建格式如下: ...

最新文章

  1. java postdelayed_你真的懂Handler.postDelayed()的原理吗?
  2. MySQL笔记4:desc命令的两个用法
  3. Android NDK 内存泄露检测
  4. 生日祝福(HTML+CSS+JavaScript+jQuery)
  5. 【STM32】定时器程序
  6. HP proliant服务器从usb启动
  7. Qt工作笔记-让界面飞一会(让界面旋转出来)
  8. java中json对象去重复_如何忽略Java中JSON对象的多个属性?
  9. 【Linux】Linux中文本编辑器和系统管理命令
  10. 用R进行文本挖掘与分析:分词、画词云
  11. python request返回的响应_Python爬虫库requests获取响应内容、响应状态码、响应头...
  12. 使用Tenorshare iCareFone for mac如何对iPhone进行系统修复?
  13. 拓端tecdat|R语言计算资本资产定价模型(CAPM)中的Beta值和可视化
  14. 玩转BIOS与注册表
  15. Win10下windows mobile device center设备中心连接不上无法启动
  16. 打字母案例完整版(C#)
  17. 马云坦然不懂计算机,来自马云的绝望:三角函数让我彻底失去学数学的信心
  18. 经度从0-360更改为-180到180
  19. Mybatis源码解析《二》
  20. Several ports (8005, 8080) required by Tomcat v9.0 Server at localhost are already in use

热门文章

  1. 为什么在Notepad++里面打开图片是乱码?底层原理是什么?
  2. c语言老鼠走迷宫原理,C语言算法(3) 老鼠走迷宫
  3. 一名非计算机专业大学牲想说的一些话
  4. 电脑键盘各个按键作用讲解
  5. 前端get请求接收后端传来的二进制文件流blob实现下载功能,解决下载文件打不开问题
  6. 树莓派安装Ubuntu Mate解决无法连接WiFi问题,并部署Ros系统
  7. 单板计算机(SBC)市场现状研究分析与发展前景预测报告
  8. Roson讲Qt#14 设置滚动条样式
  9. php 门户网站,PHP的CMS系统整理
  10. 网站css内图片下载脚本