예제 #1
0
  /**
   * Score a frame with the given model and return just the metrics.
   *
   * <p>NOTE: ModelMetrics are now always being created by {@code model.score}, so the predictions
   * frame it returns is removed right away and only the metrics created as a side effect are
   * fetched and returned.
   *
   * @param version REST API version, forwarded to {@code fetch}
   * @param s request schema naming the model and the frame to score
   * @return the fetched metrics; a new empty schema if {@code fetch} returned null. A warning is
   *     logged when no ModelMetrics are present.
   * @throws H2OIllegalArgumentException if the model or the frame is not specified
   * @throws H2OKeyNotFoundArgumentException if the model or frame key is not in the DKV
   */
  @SuppressWarnings("unused") // called through reflection by RequestServer
  public ModelMetricsListSchemaV3 score(int version, ModelMetricsListSchemaV3 s) {
    // Parameter checking. The function name reported is "score" (this endpoint), not "predict".
    if (null == s.model) throw new H2OIllegalArgumentException("model", "score", s.model);
    if (null == DKV.get(s.model.name))
      throw new H2OKeyNotFoundArgumentException("model", "score", s.model.name);

    if (null == s.frame) throw new H2OIllegalArgumentException("frame", "score", s.frame);
    if (null == DKV.get(s.frame.name))
      throw new H2OKeyNotFoundArgumentException("frame", "score", s.frame.name);

    ModelMetricsList parms = s.createAndFillImpl();
    // Score only for the side effect of creating ModelMetrics; throw the predictions away.
    parms._model.score(parms._frame, parms._predictions_name).remove();
    ModelMetricsListSchemaV3 mm = this.fetch(version, s);

    // TODO: for now only binary predictors write an MM object.
    // For the others cons one up here to return the predictions frame.
    if (null == mm) mm = new ModelMetricsListSchemaV3();

    if (null == mm.model_metrics || 0 == mm.model_metrics.length) {
      Log.warn(
          "Score() did not return a ModelMetrics for model: " + s.model + " on frame: " + s.frame);
    }

    return mm;
  }
예제 #2
0
  /**
   * Start a job that scores a frame with the given model, and return that job.
   *
   * <p>The job writes either the regular predictions frame or, when {@code
   * deep_features_hidden_layer >= 0}, the deep features of that hidden layer. Deep features
   * require a model implementing {@link Model.DeepFeatures}. Unlike {@code predict}, no model
   * metrics are returned.
   *
   * @param version REST API version (not used by this method)
   * @param s request schema naming the model, the frame and the scoring options
   * @return a schema describing the started job
   * @throws H2OIllegalArgumentException if the model or frame is not specified, or deep features
   *     are requested from a model that does not implement {@link Model.DeepFeatures}
   * @throws H2OKeyNotFoundArgumentException if the model or frame key is not in the DKV
   */
  @SuppressWarnings("unused") // called through reflection by RequestServer
  public JobV3 predict2(int version, final ModelMetricsListSchemaV3 s) {
    // parameters checking:
    if (null == s.model) throw new H2OIllegalArgumentException("model", "predict", s.model);
    if (null == DKV.get(s.model.name))
      throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name);

    if (null == s.frame) throw new H2OIllegalArgumentException("frame", "predict", s.frame);
    if (null == DKV.get(s.frame.name))
      throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name);

    final ModelMetricsList parms = s.createAndFillImpl();

    // predict2 does not return modelmetrics, so cannot handle deeplearning: reconstruction_error
    // (anomaly) or GLRM: reconstruct and archetypes
    // predict2 can handle deeplearning: deepfeatures and predict

    // Hidden layer 0 is a valid index. Use the same ">= 0" test for naming and for scoring so
    // the output frame name always matches what compute2() produces.
    final boolean deepFeatures = s.deep_features_hidden_layer >= 0;

    // Fail fast with a clear error instead of a ClassCastException inside the background job.
    if (deepFeatures && !(parms._model instanceof Model.DeepFeatures))
      throw new H2OIllegalArgumentException(
          "Deep features can only be computed for models implementing Model.DeepFeatures.", "");

    if (null == parms._predictions_name)
      parms._predictions_name =
          (deepFeatures ? "deep_features" : "predictions")
              + Key.make().toString().substring(0, 5)
              + "_"
              + parms._model._key.toString()
              + "_on_"
              + parms._frame._key.toString();

    final Job<Frame> j =
        new Job(Key.make(parms._predictions_name), Frame.class.getName(), "prediction");

    H2O.H2OCountedCompleter work =
        new H2O.H2OCountedCompleter() {
          @Override
          public void compute2() {
            if (deepFeatures) {
              Frame features =
                  ((Model.DeepFeatures) parms._model)
                      .scoreDeepFeatures(parms._frame, s.deep_features_hidden_layer, j);
              // Re-wrap the feature vecs under the requested key and publish them in the DKV.
              Frame named =
                  new Frame(
                      Key.make(parms._predictions_name), features.names(), features.vecs());
              DKV.put(named._key, named);
            } else {
              parms._model.score(parms._frame, parms._predictions_name, j);
            }
            tryComplete();
          }
        };
    j.start(work, parms._frame.anyVec().nChunks());
    return new JobV3().fillFromImpl(j);
  }
예제 #3
0
  /**
   * Score a frame with the given model and return the metrics AND the prediction frame.
   *
   * <p>With none of the special options set, this runs the model's regular {@code score}.
   * Otherwise a specialized scorer is chosen by the model's capabilities:
   *
   * <ul>
   *   <li>{@link Model.DeepFeatures}: reconstruction error (optionally per feature), or the
   *       activations of hidden layer {@code deep_features_hidden_layer} (not both);
   *   <li>{@link Model.GLRMArchetypes}: projected archetypes or the reconstructed frame;
   *   <li>{@link Model.LeafNodeAssignment}: leaf node assignment (no metrics are returned).
   * </ul>
   *
   * <p>If no predictions frame name was supplied, one is generated from the model and frame keys.
   *
   * @param version REST API version, forwarded to {@code fetch}
   * @param s request schema naming the model, the frame and the scoring options
   * @return the fetched metrics (if any) together with a key to the predictions frame
   * @throws H2OIllegalArgumentException if the model or frame is not specified, the options
   *     conflict, or the model does not support the requested kind of scoring
   * @throws H2OKeyNotFoundArgumentException if the model or frame key is not in the DKV
   */
  @SuppressWarnings("unused") // called through reflection by RequestServer
  public ModelMetricsListSchemaV3 predict(int version, ModelMetricsListSchemaV3 s) {
    // parameters checking:
    if (null == s.model) throw new H2OIllegalArgumentException("model", "predict", s.model);
    if (null == DKV.get(s.model.name))
      throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name);

    if (null == s.frame) throw new H2OIllegalArgumentException("frame", "predict", s.frame);
    if (null == DKV.get(s.frame.name))
      throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name);

    ModelMetricsList parms = s.createAndFillImpl();

    Frame predictions;
    if (!s.reconstruction_error
        && !s.reconstruction_error_per_feature
        && s.deep_features_hidden_layer < 0
        && !s.project_archetypes
        && !s.reconstruct_train
        && !s.leaf_node_assignment) {
      // Plain scoring.
      if (null == parms._predictions_name)
        parms._predictions_name = defaultPredictionsName("predictions", "_on_", parms);
      predictions = parms._model.score(parms._frame, parms._predictions_name);
    } else if (parms._model instanceof Model.DeepFeatures) {
      if (s.reconstruction_error || s.reconstruction_error_per_feature) {
        if (s.deep_features_hidden_layer >= 0)
          throw new H2OIllegalArgumentException(
              "Can only compute either reconstruction error OR deep features.", "");
        if (null == parms._predictions_name)
          parms._predictions_name = defaultPredictionsName("reconstruction_error", "_on_", parms);
        predictions =
            ((Model.DeepFeatures) parms._model)
                .scoreAutoEncoder(
                    parms._frame,
                    Key.make(parms._predictions_name),
                    parms._reconstruction_error_per_feature);
      } else {
        if (s.deep_features_hidden_layer < 0)
          throw new H2OIllegalArgumentException(
              "Deep features hidden layer index must be >= 0.", "");
        if (null == parms._predictions_name)
          parms._predictions_name = defaultPredictionsName("deep_features", "_on_", parms);
        predictions =
            ((Model.DeepFeatures) parms._model)
                .scoreDeepFeatures(parms._frame, s.deep_features_hidden_layer);
      }
      // Re-wrap the result's vecs under the requested key and publish it in the DKV.
      predictions =
          new Frame(Key.make(parms._predictions_name), predictions.names(), predictions.vecs());
      DKV.put(predictions._key, predictions);
    } else if (parms._model instanceof Model.GLRMArchetypes) {
      if (s.project_archetypes) {
        if (null == parms._predictions_name)
          parms._predictions_name =
              defaultPredictionsName("reconstructed_archetypes_", "_of_", parms);
        predictions =
            ((Model.GLRMArchetypes) parms._model)
                .scoreArchetypes(
                    parms._frame, Key.make(parms._predictions_name), s.reverse_transform);
      } else if (s.reconstruct_train) {
        if (null == parms._predictions_name)
          parms._predictions_name = defaultPredictionsName("reconstruction_", "_of_", parms);
        predictions =
            ((Model.GLRMArchetypes) parms._model)
                .scoreReconstruction(
                    parms._frame, Key.make(parms._predictions_name), s.reverse_transform);
      } else {
        // Explicit check rather than an assert: with assertions disabled (the JVM default) an
        // unsupported option would otherwise silently fall through to reconstruction.
        throw new H2OIllegalArgumentException(
            "For this model, only project_archetypes or reconstruct_train are supported.", "");
      }
    } else if (parms._model instanceof Model.LeafNodeAssignment) {
      // Explicit check rather than an assert, which is skipped when assertions are disabled.
      if (!s.leaf_node_assignment)
        throw new H2OIllegalArgumentException(
            "For this model, leaf_node_assignment is the only supported special scoring option.",
            "");
      if (null == parms._predictions_name)
        parms._predictions_name = defaultPredictionsName("leaf_node_assignement", "_on_", parms);
      predictions =
          ((Model.LeafNodeAssignment) parms._model)
              .scoreLeafNodeAssignment(parms._frame, Key.make(parms._predictions_name));
    } else {
      throw new H2OIllegalArgumentException(
          "Requires a Deep Learning, GLRM, DRF or GBM model.",
          "Model must implement specific methods.");
    }

    ModelMetricsListSchemaV3 mm = this.fetch(version, s);

    // TODO: for now only binary predictors write an MM object.
    // For the others cons one up here to return the predictions frame.
    if (null == mm) mm = new ModelMetricsListSchemaV3();

    mm.predictions_frame = new KeyV3.FrameKeyV3(predictions._key);
    // Don't show metrics if leaf node assignments are made.
    if (parms._leaf_node_assignment) mm.model_metrics = null;

    // With no metrics (e.g. no response in the test set) there is nothing to attach to.
    if (null != mm.model_metrics && mm.model_metrics.length > 0) {
      mm.model_metrics[0].predictions =
          new FrameV3(predictions, 0, 100); // TODO: Should call schema(version)
    }
    return mm;
  }

  /**
   * Builds a default output frame name.
   *
   * <p>The name is {@code prefix}, then the first 5 characters of a freshly made {@link Key},
   * then {@code "_"}, the model key, {@code link}, and the frame key.
   */
  private static String defaultPredictionsName(
      String prefix, String link, ModelMetricsList parms) {
    return prefix
        + Key.make().toString().substring(0, 5)
        + "_"
        + parms._model._key.toString()
        + link
        + parms._frame._key.toString();
  }
예제 #4
0
 /**
  * Delete one or more ModelMetrics.
  *
  * <p>The request schema is converted to its {@code ModelMetricsList} implementation and
  * {@code delete()} is invoked on it. The same schema is then refilled from the returned value.
  *
  * @param version REST API version (not used by this method)
  * @param s request schema; refilled in place from the result of the delete
  * @return {@code s}, after being refilled
  */
 @SuppressWarnings("unused") // called through reflection by RequestServer
 public ModelMetricsListSchemaV3 delete(int version, ModelMetricsListSchemaV3 s) {
   s.fillFromImpl(s.createAndFillImpl().delete());
   return s;
 }