推荐系统遇上深度学习(八)--AFM模型理论和实践
预计阅读时间10分钟。
引言
在CTR预估中,为了解决稀疏特征的问题,学者们提出了FM模型来建模特征之间的交互关系。但是FM模型只能表达特征之间两两组合之间的关系,无法建模两个特征之间深层次的关系或者说多个特征之间的交互关系,因此学者们通过Deep Network来建模更高阶的特征之间的关系。
因此 FM和深度网络DNN的结合也就成为了CTR预估问题中主流的方法。有关FM和DNN的结合有两种主流的方法,并行结构和串行结构。两种结构的理解以及实现如下表所示:
今天介绍的AFM模型(Attentional Factorization Machine),便是串行结构中一种网络模型。
AFM模型介绍
我们首先来回顾一下FM模型,FM模型用n个隐变量来刻画特征之间的交互关系。这里要强调的一点是,n是特征的总数,是one-hot展开之后的,比如有三组特征,两个连续特征,一个离散特征有5个取值,那么n=7而不是n=3.
顺便回顾一下化简过程:
可以看到,不考虑最外层的求和,我们可以得到一个K维的向量。
不难发现,在进行预测时,FM会让一个特征固定一个特定的向量,当这个特征与其他特征做交叉时,都是用同样的向量去做计算。这个是很不合理的,因为不同的特征之间的交叉,重要程度是不一样的。如何体现这种重要程度,之前介绍的FFM模型是一个方案。另外,结合了attention机制的AFM模型,也是一种解决方案。
关于什么是attention model?本文不打算详细赘述,我们这里只需要知道的是,attention机制相当于一个加权平均,attention的值就是其中权重,判断不同特征之间交互的重要性。
刚才提到了,attention相等于加权的过程,因此我们的预测公式变为:
圆圈中有个点的符号代表的含义是element-wise product,即:
因此,我们在求和之后得到的是一个K维的向量,还需要跟一个向量p相乘,得到一个具体的数值。
可以看到,AFM的前两部分和FM相同,后面的一项经由如下的网络得到:
图中的前三部分:sparse iput,embedding layer,pair-wise interaction layer,都和FM是一样的。而后面的两部分,则是AFM的创新所在,也就是我们的Attention net。Attention背后的数学公式如下:
总结一下,不难看出AFM只是在FM的基础上添加了attention的机制,但是实际上,由于最后的加权累加,二次项并没有进行更深的网络去学习非线性交叉特征,所以AFM并没有发挥出DNN的优势,也许结合DNN可以达到更好的结果。
代码实现
终于到了激动人心的代码实战环节了,本文的代码有不对的的地方或者改进之处还望大家多多指正。
本文的github地址为:
https://github.com/princewen/tensorflow_practice/tree/master/recommendation/Basic-AFM-Demo
本文的代码根据之前DeepFM的代码进行改进,我们只介绍模型的实现部分,其他数据处理的细节大家可以参考我的github上的代码.
在介绍之前,我们先定义几个维度,方便下面的介绍:
Embedding Size:K
Batch Size:N
Attention Size :A
Field Size (这里是field size 不是feature size!!!!): F
模型输入
模型的输入主要有下面几个部分:
self.feat_index = tf.placeholder(tf.int32, shape=[None,None], name='feat_index')self.feat_value = tf.placeholder(tf.float32, shape=[None,None], name='feat_value')
self.label = tf.placeholder(tf.float32,shape=[None,1],name='label')self.dropout_keep_deep = tf.placeholder(tf.float32,shape=[None],name='dropout_deep_deep')
feat_index是特征的一个序号,主要用于通过embedding_lookup选择我们的embedding。feat_value是对应的特征值,如果是离散特征的话,就是1,如果不是离散特征的话,就保留原来的特征值。label是实际值。还定义了dropout来防止过拟合。
权重构建
权重主要分以下几部分,偏置项,一次项权重,embeddings,以及Attention部分的权重。除Attention部分的权重如下:
def _initialize_weights(self):weights = dict()
#embeddingsweights['feature_embeddings'] = tf.Variable(tf.random_normal([self.feature_size,self.embedding_size],0.0,0.01),name='feature_embeddings')weights['feature_bias'] = tf.Variable(tf.random_normal([self.feature_size,1],0.0,1.0),name='feature_bias')weights['bias'] = tf.Variable(tf.constant(0.1),name='bias')
Attention部分的权重我们详细介绍一下,这里共有四个部分,分别对应公式中的w,b,h和p。
weights['attention_w'] 的维度为 K * A,
weights['attention_b'] 的维度为 A,
weights['attention_h'] 的维度为 A,
weights['attention_p'] 的维度为 K * 1
Embedding Layer这个部分很简单啦,是根据feat_index选择对应的weights['feature_embeddings']中的embedding值,然后再与对应的feat_value相乘就可以了:
# attention partglorot = np.sqrt(2.0 / (self.attention_size + self.embedding_size))
weights['attention_w'] = tf.Variable(np.random.normal(loc=0,scale=glorot,size=(self.embedding_size,self.attention_size)), dtype=tf.float32,name='attention_w')
weights['attention_b'] = tf.Variable(np.random.normal(loc=0,scale=glorot,size=(self.attention_size,)), dtype=tf.float32,name='attention_b')
weights['attention_h'] = tf.Variable(np.random.normal(loc=0,scale=1,size=(self.attention_size,)), dtype=tf.float32,name='attention_h')
weights['attention_p'] = tf.Variable(np.ones((self.embedding_size,1)),dtype=np.float32)
Attention NetAttention部分的实现严格按照上面给出的数学公式:
这里我们一步步来实现。
对于得到的embedding向量,我们首先需要两两计算其element-wise-product。即:
通过嵌套循环的方式得到的结果需要通过stack将其变为一个tenser,此时的维度为(F * F - 1 / 2) * N* K,因此我们需要一个转置操作,来得到维度为 N * (F * F - 1 / 2) * K的element-wize-product结果。
element_wise_product_list = []for i in range(self.field_size):for j in range(i+1,self.field_size):element_wise_product_list.append(tf.multiply(self.embeddings[:,i,:],self.embeddings[:,j,:])) # None * Kself.element_wise_product = tf.stack(element_wise_product_list) # (F * F - 1 / 2) * None * Kself.element_wise_product = tf.transpose(self.element_wise_product,perm=[1,0,2],name='element_wise_product') # None * (F * F - 1 / 2) * K
得到了element-wise-product之后,我们接下来计算:
计算之前,我们需要先对element-wise-product进行reshape,将其变为二维的tensor,在计算完之后再变换回三维tensor,此时的维度为 N * (F * F - 1 / 2) * A:
self.attention_wx_plus_b = tf.reshape(tf.add(tf.matmul(tf.reshape(self.element_wise_product,shape=(-1,self.embedding_size)),self.weights['attention_w']),self.weights['attention_b']),shape=[-1,num_interactions,self.attention_size]) # N * ( F * F - 1 / 2) * A
然后我们计算:
此时的维度为 N * ( F * F - 1 / 2) * 1
self.attention_exp = tf.exp(tf.reduce_sum(tf.multiply(tf.nn.relu(self.attention_wx_plus_b),self.weights['attention_h']),axis=2,keep_dims=True)) # N * ( F * F - 1 / 2) * 1
然后计算:
这一层相当于softmax了,不过我们还是用基本的方式写出来:
self.attention_exp_sum = tf.reduce_sum(self.attention_exp,axis=1,keep_dims=True) # N * 1 * 1self.attention_out = tf.div(self.attention_exp,self.attention_exp_sum,name='attention_out') # N * ( F * F - 1 / 2) * 1
最后,我们计算得到经attention net加权后的二次项结果:
self.attention_x_product = tf.reduce_sum(tf.multiply(self.attention_out,self.element_wise_product),axis=1,name='afm') # N * Kself.attention_part_sum = tf.matmul(self.attention_x_product,self.weights['attention_p']) # N * 1
得到预测输出为了得到预测输出,除Attention part的输出外,我们还需要两部分,分别是偏置项和一次项:
# first order termself.y_first_order = tf.nn.embedding_lookup(self.weights['feature_bias'], self.feat_index)self.y_first_order = tf.reduce_sum(tf.multiply(self.y_first_order, feat_value), 2)
# biasself.y_bias = self.weights['bias'] * tf.ones_like(self.label)
而我们的最终输出如下:
# outself.out = tf.add_n([tf.reduce_sum(self.y_first_order,axis=1,keep_dims=True), self.attention_part_sum, self.y_bias],name='out_afm')
剩下的代码就不介绍啦!
好啦,本文只是提供一个引子,有关AFM的知识大家可以更多的进行学习呦。
参考文献
https://zhuanlan.zhihu.com/p/33540686
原文链接:https://mp.weixin.qq.com/s/1PmLkfd6CvI0d3owFdTYZA
查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:
www.leadai.org
请关注人工智能LeadAI公众号,查看更多专业文章
大家都在看
LSTM模型在问答系统中的应用
基于TensorFlow的神经网络解决用户流失概览问题
最全常见算法工程师面试题目整理(一)
最全常见算法工程师面试题目整理(二)
TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络
装饰器 | Python高级编程
今天不如来复习下Python基础
推荐系统遇上深度学习(八)--AFM模型理论和实践相关推荐
- 推荐系统遇上深度学习(七)--NFM模型理论和实践
预计阅读时间10分钟. 引言 在CTR预估中,为了解决稀疏特征的问题,学者们提出了FM模型来建模特征之间的交互关系.但是FM模型只能表达特征之间两两组合之间的关系,无法建模两个特征之间深层次的关系或者 ...
- 推荐系统遇上深度学习(一)--FM模型理论和实践
全文共2503字,15张图,预计阅读时间15分钟. FM背景 在计算广告和推荐系统中,CTR预估(click-through rate)是非常重要的一个环节,判断一个商品的是否进行推荐需要根据CTR预 ...
- 推荐系统遇上深度学习(六)--PNN模型理论和实践
全文共2621字,21张图,预计阅读时间15分钟. 原理 PNN,全称为Product-based Neural Network,认为在embedding输入到MLP之后学习的交叉特征表达并不充分,提 ...
- 推荐系统遇上深度学习(二)--FFM模型理论和实践
全文共1979字,6张图,预计阅读时间12分钟. FFM理论 在CTR预估中,经常会遇到one-hot类型的变量,one-hot类型变量会导致严重的数据特征稀疏的情况,为了解决这一问题,在上一讲中,我 ...
- 推荐系统遇上深度学习(八十七)-[阿里]基于搜索的用户终身行为序列建模
本文介绍的论文是<Search-based User Interest Modeling with Lifelong Sequential Behavior Data for Click-Thr ...
- 推荐系统遇上深度学习(九十二)-[腾讯]RecSys2020最佳长论文-多任务学习模型PLE
今天介绍的是腾讯提出的一种新的多任务学习个性化推荐模型,该论文荣获了RecSys2020最佳长论文奖,一起来学习下! 1.背景 多任务学习通过在一个模型中同时学习多个不同的目标,如CTR和CVR,最近 ...
- 知识图谱论文阅读(八)【转】推荐系统遇上深度学习(二十六)--知识图谱与推荐系统结合之DKN模型原理及实现
学习的博客: 推荐系统遇上深度学习(二十六)–知识图谱与推荐系统结合之DKN模型原理及实现 知识图谱特征学习的模型分类汇总 知识图谱嵌入(KGE):方法和应用的综述 论文: Knowledge Gra ...
- 推荐系统遇上深度学习,9篇阿里推荐论文汇总!
作者 | 石晓文 转载自小小挖掘机(ID: wAIsjwj) 业界常用的推荐系统主要分为两个阶段,召回阶段和精排阶段,当然有时候在最后还会接一些打散或者探索的规则,这点咱们就不考虑了. 前面九篇文章中 ...
- 推荐系统遇上深度学习(三十九)-推荐系统中召回策略演进!
推荐系统中的核心是从海量的商品库挑选合适商品最终展示给用户.由于商品库数量巨大,因此常见的推荐系统一般分为两个阶段,即召回阶段和排序阶段.召回阶段主要是从全量的商品库中得到用户可能感兴趣的一小部分候选 ...
最新文章
- 2021年大数据Spark(四十九):Structured Streaming 整合 Kafka
- 黄学东出任微软全球人工智能首席技术官!微软首位华人技术院士全面负责Azure云AI...
- WINCE6.0 chain.bin和xipkernel.bin解析
- 硬货 | 浅谈 CAP 和 Paxos 共识算法
- 设置vim打开文件光标指在上次退出位置
- AppendStream和RetractStream(没有弄完)
- xshell使用xftp传输文件和使用pure-ftpd搭建ftp服务
- php 删除单个文件大小,php删除指定大小的jpg文件
- ztree在刷新时第一个父节点消失_第一个关于中式菜谱的智能问答机器人小程序正式上线啦...
- hadoop的限制/不足
- golang |问题代码报go并发死锁
- vs qt 在linux运行,QT安装以及使用(QT支持linux和windows,也支持C/C++代码的编译运行,比vs简洁多)...
- thinkcmf5调用指定分类的二级_python机器学习API介绍11: 伯努利贝叶斯分类器
- movielens电影数据分析
- 【图像检测-缺陷检测】基于计算机视觉实现液晶显示器表面缺陷检测含Matlab源码
- Android Binder学习(四)之addService流程分析
- 毕业设计开题分析:MIPS指令集硬件化设计与实现
- 吃货必须知道的经验,收藏备用了!太全面了!
- 百度CarLife Android车机端黑屏问题
- 怎么隐藏CAD文件里的图层?
热门文章
- python wmi 重启网卡_python使用WMI检测windows系统信息、硬盘信息、网卡信息的方法...
- matlab 平滑曲线连接_【仪光学习】技能分享 | 前方高能:如何用Matlab轻松实现数学建模...
- php字符串以符号截取,PHP按符号截取字符串的指定部分的实现方法
- oracle edit历史,OGG-00952---oracle goldengate无法purge历史表和mark表处理一例
- 前端之BOM和DOM
- [Apple开发者帐户帮助]七、注册设备(3)禁用或启用设备
- supervisor linux下进程管理工具
- 转:Vim中显示不可见字符
- html css 深入理解float
- FIREDAC连接MSSQL 2000报不能支持连接MSSQL2000及更低版本的解决办法