torch.Tensor有4种常见的乘法:*, torch.mul, torch.mm, torch.matmul. 本文抛砖引玉,简单叙述一下这4种乘法的区别,具体使用还是要参照官方文档。

点乘

a与b做*乘法,原则是如果a与b的size不同,则以某种方式将a或b进行复制,使得复制后的a和b的size相同,然后再将a和b做element-wise的乘法。

下面以*标量和*一维向量为例展示上述过程。

* 标量

Tensor与标量k做*乘法的结果是Tensor的每个元素乘以k(相当于把k复制成与lhs大小相同,元素全为k的Tensor).

>>> a = torch.ones(3,4)

>>> a

tensor([[1., 1., 1., 1.],

[1., 1., 1., 1.],

[1., 1., 1., 1.]])

>>> a * 2

tensor([[2., 2., 2., 2.],

[2., 2., 2., 2.],

[2., 2., 2., 2.]])

* 一维向量

Tensor与行向量做*乘法的结果是每列乘以行向量对应列的值(相当于把行向量的行复制,成为与lhs维度相同的Tensor). 注意此时要求Tensor的列数与行向量的列数相等。

>>> a = torch.ones(3,4)

>>> a

tensor([[1., 1., 1., 1.],

[1., 1., 1., 1.],

[1., 1., 1., 1.]])

>>> b = torch.Tensor([1,2,3,4])

>>> b

tensor([1., 2., 3., 4.])

>>> a * b

tensor([[1., 2., 3., 4.],

[1., 2., 3., 4.],

[1., 2., 3., 4.]])

Tensor与列向量做*乘法的结果是每行乘以列向量对应行的值(相当于把列向量的列复制,成为与lhs维度相同的Tensor). 注意此时要求Tensor的行数与列向量的行数相等。

>>> a = torch.ones(3,4)

>>> a

tensor([[1., 1., 1., 1.],

[1., 1., 1., 1.],

[1., 1., 1., 1.]])

>>> b = torch.Tensor([1,2,3]).reshape((3,1))

>>> b

tensor([[1.],

[2.],

[3.]])

>>> a * b

tensor([[1., 1., 1., 1.],

[2., 2., 2., 2.],

[3., 3., 3., 3.]])

* 矩阵

经Arsmart在评论区提醒,增补一个矩阵 * 矩阵的例子,感谢Arsmart的热心评论!

如果两个二维矩阵A与B做点积A * B,则要求A与B的维度完全相同,即A的行数=B的行数,A的列数=B的列数

>>> a = torch.tensor([[1, 2], [2, 3]])

>>> a * a

tensor([[1, 4],

[4, 9]])

broadcast

点积是broadcast的。broadcast是torch的一个概念,简单理解就是在一定的规则下允许高维Tensor和低维Tensor之间的运算。broadcast的概念稍显复杂,在此不做展开,可以参考官方文档关于broadcast的介绍. 在torch.matmul里会有关于broadcast的应用的一个简单的例子。

这里举一个点积broadcast的例子。在例子中,a是二维Tensor,b是三维Tensor,但是a的维度与b的后两位相同,那么a和b仍然可以做点积,点积结果是一个和b维度一样的三维Tensor,运算规则是:若c = a * b, 则c[i,*,*] = a * b[i, *, *],即沿着b的第0维做二维Tensor点积,或者可以理解为运算前将a沿着b的第0维也进行了expand操作,即a = a.expand(b.size()); a * b。

>>> a = torch.tensor([[1, 2], [2, 3]])

>>> b = torch.tensor([[[1,2],[2,3]],[[-1,-2],[-2,-3]]])

>>> a * b

tensor([[[ 1, 4],

[ 4, 9]],

[[-1, -4],

[-4, -9]]])

>>> b * a

tensor([[[ 1, 4],

[ 4, 9]],

[[-1, -4],

[-4, -9]]])

其实,上面提到的二维Tensor点积标量、二维Tensor点积行向量,都是发生在高维向量和低维向量之间的,也可以看作是broadcast.

torch.mul

官方文档关于torch.mul的介绍. 用法与*乘法相同,也是element-wise的乘法,也是支持broadcast的。

下面是几个torch.mul的例子.

乘标量

>>> a = torch.ones(3,4)

>>> a

tensor([[1., 1., 1., 1.],

[1., 1., 1., 1.],

[1., 1., 1., 1.]])

>>> a * 2

tensor([[2., 2., 2., 2.],

[2., 2., 2., 2.],

[2., 2., 2., 2.]])

乘行向量

>>> a = torch.ones(3,4)

>>> a

tensor([[1., 1., 1., 1.],

[1., 1., 1., 1.],

[1., 1., 1., 1.]])

>>> b = torch.Tensor([1,2,3,4])

>>> b

tensor([1., 2., 3., 4.])

>>> torch.mul(a, b)

tensor([[1., 2., 3., 4.],

[1., 2., 3., 4.],

[1., 2., 3., 4.]])

乘列向量

>>> a = torch.ones(3,4)

>>> a

tensor([[1., 1., 1., 1.],

[1., 1., 1., 1.],

[1., 1., 1., 1.]])

>>> b = torch.Tensor([1,2,3]).reshape((3,1))

>>> b

tensor([[1.],

[2.],

[3.]])

>>> torch.mul(a, b)

tensor([[1., 1., 1., 1.],

[2., 2., 2., 2.],

[3., 3., 3., 3.]])

乘矩阵

例1:二维矩阵 mul 二维矩阵

>>> a = torch.tensor([[1, 2], [2, 3]])

>>> torch.mul(a,a)

tensor([[1, 4],

[4, 9]])

例2:二维矩阵 mul 三维矩阵(broadcast)

>>> a = torch.tensor([[1, 2], [2, 3]])

>>> b = torch.tensor([[[1,2],[2,3]],[[-1,-2],[-2,-3]]])

>>> torch.mul(a,b)

tensor([[[ 1, 4],

[ 4, 9]],

[[-1, -4],

[-4, -9]]])

torch.mm

官方文档关于torch.mm的介绍. 数学里的矩阵乘法,要求两个Tensor的维度满足矩阵乘法的要求.

例子:

>>> a = torch.ones(3,4)

>>> b = torch.ones(4,2)

>>> torch.mm(a, b)

tensor([[4., 4.],

[4., 4.],

[4., 4.]])

torch.matmul

官方文档关于torch.matmul的介绍. torch.mm的broadcast版本.

例子:

>>> a = torch.ones(3,4)

>>> b = torch.ones(5,4,2)

>>> torch.matmul(a, b)

tensor([[[4., 4.],

[4., 4.],

[4., 4.]],

[[4., 4.],

[4., 4.],

[4., 4.]],

[[4., 4.],

[4., 4.],

[4., 4.]],

[[4., 4.],

[4., 4.],

[4., 4.]],

[[4., 4.],

[4., 4.],

[4., 4.]]])

同样的a和b,使用torch.mm相乘会报错

>>> torch.mm(a, b)

Traceback (most recent call last):

File "", line 1, in

RuntimeError: matrices expected, got 2D, 3D tensors at /pytorch/aten/src/TH/generic/THTensorMath.cpp:2065

到此这篇关于详解torch.Tensor的4种乘法的文章就介绍到这了,更多相关torch.Tensor 乘法内容请搜索菜鸟教程www.piaodoo.com以前的文章或继续浏览下面的相关文章希望大家以后多多支持菜鸟教程www.piaodoo.com!

标签:tensor,Tensor,python,torch,broadcast,ones,乘法

来源: https://www.cnblogs.com/piaodoo/p/13936333.html

matmul torch 详解_python基础教程详解torch.Tensor的4种乘法相关推荐

  1. python自定义函数详解_python基础教程之自定义函数介绍

    函数最重要的目的是方便我们重复使用相同的一段程序. 将一些操作隶属于一个函数,以后你想实现相同的操作的时候,只用调用函数名就可以,而不需要重复敲所有的语句. 函数的定义 首先,我们要定义一个函数, 以 ...

  2. python雷达图详解_Python基础教程 - matplotlib实现雷达图和柱状图

    原标题:Python基础教程 - matplotlib实现雷达图和柱状图 Python基础教程记录 - 使用matplotlib实现雷达图和柱状图. 注:主要是设置add_subplot(133),分 ...

  3. python变量详解_python基础教程-03-变量详解

    变量就像一个小罐子,里面是存放着各种数据类型的数据,并且在程序运行过程中会发生变化.变量名在一个工作空间内是唯一的,通过变量的名字就能找到对应的数据. 变量的赋值 变量的赋值就可以理解为往小罐子里存放 ...

  4. python布尔值的作用_Python基础教程详解布尔变量的作用

    布尔值也叫真值,在Python开发(http://www.maiziedu.com/course/python-px/)中所有的值都被解释为真值,标准的真值为true和false.那么布尔变量在Pyt ...

  5. python 字符串替换_Python基础教程,第四讲,字符串详解

    本节课主要和大家一起学习一下Python中的字符串操作,对字符串的操作在开发工作中的使用频率比较高,所以单独作为一课来讲. 学完此次课程,我能做什么? 学完本次课程后,我们将学会如何创建字符串,以及如 ...

  6. 判断字符串格式_Python基础教程,第四讲,字符串详解

    本节课主要和大家一起学习一下Python中的字符串操作,对字符串的操作在开发工作中的使用频率比较高,所以单独作为一课来讲. 学完此次课程,我能做什么? 学完本次课程后,我们将学会如何创建字符串,以及如 ...

  7. 计算机基础知识及其详解,计算机基础知识详解:计算机入门基础知识

    能力训练网权威发布计算机基础知识详解,更多计算机基础知识详解相关信息请访问少儿综合素质训练网. [导语]以下是大范文网整理的计算机基础知识详解,欢迎阅读! 1.第一台计算机-ENIAC 大家只要知道第 ...

  8. python协程详解_python协程详解

    原博文 2019-10-25 10:07 − # python协程详解 ![python协程详解](https://pic2.zhimg.com/50/v2-9f3e2152b616e89fbad86 ...

  9. python基础教程博客_Python基础教程_Python入门知识

    Python基础教程频道为编程初学者提供入门前的所有基础知识,必须要掌握的一些PYTHON基础语法语句,基本的数据类型. 让大家可以更快速.更容易理解的的方式掌握Python编程所需要的基础知识,灵活 ...

  10. python 包用法_Python 基础教程之包和类的用法

    Python 基础教程之包和类的用法 这篇文章主要介绍了 Python 基础教程之包和类的用法的相关资料, 需要的朋友可以参考下 Python 是一种面向对象.解释型计算机程序设计语言,由 Guido ...

最新文章

  1. 酷!一键构建我自己的PHP框架的开发环境
  2. oracle10.2 管理工具,Oracle 10.2.0.5 EM管理器的BUG
  3. Application Architecture - Table Data Gateway
  4. 【Android 安全】DEX 加密 ( Application 替换 | 分析 Service 组件中调用 getApplication() 获取的 Application 是否替换成功 )
  5. Docker容器日志集中收集(client-server模式)
  6. 云炬Android开发笔记 4单Activity界面架构设计与验证
  7. fcntl函数-文件控制函数
  8. apache 工作模式prefork进程模式和worker线程模式参式详解和推荐设置
  9. MFC实现Windows锁屏
  10. Linux环境下查看IP不显示IPv4地址
  11. zt:tcpdump抓包对性能的影响
  12. Memcached概述
  13. 【EOS】2.4 EOS数据存储
  14. windows10下 mysql5.7.24 免安装版 安装笔记
  15. jvisualvm (Java VisualVM)
  16. 人体姿态识别 tensorflow版本
  17. CVE和NVD的关系
  18. sql 一张表递归_查看我的递归视觉指南(因为一张图片价值1,000字)
  19. android 高德卫星地图数据,白马地图 Bmap for Android v7.3.81 强大高德百度地图应用|张小北...
  20. 视频教程:Java七大外企经典面试套路之基础篇

热门文章

  1. 行政管理专业考计算机研究生分数,行政管理学,考研,历年分数线是多少?
  2. 单纯学python能干啥_如何高效学习Python编程,转行的朋友可以过来看看,单纯的经验分享...
  3. Spring中IOC容器概念
  4. 第一冲刺阶段意见汇总
  5. EasyUI Dialog 对话框
  6. 语音识别技术突飞猛进 终有一天将超过人
  7. Docker 三剑客之 Docker Swarm
  8. 持续集成并不能消除 Bug,而是让它们非常容易发现和改正(转)
  9. MySql中PreparedStatement对象与Statement对象
  10. tika获取压缩文件内容