예제 #1
0
  /**
   * Loads everything needed to compare the predictions of several runs on one task.
   *
   * <p>The constructor fetches the task, downloads its source dataset as ARFF, obtains the task
   * splits (downloaded for batch classification, simulated in memory for data stream
   * classification), and for every run id downloads the run description together with its
   * "predictions" output file.
   *
   * @param openml connector used for every server request
   * @param run_ids ids of the runs to load; each run must belong to {@code task_id}
   * @param task_id id of the task all runs are expected to be executed on
   * @throws RuntimeException if the task type is unsupported, a run has no "predictions" output
   *     file, or a run belongs to a different task
   * @throws Exception if a server request, download or ARFF parse fails
   */
  public InstanceBased(OpenmlConnector openml, List<Integer> run_ids, Integer task_id)
      throws Exception {
    this.run_ids = run_ids;
    this.openml = openml;

    this.predictions =
        new HashMap<Integer, Map<Integer, Map<Integer, Map<Integer, Map<Integer, String>>>>>();
    this.runs = new HashMap<Integer, Run>();

    this.task_id = task_id;
    Task currentTask = openml.taskGet(task_id);

    if (!currentTask.getTask_type().equals("Supervised Classification")
        && !currentTask
            .getTask_type()
            .equals("Supervised Data Stream Classification")) { // TODO: no string based comp.
      throw new RuntimeException(
          "Experimental function, only works with 'Supervised Classification' tasks for now (ttid / 1)");
    }

    DataSetDescription dsd =
        openml.dataGet(TaskInformation.getSourceData(currentTask).getData_set_id());
    dataset =
        readArffAndClose(
            new BufferedReader(
                Input.getURL(openml.getOpenmlFileUrl(dsd.getFile_id(), dsd.getName()))));

    if (currentTask.getTask_type().equals("Supervised Data Stream Classification")) {
      // Simulate a task splits file: one repeat, one fold, every row marked as TEST.
      int numberOfInstances = dataset.numInstances();

      ArrayList<Attribute> attributes = new ArrayList<Attribute>();
      List<String> typeValues = new ArrayList<String>();
      typeValues.add("TEST");
      attributes.add(new Attribute("repeat"));
      attributes.add(new Attribute("fold"));
      attributes.add(new Attribute("rowid"));
      attributes.add(new Attribute("type", typeValues)); // don't need train

      task_splits =
          new Instances("task" + task_id + "splits-simulated", attributes, numberOfInstances);

      for (int i = 0; i < numberOfInstances; ++i) {
        // {repeat, fold, rowid, index of "TEST" in typeValues}
        double[] attValues = {0, 0, i, 0};
        task_splits.add(new DenseInstance(1.0, attValues));
      }
    } else {
      URL taskUrl =
          new URL(TaskInformation.getEstimationProcedure(currentTask).getData_splits_url());
      task_splits = readArffAndClose(new BufferedReader(Input.getURL(taskUrl)));
    }

    for (Integer run_id : run_ids) {
      Run current = this.openml.runGet(run_id);
      runs.put(run_id, current);
      Run.Data.File[] outputFiles = current.getOutputFile();

      boolean found = false;
      // A run without any output files is reported below as "no prediction files" instead of
      // failing with a bare NullPointerException.
      if (outputFiles != null) {
        for (Run.Data.File f : outputFiles) {
          if (f.getName().equals("predictions")) {
            found = true;
            URL predictionsURL = openml.getOpenmlFileUrl(f.getFileId(), f.getName());
            Instances runPredictions =
                readArffAndClose(new BufferedReader(Input.getURL(predictionsURL)));
            predictions.put(run_id, predictionsToHashMap(runPredictions));
          }
        }
      }

      if (!found) {
        throw new RuntimeException("No prediction files associated with run. Id: " + run_id);
      }
      // task_id is a boxed Integer: compare by value, never with != (reference comparison when
      // the other operand is boxed as well).
      if (!task_id.equals(current.getTask_id())) {
        throw new RuntimeException(
            "Runs are not of the same task type: Should be: "
                + this.task_id
                + "; found "
                + current.getTask_id()
                + " (and maybe more)");
      }
    }

    correct =
        datasetToHashMap(dataset, TaskInformation.getSourceData(currentTask).getTarget_feature());
  }

  /**
   * Parses ARFF content from {@code reader} and closes the reader afterwards, whether or not
   * parsing succeeded.
   */
  private static Instances readArffAndClose(BufferedReader reader) throws Exception {
    try {
      return new Instances(reader);
    } finally {
      reader.close();
    }
  }
예제 #2
0
  /**
   * Evaluates one run and uploads the result.
   *
   * <p>Looks up the run's task and uploaded files, builds the prediction evaluator matching the
   * task type, collects the calculated evaluation scores and compares them with the scores recorded
   * in the run description. Inconsistencies are stored as a warning on the evaluation. Any failure
   * during these steps is stored as the evaluation's error, after which the evaluation (and the
   * trace, when one was uploaded) is still sent to the server.
   *
   * @param run_id id of the run to evaluate
   * @throws Exception if one of the initial database queries fails
   */
  public void evaluate(int run_id) throws Exception {
    Conversion.log("OK", "Process Run", "Start processing run: " + run_id);
    final Map<String, Integer> file_ids = new HashMap<String, Integer>();
    final Task task;
    final DataSetDescription dataset;

    PredictionEvaluator predictionEvaluator;
    RunEvaluation runevaluation = new RunEvaluation(run_id);
    RunTrace trace = null;

    // run_id is a primitive int, so concatenating it into the query cannot inject SQL.
    JSONArray runJson =
        (JSONArray)
            apiconnector
                .freeQuery("SELECT `task_id` FROM `run` WHERE `rid` = " + run_id)
                .get("data");
    JSONArray filesJson =
        (JSONArray)
            apiconnector
                .freeQuery("SELECT `field`,`file_id` FROM `runfile` WHERE `source` = " + run_id)
                .get("data");

    try {
      int task_id = ((JSONArray) runJson.get(0)).getInt(0);
      task = apiconnector.taskGet(task_id);
      Data_set source_data = TaskInformation.getSourceData(task);
      Estimation_procedure estimationprocedure = null;
      Exception estimationProcedureFailure = null;
      try {
        estimationprocedure = TaskInformation.getEstimationProcedure(task);
      } catch (Exception e) {
        // Best effort: the stream and subgroup branches below never use the estimation procedure.
        // The failure is kept so the branches that do need it can report it as the cause.
        estimationProcedureFailure = e;
      }
      Integer dataset_id =
          source_data.getLabeled_data_set_id() != null
              ? source_data.getLabeled_data_set_id()
              : source_data.getData_set_id();

      for (int i = 0; i < filesJson.length(); ++i) {
        JSONArray row = (JSONArray) filesJson.get(i);
        file_ids.put(row.getString(0), row.getInt(1));
      }

      if (file_ids.get("description") == null) {
        runevaluation.setError("Run description file not present. ");
        uploadErrorEvaluation(runevaluation, run_id);
        return;
      }

      if (file_ids.get("predictions") == null
          && file_ids.get("subgroups")
              == null) { // TODO: this is currently true, but later on we might have tasks that do
                         // not require evaluations!
        runevaluation.setError("Required output files not present (e.g., arff predictions). ");
        uploadErrorEvaluation(runevaluation, run_id);
        return;
      }

      if (file_ids.get("trace") != null) {
        trace = traceToXML(file_ids.get("trace"), task_id, run_id);
      }

      String description =
          OpenmlConnector.getStringFromUrl(
              apiconnector
                  .getOpenmlFileUrl(
                      file_ids.get("description"), "Run_" + run_id + "_description.xml")
                  .toString());
      Run run_description = (Run) xstream.fromXML(description);
      dataset = apiconnector.dataGet(dataset_id);

      Conversion.log("OK", "Process Run", "Start prediction evaluator. ");
      // TODO! no string comparisons, do something better
      String filename_prefix = "Run_" + run_id + "_";
      if (task.getTask_type().equals("Supervised Data Stream Classification")) {
        predictionEvaluator =
            new EvaluateStreamPredictions(
                apiconnector.getOpenmlFileUrl(dataset.getFile_id(), dataset.getName()),
                apiconnector.getOpenmlFileUrl(
                    file_ids.get("predictions"), filename_prefix + "predictions.arff"),
                source_data.getTarget_feature());
      } else if (task.getTask_type().equals("Survival Analysis")) {
        requireEstimationProcedure(estimationprocedure, task_id, estimationProcedureFailure);
        predictionEvaluator =
            new EvaluateSurvivalAnalysisPredictions(
                task,
                apiconnector.getOpenmlFileUrl(dataset.getFile_id(), dataset.getName()),
                new URL(estimationprocedure.getData_splits_url()),
                apiconnector.getOpenmlFileUrl(
                    file_ids.get("predictions"), filename_prefix + "predictions.arff"));
      } else if (task.getTask_type().equals("Subgroup Discovery")) {
        predictionEvaluator = new EvaluateSubgroups();

      } else {
        requireEstimationProcedure(estimationprocedure, task_id, estimationProcedureFailure);
        predictionEvaluator =
            new EvaluateBatchPredictions(
                task,
                apiconnector.getOpenmlFileUrl(dataset.getFile_id(), dataset.getName()),
                new URL(estimationprocedure.getData_splits_url()),
                apiconnector.getOpenmlFileUrl(
                    file_ids.get("predictions"), filename_prefix + "predictions.arff"),
                estimationprocedure
                    .getType()
                    .equals(EstimationProcedure.estimationProceduresTxt[6]));
      }
      runevaluation.addEvaluationMeasures(predictionEvaluator.getEvaluationScores());

      if (run_description.getOutputEvaluation() != null) {
        Conversion.log(
            "OK",
            "Process Run",
            "Start consistency check with user defined measures. (x "
                + run_description.getOutputEvaluation().length
                + ")");

        // TODO: This can be done so much faster ...
        StringBuilder warnings = new StringBuilder();

        for (EvaluationScore recorded : run_description.getOutputEvaluation()) {
          boolean foundSame = false;

          // important check: because of legacy (implementation_id), the flow id might be missing
          if (recorded.getFlow() != null && recorded.getFunction() != null) {
            for (EvaluationScore calculated : runevaluation.getEvaluation_scores()) {
              if (recorded.isSame(calculated)) {
                foundSame = true;
                if (!recorded.sameValue(calculated)) {
                  String offByStr = "";
                  // parseDouble(null) throws NullPointerException, not NumberFormatException, so
                  // absent values are skipped up front instead of aborting the whole evaluation.
                  if (recorded.getValue() != null && calculated.getValue() != null) {
                    try {
                      double diff =
                          Math.abs(
                              Double.parseDouble(recorded.getValue())
                                  - Double.parseDouble(calculated.getValue()));
                      offByStr = " (off by " + diff + ")";
                    } catch (NumberFormatException ignored) {
                      // Non-numeric score: report the inconsistency without a difference.
                    }
                  }

                  warnings
                      .append("Inconsistent Evaluation score: ")
                      .append(recorded)
                      .append(offByStr);
                }
              }
            }
            if (!foundSame) {
              // give the record the correct sample size
              if (recorded.getSample() != null && recorded.getSample_size() == null) {
                recorded.setSample_size(
                    predictionEvaluator
                        .getPredictionCounter()
                        .getShadowTypeSize(
                            recorded.getRepeat(), recorded.getFold(), recorded.getSample()));
              }
              runevaluation.addEvaluationMeasure(recorded);
            }
          }
        }
        if (warnings.length() > 0) {
          runevaluation.setWarning(warnings.toString());
        }
      } else {
        Conversion.log("OK", "Process Run", "No local evaluation measures to compare to. ");
      }
    } catch (Exception e) {
      e.printStackTrace();
      // getMessage() is null for many exceptions (e.g. NullPointerException). Fall back to
      // toString() so the uploaded evaluation always carries a non-null error.
      String errorMessage = e.getMessage() != null ? e.getMessage() : e.toString();
      Conversion.log(
          "Warning",
          "Process Run",
          "Unexpected error, will proceed with upload process: " + errorMessage);
      runevaluation.setError(errorMessage);
    }

    Conversion.log("OK", "Process Run", "Start uploading results ... ");
    try {
      String runEvaluation = xstream.toXML(runevaluation);
      File evaluationFile =
          Conversion.stringToTempFile(runEvaluation, "run_" + run_id + "evaluations", "xml");
      RunEvaluate re = apiconnector.runEvaluate(evaluationFile);

      if (trace != null) {
        String runTrace = xstream.toXML(trace);
        File traceFile = Conversion.stringToTempFile(runTrace, "run_" + run_id + "trace", "xml");

        apiconnector.runTraceUpload(traceFile);
      }

      Conversion.log("OK", "Process Run", "Run processed: " + re.getRun_id());
    } catch (Exception e) {
      Conversion.log("ERROR", "Process Run", "An error occured during API call: " + e.getMessage());
    }
  }

  /** Serializes an evaluation that already carries an error, uploads it and logs the outcome. */
  private void uploadErrorEvaluation(RunEvaluation runevaluation, int run_id) throws Exception {
    File evaluationFile =
        Conversion.stringToTempFile(
            xstream.toXML(runevaluation), "run_" + run_id + "evaluations", "xml");

    RunEvaluate re = apiconnector.runEvaluate(evaluationFile);
    Conversion.log("Error", "Process Run", "Run processed, but with error: " + re.getRun_id());
  }

  /**
   * Fails with a descriptive message when a task type that needs data splits has no estimation
   * procedure; {@code lookupFailure} (may be null) is attached as the cause.
   */
  private static void requireEstimationProcedure(
      Estimation_procedure estimationprocedure, int task_id, Exception lookupFailure) {
    if (estimationprocedure == null) {
      throw new IllegalStateException(
          "Task " + task_id + " has no usable estimation procedure (data splits). ",
          lookupFailure);
    }
  }