GPT模型

GPT模型:生成式预训练模型(Generative Pre-Training)

总体结构:

无监督的预训练
有监督的下游任务精调

核心结构:中间部分主要由12个Transformer Decoder的block堆叠而成

下面这张图更直观地反映了模型的整体结构:

模型描述

GPT 使用 Transformer的 Decoder 结构,并对 Transformer Decoder 进行了一些改动,原本的 Decoder 包含了两个 Multi-Head Attention 结构,GPT 只保留了 Mask Multi-Head Attention,如下图所示。
(很多资料上说类似于decoder结构,因为采用了decoder的mask机制,不过抛开这一点,其实感觉和encoder会更像,所以实现时有时反而是调encoder实现 莫烦Python GPT实现代码)

对比原有transformer的结构

阶段描述

预训练阶段:



预训练阶段为文本预测,即根据已有的历史词预测当前时刻的词,7-2,7-3,7-4三个式子对应之前的GPT结构图,输出P(x)为输出,每个词被预测到的概率,再利用7-1式,计算最大似然函数,据此构造损失函数,即可以对该语言模型进行优化。

下游任务精调阶段

损失函数

下游任务与上游任务损失的线性组合

计算过程:

  1. 输入
  2. Embedding
  3. 多层transformer的block
  4. 拿到两个输出端结果
  5. 计算损失
  6. 反向传播
  7. 更新参数

一个具体的GPT实例代码:
可以看到GPT模型的forward函数中,首先进行Embedding操作,然后经过12层transformer的block中进行运算,然后分别经过两个线性变换得到最终计算值(一个用于文本预测,一个用于任务分类器),代码与最开始展示的模型结构图保持一致。
参考:莫烦Python GPT实现代码

下面我们着重关注计算步骤2, 3

计算细节:

【Embedding层】:

查表操作
Embedding层就是以one hot为输入、中间层节点为字向量维数的全连接层。而这个全连接层的参数,就是一个“字向量表”。

one hot型的矩阵相乘,就像是相当于查表,于是它直接用查表作为操作,而不写成矩阵再运算,这大大降低了运算量。再次强调,降低了运算量不是因为词向量的出现,而是因为把one hot型的矩阵运算简化为了查表操作。

【GPT中类似transformer的decoder层】:


每个decoder层包含两个子层

  1. sublayer1: mask的多头注意力层
  2. sublayer2: ffn (feed-forward network)前馈网络(多层感知机)

sublayer1:mask的多头注意力层

输入: q, k, v, mask
计算注意力:Linear(矩阵乘法)→Scaled Dot-Product Attention→Concat(多个注意力的结果, reshape )→Linear(矩阵乘法)

残差连接和归一化操作:Dropout操作→残差连接→层归一化操作

计算过程:

下面这段内容介绍了计算注意力的整体过程:

分解说明:

Mask Multi-head Attention

1.矩阵乘法:

将输入的q,k,v进行变换

2.Scaled Dot-Product Attention

主要就是进行attention的计算以及mask的操作


Mask操作:masked_fill_(mask, value)
掩码操作,用value填充tensor中与mask中值为1位置相对应的元素。mask的形状必须与要填充的tensor形状一致。(这里采用-inf填充,从而softmax之后变成0,相当于看不见后面的词)
transformer中的mask操作

mask后可视化矩阵:
直观理解是每个词只能看到它之前的词(因为目的就是要预测未来的词嘛,要是看到了就不用预测了)

3.Concat操作:

综合多个注意力头的结果,实际上是对矩阵做变换:permute,reshape操作,降维。(如下图红框中所示)

4.矩阵乘法:一个Linear层,对注意力结果线性变换

整个mask多头注意力层的代码

注意到:上述代码中后面几行是对注意力结果进行残差连接和归一化操作
下说明这一过程:

残差连接和归一化操作:

5.Dropout层

6.矩阵加法

7.层归一化

批量归一化是不同训练数据之间对单个神经元的归一化,层归一化是单个训练数据对某一层所有神经元之间的归一化。
输入归一化、批量归一化(BN)与层归一化(LN)

代码展示

sublayer2: ffn (feed-forward network)前馈网络

1.线性层(矩阵乘法)

2.relu函数激活

3.线性层(矩阵乘法)

4.Dropout操作

5.层归一化

【线性层】:

多层block的输出结果放到两个线性层中进行变换,比较简单,不做赘述。

补充:注意力层流程图示

参考资料

1.参考论文:Radford et al. 《Improving Language Undersatnding by Generative Pre-Training"》
2.参考书籍:《自然语言处理 基于预训练模型的方法》车万翔,郭江,崔一鸣
3.本文中代码来源:莫烦Python GPT实现代码
4.其它参考链接(博文中已提到部分):
word embedding计算过程剖析
Transformer的矩阵维度分析和Mask详解

GPT模型总结【模型结构及计算过程_详细说明】相关推荐

  1. 无人驾驶运动学模型——线性时变模型预测控制的思路推演过程_百叶书的博客-CSDN博客_线性时变模型预测控制 转

    无人驾驶运动学模型--线性时变模型预测控制的思路推演过程_百叶书的博客-CSDN博客_线性时变模型预测控制

  2. 详解多分类模型的Macro-F1/Precision/Recall计算过程

    引入 关于准确率(accuracy).精度(precision).查全率(recall).F1的计算过程,之前写过一篇文章[1]. 根据文章[1]中的公式,我们可以知道,精度(precision).查 ...

  3. 推理计算过程_转导推理—Transductive Learning

    在统计学习中,转导推理(Transductive Inference)是一种通过观察特定的训练样本,进而预测特定的测试样本的方法.另一方面,归纳推理(Induction Inference)先从训练样 ...

  4. 弹性均质圆环法计算过程_同济大学地下建筑结构复习要点

    同济大学地下建筑结构复习 1 绪论 1.1简述地下建筑结构的概念及形式 地下建筑结构即埋置于地层内部的结构.包括衬砌结构和内部结构两部分.要考虑地下结构与周围岩土体的共同作用.地下建筑结构的形式主要由 ...

  5. 弹性均质圆环法计算过程_第十章盾构隧道衬砌计算方法综述.ppt

    第十章 盾构隧道衬砌计算方法 10.1 国内外的发展动态-常用模型 盾构隧道的设计模型,多用荷载一结构模型.但由于其断面为圆形,地层结构法对均一地层中单孔圆形隧道也取得了精确的解析解,但其他情况仍须借 ...

  6. 图像sobel梯度详细计算过程_数字图像处理(第十章)

    点.线.边缘检测 背景知识.书中主要介绍了图像的一阶导数与二阶导数,这个之前的文章中有过介绍这里在复习一遍.对于函数 ,对于点 在x方向的一阶偏导为: ,二阶偏导为: 之后书中总结了一阶导与二阶导对于 ...

  7. 图像sobel梯度详细计算过程_视频处理之Sobel【附源码】

    边缘检测是检测图像中的一些像素点,它们周围的像素点的灰度发生了急剧的变化,我们认为在这过程中,图像中的物体不同导致了这一变化,因此可以将这些像素点作为一个集合,可以用来标注图像中不同物体的边界.边缘区 ...

  8. 推理计算过程_初中物理电学计算题第六讲:极值问题推理和限制条件

    初中物理电学计算题第六讲:极值问题推理和限制条件 前面已经讲过:初中物理电学计算题第三讲:串联电路电流电阻极值推理实例,本讲是这一问题的进一步深入讨论. 题型分析 极值问题是电学计算题中一类较难的题目 ...

  9. 直线插补计算过程_【计鹏视角】风速数据插补对发电量的影响

    测风数据在插补时通常通过相关函数实现,相关函数一般采用线性方程函数,线性函数根据不同通道的风速相关性散点图来得到. 不同高度层的相关性散点图是成"带"状分布,相关系数越大,&quo ...

  10. 弹性均质圆环法计算过程_蚝油的加工工艺,蚝油总固形物(水分含量)计算公式,检测方法...

    蚝油是用蚝(牡蛎)熬制而成的调味料.蚝油是由豉(牡蛎干)熬制成的汤,经过滤浓缩后即为蚝油.它是一种营养丰富.味道鲜美的调味佐料.蚝油做法程序繁多,最重要的步骤是用水将鲜蚝煮至理想黏度,做出优质的蚝油应 ...

最新文章

  1. 他24岁,4篇Nature在手,也会关心学不懂C语言怎么办
  2. 95 后大学生利用漏洞免费吃肯德基获刑
  3. Activiti数据库
  4. 量子计算机完整的图片,记者带你走近世界首台超越早期经典计算机的光量子计算机(组图)...
  5. cocos2dx游戏开发——微信打飞机学习笔记(五)——BackgroundLayer的搭建
  6. python2.7更新_centos系统python2.7更新到3.5
  7. LINGO--Error Code 1017
  8. PyQt在qrc文件中添加自定义字体并使用
  9. 浙大PAT考试经验/考前必看/日常刷题总结(经验只写了一点点
  10. LSD_SLAM 单目直接法 半稠密slam 加权LM优化 深度值高斯-高斯分布卡尔曼滤波
  11. 《巴菲特法则》书中的精髓:用好巴菲特企业前景投资法则,股票投资稳赚不赔。
  12. RS485通讯四路模拟量隔离采样模块的功能特点及应用
  13. linux ubi 分区,Linux ubi子系统原理分析
  14. RecSys'22|CARCA:交叉注意力感知上下文和属性进行推荐
  15. Jni调用so动态库
  16. Scratch少儿编程案例-植物大战僵尸-趣味角色版
  17. 《部落冲突:皇室战争》——一款不能错过的游戏!
  18. 当前服务器系统内核版本是多少,linux下如何查看系统和内核版本
  19. IE浏览器打不开jupyter notebook网页的解决办法
  20. sql操作access时出现 MSDTC错误,服务器 'SERVER' 上的 MSDTC 不可用。

热门文章

  1. 7、mysql的redo log、bin log日志
  2. push_back()函数的用法
  3. pr中小人国微缩世界,速度快门的变化,动态地图,手写文字效果,打字机输入文字,照片定格效果
  4. 17产品经理需要具备的领导能力
  5. 我也来谈谈《我不是药神》这部电影
  6. 大数据HBase(十五):HBase的Bulk Load批量加载操作
  7. C++程序员必备知识
  8. 原来姹紫嫣红开遍 -- 牡丹亭·游园惊梦
  9. nw.js html5,nw.js 如何使用?
  10. 现有存储系统技术架构