Preface

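I recently read the paper "Deep Interest Network for Click-Through Rate Prediction", published by Alibaba at KDD 2018, and wanted to reproduce it. The open-source code released on GitHub with the paper targets TF1.4 and has very few comments, and there were parts I did not fully understand [I'm still a beginner]. So I decided to use the original code as a reference and write a simple reproduction of the model in TF2.0. If anything here differs from the original or is wrong, please point it out, thanks! [I don't have a server at the moment, so I did not run the full 50 epochs used in the open-source code.]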

Data Analysis

1. The dataset is the Amazon Dataset used in the paper. Download and decompress it:

wget -c http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Electronics_5.json.gz
gzip -d reviews_Electronics_5.json.gz
wget -c http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Electronics.json.gz
gzip -d meta_Electronics.json.gz

reviews_Electronics_5.json contains the user behavior data, and meta_Electronics.json contains the metadata of the ads (items).

A single sample from reviews looks like this:

{
  "reviewerID": "A2SUAM1J3GNN3B",
  "asin": "0000013714",
  "reviewerName": "J. McDonald",
  "helpful": [2, 3],
  "reviewText": "I bought this for my husband who plays the piano.  He is having a wonderful time playing these old hymns.  The music  is at times hard to read because we think the book was published for singing from more than playing from.  Great purchase though!",
  "overall": 5.0,
  "summary": "Heavenly Highway Hymns",
  "unixReviewTime": 1252800000,
  "reviewTime": "09 13, 2009"
}

The fields are:

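  • reviewerID: user ID;
  • asin: item ID;
  • reviewerName: user name;
  • helpful: helpfulness of the review, e.g. 2/3 in the sample above;
  • reviewText: review text;
  • overall: rating of the item;
  • summary: summary of the review;
  • unixReviewTime: Unix timestamp;
  • reviewTime: raw review date.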
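As a quick sanity check of these fields, the sketch below parses an abbreviated copy of the sample record above with the standard json module. reviewText, summary and reviewerName are omitted. It then reads `helpful` as [helpful votes, total votes] and converts `unixReviewTime` to a UTC date. The result matches `reviewTime`:

```python
import json
from datetime import datetime, timezone

# Abbreviated copy of the sample review record above (text fields omitted)
line = ('{"reviewerID": "A2SUAM1J3GNN3B", "asin": "0000013714", '
        '"helpful": [2, 3], "overall": 5.0, '
        '"unixReviewTime": 1252800000, "reviewTime": "09 13, 2009"}')

review = json.loads(line)

# helpful = [helpful votes, total votes]
helpful_votes, total_votes = review["helpful"]
helpful_ratio = helpful_votes / total_votes if total_votes else 0.0

# unixReviewTime counts seconds since the epoch; interpret it as UTC
review_date = datetime.fromtimestamp(review["unixReviewTime"], tz=timezone.utc).date()

print(review["reviewerID"], round(helpful_ratio, 3), review_date)  # A2SUAM1J3GNN3B 0.667 2009-09-13
```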

A sample from meta looks like this:

{  "asin": "0000031852",  "title": "Girls Ballet Tutu Zebra Hot Pink",  "price": 3.17,  "imUrl": "http://ecx.images-amazon.com/images/I/51fAmVkTbyL._SY300_.jpg",  "related":  {    "also_bought": ["B00JHONN1S", "B002BZX8Z6", "B00D2K1M3O", "0000031909", "B00613WDTQ", "B00D0WDS9A", "B00D0GCI8S", "0000031895", "B003AVKOP2", "B003AVEU6G", "B003IEDM9Q", "B002R0FA24", "B00D23MC6W", "B00D2K0PA0", "B00538F5OK", "B00CEV86I6", "B002R0FABA", "B00D10CLVW", "B003AVNY6I", "B002GZGI4E", "B001T9NUFS", "B002R0F7FE", "B00E1YRI4C", "B008UBQZKU", "B00D103F8U", "B007R2RM8W"],    "also_viewed": ["B002BZX8Z6", "B00JHONN1S", "B008F0SU0Y", "B00D23MC6W", "B00AFDOPDA", "B00E1YRI4C", "B002GZGI4E", "B003AVKOP2", "B00D9C1WBM", "B00CEV8366", "B00CEUX0D8", "B0079ME3KU", "B00CEUWY8K", "B004FOEEHC", "0000031895", "B00BC4GY9Y", "B003XRKA7A", "B00K18LKX2", "B00EM7KAG6", "B00AMQ17JA", "B00D9C32NI", "B002C3Y6WG", "B00JLL4L5Y", "B003AVNY6I", "B008UBQZKU", "B00D0WDS9A", "B00613WDTQ", "B00538F5OK", "B005C4Y4F6", "B004LHZ1NY", "B00CPHX76U", "B00CEUWUZC", "B00IJVASUE", "B00GOR07RE", "B00J2GTM0W", "B00JHNSNSM", "B003IEDM9Q", "B00CYBU84G", "B008VV8NSQ", "B00CYBULSO", "B00I2UHSZA", "B005F50FXC", "B007LCQI3S", "B00DP68AVW", "B009RXWNSI", "B003AVEU6G", "B00HSOJB9M", "B00EHAGZNA", "B0046W9T8C", "B00E79VW6Q", "B00D10CLVW", "B00B0AVO54", "B00E95LC8Q", "B00GOR92SO", "B007ZN5Y56", "B00AL2569W", "B00B608000", "B008F0SMUC", "B00BFXLZ8M"],    "bought_together": ["B002BZX8Z6"]  },  "salesRank": {"Toys & Games": 211836},  "brand": "Coxlures",  "categories": [["Sports & Outdoors", "Other Sports", "Dance"]]}

The fields are:

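  • asin: item ID;
  • title: item name;
  • price: item price;
  • imUrl: URL of the item image;
  • related: related products (also bought, also viewed, bought together, bought after viewing);
  • salesRank: sales rank information;
  • brand: brand name;
  • categories: list of categories the item belongs to.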
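As far as I know, the raw meta file of this dataset version is not strict JSON, because its records use single-quoted strings. That is why the loading code in the next step evaluates each line instead of calling json.loads. The sketch below assumes such a single-quoted record and abbreviates it. It shows that json.loads rejects the record, while ast.literal_eval parses it without executing code. It then extracts the last category, which is the one the processing step keeps:

```python
import ast
import json

# Abbreviated meta record in Python-literal (single-quoted) syntax,
# assumed to match the raw meta_Electronics.json lines
line = ("{'asin': '0000031852', 'title': 'Girls Ballet Tutu Zebra Hot Pink', "
        "'price': 3.17, 'related': {'bought_together': ['B002BZX8Z6']}, "
        "'categories': [['Sports & Outdoors', 'Other Sports', 'Dance']]}")

# json.loads rejects single-quoted strings ...
try:
    json.loads(line)
    strict_json = True
except json.JSONDecodeError:
    strict_json = False

# ... while ast.literal_eval accepts Python literals only (no arbitrary code)
meta = ast.literal_eval(line)

# categories is a list of category paths; keep the last entry of the last path
last_category = meta["categories"][-1][-1]
print(strict_json, meta["asin"], last_category)  # False 0000031852 Dance
```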

2. First, convert the raw JSON data into pickle format so that it loads more easily:

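import ast
import pickle

import pandas as pd


def to_df(file_path):
    """
    Convert a file with one record per line into a DataFrame
    :param file_path: file path
    :return: DataFrame with one row per record
    """
    with open(file_path, 'r') as fin:
        df = {}
        i = 0
        for line in fin:
            # ast.literal_eval instead of eval: it also parses the single-quoted
            # meta records, but never executes code
            df[i] = ast.literal_eval(line)
            i += 1
    df = pd.DataFrame.from_dict(df, orient='index')
    return df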

reviews_df = to_df('../raw_data/reviews_Electronics_5.json')

# pandas' read_json can also be used directly, but it changes the column order
# reviews2_df = pd.read_json('../raw_data/reviews_Electronics_5.json', lines=True)

# Serialize and save
with open('../raw_data/reviews.pkl', 'wb') as f:
    pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)

meta_df = to_df('../raw_data/meta_Electronics.json')
# Keep only the ads (items) that appear in reviews_df
meta_df = meta_df[meta_df['asin'].isin(reviews_df['asin'].unique())]
meta_df = meta_df.reset_index(drop=True)

with open('../raw_data/meta.pkl', 'wb') as f:
    pickle.dump(meta_df, f, pickle.HIGHEST_PROTOCOL)
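To see what this step produces without downloading the full files, here is a small sketch on made-up toy records. The IDs 'A', 'B', 'C', 'u1', 'u2' and the temporary file are purely illustrative. The records become DataFrame rows in the same way as in to_df, meta is filtered to items that appear in reviews, and the result round-trips through a pickle file with pd.read_pickle, which is how the next step reloads it:

```python
import os
import pickle
import tempfile

import pandas as pd

# Toy stand-ins for parsed review and meta records (hypothetical data)
reviews_records = [
    {"reviewerID": "u1", "asin": "A", "unixReviewTime": 3},
    {"reviewerID": "u1", "asin": "B", "unixReviewTime": 1},
    {"reviewerID": "u2", "asin": "A", "unixReviewTime": 2},
]
meta_records = [
    {"asin": "A", "categories": [["Electronics", "Cables"]]},
    {"asin": "B", "categories": [["Electronics", "Audio"]]},
    {"asin": "C", "categories": [["Electronics", "Cameras"]]},  # never reviewed
]

# Same construction as to_df: {row index: record} -> one row per record
reviews_df = pd.DataFrame.from_dict(dict(enumerate(reviews_records)), orient="index")
meta_df = pd.DataFrame.from_dict(dict(enumerate(meta_records)), orient="index")

# Keep only items that appear in the reviews
meta_df = meta_df[meta_df["asin"].isin(reviews_df["asin"].unique())]
meta_df = meta_df.reset_index(drop=True)

# Round-trip through a pickle file, as the next step does with pd.read_pickle
path = os.path.join(tempfile.mkdtemp(), "meta.pkl")
with open(path, "wb") as f:
    pickle.dump(meta_df, f, pickle.HIGHEST_PROTOCOL)
restored = pd.read_pickle(path)

print(restored["asin"].tolist())  # ['A', 'B']
```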

3. Process the reviews and meta data:

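  • reviews: keep the 'reviewerID', 'asin' and 'unixReviewTime' columns, and map user IDs and item IDs (the item mapping comes from meta) to integers;
  • meta: keep the 'asin' and 'categories' columns, keep only the last entry of the category list, and map item IDs and category IDs to integers;
  • count the number of users user_count, the number of items item_count, the number of categories cate_count, and the total number of samples example_count;
  • save the reviews data, the item category list, the counts, and the mapping keys.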
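import pickle

import numpy as np
import pandas as pd


def build_map(df, col_name):
    """
    Build a mapping whose keys are the column's distinct values and whose values are consecutive integers,
    and replace the column in place with the mapped integers
    :param df: reviews_df / meta_df
    :param col_name: column name
    :return: the mapping dict and the sorted list of keys
    """
    key = sorted(df[col_name].unique().tolist())
    m = dict(zip(key, range(len(key))))
    df[col_name] = df[col_name].map(lambda x: m[x])
    return m, key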

# reviews
reviews_df = pd.read_pickle('../raw_data/reviews.pkl')
# .copy() avoids pandas' SettingWithCopyWarning when columns are reassigned later
reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']].copy()

# meta
meta_df = pd.read_pickle('../raw_data/meta.pkl')
meta_df = meta_df[['asin', 'categories']].copy()
# Keep only the last category (the last entry of the last category path)
meta_df['categories'] = meta_df['categories'].map(lambda x: x[-1][-1])

# Map the item IDs in meta_df
asin_map, asin_key = build_map(meta_df, 'asin')
# Map the item categories in meta_df
cate_map, cate_key = build_map(meta_df, 'categories')
# Map the user IDs in reviews_df
revi_map, revi_key = build_map(reviews_df, 'reviewerID')

# user_count: 192403 item_count: 63001 cate_count: 801 example_count: 1689188
user_count, item_count, cate_count, example_count = \
    len(revi_map), len(asin_map), len(cate_map), reviews_df.shape[0]
# print('user_count: %d\titem_count: %d\tcate_count: %d\texample_count: %d' %
#       (user_count, item_count, cate_count, example_count))

# Sort by item id and reset the index
meta_df = meta_df.sort_values('asin')
meta_df = meta_df.reset_index(drop=True)

# Map the item ids in reviews_df, sort by user id and review time, and reset the index
reviews_df['asin'] = reviews_df['asin'].map(lambda x: asin_map[x])
reviews_df = reviews_df.sort_values(['reviewerID', 'unixReviewTime'])
reviews_df = reviews_df.reset_index(drop=True)
reviews_df = reviews_df[['reviewerID', 'asin', 'unixReviewTime']]

# Category of each item (indexed by the mapped item id)
cate_list = np.array(meta_df['categories'], dtype='int32')

# Save the required data to a pkl file
with open('../raw_data/remap.pkl', 'wb') as f:
    pickle.dump(reviews_df, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump((user_count, item_count, cate_count, example_count),
                f, pickle.HIGHEST_PROTOCOL)
    pickle.dump((asin_key, cate_key, revi_key), f, pickle.HIGHEST_PROTOCOL)
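Two details in this step are easy to miss:

- build_map returns the sorted key list, so key[i] recovers the original ID of integer id i.
- meta_df is sorted by the mapped asin before cate_list is built, so cate_list[i] is the category id of item i.

A toy run shows both. The data is invented, and build_map is copied from above:

```python
import numpy as np
import pandas as pd


def build_map(df, col_name):
    # Same as above: sorted distinct values -> 0..k-1, column replaced in place
    key = sorted(df[col_name].unique().tolist())
    m = dict(zip(key, range(len(key))))
    df[col_name] = df[col_name].map(lambda x: m[x])
    return m, key


# Toy meta and reviews tables (hypothetical data)
meta_df = pd.DataFrame({"asin": ["B", "A", "C"], "categories": ["Audio", "Cables", "Audio"]})
reviews_df = pd.DataFrame({
    "reviewerID": ["u2", "u1", "u1"],
    "asin": ["C", "B", "A"],
    "unixReviewTime": [5, 9, 2],
})

asin_map, asin_key = build_map(meta_df, "asin")
cate_map, cate_key = build_map(meta_df, "categories")
revi_map, revi_key = build_map(reviews_df, "reviewerID")

# Sorting by the mapped item id makes row i of meta_df describe item i
meta_df = meta_df.sort_values("asin").reset_index(drop=True)
cate_list = np.array(meta_df["categories"], dtype="int32")

# Map the item ids in reviews and order each user's reviews by time
reviews_df["asin"] = reviews_df["asin"].map(lambda x: asin_map[x])
reviews_df = reviews_df.sort_values(["reviewerID", "unixReviewTime"]).reset_index(drop=True)

print(cate_list.tolist())                   # [1, 0, 0]
print(reviews_df["asin"].tolist())          # [0, 1, 2]
print(asin_key[0], cate_key[cate_list[0]])  # A Cables
```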

4. Build the dataset

import pickle
import random

with open('raw_data/remap.pkl', 'rb') as f:
    reviews_df = pickle.load(f)
    cate_list = pickle.load(f)
    user_count, item_count, cate_count, example_count = pickle.load(f)

train_set, test_set = [], []

# Maximum sequence length
max_sl = 0

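"""
Generate the training and test sets. For each user with n browsed items in total, the first n-1 items
are used for training (as positive samples), and a matching negative sample is generated for each positive
one. Each user therefore contributes n-2 positive training samples (the 1st item has no browsing history),
and the n-th item is used for testing.
"""
for reviewerID, hist in reviews_df.groupby('reviewerID'):
    # Items the user has browsed, i.e. the positive samples
    pos_list = hist['asin'].tolist()
    max_sl = max(max_sl, len(pos_list))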

    # Generate a negative sample: a random item the user never browsed
    def gen_neg():
        neg = pos_list[0]
        while neg in pos_list:
            neg = random.randint(0, item_count - 1)
        return neg

    # Positive-to-negative ratio 1:1
    neg_list = [gen_neg() for i in range(len(pos_list))]

    for i in range(1, len(pos_list)):
        # Browsing history before the current item
        hist = pos_list[:i]
        sl = len(hist)
        if i != len(pos_list) - 1:
            # Save a positive and a negative sample, format:
            # (user ID, positive/negative item id, browsing history, history length, label 1/0)
            train_set.append((reviewerID, pos_list[i], hist, sl, 1))
            train_set.append((reviewerID, neg_list[i], hist, sl, 0))
        else:
            # The last item goes into the test set
            test_set.append((reviewerID, pos_list[i], hist, sl, 1))
            test_set.append((reviewerID, neg_list[i], hist, sl, 0))

# Shuffle
random.shuffle(train_set)
random.shuffle(test_set)

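# Every user in the 5-core data has at least 5 reviews, so each user adds
# exactly one positive and one negative test sample
assert len(test_set) == 2 * user_count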

# Write to dataset/dataset.pkl
with open('dataset/dataset.pkl', 'wb') as f:
    pickle.dump(train_set, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(test_set, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump(cate_list, f, pickle.HIGHEST_PROTOCOL)
    pickle.dump((user_count, item_count, cate_count, max_sl), f, pickle.HIGHEST_PROTOCOL)
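For a single user with n items, the loop yields n-2 positive/negative pairs for training and exactly one pair for testing. test_set therefore ends up with two entries per user.

The sketch below wraps the loop body for one user in a hypothetical helper, build_user_samples. It uses a seeded random.Random so that runs are repeatable, which makes the counts and the negative sampling easy to check:

```python
import random


def build_user_samples(user_id, pos_list, item_count, rng):
    """Hypothetical helper mirroring the loop above for a single user."""
    # One negative per position, drawn from items the user never interacted with
    def gen_neg():
        neg = pos_list[0]
        while neg in pos_list:
            neg = rng.randint(0, item_count - 1)
        return neg

    neg_list = [gen_neg() for _ in range(len(pos_list))]
    train, test = [], []
    for i in range(1, len(pos_list)):
        hist = pos_list[:i]
        # All but the last position go to training, the last one to testing
        target = train if i != len(pos_list) - 1 else test
        target.append((user_id, pos_list[i], hist, len(hist), 1))
        target.append((user_id, neg_list[i], hist, len(hist), 0))
    return train, test


rng = random.Random(0)
train, test = build_user_samples(7, [4, 9, 2, 5, 8], item_count=20, rng=rng)

print(len(train), len(test))  # 6 2  -> (n - 2) pairs for training, 1 pair for testing
print(test[0][:4])            # (7, 8, [4, 9, 2, 5], 4)
```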

Model Construction

1. Define the layers the model needs

import tensorflow as tf


class DIN(tf.keras.Model):
    def __init__(self, user_num, item_num, cate_num, cate_list, hidden_units):
        """
        :param user_num: number of users
        :param item_num: number of items
        :param cate_num: number of item categories
        :param cate_list: category id of each item (indexed by item id)
        :param hidden_units: hidden units, also used as the embedding dimension
        """
        super(DIN, self).__init__()
        self.cate_list = tf.convert_to_tensor(cate_list, dtype=tf.int32)
        self.hidden_units = hidden_units
        # self.user_embed = tf.keras.layers.Embedding(
        #     input_dim=user_num, output_dim=hidden_units, embeddings_initializer='random_uniform',
        #     embeddings_regularizer=tf.keras.regularizers.l2(0.01), name='user_embed')
        self.item_embed = tf.keras.layers.Embedding(
            input_dim=item_num, output_dim=self.hidden_units, embeddings_initializer='random_uniform',
            embeddings_regularizer=tf.keras.regularizers.l2(0.01), name='item_embed')
        self.cate_embed = tf.keras.layers.Embedding(
            input_dim=cate_num, output_dim=self.hidden_units, embeddings_initializer='random_uniform',
            embeddings_regularizer=tf.keras.regularizers.l2(0.01), name='cate_embed'
        )
        self.dense = tf.keras.layers.Dense(self.hidden_units)
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.concat = tf.keras.layers.Concatenate(axis=-1)
        # Attention (activation unit) MLP
        self.att_dense1 = tf.keras.layers.Dense(80, activation='sigmoid')
        self.att_dense2 = tf.keras.layers.Dense(40, activation='sigmoid')
        self.att_dense3 = tf.keras.layers.Dense(1)
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.concat2 = tf.keras.layers.Concatenate(axis=-1)
        # Output MLP: the Dense layers stay linear because PReLU/Dice provides the nonlinearity
        self.dense1 = tf.keras.layers.Dense(80, activation=None)
        self.activation1 = tf.keras.layers.PReLU()
        # self.activation1 = Dice()
        self.dense2 = tf.keras.layers.Dense(40, activation=None)
        self.activation2 = tf.keras.layers.PReLU()
        # self.activation2 = Dice()
        self.dense3 = tf.keras.layers.Dense(1, activation=None)
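The commented-out Dice() refers to the data-adaptive activation proposed in the DIN paper. It is not a built-in Keras layer. Its rule is f(s) = p(s)·s + (1 − p(s))·α·s, where p(s) = sigmoid((s − E[s]) / sqrt(Var[s] + ε)) and E[s] and Var[s] are mini-batch statistics during training.

Below is a NumPy sketch of the training-mode forward pass only. It simplifies in two ways: there are no moving averages for inference, and α is a fixed argument rather than a learned parameter.

```python
import numpy as np


def dice(s, alpha=0.0, eps=1e-8):
    """Training-mode Dice forward pass on a batch s of shape [batch, units]."""
    # Per-unit statistics over the mini-batch
    mean = s.mean(axis=0, keepdims=True)
    var = s.var(axis=0, keepdims=True)
    # Control function p(s): sigmoid of the standardized input
    p = 1.0 / (1.0 + np.exp(-(s - mean) / np.sqrt(var + eps)))
    # Smoothly interpolate between s (p -> 1) and alpha * s (p -> 0)
    return p * s + (1.0 - p) * alpha * s


x = np.random.default_rng(0).normal(size=(32, 80))
y = dice(x, alpha=0.25)
print(y.shape)  # (32, 80)
```

With α = 1 both branches coincide and Dice reduces to the identity, which is a handy check on an implementation.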

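2. Following the model diagram, first build the embeddings of the User Behaviors and the Candidate Ad. In this dataset, the Goods ID and Cate ID embeddings are combined (concatenated). [Since the user's gender, age, etc. are not available, no embedding of the user's own profile features is needed.]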

    def call(self, inputs):
        # user: user ID; item: candidate item id; hist: browsing history (zero-padded list of item ids);
        # sl: actual length of each history
        user, item, hist, sl = inputs[0], tf.squeeze(inputs[1], axis=1), inputs[2], tf.squeeze(inputs[3], axis=1)
        # user_embed = self.u_embed(user)
        item_embed = self.concat_embed(item)
        hist_embed = self.concat_embed(hist)
        ......

    def concat_embed(self, item):
        """
        Concatenate the item embedding and the item-category embedding
        :param item: item ids, shape [batch] or [batch, max_sl]
        :return: concatenated embedding, shape [..., hidden_units * 2]
        """
        # tf.gather keeps the shape of item, so cate has the same shape as item
        cate = tf.gather(self.cate_list, item)
        item_embed = self.item_embed(item)
        item_cate_embed = self.cate_embed(cate)
        embed = self.concat([item_embed, item_cate_embed])
        return embed
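Indexing cate_list with the item ids keeps the shape of the index tensor, so the same lookup serves two cases:

- a batch of candidate items, shape [batch], gives [batch, 2 * hidden_units];
- a batch of histories, shape [batch, max_sl], gives [batch, max_sl, 2 * hidden_units].

The sketch below applies the same lookup-and-concatenate logic in NumPy. Random tables stand in for the two embedding layers, and all sizes are arbitrary:

```python
import numpy as np

rng = np.random.default_rng(0)
item_num, cate_num, hidden_units = 10, 4, 8

# Random stand-ins for the item and category embedding tables
item_table = rng.normal(size=(item_num, hidden_units))
cate_table = rng.normal(size=(cate_num, hidden_units))
# cate_list[i] = category id of item i
cate_list = np.array([0, 1, 1, 2, 3, 0, 2, 2, 1, 3])


def concat_embed(item):
    # Look up each item's category, then both embeddings, then concatenate on the last axis
    cate = cate_list[item]
    return np.concatenate([item_table[item], cate_table[cate]], axis=-1)


candidate = np.array([3, 7])                      # [batch]
history = np.array([[1, 4, 0, 0], [2, 5, 9, 6]])  # [batch, max_sl]
print(concat_embed(candidate).shape)  # (2, 16)
print(concat_embed(history).shape)    # (2, 4, 16)
```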

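3. Next, as in the model, apply attention over the item embeddings in the user behavior sequence, conditioned on the candidate ad: item embeddings that are more relevant to the candidate ad receive larger weights.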

    def call(self, inputs):
        ......
        # Item embeddings after attention
        hist_att_embed = self.attention(item_embed, hist_embed, sl)
        hist_att_embed = self.bn1(hist_att_embed)
        hist_att_embed = tf.reshape(hist_att_embed, [-1, self.hidden_units * 2])
        u_embed = self.dense(hist_att_embed)
        ......

    def attention(self, queries, keys, keys_length):
        """
        Activation unit
        :param queries: candidate ad (item) embedding
        :param keys: user behavior (history) embeddings
        :param keys_length: number of valid (non-padded) positions in each history
        :return: attention-weighted sum of the history embeddings
        """
        # Hidden size of the candidate item embedding: hidden_units * 2
        queries_hidden_units = queries.shape[-1]
        # The candidate embedding is paired with every history item embedding, so it is repeated
        # keys.shape[1] times; keys.shape[1] is the maximum sequence length (431 here)
        # [None, 431 * hidden_units * 2]
        queries = tf.tile(queries, [1, keys.shape[1]])
        # Reshape the tiled candidate embedding
        # [None, 431, hidden_units * 2]
        queries = tf.reshape(queries, [-1, keys.shape[1], queries_hidden_units])
        # Concatenate the candidate and history embeddings with their difference and product
        # [None, 431, hidden_units * 2 * 4]
        embed = tf.concat([queries, keys, queries - keys, queries * keys], axis=-1)
        # Fully connected layers produce the attention scores
        d_layer_1 = self.att_dense1(embed)
        d_layer_2 = self.att_dense2(d_layer_1)
        # [None, 431, 1]
        d_layer_3 = self.att_dense3(d_layer_2)
        # Reshape so that each history item has one score
        # [None, 1, 431]
        outputs = tf.reshape(d_layer_3, [-1, 1, keys.shape[1]])

        # Mask
        # True at positions holding real history items, False at padding
        # [None, 431]
        key_masks = tf.sequence_mask(keys_length, keys.shape[1])
        # Add a dimension
        # [None, 1, 431]
        key_masks = tf.expand_dims(key_masks, 1)
        # Padding value: a very large negative number
        paddings = tf.ones_like(outputs) * (-2 ** 32 + 1)
        # Keep the original score where the mask is True and use the padding value elsewhere,
        # so padded positions get ~0 weight after softmax and drop out of the weighted sum pooling
        # [None, 1, 431] ----> weight of each browsed item
        outputs = tf.where(key_masks, outputs, paddings)
        # Scale; keys.shape[-1] is the hidden size of hist_embed
        outputs = outputs / (keys.shape[-1] ** 0.5)
        # Activation: normalize with softmax
        outputs = tf.nn.softmax(outputs)
        # Weighted sum of hist_embed
        # [None, 1, 431] * [None, 431, hidden_units * 2] = [None, 1, hidden_units * 2]
        outputs = tf.matmul(outputs, keys)
        return outputs
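The masking and pooling logic can be checked in isolation with a NumPy re-implementation of the same computation. This is a sketch with small shapes. The score MLP (80 and 40 sigmoid units, then 1 linear unit) uses random weights instead of trained ones. Subtracting the row maximum before exp is only for numerical stability and does not change the softmax.

The run shows three properties:

- padded positions get zero weight;
- each sample's weights sum to 1;
- the output is one pooled vector per sample.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, max_sl, dim = 2, 5, 6  # dim plays the role of hidden_units * 2


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Random weights standing in for att_dense1 (80), att_dense2 (40), att_dense3 (1)
w1, b1 = rng.normal(size=(4 * dim, 80)) * 0.1, np.zeros(80)
w2, b2 = rng.normal(size=(80, 40)) * 0.1, np.zeros(40)
w3, b3 = rng.normal(size=(40, 1)) * 0.1, np.zeros(1)


def attention(queries, keys, keys_length):
    # queries: [B, dim], keys: [B, T, dim], keys_length: [B]
    t = keys.shape[1]
    q = np.repeat(queries[:, None, :], t, axis=1)                   # [B, T, dim]
    embed = np.concatenate([q, keys, q - keys, q * keys], axis=-1)  # [B, T, 4 * dim]
    h = sigmoid(embed @ w1 + b1)
    h = sigmoid(h @ w2 + b2)
    scores = (h @ w3 + b3).reshape(-1, 1, t)                        # [B, 1, T]
    # Replace padded positions with a very large negative number
    mask = (np.arange(t)[None, :] < keys_length[:, None])[:, None, :]  # [B, 1, T]
    scores = np.where(mask, scores, -2.0 ** 32 + 1)
    # Scale, then softmax over the history positions
    scores = scores / np.sqrt(dim)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    # Weighted sum of the history embeddings: [B, 1, T] @ [B, T, dim] -> [B, 1, dim]
    return weights @ keys, weights


queries = rng.normal(size=(batch, dim))
keys = rng.normal(size=(batch, max_sl, dim))
keys_length = np.array([3, 5])
pooled, weights = attention(queries, keys, keys_length)
print(pooled.shape)      # (2, 1, 6)
print(weights[0, 0, 3:])  # [0. 0.]
```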

4. Concatenate the candidate ad embedding with the history embedding obtained by (attention-weighted) sum pooling:

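    def call(self, inputs):
        ......
        item_embed = tf.reshape(item_embed, [-1, item_embed.shape[-1]])
        # Combine the user behavior embedding and the candidate item embedding
        # [user profile and context features would also go here, but this dataset has none]
        embed = self.concat2([u_embed, item_embed])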

5. The MLP

    def call(self, inputs):
        ......
        x = self.bn2(embed)
        x = self.dense1(x)
        x = self.activation1(x)
        x = self.dense2(x)
        x = self.activation2(x)
        x = self.dense3(x)
        outputs = tf.nn.sigmoid(x)
        return outputs
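PReLU computes f(x) = x for x > 0 and f(x) = a·x otherwise, with a learned slope a. A sigmoid output is always positive, so PReLU stacked directly on a sigmoid-activated Dense layer reduces to the identity. The trainable activation only has an effect when the Dense layer in front of it is linear (activation=None). A quick NumPy check, with a = 0.25 as an arbitrary slope:

```python
import numpy as np


def prelu(x, a=0.25):
    # Identity for positive inputs, slope a for non-positive inputs
    return np.where(x > 0, x, a * x)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


pre_activation = np.random.default_rng(0).normal(size=(4, 80))

# PReLU after a sigmoid: the sigmoid output lies in (0, 1), so PReLU returns it unchanged
after_sigmoid = prelu(sigmoid(pre_activation))
print(np.array_equal(after_sigmoid, sigmoid(pre_activation)))  # True

# PReLU on the linear output actually changes the negative entries
print(np.array_equal(prelu(pre_activation), pre_activation))   # False
```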

Input Processing

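Here we process the browsing histories [following the paper's open-source code]. Each user's sequence has a different length, so the histories need padding. In NLP tasks such as sentiment analysis, sentences are similarly truncated or padded before being fed into an RNN or a similar model, and the author does something comparable here. The difference is in the padding length. In the [open-source code], the author pads each batch to the longest history among the users in that batch. Here we pad to the longest history over all users (max_sl), appending zeros at the end of shorter sequences. [This increases memory usage, but I don't know how to handle it in TF2.0.]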

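import numpy as np
import tensorflow as tf


def input_data(dataset, max_sl):
    # dataset: NumPy object array whose rows are (user ID, item id, history, history length, label)
    user = np.array(dataset[:, 0], dtype='int32')
    item = np.array(dataset[:, 1], dtype='int32')
    hist = dataset[:, 2]
    # Pad every history with zeros at the end (padding='post') up to max_sl
    hist_matrix = tf.keras.preprocessing.sequence.pad_sequences(hist, maxlen=max_sl, padding='post')

    sl = np.array(dataset[:, 3], dtype='int32')
    y = np.array(dataset[:, 4], dtype='float32')

    return user, item, hist_matrix, sl, y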
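The per-batch padding of the original open-source code does not need TensorFlow-specific machinery: each mini-batch can simply be padded to its own longest history.

The sketch below is such a generator in NumPy. The name batch_iter and its arguments are illustrative, not from the original code. It assumes dataset is the same 2-D object array of (user, item, hist, sl, label) rows that input_data receives.

```python
import numpy as np


def batch_iter(dataset, batch_size, shuffle=True, seed=0):
    """Yield ((user, item, hist, sl), y), padding hist only to the longest history in the batch."""
    order = np.arange(len(dataset))
    if shuffle:
        np.random.default_rng(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        rows = dataset[order[start:start + batch_size]]
        user = np.array(rows[:, 0], dtype="int32")
        item = np.array(rows[:, 1], dtype="int32")
        sl = np.array(rows[:, 3], dtype="int32")
        y = np.array(rows[:, 4], dtype="float32")
        # Post-padding with zeros, as padding='post' does, but only to this batch's max length
        hist = np.zeros((len(rows), sl.max()), dtype="int32")
        for r, h in enumerate(rows[:, 2]):
            hist[r, :len(h)] = h
        yield (user, item, hist, sl), y


# Toy rows in the (user, item, hist, sl, label) format built earlier (hypothetical data)
toy = [(0, 5, [1, 2], 2, 1), (0, 7, [1, 2], 2, 0), (1, 3, [4, 6, 8], 3, 1), (1, 2, [4], 1, 0)]
dataset = np.empty((len(toy), 5), dtype=object)
for i, row in enumerate(toy):
    for j, value in enumerate(row):
        dataset[i, j] = value

batches = list(batch_iter(dataset, batch_size=2, shuffle=False))
print(batches[0][0][2].shape)     # (2, 2)
print(batches[1][0][2].tolist())  # [[4, 6, 8], [4, 0, 0]]
```

Note that the model code above reads the sequence length from the static keys.shape[1]. If TensorFlow traces the model with an unknown history length, the tile/reshape calls would need the dynamic tf.shape(keys)[1] instead.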

Training

After that, compile and train the model as usual.

Github

I've uploaded the code to my GitHub: https://github.com/BlackSpaceGZY/Recommended-System

It also includes a TF2.0 implementation of NCF. [A star would be much appreciated!]

WeChat Official Account

潜心的Python小屋
