今天为大家带来阿里巴巴2021年的一篇文章:《One Model to Serve All: Star Topology Adaptive Recommender for Multi-Domain CTR Prediction》。该文章提出的方法可以只使用一种模型,便可以服务于多种CTR业务场景。这些业务场景中可能会共享一些user和item,也有自己独立的user和item。相比于传统方法中的一个模型对应一种业务,该方法既可以减少多个模型带来的维护成本与计算资源,也可以共享不同业务场景下的数据。我们接下来将详细介绍。

1. 方法动机

针对不同的业务场景,例如图1所示的首页推荐和猜你喜欢,传统方法会针对每个业务场景建立不同的模型。这种方法会带来以下几种问题:

  • 一些业务场景的流量较少,相比于其他相似的业务场景,缺乏训练数据。

  • 维护多个模型会带来大量的成本。

因此,我们提出了一种使用单个模型服务于多种业务场景的任务。我们将其称之为 multi-domain CTR prediction,「即我们的模型需要同时预测在业务场景下的点击率。模型以作为输入,其中为输入特征,为点击标签,为不同业务场景的标识」。其中由不同业务场景下的分布得到。为了充分利用不同业务场景下的数据,该文章提出了以下3种模块:

  1. 「Partitioned Normalization (PN)」: 可以针对不同业务场景下不同的数据分布做定制化归一化。

  2. 「Star topology fully-connected neural network」: 文章提出了Star Topology Adaptive Recommender(STAR) 来解决多领域的CTR预估问题。该网络可以充分利用多个业务中的数据来提升各自业务的指标。

  3. 文章提出了一种「辅助网络」(auxiliary network),直接以业务场景的标识(domain indicator)作为输入,来使得网络更好的感知不同场景下的数据分布。

图1:首页推荐和猜你喜欢。

2. 方法介绍

方法总览

如图2(a)所示,之前单场景CTR预估的方法将输入经过embedding层后,通过pooling/concat操作得到一维的向量表示后,通过BN层,经过一系列FC层,输出最后的结果。这类方法一个模型对应一种业务,不能充分利用不同业务场景下相似的数据,也提升了多个模型带来的业务成本。本文提出的方法如图2(b)所示:相比于图2(a)所示的模型,该模型有以下几点不同:

  1. 将BN(Batch Normalization)层替换为PN(Partitioned Normalization)层。

  2. 将FCN替换为Star Topology FCN。

  3. 将domain indicator直接输入。

我们接下来将详细介绍这三个不同之处。

图2:(a)单场景CTR预估模型。(b):Multi-Domain CTR预估模型。

Partitioned Normalization

Batch Normalization (BN)是一个具有代表性的方法,该方法对于深度网络的训练有着关键的作用。具体来说,BN的公式如下:

其中为输出,和 为可学习的缩放系数与bias,和为mini-batch的均值和方差。在测试阶段,BN使用训练中滑动平均得到的均值和方差:

「BN假设所有的样本都是独立同分布(i.i.d),同时所有的训练样本都有着相同的统计规律。」

「然而在multi-domain CTR prediction任务下,样本只在一个domain里遵循i.i.d,不同领域之间并不独立同分布」。因此,文章提出了Partitioned Normalization(PN),具体公式如下:

其中为domain-specific scale 和 domain-specific bias,来捕捉不同domain之间的数据分布。在测试阶段,PN使用训练中每个领域滑动平均得到的均值和方差:

Star Topology FCN

如图2(b)所示,在经过PN层后,输出会作为Star Topology FCN的输入。Star Topology FCN包含一个所有领域共享的FCN和每个领域各自独立的FCN(如图3所示)。因此,所有的FCN数量为, 为domain的数量。对于共享的FCN,我们令为共享FCN的权重,为共享FCN的偏置。对于每个领域各自独立的FCN,我们令其权重为,偏置为。对于第个领域,其最后的权重和偏置表示为:

其中为逐点相乘。我们令为网络第个领域的输入,则输出可表示为:

图3:Star Topology FCN结构

所以,通俗来说,Star Topology FCN中每个领域网络的权重由共享FCN和其domain-specific FCN的权重共同决定。共享FCN来决定每个领域中数据的共性,而domain-specific FCN习得不同领域数据之间分布的差异性。

为了方便大家理解,我们提供了Star Topology FCN的tensorflow 代码实现,核心步骤实现如代码中注释所示:由于公众号显示代码会有遮挡,欢迎大家点击阅读原文,更方便地查看详细代码。

import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.initializers import Zeros, glorot_normal
from tensorflow.python.keras.layers import Layer
from tensorflow.python.keras.regularizers import l2def activation_layer(activation):if isinstance(activation, str):act_layer = tf.keras.layers.Activation(activation)elif issubclass(activation, Layer):act_layer = activation()else:raise ValueError("Invalid activation,found %s.You should use a str or a Activation Layer Class." % (activation))return act_layerclass STAR(Layer):def __init__(self, hidden_units, num_domains, activation='relu', l2_reg=0, dropout_rate=0, use_bn=False, output_activation=None,seed=1024, **kwargs):self.hidden_units = hidden_unitsself.num_domains = num_domainsself.activation = activationself.l2_reg = l2_regself.dropout_rate = dropout_rateself.use_bn = use_bnself.output_activation = output_activationself.seed = seedsuper(STAR, self).__init__(**kwargs)def build(self, input_shape):input_size = input_shape[-1]hidden_units = [int(input_size)] + list(self.hidden_units)## 共享FCN权重self.shared_kernels = [self.add_weight(name='shared_kernel_' + str(i),shape=(hidden_units[i], hidden_units[i + 1]),initializer=glorot_normal(seed=self.seed),regularizer=l2(self.l2_reg),trainable=True) for i in range(len(self.hidden_units))]self.shared_bias = [self.add_weight(name='shared_bias_' + str(i),shape=(self.hidden_units[i],),initializer=Zeros(),trainable=True) for i in range(len(self.hidden_units))]## domain-specific 权重self.domain_kernels = [[self.add_weight(name='domain_kernel_' + str(index) + str(i),shape=(hidden_units[i], hidden_units[i + 1]),initializer=glorot_normal(seed=self.seed),regularizer=l2(self.l2_reg),trainable=True) for i in range(len(self.hidden_units))] for index in range(self.num_domains)]self.domain_bias = [[self.add_weight(name='domain_bias_' + str(index) + str(i),shape=(self.hidden_units[i],),initializer=Zeros(),trainable=True) for i in range(len(self.hidden_units))] for index in range(self.num_domains)]self.activation_layers = [activation_layer(self.activation) for _ in range(len(self.hidden_units))]if self.output_activation:self.activation_layers[-1] = activation_layer(self.output_activation)super(STAR, self).build(input_shape)  # Be sure to call this somewhere!def call(self, inputs, domain_indicator, training=None, **kwargs):deep_input = inputsoutput_list = [inputs] * self.num_domains for i in range(len(self.hidden_units)):for j in range(self.num_domains):# 网络的权重由共享FCN和其domain-specific FCN的权重共同决定output_list[j] = tf.nn.bias_add(tf.tensordot(output_list[j], self.shared_kernels[i] * self.domain_kernels[j][i], axes=(-1, 0)), self.shared_bias[i] + self.domain_bias[j][i])try:output_list[j] = self.activation_layers[i](output_list[j], training=training)except TypeError as e:  # TypeError: call() got an unexpected keyword argument 'training'print("make sure the activation function use training flag properly", e)output_list[j] = self.activation_layers[i](output_list[j])output = tf.reduce_sum(tf.stack(output_list, axis=1) * tf.expand_dims(domain_indicator,axis=-1), axis=1)return outputdef compute_output_shape(self, input_shape):if len(self.hidden_units) > 0:shape = input_shape[:-1] + (self.hidden_units[-1],)else:shape = input_shapereturn tuple(shape)def get_config(self, ):config = {'activation': self.activation, 'hidden_units': self.hidden_units,'l2_reg': self.l2_reg, 'use_bn': self.use_bn, 'dropout_rate': self.dropout_rate,'output_activation': self.output_activation, 'seed': self.seed}base_config = super(STAR, self).get_config()return dict(list(base_config.items()) + list(config.items()))

Auxiliary Network

文章还提出了一个辅助网络来学习不同领域之间数据分布的差别。该网络和主干网络相比,参数量很小,仅为几层layer。该辅助网络以domain indicator 的embedding作为输入,同时连接了其他的特征。输出为,我们令主干网络的输出为。最终的CTR预测结果如下所示:

由此可见,domain indicator会直接影响到输出分数的变化,增强了网络捕捉不同领域数据分布的能力。

总结

本文首先提出了不同业务场景下,数据互相共享互补提升的思路,提出了一种新的任务:multi-domain CTR prediction。并针对这类任务设计了PN,Star Topology FCN,辅助网络等结构。笔者认为,该文章具有很好的借鉴价值,大家可以在自己的任务上或者业务中进行尝试,欢迎大家交流。

一起交流

想和你一起学习进步!『NewBeeNLP』目前已经建立了多个不同方向交流群(机器学习 / 深度学习 / 自然语言处理 / 搜索推荐 / 图网络 / 面试交流 / 等),名额有限,赶紧添加下方微信加入一起讨论交流吧!(注意一定要备注信息才能通过)

参考文献:

Sheng XR, Zhao L, Zhou G, Ding X, Luo Q, Yang S, Lv J, Zhang C, Zhu X. One Model to Serve All: Star Topology Adaptive Recommender for Multi-Domain CTR Prediction. arXiv preprint arXiv:2101.11427. 2021 Jan 27.

END -

从Google Scholar看各大科技公司科研水平

2021-10-30

广义多目标算法探索实践

2021-10-29

升级换代!Facebook全新电商搜索系统Que2Search

2021-10-28

字节AI LAB NLP算法二面凉+被捞后通过

2021-10-27

一个模型搞定多个CTR业务!阿里STAR网络介绍(附代码实现)相关推荐

  1. Meta AI推出“杂食者”:一个模型搞定图像、视频和3D数据三大分类任务,性能还不输独立模型...

    丰色 发自 凹非寺 量子位 | 公众号 QbitAI 最近,Meta AI推出了这样一个"杂食者" (Omnivore)模型,可以对不同视觉模态的数据进行分类,包括图像.视频和3D ...

  2. 一个模型搞定十大自然语言任务:NLP全能选手来了 | 论文+代码

    夏乙 发自 凹非寺 量子位 出品 | 公众号 QbitAI 所谓自然语言处理(NLP),其实涵盖了很多方面.比如有已经无处不在的机器翻译,各大公司总在冲击排行榜的机器问答,也有普通人不太熟悉的情感分析 ...

  3. NLP通用模型诞生?一个模型搞定十大自然语言常见任务

    翻译 | 于之涵 编辑 | Leo 出品 | AI科技大本营 (公众号ID:rgznai100) AI科技大本营按:目前的NLP领域有一个问题:即使是再厉害的算法也只能针对特定的任务,比如适用于机器翻 ...

  4. NLP通用模型decaNLP诞生,一个模型搞定十大自然语言常见任务

    然而近日,Salesforce发布了一项新的研究成果:decaNLP--一个可以同时处理机器翻译.问答.摘要.文本分类.情感分析等十项自然语言任务的通用模型. Salesforce的首席科学家Rich ...

  5. AI版「女娲」来了!文字生成图像、视频,8类任务一个模型搞定

    来源丨机器之心 作者丨陈萍.小舟 AI会是未来的「造物者」吗? 近来,视觉合成任务备受关注.几天前英伟达的 GauGAN 刚刚上新了 2.0 版本,现在一个新视觉合成模型 Nüwa(女娲)也火了. 相 ...

  6. 一个模型搞定图像标注、读图问答两件事,VQA准确率逼近人类水平 | Demo可玩...

    明敏 发自 凹非寺 量子位 | 公众号 QbitAI 现在,丢给AI一张图,它不仅能看图说话,还能应对人们提出的刁钻问题了. 比如,给它看一张经典卷福照. 它便能回答出: 一个穿着西服.正在比划手势的 ...

  7. python利器app怎么查文献-科研人必备:一个工具搞定文献查阅、数据分析、模型搭建...

    原标题:科研人必备:一个工具搞定文献查阅.数据分析.模型搭建 写论文有多难?这首诗形容得好: 进入学校先选题,踌躇满志万人敌:发现前辈都做过,满脸懵逼加惊奇. 终于找到大空白,我真是个小天才:左试右试 ...

  8. 通用人工智能最新突破!一个Transformer搞定一切

    Datawhale干货 编辑:梦晨 鱼羊,来源:量子位 通用人工智能,还得看DeepMind. 这回,只一个模型,使用相同的权重,不仅把看家本领雅达利游戏玩得飞起. 和人类聊聊天.看图写话也不在话下. ...

  9. 一个系列搞定校招——简历篇

    上一篇一个系列搞定校招--综合篇总体介绍了校招从简历到面试的各个环节,没看过的可以先看上一篇,接下来将分别从每一个环节详细介绍,本篇先说[简历篇]. 前面说过,简历是求职的敲门砖,一份好的简历必然会给 ...

最新文章

  1. django 设置外键_django2.0前后版本定义外键和一对一关系的差别
  2. tcp、udp协议连接的建立和释放
  3. setsockopt()使用方法(參数具体说明)
  4. eclipse打开过的工程信息保存路径
  5. python中argsort,sort 和 sorted,operator.itemgetter函数
  6. drop 很慢 物化视图_终于解决了物化视图复制的问题
  7. 翼支付和银行网络连通准备
  8. ubuntu的磁盘扩容
  9. 【JS第1期】深拷贝实现原理
  10. datatable中使用linq的条件或_条件格式中使用公式,请提前备好晕车药
  11. Matlab绘制圆饼统计图pie的用法详解
  12. html语言中alt,html中alt的用法
  13. VsCode文件屏蔽
  14. 案例:Java用面向对象的思想设计游戏中的角色
  15. Zemax-多重结构的公差分析
  16. 手机断触怎么办_手机触摸屏失灵了怎么办,五种方法自己就能修好它!
  17. lammps案例:分子自由落体运动模拟
  18. 关于【ROM制作工具】的那点事
  19. Pytorch环境下微调BERT以及调参教程
  20. 钉钉机器人实现打卡提醒定时任务

热门文章

  1. SAP License:SAP顾问心情随笔——点燃一支烟
  2. 评分卡建模工具scorecardpy全解读
  3. Mac 配置vscode调试PHP
  4. 使用WPF创建画图箭头
  5. SQL Server 阻塞原因分析
  6. Redis Cluster 伪集群的搭建
  7. java 动手动脑
  8. Android初级教程初谈自定义view自定义属性
  9. python连接redis002
  10. POJ 1014 Dividing