Ejemplo n.º 1
0
 /**
  * Evaluates the booster against each supplied matrix for a single iteration.
  *
  * @param evalMatrixs matrices to evaluate against
  * @param evalNames names for the evaluation matrices, passed through to the native evaluator
  * @param iter iteration number being evaluated
  * @return the evaluation information string produced by the native call
  * @throws XGBoostError if the native call reports an error
  */
 public String evalSet(DMatrix[] evalMatrixs, String[] evalNames, int iter) throws XGBoostError {
   final long[] matrixHandles = dmatrixsToHandles(evalMatrixs);
   // The native layer writes its result into slot 0 of this holder.
   final String[] resultHolder = new String[1];
   final int status =
       XgboostJNI.XGBoosterEvalOneIter(handle, iter, matrixHandles, evalNames, resultHolder);
   JNIErrorHandle.checkCall(status);
   return resultHolder[0];
 }
Ejemplo n.º 2
0
 /**
  * Performs one boosting step using caller-supplied gradient statistics.
  *
  * @param dtrain training data
  * @param grad first-order gradient values
  * @param hess second-order gradient values; must have the same length as {@code grad}
  * @throws XGBoostError if the native call reports an error
  * @throws AssertionError if {@code grad} and {@code hess} differ in length
  */
 public void boost(DMatrix dtrain, float[] grad, float[] hess) throws XGBoostError {
   final int gradLen = grad.length;
   final int hessLen = hess.length;
   if (gradLen == hessLen) {
     JNIErrorHandle.checkCall(
         XgboostJNI.XGBoosterBoostOneIter(handle, dtrain.getHandle(), grad, hess));
     return;
   }
   throw new AssertionError(
       String.format("grad/hess length mismatch %s / %s", gradLen, hessLen));
 }
Ejemplo n.º 3
0
 /**
  * Retrieves a textual dump of the model as an array of strings.
  *
  * @param withStats whether split statistics should be included in the dump
  * @return the dump strings produced by the native call
  * @throws XGBoostError if the native call reports an error
  */
 private String[] getDumpInfo(boolean withStats) throws XGBoostError {
   final int statsFlag = withStats ? 1 : 0;
   // The native layer stores the resulting array in slot 0 of this holder.
   final String[][] dumpHolder = new String[1][];
   // NOTE(review): the empty-string argument presumably means "no feature map" — confirm
   // against the JNI signature.
   JNIErrorHandle.checkCall(XgboostJNI.XGBoosterDumpModel(handle, "", statsFlag, dumpHolder));
   return dumpHolder[0];
 }
Ejemplo n.º 4
0
  /**
   * Creates the native booster via {@code XGBoosterCreate} and stores the returned handle in
   * {@code handle}.
   *
   * @param dMatrixs matrices to pass to the native create call; may be null, in which case a
   *     null handle array is passed instead
   * @throws XGBoostError if the native call reports an error
   */
  private void init(DMatrix[] dMatrixs) throws XGBoostError {
    final long[] handles = (dMatrixs == null) ? null : dmatrixsToHandles(dMatrixs);
    // The native layer writes the new booster handle into slot 0.
    final long[] created = new long[1];
    JNIErrorHandle.checkCall(XgboostJNI.XGBoosterCreate(handles, created));
    handle = created[0];
  }
Ejemplo n.º 5
0
 /**
  * Base function for prediction: runs native prediction on {@code data} and reshapes the flat
  * result into one row per instance of {@code data}.
  *
  * <p>Only one output mode is sent to the native layer. When {@code predLeaf} is true, it takes
  * precedence over {@code outPutMargin}.
  *
  * @param data data to predict on
  * @param outPutMargin whether to request margin output (option mask 1)
  * @param treeLimit limit on the number of trees used for prediction
  * @param predLeaf whether to request leaf-index output (option mask 2); overrides
  *     {@code outPutMargin}
  * @return predictions shaped {@code [rows][cols]}; an empty array when {@code data} has no rows
  * @throws XGBoostError if the native call reports an error
  * @throws IllegalStateException if the native result length is not a multiple of the row count
  */
 private synchronized float[][] pred(
     DMatrix data, boolean outPutMargin, int treeLimit, boolean predLeaf) throws XGBoostError {
   int optionMask = 0;
   if (outPutMargin) {
     optionMask = 1;
   }
   // Leaf prediction replaces (rather than combines with) margin output.
   if (predLeaf) {
     optionMask = 2;
   }
   float[][] rawPredicts = new float[1][];
   JNIErrorHandle.checkCall(
       XgboostJNI.XGBoosterPredict(handle, data.getHandle(), optionMask, treeLimit, rawPredicts));
   float[] flat = rawPredicts[0];
   int row = (int) data.rowNum();
   if (row == 0) {
     // Guard the division below: an empty matrix yields no prediction rows.
     return new float[0][];
   }
   if (flat.length % row != 0) {
     throw new IllegalStateException(
         String.format(
             "prediction size %d is not a multiple of row count %d", flat.length, row));
   }
   int col = flat.length / row;
   float[][] predicts = new float[row][col];
   // The native result is row-major: row r occupies flat[r * col, (r + 1) * col).
   for (int r = 0; r < row; r++) {
     System.arraycopy(flat, r * col, predicts[r], 0, col);
   }
   return predicts;
 }
Ejemplo n.º 6
0
 /**
  * Sets a single booster parameter through the native library.
  *
  * @param key name of the parameter
  * @param value value to assign, as a string
  * @throws XGBoostError if the native call reports an error
  */
 public final void setParam(String key, String value) throws XGBoostError {
   final int status = XgboostJNI.XGBoosterSetParam(handle, key, value);
   JNIErrorHandle.checkCall(status);
 }
Ejemplo n.º 7
0
 /**
  * Frees the native booster, if one is held, and resets the stored handle to 0 so that later
  * calls do nothing.
  */
 public synchronized void dispose() {
   if (handle == 0L) {
     return;
   }
   XgboostJNI.XGBoosterFree(handle);
   handle = 0;
 }
Ejemplo n.º 8
0
 /**
  * Saves the booster model into the thread-local rabit checkpoint. Only used in distributed
  * training.
  *
  * @throws XGBoostError if the native call reports an error
  */
 void saveRabitCheckpoint() throws XGBoostError {
   final long booster = this.handle;
   JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveRabitCheckpoint(booster));
 }
Ejemplo n.º 9
0
 /**
  * Loads the booster model from the thread-local rabit checkpoint. Only used in distributed
  * training.
  *
  * @return the stored version number of the checkpoint, as written by the native call
  * @throws XGBoostError if the native call reports an error
  */
 int loadRabitCheckpoint() throws XGBoostError {
   // The native layer writes the version number into slot 0.
   final int[] versionHolder = {0};
   JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadRabitCheckpoint(handle, versionHolder));
   return versionHolder[0];
 }
Ejemplo n.º 10
0
 /**
  * Loads a model from {@code modelPath} into this booster.
  *
  * @param modelPath path of the model to load
  * @throws XGBoostError if the native call reports an error
  */
 private void loadModel(String modelPath) throws XGBoostError {
   // Check the native status the same way saveModel does, so a failed load is reported
   // instead of leaving the booster silently unloaded.
   JNIErrorHandle.checkCall(XgboostJNI.XGBoosterLoadModel(handle, modelPath));
 }
Ejemplo n.º 11
0
 /**
  * Writes this booster's model to {@code modelPath}.
  *
  * @param modelPath destination path for the model
  * @throws XGBoostError if the native call reports an error
  */
 public void saveModel(String modelPath) throws XGBoostError {
   final long booster = this.handle;
   JNIErrorHandle.checkCall(XgboostJNI.XGBoosterSaveModel(booster, modelPath));
 }
Ejemplo n.º 12
0
 /**
  * Runs a single training iteration on {@code dtrain}.
  *
  * @param dtrain training data
  * @param iter index of the current iteration
  * @throws XGBoostError if the native call reports an error
  */
 public void update(DMatrix dtrain, int iter) throws XGBoostError {
   final int status = XgboostJNI.XGBoosterUpdateOneIter(handle, iter, dtrain.getHandle());
   JNIErrorHandle.checkCall(status);
 }