基于上一篇分析中协同过滤、逻辑回归及FM的比较,可以得出这样一个结论:

主流模型迭代的关键在于增强模型表达能力,而增强方式的主要脉络为:

  1. 引入其它可用特征信息(CF->LR)。
  2. 将现有特征进行组合(LR->POLY2->FM)。

更通俗的表达:

  • 模型迭代在于找到更多的有效信息。

本文想要回顾的FFM(Field-aware Factorization Macheines)模型可以看作是FM模型的增强版,其正是沿用在FM模型的特征组合思想,并将其发扬光大,曾在多项CTR预估赛中夺魁,并且被Criteo、美团等公司深度应用在推荐系统与CTR预估等领域。

相较于FM模型,FFM模型在FM隐向量特征交叉组合的基础上,进一步引入了特征感知(field-aware)这一概念,使得模型表达能力在理论上有了一个较大的提升。

FM模型表达式:

FFM的表达式如下所示:

从表达式中可以看出,FFM与FM的不同之处在于二阶特征组合部分的隐向量由

变成了
。而这便意味着

FFM模型中每一维特征对应的不是唯一一个隐向量,而是一组特征隐向量,这也具体的引出了FFM与FM特征组合不同之处与FFM模型的提升之处。

FFM这么做的原因是什么?

  1. FM模型中每一特征共用同一个特征隐向量,意味着每一特征与不同域特征进行组合时使用的是同一隐特征向量来学习隐向量参数,这样不够细致,存在明显的信息浪费

什么是特征域Fileds,FFM怎么去做特征域级的特征组合?

参考论文原文 ,“features” can be grouped into “fields” 。把这句话的主语和宾语置换一下就可以得到Fileds的定义: “fields” is grouped by "features",也即特征域由某类特征组成。

借用论文的示例,或许可以更直观的回答这个问题:

Table 1: An artificial CTR data set, where + (-) representsthe number of clicked (unclicked) impressions.
Table 2:a example of click event

在上面两个表中,Publisher(P)、Advertiser(A)和Gander(G)是三个Fileds特征域。其中Publisher特征域的中的feature有ESPN、Vogue和NBC,套用作者定义那么有:fields of Publisher is grouped by features<ESPN,Vogue,NBC>。Advertiser与Gander特征域同理。

那么对于Table 2的数据,FFM的二阶特征组合结果为:

而对于FM二阶特征组合结果而言:

可以看到在FFM模型中,Feature ESPN在与NIKE及Male形成特征组合ESPN,NIKE)、(ESPN,Male)时,使用了不同的潜向量

来学习参数。同理,Feature NIKE在与ESPN及Male特征组合时,也分别使用了两个不同的潜向量
来计算组合权重。

而对于FM而言,Feature ESPN在与NIKE及Male特征组合时,使用的是同一个潜向量

,同时NIKE与Male特征也只存在唯一一个隐向量。

至此,什么是特征域Fileds,FFM怎么去做特征域级的特征组合得到解释。

一个问题

以上解释了FFM模型相较于FM模型的优势在于,每个特征在与不同Field下的Feature进行组合时,会使用与之特征域对应的隐向量来学习组合权重。而这也带来了一个FFM不被大规模应用的问题--参数量暴增,解释如下:

假设一份数据集dataset的维度为(m,n),FM在进行模型训练时将隐向量特征维度设置为k维,那么FM的特征量便为nk。而对于FFM而言,假设dataset的n维特征对应着f个特征域,那么在隐向量维度同维k的情况下,FFM的特征量级为nfk(实际数量为n(f-1)k,由于特征不需要自我交叉,因此为f-1)。

nk与nfk之间的差别虽然只是线性级,但是由于互联网数据特征量n动辄百万级,虽然f的值会比n小若干个数量级,但这也足以使得FFM模型的参数量暴增到一个恐怖的级别。

Table 3:two CTR datasets Criteo and Avazu from Kaggle competitions

以Table3数据集Criteo为例,当维度k取为10时,FM的参数量为

,而FFM的量级为
。虽然只增加了一个数量级别,但参数量从千万级变为了亿级别,可谓非常恐怖了。

python实现

关于FFM的工程代码其实论文作者已经在github上给出C++版本,python版,或者Amazon AI的马超开源的XLearn。

本文引用Python implementation of FFM model (ctr, cvr)一文,来分析一下FFM的代码实现。

import  numpy as np
import  math
import  random
class ffm(object):def __init__(self, feature_num, fild_num, feature_dim_num, feat_fild_dic, learning_rate, regular_para, stop_threshold):#n features, m domains, each feature dimension kself.n = feature_num #特征数量self.m = fild_num #特征域数self.k = feature_dim_num #隐向量特征长度self.dic = feat_fild_dic #特征对应的域#Set hyperparameter, learning rate eta, regularization coefficient lamdaself.eta = learning_rateself.lamda = regular_paraself.threshold = stop_thresholdself.w = np.random.rand(self.n, self.m , self.k) / math.sqrt(self.k)#权重初试值self.G = np.ones(shape = (feature_num, fild_num, feature_dim_num), dtype = np.float64)def train(self, tr_l, val_l, train_y, val_y, max_echo):#这一部分计算模型的训练损失# tr_l, val_l, train_y, val_y, max_echo are# Training set, validation set, training set label, validation set label, maximum number of iterationsminloss = 0for i in range(max_echo):# Iterative training, max_echo is the maximum number of iterationsL_val = 0Logloss = 0order = range(len(train_y))# mess up the orderrandom.shuffle(order)for each_data_index in order:# Remove a recordtr_each_data = tr_l[each_data_index]# phi() is the model formulaphi = self.phi(tr_each_data)# y_i is the actual tag valuey_i = float(train_y[each_data_index])# Calculate the gradient belowg_phi = -y_i / (1 + math.exp(y_i * phi))# Begin to update the model parameters using the gradient descent methodself.sgd_para(tr_each_data, g_phi)# Next, check on the verification set, the basic process is the same as before.for each_vadata_index, each_va_y in enumerate(val_y):val_each_data = val_l[each_vadata_index]phi_v = self.phi(val_each_data)y_vai = float(each_va_y)Logloss += -(y_vai * math.log(phi_v) + (1 - y_vai) * math.log(1 - phi_v))Logloss = Logloss / len(val_y)# L_val += math.log(1+math.exp(-y_vai * phi_v))print("The %d iteration, LOGLOSS on the validation set: %f" % (i, Logloss))if minloss == 0:# minloss stores the smallest LOGLOSSminloss = Loglossif Logloss <= self.threshold:# It can also be considered that setting the threshold allows the program to jump, and personal needs can be removed.print('Less than the threshold!')breakif minloss < Logloss:# If the next iteration does not reduce LOGLOSS, break out (early stopping)print('early stopping')breakdef phi(self, tmp_dict):#这一部分计算这FFM二阶部分的对应的值#Samples are normalized here to prevent calculation overflowsum_v = sum(tmp_dict.values())#First find the index of the non-zero feature in each piece of data and put it in a listphi_tmp = 0key_list = tmp_dict.keys()for i in range(len(key_list)):#feat_index is the index of the feature, fild_index1 is the index of the domain, and value1 is the value corresponding to the featurefeat_index1 = key_list[i]fild_index1 = self.dic[feat_index1]#The purpose of dividing here by sum_v is to normalize this one (return all feature values ​​to between 0 and 1)#Of course, each feature has been normalized before (0-1)value1 = tmp_dict[feat_index1] / sum_v#Two non-zero features pairwise inner productfor j in range(i+1, len(key_list)):feat_index2 = key_list[j]fild_index2 = self.dic[feat_index2]value2 = tmp_dict[feat_index2] / sum_vw1 = self.w[feat_index1, fild_index2]w2 = self.w[feat_index2, fild_index1]#The final value is obtained by summing up multiple characteristic combinationsphi_tmp += np.dot(w1, w2) * value1 * value2return phi_tmpdef sgd_para(self, tmp_dict, g_phi):#这一部分梯度计算 ,参数的更新#学习率是用的AdaGrad算法sum_v = sum(tmp_dict.values())key_list = tmp_dict.keys()for i in range(len(key_list)):feat_index1 = key_list[i]fild_index1 = self.dic[feat_index1]value1 = tmp_dict[feat_index1] / sum_vfor j in range(i + 1, len(key_list)):feat_index2 = key_list[j]fild_index2 = self.dic[feat_index2]value2 = tmp_dict[feat_index2] / sum_vw1 = self.w[feat_index1, fild_index2]w2 = self.w[feat_index2, fild_index1]# Update g and Gg_feati_fildj = g_phi * value1 * value2 * w2 + self.lamda * w1g_featj_fildi = g_phi * value1 * value2 * w1 + self.lamda * w2self.G[feat_index1, fild_index2] += g_feati_fildj ** 2self.G[feat_index2, fild_index1] += g_featj_fildi ** 2# math.sqrt() can only accept one element, while np.sqrt() can root the entire vectorself.w[feat_index1, fild_index2] -= self.eta / np.sqrt(self.G[feat_index1, fild_index2]) * g_feati_fildjself.w[feat_index2, fild_index1] -= self.eta / np.sqrt(self.G[feat_index2, fild_index1]) * g_featj_fildi

参考资料

  • FFM原文
  • https://programmersought.com/article/15904469481/
  • 美团深入FFM原理与实践

推荐阅读

秋雨淅淅l:前深度学习时代--因子分解机模型FM的因与果。​zhuanlan.zhihu.com

python 主语_前深度学习时代--FFM模型的原理与Python实现相关推荐

  1. 前深度学习时代CTR预估模型的演化之路:从LR到FFM\n

    本文是王喆在 AI 前线 开设的原创技术专栏"深度学习 CTR 预估模型实践"的第二篇文章(以下"深度学习 CTR 预估模型实践"简称"深度 CTR ...

  2. 前深度学习时代CTR预估模型的演化之路 [王喆观点]

    毕业于清华大学计算机系的王喆学长梳理从传统机器学习时代到深度学习时代所有经典CTR(click through rate)模型的演化关系和模型特点.内容来源:https://zhuanlan.zhih ...

  3. MATLAB算法实战应用案例精讲-【深度学习】扩散模型(DM)(附python代码实现)

    目录 前言 广播模型 扩散模型 几个高频面试题目 GAN.VAE和基于流的生成模型之间的区别

  4. 深度学习时代的计算机视觉

    在上世纪50年代,数学家图灵提出判断机器是否具有人工智能的标准:图灵测试.图灵测试是指测试者在与被测试者(一个人和一台机器)隔开的情况下,通过一些装置(如键盘)向被测试者随意提问.进行多次测试后,如果 ...

  5. 漫谈深度学习时代点击率预估技术进展

    漫谈深度学习时代点击率预估技术进展(2019-1) 本文来源:[镶嵌在互联网技术上的明珠] (https://zhuanlan.zhihu.com/p/54822778) 下文是阅读后的一些笔记 在D ...

  6. 深度学习CTR预估模型凭什么成为互联网增长的关键?

    本文是王喆在InfoQ开设的原创技术专栏"深度学习CTR预估模型实践"的第一篇文章(以下"深度学习CTR预估模型实践"简称"深度CTR模型" ...

  7. 深度学习 图像分类_深度学习时代您应该阅读的10篇文章了解图像分类

    深度学习 图像分类 前言 (Foreword) Computer vision is a subject to convert images and videos into machine-under ...

  8. 深度学习之对象检测_深度学习时代您应该阅读的12篇文章,以了解对象检测

    深度学习之对象检测 前言 (Foreword) As the second article in the "Papers You Should Read" series, we a ...

  9. 使用lucce分词怎么_深度学习时代,分词真的有必要吗

    前言 中文数据集是我一直尽量避免的问题,但生活所迫,毕竟咱还是要在国内混江湖的,于是,最近开始研究研究深度学习模型在中文数据集上的各种表现, 但随之而来的一个问题是: 我真的需要分词吗? 香侬科技在 ...

最新文章

  1. RabbitMQ 高频考点
  2. HTTPS 工作原理和 TCP 握手机制
  3. linux中网卡的流量怎么通过c语言获取_用Python获取计算机网卡信息
  4. 微信小程序(mpvue)—解决视频播放bug的一种方式
  5. h5如何上传文件二进制流_Hadoop如何将TB级大文件的上传性能优化上百倍?
  6. ISAKMP主模式分析二
  7. 标准软件开发过程 文档
  8. VBScript详解(一)
  9. vue实现结算淘宝购物车效果
  10. 【CLion】新手使用之编译运行单个文件
  11. nginx 的基本概念
  12. redis搭建哨兵天坑
  13. 物理机无法ping通虚拟机,虚拟机能ping通物理机
  14. iOS-发布按钮动画(类似于闲鱼发布),弹出动画github开源
  15. 嘉信给你介绍新加坡10大特色美食
  16. 什么是百度转码?如何禁止网站百度转码?
  17. Linux系统下如果查看用户的UID和GID
  18. P4707 重返现世 扩展 MinMax 容斥+DP
  19. 【NOIP2006 普及组】T3 Jam 的计数法 题解
  20. send函数给FTP服务器发消息,send函数给FTP服务器发消息

热门文章

  1. actionscript3 事件类型
  2. php 重复区域,如何使用Mysql和PHP从重复区域单击缩略图后检索图像
  3. python解释器可以使用什么命令_python解释器用什么写的
  4. python读取ini文件编码格式_Python读取txt(.ini)文件BOM问题
  5. python基础案例教程_python基础教程 10-11例子如何执行
  6. 对齐方式有那些_字节对齐不慎引发的挂死问题
  7. access 记录集 filter find属性_《另一个伊甸》超越时空的猫时之塔阵容推荐 时之塔BOSS属性怎么打_另一个伊甸...
  8. android imap开发,企业邮箱在Android(安卓)系统手机上POP3/IMAP协议如何设置
  9. java对象的访问定位_JVM创建对象及访问定位过程详解
  10. 小学四年级下册计算机考试试题,四年级信息技术下学期测试题