• Ranklib部分源码分析LambdaMARTRandomForest

    • 声明
    • 关于Ranklib
    • 主框架Evlauatorjava
      • 程序主入口main函数
      • 训练器evaluate函数
      • 预测test函数
    • 数据结构基础类DataPointRankList
      • 数据格式
      • DataPointjava
      • RankListjava
    • Rank基础类RANKER_TYPERankerFactoryRanker
      • RANKER_TYPEjava
      • RankerFactoryjava
      • Rankerjava
      • RankerTrainer
    • LambdaMART基础类FeatureHistorgramRegressionTreeSplitEnsemble
      • FeatureHistorgramjava
      • RegressionTreejava
      • Ensemblejava
    • LambdaMART流程

Ranklib部分源码分析(LambdaMART+RandomForest)

声明

本文是Ranklib部分源码的分析,参考了RankLib源码分析——guoguo881218的专栏以及Learning to Rank——wowarsenal,在此对原博主表示感谢

关于Ranklib

在The Lemur Project可以下载到Ranklib程序,Ranklib2.1和Ranklib2.3有源码可以下载,Ranklib2.4和Ranklib2.5只有jar文件可以下载。通过jad反编译后可以看到源码,整体结构差别不大。本文以Ranklib2.3为标准进行说明。

主框架(Evlauator.java)

程序主入口main函数

Ranklib程序主入口为ciir.umass.edu.eval.Evaluator类中main函数.
其中public static void main(String[] args)函数接收命令行传入参数。

  1. 首先初始化一些变量并根据传入参数给变量赋值:
        for(int i=0;i<args.length;i++){if(args[i].compareTo("-train")==0)trainFile = args[++i]; //训练集else if(args[i].compareTo("-ranker")==0)rankerType = Integer.parseInt(args[++i]); //Rank类型...else if(args[i].compareTo("-metric2t")==0)trainMetric = args[++i]; //训练集Metricelse if(args[i].compareTo("-metric2T")==0)testMetric = args[++i]; //测试集Metric...else if(args[i].compareTo("-validate")==0)validationFile = args[++i]; //验证集else if(args[i].compareTo("-test")==0){testFile = args[++i];testFiles.add(testFile);} //测试集...else if(args[i].compareTo("-save")==0)Evaluator.modelFile = args[++i]; //模型保存位置...else if(args[i].compareTo("-load")==0){savedModelFile = args[++i];savedModelFiles.add(args[i]);} //导入模型...else if(args[i].compareTo("-rank")==0)rankFile = args[++i]; //待排序数据... ... ...//MART / LambdaMART / Random forestelse if(args[i].compareTo("-tree")==0){LambdaMART.nTrees = Integer.parseInt(args[++i]);RFRanker.nTrees = Integer.parseInt(args[i]);} //树的棵树else if(args[i].compareTo("-leaf")==0){LambdaMART.nTreeLeaves = Integer.parseInt(args[++i]);RFRanker.nTreeLeaves = Integer.parseInt(args[i]);} //每棵树叶子结点数else if(args[i].compareTo("-shrinkage")==0){LambdaMART.learningRate = Float.parseFloat(args[++i]);RFRanker.learningRate = Float.parseFloat(args[i]);} //收缩系数...//Random forestelse if(args[i].compareTo("-bag")==0)RFRanker.nBag = Integer.parseInt(args[++i]); //bags数目
  1. 根据参数变量进行训练
        if(nThread == -1)nThread = Runtime.getRuntime().availableProcessors();MyThreadPool.init(nThread); //线程池初始化...Evaluator e = new Evaluator(rType2[rankerType], trainMetric, testMetric); //根据Rank类型以及训练集、测试集上的评价函数生成Evaluator对象... ...RankerFactory rf = new RankerFactory();rf.createRanker(rType2[rankerType]).printParameters();//根据参数创建Rank对象...e.evaluate() //多个实现,针对不同情况进行evaluate...if(testFiles.size() > 1)e.test(savedModelFiles, testFiles, prpFile);elsee.test(savedModelFiles, testFile, prpFile) //利用已有模型在测试集上进行预测s

训练器evaluate函数

  1. Evaluator初始化
    public Evaluator(RANKER_TYPE rType, String trainMetric, String testMetric){this.type = rType; //Ranke类型trainScorer = mFact.createScorer(trainMetric); //训练集上得分testScorer = mFact.createScorer(testMetric); //测试集上得分...}
  1. evaluate函数调用
//根据训练集验证集和测试集进行训练的evaluate()函数调用public void evaluate(String trainFile, String validationFile, String testFile, String featureDefFile){List<RankList> train = readInput(trainFile);...test = readInput(testFile);//读取训练、验证、测试文件... ...RankerTrainer trainer = new RankerTrainer();Ranker ranker = trainer.train(type, train, validation, features, trainScorer);//利用训练集和验证集训练模型... ...double rankScore = evaluate(ranker, test); //计算测试集得分... ...ranker.save(modelFile); //保存模型

预测test函数

test函数调用

//根据已有模型以及测试集进行训练的test()函数调用public void test(String modelFile, String testFile, String prpFile){Ranker ranker = rFact.loadRanker(modelFile); //导入模型List<RankList> test = readInput(testFile); //读取测试数据RankList l = ranker.rank(test.get(i)); //排序评分double score = testScorer.score(l); //取得评分

数据结构基础类(DataPoint、RankList)

数据格式

数据格式与SVM-Rank、libSVM、LETOR格式均相同。格式如下

<line> .=. <target> qid:<qid> <feature>:<value> <feature>:<value> ... <feature>:<value> # <info>
<target> .=. <positive integer> //正整数型评分
<qid> .=. <positive integer> //正整数型查询
<feature> .=. <positive integer> //正整数型特征序号
<value> .=. <float> //浮点型特征值
<info> .=. <string> //注释

DataPoint.java

ciir.umass.edu.learning.DataPoint

实现了需要评分的对象的数据结构。每个对象是一个待评分文档。

RankList.java

ciir.umass.edu.learning.RankList
实现了需要评分的对象组成的列表的数据结构。每个对象是一个包含对应于同一查询的不同文献的集合。

Rank基础类(RANKER_TYPE、RankerFactory、Ranker)

RANKER_TYPE.java

ciir.umass.edu.learning.RANKER_TYPE

枚举类型,包含各种Rank类型

RankerFactory.java

ciir.umass.edu.learning.RankerFactory

实现了RankerFactory,所有Rank方法都需要在此类中注册。

public Ranker createRanker(RANKER_TYPE type)创建某种类型的Rank对象。
public Ranker loadRanker(String modelFile)导入已有模型。

Ranker.java

ciir.umass.edu.learning.Ranker

Ranker类实现了一般的Rank接口,所有Rank类型都需要集成Ranker。

通用方法有:

public void setTrainingSet(List<RankList> samples) //设置训练集
public void setValidationSet(List<RankList> samples) //设置验证集
public double getScoreOnTrainingData() //训练集得分...
public void save(String modelFile) //保存模型
public RankList rank(RankList rl) //给出评分后的排序
public List<RankList> rank(List<RankList> l) //给出评分后的排序

必须在子类中实现的方法有:

public void init() //初始化
public void learn() //学习
public double eval(DataPoint p) //评价
public String toString() //模型转为字符串
public void load(String fn) //导入模型

RankerTrainer

ciir.umass.edu.learning.RankerTrainer
实现了对模型进行训练的函数:
public Ranker train(RANKER_TYPE type, List<RankList> train, List<RankList> validation, int[] features, MetricScorer scorer)

LambdaMART基础类(FeatureHistorgram、RegressionTree、Split、Ensemble)

FeatureHistorgram.java

ciir.umass.edu.learning.tree.FeatureHistogram

特征直方图类,对RankList对象进行特征的直方图统计,选择每次split时的最优feature和最优划分点。

  • construct方法:
public void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds){...sum = new double[features.length][];count = new int[features.length][];...}
 * sum[i][j]:指定feature i 的所有值(训练数据中出现的值),每个j代表一个训练数据中出现的一个值,sum[i][j]的值为feature i的所有小于某个指定值(该值由threshold[j]提供)的训练数据datapoint的label(该算法里为lambda)之和。* count[i][j]:指定feature i 的所有值(训练数据中出现的值),每个j代表一个训练数据中出现的一个值,sum[i][j]的值为feature i的所有小于某个指定值(该值由threshold[j]提供)的训练数据datapoint的总数。
  • update方法:
protected void update(double[] labels)

用新的label更新sum[i][j]。

  • findBestSplit方法思路:

    • 选取feature作为划分的备选(可全选,可选部分)。
    • 选取最优feature和最优划分点
      • 计算每个feature的每个划分点,doubleS=sumLeft∗sumLeft/countLeft+sumRight∗sumRight/countRightdouble S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight,最小的S即为最优feature和最优划分点s(该s是feature的具体值)。
//findBestSplit方法:protected Config findBestSplit(int[] usedFeatures, int minLeafSupport, int start, int end){...int countLeft = count[i][t]; //countLeft是该节点下某个feature的值小于指定值(备选s)的所有训练数据的总数int countRight = totalCount - countLeft; //countRight是该节点下某个feature的值大于等于指定值(备选s)的所有训练数据的总数...double sumLeft = sum[i][t]; //sumLeft是该节点下某个feature的值小于指定值(备选s)的所有训练数据的lambad之和double sumRight = sumResponse - sumLeft; //sumRight 是该节点下某个feature的值大于等于指定值(备选s)的所有训练数据的lambad之和...}

构建树的时候,输入为(xi,lambdai),其中lambdai代表着对xi的评分(影响排序结果,是增大还是减少)。最好的划分点,就是把增大的划分到一起(全部为正值,相加结果为sumA),减少的划分到一起(全部为负值,相加结果为sumb).此时的sumA*sumA/countA+sumB*sumB/countB为最大。
因此,这里的S的含义为:该划分点尽量把正值和负值区分开。正值表示:后续评分调大;负值表示:后续评分调小;lambdai就是si从newTree中获取的值,表示si的值如何调整才能满足C最大(类似梯度)。C表示的是排序后的NDCG,求其最大值。

RegressionTree.java

ciir.umass.edu.learning.tree.RegressionTree
回归树实现。

    protected int nodes = 10; //控制分裂的次数,这个次数是按照节点来算的,而不是按照层数来计算的protected int minLeafSupport = 1;//控制分裂的次数,如果某个节点所包含的训练数据小于2*minLeafSupport ,则该节点不再分裂。...protected DataPoint[] trainingSamples = null; //训练的数据protected double[] trainingLabels = null; //这里的lables就是y值,在lambdaMART里为lambda值...
public void fit() //根据输入的数据以及lable值,生成回归树

Ensemble.java

ciir.umass.edu.learning.tree.Ensemble

LambdaMART流程

ciir.umass.edu.learning.tree.LambdaMART

  • LambdaMART初始化函数init()
    a. 设置训练数据,为每个训练数据i设置初值(0),为每个训练数据的y设置初值(0),为每个训练数据的w设置初值(0)
    b. 按照每个feature的大小重新排训练数据,为方便后面的计算。
    c. 每个feature都设置一批值以供后续做回归树split时的切分点。
    d. 初始化一个回归树(该树未进行分裂)
    public void init(){...//将样本根绝特征排序,方便做树的分列时快速找出最优分列点sortedIdx = new int[features.length][];MyThreadPool p = MyThreadPool.getInstance();if(p.size() == 1)//single-threadsortSamplesByFeature(0, features.length-1);...//创建存放候选阈值(分列点)的表thresholds = new float[features.length][];for(int f=0;f<features.length;f++){...}//计算特征直方图,加速寻找分列点hist = new FeatureHistogram();hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);...}
  • LambdaMART训练函数learn()
    生成指定数目的tree,以下为生成一个树的流程。

    1. 清空以前生成的pseudoResponses(yi),weights(wi)
    2. computePseudoResponses函数中生成新的pseudoResponses(yi),weights(wi)
    3. 用新生成的pseudoResponses(yi)来更新回归树
    4. 生成一棵新的回归树,并保存结果
    5. 求得γlk。
    6. 重新计算modelScores,即每个训练数据的评分
    7. 通过early stop的方式校验数据和退出
    public void learn(){//开始梯度提升训练过程for(int m=0; m<nTrees; m++){PRINT(new int[]{7}, new String[]{(m+1)+""});//计算lambdas (pseudo responses)computePseudoResponses();//根据新的label更新特征直方图hist.update(pseudoResponses);//回归决策树RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, hist, minLeafSupport);rt.fit();//将新生成的树加入模型ensemble.add(rt, learningRate);//更新树的输出(同时计算利用Newton-Raphson方法计算gamma)updateTreeOutput(rt);//更新所有训练样本的模型输出List<Split> leaves = rt.leaves();for(int i=0;i<leaves.size();i++){Split s = leaves.get(i);int[] idx = s.getSamples();for(int j=0;j<idx.length;j++)modelScores[idx[j]] += learningRate * s.getOutput();}//评价模型scoreOnTrainingData = computeModelScoreOnTraining();//检验是否应该提前结束if(m - bestModelOnValidation > nRoundToStopEarly)break;...//回滚到在验证集上最优的模型ensemble.remove(ensemble.treeCount()-1);...}

Ranklib部分源码分析相关推荐

  1. 【Golang源码分析】Go Web常用程序包gorilla/mux的使用与源码简析

    目录[阅读时间:约10分钟] 一.概述 二.对比: gorilla/mux与net/http DefaultServeMux 三.简单使用 四.源码简析 1.NewRouter函数 2.HandleF ...

  2. SpringBoot-web开发(四): SpringMVC的拓展、接管(源码分析)

    [SpringBoot-web系列]前文: SpringBoot-web开发(一): 静态资源的导入(源码分析) SpringBoot-web开发(二): 页面和图标定制(源码分析) SpringBo ...

  3. SpringBoot-web开发(二): 页面和图标定制(源码分析)

    [SpringBoot-web系列]前文: SpringBoot-web开发(一): 静态资源的导入(源码分析) 目录 一.首页 1. 源码分析 2. 访问首页测试 二.动态页面 1. 动态资源目录t ...

  4. SpringBoot-web开发(一): 静态资源的导入(源码分析)

    目录 方式一:通过WebJars 1. 什么是webjars? 2. webjars的使用 3. webjars结构 4. 解析源码 5. 测试访问 方式二:放入静态资源目录 1. 源码分析 2. 测 ...

  5. Yolov3Yolov4网络结构与源码分析

    Yolov3&Yolov4网络结构与源码分析 从2018年Yolov3年提出的两年后,在原作者声名放弃更新Yolo算法后,俄罗斯的Alexey大神扛起了Yolov4的大旗. 文章目录 论文汇总 ...

  6. ViewGroup的Touch事件分发(源码分析)

    Android中Touch事件的分发又分为View和ViewGroup的事件分发,View的touch事件分发相对比较简单,可参考 View的Touch事件分发(一.初步了解) View的Touch事 ...

  7. View的Touch事件分发(二.源码分析)

    Android中Touch事件的分发又分为View和ViewGroup的事件分发,先来看简单的View的touch事件分发. 主要分析View的dispatchTouchEvent()方法和onTou ...

  8. MyBatis原理分析之四:一次SQL查询的源码分析

    上回我们讲到Mybatis加载相关的配置文件进行初始化,这回我们讲一下一次SQL查询怎么进行的. 准备工作 Mybatis完成一次SQL查询需要使用的代码如下: Java代码   String res ...

  9. [转]slf4j + log4j原理实现及源码分析

    slf4j + log4j原理实现及源码分析 转载于:https://www.cnblogs.com/jasonzeng888/p/6051080.html

最新文章

  1. PAT题解-1118. Birds in Forest (25)-(并查集模板题)
  2. 15个顶级Java多线程面试题及回答(高级java工程师)
  3. PowerPC VxWorks BSP分析(1)--PowerPC体系结构
  4. mysql快捷键设置_MySQL快捷命令
  5. 取消服务器系统,取消系统服务器默认共享通道的方法
  6. servlet实现mvc
  7. python range倒序_Python算法学习之堆和堆排序
  8. oracle call 存储过程 带out_详解oracle数据库存储过程调试方法
  9. Anaconda创建、激活、退出、删除虚拟环境
  10. UVa 621 - Secret Research
  11. matlab方程例子,Matlab求解超定方程组实例
  12. ftp上传工具 免安装,3款最好用的ftp上传工具 免安装
  13. 共议新时代的文化自信与守正创新,第十四届文化中国讲坛举办
  14. 苹果发布iOS9.1 Beta 3:新增壁纸和表情
  15. 70后.net老猿,尚能饭否?
  16. java高仿新浪微博短链接地址生成工具ShortUrlGenerator.java
  17. java的虚引用_java虚引用的使用说明
  18. 24种设计模式的定义和使用场合
  19. 什么是嵌入式人工智能,它的实际应用
  20. 机器人军团防护罩_冒险岛贴吧 - 1000A导轨保护罩Y轴保护盖板的简单介绍

热门文章

  1. XCTF Leaking
  2. VC驿站黑客编程(关机,重新启动,注销)
  3. 解题记录 P4017 最大食物链计数 拓扑排序
  4. 南通大学计算机科学与技术在哪个校区,南通大学各校区分别哪些专业
  5. 用python画花瓣-用python画花瓣
  6. Servlet[jsp]的Servlet.service()引发了具有根本原因的异常 (这个是什么情况?求解答)
  7. ACM/ICPC 2018亚洲区预选赛北京赛站网络赛 D. 80 Days
  8. hcaptcha 我是人类验证码怎么跳过怎么验证自动识别
  9. CNopendata空气质量站点监测数据
  10. windows重装系统之后,开机显示“An operating system wasn't found,Try disconnecting any drives that...”(亲身遇到+解决方法)