コード例 #1
0
  /**
   * Replaces the given evaluation with one built on the model header when the classifier is a
   * {@code weka.classifiers.misc.InputMappedClassifier}; for any other classifier {@code eval} is
   * returned unchanged.
   *
   * <p>If the model header matches the structure of {@code inst}, the priors of the new
   * evaluation are estimated from {@code inst} and {@code plotInstances} is pointed at the new
   * evaluation. Otherwise the instances are first mapped to the structure the model expects, the
   * priors are estimated from the mapped data, and {@code plotInstances} is reconfigured to use
   * the mapped data and the new evaluation.
   *
   * @param eval the evaluation the caller set up for {@code inst}
   * @param classifier the classifier being evaluated
   * @param inst the data the caller set up the evaluation with (training data, or test data when
   *     no training data is available)
   * @param plotInstances the plot-instances object the caller configured for {@code inst}
   * @return the evaluation object the caller should use from now on
   * @throws Exception if the model header cannot be obtained or an instance cannot be mapped
   */
  protected static Evaluation adjustForInputMappedClassifier(
      Evaluation eval,
      weka.classifiers.Classifier classifier,
      Instances inst,
      ClassifierErrorsPlotInstances plotInstances)
      throws Exception {

    if (!(classifier instanceof weka.classifiers.misc.InputMappedClassifier)) {
      return eval;
    }
    weka.classifiers.misc.InputMappedClassifier mapper =
        (weka.classifiers.misc.InputMappedClassifier) classifier;

    Instances mappedClassifierHeader = mapper.getModelHeader(new Instances(inst, 0));
    Evaluation mappedEval = new Evaluation(new Instances(mappedClassifierHeader, 0));

    if (mappedEval.getHeader().equalHeaders(inst)) {
      // The new evaluation was built from an empty header, so it has no data to estimate the
      // class priors from; estimate them from inst, as the caller's evaluation did.
      mappedEval.setPriors(inst);
      if (plotInstances != null) {
        // keep the plot instances in sync with the evaluation that is actually returned
        plotInstances.setEvaluation(mappedEval);
      }
      return mappedEval;
    }

    // When the InputMappedClassifier is loading a model, we need to make a new dataset that
    // maps the instances to the structure expected by the mapped classifier - this is only to
    // ensure that the ClassifierErrorsPlotInstances object is configured in accordance with
    // what the embedded classifier was trained with. The instances go into a fresh, empty copy
    // of the header: getModelHeader() may return the classifier's own stored header, which
    // must not have instances appended to it.
    Instances mappedClassifierDataset =
        new Instances(mappedClassifierHeader, inst.numInstances());
    for (int zz = 0; zz < inst.numInstances(); zz++) {
      mappedClassifierDataset.add(mapper.constructMappedInstance(inst.instance(zz)));
    }

    mappedEval.setPriors(mappedClassifierDataset);
    plotInstances.setInstances(mappedClassifierDataset);
    plotInstances.setClassifier(classifier);
    plotInstances.setClassIndex(mappedClassifierDataset.classIndex());
    plotInstances.setEvaluation(mappedEval);
    return mappedEval;
  }
コード例 #2
0
  /**
   * Accept a classifier to be evaluated.
   *
   * <p>Events whose test set is missing or structure-only are ignored. The first event carrying
   * a new group identifier sets up the aggregate evaluation and the plot instances, and starts
   * the executor pool if needed. A new group is refused while sets of the current group have
   * completed but the group is not yet finished. Each accepted event then schedules an {@code
   * EvaluationTask} for its train/test split. Any exception stops all processing and is logged.
   *
   * @param ce a <code>BatchClassifierEvent</code> value
   */
  public void acceptClassifier(BatchClassifierEvent ce) {
    if (ce.getTestSet() == null || ce.getTestSet().isStructureOnly()) {
      return; // can't evaluate empty/non-existent test instances
    }

    Classifier classifier = ce.getClassifier();

    try {
      Instances testData = ce.getTestSet().getDataSet();

      if (ce.getGroupIdentifier() != m_currentBatchIdentifier) {
        if (m_setsComplete > 0) {
          // part of the current batch has already been evaluated; don't mix in a new one
          if (m_logger != null) {
            m_logger.statusMessage(
                statusMessagePrefix() + "BUSY. Can't accept data " + "at this time.");
            m_logger.logMessage(
                "[ClassifierPerformanceEvaluator] "
                    + statusMessagePrefix()
                    + " BUSY. Can't accept data at this time.");
          }
          return;
        }

        Instances trainData = ce.getTrainSet().getDataSet();
        // With no training set there is nothing to estimate the majority class or the mean of
        // the target from, so set up from the test set and disable prior-based statistics.
        boolean noTrainingData = trainData == null || trainData.numInstances() == 0;
        Instances setupData = noTrainingData ? testData : trainData;

        Evaluation eval = new Evaluation(setupData);
        m_PlotInstances = ExplorerDefaults.getClassifierErrorsPlotInstances();
        m_PlotInstances.setInstances(setupData);
        m_PlotInstances.setClassifier(classifier);
        // take the class index from the same data the plot instances are built on
        m_PlotInstances.setClassIndex(setupData.classIndex());
        m_PlotInstances.setEvaluation(eval);

        eval = adjustForInputMappedClassifier(eval, classifier, setupData, m_PlotInstances);
        if (noTrainingData) {
          eval.useNoPriors();
        }
        m_eval = new AggregateableEvaluation(eval);

        m_PlotInstances.setUp();

        m_currentBatchIdentifier = ce.getGroupIdentifier();
        m_setsComplete = 0;

        m_aggregatedPlotInstances = null;

        String msg =
            "[ClassifierPerformanceEvaluator] "
                + statusMessagePrefix()
                + " starting executor pool ("
                + getExecutionSlots()
                + " slots)...";
        // start the execution pool
        if (m_executorPool == null) {
          startExecutorPool();
        }
        m_tasks = new ArrayList<EvaluationTask>();

        if (m_logger != null) {
          m_logger.logMessage(msg);
        } else {
          System.out.println(msg);
        }
      }

      // if m_tasks == null then we've been stopped
      if (m_setsComplete < ce.getMaxSetNumber() && m_tasks != null) {
        EvaluationTask newTask =
            new EvaluationTask(
                classifier,
                ce.getTrainSet().getDataSet(),
                testData,
                ce.getSetNumber(),
                ce.getMaxSetNumber());
        String msg =
            "[ClassifierPerformanceEvaluator] "
                + statusMessagePrefix()
                + " scheduling "
                + " evaluation of fold "
                + ce.getSetNumber()
                + " for execution...";
        if (m_logger != null) {
          m_logger.logMessage(msg);
        } else {
          System.out.println(msg);
        }
        m_tasks.add(newTask);
        m_executorPool.execute(newTask);
      }
    } catch (Exception ex) {
      // stop everything, but report why instead of failing silently
      stop();
      if (m_logger != null) {
        m_logger.logMessage(
            "[ClassifierPerformanceEvaluator] "
                + statusMessagePrefix()
                + " problem setting up evaluation. "
                + ex.getMessage());
      }
      ex.printStackTrace();
    }
  }
コード例 #3
0
  /**
   * Takes an evaluation object from a task and aggregates it with the overall one.
   *
   * <p>The task's plot instances, shapes and sizes are appended to the aggregated plot data.
   * Once {@code maxSetNum} sets have been aggregated, the following are sent to the registered
   * listeners:
   *
   * <ul>
   *   <li>the textual summary,
   *   <li>the visualizable errors,
   *   <li>for a nominal class, the threshold curve for the first class value.
   * </ul>
   *
   * <p>The per-batch state is then reset.
   *
   * @param eval the evaluation object to aggregate
   * @param classifier the classifier used by the task
   * @param testData the testData from the task
   * @param plotInstances the ClassifierErrorsPlotInstances object from the task
   * @param setNum the set number processed by the task
   * @param maxSetNum the maximum number of sets in this batch
   */
  protected synchronized void aggregateEvalTask(
      Evaluation eval,
      Classifier classifier,
      Instances testData,
      ClassifierErrorsPlotInstances plotInstances,
      int setNum,
      int maxSetNum) {

    m_eval.aggregate(eval);

    if (m_aggregatedPlotInstances == null) {
      // first task of the batch: copy its plot instances; shapes and sizes are reused as-is
      m_aggregatedPlotInstances = new Instances(plotInstances.getPlotInstances());
      m_aggregatedPlotShapes = plotInstances.getPlotShapes();
      m_aggregatedPlotSizes = plotInstances.getPlotSizes();
    } else {
      Instances temp = plotInstances.getPlotInstances();
      for (int i = 0; i < temp.numInstances(); i++) {
        m_aggregatedPlotInstances.add(temp.get(i));
        m_aggregatedPlotShapes.addElement(plotInstances.getPlotShapes().get(i));
        m_aggregatedPlotSizes.addElement(plotInstances.getPlotSizes().get(i));
      }
    }
    m_setsComplete++;

    if (m_setsComplete == maxSetNum) {
      try {
        String textTitle = classifier.getClass().getName();
        String textOptions = "";
        if (classifier instanceof OptionHandler) {
          textOptions = Utils.joinOptions(((OptionHandler) classifier).getOptions());
        }
        // strip the package name
        textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1, textTitle.length());
        String resultT =
            "=== Evaluation result ===\n\n"
                + "Scheme: "
                + textTitle
                + "\n"
                + ((textOptions.length() > 0) ? "Options: " + textOptions + "\n" : "")
                + "Relation: "
                + testData.relationName()
                + "\n\n"
                + m_eval.toSummaryString();

        if (testData.classAttribute().isNominal()) {
          resultT += "\n" + m_eval.toClassDetailsString() + "\n" + m_eval.toMatrixString();
        }

        TextEvent te = new TextEvent(ClassifierPerformanceEvaluator.this, resultT, textTitle);
        notifyTextListeners(te);

        // set up visualizable errors
        if (m_visualizableErrorListeners.size() > 0) {
          PlotData2D errorD = new PlotData2D(m_aggregatedPlotInstances);
          errorD.setShapeSize(m_aggregatedPlotSizes);
          errorD.setShapeType(m_aggregatedPlotShapes);
          errorD.setPlotName(textTitle + " " + textOptions);

          VisualizableErrorEvent vel =
              new VisualizableErrorEvent(ClassifierPerformanceEvaluator.this, errorD);
          notifyVisualizableErrorListeners(vel);
          if (m_PlotInstances != null) {
            m_PlotInstances.cleanUp();
          }
        }

        if (testData.classAttribute().isNominal() && m_thresholdListeners.size() > 0) {
          ThresholdCurve tc = new ThresholdCurve();
          Instances result = tc.getCurve(m_eval.predictions(), 0);
          result.setRelationName(testData.relationName());
          PlotData2D pd = new PlotData2D(result);
          String htmlTitle = "<html><font size=-2>" + textTitle;
          StringBuilder newOptions = new StringBuilder();
          if (classifier instanceof OptionHandler) {
            for (String option : ((OptionHandler) classifier).getOptions()) {
              if (option.length() == 0) {
                continue;
              }
              // Put each option flag on its own line, but not negative numbers such as "-1".
              // A lone "-" has no second character, so it must not be indexed at 1.
              if (option.charAt(0) == '-'
                  && (option.length() == 1
                      || !(option.charAt(1) >= '0' && option.charAt(1) <= '9'))) {
                newOptions.append("<br>");
              }
              newOptions.append(option);
            }
          }

          htmlTitle +=
              " "
                  + newOptions
                  + "<br>"
                  + " (class: "
                  + testData.classAttribute().value(0)
                  + ")"
                  + "</font></html>";
          pd.setPlotName(textTitle + " (class: " + testData.classAttribute().value(0) + ")");
          pd.setPlotNameHTML(htmlTitle);
          // connect every curve point to its predecessor
          boolean[] connectPoints = new boolean[result.numInstances()];
          for (int jj = 1; jj < connectPoints.length; jj++) {
            connectPoints[jj] = true;
          }

          pd.setConnectPoints(connectPoints);

          ThresholdDataEvent rde =
              new ThresholdDataEvent(
                  ClassifierPerformanceEvaluator.this, pd, testData.classAttribute());
          notifyThresholdListeners(rde);
        }
        if (m_logger != null) {
          m_logger.statusMessage(statusMessagePrefix() + "Finished.");
        }

      } catch (Exception ex) {
        if (m_logger != null) {
          m_logger.logMessage(
              "[ClassifierPerformanceEvaluator] "
                  + statusMessagePrefix()
                  + " problem constructing evaluation results. "
                  + ex.getMessage());
        }
        ex.printStackTrace();
      } finally {
        m_visual.setStatic();
        // save memory
        m_PlotInstances = null;
        m_setsComplete = 0;
        m_tasks = null;
        m_aggregatedPlotInstances = null;
      }
    }
  }
コード例 #4
0
    /**
     * Evaluates the classifier on this task's test set and passes the resulting evaluation and
     * plot data to {@code aggregateEvalTask}.
     *
     * <p>If the training data is missing or empty, the evaluation is set up from the test data
     * with priors disabled. Returns without aggregating if the evaluator is stopped before or
     * during the evaluation. Any exception stops all processing and is logged.
     */
    public void execute() {
      if (m_stopped) {
        return;
      }

      if (m_logger != null) {
        m_logger.statusMessage(statusMessagePrefix() + "Evaluating (" + m_setNum + ")...");
      }
      // Animate whether or not a logger is attached; the matching m_visual.setStatic() in
      // aggregateEvalTask() does not depend on the logger either.
      m_visual.setAnimated();

      try {
        ClassifierErrorsPlotInstances plotInstances =
            ExplorerDefaults.getClassifierErrorsPlotInstances();

        // Without training data there is nothing to estimate priors from, so set up from the
        // test data and disable prior-based statistics.
        boolean noTrainingData = m_trainData == null || m_trainData.numInstances() == 0;
        Instances setupData = noTrainingData ? m_testData : m_trainData;

        Evaluation eval = new Evaluation(setupData);
        plotInstances.setInstances(setupData);
        plotInstances.setClassifier(m_classifier);
        plotInstances.setClassIndex(setupData.classIndex());
        plotInstances.setEvaluation(eval);
        eval = adjustForInputMappedClassifier(eval, m_classifier, setupData, plotInstances);
        if (noTrainingData) {
          eval.useNoPriors();
        }

        plotInstances.setUp();

        for (int i = 0; i < m_testData.numInstances() && !m_stopped; i++) {
          plotInstances.process(m_testData.instance(i), m_classifier, eval);
        }

        if (m_stopped) {
          return;
        }

        aggregateEvalTask(eval, m_classifier, m_testData, plotInstances, m_setNum, m_maxSetNum);

      } catch (Exception ex) {
        ClassifierPerformanceEvaluator.this.stop(); // stop all processing
        if (m_logger != null) {
          m_logger.logMessage(
              "[ClassifierPerformanceEvaluator] "
                  + statusMessagePrefix()
                  + " problem evaluating classifier. "
                  + ex.getMessage());
        }
        ex.printStackTrace();
      }
    }