示例#1
0
  /**
   * Trains the stacked ensemble. Each base classifier is evaluated fold-by-fold, its predictions
   * on the held-out fold become the features of a meta data set, the aggregating classifier is
   * trained on that meta data set, and finally the base classifiers are trained on the full data
   * set.
   *
   * @param dataSet the data set to learn from
   * @param threadPool the source of threads for training, or {@code null} to use the
   *     single-argument {@code trainC} overloads
   */
  @Override
  public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
    final int models = baseClassifiers.size();
    final int C = dataSet.getClassSize();
    // binary problems use a single meta feature per model, otherwise one per class
    weightsPerModel = C == 2 ? 1 : C;
    final int metaFeatures = weightsPerModel * models;
    ClassificationDataSet metaSet =
        new ClassificationDataSet(metaFeatures, new CategoricalData[0], dataSet.getPredicting());

    List<ClassificationDataSet> dataFolds = dataSet.cvSet(folds);
    // iterate in the order of the folds so we get the right datum weights
    for (ClassificationDataSet cds : dataFolds) {
      for (int i = 0; i < cds.getSampleSize(); i++) {
        metaSet.addDataPoint(
            new DenseVector(metaFeatures),
            cds.getDataPointCategory(i),
            cds.getDataPoint(i).getWeight());
      }
    }

    // create the meta training set
    for (int c = 0; c < models; c++) {
      Classifier cl = baseClassifiers.get(c);
      // pos walks metaSet in the same fold order used to populate it above
      int pos = 0;
      for (int f = 0; f < dataFolds.size(); f++) {
        ClassificationDataSet test = dataFolds.get(f);
        // With a single fold, excluding fold f would presumably leave nothing to train on
        // (comineAllBut skips the excluded fold), so train on the fold itself. This is what
        // allows the final retraining step below to be skipped when folds == 1.
        ClassificationDataSet train =
            dataFolds.size() == 1 ? test : ClassificationDataSet.comineAllBut(dataFolds, f);
        if (threadPool == null) {
          cl.trainC(train);
        } else {
          cl.trainC(train, threadPool);
        }

        // evaluate and mark each point in the held out fold.
        for (int i = 0; i < test.getSampleSize(); i++) {
          CategoricalResults pred = cl.classify(test.getDataPoint(i));
          Vec toSet = metaSet.getDataPoint(pos).getNumericalValues();
          if (C == 2) {
            // map the probability of class 0 from [0, 1] onto [-1, 1]
            toSet.set(c, pred.getProb(0) * 2 - 1);
          } else {
            // model c owns the slots [weightsPerModel * c, weightsPerModel * (c + 1))
            final int offset = weightsPerModel * c;
            for (int j = offset; j < offset + weightsPerModel; j++) {
              toSet.set(j, pred.getProb(j - offset));
            }
          }

          pos++;
        }
      }
    }

    // train the meta model
    if (threadPool == null) {
      aggregatingClassifier.trainC(metaSet);
    } else {
      aggregatingClassifier.trainC(metaSet, threadPool);
    }

    // train the final classifiers, unless folds=1. In that case they are already trained
    if (folds != 1) {
      for (Classifier cl : baseClassifiers) {
        if (threadPool == null) {
          cl.trainC(dataSet);
        } else {
          cl.trainC(dataSet, threadPool);
        }
      }
    }
  }
示例#2
0
  /**
   * Runs the grid search: one clone of the base classifier is created per parameter combination,
   * each is scored with cross validation, and the classifier with the best mean target score is
   * stored as the trained classifier (re-trained on the whole data set when {@code
   * trainFinalModel} is set).
   *
   * <p>If the calling thread is interrupted while waiting for the evaluations, the interruption
   * is logged, the thread's interrupt status is restored, and the method returns without updating
   * the trained classifier.
   *
   * @param dataSet the data set to search and train on
   * @param threadPool the executor used either to evaluate models in parallel or, when models are
   *     evaluated one at a time, handed to each evaluation
   * @throws IllegalStateException if every submitted evaluation finished without producing a
   *     result
   */
  @Override
  public void trainC(final ClassificationDataSet dataSet, final ExecutorService threadPool) {
    // Ordered so that the evaluation with the best mean score sits at the head of the queue.
    final PriorityQueue<ClassificationModelEvaluation> bestModels =
        new PriorityQueue<ClassificationModelEvaluation>(
            folds,
            new Comparator<ClassificationModelEvaluation>() {
              @Override
              public int compare(
                  ClassificationModelEvaluation t, ClassificationModelEvaluation t1) {
                double v0 = t.getScoreStats(classificationTargetScore).getMean();
                double v1 = t1.getScoreStats(classificationTargetScore).getMean();
                // flip the natural order when a higher score is better
                int order = classificationTargetScore.lowerIsBetter() ? 1 : -1;
                return order * Double.compare(v0, v1);
              }
            });

    /*
     * Use this to keep track of which parameter we are altering. Index correspondence to the
     * parameter, and its value corresponds to which value has been used. Increment and carry
     * counts to iterate over all possible combinations.
     */
    int[] setTo = new int[searchParams.size()];

    /*
     * Each model is set to have different combination of parameters. We then train each model to
     * determine the best one.
     */
    final List<Classifier> paramsToEval = new ArrayList<Classifier>();

    // the loop ends once incrementCombination reports true
    while (true) {
      setParameters(setTo);

      paramsToEval.add(baseClassifier.clone());

      if (incrementCombination(setTo)) {
        break;
      }
    }

    /*
     * This is the Executor used for training the models in parallel. If we
     * are not supposed to do that, it will be an executor that executes
     * them sequentially.
     */
    final ExecutorService modelService;
    if (trainModelsInParallel) {
      modelService = threadPool;
    } else {
      modelService = new FakeExecutor();
    }

    final CountDownLatch latch; // used for stopping in both cases

    // if we are doing our CV splits ahead of time, get them done now
    final List<ClassificationDataSet> preFolded;

    // Pre-combine our training combinations so that any caching can be re-used
    final List<ClassificationDataSet> trainCombinations;

    if (reuseSameCVFolds) {
      preFolded = dataSet.cvSet(folds);
      trainCombinations = new ArrayList<ClassificationDataSet>(preFolded.size());
      for (int i = 0; i < preFolded.size(); i++) {
        trainCombinations.add(ClassificationDataSet.comineAllBut(preFolded, i));
      }
    } else {
      preFolded = null;
      trainCombinations = null;
    }

    boolean considerWarm = useWarmStarts && baseClassifier instanceof WarmClassifier;

    /*
     * Make sure we don't do a warm start if it is only supported when trained on the same data
     * but we aren't re-using the same CV splits. So we get the truth table
     *
     *   a | b | (a&&b)||!a
     *   T | T | T
     *   T | F | F
     *   F | T | T
     *   F | F | T
     *
     * where a = warmFromSameDataOnly and b = reuseSameSplit. So we can instead use !a || b
     */
    if (considerWarm
        && (!((WarmClassifier) baseClassifier).warmFromSameDataOnly() || reuseSameCVFolds)) {
      /* we want all of the first parameter (which is the warm parameter,
       * taken care of for us) values done in a group. So We can get this
       * by just dividing up the larger list into sub lists, each sub list
       * is adjacent in the original and is the number of parameter values
       * we wanted to try
       */

      int stepSize = searchValues.get(0).size();
      int totalJobs = paramsToEval.size() / stepSize;
      latch = new CountDownLatch(totalJobs);
      for (int startPos = 0; startPos < paramsToEval.size(); startPos += stepSize) {
        final List<Classifier> subSet = paramsToEval.subList(startPos, startPos + stepSize);
        modelService.submit(
            new Runnable() {

              @Override
              public void run() {
                try {
                  Classifier[] prevModels = null;

                  for (Classifier c : subSet) {
                    ClassificationModelEvaluation cme =
                        trainModelsInParallel
                            ? new ClassificationModelEvaluation(c, dataSet)
                            : new ClassificationModelEvaluation(c, dataSet, threadPool);
                    cme.setKeepModels(true); // we need these to do warm starts!
                    cme.setWarmModels(prevModels);
                    cme.addScorer(classificationTargetScore.clone());
                    if (reuseSameCVFolds) {
                      cme.evaluateCrossValidation(preFolded, trainCombinations);
                    } else {
                      cme.evaluateCrossValidation(folds);
                    }
                    prevModels = cme.getKeptModels();
                    synchronized (bestModels) {
                      bestModels.add(cme);
                    }
                  }
                } finally {
                  // always count down so a failing evaluation cannot leave the caller blocked
                  // on latch.await() forever
                  latch.countDown();
                }
              }
            });
      }
    } else { // regular CV, train a new model from scratch at every step
      latch = new CountDownLatch(paramsToEval.size());

      for (final Classifier toTrain : paramsToEval) {

        modelService.submit(
            new Runnable() {

              @Override
              public void run() {
                try {
                  ClassificationModelEvaluation cme =
                      trainModelsInParallel
                          ? new ClassificationModelEvaluation(toTrain, dataSet)
                          : new ClassificationModelEvaluation(toTrain, dataSet, threadPool);
                  cme.addScorer(classificationTargetScore.clone());
                  if (reuseSameCVFolds) {
                    cme.evaluateCrossValidation(preFolded, trainCombinations);
                  } else {
                    cme.evaluateCrossValidation(folds);
                  }
                  synchronized (bestModels) {
                    bestModels.add(cme);
                  }
                } finally {
                  // always count down so a failing evaluation cannot leave the caller blocked
                  // on latch.await() forever
                  latch.countDown();
                }
              }
            });
      }
    }

    // now wait for everyone to finish
    try {
      latch.await();
      if (bestModels.isEmpty()) {
        // only reachable when no evaluation managed to add a result
        throw new IllegalStateException(
            "No parameter combination could be evaluated; every evaluation failed");
      }
      // Now we know the best classifier, we need to train one on the whole data set.
      Classifier bestClassifier =
          bestModels.peek().getClassifier(); // Just re-train it on the whole set
      if (trainFinalModel) {
        // try and warm start the final model if we can; the warmFromSameDataOnly check is
        // needed to make sure we can do this warm train
        if (useWarmStarts
            && bestClassifier instanceof WarmClassifier
            && !((WarmClassifier) bestClassifier).warmFromSameDataOnly()) {
          WarmClassifier wc = (WarmClassifier) bestClassifier;
          if (threadPool instanceof FakeExecutor) {
            wc.trainC(dataSet, wc.clone());
          } else {
            wc.trainC(dataSet, wc.clone(), threadPool);
          }
        } else {
          if (threadPool instanceof FakeExecutor) {
            bestClassifier.trainC(dataSet);
          } else {
            bestClassifier.trainC(dataSet, threadPool);
          }
        }
      }
      trainedClassifier = bestClassifier;

    } catch (InterruptedException ex) {
      // restore the interrupt status so callers further up the stack can still observe it
      Thread.currentThread().interrupt();
      Logger.getLogger(GridSearch.class.getName()).log(Level.SEVERE, null, ex);
    }
  }