作者|Yifan Jiang等

来源|机器之心

attention is really becoming『all you need』.

最近,CV 研究者对 transformer 产生了极大的兴趣并取得了不少突破。这表明,transformer 有可能成为计算机视觉任务(如分类、检测和分割)的强大通用模型。

我们都很好奇:在计算机视觉领域,transformer 还能走多远?对于更加困难的视觉任务,比如生成对抗网络 (GAN),transformer 表现又如何?

在这种好奇心的驱使下,德州大学奥斯汀分校的 Yifan Jiang、Zhangyang Wang,IBM Research 的 Shiyu Chang 等研究者进行了第一次试验性研究,构建了一个只使用纯 transformer 架构、完全没有卷积的 GAN,并将其命名为 TransGAN。与其它基于 transformer 的视觉模型相比,仅使用 transformer 构建 GAN 似乎更具挑战性,这是因为与分类等任务相比,真实图像生成的门槛更高,而且 GAN 训练本身具有较高的不稳定性。

论文链接:https://arxiv.org/pdf/2102.07074.pdf

代码链接:https://github.com/VITA-Group/TransGAN

从结构上来看,TransGAN 包括两个部分:一个是内存友好的基于 transformer 的生成器,该生成器可以逐步提高特征分辨率,同时降低嵌入维数;另一个是基于 transformer 的 patch 级判别器。

研究者还发现,TransGAN 显著受益于数据增强(超过标准的 GAN)、生成器的多任务协同训练策略和强调自然图像邻域平滑的局部初始化自注意力。这些发现表明,TransGAN 可以有效地扩展至更大的模型和具有更高分辨率的图像数据集。

实验结果表明,与当前基于卷积骨干的 SOTA GAN 相比,表现最佳的 TransGAN 实现了极具竞争力的性能。具体来说,TransGAN 在 STL-10 上的 IS 评分为 10.10,FID 为 25.32,实现了新的 SOTA。

该研究表明,对于卷积骨干以及许多专用模块的依赖可能不是 GAN 所必需的,纯 transformer 有足够的能力生成图像。

在该论文的相关讨论中,有读者调侃道,「attention is really becoming『all you need』.」

不过,也有部分研究者表达了自己的担忧:在 transformer 席卷整个社区的大背景下,势单力薄的小实验室要怎么活下去?

如果 transformer 真的成为社区「刚需」,如何提升这类架构的计算效率将成为一个棘手的研究问题。

基于纯 Transformer 的 GAN

作为基础块的 Transformer 编码器

研究者选择将 Transformer 编码器(Vaswani 等人,2017)作为基础块,并尽量进行最小程度的改变。编码器由两个部件组成,第一个部件由一个多头自注意力模块构造而成,第二个部件是具有 GELU 非线性的前馈 MLP(multiple-layer perceptron,多层感知器)。此外,研究者在两个部件之前均应用了层归一化(Ba 等人,2016)。两个部件也都使用了残差连接。

内存友好的生成器

NLP 中的 Transformer 将每个词作为输入(Devlin 等人,2018)。但是,如果以类似的方法通过堆叠 Transformer 编码器来逐像素地生成图像,则低分辨率图像(如 32×32)也可能导致长序列(1024)以及更高昂的自注意力开销。

所以,为了避免过高的开销,研究者受到了基于 CNN 的 GAN 中常见设计理念的启发,在多个阶段迭代地提升分辨率(Denton 等人,2015;Karras 等人,2017)。他们的策略是逐步增加输入序列,并降低嵌入维数

如下图 1 左所示,研究者提出了包含多个阶段的内存友好、基于 Transformer 的生成器:

每个阶段堆叠了数个编码器块(默认为 5、2 和 2)。通过分段式设计,研究者逐步增加特征图分辨率,直到其达到目标分辨率 H_T×W_T。具体来说,该生成器以随机噪声作为其输入,并通过一个 MLP 将随机噪声传递给长度为 H×W×C 的向量。该向量又变形为分辨率为 H×W 的特征图(默认 H=W=8),每个点都是 C 维嵌入。然后,该特征图被视为长度为 64 的 C 维 token 序列,并与可学得的位置编码相结合。

与 BERT(Devlin 等人,2018)类似,该研究提出的 Transformer 编码器以嵌入 token 作为输入,并递归地计算每个 token 之间的匹配。为了合成分辨率更高的图像,研究者在每个阶段之后插入了一个由 reshaping 和 pixelshuffle 模块组成的上采样模块。

具体操作上,上采样模块首先将 1D 序列的 token 嵌入变形为 2D 特征图,然后采用 pixelshuffle 模块对 2D 特征图的分辨率进行上采样处理,并下采样嵌入维数,最终得到输出。然后,2D 特征图 X’_0 再次变形为嵌入 token 的 1D 序列,其中 token 数为 4HW,嵌入维数为 C/4。所以,在每个阶段,分辨率(H, W)提升到两倍,同时嵌入维数 C 减少至输入的四分之一。这一权衡(trade-off)策略缓和了内存和计算量需求的激增。

研究者在多个阶段重复上述流程,直到分辨率达到(H_T , W_T )。然后,他们将嵌入维数投影到 3,并得到 RGB 图像

用于判别器的 tokenized 输入

与那些需要准确合成每个像素的生成器不同,该研究提出的判别器只需要分辨真假图像即可。这使得研究者可以在语义上将输入图像 tokenize 为更粗糙的 patch level(Dosovitskiy 等人,2020)。

如上图 1 右所示,判别器以图像的 patch 作为输入。研究者将输入图像分解为 8 × 8 个 patch,其中每个 patch 可被视为一个「词」。然后,8 × 8 个 patch 通过一个线性 flatten 层转化为 token 嵌入的 1D 序列,其中 token 数 N = 8 × 8 = 64,嵌入维数为 C。再之后,研究者在 1D 序列的开头添加了可学得位置编码和一个 [cls] token。在通过 Transformer 编码器后,分类 head 只使用 [cls] token 来输出真假预测。

实验

CIFAR-10 上的结果

研究者在 CIFAR-10 数据集上对比了 TransGAN 和近来基于卷积的 GAN 的研究,结果如下表 5 所示:

如上表 5 所示,TransGAN 优于 AutoGAN (Gong 等人,2019) ,在 IS 评分方面也优于许多竞争者,如 SN-GAN (Miyato 等人, 2018)、improving MMDGAN (Wang 等人,2018a)、MGAN (Hoang 等人,2018)。TransGAN 仅次于 Progressive GAN 和 StyleGAN v2。

对比 FID 结果,研究发现,TransGAN 甚至优于 Progressive GAN,而略低于 StyleGANv2 (Karras 等人,2020b)。在 CIFAR-10 上生成的可视化示例如下图 4 所示:

STL-10 上的结果

研究者将 TransGAN 应用于另一个流行的 48×48 分辨率的基准 STL-10。为了适应目标分辨率,该研究将第一阶段的输入特征图从(8×8)=64 增加到(12×12)=144,然后将提出的 TransGAN-XL 与自动搜索的 ConvNets 和手工制作的 ConvNets 进行了比较,结果下表 6 所示:

与 CIFAR-10 上的结果不同,该研究发现,TransGAN 优于所有当前的模型,并在 IS 和 FID 得分方面达到新的 SOTA 性能。

高分辨率生成

由于 TransGAN 在标准基准 CIFAR-10 和 STL-10 上取得不错的性能,研究者将 TransGAN 用于更具挑战性的数据集 CelebA 64 × 64,结果如下表 10 所示:

TransGAN-XL 的 FID 评分为 12.23,这表明 TransGAN-XL 可适用于高分辨率任务。可视化结果如图 4 所示。

局限性

虽然 TransGAN 已经取得了不错的成绩,但与最好的手工设计的 GAN 相比,它还有很大的改进空间。在论文的最后,作者指出了以下几个具体的改进方向:

  • 对 G 和 D 进行更加复杂的 tokenize 操作,如利用一些语义分组 (Wu et al., 2020)。

  • 使用代理任务(pretext task)预训练 Transformer,这样可能会改进该研究中现有的 MT-CT。

  • 更加强大的注意力形式,如 (Zhu 等人,2020)。

  • 更有效的自注意力形式 (Wang 等人,2020;Choromanski 等人,2020),这不仅有助于提升模型效率,还能节省内存开销,从而有助于生成分辨率更高的图像。

作者简介

本文一作 Yifan Jiang 是德州大学奥斯汀分校电子与计算机工程系的一年级博士生(此前在德克萨斯 A&M 大学学习过一年),本科毕业于华中科技大学,研究兴趣集中在计算机视觉、深度学习等方向。目前,Yifan Jiang 主要从事神经架构搜索、视频理解和高级表征学习领域的研究,师从德州大学奥斯汀分校电子与计算机工程系助理教授 Zhangyang Wang。

在本科期间,Yifan Jiang 曾在字节跳动 AI Lab 实习。今年夏天,他将进入 Google Research 实习。

一作主页:https://yifanjiang.net/

参考链接:https://www.reddit.com/r/MachineLearning/comments/ll30kf/r_transgan_two_transformers_can_make_one_strong/

????

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。

不用卷积也能生成清晰图像,用两个Transformer构建一个GAN相关推荐

  1. 不用卷积,也能生成清晰图像,华人博士生首次尝试用两个Transformer构建一个GAN

    「attention is really becoming『all you need』.」 选自arXiv,作者:Yifan Jiang等,机器之心编译,机器之心编辑部 最近,CV 研究者对 tran ...

  2. 华人博士生首次尝试用两个Transformer构建一个GAN

    选自arXiv 作者:Yifan Jiang等 机器之心编译 机器之心编辑部 「attention is really becoming『all you need』.」 最近,CV 研究者对 tran ...

  3. LIVE 预告 | TransGAN:丢弃卷积,纯Transformer构建GAN网络

    自2014年Ian J. Goodfellow等人提出以来,生成对抗网络(GAN,Generative Adversarial Networks)便迅速成为人工智能领域中最有前景的研究方向之一. 而另 ...

  4. 不用卷积,也能生成清晰图像!Transformer再下一城

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要15分钟 Follow小博主,每天更新前沿干货 「attention is really becoming『all you need』.」 最 ...

  5. 多态指针访问虚函数不能被继承的类快速排序N皇后问题插入排序堆排序merge归并排序栈上生成对象两个栈实现一个队列...

    多态 /*1. 要想实现覆盖(重写)父类必须声明为virtual,子类可以不声明为virtual.-->FunB()2. 派生类重写基类的虚函数实现多态,要求函数名.参数列表.返回值完全相同.( ...

  6. 开源人工智能使用卷积网格自动编码器生成3D面部

    开源人工智能使用卷积网格自动编码器生成3D面部摘要:人脸的学习3D表示对于计算机视觉问题是有用的,例如3D面部跟踪和从图像重建,以及诸如角色生成和动画的图形应用.传统模型使用线性子空间或高阶张量概括来 ...

  7. 在word文档中如何自动生成目录,两种方法制作目录,总有一种适合你

    在word文档中如何自动生成目录,两种方法制作目录,总有一种适合你 目录 在word文档中如何自动生成目录,两种方法制作目录,总有一种适合你 1.文章中的标题较多,每个单独调整格式太费劲,这里我们用一 ...

  8. php 八字算法,南方排八字专业程序 php不用内置函数对数组排序的两个算法代码...

    一朋友找工作遇到的试题,备注一下. 极有可能今后我也会遇到的. 问题:php不用内置函数对数组排序,可能是降序或者升序 第一种方法:传说中的冒泡法 复制代码 代码如下: function arrays ...

  9. C++生成GUID的两种方法

    C++生成GUID的两种方法 C++生成GUID的两种方法 使用CoCreateGuid函数 使用Boost库 C++生成GUID的两种方法 GUID是软件开发中常用的组件,用于生成唯一的对象,在C# ...

最新文章

  1. Freetype library not found问题解决
  2. 4.vuex学习之getters、mapGetters
  3. JavaScript玩转机器学习:张量(Tensors) 和 操作(operations)
  4. jenkins运行web自动化测试找不到文件file not found
  5. 解决API中无法使用session问题
  6. 记一次网易云解锁灰色音乐代理异常
  7. Redis事务特性分析
  8. Python 玩转数据 19 - 数据操作 正则表达式 Regular Expressions 搜索模式匹配
  9. 模拟登陆115网盘(MFC版)
  10. 16 款基于jQuery的图片缩放效果插件推荐
  11. 三星14纳米EUV DDR5 DRAM量产;Amazfit推出三款智能手表;Whale帷幄获5000万美元融资 | 全球TMT...
  12. 【渝粤教育】电大中专跨境电子商务理论与实务 (12)作业 题库
  13. mysqli_connect(): (HY000/2002):
  14. Intellij IDEA插件--Key Promoter X
  15. NTP时间戳和UTC时间戳互转及其原理
  16. app文件上传到服务器教程,app上传文件到云服务器
  17. 之江汇空间如何加音乐背景_互动课堂的使用|之江汇互动课堂如何使用?之江汇互动课堂使用方法...
  18. Tesla技术方案深度剖析:自动标注/感知定位/决策规划/场景重建/场景仿真/数据引擎...
  19. x64长模式与段的纠葛
  20. 服务器盘符修改不了怎么办,服务器怎么修改盘符

热门文章

  1. int类型存小数 mysql_MySQL面试题-数据类型
  2. 测试dali协议的软件,基于DALI协议的数字照明控制软件的研发
  3. require与include+php,PHP中include与require有什么区别
  4. Alpha预乘-混合与不混合[转]
  5. 数据结构实验之查找七:线性之哈希表
  6. linux下Mysql命令
  7. 指定的命名连接在配置中找不到、非计划用于 EntityClient 提供程序或者无效
  8. HADOOP都升级到2.5啦~~~
  9. poj 1041(欧拉回路+输出字典序最小路径)
  10. 可视化Python设计工具