/** * Score a frame with the given model and return just the metrics. * * <p>NOTE: ModelMetrics are now always being created by model.score. . . */ @SuppressWarnings("unused") // called through reflection by RequestServer public ModelMetricsListSchemaV3 score(int version, ModelMetricsListSchemaV3 s) { // parameters checking: if (null == s.model) throw new H2OIllegalArgumentException("model", "predict", s.model); if (null == DKV.get(s.model.name)) throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name); if (null == s.frame) throw new H2OIllegalArgumentException("frame", "predict", s.frame); if (null == DKV.get(s.frame.name)) throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name); ModelMetricsList parms = s.createAndFillImpl(); parms ._model .score(parms._frame, parms._predictions_name) .remove(); // throw away predictions, keep metrics as a side-effect ModelMetricsListSchemaV3 mm = this.fetch(version, s); // TODO: for now only binary predictors write an MM object. // For the others cons one up here to return the predictions frame. if (null == mm) mm = new ModelMetricsListSchemaV3(); if (null == mm.model_metrics || 0 == mm.model_metrics.length) { Log.warn( "Score() did not return a ModelMetrics for model: " + s.model + " on frame: " + s.frame); } return mm; }
/** Score a frame with the given model and return the metrics AND the prediction frame. */ @SuppressWarnings("unused") // called through reflection by RequestServer public JobV3 predict2(int version, final ModelMetricsListSchemaV3 s) { // parameters checking: if (null == s.model) throw new H2OIllegalArgumentException("model", "predict", s.model); if (null == DKV.get(s.model.name)) throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name); if (null == s.frame) throw new H2OIllegalArgumentException("frame", "predict", s.frame); if (null == DKV.get(s.frame.name)) throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name); final ModelMetricsList parms = s.createAndFillImpl(); // predict2 does not return modelmetrics, so cannot handle deeplearning: reconstruction_error // (anomaly) or GLRM: reconstruct and archetypes // predict2 can handle deeplearning: deepfeatures and predict if (s.deep_features_hidden_layer > 0) { if (null == parms._predictions_name) parms._predictions_name = "deep_features" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString(); } else if (null == parms._predictions_name) parms._predictions_name = "predictions" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString(); final Job<Frame> j = new Job(Key.make(parms._predictions_name), Frame.class.getName(), "prediction"); H2O.H2OCountedCompleter work = new H2O.H2OCountedCompleter() { @Override public void compute2() { if (s.deep_features_hidden_layer < 0) { parms._model.score(parms._frame, parms._predictions_name, j); } else { Frame predictions = ((Model.DeepFeatures) parms._model) .scoreDeepFeatures(parms._frame, s.deep_features_hidden_layer, j); predictions = new Frame( Key.make(parms._predictions_name), predictions.names(), predictions.vecs()); DKV.put(predictions._key, predictions); } tryComplete(); } }; j.start(work, parms._frame.anyVec().nChunks()); return new JobV3().fillFromImpl(j); }
/** Score a frame with the given model and return the metrics AND the prediction frame. */ @SuppressWarnings("unused") // called through reflection by RequestServer public ModelMetricsListSchemaV3 predict(int version, ModelMetricsListSchemaV3 s) { // parameters checking: if (null == s.model) throw new H2OIllegalArgumentException("model", "predict", s.model); if (null == DKV.get(s.model.name)) throw new H2OKeyNotFoundArgumentException("model", "predict", s.model.name); if (null == s.frame) throw new H2OIllegalArgumentException("frame", "predict", s.frame); if (null == DKV.get(s.frame.name)) throw new H2OKeyNotFoundArgumentException("frame", "predict", s.frame.name); ModelMetricsList parms = s.createAndFillImpl(); Frame predictions; if (!s.reconstruction_error && !s.reconstruction_error_per_feature && s.deep_features_hidden_layer < 0 && !s.project_archetypes && !s.reconstruct_train && !s.leaf_node_assignment) { if (null == parms._predictions_name) parms._predictions_name = "predictions" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString(); predictions = parms._model.score(parms._frame, parms._predictions_name); } else { if (Model.DeepFeatures.class.isAssignableFrom(parms._model.getClass())) { if (s.reconstruction_error || s.reconstruction_error_per_feature) { if (s.deep_features_hidden_layer >= 0) throw new H2OIllegalArgumentException( "Can only compute either reconstruction error OR deep features.", ""); if (null == parms._predictions_name) parms._predictions_name = "reconstruction_error" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString(); predictions = ((Model.DeepFeatures) parms._model) .scoreAutoEncoder( parms._frame, Key.make(parms._predictions_name), parms._reconstruction_error_per_feature); } else { if (s.deep_features_hidden_layer < 0) throw new H2OIllegalArgumentException( "Deep features hidden layer index must be >= 0.", ""); if (null == parms._predictions_name) parms._predictions_name = "deep_features" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString(); predictions = ((Model.DeepFeatures) parms._model) .scoreDeepFeatures(parms._frame, s.deep_features_hidden_layer); } predictions = new Frame(Key.make(parms._predictions_name), predictions.names(), predictions.vecs()); DKV.put(predictions._key, predictions); } else if (Model.GLRMArchetypes.class.isAssignableFrom(parms._model.getClass())) { if (s.project_archetypes) { if (null == parms._predictions_name) parms._predictions_name = "reconstructed_archetypes_" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_of_" + parms._frame._key.toString(); predictions = ((Model.GLRMArchetypes) parms._model) .scoreArchetypes( parms._frame, Key.make(parms._predictions_name), s.reverse_transform); } else { assert s.reconstruct_train; if (null == parms._predictions_name) parms._predictions_name = "reconstruction_" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_of_" + parms._frame._key.toString(); predictions = ((Model.GLRMArchetypes) parms._model) .scoreReconstruction( parms._frame, Key.make(parms._predictions_name), s.reverse_transform); } } else if (Model.LeafNodeAssignment.class.isAssignableFrom(parms._model.getClass())) { assert (s.leaf_node_assignment); if (null == parms._predictions_name) parms._predictions_name = "leaf_node_assignement" + Key.make().toString().substring(0, 5) + "_" + parms._model._key.toString() + "_on_" + parms._frame._key.toString(); predictions = ((Model.LeafNodeAssignment) parms._model) .scoreLeafNodeAssignment(parms._frame, Key.make(parms._predictions_name)); } else throw new H2OIllegalArgumentException( "Requires a Deep Learning, GLRM, DRF or GBM model.", "Model must implement specific methods."); } ModelMetricsListSchemaV3 mm = this.fetch(version, s); // TODO: for now only binary predictors write an MM object. // For the others cons one up here to return the predictions frame. if (null == mm) mm = new ModelMetricsListSchemaV3(); mm.predictions_frame = new KeyV3.FrameKeyV3(predictions._key); if (parms._leaf_node_assignment) // don't show metrics in leaf node assignments are made mm.model_metrics = null; if (null == mm.model_metrics || 0 == mm.model_metrics.length) { // There was no response in the test set -> cannot make a model_metrics object } else { mm.model_metrics[0].predictions = new FrameV3(predictions, 0, 100); // TODO: Should call schema(version) } return mm; }
/** Delete one or more ModelMetrics. */ @SuppressWarnings("unused") // called through reflection by RequestServer public ModelMetricsListSchemaV3 delete(int version, ModelMetricsListSchemaV3 s) { ModelMetricsList m = s.createAndFillImpl(); s.fillFromImpl(m.delete()); return s; }