Ejemplo n.º 1
0
  /**
   * Loads the task, its source dataset, its train/test splits and the predictions of every
   * requested run, so that the runs can be compared instance by instance.
   *
   * <p>For "Supervised Data Stream Classification" tasks no splits file is downloaded; a
   * single-repeat, single-fold splits table marking every row as "TEST" is generated instead.
   *
   * @param openml connector used for all server look-ups and file URLs
   * @param run_ids ids of the runs whose prediction files are loaded
   * @param task_id id of the task that every run must belong to
   * @throws RuntimeException if the task type is not supported, if a run has no output file
   *     named "predictions", or if a run reports a task id different from {@code task_id}
   * @throws Exception propagated from the connector calls and from downloading/parsing the files
   */
  public InstanceBased(OpenmlConnector openml, List<Integer> run_ids, Integer task_id)
      throws Exception {
    this.run_ids = run_ids;
    this.openml = openml;

    this.predictions =
        new HashMap<Integer, Map<Integer, Map<Integer, Map<Integer, Map<Integer, String>>>>>();
    this.runs = new HashMap<Integer, Run>();

    this.task_id = task_id;
    Task currentTask = openml.taskGet(task_id);

    // TODO: no string based comp.
    String taskType = currentTask.getTask_type();
    boolean isClassification = taskType.equals("Supervised Classification");
    boolean isStreamClassification = taskType.equals("Supervised Data Stream Classification");
    if (!isClassification && !isStreamClassification) {
      throw new RuntimeException(
          "Experimental function, only works with 'Supervised Classification' tasks for now (ttid / 1)");
    }

    DataSetDescription dsd =
        openml.dataGet(TaskInformation.getSourceData(currentTask).getData_set_id());
    dataset =
        readInstancesAndClose(
            new BufferedReader(
                Input.getURL(openml.getOpenmlFileUrl(dsd.getFile_id(), dsd.getName()))));

    if (isStreamClassification) {
      // simulate task splits file.

      int numberOfInstances = dataset.numInstances();

      ArrayList<Attribute> attributes = new ArrayList<Attribute>();
      List<String> typeValues = new ArrayList<String>();
      typeValues.add("TEST");
      attributes.add(new Attribute("repeat"));
      attributes.add(new Attribute("fold"));
      attributes.add(new Attribute("rowid"));
      attributes.add(new Attribute("type", typeValues)); // don't need train

      task_splits =
          new Instances("task" + task_id + "splits-simulated", attributes, numberOfInstances);

      for (int i = 0; i < numberOfInstances; ++i) {
        // repeat 0, fold 0, row i, type index 0 (the only declared value, "TEST").
        double[] attValues = {0, 0, i, 0};
        task_splits.add(new DenseInstance(1.0, attValues));
      }
    } else {
      URL taskUrl =
          new URL(TaskInformation.getEstimationProcedure(currentTask).getData_splits_url());
      task_splits = readInstancesAndClose(new BufferedReader(Input.getURL(taskUrl)));
    }

    for (Integer run_id : run_ids) {
      Run current = this.openml.runGet(run_id);
      runs.put(run_id, current);
      Run.Data.File[] outputFiles = current.getOutputFile();

      boolean found = false;
      for (Run.Data.File f : outputFiles) {
        if (f.getName().equals("predictions")) {
          found = true;
          URL predictionsURL = openml.getOpenmlFileUrl(f.getFileId(), f.getName());
          Instances runPredictions =
              readInstancesAndClose(new BufferedReader(Input.getURL(predictionsURL)));
          predictions.put(run_id, predictionsToHashMap(runPredictions));
        }
      }

      if (!found) {
        throw new RuntimeException("No prediction files associated with run. Id: " + run_id);
      }
      // Compare by value: `!=` on two boxed Integers compares references, which reports a
      // false mismatch for equal ids outside the small-integer cache (-128..127).
      if (task_id.intValue() != current.getTask_id()) {
        throw new RuntimeException(
            "Runs are not of the same task type: Should be: "
                + this.task_id
                + "; found "
                + current.getTask_id()
                + " (and maybe more)");
      }
    }

    correct =
        datasetToHashMap(dataset, TaskInformation.getSourceData(currentTask).getTarget_feature());
  }

  /**
   * Parses the content behind {@code reader} into an {@link Instances} object and closes the
   * reader afterwards, whether or not parsing succeeded, so the underlying URL stream is not
   * leaked.
   *
   * @param reader open reader positioned at the start of the data; closed on return
   * @return the parsed instances
   * @throws java.io.IOException if reading, parsing or closing fails
   */
  private static Instances readInstancesAndClose(BufferedReader reader)
      throws java.io.IOException {
    try {
      return new Instances(reader);
    } finally {
      reader.close();
    }
  }