/**
 * Appends this model's POJO preamble to the code produced by the superclass.
 *
 * <p>Emitted, in order:
 * <ul>
 *   <li>overrides of {@code isSupervised()}, {@code nfeatures()} and {@code nclasses()} (the last
 *       one returns {@code _parms._k});</li>
 *   <li>{@code NORMMUL} / {@code NORMSUB} static arrays, only when the model has numeric columns
 *       ({@code _output._nnums > 0});</li>
 *   <li>{@code CATOFFS}, {@code PERMUTE} and {@code EIGVECS} static arrays.</li>
 * </ul>
 *
 * @param sb builder passed on to the superclass implementation
 * @param fileContextSB file-level builder, forwarded to the superclass unchanged
 * @return the builder returned by {@code super.toJavaInit}, with the additions above
 */
@Override protected SB toJavaInit(SB sb, SB fileContextSB) {
  final SB out = super.toJavaInit(sb, fileContextSB);

  // Generated accessor overrides.
  out.ip("public boolean isSupervised() { return " + isSupervised() + "; }").nl();
  out.ip("public int nfeatures() { return " + _output.nfeatures() + "; }").nl();
  out.ip("public int nclasses() { return " + _parms._k + "; }").nl();

  // Numeric scaling tables only make sense when there are numeric columns.
  final boolean hasNumericColumns = _output._nnums > 0;
  if (hasNumericColumns) {
    JCodeGen.toStaticVar(out, "NORMMUL", _output._normMul,
        "Standardization/Normalization scaling factor for numerical variables.");
    JCodeGen.toStaticVar(out, "NORMSUB", _output._normSub,
        "Standardization/Normalization offset for numerical variables.");
  }

  // Tables that are always emitted.
  JCodeGen.toStaticVar(out, "CATOFFS", _output._catOffsets, "Categorical column offsets.");
  JCodeGen.toStaticVar(out, "PERMUTE", _output._permutation, "Permutation index vector.");
  JCodeGen.toStaticVar(out, "EIGVECS", _output._eigenvectors_raw, "Eigenvector matrix.");
  return out;
}
// Note: POJO scoring code doesn't support per-row offsets (the scoring API would need to be // changed to pass in offsets) @Override protected void toJavaUnifyPreds(SB body, SB file) { // Preds are filled in from the trees, but need to be adjusted according to // the loss function. if (_parms._distribution == Distributions.Family.bernoulli) { body.ip("preds[2] = preds[1] + ").p(_output._init_f).p(";").nl(); body.ip("preds[2] = " + _parms._distribution.linkInvString("preds[2]") + ";").nl(); body.ip("preds[1] = 1.0-preds[2];").nl(); if (_parms._balance_classes) body.ip( "hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);") .nl(); body.ip( "preds[0] = hex.genmodel.GenModel.getPrediction(preds, data, " + defaultThreshold() + ");") .nl(); return; } if (_output.nclasses() == 1) { // Regression body.ip("preds[0] += ").p(_output._init_f).p(";").nl(); body.ip("preds[0] = " + _parms._distribution.linkInvString("preds[0]") + ";").nl(); return; } if (_output.nclasses() == 2) { // Kept the initial prediction for binomial body.ip("preds[1] += ").p(_output._init_f).p(";").nl(); body.ip("preds[2] = - preds[1];").nl(); } body.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl(); if (_parms._balance_classes) body.ip( "hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);") .nl(); body.ip( "preds[0] = hex.genmodel.GenModel.getPrediction(preds, data, " + defaultThreshold() + ");") .nl(); }