两个大小都是\(N \times N\)的矩阵相乘,如果使用naive的算法,时间复杂度应该是\(\mathcal{O}(N^3)\),如果使用一些高级的算法,可以使幂指数降到3以下。对于一般情况的矩阵乘法,特别是张量乘法(numpy中的tensordot函数),时间复杂度又如何呢?

二维矩阵乘法

首先规定一下记号:\(\mathbf{A}_{MN}\),表示一个有两个指标,大小是\(M\times N\)的矩阵\(\mathbf{A}\)。那么\(\mathbf{A}_{MN}\mathbf{B}_{NL}\)的时间复杂度是\(\mathcal{O}(MNL)\)。如果我们把乘法的过程用计算机语言表示出来,这一结论就会非常清晰:

1

2

3

4

5C = np.zeros((M, L))

for m in range(M):

for l in range(L):

for n in range(N):

C[m][l] += A[m][n] * B[n][l]

我们也可以简单地验证一下numpy.dot函数是否满足这样的时间复杂度,首先变化\(M\)。为了节省篇幅,一次将其扩大到四倍:

1

2

3

4

5

6

7

8M = 71

N = 513

L = 4097

for i in range(5):

m1 = np.random.random((M, N))

m2 = np.random.random((N, L))

%timeit m1.dot(m2)

M *= 4

输出是:

1

2

3

4

5100loops, best of 3: 6.82 ms per loop

10loops, best of 3: 22.5 ms per loop

10loops, best of 3: 77.5 ms per loop

1loop, best of 3: 304 ms per loop

1loop, best of 3: 1.38 s per loop

可见基本是线性的(耗时一次扩大到四倍)。然后变化\(N\),代码和上面的一段只变了一个字母,输出是:

1

2

3

4

5100loops, best of 3: 6.79 ms per loop

10loops, best of 3: 22.1 ms per loop

10loops, best of 3: 84.4 ms per loop

1loop, best of 3: 329 ms per loop

1loop, best of 3: 1.31 s per loop

仍然基本是线性的。最后变化\(L\),输出是:

1

2

3

4

5100loops, best of 3: 8.42 ms per loop

10loops, best of 3: 43.5 ms per loop

10loops, best of 3: 115 ms per loop

1loop, best of 3: 408 ms per loop

1loop, best of 3: 1.88 s per loop

耗时是三组实验中最长的。结果汇总起来如下图

不难发现,时间与矩阵维度的关系是线性的且斜率为1,所以\(\mathbf{A}_{MN}\mathbf{B}_{NL}\)的时间复杂度是\(\mathcal{O}(MNL)\)。

高维矩阵(张量)乘法-只对一个轴求和

在numpy中dot,einsum,tensordot等函数都可以做高维矩阵乘法,这里只研究最常见的tensordot。我们从\(\mathbf{A}_{MNL}\mathbf{B}_{LPQ}\)这样一个例子入手。从理论上分析,\(\mathbf{A}_{MNL}\mathbf{B}_{LPQ}\)的时间复杂度是\(\mathcal{O}(MNLPQ)\),感兴趣的读者可以自己写写代码分析,或者看一看我之前写的一篇博文。这里简单做一下实验,变化\(M\):

1

2

3

4

5

6

7

8

9

10M = 63

N = 17

L = 255

P = 127

Q = 31

for i in range(5):

m1 = np.random.random((M, N, L))

m2 = np.random.random((L, P, Q))

%timeit np.tensordot(m1, m2, 1)

M *= 4

输出是:

1

2

3

4

510loops, best of 3: 47.6 ms per loop

1loop, best of 3: 166 ms per loop

1loop, best of 3: 700 ms per loop

1loop, best of 3: 2.7 s per loop

1loop, best of 3: 11.5 s per loop

而变化\(L\)输出是:

1

2

3

4

510loops, best of 3: 46.3 ms per loop

10loops, best of 3: 116 ms per loop

1loop, best of 3: 368 ms per loop

1loop, best of 3: 1.52 s per loop

1loop, best of 3: 6 s per loop

如图所示:

类似地,耗时与\(M\)和\(L\)都是线性关系,后者速度貌似比前者略快。

高维矩阵(张量)乘法-对多个轴求和

下面我们再考虑对多个轴求和的情况,这种情况下“数学语言”已经不好给出清晰的描述了。如果想举个例子,也只能啰嗦地说:\(\mathbf{A}_{MNL}\)和\(\mathbf{B}_{NLP}\)之间进行双点积contract掉维数为\(N\)和\(L\)的两个指标。倒是计算机语言还算游刃有余:

1

2

3

4

5

6C = np.zeros((M, P))

for m in range(M):

for p in range(P):

for n in range(N):

for l in range(L):

C[m][p] += A[m][n][l] * B[n][l][p]

也容易据此估计出时间复杂度为\(\mathcal{O}(MNLP)\)。实验一下的话,首先试试\(M\):

1

2

3

4

5

6

7

8

9M = 63

N = 31

L = 255

P = 127

for i in range(5):

m1 = np.random.random((M, N, L))

m2 = np.random.random((N, L, P))

%timeit np.tensordot(m1, m2, 2)

M *= 4

输出为:

1

2

3

4

5100loops, best of 3: 2.41 ms per loop

100loops, best of 3: 5.8 ms per loop

10loops, best of 3: 23.2 ms per loop

10loops, best of 3: 171 ms per loop

1loop, best of 3: 817 ms per loop

然后\(N\)和\(L\)分别为:

1

2

3

4

5100loops, best of 3: 2.43 ms per loop

100loops, best of 3: 8.69 ms per loop

10loops, best of 3: 33.7 ms per loop

10loops, best of 3: 138 ms per loop

1loop, best of 3: 560 ms per loop

1

2

3

4

5100loops, best of 3: 2.69 ms per loop

100loops, best of 3: 9.01 ms per loop

10loops, best of 3: 36.2 ms per loop

10loops, best of 3: 140 ms per loop

1loop, best of 3: 563 ms per loop

总结起来如图所示:

结语

总结规律的话,要想知道矩阵、张量乘法的时间复杂度,就把两个矩阵、张量所有没contract掉的维度乘起来,再把contract掉的维度两个取一个乘起来即可。举个例子:\(\mathbf{A}_{MNL}\mathbf{B}_{LPQ}\),没有contract掉的维度乘起来即\(NMPQ\),contract掉的维度有两个\(L\),只取一个,最后合起来就是\(\mathcal{O}(MNLPQ)\)。

这一规律其实很好理解。np.tensordot在实现时实际上是对普通的np.dot的一个包装,进行了一些前处理和后处理。所谓前处理,基本上就是通过转置和合并(np.reshape)把两个参与运算的高阶张量分别变成矩阵,其中一个指标是原张量所有没contract掉的指标组成的,维度自然就是这些指标的维度的积,而另一个指标是原张量要进行contract的指标组成的,维度也是这些指标的维度的积。而后处理,就是将np.dot之后的结果再通过np.reshape变回原来的形状。np.tensordot的代码位于numpy/core/numeric.py中,核心部分如下图所示(NumPy 1.15):

1

2

3

4at = a.transpose(newaxes_a).reshape(newshape_a)

bt = b.transpose(newaxes_b).reshape(newshape_b)

res = dot(at, bt)

return res.reshape(olda + oldb)

其中a和b是调用者传入的要进行tensordot的矩阵,newaxes_a等参数是根据调用者指定的contract规则确定的用于将a或者b变形为适合进行np.dot的参数。得到变形后的at和bt后直接进行dot,再将中间结果reshape回去就得到了最终的结果。所以张量乘法的时间复杂度与矩阵乘法的时间复杂度其实是一回事。

python矩阵运算dot_矩阵、张量乘法(numpy.tensordot)的时间复杂度分析相关推荐

  1. python矩阵运算_Python矩阵常见运算操作实例总结

    本文实例讲述了Python矩阵常见运算操作.分享给大家供大家参考,具体如下: python的numpy库提供矩阵运算的功能,因此我们在需要矩阵运算的时候,需要导入numpy的包. 一.numpy的导入 ...

  2. python生成魔方矩阵

    python生成魔方矩阵 import numpy as npdef magic(n):row, col = 0, n // 2magic = []for i in range(n):magic.ap ...

  3. python如何创建不同元素的矩阵_Python numpy学习(2)——矩阵的用法

    Python矩阵的基本用法 mat()函数将目标数据的类型转化成矩阵(matrix) 1,mat()函数和array()函数的区别 Numpy函数库中存在两种不同的数据类型(矩阵matrix和数组ar ...

  4. Python中矩阵库Numpy基本操作

    NumPy是一个关于矩阵运算的库,熟悉Matlab的都应该清楚,这个库就是让python能够进行矩阵话的操作,而不用去写循环操作. 下面对numpy中的操作进行总结.  numpy包含两种基本的数据类 ...

  5. python中math函数库矩阵_Python中矩阵库Numpy基本操作详解

    NumPy是一个关于矩阵运算的库,熟悉Matlab的都应该清楚,这个库就是让python能够进行矩阵话的操作,而不用去写循环操作. 下面对numpy中的操作进行总结. numpy包含两种基本的数据类型 ...

  6. Python: 向量、矩阵和多维数组(基于NumPy库)

    参考文章: 数值 Python: 向量.矩阵和多维数组 Numpy 中的矩阵向量乘法 对NumPy中dot()函数的理解 np.random.rand()函数 numpy.array函数详解 nump ...

  7. 怎么在python中输入矩阵_如何使用NumPy在Python中实现矩阵?

    矩阵被用作数学工具,在现实世界中有多种用途.在本文中,我们将按照以下顺序讨论Python中关于使用著名的NumPy库的矩阵的所有内容:什么是NumPy以及何时使用它?在NumPy 矩阵被用作数学工具, ...

  8. python矩阵运算实例_Python矩阵常见运算操作实例总结 python 怎么实现矩阵运算

    python 怎么查看一个矩阵的维数你是知道的,等你,我已经栖息了疲惫的憧憬,夜夜抚慰残梦的翅膀. 都是复制党,百度知道回答真的质量太低了,真的很心疼,言归正传 利用numpy分享矩阵维数: impo ...

  9. python矩阵运算实例_Python矩阵常见运算操作实例总结

    本文实例讲述了python矩阵常见运算操作.分享给大家供大家参考,具体如下: python的numpy库提供矩阵运算的功能,因此我们在需要矩阵运算的时候,需要导入numpy的包. 一.numpy的导入 ...

最新文章

  1. 智能+制造,聪明的公司都走上了智能制造的道路
  2. ubuntu 安装 codelite
  3. 无线节能组的充电问题
  4. mysql题目(二学年)
  5. 长春成人计算机学校有哪些专业学校,长春成人高考学校有哪些
  6. Java 14 发布了,再也不怕NullPointerException 了!?
  7. 只要你能想明白一个道理,你也可以在互联网上赚到属于自己的钱
  8. OpenCV学习(14) 细化算法(2)
  9. Xmodem、Ymodem和Zmodem协议是最常用的三种通信协议
  10. 详解CAN 2.0协议
  11. CHEMKIN III 学习笔记
  12. 计算机专硕毕业论文写什么,关于学姐写硕士毕业论文的一些经验,分享给大家...
  13. 国外常见16款著名的实时网站统计系统
  14. C语言文件重定向---“系统找不到指定的文件”
  15. c++基础三 (数组——指针)
  16. Django学习-app创建与注册
  17. 桌面计算机图标无响应,win7系统电脑鼠标点击桌面图标没反应怎么办【图文】...
  18. 拖拽功能之水平拖动图片
  19. python-简单用户登录注册界面实现
  20. android ----- goldfish内核编译

热门文章

  1. 双11怎么那么强!之二:浅析淘宝网络通信库tbnet的实现
  2. 疫情当前,宅家学习不无聊,AI视频课程资源盘点
  3. 【python 学习】知识点日记
  4. sonarqube下载地址
  5. maven报错: 错误的类文件:… 类文件具有错误的版本 52.0,应为 54.0
  6. JVM调优:heap dump信息分析
  7. scala reduceLeft和reduceRight执行分析
  8. redis hash数据类型常用命令
  9. CPU乱序执行(指令重排序)
  10. Spring Cloud构建微服务架构:消息驱动的微服务(入门)【Dalston版】