/**
 * Produces a regression estimate for a single data point.
 *
 * <p>Each base regressor is asked for its prediction on {@code data}; the predictions are stored,
 * in base-regressor order, in a dense vector that is wrapped in a new {@link DataPoint} and handed
 * to the aggregating regressor, whose output is returned.
 *
 * @param data the data point to make a prediction for
 * @return the value returned by the aggregating regressor for the vector of base predictions
 */
@Override
public double regress(DataPoint data) {
  final int modelCount = baseRegressors.size();
  Vec basePredictions = new DenseVector(modelCount);

  int slot = 0;
  for (Regressor base : baseRegressors) {
    basePredictions.set(slot, base.regress(data));
    slot++;
  }

  DataPoint metaPoint = new DataPoint(basePredictions);
  return aggregatingRegressor.regress(metaPoint);
}
/** * Copy constructor * * @param toCopy the object to copy */ public Stacking(Stacking toCopy) { this.folds = toCopy.folds; this.weightsPerModel = toCopy.weightsPerModel; if (toCopy.aggregatingClassifier != null) { this.aggregatingClassifier = toCopy.aggregatingClassifier.clone(); this.baseClassifiers = new ArrayList<Classifier>(toCopy.baseClassifiers.size()); for (Classifier bc : toCopy.baseClassifiers) this.baseClassifiers.add(bc.clone()); if (toCopy.aggregatingRegressor == toCopy.aggregatingClassifier) // supports both { aggregatingRegressor = (Regressor) aggregatingClassifier; baseRegressors = (List) baseClassifiers; // ugly type easure exploitation... } } else // we are doing with regressors only { this.aggregatingRegressor = toCopy.aggregatingRegressor.clone(); this.baseRegressors = new ArrayList<Regressor>(toCopy.baseRegressors.size()); for (Regressor br : toCopy.baseRegressors) this.baseRegressors.add(br.clone()); } }
@Override public void train(RegressionDataSet dataSet, ExecutorService threadPool) { final int models = baseRegressors.size(); weightsPerModel = 1; RegressionDataSet metaSet = new RegressionDataSet(models, new CategoricalData[0]); List<RegressionDataSet> dataFolds = dataSet.cvSet(folds); // iterate in the order of the folds so we get the right dataum weights for (RegressionDataSet rds : dataFolds) for (int i = 0; i < rds.getSampleSize(); i++) metaSet.addDataPoint( new DataPoint( new DenseVector(weightsPerModel * models), rds.getDataPoint(i).getWeight()), rds.getTargetValue(i)); // create the meta training set for (int c = 0; c < baseRegressors.size(); c++) { Regressor reg = baseRegressors.get(c); int pos = 0; for (int f = 0; f < dataFolds.size(); f++) { RegressionDataSet train = RegressionDataSet.comineAllBut(dataFolds, f); RegressionDataSet test = dataFolds.get(f); if (threadPool == null) reg.train(train); else reg.train(train, threadPool); for (int i = 0; i < test.getSampleSize(); i++) // evaluate and mark each point in the held out fold. { double pred = reg.regress(test.getDataPoint(i)); metaSet.getDataPoint(pos++).getNumericalValues().set(c, pred); } } } // train the meta model if (threadPool == null) aggregatingRegressor.train(metaSet); else aggregatingRegressor.train(metaSet, threadPool); // train the final classifiers, unless folds=1. In that case they are already trained if (folds != 1) { for (Regressor reg : baseRegressors) if (threadPool == null) reg.train(dataSet); else reg.train(dataSet, threadPool); } }
/**
 * Reports whether weighted data is supported by delegating to the aggregating model: the
 * aggregating classifier when one is set, otherwise the aggregating regressor.
 *
 * @return the answer given by the aggregating model's own {@code supportsWeightedData()}
 */
@Override
public boolean supportsWeightedData() {
  return aggregatingClassifier == null
      ? aggregatingRegressor.supportsWeightedData()
      : aggregatingClassifier.supportsWeightedData();
}
@Override public void train(final RegressionDataSet dataSet, final ExecutorService threadPool) { final PriorityQueue<RegressionModelEvaluation> bestModels = new PriorityQueue<RegressionModelEvaluation>( folds, new Comparator<RegressionModelEvaluation>() { @Override public int compare(RegressionModelEvaluation t, RegressionModelEvaluation t1) { double v0 = t.getScoreStats(regressionTargetScore).getMean(); double v1 = t1.getScoreStats(regressionTargetScore).getMean(); int order = regressionTargetScore.lowerIsBetter() ? 1 : -1; return order * Double.compare(v0, v1); } }); /** * Use this to keep track of which parameter we are altering. Index correspondence to the * parameter, and its value corresponds to which value has been used. Increment and carry counts * to iterate over all possible combinations. */ int[] setTo = new int[searchParams.size()]; /** * Each model is set to have different combination of parameters. We then train each model to * determine the best one. */ final List<Regressor> paramsToEval = new ArrayList<Regressor>(); while (true) { setParameters(setTo); paramsToEval.add(baseRegressor.clone()); if (incrementCombination(setTo)) break; } /* * This is the Executor used for training the models in parallel. If we * are not supposed to do that, it will be an executor that executes * them sequentually. 
*/ final ExecutorService modelService; if (trainModelsInParallel) modelService = threadPool; else modelService = new FakeExecutor(); final CountDownLatch latch; // used for stopping in both cases // if we are doing our CV splits ahead of time, get them done now final List<RegressionDataSet> preFolded; /** Pre-combine our training combinations so that any caching can be re-used */ final List<RegressionDataSet> trainCombinations; if (reuseSameCVFolds) { preFolded = dataSet.cvSet(folds); trainCombinations = new ArrayList<RegressionDataSet>(preFolded.size()); for (int i = 0; i < preFolded.size(); i++) trainCombinations.add(RegressionDataSet.comineAllBut(preFolded, i)); } else { preFolded = null; trainCombinations = null; } boolean considerWarm = useWarmStarts && baseRegressor instanceof WarmRegressor; /** * make sure we don't do a warm start if its only supported when trained on the same data but we * aren't reuse-ing the same CV splits So we get the truth table * * <p>a | b | (a&&b)||¬a T | T | T T | F | F F | T | T F | F | T * * <p>where a = warmFromSameDataOnly and b = reuseSameSplit So we can instead use ¬ a || b */ if (considerWarm && (!((WarmRegressor) baseRegressor).warmFromSameDataOnly() || reuseSameCVFolds)) { /* we want all of the first parameter (which is the warm paramter, * taken care of for us) values done in a group. 
So We can get this * by just dividing up the larger list into sub lists, each sub list * is adjacent in the original and is the number of parameter values * we wanted to try */ int stepSize = searchValues.get(0).size(); int totalJobs = paramsToEval.size() / stepSize; latch = new CountDownLatch(totalJobs); for (int startPos = 0; startPos < paramsToEval.size(); startPos += stepSize) { final List<Regressor> subSet = paramsToEval.subList(startPos, startPos + stepSize); modelService.submit( new Runnable() { @Override public void run() { Regressor[] prevModels = null; for (Regressor r : subSet) { RegressionModelEvaluation rme = trainModelsInParallel ? new RegressionModelEvaluation(r, dataSet) : new RegressionModelEvaluation(r, dataSet, threadPool); rme.setKeepModels(true); // we need these to do warm starts! rme.setWarmModels(prevModels); rme.addScorer(regressionTargetScore.clone()); if (reuseSameCVFolds) rme.evaluateCrossValidation(preFolded, trainCombinations); else rme.evaluateCrossValidation(folds); prevModels = rme.getKeptModels(); synchronized (bestModels) { bestModels.add(rme); } } latch.countDown(); } }); } } else // regular CV, train a new model from scratch at every step { latch = new CountDownLatch(paramsToEval.size()); for (final Regressor toTrain : paramsToEval) { modelService.submit( new Runnable() { @Override public void run() { RegressionModelEvaluation rme = trainModelsInParallel ? new RegressionModelEvaluation(toTrain, dataSet) : new RegressionModelEvaluation(toTrain, dataSet, threadPool); rme.addScorer(regressionTargetScore.clone()); if (reuseSameCVFolds) rme.evaluateCrossValidation(preFolded, trainCombinations); else rme.evaluateCrossValidation(folds); synchronized (bestModels) { bestModels.add(rme); } latch.countDown(); } }); } } try { latch.await(); // Now we know the best classifier, we need to train one on the whole data set. 
Regressor bestRegressor = bestModels.peek().getRegressor(); // Just re-train it on the whole set if (trainFinalModel) { // try and warm start the final model if we can if (useWarmStarts && bestRegressor instanceof WarmRegressor && !((WarmRegressor) bestRegressor) .warmFromSameDataOnly()) // last line here needed to make sure we can do this warm // train { WarmRegressor wr = (WarmRegressor) bestRegressor; if (threadPool instanceof FakeExecutor) wr.train(dataSet, wr.clone()); else wr.train(dataSet, wr.clone(), threadPool); } else { if (threadPool instanceof FakeExecutor) bestRegressor.train(dataSet); else bestRegressor.train(dataSet, threadPool); } } trainedRegressor = bestRegressor; } catch (InterruptedException ex) { Logger.getLogger(GridSearch.class.getName()).log(Level.SEVERE, null, ex); } }