예제 #1
0
  /**
   * Gets the classifier specification string, which contains the class name of the classifier and
   * any options to the classifier.
   *
   * @return the classifier string.
   */
  /**
   * Returns the specification string for the current classifier: its fully qualified class name,
   * followed by the classifier's options when it exposes any.
   *
   * @return the classifier specification string.
   */
  protected String getClassifierSpec() {

    Classifier classifier = getClassifier();
    String spec = classifier.getClass().getName();
    // Append the joined option string only when the classifier publishes its options.
    if (classifier instanceof OptionHandler) {
      spec += " " + Utils.joinOptions(((OptionHandler) classifier).getOptions());
    }
    return spec;
  }
예제 #2
0
  /**
   * Gets the classifier specification string, which contains the class name of the classifier and
   * any options to the classifier
   *
   * @param index the index of the classifier string to retrieve, starting from 0.
   * @return the classifier string, or the empty string if no classifier has been assigned (or the
   *     index given is out of range).
   */
  /**
   * Gets the classifier specification string, which contains the class name of the classifier and
   * any options to the classifier.
   *
   * @param index the index of the classifier string to retrieve, starting from 0.
   * @return the classifier string, or the empty string if no classifier has been assigned (or the
   *     index given is out of range).
   */
  protected String getClassifierSpec(int index) {

    // Guard against an out-of-range index. The previous check (length < index) let
    // index == length slip through to getClassifier(index), which would throw an
    // ArrayIndexOutOfBoundsException instead of returning the documented empty string;
    // it also accepted negative indices.
    if ((index < 0) || (index >= m_Classifiers.length)) {
      return "";
    }
    Classifier c = getClassifier(index);
    // Include the option string only when the classifier can report its options.
    if (c instanceof OptionHandler) {
      return c.getClass().getName() + " " + Utils.joinOptions(((OptionHandler) c).getOptions());
    }
    return c.getClass().getName();
  }
예제 #3
0
  /**
   * Returns an enumeration describing the available options..
   *
   * @return an enumeration of all the available options.
   */
  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector options = new Vector(1);

    // Flag option: skip determining train/test/classifier sizes.
    options.addElement(
        new Option(
            "\tSkips the determination of sizes (train/test/classifier)\n"
                + "\t(default: sizes are determined)",
            "no-size",
            0,
            "-no-size"));
    // Option naming the classifier class to use.
    options.addElement(
        new Option(
            "\tThe full class name of the classifier.\n"
                + "\teg: weka.classifiers.bayes.NaiveBayes",
            "W",
            1,
            "-W <class name>"));

    // When a template classifier is set and can describe its own options,
    // append a header entry followed by each of the template's options.
    // (instanceof is false for null, so no separate null check is needed.)
    if (m_Template instanceof OptionHandler) {
      options.addElement(
          new Option(
              "",
              "",
              0,
              "\nOptions specific to classifier " + m_Template.getClass().getName() + ":"));
      Enumeration templateOptions = ((OptionHandler) m_Template).listOptions();
      while (templateOptions.hasMoreElements()) {
        options.addElement(templateOptions.nextElement());
      }
    }
    return options.elements();
  }
예제 #4
0
  /**
   * Gets the key describing the current SplitEvaluator. For example This may contain the name of
   * the classifier used for classifier predictive evaluation. The number of key fields must be
   * constant for a given SplitEvaluator.
   *
   * @return an array of objects containing the key.
   */
  /**
   * Gets the key describing the current SplitEvaluator. For example this may contain the name of
   * the classifier used for classifier predictive evaluation. The number of key fields must be
   * constant for a given SplitEvaluator.
   *
   * @return an array of objects containing the key.
   */
  public Object[] getKey() {

    Object[] result = new Object[KEY_SIZE];
    result[0] = m_Template.getClass().getName(); // classifier class name
    result[1] = m_ClassifierOptions; // option string
    result[2] = m_ClassifierVersion; // version marker
    return result;
  }
예제 #5
0
  /**
   * Buildclassifier selects a classifier from the set of classifiers by minimising error on the
   * training data.
   *
   * @param data the training data to be used for generating the boosted classifier.
   * @exception Exception if the classifier could not be built successfully
   */
  /**
   * Buildclassifier selects a classifier from the set of classifiers by minimising error on the
   * training data.
   *
   * @param data the training data to be used for generating the boosted classifier.
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    if (m_Classifiers.length == 0) {
      throw new Exception("No base classifiers have been set!");
    }
    // Work on a copy so the caller's data set is not modified by the steps below.
    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();
    newData.randomize(new Random(m_Seed));
    // Stratify only when cross-validating a nominal class.
    if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1))
      newData.stratify(m_NumXValFolds);
    Instances train = newData; // train on all data by default
    Instances test = newData; // test on training data by default
    Classifier bestClassifier = null;
    int bestIndex = -1;
    // NaN placeholder; the (i == 0) guard below guarantees it is replaced on the first pass.
    double bestPerformance = Double.NaN;
    int numClassifiers = m_Classifiers.length;
    for (int i = 0; i < numClassifiers; i++) {
      Classifier currentClassifier = getClassifier(i);
      Evaluation evaluation;
      if (m_NumXValFolds > 1) {
        // Estimate error by m_NumXValFolds-fold cross-validation on the shuffled data;
        // train/test are reassigned per fold.
        evaluation = new Evaluation(newData);
        for (int j = 0; j < m_NumXValFolds; j++) {
          train = newData.trainCV(m_NumXValFolds, j);
          test = newData.testCV(m_NumXValFolds, j);
          currentClassifier.buildClassifier(train);
          evaluation.setPriors(train);
          evaluation.evaluateModel(currentClassifier, test);
        }
      } else {
        // No cross-validation: train and evaluate on the same (full) data set,
        // i.e. resubstitution error.
        currentClassifier.buildClassifier(train);
        evaluation = new Evaluation(train);
        evaluation.evaluateModel(currentClassifier, test);
      }

      double error = evaluation.errorRate();
      if (m_Debug) {
        System.err.println(
            "Error rate: "
                + Utils.doubleToString(error, 6, 4)
                + " for classifier "
                + currentClassifier.getClass().getName());
      }

      // Keep the classifier with the lowest error; i == 0 seeds the comparison.
      if ((i == 0) || (error < bestPerformance)) {
        bestClassifier = currentClassifier;
        bestPerformance = error;
        bestIndex = i;
      }
    }
    m_ClassifierIndex = bestIndex;
    m_Classifier = bestClassifier;
    // After cross-validation the winner was last trained on a single fold only,
    // so rebuild it on the complete training data.
    if (m_NumXValFolds > 1) {
      m_Classifier.buildClassifier(newData);
    }
  }
예제 #6
0
  /**
   * Returns a text description of the split evaluator.
   *
   * @return a text description of the split evaluator.
   */
  /**
   * Returns a text description of the split evaluator.
   *
   * @return a text description of the split evaluator.
   */
  public String toString() {

    StringBuffer description = new StringBuffer("RegressionSplitEvaluator: ");
    if (m_Template == null) {
      description.append("<null> classifier");
    } else {
      description
          .append(m_Template.getClass().getName())
          .append(" ")
          .append(m_ClassifierOptions)
          .append("(version ")
          .append(m_ClassifierVersion)
          .append(")");
    }
    return description.toString();
  }
예제 #7
0
  /** Updates the options that the current classifier is using. */
  /** Updates the options that the current classifier is using. */
  protected void updateOptions() {

    // Record the template's option string, or an empty string if it has no options.
    m_ClassifierOptions =
        (m_Template instanceof OptionHandler)
            ? Utils.joinOptions(((OptionHandler) m_Template).getOptions())
            : "";

    // Record the template's serialVersionUID as a version marker when it is serializable.
    if (m_Template instanceof Serializable) {
      m_ClassifierVersion =
          "" + ObjectStreamClass.lookup(m_Template.getClass()).getSerialVersionUID();
    } else {
      m_ClassifierVersion = "";
    }
  }
예제 #8
0
 /**
  * Returns the value of the named measure
  *
  * @param additionalMeasureName the name of the measure to query for its value
  * @return the value of the named measure
  * @throws IllegalArgumentException if the named measure is not supported
  */
 /**
  * Returns the value of the named measure.
  *
  * @param additionalMeasureName the name of the measure to query for its value
  * @return the value of the named measure
  * @throws IllegalArgumentException if the named measure is not supported
  */
 public double getMeasure(String additionalMeasureName) {
   // NOTE(review): the capability check is performed on m_Template, but the measure is
   // queried from m_Classifier — presumably both always hold the same classifier class;
   // confirm that invariant, otherwise the cast below could fail at runtime.
   if (m_Template instanceof AdditionalMeasureProducer) {
     // The measure comes from the built classifier, so it must exist first.
     if (m_Classifier == null) {
       throw new IllegalArgumentException(
           "ClassifierSplitEvaluator: "
               + "Can't return result for measure, "
               + "classifier has not been built yet.");
     }
     return ((AdditionalMeasureProducer) m_Classifier).getMeasure(additionalMeasureName);
   } else {
     throw new IllegalArgumentException(
         "ClassifierSplitEvaluator: "
             + "Can't return value for : "
             + additionalMeasureName
             + ". "
             + m_Template.getClass().getName()
             + " "
             + "is not an AdditionalMeasureProducer");
   }
 }
예제 #9
0
  /**
   * Trains the classifier on {@code trainData}, evaluates it on {@code testData}, prints
   * per-class details and a summary to standard output, and writes the per-instance predictions
   * ("prediction.trec") and the predicted class distributions ("predicted_distribution.csv") to
   * files in the working directory.
   *
   * @param c the classifier to build and evaluate
   * @param trainData the training instances (class attribute must be set)
   * @param testData the test instances used for evaluation
   * @throws IllegalStateException if the class attribute of the training data has not been set
   * @throws Exception if building or evaluating the classifier fails
   */
  private static void evaluateClassifier(Classifier c, Instances trainData, Instances testData)
      throws Exception {
    // Validate the precondition FIRST: the log line below dereferences classAttribute(),
    // which fails with a far less helpful error when no class index is set. The original
    // performed this check only after logging.
    if (trainData.classIndex() < 0) throw new IllegalStateException("class attribute not set");

    System.err.println(
        "INFO: Starting split validation to predict '"
            + trainData.classAttribute().name()
            + "' using '"
            + c.getClass().getCanonicalName()
            + ":"
            + Arrays.toString(c.getOptions())
            + "' (#train="
            + trainData.numInstances()
            + ",#test="
            + testData.numInstances()
            + ") ...");

    c.buildClassifier(trainData);
    Evaluation eval = new Evaluation(testData);
    eval.useNoPriors();
    double[] predictions = eval.evaluateModel(c, testData);

    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));

    // Write predictions to file. try-with-resources closes the writer even when writing
    // fails; the original leaked the FileWriter on any exception.
    System.err.println("INFO: Writing predictions to file ...");
    try (Writer out = new FileWriter("prediction.trec")) {
      writePredictionsTrecEval(predictions, testData, 0, trainData.classIndex(), out);
    }

    // Write predicted distributions to CSV, with the same leak-proof resource handling.
    System.err.println("INFO: Writing predicted distributions to CSV ...");
    try (Writer out = new FileWriter("predicted_distribution.csv")) {
      writePredictedDistributions(c, testData, 0, out);
    }
  }
예제 #10
0
  /** outputs some data about the classifier */
  /** Outputs some data about the classifier. */
  public String toString() {
    StringBuffer buf = new StringBuffer();

    buf.append("Weka - Demo\n===========\n\n");

    // Classifier line: class name plus its joined options.
    buf.append(
        "Classifier...: "
            + m_Classifier.getClass().getName()
            + " "
            + Utils.joinOptions(m_Classifier.getOptions())
            + "\n");

    // Filter line: include the filter's options only when it can report them.
    if (m_Filter instanceof OptionHandler) {
      buf.append(
          "Filter.......: "
              + m_Filter.getClass().getName()
              + " "
              + Utils.joinOptions(((OptionHandler) m_Filter).getOptions())
              + "\n");
    } else {
      buf.append("Filter.......: " + m_Filter.getClass().getName() + "\n");
    }
    buf.append("Training file: " + m_TrainingFile + "\n");
    buf.append("\n");

    // Model description followed by the evaluation summary.
    buf.append(m_Classifier.toString() + "\n");
    buf.append(m_Evaluation.toSummaryString() + "\n");
    // Confusion matrix and per-class details may be unavailable (e.g. numeric class);
    // failures are reported but do not abort the description.
    try {
      buf.append(m_Evaluation.toMatrixString() + "\n");
    } catch (Exception e) {
      e.printStackTrace();
    }
    try {
      buf.append(m_Evaluation.toClassDetailsString() + "\n");
    } catch (Exception e) {
      e.printStackTrace();
    }

    return buf.toString();
  }
예제 #11
0
  /**
   * Cross-validates the given classifier on the data, writes the per-fold prediction log to
   * "cv.log" in the working directory, and prints per-class details and a summary to standard
   * output.
   *
   * @param c the classifier to evaluate
   * @param data the instances to cross-validate on
   * @param folds the number of cross-validation folds
   * @throws Exception if the evaluation fails
   */
  static void evaluateClassifier(Classifier c, Instances data, int folds) throws Exception {
    System.err.println(
        "INFO: Starting crossvalidation to predict '"
            + data.classAttribute().name()
            + "' using '"
            + c.getClass().getCanonicalName()
            + ":"
            + Arrays.toString(c.getOptions())
            + "' ...");

    StringBuffer sb = new StringBuffer();
    Evaluation eval = new Evaluation(data);
    // Fixed seed (1) keeps fold assignment reproducible between runs.
    eval.crossValidateModel(c, data, folds, new Random(1), sb, new Range("first"), Boolean.FALSE);

    // Write the prediction log. try-with-resources closes the writer even when write()
    // throws; the original leaked the FileWriter in that case.
    try (Writer out = new FileWriter("cv.log")) {
      out.write(sb.toString());
    }

    System.out.println(eval.toClassDetailsString());
    System.out.println(eval.toSummaryString("\nResults\n======\n", false));
  }
  /**
   * Takes an evaluation object from a task and aggregates it with the overall one.
   *
   * <p>Once all {@code maxSetNum} sets have been aggregated, builds the textual evaluation
   * result, notifies the text listeners and (when applicable) the visualizable-error and
   * threshold-curve listeners, then resets the aggregation state.
   *
   * @param eval the evaluation object to aggregate
   * @param classifier the classifier used by the task
   * @param testData the testData from the task
   * @param plotInstances the ClassifierErrorsPlotInstances object from the task
   * @param setNum the set number processed by the task
   * @param maxSetNum the maximum number of sets in this batch
   */
  protected synchronized void aggregateEvalTask(
      Evaluation eval,
      Classifier classifier,
      Instances testData,
      ClassifierErrorsPlotInstances plotInstances,
      int setNum,
      int maxSetNum) {

    // NOTE(review): setNum is currently unused; completion is tracked via m_setsComplete.
    m_eval.aggregate(eval);

    // First task seeds the aggregated plot data; subsequent tasks append their
    // instances, shapes, and sizes element by element.
    if (m_aggregatedPlotInstances == null) {
      m_aggregatedPlotInstances = new Instances(plotInstances.getPlotInstances());
      m_aggregatedPlotShapes = plotInstances.getPlotShapes();
      m_aggregatedPlotSizes = plotInstances.getPlotSizes();
    } else {
      Instances temp = plotInstances.getPlotInstances();
      for (int i = 0; i < temp.numInstances(); i++) {
        m_aggregatedPlotInstances.add(temp.get(i));
        m_aggregatedPlotShapes.addElement(plotInstances.getPlotShapes().get(i));
        m_aggregatedPlotSizes.addElement(plotInstances.getPlotSizes().get(i));
      }
    }
    m_setsComplete++;

    //  if (ce.getSetNumber() == ce.getMaxSetNumber()) {
    // All sets done: publish results to the registered listeners.
    if (m_setsComplete == maxSetNum) {
      try {
        String textTitle = classifier.getClass().getName();
        String textOptions = "";
        if (classifier instanceof OptionHandler) {
          textOptions = Utils.joinOptions(((OptionHandler) classifier).getOptions());
        }
        // Strip the package prefix, keeping only the simple class name for the title.
        textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1, textTitle.length());
        String resultT =
            "=== Evaluation result ===\n\n"
                + "Scheme: "
                + textTitle
                + "\n"
                + ((textOptions.length() > 0) ? "Options: " + textOptions + "\n" : "")
                + "Relation: "
                + testData.relationName()
                + "\n\n"
                + m_eval.toSummaryString();

        // Per-class details and the confusion matrix only apply to nominal classes.
        if (testData.classAttribute().isNominal()) {
          resultT += "\n" + m_eval.toClassDetailsString() + "\n" + m_eval.toMatrixString();
        }

        TextEvent te = new TextEvent(ClassifierPerformanceEvaluator.this, resultT, textTitle);
        notifyTextListeners(te);

        // set up visualizable errors
        if (m_visualizableErrorListeners.size() > 0) {
          PlotData2D errorD = new PlotData2D(m_aggregatedPlotInstances);
          errorD.setShapeSize(m_aggregatedPlotSizes);
          errorD.setShapeType(m_aggregatedPlotShapes);
          errorD.setPlotName(textTitle + " " + textOptions);

          /*          PlotData2D errorD = m_PlotInstances.getPlotData(
          textTitle + " " + textOptions); */
          VisualizableErrorEvent vel =
              new VisualizableErrorEvent(ClassifierPerformanceEvaluator.this, errorD);
          notifyVisualizableErrorListeners(vel);
          m_PlotInstances.cleanUp();
        }

        // Threshold (e.g. ROC) curve for the first class value, nominal classes only.
        if (testData.classAttribute().isNominal() && m_thresholdListeners.size() > 0) {
          ThresholdCurve tc = new ThresholdCurve();
          Instances result = tc.getCurve(m_eval.predictions(), 0);
          result.setRelationName(testData.relationName());
          PlotData2D pd = new PlotData2D(result);
          String htmlTitle = "<html><font size=-2>" + textTitle;
          String newOptions = "";
          if (classifier instanceof OptionHandler) {
            String[] options = ((OptionHandler) classifier).getOptions();
            if (options.length > 0) {
              // Rebuild the option string with HTML line breaks before each flag
              // (a leading '-' not followed by a digit, to avoid splitting negative numbers).
              // NOTE(review): charAt(1) assumes every '-'-prefixed token has at least two
              // characters; a bare "-" would throw StringIndexOutOfBoundsException —
              // confirm getOptions() never yields one.
              for (int ii = 0; ii < options.length; ii++) {
                if (options[ii].length() == 0) {
                  continue;
                }
                if (options[ii].charAt(0) == '-'
                    && !(options[ii].charAt(1) >= '0' && options[ii].charAt(1) <= '9')) {
                  newOptions += "<br>";
                }
                newOptions += options[ii];
              }
            }
          }

          htmlTitle +=
              " "
                  + newOptions
                  + "<br>"
                  + " (class: "
                  + testData.classAttribute().value(0)
                  + ")"
                  + "</font></html>";
          pd.setPlotName(textTitle + " (class: " + testData.classAttribute().value(0) + ")");
          pd.setPlotNameHTML(htmlTitle);
          // Connect every point after the first so the curve renders as a line.
          boolean[] connectPoints = new boolean[result.numInstances()];
          for (int jj = 1; jj < connectPoints.length; jj++) {
            connectPoints[jj] = true;
          }

          pd.setConnectPoints(connectPoints);

          ThresholdDataEvent rde =
              new ThresholdDataEvent(
                  ClassifierPerformanceEvaluator.this, pd, testData.classAttribute());
          notifyThresholdListeners(rde);
        }
        if (m_logger != null) {
          m_logger.statusMessage(statusMessagePrefix() + "Finished.");
        }

      } catch (Exception ex) {
        if (m_logger != null) {
          m_logger.logMessage(
              "[ClassifierPerformanceEvaluator] "
                  + statusMessagePrefix()
                  + " problem constructing evaluation results. "
                  + ex.getMessage());
        }
        ex.printStackTrace();
      } finally {
        m_visual.setStatic();
        // save memory
        m_PlotInstances = null;
        m_setsComplete = 0;
        m_tasks = null;
        m_aggregatedPlotInstances = null;
      }
    }
  }
예제 #13
0
  @Override
  protected void notifyJobOutputListeners() {
    // Hoist the repeated cast of the running Spark job.
    weka.distributed.spark.WekaClassifierSparkJob job =
        (weka.distributed.spark.WekaClassifierSparkJob) m_runningJob;
    weka.classifiers.Classifier finalClassifier = job.getClassifier();
    Instances modelHeader = job.getTrainingHeader();
    String classAtt = job.getClassAttribute();

    // Best effort: set the class index on the header; log and continue on failure.
    try {
      weka.distributed.spark.WekaClassifierSparkJob.setClassIndex(classAtt, modelHeader, true);
    } catch (Exception ex) {
      if (m_log != null) {
        m_log.logMessage(statusMessagePrefix() + ex.getMessage());
      }
      ex.printStackTrace();
    }

    if (finalClassifier == null && m_log != null) {
      m_log.logMessage(statusMessagePrefix() + "No classifier produced!");
    }
    if (modelHeader == null && m_log != null) {
      m_log.logMessage(statusMessagePrefix() + "No training header available for the model!");
    }

    if (finalClassifier == null) {
      return;
    }

    // Publish the model's textual description to any text listeners.
    if (m_textListeners.size() > 0) {
      String classifierSpec = finalClassifier.getClass().getName();
      if (finalClassifier instanceof OptionHandler) {
        classifierSpec += " " + Utils.joinOptions(((OptionHandler) finalClassifier).getOptions());
      }
      TextEvent te = new TextEvent(this, finalClassifier.toString(), "Spark: " + classifierSpec);
      for (TextListener t : m_textListeners) {
        t.acceptText(te);
      }
    }

    if (modelHeader != null) {
      // have to add a single bogus instance to the header to trick
      // the SerializedModelSaver into saving it (since it ignores
      // structure only DataSetEvents) :-)
      double[] vals = new double[modelHeader.numAttributes()];
      for (int i = 0; i < vals.length; i++) {
        vals[i] = Utils.missingValue();
      }
      modelHeader.add(new DenseInstance(1.0, vals));
      DataSetEvent dse = new DataSetEvent(this, modelHeader);
      BatchClassifierEvent be = new BatchClassifierEvent(this, finalClassifier, dse, dse, 1, 1);
      for (BatchClassifierListener b : m_classifierListeners) {
        b.acceptClassifier(be);
      }
    }
  }