ALS矩阵分解

一个 的打分矩阵 A 可以用两个小矩阵和的乘积来近似,描述一个人的喜好经常是在一个抽象的低维空间上进行的,并不需要把其喜欢的事物一一列出。再抽象一些,把人们的喜好和电影的特征都投到这个低维空间,一个人的喜好映射到了一个低维向量,一个电影的特征变成了纬度相同的向量,那么这个人和这个电影的相似度就可以表述成这两个向量之间的内积。

我们把打分理解成相似度,那么“打分矩阵A(m*n)”就可以由“用户喜好特征矩阵U(m*k)”和“产品特征矩阵V(n*k)”的乘积。

矩阵分解过程中所用的优化方法分为两种:交叉最小二乘法(alternative least squares)和随机梯度下降法(stochastic gradient descent)。

损失函数包括正则化项(setRegParam)。

参数选取

分块数:分块是为了并行计算,默认为10。 正则化参数:默认为1。 秩:模型中隐藏因子的个数显示偏好信息-false,隐式偏好信息-true,默认false(显示) alpha:只用于隐式的偏好数据,偏好值可信度底线。 非负限定 numBlocks is the number of blocks the users and items will be

partitioned into in order to parallelize computation (defaults to

10). rank is the number of latent factors in the model (defaults to 10). maxIter is the maximum number of iterations to run (defaults to 10). regParam specifies the regularization parameter in ALS (defaults to 1.0). implicitPrefs specifies whether to use the explicit feedback ALS variant or one adapted for implicit feedback data (defaults to false

which means using explicit feedback). alpha is a parameter applicable to the implicit feedback variant of ALS that governs the baseline confidence in preference

observations (defaults to 1.0). nonnegative specifies whether or not to use nonnegative constraints for least squares (defaults to false).

ALS als = newALS()

.setMaxIter(10)//最大迭代次数,设置太大发生java.lang.StackOverflowError

.setRegParam(0.16)//正则化参数

.setAlpha(1.0)

.setImplicitPrefs(false)

.setNonnegative(false)

.setNumBlocks(10)

.setRank(10)

.setUserCol("userId")

.setItemCol("movieId")

.setRatingCol("rating");

需要注意的问题:

对于用户和物品项ID ,基于DataFrame API 只支持integers,因此最大值限定在integers范围内。

The DataFrame-based API for ALS currently only supports integers for

user and item ids. Other numeric types are supported for the user and

item id columns, but the ids must be within the integer value range.

//循环正则化参数,每次由Evaluator给出RMSError

List RMSE=new ArrayList();//构建一个List保存所有的RMSE

for(int i=0;i<20;i++){//进行20次循环

double lambda=(i*5+1)*0.01;//RegParam按照0.05增加

ALS als = newALS()

.setMaxIter(5)//最大迭代次数

.setRegParam(lambda)//正则化参数

.setUserCol("userId")

.setItemCol("movieId")

.setRatingCol("rating");

ALSModel model=als.fit(training);//Evaluate the model by computing the RMSE on the test data

Dataset predictions =model.transform(test);//RegressionEvaluator.setMetricName可以定义四种评估器//"rmse" (default): root mean squared error//"mse": mean squared error//"r2": R^2^ metric//"mae": mean absolute error

RegressionEvaluator evaluator = newRegressionEvaluator()

.setMetricName("rmse")//RMS Error

.setLabelCol("rating")

.setPredictionCol("prediction");

Double rmse=evaluator.evaluate(predictions);

RMSE.add(rmse);

System.out.println("RegParam "+0.01*i+" RMSE " + rmse+"\n");

}//输出所有结果

for (int j = 0; j < RMSE.size(); j++) {

Double lambda=(j*5+1)*0.01;

System.out.println("RegParam= "+lambda+" RMSE= " + RMSE.get(j)+"\n");

}

通过设计一个循环,可以研究最合适的参数,部分结果如下:

RegParam= 0.01 RMSE= 1.956

RegParam= 0.06 RMSE= 1.166

RegParam= 0.11 RMSE= 0.977

RegParam= 0.16 RMSE= 0.962//具备最小的RMSE,参数最合适

RegParam= 0.21 RMSE= 0.985

RegParam= 0.26 RMSE= 1.021

RegParam= 0.31 RMSE= 1.061

RegParam= 0.36 RMSE= 1.102

RegParam= 0.41 RMSE= 1.144

RegParam= 0.51 RMSE= 1.228

RegParam= 0.56 RMSE= 1.267

RegParam= 0.61 RMSE= 1.300

//将RegParam固定在0.16,继续研究迭代次数的影响

输出如下的结果,在单机环境中,迭代次数设置过大,会出现一个java.lang.StackOverflowError异常。是由于当前线程的栈满了引起的。

numMaxIteration= 1 RMSE= 1.7325

numMaxIteration= 4 RMSE= 1.0695

numMaxIteration= 7 RMSE= 1.0563

numMaxIteration= 10 RMSE= 1.055

numMaxIteration= 13 RMSE= 1.053

numMaxIteration= 16 RMSE= 1.053

//测试Rank隐含语义个数

Rank =1 RMSErr = 1.1584

Rank =3 RMSErr = 1.1067

Rank =5 RMSErr = 0.9366

Rank =7 RMSErr = 0.9745

Rank =9 RMSErr = 0.9440

Rank =11 RMSErr = 0.9458

Rank =13 RMSErr = 0.9466

Rank =15 RMSErr = 0.9443

Rank =17 RMSErr = 0.9543

//可以用SPARK-SQL自己定义评估算法(如下面定义了一个平均绝对值误差计算过程)//Register the DataFrame as a SQL temporary view

predictions.createOrReplaceTempView("tmp_predictions");

Dataset absDiff=spark.sql("select abs(prediction-rating) as diff from tmp_predictions");

absDiff.createOrReplaceTempView("tmp_absDiff");

spark.sql("select mean(diff) as absMeanDiff from tmp_absDiff").show();

完整代码

public class Rating implements Serializable{...}

可以在 http://spark.apache.org/docs/latest/ml-collaborative-filtering.html找到:

packagemy.spark.ml.practice.classification;importorg.apache.spark.api.java.function.Function;importorg.apache.spark.ml.evaluation.RegressionEvaluator;importorg.apache.spark.ml.recommendation.ALS;importorg.apache.spark.ml.recommendation.ALSModel;importorg.apache.log4j.Level;importorg.apache.log4j.Logger;importorg.apache.spark.api.java.JavaRDD;importorg.apache.spark.sql.Dataset;importorg.apache.spark.sql.Row;importorg.apache.spark.sql.SparkSession;public classmyCollabFilter2 {public static voidmain(String[] args) {

SparkSession spark=SparkSession

.builder()

.appName("CoFilter")

.master("local[4]")

.config("spark.sql.warehouse.dir","file///:G:/Projects/Java/Spark/spark-warehouse")

.getOrCreate();

String path="G:/Projects/CgyWin64/home/pengjy3/softwate/spark-2.0.0-bin-hadoop2.6/"

+ "data/mllib/als/sample_movielens_ratings.txt";//屏蔽日志

Logger.getLogger("org.apache.spark").setLevel(Level.WARN);

Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF);//-------------------------------1.0 准备DataFrame----------------------------//..javaRDD()函数将DataFrame转换为RDD//然后对RDD进行Map 每一行String->Rating

JavaRDD ratingRDD=spark.read().textFile(path).javaRDD()

.map(newFunction() {

@Overridepublic Rating call(String str) throwsException {returnRating.parseRating(str);

}

});//System.out.println(ratingRDD.take(10).get(0).getMovieId());//由JavaRDD(每一行都是一个实例化的Rating对象)和Rating Class创建DataFrame

Dataset ratings=spark.createDataFrame(ratingRDD, Rating.class);//ratings.show(30);//将数据随机分为训练集和测试集

double[] weights=new double[] {0.8,0.2};long seed=1234;

Dataset [] split=ratings.randomSplit(weights, seed);

Dataset training=split[0];

Dataset test=split[1];//------------------------------2.0 ALS算法和训练数据集,产生推荐模型-------------

for(int rank=1;rank<20;rank++)

{//定义算法

ALS als=newALS()

.setMaxIter(5)最大迭代次数,设置太大发生java.lang.StackOverflowError

.setRegParam(0.16)

.setUserCol("userId")

.setRank(rank)

.setItemCol("movieId")

.setRatingCol("rating");//训练模型

ALSModel model=als.fit(training);//---------------------------3.0 模型评估:计算RMSE,均方根误差---------------------

Dataset predictions=model.transform(test);//predictions.show();

RegressionEvaluator evaluator=newRegressionEvaluator()

.setMetricName("rmse")

.setLabelCol("rating")

.setPredictionCol("prediction");

Double rmse=evaluator.evaluate(predictions);

System.out.println("Rank =" + rank+" RMSErr = " +rmse);

}

}

}

als算法参数_Spark2.0协同过滤与ALS算法介绍相关推荐

  1. 机器学习(十四)——协同过滤的ALS算法(2)、主成分分析

    http://antkillerfarm.github.io/ Kendall秩相关系数(Kendall rank correlation coefficient) 对于秩变量对(xi,yi),(xj ...

  2. 机器学习(十三)——机器学习中的矩阵方法(3)病态矩阵、协同过滤的ALS算法(1)

    http://antkillerfarm.github.io/ 向量的范数(续) 范数可用符号∥x∥λ\|x\|_\lambda∥x∥λ​表示.常用的有: ∥x∥1=∣x1∣+⋯+∣xn∣\|x\|_ ...

  3. 机器学习(十三)——机器学习中的矩阵方法(3)病态矩阵、协同过滤的ALS算法(1)...

    http://antkillerfarm.github.io/ 向量的范数(续) 范数可用符号∥x∥λ表示. 经常使用的有: ∥x∥1=|x1|+⋯+|xn| ∥x∥2=x21+⋯+x2n−−−−−− ...

  4. 推荐引擎算法学习导论:协同过滤、聚类、分类(2011年旧文)

    推荐引擎算法学习导论:协同过滤.聚类.分类 作者:July 出处:结构之法算法之道 引言 昨日看到几个关键词:语义分析,协同过滤,智能推荐,想着想着便兴奋了.于是昨天下午开始到今天凌晨3点,便研究了一 ...

  5. 92 推荐算法——相似性推荐和协同过滤

    1 基于相似性的推荐流程 用户偏好如何收集 用户偏好如何整合 大多数情况我们提取的用户行为都多于一种,如何组合这些不同的用户行为,基本上有以下两种方式: 不同的行为分组 一般可以分为"查看& ...

  6. 基于CF(协同过滤)推荐算法

    基于物品的CF(协同过滤)推荐算法 1.1算法简介 CF(协同过滤)简单来形容就是利用兴趣相投的原理进行推荐,协同过滤主要分两类,一类是基于物品的协同过滤算法,另一种是基于用户的协同过滤算法,这里主要 ...

  7. 推荐算法——基于用户的协同过滤算法(User-base CF)的java实现

    推荐算法--基于用户的协同过滤算法(User-base CF)的java实现 推荐系统 什么是推荐系统 为什么要有推荐系统 推荐算法 基于用户的协同过滤算法(User-base CF) 算法介绍 代码 ...

  8. 推荐算法-基于协同过滤的推荐算法

    推荐算法-基于协同过滤的推荐算法 在如今信息量呈爆炸式增长的时代,谷歌百度等搜索引擎为人们查找信息提供了便利,帮助用户快速查找有价值的信息.然而此类查询方式是大众化的,无法根据个人兴趣为用户展示相关的 ...

  9. 在线新闻推荐网 Python+Django+Mysql开发技术 基于用户、物品的协同过滤推荐算法 个性化新闻推荐系统 协同过滤推荐算法在新闻网站中的运用 个性化推荐算法、机器学习、分布式大数据、人工智

    在线新闻推荐网 Python+Django+Mysql开发技术 基于用户.物品的协同过滤推荐算法 个性化新闻推荐系统 协同过滤推荐算法在新闻网站中的运用 个性化推荐算法.机器学习.分布式大数据.人工智 ...

最新文章

  1. html表单php连接mysql数据库,PHP 连接MySQL数据库
  2. html盒子模型子元素怎么水平占满父元素_CSS3——弹性盒模型-flex——父级属性...
  3. Windows10搭建ftp服务
  4. nginx 正则 结尾 配置_nginx正则配置解释多用于伪静态规则
  5. 海龟绘图小动物_被解救海龟经野化训练后放归大海,这一幕让人惊叹不已
  6. android禁止锁屏保持常亮
  7. [转]window.location方法获取URL及window.location.assign(url)和replace(url)区别
  8. Spring Boot 阿里云短信平台手机验证码测试
  9. 超威主板关闭超线程教程
  10. 油/水溶性CdS-ZnS/InP-ZnS/ZnSe-ZnS/CdSe/ZnS量子点的应用
  11. SAP中 关于BAPI_MATERIAL_SAVEDATA创建物料报错:字段MARA-MEINS/BAPI_MARA-BASE_UOM(_ISO)被定义为必需的字段; 它不包含条目
  12. 杀戮尖塔是用java_杀戮尖塔修改class文件图文教程 杀戮尖塔怎么改class
  13. 【身体这些部位不舒服的时候,你知道意味着什么吗?】
  14. Webots+ROS学习记录(4)——六轮全地形移动机器人
  15. 论网络喷子的功力-关于全栈工程师那篇文
  16. 经典系列——鸡尾酒排序
  17. bzoj1787.紧急集合(倍增LCA)
  18. Ubuntu8.04安装Realplayer11
  19. Scala从入门到入土(从入门到放弃)
  20. 豆瓣用python写的网站_用python写一个豆瓣短评通用爬虫(登录、爬取、可视化)

热门文章

  1. tf.data.Dataset 用法
  2. elasticsearch python API
  3. 风控项目-收集基础知识1
  4. 深度学习数学基础(一)~卷积
  5. markdown 笔记
  6. 【数学建模】基于随机机会约束规划方法对旅行商问题TSP求解
  7. 基于MATLAB均值漂移图像分割技术
  8. Flink从入门到精通100篇(十五)-Flink SQL FileSystem Connector 分区提交与自定义小文件合并策略 ​
  9. python socket 说明
  10. eclipse下新建py文件的辅助信息设置