DSSM原理解读与工程实践
推荐算法实践
DSSM原理解读与工程实践
一、原理
DSSM(Deep Structured Semantic Model),由微软研究院提出,利用深度神经网络将文本表示为低维度的向量,应用于文本相似度匹配场景下的一个算法。不仅局限于文本,在其他可以计算相似性计算的场景,例如推荐系统中。
其实我们现在来说一件事就是推荐系统和搜索引擎之间的关系。他们两者之间很相似,都是根据满足用户需求,根据用户喜好给出答案,但又不是完全相同,只不过推荐系统更难,因为推荐系统需要挖掘用户潜在喜好来推荐内容和物品给用户。这是因为搜索引擎和推荐系统的关系之间相似性,所以适用于文本匹配的模型也可以应用到推荐系统中。
二、模型结构
我们还是先看网络结果,网络结果比较简单,是一个由几层全连接组成网络,我们将要搜索文本(Query)和要匹配的文本(Document)的 embedding 输入到网络,网络输出为 128 维的向量,然后通过向量之间计算余弦相似度来计算向量之间距离,可以看作每一个 D 和 Q 之间相似分数,然后在做 softmax ,网络结构如下图
其中Q代表Query信息,D表示Document信息。
三、DSSM 模型在推荐召回环节的应用
1)DSSM 模型在推荐召回环节的结构
DSSM 模型的最大特点就是 Query 和 Document 是两个独立的子网络,后来这一特色被移植到推荐算法的召回环节,即对用户端(User)和物品端(Item)分别构建独立的子网络塔式结构。该方式对工业界十分友好,两个子网络产生的 Embedding 向量可以独自获取及缓存。目前工业界流行的 DSSM 双塔网络结构如图所示(美图DSSM架构图)。
双塔模型两侧分别对(用户,上下文)和(物品)进行建模,并在最后一层计算二者的内积。
2)候选集合召回
当模型训练完成时,物品的 Embedding 是可以保存成词表的,线上应用的时候只需要查找对应的 Embedding 即可。因此线上只需要计算 用户塔 一侧的 Embedding,基于 Milvus 或 Faiss 技术索引得到用户偏好的候选集。
四、DSSM召回实战
下面使用ml-1m数据集,实践一下DSSM召回模型。该模型的实现主要参考:python软件的DeepCtr和DeepMatch模块。
- u2i召回
DSSM模型训练完成可得到用户和物品的Embedding向量,再利用向量最近邻的方法(如局部敏感哈希LSH、kd树、annoy、milvus、faiss等)可计算出与每个用户最相似(向量相似度最高)的top-m个物品。线上召回时输入用户特征给模型,模型预测得到用户向量,利用向量检索工具召回M个相似物品作为候选物品作为该路召回的结果,进入后续的排序阶段。
- I2I召回
DSSM模型训练完成后输入物品特征会生成每个物品的Embedding向量,再利用向量最近邻的方法(如局部敏感哈希LSH、kd树、annoy、milvus、faiss等)可计算出与每个物品最相似(向量相似度最高)的top-m个物品。线上召回时可根据用户最近操作(如点击)过的N个物品,分别召回k个相似物品,一共N*k个作为候选物品作为该路召回的结果,进入后续的排序阶段。
3)两种召回方式效果对比
u2i召回:
对用户行为预测为一个向量后再召回用户向量的topN个物品
i2i召回:
用户最近L个行为物品一一召回k个物品,总体再求topN个物品
用开源数据集ml-1m测试得到的结果如下:
由上述结果可知,对ml-1m数据集u2i召回方式效果要好于i2i召回方式。
完整代码如下:
import faiss
import pandas as pd
from deepctr.feature_column import SparseFeat, VarLenSparseFeat
from preprocess import gen_data_set, gen_model_input
from sklearn.preprocessing import LabelEncoder
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import Model
from deepmatch.models import *import numpy as np
from tqdm import tqdm
from deepmatch.utils import recall_N
import os
data_path = "../"unames = ['user_id','gender','age','occupation','zip']
user = pd.read_csv(data_path+'ml-1m/users.dat',sep='::',header=None,names=unames,engine='python')
rnames = ['user_id','item_id','rating','timestamp']
ratings = pd.read_csv(data_path+'ml-1m/ratings.dat',sep='::',header=None,names=rnames,engine='python')
mnames = ['item_id','title','genres']
movies = pd.read_csv(data_path+'ml-1m/movies.dat',sep='::',header=None,names=mnames,engine='python')data = pd.merge(pd.merge(ratings,movies),user)sparse_features = ["item_id", "user_id", "gender", "age", "occupation", "zip", ]
SEQ_LEN = 50
negsample = 3# 1.稀疏特征编码,生成训练和测试集features = ['user_id', 'item_id', 'gender', 'age', 'occupation', 'zip']
feature_max_idx = {}
for feature in features:lbe = LabelEncoder()data[feature] = lbe.fit_transform(data[feature]) + 1feature_max_idx[feature] = data[feature].max() + 1user_profile = data[["user_id", "gender", "age", "occupation", "zip"]].drop_duplicates('user_id')item_profile = data[["item_id"]].drop_duplicates('item_id')user_profile.set_index("user_id", inplace=True)user_item_list = data.groupby("user_id")['item_id'].apply(list)train_set, test_set = gen_data_set(data, negsample)train_model_input, train_label = gen_model_input(train_set, user_profile, SEQ_LEN)
test_model_input, test_label = gen_model_input(test_set, user_profile, SEQ_LEN)# 2.配置稀疏特征维度及emb维度embedding_dim = 32user_feature_columns = [SparseFeat('user_id', feature_max_idx['user_id'], 16),SparseFeat("gender", feature_max_idx['gender'], 16),SparseFeat("age", feature_max_idx['age'], 16),SparseFeat("occupation", feature_max_idx['occupation'], 16),SparseFeat("zip", feature_max_idx['zip'], 16),VarLenSparseFeat(SparseFeat('hist_item_id', feature_max_idx['item_id'], embedding_dim,embedding_name="item_id"), SEQ_LEN, 'mean', 'hist_len'),]item_feature_columns = [SparseFeat('item_id', feature_max_idx['item_id'], embedding_dim)]# 3.定义模型并训练K.set_learning_phase(True)import tensorflow as tf
if tf.__version__ >= '2.0.0':tf.compat.v1.disable_eager_execution()model = DSSM(user_feature_columns, item_feature_columns)
model.compile(optimizer='adagrad', loss="binary_crossentropy", metrics=['accuracy'])
history = model.fit(train_model_input, train_label, # train_label,batch_size=256, epochs=10, verbose=1, validation_split=0.2, )# 4. 生成用户emb和物品emb,用于召回
test_user_model_input = test_model_input
all_item_model_input = {"item_id": item_profile['item_id'].values,}user_embedding_model = Model(inputs=model.user_input, outputs=model.user_embedding)
item_embedding_model = Model(inputs=model.item_input, outputs=model.item_embedding)user_embs = user_embedding_model.predict(test_user_model_input, batch_size=2 ** 12)
item_embs = item_embedding_model.predict(all_item_model_input, batch_size=2 ** 12)test_user_np = test_user_model_input['user_id']
all_item_np = all_item_model_input['item_id']test_user_emb_all = np.hstack((test_user_np.reshape(-1, 1),user_embs))
all_item_all = np.hstack((all_item_np.reshape(-1, 1),item_embs))np.savetxt('user_embs.csv', test_user_emb_all, delimiter = ',')
np.savetxt('item_embs.csv', all_item_all, delimiter = ',')idex = np.lexsort([all_item_all[:, 0]])
sorted_item_embs = all_item_all[idex, :]
sorted_item_embs2 = sorted_item_embs[:,1:]
np.savetxt('sorted_item_embs.csv', sorted_item_embs, delimiter = ',')
np.savetxt('sorted_item_embs2.csv', sorted_item_embs2, delimiter = ',')print("sorted_item_emb.shape = ",sorted_item_embs.shape)
print("sorted_item_emb2.shape = ",sorted_item_embs2.shape)print(test_user_emb_all.shape)
print(all_item_all.shape)test_true_label = {line[0]:[line[2]] for line in test_set}# 5、faiss 创建索引 插入item_embs
index = faiss.IndexFlatIP(embedding_dim)
# faiss.normalize_L2(item_embs)
index.add(item_embs)# 6、根据user_emb 检索物品列表
# faiss.normalize_L2(user_embs)
D, I = index.search(np.ascontiguousarray(user_embs), 1000)
s1000 = []
s500 = []
s100 = []
s50 = []
s10 = []
hit = 0filename = 'user_emb_rec_list.txt'
if os.path.exists(filename):os.remove(filename)
with open(filename, 'a') as f:for i, uid in tqdm(enumerate(test_user_model_input['user_id'])):pred = [item_profile['item_id'].values[x] for x in I[i]]item_list = ",".join('%s' %x for x in pred)filter_item = Nonerecall_score_1000 = recall_N(test_true_label[uid], pred, N=1000)recall_score_500 = recall_N(test_true_label[uid], pred, N=500)recall_score_100 = recall_N(test_true_label[uid], pred, N=100)recall_score_50 = recall_N(test_true_label[uid], pred, N=50)recall_score_10 = recall_N(test_true_label[uid], pred, N=10)s1000.append(recall_score_1000)s500.append(recall_score_500)s100.append(recall_score_100)s50.append(recall_score_50)s10.append(recall_score_10)# if test_true_label[uid] in pred:# hit += 1f.write("{} {}\n".format(uid, item_list))print("recall1000", np.mean(s1000))
print("recall500", np.mean(s500))
print("recall100", np.mean(s100))
print("recall50", np.mean(s50))
print("recall10", np.mean(s10))
# print("hit rate", hit / len(test_user_model_input['user_id']))# 7、根据item_emb 检索物品列表生成I2I倒排索引
# faiss.normalize_L2(item_embs)
D, I = index.search(np.ascontiguousarray(item_embs), 50)
s1000 = []
s500 = []
s100 = []
s50 = []
s10 = []
hit = 0
i2i_dict = {}
filename = 'item_item_list.txt'
if os.path.exists(filename):os.remove(filename)
with open(filename, 'a') as f:for i, item_id in tqdm(enumerate(all_item_model_input['item_id'])):pred = [item_profile['item_id'].values[x] for x in I[i] ]pred2 = [x for x in pred if x != item_id]item_list = ",".join('%s' % x for x in pred2)# i2i倒排索引i2i_dict[item_id] = [x for x in pred2]f.write("{} {}\n".format(item_id, item_list))# 不改变顺序去重
def dupe(items):seen = set()for item in items:if item not in seen:yield itemseen.add(item)# 8、根据用户最近操作的50个物品检索物品列表
data.sort_values("timestamp", inplace=True, ascending=False)
filename = 'user_action_rec_list.txt'
if os.path.exists(filename):os.remove(filename)filename2 = 'user_action_rec_list2.txt'
if os.path.exists(filename2):os.remove(filename2)filename3 = 'user_action_list.txt'
if os.path.exists(filename3):os.remove(filename3)with open(filename, 'a') as f, open(filename2,'a') as f2, open(filename3,'a') as f3:for uid, hist in tqdm(data.groupby('user_id')):pred = []result = []# 截取最近50个物品pos_list = hist['item_id'].tolist()[1:51]act_list = ",".join('%s' % x for x in pos_list)# 根据这50个物品检索物品列表for item_id in pos_list:pred = pred + i2i_dict[item_id]result.append(str(item_id) + ":[" + ",".join('%s' % x for x in i2i_dict[item_id]) + "]")pred = list(dupe(pred))pred = pred[:1000]item_list = ",".join('%s' % x for x in result)item_list2 = ",".join('%s' % x for x in pred)filter_item = Nonerecall_score_1000 = recall_N(test_true_label[uid], pred, N=1000)recall_score_500 = recall_N(test_true_label[uid], pred, N=500)recall_score_100 = recall_N(test_true_label[uid], pred, N=100)recall_score_50 = recall_N(test_true_label[uid], pred, N=50)recall_score_10 = recall_N(test_true_label[uid], pred, N=10)s1000.append(recall_score_1000)s500.append(recall_score_500)s100.append(recall_score_100)s50.append(recall_score_50)s10.append(recall_score_10)# if test_true_label[uid] in pred:# hit += 1f.write("{} {}\n".format(uid, item_list))f2.write("{} {}\n".format(uid, item_list2))f3.write("{} {}\n".format(uid, act_list))print("recall1000", np.mean(s1000))
print("recall500", np.mean(s500))
print("recall100", np.mean(s100))
print("recall50", np.mean(s50))
print("recall10", np.mean(s10))
五、相关思考
1)上下文特征是放到用户塔还是物品塔?
2)新物品如何计算其Embeding?
3) 在此基础上如何进一步优化其效果?有哪些思路?
4)线上如何部署?
DSSM原理解读与工程实践相关推荐
- SDM原理解读与工程实践
SDM原理解读与工程实践 本文主要介绍的是阿里在召回阶段使用的深度召回模型SDM,paper名称为<SDM: Sequential Deep Matching Model for Online ...
- 副本放置策略Copysets论文解读及工程实践
副本放置策略Copysets论文解读及工程实践 概述 CopySet论文解读 术语定义 Random Replication Copyset Replication Premutation Repli ...
- 深入理解 ProtoBuf 原理与工程实践(概述)
ProtoBuf 作为一种跨平台.语言无关.可扩展的序列化结构数据的方法,已广泛应用于网络数据交换及存储.随着互联网的发展,系统的异构性会愈发突出,跨语言的需求会愈加明显,同时 gRPC 也大有取代R ...
- 分级加权评分算法 java_荐书|智能风控:原理、算法与工程实践
图书简介 风控领域是新兴的机器学习应用场景之一,其特点包括了负样本占比极少.业务对模型解释性要求偏高.业务模型多样.风控数据源丰富等. <智能风控:原理.算法与工程实践>一书共 8 章,包 ...
- 人脸识别技术原理与工程实践
1人脸识别应用场景(验证) 我们先来看看人脸识别的几个应用.第一个是苹果的FACE ID,自从苹果推出FaceID后,业界对人脸识别的应用好像信心大增,各种人脸识别的应用从此开始"野蛮生长& ...
- 人脸识别技术原理与工程实践(10个月人脸识别领域实战总结)
1人脸识别应用场景(验证) 我们先来看看人脸识别的几个应用.第一个是苹果的FACE ID,自从苹果推出FaceID后,业界对人脸识别的应用好像信心大增,各种人脸识别的应用从此开始"野蛮生长& ...
- 深入理解 ProtoBuf 原理与工程实践
ProtoBuf 作为一种跨平台.语言无关.可扩展的序列化结构数据的方法,已广泛应用于网络数据交换及存储.随着互联网的发展,系统的异构性会愈发突出,跨语言的需求会愈加明显,同时 gRPC 也大有取代R ...
- 关于概率分布理论的原理分析的一些讨论,以及经典概率分布的应用场景,以及概率统计其在工程实践中的应用...
1. 随机变量定义 0x1:为什么要引入随机变量这个数学概念 在早期的古典概率理论研究中,人们基于随机试验的样本空间去研究随机事件,也发展出了非常多辉煌的理论,包括著名的贝叶斯估计在内. 但是随着研究 ...
- 从入门到深入:移动平台模型裁剪与优化的技术探索与工程实践
可以看到,通过机器学习技术,软件或服务的功能和体验得到了质的提升.比如,我们甚至可以通过启发式引擎智能地预测并调节云计算分布式系统的节点压力,以此改善服务的弹性和稳定性,这是多么美妙. 而对移动平台来 ...
最新文章
- 九零后的五年七次工作经历
- 如何高效的做机器学习项目
- HTML5手机端几秒钟自动跳转
- js实现语音播报功能
- 快做这 15点,让 SpringBoot 启动更快一点!
- oracle创建包 和调用,oracle创建函数和调用存储过程和调用函数的例子(区别)...
- 64位WINDOWS 使用PL SQL DEVELOPER 连接ORACLE 出错问题解决
- 使用CLONE TABLE方式实现同region不同可用区的MaxCompute
- 面趣 | 据说这道烧脑的微软面试题很奇葩,你来试试?
- 【LeetCode】剑指 Offer 20. 表示数值的字符串
- 【开源.NET】 分享一个前后端分离的轻量级内容管理框架
- 生科实验室仪器维护保养,一篇就够了!
- 万网域名转向指定URL地址
- 用英文给领导写建议信
- Java instead of 用法_实例讲解instance of 运算符用法
- 服装erp软件如何提高企业利润
- 廊坊金彩教育:如何进行选品
- 基于OAI协议元数据收割的.NET资源
- 【第三章 有限自动机与右线性文法】形式语言与自动机第三章个人总结复习笔记分享!(含文件、持续更新...)
- 斩获数亿元B轮融资,这家Tier 1抢跑「L2/L2+」主战场
热门文章
- CentOS7安装MySQL8.0
- BP神经网络中的BP是指,bp神经网络属于什么
- Ubuntu安装Nginx和正确卸载Nginx Nginx相关
- Android 自定义ImageView实现圆角图片
- Apache中 RewriteRule说明
- 斗鱼上市,腾讯坐“快”观“虎斗”
- VUE element-ui 之table表格表头插入输入框
- 数据代码分享|PYTHON用NLP自然语言处理LSTM神经网络TWITTER推特灾难文本数据、词云可视化...
- IDC发布《企业数据智能实施部署指南》,巨杉数据库获评数字化平台代表供应商
- 信息系统项目管理论文范文(二)