public InstanceBased(OpenmlConnector openml, List<Integer> run_ids, Integer task_id) throws Exception {
    this.run_ids = run_ids;
    this.openml = openml;
    this.predictions = new HashMap<Integer, Map<Integer, Map<Integer, Map<Integer, Map<Integer, String>>>>>();
    this.runs = new HashMap<Integer, Run>();
    this.task_id = task_id;

    Task currentTask = openml.taskGet(task_id);
    // TODO: no string based comparison
    if (!currentTask.getTask_type().equals("Supervised Classification")
            && !currentTask.getTask_type().equals("Supervised Data Stream Classification")) {
        throw new RuntimeException(
                "Experimental function, only works with 'Supervised Classification' tasks for now (ttid / 1)");
    }

    DataSetDescription dsd = openml.dataGet(TaskInformation.getSourceData(currentTask).getData_set_id());
    dataset = new Instances(
            new BufferedReader(Input.getURL(openml.getOpenmlFileUrl(dsd.getFile_id(), dsd.getName()))));

    if (currentTask.getTask_type().equals("Supervised Data Stream Classification")) {
        // Data stream tasks have no splits file on the server; simulate one in
        // which every instance is a test instance.
        int numberOfInstances = dataset.numInstances();
        ArrayList<Attribute> attributes = new ArrayList<Attribute>();
        List<String> typeValues = new ArrayList<String>();
        typeValues.add("TEST"); // don't need train
        attributes.add(new Attribute("repeat"));
        attributes.add(new Attribute("fold"));
        attributes.add(new Attribute("rowid"));
        attributes.add(new Attribute("type", typeValues));
        task_splits = new Instances("task" + task_id + "splits-simulated", attributes, numberOfInstances);
        for (int i = 0; i < numberOfInstances; ++i) {
            double[] attValues = {0, 0, i, 0};
            task_splits.add(new DenseInstance(1.0, attValues));
        }
    } else {
        URL taskUrl = new URL(TaskInformation.getEstimationProcedure(currentTask).getData_splits_url());
        task_splits = new Instances(new BufferedReader(Input.getURL(taskUrl)));
    }

    // Download the predictions of each run; all runs must belong to the same task.
    for (Integer run_id : run_ids) {
        Run current = this.openml.runGet(run_id);
        runs.put(run_id, current);
        Run.Data.File[] outputFiles = current.getOutputFile();
        boolean found = false;
        for (Run.Data.File f : outputFiles) {
            if (f.getName().equals("predictions")) {
                found = true;
                URL predictionsURL = openml.getOpenmlFileUrl(f.getFileId(), f.getName());
                Instances runPredictions = new Instances(new BufferedReader(Input.getURL(predictionsURL)));
                predictions.put(run_id, predictionsToHashMap(runPredictions));
            }
        }
        if (!found) {
            throw new RuntimeException("No prediction files associated with run. Id: " + run_id);
        }
        if (task_id != current.getTask_id()) {
            throw new RuntimeException(
                    "Runs are not all of the same task: Should be: " + this.task_id
                            + "; found " + current.getTask_id() + " (and maybe more)");
        }
    }
    correct = datasetToHashMap(dataset, TaskInformation.getSourceData(currentTask).getTarget_feature());
}
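// A minimal usage sketch (not part of the class): constructing the comparator
// for a handful of runs on the same task. The api key, run ids, and task id
// below are hypothetical placeholders; this assumes OpenmlConnector exposes an
// api-key constructor, as in recent openml-java versions.
//
//   OpenmlConnector openml = new OpenmlConnector("YOUR_API_KEY");
//   List<Integer> runIds = Arrays.asList(500000, 500001); // hypothetical run ids
//   InstanceBased comparator = new InstanceBased(openml, runIds, 59); // hypothetical task id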
public void evaluate(int run_id) throws Exception {
    Conversion.log("OK", "Process Run", "Start processing run: " + run_id);
    final Map<String, Integer> file_ids = new HashMap<String, Integer>();
    final Task task;
    final DataSetDescription dataset;
    PredictionEvaluator predictionEvaluator;
    RunEvaluation runevaluation = new RunEvaluation(run_id);
    RunTrace trace = null;

    JSONArray runJson = (JSONArray) apiconnector
            .freeQuery("SELECT `task_id` FROM `run` WHERE `rid` = " + run_id)
            .get("data");
    JSONArray filesJson = (JSONArray) apiconnector
            .freeQuery("SELECT `field`,`file_id` FROM `runfile` WHERE `source` = " + run_id)
            .get("data");

    try {
        int task_id = ((JSONArray) runJson.get(0)).getInt(0);
        task = apiconnector.taskGet(task_id);
        Data_set source_data = TaskInformation.getSourceData(task);
        Estimation_procedure estimationprocedure = null;
        try {
            estimationprocedure = TaskInformation.getEstimationProcedure(task);
        } catch (Exception e) {
            // not all task types define an estimation procedure
        }
        Integer dataset_id = source_data.getLabeled_data_set_id() != null
                ? source_data.getLabeled_data_set_id()
                : source_data.getData_set_id();

        // Index the run's output files by field name (description, predictions, trace, ...)
        for (int i = 0; i < filesJson.length(); ++i) {
            String field = ((JSONArray) filesJson.get(i)).getString(0);
            int file_index = ((JSONArray) filesJson.get(i)).getInt(1);
            file_ids.put(field, file_index);
        }

        if (file_ids.get("description") == null) {
            runevaluation.setError("Run description file not present. ");
            File evaluationFile = Conversion.stringToTempFile(
                    xstream.toXML(runevaluation), "run_" + run_id + "evaluations", "xml");
            RunEvaluate re = apiconnector.runEvaluate(evaluationFile);
            Conversion.log("Error", "Process Run", "Run processed, but with error: " + re.getRun_id());
            return;
        }

        if (file_ids.get("predictions") == null && file_ids.get("subgroups") == null) {
            // TODO: this is currently true, but later on we might have tasks
            // that do not require evaluations!
            runevaluation.setError("Required output files not present (e.g., arff predictions). ");
            File evaluationFile = Conversion.stringToTempFile(
                    xstream.toXML(runevaluation), "run_" + run_id + "evaluations", "xml");
            RunEvaluate re = apiconnector.runEvaluate(evaluationFile);
            Conversion.log("Error", "Process Run", "Run processed, but with error: " + re.getRun_id());
            return;
        }

        if (file_ids.get("trace") != null) {
            trace = traceToXML(file_ids.get("trace"), task_id, run_id);
        }

        String description = OpenmlConnector.getStringFromUrl(
                apiconnector
                        .getOpenmlFileUrl(file_ids.get("description"), "Run_" + run_id + "_description.xml")
                        .toString());
        Run run_description = (Run) xstream.fromXML(description);
        dataset = apiconnector.dataGet(dataset_id);

        Conversion.log("OK", "Process Run", "Start prediction evaluator. ");
        // TODO! no string comparisons, do something better
        String filename_prefix = "Run_" + run_id + "_";
        if (task.getTask_type().equals("Supervised Data Stream Classification")) {
            predictionEvaluator = new EvaluateStreamPredictions(
                    apiconnector.getOpenmlFileUrl(dataset.getFile_id(), dataset.getName()),
                    apiconnector.getOpenmlFileUrl(file_ids.get("predictions"), filename_prefix + "predictions.arff"),
                    source_data.getTarget_feature());
        } else if (task.getTask_type().equals("Survival Analysis")) {
            predictionEvaluator = new EvaluateSurvivalAnalysisPredictions(
                    task,
                    apiconnector.getOpenmlFileUrl(dataset.getFile_id(), dataset.getName()),
                    new URL(estimationprocedure.getData_splits_url()),
                    apiconnector.getOpenmlFileUrl(file_ids.get("predictions"), filename_prefix + "predictions.arff"));
        } else if (task.getTask_type().equals("Subgroup Discovery")) {
            predictionEvaluator = new EvaluateSubgroups();
        } else {
            predictionEvaluator = new EvaluateBatchPredictions(
                    task,
                    apiconnector.getOpenmlFileUrl(dataset.getFile_id(), dataset.getName()),
                    new URL(estimationprocedure.getData_splits_url()),
                    apiconnector.getOpenmlFileUrl(file_ids.get("predictions"), filename_prefix + "predictions.arff"),
                    estimationprocedure.getType().equals(EstimationProcedure.estimationProceduresTxt[6]));
        }
        runevaluation.addEvaluationMeasures(predictionEvaluator.getEvaluationScores());

        if (run_description.getOutputEvaluation() != null) {
            Conversion.log("OK", "Process Run", "Start consistency check with user defined measures. (x "
                    + run_description.getOutputEvaluation().length + ")");
            // TODO: This can be done so much faster ...
            String warningMessage = "";
            boolean warningFound = false;
            for (EvaluationScore recorded : run_description.getOutputEvaluation()) {
                boolean foundSame = false;
                // important check: because of legacy (implementation_id), the flow id might be missing
                if (recorded.getFlow() != null && recorded.getFunction() != null) {
                    for (EvaluationScore calculated : runevaluation.getEvaluation_scores()) {
                        if (recorded.isSame(calculated)) {
                            foundSame = true;
                            if (!recorded.sameValue(calculated)) {
                                String offByStr = "";
                                try {
                                    double diff = Math.abs(Double.parseDouble(recorded.getValue())
                                            - Double.parseDouble(calculated.getValue()));
                                    offByStr = " (off by " + diff + ")";
                                } catch (NumberFormatException nfe) {
                                    // non-numeric score; report without the difference
                                }
                                warningMessage += "Inconsistent Evaluation score: " + recorded + offByStr;
                                warningFound = true;
                            }
                        }
                    }
                    if (!foundSame) {
                        // give the record the correct sample size
                        if (recorded.getSample() != null && recorded.getSample_size() == null) {
                            recorded.setSample_size(predictionEvaluator.getPredictionCounter().getShadowTypeSize(
                                    recorded.getRepeat(), recorded.getFold(), recorded.getSample()));
                        }
                        runevaluation.addEvaluationMeasure(recorded);
                    }
                }
            }
            if (warningFound) {
                runevaluation.setWarning(warningMessage);
            }
        } else {
            Conversion.log("OK", "Process Run", "No local evaluation measures to compare to. ");
        }
    } catch (Exception e) {
        e.printStackTrace();
        Conversion.log("Warning", "Process Run",
                "Unexpected error, will proceed with upload process: " + e.getMessage());
        runevaluation.setError(e.getMessage());
    }

    Conversion.log("OK", "Process Run", "Start uploading results ... ");
    try {
        String runEvaluation = xstream.toXML(runevaluation);
        File evaluationFile = Conversion.stringToTempFile(runEvaluation, "run_" + run_id + "evaluations", "xml");
        RunEvaluate re = apiconnector.runEvaluate(evaluationFile);

        if (trace != null) {
            String runTrace = xstream.toXML(trace);
            File traceFile = Conversion.stringToTempFile(runTrace, "run_" + run_id + "trace", "xml");
            apiconnector.runTraceUpload(traceFile);
        }

        Conversion.log("OK", "Process Run", "Run processed: " + re.getRun_id());
    } catch (Exception e) {
        Conversion.log("ERROR", "Process Run", "An error occurred during API call: " + e.getMessage());
    }
}