Ranklib部分源码分析
- Ranklib部分源码分析LambdaMARTRandomForest
- 声明
- 关于Ranklib
- 主框架Evlauatorjava
- 程序主入口main函数
- 训练器evaluate函数
- 预测test函数
- 数据结构基础类DataPointRankList
- 数据格式
- DataPointjava
- RankListjava
- Rank基础类RANKER_TYPERankerFactoryRanker
- RANKER_TYPEjava
- RankerFactoryjava
- Rankerjava
- RankerTrainer
- LambdaMART基础类FeatureHistorgramRegressionTreeSplitEnsemble
- FeatureHistorgramjava
- RegressionTreejava
- Ensemblejava
- LambdaMART流程
Ranklib部分源码分析(LambdaMART+RandomForest)
声明
本文是Ranklib部分源码的分析,参考了RankLib源码分析——guoguo881218的专栏以及Learning to Rank——wowarsenal,在此对原博主表示感谢
关于Ranklib
在The Lemur Project可以下载到Ranklib程序,Ranklib2.1和Ranklib2.3有源码可以下载,Ranklib2.4和Ranklib2.5只有jar文件可以下载。通过jad反编译后可以看到源码,整体结构差别不大。本文以Ranklib2.3为标准进行说明。
主框架(Evlauator.java)
程序主入口main函数
Ranklib程序主入口为ciir.umass.edu.eval.Evaluator
类中main函数.
其中public static void main(String[] args)
函数接收命令行传入参数。
- 首先初始化一些变量并根据传入参数给变量赋值:
for(int i=0;i<args.length;i++){if(args[i].compareTo("-train")==0)trainFile = args[++i]; //训练集else if(args[i].compareTo("-ranker")==0)rankerType = Integer.parseInt(args[++i]); //Rank类型...else if(args[i].compareTo("-metric2t")==0)trainMetric = args[++i]; //训练集Metricelse if(args[i].compareTo("-metric2T")==0)testMetric = args[++i]; //测试集Metric...else if(args[i].compareTo("-validate")==0)validationFile = args[++i]; //验证集else if(args[i].compareTo("-test")==0){testFile = args[++i];testFiles.add(testFile);} //测试集...else if(args[i].compareTo("-save")==0)Evaluator.modelFile = args[++i]; //模型保存位置...else if(args[i].compareTo("-load")==0){savedModelFile = args[++i];savedModelFiles.add(args[i]);} //导入模型...else if(args[i].compareTo("-rank")==0)rankFile = args[++i]; //待排序数据... ... ...//MART / LambdaMART / Random forestelse if(args[i].compareTo("-tree")==0){LambdaMART.nTrees = Integer.parseInt(args[++i]);RFRanker.nTrees = Integer.parseInt(args[i]);} //树的棵树else if(args[i].compareTo("-leaf")==0){LambdaMART.nTreeLeaves = Integer.parseInt(args[++i]);RFRanker.nTreeLeaves = Integer.parseInt(args[i]);} //每棵树叶子结点数else if(args[i].compareTo("-shrinkage")==0){LambdaMART.learningRate = Float.parseFloat(args[++i]);RFRanker.learningRate = Float.parseFloat(args[i]);} //收缩系数...//Random forestelse if(args[i].compareTo("-bag")==0)RFRanker.nBag = Integer.parseInt(args[++i]); //bags数目
- 根据参数变量进行训练
if(nThread == -1)nThread = Runtime.getRuntime().availableProcessors();MyThreadPool.init(nThread); //线程池初始化...Evaluator e = new Evaluator(rType2[rankerType], trainMetric, testMetric); //根据Rank类型以及训练集、测试集上的评价函数生成Evaluator对象... ...RankerFactory rf = new RankerFactory();rf.createRanker(rType2[rankerType]).printParameters();//根据参数创建Rank对象...e.evaluate() //多个实现,针对不同情况进行evaluate...if(testFiles.size() > 1)e.test(savedModelFiles, testFiles, prpFile);elsee.test(savedModelFiles, testFile, prpFile) //利用已有模型在测试集上进行预测s
训练器evaluate函数
- Evaluator初始化
public Evaluator(RANKER_TYPE rType, String trainMetric, String testMetric){this.type = rType; //Ranke类型trainScorer = mFact.createScorer(trainMetric); //训练集上得分testScorer = mFact.createScorer(testMetric); //测试集上得分...}
- evaluate函数调用
//根据训练集验证集和测试集进行训练的evaluate()函数调用public void evaluate(String trainFile, String validationFile, String testFile, String featureDefFile){List<RankList> train = readInput(trainFile);...test = readInput(testFile);//读取训练、验证、测试文件... ...RankerTrainer trainer = new RankerTrainer();Ranker ranker = trainer.train(type, train, validation, features, trainScorer);//利用训练集和验证集训练模型... ...double rankScore = evaluate(ranker, test); //计算测试集得分... ...ranker.save(modelFile); //保存模型
预测test函数
test函数调用
//根据已有模型以及测试集进行训练的test()函数调用public void test(String modelFile, String testFile, String prpFile){Ranker ranker = rFact.loadRanker(modelFile); //导入模型List<RankList> test = readInput(testFile); //读取测试数据RankList l = ranker.rank(test.get(i)); //排序评分double score = testScorer.score(l); //取得评分
数据结构基础类(DataPoint、RankList)
数据格式
数据格式与SVM-Rank、libSVM、LETOR格式均相同。格式如下
<line> .=. <target> qid:<qid> <feature>:<value> <feature>:<value> ... <feature>:<value> # <info>
<target> .=. <positive integer> //正整数型评分
<qid> .=. <positive integer> //正整数型查询
<feature> .=. <positive integer> //正整数型特征序号
<value> .=. <float> //浮点型特征值
<info> .=. <string> //注释
DataPoint.java
ciir.umass.edu.learning.DataPoint
实现了需要评分的对象的数据结构。每个对象是一个待评分文档。
RankList.java
ciir.umass.edu.learning.RankList
实现了需要评分的对象组成的列表的数据结构。每个对象是一个包含对应于同一查询的不同文献的集合。
Rank基础类(RANKER_TYPE、RankerFactory、Ranker)
RANKER_TYPE.java
ciir.umass.edu.learning.RANKER_TYPE
枚举类型,包含各种Rank类型
RankerFactory.java
ciir.umass.edu.learning.RankerFactory
实现了RankerFactory,所有Rank方法都需要在此类中注册。
public Ranker createRanker(RANKER_TYPE type)
创建某种类型的Rank对象。
public Ranker loadRanker(String modelFile)
导入已有模型。
Ranker.java
ciir.umass.edu.learning.Ranker
Ranker类实现了一般的Rank接口,所有Rank类型都需要集成Ranker。
通用方法有:
public void setTrainingSet(List<RankList> samples) //设置训练集
public void setValidationSet(List<RankList> samples) //设置验证集
public double getScoreOnTrainingData() //训练集得分...
public void save(String modelFile) //保存模型
public RankList rank(RankList rl) //给出评分后的排序
public List<RankList> rank(List<RankList> l) //给出评分后的排序
必须在子类中实现的方法有:
public void init() //初始化
public void learn() //学习
public double eval(DataPoint p) //评价
public String toString() //模型转为字符串
public void load(String fn) //导入模型
RankerTrainer
ciir.umass.edu.learning.RankerTrainer
实现了对模型进行训练的函数:
public Ranker train(RANKER_TYPE type, List<RankList> train, List<RankList> validation, int[] features, MetricScorer scorer)
LambdaMART基础类(FeatureHistorgram、RegressionTree、Split、Ensemble)
FeatureHistorgram.java
ciir.umass.edu.learning.tree.FeatureHistogram
特征直方图类,对RankList对象进行特征的直方图统计,选择每次split时的最优feature和最优划分点。
- construct方法:
public void construct(DataPoint[] samples, double[] labels, int[][] sampleSortedIdx, int[] features, float[][] thresholds){...sum = new double[features.length][];count = new int[features.length][];...}
* sum[i][j]:指定feature i 的所有值(训练数据中出现的值),每个j代表一个训练数据中出现的一个值,sum[i][j]的值为feature i的所有小于某个指定值(该值由threshold[j]提供)的训练数据datapoint的label(该算法里为lambda)之和。* count[i][j]:指定feature i 的所有值(训练数据中出现的值),每个j代表一个训练数据中出现的一个值,sum[i][j]的值为feature i的所有小于某个指定值(该值由threshold[j]提供)的训练数据datapoint的总数。
- update方法:
protected void update(double[] labels)
用新的label更新sum[i][j]。
- findBestSplit方法思路:
- 选取feature作为划分的备选(可全选,可选部分)。
- 选取最优feature和最优划分点
- 计算每个feature的每个划分点,doubleS=sumLeft∗sumLeft/countLeft+sumRight∗sumRight/countRightdouble S = sumLeft * sumLeft / countLeft + sumRight * sumRight / countRight,最小的S即为最优feature和最优划分点s(该s是feature的具体值)。
//findBestSplit方法:protected Config findBestSplit(int[] usedFeatures, int minLeafSupport, int start, int end){...int countLeft = count[i][t]; //countLeft是该节点下某个feature的值小于指定值(备选s)的所有训练数据的总数int countRight = totalCount - countLeft; //countRight是该节点下某个feature的值大于等于指定值(备选s)的所有训练数据的总数...double sumLeft = sum[i][t]; //sumLeft是该节点下某个feature的值小于指定值(备选s)的所有训练数据的lambad之和double sumRight = sumResponse - sumLeft; //sumRight 是该节点下某个feature的值大于等于指定值(备选s)的所有训练数据的lambad之和...}
构建树的时候,输入为(xi,lambdai),其中lambdai代表着对xi的评分(影响排序结果,是增大还是减少)。最好的划分点,就是把增大的划分到一起(全部为正值,相加结果为sumA),减少的划分到一起(全部为负值,相加结果为sumb).此时的sumA*sumA/countA+sumB*sumB/countB为最大。
因此,这里的S的含义为:该划分点尽量把正值和负值区分开。正值表示:后续评分调大;负值表示:后续评分调小;lambdai就是si从newTree中获取的值,表示si的值如何调整才能满足C最大(类似梯度)。C表示的是排序后的NDCG,求其最大值。
RegressionTree.java
ciir.umass.edu.learning.tree.RegressionTree
回归树实现。
protected int nodes = 10; //控制分裂的次数,这个次数是按照节点来算的,而不是按照层数来计算的protected int minLeafSupport = 1;//控制分裂的次数,如果某个节点所包含的训练数据小于2*minLeafSupport ,则该节点不再分裂。...protected DataPoint[] trainingSamples = null; //训练的数据protected double[] trainingLabels = null; //这里的lables就是y值,在lambdaMART里为lambda值...
public void fit() //根据输入的数据以及lable值,生成回归树
Ensemble.java
ciir.umass.edu.learning.tree.Ensemble
LambdaMART流程
ciir.umass.edu.learning.tree.LambdaMART
- LambdaMART初始化函数
init()
a. 设置训练数据,为每个训练数据i设置初值(0),为每个训练数据的y设置初值(0),为每个训练数据的w设置初值(0)
b. 按照每个feature的大小重新排训练数据,为方便后面的计算。
c. 每个feature都设置一批值以供后续做回归树split时的切分点。
d. 初始化一个回归树(该树未进行分裂)
public void init(){...//将样本根绝特征排序,方便做树的分列时快速找出最优分列点sortedIdx = new int[features.length][];MyThreadPool p = MyThreadPool.getInstance();if(p.size() == 1)//single-threadsortSamplesByFeature(0, features.length-1);...//创建存放候选阈值(分列点)的表thresholds = new float[features.length][];for(int f=0;f<features.length;f++){...}//计算特征直方图,加速寻找分列点hist = new FeatureHistogram();hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);...}
- LambdaMART训练函数
learn()
生成指定数目的tree,以下为生成一个树的流程。- 清空以前生成的pseudoResponses(yi),weights(wi)
- computePseudoResponses函数中生成新的pseudoResponses(yi),weights(wi)
- 用新生成的pseudoResponses(yi)来更新回归树
- 生成一棵新的回归树,并保存结果
- 求得γlk。
- 重新计算modelScores,即每个训练数据的评分
- 通过early stop的方式校验数据和退出
public void learn(){//开始梯度提升训练过程for(int m=0; m<nTrees; m++){PRINT(new int[]{7}, new String[]{(m+1)+""});//计算lambdas (pseudo responses)computePseudoResponses();//根据新的label更新特征直方图hist.update(pseudoResponses);//回归决策树RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, hist, minLeafSupport);rt.fit();//将新生成的树加入模型ensemble.add(rt, learningRate);//更新树的输出(同时计算利用Newton-Raphson方法计算gamma)updateTreeOutput(rt);//更新所有训练样本的模型输出List<Split> leaves = rt.leaves();for(int i=0;i<leaves.size();i++){Split s = leaves.get(i);int[] idx = s.getSamples();for(int j=0;j<idx.length;j++)modelScores[idx[j]] += learningRate * s.getOutput();}//评价模型scoreOnTrainingData = computeModelScoreOnTraining();//检验是否应该提前结束if(m - bestModelOnValidation > nRoundToStopEarly)break;...//回滚到在验证集上最优的模型ensemble.remove(ensemble.treeCount()-1);...}
Ranklib部分源码分析相关推荐
- 【Golang源码分析】Go Web常用程序包gorilla/mux的使用与源码简析
目录[阅读时间:约10分钟] 一.概述 二.对比: gorilla/mux与net/http DefaultServeMux 三.简单使用 四.源码简析 1.NewRouter函数 2.HandleF ...
- SpringBoot-web开发(四): SpringMVC的拓展、接管(源码分析)
[SpringBoot-web系列]前文: SpringBoot-web开发(一): 静态资源的导入(源码分析) SpringBoot-web开发(二): 页面和图标定制(源码分析) SpringBo ...
- SpringBoot-web开发(二): 页面和图标定制(源码分析)
[SpringBoot-web系列]前文: SpringBoot-web开发(一): 静态资源的导入(源码分析) 目录 一.首页 1. 源码分析 2. 访问首页测试 二.动态页面 1. 动态资源目录t ...
- SpringBoot-web开发(一): 静态资源的导入(源码分析)
目录 方式一:通过WebJars 1. 什么是webjars? 2. webjars的使用 3. webjars结构 4. 解析源码 5. 测试访问 方式二:放入静态资源目录 1. 源码分析 2. 测 ...
- Yolov3Yolov4网络结构与源码分析
Yolov3&Yolov4网络结构与源码分析 从2018年Yolov3年提出的两年后,在原作者声名放弃更新Yolo算法后,俄罗斯的Alexey大神扛起了Yolov4的大旗. 文章目录 论文汇总 ...
- ViewGroup的Touch事件分发(源码分析)
Android中Touch事件的分发又分为View和ViewGroup的事件分发,View的touch事件分发相对比较简单,可参考 View的Touch事件分发(一.初步了解) View的Touch事 ...
- View的Touch事件分发(二.源码分析)
Android中Touch事件的分发又分为View和ViewGroup的事件分发,先来看简单的View的touch事件分发. 主要分析View的dispatchTouchEvent()方法和onTou ...
- MyBatis原理分析之四:一次SQL查询的源码分析
上回我们讲到Mybatis加载相关的配置文件进行初始化,这回我们讲一下一次SQL查询怎么进行的. 准备工作 Mybatis完成一次SQL查询需要使用的代码如下: Java代码 String res ...
- [转]slf4j + log4j原理实现及源码分析
slf4j + log4j原理实现及源码分析 转载于:https://www.cnblogs.com/jasonzeng888/p/6051080.html
最新文章
- PAT题解-1118. Birds in Forest (25)-(并查集模板题)
- 15个顶级Java多线程面试题及回答(高级java工程师)
- PowerPC VxWorks BSP分析(1)--PowerPC体系结构
- mysql快捷键设置_MySQL快捷命令
- 取消服务器系统,取消系统服务器默认共享通道的方法
- servlet实现mvc
- python range倒序_Python算法学习之堆和堆排序
- oracle call 存储过程 带out_详解oracle数据库存储过程调试方法
- Anaconda创建、激活、退出、删除虚拟环境
- UVa 621 - Secret Research
- matlab方程例子,Matlab求解超定方程组实例
- ftp上传工具 免安装,3款最好用的ftp上传工具 免安装
- 共议新时代的文化自信与守正创新,第十四届文化中国讲坛举办
- 苹果发布iOS9.1 Beta 3:新增壁纸和表情
- 70后.net老猿,尚能饭否?
- java高仿新浪微博短链接地址生成工具ShortUrlGenerator.java
- java的虚引用_java虚引用的使用说明
- 24种设计模式的定义和使用场合
- 什么是嵌入式人工智能,它的实际应用
- 机器人军团防护罩_冒险岛贴吧 - 1000A导轨保护罩Y轴保护盖板的简单介绍
热门文章
- XCTF Leaking
- VC驿站黑客编程(关机,重新启动,注销)
- 解题记录 P4017 最大食物链计数 拓扑排序
- 南通大学计算机科学与技术在哪个校区,南通大学各校区分别哪些专业
- 用python画花瓣-用python画花瓣
- Servlet[jsp]的Servlet.service()引发了具有根本原因的异常 (这个是什么情况?求解答)
- ACM/ICPC 2018亚洲区预选赛北京赛站网络赛 D. 80 Days
- hcaptcha 我是人类验证码怎么跳过怎么验证自动识别
- CNopendata空气质量站点监测数据
- windows重装系统之后,开机显示“An operating system wasn't found,Try disconnecting any drives that...”(亲身遇到+解决方法)