这是博主在五一期间对Transformer几篇相关论文阅读的小笔记和总结
也借鉴参考了很多大佬的优秀文章,链接贴在文章下方,推荐大家前去阅读

该文章只是简单叙述几个Transformer模型的基本框架,对其详细信息(如实验情况等)请阅读论文或点击下方对应文章链接前往阅读

阅读论文

A Survey on Visual Transformer[1]

Transformers in Vision: A Survey[2]

An Image Is Worth 16X16 Words: Transformers for Image Recognition at Scale[3]

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows[4]

Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions[5]

Transformer in Transformer[6]

ConvBERT: Improving BERT with Span-based Dynamic Convolution[7]

Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet[8]

AutoTrans: Automating Transformer Design via Reinforced Architecture Search[9]

综述

[1],[2]讲述了目前Visual Transformer的爆火和Transformer在CV方面的各种应用,

Transformer应用

虽然Transformer爆火,但目前CV领域还是由CNN主导

但人们依然很看好Transformer在CV领域的应用于前景,希望将Transformer迁移到CV领域中,

主要分为两大类

  • 将self-attention机制与常见的CNN架构结合
  • 用self-attention机制完全替代CNN,如[3],[4],[5]

[3]提出VIT模型,应用于CV领域的Image Classification,在大规模数据上训练的得到的模型达到不错的效果

[4]提出Swin Transformer模型,基于VIT模型的改进,应用于Image classification,Object detection,Semantic segmentation

[5]提出Pyramid Vision Transformer模型,基于VIT模型的改进,应用于Object detection,Semantic segmentation

[6]提出Transformer in Transformer模型,应用于Image benchmark and downstream tasks

[7]提出ConvBERT模型,基于BERT的预训练语言模型

[8]提出Tokens-to-Token ViT模型,基于VIT模型改进

[9]讨论了Transformer Model 的自动化设计,对于layer-norm的位置,个数,attention head 的个数,使用哪个激活函数的问题

VIT

VIT是早期提出用于CV中Image classification的结构,虽然存在着不足,但对最近Swin TransformerPyramid Vision Transformer等模型提出提供了宝贵经验。

Transformer在NLP中处理的是序列化的数据,而CV中处理的是三维的图像数据(H,W,C)

所以,VIT提出了Patch划分的方法,将三维图像数据转化为序列化数据

基础模型

Model

  • 将图像转化为序列化数据

    • 将图像划分为一个个Patch,对这一系列Patch进行reshape,从而得到序列化的Patch数据,即Flattened Patch

    对于一张H×W×CH \times W \times CH×W×C的图片,采用P×PP \times PP×P的Patch来划分,得到N=HP×WPN=\frac{H}{P} \times \frac{W}{P}N=PH×PW个patch(P×P×CP\times P\times CP×P×C),reshape后得到维度P2CP^2CP2C的Flattened Patch(向量),对N个Flattened Patch进行concat之后得到一个N×(P2C)N\times (P^2C)N×(P2C)的二维矩阵,相对于NLP中输入Transformer的Word Embedding

    • Linear Projection:对维度P2CP^2CP2C的Flattened Patch进行维度转化,得到固定长度DDD的特征向量

      • 防止模型结构受Patch大小的影响

    综上,原本H×W×CH \times W \times CH×W×C的图片转化为了N个DN个DND维向量(或一个N×DN \times DN×D的二维矩阵)

  • Position embedding

    • 0-9的紫色框表示各个位置的position embedding,粉色框为经过Linear Projection后的Flattened Patch,二者相加
  • Learnable embedding

    • 分类标志位[class],即图中0编号,它经过Encoder后的结果作为整个图像的表示
    • 假如随意指定1-9作为整个图像的表示,则会导致整体表示偏向这个Embedding信息,即图像表示偏重于反映某个Patch
  • Transformer Encoder

    • 第一层输入,xxx表示一个Patch,EEE表示Linear Projection转化矩阵
      z0=[xclass;xp1E;...xpNE]+EposE∈R(P2⋅C)×D,Epos∈R(N+1)×Dz_0=[x_{class};x^1_pE;...x_p^NE]+E_{pos}\\ E\in \Bbb R^{(P^2·C)\times D},E_{pos}\in \Bbb R^{(N+1)\times D} z0=[xclass;xp1E;...xpNE]+EposER(P2C)×D,EposR(N+1)×D

    • 对于第lll

    zl′=MSA(LN(zl−1))+zl−1zl=MLP(LN(zl′))+zl′l=1...Lz^{'}_l=MSA(LN(z_{l-1}))+z_{l-1}\\ z_l=MLP(LN(z^{'}_l))+z^{'}_l\\ l = 1...L zl=MSA(LN(zl1))+zl1zl=MLP(LN(zl))+zll=1...L

    • 每个块后使用Residual connection,每个块前使用Layernorm

不足

在中等大小的数据集(如 ImageNet)上训练得到的VIT模型准确率低于CNN模型

Swin

Swin Transformer基于VIT进行了改进,可以作为通用backone,应用于Image classification,Object detection,Semantic segmentation

  • 引入层次化结构

    • 解决CV领域scale变化范围大的问题
  • 将self-attention计算限制在local window中
    • 解决CV使用Transformer计算复杂度为图像size二次方的问题
  • 支持cross-window connection

基础模型

创新

  • 引入stage,随着网络深度增加,合并patch
  • 使用基于Shifted Windows的W-MSA替换标准的MSA
  • 提出Shifted Window划分

Stage

引入4个stage,每个stage的操作类似于上述的VIT模型

区别

Swin随着网络深度的加深,patch的数量在不断减少,且每个patch的感知范围在不断扩大,而VIT模型的patch保持不变

  • Patch Partition:如VIT模型将图像转化为序列化数据,将原始图像划分为一个个Patch,但彼此不重叠

    • 根据设置的Patch大小(4×44 \times 44×4),对输入的Image(H×W×3H\times W\times 3H×W×3)划分为H4×W4\frac{H}{4} \times \frac{W}{4}4H×4W个Patch
    • 将原始图像的像素值作为其 Feature
    • 每个Patch的维度为4×4×3=484 \times 4 \times 3 = 484×4×3=48
  • Stage1

    • Linear Embedding:将 Feature 映射到任意维度,记作CCC
    • Transformer Block:进行Self-Attention运算,不改变token数量
  • Stage2

    • Patch Merging:对2×22 \times 22×2的相邻Patch进行合并,通过linear layer将Feature映射到4C4C4C维度

      • 合并过后,Patch数量降为14\frac{1}{4}41,MLP作用后,维度为2C2C2C
  • Stage3,Stage4

    • 类似于Stage2,最终得到一个层次化的结构

W-MSA

每个window包含M×MM \times MM×M个patch,整个图像包含h×wh \times wh×w个patch,总共h4×w4\frac{h}{4} \times \frac{w}{4}4h×4w个Patch

复杂度计算
Ω(MSA)=4hwC2+2(hw)2CΩ(W−MSA)=4hwC2+2M2hwC\Omega(MSA) = 4hwC^2 + 2(hw)^2C\\ \Omega(W-MSA)=4hwC^2 + 2M^2hwC Ω(MSA)=4hwC2+2(hw)2CΩ(WMSA)=4hwC2+2M2hwC
MMM固定大小时,Ω(W−MSA)\Omega(W-MSA)Ω(WMSA)线性增长

Shifted Window 划分

在连续的Swin Transformer Block之间交替进行两种划分配置,进而保持 non-overlapped windows 的高效计算同时引入 cross-window connection

  • Layer1:将8×88\times 88×8的Feature Map 划分为2×22 \times 22×2个window,每个window大小为4×4,即M=44 \times4,即M=44×4,M=4
  • Layer2:将Layer1的一系列window移动(⌊M2⌋,⌊M2⌋)(\lfloor\frac{M}{2} \rfloor,\lfloor\frac{M}{2} \rfloor)(2M,2M)个像素

存在问题

  • 可能会生成过多window,范围[⌈hM⌉×⌈wM⌉,(⌈hM⌉+1)×(⌈wM⌉+1)][\lceil\frac{h}{M} \rceil\times\lceil\frac{w}{M} \rceil,(\lceil\frac{h}{M} \rceil+1)\times(\lceil\frac{w}{M} \rceil+1)][Mh×Mw,(Mh+1)×(Mw+1)]
  • 生成window大小不一致,部分window较小

解决方案

Cyclic Shift

经过循环填补,Layer2的window划分类似于Layer1,此时的batch window由不相邻的sub-window构成,采用masking mechanism来限制seft attention在sub-window中的计算

不足

图像分类上比ViT、DeiT等Transformer类型的网络效果更好,但是比不过CNN类型的EfficientNet,猜测Swin Transformer还是更加适用于更加复杂、尺度变化更多的任务。

Swin Transformer没有提供类似反卷积的上采样的算法

每一个window的Q,K,V都是独立的,即没有CNN的权值共享特性

PVT

Pyramid Vision Transformer同样基于VIT模型进行改进,是一种使用Transformer模型的无卷积骨干网络,主要用于除图像分类外的密集预测任务。

  • 引入金字塔结构

    • Feature Map的分辨率随着网络加深,逐渐减小
    • Feature Map 的Channel数随着网络加深,逐渐增大
    • 解决了VIT模型中,输入高分辨率图像产生高内存占用甚至显存溢出的问题
  • PVT继承了CNN和Transformer的优势,通过简单的替换CNN骨干使其成为不同视觉任务的统一骨干结构
  • 在object detection,semantic and instance segmentation 任务上取得优越性

基础模型

创新

  • 引入stage,随着网络深度增加,Feature Map的分辨率逐渐减小Channel数逐渐增大
  • 使用**spatial-reduction attention(SRA)**替换标准的MSA

Stage

  • Stage1

    • patch embedding:对输入的H×W×3H\times W\times 3H×W×3 图像(Feature Map)进行token化,设定其patch的大小p×pp\times pp×p,将RGB图像划分为HWP2\frac{HW}{P^2}P2HW个patch

      • 对这一系列的patch进行展开,然后传入linear projection中进行Patch Embedding
      • 将结果reshape后得到大小HWP2×C\frac{HW}{P^2}\times CP2HW×C的embedded patch
    • 将embedded patch 和其对应的position embedding 一起传入Transformer Encoder中
      • SRA处理:使用函数Reshape(x,Ri)Reshape(x,R_i)Reshape(x,Ri)将传入的embedding patch进行变形,由R(HW)×C\Bbb R^{(HW)\times C}R(HW)×C转换为RHWRi2×(Ri2C)\Bbb R^{\frac{HW}{R_i^2}\times(R_i^2C)}RRi2HW×(Ri2C),然后乘以一个Ws∈R(Ri2C)×CW^s\in \Bbb R^{(R_i^2C)\times C}WsR(Ri2C)×C矩阵,进而将原本的空间规模降为1Ri2\frac{1}{R_i^2}Ri21,即RHWRi2×C\Bbb R^{\frac{HW}{R_i^2}\times C}RRi2HW×C,这是SRA的核心
    • 将Transformer Encoder的结果进行reshape得到HP×WP×C\frac{H}{P}\times\frac{W}{P}\times CPH×PW×C Feature Map,即F1F_1F1
  • Stage2,Stage3,Stage4
    • 由上一个Stage的输出作为输入,重复流程

SRA

SRA的核心是减少K与V的空间规模,也就是<K,V>对的数量,对比于MHA,K与V的空间规模降为其1Ri2\frac{1}{R_i^2}Ri21

处理过程
SRA(Q,K,V)=Concat(head0,...,headNi)Woheadj=Attention(QWjQ,SR(K)WjV,SR(V)WjV)SRA(Q,K,V)=Concat(head_0,...,head_{N_{i}})W^o\\ head_j=Attention(QW_j^Q,SR(K)W_j^V,SR(V)W_j^V) SRA(Q,K,V)=Concat(head0,...,headNi)Woheadj=Attention(QWjQ,SR(K)WjV,SR(V)WjV)
SR(·)为spatial-reduction operation,WsW^sWs是一个linear projection,它把input sequence的维度降到CiC_iCi
SR(x)=Norm(Reshape(x,Ri)Ws)SR(x)=Norm(Reshape(x,R_i)W^s) SR(x)=Norm(Reshape(x,Ri)Ws)

不足

  • 随着输入图片的增大,PVT的资源消耗增长率比ResNet高

TNT

Transformer in Transformer,用于视觉识别

  • 对patch-level和pixel-level representation都进行建模

    • 解决VIT相关模型忽略了每个patch内部固有结构信息的问题
  • 堆叠TNT Block来构建TNT模型

基础模型

创新

  • Inner Transformer Block用于提取pixel embedding的局部结构信息
  • Outer Transformer Block用于提取patch embedding的全局信息
  • 通过Linear Projection将pixel embedding投影到patch embedding space
  • 将输入图像划分为一个个Patch,

  • Unfold & Linear:将输入的Patch,转换为对应的Patch Embedding,Pixel Embedding

    • Patch
      γ0=[Y01,...Y0n]∈Rn×p′×p′×cY0i∈Rp′×p′×c\gamma_0=[Y_0^1,...Y_0^n]\in\Bbb R^{n \times p^{'}\times p^{'} \times c}\\ Y_0^i\in \Bbb R^{p^{'}\times p^{'} \times c} γ0=[Y01,...Y0n]Rn×p×p×cY0iRp×p×c

    • Pixel
      Y0i=[y0i,1,...,y0i,m]m=p′2Y_0^i=[y_0^{i,1},...,y_0^{i,m}]\\ m=p^{'2} Y0i=[y0i,1,...,y0i,m]m=p2

  • TNT Block

    • inner transformer block
      Yl′i=Yl−1i+MSA(LN(Yl−1i))Yli=Yl′i+MLP(LN(Yl′i))Y_l^{'i}=Y_{l-1}^{i}+MSA(LN(Y_{l-1}^{i}))\\ Y_l^{i}=Y_{l}^{'i}+MLP(LN(Y_{l}^{'i}))\\ Yli=Yl1i+MSA(LN(Yl1i))Yli=Yli+MLP(LN(Yli))

    • outer transformer block 输入,Vec()Vec()Vec()将Pixel展开成向量,b为bias
      Zl−1i=Zl−1i+Vec(Yl−1i)Wl−1+bl−1Z_{l-1}^i=Z_{l-1}^i+Vec(Y_{l-1}^{i})W_{l-1}+b_{l-1} Zl1i=Zl1i+Vec(Yl1i)Wl1+bl1

    • outer transformer block
      Zl′i=Zl−1i+MSA(LN(Zl−1i))Zli=Zl′i+MLP(LN(Zl′i))Z_l^{'i}=Z_{l-1}^{i}+MSA(LN(Z_{l-1}^{i}))\\ Z_l^{i}=Z_{l}^{'i}+MLP(LN(Z_{l}^{'i}))\\ Zli=Zl1i+MSA(LN(Zl1i))Zli=Zli+MLP(LN(Zli))

    • TNT block
      γl,Zl=TNT(γl−1,Zl−1)\gamma_l,Z_l=TNT(\gamma_{l-1},Z_{l-1}) γl,Zl=TNT(γl1,Zl1)

position encoding

  • Pixel position encoding 在每一个Patch是共享的

ConvBERT

ConvBERT是一种基于span的动态卷积模型

  • 提出span-based dynamic convolution来代替一些冗余的self-attention head

    • 减少了预训练的计算花销,提高了local dependencies的建模能力
  • 提出mixed attention block
    • 结合span-based dynamic convolution和剩余的self-attention head
    • 更高效地学习 global and local context
  • 基于BERT结合mixed attention block,建立ConvBERT模型
    • 在各种downstream tasks中表现比BERT及其变体模型优越
    • 且更少的训练花销和更少的模型参数

基础模型

Span-based dynamic convolution

  • self-attention:使用所有token来捕获全局依赖关系,但由观察得知,BERT模型学习的更多是局部依赖关系

    • 随着序列增长,复杂度呈现二次增长

    • 故传统BERT模型存在大量冗余

      attention map

  • dynamic convolution:使用一个Kernel Generator来为每一个word embedding生成自己的Kernel

    • 但对于上下文相同的word生成的Kernel是相同的,无法解决一词多义的问题
  • span-based dynamic convolution:通过输入的word embedding和其周围的word embedding结合来生成Kernel

    • 解决一词多义的问题

Span-based dynamic convolution

Span-based dynamic convolution

生成dynamic Kernel
f(Q,Ks)=softmax(Wf(Q∘Ks))f(Q,K_s)=softmax(W_f(Q\circ K_s)) f(Q,Ks)=softmax(Wf(QKs))
span-based dynamic convolution公式
SDConv(Q,Ks,V;Wf,i)=LConv(V,softmax(Wf(Q∘Ks)),i)SDConv(Q,K_s,V;W_f,i)=LConv(V,softmax(W_f(Q \circ K_s)),i) SDConv(Q,Ks,V;Wf,i)=LConv(V,softmax(Wf(QKs)),i)
ConvBERT架构

  • Mixed Attention

    • 结合Self-Attention和Span-based dynamic convolution
    • Self-Attention捕获全局信息,Span-based dynamic convolution捕获局部信息

ConvBERT $$ MixedAttention(K,Q,K_s,V;W_f)=Cat(SelfAttention(Q,K,V),SDConv(Q,K_s,V;W_f)) $$

  • Bottleneck design for Self-Attention

    • 使用bottleneck structure来减少attention head
    • 将输入的embedding映射到更低维度
  • Grouped Feed-Forward module

    • 对Feed-Forward的改进,减少参数数量
      M=∏i=0g[fdg→mgi(H[:,i−1:i×dg])M′=GeLU(M)H′=∏i=0g[fdg→mgi(M[:,i−1:i×dg]′)]M=\prod_{i=0}^{g}[f_{\frac{d}{g} \rightarrow \frac{m}{g}}^i(H_{[:,i-1:i\times \frac{d}{g}]})\\ M^{'} = GeLU(M)\\ H^{'}=\prod_{i=0}^{g}[f_{\frac{d}{g} \rightarrow \frac{m}{g}}^i(M_{[:,i-1:i\times \frac{d}{g}]}^{'})] M=i=0g[fgdgmi(H[:,i1:i×gd])M=GeLU(M)H=i=0g[fgdgmi(M[:,i1:i×gd])]

Tokens-to-Token ViT

Tokens-to-Token Vit是基于Vit模型进行改进的

  • 提出Tokens-to-token

    • 解决Vit无法对图像相邻像素的局部结构信息(边缘,线条,纹理)进行建模
    • 对token进行局部建模,保留token局部结构信息并且减少token长度
  • 结合了deep-narrow结构的高效backone
    • 解决Vit冗余attention对feature richness的限制

基础模型

创新

  • Tokens to token Module:对图像的local structure 信息进行建模,并且减少了token的长度

    • re-structurization
    • soft split
  • T2T-ViT backone:应用了deep-narrow结构,减少了attention冗余,提高feature richness

Tokens to token Module

  • Re-structurization

    • 对输入进行传统Transformer处理
      T′=MLP(MSA(T))T^{'}=MLP(MSA(T)) T=MLP(MSA(T))

    • Reshape:将T′∈Rl×cT^{'}\in \Bbb R^{l\times c}TRl×c转为I∈Rh×w×cI \in \Bbb R^{h \times w\times c}IRh×w×c
      I=Reshape(T′)I=Reshape(T^{'}) I=Reshape(T)

  • Soft Split

    • 建立local structure信息,减少token长度lll

    • 避免re-structurization过程中信息丢失,在split中采取overlap机制,将每一个patch与其周围的patch联系起来,从而捕获周围pixel和patch的信息
      l0=⌊h+2p−kk−s+1⌋×⌊w+2p−kk−s+1⌋l_0=\lfloor\frac{h+2p-k}{k-s}+1 \rfloor \times \lfloor\frac{w+2p-k}{k-s}+1 \rfloor l0=ksh+2pk+1×ksw+2pk+1

      Ti+1=SS(Ti)T_{i+1}=SS(T_i) Ti+1=SS(Ti)

T2T-ViT backone

AutoTrans

  • comprehensive search space

    • 对layer-norm的位置设置
  • PL strategy and parameter sharing strategy
    • 对attention head数量的设置

参考文章

感谢下列文章提供的帮助,推荐大伙们阅读

用Transformer完全替代CNN

Swin Transformer对CNN的降维打击

大白话Pyramid Vision Transformer

Transformer in Transformer论文解读

ConvBERT:使用基于区间的动态卷积来提升BERT

Transformer系列论文阅读相关推荐

  1. 定位系列论文阅读-RoNIN(二)-Robust Neural Inertial Navigation in the Wild: Benchmark, Evaluations

    这里写目录标题 0.Abstract 0.1逐句翻译 0.2总结 1. Introduction 1.1逐句翻译 第一段(就是说惯性传感器十分重要有研究的必要) 第二段(惯性导航是非常理想的一个导航方 ...

  2. dqn系列梳理_系列论文阅读——DQN及其改进

    DQN 作为DRL的开山之作,DeepMind的DQN可以说是每一个入坑深度增强学习的同学必了解的第一个算法了吧.先前,将RL和DL结合存在以下挑战:1.deep learning算法需要大量的lab ...

  3. SSL for Medical Image Classification系列论文阅读笔记 -- ACPL

    ACPL: Anti-curriculum Pseudo-labelling for Semi-supervised Medical Image Classification(CVPR 2022) C ...

  4. 【定位系列论文阅读】-Indoor Visual Positioning Aided by CNN-Based Image Retrieval: Training-Free(一)

    文章目录 0.论文速览 0.1 文章信息 0.2 概述 0.2.1 研究什么东西 0.2.2 评价 1.Abstract 1.1 逐句翻译 1.2 总结 2.INTRODUCTION 2.1 逐句翻译 ...

  5. BIBM系列论文阅读笔记

    <Detecting Driver Sleepiness from EEG Alpha Wave during Daytime Driving> 数据集:8人5男3女EEG数据,采样率=1 ...

  6. 学习笔记:R-CNN系列论文阅读,用faster-rcnn实现交通标志牌的检测

    R-CNN,Fast-RCNN,Faster-RCNN都是基于候选区域(region proposal)的识别网络,在图片上寻找可能是目标存在的区域,对每个区域进行分类和检测框回归,实现目标检测. R ...

  7. 视觉注意力系列概念及论文阅读学习

    最近在看视觉注意力机制,看到比较好的博客或者公众号文章如下: 首先当然是要了解自然语言处理NLP里面的的注意力机制和Transformer基本概念: 1.Visualizing A Neural Ma ...

  8. 【论文阅读】Scene Text Image Super-Resolution in the Wild

    [论文阅读]Scene Text Image Super-Resolution in the Wild 摘要 引言 相关工作 TextZoom数据集 方法 pipeline SRB 中央对齐模块 梯度 ...

  9. [论文阅读](Transformer系列)

    文章目录 一.Video Transformer Network 摘要 引言 相关工作:Applying Transformers on long sequences Video Transforme ...

最新文章

  1. 2021年大数据ELK(三):Lucene全文检索库介绍
  2. 跨域调用报表展现页面的flash打印方法
  3. Java SE7新特性之try-with-resources语句
  4. 数据库分片教程mysql_简述MySQL分片中快速数据迁移
  5. python中 的用法_详解python中@的用法
  6. vue的html自动刷新,Vue页面刷新记住页面状态的实现
  7. jquery jeditable 多选插件 (checkbox or select)
  8. 基于深度学习的图像修补/完整方法分析
  9. pmwiki 安装和基本配置
  10. 物联网时代下,如何打造智慧新社区?
  11. 初次远程做Linux Iptables规则注意事项
  12. Jmeter下载安装详细步骤(2021)
  13. java applepay_java后端处理Apple Pay流程
  14. DAS、NAS、SAN、ISCSI的区别
  15. python-flask 设置网页保留缓存静态文件时间
  16. android开发split的方法在String中的特殊使用
  17. uefi+guid分区与legacy+mbr分区_对硬盘进行分区时,GPT和MBR有什么区别
  18. 【原创】PID控制算法模拟器
  19. CTF基础知识-Web
  20. 学习进度总结————王烁130201218

热门文章

  1. LCD1602液晶屏使用(51单片机七夕特别版)
  2. python 实现轨迹数据可视化
  3. C语言实现万年历系统
  4. 单片机(STC系列8051内核单片机)
  5. 「学习笔记」黑马面面布局开发
  6. 身份证合法性校验规则
  7. JS(JavaScript)验证身份证号码格式的合法性
  8. 普元 AppServer 6.5 业务应用连接mysql数据库报错:java.security.UnrecoverableKeyException: Password verification fai
  9. CentOS中 DNF 和 Yum 的区别
  10. 北京大学可视化暑期学校Day1总结