©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 追一科技

研究方向 | NLP、神经网络

看了标题,可能读者会有疑惑,大家不都想着将大模型缩小吗?怎么你想着将小模型放大了?其实背景是这样的:通常来说更大的模型加更多的数据确实能起得更好的效果,然而算力有限的情况下,从零预训练一个大的模型时间成本太大了,如果还要调试几次参数,那么可能几个月就过去了。

这时候“穷人思维”就冒出来了(土豪可以无视):能否先训练一个同样层数的小模型,然后放大后继续训练? 这样一来,预训练后的小模型权重经过放大后,就是大模型一个起点很高的初始化权重,那么大模型阶段的训练步数就可以减少了,从而缩短整体的训练时间。

那么,小模型可以无损地放大为一个大模型吗?本文就来从理论上分析这个问题。

含义

有的读者可能想到:这肯定可以呀,大模型的拟合能力肯定大于小模型呀。的确,从拟合能力角度来看,这件事肯定是可以办到的,但这还不是本文关心的“无损放大”的全部。

以 BERT 为例,预训练阶段主要就是一个 MLM 模型,那么“无损放大”的含义就是:是否可以通过某种变换,把一个小模型直接变换成一个大模型,并且输出完全不改变?

这里的变换,指的是对权重做一些确定性的变换,而不用通过梯度下降来继续训练;输出完全不改变,指的是对于同一个输入,小模型和大模型给出的预测结果是完全一致的,也就是说它们表面上看起来不一样,但数学上它们是完全一致的函数,所以称为“无损放大”。由于是无损放大,我们至少可以保证大模型不差于小模型,所以继续预训练理论上有正的收益。至于先小后大这样预训练在效果上能不能比得上一开始就从大训练,这个需要实验来确定,并不是本文关心的问题。

直觉来想,这种放大也不困难,比如通过“重复”、“补零”等操作就可以实现模型权重的自然放大。事实上尝试的方向也是如此,但难点在于我们需要仔细分析模型的每一个模块在被放大之后所产生的后果,以确保最终的结果是无损的。

尝试

下面我们以“将一个 BERT 放大为 2 倍”为例子进行分析尝试,来确定最终的变换形式。这里的“放大”指的是仅仅扩大隐层向量的维度,并不改变模型的层数,也不改变多头注意力机制的头数。

2.1 Embedding

首先,输入层是 Embedding 层,因此先要解决的是 Embedding 层的放大问题。这也是其中最简单的一环,就是直接将每个 token 的向量维度都放大为 2 倍即可,主要就是“重复”、“补零”两种操作:

两种方案都可以作为候选方案,但直觉上来想,补零这种方式引入了太多的零,会导致过度稀疏和同一个值重复次数过多,不利于权重的多样性,因此我们还是选择了重复这种方案。不过,就算只看重复,也不指上述一种方式,比如 也是一种方案,但后面关于 Attention 层的分析表明,后一种方案是不可取的。

除此之外,我们通常还希望变换是正交的,这通常能最大程度上保证模型的稳定性,具体来说,正交变换的最基本性质是不改变向量的模型,所以我们将最终的重复变换调整为:

或者简记成


































,其中





是上取整运算,我们称之为“重复再除以









”。

2.2 LayerNorm

Embedding 的下一层就是 LayerNorm 了,变换前,LayerNorm 的运算为:

而变换后,我们有:

这也就是说,“减均值除以标准差”这一步自动帮我们消去了













这个因子,其结果是放大前结果的直接重复。如果我们将参数向量






也按照公式(2)进行变换,那么结果将是


































,跟 Embedding 层的变换结果一致,而我们就是要尽量使得每一层“净变换”都是同样的一个简单变换:“重复再除以









2.3 FeedForward

按照顺序,接下来本来应该分析 Attention 层才对,不过 FeedForward 层相对简单一点,并且 FeedForward 层的分析结果也对后面理解 Attention 层的变换有所帮助,因此这里先来考虑 FeedForward 层的变换。

FeedForward 层只是两个全连接层的复合,所以我们只需要分析单个全连接层:

这里的









是激活函数。鉴于之前的经验,我们尝试如下变换:

也就是将







按照式(2)进行变换,而对于











则尝试使用形式下述变换:

这里的 D 就是输出维度大小,这里我们假设模型放大 2 倍后,D 也放大 2 倍。不难看出,该变换其实就是对变换矩阵











行列两个方向都分别执行变换(2)。此时:

这说明变换(6)对于线性变换层来说,能够满足我们的理想追求——放大后的结果就是“重复再除以









”。然而,这还不够,因为全连接层还有个激活函数









,现在的问题在于


















未必等于


















,而如果不等,我们就没法让整体的变换等价于“重复再除以









”。

事实上,BERT 用的 GeLU 激活函数就不满足该恒等式;线性激活函数(不加激活函数)显然是满足这个等式的,而满足这个等式一个常见的非线性激活函数便是 ReLU(也包括 LeakyReLU)函数,因此一个直接的解决方式就是 FeedForward 层换用 ReLU 激活函数。事实上,这也已经是预训练模型的一个常见选择了,百度的 Ernie 和 Google 的 T5 模型,它们的 FeedForward 层激活函数都是用 ReLU。

那么,像 BERT 这样的非 ReLU 激活函数的 FeedForward 层就没办法了吗?那也不至于,因为 FeedForward 层是两个全连接层的复合,我们只需要在变换第一个全连接的时候少除以一个









,变换第二个全连接的时候多除以一个









就行了。具体来说,第一个全连接权重变为:

此时就有:

此时结果就是原结果的直接重复,没有除以









,既然如此,后面紧接着的全连接层多除以一个









就行了,即后面的全连接层权重变换为:

这样整个 FeedForward 层的效果就等价于“重复再除以









”了。

2.4 Attention

现在到了最难啃的“硬骨头”——Attention 层的变换。Attention 层首先通过三个线性层将每个输入向量变换为 q,k,v:

根据前面对 FeedForward 层的分析可以得知,如果要想 q,k,v 都达到“重复再除以









”的效果,只需要按照变换(6)进行。但 Attention 层不是单纯的全连接层,变换完之后,我们要检查 Attention 矩阵是否不变,我们来算内积:

其中 d' 是对应的 head_size。这个结果告诉我们,上述变换保持了内积不变,所以应该也保持 Attention 矩阵不变。但是,这里有一个陷阱!如果是 T5 这样的模型,它的内积之后是没有尺度缩放的,所以这样的确完事了;然而像 BERT 这样的模型,它是内积之后除了个












再做 Softmax 的,,而一旦放大模型后,除以












变成了除以













,内积不变也不能保持 Attention 矩阵不变,而应当还需要往 q,k 的权重分别再乘个








,所以最终的变换应该是:

经过这样变换后,Attention 矩阵不变,而


































,所以最终的输出结果也是


































上述内容只是针对 Attention 的单个头进行分析,事实上 Attention 有多个头,多个头的输出结果还要拼接起来再接一个全连接层。当然,由于每个头都是平等的、独立的,因此上述结论基本不变,最后全连接层也只需要按照式(6)进行变换,就可以让 Attention 的变换效果。但是,多头带来的一个效应是,我们在重复的时候,必须局部地进行重复。

具体来说,我们在实现多头的时候,并非是真的做了多个全连接运算,而是做了一个大的全连接运算后再 reshape,这样一来我们可以比较两种不同的重复方式的 reshape 结果:

注意放大前 reshape 结果是 ,所以对比两种不同的重复方式的 reshape 结果,我们发现第二种重复方式 reshape 之后的结果全乱了,不等价于每个头分别重复。因此我们只能选择前一种重复方式。

2.5 输出概率分布

通过以上分析,我们可以使得整个 Encoder 在放大到 2 倍之后,实现“重复再除以









”的效果。最后剩下的就是输出部分,即将 Encoder 的输出向量转化为 token 的概率分布,这里边包含几种情况。

像 GPT、T5 等模型,它们是直接在 Encoder 输出后面乘以了 Embedding 矩阵的转置来做作为概率分布的 logits(当然有可能还有个偏置),由于 Embedding 矩阵本身就包含了“重复再除以









”的操作,而 Encoder 的输出也是“重复再除以









”,两者结合刚好抵消,所以从概率分布角度看,输出是完全不变的。

不过 BERT 多了一层全连接,也就是说它先接了一个 GeLU 激活的全连接层,然后才乘以 Embedding 矩阵的转置并加上偏置项作为 logitis。在“FeedForward”那一节我们已经讨论了,非 ReLU 激活的全连接层无法实现“重复再除以









”的效果,而只能通过变换(9)来实现单纯的“重复”效果,这时候乘以 Embedding 矩阵的转置的话,得到的是原来的 logits 乘以









的效果,输出会有所改变。

当然,由于只是乘以了一个常数倍,所以分布虽然改变了,但是每个 token 的概率大小顺序并没有改变,这也就意味着,如果只看 MLM 的准确率,那么是完全没有改变的,所以问题应该不大。

当然,如果是 ReLU 激活,那么按照式(6)来变换,那么可以实现完全不改变了。此外,如果是像 mT5 那样,最后转为 logits 的变换矩阵跟 Embedding 层不共享,那么可以同时调整最后的变换矩阵,也能实现输出的完全不变。

2.6 RoPE位置编码

前面的分析都只适用于每个神经元都是不相关的情形,也就是说向量的任意两个分量












是没啥关联的。但如果我们在模型中用了“旋转式位置编码(RoPE)”,那么这个假设就不成立了,因为 RoPE 是以每两个分量为一组进行变换的,即














为一组、














为一组,依此类推。

如果还是按照之前式(2)进行重复变换,那么变换之后就变成了














为一组、














为一组、...,跟原来的分组不一致,所以会带来很大的偏差。这种情况下,重复的时候也应当按照两个为一组来进行:

当然,由于默认的 RoPE 是没有可训练权重的,它是按照固定的方式进行渐变的,所以哪怕按照该方式进行重复,那不能完全保证结果一致。也就是说,如果使用了 RoPE,那么基本上不能实现无损放大。不过实际测试结果表明,按照该方式进行重复放大后,对应的 RoFormer 虽然性能有所损失,但不多,可以很快通过继续训练恢复。

结论

现在我们可以确认,对于 BERT 来说,如果非线性激活函数用 ReLU,那么 BERT 是可以直接无损放大的,如果非线性激活函数不是 ReLU,那么可以实现 MLM 准确率无损的放大(事实上经过更精细的调整,也可以实现完全无损放大,但每个层的变换有点不统一了,不够优雅);对于 GPT、T5 等模型来说,不管激活函数用啥(包括 mT5 用的 GLU 激活,也可以定制适当),其实都可以实现无损放大。

其中,将 BERT 权重进行放大为 2 倍的变换汇总如下:

如果是其他略有不同的模型,那么就模仿前面的思想进行类似的分析即可。如果是 RoPE,那么将重复的方案改为式(15)就好;如果是扩大k倍,那么将表格中的多数 2 换为 k 就好。简单来说,如果 Attention 没有尺度缩放(除以












),以及 FeedForward 的激活函数是 ReLU(或者 LeakyReLU),那么放大 k 倍的变换就最简单的,将权重的每一维都执行“重复 k 次并除以









”就好了。

小结

本文从数学上分析了直接放大 Transformer 模型的可能性,得到了若干可用的变换,在部分情况下可以无损放大 Transformer 模型,另外一些情况则可以将损失降到很小(比如保持 MLM 的准确率完全不变)。而研究 Transformer 模型的无损放大操作,可以为我们实现渐进式地训练大模型提供参考思路。

更多阅读

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

???? 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

???? 投稿通道:

• 投稿邮箱:hr@paperweekly.site

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

△长按添加PaperWeekly小编

????

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。

我们可以无损放大一个Transformer模型吗?相关推荐

  1. 构建Transformer模型 | 在wikiText-2数据集上训练一个语言模型

    0 Introduction 自然语言处理通用解决方案 需要熟悉word2Vec, 了解词向量如何建模 重点在于Transformer网络架构,BERT训练方法,实际应用 开源项目,都是现成的,套用进 ...

  2. Transformer模型总结

    Transformer改进了RNN最被人诟病的训练慢的缺点,利用self-attention机制实现快速并行. 它是由编码组件.解码组件和它们之间的连接组成. 编码组件部分由一堆编码器(6个 enco ...

  3. Paper:Transformer模型起源—2017年的Google机器翻译团队—《Transformer:Attention Is All You Need》翻译并解读

    Paper:Transformer模型起源-2017年的Google机器翻译团队-<Transformer:Attention Is All You Need>翻译并解读 目录 <T ...

  4. 基于Tensorflow实现一个Transformer翻译器

    Transformer是谷歌在2017年的一篇论文"Attention is all you need"提出的一个seq2seq的模型架构,其创造性的提出了自注意力的思想,可以很好 ...

  5. 一个既能做CV任务,也能做NLP任务的Transformer模型!谷歌UCLA提出统一的基础模型...

    关注公众号,发现CV技术之美 本文分享论文『Towards a Unified Foundation Model: Jointly Pre-Training Transformers on Unpair ...

  6. ​怎么把图片无损放大?分享一个图片无损放大小妙招

    怎么把图片无损放大呢?无损放大功能是指在不损失图像质量的情况下将图片放大,这意味着即使将图片放大到很大的尺寸,也不会出现模糊或失真的情况.在我们制作海报.PPT演示.社交媒体分享等方面,如果我们使用的 ...

  7. 图片如何无损放大?分享一个简单好用的工具

    图片如何无损放大?很多朋友在日常生活或者工作过程中,都需要处理一些图片,但有时候我们找到的图片大小可能无法满足我们的需要,这时候就需要放大处理,但常规的方法放大后图片变得模糊不堪,根本无法使用,怎么才 ...

  8. 每日一个小技巧:1招教你怎么将照片无损放大

    照片是一种记录.分享和保存记忆的重要方式.它可以记录特殊的时刻和经历,如毕业典礼.婚礼.旅游等,为我们锁住美好回忆.不知道大家有没有经历过,在手机或者电脑上打开一张拍摄的照片,却发现它的尺寸太小了,手 ...

  9. 视频分辨率无损放大软件 Topaz Video Enhance AI 2.3.0

    视频分辨率无损放大软件 Topaz Video Enhance AI 2.3.0 Topaz Video Enhance AI是一款非常好用的视频分辨率放大软件,用户可以通过这款软件将视频的分辨率进行 ...

最新文章

  1. python从零实习深度学习_月薪45K的深度程序员教你从零在Python中开发深度学习
  2. Linux记录-普通用户下执行sudo xxx 找不到命令解决方案
  3. 苹果新的编程语言 Swift 语言进阶(一)--综述
  4. 2020年,.NET Core起飞在即,最强日志分析ELK还不会?
  5. python实战扫码下载_实例:用 Python 做一个扫码工具
  6. linux安装dev命令,Linux安装与基础命令
  7. 联想台式计算机驱动程序,联想台式机网卡驱动,详细教您联想台式机网卡驱动...
  8. linux系统dc模拟器,wine(linux模拟器)
  9. win7_64位下部署Apache+Mysql5.7.19+Php7+Snipe-IT
  10. 求助华为HG8321R光猫这样还有救吗
  11. 【滤波】概率、高斯和贝叶斯
  12. Spring Cloud:熔断器Hystrix
  13. 暗影精灵4适合计算机专业,暗影精灵4什么时候出?今日发布,为专业电竞而生...
  14. Spring Cloud Task 主要是干什么的啊?跟 Quartz 和 Spring Task 有啥关系?
  15. 个人电脑安全防范措施
  16. 逐鹿量子计算,“先导杯”向世界难题发起冲击!
  17. docsify部署静态文件服务器,云开发 Docsify 文档部署
  18. python如何控制伺服驱动_在控制伺服电机的驱动中,控制器和驱动器各有什么功能和作用?...
  19. 怎么判断间隙过渡过盈配合_[判断题] 配合有间隙配合、过盈配合和过渡配合三种...
  20. 【编译原理】复习总结

热门文章

  1. java 获取三天前时间_java 获取前几天时间
  2. android 异步线程的使用
  3. android http pos 请求和gson解析处理head头信息
  4. 转自知乎大神----JS 的 new 到底是干什么的?
  5. Java并发编程--CountDownLatch
  6. G少爷上证技术分析 8月31日
  7. linux内核合并dtb文件,c – 如何修改内核DTB文件
  8. 机器人聊天软件c#_使用python3.7配置开发钉钉群自定义机器人(2020年新版攻略)
  9. 爬虫模拟登陆手机验证码_网络爬虫干货总结,这次比较全面!
  10. spirngmvc如何实现直接输入网页重定向到登录_Python 模拟新浪微博登录