代码

为了处理二维图像,我们将尺寸为 H×W×C的图像reshape为拉平的2维图块,尺寸为 (N×(P^2×C))。其中, (P,P)为图块的大小, N=HW/P^2 。 N 是图块的数量,会影响输入序列的长度。Transformer在所有图层上使用恒定的隐矢量D,因此我们将图块拉平,并使用可训练的线性投影映射到D的大小,将此投影的输出称为patch embedding。对应代码如下:直接暴力拉伸

# Transformer.n, h, w, c = x.shapex = jnp.reshape(x, [n, h * w, c])

类似BERT的[class] token,我们在可嵌入的补丁序列(z_0^0=x_class )之前准备了可学习的embedding向量,该序列在Transformer编码器的输出(z_L^0 )的状态用作图像表示y。 在预训练和微调期间,都将分类head连接到 z_L^0。

# If we want to add a class token, add it here.if self.classifier == 'token':cls = self.param('cls', nn.initializers.zeros, (1, 1, c))cls = jnp.tile(cls, [n, 1, 1])x = jnp.concatenate([cls, x], axis=1)

分类head是通过在预训练时具有一个隐藏层的MLP以及在微调时通过一个线性层的MLP来实现的。

class MlpBlock(nn.Module):"""Transformer MLP / feed-forward block."""mlp_dim: intdtype: Dtype = jnp.float32out_dim: Optional[int] = Nonedropout_rate: float = 0.1kernel_init: Callable[[PRNGKey, Shape, Dtype],Array] = nn.initializers.xavier_uniform()bias_init: Callable[[PRNGKey, Shape, Dtype],Array] = nn.initializers.normal(stddev=1e-6)@nn.compactdef __call__(self, inputs, *, deterministic):"""Applies Transformer MlpBlock module."""actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dimx = nn.Dense(features=self.mlp_dim,dtype=self.dtype,kernel_init=self.kernel_init,bias_init=self.bias_init)(  # pytype: disable=wrong-arg-typesinputs)x = nn.gelu(x)x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)output = nn.Dense(features=actual_out_dim,dtype=self.dtype,kernel_init=self.kernel_init,bias_init=self.bias_init)(  # pytype: disable=wrong-arg-typesx)output = nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic)return

位置embedding会添加到patch embedding中,以保留位置信息。我们使用标准的可学习1D位置embedding,因为我们没有观察到使用更高级的2D感知位置embedding可显着提高性能。embedding向量的结果序列用作编码器的输入。

class AddPositionEmbs(nn.Module):"""Adds (optionally learned) positional embeddings to the inputs.Attributes:posemb_init: positional embedding initializer."""posemb_init: Callable[[PRNGKey, Shape, Dtype], Array]@nn.compactdef __call__(self, inputs):"""Applies AddPositionEmbs module.By default this layer uses a fixed sinusoidal embedding table. If alearned position embedding is desired, pass an initializer toposemb_init.Args:inputs: Inputs to the layer.Returns:Output tensor with shape `(bs, timesteps, in_dim)`."""# inputs.shape is (batch_size, seq_len, emb_dim).assert inputs.ndim == 3, ('Number of dimensions should be 3,'' but it is: %d' % inputs.ndim)pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])pe = self.param('pos_embedding', self.posemb_init, pos_emb_shape)return inputs + pe

transformer--ViT相关推荐

  1. Keras构建用于分类任务的Transformer(Vision Transformer/VIT)

    文章目录 一.Vision Transformer (ViT)详细信息 二.Vision Transformer结构 三.Keras实现 3.1 相关包 3.2 数据读取 3.3 声明超参数 3.4 ...

  2. 品论文:VISION TRANSFORMER (VIT)

    今天上午看了个论文,每当遇到全英文论文的时候,就会发现自己的英文水平属实是太一般,但是看完这篇论文确实是感触良多!!! 论文标题:<AN IMAGE IS WORTH 16X16 WORDS: ...

  3. Vision Transformer(ViT)解读

    Vision Transformer Transformer原本是用在NLP上的模型,直到Vision Transformer的出现,transformer开始了在视觉领域的应用. 论文:An Ima ...

  4. Vision Transformer(ViT) 2: 应用及代码讲解

    文章目录 1. 代码讲解 1.1 PatchEmbed类 1)`__init__ `函数 2) forward 过程 1.2 Attention类 1)`__init__ `函数 2)forward ...

  5. Vision Transformer(VIT)代码分析——保姆级教程

    目录 前言 一.代码分析 1.1.DropPath模块 1.2.Patch Embeding 1.3.Multi-Head Attention 1.4.MLP 1.5.Block 1.6.Vision ...

  6. Vision Transformer(ViT) 1: 理论详解

    Vison Transformer 介绍 Vison Transformer论文- An Image is Worth 16x16 Words: Transformers for Image Reco ...

  7. ViT(vision transformer)原理快速入门

    本专题需要具备的基础: 了解深度学习分类网络原理. 了解2017年的transformer. Transformer 技术里程碑: ViT简介 时间:2020年CVPR 论文全称:<An Ima ...

  8. 各类Transformer都得稍逊一筹,LV-ViT:探索多个用于提升ViT性能的高效Trick

    [导读]本文探索了用于提升ViT性能的各种训练技巧.通过一系列实验对比.改进与组合,本文所提方案取得了SOTA方案,超越了EfficientNet.T2TViT.DeiT.Swin Transform ...

  9. vision transformer(viT)教学视频【通俗易懂】

    11.1 Vision Transformer(vit)网络详解_哔哩哔哩_bilibili 文章地址:Vision Transformer详解_霹雳吧啦Wz-CSDN博客 其中两个关键的图

  10. ICCV2021-PiT-池化操作不是CNN的专属,ViT说:“我也可以”;南大提出池化视觉Transformer(PiT)...

    关注公众号,发现CV技术之美 本文分享一篇 ICCV2021 论文:『Rethinking Spatial Dimensions of Vision Transformers』. 详细信息如下: 论文 ...

最新文章

  1. 前后端分离的探索(三)
  2. 鸿蒙系统哪里的,华为“鸿蒙系统”IPFS/FIL:区块链的“鸿蒙系统”?
  3. 微软的Ajax库客户端Bug总结
  4. axis2生成客户端代码_利用ApiPost一键、快速生成接口文档!女猿也过38节!
  5. 李笑来登GitHub趋势榜第一,教你自学编程,含37%“硬核鸡汤”
  6. OPEN SQL中通配符的使用
  7. POJ 1191 棋盘分割【区间类DP】
  8. Tssd2019最新版下载地址和更新说明
  9. EndNoteX7中conference proceeding和conference paper的区别
  10. 小米浏览器保存的html文件怎么打开方式,怎么将小米手机浏览器中的网页设置为书签...
  11. 国内期刊 CCT 模板编译经验
  12. SNPS IP LPDDR4 调试
  13. GSOAP 在一个客户端内调用多个服务出现的问题解决
  14. unity3D 涂涂乐使用shader实现上色效果
  15. 工作和生活遇到的Windows常见需求 跨局域网共享文件 网页加载慢更换DNS
  16. 理想电压源的内阻是0,理想电流源的内阻是无穷大
  17. 快速排序(quickSort) 和 插入排序(insertSort)
  18. AsyncTask原理
  19. 单细胞转录组:Smart-seq 2还是10X Genomics Chromium?
  20. php 强制不换行,HTMLnobr强制不换行标签元素

热门文章

  1. P1106 删数问题【贪心】
  2. Linux下安装OpenOffice
  3. 铝碳化硅封装材料行业研究及十四五规划分析报告
  4. lumen php命令,Lumen创建自定义make命令
  5. 北洋大讲堂之“斯凯网络CEO宋涛-我的创业之路”感想
  6. Python数据分析之股票双均线策略制定
  7. 一个高速交警的忠告(转)
  8. Boosting算法与假设间隔
  9. 基于MediaPlayer的Android播放器控件
  10. 编写大并发高负载通讯程序