Home 学习使用Xgboost4j
Post
Cancel

学习使用Xgboost4j

Xgboost4j使用Java训练rank(Learning to Rank)模型,跟一般算法不同, 这里数据有个组的概念,

可以通过DMatrix的setGroup()方法设置,参数是一个int数组。这里还是用demo中rank的数据(mq2008),训练代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package ml.dmlc.xgboost4j.java.example;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

/**
 * Trains a learning-to-rank model with XGBoost4J on the MQ2008 demo files and
 * prints the predictions for the test set.
 *
 * <p>Ranking data carries group information in addition to features and labels.
 * The group values are read from the {@code *.group} files (one integer per
 * line) and attached to each {@code DMatrix} through {@code setGroup(int[])}
 * before training.
 */
public class L2rankTrain {

	/**
	 * Reads a group file that holds one integer per line.
	 *
	 * <p>Blank lines are skipped and surrounding whitespace is ignored so that a
	 * trailing newline at the end of the file does not abort the run.
	 *
	 * @param groupFile path of the group file to read
	 * @return the integers of the file, in file order
	 * @throws java.io.IOException if the file cannot be opened or read
	 * @throws NumberFormatException if a non-blank line is not a valid integer
	 */
	private static int[] readGroupSizes(String groupFile) throws java.io.IOException {
		List<Integer> sizes = new ArrayList<Integer>();
		// try-with-resources closes the reader even when parsing fails half-way.
		try (BufferedReader reader = new BufferedReader(new FileReader(groupFile))) {
			String line;
			while ((line = reader.readLine()) != null) {
				String trimmed = line.trim();
				if (trimmed.isEmpty()) {
					continue;
				}
				sizes.add(Integer.parseInt(trimmed));
			}
		}

		int[] result = new int[sizes.size()];
		for (int i = 0; i < result.length; i++) {
			result[i] = sizes.get(i);
		}
		return result;
	}

	/**
	 * Entry point: loads the train/test matrices, trains for a fixed number of
	 * rounds and prints every predicted value on its own line.
	 *
	 * <p>Any failure is reported with a stack trace instead of being rethrown.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {
		try {
			DMatrix trainMat = new DMatrix("xgboost/xgboost/demo/rank/mq2008.train");
			DMatrix testMat = new DMatrix("xgboost/xgboost/demo/rank/mq2008.test");

			int[] trainGroupArr = readGroupSizes("xgboost/xgboost/demo/rank/mq2008.train.group");
			int[] testGroupArr = readGroupSizes("xgboost/xgboost/demo/rank/mq2008.test.group");

			HashMap<String, Object> params = new HashMap<String, Object>();
			params.put("eta", 0.1);
			params.put("gamma", 1.0);
			params.put("min_child_weight", 0.1);
			params.put("max_depth", 6);
			params.put("silent", 1);

			// Ranking objective; "rank:ndcg" is the alternative noted by the author.
			params.put("objective", "rank:pairwise");

			// Group information must be attached before the matrices are used.
			trainMat.setGroup(trainGroupArr);
			testMat.setGroup(testGroupArr);

			// Matrices that are evaluated after every boosting round.
			HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
			watches.put("train", trainMat);
			watches.put("test", testMat);

			int round = 4;

			// The two trailing nulls are passed as in the original demo;
			// presumably custom objective / evaluation hooks - confirm against the API.
			Booster booster = XGBoost.train(trainMat, params, round, watches,
					null, null);

			// The trained booster could be persisted with booster.saveModel(path)
			// and restored later with XGBoost.loadModel(path or InputStream).

			float[][] predicts = booster.predict(testMat);
			System.out.println("predicts.length: " + predicts.length);
			for (int i = 0; i < predicts.length; i++) {
				float[] pred = predicts[i];
				for (int j = 0; j < pred.length; j++) {
					System.out.println(pred[j]);
				}
			}
		} catch (Exception e) {
			e.printStackTrace();
		}

	}

}

Java调用训练好的model文件, 其中模型文件是rank产生的0004.model

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
/**
 * Loads a previously trained ranking model from disk and prints its
 * predictions for the MQ2008 test file, one value per line.
 */
public class LoadRankModelTest {

	/**
	 * Runs the prediction demo.
	 *
	 * @param args unused
	 * @throws XGBoostError if the data or the model cannot be loaded, or prediction fails
	 */
	public static void main(String[] args) throws XGBoostError {
		final String testDataPath = "xgboost/xgboost/demo/rank/mq2008.test";
		final String modelPath = "xgboost/xgboost/demo/rank/0004.model";

		// Test data is loaded first, then the model, as in the original listing.
		DMatrix testData = new DMatrix(testDataPath);
		Booster model = XGBoost.loadModel(modelPath);

		// The boolean flag is forwarded unchanged; presumably it requests the raw
		// margin output - confirm against the Booster.predict documentation.
		float[][] scores = model.predict(testData, true);

		int row = 0;
		while (row < scores.length) {
			for (float value : scores[row]) {
				System.out.println("i: " + row + " pred: " + value);
			}
			row++;
		}
	}

}

查看API可知,模型文件也可以从InputStream加载

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
// Load the model from an InputStream instead of a file path and print the
// predictions for the MQ2008 test file.
File f = new File("xgboost/xgboost/demo/rank/0004.model");
// try-with-resources closes the model stream even if loading or prediction fails.
try (InputStream input = new FileInputStream(f)) {
	Booster booster2 = XGBoost.loadModel(input);
	DMatrix testMat = new DMatrix("xgboost/xgboost/demo/rank/mq2008.test");
	// The boolean flag is passed as in the path-based example; presumably it
	// requests the raw margin output - confirm against Booster.predict.
	float[][] preds = booster2.predict(testMat, true);
	for (int i = 0; i < preds.length; i++) {
		float[] pred = preds[i];
		for (int j = 0; j < pred.length; j++) {
			System.out.println("i: "+i+" pred: "+pred[j]);
		}
	}
} catch (XGBoostError e) {
	System.out.println("XGBoostError: "+e);
} catch (FileNotFoundException e) {
	// Must precede IOException: it is the more specific type.
	System.out.println("FileNotFoundException: "+e);
}  catch (IOException e) {
	System.out.println("e: "+e);
}

很多时候DMatrix的数据不可能直接来自文件,这时就不能通过DMatrix(String dataPath)来得到测试数据了。这里可以先构造

1
List<LabeledPoint>

LabeledPoint可以认为是SVM格式的一行数据,然后通过这个list的iterator构造DMatrix

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// Build the test DMatrix in memory from LabeledPoint rows instead of handing
// a file path to DMatrix, then predict with a model loaded from disk.
File svmFormatFile = new File(TEST_SVM_FORMAT_FILE_PATH);
// try-with-resources closes the reader even if a line fails to parse.
try (BufferedReader bfReader = new BufferedReader(new FileReader(svmFormatFile))) {
	java.util.List<LabeledPoint> blist = new java.util.ArrayList<LabeledPoint>();
	String line;
	while ((line = bfReader.readLine()) != null) {
		// A blank line (e.g. a trailing newline) carries no row; skip it
		// rather than failing on Float.parseFloat("").
		if (line.trim().isEmpty()) {
			continue;
		}
		String[] lines = line.split(" ");
		// Token 0 is the label. Token 1 is skipped on purpose; in the ranking
		// demo data it is presumably the "qid:N" field - confirm for other inputs.
		float label = Float.parseFloat(lines[0]);
		float[] fVals = new float[lines.length-2];
		int[] indices = new int[lines.length-2];

		// Remaining tokens have the form "<featureIndex>:<value>".
		for (int i = 2; i < lines.length; i++) {
			String[] featureAndVal = lines[i].split(":");
			indices[i-2] = Integer.parseInt(featureAndVal[0]);
			fVals[i-2] = Float.parseFloat(featureAndVal[1]);
		}
		blist.add(LabeledPoint.fromSparseVector(label, indices, fVals));
	}

	// The second constructor argument is left null as in the original listing;
	// presumably optional cache information - confirm against the DMatrix API.
	DMatrix testMat = new DMatrix(blist.iterator(), null);

	Booster booster2 = XGBoost.loadModel("xgboost/xgboost/demo/rank/0004.model");
	float[][] preds = booster2.predict(testMat, true);
	for (int i = 0; i < preds.length; i++) {
		float[] pred = preds[i];
		for (int j = 0; j < pred.length; j++) {
			System.out.println("i: "+i+" pred: "+pred[j]);
		}
	}
}catch(XGBoostError xGBoostError){
	System.out.println("xGBoostError: "+xGBoostError);
} catch(IOException e){

	System.out.println("e: "+e);

}
This post is licensed under CC BY 4.0 by the author.