使用XGBoost4J在Java中训练rank(Learning to Rank)模型时,与一般算法不同,这里的数据有“组”(group)的概念,
可以通过DMatrix的setGroup()方法设置,参数是一个int数组。这里还是用demo中rank的数据:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package ml.dmlc.xgboost4j.java.example;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
/**
 * Trains an XGBoost learning-to-rank model on the MQ2008 demo data and prints
 * the predictions made for the test set.
 *
 * <p>The group arrays passed to {@code DMatrix.setGroup} are read from the
 * companion {@code .group} files, which hold one integer per line.
 */
public class L2rankTrain {

    /**
     * Reads a group file containing one integer per line.
     *
     * @param path location of the group file
     * @return the parsed integers, in file order
     * @throws java.io.IOException if the file cannot be opened or read
     * @throws NumberFormatException if a non-blank line is not a valid integer
     */
    private static int[] readGroupSizes(String path) throws java.io.IOException {
        List<Integer> sizes = new ArrayList<Integer>();
        // try-with-resources guarantees the reader is closed even when parsing fails.
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = reader.readLine()) != null) {
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    // Tolerate blank lines (e.g. a trailing newline at end of file).
                    continue;
                }
                sizes.add(Integer.parseInt(trimmed));
            }
        }
        int[] result = new int[sizes.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = sizes.get(i);
        }
        return result;
    }

    /**
     * Entry point: loads the train/test matrices, attaches group information,
     * trains for a fixed number of rounds and prints every prediction.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        try {
            DMatrix trainMat = new DMatrix("xgboost/xgboost/demo/rank/mq2008.train");
            DMatrix testMat = new DMatrix("xgboost/xgboost/demo/rank/mq2008.test");

            int[] trainGroupArr = readGroupSizes("xgboost/xgboost/demo/rank/mq2008.train.group");
            int[] testGroupArr = readGroupSizes("xgboost/xgboost/demo/rank/mq2008.test.group");

            HashMap<String, Object> params = new HashMap<String, Object>();
            params.put("eta", 0.1);
            params.put("gamma", 1.0);
            params.put("min_child_weight", 0.1);
            params.put("max_depth", 6);
            params.put("silent", 1);
            // The original author noted "rank:ndcg" as an alternative objective.
            params.put("objective", "rank:pairwise");

            // Group information must be attached before training.
            trainMat.setGroup(trainGroupArr);
            testMat.setGroup(testGroupArr);

            // Matrices that are evaluated and reported during training.
            HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
            watches.put("train", trainMat);
            watches.put("test", testMat);

            int round = 4;
            Booster booster = XGBoost.train(trainMat, params, round, watches,
                    null, null);
            // To persist the model call booster.saveModel(modelPath); it can be
            // reloaded with XGBoost.loadModel(modelPath) or XGBoost.loadModel(in).

            float[][] predicts = booster.predict(testMat);
            System.out.println("predicts.length: " + predicts.length);
            for (float[] pred : predicts) {
                for (float value : pred) {
                    System.out.println(value);
                }
            }
        } catch (Exception e) {
            // main() is the program boundary, so report any failure and exit.
            e.printStackTrace();
        }
    }
}
Java调用训练好的model文件, 其中模型文件是rank产生的0004.model
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/**
 * Loads a previously saved ranking model from disk and prints its predictions
 * for the MQ2008 test data.
 */
public class LoadRankModelTest {

    /**
     * Builds the test matrix, loads the model file and prints one line per
     * predicted value, prefixed with the row index.
     *
     * @param args unused
     * @throws XGBoostError if the matrix or model cannot be loaded, or prediction fails
     */
    public static void main(String[] args) throws XGBoostError {
        // Test data.
        DMatrix testData = new DMatrix("xgboost/xgboost/demo/rank/mq2008.test");

        // Load the model.
        Booster model = XGBoost.loadModel("xgboost/xgboost/demo/rank/0004.model");

        // Predict. No group information is set on the test matrix here.
        // NOTE(review): the second argument is presumably xgboost4j's
        // outputMargin flag - confirm against the API docs.
        float[][] scores = model.predict(testData, true);

        int row = 0;
        for (float[] rowScores : scores) {
            for (float score : rowScores) {
                System.out.println("i: " + row + " pred: " + score);
            }
            row++;
        }
    }
}
看Api,模型文件也可以从InputStream得到
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// Load the model through an InputStream instead of a file path, then print
// the predictions for the MQ2008 test data.
File f = new File("xgboost/xgboost/demo/rank/0004.model");
// try-with-resources closes the stream whether loading succeeds or fails;
// the original never closed it.
try (InputStream input = new FileInputStream(f)) {
    Booster booster2 = XGBoost.loadModel(input);
    DMatrix testMat = new DMatrix("xgboost/xgboost/demo/rank/mq2008.test");
    // Path-based alternative:
    // XGBoost.loadModel("xgboost/xgboost/demo/rank/0004.model")
    float[][] preds = booster2.predict(testMat, true);
    for (int i = 0; i < preds.length; i++) {
        for (float value : preds[i]) {
            System.out.println("i: " + i + " pred: " + value);
        }
    }
} catch (XGBoostError e) {
    System.out.println("XGBoostError: " + e);
} catch (FileNotFoundException e) {
    // Must be caught before IOException because it is a subclass of it.
    System.out.println("FileNotFoundException: " + e);
} catch (IOException e) {
    System.out.println("e: " + e);
}
很多时候DMatrix的数据不可能直接来自文件,这时就不能通过DMatrix(String dataPath)来得到测试数据了。这里可以先构造List<LabeledPoint>(一个LabeledPoint可以认为是SVM格式的一行数据),然后通过这个list的iterator构造DMatrix:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// Build a DMatrix from in-memory LabeledPoint rows parsed out of an
// SVM-format text file, then print the loaded model's predictions.
File svmFormatFile = new File(TEST_SVM_FORMAT_FILE_PATH);
// try-with-resources closes the reader on every path; the original leaked it.
try (BufferedReader bfReader = new BufferedReader(new FileReader(svmFormatFile))) {
    java.util.List<LabeledPoint> blist = new java.util.ArrayList<LabeledPoint>();
    String line;
    while ((line = bfReader.readLine()) != null) {
        if (line.trim().isEmpty()) {
            // Skip blank lines instead of failing in Float.parseFloat.
            continue;
        }
        String[] tokens = line.split(" ");
        float label = Float.parseFloat(tokens[0]);
        // Features start at token 2: token 0 is the label and token 1 is skipped.
        // NOTE(review): presumably token 1 is a "qid:<n>" field - confirm
        // against the actual input file.
        // Clamp to zero so a line with fewer than two tokens does not throw
        // NegativeArraySizeException.
        int featureCount = Math.max(0, tokens.length - 2);
        float[] fVals = new float[featureCount];
        int[] indices = new int[featureCount];
        for (int i = 0; i < featureCount; i++) {
            String[] featureAndVal = tokens[i + 2].split(":");
            indices[i] = Integer.parseInt(featureAndVal[0]);
            fVals[i] = Float.parseFloat(featureAndVal[1]);
        }
        // Dense alternative: LabeledPoint.fromDenseVector(label, fVals)
        blist.add(LabeledPoint.fromSparseVector(label, indices, fVals));
    }
    DMatrix testMat = new DMatrix(blist.iterator(), null);
    Booster booster2 = XGBoost.loadModel("xgboost/xgboost/demo/rank/0004.model");
    float[][] preds = booster2.predict(testMat, true);
    for (int i = 0; i < preds.length; i++) {
        for (float value : preds[i]) {
            System.out.println("i: " + i + " pred: " + value);
        }
    }
} catch (XGBoostError xGBoostError) {
    System.out.println("xGBoostError: " + xGBoostError);
} catch (IOException e) {
    System.out.println("e: " + e);
}