/**
 * Builds ROC data for a binominal label and passes it to {@code createROCPlotDialog}.
 *
 * <p>The example set must carry a nominal label with exactly two mapped values. When the
 * {@code PARAMETER_USE_MODEL} parameter is set, any predicted label already present is removed
 * (with a warning) and a {@code Model} taken from the input is applied to the example set;
 * otherwise the example set must already contain a predicted label.
 *
 * <p>NOTE(review): the predicted label is removed from the returned example set in both modes,
 * including when it was supplied by the input rather than produced by the model — confirm this
 * is intended.
 *
 * @return the (possibly model-applied) example set, followed by the model when one was used
 * @throws OperatorException if the label is missing (105), not nominal (101) or does not have
 *     exactly two values (114), or if no predicted label is available (107)
 */
public IOObject[] apply() throws OperatorException {
  ExampleSet exampleSet = getInput(ExampleSet.class);

  // Label validation: it must exist, be nominal, and map exactly two values.
  if (exampleSet.getAttributes().getLabel() == null) {
    throw new UserError(this, 105);
  }
  if (!exampleSet.getAttributes().getLabel().isNominal()) {
    // NOTE(review): confirm this argument order matches the message template for error 101.
    throw new UserError(this, 101, "ROC Charts", exampleSet.getAttributes().getLabel());
  }
  if (exampleSet.getAttributes().getLabel().getMapping().getValues().size() != 2) {
    throw new UserError(this, 114, "ROC Charts", exampleSet.getAttributes().getLabel());
  }

  final boolean useModel = getParameterAsBoolean(PARAMETER_USE_MODEL);
  Model model = null;
  if (useModel) {
    // Predictions will come from the model, so discard any that are already attached.
    if (exampleSet.getAttributes().getPredictedLabel() != null) {
      logWarning("Input example already has a predicted label which will be removed.");
      PredictionModel.removePredictedLabel(exampleSet);
    }
    model = getInput(Model.class);
    exampleSet = model.apply(exampleSet);
  } else if (exampleSet.getAttributes().getPredictedLabel() == null) {
    throw new UserError(this, 107);
  }

  // Re-check after the optional model application: a prediction is required from here on.
  if (exampleSet.getAttributes().getPredictedLabel() == null) {
    throw new UserError(this, 107);
  }

  ROCDataGenerator generator = new ROCDataGenerator(1.0d, 1.0d);
  ROCData rocData =
      generator.createROCData(exampleSet, getParameterAsBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS));
  generator.createROCPlotDialog(rocData);

  PredictionModel.removePredictedLabel(exampleSet);
  return useModel ? new IOObject[] {exampleSet, model} : new IOObject[] {exampleSet};
}
/**
 * Computes a classification threshold from ROC data built over the input's predictions.
 *
 * <p>The input must have a nominal label with exactly two mapped values, a predicted label, and
 * confidence attributes for both the positive and the negative class. The two misclassification
 * costs are read from parameters. When explicit labels are defined and the first label's mapping
 * index is greater than the second's, the two costs are swapped before use. The example set is
 * delivered unchanged on {@code exampleSetOutput}. A {@code Threshold} is delivered on
 * {@code thresholdOutput}; it is built from {@code ROCDataGenerator.getBestThreshold()} together
 * with the negative and positive class names.
 *
 * @throws OperatorException if the label is missing (105), not nominal (101) or does not have
 *     exactly two values (118); if no predicted label exists (107); if an explicit label is not
 *     in the mapping (143); or if a class confidence attribute is missing (113)
 */
@Override
public void doWork() throws OperatorException {
  ExampleSet exampleSet = exampleSetInput.getData(ExampleSet.class);

  // --- preconditions on the label ---
  Attribute label = exampleSet.getAttributes().getLabel();
  if (label == null) {
    throw new UserError(this, 105);
  }
  if (!label.isNominal()) {
    throw new UserError(this, 101, label, "threshold finding");
  }
  // NOTE(review): statistics are refreshed here but not read directly in this method;
  // presumably something downstream relies on them — confirm.
  exampleSet.recalculateAttributeStatistics(label);

  NominalMapping mapping = label.getMapping();
  if (mapping.size() != 2) {
    Object[] errorArgs =
        new Object[] {label, Integer.valueOf(mapping.getValues().size()), Integer.valueOf(2)};
    throw new UserError(this, 118, errorArgs);
  }
  if (exampleSet.getAttributes().getPredictedLabel() == null) {
    throw new UserError(this, 107);
  }

  // --- misclassification costs, optionally re-ordered to match the mapping ---
  boolean useExplicitLabels = getParameterAsBoolean(PARAMETER_DEFINE_LABELS);
  double secondCost = getParameterAsDouble(PARAMETER_MISCLASSIFICATION_COSTS_SECOND);
  double firstCost = getParameterAsDouble(PARAMETER_MISCLASSIFICATION_COSTS_FIRST);
  if (useExplicitLabels) {
    String firstLabel = getParameterAsString(PARAMETER_FIRST_LABEL);
    String secondLabel = getParameterAsString(PARAMETER_SECOND_LABEL);
    int firstIndex = mapping.getIndex(firstLabel);
    if (firstIndex == -1) {
      throw new UserError(this, 143, firstLabel, label.getName());
    }
    int secondIndex = mapping.getIndex(secondLabel);
    if (secondIndex == -1) {
      throw new UserError(this, 143, secondLabel, label.getName());
    }
    // The explicit order disagrees with the mapping's order: swap the costs internally.
    if (firstIndex > secondIndex) {
      double swap = firstCost;
      firstCost = secondCost;
      secondCost = swap;
    }
  }

  // --- confidences for both classes must be present ---
  String positiveClass = mapping.getPositiveString();
  String negativeClass = mapping.getNegativeString();
  for (String className : new String[] {positiveClass, negativeClass}) {
    if (exampleSet.getAttributes().getConfidence(className) == null) {
      throw new UserError(this, 113, Attributes.CONFIDENCE_NAME + "_" + className);
    }
  }

  // --- ROC data, optional plot, outputs ---
  ROCDataGenerator generator = new ROCDataGenerator(firstCost, secondCost);
  ROCData rocData =
      generator.createROCData(
          exampleSet,
          getParameterAsBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS),
          ROCBias.getROCBiasParameter(this));
  if (getParameterAsBoolean(PARAMETER_SHOW_ROC_PLOT)) {
    generator.createROCPlotDialog(rocData, true, true);
  }

  exampleSetOutput.deliver(exampleSet);
  thresholdOutput.deliver(
      new Threshold(generator.getBestThreshold(), negativeClass, positiveClass));
}