Пример #1
0
  /**
   * Builds a resampled training set by drawing one instance per weight slot.
   *
   * <p>Each of the {@code w.length} slots is filled by asking a {@link WeightBasedRandom} built
   * from {@code w} for an index and looking that index up in {@code tSet.getInstances()}.
   * Presumably the random source picks index {@code i} with probability proportional to
   * {@code w[i]} (sampling with replacement). TODO confirm against {@code WeightBasedRandom}.
   *
   * @param tSet source training set whose instances are keyed by integer index
   * @param w per-instance selection weights; its length is the size of the returned sample
   * @return a new {@code TrainingSet} wrapping the drawn instances
   */
  public TrainingSet buildTSet(TrainingSet tSet, double[] w) {
    final WeightBasedRandom picker = new WeightBasedRandom(w);
    final Instance[] drawn = new Instance[w.length];
    final Map<Integer, Instance> byIndex = tSet.getInstances();

    int slot = 0;
    while (slot < drawn.length) {
      // If the map has no entry for the drawn index, Map.get leaves this slot null.
      drawn[slot] = byIndex.get(picker.nextInt());
      slot++;
    }

    return new TrainingSet(drawn);
  }
Пример #2
0
  /**
   * Trains the ensemble by repeatedly resampling the original training set.
   *
   * <p>Starts from uniform sampling weights. Then, for {@code classifierPopulation} rounds, it:
   * draws a weighted sample via {@code buildTSet}, trains a fresh base classifier on it, folds
   * that classifier's errors on the original set into the weights via {@code updateWeights}, and
   * appends the classifier to {@code baseClassifiers}. Any previously trained members are
   * discarded first.
   *
   * @return always {@code true}; the result of each base classifier's {@code train()} is not
   *     inspected
   */
  @Override
  public boolean train() {
    baseClassifiers = new ArrayList<Classifier>();

    final int size = originalTSet.getSize();

    // Sampling weights, initially uniform.
    final double[] weights = new double[size];
    Arrays.fill(weights, 1.0 / size);

    // How many current ensemble members misclassified each instance; new int[] starts at zero.
    final int[] missCounts = new int[size];

    for (int round = 0; round < classifierPopulation; round++) {
      if (verbose) {
        System.out.println("Instance weights: " + Arrays.toString(weights));
        System.out.println("Instance misclassifications: " + Arrays.toString(missCounts));
      }

      final Classifier member = getClassifierForTraining(buildTSet(originalTSet, weights));
      member.train();

      // Errors are measured on the original set, not on the resampled one.
      updateWeights(originalTSet, weights, missCounts, member);
      baseClassifiers.add(member);
    }

    return true;
  }
Пример #3
0
  /**
   * Records the latest classifier's mistakes and recomputes the sampling weights in place.
   *
   * <p>For every index {@code i < w.length}, instance {@code tSet.getInstance(i)} counts as
   * misclassified when {@code baseClassifier} returns {@code null} or a concept whose name
   * differs from the instance's own concept name; in that case {@code m[i]} is incremented.
   * Each weight is then set to {@code (1 + m[i]^4) / sum_j (1 + m[j]^4)}. This appears to be
   * Breiman's arc-x4 reweighting rule.
   *
   * @param tSet the training set to evaluate against, indexed {@code 0..w.length-1}
   * @param w sampling weights, overwritten with the new normalized values
   * @param m running per-instance misclassification counts, updated in place
   * @param baseClassifier the classifier most recently added to the ensemble
   */
  public void updateWeights(TrainingSet tSet, double[] w, int[] m, Classifier baseClassifier) {
    final int count = w.length;

    for (int idx = 0; idx < count; idx++) {
      final Instance inst = tSet.getInstance(idx);
      final Concept predicted = baseClassifier.classify(inst);
      final Concept target = inst.getConcept();
      // A null prediction is treated as a miss; names are compared, not Concept identity.
      final boolean hit = predicted != null && predicted.getName().equals(target.getName());
      if (!hit) {
        m[idx]++;
      }
    }

    // Store the unnormalized weight 1 + m^4 and accumulate the normalizer in the same pass.
    double total = 0.0;
    for (int idx = 0; idx < count; idx++) {
      w[idx] = 1.0 + Math.pow(m[idx], 4);
      total += w[idx];
    }

    for (int idx = 0; idx < count; idx++) {
      w[idx] /= total;
    }
  }