예제 #1
0
  /**
   * Downloads the ARFF trace file of a run and converts it into a {@link RunTrace}.
   *
   * <p>The trace must contain the attributes {@code repeat}, {@code fold}, {@code iteration},
   * {@code evaluation} and {@code selected}, plus at least one {@code parameter_*} attribute. At
   * most one further attribute is tolerated (meant for {@code setup_string}, according to the error
   * message). For each row, the values of all {@code parameter_*} attributes are serialized into a
   * JSON object that becomes the setup string of the corresponding trace iteration.
   *
   * @param file_id OpenML file id of the trace ARFF file
   * @param task_id task of the run; only used to build the file name
   * @param run_id the run whose trace is converted
   * @return a trace with one iteration per row of the ARFF file
   * @throws Exception if the file cannot be downloaded or parsed, or violates the rules above
   */
  private RunTrace traceToXML(int file_id, int task_id, int run_id) throws Exception {
    RunTrace trace = new RunTrace(run_id);
    URL traceURL = apiconnector.getOpenmlFileUrl(file_id, "Task_" + task_id + "_trace.arff");
    Instances traceDataset;
    // Release the underlying connection as soon as the dataset has been parsed.
    try (BufferedReader reader = new BufferedReader(Input.getURL(traceURL))) {
      traceDataset = new Instances(reader);
    }

    if (traceDataset.attribute("repeat") == null
        || traceDataset.attribute("fold") == null
        || traceDataset.attribute("iteration") == null
        || traceDataset.attribute("evaluation") == null
        || traceDataset.attribute("selected") == null) {
      throw new Exception("trace file missing mandatory attributes. ");
    }
    // Resolve the mandatory columns once instead of once per row.
    int repeatIdx = traceDataset.attribute("repeat").index();
    int foldIdx = traceDataset.attribute("fold").index();
    int iterationIdx = traceDataset.attribute("iteration").index();
    int evaluationIdx = traceDataset.attribute("evaluation").index();
    int selectedIdx = traceDataset.attribute("selected").index();

    List<Integer> parameterIndexes = new ArrayList<Integer>();
    for (int i = 0; i < traceDataset.numAttributes(); ++i) {
      if (traceDataset.attribute(i).name().startsWith("parameter_")) {
        parameterIndexes.add(i);
      }
    }
    if (parameterIndexes.size() == 0) {
      throw new Exception(
          "trace file contains no fields with prefix 'parameter_' (i.e., parameters are not registered). ");
    }
    // 5 mandatory attributes + optional setup_string + the parameter_* attributes.
    if (traceDataset.numAttributes() > 6 + parameterIndexes.size()) {
      throw new Exception(
          "trace file contains illegal attributes (only allow for repeat, fold, iteration, evaluation, selected, setup_string and parameter_*). ");
    }

    for (int i = 0; i < traceDataset.numInstances(); ++i) {
      Instance current = traceDataset.get(i);
      Integer repeat = (int) current.value(repeatIdx);
      Integer fold = (int) current.value(foldIdx);
      Integer iteration = (int) current.value(iterationIdx);
      Double evaluation = current.value(evaluationIdx);
      Boolean selected = "true".equals(current.stringValue(selectedIdx));

      // Serialize all parameter_* values of this row into a JSON object.
      Map<String, String> parameters = new HashMap<String, String>();
      for (int attIdx : parameterIndexes) {
        String name = traceDataset.attribute(attIdx).name();
        String value =
            traceDataset.attribute(attIdx).isNumeric()
                ? String.valueOf(current.value(attIdx))
                : current.stringValue(attIdx);
        parameters.put(name, value);
      }
      String setup_string = new JSONObject(parameters).toString();

      trace.addIteration(
          new RunTrace.Trace_iteration(
              repeat, fold, iteration, setup_string, evaluation, selected));
    }

    return trace;
  }
예제 #2
0
  /**
   * Parses a given list of options.
   *
   * <!-- options-start -->
   * Valid options are:
   *
   * <pre>
   * -T &lt;task_id&gt;
   *  The OpenML task to run the experiment on. (required)
   * </pre>
   *
   * <pre>
   * -C &lt;class name&gt;
   *  The full class name of the classifier.
   *  eg: weka.classifiers.bayes.NaiveBayes
   * </pre>
   *
   * <!-- options-end -->
   * All options after -- will be passed to the classifier.
   *
   * <p>The task is downloaded from OpenML and becomes the only task of this experiment. The
   * classifier is instantiated via reflection. If that fails, Weka packages are loaded and
   * instantiation is retried once, so that classifiers provided by packages can be found.
   *
   * @param options the list of options as an array of strings
   * @throws Exception if the task id is missing or malformed, the task cannot be fetched, or the
   *     classifier cannot be instantiated even after loading packages
   */
  public void setOptions(String[] options) throws Exception {
    String taskIdOption = Utils.getOption('T', options);
    if (taskIdOption.isEmpty()) {
      throw new Exception("Option -T <task_id> is required. ");
    }
    Integer task_id = Integer.parseInt(taskIdOption);
    String classifierName = Utils.getOption('C', options);
    String[] classifierOptions = Utils.partitionOptions(options);

    DefaultListModel<Task> tasks = new DefaultListModel<Task>();
    tasks.add(0, apiconnector.taskGet(task_id));
    setTasks(tasks);

    Classifier[] cArray = new Classifier[1];
    // Each attempt gets its own copy: instantiation may consume (blank out) accepted options.
    try {
      cArray[0] =
          (Classifier) Utils.forName(Classifier.class, classifierName, classifierOptions.clone());
    } catch (Exception firstAttempt) {
      // The class may live in a Weka package that has not been loaded yet: load packages, retry.
      weka.core.WekaPackageManager.loadPackages(false);
      try {
        cArray[0] =
            (Classifier)
                Utils.forName(Classifier.class, classifierName, classifierOptions.clone());
      } catch (Exception retry) {
        // Keep the first failure visible for diagnosis.
        retry.addSuppressed(firstAttempt);
        throw retry;
      }
    }
    setPropertyArray(cArray);
  }
예제 #3
0
  /**
   * Picks a run that has not been processed yet.
   *
   * <p>Candidates are the (at most) 100 runs with the earliest {@code start_time} that are neither
   * processed nor flagged with an error or error message.
   *
   * @param random if {@code true}, a uniformly random candidate is returned; otherwise the one
   *     with the earliest start time
   * @return the id of the selected run, or {@code null} if there is no candidate
   * @throws Exception if the query fails or its result cannot be parsed
   */
  public Integer getRunId(boolean random) throws Exception {
    String sql =
        "SELECT `rid`,`start_time`,`processed`,`error` "
            + "FROM `run` WHERE `processed` IS NULL AND `error` IS NULL AND error_message IS NULL "
            + "ORDER BY `start_time` ASC LIMIT 0, 100; ";

    JSONArray runJson = (JSONArray) apiconnector.freeQuery(sql).get("data");

    if (runJson.length() == 0) {
      return null;
    }
    // nextInt(bound) always lies in [0, bound). The former Math.abs(nextInt()) % n could be
    // negative, because Math.abs(Integer.MIN_VALUE) == Integer.MIN_VALUE.
    int index = random ? new Random().nextInt(runJson.length()) : 0;
    return ((JSONArray) runJson.get(index)).getInt(0);
  }
예제 #4
0
  /**
   * Loads everything needed to compare the predictions of several runs on a single task, instance
   * by instance.
   *
   * <p>Fetches the task, its source dataset and its data splits. For data stream tasks no splits
   * file is downloaded. Instead, a single TEST split (repeat 0, fold 0) covering every instance is
   * simulated. Then the predictions file of every run is downloaded. Finally, the true labels of
   * the dataset are collected.
   *
   * @param openml connector used for all OpenML API calls
   * @param run_ids the runs to compare; all of them must belong to {@code task_id}
   * @param task_id the task the runs were performed on
   * @throws RuntimeException if the task type is unsupported, a run has no predictions file, or a
   *     run belongs to another task
   * @throws Exception if fetching or parsing any of the resources fails
   */
  public InstanceBased(OpenmlConnector openml, List<Integer> run_ids, Integer task_id)
      throws Exception {
    this.run_ids = run_ids;
    this.openml = openml;

    this.predictions =
        new HashMap<Integer, Map<Integer, Map<Integer, Map<Integer, Map<Integer, String>>>>>();
    this.runs = new HashMap<Integer, Run>();

    this.task_id = task_id;
    Task currentTask = openml.taskGet(task_id);
    // TODO: no string based comp.
    boolean isStreamTask =
        currentTask.getTask_type().equals("Supervised Data Stream Classification");

    if (!isStreamTask && !currentTask.getTask_type().equals("Supervised Classification")) {
      throw new RuntimeException(
          "Experimental function, only works with 'Supervised Classification' tasks for now (ttid / 1)");
    }

    DataSetDescription dsd =
        openml.dataGet(TaskInformation.getSourceData(currentTask).getData_set_id());
    dataset = loadArffFromUrl(openml.getOpenmlFileUrl(dsd.getFile_id(), dsd.getName()));

    if (isStreamTask) {
      // Simulate the task splits file: every instance is a TEST instance of repeat 0, fold 0.
      int numberOfInstances = dataset.numInstances();

      ArrayList<Attribute> attributes = new ArrayList<Attribute>();
      List<String> typeValues = new ArrayList<String>();
      typeValues.add("TEST"); // don't need train
      attributes.add(new Attribute("repeat"));
      attributes.add(new Attribute("fold"));
      attributes.add(new Attribute("rowid"));
      attributes.add(new Attribute("type", typeValues));

      task_splits =
          new Instances("task" + task_id + "splits-simulated", attributes, numberOfInstances);

      for (int i = 0; i < numberOfInstances; ++i) {
        // Last value is index 0 of the "type" attribute, i.e. "TEST".
        double[] attValues = {0, 0, i, 0};
        task_splits.add(new DenseInstance(1.0, attValues));
      }
    } else {
      URL taskUrl =
          new URL(TaskInformation.getEstimationProcedure(currentTask).getData_splits_url());
      task_splits = loadArffFromUrl(taskUrl);
    }

    for (Integer run_id : run_ids) {
      Run current = this.openml.runGet(run_id);
      runs.put(run_id, current);

      boolean found = false;
      for (Run.Data.File f : current.getOutputFile()) {
        if (f.getName().equals("predictions")) {
          found = true;
          URL predictionsURL = openml.getOpenmlFileUrl(f.getFileId(), f.getName());
          predictions.put(run_id, predictionsToHashMap(loadArffFromUrl(predictionsURL)));
        }
      }

      if (!found) {
        throw new RuntimeException("No prediction files associated with run. Id: " + run_id);
      }
      // Compare by value. task_id is a boxed Integer; if getTask_id() is boxed as well, != would
      // compare references and spuriously fail for ids outside the Integer cache (-128..127).
      if (!task_id.equals(current.getTask_id())) {
        throw new RuntimeException(
            "Runs are not of the same task type: Should be: "
                + this.task_id
                + "; found "
                + current.getTask_id()
                + " (and maybe more)");
      }
    }

    correct =
        datasetToHashMap(dataset, TaskInformation.getSourceData(currentTask).getTarget_feature());
  }

  /** Reads a complete ARFF file from {@code url} into memory, closing the connection afterwards. */
  private static Instances loadArffFromUrl(URL url) throws Exception {
    try (BufferedReader reader = new BufferedReader(Input.getURL(url))) {
      return new Instances(reader);
    }
  }
예제 #5
0
  /**
   * Evaluates a single run and uploads the resulting evaluation (and its trace, if any) to OpenML.
   *
   * <p>The run's task and files are looked up first. If the description file, or both the
   * predictions and the subgroups file, are missing, an evaluation carrying only an error is
   * uploaded and the method returns. Otherwise:
   *
   * <ul>
   *   <li>the optional trace is parsed;
   *   <li>evaluation measures are computed with an evaluator chosen by task type;
   *   <li>the computed measures are compared with those recorded in the run description.
   * </ul>
   *
   * <p>Recorded measures that were not recomputed are added to the evaluation. Recorded measures
   * that disagree with a recomputed one produce a warning.
   *
   * <p>Failures during evaluation are recorded as the evaluation's error, and the evaluation is
   * still uploaded. Failures during the final upload are only logged.
   *
   * @param run_id the run to evaluate
   * @throws Exception if the initial run / run-file queries fail (they are not guarded)
   */
  public void evaluate(int run_id) throws Exception {
    Conversion.log("OK", "Process Run", "Start processing run: " + run_id);
    final Map<String, Integer> file_ids = new HashMap<String, Integer>();
    final Task task;
    final DataSetDescription dataset;

    PredictionEvaluator predictionEvaluator;
    RunEvaluation runevaluation = new RunEvaluation(run_id);
    RunTrace trace = null;

    JSONArray runJson =
        (JSONArray)
            apiconnector
                .freeQuery("SELECT `task_id` FROM `run` WHERE `rid` = " + run_id)
                .get("data");
    JSONArray filesJson =
        (JSONArray)
            apiconnector
                .freeQuery("SELECT `field`,`file_id` FROM `runfile` WHERE `source` = " + run_id)
                .get("data");

    try {
      int task_id = ((JSONArray) runJson.get(0)).getInt(0);
      task = apiconnector.taskGet(task_id);
      Data_set source_data = TaskInformation.getSourceData(task);
      Estimation_procedure estimationprocedure = null;
      try {
        estimationprocedure = TaskInformation.getEstimationProcedure(task);
      } catch (Exception ignored) {
        // Deliberately tolerated (presumably not every task type defines an estimation
        // procedure). Task types that need one are checked explicitly below.
      }
      Integer dataset_id =
          source_data.getLabeled_data_set_id() != null
              ? source_data.getLabeled_data_set_id()
              : source_data.getData_set_id();

      // Map each run file's field name (e.g. "description", "predictions") to its file id.
      for (int i = 0; i < filesJson.length(); ++i) {
        JSONArray fileRow = (JSONArray) filesJson.get(i);
        String field = fileRow.getString(0);
        int file_index = fileRow.getInt(1);
        file_ids.put(field, file_index);
      }

      // Without the required files the run cannot be evaluated: upload an evaluation that only
      // carries the error, and stop.
      String missingFileError = null;
      if (file_ids.get("description") == null) {
        missingFileError = "Run description file not present. ";
      } else if (file_ids.get("predictions") == null && file_ids.get("subgroups") == null) {
        // TODO: this is currently true, but later on we might have tasks that do not require
        // evaluations!
        missingFileError = "Required output files not present (e.g., arff predictions). ";
      }
      if (missingFileError != null) {
        runevaluation.setError(missingFileError);
        File evaluationFile =
            Conversion.stringToTempFile(
                xstream.toXML(runevaluation), "run_" + run_id + "evaluations", "xml");

        RunEvaluate re = apiconnector.runEvaluate(evaluationFile);
        Conversion.log("Error", "Process Run", "Run processed, but with error: " + re.getRun_id());
        return;
      }

      if (file_ids.get("trace") != null) {
        trace = traceToXML(file_ids.get("trace"), task_id, run_id);
      }

      String description =
          OpenmlConnector.getStringFromUrl(
              apiconnector
                  .getOpenmlFileUrl(
                      file_ids.get("description"), "Run_" + run_id + "_description.xml")
                  .toString());
      Run run_description = (Run) xstream.fromXML(description);
      dataset = apiconnector.dataGet(dataset_id);

      Conversion.log("OK", "Process Run", "Start prediction evaluator. ");
      // TODO! no string comparisons, do something better
      String taskType = task.getTask_type();
      boolean needsEstimationProcedure =
          !taskType.equals("Supervised Data Stream Classification")
              && !taskType.equals("Subgroup Discovery");
      if (needsEstimationProcedure && estimationprocedure == null) {
        // Fail with a clear message instead of a NullPointerException further down.
        throw new Exception(
            "Task "
                + task_id
                + " has no estimation procedure, which is required for task type '"
                + taskType
                + "'. ");
      }
      String filename_prefix = "Run_" + run_id + "_";
      if (taskType.equals("Supervised Data Stream Classification")) {
        predictionEvaluator =
            new EvaluateStreamPredictions(
                apiconnector.getOpenmlFileUrl(dataset.getFile_id(), dataset.getName()),
                apiconnector.getOpenmlFileUrl(
                    file_ids.get("predictions"), filename_prefix + "predictions.arff"),
                source_data.getTarget_feature());
      } else if (taskType.equals("Survival Analysis")) {
        predictionEvaluator =
            new EvaluateSurvivalAnalysisPredictions(
                task,
                apiconnector.getOpenmlFileUrl(dataset.getFile_id(), dataset.getName()),
                new URL(estimationprocedure.getData_splits_url()),
                apiconnector.getOpenmlFileUrl(
                    file_ids.get("predictions"), filename_prefix + "predictions.arff"));
      } else if (taskType.equals("Subgroup Discovery")) {
        predictionEvaluator = new EvaluateSubgroups();
      } else {
        predictionEvaluator =
            new EvaluateBatchPredictions(
                task,
                apiconnector.getOpenmlFileUrl(dataset.getFile_id(), dataset.getName()),
                new URL(estimationprocedure.getData_splits_url()),
                apiconnector.getOpenmlFileUrl(
                    file_ids.get("predictions"), filename_prefix + "predictions.arff"),
                estimationprocedure
                    .getType()
                    .equals(EstimationProcedure.estimationProceduresTxt[6]));
      }
      runevaluation.addEvaluationMeasures(predictionEvaluator.getEvaluationScores());

      if (run_description.getOutputEvaluation() != null) {
        Conversion.log(
            "OK",
            "Process Run",
            "Start consistency check with user defined measures. (x "
                + run_description.getOutputEvaluation().length
                + ")");

        // TODO: This can be done so much faster ...
        StringBuilder warningMessage = new StringBuilder();

        for (EvaluationScore recorded : run_description.getOutputEvaluation()) {
          // Because of legacy (implementation_id), the flow id might be missing; such recorded
          // measures are skipped.
          if (recorded.getFlow() == null || recorded.getFunction() == null) {
            continue;
          }
          boolean foundSame = false;
          for (EvaluationScore calculated : runevaluation.getEvaluation_scores()) {
            if (!recorded.isSame(calculated)) {
              continue;
            }
            foundSame = true;
            if (!recorded.sameValue(calculated)) {
              String offByStr = "";
              try {
                double diff =
                    Math.abs(
                        Double.parseDouble(recorded.getValue())
                            - Double.parseDouble(calculated.getValue()));
                offByStr = " (off by " + diff + ")";
              } catch (NumberFormatException nfe) {
                // Non-numeric values: report the inconsistency without a difference.
              }
              // Separate consecutive warnings so the message stays readable.
              if (warningMessage.length() > 0) {
                warningMessage.append("; ");
              }
              warningMessage
                  .append("Inconsistent Evaluation score: ")
                  .append(recorded)
                  .append(offByStr);
            }
          }
          if (!foundSame) {
            // give the record the correct sample size
            if (recorded.getSample() != null && recorded.getSample_size() == null) {
              recorded.setSample_size(
                  predictionEvaluator
                      .getPredictionCounter()
                      .getShadowTypeSize(
                          recorded.getRepeat(), recorded.getFold(), recorded.getSample()));
            }
            runevaluation.addEvaluationMeasure(recorded);
          }
        }
        if (warningMessage.length() > 0) {
          runevaluation.setWarning(warningMessage.toString());
        }
      } else {
        Conversion.log("OK", "Process Run", "No local evaluation measures to compare to. ");
      }
    } catch (Exception e) {
      e.printStackTrace();
      // Some exceptions (e.g. NullPointerException) carry no message; fall back to toString() so
      // a non-null error is recorded on the evaluation.
      String error = e.getMessage() != null ? e.getMessage() : e.toString();
      Conversion.log(
          "Warning", "Process Run", "Unexpected error, will proceed with upload process: " + error);
      runevaluation.setError(error);
    }

    Conversion.log("OK", "Process Run", "Start uploading results ... ");
    try {
      String runEvaluation = xstream.toXML(runevaluation);
      File evaluationFile =
          Conversion.stringToTempFile(runEvaluation, "run_" + run_id + "evaluations", "xml");
      RunEvaluate re = apiconnector.runEvaluate(evaluationFile);

      if (trace != null) {
        String runTrace = xstream.toXML(trace);
        File traceFile = Conversion.stringToTempFile(runTrace, "run_" + run_id + "trace", "xml");

        apiconnector.runTraceUpload(traceFile);
      }

      Conversion.log("OK", "Process Run", "Run processed: " + re.getRun_id());
    } catch (Exception e) {
      Conversion.log("ERROR", "Process Run", "An error occured during API call: " + e.getMessage());
    }
  }
예제 #6
0
  /**
   * Compares the predictions of exactly two runs on every TEST row of the task splits.
   *
   * <p>Each row on which the two runs predict different labels is added to {@code resultSet}. The
   * entry holds the repeat, fold and row id, plus which run (if any) predicted the correct label.
   * The number of differences is then reported via {@code openml.setupDifferences}. That call is
   * best effort, because it requires admin rights.
   *
   * @return the number of TEST rows on which the two runs disagree
   * @throws RuntimeException if not exactly two runs are being compared
   */
  public int calculateDifference() {
    if (run_ids.size() != 2) {
      throw new RuntimeException(
          "Exactly 2 runs are required to compare, but found " + run_ids.size() + ". ");
    }

    List<String> values = new ArrayList<String>();
    for (Integer run : run_ids) {
      values.add(run + "");
    }
    values.add("none");

    ArrayList<Attribute> attributes = new ArrayList<Attribute>();
    attributes.add(new Attribute("repeat"));
    attributes.add(new Attribute("fold"));
    attributes.add(new Attribute("rowid"));
    attributes.add(new Attribute("whichCorrect", values));

    resultSet = new Instances("difference", attributes, task_splits.numInstances());

    // Resolve attributes once instead of once per row.
    Attribute typeAtt = task_splits.attribute("type");
    Attribute rowIdAtt = task_splits.attribute("rowid");
    Attribute repeatAtt = task_splits.attribute("repeat");
    Attribute foldAtt = task_splits.attribute("fold");
    // Optional column (e.g. the splits simulated for data stream tasks have none); without it
    // every row belongs to sample 0.
    Attribute sampleAtt = task_splits.attribute("sample");
    Attribute whichCorrectAtt = resultSet.attribute("whichCorrect");
    double noneCorrect = whichCorrectAtt.indexOfValue("none");

    for (int i = 0; i < task_splits.numInstances(); ++i) {
      Instance current = task_splits.get(i);
      if (!current.stringValue(typeAtt).equals("TEST")) {
        continue;
      }

      Integer row_id = (int) current.value(rowIdAtt);
      Integer repeat = (int) current.value(repeatAtt);
      Integer fold = (int) current.value(foldAtt);
      Integer sample = sampleAtt == null ? 0 : (int) current.value(sampleAtt);

      String label = null;
      boolean difference = false;
      String correctLabel = correct.get(row_id);
      double whichCorrect = noneCorrect;

      for (Integer run_id : run_ids) {
        String currentLabel = predictions.get(run_id).get(repeat).get(fold).get(sample).get(row_id);
        // check for difference
        if (label == null) {
          label = currentLabel;
        } else if (!label.equals(currentLabel)) {
          difference = true;
        }

        // check for correct label
        if (currentLabel.equals(correctLabel)) {
          whichCorrect = whichCorrectAtt.indexOfValue(run_id + "");
        }
      }

      if (difference) {
        double[] instance = {repeat, fold, row_id, whichCorrect};
        resultSet.add(new DenseInstance(1.0, instance));
      }
    }

    try {
      openml.setupDifferences(
          setup_ids.get(0), setup_ids.get(1), task_id, task_splits_size, resultSet.size());
    } catch (Exception ignored) {
      // Best effort: storing the differences requires admin rights.
    }

    return resultSet.size();
  }
예제 #7
0
  /**
   * Carries out the next iteration of the experiment.
   *
   * <p>When no task is active yet, the task at {@code m_DatasetNumber} is taken from the task
   * list. It is handed to the result producer, and the number of runs is set to the task's number
   * of repeats.
   *
   * <p>If duplicate avoidance is configured and OpenML already lists a run of the same setup on
   * this task, the iteration is skipped. Otherwise the result producer executes run {@code
   * m_RunNumber}. After the last repeat it may also build a model on the full dataset. The
   * counters are then advanced.
   */
  @Override
  public void nextIteration() throws Exception {
    if (m_CurrentTask == null) {
      m_CurrentTask = (Task) getTasks().elementAt(m_DatasetNumber);

      TaskResultProducer taskProducer = (TaskResultProducer) m_ResultProducer;
      taskProducer.setTask(m_CurrentTask);
      setRunUpper(TaskInformation.getNumberOfRepeats(m_CurrentTask));

      // Debug output: the split evaluator the producer currently uses.
      System.err.println(taskProducer.getSplitEvaluator().getClass().toString());

      // Re-apply the classifier property: alternating between regression and classification
      // tasks may have reset the split evaluator.
      if (m_UsePropertyIterator) {
        setProperty(0, m_ResultProducer);
        m_CurrentProperty = m_PropertyNumber;
      }
    }

    if (openmlconfig.getAvoidDuplicateRuns()) {
      TaskResultProducer producer = (TaskResultProducer) m_ResultProducer;
      String flowName = (String) producer.getSplitEvaluatorKey(0);
      String flowOptions = (String) producer.getSplitEvaluatorKey(1);
      Integer setupId = WekaAlgorithm.getSetupId(flowName, flowOptions, apiconnector);

      if (setupId != null) {
        List<Integer> taskFilter = new ArrayList<Integer>();
        taskFilter.add(m_CurrentTask.getTask_id());
        List<Integer> setupFilter = new ArrayList<Integer>();
        setupFilter.add(setupId);

        try {
          Run[] existingRuns = apiconnector.runList(taskFilter, setupFilter, null, null).getRuns();
          if (existingRuns.length > 0) {
            List<Integer> existingIds = new ArrayList<Integer>();
            for (Run existing : existingRuns) {
              existingIds.add(existing.getRun_id());
            }
            Conversion.log(
                "INFO",
                "Skip",
                "Skipping run "
                    + flowName
                    + " (setup #"
                    + setupId
                    + ") repeat "
                    + m_RunNumber
                    + ", already available. Run ids: "
                    + existingIds);
            advanceCounters();
            return;
          }
        } catch (Exception ignored) {
          // Any failure while looking up existing runs is ignored; the run is executed instead.
        }
      }
    }

    m_ResultProducer.doRun(m_RunNumber);

    // Before advancing the counters: after the last repeat, optionally also build a model over
    // the full dataset.
    if (m_RunNumber == getRunUpper() && openmlconfig.getModelFullDataset()) {
      ((TaskResultProducer) m_ResultProducer).doFullRun();
    }

    advanceCounters();
  }