(慢慢补充中…)
来源:王喆老师的github https://github.com/wzhe06/SparrowRecSys

Spark处理部分

Embedding

本部分代码的功能有:

  1. 制作用于模型训练的训练数据
  2. 支持item2vec和graph2vec两种对模型进行训练的方式
package com.sparrowrecsys.offline.spark.embedding
//相关库的导入
import java.io.{BufferedWriter, File, FileWriter}import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.ml.feature.BucketedRandomProjectionLSH
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.mllib.feature.{Word2Vec, Word2VecModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Row, SparkSession}
import redis.clients.jedis.Jedis
import redis.clients.jedis.params.SetParamsimport scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.util.Random//图embedding使用转移概率矩阵时会用到
import scala.util.control.Breaks.{break, breakable}
object Embedding {val redisEndpoint = "localhost"val redisPort = 6379//函数目的:将原始数据集转换为可以用于item2vec训练的数据集形式def processItemSequence(sparkSession: SparkSession, rawSampleDataPath: String): RDD[Seq[String]] ={//总体逻辑:1、读取 ratings 原始数据到 Spark 平台。2、用 where 语句过滤评分低的评分记录。//3、用 groupBy userId 操作聚合每个用户的评分记录,DataFrame 中每条记录是一个用户的评分序列。//4、定义一个自定义操作 sortUdf,用它实现每个用户的评分记录按照时间戳进行排序。//5、把每个用户的评分记录处理成一个字符串的形式,供后续训练过程使用val ratingsResourcesPath = this.getClass.getResource(rawSampleDataPath)val ratingSamples = sparkSession.read.format("csv").option("header", "true").load(ratingsResourcesPath.getPath)//设定rating数据的路径并用spark载入数据//rawSampleDataPath传入的是相对路径,用this.getClass.getResource来获取其绝对路径,这样的方式更易于扩展。//注意getResource返回的是file对象,获取路径还得用getPath来获取。文件的IO参考:https://www.liaoxuefeng.com/wiki/1252599548343744/1298069154955297//文件读取的学习可参考:https://blog.csdn.net/lcj8/article/details/3502849//读取文件既可以采用spark.read.csv(path)也可以用spark.read.format('csv').load(path)。spark2.0才开始源码支持CSV,所以网上会查到好多读取方法与这里不一样的,大都是早期的读取方法。//注意:这里读取得到的数据格式为sql.dataframe,不是RDD形式//这里千千万万要注意,文件路径必须是英文,并且没有空格,否则得到的path就是乱码的!!//实现一个用户定义的操作函数(UDF),用于之后的排序。//输入是每个用户的观影序列[[电影id1,时间戳],[电影id2,时间戳]...]//输出是按照时间戳进行了排序的观影序列[电影id1,电影id2...]val sortUdf: UserDefinedFunction = udf((rows: Seq[Row]) => {rows.map { case Row(movieId: String, timestamp: String) => (movieId, timestamp) }.sortBy { case (_, timestamp) => timestamp }.map { case (movieId, _) => movieId }})//case偏函数用法:https://blog.csdn.net/bluishglc/article/details/50995939//疑问:为什么rows: Seq[Row]和Row(movieId: String, timestamp: String)都必须//声明类型,不声明就会报错。ratingSamples.printSchema() //打印数据结构val userSeq = ratingSamples//把原始的rating数据处理成序列数据.where(col("rating") >= 3.5)//筛除小于3.5分的电影.groupBy("userId")//按userId进行分组.agg(sortUdf(collect_list(struct("movieId", "timestamp"))) as "movieIds")//每组数据中都按照时间戳排序,并只保留moiveId,然后将每组userId对应的moiveId聚合一个列表,取名为movieIds。其中collect_list表示把每个userId对应的movieId聚合到一个列表中,是一种聚合函数,更多的可以参考:https://www.cnblogs.com/shun7man/p/13195599.html.withColumn("movieIdStr", array_join(col("movieIds"), " "))//groupby分组、agg聚合、withColumn(名称,内容)新增列//array_join是spark2.4新增内容,参考https://www.iteblog.com/archives/2459.html,将指定列的内容拼接在一起。//array_join(col(“列名”),“分隔符”)userSeq.select("userId", "movieIdStr").show(10, truncate = false)userSeq.select("movieIdStr").rdd.map(r => r.getAs[String]("movieIdStr").split(" ").toSeq)//将拼接在一起的再拆分转换为Sequence//疑惑:为什么不能直接用之前的agg结束后的列表形式的movieIds,而是要转成字符串再转成序列。回答:因为spark word2vec的输入要求是Seq[String]格式的。新疑问:不能直接将array格式转换为Seq[String]格式吗,感觉从list转为字符串然后再拆分有点绕。}def generateUserEmb(sparkSession: SparkSession, rawSampleDataPath: String, word2VecModel: Word2VecModel, embLength:Int, embOutputFilename:String, saveToRedis:Boolean, redisKeyPrefix:String): Unit ={val ratingsResourcesPath = this.getClass.getResource(rawSampleDataPath)val ratingSamples = sparkSession.read.format("csv").option("header", "true").load(ratingsResourcesPath.getPath)ratingSamples.show(10, false)val userEmbeddings = new ArrayBuffer[(String, Array[Float])]()ratingSamples.collect().groupBy(_.getAs[String]("userId")).foreach(user => {val userId = user._1var userEmb = new Array[Float](embLength)userEmb = user._2.foldRight[Array[Float]](userEmb)((row, newEmb) => {val movieId = row.getAs[String]("movieId")val movieEmb = word2VecModel.getVectors.get(movieId)if(movieEmb.isDefined){newEmb.zip(movieEmb.get).map { case (x, y) => x + y }}else{newEmb}})userEmbeddings.append((userId,userEmb))})val embFolderPath = this.getClass.getResource("/webroot/modeldata/")val file = new File(embFolderPath.getPath + embOutputFilename)val bw = new BufferedWriter(new FileWriter(file))for (userEmb <- userEmbeddings) {bw.write(userEmb._1 + ":" + userEmb._2.mkString(" ") + "\n")}bw.close()if (saveToRedis) {val redisClient = new Jedis(redisEndpoint, redisPort)val params = SetParams.setParams()//set ttl to 24hsparams.ex(60 * 60 * 24)for (userEmb <- userEmbeddings) {redisClient.set(redisKeyPrefix + ":" + userEmb._1, userEmb._2.mkString(" "), params)}redisClient.close()}}def trainItem2vec(sparkSession: SparkSession, samples : RDD[Seq[String]], embLength:Int, embOutputFilename:String, saveToRedis:Boolean, redisKeyPrefix:String): Word2VecModel = {val word2vec = new Word2Vec().setVectorSize(embLength).setWindowSize(5).setNumIterations(10)val model = word2vec.fit(samples)val synonyms = model.findSynonyms("158", 20)for ((synonym, cosineSimilarity) <- synonyms) {println(s"$synonym $cosineSimilarity")}//BufferedWriter:https://blog.csdn.net/evilcry2012/article/details/83580907//https://blog.csdn.net/u010076574/article/details/101353532//https://c.lanmit.com/redianjishu/yunjisuan/23737.htmlval embFolderPath = this.getClass.getResource("/webroot/modeldata/")val file = new File(embFolderPath.getPath + embOutputFilename)val bw = new BufferedWriter(new FileWriter(file))for (movieId <- model.getVectors.keys) {bw.write(movieId + ":" + model.getVectors(movieId).mkString(" ") + "\n")}bw.close()if (saveToRedis) {val redisClient = new Jedis(redisEndpoint, redisPort)val params = SetParams.setParams()//set ttl to 24hsparams.ex(60 * 60 * 24)for (movieId <- model.getVectors.keys) {redisClient.set(redisKeyPrefix + ":" + movieId, model.getVectors(movieId).mkString(" "), params)}redisClient.close()}embeddingLSH(sparkSession, model.getVectors)model}def oneRandomWalk(transitionMatrix : mutable.Map[String, mutable.Map[String, Double]], itemDistribution : mutable.Map[String, Double], sampleLength:Int): Seq[String] ={val sample = mutable.ListBuffer[String]()//根据物品出现的概率,随机决定起始点val randomDouble = Random.nextDouble()//0到1之间var firstItem = ""var accumulateProb:Double = 0Dbreakable { for ((item, prob) <- itemDistribution) {//没有太明白为什么要用这种形式accumulateProb += probif (accumulateProb >= randomDouble){firstItem = itembreak}}}sample.append(firstItem)var curElement = firstItembreakable { for(_ <- 1 until sampleLength) {if (!itemDistribution.contains(curElement) || !transitionMatrix.contains(curElement)){break}val probDistribution = transitionMatrix(curElement)val randomDouble = Random.nextDouble()breakable { for ((item, prob) <- probDistribution) {if (randomDouble >= prob){curElement = itembreak}}}sample.append(curElement)}}Seq(sample.toList : _*)}def randomWalk(transitionMatrix : mutable.Map[String, mutable.Map[String, Double]], itemDistribution : mutable.Map[String, Double], sampleCount:Int, sampleLength:Int): Seq[Seq[String]] ={val samples = mutable.ListBuffer[Seq[String]]()for(_ <- 1 to sampleCount) {samples.append(oneRandomWalk(transitionMatrix, itemDistribution, sampleLength))}Seq(samples.toList : _*)}//generateTransitionMatrix用于产生转移矩阵,为图emmbedding做准备def generateTransitionMatrix(samples : RDD[Seq[String]]): (mutable.Map[String,   mutable.Map[String, Double]], mutable.Map[String, Double]) ={     //mutable.Map 可变的Map//输入多个sample,每个sample的形式都为(23,52,43,34),//将每个sample拆分为pairSeq{(23,52),(52,43),(43,34)}。然后flatmap将所有的sample平铺成一维。val pairSamples = samples.flatMap[(String, String)]( sample => {     //这里的flatMap是先对集合中每个元素进行操作然后再扁平化,而不是直接扁平化。var pairSeq = Seq[(String,String)]()var previousItem:String = nullsample.foreach((element:String) => {     //逻辑很好理解,关键是得习惯foreach, =>, :+这些自己不熟悉的操作if(previousItem != null){              //if(previousItem != null)其实就代表着从第二个元素开始循环pairSeq = pairSeq :+ (previousItem, element)      } //:+表示在序列的屁股后面进行拼接,在课程代码里,王老师用了“:”来连接影片,在这里取消了。previousItem = element})pairSeq})//countByValue()返回一个map,key就是电影对,value就是个数val pairCountMap = pairSamples.countByValue()var pairTotalCount = 0Lval transitionCountMatrix = mutable.Map[String, mutable.Map[String, Long]]()val itemCountMap = mutable.Map[String, Long]()pairCountMap.foreach( pair => {val pairItems = pair._1val count = pair._2if(!transitionCountMatrix.contains(pairItems._1)){transitionCountMatrix(pairItems._1) = mutable.Map[String, Long]()}//两层map的初始化,学习这里contains的用法transitionCountMatrix(pairItems._1)(pairItems._2) = countitemCountMap(pairItems._1) = itemCountMap.getOrElse[Long](pairItems._1, 0) + count//记录所有由item1跳转到其它item的次数的和。用做之后概率计算的分母。//getOrElse(key,default)主要就是防范措施,如果有值,那就可以得到这个值,如果没有就会得到一个默认值//但是注意这里的item.get得到的不是key,而是key对应的值pairTotalCount = pairTotalCount + count})val transitionMatrix = mutable.Map[String, mutable.Map[String, Double]]()//转移概率矩阵val itemDistribution = mutable.Map[String, Double]()//各item出现的概率//transitionCountMatrix形式:{(776:(5->1,595->1)),(590:(485->1,204->1))}transitionCountMatrix foreach {case (itemAId, transitionMap) =>transitionMatrix(itemAId) = mutable.Map[String, Double]()transitionMap foreach { case (itemBId, transitionCount) => transitionMatrix(itemAId)(itemBId) = transitionCount.toDouble / itemCountMap(itemAId) }}itemCountMap foreach { case (itemId, itemCount) => itemDistribution(itemId) = itemCount.toDouble / pairTotalCount }(transitionMatrix, itemDistribution)}//LSH:局部敏感哈希算法,是一种针对海量高维数据的快速最近邻查找算法def embeddingLSH(spark:SparkSession, movieEmbMap:Map[String, Array[Float]]): Unit ={val movieEmbSeq = movieEmbMap.toSeq.map(item => (item._1, Vectors.dense(item._2.map(f => f.toDouble))))val movieEmbDF = spark.createDataFrame(movieEmbSeq).toDF("movieId", "emb")//LSH bucket modelval bucketProjectionLSH = new BucketedRandomProjectionLSH().setBucketLength(0.1).setNumHashTables(3).setInputCol("emb").setOutputCol("bucketId")val bucketModel = bucketProjectionLSH.fit(movieEmbDF)val embBucketResult = bucketModel.transform(movieEmbDF)println("movieId, emb, bucketId schema:")embBucketResult.printSchema()println("movieId, emb, bucketId data result:")embBucketResult.show(10, truncate = false)println("Approximately searching for 5 nearest neighbors of the sample embedding:")val sampleEmb = Vectors.dense(0.795,0.583,1.120,0.850,0.174,-0.839,-0.0633,0.249,0.673,-0.237)bucketModel.approxNearestNeighbors(movieEmbDF, sampleEmb, 5).show(truncate = false)}def graphEmb(samples : RDD[Seq[String]], sparkSession: SparkSession, embLength:Int, embOutputFilename:String, saveToRedis:Boolean, redisKeyPrefix:String): Word2VecModel ={val transitionMatrixAndItemDis = generateTransitionMatrix(samples)println(transitionMatrixAndItemDis._1.size)println(transitionMatrixAndItemDis._2.size)val sampleCount = 20000val sampleLength = 10val newSamples = randomWalk(transitionMatrixAndItemDis._1, transitionMatrixAndItemDis._2, sampleCount, sampleLength)val rddSamples = sparkSession.sparkContext.parallelize(newSamples)trainItem2vec(sparkSession, rddSamples, embLength, embOutputFilename, saveToRedis, redisKeyPrefix)}def main(args: Array[String]): Unit = {Logger.getLogger("org").setLevel(Level.ERROR)val conf = new SparkConf().setMaster("local").setAppName("ctrModel").set("spark.submit.deployMode", "client")val spark = SparkSession.builder.config(conf).getOrCreate()val rawSampleDataPath = "/webroot/sampledata/ratings.csv"val embLength = 10val samples = processItemSequence(spark, rawSampleDataPath)val model = trainItem2vec(spark, samples, embLength, "item2vecEmb.csv", saveToRedis = false, "i2vEmb")//graphEmb(samples, spark, embLength, "itemGraphEmb.csv", saveToRedis = true, "graphEmb")//generateUserEmb(spark, rawSampleDataPath, model, embLength, "userEmb.csv", saveToRedis = false, "uEmb")}
}

电影推荐系统Sparrow Recsys源码解读相关推荐

  1. 电影推荐系统Sparrow Recsys源码解读——FeatureEngForRecModel部分

    小广告 (欢迎大家关注我的公众号"机器学习面试基地",之后将在公众号上持续记录本人从非科班转到算法路上的学习心得.笔经面经.心得体会.未来的重点也会主要放在机器学习面试上!) -- ...

  2. Sparrow RecSys 源码阅读

    https://github.com/wzhe06/SparrowRecSys 文章目录 根据接口进行调试 RecommendationService MovieService SimilarMovi ...

  3. java毕业设计电影推荐网站mybatis+源码+调试部署+系统+数据库+lw

    java毕业设计电影推荐网站mybatis+源码+调试部署+系统+数据库+lw java毕业设计电影推荐网站mybatis+源码+调试部署+系统+数据库+lw 本源码技术栈: 项目架构:B/S架构 开 ...

  4. java毕业生设计星星电影购票网站计算机源码+系统+mysql+调试部署+lw

    java毕业生设计星星电影购票网站计算机源码+系统+mysql+调试部署+lw java毕业生设计星星电影购票网站计算机源码+系统+mysql+调试部署+lw 本源码技术栈: 项目架构:B/S架构 开 ...

  5. 基于java图书个性化推荐系统计算机毕业设计源码+系统+lw文档+mysql数据库+调试部署

    基于java图书个性化推荐系统计算机毕业设计源码+系统+lw文档+mysql数据库+调试部署 基于java图书个性化推荐系统计算机毕业设计源码+系统+lw文档+mysql数据库+调试部署 本源码技术栈 ...

  6. RTC 月度小报 5 月 | WebRTC M83、SOLO 源码解读、实时美声……

    本月亮点速览 RTC开发者社区: CSDN专访RTC编程大赛获奖者 如何高效实现PSTN/SIP接入实时音视频网络 在线教育的创新模式及AI应用实践 开源与技术科普: WebRTC M83 Relea ...

  7. 基于JAVA健康饮食推荐系统计算机毕业设计源码+数据库+lw文档+系统+部署

    基于JAVA健康饮食推荐系统计算机毕业设计源码+数据库+lw文档+系统+部署 基于JAVA健康饮食推荐系统计算机毕业设计源码+数据库+lw文档+系统+部署 本源码技术栈: 项目架构:B/S架构 开发语 ...

  8. 基于JAVA大学生专业分配推荐系统计算机毕业设计源码+系统+lw文档+部署

    基于JAVA大学生专业分配推荐系统计算机毕业设计源码+系统+lw文档+部署 基于JAVA大学生专业分配推荐系统计算机毕业设计源码+系统+lw文档+部署 本源码技术栈: 项目架构:B/S架构 开发语言: ...

  9. 基于JAVA用户行为自动化书籍推荐系统计算机毕业设计源码+数据库+lw文档+系统+部署

    基于JAVA用户行为自动化书籍推荐系统计算机毕业设计源码+数据库+lw文档+系统+部署 基于JAVA用户行为自动化书籍推荐系统计算机毕业设计源码+数据库+lw文档+系统+部署 本源码技术栈: 项目架构 ...

最新文章

  1. mysql keepalived_mysql高可用+keepalived
  2. 两种富文本编辑器-ckeditor和ueditor
  3. Rust编程语言的核心部件
  4. Git之pull后回退版本
  5. n.html id=198,YPE htmlhtml lang=enhead data-n-head-ssrtitle data-n-=true小程序获取不到unionid 微信开放社区...
  6. python打包出现乱码_python解压zip包中文乱码解决方法
  7. python concurrent queue_Python的并发并行[2] - 队列[0] - queue 模块
  8. php清除h5格式,移动端H5页面端怎样除去input输入框的默认样式
  9. eclipse中配置server
  10. 放生大海的鱼,为什么要在鱼肚子上捅一个洞?
  11. 免费「模拟面试」福利反馈连载(20180128期)
  12. 【笔记】scp如何复制文件到带空格路径的server目录
  13. 智能机械按摩椅的改进设计
  14. QT D:\搜狗输入法\SogouInput\Components\ 13:53:42: 程序异常结束。 13:53:42: T
  15. 安装历史版本nvidia显卡驱动
  16. nginx错误502,503,504分析
  17. 京东VS淘宝:待付款订单-再次支付方案对比
  18. 苹果手机不支持网易云音乐服务器,为什么我的苹果手机总是打不开网易云音乐。...
  19. Qt MDI Window开发
  20. matlab均值量化函数_Matlab量化函数quantiz解析

热门文章

  1. 少样本 N-way K-shot
  2. 达林顿驱动器ULN2003,ULN2803使用注意要点
  3. 线段树 P3797 妖梦斩木棒
  4. ABAQUS中的碳纤维增强复合材料失效演化Hashin准则及参数详解
  5. 书虫小说在线阅读网站
  6. NSIS 中的$DOCUMENTS
  7. 基于http请求web打印组件,实现浏览器、移动端、服务端无预览打印
  8. IT项目管理小题计算总结【太原理工大学】
  9. html最全知识点(超级详细)
  10. Windows11系统中文乱码,软件中文路径打不开