本文主要译介自Graphcore在2017年1月的这篇博客: Why is so much memory needed for deep neural networks。介绍了深度学习中内存的开销,以及降低内存需求的几种解决方案。

为便于阅读,本文修改了原文分段,并添加更详细的计算说明。

深度学习的内存消耗在哪里?

回顾:简单例子

考虑一个单层线性网络,附带一个激活函数:
h=w1x+w2h=w_1x+w_2h=w1​x+w2​

y=f(h)y=f(h)y=f(h)

代价函数:E=∣∣y−y‾∣∣2E=||y-\overline{y}||^2E=∣∣y−y​∣∣2

在训练时,每一个迭代要记录以下数据:

  • 当前模型参数w1,w2w_1,w_2w1​,w2​
  • 前向运算各层响应:x,h,yx, h, yx,h,y

这样,可以在后向运算中用梯度下降更新参数:
Δw1=η⋅∂E∂w1=η⋅2(y−y‾)⋅f′(h)⋅x\Delta w_1=\eta\cdot \frac{\partial E}{\partial w_1}=\eta \cdot 2(y-\overline{y})\cdot f'(h) \cdot xΔw1​=η⋅∂w1​∂E​=η⋅2(y−y​)⋅f′(h)⋅x

Δw2=η⋅∂E∂w1=η⋅2(y−y‾)⋅f′(h)\Delta w_2=\eta\cdot \frac{\partial E}{\partial w_1}=\eta \cdot 2(y-\overline{y})\cdot f'(h)Δw2​=η⋅∂w1​∂E​=η⋅2(y−y​)⋅f′(h)

内存消耗的三方面

输入数据

很小,不做考量。

256256的彩色图像:25625631 byte= 192KB

模型参数

较大,和模型复杂度有关。

入门级的MNIST识别网络有6.6 million参数,使用32-bit浮点精度,占内存:6.6M * 32 bit = 25MB

50层的ResNet有26 million参数,占内存:26M * 32 bit = 99MB

当然,你可以设计精简的网络来处理很复杂的问题。

各层响应

较大,同样和模型复杂度有关。

50层的ResNet有16 million响应,占内存:16M*32bit = 64MB

响应和模型参数的数量并没有直接关系。卷积层可以有很大尺寸的响应,但只有很少的参数;激活层甚至可以没有参数。

– 这样看起来也不大啊?几百兆而已。
– 往下看。

batch的影响

为了有效利用GPU的SIMD机制,要把数据以mini-batch的形式输入网络。
如果要用32 bit的浮点数填满常见的1024 bit通路,需要32个样本同时计算。

在使用mini-batch时,模型参数依然只保存一份,但各层响应需要按mini-batch大小翻倍。

50层的ResNet,mini-batch=32,各层相应占内存:64MB*32 = 2GB

卷积计算的影响

设H×WH\times WH×W的输入图像为XXX,K×KK\times KK×K的卷积核为RRR,符合我们直觉的卷积是这样计算的。

对每一个输出位置,计算小块对位乘法结果之和。
Y(h,w)=∑Xk,ks(h,w)⊙RY(h,w) = \sum{X^s_{k,k}(h,w) \odot R}Y(h,w)=∑Xk,ks​(h,w)⊙R

h=1:H,w=1:Wh=1:H, w=1:Wh=1:H,w=1:W
其中,Xk,ks(h,w)X^s_{k,k}(h,w)Xk,ks​(h,w)表示输入图像中,以h,wh,wh,w为中心,尺寸为K×KK\times KK×K的子图像。

但是,这种零碎运算很慢

在深度学习库中,一般会采用lowering的方式,把卷积计算转换成矩阵乘法

首先,把输入图像分别平移不同距离,得到K2K^2K2个H×WH\times WH×W的位移图像,串接成H×W×K2H\times W \times K^2H×W×K2的矩阵X‾\overline{X}X。
之后,把K×KK\times KK×K的卷积核按照同样顺序拉伸成K2×1K^2\times 1K2×1的矩阵R‾\overline{R}R
卷积结果通过一次矩阵乘法获得:
Y=X‾⋅R‾Y=\overline{X}\cdot \overline{R}Y=X⋅R

输入输出为多通道时,方法类似,详情参见这篇博客。

在计算此类卷积时,前层响应XXX需要扩大K2K^2K2倍。

50层的ResNet,考虑lowering效应时,各层响应占内存7.5GB

使用低精度不能降内存

为了有效利用SIMD,如果精度降低一倍,batch大小要扩大一倍。不能降低内存消耗。

降内存的有效方法

in-place运算

不开辟新内存,直接重写原有响应。
很多激活函数都可以这样操作。
复杂一些,通过分析整个网络图,可以找出只需要用一次的响应,它可以和后续响应共享内存。例如MxNet的memory sharing机制。

综合运用这种方法,MIT在2016年的这篇论文能够把内存降低两到三倍。

计算换存储

找出那些容易计算的响应结果(例如激活函数层的输出)不与存储,在需要使用的时候临时计算。

使用这种方法,MxNet的这个例子能够把50层的ResNet网络占用的内存减小四倍。

类似地,DeepMind在2016年的这篇论文用RNN处理长度为1000的序列,内存占用降低20倍,计算量增加30%。

百度语音在2016年的这篇论文同样针对RNN,内存占用降低16倍,可以训练100层网络。

当然,还有Graphcore自家的IPU,也通过存储和计算的平衡来节约资源。

Graphcore本身是一家机器学习芯片初创公司,行文中难免夹带私货,请明辨。

【深度学习】为什么深度学习需要大内存?相关推荐

  1. 【深度学习】【物联网】深度解读:深度学习在IoT大数据和流分析中的应用

    作者|Natalie 编辑|Emily AI 前线导读:在物联网时代,大量的感知器每天都在收集并产生着涉及各个领域的数据.由于商业和生活质量提升方面的诉求,应用物联网(IoT)技术对大数据流进行分析是 ...

  2. 【深度学习】深度解读:深度学习在IoT大数据和流分析中的应用

    来源:网络大数据(ID:raincent_com) 摘要:这篇论文对于使用深度学习来改进IoT领域的数据分析和学习方法进行了详细的综述. 在物联网时代,大量的感知器每天都在收集并产生着涉及各个领域的数 ...

  3. Dataset:数据集集合(综合性)——机器学习、深度学习算法中常用数据集大集合(建议收藏,持续更新)

    Dataset:数据集集合(综合性)--机器学习.深度学习算法中常用数据集大集合(建议收藏,持续更新) 目录 常规数据集 各大方向分类数据集汇总 具体数据集分类 相关文章 DL:关于深度学习常用数据集 ...

  4. 【资源放送】机器学习/深度学习最全公开视频大放送!

    文章首发于微信公众号<有三AI> [资源放送]机器学习/深度学习最全公开视频大放送! 该篇小记一下机器学习与深度学习的一些好的基础视频资源. 如果你是刚入门的小白,建议细细阅读一下下面将要 ...

  5. 自学机器学习、深度学习、人工智能学习资源推大聚合

    想要解决如何自学机器学习.深度学习和人工智能这一问题,首先要了解三个概念以及它们之间的关系. 人工智能:人工智能英文缩写为AI,它是研究.开发用于模拟.延伸和扩展人的智能的理论.方法.技术及应用系统的 ...

  6. 深度学习数据集中数据差异大_使用差异隐私来利用大数据并保留隐私

    深度学习数据集中数据差异大 The modern world runs on "big data," the massive data sets used by governmen ...

  7. 过拟合和欠拟合_现代深度学习解决方案中的两大挑战:拟合和欠拟合

    全文共2306字,预计学习时长5分钟 对机器学习模型而言,最糟糕的两种情况无非是构建无用的知识体系,或是从训练数据集中一无所获.在机器学习理论中,这两种现象分别被称为过拟合和欠拟合,是现代深度学习解决 ...

  8. dpg learning 和q_深度学习和强化学习之间的差别有多大?

    我是做深度强化学习的(Deep Reinforcement Learning)的,这个问题有趣.我对@张馨宇他在此问题下的简洁回答非常认同:"可以用深度学习这个工具来做强化学习这个任务,也可 ...

  9. 2017深度学习最新报告及8大主流深度学习框架超详细对比(内含PPT)

    2017深度学习最新报告(PPT) ​ 深度学习领军人物 Yoshua Bengio 主导的蒙特利尔大学深度学习暑期学校目前"深度学习"部分的报告已经全部结束. 本年度作报告的学术 ...

  10. 大疆M100无人机 妙算Manifold 深度学习视觉伺服系统 学习历程(一)妙算Manifold环境配置

    实验室有一架 DJI M100 无人机和若干台 DJI Manifold ,由于与我的研究方向有相关性,因此打算将其利用起来做一些深度学习视觉伺服的开发工作,本系列文章将一些我在学习和研究过程中经历的 ...

最新文章

  1. linux复制压缩文件,Linux如何复制,打包,压缩文件
  2. 服务器硬盘冷迁移后网卡无法启动问题
  3. 【Python】手把手教你用Python做一个图像融合demo,小白可上手!
  4. OS / 总线锁和缓存一致性
  5. html如何设置滑轮效果,HTML中鼠标滚轮事件onmousewheel处理
  6. qq发文件大小上限_微信推出新功能!网友:终于不用转QQ了
  7. eve星战前夜登录提示服务器维护中,EVE星战前夜进不去怎么办 游戏进不去问题解决方法...
  8. 【前端】关于事件的代码片段
  9. 转载:PLSQL中显示Cursor、隐示Cursor、动态Ref Cursor区别
  10. C++_Operator Overloading(运算符重载 | 计算有理数的加减乘除)
  11. 3. Longest Substring Without Repeating Characters
  12. vue学习笔记-6-属性绑定
  13. 因果信号的傅里叶变换_常用信号的傅里叶变换对
  14. crmeb多商户二开crmeb类库二开文档services服务类【5】
  15. ARM Cortex-M0系统简介
  16. springboot酒店客房预定管理系统
  17. flex布局实现骰子六面的示例
  18. The authenticity of host ‘172.16.132.189 (172.16.132.189)‘ can‘t be established.
  19. 如何使用开源合成器Natron入门
  20. ECShop 后台订单列表美化

热门文章

  1. AWS免费账号取消步骤
  2. Python爬虫实践-网易云音乐
  3. mysql所选路径已经存在_mysql安装常见问题解决办法
  4. 高数_第5章常微分方程_二阶微分方程
  5. DStream转换操作
  6. 2021-2027全球与中国USB智能电源板市场现状及未来发展趋势
  7. CVE-2021-3560-POLKIT本地提权漏洞复现
  8. 微信翻译出 Bug 上热搜,程序员又背锅?!
  9. 野火霸道者开发板移植LVGL代码
  10. 《算法图解》读书笔记