一、前言

模型针对W&D的wide部分进行了改进, 因为Wide部分有一个不足就是需要人工进行特征的组合筛选, 过程繁琐且需要经验, 2阶的FM模型在线性的时间复杂度中自动进行特征交互,但是这些特征交互的表现能力并不够,并且随着阶数的上升,模型复杂度会大幅度提高。于是乎,作者用一个Cross Network替换掉了Wide部分,来自动进行特征之间的交叉,并且网络的时间和空间复杂度都是线性的。 通过与Deep部分相结合,构成了深度交叉网络(Deep & Cross Network),简称DCN。

二、Deep&Cross模型

下面就来看一下DCN的结构:模型的结构非常简洁,从下往上依次为:Embedding和Stacking层、Cross网络层与Deep网络层并列、输出合并层,得到最终的预测结果。

(1)Embedding and stacking layer

这里的作用依然是把稀疏离散的类别型特征变成低维密集型。

运用起来就是在训练得到的Embedding参数矩阵中找到属于当前样本对应的Embedding向量

最后,该层需要将所有的密集型特征与通过embedding转换后的特征进行联合(Stacking)

(2) Cross Network

设计该网络的目的是增加特征之间的交互力度。 交叉网络由多个交叉层组成

交叉层的操作的二阶部分非常类似PNN提到的外积操作, 在此基础上增加了外积操作的权重向量w_{1}, 以及原输入向量x_{1}和偏置向量b_{1}。 交叉层的可视化如下:

代码实现部分对应

 # x是(None, dim)的形状, 先扩展一个维度到(None, dim, 1)x_0 = torch.unsqueeze(x, dim=2)x = x_0.clone()  #32*221*1xT = x_0.clone().permute((0, 2, 1))  # (None, 1, dim)  32*1*221for i in range(self.layer_num):x = torch.matmul(torch.bmm(x_0, xT), self.cross_weights[i]) + self.cross_bias[i] + x  # (None, dim, 1)32*221*1            bmm(32*221*1,32*1*221), W=221*1, b=221*1xT = x.clone().permute((0, 2, 1))  # (None, 1, dim)

x_{0} 是最开始的输入,一直保持不变,x_{1}^{T}是不断更新的,与权重+偏差 做内积

x_{0}并且在每一层均保留了输入向量, 因此输入和输出之间的变化不会特别明显。

(3)Deep work

全连接层原理一样。

(4)组合层

负责将两个网络的输出进行拼接, 并且通过简单的Logistics回归完成最后的预测:

最后二分类的损失函数依然是交叉熵损失:

其核心部分就是Cross Network, 这个可以进行特征的自动交叉, 避免了更多基于业务理解的人工特征组合。 该模型相比于W&D,Cross部分表达能力更强, 使得模型具备了更强的非线性学习能力。

三、Deep&Cross模型的pytorch实现

(1)DNN网络

class Dnn(nn.Module):"""Dnn part"""def __init__(self, hidden_units, dropout=0.):"""hidden_units: 列表, 每个元素表示每一层的神经单元个数, 比如[256, 128, 64], 两层网络, 第一层神经单元128, 第二层64, 第一个维度是输入维度dropout: 失活率"""super(Dnn, self).__init__()self.dnn_network = nn.ModuleList([nn.Linear(layer[0], layer[1]) for layer in list(zip(hidden_units[:-1], hidden_units[1:]))])                       #221*128*64self.dropout = nn.Dropout(p=dropout)def forward(self, x):for linear in self.dnn_network:x = linear(x)x = F.relu(x)x = self.dropout(x)return x

(2)cross_network

class CrossNetwork(nn.Module):"""Cross Network"""def __init__(self, layer_num, input_dim):super(CrossNetwork, self).__init__()self.layer_num = layer_num# 定义网络层的参数   221*3       三个self.cross_weights = nn.ParameterList([nn.Parameter(torch.rand(input_dim, 1))for i in range(self.layer_num)])self.cross_bias = nn.ParameterList([nn.Parameter(torch.rand(input_dim, 1))for i in range(self.layer_num)])def forward(self, x):# x是(None, dim)的形状, 先扩展一个维度到(None, dim, 1)x_0 = torch.unsqueeze(x, dim=2)x = x_0.clone()  #32*221*1xT = x_0.clone().permute((0, 2, 1))  # (None, 1, dim)  32*1*221for i in range(self.layer_num):x = torch.matmul(torch.bmm(x_0, xT), self.cross_weights[i]) + self.cross_bias[i] + x  # (None, dim, 1)32*221*1            bmm(32*221*1,32*1*221), W=221*1, b=221*1xT = x.clone().permute((0, 2, 1))  # (None, 1, dim)x = torch.squeeze(x)  # (None, dim) 32*221  再降维return x

(3)DCN网络

class DCN(nn.Module):def __init__(self, feature_columns, hidden_units, layer_num, dnn_dropout=0.):super(DCN, self).__init__()self.dense_feature_cols, self.sparse_feature_cols = feature_columns# embeddingself.embed_layers = nn.ModuleDict({'embed_' + str(i): nn.Embedding(num_embeddings=feat['feat_num'], embedding_dim=feat['embed_dim'])for i, feat in enumerate(self.sparse_feature_cols)})hidden_units.insert(0,len(self.dense_feature_cols) + len(self.sparse_feature_cols) * self.sparse_feature_cols[0]['embed_dim'])self.dnn_network = Dnn(hidden_units)self.cross_network = CrossNetwork(layer_num, hidden_units[0])  # layer_num是交叉网络的层数, hidden_units[0]表示输入的整体维度大小self.final_linear = nn.Linear(hidden_units[-1] + hidden_units[0], 1)def forward(self, x):dense_input, sparse_inputs = x[:, :len(self.dense_feature_cols)], x[:, len(self.dense_feature_cols):]  #32*13 32*26sparse_inputs = sparse_inputs.long()sparse_embeds = [self.embed_layers['embed_' + str(i)](sparse_inputs[:, i]) for i inrange(sparse_inputs.shape[1])]sparse_embeds = torch.cat(sparse_embeds, axis=-1)  #32*208    208=(26*8)x = torch.cat([sparse_embeds, dense_input], axis=-1) #32*221# cross Networkcross_out = self.cross_network(x)  #32*221# Deep Networkdeep_out = self.dnn_network(x)  #32*32#  Concatenatetotal_x = torch.cat([cross_out, deep_out], axis=-1)  #32*253# outoutputs = F.sigmoid(self.final_linear(total_x))return outputs

模型训练

# 模型的相关设置
def auc(y_pred, y_true):pred = y_pred.datay = y_true.datareturn roc_auc_score(y, pred)loss_func = nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
metric_func = auc
metric_name = 'auc'# 脚本训练风格
epochs = 10
log_step_freq = 10dfhistory = pd.DataFrame(columns=['epoch', 'loss', metric_name, 'val_loss', 'val_' + metric_name])print('start_training.........')
nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print('========' * 8 + '%s' % nowtime)for epoch in range(1, epochs + 1):# 训练阶段model.train()loss_sum = 0.0metric_sum = 0.0step = 1for step, (features, labels) in enumerate(dl_train, 1):# 梯度清零optimizer.zero_grad()# 正向传播predictions = model(features).squeeze()loss = loss_func(predictions, labels)try:metric = metric_func(predictions, labels)except ValueError:pass# 反向传播loss.backward()optimizer.step()# 打印batch级别日志loss_sum += loss.item()metric_sum += metric.item()if step % log_step_freq == 0:print(("[step=%d] loss: %.3f, " + metric_name + ": %.3f") % (step, loss_sum / step, metric_sum / step));# 验证阶段model.eval()val_loss_sum = 0.0val_metric_sum = 0.0val_step = 1for val_step, (features, labels) in enumerate(dl_val, 1):with torch.no_grad():predictions = model(features).squeeze()val_loss = loss_func(predictions, labels)try:val_metric = metric_func(predictions, labels)except ValueError:passval_loss_sum += val_loss.item()val_metric_sum += val_metric.item()# 记录日志info = (epoch, loss_sum / step, metric_sum / step, val_loss_sum / val_step, val_metric_sum / val_step)dfhistory.loc[epoch - 1] = info# 打印日志print(("\nEPOCH=%d, loss=%.3f, " + metric_name + " = %.3f, val_loss=%.3f, " + "val_" + metric_name + " = %.3f") % info)nowtime = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")print('\n' + '==========' * 8 + '%s' % nowtime)print('Finished Training')

测试集的预测

y_pred_probs = model(torch.tensor(test_x).float())
y_pred = torch.where(y_pred_probs>0.5, torch.ones_like(y_pred_probs), torch.zeros_like(y_pred_probs))print(y_pred.data)

总结

W&D开启了组合模型的探索之后,DCN也是替换掉了wide部分, 后面又出现了几个FM的演化版本模型, 比如FNN, DeepFM和NFM, 后面也会陆续整理!!

DeepCross(DCN)模型及torch实现相关推荐

  1. PyTorch:将模型转换为torch.jit.ScriptModule

    参见上一篇 C++调用PyTorch模型 import torch import torchvision# An instance of your model. model = torchvision ...

  2. dcn和dcnv2 torch

    dcn也叫可变形卷积 c++ cu版编译参考我的另一篇博客: DCNv2 windows编译 2021ok_jacke121的专栏-CSDN博客 pytorch版,输入x,参考我另一篇博客: 可变形卷 ...

  3. 60分钟吃掉嘎嘣脆的DeepCross模型

    谷歌在CTR预估和推荐排序模型方面有3篇相对重要的文章. 第1篇是2016年的Deep&Wide,第2篇是2017年的Deep&Cross(DCN),第3篇是2020年的DCN-V2. ...

  4. 推荐系统XDeepFM模型--DeepFM和DCN升级版

    xDeepFM模型 目标: 引言: xDeepFM模型介绍: Compressed Interaction Network(CIN) xDeepFM 复杂度分析 Experiment 产出: Conc ...

  5. Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()

    Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...

  6. torch编程-加载预训练权重-模型冻结-解耦-梯度不反传

    1)加载预训练权重 net = torchvision.models.resnet50(pretrained=False) # 构建模型 pretrained_model = torch.load(p ...

  7. Diffusion 扩散模型(DDPM)详解及torch复现

    文章目录 torch复现 第1步:正向过程=噪声调度器 Step 2: 反向传播 = U-Net Step 3: 损失函数 采样 Training 我公众号文章目录综述: https://wanggu ...

  8. Pytorch两种模型保存方式

    以字典方式保存,更容易解析和可视化 Pytorch两种模型保存方式 大黑_7e1b关注 2019.02.12 17:49:35字数 13阅读 5,907 只保存模型参数 # 保存 torch.save ...

  9. TVM部署预定义模型

    TVM部署预定义模型 本文通过深度学习框架量化的模型加载到TVM中.预量化的模型导入是在TVM中提供的量化支持之一. 本文演示如何加载和运行由PyTorch,MXNet和TFLite量化的模型.加载后 ...

  10. 在C++中加载TorchScript模型

    在C++中加载TorchScript模型 本教程已更新为可与PyTorch 1.2一起使用 顾名思义,PyTorch的主要接口是Python编程语言.尽管Python是合适于许多需要动态性和易于迭代的 ...

最新文章

  1. 中国和英国的旅行的对比
  2. 连招 横版 flash 游戏_街机游戏中的无限连究竟有多变态?有种对决叫作没开始就结束了!...
  3. 笔记本电脑处理器_高通提示低成本5G芯片更强大的笔记本电脑处理器
  4. 华为机试HJ90:合法IP
  5. css border流光效果
  6. 100个java项目_我如何在100天内建立​​100个项目
  7. 计算机毕业设计Django毕业设计论文源代码服装展示平台电商商城购物系统
  8. 根据观测时间,经纬度,求太阳高度角
  9. 正点原子IMX6ULL开发板禁用出厂QT界面
  10. STM32+DAC8830驱动程序
  11. 计算机网络技术期末论文,计算机网络技术专业论文题目 计算机网络技术论文题目怎么定...
  12. 一文集齐几大硬核Linux技术公众号,不是精品不推荐
  13. Python-Flask入门,静态文件、页面跳转、错误信息、动态网页模板
  14. Hystrix之四种触发fallback情况
  15. 《遥感原理与应用》总结——遥感图像自动识别分类
  16. Oracle Data Recovery Advisor(DRA)
  17. sony直营店可以享受到什么体验和服务?
  18. 工程师不可不知:解决EMI之传导干扰的八大对策
  19. linux网卡e1000下载,linux安装主板自带网卡e1000 步骤
  20. 第九章 DLL文件 windows程序设计 王艳平版

热门文章

  1. 智商情商哪个重要_智商情商哪个更重要 辩论赛
  2. 如何从零开始搭建公司自动化测试框架?
  3. MATLAB神经网络工具箱的使用——Neural Net Fitting app
  4. Drug Discov. Today | 药物发现中的先进机器学习技术
  5. python pil_Python PIL composite()用法及代码示例
  6. 浏览器的邮件html编辑器无效,eWebEditor 辑器按钮失效 IE8下eWebEditor编辑器无法使用的解决方法...
  7. 2017年第22届中国国际涂料、油墨及粘合剂展览会会刊(参展商名录)
  8. 企业微信私聊安全吗?
  9. python的print与sys.stdout
  10. zhu的Oracle数据库笔记