/* * Uses the training set (current batch) to update current model stored in * modelInfo. */ private boolean trainAdditionalModel( ExampleSet trainingSet, Vector<BayBoostBaseModelInfo> modelInfo) throws OperatorException { Model model = this.trainBaseModel(trainingSet); trainingSet = model.apply(trainingSet); // get the weighted performance value of the example set with // respect to the new model WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(trainingSet); // debugMessage(wp); if (this.isModelUseful(wp.getContingencyMatrix()) == false) { // If the model is not considered to be useful then discard it. log("Discard model because of low advantage on training data."); return false; } else { // Add the new model and its weights to the collection of // models: modelInfo.add(new BayBoostBaseModelInfo(model, wp.getContingencyMatrix())); return true; } }
private BayBoostModel retrainLastWeight( BayBoostModel ensemble, ExampleSet exampleSet, Vector holdOutSet) throws OperatorException { this.prepareExtendedBatch(exampleSet); // method fits by chance int modelNum = ensemble.getNumberOfModels(); Vector<BayBoostBaseModelInfo> modelInfo = new Vector<BayBoostBaseModelInfo>(); double[] priors = ensemble.getPriors(); for (int i = 0; i < modelNum - 1; i++) { Model model = ensemble.getModel(i); ContingencyMatrix cm = ensemble.getContingencyMatrix(i); modelInfo.add(new BayBoostBaseModelInfo(model, cm)); exampleSet = model.apply(exampleSet); WeightedPerformanceMeasures.reweightExamples(exampleSet, cm, false); } Model latestModel = ensemble.getModel(modelNum - 1); exampleSet = latestModel.apply(exampleSet); // quite ugly: double[] weights = new double[holdOutSet.size()]; Iterator it = holdOutSet.iterator(); int index = 0; while (it.hasNext()) { Example example = (Example) it.next(); weights[index++] = example.getWeight(); } Iterator<Example> reader = exampleSet.iterator(); while (reader.hasNext()) { reader.next().setWeight(0); } it = holdOutSet.iterator(); index = 0; while (it.hasNext()) { Example example = (Example) it.next(); example.setWeight(weights[index++]); } WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(exampleSet); modelInfo.add(new BayBoostBaseModelInfo(latestModel, wp.getContingencyMatrix())); return new BayBoostModel(exampleSet, modelInfo, priors); }
/** * This helper method takes as input the traing set and the set of models trained so far. It * re-estimates the model weights one by one, which means that it changes the contents of the * modelInfo container. Works with crisp base classifiers, only! * * @param exampleSet the training set to be used to tune the weights * @param modelInfo the <code>Vector</code> of <code>Model</code>s, each with its biasMatrix * @return <code>true</code> iff the <code>ExampleSet</code> contains at least one example that is * not yet explained deterministically (otherwise: nothing left to learn) */ private boolean adjustBaseModelWeights( ExampleSet exampleSet, Vector<BayBoostBaseModelInfo> modelInfo) throws OperatorException { for (int j = 0; j < modelInfo.size(); j++) { BayBoostBaseModelInfo consideredModelInfo = modelInfo.get(j); Model consideredModel = consideredModelInfo.getModel(); ContingencyMatrix cm = consideredModelInfo.getContingencyMatrix(); // double[][] oldBiasMatrix = (double[][]) consideredModelInfo[1]; BayBoostStream.createOrReplacePredictedLabelFor(exampleSet, consideredModel); exampleSet = consideredModel.apply(exampleSet); if (exampleSet.getAttributes().getPredictedLabel().isNominal() == false) { // Only the case of nominal base classifiers is supported! throw new UserError( this, 101, new Object[] {exampleSet.getAttributes().getLabel(), "BayBoostStream base learners"}); } WeightedPerformanceMeasures wp = new WeightedPerformanceMeasures(exampleSet); ContingencyMatrix cmNew = wp.getContingencyMatrix(); // double[][] newBiasMatrix = wp.createLiftRatioMatrix(); if (isModelUseful(cm) == false) { modelInfo.remove(j); j--; log("Discard base model because of low advantage."); } else { consideredModelInfo = new BayBoostBaseModelInfo(consideredModel, cmNew); modelInfo.set(j, consideredModelInfo); boolean stillUncoveredExamples = (WeightedPerformanceMeasures.reweightExamples(exampleSet, cmNew, false) > 0); if (stillUncoveredExamples == false) return false; } } return true; }