文章目录

  • Item2Vec
    • 序列数据的处理
    • 模型训练
  • 随机游走的Graph Embedding算法
    • 数据准备
    • 随机游走采样过程

Item2Vec

序列数据的处理

  • Item2Vec是要处理的是类似文本句子、观影序列之类的序列数据。而在Item2Vec训练之前,还需要先为它准备好训练用的序列数据。在 MovieLens数据集中,有一张叫rating的数据表,里面包含了用户对看过电影的评分和评分的时间。观影序列自然就可以通过处理rating得到了。rating.csv文件中也包含userId、movieId、rating和timestamp。

  • 在使用观影序列编码之前,还有两个问题:

    • 一是MovieLens这个rating表只是一个评分表,不是真正的观影序列。对于用户来说,只有看过电影才能够评价它,所以,我们可以把评分序列当作是观影序列。
    • 二是我们是应该把所有电影都放到序列中,还是只放那些打分比较高的。
  • 这里建议对评分进行过滤,保留评分高的数据。因为我们希望Item2Vec能够学习到物品之间的近似性。当然是希望评分好的电影靠近一些,评分差的电影和评分好的电影不要在序列中结对出现。

  • 所以样本处理的思路就是:对于一个用户先过滤掉他评分低的电影,再把他评论过的电影按照时间戳排序,得到了一个用户的观影序列,所有用户的观影序列就组成了Item2Vec的训练样本集。

  • 处理步骤

    1. 读取ratings原始数据到Spark平台。
    2. 用where语句过滤评分低的评分记录。
    3. 用groupBy userId操作聚合每个用户的评分记录,DataFrame中每条记录是一个用户的评分序列。
    4. 定义一个自定义操作sortUdf,用它实现每个用户的评分记录按照时间戳进行排序。
    5. 把每个用户的评分记录处理成一个字符串的形式,供后续训练过程使用。
  • 代码展示

def processItemSequence(sparkSession: SparkSession): RDD[Seq[String]] ={//设定rating数据的路径并用spark载入数据val ratingsResourcesPath = this.getClass.getResource("/webroot/sampledata/ratings.csv")val ratingSamples = sparkSession.read.format("csv").option("header", "true").load(ratingsResourcesPath.getPath)//实现一个用户定义的操作函数(UDF),用于之后的排序val sortUdf: UserDefinedFunction = udf((rows: Seq[Row]) => {rows.map { case Row(movieId: String, timestamp: String) => (movieId, timestamp) }.sortBy { case (movieId, timestamp) => timestamp }.map { case (movieId, timestamp) => movieId }})//把原始的rating数据处理成序列数据val userSeq = ratingSamples.where(col("rating") >= 3.5)  //过滤掉评分在3.5一下的评分记录.groupBy("userId")            //按照用户id分组.agg(sortUdf(collect_list(struct("movieId", "timestamp"))) as "movieIds")     //每个用户生成一个序列并用刚才定义好的udf函数按照timestamp排序.withColumn("movieIdStr", array_join(col("movieIds"), " "))//把所有id连接成一个String,方便后续word2vec模型处理//把序列数据筛选出来,丢掉其他过程数据userSeq.select("movieIdStr").rdd.map(r => r.getAs[String]("movieIdStr").split(" ").toSeq)
  • 通过这段代码生成用户的评分序列样本中,每条样本的形式非常简单,就是电影ID组成的序列。

模型训练

  • 这里我们可以使用Spark MLlib机器学习工具包中调用的Word2Vec模型接口,来进行有效地训练。
  • 关键步骤:
    1. 第一步:创建Word2Vec模型并设定模型参数。关键参数有3个,分别是setVectorSize(设定生成的 Embedding 向量的维度)、setWindowSize(设定在序列数据上采样的滑动窗口大小)和setNumIterations(设定训练时的迭代次数)。这些超参数的具体选择就要根据实际的训练效果调整。
    2. 第二步:用模型的fit接口进行训练,完成之后,模型会返回一个包含了所有模型参数的对象。
    3. 最后一步:就是提取和保存Embedding向量,调用getVectors接口就可以提取出某个电影ID对应的Embedding向量,之后就可以把它们保存到文件或者其他数据库中,供其他模块使用了。
  • 具体代码:
def trainItem2vec(samples : RDD[Seq[String]]): Unit ={//设置模型参数val word2vec = new Word2Vec().setVectorSize(10).setWindowSize(5).setNumIterations(10)//训练模型val model = word2vec.fit(samples)//训练结束,用模型查找与item"592"最相似的20个itemval synonyms = model.findSynonyms("592", 20)for((synonym, cosineSimilarity) <- synonyms) {println(s"$synonym $cosineSimilarity")}//保存模型val embFolderPath = this.getClass.getResource("/webroot/sampledata/")val file = new File(embFolderPath.getPath + "embedding.txt")val bw = new BufferedWriter(new FileWriter(file))var id = 0//用model.getVectors获取所有Embedding向量for (movieId <- model.getVectors.keys){id+=1bw.write( movieId + ":" + model.getVectors(movieId).mkString(" ") + "\n")}bw.close()

随机游走的Graph Embedding算法

数据准备

  • Deep Walk 方法中,我们需要准备的最关键数据是物品之间的转移概率矩阵,实现代码如下:
//samples 输入的观影序列样本集
def graphEmb(samples : RDD[Seq[String]], sparkSession: SparkSession): Unit ={//通过flatMap操作把观影序列打碎成一个个影片对val pairSamples = samples.flatMap[String]( sample => {var pairSeq = Seq[String]()var previousItem:String = nullsample.foreach((element:String) => {if(previousItem != null){pairSeq = pairSeq :+ (previousItem + ":" + element)}previousItem = element})pairSeq})//统计影片对的数量val pairCount = pairSamples.countByValue()//转移概率矩阵的双层Map数据结构val transferMatrix = scala.collection.mutable.Map[String, scala.collection.mutable.Map[String, Long]]()val itemCount = scala.collection.mutable.Map[String, Long]()//求取转移概率矩阵pairCount.foreach( pair => {val pairItems = pair._1.split(":")val count = pair._2lognumber = lognumber + 1println(lognumber, pair._1)if (pairItems.length == 2){val item1 = pairItems.apply(0)val item2 = pairItems.apply(1)if(!transferMatrix.contains(pairItems.apply(0))){transferMatrix(item1) = scala.collection.mutable.Map[String, Long]()}transferMatrix(item1)(item2) = countitemCount(item1) = itemCount.getOrElse[Long](item1, 0) + count}
  • 生成转移概率矩阵的函数输入是在训练Item2Vec时处理好的观影序列数据。输出的是转移概率矩阵,由于转移概率矩阵比较稀疏,因此我没有采用比较浪费内存的二维数组的方法,而是采用了一个双层Map的结构去实现它。比如说,我们要得到物品A到物品B的转移概率,那么 transferMatrix(itemA)(itemB) 就是这一转移概率。在求取转移概率矩阵的过程中,先用flatMap操作把观影序列“打碎”成一个个影片对,再利用 countByValue操作统计这些影片对的数量,最后根据这些影片对的数量求取每两个影片之间的转移概率。在获得了物品之间的转移概率矩阵之后,就可以进行随机游走采样了。

随机游走采样过程

  • 随机游走采样的过程是利用转移概率矩阵生成新的序列样本的过程。首先,我们要根据物品出现次数的分布随机选择一个起始物品,之后就进入随机游走的过程。在每次游走时,我们根据转移概率矩阵查找到两个物品之间的转移概率,然后根据这个概率进行跳转。
  • 举个例子,当前的物品是A,从转移概率矩阵中查找到 A 可能跳转到物品B或物品C,转移概率分别是0.4和0.6,那么我们就按照这个概率来随机游走到B或C,依次进行下去,直到样本的长度达到了我们的要求。
//随机游走采样函数
//transferMatrix 转移概率矩阵
//itemCount 物品出现次数的分布
def randomWalk(transferMatrix : scala.collection.mutable.Map[String, scala.collection.mutable.Map[String, Long]], itemCount : scala.collection.mutable.Map[String, Long]): Seq[Seq[String]] ={//样本的数量val sampleCount = 20000//每个样本的长度val sampleLength = 10val samples = scala.collection.mutable.ListBuffer[Seq[String]]()//物品出现的总次数var itemTotalCount:Long = 0for ((k,v) <- itemCount) itemTotalCount += v//随机游走sampleCount次,生成sampleCount个序列样本for( w <- 1 to sampleCount) {samples.append(oneRandomWalk(transferMatrix, itemCount, itemTotalCount, sampleLength))}Seq(samples.toList : _*)
}//通过随机游走产生一个样本的过程
//transferMatrix 转移概率矩阵
//itemCount 物品出现次数的分布
//itemTotalCount 物品出现总次数
//sampleLength 每个样本的长度
def oneRandomWalk(transferMatrix : scala.collection.mutable.Map[String, scala.collection.mutable.Map[String, Long]], itemCount : scala.collection.mutable.Map[String, Long], itemTotalCount:Long, sampleLength:Int): Seq[String] ={val sample = scala.collection.mutable.ListBuffer[String]()//决定起始点val randomDouble = Random.nextDouble()var firstElement = ""var culCount:Long = 0//根据物品出现的概率,随机决定起始点breakable { for ((item, count) <- itemCount) {culCount += countif (culCount >= randomDouble * itemTotalCount){firstElement = itembreak}}}sample.append(firstElement)var curElement = firstElement//通过随机游走产生长度为sampleLength的样本breakable { for( w <- 1 until sampleLength) {if (!itemCount.contains(curElement) || !transferMatrix.contains(curElement)){break}//从curElement到下一个跳的转移概率向量val probDistribution = transferMatrix(curElement)val curCount = itemCount(curElement)val randomDouble = Random.nextDouble()var culCount:Long = 0//根据转移概率向量随机决定下一跳的物品breakable { for ((item, count) <- probDistribution) {culCount += countif (culCount >= randomDouble * curCount){curElement = itembreak}}}sample.append(curElement)}}Seq(sample.toList : _

推荐系统如何用spark训练得到Embedding向量相关推荐

  1. 推荐系统如何用TensorFlow实现经典的深度学习模型(Embedding+MLP)

    文章目录 Embedding+MLP模型的结构 最经典的模型Deep Crossing Embedding+MLP模型的实战 特征选择和模型设计 基于TensorFlow的模型实现 Embedding ...

  2. 推荐系统局部敏感哈希解决Embedding最近邻搜索问题

    文章目录 快速Embedding最近邻搜索问题 聚类.索引搜索最近邻 聚类搜索最近邻 索引搜索最近邻 局部敏感哈希及多桶策略 局部敏感哈希的基本原理 局部敏感哈希的多桶策略 局部敏感哈希代码实现 快速 ...

  3. 深度学习之循环神经网络(12)预训练的词向量

    深度学习之循环神经网络(12)预训练的词向量  在情感分类任务时,Embedding层是从零开始训练的.实际上,对于文本处理任务来说,领域知识大部分是共享的,因此我们能够利用在其它任务上训练好的词向量 ...

  4. github设置中文_【Github】100+ Chinese Word Vectors 上百种预训练中文词向量

    (给机器学习算法与Python学习加星标,提升AI技能) 该项目提供了不同表征(密集和稀疏)上下文特征(单词,ngram,字符等)和语料库训练的中文单词向量.开发者可以轻松获得具有不同属性的预先训练的 ...

  5. 比赛必备 ︱ 省力搞定三款词向量训练 + OOV词向量问题的可性方案

    本篇为资源汇总,一些NLP的比赛在抽取文本特征的时候会使用非常多的方式. 传统的有:TFIDF/LDA/LSI等 偏深度的有:word2vec/glove/fasttext等 还有一些预训练方式:el ...

  6. 如何用Spark进行数据分析

    小编和大家分享一下Spark是什么?如何用Spark进行数据分析,对大数据感兴趣的小伙伴就随着小编一起来了解一下吧. 如何用Spark进行数据分析 什么是Apache Spark? Apache Sp ...

  7. 预训练词向量中文维基百科,英文斯坦福glove预训练的词向量下载

    中文预训练词向量--基于中文维基百科语料训练 英文预训练词向量--斯坦福glove预训练的词向量 百度云分享:https://pan.baidu.com/s/1UpZeuqlNMl6XtTB5la53 ...

  8. NLP之word2vec:利用 Wikipedia Text(中文维基百科)语料+Word2vec工具来训练简体中文词向量

    NLP之word2vec:利用 Wikipedia Text(中文维基百科)语料+Word2vec工具来训练简体中文词向量 目录 输出结果 设计思路 1.Wikipedia Text语料来源 2.维基 ...

  9. pytorch Embedding模块,自动为文本加载预训练的embedding

    pytorch 提供了一个简便方法torch.nn.Embedding.from_pretrained,可以将文本与预训练的embedding对应起来: 词 embedding word1 0,2,3 ...

最新文章

  1. 二、OCR训练时,将txt文件和图片数据转为lmdb文件格式
  2. 关于百度分享——bdCustomStyle一点bug
  3. 【NLP】 理解NLP中网红特征抽取器Tranformer
  4. Luogu4022 CTSC2012 熟悉的文章 广义SAM、二分答案、单调队列
  5. 错误 1 类型“System.Web.UI.ScriptManager”同时存在于“c:\windows\assembly\GAC_MSIL\System.Web.Extensions\3.5.0.0
  6. sqlserver 建表指定主键_3-自增字段;主键约束
  7. Android中Bitmap和Drawable 总结
  8. car-like robot运动机构简析
  9. 每日算法系列【LeetCode 926】将字符串翻转到单调递增
  10. c语言用循环转换单词首字母,用c++实现将文本每个单词首字母转换为大写
  11. centos6安装wget
  12. 如何通过命令提示符进入MySQL服务器
  13. faststone capture注册码
  14. tensorflow学习笔记(二十七):leaky relu
  15. 2.5D(伪3D)站点可视化第一弹
  16. 磁盘性能--IOPS和吞吐量
  17. 精妙绝伦的脑洞acm题
  18. 牛客网QR46 字符集合
  19. 字符串逆序输出----Python
  20. php在线视频网站,GitHub - unkaer/olvideos: 简易 PHP 在线视频网站,搜索并播放资源站视频。...

热门文章

  1. https页面打不开
  2. 我的服务器我做主!搞定远程访问控制,轻松掌控全场
  3. 让phpcms支持https
  4. cnBeta 08年度精彩评论
  5. Virtual pc2007下安装ubuntu手记
  6. 使用phantomjs截图
  7. Python在子类中调用父类方法
  8. 1 安培3.2V磷酸铁锂电池充电方案
  9. 来一个腾讯开平的工具类
  10. YottaChain王东临:优塔令存储创造价值