/**
 * Prunes the mushroom model using the full training sample set as the pruning validation data,
 * then checks that accuracy measured on that same data is unchanged by pruning.
 *
 * <p>NOTE(review): the method name promises "100 percent", but only equality of the before/after
 * accuracies is asserted; the absolute value is never checked. Confirm the intended contract (and
 * whether {@code accuracy.evaluate} reports a fraction or a percentage) before tightening it.
 *
 * @throws FileNotFoundException propagated from the test fixtures
 * @throws UnsupportedEncodingException propagated from the test fixtures
 */
@Test
public void prune_givenValidationDataSameAsTrainingData_expect100Percent()
        throws FileNotFoundException, UnsupportedEncodingException {
    Model model = data.getMushroomModel();
    List<Sample> trainingSet = data.getAllMushroomSamples();

    // Baseline: accuracy of the unpruned model on the training samples.
    List<Prediction> originalPrediction = predictor.predict(model, trainingSet);
    Double originalAccuracy = accuracy.evaluate(originalPrediction, model.getTargetAttribute());

    // The training samples double as the pruning validation set here.
    List<Rule> prunedRules = pruner.pruneRepeatedly(model, trainingSet);
    List<Prediction> prunedPrediction = predictor.predict(prunedRules, trainingSet);
    Double prunedAccuracy = accuracy.evaluate(prunedPrediction, model.getTargetAttribute());

    log.debug(
            "Accuracy before pruning = {}, Accuracy after pruning = {}",
            originalAccuracy, prunedAccuracy);
    // Hamcrest's assertThat takes (actual, matcher): the pruned accuracy is the value under test
    // and the original accuracy is the expectation, so a failure message reads the right way round.
    assertThat(prunedAccuracy, is(originalAccuracy));
}
/**
 * Loads the mushroom data with a 25% validation split, prunes the model against that validation
 * set, and asserts that the pruned rules score at least as well on the validation set as the
 * unpruned model did.
 *
 * @throws FileNotFoundException propagated from the test fixtures
 * @throws UnsupportedEncodingException propagated from the test fixtures
 */
@Test
public void prune_given25PercentValidationSet_expectNoDecreaseInAccuracy()
        throws FileNotFoundException, UnsupportedEncodingException {
    // Reload the fixtures with 25.0 passed as the validation-set size.
    data.loadData(25.0);
    Model mushroomModel = data.getMushroomModel();
    List<Sample> holdOut = data.getValidationSet();

    // Score the unpruned model on the hold-out samples.
    Double accuracyBefore = accuracy.evaluate(
            predictor.predict(mushroomModel, holdOut), mushroomModel.getTargetAttribute());

    // Prune against the hold-out samples, then score the surviving rules on the same samples.
    List<Rule> survivingRules = pruner.pruneRepeatedly(mushroomModel, holdOut);
    Double accuracyAfter = accuracy.evaluate(
            predictor.predict(survivingRules, holdOut), mushroomModel.getTargetAttribute());

    log.debug(
            "Accuracy before pruning = {}%, Accuracy after pruning = {}%",
            accuracyBefore, accuracyAfter);
    // Pruning must never make validation accuracy worse.
    assertThat(accuracyAfter, Matchers.greaterThanOrEqualTo(accuracyBefore));
}