python矩阵运算dot_矩阵、张量乘法(numpy.tensordot)的时间复杂度分析
两个大小都是\(N \times N\)的矩阵相乘,如果使用naive的算法,时间复杂度应该是\(\mathcal{O}(N^3)\),如果使用一些高级的算法,可以使幂指数降到3以下。对于一般情况的矩阵乘法,特别是张量乘法(numpy中的tensordot函数),时间复杂度又如何呢?
二维矩阵乘法
首先规定一下记号:\(\mathbf{A}_{MN}\),表示一个有两个指标,大小是\(M\times N\)的矩阵\(\mathbf{A}\)。那么\(\mathbf{A}_{MN}\mathbf{B}_{NL}\)的时间复杂度是\(\mathcal{O}(MNL)\)。如果我们把乘法的过程用计算机语言表示出来,这一结论就会非常清晰:
1
2
3
4
5C = np.zeros((M, L))
for m in range(M):
for l in range(L):
for n in range(N):
C[m][l] += A[m][n] * B[n][l]
我们也可以简单地验证一下numpy.dot函数是否满足这样的时间复杂度,首先变化\(M\)。为了节省篇幅,一次将其扩大到四倍:
1
2
3
4
5
6
7
8M = 71
N = 513
L = 4097
for i in range(5):
m1 = np.random.random((M, N))
m2 = np.random.random((N, L))
%timeit m1.dot(m2)
M *= 4
输出是:
1
2
3
4
5100loops, best of 3: 6.82 ms per loop
10loops, best of 3: 22.5 ms per loop
10loops, best of 3: 77.5 ms per loop
1loop, best of 3: 304 ms per loop
1loop, best of 3: 1.38 s per loop
可见基本是线性的(耗时一次扩大到四倍)。然后变化\(N\),代码和上面的一段只变了一个字母,输出是:
1
2
3
4
5100loops, best of 3: 6.79 ms per loop
10loops, best of 3: 22.1 ms per loop
10loops, best of 3: 84.4 ms per loop
1loop, best of 3: 329 ms per loop
1loop, best of 3: 1.31 s per loop
仍然基本是线性的。最后变化\(L\),输出是:
1
2
3
4
5100loops, best of 3: 8.42 ms per loop
10loops, best of 3: 43.5 ms per loop
10loops, best of 3: 115 ms per loop
1loop, best of 3: 408 ms per loop
1loop, best of 3: 1.88 s per loop
耗时是三组实验中最长的。结果汇总起来如下图
不难发现,时间与矩阵维度的关系是线性的且斜率为1,所以\(\mathbf{A}_{MN}\mathbf{B}_{NL}\)的时间复杂度是\(\mathcal{O}(MNL)\)。
高维矩阵(张量)乘法-只对一个轴求和
在numpy中dot,einsum,tensordot等函数都可以做高维矩阵乘法,这里只研究最常见的tensordot。我们从\(\mathbf{A}_{MNL}\mathbf{B}_{LPQ}\)这样一个例子入手。从理论上分析,\(\mathbf{A}_{MNL}\mathbf{B}_{LPQ}\)的时间复杂度是\(\mathcal{O}(MNLPQ)\),感兴趣的读者可以自己写写代码分析,或者看一看我之前写的一篇博文。这里简单做一下实验,变化\(M\):
1
2
3
4
5
6
7
8
9
10M = 63
N = 17
L = 255
P = 127
Q = 31
for i in range(5):
m1 = np.random.random((M, N, L))
m2 = np.random.random((L, P, Q))
%timeit np.tensordot(m1, m2, 1)
M *= 4
输出是:
1
2
3
4
510loops, best of 3: 47.6 ms per loop
1loop, best of 3: 166 ms per loop
1loop, best of 3: 700 ms per loop
1loop, best of 3: 2.7 s per loop
1loop, best of 3: 11.5 s per loop
而变化\(L\)输出是:
1
2
3
4
510loops, best of 3: 46.3 ms per loop
10loops, best of 3: 116 ms per loop
1loop, best of 3: 368 ms per loop
1loop, best of 3: 1.52 s per loop
1loop, best of 3: 6 s per loop
如图所示:
类似地,耗时与\(M\)和\(L\)都是线性关系,后者速度貌似比前者略快。
高维矩阵(张量)乘法-对多个轴求和
下面我们再考虑对多个轴求和的情况,这种情况下“数学语言”已经不好给出清晰的描述了。如果想举个例子,也只能啰嗦地说:\(\mathbf{A}_{MNL}\)和\(\mathbf{B}_{NLP}\)之间进行双点积contract掉维数为\(N\)和\(L\)的两个指标。倒是计算机语言还算游刃有余:
1
2
3
4
5
6C = np.zeros((M, P))
for m in range(M):
for p in range(P):
for n in range(N):
for l in range(L):
C[m][p] += A[m][n][l] * B[n][l][p]
也容易据此估计出时间复杂度为\(\mathcal{O}(MNLP)\)。实验一下的话,首先试试\(M\):
1
2
3
4
5
6
7
8
9M = 63
N = 31
L = 255
P = 127
for i in range(5):
m1 = np.random.random((M, N, L))
m2 = np.random.random((N, L, P))
%timeit np.tensordot(m1, m2, 2)
M *= 4
输出为:
1
2
3
4
5100loops, best of 3: 2.41 ms per loop
100loops, best of 3: 5.8 ms per loop
10loops, best of 3: 23.2 ms per loop
10loops, best of 3: 171 ms per loop
1loop, best of 3: 817 ms per loop
然后\(N\)和\(L\)分别为:
1
2
3
4
5100loops, best of 3: 2.43 ms per loop
100loops, best of 3: 8.69 ms per loop
10loops, best of 3: 33.7 ms per loop
10loops, best of 3: 138 ms per loop
1loop, best of 3: 560 ms per loop
和
1
2
3
4
5100loops, best of 3: 2.69 ms per loop
100loops, best of 3: 9.01 ms per loop
10loops, best of 3: 36.2 ms per loop
10loops, best of 3: 140 ms per loop
1loop, best of 3: 563 ms per loop
总结起来如图所示:
结语
总结规律的话,要想知道矩阵、张量乘法的时间复杂度,就把两个矩阵、张量所有没contract掉的维度乘起来,再把contract掉的维度两个取一个乘起来即可。举个例子:\(\mathbf{A}_{MNL}\mathbf{B}_{LPQ}\),没有contract掉的维度乘起来即\(NMPQ\),contract掉的维度有两个\(L\),只取一个,最后合起来就是\(\mathcal{O}(MNLPQ)\)。
这一规律其实很好理解。np.tensordot在实现时实际上是对普通的np.dot的一个包装,进行了一些前处理和后处理。所谓前处理,基本上就是通过转置和合并(np.reshape)把两个参与运算的高阶张量分别变成矩阵,其中一个指标是原张量所有没contract掉的指标组成的,维度自然就是这些指标的维度的积,而另一个指标是原张量要进行contract的指标组成的,维度也是这些指标的维度的积。而后处理,就是将np.dot之后的结果再通过np.reshape变回原来的形状。np.tensordot的代码位于numpy/core/numeric.py中,核心部分如下图所示(NumPy 1.15):
1
2
3
4at = a.transpose(newaxes_a).reshape(newshape_a)
bt = b.transpose(newaxes_b).reshape(newshape_b)
res = dot(at, bt)
return res.reshape(olda + oldb)
其中a和b是调用者传入的要进行tensordot的矩阵,newaxes_a等参数是根据调用者指定的contract规则确定的用于将a或者b变形为适合进行np.dot的参数。得到变形后的at和bt后直接进行dot,再将中间结果reshape回去就得到了最终的结果。所以张量乘法的时间复杂度与矩阵乘法的时间复杂度其实是一回事。
python矩阵运算dot_矩阵、张量乘法(numpy.tensordot)的时间复杂度分析相关推荐
- python矩阵运算_Python矩阵常见运算操作实例总结
本文实例讲述了Python矩阵常见运算操作.分享给大家供大家参考,具体如下: python的numpy库提供矩阵运算的功能,因此我们在需要矩阵运算的时候,需要导入numpy的包. 一.numpy的导入 ...
- python生成魔方矩阵
python生成魔方矩阵 import numpy as npdef magic(n):row, col = 0, n // 2magic = []for i in range(n):magic.ap ...
- python如何创建不同元素的矩阵_Python numpy学习(2)——矩阵的用法
Python矩阵的基本用法 mat()函数将目标数据的类型转化成矩阵(matrix) 1,mat()函数和array()函数的区别 Numpy函数库中存在两种不同的数据类型(矩阵matrix和数组ar ...
- Python中矩阵库Numpy基本操作
NumPy是一个关于矩阵运算的库,熟悉Matlab的都应该清楚,这个库就是让python能够进行矩阵话的操作,而不用去写循环操作. 下面对numpy中的操作进行总结. numpy包含两种基本的数据类 ...
- python中math函数库矩阵_Python中矩阵库Numpy基本操作详解
NumPy是一个关于矩阵运算的库,熟悉Matlab的都应该清楚,这个库就是让python能够进行矩阵话的操作,而不用去写循环操作. 下面对numpy中的操作进行总结. numpy包含两种基本的数据类型 ...
- Python: 向量、矩阵和多维数组(基于NumPy库)
参考文章: 数值 Python: 向量.矩阵和多维数组 Numpy 中的矩阵向量乘法 对NumPy中dot()函数的理解 np.random.rand()函数 numpy.array函数详解 nump ...
- 怎么在python中输入矩阵_如何使用NumPy在Python中实现矩阵?
矩阵被用作数学工具,在现实世界中有多种用途.在本文中,我们将按照以下顺序讨论Python中关于使用著名的NumPy库的矩阵的所有内容:什么是NumPy以及何时使用它?在NumPy 矩阵被用作数学工具, ...
- python矩阵运算实例_Python矩阵常见运算操作实例总结 python 怎么实现矩阵运算
python 怎么查看一个矩阵的维数你是知道的,等你,我已经栖息了疲惫的憧憬,夜夜抚慰残梦的翅膀. 都是复制党,百度知道回答真的质量太低了,真的很心疼,言归正传 利用numpy分享矩阵维数: impo ...
- python矩阵运算实例_Python矩阵常见运算操作实例总结
本文实例讲述了python矩阵常见运算操作.分享给大家供大家参考,具体如下: python的numpy库提供矩阵运算的功能,因此我们在需要矩阵运算的时候,需要导入numpy的包. 一.numpy的导入 ...
最新文章
- 智能+制造,聪明的公司都走上了智能制造的道路
- ubuntu 安装 codelite
- 无线节能组的充电问题
- mysql题目(二学年)
- 长春成人计算机学校有哪些专业学校,长春成人高考学校有哪些
- Java 14 发布了,再也不怕NullPointerException 了!?
- 只要你能想明白一个道理,你也可以在互联网上赚到属于自己的钱
- OpenCV学习(14) 细化算法(2)
- Xmodem、Ymodem和Zmodem协议是最常用的三种通信协议
- 详解CAN 2.0协议
- CHEMKIN III 学习笔记
- 计算机专硕毕业论文写什么,关于学姐写硕士毕业论文的一些经验,分享给大家...
- 国外常见16款著名的实时网站统计系统
- C语言文件重定向---“系统找不到指定的文件”
- c++基础三 (数组——指针)
- Django学习-app创建与注册
- 桌面计算机图标无响应,win7系统电脑鼠标点击桌面图标没反应怎么办【图文】...
- 拖拽功能之水平拖动图片
- python-简单用户登录注册界面实现
- android ----- goldfish内核编译