Example #1
  /**
   * buildClassifier selects a classifier from the set of classifiers by minimising error on the
   * training data.
   *
   * @param data the training data to be used for generating the boosted classifier.
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    if (m_Classifiers.length == 0) {
      throw new Exception("No base classifiers have been set!");
    }
    Instances newData = new Instances(data);
    newData.deleteWithMissingClass();
    newData.randomize(new Random(m_Seed));
    if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1)) {
      newData.stratify(m_NumXValFolds);
    }
    Instances train = newData; // train on all data by default
    Instances test = newData; // test on training data by default
    Classifier bestClassifier = null;
    int bestIndex = -1;
    double bestPerformance = Double.NaN;
    int numClassifiers = m_Classifiers.length;
    for (int i = 0; i < numClassifiers; i++) {
      Classifier currentClassifier = getClassifier(i);
      Evaluation evaluation;
      if (m_NumXValFolds > 1) {
        // estimate error by m_NumXValFolds-fold cross-validation
        evaluation = new Evaluation(newData);
        for (int j = 0; j < m_NumXValFolds; j++) {
          train = newData.trainCV(m_NumXValFolds, j);
          test = newData.testCV(m_NumXValFolds, j);
          currentClassifier.buildClassifier(train);
          evaluation.setPriors(train);
          evaluation.evaluateModel(currentClassifier, test);
        }
      } else {
        currentClassifier.buildClassifier(train);
        evaluation = new Evaluation(train);
        evaluation.evaluateModel(currentClassifier, test);
      }

      double error = evaluation.errorRate();
      if (m_Debug) {
        System.err.println(
            "Error rate: "
                + Utils.doubleToString(error, 6, 4)
                + " for classifier "
                + currentClassifier.getClass().getName());
      }

      if ((i == 0) || (error < bestPerformance)) {
        bestClassifier = currentClassifier;
        bestPerformance = error;
        bestIndex = i;
      }
    }
    m_ClassifierIndex = bestIndex;
    m_Classifier = bestClassifier;
    if (m_NumXValFolds > 1) {
      // the winner was only trained on folds; rebuild it on the full data
      m_Classifier.buildClassifier(newData);
    }
  }
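For orientation, this is the selection loop of a Weka MultiScheme-style meta classifier. Below is a hedged usage sketch, not part of the original example: it assumes the stock weka.classifiers.meta.MultiScheme API (setClassifiers, setNumFolds, setSeed) and a hypothetical iris.arff on disk.

  import weka.classifiers.Classifier;
  import weka.classifiers.bayes.NaiveBayes;
  import weka.classifiers.meta.MultiScheme;
  import weka.classifiers.trees.J48;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class MultiSchemeDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("iris.arff"); // hypothetical dataset path
      data.setClassIndex(data.numAttributes() - 1);

      MultiScheme ms = new MultiScheme();
      // candidate classifiers; buildClassifier() keeps the lowest-error one
      ms.setClassifiers(new Classifier[] {new J48(), new NaiveBayes()});
      ms.setNumFolds(5); // m_NumXValFolds > 1 selects the cross-validation branch
      ms.setSeed(1); // m_Seed drives randomize()

      ms.buildClassifier(data);
      System.out.println(ms); // reports which classifier was selected
    }
  }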
Example #2
  public void split() throws IOException {
    FileSystem fs = FileSystem.get(new Configuration());

    // remove any splits left over from a previous run
    fs.delete(splitsDir, true);

    // shuffle, then stratify so every split keeps a similar class distribution
    instances.randomize(new Random(1));
    instances.stratify(numberOfSplits);

    // testCV(n, i) yields the i-th of n disjoint stratified folds, so each
    // fold is written out as one self-contained training split
    for (int i = 0; i < numberOfSplits; i++) {
      BufferedWriter bw =
          new BufferedWriter(
              new OutputStreamWriter(fs.create(new Path(splitsDir, "train_split_" + i + ".arff"))));
      bw.write(instances.testCV(numberOfSplits, i).toString());
      bw.close();
    }
  }
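Writing instances.testCV(numberOfSplits, i) for each i is what makes this a partition: after stratify(n), the n test folds are disjoint and together cover the whole dataset, so each fold can serve as a standalone training split. A small self-contained check of that property, assuming only the standard weka.core.Instances API and a hypothetical data.arff:

  import java.util.Random;

  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class FoldCoverageCheck {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("data.arff"); // hypothetical dataset path
      data.setClassIndex(data.numAttributes() - 1);
      data.randomize(new Random(1));
      data.stratify(10);

      int total = 0;
      for (int i = 0; i < 10; i++) {
        total += data.testCV(10, i).numInstances(); // disjoint stratified folds
      }
      // the folds partition the data, so the counts add up exactly
      System.out.println(total == data.numInstances()); // prints true
    }
  }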
Example #3
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    Instances trainData = new Instances(instances);
    trainData.deleteWithMissingClass();

    if (!(m_Classifier instanceof OptionHandler)) {
      throw new IllegalArgumentException("Base classifier should be OptionHandler.");
    }
    m_InitOptions = ((OptionHandler) m_Classifier).getOptions();
    m_BestPerformance = -99; // sentinel: no performance recorded yet
    m_NumAttributes = trainData.numAttributes();
    Random random = new Random(m_Seed);
    trainData.randomize(random);
    m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();

    // Check whether there are any parameters to optimize
    if (m_CVParams.size() == 0) {
      m_Classifier.buildClassifier(trainData);
      m_BestClassifierOptions = m_InitOptions;
      return;
    }

    if (trainData.classAttribute().isNominal()) {
      trainData.stratify(m_NumFolds);
    }
    m_BestClassifierOptions = null;

    // Set up m_ClassifierOptions -- take getOptions() and remove
    // those being optimised.
    m_ClassifierOptions = ((OptionHandler) m_Classifier).getOptions();
    for (int i = 0; i < m_CVParams.size(); i++) {
      Utils.getOption(((CVParameter) m_CVParams.elementAt(i)).m_ParamChar, m_ClassifierOptions);
    }
    findParamsByCrossValidation(0, trainData, random);

    String[] options = (String[]) m_BestClassifierOptions.clone();
    ((OptionHandler) m_Classifier).setOptions(options);
    m_Classifier.buildClassifier(trainData);
  }
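This method is the entry point of a CVParameterSelection-style search: findParamsByCrossValidation() sweeps the registered CVParameters and records m_BestClassifierOptions. A minimal driving sketch, assuming Weka's stock weka.classifiers.meta.CVParameterSelection API (setClassifier, addCVParameter, getBestClassifierOptions) and a hypothetical dataset path:

  import weka.classifiers.meta.CVParameterSelection;
  import weka.classifiers.trees.J48;
  import weka.core.Instances;
  import weka.core.Utils;
  import weka.core.converters.ConverterUtils.DataSource;

  public class ParamSearchDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("iris.arff"); // hypothetical dataset path
      data.setClassIndex(data.numAttributes() - 1);

      CVParameterSelection ps = new CVParameterSelection();
      ps.setClassifier(new J48());
      ps.setNumFolds(5);
      // sweep J48's -C (confidence factor) from 0.05 to 0.5 in 10 steps
      ps.addCVParameter("C 0.05 0.5 10");

      ps.buildClassifier(data); // non-empty m_CVParams triggers the search
      System.out.println(Utils.joinOptions(ps.getBestClassifierOptions()));
    }
  }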
Example #4
  /**
   * Cleanses the data based on misclassifications when performing cross-validation.
   *
   * @param data the data to train with and cleanse
   * @return the cleansed data
   * @throws Exception if something goes wrong
   */
  private Instances cleanseCross(Instances data) throws Exception {

    Instance inst;
    Instances crossSet = new Instances(data);
    Instances temp = new Instances(data, data.numInstances());
    Instances inverseSet = new Instances(data, data.numInstances());
    int count = 0;
    double ans;
    int iterations = 0;
    int classIndex = m_classIndex;
    if (classIndex < 0) {
      classIndex = data.classIndex();
    }
    if (classIndex < 0) {
      classIndex = data.numAttributes() - 1;
    }

    // loop until a pass removes no instances, the set becomes too small for
    // the requested number of folds, or the iteration cap below is reached
    while (count != crossSet.numInstances()
        && crossSet.numInstances() >= m_numOfCrossValidationFolds) {

      count = crossSet.numInstances();

      // check if hit maximum number of iterations
      iterations++;
      if (m_numOfCleansingIterations > 0 && iterations > m_numOfCleansingIterations) {
        break;
      }

      crossSet.setClassIndex(classIndex);

      if (crossSet.classAttribute().isNominal()) {
        crossSet.stratify(m_numOfCrossValidationFolds);
      }
      // do the folds
      temp = new Instances(crossSet, crossSet.numInstances());

      for (int fold = 0; fold < m_numOfCrossValidationFolds; fold++) {
        Instances train = crossSet.trainCV(m_numOfCrossValidationFolds, fold);
        m_cleansingClassifier.buildClassifier(train);
        Instances test = crossSet.testCV(m_numOfCrossValidationFolds, fold);
        // now test
        for (int i = 0; i < test.numInstances(); i++) {
          inst = test.instance(i);
          ans = m_cleansingClassifier.classifyInstance(inst);
          if (crossSet.classAttribute().isNumeric()) {
            if (ans >= inst.classValue() - m_numericClassifyThreshold
                && ans <= inst.classValue() + m_numericClassifyThreshold) {
              temp.add(inst);
            } else if (m_invertMatching) {
              inverseSet.add(inst);
            }
          } else { // class is nominal
            if (ans == inst.classValue()) {
              temp.add(inst);
            } else if (m_invertMatching) {
              inverseSet.add(inst);
            }
          }
        }
      }
      crossSet = temp;
    }

    if (m_invertMatching) {
      inverseSet.setClassIndex(data.classIndex());
      return inverseSet;
    } else {
      crossSet.setClassIndex(data.classIndex());
      return crossSet;
    }
  }
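This cleansing loop is the cross-validation path of a RemoveMisclassified-style filter. A hedged sketch of invoking it through the filter interface, assuming the stock weka.filters.unsupervised.instance.RemoveMisclassified API and a hypothetical dataset path:

  import weka.classifiers.trees.J48;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;
  import weka.filters.Filter;
  import weka.filters.unsupervised.instance.RemoveMisclassified;

  public class CleanseDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("labor.arff"); // hypothetical dataset path
      data.setClassIndex(data.numAttributes() - 1);

      RemoveMisclassified rm = new RemoveMisclassified();
      rm.setClassifier(new J48()); // the cleansing classifier
      rm.setNumFolds(10); // number of cross-validation folds
      rm.setMaxIterations(3); // caps the outer "until no change" loop
      rm.setInputFormat(data);

      Instances cleansed = Filter.useFilter(data, rm);
      System.out.println(cleansed.numInstances() + " instances kept");
    }
  }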
Example #5
  /**
   * Evaluates a feature subset by cross-validation.
   *
   * @param feature_set the subset to be evaluated
   * @param num_atts the number of attributes in the subset
   * @return the estimated accuracy
   * @throws Exception if subset can't be evaluated
   */
  protected double estimatePerformance(BitSet feature_set, int num_atts) throws Exception {

    m_evaluation = new Evaluation(m_theInstances);
    int i;
    int[] fs = new int[num_atts];

    double[] instA = new double[num_atts];
    int classI = m_theInstances.classIndex();

    int index = 0;
    for (i = 0; i < m_numAttributes; i++) {
      if (feature_set.get(i)) {
        fs[index++] = i;
      }
    }

    // create new hash table
    m_entries = new Hashtable((int) (m_theInstances.numInstances() * 1.5));

    // insert instances into the hash table
    for (i = 0; i < m_numInstances; i++) {

      Instance inst = m_theInstances.instance(i);
      for (int j = 0; j < fs.length; j++) {
        if (fs[j] == classI) {
          instA[j] = Double.MAX_VALUE; // missing for the class
        } else if (inst.isMissing(fs[j])) {
          instA[j] = Double.MAX_VALUE;
        } else {
          instA[j] = inst.value(fs[j]);
        }
      }
      insertIntoTable(inst, instA);
    }

    if (m_CVFolds == 1) {

      // calculate leave one out error
      for (i = 0; i < m_numInstances; i++) {
        Instance inst = m_theInstances.instance(i);
        for (int j = 0; j < fs.length; j++) {
          if (fs[j] == classI) {
            instA[j] = Double.MAX_VALUE; // missing for the class
          } else if (inst.isMissing(fs[j])) {
            instA[j] = Double.MAX_VALUE;
          } else {
            instA[j] = inst.value(fs[j]);
          }
        }
        evaluateInstanceLeaveOneOut(inst, instA);
      }
    } else {
      m_theInstances.randomize(m_rr);
      m_theInstances.stratify(m_CVFolds);

      // calculate m_CVFolds-fold cross-validation error
      for (i = 0; i < m_CVFolds; i++) {
        Instances insts = m_theInstances.testCV(m_CVFolds, i);
        evaluateFoldCV(insts, fs);
      }
    }

    switch (m_evaluationMeasure) {
      case EVAL_DEFAULT:
        if (m_classIsNominal) {
          return m_evaluation.pctCorrect();
        }
        return -m_evaluation.rootMeanSquaredError();
      case EVAL_ACCURACY:
        return m_evaluation.pctCorrect();
      case EVAL_RMSE:
        return -m_evaluation.rootMeanSquaredError();
      case EVAL_MAE:
        return -m_evaluation.meanAbsoluteError();
      case EVAL_AUC:
        double[] classPriors = m_evaluation.getClassPriors();
        Utils.normalize(classPriors);
        double weightedAUC = 0;
        for (i = 0; i < m_theInstances.classAttribute().numValues(); i++) {
          double tempAUC = m_evaluation.areaUnderROC(i);
          if (!Utils.isMissingValue(tempAUC)) {
            weightedAUC += (classPriors[i] * tempAUC);
          } else {
            System.err.println("Undefined AUC!!");
          }
        }
        return weightedAUC;
    }
    // shouldn't get here
    return 0.0;
  }
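This evaluator sits inside a DecisionTable-style feature-subset search, where every candidate subset is scored by the measure selected in the switch above. A rough usage sketch at the classifier level, assuming Weka's weka.classifiers.rules.DecisionTable and its setCrossVal option, with a hypothetical dataset path:

  import weka.classifiers.rules.DecisionTable;
  import weka.core.Instances;
  import weka.core.converters.ConverterUtils.DataSource;

  public class DecisionTableDemo {
    public static void main(String[] args) throws Exception {
      Instances data = DataSource.read("diabetes.arff"); // hypothetical dataset path
      data.setClassIndex(data.numAttributes() - 1);

      DecisionTable dt = new DecisionTable();
      dt.setCrossVal(1); // 1 = leave-one-out, the m_CVFolds == 1 branch above
      dt.buildClassifier(data); // estimatePerformance() scores each subset tried
      System.out.println(dt);
    }
  }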