コード例 #1
0
  /**
   * Score a frame with the given model and return the metrics AND the prediction frame.
   *
   * <p>When none of the special scoring options on {@code s} is set, the model's regular
   * {@code score} method is used. Otherwise the option is dispatched to the matching model
   * capability: {@code Model.DeepFeatures} (reconstruction error / deep features),
   * {@code Model.GLRMArchetypes} (archetype projection / training reconstruction) or
   * {@code Model.LeafNodeAssignment} (leaf node assignment). If the request carries no destination
   * name, one is generated from a short random key fragment, the model key and the frame key.
   *
   * @param version REST API version of the request
   * @param s request schema holding the model, the frame and the scoring options
   * @return the metrics schema returned by {@code fetch} (or a fresh empty one when it returned
   *     null), with {@code predictions_frame} set to the key of the produced frame
   * @throws H2OIllegalArgumentException if the model or frame is missing from the request, if the
   *     requested options conflict, or if the model does not support the requested option
   * @throws H2OKeyNotFoundArgumentException if the model or frame key is not present in the DKV
   */
  @SuppressWarnings("unused") // called through reflection by RequestServer
  public ModelMetricsListSchemaV3 predict(int version, ModelMetricsListSchemaV3 s) {
    // parameters checking:
    if (null == s.model) throw new H2OIllegalArgumentException("model", "predict", s.model);
    if (null == DKV.get(s.model.name))
      throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name);

    if (null == s.frame) throw new H2OIllegalArgumentException("frame", "predict", s.frame);
    if (null == DKV.get(s.frame.name))
      throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name);

    ModelMetricsList parms = s.createAndFillImpl();

    Frame predictions;
    if (!s.reconstruction_error
        && !s.reconstruction_error_per_feature
        && s.deep_features_hidden_layer < 0
        && !s.project_archetypes
        && !s.reconstruct_train
        && !s.leaf_node_assignment) {
      // Plain scoring.
      if (null == parms._predictions_name) {
        parms._predictions_name = defaultPredictionsName("predictions", "_on_", parms);
      }
      predictions = parms._model.score(parms._frame, parms._predictions_name);
    } else {
      if (Model.DeepFeatures.class.isAssignableFrom(parms._model.getClass())) {
        if (s.reconstruction_error || s.reconstruction_error_per_feature) {
          if (s.deep_features_hidden_layer >= 0) {
            throw new H2OIllegalArgumentException(
                "Can only compute either reconstruction error OR deep features.", "");
          }
          if (null == parms._predictions_name) {
            parms._predictions_name =
                defaultPredictionsName("reconstruction_error", "_on_", parms);
          }
          predictions =
              ((Model.DeepFeatures) parms._model)
                  .scoreAutoEncoder(
                      parms._frame,
                      Key.make(parms._predictions_name),
                      parms._reconstruction_error_per_feature);
        } else {
          if (s.deep_features_hidden_layer < 0) {
            throw new H2OIllegalArgumentException(
                "Deep features hidden layer index must be >= 0.", "");
          }
          if (null == parms._predictions_name) {
            parms._predictions_name = defaultPredictionsName("deep_features", "_on_", parms);
          }
          predictions =
              ((Model.DeepFeatures) parms._model)
                  .scoreDeepFeatures(parms._frame, s.deep_features_hidden_layer);
        }
        // Re-publish the result under the requested/generated name.
        predictions =
            new Frame(Key.make(parms._predictions_name), predictions.names(), predictions.vecs());
        DKV.put(predictions._key, predictions);
      } else if (Model.GLRMArchetypes.class.isAssignableFrom(parms._model.getClass())) {
        if (s.project_archetypes) {
          if (null == parms._predictions_name) {
            parms._predictions_name =
                defaultPredictionsName("reconstructed_archetypes_", "_of_", parms);
          }
          predictions =
              ((Model.GLRMArchetypes) parms._model)
                  .scoreArchetypes(
                      parms._frame, Key.make(parms._predictions_name), s.reverse_transform);
        } else {
          // Explicit check instead of an assert: with assertions disabled, any other option
          // (e.g. reconstruction_error) would silently fall through to a reconstruction.
          if (!s.reconstruct_train) {
            throw new H2OIllegalArgumentException(
                "This model only supports project_archetypes or reconstruct_train.", "");
          }
          if (null == parms._predictions_name) {
            parms._predictions_name = defaultPredictionsName("reconstruction_", "_of_", parms);
          }
          predictions =
              ((Model.GLRMArchetypes) parms._model)
                  .scoreReconstruction(
                      parms._frame, Key.make(parms._predictions_name), s.reverse_transform);
        }
      } else if (Model.LeafNodeAssignment.class.isAssignableFrom(parms._model.getClass())) {
        // Explicit check instead of an assert: with assertions disabled, any other option
        // would silently produce leaf node assignments.
        if (!s.leaf_node_assignment) {
          throw new H2OIllegalArgumentException(
              "This model only supports leaf_node_assignment.", "");
        }
        if (null == parms._predictions_name) {
          // NOTE: the "assignement" spelling is kept as-is; changing it would change the
          // generated frame key names.
          parms._predictions_name = defaultPredictionsName("leaf_node_assignement", "_on_", parms);
        }
        predictions =
            ((Model.LeafNodeAssignment) parms._model)
                .scoreLeafNodeAssignment(parms._frame, Key.make(parms._predictions_name));
      } else {
        throw new H2OIllegalArgumentException(
            "Requires a Deep Learning, GLRM, DRF or GBM model.",
            "Model must implement specific methods.");
      }
    }

    ModelMetricsListSchemaV3 mm = this.fetch(version, s);

    // TODO: for now only binary predictors write an MM object.
    // For the others cons one up here to return the predictions frame.
    if (null == mm) mm = new ModelMetricsListSchemaV3();

    mm.predictions_frame = new KeyV3.FrameKeyV3(predictions._key);
    // Don't show metrics if leaf node assignments are made.
    if (parms._leaf_node_assignment) {
      mm.model_metrics = null;
    }

    // With no model_metrics (e.g. no response in the test set) there is nothing to attach the
    // predictions preview to.
    if (null != mm.model_metrics && mm.model_metrics.length > 0) {
      // TODO: Should call schema(version)
      mm.model_metrics[0].predictions = new FrameV3(predictions, 0, 100);
    }
    return mm;
  }

  /**
   * Builds a default destination name for a scoring result.
   *
   * <p>The result is {@code prefix} + the first 5 characters of a fresh random key + {@code "_"}
   * + the model key + {@code link} + the frame key. Call only when no name was supplied, so that
   * no random key is created needlessly.
   *
   * @param prefix leading text describing the kind of result (used verbatim)
   * @param link separator placed between the model key and the frame key (e.g. "_on_")
   * @param parms filled request holding the model and frame
   * @return the generated name
   */
  private static String defaultPredictionsName(
      String prefix, String link, ModelMetricsList parms) {
    return prefix
        + Key.make().toString().substring(0, 5)
        + "_"
        + parms._model._key.toString()
        + link
        + parms._frame._key.toString();
  }