/**
 * Evaluates the booster on the given DMatrix objects for one iteration.
 *
 * @param evalMatrixs dmatrixs for evaluation
 * @param evalNames names for the eval dmatrixs, paired by index with {@code evalMatrixs}; used to
 *     label the results
 * @param iter current eval iteration
 * @return eval information string produced by the native library
 * @throws IllegalArgumentException if {@code evalMatrixs} and {@code evalNames} differ in length
 * @throws XGBoostError native error
 */
public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
  // Handles and names are passed to native code as parallel arrays; reject a mismatch on the Java
  // side instead of handing inconsistent arrays across JNI.
  if (evalMatrixs.length != evalNames.length) {
    throw new IllegalArgumentException(
        String.format(
            "evalMatrixs/evalNames length mismatch %s / %s",
            evalMatrixs.length, evalNames.length));
  }
  long[] handles = dmatrixsToHandles(evalMatrixs);
  // Single-element out-parameter filled in by the native call.
  String[] evalInfo = new String[1];
  JNIErrorHandle.checkCall(
      XgboostJNI.XGBoosterEvalOneIter(handle, iter, handles, evalNames, evalInfo));
  return evalInfo[0];
}
/**
 * Performs one boosting iteration using caller-supplied gradient statistics.
 *
 * @param dtrain training data
 * @param grad first-order gradient values
 * @param hess second-order gradient values; must have the same length as {@code grad}
 * @throws AssertionError if {@code grad} and {@code hess} differ in length
 * @throws XGBoostError native error
 */
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
  final int gradLen = grad.length;
  final int hessLen = hess.length;
  if (gradLen == hessLen) {
    final int ret = XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess);
    JNIErrorHandle.checkCall(ret);
    return;
  }
  throw new AssertionError(String.format("grad/hess length mismatch %s / %s", gradLen, hessLen));
}
/**
 * Dumps the model into a string array via the native library.
 *
 * @param withStats whether the split statistics are included in the dump
 * @return dumped model information as returned by the native call
 * @throws XGBoostError native error
 */
private String[] getDumpInfo(boolean withStats) throws XGBoostError {
  // The native API takes the flag as an int: 1 = include stats, 0 = omit.
  final int statsFlag = withStats ? 1 : 0;
  // Out-parameter: the native call stores the dump array in slot 0.
  final String[][] dumpHolder = new String[1][];
  // NOTE(review): the empty string is passed as the second native argument — presumably
  // "no feature map"; confirm against the JNI signature.
  final int ret = XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, dumpHolder);
  JNIErrorHandle.checkCall(ret);
  return dumpHolder[0];
}
/**
 * Creates the native booster and stores its handle in {@code handle}.
 *
 * @param dMatrixs matrices to associate with the booster at creation time; may be {@code null},
 *     in which case {@code null} is forwarded to the native create call
 * @throws XGBoostError native error
 */
private void init(DMatrix[] dMatrixs) throws XGBoostError {
  final long[] handles = (dMatrixs == null) ? null : dmatrixsToHandles(dMatrixs);
  // Out-parameter: the native call writes the new booster handle into slot 0.
  final long[] created = new long[1];
  JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, created));
  handle = created[0];
}
/**
 * Base function for prediction: runs the native predictor and reshapes its flat output into one
 * row per input row.
 *
 * @param data data to predict on
 * @param outPutMargin whether to request raw margin output (native option bit 1)
 * @param treeLimit limit on the number of trees used, forwarded to the native call
 * @param predLeaf whether to request leaf predictions (native option value 2); when set it takes
 *     precedence over {@code outPutMargin}
 * @return predictions as a {@code rowNum x k} matrix; an empty array when {@code data} has no rows
 * @throws IllegalStateException if the native output size is not a multiple of the row count
 * @throws XGBoostError native error
 */
private synchronized float[][] pred(
    DMatrix data, boolean outPutMargin, int treeLimit, boolean predLeaf) throws XGBoostError {
  // Same mask as before: predLeaf overrides outPutMargin rather than being OR-ed with it.
  int optionMask = 0;
  if (outPutMargin) {
    optionMask = 1;
  }
  if (predLeaf) {
    optionMask = 2;
  }
  float[][] rawPredicts = new float[1][];
  JNIErrorHandle.checkCall(
      XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit, rawPredicts));

  float[] flat = rawPredicts[0];
  int row = (int) data.rowNum();
  // An empty DMatrix previously caused a divide-by-zero below.
  if (row == 0) {
    return new float[0][];
  }
  if (flat.length % row != 0) {
    throw new IllegalStateException(
        String.format(
            "prediction size %s is not a multiple of row count %s", flat.length, row));
  }
  int col = flat.length / row;
  float[][] predicts = new float[row][col];
  // The native output is row-major: copy each contiguous row slice.
  for (int r = 0; r < row; r++) {
    System.arraycopy(flat, r * col, predicts[r], 0, col);
  }
  return predicts;
}
/**
 * Sets a single booster parameter in the native library.
 *
 * @param key parameter name
 * @param value parameter value, passed as a string
 * @throws XGBoostError native error
 */
public final void setParam(String key, String value) throws XGBoostError {
  final int ret = XgboostJNI.XGBoosterSetParam(handle, key, value);
  JNIErrorHandle.checkCall(ret);
}
/**
 * Frees the native booster if one is held and resets {@code handle} to 0. Calling this again
 * after the handle has been cleared does nothing.
 */
public synchronized void dispose() {
  if (handle == 0L) {
    return;
  }
  // The native free call's return code is not checked here.
  XgboostJNI.XGBoosterFree(handle);
  handle = 0L;
}
/**
 * Saves the booster model into the thread-local rabit checkpoint. This is only used in
 * distributed training.
 *
 * @throws XGBoostError if the native checkpoint call reports an error
 */
void saveRabitCheckpoint() throws XGBoostError {
  final int ret = XgboostJNI.XGBoosterSaveRabitCheckpoint(handle);
  JNIErrorHandle.checkCall(ret);
}
/**
 * Loads the booster model from the thread-local rabit checkpoint. This is only used in
 * distributed training.
 *
 * @return the stored version number of the checkpoint, as written by the native call
 * @throws XGBoostError if the native checkpoint call reports an error
 */
int loadRabitCheckpoint() throws XGBoostError {
  // Out-parameter: the native call stores the checkpoint version in slot 0.
  final int[] version = {0};
  JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(handle, version));
  return version[0];
}
/**
 * Loads a saved model from {@code modelPath} into this booster's native handle.
 *
 * <p>The native return code was previously discarded, so a failed load (bad path, corrupt file)
 * went unnoticed. It is now checked like every other native call in this class.
 *
 * @param modelPath path of the model file to load
 * @throws XGBoostError native error
 */
private void loadModel(String modelPath) throws XGBoostError {
  JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModel(handle, modelPath));
}
/**
 * Saves the booster model to {@code modelPath} via the native library.
 *
 * @param modelPath destination path for the model file
 * @throws XGBoostError native error
 */
public void saveModel(String modelPath) throws XGBoostError {
  final int ret = XgboostJNI.XGBoosterSaveModel(handle, modelPath);
  JNIErrorHandle.checkCall(ret);
}
/**
 * Runs one training iteration on {@code dtrain} using the booster's configured objective.
 *
 * @param dtrain training data
 * @param iter current iteration number, forwarded to the native call
 * @throws XGBoostError native error
 */
public void update(DMatrix dtrain, int iter) throws XGBoostError {
  final long dtrainHandle = dtrain.getHandle();
  final int ret = XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrainHandle);
  JNIErrorHandle.checkCall(ret);
}