API文档
https://numpy.org/doc/stable/reference/generated/numpy.tensordot.html

文档说的过于晦涩,下面以实际例子来研究一下

样例1 axes=0 (二维为叉乘运算)

m1=np.array([0.1,0.2,0.3,0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,0.1, 0.1,0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=0)print("============================")
print(m3.shape)
print(m3)

运行结果

实际计算过程是

  1. m1 reshape为(6,1) m2 reshape为(1,6)
  2. m1 dot m2 得到m3(6,6)
  3. 再把m3reshape为两个输入矩阵shape的串联(2,3,3,2)

更简单的理解就是,把任意维度的两个张量m1(m1s1,m1s2,…m1sn),m2(m2s1,m2s2,…m2sn)压“扁”为一个(N,1)和(1,N)矩阵,然后两个矩阵相乘,再把这个结果矩阵(N,N)做reshape,reshape的结果为两个原始张量shape的串联,即为(m1s1,m1s2,…m1sn,m2s1,m2s2,…m2sn)


样例2 axes=1 (二维为点乘运算)

m1=np.array([0.1,0.2,0.3,0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,0.1, 0.1,0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))print(m1.shape)
print(m2.shape)
# 等同于np.dot(m1,m2)
m3 = np.tensordot(m1,m2, axes=1)print("============================")
print(m3.shape)
print(m3)

运行结果,它实际等同于np.dot(m1,m2)

与axes=0的区别在于,它要求m1的最后一维与m2的第一位必须“一致”,这类似于二维矩阵的点乘,相乘后的两个维度“消失”


样例3 axes=2 (default)

m1=np.array([0.1,0.2,0.3,0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,0.1, 0.1,0.2,0.2])
m1=m1.reshape((3,2))
m2=m2.reshape((3,2))print(m1.shape)
print(m2.shape)
# 注意这里的shape是一样的
m3 = np.tensordot(m1,m2, axes=2)print("============================")
print(m3.shape)
print(m3)

类似于axes=1,这里要求m1最后二个维度必须和m1的开始二个维度“一致”,其他维度任意。然后类似与做矩阵乘法,就是对axes=1做了+1扩展。
那么如果维度更高怎么控制呢? 这个就是后面采用tuple和list的用法了


样例4 axis为tuple

m1=np.array([0.1,0.2,0.3,0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,0.1, 0.1,0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=(0,1))print("============================")
print(m3.shape)
print(m3)


这里的tuple含义是,tuple的第一个值是指定m1的第几个维度,第二个值指定m2的第几个维度。
如这里的(0,1),0表示m1的shape(2,3)中的第0个维度,2;1表示m2的shape(3,2)中的第1个维度,2;
注意,这两个指定的维度值必须相等,否则报错。
另一方面可见,这里tuple只能两个元素.

计算维度大于2的情况

m1.shape=(4,1,3,2)
m2.shape=(2,3,1,2)

举例:

  • 如果使用np.tensordot(m1,m2, axes=(-1,0))使用的就是m1最后一维2,与m2第0维2
    输出的结果shape就是除掉这两位的串联,(4,1,3,3,1,2)
    结果计算与之前的算法类似,首先是将m1转换为 (413,2)的矩阵,m2转换为(2,312)的矩阵,然后两者矩阵相乘,结果再reshape为(4,1,3,3,1,2)

  • 如果使用np.tensordot(m1,m2, axes=(2,1))使用的就是m1第3维3,与m2第1维3
    输出的结果shape就是除掉这两位的串联,(4,1,2,2,1,2)
    这里的计算稍微多一步:
    axes的重新拼接

    对于m1

    去掉拿掉的那一维,其他维度“凑紧”后在最后再补上这一维,见下图示例

    所以重新拼接后的axes为 (0,1,3,2)

    对于m2


重新拼接后的axes为 (1,0,2,3)

拼接完成后,分别做transpose

m1.transpose(0,1,3,2)
m2.transpose(1,0,2,3)

然后m1 reshape为(4*1*2, 3), m2 reshape为(3,2*1*2)在做矩阵相乘。
相乘后的结果做reshape(4,1,2,2,1,2)


样例5 axis为list

m1=np.array([0.1,0.2,0.3,0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,0.1, 0.1,0.2,0.2])
m1=m1.reshape((2,3,1))
m2=m2.reshape((3,2,1))print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=[(1,0),(0,1)])print("============================")
print(m3.shape)
print(m3)

计算方式同tuple,只不过每个张量维度的选择变成了多个,list中有两个tuple。第一个tuple中指定了m1的维度,第二个tuple中指定了m2的维度

【numpy】tensordot的用法研究相关推荐

  1. (Python)numpy的argmax用法

    (Python)numpy的argmax用法 解释 还是从一维数组出发.看下面的例子. import numpy as np a = np.array([3, 1, 2, 4, 6, 1]) prin ...

  2. numpy.ix_的用法详解

    import numpy as np x=np.arange(32).reshape((8,4)) print (x[np.ix_([1,5,7,2],[0,3,1,2])]) x的矩阵是: [[ 0 ...

  3. NumPy之pad()用法

    Numpy之pad用法 1.函数说明 1.1.语法 1.2.参数解释 2.一.二维数组填充 2.1.一维数组填充 2.2.二维数组填充 3.三维数组填充 3.1.通道在前[CHW] 3.2.通道在后[ ...

  4. python Numpy 的基础用法以及 matplotlib 基础图形绘制

    python Numpy 的基础用法以及 matplotlib 基础图形绘制 1. 环境搭建 1.1 Anaconda ​ anaconda 集成了数据分析,科学计算相关的所有常用安装包,比如Numo ...

  5. python矩阵运算dot_矩阵、张量乘法(numpy.tensordot)的时间复杂度分析

    两个大小都是\(N \times N\)的矩阵相乘,如果使用naive的算法,时间复杂度应该是\(\mathcal{O}(N^3)\),如果使用一些高级的算法,可以使幂指数降到3以下.对于一般情况的矩 ...

  6. numpy.random.choice用法

    python,numpy中np.random.choice()的用法详解及其参考代码 处理数据时经常需要从数组中随机抽取元素,这时候就需要用到np.random.choice().然而choice用法 ...

  7. Python:一篇文章掌握Numpy的基本用法

    前言 Numpy是一个开源的Python科学计算库,它是python科学计算库的基础库,许多其他著名的科学计算库如Pandas,Scikit-learn等都要用到Numpy库的一些功能. 本文主要内容 ...

  8. 关于Python里的super用法研究

    转自:http://blog.csdn.net/johnsonguo/article/details/585193 虽然我现在没看懂,不过先转一个,以后有时间了再看. 一.问题的发现与提出 在Pyth ...

  9. numpy.cov()和numpy.var()的用法

    在PCA中涉及到了方差var和协方差cov,这里简单总结下. 首先:均值,样本方差,样本协方差的公式为 均值:X¯=1N∑Ni=1Xi\bar{X}=\frac{1}{N} \sum_{i=1}^{N ...

最新文章

  1. 先定一个小目标,自己封装个ajax
  2. easyui左侧导航菜单右侧载入百度地图项目框架
  3. Neko and Aki's Prank
  4. mybatis批量插入oracle报表达式,mybatis oracle两种方式批量插入数据
  5. 星辰大海:阿里数据体验技术揭秘!
  6. Atitit.操作注册表 树形数据库 注册表的历史 java版本类库总结
  7. java学习(48):带参带返回
  8. 线程基础知识_线程生命周期_从JVM内存结构看多线程下的共享资源
  9. 迅雷游戏盒子下载|迅雷游戏盒子下载
  10. 3399元起!120Hz瞳孔屏+65W超级闪充,一加 8T今日发布
  11. 做生意失败是一种什么体验?创业中有哪些雷区需要注意?
  12. Linux下MariaDB 安装及root密码设置(修改)
  13. 计算机平均值的快捷键,Excel用快捷键和选项求平均值,且能一次对多行多列批量快速求平均值...
  14. 实现Promise的resolve/reject/then/all/race/finally/catch方法
  15. Unity 3D课程总结
  16. 最近遇到使用Zing.DLL生成条码,但是打印出来不清晰的问题,解决代码记录一下,
  17. HDOJ 4239 - Decoding EDSAC Data 模拟
  18. R语言中的函数1:outer(张量积)
  19. 让div中的p标签文字垂直居中的方法
  20. 海龟如何保留米帝手机号

热门文章

  1. 校正光学系统像差原则
  2. Cell | 分子胶水的兴起
  3. Nat. Commun. | 深度学习探索可编程RNA开关
  4. JAVA连接SQL Server数据库的端口配置操作步骤
  5. 2019 年回顾:生物学年
  6. GEO数据挖掘(1)引出
  7. Nature 子刊:加州大学Banfield组揭示CPR细菌和DPANN古菌多样性及与低温TEM下宿主互作关系...
  8. WebMGA:超快的基因组序列聚类注释在线工具
  9. 回归模型和时间序列模型中的MAPE指标是什么?MAPE指标解读、MAPE越大越好还是越小越好、使用MAPE指标的注意事项
  10. Python使用openCV把原始彩色图像转化为灰度图、使用OpenCV把图像二值化(仅仅包含黑色和白色的简化版本)、基于自适应阈值预处理(adaptive thresholding)方法