本节主要介绍的是libFM源码分析的第五部分之一——libFM的训练过程之SGD的方法。

5.1、基于梯度的模型训练方法

在libFM中,提供了两大类的模型训练方法,一类是基于梯度的训练方法,另一类是基于MCMC的模型训练方法。对于基于梯度的训练方法,其类为fm_learn_sgd类,其父类为fm_learn类,主要关系为:

fm_learn_sgd类是所有基于梯度的训练方法的父类,其具体的代码如下所示:

#include "fm_learn.h"
#include "../../fm_core/fm_sgd.h"// 继承自fm_learn
class fm_learn_sgd: public fm_learn {protected://DVector<double> sum, sum_sqr;public:int num_iter;// 迭代次数double learn_rate;// 学习率DVector<double> learn_rates;// 多个学习率        // 初始化virtual void init() {       fm_learn::init();   learn_rates.setSize(3);// 设置学习率//  sum.setSize(fm->num_factor);        //  sum_sqr.setSize(fm->num_factor);}       // 利用梯度下降法进行更新,具体的训练的过程在其子类中virtual void learn(Data& train, Data& test) { fm_learn::learn(train, test);// 该函数并没有具体实现// 输出运行时的参数,包括:学习率,迭代次数std::cout << "learnrate=" << learn_rate << std::endl;std::cout << "learnrates=" << learn_rates(0) << "," << learn_rates(1) << "," << learn_rates(2) << std::endl;std::cout << "#iterations=" << num_iter << std::endl;if (train.relation.dim > 0) {// 判断relationthrow "relations are not supported with SGD";}std::cout.flush();// 刷新}// SGD重新修正fm模型的权重void SGD(sparse_row<DATA_FLOAT> &x, const double multiplier, DVector<double> &sum) {fm_SGD(fm, learn_rate, x, multiplier, sum);// 调用fm_sgd中的fm_SGD函数} // debug函数,主要用于打印中间结果void debug() {std::cout << "num_iter=" << num_iter << std::endl;fm_learn::debug();          }// 对数据进行预测virtual void predict(Data& data, DVector<double>& out) {assert(data.data->getNumRows() == out.dim);// 判断样本个数是否相等for (data.data->begin(); !data.data->end(); data.data->next()) {double p = predict_case(data);// 得到线性项和交叉项的和,调用的是fm_learn中的方法if (task == TASK_REGRESSION ) {// 回归任务p = std::min(max_target, p);p = std::max(min_target, p);} else if (task == TASK_CLASSIFICATION) {// 分类任务p = 1.0/(1.0 + exp(-p));// Sigmoid函数处理} else {// 异常处理throw "task not supported";}out(data.data->getRowIndex()) = p;}               } };

fm_learn_sgd类中,主要包括五个函数,分别为:初始化init函数,训练learn函数,SGD训练SGD函数,debug的debug函数和预测predict函数。

5.1.1、初始化init函数

在初始化中,对学习率的大小进行了初始化,同时继承了父类中的初始化方法。

5.1.2、训练learn函数

learn函数中,没有具体的训练的过程,只是对训练中需要用到的参数进行输出,具体的训练的过程在其对应的子类中定义,如fm_learn_sgd_element类和fm_learn_sgd_element_adapt_reg类。

5.1.3、SGD训练SGD函数

SGD函数使用的是fm_sgd.h文件中的fm_SGD函数。fm_SGD函数是利用梯度下降法对模型中的参数进行调整,以得到最终的模型中的参数。在利用梯度下降法对模型中的参数进行调整的过程中,假设损失函数为ll,那么,对于回归问题来说,其损失函数为:

l=12(y^(i)−y(i))2l=12(y^(i)−y(i))2

对于二分类问题,其损失函数为:

l=−lnσ(y^(i)y(i))l=−lnσ(y^(i)y(i))

其中,σσ为Sigmoid函数:

σ(x)=11+e(−x)σ(x)=11+e(−x)

对于σ(x)σ(x),其导函数为:

σ′=σ(1−σ)σ′=σ(1−σ)

在可用SGD更新的过程中,首先需要计算损失函数的梯度,因此,对应于上述的回归问题和二分类问题,其中回归问题的损失函数的梯度为:

∂l∂θ=(y^(i)−y(i))⋅∂y^(i)∂θ∂l∂θ=(y^(i)−y(i))⋅∂y^(i)∂θ

分类问题的损失函数的梯度为:

∂l∂θ=(σ(y^(i)y(i))−1)⋅y(i)⋅∂y^(i)∂θ∂l∂θ=(σ(y^(i)y(i))−1)⋅y(i)⋅∂y^(i)∂θ

其中,λλ称为正则化参数,在具体的应用中,通常加上L2L2正则,即:

∂l∂θ+λθ∂l∂θ+λθ

在定义好上述的计算方法后,其核心的问题是如何计算∂y^(i)∂θ∂y^(i)∂θ,在“机器学习算法实现解析——libFM之libFM的模型处理部分”中已知:

y^:=w0+∑i=1nwixi+∑i=1n−1∑j=i+1n〈vi,vj〉xixjy^:=w0+∑i=1nwixi+∑i=1n−1∑j=i+1n〈vi,vj〉xixj

因此,当y^y^分别对w0w0,wiwi以及vi,fvi,f求偏导时,其结果分别为:

∂y^∂θ=⎧⎩⎨⎪⎪⎪⎪1xixi(∑j=1xjvj,f−xivi,f) if θ=w0 if θ=wi if θ=vi,f∂y^∂θ={1 if θ=w0xi if θ=wixi(∑j=1xjvj,f−xivi,f) if θ=vi,f

在利用梯度的方法中,其参数θθ的更新方法为:

θ=θ−η⋅(∂l∂θ+λθ)θ=θ−η⋅(∂l∂θ+λθ)

其中,ηη为学习率,在libFM中,其具体的代码如下所示:

// 利用SGD更新模型的参数
void fm_SGD(fm_model* fm, const double& learn_rate, sparse_row<DATA_FLOAT> &x, const double multiplier, DVector<double> &sum) {// 1、常数项的修正if (fm->k0) {double& w0 = fm->w0;w0 -= learn_rate * (multiplier + fm->reg0 * w0);}// 2、一次项的修正if (fm->k1) {for (uint i = 0; i < x.size; i++) {double& w = fm->w(x.data[i].id);w -= learn_rate * (multiplier * x.data[i].value + fm->regw * w);}}// 3、交叉项的修正for (int f = 0; f < fm->num_factor; f++) {for (uint i = 0; i < x.size; i++) {double& v = fm->v(f,x.data[i].id);double grad = sum(f) * x.data[i].value - v * x.data[i].value * x.data[i].value; v -= learn_rate * (multiplier * grad + fm->regv * v);}}
}

以上的更新的过程分别对应着上面的更新公式,其中multiplier变量分别对应着回归中的(y^(i)−y(i))(y^(i)−y(i))和分类中的(σ(y^(i)y(i))−1)⋅y(i)(σ(y^(i)y(i))−1)⋅y(i)。

5.1.4、预测predict函数

predict函数用于对样本进行预测,这里使用到了predict_case函数,该函数在“机器学习算法实现解析——libFM之libFM的训练过程概述”中有详细的说明,得到值后,分别对回归问题和分类问题做处理,在回归问题中,主要是防止超出最大值和最小值,在分类问题中,将其值放入Sigmoid函数,得到最终的结果。

5.2、SGD的训练方法

随机梯度下降法(Stochastic Gradient Descent ,SGD)是一种简单有效的优化方法。对于梯度下降法的更多内容,可以参见“梯度下降优化算法综述”。在利用SGD对FM模型训练的过程如下图所示:

在libFM中,SGD的实现在fm_learn_sgd_element.h文件中。在该文件中,定义了fm_learn_sgd_element类,fm_learn_sgd_element类继承自fm_learn_sgd类,主要实现了fm_learn_sgd类中的learn方法,具体的程序代码如下所示:

#include "fm_learn_sgd.h"// 继承了fm_learn_sgd
class fm_learn_sgd_element: public fm_learn_sgd {public:// 初始化virtual void init() {fm_learn_sgd::init();// 日志输出if (log != NULL) {log->addField("rmse_train", std::numeric_limits<double>::quiet_NaN());}}// 利用SGD训练FM模型virtual void learn(Data& train, Data& test) {fm_learn_sgd::learn(train, test);// 输出参数信息std::cout << "SGD: DON'T FORGET TO SHUFFLE THE ROWS IN TRAINING DATA TO GET THE BEST RESULTS." << std::endl; // SGDfor (int i = 0; i < num_iter; i++) {// 开始迭代,每一轮的迭代过程double iteration_time = getusertime();// 记录开始的时间for (train.data->begin(); !train.data->end(); train.data->next()) {// 对于每一个样本double p = fm->predict(train.data->getRow(), sum, sum_sqr);// 得到样本的预测值double mult = 0;// 损失函数的导数if (task == 0) {// 回归p = std::min(max_target, p);p = std::max(min_target, p);// loss=(y_ori-y_pre)^2mult = -(train.target(train.data->getRowIndex())-p);// 对损失函数求导} else if (task == 1) {// 分类// lossmult = -train.target(train.data->getRowIndex())*(1.0-1.0/(1.0+exp(-train.target(train.data->getRowIndex())*p)));}// 利用梯度下降法对参数进行学习SGD(train.data->getRow(), mult, sum);                   }               iteration_time = (getusertime() - iteration_time);// 记录时间差// evaluate函数是调用的fm_learn类中的方法double rmse_train = evaluate(train);// 对训练结果评估double rmse_test = evaluate(test);// 将模型应用在测试数据上std::cout << "#Iter=" << std::setw(3) << i << "\tTrain=" << rmse_train << "\tTest=" << rmse_test << std::endl;// 日志输出if (log != NULL) {log->log("rmse_train", rmse_train);log->log("time_learn", iteration_time);log->newLine();}}       }};

learn函数中,实现了SGD训练FM模型的主要过程,在实现的过程中,分别调用了SGD函数和evaluate函数,其中SGD函数如上面的5.1.3、SGD训练SGD函数小节所示,利用SGD函数对FM模型中的参数进行更新,evaluate函数如“机器学习算法实现解析——libFM之libFM的训练过程概述”中所示,evaluate函数用于评估学习出的模型的效果。其中mult变量分别对应着回归中的(y^(i)−y(i))(y^(i)−y(i))和分类中的(σ(y^(i)y(i))−1)⋅y(i)(σ(y^(i)y(i))−1)⋅y(i)。

参考文献

  • Rendle S. Factorization Machines[C]// IEEE International Conference on Data Mining. IEEE Computer Society, 2010:995-1000.
  • Rendle S. Factorization Machines with libFM[M]. ACM, 2012.

--------------------- 本文来自 zhiyong_will 的CSDN 博客 ,全文地址请点击:https://blog.csdn.net/google19890102/article/details/72866334?utm_source=copy

机器学习算法实现解析:libFM之libFM的训练过程之SGD的方法相关推荐

  1. 【机器学习算法】关联规则-3 关联规则的指标问题和关联规则的使用方法

    目录 关联规则的指标问题和关联规则的使用方法 再谈评估指标 支持度与置信度的问题 提升度指标 关联规则的生成 关联规则的延伸 虚拟产品 负向相关规则dissociation rules 相依性网络 总 ...

  2. bfgs算法c语言,机器学习算法实现解析——liblbfgs之L-BFGS算法

    在博文"优化算法--拟牛顿法之L-BFGS算法"中,已经对L-BFGS的算法原理做了详细的介绍,本文主要就开源代码liblbfgs重新回顾L-BFGS的算法原理以及具体的实现过程, ...

  3. 机器学习算法实现解析——word2vec源码解析

    在wrod2vec工具中,有如下的几个比较重要的概念: CBOW Skip-Gram Hierarchical Softmax Negative Sampling 其中CBOW和Skip-Gram是w ...

  4. 成为顶尖机器学习算法专家需要知道哪些算法?

    2019独角兽企业重金招聘Python工程师标准>>> 成为顶尖机器学习算法专家需要知道哪些算法? 摘要:顶尖的机器学习专家需要的算法,要不要? 机器学习算法简介 有两种方法可以对你 ...

  5. Pymetrics开源公平性感知机器学习算法Audit AI

    Pymetrics是一件专注于向企业提供招聘服务的初创企业.最近,Pymetrics在Github上开源了企业使用的偏差检测(bias detection)算法,称为"Audio AI&qu ...

  6. 机器学习算法中的过拟合与欠拟合(转载)

    在机器学习表现不佳的原因要么是过度拟合或欠拟合数据. 1.机器学习中的逼近目标函数过程 监督式机器学习通常理解为逼近一个目标函数(f)(f),此函数映射输入变量(X)到输出变量(Y). Y=f(X)Y ...

  7. 8个常见机器学习算法的计算复杂度总结!

    Datawhale干货 来源:DeepHub IMBA,编辑:数据派THU 计算的复杂度是一个特定算法在运行时所消耗的计算资源(时间和空间)的度量. 计算复杂度又分为两类: 一.时间复杂度 时间复杂度 ...

  8. 用于预测脊柱转移术后30天死亡率的机器学习算法的开发

    用于预测脊柱转移术后30天死亡率的机器学习算法的开发 Development of Machine Learning Algorithms for Prediction of 30-Day Morta ...

  9. 各种机器学习算法比较

    前言  简单介绍各种机器学习算法的优缺点,和用python中的一些相关库的用法 一.监督学习算法 1.k-NN近邻 1.1 简介  k-NN 算法可以说是最简单的机器学习算法.构建模型只需要保存训练数 ...

最新文章

  1. 属性驱动的架构设计方法图解【转载】
  2. mysql数据库扫描_使用nmap对mysql 数据库进行扫描
  3. 转载:JAVA日期处理
  4. Numpy 之 where理解
  5. 还在家隔离呢?没事写写这些程序吧!
  6. 为RedHat系统安装发布版的PostgreSQL数据库
  7. Atitit webdav的使用与配置总结attilax总结 目录 1. 支持的协议 2 1.1. http File unc 2 2. 应用场景 2 2.1. 远程文件管理实现功能 文件建立
  8. java interface作用是什么_关于Java反射原理:
  9. 小米air2se耳机只有一边有声音怎么办_别光盯着AirPods,这些无线蓝牙耳机,其实也很好用...
  10. 关于计算机团队名字大全集,好听的团队名字大全
  11. 牛客网暑期ACM多校训练营(第三场) J.Distance to Work 计算几何
  12. my python voyage
  13. matlab水下机器人,水下机器人路径控制与仿真
  14. spark RDD算子大全
  15. CJSON 使用介绍
  16. 用的五大bug管理工具的优缺点和下载地址
  17. docker开放的端口_docker容器怎么开端口
  18. 可变参数传递与不可变参数传递
  19. LaTeX快速入门(简易模板)
  20. html引用不了自定义字体,html5 – 自定义@ font-face不加载chrome(chrome自定义字体无法渲染)...

热门文章

  1. 服务器里怎么维修装备,教你在服务器加自己的装备
  2. 为进阶Linux大佬打牢地基
  3. url模糊匹配优化_详情页怎么做SEO优化?
  4. 用java正则表达式验证字符串(邮箱与网址)
  5. java 文件 递归_JAVA实现遍历文件夹下的所有文件(递归调用和非递归调用)
  6. php中的select case语句吗,VBS教程:VBScript 语句-Select Case 语句
  7. html中dir标签的作用是什么意思,htmldir标签是干啥的?dir标签的具体定义和属性介绍...
  8. 两个计算机系统安装,如何在一台电脑上同时重装两个系统|戴尔电脑怎么安装两个系统...
  9. 自定义权限 android,Android权限控制之自定义权限
  10. php 微信定位,微信企业号(服务号)坐标定位发生偏移解决方案记录( 附PHP代码)...