示例#1
0
 /**
  * Create DMatrix from a sparse matrix in CSR or CSC format.
  *
  * <p>For {@code SparseType.CSR} the arrays are handed to {@code XGDMatrixCreateFromCSR}; for
  * {@code SparseType.CSC} they are handed to {@code XGDMatrixCreateFromCSC}.
  *
  * @param headers The index-pointer array of the matrix (row pointers for CSR; presumably
  *     column pointers for CSC).
  * @param indices The indices of the stored entries.
  * @param data The values of the stored entries.
  * @param st Type of sparsity; must be {@code SparseType.CSR} or {@code SparseType.CSC}.
  * @throws NullPointerException if {@code headers}, {@code indices} or {@code data} is null
  * @throws UnknownError if {@code st} is neither CSR nor CSC (including null)
  * @throws XGBoostError if the native call reports an error
  */
 public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
   // Reject null arrays here with a clear message instead of handing them to native code.
   if (headers == null) {
     throw new NullPointerException("headers: null");
   }
   if (indices == null) {
     throw new NullPointerException("indices: null");
   }
   if (data == null) {
     throw new NullPointerException("data: null");
   }
   long[] out = new long[1];
   if (st == SparseType.CSR) {
     JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
   } else if (st == SparseType.CSC) {
     JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
   } else {
     // The exception type is kept as UnknownError for compatibility with existing callers.
     throw new UnknownError("unknown sparse type: " + st);
   }
   handle = out[0];
 }
示例#2
0
 /**
  * Slice the DMatrix and return a new DMatrix that contains only the rows in {@code rowIndex}.
  *
  * @param rowIndex indices of the rows to keep
  * @return a new DMatrix wrapping the native handle produced by the slice
  * @throws NullPointerException if {@code rowIndex} is null
  * @throws XGBoostError native error
  */
 public DMatrix slice(int[] rowIndex) throws XGBoostError {
   // Fail fast on null rather than passing it through to native code.
   if (rowIndex == null) {
     throw new NullPointerException("rowIndex: null");
   }
   long[] out = new long[1];
   JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, out));
   return new DMatrix(out[0]);
 }
示例#3
0
 /**
  * Create DMatrix by loading libsvm file from dataPath
  *
  * @param dataPath The path to the data.
  * @throws XGBoostError
  */
 public DMatrix(String dataPath) throws XGBoostError {
   if (dataPath == null) {
     throw new NullPointerException("dataPath: null");
   }
   long[] out = new long[1];
   JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, out));
   handle = out[0];
 }
示例#4
0
 /**
  * Create DMatrix from an iterator of labeled points.
  *
  * <p>The points are wrapped in a {@code DataBatch.BatchIterator} with a batch size of
  * 32768 (32 * 1024) before being passed to the native library.
  *
  * @param iter The iterator supplying the labeled points; must not be null.
  * @param cacheInfo Cache path information, used for external memory setting, can be null.
  * @throws NullPointerException if {@code iter} is null
  * @throws XGBoostError if the native call reports an error
  */
 public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostError {
   if (iter == null) {
     throw new NullPointerException("iter: null");
   }
   final int batchSize = 32 * 1024; // 32k entries per batch
   Iterator<DataBatch> batches = new DataBatch.BatchIterator(iter, batchSize);
   long[] created = new long[1];
   int status = XGBoostJNI.XGDMatrixCreateFromDataIter(batches, cacheInfo, created);
   JNIErrorHandle.checkCall(status);
   handle = created[0];
 }
示例#5
0
 /**
  * Release the native handle, if one is held, and reset it to 0.
  *
  * <p>Repeated calls are harmless: once the handle is 0 nothing further is freed. The method
  * is {@code synchronized} on this instance.
  */
 public synchronized void dispose() {
   if (handle == 0) {
     return; // nothing to free
   }
   XGBoostJNI.XGDMatrixFree(handle);
   handle = 0;
 }
示例#6
0
 /**
  * Save this DMatrix to {@code filePath} in binary form.
  *
  * <p>The native status code is checked. Because this method declares no checked exception, a
  * native failure is reported as an unchecked {@link IllegalStateException} whose cause is the
  * {@code XGBoostError}.
  *
  * @param filePath destination path; must not be null
  * @throws NullPointerException if {@code filePath} is null
  * @throws IllegalStateException if the native save call reports an error
  */
 public void saveBinary(String filePath) {
   if (filePath == null) {
     throw new NullPointerException("filePath: null");
   }
   try {
     // Previously the status code was ignored, so a failed save went unnoticed.
     JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSaveBinary(handle, filePath, 1));
   } catch (XGBoostError e) {
     throw new IllegalStateException("failed to save DMatrix to " + filePath, e);
   }
 }
示例#7
0
 /**
  * Get the number of rows in this DMatrix.
  *
  * @return number of rows reported by the native library
  * @throws XGBoostError native error
  */
 public long rowNum() throws XGBoostError {
   long[] count = new long[1];
   int status = XGBoostJNI.XGDMatrixNumRow(handle, count);
   JNIErrorHandle.checkCall(status);
   return count[0];
 }
示例#8
0
 /**
  * Set the group sizes of this DMatrix (used for ranking).
  *
  * @param group size of each group
  * @throws XGBoostError native error
  */
 public void setGroup(int[] group) throws XGBoostError {
   int status = XGBoostJNI.XGDMatrixSetGroup(handle, group);
   JNIErrorHandle.checkCall(status);
 }
示例#9
0
 /**
  * Read an unsigned-int info field of this DMatrix via {@code XGDMatrixGetUIntInfo}.
  *
  * @param field name of the info field to read
  * @return the array the native call stored in the output slot
  * @throws XGBoostError native error
  */
 private int[] getIntInfo(String field) throws XGBoostError {
   int[][] holder = new int[1][];
   int status = XGBoostJNI.XGDMatrixGetUIntInfo(handle, field, holder);
   JNIErrorHandle.checkCall(status);
   int[] values = holder[0];
   return values;
 }
示例#10
0
 /**
  * Set the base margin (initial prediction) that boosting starts from.
  *
  * <p>Stored in the {@code "base_margin"} float info field.
  *
  * @param baseMargin base margin values
  * @throws XGBoostError native error
  */
 public void setBaseMargin(float[] baseMargin) throws XGBoostError {
   int status = XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin);
   JNIErrorHandle.checkCall(status);
 }
示例#11
0
 /**
  * Set the weight of each instance (the {@code "weight"} float info field).
  *
  * @param weights per-instance weights
  * @throws XGBoostError native error
  */
 public void setWeight(float[] weights) throws XGBoostError {
   int status = XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights);
   JNIErrorHandle.checkCall(status);
 }
示例#12
0
 /**
  * Set the labels of this DMatrix (the {@code "label"} float info field).
  *
  * @param labels label values
  * @throws XGBoostError native error
  */
 public void setLabel(float[] labels) throws XGBoostError {
   int status = XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels);
   JNIErrorHandle.checkCall(status);
 }
示例#13
0
 /**
  * Create DMatrix from a dense matrix stored in a flat array.
  *
  * @param data matrix values (presumably row-major); must hold at least {@code nrow * ncol}
  *     values
  * @param nrow number of rows; must be non-negative
  * @param ncol number of columns; must be non-negative
  * @throws NullPointerException if {@code data} is null
  * @throws IllegalArgumentException if a dimension is negative or {@code data} holds fewer than
  *     {@code nrow * ncol} values
  * @throws XGBoostError native error
  */
 public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {
   if (data == null) {
     throw new NullPointerException("data: null");
   }
   if (nrow < 0 || ncol < 0) {
     throw new IllegalArgumentException(
         "nrow and ncol must be non-negative: nrow=" + nrow + ", ncol=" + ncol);
   }
   // The native call receives only the array and the dimensions, so an array shorter than
   // nrow * ncol would presumably be read out of bounds. The product is computed as long so a
   // large nrow * ncol cannot overflow int.
   long expected = (long) nrow * ncol;
   if (data.length < expected) {
     throw new IllegalArgumentException(
         "data has " + data.length + " values but nrow * ncol = " + expected);
   }
   long[] out = new long[1];
   // NOTE(review): 0.0f is presumably the "missing" marker of XGDMatrixCreateFromMat — confirm.
   JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, out));
   handle = out[0];
 }
示例#14
0
 /**
  * Check the status code returned by an xgboost C API call made through JNI.
  *
  * @param ret return value of the XGBoostJNI C API call; 0 means success
  * @throws XGBoostError carrying the native library's last error message when {@code ret} is
  *     non-zero
  */
 static void checkCall(int ret) throws XGBoostError {
   if (ret == 0) {
     return;
   }
   throw new XGBoostError(XGBoostJNI.XGBGetLastError());
 }