public Model learn(ExampleSet exampleSet) throws OperatorException { double value = 0.0; double[] confidences = null; int method = getParameterAsInt(PARAMETER_METHOD); Attribute label = exampleSet.getAttributes().getLabel(); if ((label.isNominal()) && ((method == MEDIAN) || (method == AVERAGE))) { logWarning( "Cannot use method '" + METHODS[method] + "' for nominal labels: changing to 'mode'!"); method = MODE; } else if ((!label.isNominal()) && (method == MODE)) { logWarning( "Cannot use method '" + METHODS[method] + "' for numerical labels: changing to 'average'!"); method = AVERAGE; } switch (method) { case MEDIAN: double[] labels = new double[exampleSet.size()]; Iterator<Example> r = exampleSet.iterator(); int counter = 0; while (r.hasNext()) { Example example = r.next(); labels[counter++] = example.getValue(example.getAttributes().getLabel()); } java.util.Arrays.sort(labels); value = labels[exampleSet.size() / 2]; break; case AVERAGE: exampleSet.recalculateAttributeStatistics(label); value = exampleSet.getStatistics(label, Statistics.AVERAGE); break; case MODE: exampleSet.recalculateAttributeStatistics(label); value = exampleSet.getStatistics(label, Statistics.MODE); confidences = new double[label.getMapping().size()]; for (int i = 0; i < confidences.length; i++) { confidences[i] = exampleSet.getStatistics(label, Statistics.COUNT, label.getMapping().mapIndex(i)) / exampleSet.size(); } break; case CONSTANT: value = getParameterAsDouble(PARAMETER_CONSTANT); break; case ATTRIBUTE: return new AttributeDefaultModel( exampleSet, getParameterAsString(PARAMETER_ATTRIBUTE_NAME)); default: // cannot happen throw new OperatorException("DefaultLearner: Unknown default method '" + method + "'!"); } log( "Default value is '" + (label.isNominal() ? label.getMapping().mapIndex((int) value) : value + "") + "'."); return new DefaultModel(exampleSet, value, confidences); }
private RuleModel createNumericalRuleModel(ExampleSet trainingSet, Attribute attribute) { RuleModel model = new RuleModel(trainingSet); // split by best attribute int oldSize = -1; while ((trainingSet.size() > 0) && (trainingSet.size() != oldSize)) { ExampleSet exampleSet = (ExampleSet) trainingSet.clone(); Split bestSplit = splitter.getBestSplit(exampleSet, attribute, null); double bestSplitValue = bestSplit.getSplitPoint(); if (!Double.isNaN(bestSplitValue)) { SplittedExampleSet splittedSet = SplittedExampleSet.splitByAttribute(exampleSet, attribute, bestSplitValue); Attribute label = splittedSet.getAttributes().getLabel(); splittedSet.selectSingleSubset(0); SplitCondition condition = new LessEqualsSplitCondition(attribute, bestSplitValue); splittedSet.recalculateAttributeStatistics(label); int labelValue = (int) splittedSet.getStatistics(label, Statistics.MODE); String labelName = label.getMapping().mapIndex(labelValue); Rule rule = new Rule(labelName, condition); int[] frequencies = new int[label.getMapping().size()]; int counter = 0; for (String value : label.getMapping().getValues()) frequencies[counter++] = (int) splittedSet.getStatistics(label, Statistics.COUNT, value); rule.setFrequencies(frequencies); model.addRule(rule); oldSize = trainingSet.size(); trainingSet = rule.removeCovered(trainingSet); } else { break; } } // add default rule if some examples were not yet covered if (trainingSet.size() > 0) { Attribute label = trainingSet.getAttributes().getLabel(); trainingSet.recalculateAttributeStatistics(label); int index = (int) trainingSet.getStatistics(label, Statistics.MODE); String defaultLabel = label.getMapping().mapIndex(index); Rule defaultRule = new Rule(defaultLabel); int[] frequencies = new int[label.getMapping().size()]; int counter = 0; for (String value : label.getMapping().getValues()) frequencies[counter++] = (int) (trainingSet.getStatistics(label, Statistics.COUNT, value)); defaultRule.setFrequencies(frequencies); 
model.addRule(defaultRule); } return model; }
@Override public Model learn(ExampleSet exampleSet) throws OperatorException { Attribute label = exampleSet.getAttributes().getLabel(); RuleModel ruleModel = new RuleModel(exampleSet); double pureness = getParameterAsDouble(PARAMETER_PURENESS); TermDetermination termDetermination = new TermDetermination(new AccuracyCriterion(), 0.5d); ExampleSet trainingSet = (ExampleSet) exampleSet.clone(); for (String labelName : label.getMapping().getValues()) { trainingSet.recalculateAttributeStatistics(label); int oldSize = -1; while (trainingSet.size() > 0 && trainingSet.size() != oldSize && trainingSet.getStatistics(label, Statistics.COUNT, labelName) > 0) { Rule rule = new Rule(labelName); ExampleSet oldTrainingSet = (ExampleSet) trainingSet.clone(); // grow rule int growOldSize = -1; ExampleSet growSet = (ExampleSet) trainingSet.clone(); while (growSet.size() > 0 && growSet.size() != growOldSize && !rule.isPure(growSet, pureness) && growSet.getAttributes().size() > 0) { SplitCondition term = termDetermination.getBestTerm(growSet, labelName); if (term == null) { break; } rule.addTerm(term); Attribute splitAttribute = growSet.getAttributes().get(term.getAttributeName()); growSet.getAttributes().remove(splitAttribute); growOldSize = growSet.size(); growSet = rule.getCovered(growSet); } // add rule if not empty if (rule.getTerms().size() > 0) { growSet = rule.getCovered(trainingSet); growSet.recalculateAttributeStatistics(label); int[] frequencies = new int[label.getMapping().size()]; int counter = 0; for (String value : label.getMapping().getValues()) { frequencies[counter++] = (int) growSet.getStatistics(label, Statistics.COUNT, value); } rule.setFrequencies(frequencies); ruleModel.addRule(rule); oldSize = trainingSet.size(); trainingSet = rule.removeCovered(oldTrainingSet); } else { break; // no other terms found for this class --> next class } trainingSet.recalculateAttributeStatistics(label); } checkForStop(); } // training set not empty? 
add default rule if (trainingSet.size() > 0) { trainingSet.recalculateAttributeStatistics(label); int index = (int) trainingSet.getStatistics(label, Statistics.MODE); String defaultLabel = label.getMapping().mapIndex(index); Rule defaultRule = new Rule(defaultLabel); int[] frequencies = new int[label.getMapping().size()]; int counter = 0; for (String value : label.getMapping().getValues()) { frequencies[counter++] = (int) trainingSet.getStatistics(label, Statistics.COUNT, value); } defaultRule.setFrequencies(frequencies); ruleModel.addRule(defaultRule); } return ruleModel; }
/**
 * Computes a classification threshold from ROC data and delivers it along with the unchanged
 * input example set.
 *
 * <p>The input must have all of the following, otherwise a {@link UserError} is thrown:
 * <ul>
 *   <li>a nominal label with exactly two values;
 *   <li>a predicted label;
 *   <li>confidence attributes for both label values.
 * </ul>
 *
 * <p>The two misclassification costs are passed to {@code ROCDataGenerator}. The threshold is the
 * one its {@code getBestThreshold()} reports. When explicit first/second labels are configured and
 * their order is the reverse of the mapping's index order, the two costs are swapped first.
 *
 * @throws OperatorException if the input is unusable or a parameter cannot be read
 */
@Override
public void doWork() throws OperatorException {
  ExampleSet exampleSet = exampleSetInput.getData(ExampleSet.class);

  Attribute label = exampleSet.getAttributes().getLabel();
  if (label == null) {
    throw new UserError(this, 105);
  }
  if (!label.isNominal()) {
    throw new UserError(this, 101, label, "threshold finding");
  }
  exampleSet.recalculateAttributeStatistics(label);

  NominalMapping mapping = label.getMapping();
  if (mapping.size() != 2) {
    throw new UserError(
        this,
        118,
        new Object[] {label, Integer.valueOf(mapping.getValues().size()), Integer.valueOf(2)});
  }
  if (exampleSet.getAttributes().getPredictedLabel() == null) {
    throw new UserError(this, 107);
  }

  boolean explicitLabels = getParameterAsBoolean(PARAMETER_DEFINE_LABELS);
  double costSecond = getParameterAsDouble(PARAMETER_MISCLASSIFICATION_COSTS_SECOND);
  double costFirst = getParameterAsDouble(PARAMETER_MISCLASSIFICATION_COSTS_FIRST);
  if (explicitLabels) {
    String firstLabel = getParameterAsString(PARAMETER_FIRST_LABEL);
    String secondLabel = getParameterAsString(PARAMETER_SECOND_LABEL);

    int firstIndex = mapping.getIndex(firstLabel);
    if (firstIndex == -1) {
      throw new UserError(this, 143, firstLabel, label.getName());
    }
    int secondIndex = mapping.getIndex(secondLabel);
    if (secondIndex == -1) {
      throw new UserError(this, 143, secondLabel, label.getName());
    }

    // The cost parameters follow the user-defined label order. If that order is the reverse of
    // the mapping's index order, swap the costs so they match the internal order.
    if (firstIndex > secondIndex) {
      double swap = costFirst;
      costFirst = costSecond;
      costSecond = swap;
    }
  }

  // Both confidence attributes must exist (positive value checked first).
  for (String className : new String[] {mapping.getPositiveString(), mapping.getNegativeString()}) {
    if (exampleSet.getAttributes().getConfidence(className) == null) {
      throw new UserError(this, 113, Attributes.CONFIDENCE_NAME + "_" + className);
    }
  }

  ROCDataGenerator generator = new ROCDataGenerator(costFirst, costSecond);
  ROCData rocData =
      generator.createROCData(
          exampleSet,
          getParameterAsBoolean(PARAMETER_USE_EXAMPLE_WEIGHTS),
          ROCBias.getROCBiasParameter(this));

  if (getParameterAsBoolean(PARAMETER_SHOW_ROC_PLOT)) {
    generator.createROCPlotDialog(rocData, true, true);
  }

  exampleSetOutput.deliver(exampleSet);
  thresholdOutput.deliver(
      new Threshold(
          generator.getBestThreshold(), mapping.getNegativeString(), mapping.getPositiveString()));
}