Пример #1
0
  /**
   * Appends this model's POJO preamble: the base-class initialization, three metadata accessors,
   * and the constant arrays the generated scoring code refers to.
   *
   * <p>The generated {@code nclasses()} accessor reports {@code _parms._k}. This is presumably the
   * number of output columns the POJO produces rather than a true class count; confirm against
   * the generated scoring code.
   *
   * @param sb builder handed to the superclass initializer
   * @param fileContextSB file-level builder, forwarded unchanged to the superclass
   * @return the builder returned by {@code super.toJavaInit}, with this model's code appended
   */
  @Override
  protected SB toJavaInit(SB sb, SB fileContextSB) {
    final SB out = super.toJavaInit(sb, fileContextSB);

    // Metadata accessors, written one per line in a fixed order.
    final String[] accessors = {
      "public boolean isSupervised() { return " + isSupervised() + "; }",
      "public int nfeatures() { return " + _output.nfeatures() + "; }",
      "public int nclasses() { return " + _parms._k + "; }"
    };
    for (String accessor : accessors) {
      out.ip(accessor).nl();
    }

    // The scaling arrays are only written when _nnums is positive
    // (presumably: the model has numeric input columns).
    final boolean hasNumericColumns = _output._nnums > 0;
    if (hasNumericColumns) {
      JCodeGen.toStaticVar(out, "NORMMUL", _output._normMul,
          "Standardization/Normalization scaling factor for numerical variables.");
      JCodeGen.toStaticVar(out, "NORMSUB", _output._normSub,
          "Standardization/Normalization offset for numerical variables.");
    }

    // These arrays are always written, regardless of column types.
    JCodeGen.toStaticVar(out, "CATOFFS", _output._catOffsets, "Categorical column offsets.");
    JCodeGen.toStaticVar(out, "PERMUTE", _output._permutation, "Permutation index vector.");
    JCodeGen.toStaticVar(out, "EIGVECS", _output._eigenvectors_raw, "Eigenvector matrix.");
    return out;
  }
Пример #2
0
 // Note: POJO scoring code doesn't support per-row offsets (the scoring API would need to be
 // changed to pass in offsets)
 @Override
 protected void toJavaUnifyPreds(SB body, SB file) {
   // Preds are filled in from the trees, but need to be adjusted according to
   // the loss function.
   if (_parms._distribution == Distributions.Family.bernoulli) {
     body.ip("preds[2] = preds[1] + ").p(_output._init_f).p(";").nl();
     body.ip("preds[2] = " + _parms._distribution.linkInvString("preds[2]") + ";").nl();
     body.ip("preds[1] = 1.0-preds[2];").nl();
     if (_parms._balance_classes)
       body.ip(
               "hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);")
           .nl();
     body.ip(
             "preds[0] = hex.genmodel.GenModel.getPrediction(preds, data, "
                 + defaultThreshold()
                 + ");")
         .nl();
     return;
   }
   if (_output.nclasses() == 1) { // Regression
     body.ip("preds[0] += ").p(_output._init_f).p(";").nl();
     body.ip("preds[0] = " + _parms._distribution.linkInvString("preds[0]") + ";").nl();
     return;
   }
   if (_output.nclasses() == 2) { // Kept the initial prediction for binomial
     body.ip("preds[1] += ").p(_output._init_f).p(";").nl();
     body.ip("preds[2] = - preds[1];").nl();
   }
   body.ip("hex.genmodel.GenModel.GBM_rescale(preds);").nl();
   if (_parms._balance_classes)
     body.ip(
             "hex.genmodel.GenModel.correctProbabilities(preds, PRIOR_CLASS_DISTRIB, MODEL_CLASS_DISTRIB);")
         .nl();
   body.ip(
           "preds[0] = hex.genmodel.GenModel.getPrediction(preds, data, "
               + defaultThreshold()
               + ");")
       .nl();
 }