直接参考官方文档,x1和x2是两个输入,A是参数矩阵,如下表达式

但仔细看实现发现这个表达式并不是简单连乘的关系。假设x1(shape是b,n)和x2(shape是b,m)是二维,那么A是个三维tensor(shape是a,n,m)。具体实现时,A先拆成a个(n,m)形状的tensor,x1分别与之矩阵乘后再点乘(公式里的两次乘法),得到了a个(b,m)形状的tensor,然后在axis=1的维度上求和,最终输出成(b,a)的shape,用爱因斯坦表示法就是bn,anm,bm->ba,下面上一下code:

import torch
import torch.nn as nn
import numpy as npl = torch.ones(2,5)
A = torch.ones(3,5,4)
r = torch.ones(2,4)
print(torch.einsum('bn,anm,bm->ba', l, A, r))
print(torch.nn.functional.bilinear(l,r,A))x = torch.ones(2,5)
w = torch.ones(3,5)
print(torch.einsum('ij,kj->ik', x,w))
print(torch.nn.functional.linear(x,w))print('learn nn.Bilinear')
m = nn.Bilinear(5, 4, 3)output = m(l, r)
print(output.size())
arr_output = output.data.cpu().numpy()weight = m.weight.data.cpu().numpy()
bias = m.bias.data.cpu().numpy()
x1 = l.data.cpu().numpy()
x2 = r.data.cpu().numpy()
print(x1.shape, weight.shape, x2.shape, bias.shape)
y = np.zeros((x1.shape[0], weight.shape[0]))
for k in range(weight.shape[0]):buff = np.dot(x1, weight[k])buff = buff * x2buff = np.sum(buff, axis=1)y[:, k] = buff
y += bias
dif = y - arr_output
print(np.mean(np.abs(dif.flatten())))

pytorch中bilinear的理解相关推荐

  1. pytorch中repeat()函数理解

    pytorch中repeat()函数理解 最近在学习过程中遇到了repeat()函数的使用,这里记录一下自己对这个函数的理解. 情况1:repeat参数个数与tensor维数一致时 a = torch ...

  2. pytorch 中 contiguous() 函数理解

    pytorch 中 contiguous() 函数理解 文章目录 pytorch 中 contiguous() 函数理解 引言 使用 contiguous() 后记 文章抄自 Pytorch中cont ...

  3. Pytorch中的contiguous理解

    最近遇到这个函数,但查的中文博客里的解释貌似不是很到位,这里翻译一下stackoverflow上的回答并加上自己的理解. 在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新 ...

  4. Pytorch中contiguous()函数理解

    引言 在pytorch中,只有很少几个操作是不改变tensor的内容本身,而只是重新定义下标与元素的对应关系的.换句话说,这种操作不进行数据拷贝和数据的改变,变的是元数据. 会改变元数据的操作是: n ...

  5. Pytorch中dim的理解

    dim的定义 dim 表示维度 x = torch.randn(2, 3, 3)print(x) print(x.size()) print(x.dim()) 输出: tensor([[[-1.694 ...

  6. pytorch中unsqueeze()函数理解

    unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度. 在第一个维度(中括号)的每个元素加中括号 0表示在张量最外层加一个中括号变成第一维. 直接看例子: import torch i ...

  7. pytorch中数组维度的理解

    pytorch中数组维度理解与numpy中类似,pytorch中维度用dim表示,numpy中用axis表示 这里主要想说下维度的变化. dim = x ,表示在第x为上进行操作,那个维度会发生变化. ...

  8. pytorch中的nn.Bilinear

    参考:pytorch中的nn.Bilinear的计算原理详解 代码实现 使用numpy实现Bilinear(来自参考资料): print('learn nn.Bilinear') m = nn.Bil ...

  9. pytorch中网络loss传播和参数更新理解

    相比于2018年,在ICLR2019提交论文中,提及不同框架的论文数量发生了极大变化,网友发现,提及tensorflow的论文数量从2018年的228篇略微提升到了266篇,keras从42提升到56 ...

最新文章

  1. NEO共识节点推荐搭建步骤
  2. 在ComboBox控件中使用嵌入字体。
  3. 授以渔 - Autodesk Forge 学习简谈 - 引言
  4. QT-- MainWindow外的cpp文件调用ui
  5. AWS RDS强制升级的应对之道——版本升级的最佳实践
  6. php 反射 调用私有方法,PHP通过反射方法调用执行类中的私有方法
  7. windows下多个静态库合并的方法
  8. 中国云市场生变:华为云 Q2 份额超 AWS,IaaS+PaaS 迎来整体增长
  9. java 创建水果_简单的java水果商店后台
  10. 家校协同小程序实战教程
  11. PAT 1055 集体照
  12. Python策略模式实例
  13. jsp 中${ } 是什么意思?
  14. 面试题:看数字找规律
  15. 【Spring】SpringIOC容器启动过程源码分析 以及 循环依赖问题
  16. sp_addlinkedserver oracle,SP_addlinkedserver 小结 (oracle,sql server,access,excel)
  17. 用 JavaScript 比较两个日期
  18. PIC单片机的AD数据传输和上位机C#串口界面实时显示
  19. 简单理解在线性函数的估计中bias(偏差)与variance(方差)的影响
  20. python 拆分(几G)的tsv文件为较小的csv文件

热门文章

  1. 红极一时的VB,输给时代,新型开发工具,或成未来
  2. 腾讯云服务器Intel Xeon Cascade Lake 8255C处理器CPU性能评测
  3. android 复制控件,Android长按复制文本功能
  4. 11款开放中文分词引擎大比拼 1
  5. 三星鸿蒙手机,绝版麒麟芯的手机、鸿蒙 OS 的手表,华为 Mate 系列全家桶曝光汇总...
  6. python中计数器函数_Python中使用多个函数的字计数器
  7. 什么是2147483647 ?
  8. 原生js实现淘宝衣服相册悬浮切换效果
  9. Python小游戏——孔明棋
  10. 树莓派3B+ 移动硬盘启动系统,从SD卡直接复制系统