深度学习之图像分类(二十一)MLP-Mixer网络详解

目录

  • 深度学习之图像分类(二十一)MLP-Mixer网络详解
    • 1. 前言
    • 2. MLP-Mixer 网络结构
    • 3. 总结
    • 4. 代码

继 Transformer 之后,我们开启了一个新篇章,即无关卷积和注意力机制的最原始形态,全连接网络。在本章中我们学习全连接构成的 MLP-Mixer。(仔细发现,这个团队其实就是 ViT 团队…),作为一种“开创性”的工作,挖了很多很多的新坑,也引发了后续一系列工作。也许之后是 CNN、Transformer、MLP 三分天下了。

1. 前言

MLP-Mixer 是谷歌 AI 团队 2021 初的文章,论文为 MLP-Mixer: An all-MLP Architecture for Vision。卷积神经网络(CNN)是计算机视觉的首选模型。 最近,基于注意力的网络(例如Vision Transformer)也变得很流行。 在工作中,研究人员表明,尽管卷积和注意力都足以获得良好的性能,但它们都不是必需的,纯 MLP + 非线性激活函数 + Layer Normalization 也能取得 comparable 的性能,其预训练和推理成本可与最新模型相媲美。看了实验结果,只能说:有钱有钱有钱,常人即使有想法也跑不出来。

为什么要用全连接层有什么好处呢?它的归纳偏置(inductive bias)更低。归纳偏置可以看作学习算法自身在一个庞大的假设空间中对假设进行选择的启发式或者“价值观”。下图展示了机器学习中常用的归纳偏置。CNN 的归纳偏置在于卷积操作,只有感受野内部有相互作用,即图像的局部性特征。时序网络 RNN 的归纳偏置在于时间维度上的连续性和局部性。事实上,ViT 已经开始延续了一直以来想要在神经网络中移除手工视觉特征和归纳偏置的趋势,让模型只依赖于原始数据进行学习。MLP 则更进了一步。原文提到说:One could speculate and explain it again with the difference in inductive biases: self-attention layers in ViT lead to certain properties of the learned functions that are less compatible with the true underlying distribution than those discovered with Mixer architecture

2. MLP-Mixer 网络结构

MLP-Mixer 的其整体思路为:先将输入图片拆分成多个 patches(每个 patch 之间不重叠),通过 Per-patch Fully-connected 层的操作将每个 patch 转换成 feature embedding,然后送入N个Mixer Layer。最后,Mixer 将标准分类头与全局平均池化层配合使用,随后使用 Fully-connected 进行分类。细细看来,其第一步与 ViT 其实是一致的,Mixer Layer 替换了 Transformer Block,最后直接的输出接到全连接层而无需 class token。此外,Mixer 的输出基于输入的信息,因为因为全连接层,所以交换任意两个 token 会得到不同的结果(对应的权重不一样了),所以无需 position embedding。

假设我们有输入图像 224×224×3224 \times 224 \times 3224×224×3,首先我们切 patch,例如长宽都取 32,则我们可以切成 7×7=497 \times 7 = 497×7=49 个 patch,每个 patch 是 32×32×332 \times 32 \times 332×32×3。我们将每个 patch 展平就能成为 49 个 307230723072 维的向量。通过一个全连接层(Per-patch Fully-connected)进行降维,例如 512 维,就得到了 49 个 token,每个 token 的维度为 512。然后将他们馈入 Mixer Layer。

细看 Mixer Layer,Mixer 架构采用两种不同类型的 MLP 层:token-mixing MLP 和 channel-mixing MLP。token-mixing MLP 指的是 cross-location operation,即对于 49 个 512512512 维的 token,将每一个 token 内部进行自融合,将 49 维映射到 49 维,即“混合”空间信息;channel-mixing MLP 指的是 pre-location operation,即对于 49 个 512512512 维的 token,将每一维进行融合,将 512 维映射到 512 维,即“混合”每个位置特征。为了简单实现,其实将矩阵转置一下就可以了。这两种类型的层交替执行以促进两个维度间的信息交互。单个 MLP 是由两个全连接层和一个 GELU 激活函数组成的。此外,Mixer 还是用了跳跃连接(Skip-connection)和层归一化(Layer Norm),这里的跳跃连接其实不是 UNet 中的通道拼接,而是 ResNet 中的残差结构,将输入输出相加;而这里的层归一化与 DenseNet 等网络一致,作用于全连接层前面,进行前归一化(DenseNet 的 BN 层在卷积层前面,也有工作证明 LN 在 Multi-head Self-Attention 前会更好)。
U∗,i=X∗,i+W2σ(W1LayerNorm⁡(X)∗,i),for i=1…CYj,∗=Uj,∗+W4σ(W3LayerNorm⁡(U)j,∗),for j=1…S\begin{aligned} \mathbf{U}_{*, i} &=\mathbf{X}_{*, i}+\mathbf{W}_{2} \sigma\left(\mathbf{W}_{1} \operatorname{LayerNorm}(\mathbf{X})_{*, i}\right), & \text { for } i=1 \ldots C \\ \mathbf{Y}_{j, *} &=\mathbf{U}_{j, *}+\mathbf{W}_{4} \sigma\left(\mathbf{W}_{3} \operatorname{LayerNorm}(\mathbf{U})_{j, *}\right), & \text { for } j=1 \ldots S \end{aligned} U,iYj,=X,i+W2σ(W1LayerNorm(X),i),=Uj,+W4σ(W3LayerNorm(U)j,),fori=1Cforj=1S

为什么作者要采用 pre-location operation 和 cross-location operation 呢?因为信息融合无非就是分为固定位置的信息融合和位置间的信息融合。CNN 的卷积核大小为 K×K×CK \times K \times CK×K×C 其实是同时考虑了这两个,当 K=1K = 1K=1 时则是 pre-location operation,当 C=1C = 1C=1 时对应的 depth-wise convolution 其实就是考虑了 cross-location operation。有 CNN (MobileNet,EfficientNet…) 为了计算复杂度降低采用通道可分离卷积,先做 1×11 \times 11×1,再做 DW 卷积,效果也可,说明他们可以不一起做。所以 MLP-Mixer 也干脆将这两个正交开来处理。

注意到:这里的 channel-mixing MLP 其实就是卷积神经网络中的 1×11 \times 11×1 卷积,卷积核的个数和输入通道数一致。此外 token-mixing MLP 其实就是卷积神经网络中的 padding = 0 的超大型卷积,将 512 维映射到 512 维的全连接层对应的就是 512 个 32×3232 \times 3232×32 的卷积层。只不过在 Mixer Layer 中 49 个 token 是共用这个全连接层,也就是如果换成卷积层的话是 single-channel depth-wise convolutions(这个和组卷积还不一样,49个组卷积,每个组内有 512 个 32×3232 \times 3232×32 的卷积层,这样组间的卷积核可以不一样,但是 MLP-Mixer 中是同一个全连接层,所以组卷积每个组的卷积核是一样的)。所以作者说:MLP-Mixer 是 CNN 的特例,但 CNN 不是 MLP-Mixer 的特例

MLP-Mixer 有多个 Mixer Layer 拼接,Mixer 中的每一层 (除了初始块投影层外) 采用相同尺寸的输入(token 维度不变),这种“各向同性”设计类似于Transformer 和 RNN 中定宽;这与 CNN 中金字塔结构 (越深的层具有更低的分辨率、更多的通道数) 不同。

至于最后的全局池化和全连接,则是相对普通的操作了,按照 Channel 对每个 token 进行池化,最后得到 Channel 维度的向量,然后经过一个全连接层输出目标分类分数即可。

最终 MLP-Mixer 的网络配置如下表所示:

3. 总结

对于 MLP-Mixer 的一些反思:CV 和 Transformer 已经在目标检测和分割上展现了出色的性能,但是 MLP-Mixer 是否可以为分割、识别等方向提供太大的帮助呢?暂时还不得而知,因为 MLP-Mixer 的特征对分类有用,但是还有待下游其他任务验证。此外 MLP-Mixer 计算量很大,对于增大图像输入分辨率不太友好。还有一个重要的是,MLP 思想虽然看起来很哇塞,但是残差结构和 LN 的 block 设计感觉效用也很大,如果没有他们 MLP-Mixer 还能 work 吗?

对于 MLP-Mixer 的一些展望:MLP-Mixer 能不能也和 TNT 一样增加 patch 内的信息交流呢?虽然 token-mixing MLP 可以被视为 patch 内的信息交流,但是这也可以看起来和 Transformer Block 在 Multi-Head Self-Attention 之后的映射一样的,所以增加 TNT 一样增加 patch 内的信息交流应该会有帮助。CNN + MLP-Mixer 会比 CNN + Transformer 更好吗?我个人暂时看起来不一定,因为无论从可解释性还是训练速度等角度看起来还是持保有态度。

MLP-Mixer 有很多微弱修改和改进,例如 RepMLP,ResMLP,gMLP,CycleMLP 等有待后续学习。

关于实验部分,推荐大家去浏览原文,这里不做过多的描述。最后依然是一贯认识:只要算力和数据够,一切就都有可能

4. 代码

代码出处见 此处。

import torch
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduceclass PreNormResidual(nn.Module):def __init__(self, dim, fn):super().__init__()self.fn = fnself.norm = nn.LayerNorm(dim)def forward(self, x):return self.fn(self.norm(x)) + xdef FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear):return nn.Sequential(dense(dim, dim * expansion_factor),nn.GELU(),nn.Dropout(dropout),dense(dim * expansion_factor, dim),nn.Dropout(dropout))def MLPMixer(*, image_size, channels, patch_size, dim, depth, num_classes, expansion_factor = 4, dropout = 0.):assert (image_size % patch_size) == 0, 'image must be divisible by patch size'num_patches = (image_size // patch_size) ** 2chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linearreturn nn.Sequential(Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),nn.Linear((patch_size ** 2) * channels, dim),*[nn.Sequential(PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last))) for _ in range(depth)],nn.LayerNorm(dim),Reduce('b n c -> b c', 'mean'),nn.Linear(dim, num_classes))model = MLPMixer(image_size = 256,channels = 3,patch_size = 16,dim = 512,depth = 12,num_classes = 1000
)img = torch.randn(1, 3, 256, 256)
pred = model(img) # (1, 1000)

深度学习之图像分类(二十一)-- MLP-Mixer网络详解相关推荐

  1. Keras深度学习实战(22)——生成对抗网络详解与实现

    Keras深度学习实战(22)--生成对抗网络详解与实现 0. 前言 1. 生成对抗网络原理 2. 模型分析 3. 利用生成对抗网络生成手写数字图像 小结 系列链接 0. 前言 生成对抗网络 (Gen ...

  2. 笔记 | 百度飞浆AI达人创造营:深度学习模型训练和关键参数调优详解

    笔记 | 百度飞浆AI达人创造营:深度学习模型训练和关键参数调优详解 针对特定场景任务从模型选择.模型训练.超参优化.效果展示这四个方面进行模型开发. 一.模型选择 从任务类型出发,选择最合适的模型. ...

  3. 深度学习模型训练和关键参数调优详解

    深度学习模型训练和关键参数调优详解 一.模型选择 1.回归任务 人脸关键点检测 2.分类任务 图像分类 3.场景任务 目标检测 人像分割 文字识别 二.模型训练 1.基于高层API训练模型 加载数据集 ...

  4. python数据挖掘课程】二十一.朴素贝叶斯分类器详解及中文文本舆情分析

    #2018-04-06 13:52:30 April Friday the 14 week, the 096 day SZ SSMR python数据挖掘课程]二十一.朴素贝叶斯分类器详解及中文文本舆 ...

  5. Keras深度学习实战(26)——文档向量详解

    Keras深度学习实战(26)--文档向量详解 0. 前言 1. 文档向量基本概念 2. 神经网络模型与数据集分析 2.1 模型分析 2.2 数据集介绍 3. 利用 Keras 构建神经网络模型生成文 ...

  6. 【深度学习】扩散模型(Diffusion Model)详解

    [深度学习]扩散模型(Diffusion Model)详解 文章目录 [深度学习]扩散模型(Diffusion Model)详解 1. 介绍 2. 具体方法 2.1 扩散过程 2.2 逆扩散过程 2. ...

  7. (二十一)状态模式详解(DOTA版) - 转

    作者:zuoxiaolong8810(左潇龙),转载请注明出处. 本次LZ给各位介绍状态模式,之前在写设计模式的时候,引入了一些小故事,二十章职责连模式是故事版的最后一篇,之后还剩余四个设计模式,LZ ...

  8. Keras深度学习实战(21)——神经风格迁移详解

    Keras深度学习实战(21)--神经风格迁移详解 0. 前言 1. 神经风格迁移原理 2. 模型分析 3. 使用 Keras 实现神经风格迁移 小结 系列链接 0. 前言 在 DeepDream 图 ...

  9. 二、代码实现深度学习道路训练样本数据的制作(代码部分详解)——重复工作+多次返工的血泪史

    使用python读取文件夹对图片进行批量裁剪 通过第一部分操作arcgis制作了一部分样本数据 分辨率与原相片保持一致 为6060*6060 具体如图所示: 而我们深度学习一般使用的分辨率是1024和 ...

  10. 深度学习之卷积神经网络CNN理论与实践详解

    六月 北京 | 高性能计算之GPU CUDA培训 6月22-24日三天密集式学习  快速带你入门阅读全文> 正文共1416个字,6张图,预计阅读时间6分钟. 概括 大体上简单的卷积神经网络是下面 ...

最新文章

  1. Redis 高级特性(1)—— 事务 过期时间 排序
  2. Adam 又要“退休”了?耶鲁大学团队提出 AdaBelief,却引来网友质疑
  3. 矩阵的终极分解-奇异值分解 SVD
  4. linux 定时器_定时器: Nodejs 中的 timers
  5. pr抖动插件_2020最全的8000多款PR插件合集,一键安装
  6. word List 17
  7. VS的包含目录、库目录、引用目录、可执行目录解释
  8. Mac OSX 安装nvm(node.js版本管理器)
  9. 红橙Darren视频笔记 IOC注解框架 了解xUtils3与ButterKnife的原理
  10. [Luogu] 程序自动分析
  11. 学术资源不定期分享-【钱学森《工程控制论》英文原版】
  12. vc linux 中文版下载64位,vc2015运行库64位-vc++ 2015运行库64位下载 v14.0.23026官方版--pc6下载站...
  13. 清华大学计算机信息学院舒教授,清华大学出计算机与信息分社.ppt
  14. 网易云音乐推荐中的用户行为序列深度建模
  15. C++ 实现计算24点
  16. Android LayoutInflater源码分析
  17. 机器学习:二分类到多分类-ovr,ovo,mvm,sofmax
  18. 使用STM32CubeMX 快速生成 USB HID 工程 - STM32F107VCT6
  19. php 购物车 原理及实现,纯干货丨PHP实现购物车的构建
  20. 贴片电阻0603、1206之间的区别是什么

热门文章

  1. MSB3644 找不到 .NETFramework,Version=v4.7 的引用程序集。要解决此问题,请为此框架版本安装......
  2. 伍斯特理工学院计算机硕士怎么样,伍斯特理工学院硕士怎么样?
  3. 南京大学2019计算机学院复试名单,南京大学计算机科学与技术系2019考研复试名单...
  4. 新手兼职也能月入5000的副业项目,几乎零门槛
  5. Notepad++删除空白行
  6. Matlab导入数据(一定有用!!)
  7. FID指标复现踩坑避坑 文本生成图像FID定量实验全流程复现(Fréchet Inception Distance )定量评价实验踩坑避坑流程
  8. Ubuntu下将TinyOS移植到CC2430芯片
  9. c语言汉字utf8,C语言汉字gbk转utf-8
  10. UE4(VR)中3D世界内的UI模糊问题解决