推荐系统(十五)多任务学习:谷歌MMoE(Multi-gate Mixture-of-Experts )

推荐系统系列博客:

  1. 推荐系统(一)推荐系统整体概览
  2. 推荐系统(二)GBDT+LR模型
  3. 推荐系统(三)Factorization Machines(FM)
  4. 推荐系统(四)Field-aware Factorization Machines(FFM)
  5. 推荐系统(五)wide&deep
  6. 推荐系统(六)Deep & Cross Network(DCN)
  7. 推荐系统(七)xDeepFM模型
  8. 推荐系统(八)FNN模型(FM+MLP=FNN)
  9. 推荐系统(九)PNN模型(Product-based Neural Networks)
  10. 推荐系统(十)DeepFM模型
  11. 推荐系统(十一)阿里深度兴趣网络(一):DIN模型(Deep Interest Network)
  12. 推荐系统(十二)阿里深度兴趣网络(二):DIEN模型(Deep Interest Evolution Network)
  13. 推荐系统(十三)阿里深度兴趣网络(三):DSIN模型(Deep Session Interest Network)
  14. 推荐系统(十四)多任务学习:阿里ESMM(完整空间多任务模型)

阿里的ESMM模型似乎是为广告ctr和cvr专门量身打造的,在ESMM模型结构中,两个塔具有明确的依赖关系。再早之前的MTL模型中,基本都是N个塔共享底座embedding,然后不同的任务分不同的塔,这种模式需要这些塔之间具有比较强的相关性,不然性能就很差,甚至会发生 『跷跷板』现象,即一个task性能的提升是通过损害另一个task性能作为代价换来的。因此,如果两个task都有足够的数据量,这种共享底座embedding的多塔设计的性能并没有分开单独建模效果来得好,原因是几乎必然出现负迁移(negative transfer)和“跷跷板”现象。因此在实际应用中,并不要盲目的为了MTL而MTL,这样只会弄巧成拙。(以上加黑部分为个人见解,欢迎讨论)

但如果又有多个目标,多个tower之间的相关性并不是很强,比如,CTR、点赞、时长、完播、分享等,并且有的目标的数据量并不是很足够,甚至无法单独训练一个DNN(当然,你如果说我单独建模用xgb,那我无话可说),在这种情况下,我们可能就要考虑MTL了,这时候MMoE就可以派上用场了。值得一提的是,MMoE是谷歌发表在KDD’18上的,和阿里的ESMM同年发表,所以相互之间应该独立的两个工作。

这篇博客将会从以下几个方面介绍MMoe(Multi-gate Mixture-of-Experts):

  1. 动机
  2. MMoe模型结构
  3. MMoe代码实现
  4. 总结

一、动机

说到动机,自然就要先说当前现状存在的问题。目前在MTL领域存在的问题:

  1. 工业界真实场景下,多个任务之间的相关性并不是很强,这个时候如果再用过去那种共享底座embedding的结构,往往会导致『跷跷板』现象。
  2. 当前学术界已经有很多工作意识到1中描述的问题并且尝试去解决,但大多数工作的套路都是『大力出奇迹』的路子,即加很多可学的参数去学习多个任务之间的difference,这在学术界跑跑数据,写写论文倒是没什么,但是在工业界场景下,增加这些参数会导致线上做infer时耗时增加,导致模型服务可用性大大下降,这是无法接受的。

二、MMoe模型结构

关于MMoe的模型结构,先上个论文中的原图:

图1. MMoE整体网络结构

图1中(a)展示了传统的MTL模型结构,即多个task共享底座(一般都是embedding向量),(b)则是论文中提到的一个gate的Mixture-of-Experts模型结构,(c)则是论文中的MMoE模型结构。

我们重点来看下MMoE结构,也就是图1 ( c ),这里每一个expert和gate都是一个全连接网络(MLP),层数由在实际的场景下自己决定。下面上一个我画的详细版本的MMoE模型结构图,有了这个图,公式都不用了,直接对着图代码实现就可以了。

图2. MMoE模型细节版

注:GateB那部分没画出来,因为画出来会显得很乱,参考GateA即可。

从上图2中,我们可以出来几个细节(一定要仔细看,非常重要!!!):

  1. Gate网络的数量取决于task数量,即与task数量相同。Gate网络最后一层全连接层的隐藏单元(即输出)size必须等于expert个数。另外,Gate网络最后的输出会经过softmax进行归一化。
  2. Gate网络最后一层全连接层经过softmax归一化后的输出,对应作用到每一个expert上(图2中GateA输出的红、紫、绿三条线分别作用与expert0,expert1,expert2),注意是通过广播机制作用到expert中的每一个隐藏单元,比如红线作用于expert0的2个隐藏单元。这里gate网络的作用非常类似于attention机制,提供了权重。
  3. 假设GateA的输出为[GA1,GA2,GA3][GA_1, GA_2, GA_3][GA1,GA2,GA3],expert0的输出为[E01,E02][E0_1, E0_2][E01,E02],expert1的输出为[E11,E12][E1_1, E1_2][E11,E12],expert2的输出为[E21,E22][E2_1, E2_2][E21,E22]。GateA分别与expert0、expert1、expert2作用,得到[GA1∗E01,GA1∗E02],[GA2∗E11,GA2∗E12],[GA3∗E21,GA3∗E22][GA_1*E0_1, GA_1*E0_2], [GA_2*E1_1, GA_2*E1_2], [GA_3*E2_1, GA_3*E2_2][GA1E01,GA1E02],[GA2E11,GA2E12],[GA3E21,GA3E22],然后对应位置求和得到towerA的输入,即towerA的输入size等于expert输出隐藏单元个数(在这个例子中,expert最后一层全连接层隐藏单元个数为2,因此towerA的输入维度也为2),所以towerA的输入为[GA1∗E01+GA2∗E11+GA3∗E21,GA1∗E02+GA2∗E12+GA3∗E22][GA_1*E0_1+GA_2*E1_1+GA_3*E2_1, GA_1*E0_2+GA_2*E1_2+GA_3*E2_2][GA1E01+GA2E11+GA3E21,GA1E02+GA2E12+GA3E22]
  4. expert每个网络的输入特征都是一样的,其网络结构也是一致的。
  5. 两个gate网络的输入也是一样的,gate网络结构也是一样的。

一直觉得举例子画图胜过任何繁琐复杂的解释,有了上面那个例子,相信大家基本上看完一遍就理解整个MMoE的精髓了。

三、MMoe代码实现

在实现的时候,expert网络和gate网络的全连接层数依据自己的实际场景设置,个人建议不要设置太深,通常2层就足够。
paddle给出了代码实现(但paddle这里犯了一个致命错误,详情参见我提的issue:关于MMoe网络一些疑问),paddle代码参见:
paddle MMoE

我这里给加了详细的注释,方便大家理解:

import paddle
import paddle.nn as nn
import paddle.nn.functional as Fclass MMoELayer(nn.Layer):def __init__(self, feature_size, expert_num, expert_size, tower_size,gate_num):super(MMoELayer, self).__init__()"""feature_size: 499"""self.expert_num = expert_num  # 8self.expert_size = expert_size  # 16self.tower_size = tower_size  # 8self.gate_num = gate_num  # 2self._param_expert = []for i in range(0, self.expert_num):# shape(499, 16)linear = self.add_sublayer(name='expert_' + str(i),sublayer=nn.Linear(feature_size,expert_size,weight_attr=nn.initializer.Constant(value=0.1),bias_attr=nn.initializer.Constant(value=0.1),#bias_attr=paddle.ParamAttr(learning_rate=1.0),name='expert_' + str(i)))# print("linear: ", linear.weight)self._param_expert.append(linear)self._param_gate = []self._param_tower = []self._param_tower_out = []# gate_num=2for i in range(0, self.gate_num):# shape(499, 8)linear = self.add_sublayer(name='gate_' + str(i),sublayer=nn.Linear(feature_size,expert_num,weight_attr=nn.initializer.Constant(value=0.1),bias_attr=nn.initializer.Constant(value=0.1),#bias_attr=paddle.ParamAttr(learning_rate=1.0),name='gate_' + str(i)))self._param_gate.append(linear)# shape(16, 8)linear = self.add_sublayer(name='tower_' + str(i),sublayer=nn.Linear(expert_size,tower_size,weight_attr=nn.initializer.Constant(value=0.1),bias_attr=nn.initializer.Constant(value=0.1),#bias_attr=paddle.ParamAttr(learning_rate=1.0),name='tower_' + str(i)))self._param_tower.append(linear)# shape(8, 2)linear = self.add_sublayer(name='tower_out_' + str(i),sublayer=nn.Linear(tower_size,2,weight_attr=nn.initializer.Constant(value=0.1),bias_attr=nn.initializer.Constant(value=0.1),name='tower_out_' + str(i)))self._param_tower_out.append(linear)def forward(self, input_data):"""input_data: Tensor(shape=[2, 499], 2--> batchsize"""expert_outputs = []# expert_num=8for i in range(0, self.expert_num):# Tensor(shape=[2, 16])linear_out = self._param_expert[i](input_data)expert_output = F.relu(linear_out)expert_outputs.append(expert_output)# Tensor(shape=[2, 128])  128=16*8expert_concat = paddle.concat(x=expert_outputs, axis=1)# Tensor(shape=[2, 8, 16]), 2-->batch_sizeexpert_concat = paddle.reshape(expert_concat, [-1, self.expert_num, self.expert_size])output_layers = []for i in range(0, self.gate_num):# Tensor(shape=[2, 8])cur_gate_linear = self._param_gate[i](input_data)# Tensor(shape=[2, 8]cur_gate = F.softmax(cur_gate_linear)# Tensor(shape=[2, 8, 1]cur_gate = paddle.reshape(cur_gate, [-1, self.expert_num, 1])# Tensor(shape=[2, 8, 16]) x Tensor(shape=[2, 8, 1]# = Tensor(shape=[2, 8, 16])cur_gate_expert = paddle.multiply(x=expert_concat, y=cur_gate)# Tensor(shape=[2, 16])cur_gate_expert = paddle.sum(x=cur_gate_expert, axis=1)# Tensor(shape=[2, 8])cur_tower = self._param_tower[i](cur_gate_expert)cur_tower = F.relu(cur_tower)# Tensor(shape=[2, 2])out = self._param_tower_out[i](cur_tower)out = F.softmax(out)out = paddle.clip(out, min=1e-15, max=1.0 - 1e-15)output_layers.append(out)return output_layers

四、总结

这里主要列出一些大家深入思考后可能遇到的疑问点及解释。

1. 疑问点1(呼应三中的4.5两个点)

【问】: expert网络结构一样,输入特征一样,是否会导致每个expert学出来的参数趋向于一致,从而失去了ensemble的意义?
【答】: 在网络参数随机初始化的情况下,不会发生问题中提到的问题。核心原因在于数据存在multi-view,只要每一个expert网络参数初始化是不一样的,就会导致每一个expert学到数据中不同的view(paddle官方实现就犯了这个致命错误)。微软的一篇论文中提到因为数据存在multi-view,训练多个DNN时,即使一样的特征,一样的超参数,只要简单的把参数初始化设置不一样, 这多个DNN也会有差异。论文参见:Towards Understanding Ensemble, Knowledge Distillation, and Self-Distillation in Deep Learning
所以大家在实现的时候,一定要注意这一个点,只需要简单的把参数初始化设置为随机即可。

2. 疑问点2

【问】: 是否应该强上MTL?
【答】: 如果task之间的相关性很弱,基本上都会发生negative transfer,所以MTL是绝对打不过single model的,不要盲目的为了显得高大上牛逼哄哄的一股脑MTL。还是那句话,模型不重要,重要的是对数据及场景的理解。

参考文献

[1] Ma J , Zhao Z , Yi X , et al. Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts. ACM, 2018.

推荐系统(十五)多任务学习:谷歌MMoE(Multi-gate Mixture-of-Experts )相关推荐

  1. 十五的学习日记20160926-你不知道的JS笔记/

    十五的学习日记20160926 JavaScript 一个用于检测正负值的函数,可以用它辨别-0值. 我觉得挺好用,以后可以写到自己的工具库里. //函数:检查传入参数是否为正数.Number=> ...

  2. 十五的学习日记20160925

    十五的学习日记20160925 CSS 学过盒子布局的人都知道,元素之间的上下margin会合并,保留数值较大的margin作为渲染依据. 但是今天在群里讨论发现: img元素和p元素的上下margi ...

  3. 第十五周学习周记——微信小程序开发初步

    第十五周学习周记 前言 一.小程序简介 二.小程序代码构成 1. JSON配置 2. WXML模板 3. WXSS样式 4. JS逻辑交互 总结 前言 这一周将开始微信小程序的学习. 一.小程序简介 ...

  4. 第十五周学习周报(20180611-20180617)

    第十五周学习周报 一.本周学习情况 1.本周主要学习李宏毅老师的机器学习课程 Backpropagation Convolutional Neural Network Semi-supervised ...

  5. 脉络梳理:推荐系统中的多任务学习

    © 作者|杨晨 机构|中国人民大学 研究方向|推荐系统 本文聚焦推荐系统中的一个研究方向 -- Multi-Task Recommendation,整理近五年内的研究工作,进行分类总结,并针对22年最 ...

  6. Google 多任务学习框架 MMoE

    2020-06-16 23:21:40 基于神经网络的多任务学习已经过成功应用内许多现实应用中,比如说之前我们介绍的阿里巴巴基于多任务联合学习的 ESMM 算法,其利用多任务学习解决了 CVR 中样本 ...

  7. 201771010137 赵栋《面向对象程序设计(java)》第十五周学习总结

    实验十五  GUI编程练习与应用程序部署 实验时间 2018-12-6 一:理论部分. 1.Java 程序的打包:编译完成后,程序员将.class 文件压缩打包为 .jar 文件后,GUI 界面序就可 ...

  8. 2017面向对象程序设计(Java)第十五周学习总结

    上周,老师要求同学们自学应用程序部署,并布置了相关的实验任务.此次实验的目的是掌握Java应用程序的打包操作:了解应用程序存储配置信息的两种方法: 了解Applet小应用程序的开发及应用方法:掌握基于 ...

  9. 【软件开发底层知识修炼】十五 快速学习GDB调试二 使用GDB进行断点调试

    上一篇文章我们学习了使用GDB的最基本方法:[软件开发底层知识修炼]十四 快速学习GDB调试一 入门使用 本篇文章将学习GDB的断点调试.断点调试是一种非常重要的调试方法. 文章目录 1 断点类型 2 ...

  10. 软件工程--第十五周学习进度

      第十五周 代码量  245 所花时间 6h  博客量  3篇 了解到的知识点  搭建基本web,了解了服务器的配置过程,也开始为自己的项目投入基金. 转载于:https://www.cnblogs ...

最新文章

  1. 卷起来了,写了一套计算机视觉学习笔记(20G/代码/PPT/视频)
  2. MySQL 性能跟踪语句
  3. 我整理了HMOV四大5G旗舰的参数,可依然没能拯救我的选择困难症
  4. vs目录(继承的值)配置
  5. Cuda编程学习(一)
  6. 数据通信技术_共建价值空间 共赢发展契机——2020华为贵数通新技术创享会在遵义市圆满举行...
  7. bim 模型web页面展示_BIM+装配式建筑工程师2020年必须拿下的技能证书
  8. android滑动开关框架,Android之实现滑动开关组件
  9. JAVA入门级教学之(什么是类加载)
  10. 城市运行一网统管_全国率先!“一屏观天下、一网管全城”,临港城市运行“一网统管”平台启动建设...
  11. CSS浮动(float)属性学习经验分享
  12. python ** 运算符_Python语法基础(2)运算符
  13. React Native系列文章
  14. 【智能制造】推进智能制造,他山之石可以攻玉!
  15. STM32 CAN通信协议详解—小白入门(一)
  16. 基于遥感解译与GIS技术环境影响评价图件制作(最新导则)
  17. C++ 实现matlab高斯滤波函数imgaussfilt
  18. .net之PDF合并(直接拼接,不改变尺寸和样式)
  19. Adobe Acrobat 如何批量删除PDF文件最后一页或倒数第二页?
  20. 表面缺陷检测的意义及现状

热门文章

  1. MathType不能正常右对齐解决方法
  2. VMware虚拟机ping不通主机,Destination Host Unreachable
  3. 教女朋友学Python(8)——排排坐吃果果
  4. 1.1.4 分支, if, if else, if elseif else, switch,循环,for,break,continue,双重for,while, do while
  5. LeetCode - 1002 - 查找常用字符(find-common-characters)
  6. 湖南独立学院计算机排名2015,2015年湖南独立学院高校名单
  7. E: dpkg was interrupted, you must manually run ‘dpkg –configure -a’ to correct the problem. 解决办法
  8. Ontonotes Release 5.0数据集的获取与处理
  9. 【smoj 1167】松果
  10. Transformer Fusion for Indoor RGB-D Semantic Segmentation非官方自己实现的代码