选自arXiv

作者:Yifan Jiang等

机器之心编译

机器之心编辑部

「attention is really becoming『all you need』.」

最近,CV 研究者对 transformer 产生了极大的兴趣并取得了不少突破。这表明,transformer 有可能成为计算机视觉任务(如分类、检测和分割)的强大通用模型。

我们都很好奇:在计算机视觉领域,transformer 还能走多远?对于更加困难的视觉任务,比如生成对抗网络 (GAN),transformer 表现又如何?

在这种好奇心的驱使下,德州大学奥斯汀分校的 Yifan Jiang、Zhangyang Wang,IBM Research 的 Shiyu Chang 等研究者进行了第一次试验性研究,构建了一个只使用纯 transformer 架构、完全没有卷积的 GAN,并将其命名为 TransGAN。与其它基于 transformer 的视觉模型相比,仅使用 transformer 构建 GAN 似乎更具挑战性,这是因为与分类等任务相比,真实图像生成的门槛更高,而且 GAN 训练本身具有较高的不稳定性。

  • 论文链接:https://arxiv.org/pdf/2102.07074.pdf

  • 代码链接:https://github.com/VITA-Group/TransGAN

从结构上来看,TransGAN 包括两个部分:一个是内存友好的基于 transformer 的生成器,该生成器可以逐步提高特征分辨率,同时降低嵌入维数;另一个是基于 transformer 的 patch 级判别器。

研究者还发现,TransGAN 显著受益于数据增强(超过标准的 GAN)、生成器的多任务协同训练策略和强调自然图像邻域平滑的局部初始化自注意力。这些发现表明,TransGAN 可以有效地扩展至更大的模型和具有更高分辨率的图像数据集。

实验结果表明,与当前基于卷积骨干的 SOTA GAN 相比,表现最佳的 TransGAN 实现了极具竞争力的性能。具体来说,TransGAN 在 STL-10 上的 IS 评分为 10.10,FID 为 25.32,实现了新的 SOTA。

该研究表明,对于卷积骨干以及许多专用模块的依赖可能不是 GAN 所必需的,纯 transformer 有足够的能力生成图像。

在该论文的相关讨论中,有读者调侃道,「attention is really becoming『all you need』.」

不过,也有部分研究者表达了自己的担忧:在 transformer 席卷整个社区的大背景下,势单力薄的小实验室要怎么活下去?

如果 transformer 真的成为社区「刚需」,如何提升这类架构的计算效率将成为一个棘手的研究问题。

基于纯 Transformer 的 GAN

作为基础块的 Transformer 编码器

研究者选择将 Transformer 编码器(Vaswani 等人,2017)作为基础块,并尽量进行最小程度的改变。编码器由两个部件组成,第一个部件由一个多头自注意力模块构造而成,第二个部件是具有 GELU 非线性的前馈 MLP(multiple-layer perceptron,多层感知器)。此外,研究者在两个部件之前均应用了层归一化(Ba 等人,2016)。两个部件也都使用了残差连接。

内存友好的生成器

NLP 中的 Transformer 将每个词作为输入(Devlin 等人,2018)。但是,如果以类似的方法通过堆叠 Transformer 编码器来逐像素地生成图像,则低分辨率图像(如 32×32)也可能导致长序列(1024)以及更高昂的自注意力开销。

所以,为了避免过高的开销,研究者受到了基于 CNN 的 GAN 中常见设计理念的启发,在多个阶段迭代地提升分辨率(Denton 等人,2015;Karras 等人,2017)。他们的策略是逐步增加输入序列,并降低嵌入维数

如下图 1 左所示,研究者提出了包含多个阶段的内存友好、基于 Transformer 的生成器:

每个阶段堆叠了数个编码器块(默认为 5、2 和 2)。通过分段式设计,研究者逐步增加特征图分辨率,直到其达到目标分辨率 H_T×W_T。具体来说,该生成器以随机噪声作为其输入,并通过一个 MLP 将随机噪声传递给长度为 H×W×C 的向量。该向量又变形为分辨率为 H×W 的特征图(默认 H=W=8),每个点都是 C 维嵌入。然后,该特征图被视为长度为 64 的 C 维 token 序列,并与可学得的位置编码相结合。

与 BERT(Devlin 等人,2018)类似,该研究提出的 Transformer 编码器以嵌入 token 作为输入,并递归地计算每个 token 之间的匹配。为了合成分辨率更高的图像,研究者在每个阶段之后插入了一个由 reshaping 和 pixelshuffle 模块组成的上采样模块。

具体操作上,上采样模块首先将 1D 序列的 token 嵌入变形为 2D 特征图,然后采用 pixelshuffle 模块对 2D 特征图的分辨率进行上采样处理,并下采样嵌入维数,最终得到输出。然后,2D 特征图 X’_0 再次变形为嵌入 token 的 1D 序列,其中 token 数为 4HW,嵌入维数为 C/4。所以,在每个阶段,分辨率(H, W)提升到两倍,同时嵌入维数 C 减少至输入的四分之一。这一权衡(trade-off)策略缓和了内存和计算量需求的激增。

研究者在多个阶段重复上述流程,直到分辨率达到(H_T , W_T )。然后,他们将嵌入维数投影到 3,并得到 RGB 图像

用于判别器的 tokenized 输入

与那些需要准确合成每个像素的生成器不同,该研究提出的判别器只需要分辨真假图像即可。这使得研究者可以在语义上将输入图像 tokenize 为更粗糙的 patch level(Dosovitskiy 等人,2020)。

如上图 1 右所示,判别器以图像的 patch 作为输入。研究者将输入图像分解为 8 × 8 个 patch,其中每个 patch 可被视为一个「词」。然后,8 × 8 个 patch 通过一个线性 flatten 层转化为 token 嵌入的 1D 序列,其中 token 数 N = 8 × 8 = 64,嵌入维数为 C。再之后,研究者在 1D 序列的开头添加了可学得位置编码和一个 [cls] token。在通过 Transformer 编码器后,分类 head 只使用 [cls] token 来输出真假预测。

实验

CIFAR-10 上的结果

研究者在 CIFAR-10 数据集上对比了 TransGAN 和近来基于卷积的 GAN 的研究,结果如下表 5 所示:

如上表 5 所示,TransGAN 优于 AutoGAN (Gong 等人,2019) ,在 IS 评分方面也优于许多竞争者,如 SN-GAN (Miyato 等人, 2018)、improving MMDGAN (Wang 等人,2018a)、MGAN (Hoang 等人,2018)。TransGAN 仅次于 Progressive GAN 和 StyleGAN v2。

对比 FID 结果,研究发现,TransGAN 甚至优于 Progressive GAN,而略低于 StyleGANv2 (Karras 等人,2020b)。在 CIFAR-10 上生成的可视化示例如下图 4 所示:

STL-10 上的结果

研究者将 TransGAN 应用于另一个流行的 48×48 分辨率的基准 STL-10。为了适应目标分辨率,该研究将第一阶段的输入特征图从(8×8)=64 增加到(12×12)=144,然后将提出的 TransGAN-XL 与自动搜索的 ConvNets 和手工制作的 ConvNets 进行了比较,结果下表 6 所示:

与 CIFAR-10 上的结果不同,该研究发现,TransGAN 优于所有当前的模型,并在 IS 和 FID 得分方面达到新的 SOTA 性能。

高分辨率生成

由于 TransGAN 在标准基准 CIFAR-10 和 STL-10 上取得不错的性能,研究者将 TransGAN 用于更具挑战性的数据集 CelebA 64 × 64,结果如下表 10 所示:

TransGAN-XL 的 FID 评分为 12.23,这表明 TransGAN-XL 可适用于高分辨率任务。可视化结果如图 4 所示。

局限性

虽然 TransGAN 已经取得了不错的成绩,但与最好的手工设计的 GAN 相比,它还有很大的改进空间。在论文的最后,作者指出了以下几个具体的改进方向:

  • 对 G 和 D 进行更加复杂的 tokenize 操作,如利用一些语义分组 (Wu et al., 2020)。

  • 使用代理任务(pretext task)预训练 Transformer,这样可能会改进该研究中现有的 MT-CT。

  • 更加强大的注意力形式,如 (Zhu 等人,2020)。

  • 更有效的自注意力形式 (Wang 等人,2020;Choromanski 等人,2020),这不仅有助于提升模型效率,还能节省内存开销,从而有助于生成分辨率更高的图像。

作者简介

本文一作 Yifan Jiang 是德州大学奥斯汀分校电子与计算机工程系的一年级博士生(此前在德克萨斯 A&M 大学学习过一年),本科毕业于华中科技大学,研究兴趣集中在计算机视觉、深度学习等方向。目前,Yifan Jiang 主要从事神经架构搜索、视频理解和高级表征学习领域的研究,师从德州大学奥斯汀分校电子与计算机工程系助理教授 Zhangyang Wang。

在本科期间,Yifan Jiang 曾在字节跳动 AI Lab 实习。今年夏天,他将进入 Google Research 实习。

一作主页:https://yifanjiang.net/

参考链接:https://www.reddit.com/r/MachineLearning/comments/ll30kf/r_transgan_two_transformers_can_make_one_strong/

© THE END

转载请联系 机器之心 公众号获得授权

投稿或寻求报道:content@jiqizhixin.com

点个在看 paper不断!

华人博士生首次尝试用两个Transformer构建一个GAN相关推荐

  1. 不用卷积,也能生成清晰图像,华人博士生首次尝试用两个Transformer构建一个GAN

    「attention is really becoming『all you need』.」 选自arXiv,作者:Yifan Jiang等,机器之心编译,机器之心编辑部 最近,CV 研究者对 tran ...

  2. 不用卷积也能生成清晰图像,用两个Transformer构建一个GAN

    作者|Yifan Jiang等 来源|机器之心 attention is really becoming『all you need』. 最近,CV 研究者对 transformer 产生了极大的兴趣并 ...

  3. LIVE 预告 | TransGAN:丢弃卷积,纯Transformer构建GAN网络

    自2014年Ian J. Goodfellow等人提出以来,生成对抗网络(GAN,Generative Adversarial Networks)便迅速成为人工智能领域中最有前景的研究方向之一. 而另 ...

  4. 2021年Facebook博士生奖研金名单公布!一半获奖者是华人博士生

    作者 | 陈彩娴.贝爽 转自:AI科技评论 当地时间22日,Facebook公布了2021年博士生奖研金(2021 PhD Fellowship)名单!在来自全球100所大学的2163份申请中,Fac ...

  5. 摩根大通公布2021年AI研究博士生奖学金名单!获奖华人博士生占1/3

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 本文转载自:AI科技评论 作者 | 陈彩娴 编辑 | 刘冰一 不久前,摩根大通(J.P.Morgan Cha ...

  6. 脉冲神经网络在目标检测的首次尝试,性能堪比CNN | AAAI 2020

    译者 |  VincentLee 来源 | 晓飞的算法工程笔记 脉冲神经网络(Spiking neural network, SNN)将脉冲神经元作为计算单元,能够模仿人类大脑的信息编码和处理过程.不 ...

  7. Spiking-YOLO : 前沿!脉冲神经网络在目标检测的首次尝试 | AAAI 2020

    点击我爱计算机视觉标星,更快获取CVML新技术 论文提出Spiking-YOLO,是脉冲神经网络在目标检测领域的首次成功尝试,实现了与卷积神经网络相当的性能,而能源消耗极低.论文内容新颖,比较前沿,推 ...

  8. java如何给一个链表定义和传值_如何在CUDA中为Transformer编写一个PyTorch自定义层...

    如今,深度学习模型处于持续的演进中,它们正变得庞大而复杂.研究者们通常通过组合现有的 TensorFlow 或 PyTorch 操作符来发现新的架构.然而,有时候,我们可能需要通过自定义的操作符来实现 ...

  9. docker 解决php 502,Docker里两个php容器一个正常访问,一个出现502 Bad Gateway nginx/1.17.8。...

    问题描述 Docker里两个php容器一个镜像是phpfpm5.6,运行的程序是emlog,正常访问.一个镜像是phpfpm7.3,运行的程序是wordpress,运行的出现502 Bad Gatew ...

最新文章

  1. Git 头指针分离与 FETCH_HEAD
  2. autojs遍历当前页面所有控件_自定义控件(引入布局)
  3. Prism for WPF初探(构建简单的模块化开发框架)
  4. java转成图形界面_【转】java图形界面设计(AWT)
  5. leetcode-665-Non-decreasing Array
  6. 深入浅出讲解语言模型
  7. mysql中limit关键字_【JAVA】关于mysql的limit关键字使用。
  8. truffle unbox webpack报错
  9. Kafka 源码分析之网络层(一)
  10. Python官方软件包存储库成恶意软件大本营?
  11. Elasticsearch 备份数据到 AWS S3
  12. android局域网怎么传文件,两手机同一局域网怎么传文件
  13. Tone-Mapped Image Quality Assessment
  14. 巴特沃斯(Butterworth)滤波器(二)
  15. css旋转立方体教程,如何通过CSS3实现旋转立方体
  16. Teemo Attacking 提莫攻击
  17. IDEA 一直卡在Buil(编译 write classes)报错资源不足
  18. PHP7.2与apache环境安装部署详细流程
  19. cmd html 查找汉子字,cmd搜索字符串加换行 在cmd(命令提示符)中怎样换行
  20. 2020华数杯C题脱贫帮扶绩效评价你怕了吗?

热门文章

  1. .Net(c#) 通过 Fortran 动态链接库,实现混合编程
  2. LeetCode实战:删除链表中的节点
  3. Ivanti 洞察职场新趋势:71% 的员工宁愿放弃升职也要选择随处工作
  4. 程序员是复制粘贴的工具人?还是掌握“谜底”的魔术师?
  5. 改名 Meta,打元宇宙牌,老龄化的 Facebook 能否再换新颜
  6. WebDriver 识别反爬虫的原理和破解方法~
  7. 用 Python 动态可视化,看看比特币这几年
  8. 还缺30万人!程序员2020年要过好日子了……
  9. 我佛了!用KNN实现验证码识别,又 Get 到一招!
  10. 一次性掌握机器学习基础知识脉络 | 公开课笔记