在微调GPT/BERT模型时,会经常遇到“ cuda out of memory”的情况。这是因为transformer是内存密集型的模型,并且内存要求也随序列长度而增加。所以如果能对模型的内存要求进行粗略的估计将有助于估计任务所需的资源。

如果你想直接看结果,可以跳到本文最后。不过在阅读本文前请记住所有神经网络都是通过反向传播的方法进行训练的, 这一点对于我们计算内存的占用十分重要。

 total_memory = memory_modal + memory_activations + memory_gradients

这里的memory_modal是指存储模型所有参数所需的内存。memory_activations是计算并存储在正向传播中的中间变量,在计算梯度时需要使用这些变量。因为模型中梯度的数量通常等于中间变量的数量,所以memory_activations= memory_gradients。因此可以写成:

 total_memory = memory_modal + 2 * memory_activations

所以我们计算总体内存的要求时只需要找到memory_modal和memory_activations就可以了。

估算模型的内存

下面我们以GPT为例。GPT由许多transformer块组成(后面我用n_tr_blocks表示其数量)。每个transformer块都包含以下结构:

 multi_headed_attention --> layer_normalization --> MLP -->layer_normalization

每个multi_headed_attention元素都由键,值和查询组成。其中包括n_head个注意力头和dim个维度。MLP是包含有n_head * dim的尺寸。这些权重都是要占用内存的,那么

 memory_modal = memory of multi_headed_attention + memory of MLP= memory of value  + memory of key + memory of query + memory of MLP= square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim) + square_of(n_head * dim)= 4*square_of(n_head * dim)

因为我们的模型包含了n个单元。所以最后内存就变为:

 memory_modal = 4*n_tr_blocks*square_of(n_head * dim)

上面的估算没有考虑到偏差所需的内存,因为这大部分是静态的,不依赖于批大小、输入序列等。

估算中间变量的内存

多头注意力通常使用softmax,可以写成:

 multi_headed_attention = softmax(query * key * sequence_length) * value

k,q,v的维度是:

 [batch_size, n_head, sequence_length, dim]

multi_headed_attention操作会得出如下形状:

 [batch_size, n_head, sequence_length, sequence_length]

所以最终得内存为:

 memory_softmax  = batch_size * n_head * square_of(sequence_length)

q* k * sequence_length操作乘以value的形状为[batch_size, n_head, sequence_length, dim]。MLP也有相同的维度:

 memory of MLP  = batch_size * n_head * sequence_length * dimmemory of value = batch_size * n_head * sequence_length * dim

我们把上面的整合在一起,单个transformer的中间变量为:

 memory_activations = memory_softmax + memory_value + memory_MLP= batch_size * n_head * square_of(sequence_length)+ batch_size * n_head * sequence_length * dim+ batch_size * n_head * sequence_length * dim= batch_size * n_head * sequence_length * (sequence_length + 2*dim)

再乘以块的数量,模型所有的memory_activations就是:

 n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))

整合在一起

我们把上面两个公式进行归纳总结,想看结果的话直接看这里就行了。transformer模型所需的总内存为:

 total_memory = memory_modal + 2 * memory_activations

模型参数的内存:

 4*n_tr_blocks*square_of(n_head * dim)

中间变量内存:

 n_tr_blocks * (batch_size * n_head * sequence_length * (sequence_length + 2*dim))

我们使用下面的符号可以更简洁地写出这些公式。

 R = n_tr_blocks = transformer层堆叠的数量N = n_head = 注意力头数量D = dim = 注意力头的维度B = batch_size = 批大小S = sequence_length =输入序列的长度memory modal = 4 * R * N^2 * D^2memory activations = RBNS(S + 2D)

所以在训练模型时总的内存占用为:

 M = (4 * R * N^2 * D^2) + RBNS(S + 2D)

因为内存的占用和序列长度又很大的关系,如果有一个很长的序列长度S >> D S + 2D <——> S,这时可以将计算变为:

 M = (4 * R * N^2 * D^2) + RBNS(S) = 4*R*N^2*D^2 + RBNS^2

可以看到对于较大的序列,M与输入序列长度的平方成正比,与批大小成线性比例,这也就证明了序列长度和内存占用有很大的关系。

所以最终的内存占用的评估为:

 总内存 = ((4 * R * N^2 * D^2) + RBNS(S + 2D)) * float64(以字节为单位)

https://avoid.overfit.cn/post/6724eec842b740d482f73386b1b8b012

作者:Schartz Rehan

如何估算transformer模型的显存大小相关推荐

  1. 如何估算 transformer 模型的显存大小?

    以下文章来源于微信公众号:DeepHub IMBA 作者:Schartz Rehan   文仅分享,侵删 导读 在微调GPT/BERT模型时,会经常遇到" cuda out of memor ...

  2. 模型的显存和参数量计算

    写在前面:以此记录关于模型显存和参数量的一些理解和计算. 首先是"运算量"和"参数量"两个概念: 参数量:这个比较好理解,例如卷积层中的卷积核c_i*k*k*n ...

  3. win8 查看 linux硬盘大小,如何查看显存大小_win8如何查看显存大小

    2017-01-04 13:57:08 你好哦.同时按下电脑键盘的win键(窗口键)和R键,跳出运行窗口,在运行窗口输入dxdiag,然后点击确定,在弹出的窗口点击上方的显示选项卡!注意调出这个对话窗 ...

  4. tensorflow 显存 训练_tensorflow手动指定GPU以及显存大小

    以前我们组就一块显卡,不存在指定设备的问题.近期刚插了一块新的gtx 1080ti,几人公用两块卡来做训练.测试.预测等等,网上找了个方式可以指定使用的设备,并且限定使用的显存大小,还是很有用的,亲测 ...

  5. linux系统显卡显存容量,Linux下检查显存大小

    Linux下检查显存大小 使用 lspci 检查显存大小 首先使用 lspci 命令列出所有 PCI 设备: [root@localhost ~]# lspci 00:00.0 Host bridge ...

  6. 计算机怎么看显卡内存容量,Win10系统显卡显存大小怎样查看?Win10查看显存大小的两种方法...

    对于十分关心电脑配置的用户而言,查看电脑显卡显存大小是一件非常必要的事情.那么,Win10系统电脑该怎样查看显卡显存大小呢?虽然现在有很多软件都可以直接查看,但是不使用软件查看才是真正的好方法.下面, ...

  7. 怎么看计算机内存和独显,电脑独立显卡或集成显卡的显存大小怎么查看?

    我们有时候想要查看自己电脑显卡的显存大小,当然现在的显卡通常有两种,独立显卡和集成显卡,集成显卡是集成到电脑主板上的,怎么查看电脑显卡显存的大小呢?我们可以用很多方法来进行查看,当然通常的集成显卡本身 ...

  8. 设置GPU及显存大小

    20210128 - 引言 之前搜索过设置GPU和显存大小的方式,但是升级了新的版本的keras以及tensorflow,导致之前的代码失效了,这里记录一下. 本质上,就是版本更换的原因,很多api可 ...

  9. 利用CUDA查看多张显卡可用显存和总显存大小

    利用CUDA查看每张显卡上的可用显存大小和总的显存大小,参考了博文1,博文2,主要使用的函数是cudaMemGetInfo(),cudaGetDeviceCount()和cudaSetDevice() ...

最新文章

  1. MySQL 学习笔记(18)— 索引的分类、创建、查看、删除等
  2. 深入浅出统计学 第一章 数据的可视化
  3. 贪心 - 划分字母区间
  4. linux as86,记linux_centOS安装as86过程
  5. 进行有效编辑的七种习惯
  6. mysql客户端centos离线安装_mysql离线安装部署centos
  7. Python基础之补充1
  8. 如何dos中查看当前MySQL版本信息?
  9. Swift基础语法: 25 - Swift的类和结构体
  10. 【mysql】解决MySQL GPG密钥过期问题
  11. java查找目录文件函数_java 实现 文件操作工具集。包括文件、目录树的拷贝、删除、移动、查找等工具函数...
  12. Unity UI事件管理系统设计
  13. python大文件去重_python3 大文件去重
  14. outlook企业邮箱服务器要多少钱,怎么把企业邮箱配置到outlook中
  15. GitHub项目徽章的添加和设置
  16. 直播答题狂撒币,这些“AI开挂神器”如何在10秒内算出正确答案?
  17. RTX30系列-Ubuntu系统配置与深度学习环境Pytorch配置
  18. iOS开发:图标生成器Prepo 的使用,讲的明明白白
  19. 全网最全的Kali工具大全
  20. del服务器能装win7系统吗,500系列主板能不能装win7?500系列主板装win7教程(支持11代)...

热门文章

  1. 医咖会免费STATA教程学习笔记——多元线性回归
  2. MQTT服务器搭建与试用,桌面工具连接MQTT服务器
  3. SpringCloud微服务-----面试内容
  4. [C语言编程练习][09]编写一个程序,提示用户输入名和姓,并执行以下操作
  5. 又把BLOG捡起来~~
  6. Ubuntu修改swap分区空间大小
  7. 弘辽科技:2020全民参与直播电商模式,新时代的营销战场
  8. 路边烟酒店明明没什么客人,为什么一年还能赚几十万?原因很简单
  9. java interface和impl,被测单元:Impl还是接口?
  10. explode()函数的使用总结