/**
 * Update the booster for one iteration with the given gradient and hessian, e.g. statistics
 * computed by a custom objective.
 *
 * @param dtrain training data
 * @param grad first-order gradient, one value per training row
 * @param hess second-order gradient (hessian), one value per training row
 * @throws XGBoostError native error
 * @throws AssertionError if grad and hess differ in length
 */
public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
  if (grad.length != hess.length) {
    throw new AssertionError(
        String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
  }
  JNIErrorHandle.checkCall(
      XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
}
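// Sketch (hypothetical helper, not part of XGBoost4J): boost() expects one gradient and one
// hessian value per training row, usually derived from the current raw margin predictions.
// Assuming binary logistic loss with labels in {0, 1}, the per-row statistics are
// grad = p - y and hess = p * (1 - p), where p = sigmoid(margin). The class and method names
// below are illustrative only.
final class LogisticGradHessSketch {
  private LogisticGradHessSketch() {}

  /** Returns {grad, hess} for binary logistic loss; margins and labels must align by row. */
  static float[][] gradHess(float[] margins, float[] labels) {
    if (margins.length != labels.length) {
      throw new IllegalArgumentException("margins/labels length mismatch");
    }
    float[] grad = new float[margins.length];
    float[] hess = new float[margins.length];
    for (int i = 0; i < margins.length; i++) {
      // Assumes labels are 0 or 1; p is the sigmoid of the raw (untransformed) margin.
      double p = 1.0 / (1.0 + Math.exp(-margins[i]));
      grad[i] = (float) (p - labels[i]);
      hess[i] = (float) (p * (1.0 - p));
    }
    return new float[][] {grad, hess};
  }
}
// Usage, assuming `margins` holds one raw margin per training row:
//   float[][] gh = LogisticGradHessSketch.gradHess(margins, labels);
//   booster.boost(dtrain, gh[0], gh[1]);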
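/**
 * Base function for prediction.
 *
 * @param data data to predict on
 * @param outPutMargin whether to output the raw, untransformed margin values
 * @param treeLimit limit on the number of trees used in prediction (0 means use all trees)
 * @param predLeaf whether to output, for each row, the index of the leaf it falls into in
 *     every tree
 * @return prediction results, one array per data row
 * @throws XGBoostError native error
 */
private synchronized float[][] pred(
    DMatrix data, boolean outPutMargin, int treeLimit, boolean predLeaf) throws XGBoostError {
  // Option mask for the native call: 1 = output margin, 2 = predict leaf index.
  // predLeaf takes precedence when both flags are set.
  int optionMask = 0;
  if (outPutMargin) {
    optionMask = 1;
  }
  if (predLeaf) {
    optionMask = 2;
  }
  float[][] rawPredicts = new float[1][];
  JNIErrorHandle.checkCall(
      XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit, rawPredicts));
  int row = (int) data.rowNum();
  if (row == 0) {
    // Nothing to reshape; also avoids a division by zero when computing col below.
    return new float[0][];
  }
  // The native result is a flat row-major array; reshape it to [row][col].
  int col = rawPredicts[0].length / row;
  float[][] predicts = new float[row][col];
  for (int i = 0; i < rawPredicts[0].length; i++) {
    predicts[i / col][i % col] = rawPredicts[0][i];
  }
  return predicts;
}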
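// Sketch (hypothetical helper, not part of XGBoost4J) of the two steps pred() performs around
// the native call: building the option mask (1 = output margin, 2 = predict leaf, with predLeaf
// overriding outPutMargin as in pred()) and reshaping the flat row-major result into one array
// per data row. This variant copies whole rows with System.arraycopy instead of indexing
// element by element; the names below are illustrative only.
final class PredictionLayoutSketch {
  private PredictionLayoutSketch() {}

  /** Mirrors the mask logic in pred(): predLeaf wins over outPutMargin. */
  static int optionMask(boolean outPutMargin, boolean predLeaf) {
    if (predLeaf) {
      return 2;
    }
    return outPutMargin ? 1 : 0;
  }

  /** Reshapes a flat row-major array of rows * cols values into [rows][cols]. */
  static float[][] reshapeRowMajor(float[] flat, int rows) {
    if (rows == 0) {
      return new float[0][];
    }
    if (flat.length % rows != 0) {
      throw new IllegalArgumentException(
          "length " + flat.length + " is not divisible by " + rows + " rows");
    }
    int cols = flat.length / rows;
    float[][] out = new float[rows][cols];
    for (int r = 0; r < rows; r++) {
      // Row r occupies flat[r * cols] up to (but excluding) flat[(r + 1) * cols].
      System.arraycopy(flat, r * cols, out[r], 0, cols);
    }
    return out;
  }
}
// For example, a flat result {1, 2, 3, 4, 5, 6} over 2 rows becomes {{1, 2, 3}, {4, 5, 6}}.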
/**
 * Update (one iteration).
 *
 * @param dtrain training data
 * @param iter current iteration number
 * @throws XGBoostError native error
 */
public void update(DMatrix dtrain, int iter) throws XGBoostError {
  JNIErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
}