简介:本文介绍的是阿里巴巴团队发表在 SIGIR’2018 的论文《Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate》。文章基于 Multi-Task Learning (MTL) 的思路,提出一种名为ESMM的CVR预估模型,有效解决了真实场景中CVR预估面临的数据稀疏以及样本选择偏差这两个关键问题。后续还会陆续介绍MMoE,PLE,DBMTL等多任务学习模型。

多任务学习背景

目前工业中使用的推荐算法已不只局限在单目标(ctr)任务上,还需要关注后续的转换链路,如是否评论、收藏、加购、购买、观看时长等目标。

本文介绍的是阿里巴巴团队发表在 SIGIR’2018 的论文《Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate》。文章基于 Multi-Task Learning (MTL) 的思路,提出一种名为ESMM的CVR预估模型,有效解决了真实场景中CVR预估面临的数据稀疏以及样本选择偏差这两个关键问题。后续还会陆续介绍MMoE,PLE,DBMTL等多任务学习模型。

论文介绍

CVR预估面临两个关键问题:

1. Sample Selection Bias (SSB)

转化是在点击之后才“有可能”发生的动作,传统CVR模型通常以点击数据为训练集,其中点击未转化为负例,点击并转化为正例。但是训练好的模型实际使用时,则是对整个空间的样本进行预估,而非只对点击样本进行预估。即训练数据与实际要预测的数据来自不同分布,这个偏差对模型的泛化能力构成了很大挑战,导致模型上线后,线上业务效果往往一般。

2. Data Sparsity (DS)

CVR预估任务的使用的训练数据(即点击样本)远小于CTR预估训练使用的曝光样本。仅使用数量较小的样本进行训练,会导致深度模型拟合困难。

一些策略可以缓解这两个问题,例如从曝光集中对unclicked样本抽样做负例缓解SSB,对转化样本过采样缓解DS等。但无论哪种方法,都没有从实质上解决上面任一个问题。

由于点击=>转化,本身是两个强相关的连续行为,作者希望在模型结构中显示考虑这种“行为链关系”,从而可以在整个空间上进行训练及预测。这涉及到CTR与CVR两个任务,因此使用多任务学习(MTL)是一个自然的选择,论文的关键亮点正在于“如何搭建”这个MTL。

首先需要重点区分下,CVR预估任务与CTCVR预估任务。

  • CVR = 转化数/点击数。是预测“假设item被点击,那么它被转化”的概率。CVR预估任务,与CTR没有绝对的关系。一个item的ctr高,cvr不一定同样会高,如标题党文章的浏览时长往往较低。这也是不能直接使用全部样本训练CVR模型的原因,因为无法确定那些曝光未点击的样本,假设他们被点击了,是否会被转化。如果直接使用0作为它们的label,会很大程度上误导CVR模型的学习。
  • CTCVR = 转换数/曝光数。是预测“item被点击,然后被转化”的概率。

其中x,y,z分别表示曝光,点击,转换。注意到,在全部样本空间中,CTR对应的label为click,而CTCVR对应的label为click & conversion,这两个任务是可以使用全部样本的。因此,ESMM通过学习CTR,CTCVR两个任务,再根据上式隐式地学习CVR任务。具体结构如下:

网络结构上有两点值得强调:

  1. 共享Embedding。 CVR-task和CTR-task使用相同的特征和特征embedding,即两者从Concatenate之后才学习各自独享的参数;
  2. 隐式学习pCVR。这里pCVR 仅是网络中的一个variable,没有显示的监督信号。

具体地,反映在目标函数中:

代码实现

基于EasyRec推荐算法框架,我们实现了ESMM算法,具体实现可移步至github:EasyRec-ESMM。

EasyRec介绍:EasyRec是阿里云计算平台机器学习PAI团队开源的大规模分布式推荐算法框架,EasyRec 正如其名字一样,简单易用,集成了诸多优秀前沿的推荐系统论文思想,并且有在实际工业落地中取得优良效果的特征工程方法,集成训练、评估、部署,与阿里云产品无缝衔接,可以借助 EasyRec 在短时间内搭建起一套前沿的推荐系统。作为阿里云的拳头产品,现已稳定服务于数百个企业客户。

模型前馈网络:

  def build_predict_graph(self):"""Forward function.Returns:self._prediction_dict: Prediction result of two tasks."""# 此处从Concatenate后的tensor(all_fea)开始,省略其生成逻辑cvr_tower_name = self._cvr_tower_cfg.tower_namednn_model = dnn.DNN(self._cvr_tower_cfg.dnn,self._l2_reg,name=cvr_tower_name,is_training=self._is_training)cvr_tower_output = dnn_model(all_fea)cvr_tower_output = tf.layers.dense(inputs=cvr_tower_output,units=1,kernel_regularizer=self._l2_reg,name='%s/dnn_output' % cvr_tower_name)ctr_tower_name = self._ctr_tower_cfg.tower_namednn_model = dnn.DNN(self._ctr_tower_cfg.dnn,self._l2_reg,name=ctr_tower_name,is_training=self._is_training)ctr_tower_output = dnn_model(all_fea)ctr_tower_output = tf.layers.dense(inputs=ctr_tower_output,units=1,kernel_regularizer=self._l2_reg,name='%s/dnn_output' % ctr_tower_name)tower_outputs = {cvr_tower_name: cvr_tower_output,ctr_tower_name: ctr_tower_output}self._add_to_prediction_dict(tower_outputs)return self._prediction_dict

loss计算:

注意:计算CVR的指标时需要mask掉曝光数据。

  def build_loss_graph(self):"""Build loss graph.Returns:self._loss_dict: Weighted loss of ctr and cvr."""cvr_tower_name = self._cvr_tower_cfg.tower_namectr_tower_name = self._ctr_tower_cfg.tower_namecvr_label_name = self._label_name_dict[cvr_tower_name]ctr_label_name = self._label_name_dict[ctr_tower_name]ctcvr_label = tf.cast(self._labels[cvr_label_name] * self._labels[ctr_label_name], tf.float32)cvr_loss = tf.keras.backend.binary_crossentropy(ctcvr_label, self._prediction_dict['probs_ctcvr'])cvr_loss = tf.reduce_sum(cvr_losses, name="ctcvr_loss")# The weight defaults to 1.self._loss_dict['weighted_cross_entropy_loss_%s' %cvr_tower_name] = self._cvr_tower_cfg.weight * cvr_lossctr_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.cast(self._labels[ctr_label_name], tf.float32),logits=self._prediction_dict['logits_%s' % ctr_tower_name]), name="ctr_loss")self._loss_dict['weighted_cross_entropy_loss_%s' %ctr_tower_name] = self._ctr_tower_cfg.weight * ctr_lossreturn self._loss_dict

note: 这里loss是 weighted_cross_entropy_loss_ctr + weighted_cross_entropy_loss_cvr, EasyRec框架会自动对self._loss_dict中的内容进行加和。

metric计算:

注意:计算CVR的指标时需要mask掉曝光数据。

  def build_metric_graph(self, eval_config):"""Build metric graph.Args:eval_config: Evaluation configuration.Returns:metric_dict: Calculate AUC of ctr, cvr and ctrvr."""metric_dict = {}cvr_tower_name = self._cvr_tower_cfg.tower_namectr_tower_name = self._ctr_tower_cfg.tower_namecvr_label_name = self._label_name_dict[cvr_tower_name]ctr_label_name = self._label_name_dict[ctr_tower_name]for metric in self._cvr_tower_cfg.metrics_set:# CTCVR metricctcvr_label_name = cvr_label_name + '_ctcvr'cvr_dtype = self._labels[cvr_label_name].dtypeself._labels[ctcvr_label_name] = self._labels[cvr_label_name] * tf.cast(self._labels[ctr_label_name], cvr_dtype)metric_dict.update(self._build_metric_impl(metric,loss_type=self._cvr_tower_cfg.loss_type,label_name=ctcvr_label_name,num_class=self._cvr_tower_cfg.num_class,suffix='_ctcvr'))# CVR metriccvr_label_masked_name = cvr_label_name + '_masked'ctr_mask = self._labels[ctr_label_name] > 0self._labels[cvr_label_masked_name] = tf.boolean_mask(self._labels[cvr_label_name], ctr_mask)pred_prefix = 'probs' if self._cvr_tower_cfg.loss_type == LossType.CLASSIFICATION else 'y'pred_name = '%s_%s' % (pred_prefix, cvr_tower_name)self._prediction_dict[pred_name + '_masked'] = tf.boolean_mask(self._prediction_dict[pred_name], ctr_mask)metric_dict.update(self._build_metric_impl(metric,loss_type=self._cvr_tower_cfg.loss_type,label_name=cvr_label_masked_name,num_class=self._cvr_tower_cfg.num_class,suffix='_%s_masked' % cvr_tower_name))for metric in self._ctr_tower_cfg.metrics_set:# CTR metricmetric_dict.update(self._build_metric_impl(metric,loss_type=self._ctr_tower_cfg.loss_type,label_name=ctr_label_name,num_class=self._ctr_tower_cfg.num_class,suffix='_%s' % ctr_tower_name))return metric_dict

实验及不足

我们基于开源AliCCP数据,进行了大量实验,实验部分请期待下一篇文章。实验发现,ESMM的跷跷板现象较为明显,CTR与CVR任务的效果较难同时提升。

参考文献

  1. Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate
  2. 阿里CVR预估模型之ESMM
  3. EasyRec-ESMM使用介绍多任务学习模型之ESMM介绍与实现

原文链接

本文为阿里云原创内容,未经允许不得转载。

多任务学习模型之ESMM介绍与实现相关推荐

  1. 多任务学习模型ESMM原理与实现(附代码)

    来源:DataFunTalk 本文约2500字,建议阅读5分钟 文章基于 Multi-Task Learning (MTL) 的思路,提出一种名为ESMM的CVR预估模型. [ 导读 ] 本文介绍的是 ...

  2. 推荐系统(十四)多任务学习:阿里ESMM(完整空间多任务模型)

    推荐系统(十四)多任务学习:阿里ESMM(完整空间多任务模型) 推荐系统系列博客: 推荐系统(一)推荐系统整体概览 推荐系统(二)GBDT+LR模型 推荐系统(三)Factorization Mach ...

  3. 推荐系统遇上深度学习(九十二)-[腾讯]RecSys2020最佳长论文-多任务学习模型PLE

    今天介绍的是腾讯提出的一种新的多任务学习个性化推荐模型,该论文荣获了RecSys2020最佳长论文奖,一起来学习下! 1.背景 多任务学习通过在一个模型中同时学习多个不同的目标,如CTR和CVR,最近 ...

  4. IJCAI 2019 | 为推荐系统生成高质量的文本解释:基于互注意力机制的多任务学习模型...

    编者按:在个性化推荐系统中,如果能在提高推荐准确性的同时生成高质量的文本解释,将更容易获得用户的"芳心".然而,现有方法通常将两者分开优化,或只优化其中一个目标.为了同时兼顾二者, ...

  5. 排序层-深度模型-2020:PLE【多任务学习模型】【腾讯】

    PLE模型是腾讯发表在RecSys '20上的文章,这篇paper获得了recsys'20的best paper award,也算为腾讯脱离技术贫民的大业添砖加瓦了.这篇文章号称极大的缓解了多任务学习 ...

  6. Sunny.Xia的深度学习(四)MMOE多任务学习模型实战演练

    本专栏文章会在本博客和知乎专栏--Sunny.Xia的深度学习同步更新,对于评论博主若未能够及时回复的,可以知乎私信.未经本人允许,请勿转载,谢谢. 一.什么是MMOE? 三张图分别是多任务模型的不同 ...

  7. 【推荐系统多任务学习MTL】ESMM 论文精读笔记(含代码实现)

    论文地址:Entire Space Multi-Task Model: An Effective Approach for Estimating Post-Click Conversion Rate ...

  8. MMOE——多任务学习模型

    摘要 对于多任务学习,我们的目标是建立一个单一的模型,同时学习这些多个目标和任务.然而,常用的多任务模型的预测质量往往对任务之间的关系比较敏感.因此,研究任务特定目标和任务间关系之间的建模权衡是很重要 ...

  9. 多任务学习模型MTL: MMoE、PLE

    常见的监督学习包括: 回归:预测值为连续值,如销售额: 二分类:预测值为离散值,且只有两种取值,如性别,要么是男,要么是女: 多分类:预测值为离散值,且多于两种取值,如动物分类,可能有猫.狗.狮子等等 ...

最新文章

  1. 查看控制文件的内容(oracle)
  2. DNA Sorting
  3. MOS管及MOS管的驱动电路设计
  4. Facebook的体系结构分析---外文转载
  5. MongoDB副本集学习(三):性能和优化相关
  6. CocoStuff—基于Deeplab训练数据的标定工具【二、用已提供的标注数据跑通项目】...
  7. 项目验收流程小TIPS
  8. 计算机辅助翻译与人工智能,2018年机器翻译行业概述与现状,人工智能让人人实现国际化交流...
  9. 南佛罗里达大学计算机科学硕士,去南佛罗里达大学读硕士好吗
  10. 苹果手机上如何设置qq邮箱服务器地址,iPhone手机如何添加qq邮箱
  11. 你居然不会狄杰斯特算法?惊了!
  12. LR---Loadrunner11破解方法
  13. css中outline,css中outline的解析(附示例)
  14. 机器学习:SOM聚类的实现
  15. Orc-Battle
  16. 使用Python和OpenCV标记超级像素的炫彩度
  17. BT——专门为大容量文件的共享而设计的网络协议
  18. 控油,真的可以缓解脂溢性脱发么?
  19. 微信公众号运营,如何编辑好的文案吸引粉丝
  20. 给对话框添加菜单 工具栏 状态栏简易方法

热门文章

  1. 如何在python中安装matplotlib模块_Windows下为Python安装Matplotlib模块
  2. 爬虫 页面元素变化_爬虫 基本知识 萌新
  3. python open写入_Python3 open() 函数详解 读取文件写入文件追加文件二进制文件
  4. Java的一些学习心得
  5. python get_len_Python类,特殊方法, __getitem__,__len__, __delitem__
  6. hdfs文件如何导出到服务器,[Hadoop] 如何将 HDFS 文件导出到 Windows文件系统
  7. alpine登陆mysql_如何构建一个php7-alpine的docker镜像
  8. 如何使用CNN进行物体识别和分类_RCNN物体识别
  9. go 连接服务器 并存放图片_基于 Go 语言开发在线论坛(二):通过模型类与MySQL数据库交互...
  10. mysql group_concat去重_mysql 数据库group_concat函数的一些用法