【numpy】tensordot的用法研究
API文档
https://numpy.org/doc/stable/reference/generated/numpy.tensordot.html
文档说的过于晦涩,下面以实际例子来研究一下
样例1 axes=0 (二维为叉乘运算)
m1=np.array([0.1,0.2,0.3,0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,0.1, 0.1,0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=0)print("============================")
print(m3.shape)
print(m3)
运行结果
实际计算过程是
- m1 reshape为(6,1) m2 reshape为(1,6)
- m1 dot m2 得到m3(6,6)
- 再把m3reshape为两个输入矩阵shape的串联(2,3,3,2)
更简单的理解就是,把任意维度的两个张量m1(m1s1,m1s2,…m1sn),m2(m2s1,m2s2,…m2sn)压“扁”为一个(N,1)和(1,N)矩阵,然后两个矩阵相乘,再把这个结果矩阵(N,N)做reshape,reshape的结果为两个原始张量shape的串联,即为(m1s1,m1s2,…m1sn,m2s1,m2s2,…m2sn)
样例2 axes=1 (二维为点乘运算)
m1=np.array([0.1,0.2,0.3,0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,0.1, 0.1,0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))print(m1.shape)
print(m2.shape)
# 等同于np.dot(m1,m2)
m3 = np.tensordot(m1,m2, axes=1)print("============================")
print(m3.shape)
print(m3)
运行结果,它实际等同于np.dot(m1,m2)
与axes=0的区别在于,它要求m1的最后一维与m2的第一位必须“一致”,这类似于二维矩阵的点乘,相乘后的两个维度“消失”
样例3 axes=2 (default)
m1=np.array([0.1,0.2,0.3,0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,0.1, 0.1,0.2,0.2])
m1=m1.reshape((3,2))
m2=m2.reshape((3,2))print(m1.shape)
print(m2.shape)
# 注意这里的shape是一样的
m3 = np.tensordot(m1,m2, axes=2)print("============================")
print(m3.shape)
print(m3)
类似于axes=1,这里要求m1最后二个维度必须和m1的开始二个维度“一致”,其他维度任意。然后类似与做矩阵乘法,就是对axes=1做了+1扩展。
那么如果维度更高怎么控制呢? 这个就是后面采用tuple和list的用法了
样例4 axis为tuple
m1=np.array([0.1,0.2,0.3,0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,0.1, 0.1,0.2,0.2])
m1=m1.reshape((2,3))
m2=m2.reshape((3,2))print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=(0,1))print("============================")
print(m3.shape)
print(m3)
这里的tuple含义是,tuple的第一个值是指定m1的第几个维度,第二个值指定m2的第几个维度。
如这里的(0,1),0表示m1的shape(2,3)中的第0个维度,2;1表示m2的shape(3,2)中的第1个维度,2;
注意,这两个指定的维度值必须相等,否则报错。
另一方面可见,这里tuple只能两个元素.
计算维度大于2的情况
如
m1.shape=(4,1,3,2)
m2.shape=(2,3,1,2)
举例:
如果使用
np.tensordot(m1,m2, axes=(-1,0))
使用的就是m1最后一维2,与m2第0维2
输出的结果shape就是除掉这两位的串联,(4,1,3,3,1,2)
结果计算与之前的算法类似,首先是将m1转换为 (413,2)的矩阵,m2转换为(2,312)的矩阵,然后两者矩阵相乘,结果再reshape为(4,1,3,3,1,2)如果使用
np.tensordot(m1,m2, axes=(2,1))
使用的就是m1第3维3,与m2第1维3
输出的结果shape就是除掉这两位的串联,(4,1,2,2,1,2)
这里的计算稍微多一步:
axes的重新拼接对于m1
去掉拿掉的那一维,其他维度“凑紧”后在最后再补上这一维,见下图示例
所以重新拼接后的axes为 (0,1,3,2)对于m2
m1.transpose(0,1,3,2)
m2.transpose(1,0,2,3)
然后m1 reshape为(4*1*2, 3), m2 reshape为(3,2*1*2)在做矩阵相乘。
相乘后的结果做reshape(4,1,2,2,1,2)
样例5 axis为list
m1=np.array([0.1,0.2,0.3,0.4,0.5,0.6])
m2=np.array([-0.1, -0.1,0.1, 0.1,0.2,0.2])
m1=m1.reshape((2,3,1))
m2=m2.reshape((3,2,1))print(m1.shape)
print(m2.shape)
m3 = np.tensordot(m1,m2, axes=[(1,0),(0,1)])print("============================")
print(m3.shape)
print(m3)
计算方式同tuple,只不过每个张量维度的选择变成了多个,list中有两个tuple。第一个tuple中指定了m1的维度,第二个tuple中指定了m2的维度
【numpy】tensordot的用法研究相关推荐
- (Python)numpy的argmax用法
(Python)numpy的argmax用法 解释 还是从一维数组出发.看下面的例子. import numpy as np a = np.array([3, 1, 2, 4, 6, 1]) prin ...
- numpy.ix_的用法详解
import numpy as np x=np.arange(32).reshape((8,4)) print (x[np.ix_([1,5,7,2],[0,3,1,2])]) x的矩阵是: [[ 0 ...
- NumPy之pad()用法
Numpy之pad用法 1.函数说明 1.1.语法 1.2.参数解释 2.一.二维数组填充 2.1.一维数组填充 2.2.二维数组填充 3.三维数组填充 3.1.通道在前[CHW] 3.2.通道在后[ ...
- python Numpy 的基础用法以及 matplotlib 基础图形绘制
python Numpy 的基础用法以及 matplotlib 基础图形绘制 1. 环境搭建 1.1 Anaconda anaconda 集成了数据分析,科学计算相关的所有常用安装包,比如Numo ...
- python矩阵运算dot_矩阵、张量乘法(numpy.tensordot)的时间复杂度分析
两个大小都是\(N \times N\)的矩阵相乘,如果使用naive的算法,时间复杂度应该是\(\mathcal{O}(N^3)\),如果使用一些高级的算法,可以使幂指数降到3以下.对于一般情况的矩 ...
- numpy.random.choice用法
python,numpy中np.random.choice()的用法详解及其参考代码 处理数据时经常需要从数组中随机抽取元素,这时候就需要用到np.random.choice().然而choice用法 ...
- Python:一篇文章掌握Numpy的基本用法
前言 Numpy是一个开源的Python科学计算库,它是python科学计算库的基础库,许多其他著名的科学计算库如Pandas,Scikit-learn等都要用到Numpy库的一些功能. 本文主要内容 ...
- 关于Python里的super用法研究
转自:http://blog.csdn.net/johnsonguo/article/details/585193 虽然我现在没看懂,不过先转一个,以后有时间了再看. 一.问题的发现与提出 在Pyth ...
- numpy.cov()和numpy.var()的用法
在PCA中涉及到了方差var和协方差cov,这里简单总结下. 首先:均值,样本方差,样本协方差的公式为 均值:X¯=1N∑Ni=1Xi\bar{X}=\frac{1}{N} \sum_{i=1}^{N ...
最新文章
- 先定一个小目标,自己封装个ajax
- easyui左侧导航菜单右侧载入百度地图项目框架
- Neko and Aki's Prank
- mybatis批量插入oracle报表达式,mybatis oracle两种方式批量插入数据
- 星辰大海:阿里数据体验技术揭秘!
- Atitit.操作注册表 树形数据库 注册表的历史 java版本类库总结
- java学习(48):带参带返回
- 线程基础知识_线程生命周期_从JVM内存结构看多线程下的共享资源
- 迅雷游戏盒子下载|迅雷游戏盒子下载
- 3399元起!120Hz瞳孔屏+65W超级闪充,一加 8T今日发布
- 做生意失败是一种什么体验?创业中有哪些雷区需要注意?
- Linux下MariaDB 安装及root密码设置(修改)
- 计算机平均值的快捷键,Excel用快捷键和选项求平均值,且能一次对多行多列批量快速求平均值...
- 实现Promise的resolve/reject/then/all/race/finally/catch方法
- Unity 3D课程总结
- 最近遇到使用Zing.DLL生成条码,但是打印出来不清晰的问题,解决代码记录一下,
- HDOJ 4239 - Decoding EDSAC Data 模拟
- R语言中的函数1:outer(张量积)
- 让div中的p标签文字垂直居中的方法
- 海龟如何保留米帝手机号
热门文章
- 校正光学系统像差原则
- Cell | 分子胶水的兴起
- Nat. Commun. | 深度学习探索可编程RNA开关
- JAVA连接SQL Server数据库的端口配置操作步骤
- 2019 年回顾:生物学年
- GEO数据挖掘(1)引出
- Nature 子刊:加州大学Banfield组揭示CPR细菌和DPANN古菌多样性及与低温TEM下宿主互作关系...
- WebMGA:超快的基因组序列聚类注释在线工具
- 回归模型和时间序列模型中的MAPE指标是什么?MAPE指标解读、MAPE越大越好还是越小越好、使用MAPE指标的注意事项
- Python使用openCV把原始彩色图像转化为灰度图、使用OpenCV把图像二值化(仅仅包含黑色和白色的简化版本)、基于自适应阈值预处理(adaptive thresholding)方法