文章目录

  • 前言
  • 一、MLP-Mixer原理介绍
    • 1.网络结构
    • 2. Mixer Layer
    • 3.MLP-Mixer模型类型
  • 二、网络实现
    • 1.Pytorch复现
    • 2.Keras 复现
  • 总结

前言

MLP-Mixer,Google又提出的一种基于感知机的网络。尽管CNN 已经在计算机视觉上取得很好的效果,最近提出来的基于Attention,以Vision Transformer 为首的神经网络已经 “杀疯” CV界,但是Google的大佬们认为CNN 和Attention 也不是必须的,于是就提出了MLP-Mixer,在分类任务上也达到了很好的效果。但是网络的提出却早到了CNN之父LeCun的“教育”。因为网络的第一层(embedding时)却用到了卷积。

LeCun认为,这不过是 一个卷积核为1x1的卷积网络罢了。我在复现过程中,没有找到一个合适的数据集用该网络取得一个很好的效果。所以这里不附上训练的代码了。

项目 链接
论文 链接

(仅torch, 可训练,eval top1 and top5):https://github.com/jiantenggei/torch-classification
博客中给出的就是网络结构的全部代码。

一、MLP-Mixer原理介绍

在介绍MLP -Mixer 之前,先引入一下 ViT 结构,如下图所示:

图中Transformer Encoder 的红色方框中的MLP 结构(如下图所示),就是MLP -Mixer的结构。

对于MLP -Mixer 的提出,Google方面的大佬只是为了验证MLP 也能做很多事情。
但这篇文章 Do You Even Need Attention? A Stack of Feed-Forward Layers DoesSurprisingly Well on ImageNet 为了探寻ViT中起作用的部分是 Attention 还是MLP, 直接删除到ViT中的Attention部分做实验,发现效果也还行,网络如下图所示。它与本文介绍MLP-Mixer的不同点是,它保留了ViT 中Class Token 部分(粉红色方块) 最后用Class Token 来做预测


MLP- Mixer 是通过一个全局池化后连接全连接层来做预测。

1.网络结构


从上图我们可以看出,MLP -Mixer 首先使用图片分成很多个小正方形的patch,每个patch的大小定义为patch_size。论文中实现这一步骤使用的是前面提到的卷积,卷积核的大小和步长均patch_size。论文中给的参数,也是2的幂。
网络不再使用传统的RELU激活函数,而是使用了GELU激活函数。
将图片分成小块后,在将它转换为一维结构。如图:

然后将每一个patch进行转换,如下图所示:

通过这样一种方式呢,就将一张图片转换为了一个大矩阵,就可以输入到Mixer Layer 中进行计算。

2. Mixer Layer

MixerLayer的结构如下图所示:

我们看一下论文里给出的公式:

MLP 是两个全连接层的感知机,W1,W2,对应token_mixer中两个全连接的权重,W3,W4则表示channel_mixer两个全连接的权重。σ表示GELU激活函数。那么公示就很简单了,输入X经过Layer Normalize,再乘以W1,再经过激活函数后乘以W2,再加上X。第二个公式也是相同的计算过程。
将前面通过编码得到的矩阵经过Layer Norm 在将矩阵进行旋转(T 表示旋转)连接MLP1,MLP1 就是文章token_mixer 用来寻找像素与像素之间的关系,其中,MLP1中的权值共享。计算完之后,再将矩阵旋转回来,通过Layer Norm 后再接一个channel_mixer 用于寻找通道与通道之间的关系。其中MixerLayer 还启用了ResNet中的跨连结构,跨连结构的作用可以参考[ResNet原理讲解和复现],看到这里,是不是感觉它跟卷积的原理很类似。
从上图可以看出Mixer Layer的输入维度和输出维度相同,并且通过MLP的方式来寻找图片像素与像素,通道与通道的关系。
这就是MLP-MIXER的网络结构了,目前的了解,没有开源的pytorch或者TensorFlow 预训练的权重。官方给出的代码和权重是基于JAX的。

由于需要Patch embedding的网络 对于图片大小的依赖高,所以一般很难使用官方的权重进行迁移学习,如果想使用到自己的任务中,建议使用一个较大的数据集先预训练一下

3.MLP-Mixer模型类型

文章中给出的模型参数列表,Patch resolution 就是patch 的长宽。Hidden size 就是映射成前面提到的大矩阵的维度,Squence length 是计算后的结果。
以上图红色部分为例,输入图像大小为224*224,
然后分成的块大小为32*32.
那么 (224*224)\(32*32)=7*7,Squence length 就为 49。Dc和Ds分别表示token_mixer和channel_mixer 中全连接层节点的个数。
这就是MLP-Mixer的全部过程了。

二、网络实现

1.Pytorch复现

实现的难点在于,矩阵旋转,我们使用einops中的Rearrange实现矩阵旋转。还需要使用torchsummary 来查看网络结构。安装:

pip install einops
pip install torchsummary

首先我们来实现MLP 也就是FeedForward:

#定义多层感知机
import torch
import numpy as np
from torch import nn
from einops.layers.torch import Rearrange
from torchsummary import summary
import torch.nn.functional as F
class FeedForward(nn.Module):def __init__(self,dim,hidden_dim,dropout=0.):super().__init__()self.net=nn.Sequential(#由此可以看出 FeedForward 的输入和输出维度是一致的nn.Linear(dim,hidden_dim),#激活函数nn.GELU(),#防止过拟合nn.Dropout(dropout),#重复上述过程nn.Linear(hidden_dim,dim),nn.Dropout(dropout))def forward(self,x):x=self.net(x)return x
#测试多层感知机
# mlp=FeedForward(10,20,0.4).to(device)
# summary(mlp,input_size=(10,))

实现过程很简单,就是全连接结构
接着我们来实现Mixer Block,里面包含了 token_mixer 和channel_mixer,还有矩阵转置。

#使用Rearrange 实现旋转
Rearrange('b n d -> b d n') #这里是[batch_size, num_patch, dim] -> [batch_size, dim, num_patch]

实现如下:

class MixerBlock(nn.Module):def __init__(self,dim,num_patch,token_dim,channel_dim,dropout=0.):super().__init__()self.token_mixer=nn.Sequential(nn.LayerNorm(dim),Rearrange('b n d -> b d n'),FeedForward(num_patch,token_dim,dropout),Rearrange('b d n -> b n d'))self.channel_mixer=nn.Sequential(nn.LayerNorm(dim),FeedForward(dim,channel_dim,dropout))def forward(self,x):x=x+self.token_mixer(x)x=x+self.channel_mixer(x)return x#测试mixerblock
# x=torch.randn(1,196,512)
# mixer_block=MixerBlock(512,196,32,32)
# x=mixer_block(x)
# print(x.shape)

更具上述定义好的网络零件,就可以实现我们最终的主网络mlp-mixer:

class MLPMixer(nn.Module):def __init__(self,in_channels,dim,num_classes,patch_size,image_size,depth,token_dim,channel_dim,dropout=0.):super().__init__()assert image_size%patch_size==0self.num_patches=(image_size//patch_size)**2#embedding 操作,看见没用卷积来分成一小块一小块的self.to_embedding=nn.Sequential(         Conv2d(in_channels=in_channels,out_channels=dim,kernel_size=patch_size,stride=patch_size),Rearrange('b c h w -> b (h w) c'))self.mixer_blocks=nn.ModuleList([])for _ in range(depth):self.mixer_blocks.append(MixerBlock(dim,self.num_patches,token_dim,channel_dim,dropout))self.layer_normal=nn.LayerNorm(dim)self.mlp_head=nn.Sequential(nn.Linear(dim,num_classes))def forward(self,x):x=self.to_embedding(x)for mixer_block in self.mixer_blocks:x=mixer_block(x)x=self.layer_normal(x)x=x.mean(dim=1)x=self.mlp_head(x)return x
#测试Mlp-Mixer
if __name__ == '__main__':    model = MLPMixer(in_channels=3, dim=512, num_classes=1000, patch_size=16, image_size=224, depth=1, token_dim=256,channel_dim=2048).to(device)summary(model,(3,224,224))

depth=1,网络深度等于1 这样方便我们看整体结构,最后的运行结果:

2.Keras 复现

keras的复现过程与pytorch类似,但有几个注意的地方,网络中使用的GELU 激活函数在tensorflow>=2.4才可以使用,conda 国内镜像很难安装tensorflow 2.4 以上的GPU 版本。如果要使用的话tensorflow2.4一下版本的话,自定义GELU激活函数如下:

def gelu(x):cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))return x * cdf

在keras中使用类的方法定义自己的网络层时,需要重写 get_config 函数 不然模型无法保存。keras中借助Permute 层实现转置。
keras的完整实现如下:

from tensorflow import keras
import tensorflow as tf
import numpy as np
from keras import backend as K
from tensorflow.keras.layers import (Add,Dense,Conv2D,GlobalAveragePooling1D,Flatten,Layer,LayerNormalization,Permute,Softmax,Activation,
)
class MlpBlock(Layer):def __init__(self,dim: int,hidden_dim: int,activation=None,**kwargs):super(MlpBlock, self).__init__(**kwargs)if activation is None:activation = keras.activations.geluself.dim = dimself.hidden_dim = dimself.dense1 = Dense(hidden_dim)self.activation = Activation(activation)self.dense2 = Dense(dim)def call(self, inputs):x = inputsx = self.dense1(x)x = self.activation(x)x = self.dense2(x)return xdef compute_output_shape(self, input_signature):return (input_signature[0], self.dim)def get_config(self):config = super(MlpBlock, self).get_config()config.update({'dim': self.dim,'hidden_dim': self.hidden_dim})return configclass MixerBlock(Layer):def __init__(self,num_patches: int,channel_dim: int,token_mixer_hidden_dim: int,channel_mixer_hidden_dim: int = None,activation=None,**kwargs):super(MixerBlock, self).__init__(**kwargs)self.num_patches = num_patchesself.channel_dim = channel_dimself.token_mixer_hidden_dim = token_mixer_hidden_dimself.channel_mixer_hidden_dim = channel_mixer_hidden_dimself.activation = activationif activation is None:self.activation = keras.activations.geluif channel_mixer_hidden_dim is None:channel_mixer_hidden_dim = token_mixer_hidden_dimself.norm1 = LayerNormalization(axis=1)self.permute1 = Permute((2, 1))self.token_mixer = MlpBlock(num_patches, token_mixer_hidden_dim, name='token_mixer')self.permute2 = Permute((2, 1))self.norm2 = LayerNormalization(axis=1)self.channel_mixer = MlpBlock(channel_dim, channel_mixer_hidden_dim, name='channel_mixer')self.skip_connection1 = Add()self.skip_connection2 = Add()def call(self, inputs):x = inputsskip_x = xx = self.norm1(x)x = self.permute1(x)x = self.token_mixer(x)x = self.permute2(x)x = self.skip_connection1([x, skip_x])skip_x = xx = self.norm2(x)x = self.channel_mixer(x)x = self.skip_connection2([x, skip_x])  # TODO need 2?return xdef compute_output_shape(self, input_shape):return input_shapedef get_config(self):config = super(MixerBlock, self).get_config()config.update({'num_patches': self.num_patches,'channel_dim': self.channel_dim,'token_mixer_hidden_dim': self.token_mixer_hidden_dim,'channel_mixer_hidden_dim': self.channel_mixer_hidden_dim,'activation': self.activation,})return configdef MlpMixerModel(input_shape: int,num_classes: int,num_blocks: int,patch_size: int,hidden_dim: int,tokens_mlp_dim: int,channels_mlp_dim: int = None,use_softmax: bool = False,
):height, width, _ = input_shapeif channels_mlp_dim is None:channels_mlp_dim = tokens_mlp_dimnum_patches = (height*width)//(patch_size**2)  # TODO verify how this behaves with same paddinginputs = keras.Input(input_shape)x = inputsx = Conv2D(hidden_dim,kernel_size=patch_size,strides=patch_size,padding='same',name='projector')(x)x = keras.layers.Reshape([-1, hidden_dim])(x)for _ in range(num_blocks):x = MixerBlock(num_patches=num_patches,channel_dim=hidden_dim,token_mixer_hidden_dim=tokens_mlp_dim,channel_mixer_hidden_dim=channels_mlp_dim)(x)x = Flatten()(x)  # TODO verify this global average pool is correct choice herex = LayerNormalization(name='pre_head_layer_norm')(x)x = Dense(num_classes, name='head')(x)if use_softmax:x = Softmax()(x)return keras.Model(inputs, x)
  1. keras 配置训练代码:
    训练过程参数可以自行调试:
#学习率调试,首先我们设置一个较小的学习率 查看loss的变化情况 使用Tensorboard记录下来
import tensorflow as tf
from keras_preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.callbacks import ModelCheckpointfrom tensorflow.keras.callbacks import (EarlyStopping, ReduceLROnPlateau,TensorBoard)
from tensorflow.keras.optimizers import Adamdef train():log_dir = './log' #训练日志路劲train_dataset_path=r"your_tarin_data_path" #分类训练数据集路径test_dataset_path=r"your_test_data_path" #分类测试集路径batch_size = 64# 加载数据集lr= 1e-3epochs=20num_classes=1000 #你的分类数train_datagen = ImageDataGenerator( #数据集增强,这些参数查阅keras 官方文档 我前面的博客VGG 中 说明过也有介绍说rescale=1 / 255.0,rotation_range=20,zoom_range=0.05,width_shift_range=0.05,height_shift_range=0.05,shear_range=0.05,horizontal_flip=True,fill_mode="nearest",)train_generator = train_datagen.flow_from_directory(directory=train_dataset_path,target_size=(224, 224),color_mode="rgb",batch_size=batch_size,class_mode="categorical",shuffle=True,seed=42)test_datagen = ImageDataGenerator(rescale=1 / 255.0,)valid_generator = test_datagen.flow_from_directory(directory=test_dataset_path,target_size=(224, 224),color_mode="rgb",batch_size=batch_size,shuffle=True,seed=42)#你的模型,模型参数自己调试mlp_mixer_base = MlpMixerModel(input_shape=(224, 224, 3),num_classes=num_classes, num_blocks=2, patch_size=16,hidden_dim=64, tokens_mlp_dim=32,channels_mlp_dim=64,use_softmax=True)mlp_mixer_base.summary()
#     training_weights='./weights'  #这里是保存每次训练权重的  如果需要自己取消注释
#     checkpoint_period = ModelCheckpoint(training_weights + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
#                                         monitor='val_loss', save_weights_only=True, save_best_only=False, period=3)reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=4, verbose=1) #学习率衰减early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1) # val_loss 不下降时 停止训练 防止过拟合tensorboard = TensorBoard(log_dir=log_dir)  #训练日志optimizer=tf.keras.optimizers.Adam(learning_rate=lr)mlp_mixer_base.compile(loss=tf.keras.losses.categorical_crossentropy, metrics='acc',optimizer=optimizer)mlp_mixer_base.fit(train_generator,validation_data=valid_generator,epochs=epochs,callbacks=[tensorboard, reduce_lr, early_stopping])mlp_mixer_base.evaluate(valid_generator,verbose=1)mlp_mixer_base.save('./mlp_mixer_base.h5')
if __name__ == '__main__':train()

总结

代码复现过程,参考了论文地址给出的GitHub的链接,自己手撕代码的能力还是比较弱,不足之处就是没有使用数据集去训练它并达到一个不错的效果。后续如果有结果,会更新这篇博客,在训练模型时,调参过程是真的累,应该还是缺少理论知识的原因。后续会学习如何进行调参。
创作不易,点赞鼓励。
最后说一句,pytorch真香~~~~~~,前面的博客LeNet ,AlexNet和VGG已添加pytorch实现。
最新文章 ConvMixer认为 需要Patch embedding (图片分块)的网络表现之所以如此优越,是因为Patch embedding 操作就能完成神经网络的所有下采样过程,降低了图片的分辨率,增加了感受野,更容易找到远处的空间信息。从而模型表现良好

浅谈 Mlp-Mixer(pytorch and keras)相关推荐

  1. 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式 pth中的路径加载使用

    首先xxx.pth文件里面会书写一些路径,一行一个. 将xxx.pth文件放在特定位置,则可以让python在加载模块时,读取xxx.pth中指定的路径. Python客栈送红包.纸质书 有时,在用i ...

  2. python配置核_浅谈pytorch卷积核大小的设置对全连接神经元的影响

    3*3卷积核与2*5卷积核对神经元大小的设置 #这里kerner_size = 2*5 class CONV_NET(torch.nn.Module): #CONV_NET类继承nn.Module类 ...

  3. gather torch_浅谈Pytorch中的torch.gather函数的含义

    pytorch中的gather函数 pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验. 立个flag开始学习pytorch,新开一个分类整理学习pyt ...

  4. python模型保存save_浅谈keras保存模型中的save()和save_weights()区别

    今天做了一个关于keras保存模型的实验,希望有助于大家了解keras保存模型的区别. 我们知道keras的模型一般保存为后缀名为h5的文件,比如final_model.h5.同样是h5文件用save ...

  5. pytorch保存模型pth_浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式

    我们经常会看到后缀名为.pt, .pth, .pkl的pytorch模型文件,这几种模型文件在格式上有什么区别吗? 其实它们并不是在格式上有区别,只是后缀不同而已(仅此而已),在用torch.save ...

  6. 浅谈深度学习:LSTM对股票的收益进行预测(Sequential 序贯模型,Keras实现)

    浅谈深度学习:LSTM对股票的收益进行预测(Sequential 序贯模型,Keras实现) 总包含文章: 一个完整的机器学习模型的流程 浅谈深度学习:了解RNN和构建并预测 浅谈深度学习:基于对LS ...

  7. python的matmul_浅谈keras中的batch_dot,dot方法和TensorFlow的matmul

    概述 在使用keras中的keras.backend.batch_dot和tf.matmul实现功能其实是一样的智能矩阵乘法,比如A,B,C,D,E,F,G,H,I,J,K,L都是二维矩阵,中间点表示 ...

  8. 浅谈Service Mesh体系中的Envoy

    摘要: 提到Envoy就不得不提Service Mesh,说到Service Mesh就一定要谈及微服务了,那么我们就先放下Envoy,简单了解下微服务.Service Mesh以及Envoy在Ser ...

  9. [深度学习-原理]浅谈Attention Model

    系列文章目录 深度学习NLP(一)之Attention Model; 深度学习NLP(二)之Self-attention, Muti-attention和Transformer; 深度学习NLP(三) ...

  10. 经验 | 清华大学计算机系教授~浅谈研究生学位论文选题方法

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 编辑:Sophia 计算机视觉联盟  报道  | 公众号 CVLianMeng 转载于 :清华大学,专知 AI博士笔 ...

最新文章

  1. 独家 | 手把手教你在试验中修正机器学习模型(附学习资源)
  2. mysql多表查询书籍_MySQL多表查询及子查询
  3. 反射中 BindingFlags标识
  4. 这个瑞士的项目没法在今年三月申请,因为我需要护照并且这个项目在人文社科学员下面,886
  5. 数字图像处理基础与应用 第四章
  6. 端口扫描 -- Masscan-Gui
  7. Java基础面试题40题
  8. GET请求淘宝H5页面获取商品信息
  9. 机器学习可解释性之shap模块的使用——基础用法(一)
  10. java.sql.SQLException: Incorrect string value: '\xF0\x9F\x90\x94
  11. echarts柱状图加上渐变色报错问题
  12. 【安全】【信息搜集】Google Hacking
  13. LeetCode——89.格雷编码
  14. Qt(c++)调用python一直报错slot、hypot等
  15. Linux驱动移植USB网卡r8156驱动(详细)总结
  16. HP LaserJet 1010 打印机 - 在 Win 7 下安装驱动
  17. 若依配置教程(九)若依前后端分离版部署到服务器Nginx(Windows版)
  18. 下载Linux ISO镜像的方法 (带你快速了解)
  19. 正则表达式(手机号前带区号)
  20. 【分享】解读时间同步(NTP网络授时服务器)的重要性

热门文章

  1. matlab的textscan与textread区别(转)
  2. 混淆电路的优化:PP、Free XOR、GRR
  3. Riemann问题精确解及程序实现
  4. word公式常用字体
  5. 腾讯云CDN常见问题
  6. overleaf插入参考文献
  7. 51单片机制作简易计算器(动态数码管、矩阵按键)
  8. SpringBoot+VUE项目启动方式
  9. 笔记本打印时出现打印机出现异常配置问题_笔记本电脑连接共享打印机出现错误怎么办...
  10. Chrome浏览器语音自动播放功能