Esempio n. 1
0
  /**
   * Builds a new {@link BayBoostModel} from the given ensemble in which the contingency matrix of
   * the last base model is re-estimated from the hold-out examples.
   *
   * <p>All base models except the last one are carried over together with their existing
   * contingency matrices; each of them is applied to the example set and the examples are
   * reweighted with that matrix. The last model is then applied, every example weight is set to 0
   * except for the hold-out examples (whose weights are saved beforehand and written back
   * afterwards), and a fresh contingency matrix for the last model is taken from
   * {@link WeightedPerformanceMeasures} computed on the resulting example set.
   *
   * @param ensemble the ensemble whose last model receives a new contingency matrix
   * @param exampleSet the examples the base models are applied to; weights are modified
   * @param holdOutSet the hold-out examples; every element is cast to {@link Example}
   * @return a new ensemble built from the resulting example set, the collected model infos and the
   *     priors of {@code ensemble}
   * @throws OperatorException declared by the original signature (presumably raised while applying
   *     the models - confirm against the called methods)
   */
  private BayBoostModel retrainLastWeight(
      BayBoostModel ensemble, ExampleSet exampleSet, Vector holdOutSet) throws OperatorException {
    // Return value is ignored; the original author remarks that this method "fits by chance".
    this.prepareExtendedBatch(exampleSet);

    final int lastIndex = ensemble.getNumberOfModels() - 1;
    final Vector<BayBoostBaseModelInfo> rebuiltInfo = new Vector<BayBoostBaseModelInfo>();
    final double[] priors = ensemble.getPriors();

    // Replay all but the last base model: keep its matrix, apply it, reweight the examples.
    ExampleSet current = exampleSet;
    int position = 0;
    while (position < lastIndex) {
      final Model baseModel = ensemble.getModel(position);
      final ContingencyMatrix matrix = ensemble.getContingencyMatrix(position);
      rebuiltInfo.add(new BayBoostBaseModelInfo(baseModel, matrix));
      current = baseModel.apply(current);
      WeightedPerformanceMeasures.reweightExamples(current, matrix, false);
      position++;
    }

    final Model latestModel = ensemble.getModel(lastIndex);
    current = latestModel.apply(current);

    // Remember the hold-out weights, zero every weight, then write the hold-out weights back so
    // that only the hold-out examples carry weight in the performance estimate below.
    final int holdOutCount = holdOutSet.size();
    final double[] savedWeights = new double[holdOutCount];
    for (int i = 0; i < holdOutCount; i++) {
      savedWeights[i] = ((Example) holdOutSet.get(i)).getWeight();
    }
    for (Iterator<Example> all = current.iterator(); all.hasNext(); ) {
      all.next().setWeight(0);
    }
    for (int i = 0; i < holdOutCount; i++) {
      ((Example) holdOutSet.get(i)).setWeight(savedWeights[i]);
    }

    final WeightedPerformanceMeasures holdOutMeasures = new WeightedPerformanceMeasures(current);
    rebuiltInfo.add(new BayBoostBaseModelInfo(latestModel, holdOutMeasures.getContingencyMatrix()));

    return new BayBoostModel(current, rebuiltInfo, priors);
  }
Esempio n. 2
0
  /**
   * Constructs a <code>Model</code> by repeatedly running a weak learner, reweighting the training
   * example set accordingly, and combining the hypotheses using the available weighted performance
   * values.
   *
   * <p>The example set is processed batch by batch (one loop iteration per batch). For every batch
   * the method
   *
   * <ol>
   *   <li>applies the ensembles built so far to the new batch, records the prediction accuracy in
   *       {@code runVector}, and decides which ensemble to continue with,
   *   <li>optionally rescales the label priors of the batch,
   *   <li>trains two candidate ensembles: one that adds a base model trained on the new batch only
   *       ({@code ensembleNewBatch}), and one that drops the last base model and retrains on an
   *       extended batch starting at {@code firstOpenBatch} ({@code ensembleExtBatch}),
   *   <li>if a hold-out fraction is configured, compares both candidates on the hold-out examples
   *       and re-estimates the last model's weight via {@code retrainLastWeight}.
   * </ol>
   *
   * <p>Side effects visible in this method: {@code runVector}, {@code currentIteration} and
   * {@code performance} are overwritten; a stream control attribute is created on (or recycled
   * from) the example set; a weight attribute is prepared if none exists; and
   * {@code restoreOldWeights} is called on the example set before returning.
   *
   * @param exampleSet the complete example stream; it is read sequentially and split into batches
   * @return {@code ensembleExtBatch} if it is non-null after the last batch, otherwise
   *     {@code ensembleNewBatch} (which is {@code null} if the example set has no examples)
   * @throws OperatorException declared by the signature; presumably propagated from the learner,
   *     model application or parameter access - confirm against the called methods
   */
  public Model learn(ExampleSet exampleSet) throws OperatorException {
    this.runVector = new RunVector();
    BayBoostModel ensembleNewBatch = null;
    BayBoostModel ensembleExtBatch = null;
    // Base models and their probability estimates for the "new batch" ensemble.
    final Vector<BayBoostBaseModelInfo> modelInfo = new Vector<BayBoostBaseModelInfo>();
    // Shallow copy of modelInfo used to build the "extended batch" ensemble.
    Vector<BayBoostBaseModelInfo> modelInfo2 = new Vector<BayBoostBaseModelInfo>();
    this.currentIteration = 0;
    // Number of the first batch that belongs to the extended batch.
    int firstOpenBatch = 1;

    // prepare the stream control attribute
    final Attribute streamControlAttribute;
    {
      Attribute attr = null;
      if ((attr = exampleSet.getAttributes().get(STREAM_CONTROL_ATTRIB_NAME)) == null)
        streamControlAttribute =
            com.rapidminer.example.Tools.createSpecialAttribute(
                exampleSet, STREAM_CONTROL_ATTRIB_NAME, Ontology.INTEGER);
      else {
        streamControlAttribute = attr;
        logWarning(
            "Attribute with the (reserved) name of the stream control attribute exists. It is probably an old version created by this operator. Trying to recycle it... ");
        // Resetting the stream control attribute values by overwriting
        // them with 0 avoids (unlikely)
        // problems in case the same ExampleSet is passed to this
        // operator over and over again:
        Iterator<Example> e = exampleSet.iterator();
        while (e.hasNext()) {
          e.next().setValue(streamControlAttribute, 0);
        }
      }
    }

    // and the weight attribute
    if (exampleSet.getAttributes().getWeight() == null) {
      this.prepareWeights(exampleSet);
    }

    // Which ensemble's accuracy is reported for the next batch: set in the retraining phase of
    // one iteration and read in the prediction phase of the following iteration.
    boolean estimateFavoursExtBatch = true;
    // *** The main loop, one iteration per batch: ***
    // The same iterator is handed to prepareBatch in every iteration, so the loop ends once the
    // whole example set has been consumed (presumably prepareBatch advances the reader by one
    // batch and tags the examples with the batch number - confirm against prepareBatch).
    Iterator<Example> reader = exampleSet.iterator();
    while (reader.hasNext()) {
      // increment batch number, collect batch and evaluate performance of
      // current model on batch
      double[] classPriors =
          this.prepareBatch(++this.currentIteration, reader, streamControlAttribute);

      // View on the example set filtered by the current batch number.
      ConditionedExampleSet trainingSet =
          new ConditionedExampleSet(
              exampleSet, new BatchFilterCondition(streamControlAttribute, this.currentIteration));

      final EstimatedPerformance estPerf;

      // Step 1: apply the ensemble model to the current batch (prediction
      // phase), evaluate and store result
      // NOTE(review): the casts to ConditionedExampleSet below assume that apply() returns the
      // same kind of example set it was given - confirm against BayBoostModel.apply.
      if (ensembleExtBatch != null) {
        // apply extended batch model first:
        trainingSet = (ConditionedExampleSet) ensembleExtBatch.apply(trainingSet);
        // unweighted performance
        this.performance = evaluatePredictions(trainingSet);

        // then apply new batch model:
        trainingSet = (ConditionedExampleSet) ensembleNewBatch.apply(trainingSet);
        double newBatchPerformance = evaluatePredictions(trainingSet);

        // heuristic: use extended batch model for predicting
        // unclassified instances
        if (estimateFavoursExtBatch == true)
          estPerf =
              new EstimatedPerformance("accuracy", this.performance, trainingSet.size(), false);
        else
          estPerf =
              new EstimatedPerformance("accuracy", newBatchPerformance, trainingSet.size(), false);

        // final double[] ensembleWeights;

        // continue with the better model:
        if (newBatchPerformance > this.performance) {
          // New-batch ensemble wins: keep modelInfo as is and move the start of the extended
          // batch to the previous batch.
          this.performance = newBatchPerformance;
          firstOpenBatch = Math.max(1, this.currentIteration - 1);
          // ensembleWeights = ensembleNewBatch.getModelWeights();
        } else {
          // Extended-batch ensemble wins (or ties): replace modelInfo's content by modelInfo2's.
          modelInfo.clear();
          modelInfo.addAll(modelInfo2);
          // ensembleWeights = ensembleExtBatch.getModelWeights();
        }

      } else if (ensembleNewBatch != null) {
        // Only the new-batch ensemble exists: evaluate it alone.
        trainingSet = (ConditionedExampleSet) ensembleNewBatch.apply(trainingSet);
        this.performance = evaluatePredictions(trainingSet);
        firstOpenBatch = Math.max(1, this.currentIteration - 1);
        estPerf = new EstimatedPerformance("accuracy", this.performance, trainingSet.size(), false);
      } else estPerf = null; // no model ==> no prediction performance

      // Record the accuracy estimate of this batch, if there is one.
      if (estPerf != null) {
        PerformanceVector perf = new PerformanceVector();
        perf.addAveragable(estPerf);
        this.runVector.addVector(perf);
      }

      // *** retraining phase ***
      // Step 2: First reconstruct the initial weighting, if necessary.
      if (this.getParameterAsBoolean(PARAMETER_RESCALE_LABEL_PRIORS) == true) {
        this.rescalePriors(trainingSet, classPriors);
      }

      estimateFavoursExtBatch = true;
      // Step 3: Find better weights for existing models and continue
      // training
      if (modelInfo.size() > 0) {

        // Shallow copy of the current model list. The original author notes that
        // BayBoostBaseModelInfo objects cannot be changed, so no deep copy is required.
        modelInfo2 = new Vector<BayBoostBaseModelInfo>();
        for (BayBoostBaseModelInfo bbbmi : modelInfo) {
          modelInfo2.add(bbbmi);
        }

        // separate hold out set
        final double holdOutRatio = this.getParameterAsDouble(PARAMETER_FRACTION_HOLD_OUT_SET);
        Vector<Example> holdOutExamples = new Vector<Example>();
        if (holdOutRatio > 0) {
          // Each example of the batch is drawn into the hold-out set with probability
          // holdOutRatio. Its stream control value is set to 0 for the duration of training
          // (presumably so that the batch filters no longer match it - confirm against
          // BatchFilterCondition) and restored to the batch number further below.
          RandomGenerator random = RandomGenerator.getRandomGenerator(this);
          Iterator<Example> randBatchReader = trainingSet.iterator();
          while (randBatchReader.hasNext()) {
            Example example = randBatchReader.next();
            if (random.nextDoubleInRange(0, 1) <= holdOutRatio) {
              example.setValue(streamControlAttribute, 0);
              holdOutExamples.add(example);
            }
          }
          // TODO: create new example set
          // trainingSet.updateCondition();
        }

        // model 1: train one more base classifier
        boolean trainingExamplesLeft = this.adjustBaseModelWeights(trainingSet, modelInfo);
        if (trainingExamplesLeft) {
          // "trainingExamplesLeft" needs to be checked to avoid
          // exceptions.
          // Anyway, learning does not make sense, otherwise.
          // NOTE(review): the result of trainAdditionalModel is tested but the branch is empty,
          // i.e. a failed training attempt is ignored here.
          if (!this.trainAdditionalModel(trainingSet, modelInfo)) {}
        }
        ensembleNewBatch = new BayBoostModel(exampleSet, modelInfo, classPriors);

        // model 2: remove last classifier, extend batch, train on
        // extended batch
        // Original author's note: because of the ">=" condition it is sufficient to remember the
        // opening batch (the condition itself lives in BatchFilterCondition, not visible here).
        ExampleSet extendedBatch =
            new ConditionedExampleSet(
                exampleSet, new BatchFilterCondition(streamControlAttribute, firstOpenBatch));
        // From here on classPriors refers to the extended batch.
        classPriors = this.prepareExtendedBatch(extendedBatch);
        if (this.getParameterAsBoolean(PARAMETER_RESCALE_LABEL_PRIORS) == true) {
          this.rescalePriors(extendedBatch, classPriors);
        }
        // Safe: modelInfo2 was copied from the non-empty modelInfo above.
        modelInfo2.remove(modelInfo2.size() - 1);
        trainingExamplesLeft = this.adjustBaseModelWeights(extendedBatch, modelInfo2);
        // If no training examples are left: no need and chance to
        // continue training.
        if (trainingExamplesLeft == false) {
          ensembleExtBatch = new BayBoostModel(exampleSet, modelInfo2, classPriors);
        } else {
          boolean success = this.trainAdditionalModel(extendedBatch, modelInfo2);
          if (success) {
            ensembleExtBatch = new BayBoostModel(exampleSet, modelInfo2, classPriors);
          } else {
            // Training on the extended batch failed: fall back to the new-batch ensemble only.
            ensembleExtBatch = null;
            estimateFavoursExtBatch = false;
          }
        }

        if (holdOutRatio > 0) {
          // Put the hold-out examples back into the current batch.
          Iterator hoEit = holdOutExamples.iterator();
          while (hoEit.hasNext()) {
            ((Example) hoEit.next()).setValue(streamControlAttribute, this.currentIteration);
          }
          // TODO: create new example set
          // trainingSet.updateCondition();

          if (ensembleExtBatch != null) {
            // Count misclassified hold-out examples for the new-batch ensemble ...
            // NOTE(review): if no example was drawn into the hold-out set, both error rates
            // below are 0.0 / 0 = NaN, the "<=" comparison is false and the new-batch ensemble
            // is retrained with an empty hold-out set - confirm this is acceptable.
            trainingSet = (ConditionedExampleSet) ensembleNewBatch.apply(trainingSet);
            hoEit = holdOutExamples.iterator();
            int errors = 0;
            while (hoEit.hasNext()) {
              Example example = (Example) hoEit.next();
              if (example.getPredictedLabel() != example.getLabel()) errors++;
            }
            double newBatchErr = ((double) errors) / holdOutExamples.size();

            // ... and for the extended-batch ensemble.
            trainingSet = (ConditionedExampleSet) ensembleExtBatch.apply(trainingSet);
            hoEit = holdOutExamples.iterator();
            errors = 0;
            while (hoEit.hasNext()) {
              Example example = (Example) hoEit.next();
              if (example.getPredictedLabel() != example.getLabel()) errors++;
            }
            double extBatchErr = ((double) errors) / holdOutExamples.size();

            // Ties favour the extended-batch ensemble.
            estimateFavoursExtBatch = (extBatchErr <= newBatchErr);

            // Re-estimate the last model's weight of the favoured ensemble on the hold-out set.
            if (estimateFavoursExtBatch) {
              ensembleExtBatch =
                  this.retrainLastWeight(ensembleExtBatch, trainingSet, holdOutExamples);
            } else
              ensembleNewBatch =
                  this.retrainLastWeight(ensembleNewBatch, trainingSet, holdOutExamples);
          } else
            ensembleNewBatch =
                this.retrainLastWeight(ensembleNewBatch, trainingSet, holdOutExamples);
        }
      } else {
        // No base model yet: train the first one on the current batch; the return value of
        // trainAdditionalModel is not checked here.
        this.trainAdditionalModel(trainingSet, modelInfo);
        ensembleNewBatch = new BayBoostModel(exampleSet, modelInfo, classPriors);
        ensembleExtBatch = null;
        estimateFavoursExtBatch = false;
      }
    }
    this.restoreOldWeights(exampleSet);
    // Prefer the extended-batch ensemble whenever one exists after the last batch.
    return (ensembleExtBatch == null ? ensembleNewBatch : ensembleExtBatch);
  }