在计算广告和推荐系统中,CTR预估一直是一个核心问题。无论在工业界还是学术界都是一个热点研究问题,近年来也有若干相关的算法竞赛陆续举办。本文介绍一个使用PyTorch编写的深度学习的点击率预测算法库DeepCTR-Torch,具有简洁易用、模块化和可扩展的优点,非常适合初学者快速入门学习。

(本文作者:沈伟臣,阿里巴巴算法工程师)

点击率预估问题

点击率预估问题通常形式化描述为给定用户,物料,上下文的情况下,计算用户点击物料的概率即:pCTR = p(click=1|user,item,context)

简单来说,在广告业务中使用pCTR来计算广告的预期收益,在推荐业务中通过使用pCTR来确定候选物料的一个排序列表。

DeepCTR-Torch

人们通过构造有效的组合特征和使用复杂的模型来学习数据中的模式来提升效果。基于因子分解机的方法,可以通过向量乘积的形式学习特征的交互,并且泛化到那些没有出现过的组合上。

随着深度神经网络在若干领域的巨大发展,近年来研究者也提出了若干基于深度学习的分解模型来同时学习低阶和高阶的特征交互,如:

PNN,Wide&Deep,DeepFM,Attentional FM,Neural FM,DCN,xDeepFM,AutoInt,FiBiNET

以及基于用户历史行为序列建模的DIN,DIEN,DSIN等。

对于刚接触这方面的同学来说,可能对这些方法的细节还不太了解,虽然网上有很多介绍,但是代码却没有统一的形式,且当想要迁移到自己的数据集进行实验时也很不方便。本文介绍的一个使用PyTorch实现的基于深度学习的CTR模型包DeepCTR-PyTorch,无论是使用还是学习都很方便。

DeepCTR-PyTorch是一个简洁易用、模块化可扩展的基于深度学习的CTR模型包。除了近年来主流模型外,还包括许多可用于轻松构建您自己的自定义模型的核心组件层。

您简单的通过model.fit()model.predict()来使用这些复杂的模型执行训练和预测任务,以及在通过模型初始化列表的device参数来指定运行在cpu还是gpu上。

安装与使用

  • 安装

pip install -U deepctr-torch
  • 使用例子

下面用一个简单的例子告诉大家,如何快速的应用一个基于深度学习的CTR模型,代码地址在:

https://github.com/shenweichen/DeepCTR-Torch/blob/master/examples/run_classification_criteo.py。

The Criteo Display Ads datasetkaggle上的一个CTR预估竞赛数据集。里面包含13个数值特征I1-I13和26个类别特征C1-C26

# -*- coding: utf-8 -*-
# 使用pandas 读取上面介绍的数据,并进行简单的缺失值填充
import pandas as pd
from sklearn.metrics import log_loss, roc_auc_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, MinMaxScaler
from deepctr_torch.models import *
from deepctr_torch.inputs import SparseFeat, DenseFeat, get_fixlen_feature_names
import torch# 使用pandas 读取上面介绍的数据,并进行简单的缺失值填充
data = pd.read_csv('./criteo_sample.txt')
# 上面的数据在:https://github.com/shenweichen/DeepCTR-Torch/blob/master/examples/criteo_sample.txtsparse_features = ['C' + str(i) for i in range(1, 27)]
dense_features = ['I' + str(i) for i in range(1, 14)]data[sparse_features] = data[sparse_features].fillna('-1', )
data[dense_features] = data[dense_features].fillna(0, )
target = ['label']#这里我们需要对特征进行一些预处理,对于类别特征,我们使用LabelEncoder重新编码(或者哈希编码),对于数值特征使用MinMaxScaler压缩到0~1之间。for feat in sparse_features:lbe = LabelEncoder()data[feat] = lbe.fit_transform(data[feat])
mms = MinMaxScaler(feature_range=(0, 1))
data[dense_features] = mms.fit_transform(data[dense_features])# 这里是比较关键的一步,因为我们需要对类别特征进行Embedding,所以需要告诉模型每一个特征组有多少个embbedding向量,我们通过pandas的nunique()方法统计。fixlen_feature_columns = [SparseFeat(feat, data[feat].nunique())for feat in sparse_features] + [DenseFeat(feat, 1,)for feat in dense_features]dnn_feature_columns = fixlen_feature_columns
linear_feature_columns = fixlen_feature_columnsfixlen_feature_names = get_fixlen_feature_names(linear_feature_columns + dnn_feature_columns)#最后,我们按照上一步生成的特征列拼接数据train, test = train_test_split(data, test_size=0.2)
train_model_input = [train[name] for name in fixlen_feature_names]
test_model_input = [test[name] for name in fixlen_feature_names]# 检查是否可以使用gpudevice = 'cpu'
use_cuda = True
if use_cuda and torch.cuda.is_available():print('cuda ready...')device = 'cuda:0'# 初始化模型,进行训练和预测model = DeepFM(linear_feature_columns=linear_feature_columns, dnn_feature_columns=dnn_feature_columns, task='binary',l2_reg_embedding=1e-5, device=device)model.compile("adagrad", "binary_crossentropy",metrics=["binary_crossentropy", "auc"],)
model.fit(train_model_input, train[target].values,batch_size=256, epochs=10, validation_split=0.2, verbose=2)pred_ans = model.predict(test_model_input, 256)
print("")
print("test LogLoss", round(log_loss(test[target].values, pred_ans), 4))
print("test AUC", round(roc_auc_score(test[target].values, pred_ans), 4))

相关资料

  • DeepCTR-Torch代码主页

https://github.com/shenweichen/DeepCTR-Torch

  • DeepCTR-Torch文档:

    https://deepctr-torch.readthedocs.io/en/latest/index.html

  • DeepCTR(tensorflow版)代码主页 :

    https://github.com/shenweichen/DeepCTR

  • DeepCTR(tensorflow版)文档:

    https://deepctr-doc.readthedocs.io/en/latest/index.html

作者简介

沈伟臣,浙江大学计算机硕士,阿里巴巴集团算法工程师

沈伟臣曾经参与了《DeepLearning.ai深度学习》笔记的编写。

github主页:

https://github.com/shenweichen

知乎专栏 浅梦的学习笔记

https://zhuanlan.zhihu.com/weichennote

邮箱 wcshen1994@163.com

本站简介↓↓↓ 

“机器学习初学者”是帮助人工智能爱好者入门的个人公众号(创始人:黄海广)

初学者入门的道路上,最需要的是“雪中送炭”,而不是“锦上添花”。

本站的知识星球(黄博的机器学习圈子)ID:92416895

目前在机器学习方向的知识星球排名第一(上图二维码)

往期精彩回顾

  • 那些年做的学术公益-你不是一个人在战斗

  • 良心推荐:机器学习入门资料汇总及学习建议

  • 黄海广博士的github镜像下载(机器学习及深度学习笔记及资源)

  • 机器学习小抄-(像背托福单词一样理解机器学习)

  • 首发:深度学习入门宝典-《python深度学习》原文代码中文注释版及电子书

  • 机器学习必备宝典-《统计学习方法》的python代码实现、电子书及课件

  • 重磅 | 完备的 AI 学习路线,最详细的资源整理!

  • 图解word2vec(原文翻译)

  • 机器学习的相关数学资料下载

备注:加入本站微信群或者qq群,请回复“加群

【原创】推荐广告入门:DeepCTR-Torch,基于深度学习的CTR预测算法库相关推荐

  1. DeepCTR-Torch,基于深度学习的CTR预测算法库

    点击率预估问题 点击率预估问题通常形式化描述为给定用户,物料,上下文的情况下,计算用户点击物料的概率即:pCTR = p(click=1|user,item,context). 简单来说,在广告业务中 ...

  2. CTR预估系列:DeepCTR 一个基于深度学习的CTR模型包

    在计算广告和推荐系统中,CTR预估一直是一个核心问题.无论在工业界还是学术界都是一个热点研究问题,近年来也有若干相关的算法竞赛.本文介绍一个基于深度学习的CTR模型包DeepCTR,具有简洁易用.模块 ...

  3. DeepCTR:易用可扩展的深度学习点击率预测算法库

    本文首发于知乎专栏:https://zhuanlan.zhihu.com/p/53231955 这个项目主要是对目前的一些基于深度学习的点击率预测算法进行了实现,并且对外提供了一致的调用接口. 关于每 ...

  4. 综述 | 基于深度学习的目标检测算法

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:计算机视觉life 导读:目标检测(Object Det ...

  5. 基于深度学习的目标检测算法综述(一)

    基于深度学习的目标检测算法综述(一) 基于深度学习的目标检测算法综述(二) 基于深度学习的目标检测算法综述(三) 本文内容原创,作者:美图云视觉技术部 检测团队,转载请注明出处 目标检测(Object ...

  6. 学习笔记之——基于深度学习的目标检测算法

    国庆假期闲来无事~又正好打算入门基于深度学习的视觉检测领域,就利用这个时间来写一份学习的博文~本博文主要是本人的学习笔记与调研报告(不涉及商业用途),博文的部分来自我团队的几位成员的调研报告(由于隐私 ...

  7. 基于深度学习的股票预测(完整版,有代码)

    基于深度学习的股票预测 数据获取 数据转换 LSTM模型搭建 训练模型 预测结果 数据获取 采用tushare的数据接口(不知道tushare的筒子们自行百度一下,简而言之其免费提供各类金融数据 , ...

  8. 基于深度学习的目标检测算法综述(二)

    转自:https://zhuanlan.zhihu.com/p/40020809 基于深度学习的目标检测算法综述(一) 基于深度学习的目标检测算法综述(二) 基于深度学习的目标检测算法综述(三) 本文 ...

  9. 病虫害模型算法_基于深度学习的目标检测算法综述

    sigai 基于深度学习的目标检测算法综述 导言 目标检测的任务是找出图像中所有感兴趣的目标(物体),确定它们的位置和大小,是机器视觉领域的核心问题之一.由于各类物体有不同的外观,形状,姿态,加上成像 ...

最新文章

  1. Py中enumerate方法【转载】
  2. python 反转列表的3种方式
  3. pinpoint全链路监控系统安装配置
  4. android Notification的使用
  5. COCI CONTEST #3 29.11.2014 KAMIONI
  6. Scrum 冲刺 第一日
  7. 【移植Linux 3.4.2内核第三步】从0制作支持新内核的文件系统
  8. Python带参数的装饰器
  9. 浏览器和驱动版本对应关系
  10. Hadoop tutorial - 3 Hello MapReduce- 2015-3-30
  11. Sublime Text 3 如何配置Python环境及安装插件?
  12. Android定位地图导航——基于百度地图,实现自定义图标绘制并点击时弹出泡泡...
  13. Ubuntu 安装netstat网络工具
  14. (转)学习打印机,了解打印命令 .
  15. Flowchart流程图示例
  16. 光交删zone注意事项
  17. PCIE——第5章——Montevina 的 MCH 和 ICH
  18. 110配线架打法图解_「干货」图文并茂教会你110语音配线架线缆打法
  19. 杭漂多年是时候画个句号呢
  20. 《 初学 》 html5 制作简单时钟

热门文章

  1. ORA-12154/ORA-12560 可以尝试的解决办法
  2. redhat_yum install
  3. mongo connections url string 的问题
  4. JSBing-js自动绑定C++
  5. Nginx + PHP CGI的fix_pathinfo安全漏洞
  6. 在JavaScript中使用json.js:访问JSON编码的某个值
  7. [UML]UML系列——状态机图statechart diagram
  8. PHP起点 - PHP常量
  9. 你想的到想不到的 javascript 应用小技巧方法
  10. BLOG地址变更--博客园启用二级域名