예제 #1
0
  /**
   * Builds ROC data for the input example set and hands it to
   * {@code ROCDataGenerator.createROCPlotDialog}.
   *
   * <p>The label must exist, be nominal and map exactly two values. When {@code
   * PARAMETER_USE_MODEL} is set, any predicted label already on the example set is removed (with a
   * warning) and the input model is applied to produce a fresh one; otherwise the example set must
   * already carry a predicted label. The predicted label is removed again before returning.
   *
   * @return the example set alone, or the example set followed by the model when {@code
   *     PARAMETER_USE_MODEL} is set
   * @throws OperatorException a {@code UserError} with code 105 when there is no label, 101 when
   *     the label is not nominal, 114 when the label mapping does not hold exactly two values, or
   *     107 when no predicted label is available
   */
  public IOObject[] apply() throws OperatorException {
    ExampleSet data = getInput(ExampleSet.class);

    // Preconditions on the label attribute.
    if (data.getAttributes().getLabel() == null) {
      throw new UserError(this, 105);
    }
    if (!data.getAttributes().getLabel().isNominal()) {
      throw new UserError(this, 101, "ROC Charts", data.getAttributes().getLabel());
    }
    if (data.getAttributes().getLabel().getMapping().getValues().size() != 2) {
      throw new UserError(this, 114, "ROC Charts", data.getAttributes().getLabel());
    }

    final boolean hasPrediction = data.getAttributes().getPredictedLabel() != null;
    final boolean useModel = getParameterAsBoolean(PARAMETER_USE_MODEL);

    Model model = null;
    if (useModel) {
      if (hasPrediction) {
        // The model produces its own predictions, so stale ones are dropped first.
        logWarning("Input example already has a predicted label which will be removed.");
        PredictionModel.removePredictedLabel(data);
      }
      model = getInput(Model.class);
      data = model.apply(data);
    } else if (!hasPrediction) {
      // Without a model the predictions have to come with the input.
      throw new UserError(this, 107);
    }

    // Re-check on the (possibly replaced) example set: model application must have added a
    // predicted label.
    if (data.getAttributes().getPredictedLabel() == null) {
      throw new UserError(this, 107);
    }

    ROCDataGenerator generator = new ROCDataGenerator(1.0d, 1.0d);
    boolean useWeights = getParameterAsBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS);
    ROCData curve = generator.createROCData(data, useWeights);
    generator.createROCPlotDialog(curve);

    PredictionModel.removePredictedLabel(data);
    if (useModel) {
      return new IOObject[] {data, model};
    }
    return new IOObject[] {data};
  }
  /**
   * Creates ROC data for the incoming example set and delivers the example set together with a
   * {@code Threshold} built from the generator's best threshold.
   *
   * <p>The label must exist, be nominal and map exactly two values, and the example set must carry
   * a predicted label plus confidence attributes for both the positive and the negative class.
   * When {@code PARAMETER_DEFINE_LABELS} is set, both configured labels must be present in the
   * mapping; if the configured first label has a higher mapping index than the second, the two
   * misclassification costs are swapped before being passed to the generator.
   *
   * @throws OperatorException a {@code UserError} with code 105 (no label), 101 (label not
   *     nominal), 118 (mapping size is not 2), 107 (no predicted label), 143 (a configured label
   *     is not in the mapping) or 113 (a confidence attribute is missing)
   */
  @Override
  public void doWork() throws OperatorException {
    ExampleSet data = exampleSetInput.getData(ExampleSet.class);

    // Preconditions on the label attribute.
    Attribute label = data.getAttributes().getLabel();
    if (label == null) {
      throw new UserError(this, 105);
    }
    if (!label.isNominal()) {
      throw new UserError(this, 101, label, "threshold finding");
    }
    data.recalculateAttributeStatistics(label);
    NominalMapping mapping = label.getMapping();
    if (mapping.size() != 2) {
      Object[] details =
          new Object[] {label, Integer.valueOf(mapping.getValues().size()), Integer.valueOf(2)};
      throw new UserError(this, 118, details);
    }
    if (data.getAttributes().getPredictedLabel() == null) {
      throw new UserError(this, 107);
    }

    boolean explicitLabels = getParameterAsBoolean(PARAMETER_DEFINE_LABELS);
    double costSecond = getParameterAsDouble(PARAMETER_MISCLASSIFICATION_COSTS_SECOND);
    double costFirst = getParameterAsDouble(PARAMETER_MISCLASSIFICATION_COSTS_FIRST);
    if (explicitLabels) {
      String firstLabel = getParameterAsString(PARAMETER_FIRST_LABEL);
      String secondLabel = getParameterAsString(PARAMETER_SECOND_LABEL);

      int firstIndex = mapping.getIndex(firstLabel);
      if (firstIndex == -1) {
        throw new UserError(this, 143, firstLabel, label.getName());
      }
      int secondIndex = mapping.getIndex(secondLabel);
      if (secondIndex == -1) {
        throw new UserError(this, 143, secondLabel, label.getName());
      }

      // The configured order is reversed relative to the mapping: exchange the costs internally.
      if (firstIndex > secondIndex) {
        double swapped = costFirst;
        costFirst = costSecond;
        costSecond = swapped;
      }
    }

    // Both classes need a confidence attribute.
    String positiveClass = mapping.getPositiveString();
    if (data.getAttributes().getConfidence(positiveClass) == null) {
      throw new UserError(this, 113, Attributes.CONFIDENCE_NAME + "_" + positiveClass);
    }
    String negativeClass = mapping.getNegativeString();
    if (data.getAttributes().getConfidence(negativeClass) == null) {
      throw new UserError(this, 113, Attributes.CONFIDENCE_NAME + "_" + negativeClass);
    }

    // Build the ROC data with the (possibly swapped) costs.
    ROCDataGenerator generator = new ROCDataGenerator(costFirst, costSecond);
    boolean useWeights = getParameterAsBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS);
    ROCData curve = generator.createROCData(data, useWeights, ROCBias.getROCBiasParameter(this));

    // Optional plot.
    boolean showPlot = getParameterAsBoolean(PARAMETER_SHOW_ROC_PLOT);
    if (showPlot) {
      generator.createROCPlotDialog(curve, true, true);
    }

    // Deliver the results.
    exampleSetOutput.deliver(data);
    Threshold best = new Threshold(generator.getBestThreshold(), negativeClass, positiveClass);
    thresholdOutput.deliver(best);
  }