点击上方“视学算法”,选择加"星标"或“置顶

重磅干货,第一时间送达

作者 | Michael Yuan@知乎(已授权)
来源 | https://zhuanlan.zhihu.com/p/31575074

编辑丨极市平台

导读

在梳理CNN经典模型的过程中,作者理解到其实经典模型演进中的很多创新点都与改善模型计算复杂度紧密相关,因此今天就让我们对卷积神经网络的复杂度分析简单总结一下。

在梳理CNN经典模型的过程中,我理解到其实经典模型演进中的很多创新点都与改善模型计算复杂度紧密相关,因此今天就让我们对卷积神经网络的复杂度分析简单总结一下下。

本文主要关注的是针对模型本身的复杂度分析(其实并不是很复杂啦~)。如果想要进一步评估模型在计算平台上的理论计算性能,则需要了解 Roofline Model 的相关理论,欢迎阅读本文的进阶版: Roofline Model与深度学习模型的性能分析。( 链接:https://zhuanlan.zhihu.com/p/34204282 )

“复杂度分析”其实没有那么复杂啦~

时间复杂度

即模型的运算次数,可用衡量,也就是浮点运算次数(FLoating-point OPerations)。

1.1 单个卷积层的时间复杂度

  • 每个卷积核输出特征图  的边长

  • 每个卷积核 的边长

  • 每个卷积核的通道数,也即输入通道数,也即上一层的输出通道数。

  • 本卷积层具有的卷积核个数,也即输出通道数。

  • 可见,每个卷积层的时间复杂度由输出特征图面积 、卷积核面、输入 和输出通道数完全决定。

  • 其中,输出特征图尺寸本身又由输入矩阵尺寸 、卷积核尺寸 、这四个参数所决定,表示如下:

  • 注1:为了简化表达式中的变量个数,这里统一假设输入和卷积核的形状都是正方形。

  • 注2:严格来讲每层应该还包含 1 个 参数,这里为了简洁就省略了。

1.2 卷积神经网络整体的时间复杂度

  • 神经网络所具有的卷积层数,也即网络的深度。

  • 神经网络第个卷积层

  • 神经网络第个卷积层的输出通道数,也即该层的卷积核个数。

  • 对于第个卷积层而言,其输入通道数就是第个卷积层的输出通道数。

  • 可见,CNN整体的时间复杂度并不神秘,只是所有卷积层的时间复杂度累加而已。

  • 简而言之,层内连乘,层间累加。

示例:用 Numpy 手动简单实现二维卷积

假设 Stride = 1, Padding = 0, img 和 kernel 都是 np.ndarray.

def conv2d(img, kernel):height, width, in_channels = img.shapekernel_height, kernel_width, in_channels, out_channels = kernel.shapeout_height = height - kernel_height + 1out_width = width - kernel_width + 1feature_maps = np.zeros(shape=(out_height, out_width, out_channels))for oc in range(out_channels):              # Iterate out_channels (# of kernels)for h in range(out_height):             # Iterate out_heightfor w in range(out_width):          # Iterate out_widthfor ic in range(in_channels):   # Iterate in_channelspatch = img[h: h + kernel_height, w: w + kernel_width, ic]feature_maps[h, w, oc] += np.sum(patch * kernel[:, :, ic, oc])return feature_maps

空间复杂度

空间复杂度(访存量),严格来讲包括两部分:总参数量 + 各层输出特征图。

  • 参数量:模型所有带参数的层的权重参数总量(即模型体积,下式第一个求和表达式)

  • 特征图:模型在实时运行过程中每层所计算出的输出特征图大小(下式第二个求和表达式)

  • 总参数量只与卷积核的尺寸 、通道数、层数相关,而与输入数据的大小无关。

  • 输出特征图的空间占用比较容易,就是其空间尺寸和通道数的连乘。

  • 注:实际上有些层(例如 ReLU)其实是可以通过原位运算完成的,此时就不用统计输出特征图这一项了。

复杂度对模型的影响

  • 时间复杂度决定了模型的训练/预测时间。如果复杂度过高,则会导致模型训练和预测耗费大量时间,既无法快速的验证想法和改善模型,也无法做到快速的预测。

  • 空间复杂度决定了模型的参数数量。由于维度诅咒的限制,模型的参数越多,训练模型所需的数据量就越大,而现实生活中的数据集通常不会太大,这会导致模型的训练更容易过拟合。

  • 当我们需要裁剪模型时,由于卷积核的空间尺寸通常已经很小(3x3),而网络的深度又与模型的表征能力紧密相关,不宜过多削减,因此模型裁剪通常最先下手的地方就是通道数。

Inception 系列模型是如何优化复杂度的

通过五个小例子说明模型的演进过程中是如何优化复杂度的。

4.1 中的卷积降维同时优化时间复杂度和空间复杂度

(图像被压缩的惨不忍睹...)
  • InceptionV1 借鉴了 Network in Network 的思想,在一个 Inception Module 中构造了四个并行的不同尺寸的卷积/池化模块(上图左),有效的提升了网络的宽度。但是这么做也造成了网络的时间和空间复杂度的激增。对策就是添加 1 x 1 卷积(上图右红色模块)将输入通道数先降到一个较低的值,再进行真正的卷积。

  • 以 InceptionV1 论文中的 (3b) 模块为例(可以点击上图看超级精美的大图),输入尺寸为, 卷积核  个,卷积核  个,卷积核  个,卷积核一律采用 Same Padding 确保输出不改变尺寸。

  • 卷积分支上加入卷积前后的时间复杂度对比如下式:

  • 同理,在卷积分支上加入卷积前后的时间复杂度对比如下式:

  • 可见,使用卷积降维可以降低时间复杂度3倍以上。该层完整的运算量可以在论文中查到,为 300 M,即

  • 另一方面,我们同样可以简单分析一下这一层参数量在使用 1 x 1 卷积前后的变化。可以看到,由于 1 x 1 卷积的添加,3 x 3 和 5 x 5 卷积核的参数量得以降低 4 倍,因此本层的参数量从 1000 K 降低到 300 K 左右。

4.2 中使用  代替

  • 全连接层可以视为一种特殊的卷积层,其卷积核尺寸  与输入矩阵尺寸 一模一样。每个卷积核的输出特征图是一个标量点,即 。复杂度分析如下:

  • 可见,与真正的卷积层不同,全连接层的空间复杂度与输入数据的尺寸密切相关。因此如果输入图像尺寸越大,模型的体积也就会越大,这显然是不可接受的。例如早期的VGG系列模型,其 90% 的参数都耗费在全连接层上。

  • InceptionV1 中使用的全局平均池化 GAP 改善了这个问题。由于每个卷积核输出的特征图在经过全局平均池化后都会直接精炼成一个标量点,因此全连接层的复杂度不再与输入图像尺寸有关,运算量和参数数量都得以大规模削减。复杂度分析如下:

4.3 中使用两个卷积级联替代卷积分支

感受野不变
  • 根据上面提到的二维卷积输入输出尺寸关系公式,可知:对于同一个输入尺寸,单个卷积的输出与两个卷积级联输出的尺寸完全一样,即感受野相同。

  • 同样根据上面提到的复杂度分析公式,可知:这种替换能够非常有效的降低时间和空间复杂度。我们可以把辛辛苦苦省出来的这些复杂度用来提升模型的深度和宽度,使得我们的模型能够在复杂度不变的前提下,具有更大的容量,爽爽的。

  • 同样以 InceptionV1 里的 (3b) 模块为例,替换前后的卷积分支复杂度如下:

4.4 中使用  与卷积级联替代 卷积

  • InceptionV3 中提出了卷积的 Factorization,在确保感受野不变的前提下进一步简化。

  • 复杂度的改善同理可得,不再赘述。

4.5 中使用 

  • 我们之前讨论的都是标准卷积运算,每个卷积核都对输入的所有通道进行卷积。

  • Xception 模型挑战了这个思维定势,它让每个卷积核只负责输入的某一个通道,这就是所谓的 Depth-wise Separable Convolution。

  • 从输入通道的视角看,标准卷积中每个输入通道都会被所有卷积核蹂躏一遍,而 Xception 中每个输入通道只会被对应的一个卷积核扫描,降低了模型的冗余度。

  • 标准卷积与可分离卷积的时间复杂度对比:可以看到本质上是把连乘转化成为相加。

总结

通过上面的推导和经典模型的案例分析,我们可以清楚的看到其实很多创新点都是围绕模型复杂度的优化展开的,其基本逻辑就是乘变加。模型的优化换来了更少的运算次数和更少的参数数量,一方面促使我们能够构建更轻更快的模型(例如MobileNet),一方面促使我们能够构建更深更宽的网络(例如Xception),提升模型的容量,打败各种大怪兽,欧耶~

参考论文

  • https://arxiv.org/abs/1412.1710

  • https://arxiv.org/abs/1409.4842

  • https://arxiv.org/abs/1502.03167

  • https://arxiv.org/abs/1512.00567

  • https://arxiv.org/abs/1610.02357

注:本文主要关注的是针对模型本身的复杂度分析。如果想要进一步评估模型在计算平台上的理论计算性能,则需要了解 Roofline Model 的相关理论,欢迎阅读本文的进阶版: Roofline Model与深度学习模型的性能分析。(文章链接:https://zhuanlan.zhihu.com/p/34204282)

如果觉得有用,就请分享到朋友圈吧!

点个在看 paper不断!

卷积神经网络的复杂度分析相关推荐

  1. 【 卷积神经网络CNN 数学原理分析与源码详解 深度学习 Pytorch笔记 B站刘二大人(9/10)】

    卷积神经网络CNN 数学原理分析与源码详解 深度学习 Pytorch笔记 B站刘二大人(9/10) 本章主要进行卷积神经网络的相关数学原理和pytorch的对应模块进行推导分析 代码也是通过demo实 ...

  2. 用随机场理论和卷积神经网络(CNN)分析边坡可靠度

    近期,深度学习变得越来越火热,用深度学习跨领域研究一些其他领域的问题可以开拓研究思路.比如说用深度学习中的人工神经网络(ANNs)和卷积神经网络(CNN)去分析边坡可靠度的问题.边坡可靠度的研究是在假 ...

  3. 卷积神经网络(CNN)经典模型分析(一)

    CNN经典模型分析

  4. [深度学习-实战篇]情感分析之卷积神经网络-TextCNN,包含代码

    0. 前言 在"卷积神经网络"中我们探究了如何使用二维卷积神经网络来处理二维图像数据.在之前的语言模型和文本分类任务中,我们将文本数据看作是只有一个维度的时间序列,并很自然地使用循 ...

  5. 4. 卷积神经网络CNN

    文章目录 4. 卷积神经网络CNN 4.1 概念 4.1.1 概念 4.1.2 用途 4.2 结构介绍 4.2.1 结构简介 4.2.2 卷积层 1) 基本概念 2) 前期准备 3) 参数共享 4) ...

  6. 【华为云技术分享】序列特征的处理方法之二:基于卷积神经网络方法

    [摘要] 本文介绍了针对序列特征采用的处理方法之二:基于卷积神经网络方法,并分析了为何卷积神经网络擅长对于局部特征的提取. 前言 上一篇文章介绍了基本的基于注意力机制方法对序列特征的处理,这篇主要介绍 ...

  7. 解读:基于图卷积特征的卷积神经网络的股票趋势预测(文末赠书)

    写在前面 下面这篇文章的内容主要是来自2020年发表于Information Science 的一篇文章<A novel graph convolutional feature based co ...

  8. 【从线性回归到 卷积神经网络CNN 循环神经网络RNN Pytorch 学习笔记 目录整合 源码解读 B站刘二大人 绪论(0/10)】

    深度学习 Pytorch 学习笔记 目录整合 数学推导与源码详解 B站刘二大人 目录传送门: 线性模型 Linear-Model 数学原理分析以及源码详解 深度学习 Pytorch笔记 B站刘二大人( ...

  9. 【卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10)】

    卷积神经网络CNN 实战案例 GoogleNet 实现手写数字识别 源码详解 深度学习 Pytorch笔记 B站刘二大人 (9.5/10) 在上一章已经完成了卷积神经网络的结构分析,并通过各个模块理解 ...

最新文章

  1. 数组之间的计算matlab,MATLAB软件数组的运算
  2. fft谱分析的误差有哪些原因造成的?如何减小分析误差。_回归分析 | 闯荡数据江湖的武功秘籍...
  3. java redirect 跨域_碰到了跨域问题, Redirect is not allowed for a preflight request
  4. error LNK2005: _DllMain@12 already defined的解决办法
  5. 华谊兄弟出现什么问题_什么是语言训练?这就要从语言问题的出现说起了
  6. 跳过 centos部署 webpy的各种坑
  7. JAVA获取别人发过来的json字符串(Post方式)
  8. html模拟右键系统菜单,HTML中自定义右键菜单功能
  9. 一个到顶部自动加载更多的ListView
  10. UDP协议和socketserver以及文件上传
  11. MySQL二进制日志文件格式
  12. 2020宁波银行终面一分钟抽词演讲
  13. H3C华三路由器nat避免生成null 0路由并解决nat需求
  14. 人工智能数学基础8:两个重要极限及夹逼定理
  15. Ubuntu 更改环境变量 PATH
  16. 容错对于游戏体验的重要性
  17. 深度学习课程大纲_MIT深度学习基础-2019视频课程分享
  18. maya2018放大字体及窗口
  19. steam上c语言的游戏,【图片】在steam吧你甚至可以讨论c语言_steam吧_百度贴吧
  20. 金仓数据库 KingbaseES PL/SQL 过程语言参考手册(16. A PL/SQL源文本加密)

热门文章

  1. appium IOS真机测试
  2. 通过 cygwin64 自己编译对应的 Tera Term cyglaunch.exe
  3. 上传图片并生成缩略图
  4. 数据结构与算法:14 Leetcode同步练习(五)
  5. 技术图文:Python描述符 (descriptor) 详解
  6. 【怎样写代码】确保对象的唯一性 -- 单例模式(五):一种更好的单例实现方法(静态内部类)
  7. 【POJ】2503 Babelfish(字典树,map,指针)
  8. 芯片刀片服务器,使用“刀片服务器”其实不难
  9. 又被 AI 抢饭碗?2457 亿参数规模,全球最大中文人工智能巨量模型 “源1.0”正式开源...
  10. 第三届“达观杯”文本智能信息抽取挑战赛丰厚奖金,群英集结,等你来战!...