Example #1
 /**
  * Update with the given gradient and hessian
  *
  * @param dtrain training data
  * @param grad first order gradient
  * @param hess second order gradient
  * @throws XGBoostError native error
  */
 public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
   if (grad.length != hess.length) {
     throw new AssertionError(
         String.format("grad/hess length mismatch %s / %s", grad.length, hess.length));
   }
   JNIErrorHandle.checkCall(
       XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
 }
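A minimal usage sketch of boost(): driving training with a hand-rolled squared-error objective by computing first- and second-order gradients in Java and feeding them back each round. The ml.dmlc.xgboost4j.java package path, the booster/dtrain names, and the predict(DMatrix, boolean) overload are assumptions for illustration, not part of the example above.

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class CustomObjectiveSketch {
  // Runs numRounds boosting iterations, supplying hand-computed gradients
  // and hessians of a squared-error objective to Booster.boost().
  static void train(Booster booster, DMatrix dtrain, int numRounds) throws XGBoostError {
    float[] labels = dtrain.getLabel();
    for (int round = 0; round < numRounds; round++) {
      float[][] preds = booster.predict(dtrain, true); // raw margin per row (assumed overload)
      float[] grad = new float[labels.length];
      float[] hess = new float[labels.length];
      for (int i = 0; i < labels.length; i++) {
        grad[i] = preds[i][0] - labels[i]; // first-order gradient of 0.5 * (pred - label)^2
        hess[i] = 1.0f;                    // second-order gradient of 0.5 * (pred - label)^2
      }
      booster.boost(dtrain, grad, hess);
    }
  }
}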
Example #2
 /**
  * Base function for predict
  *
  * @param data data to predict on
  * @param outPutMargin whether to output the untransformed margin value
  * @param treeLimit limit on the number of trees used in prediction; 0 means no limit
  * @param predLeaf whether to predict the index of the leaf each instance falls into
  * @return prediction results
  */
 private synchronized float[][] pred(
     DMatrix data, boolean outPutMargin, int treeLimit, boolean predLeaf) throws XGBoostError {
    // Option mask passed to the native predict call:
    // 0 = normal prediction, 1 = output margin, 2 = predict leaf index.
    int optionMask = 0;
    if (outPutMargin) {
      optionMask = 1;
    }
    if (predLeaf) {
      optionMask = 2;
    }
   float[][] rawPredicts = new float[1][];
   JNIErrorHandle.checkCall(
       XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit, rawPredicts));
   int row = (int) data.rowNum();
   int col = rawPredicts[0].length / row;
   float[][] predicts = new float[row][col];
    // Reshape the flat native output into [row][col] in row-major order.
    for (int i = 0; i < rawPredicts[0].length; i++) {
      int r = i / col;
      int c = i % col;
      predicts[r][c] = rawPredicts[0][i];
    }
   return predicts;
 }
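To make the index arithmetic in pred() concrete, here is a standalone sketch of the same row-major reshape: a flat native output of length row * col is mapped to predicts[i / col][i % col]. The sample values are made up purely for illustration.

import java.util.Arrays;

public class ReshapeSketch {
  public static void main(String[] args) {
    float[] flat = {0.1f, 0.9f, 0.2f, 0.8f, 0.7f, 0.3f}; // pretend native output: 3 rows x 2 columns
    int row = 3;
    int col = flat.length / row;
    float[][] predicts = new float[row][col];
    for (int i = 0; i < flat.length; i++) {
      predicts[i / col][i % col] = flat[i]; // same mapping as the loop in pred()
    }
    System.out.println(Arrays.deepToString(predicts)); // [[0.1, 0.9], [0.2, 0.8], [0.7, 0.3]]
  }
}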
Example #3
 /**
  * Update (one iteration)
  *
  * @param dtrain training data
  * @param iter current iteration number
  * @throws XGBoostError native error
  */
 public void update(DMatrix dtrain, int iter) throws XGBoostError {
   JNIErrorHandle.checkCall(XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle()));
 }
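A minimal training-loop sketch built on update(): one call per boosting round, relying on the objective configured in the booster's parameters. The package path, helper name, and Booster/DMatrix arguments are assumptions for illustration.

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoostError;

public class TrainingLoopSketch {
  // Performs numRounds boosting iterations using the booster's built-in objective.
  static void train(Booster booster, DMatrix dtrain, int numRounds) throws XGBoostError {
    for (int iter = 0; iter < numRounds; iter++) {
      booster.update(dtrain, iter); // one boosting iteration on dtrain
    }
  }
}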