protected static Evaluation adjustForInputMappedClassifier( Evaluation eval, weka.classifiers.Classifier classifier, Instances inst, ClassifierErrorsPlotInstances plotInstances) throws Exception { if (classifier instanceof weka.classifiers.misc.InputMappedClassifier) { Instances mappedClassifierHeader = ((weka.classifiers.misc.InputMappedClassifier) classifier) .getModelHeader(new Instances(inst, 0)); eval = new Evaluation(new Instances(mappedClassifierHeader, 0)); if (!eval.getHeader().equalHeaders(inst)) { // When the InputMappedClassifier is loading a model, // we need to make a new dataset that maps the test instances to // the structure expected by the mapped classifier - this is only // to ensure that the ClassifierPlotInstances object is configured // in accordance with what the embeded classifier was trained with Instances mappedClassifierDataset = ((weka.classifiers.misc.InputMappedClassifier) classifier) .getModelHeader(new Instances(mappedClassifierHeader, 0)); for (int zz = 0; zz < inst.numInstances(); zz++) { Instance mapped = ((weka.classifiers.misc.InputMappedClassifier) classifier) .constructMappedInstance(inst.instance(zz)); mappedClassifierDataset.add(mapped); } eval.setPriors(mappedClassifierDataset); plotInstances.setInstances(mappedClassifierDataset); plotInstances.setClassifier(classifier); plotInstances.setClassIndex(mappedClassifierDataset.classIndex()); plotInstances.setEvaluation(eval); } } return eval; }
/** * Accept a classifier to be evaluated. * * @param ce a <code>BatchClassifierEvent</code> value */ public void acceptClassifier(BatchClassifierEvent ce) { if (ce.getTestSet() == null || ce.getTestSet().isStructureOnly()) { return; // can't evaluate empty/non-existent test instances } Classifier classifier = ce.getClassifier(); try { if (ce.getGroupIdentifier() != m_currentBatchIdentifier) { if (m_setsComplete > 0) { if (m_logger != null) { m_logger.statusMessage( statusMessagePrefix() + "BUSY. Can't accept data " + "at this time."); m_logger.logMessage( "[ClassifierPerformanceEvaluator] " + statusMessagePrefix() + " BUSY. Can't accept data at this time."); } return; } if (ce.getTrainSet().getDataSet() == null || ce.getTrainSet().getDataSet().numInstances() == 0) { // we have no training set to estimate majority class // or mean of target from Evaluation eval = new Evaluation(ce.getTestSet().getDataSet()); m_PlotInstances = ExplorerDefaults.getClassifierErrorsPlotInstances(); m_PlotInstances.setInstances(ce.getTestSet().getDataSet()); m_PlotInstances.setClassifier(ce.getClassifier()); m_PlotInstances.setClassIndex(ce.getTestSet().getDataSet().classIndex()); m_PlotInstances.setEvaluation(eval); eval = adjustForInputMappedClassifier( eval, ce.getClassifier(), ce.getTestSet().getDataSet(), m_PlotInstances); eval.useNoPriors(); m_eval = new AggregateableEvaluation(eval); } else { // we can set up with the training set here Evaluation eval = new Evaluation(ce.getTrainSet().getDataSet()); m_PlotInstances = ExplorerDefaults.getClassifierErrorsPlotInstances(); m_PlotInstances.setInstances(ce.getTrainSet().getDataSet()); m_PlotInstances.setClassifier(ce.getClassifier()); m_PlotInstances.setClassIndex(ce.getTestSet().getDataSet().classIndex()); m_PlotInstances.setEvaluation(eval); eval = adjustForInputMappedClassifier( eval, ce.getClassifier(), ce.getTrainSet().getDataSet(), m_PlotInstances); m_eval = new AggregateableEvaluation(eval); } m_PlotInstances.setUp(); 
m_currentBatchIdentifier = ce.getGroupIdentifier(); m_setsComplete = 0; m_aggregatedPlotInstances = null; String msg = "[ClassifierPerformanceEvaluator] " + statusMessagePrefix() + " starting executor pool (" + getExecutionSlots() + " slots)..."; // start the execution pool if (m_executorPool == null) { startExecutorPool(); } m_tasks = new ArrayList<EvaluationTask>(); if (m_logger != null) { m_logger.logMessage(msg); } else { System.out.println(msg); } } // if m_tasks == null then we've been stopped if (m_setsComplete < ce.getMaxSetNumber() && m_tasks != null) { EvaluationTask newTask = new EvaluationTask( classifier, ce.getTrainSet().getDataSet(), ce.getTestSet().getDataSet(), ce.getSetNumber(), ce.getMaxSetNumber()); String msg = "[ClassifierPerformanceEvaluator] " + statusMessagePrefix() + " scheduling " + " evaluation of fold " + ce.getSetNumber() + " for execution..."; if (m_logger != null) { m_logger.logMessage(msg); } else { System.out.println(msg); } m_tasks.add(newTask); m_executorPool.execute(newTask); } } catch (Exception ex) { // stop everything stop(); } }
/**
 * Takes an evaluation object from a task and aggregates it with the overall one.
 * Also accumulates the task's plottable prediction errors; when the final set
 * of the batch arrives, builds the textual result, fires text/visualizable
 * error/threshold-curve events to registered listeners, and resets the batch
 * state in a finally block so the next batch can start cleanly.
 *
 * <p>Synchronized because evaluation tasks complete concurrently on the
 * executor pool and all mutate the shared aggregation fields.
 *
 * @param eval the evaluation object to aggregate
 * @param classifier the classifier used by the task
 * @param testData the testData from the task
 * @param plotInstances the ClassifierErrorsPlotInstances object from the task
 * @param setNum the set number processed by the task
 * @param maxSetNum the maximum number of sets in this batch
 */
protected synchronized void aggregateEvalTask(Evaluation eval,
    Classifier classifier, Instances testData,
    ClassifierErrorsPlotInstances plotInstances, int setNum, int maxSetNum) {

  m_eval.aggregate(eval);

  if (m_aggregatedPlotInstances == null) {
    // first completed fold of the batch - adopt its plot data as the seed
    m_aggregatedPlotInstances = new Instances(plotInstances.getPlotInstances());
    m_aggregatedPlotShapes = plotInstances.getPlotShapes();
    m_aggregatedPlotSizes = plotInstances.getPlotSizes();
  } else {
    // append this fold's plot points, shapes and sizes to the aggregate
    Instances temp = plotInstances.getPlotInstances();
    for (int i = 0; i < temp.numInstances(); i++) {
      m_aggregatedPlotInstances.add(temp.get(i));
      m_aggregatedPlotShapes.addElement(plotInstances.getPlotShapes().get(i));
      m_aggregatedPlotSizes.addElement(plotInstances.getPlotSizes().get(i));
    }
  }
  m_setsComplete++;

  // if (ce.getSetNumber() == ce.getMaxSetNumber()) {
  if (m_setsComplete == maxSetNum) {
    // all folds are in - construct and broadcast the final results
    try {
      String textTitle = classifier.getClass().getName();
      String textOptions = "";
      if (classifier instanceof OptionHandler) {
        textOptions = Utils.joinOptions(((OptionHandler) classifier).getOptions());
      }
      // strip the package prefix for display purposes
      textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1,
          textTitle.length());
      String resultT = "=== Evaluation result ===\n\n" + "Scheme: " + textTitle
          + "\n"
          + ((textOptions.length() > 0) ? "Options: " + textOptions + "\n" : "")
          + "Relation: " + testData.relationName() + "\n\n"
          + m_eval.toSummaryString();

      if (testData.classAttribute().isNominal()) {
        resultT += "\n" + m_eval.toClassDetailsString() + "\n"
            + m_eval.toMatrixString();
      }

      TextEvent te = new TextEvent(ClassifierPerformanceEvaluator.this, resultT,
          textTitle);
      notifyTextListeners(te);

      // set up visualizable errors
      if (m_visualizableErrorListeners.size() > 0) {
        PlotData2D errorD = new PlotData2D(m_aggregatedPlotInstances);
        errorD.setShapeSize(m_aggregatedPlotSizes);
        errorD.setShapeType(m_aggregatedPlotShapes);
        errorD.setPlotName(textTitle + " " + textOptions);
        /*
         * PlotData2D errorD = m_PlotInstances.getPlotData( textTitle + " " +
         * textOptions);
         */
        VisualizableErrorEvent vel = new VisualizableErrorEvent(
            ClassifierPerformanceEvaluator.this, errorD);
        notifyVisualizableErrorListeners(vel);
        m_PlotInstances.cleanUp();
      }

      if (testData.classAttribute().isNominal()
          && m_thresholdListeners.size() > 0) {
        // build the threshold curve (ROC-style plot) for the first class value
        ThresholdCurve tc = new ThresholdCurve();
        Instances result = tc.getCurve(m_eval.predictions(), 0);
        result.setRelationName(testData.relationName());
        PlotData2D pd = new PlotData2D(result);
        String htmlTitle = "<html><font size=-2>" + textTitle;
        String newOptions = "";
        if (classifier instanceof OptionHandler) {
          String[] options = ((OptionHandler) classifier).getOptions();
          if (options.length > 0) {
            for (int ii = 0; ii < options.length; ii++) {
              if (options[ii].length() == 0) {
                continue;
              }
              // insert an HTML line break before each option flag, but not
              // before negative numeric values such as "-1"
              if (options[ii].charAt(0) == '-'
                  && !(options[ii].charAt(1) >= '0'
                      && options[ii].charAt(1) <= '9')) {
                newOptions += "<br>";
              }
              newOptions += options[ii];
            }
          }
        }

        htmlTitle += " " + newOptions + "<br>" + " (class: "
            + testData.classAttribute().value(0) + ")" + "</font></html>";
        pd.setPlotName(textTitle + " (class: "
            + testData.classAttribute().value(0) + ")");
        pd.setPlotNameHTML(htmlTitle);
        // connect every point except the first so the curve draws as a line
        boolean[] connectPoints = new boolean[result.numInstances()];
        for (int jj = 1; jj < connectPoints.length; jj++) {
          connectPoints[jj] = true;
        }
        pd.setConnectPoints(connectPoints);

        ThresholdDataEvent rde = new ThresholdDataEvent(
            ClassifierPerformanceEvaluator.this, pd, testData.classAttribute());
        notifyThresholdListeners(rde);
      }
      if (m_logger != null) {
        m_logger.statusMessage(statusMessagePrefix() + "Finished.");
      }
    } catch (Exception ex) {
      if (m_logger != null) {
        m_logger.logMessage("[ClassifierPerformanceEvaluator] "
            + statusMessagePrefix()
            + " problem constructing evaluation results. " + ex.getMessage());
      }
      ex.printStackTrace();
    } finally {
      // always reset the batch state, even on failure, so a new batch
      // can be accepted
      m_visual.setStatic();
      // save memory
      m_PlotInstances = null;
      m_setsComplete = 0;
      m_tasks = null;
      m_aggregatedPlotInstances = null;
    }
  }
}
public void execute() { if (m_stopped) { return; } if (m_logger != null) { m_logger.statusMessage(statusMessagePrefix() + "Evaluating (" + m_setNum + ")..."); m_visual.setAnimated(); } try { ClassifierErrorsPlotInstances plotInstances = ExplorerDefaults.getClassifierErrorsPlotInstances(); Evaluation eval = null; if (m_trainData == null || m_trainData.numInstances() == 0) { eval = new Evaluation(m_testData); plotInstances.setInstances(m_testData); plotInstances.setClassifier(m_classifier); plotInstances.setClassIndex(m_testData.classIndex()); plotInstances.setEvaluation(eval); eval = adjustForInputMappedClassifier(eval, m_classifier, m_testData, plotInstances); eval.useNoPriors(); } else { eval = new Evaluation(m_trainData); plotInstances.setInstances(m_trainData); plotInstances.setClassifier(m_classifier); plotInstances.setClassIndex(m_trainData.classIndex()); plotInstances.setEvaluation(eval); eval = adjustForInputMappedClassifier(eval, m_classifier, m_trainData, plotInstances); } plotInstances.setUp(); for (int i = 0; i < m_testData.numInstances(); i++) { if (m_stopped) { break; } Instance temp = m_testData.instance(i); plotInstances.process(temp, m_classifier, eval); } if (m_stopped) { return; } aggregateEvalTask(eval, m_classifier, m_testData, plotInstances, m_setNum, m_maxSetNum); } catch (Exception ex) { ClassifierPerformanceEvaluator.this.stop(); // stop all processing if (m_logger != null) { m_logger.logMessage( "[ClassifierPerformanceEvaluator] " + statusMessagePrefix() + " problem evaluating classifier. " + ex.getMessage()); } ex.printStackTrace(); } }