/**
 * Create DMatrix from Sparse matrix in CSR/CSC format.
 *
 * @param headers The row index of the matrix.
 * @param indices The indices of presenting entries.
 * @param data The data content.
 * @param st Type of sparsity; must be {@code SparseType.CSR} or {@code SparseType.CSC}.
 * @throws XGBoostError native error
 * @throws IllegalArgumentException if {@code st} is neither CSR nor CSC (this includes
 *     {@code null})
 */
public DMatrix(long[] headers, int[] indices, float[] data, SparseType st) throws XGBoostError {
  // Single-element array used as an out-parameter: the native call writes the handle into it.
  long[] out = new long[1];
  if (st == SparseType.CSR) {
    JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSR(headers, indices, data, out));
  } else if (st == SparseType.CSC) {
    JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromCSC(headers, indices, data, out));
  } else {
    // An unsupported sparse type is a caller mistake, not a JVM fault, so report it as an
    // invalid argument rather than as java.lang.UnknownError (a VirtualMachineError).
    throw new IllegalArgumentException("unknown sparse type: " + st);
  }
  handle = out[0];
}
/**
 * Slice the DMatrix and return a new DMatrix that only contains the rows in {@code rowIndex}.
 *
 * @param rowIndex row index
 * @return sliced new DMatrix, wrapping the handle produced by the native slice call
 * @throws XGBoostError native error
 */
public DMatrix slice(int[] rowIndex) throws XGBoostError {
  // The native call writes the handle of the sliced matrix into slot 0.
  long[] slicedHandle = new long[1];
  JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixSliceDMatrix(handle, rowIndex, slicedHandle));
  return new DMatrix(slicedHandle[0]);
}
/**
 * Create DMatrix by loading libsvm file from dataPath.
 *
 * @param dataPath The path to the data; must not be {@code null}.
 * @throws XGBoostError native error
 * @throws NullPointerException if {@code dataPath} is {@code null}
 */
public DMatrix(String dataPath) throws XGBoostError {
  // Reject null before crossing into native code.
  if (dataPath == null) {
    throw new NullPointerException("dataPath: null");
  }

  // Slot 0 receives the handle of the newly created native matrix.
  long[] handleOut = new long[1];
  JNIErrorHandle.checkCall(
      XGBoostJNI.XGDMatrixCreateFromFile(dataPath, 1, handleOut));
  handle = handleOut[0];
}
/** * Create DMatrix from iterator. * * @param iter The data iterator of mini batch to provide the data. * @param cacheInfo Cache path information, used for external memory setting, can be null. * @throws XGBoostError */ public DMatrix(Iterator<LabeledPoint> iter, String cacheInfo) throws XGBoostError { if (iter == null) { throw new NullPointerException("iter: null"); } // 32k as batch size int batchSize = 32 << 10; Iterator<DataBatch> batchIter = new DataBatch.BatchIterator(iter, batchSize); long[] out = new long[1]; JNIErrorHandle.checkCall(XGBoostJNI.XGDMatrixCreateFromDataIter(batchIter, cacheInfo, out)); handle = out[0]; }
/**
 * Frees the native matrix behind this object, if one is still held.
 *
 * <p>After the native free call the handle is reset to 0, so a subsequent call returns
 * without touching native code. The method is {@code synchronized} on this instance.
 */
public synchronized void dispose() {
  // A zero handle means there is nothing (left) to release.
  if (handle == 0) {
    return;
  }
  XGBoostJNI.XGDMatrixFree(handle);
  handle = 0;
}
/**
 * Save DMatrix to filePath.
 *
 * <p>NOTE(review): unlike most native calls in this file, this one is not routed through
 * {@code JNIErrorHandle.checkCall}, so any failure status from the native side is not
 * reported to the caller — confirm whether that is intended.
 *
 * @param filePath destination path handed to the native save routine
 */
public void saveBinary(String filePath) {
  XGBoostJNI.XGDMatrixSaveBinary(
      handle,
      filePath,
      1);
}
/**
 * Get the row number of DMatrix.
 *
 * @return number of rows, as reported by the native layer
 * @throws XGBoostError native error
 */
public long rowNum() throws XGBoostError {
  // Single-element array used as an out-parameter for the native call.
  long[] count = new long[1];
  JNIErrorHandle.checkCall(
      XGBoostJNI.XGDMatrixNumRow(handle, count));
  return count[0];
}
/**
 * Set group sizes of DMatrix (used for ranking).
 *
 * @param group group size as array
 * @throws XGBoostError native error
 */
public void setGroup(int[] group) throws XGBoostError {
  JNIErrorHandle.checkCall(
      XGBoostJNI.XGDMatrixSetGroup(handle, group));
}
/**
 * Reads an int-array info field of this matrix through the native unsigned-int getter.
 *
 * @param field name of the info field to fetch
 * @return the array the native call stored for {@code field}
 * @throws XGBoostError native error
 */
private int[] getIntInfo(String field) throws XGBoostError {
  // One-slot holder: the native call places the result array in slot 0.
  int[][] holder = new int[1][];
  JNIErrorHandle.checkCall(
      XGBoostJNI.XGDMatrixGetUIntInfo(handle, field, holder));
  return holder[0];
}
/**
 * Set the base margin of this matrix.
 *
 * <p>If specified, xgboost will start from this init margin; it can be used to specify the
 * initial prediction to boost from.
 *
 * @param baseMargin base margin
 * @throws XGBoostError native error
 */
public void setBaseMargin(float[] baseMargin) throws XGBoostError {
  // Stored under the "base_margin" float-info field.
  JNIErrorHandle.checkCall(
      XGBoostJNI.XGDMatrixSetFloatInfo(handle, "base_margin", baseMargin));
}
/**
 * Set weight of each instance.
 *
 * @param weights weights
 * @throws XGBoostError native error
 */
public void setWeight(float[] weights) throws XGBoostError {
  // Stored under the "weight" float-info field.
  JNIErrorHandle.checkCall(
      XGBoostJNI.XGDMatrixSetFloatInfo(handle, "weight", weights));
}
/**
 * Set label of dmatrix.
 *
 * @param labels labels
 * @throws XGBoostError native error
 */
public void setLabel(float[] labels) throws XGBoostError {
  // Stored under the "label" float-info field.
  JNIErrorHandle.checkCall(
      XGBoostJNI.XGDMatrixSetFloatInfo(handle, "label", labels));
}
/**
 * Create DMatrix from dense matrix.
 *
 * @param data data values
 * @param nrow number of rows
 * @param ncol number of columns
 * @throws XGBoostError native error
 */
public DMatrix(float[] data, int nrow, int ncol) throws XGBoostError {
  // Slot 0 receives the handle of the newly created native matrix.
  long[] handleOut = new long[1];
  // NOTE(review): 0.0f is passed as the fourth native argument — presumably the value
  // treated as "missing"; confirm against the native binding.
  JNIErrorHandle.checkCall(
      XGBoostJNI.XGDMatrixCreateFromMat(data, nrow, ncol, 0.0f, handleOut));
  handle = handleOut[0];
}
/**
 * Check the return value of C API.
 *
 * <p>A return value of 0 is accepted silently; any other value is turned into an
 * {@link XGBoostError} carrying the message from {@code XGBoostJNI.XGBGetLastError()}.
 *
 * @param ret return value of xgboostJNI C API call
 * @throws XGBoostError native error
 */
static void checkCall(int ret) throws XGBoostError {
  if (ret == 0) {
    return;
  }
  throw new XGBoostError(XGBoostJNI.XGBGetLastError());
}