Ejemplo n.º 1
0
 /**
  * Emits the Java source of the row-scoring routine for this GLM into {@code body} and registers
  * a generator for the static data it relies on ({@code BETA}, {@code NUM_MEANS}, {@code
  * CAT_MODES}, {@code CATOFFS}) with {@code classCtx}.
  *
  * <p>The generated code treats the input row {@code data} as categorical columns first
  * ({@code data[0 .. cats)}) followed by numeric columns ({@code data[cats .. cats + nums)}); the
  * imputation loops and both linear-predictor loops index {@code data} accordingly.
  *
  * <p>Generated output layout:
  *
  * <ul>
  *   <li>binomial: {@code preds[0]} is the thresholded label, {@code preds[1]} / {@code preds[2]}
  *       are the class 0 / class 1 probabilities;
  *   <li>other non-multinomial families: {@code preds[0]} is the inverse-link of the linear
  *       predictor;
  *   <li>multinomial: {@code preds[1..]} hold the normalised exponentiated class scores and
  *       {@code preds[0]} the zero-based index of the largest one.
  * </ul>
  *
  * @param body stream that receives the statements of the generated predict method
  * @param classCtx pipeline that receives the generator for the static fields
  * @param fileCtx not used by this method
  * @param verboseCode not used by this method
  */
 @Override
 protected void toJavaPredictBody(
     SBPrintStream body,
     CodeGeneratorPipeline classCtx,
     CodeGeneratorPipeline fileCtx,
     final boolean verboseCode) {
   // Generate static fields
   classCtx.add(
       new CodeGenerator() {
         @Override
         public void generate(JCodeSB out) {
           JCodeGen.toClassWithArray(out, "static", "BETA", beta_internal()); // "The Coefficients"
           JCodeGen.toClassWithArray(
               out, "static", "NUM_MEANS", _output._dinfo._numMeans, "Imputed numeric values");
           JCodeGen.toClassWithArray(
               out,
               "static",
               "CAT_MODES",
               _output._dinfo._catModes,
               "Imputed categorical values.");
           JCodeGen.toStaticVar(out, "CATOFFS", dinfo()._catOffsets, "Categorical Offsets");
         }
       });
   body.ip("final double [] b = BETA.VALUES;").nl();

   // Optional imputation: NaN categoricals get the stored mode, NaN numerics the stored mean.
   // Numeric columns live after the categorical ones, hence the "+ cats" shift on data.
   if (_parms._missing_values_handling == MissingValuesHandling.MeanImputation) {
     body.ip(
             "for(int i = 0; i < "
                 + _output._dinfo._cats
                 + "; ++i) if(Double.isNaN(data[i])) data[i] = CAT_MODES.VALUES[i];")
         .nl();
     body.ip(
             "for(int i = 0; i < "
                 + _output._dinfo._nums
                 + "; ++i) if(Double.isNaN(data[i + "
                 + _output._dinfo._cats
                 + "])) data[i+"
                 + _output._dinfo._cats
                 + "] = NUM_MEANS.VALUES[i];")
         .nl();
   }

   if (_parms._family != Family.multinomial) {
     // ---- single linear predictor -------------------------------------------------------------
     body.ip("double eta = 0.0;").nl();
     if (!_parms._use_all_factor_levels) { // skip level 0 of all factors
       body.ip("for(int i = 0; i < CATOFFS.length-1; ++i) if(data[i] != 0) {").nl();
       body.ip("  int ival = (int)data[i] - 1;").nl();
       body.ip(
               "  if(ival != data[i] - 1) throw new IllegalArgumentException(\"categorical value out of range\");")
           .nl();
       body.ip("  ival += CATOFFS[i];").nl();
       body.ip("  if(ival < CATOFFS[i + 1])").nl();
       body.ip("    eta += b[ival];").nl();
     } else { // do not skip any levels
       body.ip("for(int i = 0; i < CATOFFS.length-1; ++i) {").nl();
       body.ip("  int ival = (int)data[i];").nl();
       body.ip(
               "  if(ival != data[i]) throw new IllegalArgumentException(\"categorical value out of range\");")
           .nl();
       body.ip("  ival += CATOFFS[i];").nl();
       body.ip("  if(ival < CATOFFS[i + 1])").nl();
       body.ip("    eta += b[ival];").nl();
     }
     // Closes the brace opened by whichever categorical loop header was emitted above.
     body.ip("}").nl();
     // noff converts a data index (which starts at cats for numerics) into a coefficient index.
     final int noff = dinfo().numStart() - dinfo()._cats;
     body.ip("for(int i = ").p(dinfo()._cats).p("; i < b.length-1-").p(noff).p("; ++i)").nl();
     body.ip("  eta += b[").p(noff).p("+i]*data[i];").nl();
     body.ip("eta += b[b.length-1]; // reduce intercept").nl();
     // Both branches leave the call open; the shared ");" below terminates it.
     if (_parms._family != Family.tweedie) {
       body.ip("double mu = hex.genmodel.GenModel.GLM_").p(_parms._link.toString()).p("Inv(eta");
     } else {
       body.ip(
           "double mu = hex.genmodel.GenModel.GLM_tweedieInv(eta," + _parms._tweedie_link_power);
     }
     body.p(");").nl();
     if (_parms._family == Family.binomial) {
       body.ip("preds[0] = (mu > ")
           .p(defaultThreshold())
           .p(") ? 1 : 0")
           .p("; // threshold given by ROC")
           .nl();
       body.ip("preds[1] = 1.0 - mu; // class 0").nl();
       body.ip("preds[2] =       mu; // class 1").nl();
     } else {
       body.ip("preds[0] = mu;").nl();
     }
   } else {
     // ---- one linear predictor per class, coefficients of class c start at b[c*P] -------------
     int P = _output._global_beta_multinomial[0].length;
     body.ip("preds[0] = 0;").nl();
     body.ip("for(int c = 0; c < " + _output._nclasses + "; ++c){").nl();
     body.ip("  preds[c+1] = 0;").nl();
     if (dinfo()._cats > 0) {
       if (!_parms._use_all_factor_levels) { // skip level 0 of all factors
         body.ip("  for(int i = 0; i < CATOFFS.length-1; ++i) if(data[i] != 0) {").nl();
         body.ip("    int ival = (int)data[i] - 1;").nl();
         body.ip(
                 "    if(ival != data[i] - 1) throw new IllegalArgumentException(\"categorical value out of range\");")
             .nl();
         body.ip("    ival += CATOFFS[i];").nl();
         body.ip("    if(ival < CATOFFS[i + 1])").nl();
         body.ip("      preds[c+1] += b[ival+c*" + P + "];").nl();
       } else { // do not skip any levels
         body.ip("  for(int i = 0; i < CATOFFS.length-1; ++i) {").nl();
         body.ip("    int ival = (int)data[i];").nl();
         body.ip(
                 "    if(ival != data[i]) throw new IllegalArgumentException(\"categorical value out of range\");")
             .nl();
         body.ip("    ival += CATOFFS[i];").nl();
         body.ip("    if(ival < CATOFFS[i + 1])").nl();
         body.ip("      preds[c+1] += b[ival+c*" + P + "];").nl();
       }
       body.ip("  }").nl();
     }
     final int noff = dinfo().numStart();
     body.ip("  for(int i = 0; i < " + dinfo()._nums + "; ++i)").nl();
     // Here i counts numeric columns from 0, so the data index must be shifted past the
     // categorical columns (same layout as the imputation loop and the non-multinomial branch).
     // Reading data[i] would multiply numeric coefficients by categorical level codes.
     body.ip(
             "    preds[c+1] += b["
                 + noff
                 + "+i + c*"
                 + P
                 + "]*data[i + "
                 + dinfo()._cats
                 + "];")
         .nl();
     body.ip("  preds[c+1] += b[" + (P - 1) + " + c*" + P + "]; // reduce intercept").nl();
     body.ip("}").nl();
     // Softmax over preds[1..]: shift by the running maximum, exponentiate, then normalise.
     // NOTE(review): the maximum is seeded with 0, not the first score; left untouched here,
     // presumably to stay numerically aligned with the in-cluster scorer - confirm.
     body.ip("double max_row = 0;").nl();
     body.ip("for(int c = 1; c < preds.length; ++c) if(preds[c] > max_row) max_row = preds[c];")
         .nl();
     body.ip("double sum_exp = 0;").nl();
     body.ip(
             "for(int c = 1; c < preds.length; ++c) { sum_exp += (preds[c] = Math.exp(preds[c]-max_row));}")
         .nl();
     body.ip("sum_exp = 1/sum_exp;").nl();
     body.ip("double max_p = 0;").nl();
     // The same pass normalises each class score and records the arg-max class in preds[0].
     body.ip(
             "for(int c = 1; c < preds.length; ++c) if((preds[c] *= sum_exp) > max_p){ max_p = preds[c]; preds[0] = c-1;};")
         .nl();
   }
 }