提到ALS相信大家应该都不会觉得陌生(不陌生你点进来干嘛[捂脸]),它是协同过滤的一种,并被集成到Spark的Mllib库中。本文就ALS的基本原理进行讲解,并手把手、肩并肩地带您实现这一算法。

完整实现代码请参考本人的p...哦不是...github:als.py​github.com

als_example.py​github.com

1. 原理篇

我们用人话而不是大段的数学公式来讲讲ALS是怎么一回事。

1.1 你听说过推荐算法么

假如我是豆瓣的CEO,很多豆瓣的用户在豆瓣电影上都会对电影进行评分。那么根据这个评分数据,我们有可能知道这些用户除了自己评过分的电影之外还喜欢或讨厌哪些电影吗?这就是一个典型的推荐问题,解决这一类问题的算法被称为推荐算法。

1.2 什么是协同过滤

协同过滤的英文全称是Collaborative Filtering,简称CF。注意,这不是一款游戏!从字面上分析,协同就是寻找共同点,过滤就是筛选出优质的内容。

1.3 协同过滤的分类

一般来说,协同过滤推荐分为三种类型:

1. 基于用户(user-based)的协同过滤,通过计算用户和用户的相似度找到跟用户A相似的用户B, C, D...再把这些用户喜欢的内容推荐给A;

2. 基于物品(item-based)的协同过滤,通过计算物品和物品的相似度找到跟物品1相似的物品2, 3, 4...再把这些物品推荐给看过物品1的用户们;

3. 基于模型(model based)的协同过滤。主流的方法可以分为:矩阵分解,关联算法,聚类算法,分类算法,回归算法,神经网络。

1.4 矩阵分解

矩阵分解 (decomposition, factorization)是将矩阵拆解为数个矩阵的乘积。比如豆瓣电影有m个用户,n个电影。那么用户对电影的评分可以形成一个m行n列的矩阵R,我们可以找到一个m行k列的矩阵U,和一个k行n列的矩阵I,通过U * I来得到矩阵R。

1.5 ALS

如果想通过矩阵分解的方法实现基于模型的协同过滤,ALS是一个不错的选择,其英文全称是Alternating Least Square,翻译过来是交替最小二乘法。假设用户为a,物品为b,评分矩阵为R(m, n),可分解为用户矩阵U(k, m)和物品矩阵I(k, n),其中m, n, k代表矩阵的维度。前方小段数学公式低能预警:根据矩阵分解的定义,有

用MSE作为损失函数,为了方便化简,加法符号左侧的常数改为-1/2

对损失函数求U_a的一阶偏导数,那么

令一阶偏导数等于0

同理,可证

1.6 求解用户矩阵U和物品矩阵I

矩阵R是已知的,我们随机生成用户矩阵U,

1. 利用1.5中的式5、R和U求出I

2. 利用1.5中的式6、R和I求出U

如此交替地执行步骤1和步骤2,直到算法收敛或者迭代次数超过了最大限制,最终我们用RMSE来评价模型的好坏。

2. 实现篇

本人用全宇宙最简单的编程语言——Python实现了ALS算法,没有依赖任何第三方库,便于学习和使用。简单说明一下实现过程,更详细的注释请参考本人github上的代码。

注:代码中用到的Matrix类是我写的一个矩阵类,可以取出矩阵的行或列,计算矩阵的乘法、转置和逆。代码链接:matrix.py

2.1 创建ALS类

初始化,存储用户ID、物品ID、用户ID与用户矩阵列号的对应关系、物品ID与物品矩阵列号的对应关系、用户已经看过哪些物品、评分矩阵的Shape以及RMSE。

class ALS(object):

def __init__(self):

self.user_ids = None

self.item_ids = None

self.user_ids_dict = None

self.item_ids_dict = None

self.user_matrix = None

self.item_matrix = None

self.user_items = None

self.shape = None

self.rmse = None

2.2 数据预处理

对训练数据进行处理,得到用户ID、物品ID、用户ID与用户矩阵列号的对应关系、物品ID与物品矩阵列号的对应关系、评分矩阵的Shape、评分矩阵及评分矩阵的转置。

def _process_data(self, X):

self.user_ids = tuple((set(map(lambda x: x[0], X))))

self.user_ids_dict = dict(map(lambda x: x[::-1],

enumerate(self.user_ids)))

self.item_ids = tuple((set(map(lambda x: x[1], X))))

self.item_ids_dict = dict(map(lambda x: x[::-1],

enumerate(self.item_ids)))

self.shape = (len(self.user_ids), len(self.item_ids))

ratings = defaultdict(lambda: defaultdict(int))

ratings_T = defaultdict(lambda: defaultdict(int))

for row in X:

user_id, item_id, rating = row

ratings[user_id][item_id] = rating

ratings_T[item_id][user_id] = rating

err_msg = "Length of user_ids%dand ratings%dnot match!" % (

len(self.user_ids), len(ratings))

assert len(self.user_ids) == len(ratings), err_msg

err_msg = "Length of item_ids%dand ratings_T%dnot match!" % (

len(self.item_ids), len(ratings_T))

assert len(self.item_ids) == len(ratings_T), err_msg

return ratings, ratings_T

2.3 用户矩阵乘以评分矩阵

实现稠密矩阵与稀疏矩阵的矩阵乘法,得到用户矩阵与评分矩阵的乘积。

def _users_mul_ratings(self, users, ratings_T):

def f(users_row, item_id):

user_ids = iter(ratings_T[item_id].keys())

scores = iter(ratings_T[item_id].values())

col_nos = map(lambda x: self.user_ids_dict[x], user_ids)

_users_row = map(lambda x: users_row[x], col_nos)

return sum(a * b for a, b in zip(_users_row, scores))

ret = [[f(users_row, item_id) for item_id in self.item_ids]

for users_row in users.data]

return Matrix(ret)

2.4 物品矩阵乘以评分矩阵

实现稠密矩阵与稀疏矩阵的矩阵乘法,得到物品矩阵与评分矩阵的乘积。

def _items_mul_ratings(self, items, ratings):

def f(items_row, user_id):

item_ids = iter(ratings[user_id].keys())

scores = iter(ratings[user_id].values())

col_nos = map(lambda x: self.item_ids_dict[x], item_ids)

_items_row = map(lambda x: items_row[x], col_nos)

return sum(a * b for a, b in zip(_items_row, scores))

ret = [[f(items_row, user_id) for user_id in self.user_ids]

for items_row in items.data]

return Matrix(ret)

2.5 生成随机矩阵

def _gen_random_matrix(self, n_rows, n_colums):

data = [[random() for _ in range(n_colums)] for _ in range(n_rows)]

return Matrix(data)

2.6 计算RMSE

def _get_rmse(self, ratings):

m, n = self.shape

mse = 0.0

n_elements = sum(map(len, ratings.values()))

for i in range(m):

for j in range(n):

user_id = self.user_ids[i]

item_id = self.item_ids[j]

rating = ratings[user_id][item_id]

if rating > 0:

user_row = self.user_matrix.col(i).transpose

item_col = self.item_matrix.col(j)

rating_hat = user_row.mat_mul(item_col).data[0][0]

square_error = (rating - rating_hat) ** 2

mse += square_error / n_elements

return mse ** 0.5

2.7 训练模型数据预处理

变量k合法性检查

生成随机矩阵U

交替计算矩阵U和矩阵I,并打印RMSE信息,直到迭代次数达到max_iter

保存最终的RMSE

def fit(self, X, k, max_iter=10):

ratings, ratings_T = self._process_data(X)

self.user_items = {k: set(v.keys()) for k, v in ratings.items()}

m, n = self.shape

error_msg = "Parameter k must be less than the rank of original matrix"

assert k < min(m, n), error_msg

self.user_matrix = self._gen_random_matrix(k, m)

for i in range(max_iter):

if i % 2:

items = self.item_matrix

self.user_matrix = self._items_mul_ratings(

items.mat_mul(items.transpose).inverse.mat_mul(items),

ratings

)

else:

users = self.user_matrix

self.item_matrix = self._users_mul_ratings(

users.mat_mul(users.transpose).inverse.mat_mul(users),

ratings_T

)

rmse = self._get_rmse(ratings)

print("Iterations:%d, RMSE:%.6f" % (i + 1, rmse))

self.rmse = rmse

2.8 预测一个用户

预测一个用户感兴趣的内容,剔除用户已看过的内容。然后按感兴趣分值排序,取出前n_items个内容。

def _predict(self, user_id, n_items):

users_col = self.user_matrix.col(self.user_ids_dict[user_id])

users_col = users_col.transpose

items_col = enumerate(users_col.mat_mul(self.item_matrix).data[0])

items_scores = map(lambda x: (self.item_ids[x[0]], x[1]), items_col)

viewed_items = self.user_items[user_id]

items_scores = filter(lambda x: x[0] not in viewed_items, items_scores)

return sorted(items_scores, key=lambda x: x[1], reverse=True)[:n_items]

2.9 预测多个用户

循环调用2.8,预测多个用户感兴趣的内容。

def predict(self, user_ids, n_items=10):

return [self._predict(user_id, n_items) for user_id in user_ids]

3 效果评估

3.1 main函数

使用电影评分数据集,训练模型并统计RMSE。

@run_time

def main():

print("Tesing the accuracy of ALS...")

X = load_movie_ratings()

model = ALS()

model.fit(X, k=3, max_iter=5)

print()

print("Showing the predictions of users...")

user_ids = range(1, 5)

predictions = model.predict(user_ids, n_items=2)

for user_id, prediction in zip(user_ids, predictions):

_prediction = [format_prediction(item_id, score)

for item_id, score in prediction]

print("User id:%drecommedation:%s" % (user_id, _prediction))

3.2 效果展示

设置k=3,迭代5次,并展示了前4个用户的推荐内容,最终RMSE为0.370,运行时间46.5秒,效果还算不错~

3.3 工具函数

本人自定义了一些工具函数,可以在github上查看:utils​github.com

1. run_time - 测试函数运行时间

2. load_movie_ratings - 加载电影评分数据

总结

ALS的原理:鸡生蛋、蛋生鸡

ALS的实现:基本上就是矩阵乘法

python协同过滤算法_协同过滤(ALS)的原理及Python实现相关推荐

  1. python协同过滤算法_协同过滤算法介绍及算法实现

    一.协同过滤算法简介 协同过滤算法是一种较为著名和常用的推荐算法,它基于对用户历史行为数据的挖掘发现用户的喜好偏向,并预测用户可能喜好的产品进行推荐.也就是常见的"猜你喜欢",和& ...

  2. 协同过滤算法_协同过滤在新闻推荐CTR预估中的应用

    概述协同过滤算法是推荐系统中的最基本的算法,该算法不仅在学术界得到了深入的研究,而且在工业界也得到了广泛的应用.本文介绍最基本的基于物品的和基于用户的协同过滤算法,并结合新闻推荐的CTR预估,介绍基于 ...

  3. python图像融合算法_图像融合质量评价方法的python代码实现——MS-SSIM

    图像融合质量评价方法的python代码实现--MS-SSIM 图像融合质量评价方法的python代码实现--MS-SSIM 文章目录 1 前言 2 MS-SSIM介绍 2 MS-SSIM的代码 2.1 ...

  4. python人物抠图算法_比PS还好用!Python 20行代码批量抠图

    抠图前 vs Python自动抠图后 在日常的工作和生活中,我们经常会遇到需要抠图的场景,即便是只有一张图片需要抠,也会抠得我们不耐烦,倘若遇到许多张图片需要抠,这时候你的表情应该会很有趣. Pyth ...

  5. python制作图像数据集_详细图像数据集增强原理的python代码

    导读 在深度学习时代,数据的规模越大.质量越高,模型就能够拥有更好的泛化能力,数据直接决定了模型学习的上限.然而在实际工程中,采集的数据很难覆盖全部的场景,比如图像的光照条件,同一场景拍摄的图片可能由 ...

  6. 协同过滤算法_机器学习 | 简介推荐场景中的协同过滤算法,以及SVD的使用

    本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是机器学习专题的第29篇文章,我们来聊聊SVD在上古时期的推荐场景当中的应用. 推荐的背后逻辑 有没有思考过一个问题,当我们在淘宝或者是 ...

  7. 协同过滤算法_基于Mahout的协同过滤推荐算法

    1协同过滤 基于物品的协同过滤 对物品进行相似度计算,然后再进行推荐. 基于用户的协同过滤 基于用户的协同过滤算法是通过用户的历史行为数据发现用户对商品或内容的喜欢,并对这些喜好进行度量和打分.根据不 ...

  8. 协同过滤算法_基于用户的协同过滤推荐算法原理和实现

    (给算法爱好者加星标,修炼编程内功) 来源:Create Chen's Blog https://www.cnblogs.com/technology/p/4467895.html 在推荐系统众多方法 ...

  9. 协同过滤算法_利用数据分析量化协同过滤算法的两大常见难题

    点击上方"蓝色字体",选择 "设为星标" 关键讯息,D1时间送达! 推荐系统自从问世以来解决了许多不同的商业产品问题,深受广大互联网从业者的喜爱.传统的互联网电 ...

最新文章

  1. Spring事务配置探究
  2. DCCI:Web3.0时代 微博未来有望突破4.6亿
  3. Linux C编程--进程间通信(IPC)5--System V IPC 机制3--共享内存
  4. C++三角函数用法错误error C2668: 'tan' : ambiguous call to overloaded function原因及解决方法
  5. django 框架中应用 redis 集群(亲测)
  6. python3字符串转数字_Python 3.6.0 正式版发布附新特性说明
  7. OpenJudge NOI 1.8 25:螺旋加密
  8. Matlab中错误使用mex
  9. Netty工作笔记0079---Log4j整合到Netty
  10. as的Enter_Frame与Timer
  11. Confluence 6 Windows 中以服务方式自动重启修改运行服务的用户
  12. mysql migration tool 64bit_Migration Tool支持MySQL业务SQL迁移
  13. ubuntu无法安装usb驱动
  14. Oracle中将小数转换成字符丢零.截取小数.除数为零解决法
  15. matlab cat函数_MATLAB的Categorical数组
  16. udp客户端 Java_java UDP通信客户端与服务器端实例分析
  17. 医院排队叫号系统源码
  18. js基础 js函数作用域链
  19. 【电源专题】什么是开关稳压器
  20. 基于Kinect体感器控制的机械臂项目记录

热门文章

  1. 什么是HIS,以及HIS的作用,特点,组成部分
  2. C++ Primer读书摘要(一)
  3. 我爱我专业计算机为主题的演讲稿,我爱我专业演讲稿
  4. 利用开源 SNI PROXY+DNSMASQ 工具链实战 Netflix 流媒体解锁
  5. 目标检测算法Faster R-CNN论文解读
  6. matlab单边带调制仿真,000单边带调制MATLAB模拟仿真
  7. 一个屌丝程序猿的人生(九十四)
  8. echart 三维可视化地图_可视化地图是什么?推荐3个工具!
  9. vs mysql 开发erp_ERP/MIS开发 Mindscape NHibernate + MySQL 快速开发入门
  10. 统信UOS提示系统磁盘空间不足,磁盘扩容