目录

  • 1 介绍
  • 2 模型结构
  • 3 实验结果
  • 4 总结
  • 5 代码实践

1 介绍

DeepFM 是华为诺亚方舟实验室在 2017 年提出的模型。
论文传送门:

A Factorization-Machine based Neural Network for CTR Prediction

正如名称所示,DeepFM 是 Deep 与 FM 结合的产物,也是 Wide&Deep 的改进版,只是将其中的 LR 替换成了 FM,提升了模型 wide 侧提取信息的能力。学 DeepFM 之前建议先了解 FM 与 Wide&Deep 。

DeepFM 与Wide&Deep 只是 FM 与 LR 的区别么?并不是。

2 模型结构


2.1 Sparse Features

一般类别特征无法直接输入模型,所以需要先 onehot 处理得到的其稀疏01向量表示。该层即表示经过 onehot 编码的类别特征与数值特征的拼接。

2.2 Dense Embeddings

该层为嵌入层,用于对高维稀疏的01向量做嵌入,得到低维稠密的向量 e (每个01向量对应自己的嵌入层,不同向量的嵌入过程相互独立,如上图所示)。然后将每个稠密向量横向拼接,在拼接上原始的数值特征,然后作为 Deep 与 FM 的输入。

2.3 FM Layer


FM 有两部分,线性部分交叉部分。线性部分 (黑色线段) 是给与每个特征一个权重,然后进行加权和;交叉部分 (红色线段) 是对特征进行两两相乘,然后赋予权重加权求和。然后将两部分结果累加在一起即为 FM Layer 的输出。

2.4 Hidden Layer



Deep 部分的输入 a0 为所有稠密向量的横向拼接,然后经过多层线性映射+非线性转换得到 Hidden Layer 的输出,一般会映射到1维,因为需要与 FM 的结果进行累加。

2.5 Output Units


输出层为 FM Layer 的结果与 Hidden Layer 结果的累加,低阶与高阶特征交互的融合,然后经过 Sigmoid 非线性转换,得到预测的概率输出。

3 实验结果

DeepFM 在CTR预估任务上的表现,以及与其他推荐算法的对比如下:


作者通过实验证明了 relu 激活函数比 tanh 更适合 DeepFM (但在其他模型上效果不一)。

4 总结

与 Wide&Deep 的异同:

相同点:都是线性模型与深度模型的结合,低阶与高阶特征交互的融合。

不同点:DeepFM 两个部分共享输入,而 Wide&Deep 的 wide 侧是稀疏输入,deep 侧是稠密输入;DeepFM 无需加入人工特征,可端到端的学习,线上部署更方便,Wide&Deep 则需要在输入上加入人工特征提升模型表达能力。

DeepFM 优缺点:

优点:

1 两部分联合训练,无需加入人工特征,更易部署;

2 结构简单,复杂度低,两部分共享输入,共享信息,可更精确的训练学习。

缺点:

1 将类别特征对应的稠密向量拼接作为输入,然后对元素进行两两交叉。这样导致模型无法意识到域的概念,FM 与 Deep 两部分都不会考虑到域,属于同一个域的元素应该对应同样的计算。

面试可能会考察的问题:

FM 本来就可以在稀疏输入的场景中进行学习,为什么要跟 Deep 共享稠密输入?虽然 FM 具有线性复杂度 O(nk),其中 n 为特征数,k 为隐向量维度,可以随着输入的特征数线性增长。但是经过 onehot 处理的类别特征维度往往要比稠密向量高上一两个数量级,这样还是会给 FM 侧引入大量多于的计算,不可取。

5 代码实践

如果你搭建过 FM 或者 Wide&Deep 模型,那么对于 DeepFM 的搭建就很随意了。为了便于理解,我还是会放上每个部分全部的代码。

Layer 搭建:

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import Input, Denseclass FM_layer(Layer):def __init__(self, k, w_reg, v_reg):super().__init__()self.k = kself.w_reg = w_regself.v_reg = v_regdef build(self, input_shape):self.w0 = self.add_weight(name='w0', shape=(1,),initializer=tf.zeros_initializer(),trainable=True,)self.w = self.add_weight(name='w', shape=(input_shape[-1], 1),initializer=tf.random_normal_initializer(),trainable=True,regularizer=tf.keras.regularizers.l2(self.w_reg))self.v = self.add_weight(name='v', shape=(input_shape[-1], self.k),initializer=tf.random_normal_initializer(),trainable=True,regularizer=tf.keras.regularizers.l2(self.v_reg))def call(self, inputs, **kwargs):linear_part = tf.matmul(inputs, self.w) + self.w0   #shape:(batchsize, 1)inter_part1 = tf.pow(tf.matmul(inputs, self.v), 2)  #shape:(batchsize, self.k)inter_part2 = tf.matmul(tf.pow(inputs, 2), tf.pow(self.v, 2)) #shape:(batchsize, self.k)inter_part = 0.5*tf.reduce_sum(inter_part1 - inter_part2, axis=-1, keepdims=True) #shape:(batchsize, 1)output = linear_part + inter_partreturn outputclass Dense_layer(Layer):def __init__(self, hidden_units, output_dim, activation):super().__init__()self.hidden_units = hidden_unitsself.output_dim = output_dimself.activation = activation#全连接无需定义第一层维度self.hidden_layer = [Dense(i, activation=self.activation)for i in self.hidden_units]self.output_layer = Dense(self.output_dim, activation=None)def call(self, inputs):x = inputsfor layer in self.hidden_layer:x = layer(x)output = self.output_layer(x)return output

Model 搭建:

from layer import FM_layer, Dense_layerimport tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Embeddingclass DeepFM(Model):def __init__(self, feature_columns, k, w_reg, v_reg, hidden_units, output_dim, activation):super().__init__()self.dense_feature_columns, self.sparse_feature_columns = feature_columnsself.embed_layers = {'embed_' + str(i): Embedding(feat['feat_onehot_dim'], feat['embed_dim'])for i, feat in enumerate(self.sparse_feature_columns)}self.FM = FM_layer(k, w_reg, v_reg)self.Dense = Dense_layer(hidden_units, output_dim, activation)def call(self, inputs):dense_inputs, sparse_inputs = inputs[:, :13], inputs[:, 13:]# embeddingsparse_embed = tf.concat([self.embed_layers['embed_{}'.format(i)](sparse_inputs[:, i])for i in range(sparse_inputs.shape[1])], axis=1)x = tf.concat([dense_inputs, sparse_embed], axis=-1)fm_output = self.FM(x)dense_output = self.Dense(x)output = tf.nn.sigmoid(0.5*(fm_output + dense_output))return output

完整训练代码可在文末仓库中查看。

写在最后

下一篇预告:推荐算法(五)——谷歌经典 DCN 原理及代码实践

完整的推荐算法复现代码可参考仓库: Recommend-System-tf2.0

希望看完此文的你能够有所收获…

推荐算法(四)——经典模型 DeepFM 模型详解及代码实践相关推荐

  1. 调包侠福音!机器学习经典算法开源教程(附参数详解及代码实现)

    Datawhale 作者:赵楠.杨开漠.谢文昕.张雨 寄语:本文针对5大机器学习经典算法,梳理了其模型.策略和求解等方面的内容,同时给出了其对应sklearn的参数详解和代码实现,帮助学习者入门和巩固 ...

  2. NLP | 自然语言处理经典seq2seq网络BERT详解及代码

    2019论文:BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding BERT:用于语言理解的 ...

  3. C语言wav格式详解,代码实践

    **[注意:此文章为笔者加强自己所学知识所写,如有疏漏请多多包涵,当然如果能帮助到其他人就更好了] 上篇提到了wav文件格式,但是只是简单了解,目的是读取wav数据写入pcm. 本篇重点详细介绍wav ...

  4. 【相机标定与三维重建原理及实现】学习笔记1——相机模型数学推导详解

    目录 前言 一.小孔成像模型 二.坐标系的变换 1.世界坐标系到相机坐标系的变换(刚体变换)[xw^→xc^\boldsymbol {\hat{x_{w}}}\rightarrow \boldsymb ...

  5. [深度学习概念]·实例分割模型Mask R-CNN详解

    实例分割模型Mask R-CNN详解 基础深度学习的目标检测技术演进解析 本文转载地址 Mask R-CNN是ICCV 2017的best paper,彰显了机器学习计算机视觉领域在2017年的最新成 ...

  6. 数学建模_随机森林分类模型详解Python代码

    数学建模_随机森林分类模型详解Python代码 随机森林需要调整的参数有: (1) 决策树的个数 (2) 特征属性的个数 (3) 递归次数(即决策树的深度)''' from numpy import ...

  7. Keras深度学习实战(1)——神经网络基础与模型训练过程详解

    Keras深度学习实战(1)--神经网络基础与模型训练过程详解 0. 前言 1. 神经网络基础 1.1 简单神经网络的架构 1.2 神经网络的训练 1.3 神经网络的应用 2. 从零开始构建前向传播 ...

  8. Diffusion Model (扩散生成模型)的基本原理详解(三)Stochastic Differential Equation(SDE)

    本篇是<Diffusion Model (扩散生成模型)的基本原理详解(二)Score-Based Generative Modeling(SGM)>的续写,继续介绍有关diffusion ...

  9. 神经网络学习小记录39——MobileNetV3(small)模型的复现详解

    神经网络学习小记录39--MobileNetV3(small)模型的复现详解 学习前言 什么是MobileNetV3 代码下载 large与small的区别 MobileNetV3(small)的网络 ...

最新文章

  1. mysql autocommit_【整理】MySQL 之 autocommit
  2. GIt版本回退还不会用?轻松学会不怕失误
  3. 制造机器人的现状和发展趋势
  4. 如何使用ArchUnit测试Java项目的体系结构
  5. 几种比较好看的滚动条样式及代码
  6. 聚类方法:DBSCAN算法研究
  7. java 标识符命名规则_java语言基础之标识符和命名规则详解
  8. Eygle力荐:Oracle 19c升级文档、视频、问答集锦
  9. YOLOv3使用自己数据集——Kmeans聚类计算anchor boxes
  10. Kaggle数据增强攻略来了!不氪金实现50种语言互译
  11. C# BackgroundWorker用法详解
  12. 涂抹mysql 完整_涂抹MYSQL-跟着三思一步一步学MySQL
  13. 阈的粤语发音_新编粤语读音字典 - 粤语 | Cantonese | 白话 - 声同小语种论坛 - Powered by phpwind...
  14. 社会网络分析工具—— Gephi 或 NetworkX的简单介绍和比较(源自GPTchat)
  15. PTA-莫尔斯码(字符串,模拟)
  16. 微服务架构设计实践之七:业务架构
  17. 怎样成为一名优秀的科学家
  18. 蓝桥杯 算法提高 聪明的美食家
  19. QAxObject来操作Excel的一些命令
  20. php-SER-libs【made by 这周末在做梦】

热门文章

  1. 直连获取串口服务器ip,能够进行串口modbusRTU和以太网modbus-TCP协议转换的串口服务器,并提供好用的MODBUS调试工具-专业自动化论坛-中国工控网论坛...
  2. CAD常见问题解答!恭喜你,又学会了几个重要的CAD操作!
  3. Python之变量作用域
  4. php代码覆盖工具(2)-phpunit-支持生成覆盖率报告
  5. 2020-11-30 PMP 群内练习题 - 光环
  6. DNA 13. SCI 文章肿瘤突变负荷计算方法(TMB)
  7. Eclipse查看文件的本地历史记录
  8. python的运行方式_详解python运行三种方式
  9. 如何自建企业服务器,如何将奥维账户添加到企业服务器(自建服务器)
  10. Python 制作微信全家福