CTR 系列文章:

  1. 广告点击率(CTR)预测经典模型 GBDT + LR 理解与实践(附数据 + 代码)
  2. CTR经典模型串讲:FM / FFM / 双线性 FFM 相关推导与理解
  3. CTR深度学习模型之 DeepFM 模型解读
  4. 【CTR模型】TensorFlow2.0 的 DeepFM 实现与实战(附代码+数据)
  5. CTR 模型之 Deep & Cross (DCN) 与 xDeepFM 解读
  6. 【CTR模型】TensorFlow2.0 的 DCN(Deep & Cross Network) 实现与实战(附代码+数据)
  7. 【CTR模型】TensorFlow2.0 的 xDeepFM 实现与实战(附代码+数据)

本篇文章讲解 DCN(Deep & Cross Network) 的 tensorflow2.0 实现,并使用 Criteo 数据集的子集加以实践。如果在看本文时有所困惑,可以看看 DCN(Deep & Cross Network) 的相关理论:CTR 模型之 Deep & Cross (DCN) 与 xDeepFM 解读。

本文使用的数据下载地址于代码获取地址在文末获取。

首先了解一下 Criteo数据集,它由有39个特征,1个label列,其中以I开头的为数值型特征,以C开头的为类别特征:

可以看到数据中有缺失值需要填充,并且类别变量需要进行类别编码(onehot 编码的任务交给模型),这部分预处理的代码不详细讲了。

为了方便后面建立模型,先将特征划分为 dense 特征与 sparse 特征两个类别:

# 数值型
dense_feats = [f for f in cols if f[0] == "I"]
# 类别型
sparse_feats = [f for f in cols if f[0] == "C"]

Deep & Cross Network

DCN网络结构如下:

构造模型输入

对于 dense 特征,按下面的代码构造输入:

# 构造每个 dense 特征的输入
dense_inputs = []
for f in dense_feats:_input = Input([1], name=f)dense_inputs.append(_input)
# 将输入拼接到一起
concat_dense_inputs = Concatenate(axis=1)(dense_inputs)  # ?, 13

对于 sparse 特征,按下面的代码构造输入:

# 这里单独对每一个 sparse 特征构造输入,
# 目的是方便后面构造二阶组合特征
sparse_inputs = []
for f in sparse_feats:_input = Input([1], name=f)sparse_inputs.append(_input)# embedding size
k = 8
# 对sparse特征进行embedding
sparse_kd_embed = []
for _input in sparse_inputs:f = _input.name.split(':')[0]voc_size = total_data[f].nunique()_embed = Flatten()(Embedding(voc_size, k, embeddings_regularizer=tf.keras.regularizers.l2(0.7))(_input))sparse_kd_embed.append(_embed)# 将sparse特征拼接在一起
concat_sparse_inputs = Concatenate(axis=1)(sparse_kd_embed)

然后将所有的 dense 输入和 sparse 输入拼接到一起:

embed_inputs = Concatenate(axis=1)([concat_sparse_inputs, concat_dense_inputs])

Cross Network

终于来到最核心的 Cross 部分,其中第 l+1l + 1l+1 层的计算过程为:

xl+1=x0xlTwl+bl+xlx_{l+1} =x_0x^T_lw_l + b_l + x_l xl+1=x0xlTwl+bl+xl
此公式的实现代码如下:

def cross_layer(x0, xl):"""实现一层cross layer@param x0: 特征embeddings@param xl: 前一层的输出结果"""# 1.获取xl层的embedding sizeembed_dim = xl.shape[-1]# 2.初始化当前层的W和bw = tf.Variable(tf.random.truncated_normal(shape=(embed_dim,), stddev=0.01))b = tf.Variable(tf.zeros(shape=(embed_dim,)))# 3.计算feature crossing# 下面的reshape操作相当于将列向量转换为行向量x1_T = tf.reshape(xl, [-1, 1, embed_dim])# 行向量与列向量的乘积结果是一个标量x_lw = tf.tensordot(x1_T, w, axes=1)cross = x0 * x_lw return cross + b + xl

这个代码的执行结果是第 l+1l+1l+1 层的输出,共分为三个步骤:

  1. 获取前一层的输出结果的嵌入维度
  2. 初始化本层的网络参数
  3. 进行特征交叉

关键在于第三步,计算 x0xlTwlx_0x^T_lw_lx0xlTwl 时,如果先计算 x0xTx_0x^Tx0xT则会得到一个矩阵,为了优化内存的使用,可以先计算 xlTwlx^T_lw_lxlTwl 得到标量,然后再与 x0x_0x0 相乘得到 feature crossing。

接下来可以利用循环构建多层 crossing layer:

def build_cross_layer(x0, num_layer=3):"""构建多层cross layer@param x0: 所有特征的embeddings@param num_layers: cross net的层数"""# 初始化xl为x0xl = x0# 构建多层cross netfor i in range(num_layer):xl = cross_layer(x0, xl)return xl# cross net
cross_layer_output = build_cross_layer(embed_inputs, 3)

DNN 部分

这部分好理解,直接上代码吧:

fc_layer = Dropout(0.5)(Dense(128, activation='relu')(embed_inputs))
fc_layer = Dropout(0.3)(Dense(128, activation='relu')(fc_layer))
fc_layer_output = Dropout(0.1)(Dense(128, activation='relu')(fc_layer))

输出部分

代码如下:

stack_layer = Concatenate()([cross_layer_output, fc_layer_output])
output_layer = Dense(1, activation='sigmoid', use_bias=True)(stack_layer)

完善模型

model = Model(dense_inputs+sparse_inputs, output_layer)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["binary_crossentropy", tf.keras.metrics.AUC(name='auc')])

训练模型

train_data = total_data.loc[:500000-1]
valid_data = total_data.loc[500000:]train_dense_x = [train_data[f].values for f in dense_feats]
train_sparse_x = [train_data[f].values for f in sparse_feats]
train_label = [train_data['label'].values]val_dense_x = [valid_data[f].values for f in dense_feats]
val_sparse_x = [valid_data[f].values for f in sparse_feats]
val_label = [valid_data['label'].values]model.fit(train_dense_x+train_sparse_x, train_label, epochs=5, batch_size=128,validation_data=(val_dense_x+val_sparse_x, val_label),)

最后,本文的代码链接在:https://github.com/zxxwin/tf2_DCN 。

数据下载地址为:链接:https://pan.baidu.com/s/1Qy3yemu1LYVtj0Wn47myHQ 提取码:pv7u

参考文章:

CTR预估模型:DeepFM/Deep&Cross/xDeepFM/AutoInt代码实战与讲解

NELSONZHAO/zhihu/ctr_models/DCN

【CTR模型】TensorFlow2.0 的 DCN(Deep Cross Network) 实现与实战(附代码+数据)相关推荐

  1. 【CTR模型】TensorFlow2.0 的 xDeepFM 实现与实战(附代码+数据)

    CTR 系列文章: 广告点击率(CTR)预测经典模型 GBDT + LR 理解与实践(附数据 + 代码) CTR经典模型串讲:FM / FFM / 双线性 FFM 相关推导与理解 CTR深度学习模型之 ...

  2. 【CTR模型】TensorFlow2.0 的 DeepFM 实现与实战(附代码+数据)

    CTR 系列文章: 广告点击率(CTR)预测经典模型 GBDT + LR 理解与实践(附数据 + 代码) CTR经典模型串讲:FM / FFM / 双线性 FFM 相关推导与理解 CTR深度学习模型之 ...

  3. tensorflow2.0实现Deep Cross Network(DCN)

    文章目录 预处理 模型的构建 part1-模型的输入 part2-Cross network part3-DNN部分 part4-输出部分 模型的训练 本文基于tensorflow2.0实现 Deep ...

  4. tensorflow 保存训练loss_tensorflow2.0保存和加载模型 (tensorflow2.0官方教程翻译)

    最新版本:https://www.mashangxue123.com/tensorflow/tf2-tutorials-keras-save_and_restore_models.html 英文版本: ...

  5. ctr 平滑_CTR平滑的原理,包懂!!!附代码

    为什么需要平滑? 某个物品CTR(click-Through-Rate)定义为"物品被点击的概率".CTR是某个物品在其他条件保持不变下自身的属性.但是概率我们不好确定,能确定的是 ...

  6. CTR 模型之 Deep Cross (DCN) 与 xDeepFM 解读

    CTR 系列文章: 广告点击率(CTR)预测经典模型 GBDT + LR 理解与实践(附数据 + 代码) CTR经典模型串讲:FM / FFM / 双线性 FFM 相关推导与理解 CTR深度学习模型之 ...

  7. 谷歌、阿里们的杀手锏:三大领域,十大深度学习CTR模型演化图谱

    作者 | 王喆 来源 | 转载自知乎专栏王喆的机器学习笔记 今天我们一起回顾一下近3年来的所有主流深度学习CTR模型,也是我工作之余的知识总结,希望能帮大家梳理推荐系统.计算广告领域在深度学习方面的前 ...

  8. 深度学习CTR模型粗略记录

    深度学习CTR模型粗略记录 RoadMap FM:Factorization Machines DNN:Embedding+MLP WND:Wide & Deep Learning for R ...

  9. 谷歌、阿里们的杀手锏:3大领域,10大深度学习CTR模型演化图谱(附论文)

    来源:知乎 作者:王喆 本文约4000字,建议阅读8分钟. 本文为你介绍近3年来的所有主流深度学习CTR模型. 今天我们一起回顾一下近3年来的所有主流深度学习CTR模型,也是我工作之余的知识总结,希望 ...

最新文章

  1. mysql 写入400_MySQL5.7运行CPU达百分之400处理方案
  2. 一份值得收藏的,互联网电商购物车架构演变案例
  3. 【C 语言】数组与指针操作 ( 数组符号 [] 与 指针 * 符号 的 联系 与 区别 | 数组符号 [] 与 指针 * 符号 使用效果 基本等价 | 数组首地址 与 指针 本质区别 )
  4. 2021-02-27 永磁同步电机 自抗扰控制 PI调节器 矢量控制 SVPWM
  5. 为什么数组是从0开始的
  6. 自动化测试在CI CD管道中的作用
  7. (并查集) Wireless Network --POJ --2236
  8. Java多线程系列--“JUC原子类”
  9. .NET重要技术思考
  10. 大牛直播SDK(android/iOS部分)最新功能列表
  11. 线段覆盖 java,南邮 OJ 1407 线段覆盖
  12. 基于Springboot外卖系统13:实现文件上传下载模块
  13. 语音合成和语音识别资料查询说明
  14. 【CF869E】The Untended Antiquity(哈希+二维树状数组)
  15. 新基建大热,关服务器什么事?
  16. “不务正业”斗地主?AI青年查道琛想做“被人看到”的研究
  17. html页面其中有添加员工的,编写一个添加员工信息的HTML页面,当用户点击添加按钮,请求AddEmpServlet,实现将用户提交的员工基本信息返回给客户端显示出来。...
  18. python移动文件到另一个文件夹若有同名文件更改文件名_Python 创建、复制、移动、删除和重命名文件和文件夹...
  19. 数学分析:一元实函数论----一元实函数基础
  20. flutter 中使用 WebView加载H5页面异常net:ERR_CLEARTEXT_NOT_PERMITTED

热门文章

  1. DataFrame挑选其中两列,带列名
  2. 消除ubuntu16.04自带的alt快捷键
  3. html5做一个皮卡丘,用css实现一个皮卡丘
  4. ie8 object param没有效果_如何用php实现分页效果
  5. RedHat虚拟机安装VMware Tools
  6. JavaScript基础总结(五)——Math对象
  7. TCP/IP之(四)Delay ack 和 Nagle算法
  8. listview滚动到底部
  9. Activity的常用方法和生命周期
  10. [Winform]一个简单的账户管理工具