【解析】Vision Transformer 在图像分类中的应用
An Image is Worth 16x16 Words:Transformers for Image Recognition at Scale
代码:https://github.com/google-research/vision_transformer
文章目录
- 小序
- 1、ViT原理分析:
- 1.1 Patch Embedding
- 为什么要追加这个向量?
- 1.2 Positional Encoding
- 1.3 Transformer Encoder的前向过程
- 1.4 训练方法:
- 1.5 最后,展示下ViT的动态过程:
- 1.6 Experiments:
- 2. ViT代码解读:
- 2.1 使用:
- 2.2 定义残差,FeedForward Layer 等:
- 2.3 Class ViT:
- 2.4 ViT 模型完整代码
小序
ViT(Vision Transformer) 是直接将Transformer直接应用在图像,经过微调:将图像拆成16x16 patch,然后将patch 的 the sequence of linear embedding 作为 Transformers 的输入。
- 数据处理部分:
先对图片作分块,再将每个图片块展平成一维向量 - 数据嵌入部分:
Patch Embedding:对每个向量都做一个线性变换
Positional Encoding:通过一个可学习向量加入序列的位置信息 - 编码部分:
class_token:额外的一个用于分类的可学习向量,与输入进行拼接 - 分类部分:
mlp_head:使用 LayerNorm 和两层全连接层实现的,采用的是GELU激活函数
但是实验表明,在中等尺寸数据集训练后,分类正确率相比于ResNet上往往降低几个百分点,这是由于transformer缺乏CNN的固有的inductive bias 如 translation equivariance and locality,因而在数据不充分情况时不能很好泛化。而在数据尺寸足够的情况下训练transfprmer,是能够应对这种inductive bias,实现对流行模型的性能逼近甚至超越。
1. 什么是CNN 的 inductive bisa?
表现为:transformers 在小数据上的预测正确率比 CNN 低,当采用混合结构时(即将CNN的输出特征作为输入序列时,尽在小数据上实现性能提升),这与我们预期有差,期望CNN的引入能够提升所有尺寸训练样本下的性能。就是凭借一些规律得出的偏好:如CNN天然的对图像处理的较好,天然的具有平移不变性等;
2. Patch 如何理解?
patch 是将 3 维图像 reshape 为2维之后进行切分,使用的 position embedding 是1维,将 patch作为一个小整体,然后对patch在整个图像中的位置进行编码,还是按照分割后的位置信息。
1、ViT原理分析:
这个工作本着尽可能少修改的原则,将原版的Transformer开箱即用地迁移到分类任务上面。并且作者认为没有必要总是依赖于CNN,只用Transformer也能够在分类任务中表现很好,尤其是在使用大规模训练集的时候。同时,在大规模数据集上预训练好的模型,在迁移到中等数据集或小数据集的分类任务上以后,也能取得比CNN更优的性能。下面看具体的方法:
这个工作首先把 x∈H×W×Cx\in H \times W \times Cx∈H×W×C 的图像,变成一个 xp∈N×(P2⋅C)x_p \in N \times (P^2 \cdot C)xp∈N×(P2⋅C) 的sequence of flattened 2D patches。它可以看做是一系列的展平的2D块的序列,这个序列中一共有 N=HW/P2N =HW/P^2N=HW/P2 个展平的2D块,每个块的维度是 (P2×C)(P^2\times C)(P2×C) 。其中 PPP 是块大小, CCC 是channel数。
注意作者做这步变化的意图:
根据之前的讲解,Transformer希望输入一个二维的矩阵 (N,D)(N,D)(N,D) ,其中 NNN 是sequence的长度,DDD 是sequence的每个向量的维度,常用256。所以这里也要设法把 H×W×CH\times W \times CH×W×C 的三维图片转化成 (N,D)(N,D)(N,D) 的二维输入。
所以有: H×W×C→N×(P2⋅C)H \times W \times C \to N \times (P^2 \cdot C)H×W×C→N×(P2⋅C),where N=HW/P2N=HW/P^2N=HW/P2 。
其中,NNN 是Transformer输入的sequence的长度。
代码是:
x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
具体是采用了einops库实现,具体可以参考这篇博客,科技猛兽:PyTorch 70.einops:优雅地操作张量维度
现在得到的向量维度是:xp∈N×(P2×C)x_p \in N \times (P^2 \times C)xp∈N×(P2×C) ,要转化成 (N,D)(N,D)(N,D) 的二维输入,我们还需要做一步叫做Patch Embedding的步骤。
1.1 Patch Embedding
方法是对每个向量都做一个线性变换(即全连接层),压缩后的维度为DDD ,这里我们称其为 Patch Embedding。
z0=[xclass;xp1E;xp2E;....;xpnE]+Epos(1)z_0 = [\color{green}x_{class}; \color{back} x_p^1E; x_p^2E; .... ; x_p^nE]+ E_{pos} \tag1z0=[xclass;xp1E;xp2E;....;xpnE]+Epos(1)
这个全连接层就是上式(5.1)中的 E\color{red}EE ,它的输入维度大小是 (P2⋅C)(P^2 \cdot C)(P2⋅C) ,输出维度大小是 DDD。
# 将3072变成dim,假设是1024
self.patch_to_embedding = nn.Linear(patch_dim, dim)
x = self.patch_to_embedding(x)
注意这里的绿色字体 xclass\color{green}x_{class}xclass ,假设切成9个块,但是最终到Transfomer输入是10个向量,这是人为增加的一个向量。
为什么要追加这个向量?
如果没有这个向量,假设 N=9N=9N=9 个向量输入Transformer Encoder,输出9个编码向量,然后呢?对于分类任务而言,我应该取哪个输出向量进行后续分类呢?
不知道。干脆就再来一个向量 xclass(vector,dim=D)\color{green}x_{class}(vector ,dim =D)xclass(vector,dim=D) ,这个向量是可学习的嵌入向量,它和那9个向量一并输入Transfomer Encoder,输出1+9个编码向量。然后就用第0个编码向量,即 xclass\color{green}x_{class}xclass 的输出进行分类预测即可。
这么做的原因可以理解为:ViT其实只用到了Transformer的Encoder,而并没有用到Decoder,而 xclass\color{green}x_{class}xclass 的作用有点类似于解码器中的 QueryQueryQuery 的作用,相对应的 Key,ValueKey, ValueKey,Value 就是其他9个编码向量的输出。xclass\color{green}x_{class}xclass 是一个可学习的嵌入向量,它的意义说通俗一点为:寻找其他9个输入向量对应的 imageimageimage 的类别。
代码为:
# dim=1024
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))# forward前向代码
# 变成(b,64,1024)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
# 跟前面的分块进行concat
# 额外追加token,变成b,65,1024
x = torch.cat((cls_tokens, x), dim=1)
1.2 Positional Encoding
按照Transformer的位置编码的习惯,这个工作也使用了位置编码。引入了一个 Positional encoding Epos\color{violet}E_{pos}Epos来加入序列的位置信息,同样在这里也引入了pos_embedding,是用一个可训练的变量。
z0=[xclass;xp1E;xp2E;....;xpnE]+Eposz_0 = [x_{class}; x_p^1E; x_p^2E; .... ; x_p^nE]+ \color{violet}E_{pos}z0=[xclass;xp1E;xp2E;....;xpnE]+Epos
没有采用原版Transformer的 sincossincossincos 编码,而是直接设置为可学习的Positional Encoding,效果差不多。对训练好的pos_embedding进行可视化,如下图所示。
我们发现,位置越接近,往往具有更相似的位置编码。此外,出现了行列结构;同一行/列中的patch具有相似的位置编码。
# num_patches=64,dim=1024,+1是因为多了一个cls开启解码标志
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
1.3 Transformer Encoder的前向过程
z0=[xclass;xp1E;xp2E;....;xpnE]+Epos,E∈RP2×C×D,Epos∈R(N+1)×D(2)z_0 = [x_{class}; x_p^1E; x_p^2E; .... ; x_p^nE]+ E_{pos}, \qquad \qquad E\in \mathbb{R}^{P^2 \times C\times D}, E_{pos} \in \mathbb{R}^{(N+1)\times D} \tag2z0=[xclass;xp1E;xp2E;....;xpnE]+Epos,E∈RP2×C×D,Epos∈R(N+1)×D(2)
zℓ′=MSA(LN(zℓ−1))+zℓ−1,ℓ=1...L(3){z}'_\ell = \color{violet}MSA(LN(z_{\ell-1}))+z_{\ell-1}, \qquad \qquad \color{back}\ell=1...L \qquad \qquad \qquad \tag3zℓ′=MSA(LN(zℓ−1))+zℓ−1,ℓ=1...L(3)
zℓ=MLP(LN(zℓ′))+zell′,ℓ=1...L(4)z_{\ell} = \color{blue}MLP(LN({z}'_\ell))+{z}'_{ell}, \qquad \qquad \color{back} \ell=1...L \qquad \qquad \quad \tag4zℓ=MLP(LN(zℓ′))+zell′,ℓ=1...L(4)
y=LN(zℓ0)(5)y = LN(z^0_{\ell}) \qquad\qquad\qquad \tag5y=LN(zℓ0)(5)
- 其中,第1个式子为上面讲到的 Patch Embedding 和 Positional Encoding 的过程。
- 第2个式子为Transformer Encoder的 Multi−headSelf−Attention,AddandNorm\color{violet}Multi-head \quad Self-Attention, Add and NormMulti−headSelf−Attention,AddandNorm 的过程,重复 LLL 次。
- 第3个式子为Transformer Encoder的 FeedForward,AddandNorm\color{blue}Feed Forward, AddandNormFeedForward,AddandNorm 的过程,重复 LLL 次。
作者采用的是没有任何改动的 Transformer。
最后是一个 MLPMLPMLP 的 Classfication−HeadClassfication - HeadClassfication−Head ,整个的结构只有这些,如下图所示,为了方便读者的理解,我把变量的维度变化过程标注在了图中。
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
1.4 训练方法:
先在大数据集上预训练,再迁移到小数据集上面。做法是把ViT的 prediction−head\color{violet}prediction-headprediction−head 去掉,换成一个 D×KD \times KD×K 的 FeedForwardLayer\color{violet}FeedForwardLayerFeedForwardLayer 。其中 KKK 为对应数据集的类别数。
当输入的图片是更大的shape时,patch size PPP 保持不变,则 N=HW/P2N=HW/P^2N=HW/P2 会增大。
ViT可以处理任意 NNN 的输入,但是Positional Encoding是按照预训练的输入图片的尺寸设计的,所以输入图片变大之后,Positional Encoding需要根据它们在原始图像中的位置做2D插值。
1.5 最后,展示下ViT的动态过程:
ViT的动态过程
整个流程:
- 一个图片256x256,分成了64个32x32的patch;
- 对这么多的patch做embedding,成64个1024向量;
- 再拼接一个cls_tokens,变成65个1024向量;
- 再加上pos_embedding,还是65个1024向量;
- 这些向量输入到transformer中进行自注意力的特征提取;
- 输出的是64个1024向量,然后对这个50个求均值,变成一个1024向量;
- 然后线性层把1024维变成 mlp_head维从而完成分类任务的transformer模型。
1.6 Experiments:
预训练模型使用到的数据集有:
- ILSVRC-2012 ImageNet dataset:1000 classes
- ImageNet-21k:21k classes
- JFT:18k High Resolution Images
将预训练迁移到的数据集有:
- CIFAR-10/100
- Oxford-IIIT Pets
- Oxford Flowers-102
- VTAB
作者设计了3种不同答小的ViT模型,它们分别是:
DModel | Layers | Hidden size | MLP size | Heads | Params |
---|---|---|---|---|---|
ViT-Base | 12 | 768 | 3072 | 12 | 86M |
ViT-Large | 24 | 1024 | 4096 | 16 | 307M |
ViT-Huge | 32 | 1280 | 5120 | 16 | 632M |
ViT-L/16代表ViT-Large + 16 patch size
评价指标 Metrics :
结果都是下游数据集上经过finetune之后的Accuracy,记录的是在各自数据集上finetune后的性能。
实验1:性能对比
实验结果如下图所示,整体模型还是挺大的,而经过大数据集的预训练后,性能也超过了当前CNN的一些SOTA结果。对比的CNN模型主要是:
2020年ECCV的Big Transfer (BiT)模型,它使用大的ResNet进行有监督转移学习。
2020年CVPR的Noisy Student模型,这是一个在ImageNet和JFT300M上使用半监督学习进行训练的大型高效网络,去掉了标签。
All models were trained on TPUv3 hardware。
在JFT-300M上预先训练的较小的ViT-L/16模型在所有任务上都优于BiT-L(在同一数据集上预先训练的),同时训练所需的计算资源要少得多。 更大的模型ViT-H/14进一步提高了性能,特别是在更具挑战性的数据集上——ImageNet, CIFAR-100和VTAB数据集。 与现有技术相比,该模型预训练所需的计算量仍然要少得多。
下图为VTAB数据集在Natural, Specialized, 和Structured子任务与CNN模型相比的性能,ViT模型仍然可以取得最优。
实验2:ViT对预训练数据的要求
ViT对于预训练数据的规模要求到底有多苛刻?
作者分别在下面这几个数据集上进行预训练:ImageNet, ImageNet-21k, 和JFT-300M。
结果如下图所示:
我们发现: 当在最小数据集ImageNet上进行预训练时,尽管进行了大量的正则化等操作,但ViT-大模型的性能不如ViT-Base模型。
但是有了稍微大一点的ImageNet-21k预训练,它们的表现也差不多。
只有到了JFT 300M,我们才能看到更大的ViT模型全部优势。 图3还显示了不同大小的BiT模型跨越的性能区域。BiT CNNs在ImageNet上的表现优于ViT(尽管进行了正则化优化),但在更大的数据集上,ViT超过了所有的模型,取得了SOTA。
作者还进行了一个实验: 在9M、30M和90M的随机子集以及完整的JFT300M数据集上训练模型,结果如下图所示。 ViT在较小数据集上的计算成本比ResNet高, ViT-B/32比ResNet50稍快;它在9M子集上表现更差, 但在90M+子集上表现更好。ResNet152x2和ViT-L/16也是如此。这个结果强化了一种直觉,即:
残差对于较小的数据集是有用的,但是对于较大的数据集,像attention一样学习相关性就足够了,甚至是更好的选择。
实验3:ViT的注意力机制Attention
作者还给了注意力观察得到的图片块, Self-attention使得ViT能够整合整个图像中的信息,甚至是最底层的信息。作者欲探究网络在多大程度上利用了这种能力。
具体来说,我们根据注意力权重计算图像空间中整合信息的平均距离,如下图所示。
注意这里我们只使用了attention,而没有使用CNN,所以这里的attention distance相当于CNN的receptive field的大小。
作者发现:在最底层, 有些head也已经注意到了图像的大部分,说明模型已经可以globally地整合信息了,说明它们负责global信息的整合。其他的head 只注意到图像的一小部分,说明它们负责local信息的整合。Attention Distance随深度的增加而增加。
整合局部信息的attention head在混合模型(有CNN存在)时,效果并不好,说明它可能与CNN的底层卷积有着类似的功能。
作者给出了attention的可视化,注意到了适合分类的位置:
2. ViT代码解读:
2.1 使用:
import torch
from vit_pytorch import ViTv = ViT(image_size = 256,patch_size = 32,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1
)img = torch.randn(1, 3, 256, 256)
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend topreds = v(img, mask = mask) # (1, 1000)
- 传入参数的意义: image_size:输入图片大小。
- patch_size:论文中 patch size: 图片 的大小。
- num_classes:数据集类别数。
- dim:Transformer的隐变量的维度。
- depth:Transformer的Encoder,Decoder的Layer数。
- heads:Multi-head Attention
- layer的head数。
- mlp_dim:MLP层的hidden dim。
- dropout:Dropout rate。
- emb_dropout:Embedding dropout rate。
2.2 定义残差,FeedForward Layer 等:
class Residual(nn.Module):def __init__(self, fn):super().__init__()self.fn = fndef forward(self, x, **kwargs):return self.fn(x, **kwargs) + xclass PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential(nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)
Attention和Transformer,注释已标注在代码中:
class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head * headsself.heads = headsself.scale = dim ** -0.5self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout))def forward(self, x, mask = None):# b, 65, 1024, heads = 8b, n, _, h = *x.shape, self.heads# self.to_qkv(x): b, 65, 64*8*3# qkv: b, 65, 64*8qkv = self.to_qkv(x).chunk(3, dim = -1)# b, 65, 64, 8q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)# dots:b, 65, 64, 64dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scalemask_value = -torch.finfo(dots.dtype).maxif mask is not None:mask = F.pad(mask.flatten(1), (1, 0), value = True)assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'mask = mask[:, None, :] * mask[:, :, None]dots.masked_fill_(~mask, mask_value)del mask# attn:b, 65, 64, 64attn = dots.softmax(dim=-1)# 使用einsum表示矩阵乘法:# out:b, 65, 64, 8out = torch.einsum('bhij,bhjd->bhid', attn, v)# out:b, 64, 65*8out = rearrange(out, 'b h n d -> b n (h d)')# out:b, 64, 1024out = self.to_out(out)return outclass Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))]))def forward(self, x, mask = None):for attn, ff in self.layers:x = attn(x, mask = mask)x = ff(x)return x
2.3 Class ViT:
class ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):super().__init__()assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_size // patch_size) ** 2patch_dim = channels * patch_size ** 2assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'self.patch_size = patch_sizeself.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.patch_to_embedding = nn.Linear(patch_dim, dim)self.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, num_classes))def forward(self, img, mask = None):p = self.patch_size# 图片分块x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)# 降维(b,N,d)x = self.patch_to_embedding(x)b, n, _ = x.shape# 多一个可学习的x_class,与输入concat在一起,一起输入Transformer的Encoder。(b,1,d)cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)x = torch.cat((cls_tokens, x), dim=1)# Positional Encoding:(b,N+1,d)x += self.pos_embedding[:, :(n + 1)]x = self.dropout(x)# Transformer的输入维度x的shape是:(b,N+1,d)x = self.transformer(x, mask)# (b,1,d)x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]x = self.to_latent(x)return self.mlp_head(x) # (b,1,num_class)
2.4 ViT 模型完整代码
# !/usr/bin/env python
# -*- coding:utf-8 -*-
# @Time : 2021.
# @Author : 绿色羽毛
# @Email : lvseyumao@foxmail.com
# @Blog : https://blog.csdn.net/ViatorSun
# @Note :import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
# from einops.layers.torch import Rearrangeclass Residual(nn.Module):def __init__(self, fn):super().__init__()self.fn = fndef forward(self, x, **kwargs):return self.fn(x, **kwargs) + xclass PreNorm(nn.Module):def __init__(self, dim, fn):super().__init__()self.norm = nn.LayerNorm(dim)self.fn = fndef forward(self, x, **kwargs):return self.fn(self.norm(x), **kwargs)class FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential( nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout) )def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head * headsself.heads = headsself.scale = dim ** -0.5self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout) )def forward(self, x, mask = None):# b, 65, 1024, heads = 8b, n, _ = x.shapeh = self.heads# self.to_qkv(x): b, 65, 64*8*3# qkv: b, 65, 64*8qkv = self.to_qkv(x).chunk(3, dim = -1) # 沿-1轴分为3块# b, 65, 64, 8q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)# dots:b, 65, 64, 64dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scalemask_value = -torch.finfo(dots.dtype).maxif mask is not None:mask = F.pad(mask.flatten(1), (1, 0), value = True)assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'mask = mask[:, None, :] * mask[:, :, None]dots.masked_fill_(~mask, mask_value)del mask# attn:b, 65, 64, 64attn = dots.softmax(dim=-1)# 使用einsum表示矩阵乘法:# out:b, 65, 64, 8out = torch.einsum('bhij,bhjd->bhid', attn, v)# out:b, 64, 65*8out = rearrange(out, 'b h n d -> b n (h d)')# out:b, 64, 1024out = self.to_out(out)return outclass Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):super().__init__()self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([ Residual(PreNorm(dim, Attention( dim, heads = heads, dim_head = dim_head, dropout = dropout))),Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))) ]))def forward(self, x, mask = None):for attn, ff in self.layers:x = attn(x, mask = mask)x = ff(x)return xclass ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):super().__init__()assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_size // patch_size) ** 2patch_dim = channels * patch_size ** 2# assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'self.patch_size = patch_sizeself.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.patch_to_embedding = nn.Linear(patch_dim, dim)self.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, num_classes) )def forward(self, img, mask = None):p = self.patch_size# 图片分块# print(img.shape)x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) # 1,3,256,256 -> 1,64,3072# 降维(b,N,d)x = self.patch_to_embedding(x)b, n, _ = x.shape# 多一个可学习的x_class,与输入concat在一起,一起输入Transformer的Encoder。(b,1,d)cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)x = torch.cat((cls_tokens, x), dim=1)# Positional Encoding:(b,N+1,d)x += self.pos_embedding[:, :(n + 1)]x = self.dropout(x)# Transformer的输入维度x的shape是:(b,N+1,d)x = self.transformer(x, mask)# (b,1,d)x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]x = self.to_latent(x)return self.mlp_head(x) # (b,1,num_class)if __name__ == '__main__':v = ViT(image_size=256, patch_size=32, num_classes=10, dim=1024, depth=6, heads=16, mlp_dim=2048, dropout=0.1,emb_dropout=0.1)img = torch.randn(1, 3, 256, 256)mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend topreds = v(img, mask=mask) # (1, 1000)print(preds)
【解析】Vision Transformer 在图像分类中的应用相关推荐
- Vision Transformer 必读系列之图像分类综述(二): Attention-based
文 @ 000007 号外号外:awesome-vit 上新啦,欢迎大家 Star Star Star ~ https://github.com/open-mmlab/awesome-vitgith ...
- Vision Transformer发展现状
--------------- 声明 CSDN:越来越胖的GuanRunwei 知乎:无名之辈 / IDPT集萃感知 皆为本人 --------------- 背景 自 DETR 与 Vision T ...
- Vision Transformer(ViT)PyTorch代码全解析(附图解)
Vision Transformer(ViT)PyTorch代码全解析 最近CV领域的Vision Transformer将在NLP领域的Transormer结果借鉴过来,屠杀了各大CV榜单.本文将根 ...
- Vision Transformer中的自监督学习
任何接触过机器学习的人都肯定听说过监督学习和无监督学习.这些实际上是机器学习的两种重要的可能方法,已被广泛使用多年.然而,直到最近,一个新术语"自我监督学习"出现了爆炸式增长!但是 ...
- 【解析】Token to Token Vision Transformer
Vision Transformer 的提出颠覆了我们以往对图像处理的方式,也开阔了Transformer 在CV方向上的潜力,但其有一些缺点,如需要 超大型数据集(JFT)预训练,才能达到现在CNN ...
- ICCV2021 | Vision Transformer中相对位置编码的反思与改进
前言 在计算机视觉中,相对位置编码的有效性还没有得到很好的研究,甚至仍然存在争议,本文分析了相对位置编码中的几个关键因素,提出了一种新的针对2D图像的相对位置编码方法,称为图像RPE(IRPE). ...
- Vision Transformer在CV任务中的速度如何保证?
本文作者丨盘子正@知乎 编辑丨极市平台 来源丨https://zhuanlan.zhihu.com/p/569482746 我(盘子正@知乎)的PhD课题是Vision Transformer的 ...
- VIT(vision transformer)结构解析
文章目录 背景 网络结构 VIT简介 VIT模型概述 参考 transformer的出现彻底改变了自然语言处理的世界,然而在计算机视觉中,注意力机制保持原卷积网络整体结构,常与卷积网络结合.或是取代卷 ...
- Pytorch使用Vision Transformer做肺癌和结肠癌组织病理学图像分类
模型介绍 文章链接:https://arxiv.org/pdf/2010.11929.pdf github地址: 视频教程:https://www.bilibili.com/video/BV1Jh41 ...
最新文章
- SpringBoot四大核心组件,必知必会!
- Dynamo分布式系统——「RWN」协议解决多备份数据如何读写来保证数据一致性,而「向量时钟」来保证当读取到多个备份数据的时候,如何判断哪些数据是最新的这种情况...
- python爬取今日头条的文章_Python3爬取今日头条有关《人民的名义》文章
- FreeSwitch Lua编程接口(1)dialplan里的配置
- 反转链表与链表实现两数相加(简单思路)
- 凸优化系列二:确定步长一维搜索算法
- 配置高并发jdbc连接池
- 微信api接口调用-给微信好友或群聊发消息
- Flash builder 4.6 下载 破解 序列号【你懂的】
- android自动生成cardview,CardView
- 微生物组-扩增子16S分析和可视化(2022.7)
- Android 10 电池图标修改
- Tensorflow笔记4:Saver
- CTFHub 技能树web
- 基于百度地图API的WinForm地图
- 转载--三少三多技术开发
- 那些惊艳了我的第三方插件收集
- Monkey King-左偏树
- Redis各版本描述
- px(像素)与 dp, sp换算公式
热门文章
- jQuery-链接api实现星座运势和手机号归属地查询
- 论文阅读笔记:(2022) Delving into the Devils of Bird‘s-eye-view Perception: A Review, Evaluation and Recipe
- 硅谷归来,七大感触——You Only Live Once
- 微信小程序使用TDesign(TS版本)
- 一个大学教授在美国的生活
- 特斯拉蛇形充电机器人_特斯拉也造出蛇形机器人,专为充电使用!
- 如何更改Source Insight 4.0中Tab的宽度
- QueryByExampleExecutor接口的查询
- 99年人民币冠号大全
- 计算机上可用内存不足无法打开图,Windows照片查看器无法显示此图片因为计算机上的可用内存可能不足解决方法...