Transformer:

  • 模型结构:

    • Encoder:

      • Position Embedding:引入位置信息,DNN结构,默认没有考虑位置信息
      • Multi-head Self-attention:模型中计算量最大的部分,Head(头)的数量是8,任意两两字符之间,计算相关性。
      • LayerNorm & Residual:层归一化,和残差连接
      • Feedforward Nenual Network:Self-attention位置混合,FFN通道混合,类似通道分离卷积(空间混合)和1x1卷积(通道混合),第1层2028,第2层512
    • Decoder:Teacher Forcing
      • Casual(因果) Multi-head Self-attention,下三角矩阵
      • Memory-base Multi-head Cross-attention,Decoder是Query,Encoder是Key和Value,Key是转置
  • 使用类型:
    • Encoder only:BERT、分类任务、非流式任务
    • Decoder only:GPT(Generative Pre-trained Transformer) 系列、语言建模、自回归生成任务、流式任务
    • Encoder-Decoder:机器翻译、语音识别
  • 特定:
    • 无先验假设,例如局部关联性、有序建模
    • 核心在于自注意力机制,平方复杂度
    • 数据量的要求与先验假设的程度成反比

先验假设(归纳偏置)数据(样本)量,成反比:

  • 归纳偏置(Inductive Bias):在学习算法中,当学习器去预测其未遇到过的输入结果时,所做的一些假设的集合。
  • Transformer计算量以序列长度的平方成正比。
  • 基于先验假设,优化模型,例如降低计算量,要注入先验假设。
  • Transformer长时建模性长,并行计算,对比与RNN或LSTM

Transformer的Loss函数

交叉熵:torch.nn.CrossEntropyLoss

PyTorch中,CrossEntropy的输入,期望Class放在第2维,Batch放在第1维,可以是类别索引(Class indices),也可以是类别的概率(Probabilities for each class)。

reduction默认是mean,例如6个单词的平均交叉熵。reduction是none,默认交叉熵:先做softmax,再做-ln(prob)

参考:CLIP算法的Loss详解 和 交叉熵CrossEntropy实现

# 定义softmax函数
def softmax(x):return np.exp(x) / np.sum(np.exp(x))# 利用numpy计算
def cross_entropy_np(x, y):x_softmax = [softmax(x[i]) for i in range(len(x))]x_log = [np.log(x_softmax[i][y[i]]) for i in range(len(y))]loss = - np.sum(x_log) / len(y)return loss# 测试逻辑
x = [[1.9269, 1.4873, 0.9007, -2.1055]]
y = [[2]]
v1 = cross_entropy_np(x, y)
print(f"v1:{v1}")x = torch.unsqueeze(torch.Tensor(x), dim=0)
x = x.transpose(1, 2)  # CrossEntropy输入期望: Class放在第2维,Batch放在第1维y = torch.Tensor(y)
y = y.to(torch.long)  # label的类型为longv2 = F.cross_entropy(x, y, reduction="none")
print(f"v2:{v2}")

随机种子:torch.manual_seed(42),每个rand之前,都需要添加

构建序列建模的Mask,如下:

import torch
import torch.nn as nn
import torch.nn.functional as Fimport random
import numpy as np# batch_size=2, seqlen=3, vocab_size=4
torch.manual_seed(42)
logits = torch.randn(2, 3, 4)
logits = logits.transpose(1, 2)
print(f'[Info] logits.shape:{logits.shape}')
print(f'[Info] logits: \n{logits}')# logits_softmax = F.softmax(logits, dim=1)
# print(f'[Info] logits_softmax: \n{logits_softmax}')# batch_size=2, vocab_size=4
torch.manual_seed(42)
label = torch.randint(0, 4, (2, 3))
print(f'[Info] label.shape:{label.shape}')
print(f'[Info] label: \n{label}')
# loss: torch.nn.CrossEntropyLoss -> F.cross_entropy
# (2x4x3) + (2x3) = (2x3)
val = F.cross_entropy(logits, label, reduction="none")
print(f"[Info] val.shape:{val.shape}")
print(f"[Info] val: \n{val}")# 在loss中, 增加mask, 与ignore_index参数功能类似,默认值是-100
tgt_len = torch.Tensor([2,3]).to(torch.int32)
mask = [F.pad(torch.ones(L), (0, max(tgt_len)-L)) for L in tgt_len]
mask = torch.stack(mask)
print(f"[Info] mask: \n{mask}")
val = F.cross_entropy(logits, label, reduction="none") * mask
print(f"[Info] val.shape:{val.shape}")
print(f"[Info] val: \n{val}")# 与ignore_index参数功能类似,默认值是-100
label[0, 2] = -100
val = F.cross_entropy(logits, label, reduction="none")
print(f"[Info] val.shape:{val.shape}")
print(f"[Info] val: \n{val}")

PyTorch笔记 - Attention Is All You Need (4)相关推荐

  1. PyTorch笔记 - Attention Is All You Need (1)

    CNN: 权重共享:平移不变形.可并行计算 滑动窗口:局部关联性建模.依赖多层堆积来进行长程建模 对相对位置敏感,对绝对位置不敏感 RNN:依次有序递归建模 对顺序敏感 串行计算耗时 长程建模能力弱 ...

  2. PYTORCH笔记 actor-critic (A2C)

    理论知识见:强化学习笔记:Actor-critic_UQI-LIUWJ的博客-CSDN博客 由于actor-critic是policy gradient和DQN的结合,所以同时很多部分和policy ...

  3. pytorch笔记:policy gradient

    本文参考了 策略梯度PG( Policy Gradient) 的pytorch代码实现示例 cart-pole游戏_李莹斌XJTU的博客-CSDN博客_策略梯度pytorch 在其基础上添加了注释和自 ...

  4. pytorch 笔记:手动实现AR (auto regressive)

    1 导入库& 数据说明 import numpy as np import torch import matplotlib.pyplot as plt from tensorboardX im ...

  5. pytorch 笔记:tensorboardX

    1 SummaryWriter 1.1 创建 首先,需要创建一个 SummaryWriter 的示例: from tensorboardX import SummaryWriter#以下是三种不同的初 ...

  6. pytorch 笔记:DataLoader 扩展:构造图片DataLoader

    数据来源:OneDrive for Business 涉及内容:pytorch笔记:Dataloader_UQI-LIUWJ的博客-CSDN博客 torchvision 笔记:ToTensor()_U ...

  7. pytorch 笔记:torchsummary

    作用:打印神经网络的结构 以pytorch笔记:搭建简易CNN_UQI-LIUWJ的博客-CSDN博客 中搭建的CNN为例 import torch from torchsummary import ...

  8. (d2l-ai/d2l-zh)《动手学深度学习》pytorch 笔记(2)前言(介绍各种机器学习问题)以及数据操作预备知识Ⅰ

    开源项目地址:d2l-ai/d2l-zh 教材官网:https://zh.d2l.ai/ 书介绍:https://zh-v2.d2l.ai/ 笔记基于2021年7月26日发布的版本,书及代码下载地址在 ...

  9. torch的拼接函数_从零开始深度学习Pytorch笔记(13)—— torch.optim

    前文传送门: 从零开始深度学习Pytorch笔记(1)--安装Pytorch 从零开始深度学习Pytorch笔记(2)--张量的创建(上) 从零开始深度学习Pytorch笔记(3)--张量的创建(下) ...

最新文章

  1. java学习笔记-java中运算符号的优先顺序
  2. c mysql封装 jdbc_利用Java针对MySql封装的jdbc框架类 JdbcUtils 完整实现
  3. use vue 多个_vue.use 插件系统详解
  4. 【codevs2287】火车站,第一个A掉的钻石题(迟来的解题报告)
  5. 我认知的javascript之作用域和闭包
  6. vTestStudio:变体Variant初理解
  7. ideal上初写mapreduce程序出现的报错信息解决
  8. sit是什么环境_DEV SIT UAT PET SIM PRD PROD常见环境英文缩写含义
  9. oracle 中dummy,layout设计中dummy的作用详解(上图。好贴好贴,讲的很仔细)
  10. C#编写一个简单串口通讯上位机
  11. VR火得不行 那么它商业化的突破点到底在哪里?
  12. 【ROSE】1. Rational Rose简介
  13. Wpf大屏软件开发过程中遇到的若干问题
  14. 分析Perm()函数功能、代码、时间复杂度
  15. 两台计算机都使用远程桌面,远程桌面设置及使用
  16. 一、多媒体技术的基础本章小结
  17. 《卓有成效的管理者》读书分享
  18. sas统计分析学习笔记(六)
  19. mavlink协议详解_MAVLink通讯协议在STM32上移植,并自定义协议
  20. 补足每天的饮食营养,还得数神奇的小麦胚芽

热门文章

  1. SQL Sever数据库简介
  2. linux tail -f命令
  3. 实名认证接口都有哪些?
  4. MotoSimEG-VRC软件:机器人虚拟仿真动画3DPDF文件输出方法
  5. 人到中年,奋斗了十几年结果却是负债累累,还要继续创业吗?
  6. 开源一个基于微信小程序的蓝牙室内定位软件(附下载链接)
  7. win10系统如何变更默认字体 - 教程篇
  8. IntelliJ IDEA 没有Tomcat 也没有Application Servers的解决办法
  9. 关于视频中的速率问题,海思视频速率(高速信号与高频信号区分与解释)
  10. codeforces 351A A. Jeff and Rounding(★)