文章目录

  • 提出动机
  • 结构
  • 特征交叉池化层
  • 特点
  • 工程化结构
  • tf2实现

提出动机

结构

特征交叉池化层

特点

工程化结构

tf2实现

# coding:utf-8
# @Email: wangguisen@infinities.com.cn
# @Time: 2022/1/4 11:12 上午
# @File: ctr_NFM.py
'''
NFM
'''
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Model
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
from sklearn.model_selection import train_test_split
import yamlfrom tools import *class NFM(Model):def __init__(self, spare_feature_columns, dense_feature_columns, hidden_units, output_dim, activation, droup_out, w_reg):super(NFM, self).__init__()self.spare_feature_columns = spare_feature_columnsself.dense_feature_columns = dense_feature_columnsself.spare_shape = len(spare_feature_columns)self.w_reg = w_reg# embeddingself.embedding_layer = {'embed_layer{}'.format(i): layers.Embedding(feat['vocabulary_size'], feat['embed_dim'])for i, feat in enumerate(self.spare_feature_columns)}# dnnself.DNN = tf.keras.Sequential()for hidden in hidden_units:self.DNN.add(layers.Dense(hidden))self.DNN.add(layers.BatchNormalization())self.DNN.add(layers.Activation(activation))self.DNN.add(layers.Dropout(droup_out))self.DNN.add(layers.Dense(output_dim, activation=None))  # output_dim=1def build(self, input_shape):self.b = self.add_weight(name='b', shape=(1,), initializer=tf.zeros_initializer(), trainable=True, )self.w = self.add_weight(name='w', shape=(self.spare_shape, 1), initializer=tf.random_normal_initializer(), trainable=True, regularizer=tf.keras.regularizers.l2(self.w_reg))def call(self, inputs, training=None, mask=None):# dense_inputs: 数值特征,13维# sparse_inputs: 类别特征,26维dense_inputs, sparse_inputs = inputs[:, :13], inputs[:, 13:]# # LR part# linear_part = tf.matmul(sparse_inputs, self.w) + self.b  # (batchsize, 1)# embeddingsparse_embeds = [self.embedding_layer['embed_layer{}'.format(i)](sparse_inputs[:, i]) for i in range(sparse_inputs.shape[1])]  # listsparse_embed = tf.convert_to_tensor(sparse_embeds)   # (26, batchsize, embed_dim)sparse_embed = tf.transpose(sparse_embed, [1, 0, 2])  # (batchsize, 26, embed_dim)# Bi-Interaction Layerbi_layer = 0.5 * (tf.pow(tf.reduce_sum(sparse_embed, axis=1), 2) - tf.reduce_sum(tf.pow(sparse_embed, 2), axis=1))  # (batchsize, embed_dim)# bi + densex = tf.concat([dense_inputs, bi_layer], axis=1)  # (batchsize, embed_dim + 13)# hidden layerhidden_part = self.DNN(x)   # (batchsize, 1)# output = tf.nn.sigmoid(linear_part + hidden_part)output = tf.nn.sigmoid(hidden_part)return outputif __name__ == '__main__':with open('config.yaml', 'r') as f:config = yaml.Loader(f).get_data()data = pd.read_csv('../../data/criteo_sampled_data_OK.csv')data_X = data.iloc[:, 1:]data_y = data['label'].values# I1-I13:总共 13 列数值型特征# C1-C26:共有 26 列类别型特征dense_features = ['I' + str(i) for i in range(1, 14)]sparse_features = ['C' + str(i) for i in range(1, 27)]dense_feature_columns = [denseFeature(feat) for feat in dense_features]spare_feature_columns = [sparseFeature(feat, data_X[feat].nunique(), config['NFM']['embed_dim']) for feat insparse_features]tmp_X, test_X, tmp_y, test_y = train_test_split(data_X, data_y, test_size=0.2, random_state=42, stratify=data_y)train_X, val_X, train_y, val_y = train_test_split(tmp_X, tmp_y, test_size=0.1, random_state=42, stratify=tmp_y)model = NFM(spare_feature_columns=spare_feature_columns,dense_feature_columns=dense_feature_columns,hidden_units=config['NFM']['hidden_units'],output_dim=config['NFM']['output_dim'],activation=config['NFM']['activation'],droup_out=config['NFM']['droup_out'],w_reg=config['NFM']['w_reg'],)adam = optimizers.Adam(lr=config['train']['adam_lr'], beta_1=0.95, beta_2=0.96,decay=config['train']['adam_lr'] / config['train']['epochs'])model.compile(optimizer=adam,loss='binary_crossentropy',metrics=[metrics.AUC(), metrics.Precision(), metrics.Recall()])model.fit(train_X.values, train_y,validation_data=(val_X.values, val_y),batch_size=config['train']['batch_size'],epochs=config['train']['epochs'],verbose=1,)model.summary()
Epoch 1/3
216/216 [==============================] - 9s 40ms/step - loss: 303.4375 - auc: 0.5100 - precision: 0.2705 - recall: 0.2608 - val_loss: 6.0378 - val_auc: 0.5098 - val_precision: 0.3144 - val_recall: 0.0579
Epoch 2/3
216/216 [==============================] - 8s 38ms/step - loss: 6.3794 - auc: 0.5833 - precision: 0.3619 - recall: 0.3607 - val_loss: 1.0595 - val_auc: 0.6949 - val_precision: 0.4574 - val_recall: 0.4347
Epoch 3/3
216/216 [==============================] - 8s 39ms/step - loss: 3.1246 - auc: 0.6743 - precision: 0.4674 - recall: 0.4705 - val_loss: 2.2340 - val_auc: 0.6381 - val_precision: 0.6591 - val_recall: 0.1596

CTR --- NFM论文阅读笔记,及tf2复现相关推荐

  1. CTR --- DIEN论文阅读笔记,及tf2复现

    文章目录 前言 DIN和DIEN的总体思路 DIN对兴趣建模的缺点 DIEN对兴趣建模的思路 结构 行为序列层(Behavior Layer) 兴趣抽取层(Interest Extractor Lay ...

  2. CTR --- AFM论文阅读笔记,及tf2复现

    文章目录 注意力机制 提出动机 解决方案 举例理解 结构 基于注意力机制的池化层 综合上述注意力机制的计算理解 tf2实现 注意力机制 提出动机 解决方案 把注意力机制引到里面去,来学习不同交叉特征对 ...

  3. 2019 sample-free(样本不平衡)目标检测论文阅读笔记

    点击我爱计算机视觉标星,更快获取CVML新技术 本文转载自知乎,已获作者同意转载,请勿二次转载 (原文地址:https://zhuanlan.zhihu.com/p/100052168) 背景 < ...

  4. 论文阅读笔记:看完也许能进一步了解Batch Normalization

    提示:阅读论文时进行相关思想.结构.优缺点,内容进行提炼和记录,论文和相关引用会标明出处. 文章目录 前言 介绍 BN之前的一些减少Covariate Shift的方法 BN算法描述 Batch No ...

  5. 论文阅读笔记(2):Learning a Self-Expressive Network for Subspace Clustering,SENet,用于大规模子空间聚类的自表达网络

    论文阅读笔记(2):Learning a Self-Expressive Network for Subspace Clustering. SENet--用于大规模子空间聚类的自表达网络 前言 摘要 ...

  6. 点云配准论文阅读笔记--Comparing ICP variants on real-world data sets

    目录 写在前面 点云配准系列 摘要 1引言(Introduction) 2 相关研究(Related work) 3方法( Method) 3.1输入数据的敏感性 3.2评价指标 3.3协议 4 模块 ...

  7. Designing an optimal contest(博弈论+机制设计) 论文阅读笔记

    Designing an optimal contest 论文阅读笔记 一.基本信息 二.文章摘要 三.背景介绍 四.核心模型 五.核心结论 六.总结展望 一.基本信息 题目:设计一个最优竞赛 作者: ...

  8. 【论文阅读笔记】Myers的O(ND)时间复杂度的高效的diff算法

    前言 之前咱们三个同学做了个Simple-SCM,我负责那个Merge模块,也就是对两个不同分支的代码进行合并.当时为了简便起见,遇到文件冲突的时候,就直接按照文件的更改日期来存储,直接把更改日期较新 ...

  9. 论文阅读笔记——基于CNN-GAP可解释性模型的软件源码漏洞检测方法

    本论文相关内容 论文下载地址--Engineering Village 论文阅读笔记--基于CNN-GAP可解释性模型的软件源码漏洞检测方法 文章目录 本论文相关内容 前言 基于CNN-GAP可解释性 ...

最新文章

  1. ios - 使用@try、catch捕获异常:
  2. QT的QProgressDialog类的使用
  3. 学习笔记26_MVC前台强类型参数
  4. 索尼电脑娱乐(SCE)公司周一宣布
  5. oracle数据库初始化参数分类,oracle初始化参数设置
  6. Docker学习篇(一)Docker概述、安装和常用命令
  7. web电商、商城pc端、商城、购物车、订单、线上支付、web商城、pc商城、登录注册、人工客服、收货地址、现金券、优惠券、礼品卡、团购订单、评价晒单、消息通知、电子产品商城、手机商城、电脑商城
  8. 整理一些js中常见的问题
  9. error while loading shared libraries: libtinfo.so.5
  10. 编程必会的100个代码大全,建议收藏
  11. 怎么在ASP.NET中引用JS文件
  12. modbus功能码04实例_20种PLC元件编号和Modbus编号地址对应表
  13. MySQL over函数的用法
  14. 基于Arduino的密码+指纹智能锁(LCD1602显示器)
  15. K210频谱显示桌面摆件(Sipeed Maix Dock)
  16. JavaScript2谁刚开始学习应该知道4最佳实践文章(翻译)
  17. HTML网页内嵌入网页
  18. VUE 爬坑之旅 -- vue 项目中将简体转换为繁体
  19. 讲的真详细!花三分钟看完这篇文章你就懂了
  20. MATLAB 代码资料大全

热门文章

  1. python数据爬虫——如何爬取二级页面(三)
  2. [ 重 新 预 习 ] Node.js搭建服务
  3. Python 动态加载并下载梨视频短视频
  4. iOS开发-类似微信录音上滑取消功能
  5. 锦城学院计算机系考研,考研心得分享
  6. NOIP.COM账号注册以及密码找回
  7. ue4 后期处理景深_Unreal Engine4 后期处理特效 VOL1
  8. 【附源码】计算机毕业设计JAVA医院病历管理系统
  9. python文档相似性比较代码_Python使用gensim计算文档相似性
  10. 点钞机语音怎么打开_弱弱问一下验钞机怎么开声音