Ejemplo n.º 1
0
  /**
   * Selects, from the configured base classifiers, the one with the lowest estimated error rate and
   * stores it as the model used for prediction.
   *
   * <p>Instances with a missing class value are removed from a copy of {@code data}, and the copy is
   * shuffled with a {@link Random} seeded by {@code m_Seed}. If {@code m_NumXValFolds > 1}, each
   * candidate's error is estimated by {@code m_NumXValFolds}-fold cross-validation (the data is
   * stratified first when the class is nominal), and the winner is afterwards rebuilt on all of the
   * data. Otherwise each candidate is trained and evaluated on the full training data, and the
   * winner keeps that model.
   *
   * <p>Ties keep the earlier classifier. A NaN error rate never beats a non-NaN one, and a NaN
   * incumbent is replaced by any candidate with a non-NaN error.
   *
   * @param data the training data used to build and compare the candidate classifiers
   * @exception Exception if no base classifiers have been set, or a classifier could not be built
   *     or evaluated
   */
  public void buildClassifier(Instances data) throws Exception {

    if (m_Classifiers.length == 0) {
      throw new Exception("No base classifiers have been set!");
    }
    // Work on a copy: drop instances without a class value and shuffle reproducibly.
    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();
    newData.randomize(new Random(m_Seed));
    if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1)) {
      newData.stratify(m_NumXValFolds);
    }
    Instances train = newData; // train on all data by default
    Instances test = newData; // test on training data by default
    Classifier bestClassifier = null;
    int bestIndex = -1;
    double bestPerformance = Double.NaN;
    int numClassifiers = m_Classifiers.length;
    for (int i = 0; i < numClassifiers; i++) {
      Classifier currentClassifier = getClassifier(i);
      Evaluation evaluation;
      if (m_NumXValFolds > 1) {
        // One Evaluation accumulates the test predictions of every fold. setPriors(train)
        // presumably re-bases the prior statistics on each fold's training split.
        evaluation = new Evaluation(newData);
        for (int j = 0; j < m_NumXValFolds; j++) {
          train = newData.trainCV(m_NumXValFolds, j);
          test = newData.testCV(m_NumXValFolds, j);
          currentClassifier.buildClassifier(train);
          evaluation.setPriors(train);
          evaluation.evaluateModel(currentClassifier, test);
        }
      } else {
        currentClassifier.buildClassifier(train);
        evaluation = new Evaluation(train);
        evaluation.evaluateModel(currentClassifier, test);
      }

      double error = evaluation.errorRate();
      if (m_Debug) {
        System.err.println(
            "Error rate: "
                + Utils.doubleToString(error, 6, 4)
                + " for classifier "
                + currentClassifier.getClass().getName());
      }

      // The first candidate is always accepted. Afterwards a candidate wins only with a strictly
      // lower error, so ties keep the earlier one. NaN needs explicit handling because every
      // comparison with NaN is false: without it, a NaN error on the first candidate could never
      // be displaced. A NaN error never wins, and any non-NaN error replaces a NaN incumbent.
      boolean isBetter =
          (bestIndex < 0)
              || (!Double.isNaN(error)
                  && (Double.isNaN(bestPerformance) || (error < bestPerformance)));
      if (isBetter) {
        bestClassifier = currentClassifier;
        bestPerformance = error;
        bestIndex = i;
      }
    }
    m_ClassifierIndex = bestIndex;
    m_Classifier = bestClassifier;
    if (m_NumXValFolds > 1) {
      // The winner was last trained on a single fold's training split; refit it on all the data.
      m_Classifier.buildClassifier(newData);
    }
  }
Ejemplo n.º 2
0
  /**
   * Evaluates a feature subset by building a decision table over it and estimating its
   * performance. The estimate uses leave-one-out when {@code m_CVFolds == 1}, and
   * {@code m_CVFolds}-fold cross validation otherwise.
   *
   * <p>Side effects: {@code m_evaluation} is replaced by a fresh Evaluation holding the collected
   * statistics, and {@code m_entries} is replaced by the hash table built from the subset. In
   * cross-validation mode, {@code m_theInstances} is also randomised with {@code m_rr} and
   * stratified in place.
   *
   * <p>Error measures (RMSE, MAE) are returned negated, so a larger value indicates a better subset.
   *
   * @param feature_set the subset to be evaluated; bit {@code i} selects attribute {@code i}
   * @param num_atts the number of attributes in the subset (expected to equal the number of set
   *     bits of {@code feature_set} below {@code m_numAttributes})
   * @return the estimated performance according to {@code m_evaluationMeasure}: percent correct,
   *     negated RMSE or MAE, or prior-weighted AUC; 0.0 for an unrecognised measure
   * @throws Exception if subset can't be evaluated
   */
  protected double estimatePerformance(BitSet feature_set, int num_atts) throws Exception {

    m_evaluation = new Evaluation(m_theInstances);
    int classI = m_theInstances.classIndex();

    // Indices of the attributes selected by feature_set, in ascending order.
    int[] fs = new int[num_atts];
    int index = 0;
    for (int i = 0; i < m_numAttributes; i++) {
      if (feature_set.get(i)) {
        fs[index++] = i;
      }
    }

    // Scratch buffer for the table key of the current instance, refilled for each instance.
    double[] instA = new double[num_atts];

    // create new hash table
    m_entries = new Hashtable((int) (m_theInstances.numInstances() * 1.5));

    // insert instances into the hash table
    for (int i = 0; i < m_numInstances; i++) {
      Instance inst = m_theInstances.instance(i);
      fillTableKey(inst, fs, classI, instA);
      insertIntoTable(inst, instA);
    }

    if (m_CVFolds == 1) {
      // calculate leave one out error
      for (int i = 0; i < m_numInstances; i++) {
        Instance inst = m_theInstances.instance(i);
        fillTableKey(inst, fs, classI, instA);
        evaluateInstanceLeaveOneOut(inst, instA);
      }
    } else {
      m_theInstances.randomize(m_rr);
      m_theInstances.stratify(m_CVFolds);

      // calculate m_CVFolds-fold cross validation error
      for (int fold = 0; fold < m_CVFolds; fold++) {
        Instances insts = m_theInstances.testCV(m_CVFolds, fold);
        evaluateFoldCV(insts, fs);
      }
    }

    switch (m_evaluationMeasure) {
      case EVAL_DEFAULT:
        if (m_classIsNominal) {
          return m_evaluation.pctCorrect();
        }
        return -m_evaluation.rootMeanSquaredError();
      case EVAL_ACCURACY:
        return m_evaluation.pctCorrect();
      case EVAL_RMSE:
        return -m_evaluation.rootMeanSquaredError();
      case EVAL_MAE:
        return -m_evaluation.meanAbsoluteError();
      case EVAL_AUC:
        {
          // Normalise a copy. getClassPriors() may hand back the Evaluation's own array, and
          // rescaling it in place would silently alter that object's state.
          double[] classPriors = m_evaluation.getClassPriors().clone();
          Utils.normalize(classPriors);
          double weightedAUC = 0;
          for (int c = 0; c < m_theInstances.classAttribute().numValues(); c++) {
            double tempAUC = m_evaluation.areaUnderROC(c);
            if (!Utils.isMissingValue(tempAUC)) {
              weightedAUC += (classPriors[c] * tempAUC);
            } else {
              // This class contributes nothing; the remaining weights are not renormalised.
              System.err.println("Undefined AUC!!");
            }
          }
          return weightedAUC;
        }
    }
    // shouldn't get here
    return 0.0;
  }

  /**
   * Writes the table key of {@code inst} for the selected attributes into {@code key}.
   *
   * <p>{@code key[j]} receives the value of attribute {@code fs[j]}. The class attribute and
   * missing values are both encoded as {@link Double#MAX_VALUE}, which marks them as missing.
   *
   * @param inst the instance to read attribute values from
   * @param fs the indices of the selected attributes
   * @param classIndex the index of the class attribute
   * @param key the output buffer; must be at least {@code fs.length} long
   */
  private void fillTableKey(Instance inst, int[] fs, int classIndex, double[] key) {
    for (int j = 0; j < fs.length; j++) {
      if (fs[j] == classIndex || inst.isMissing(fs[j])) {
        key[j] = Double.MAX_VALUE;
      } else {
        key[j] = inst.value(fs[j]);
      }
    }
  }