一、函数解释

在torch/_C/_VariableFunctions.py的有该定义,意义就是实现一下公式:

换句话说,就是需要传入5个参数,mat里的每个元素乘以beta,mat1和mat2进行矩阵乘法(左行乘右列)后再乘以alpha,最后将这2个结果加在一起。但是这样说可能没啥概念,接下来博主为大家写上一段代码,大家就明白了~

    def addmm(self, beta=1, mat, alpha=1, mat1, mat2, out=None): # real signature unknown; restored from __doc__"""addmm(beta=1, mat, alpha=1, mat1, mat2, out=None) -> TensorPerforms a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`.The matrix :attr:`mat` is added to the final result.If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a:math:`(m \times p)` tensor, then :attr:`mat` must be:ref:`broadcastable <broadcasting-semantics>` with a :math:`(n \times p)` tensorand :attr:`out` will be a :math:`(n \times p)` tensor.:attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between:attr:`mat1` and :attr`mat2` and the added matrix :attr:`mat` respectively... math::out = \beta\ mat + \alpha\ (mat1_i \mathbin{@} mat2_i)For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and:attr:`alpha` must be real numbers, otherwise they should be integers.Args:beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)mat (Tensor): matrix to be addedalpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)mat1 (Tensor): the first matrix to be multipliedmat2 (Tensor): the second matrix to be multipliedout (Tensor, optional): the output tensorExample::>>> M = torch.randn(2, 3)>>> mat1 = torch.randn(2, 3)>>> mat2 = torch.randn(3, 3)>>> torch.addmm(M, mat1, mat2)tensor([[-4.8716,  1.4671, -1.3746],[ 0.7573, -3.9555, -2.8681]])"""pass

二、代码范例

1.先摆出代码,大家可以先复制粘贴运行一下,在之后博主会一一讲解

"""
@author:nickhuang1996
"""
import torchrectangle_height = 3
rectangle_width = 3
inputs = torch.randn(rectangle_height, rectangle_width)
for i in range(rectangle_height):for j in range(rectangle_width):inputs[i] = i * torch.ones(rectangle_width)
'''
inputs and its transpose
-->inputs   =   tensor([[0., 0., 0.],[1., 1., 1.],[2., 2., 2.]])
-->inputs_t =   tensor([[0., 1., 2.],[0., 1., 2.],[0., 1., 2.]])
'''
print("inputs:\n", inputs)
inputs_t = inputs.t()
print("inputs_t:\n", inputs_t)
'''
inputs_t @ inputs_t    [[0., 1., 2.],       [[0., 1., 2.],          [[0., 3., 6.]=   [0., 1., 2.],   @    [0., 1., 2.],     =     [0., 3., 6.][0., 1., 2.]]        [0., 1., 2.]]           [0., 3., 6.]]
''''''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
'''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
e = torch.addmm(inputs, inputs_t, inputs_t)
f = inputs.addmm(inputs_t, inputs_t)
'''1 * inputs + 1 * (inputs_t @ inputs_t)'''
g = inputs.addmm(1, inputs_t, inputs_t)
'''2 * inputs + 1 * (inputs_t @ inputs_t)'''
g2 = inputs.addmm(2, inputs_t, inputs_t)
'''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
h = inputs.addmm(1, 1, inputs_t, inputs_t)
'''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
h12 = inputs.addmm(1, 2, inputs_t, inputs_t)
'''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
h21 = inputs.addmm(2, 1, inputs_t, inputs_t)
print("a:\n", a)
print("b:\n", b)
print("c:\n", c)
print("d:\n", d)print("e:\n", e)
print("f:\n", f)print("g:\n", g)
print("g2:\n", g2)print("h:\n", h)
print("h12:\n", h12)
print("h21:\n", h21)
print("inputs:\n", inputs)
'''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
'''
inputs @ inputs_t       [[0., 0., 0.],       [[0., 1., 2.],          [[0., 0., 0.]=    [1., 1., 1.],   @    [0., 1., 2.],     =     [0., 3., 6.][2., 2., 2.]]        [0., 1., 2.]]           [0., 6., 12.]]
'''
inputs.addmm_(1, -2, inputs, inputs_t)  # In-place
print("inputs:\n", inputs)

2.其中

inputs是一个3×3的矩阵,为

tensor([[0., 0., 0.],[1., 1., 1.],[2., 2., 2.]])

inputs_t也是一个3×3的矩阵,是inputs的转置矩阵,为

tensor([[0., 1., 2.],[0., 1., 2.],[0., 1., 2.]])

inputs_t @ inputs_t

'''
inputs_t @ inputs_t    [[0., 1., 2.],       [[0., 1., 2.],          [[0., 3., 6.]=   [0., 1., 2.],   @    [0., 1., 2.],     =     [0., 3., 6.][0., 1., 2.]]        [0., 1., 2.]]           [0., 3., 6.]]
'''

3.代码中a,b,c和d展示的是完全形式,即标明了位置参数和传入参数。可以看到input这个位置参数可以写在函数的前面,即

torch.addmm(input, mat1, mat2) = inputs.addmm(mat1, mat2)

完成的公式为:

1 × inputs + 1 ×(inputs_t @ inputs_t)

'''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
a:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
b:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
c:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
d:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])

4.下面的例子更好了说明了input参数的位置可变性,并且beta和alpha都缺省了:

完成的公式为:

1 × inputs + 1 ×(inputs_t @ inputs_t)

'''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
e = torch.addmm(inputs, inputs_t, inputs_t)
f = inputs.addmm(inputs_t, inputs_t)
e:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
f:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])

5.加一个参数,实际上是添加了beta这个参数

完成的公式为:

g   = 1 × inputs + 1 ×(inputs_t @ inputs_t)

g2 = 2 × inputs + 1 ×(inputs_t @ inputs_t)

'''1 * inputs + 1 * (inputs_t @ inputs_t)'''
g = inputs.addmm(1, inputs_t, inputs_t)
'''2 * inputs + 1 * (inputs_t @ inputs_t)'''
g2 = inputs.addmm(2, inputs_t, inputs_t)
g:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
g2:
tensor([[ 0.,  3.,  6.],[ 2.,  5.,  8.],[ 4.,  7., 10.]])

6.再加一个参数,实际上是添加了alpha这个参数

完成的公式为:

h   = 1 × inputs + 1 ×(inputs_t @ inputs_t)

h12 = 1 × inputs + 2 ×(inputs_t @ inputs_t)

h21 = 2 × inputs + 1 ×(inputs_t @ inputs_t)

'''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
h = inputs.addmm(1, 1, inputs_t, inputs_t)
'''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
h12 = inputs.addmm(1, 2, inputs_t, inputs_t)
'''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
h21 = inputs.addmm(2, 1, inputs_t, inputs_t)
h:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
h12:
tensor([[ 0.,  6., 12.],[ 1.,  7., 13.],[ 2.,  8., 14.]])
h21:
tensor([[ 0.,  3.,  6.],[ 2.,  5.,  8.],[ 4.,  7., 10.]])

7.当然,以上的步骤inputs没有变化,还是为

inputs:
tensor([[0., 0., 0.],[1., 1., 1.],[2., 2., 2.]])

*8.addmm_()的操作和addmm()函数功能相同,区别就是addmm_()有inplace的操作,也就是在原对象基础上进行修改,即把改变之后的变量再赋给原来的变量。例如:

inputs的值变成了改变之后的值,不用再去写 某个变量=addmm_() 了,因为inputs就是改变之后的变量!

*inputs@ inputs_t

'''
inputs @ inputs_t       [[0., 0., 0.],       [[0., 1., 2.],          [[0., 0., 0.]=    [1., 1., 1.],   @    [0., 1., 2.],     =     [0., 3., 6.][2., 2., 2.]]        [0., 1., 2.]]           [0., 6., 12.]]
'''

完成的公式为:

inputs   = 1 × inputs - 2 ×(inputs @ inputs_t)

'''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
inputs.addmm_(1, -2, inputs, inputs_t)  # In-place
inputs:
tensor([[  0.,   0.,   0.],[  1.,  -5., -11.],[  2., -10., -22.]])

三、代码运行结果

inputs:
tensor([[0., 0., 0.],[1., 1., 1.],[2., 2., 2.]])
inputs_t:
tensor([[0., 1., 2.],[0., 1., 2.],[0., 1., 2.]])
a:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
b:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
c:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
d:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
e:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
f:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
g:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
g2:
tensor([[ 0.,  3.,  6.],[ 2.,  5.,  8.],[ 4.,  7., 10.]])
h:
tensor([[0., 3., 6.],[1., 4., 7.],[2., 5., 8.]])
h12:
tensor([[ 0.,  6., 12.],[ 1.,  7., 13.],[ 2.,  8., 14.]])
h21:
tensor([[ 0.,  3.,  6.],[ 2.,  5.,  8.],[ 4.,  7., 10.]])
inputs:
tensor([[0., 0., 0.],[1., 1., 1.],[2., 2., 2.]])
inputs:
tensor([[  0.,   0.,   0.],[  1.,  -5., -11.],[  2., -10., -22.]])

Pytorch里addmm()和addmm_()的用法详解相关推荐

  1. BigDecimal的用法详解(保留两位小数,四舍五入,数字格式化,科学计数法转数字,数字里的逗号处理)

    一.简介 Java在java.math包中提供的API类BigDecimal,用来对超过16位有效位的数进行精确的运算.双精度浮点型变量double可以处理16位有效数.在实际应用中,需要对更大或者更 ...

  2. python中permute_PyTorch中permute的用法详解

    PyTorch中permute的用法详解 permute(dims) 将tensor的维度换位. 参数:参数是一系列的整数,代表原来张量的维度.比如三维就有0,1,2这些dimension. 例: i ...

  3. oracle中的exists 和 not exists 用法详解

    from:http://blog.sina.com.cn/s/blog_601d1ce30100cyrb.html oracle中的exists 和 not exists 用法详解 (2009-05- ...

  4. MultiByteToWideChar和WideCharToMultiByte用法详解

    //======================================================================== //TITLE: //    MultiByteT ...

  5. linux mount命令参数及用法详解

    linux mount命令参数及用法详解 非原创,主要来自 http://www.360doc.com/content/13/0608/14/12600778_291501907.shtml. htt ...

  6. js数组中foEach和map的用法详解 jq中的$.each和$.map

    数组中foEach和map的用法详解 相同点: 1.都是循环遍历数组(仅仅是数组)中的每一项. 2.forEach() 和 map() 里面每一次执行匿名函数都支持3个参数:数组中的当前项value, ...

  7. SVN switch 用法详解 (ZZ)

    SVN switch 用法详解 (ZZ)  http://www.cnblogs.com/dabaopku/archive/2011/05/21/2052820.html 确实,以前不会用switch ...

  8. 68.connect-flash 用法详解 req,flash()

    转自:http://yunkus.com/connect-flash-usage/ connect-flash 用法详解  前端工具  2016-10-05  2016-10-05  朝夕熊  11 ...

  9. 教程-Delphi中Spcomm使用属性及用法详解

    Delphi中Spcomm使用属性及用法详解 Delphi是一种具有 功能强大.简便易用和代码执行速度快等优点的可视化快速应用开发工具,它在构架企业信息系统方面发挥着越来越重要的作用,许多程序员愿意选 ...

最新文章

  1. 兰州大学萃英学院计算机,兰州大学萃英学院.PDF
  2. python打出由边框包围的_python – 提取边框并将其保存为图像
  3. Linux进阶之路————scp指令介绍与演示
  4. 【编码问题】‘utf-8‘ codec can‘t decode byte 0xce in position 0
  5. Python使用集合实现素数筛选法
  6. ios 简单的计时器游戏 NSUserDefaults NSDate NSTimer
  7. ros中web端通过 按钮加载本地静态 pgm 地图显示在canvas画布中
  8. java并发圣经,差距不止一星半点!Github星标85K的性能优化法则圣经
  9. java项目根目录_获取java项目的根目录
  10. Tensorflow之softmax应用实例
  11. 反向翻译back-translations
  12. Rockchip | Rockchip U-Boot的获取与构建
  13. Python-----函数详解(上篇)(附小项目实战)
  14. Android 系统自动获取来电/短信/提示铃声
  15. 返回到上一个页面并刷新页面
  16. WPViewPDF Delphi 和 .NET 的 PDF 查看组件
  17. 仿小米商城页面和简单效果
  18. 使用BERT fine-tuning 用于推特情感分析
  19. 2022GCVC全球人工智能视觉产业与技术大会在青岛圆满落幕
  20. js防止刷访问量_优化js脚本设计,防止浏览器假死

热门文章

  1. xiuno论坛部署及常见问题处理
  2. 给一个不多于5位的正整数,要求: 1.求出它是几位数; 2.分别输出每一位数字; 3.按逆序输出各位数字;
  3. ip a命令显示的UP与LOWER_UP的区别
  4. C. Alice and the Cake
  5. STM32 —— OLED 屏幕入门
  6. 探索Whisper语音识别
  7. GitHub 狂飙 30K+star 面试现场, 专为程序员面试打造, 现已开源可下载
  8. Adobe Premiere基础特效(卡点和转场)(四)
  9. CAPI 初探及使用小结(1)
  10. 从“穷逼VIP”论注释规范,你见过哪些奇葩的注释?