CTR --- NFM论文阅读笔记,及tf2复现
文章目录
- 提出动机
- 结构
- 特征交叉池化层
- 特点
- 工程化结构
- tf2实现
提出动机
结构
特征交叉池化层
特点
工程化结构
tf2实现
# coding:utf-8
# @Email: wangguisen@infinities.com.cn
# @Time: 2022/1/4 11:12 上午
# @File: ctr_NFM.py
'''
NFM
'''
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
from sklearn.model_selection import train_test_split
import yamlfrom tools import *class NFM(Model):def __init__(self, spare_feature_columns, dense_feature_columns, hidden_units, output_dim, activation, droup_out, w_reg):super(NFM, self).__init__()self.spare_feature_columns = spare_feature_columnsself.dense_feature_columns = dense_feature_columnsself.spare_shape = len(spare_feature_columns)self.w_reg = w_reg# embeddingself.embedding_layer = {'embed_layer{}'.format(i): layers.Embedding(feat['vocabulary_size'], feat['embed_dim'])for i, feat in enumerate(self.spare_feature_columns)}# dnnself.DNN = tf.keras.Sequential()for hidden in hidden_units:self.DNN.add(layers.Dense(hidden))self.DNN.add(layers.BatchNormalization())self.DNN.add(layers.Activation(activation))self.DNN.add(layers.Dropout(droup_out))self.DNN.add(layers.Dense(output_dim, activation=None)) # output_dim=1def build(self, input_shape):self.b = self.add_weight(name='b', shape=(1,), initializer=tf.zeros_initializer(), trainable=True, )self.w = self.add_weight(name='w', shape=(self.spare_shape, 1), initializer=tf.random_normal_initializer(), trainable=True, regularizer=tf.keras.regularizers.l2(self.w_reg))def call(self, inputs, training=None, mask=None):# dense_inputs: 数值特征,13维# sparse_inputs: 类别特征,26维dense_inputs, sparse_inputs = inputs[:, :13], inputs[:, 13:]# # LR part# linear_part = tf.matmul(sparse_inputs, self.w) + self.b # (batchsize, 1)# embeddingsparse_embeds = [self.embedding_layer['embed_layer{}'.format(i)](sparse_inputs[:, i]) for i in range(sparse_inputs.shape[1])] # listsparse_embed = tf.convert_to_tensor(sparse_embeds) # (26, batchsize, embed_dim)sparse_embed = tf.transpose(sparse_embed, [1, 0, 2]) # (batchsize, 26, embed_dim)# Bi-Interaction Layerbi_layer = 0.5 * (tf.pow(tf.reduce_sum(sparse_embed, axis=1), 2) - tf.reduce_sum(tf.pow(sparse_embed, 2), axis=1)) # (batchsize, embed_dim)# bi + densex = tf.concat([dense_inputs, bi_layer], axis=1) # (batchsize, embed_dim + 13)# hidden layerhidden_part = self.DNN(x) # (batchsize, 1)# output = tf.nn.sigmoid(linear_part + hidden_part)output = tf.nn.sigmoid(hidden_part)return outputif __name__ == '__main__':with open('config.yaml', 'r') as f:config = yaml.Loader(f).get_data()data = pd.read_csv('../../data/criteo_sampled_data_OK.csv')data_X = data.iloc[:, 1:]data_y = data['label'].values# I1-I13:总共 13 列数值型特征# C1-C26:共有 26 列类别型特征dense_features = ['I' + str(i) for i in range(1, 14)]sparse_features = ['C' + str(i) for i in range(1, 27)]dense_feature_columns = [denseFeature(feat) for feat in dense_features]spare_feature_columns = [sparseFeature(feat, data_X[feat].nunique(), config['NFM']['embed_dim']) for feat insparse_features]tmp_X, test_X, tmp_y, test_y = train_test_split(data_X, data_y, test_size=0.2, random_state=42, stratify=data_y)train_X, val_X, train_y, val_y = train_test_split(tmp_X, tmp_y, test_size=0.1, random_state=42, stratify=tmp_y)model = NFM(spare_feature_columns=spare_feature_columns,dense_feature_columns=dense_feature_columns,hidden_units=config['NFM']['hidden_units'],output_dim=config['NFM']['output_dim'],activation=config['NFM']['activation'],droup_out=config['NFM']['droup_out'],w_reg=config['NFM']['w_reg'],)adam = optimizers.Adam(lr=config['train']['adam_lr'], beta_1=0.95, beta_2=0.96,decay=config['train']['adam_lr'] / config['train']['epochs'])model.compile(optimizer=adam,loss='binary_crossentropy',metrics=[metrics.AUC(), metrics.Precision(), metrics.Recall()])model.fit(train_X.values, train_y,validation_data=(val_X.values, val_y),batch_size=config['train']['batch_size'],epochs=config['train']['epochs'],verbose=1,)model.summary()
Epoch 1/3
216/216 [==============================] - 9s 40ms/step - loss: 303.4375 - auc: 0.5100 - precision: 0.2705 - recall: 0.2608 - val_loss: 6.0378 - val_auc: 0.5098 - val_precision: 0.3144 - val_recall: 0.0579
Epoch 2/3
216/216 [==============================] - 8s 38ms/step - loss: 6.3794 - auc: 0.5833 - precision: 0.3619 - recall: 0.3607 - val_loss: 1.0595 - val_auc: 0.6949 - val_precision: 0.4574 - val_recall: 0.4347
Epoch 3/3
216/216 [==============================] - 8s 39ms/step - loss: 3.1246 - auc: 0.6743 - precision: 0.4674 - recall: 0.4705 - val_loss: 2.2340 - val_auc: 0.6381 - val_precision: 0.6591 - val_recall: 0.1596
CTR --- NFM论文阅读笔记,及tf2复现相关推荐
- CTR --- DIEN论文阅读笔记,及tf2复现
文章目录 前言 DIN和DIEN的总体思路 DIN对兴趣建模的缺点 DIEN对兴趣建模的思路 结构 行为序列层(Behavior Layer) 兴趣抽取层(Interest Extractor Lay ...
- CTR --- AFM论文阅读笔记,及tf2复现
文章目录 注意力机制 提出动机 解决方案 举例理解 结构 基于注意力机制的池化层 综合上述注意力机制的计算理解 tf2实现 注意力机制 提出动机 解决方案 把注意力机制引到里面去,来学习不同交叉特征对 ...
- 2019 sample-free(样本不平衡)目标检测论文阅读笔记
点击我爱计算机视觉标星,更快获取CVML新技术 本文转载自知乎,已获作者同意转载,请勿二次转载 (原文地址:https://zhuanlan.zhihu.com/p/100052168) 背景 < ...
- 论文阅读笔记:看完也许能进一步了解Batch Normalization
提示:阅读论文时进行相关思想.结构.优缺点,内容进行提炼和记录,论文和相关引用会标明出处. 文章目录 前言 介绍 BN之前的一些减少Covariate Shift的方法 BN算法描述 Batch No ...
- 论文阅读笔记(2):Learning a Self-Expressive Network for Subspace Clustering,SENet,用于大规模子空间聚类的自表达网络
论文阅读笔记(2):Learning a Self-Expressive Network for Subspace Clustering. SENet--用于大规模子空间聚类的自表达网络 前言 摘要 ...
- 点云配准论文阅读笔记--Comparing ICP variants on real-world data sets
目录 写在前面 点云配准系列 摘要 1引言(Introduction) 2 相关研究(Related work) 3方法( Method) 3.1输入数据的敏感性 3.2评价指标 3.3协议 4 模块 ...
- Designing an optimal contest(博弈论+机制设计) 论文阅读笔记
Designing an optimal contest 论文阅读笔记 一.基本信息 二.文章摘要 三.背景介绍 四.核心模型 五.核心结论 六.总结展望 一.基本信息 题目:设计一个最优竞赛 作者: ...
- 【论文阅读笔记】Myers的O(ND)时间复杂度的高效的diff算法
前言 之前咱们三个同学做了个Simple-SCM,我负责那个Merge模块,也就是对两个不同分支的代码进行合并.当时为了简便起见,遇到文件冲突的时候,就直接按照文件的更改日期来存储,直接把更改日期较新 ...
- 论文阅读笔记——基于CNN-GAP可解释性模型的软件源码漏洞检测方法
本论文相关内容 论文下载地址--Engineering Village 论文阅读笔记--基于CNN-GAP可解释性模型的软件源码漏洞检测方法 文章目录 本论文相关内容 前言 基于CNN-GAP可解释性 ...
最新文章
- ios - 使用@try、catch捕获异常:
- QT的QProgressDialog类的使用
- 学习笔记26_MVC前台强类型参数
- 索尼电脑娱乐(SCE)公司周一宣布
- oracle数据库初始化参数分类,oracle初始化参数设置
- Docker学习篇(一)Docker概述、安装和常用命令
- web电商、商城pc端、商城、购物车、订单、线上支付、web商城、pc商城、登录注册、人工客服、收货地址、现金券、优惠券、礼品卡、团购订单、评价晒单、消息通知、电子产品商城、手机商城、电脑商城
- 整理一些js中常见的问题
- error while loading shared libraries: libtinfo.so.5
- 编程必会的100个代码大全,建议收藏
- 怎么在ASP.NET中引用JS文件
- modbus功能码04实例_20种PLC元件编号和Modbus编号地址对应表
- MySQL over函数的用法
- 基于Arduino的密码+指纹智能锁(LCD1602显示器)
- K210频谱显示桌面摆件(Sipeed Maix Dock)
- JavaScript2谁刚开始学习应该知道4最佳实践文章(翻译)
- HTML网页内嵌入网页
- VUE 爬坑之旅 -- vue 项目中将简体转换为繁体
- 讲的真详细!花三分钟看完这篇文章你就懂了
- MATLAB 代码资料大全
热门文章
- python数据爬虫——如何爬取二级页面(三)
- [ 重 新 预 习 ] Node.js搭建服务
- Python 动态加载并下载梨视频短视频
- iOS开发-类似微信录音上滑取消功能
- 锦城学院计算机系考研,考研心得分享
- NOIP.COM账号注册以及密码找回
- ue4 后期处理景深_Unreal Engine4 后期处理特效 VOL1
- 【附源码】计算机毕业设计JAVA医院病历管理系统
- python文档相似性比较代码_Python使用gensim计算文档相似性
- 点钞机语音怎么打开_弱弱问一下验钞机怎么开声音