@Test
  public void should_classify_instance_using_classifier() {
    /*
     * Despite the method name, this exercises RegressionQuery: one stored regressor, looked up
     * by name, must produce one prediction per input tuple, in input order.
     */

    // Given: a query bound to a named regressor and two tuples with distinct feature arrays
    Double firstExpected = 2.5;
    Double secondExpected = 12.8;

    String regressorName = "TestLearner";
    RegressionQuery query = new RegressionQuery(regressorName);

    double[] firstFeatures = new double[10];
    double[] secondFeatures = new double[10];
    TridentTuple firstTuple = createMockedInstanceTuple(firstFeatures);
    TridentTuple secondTuple = createMockedInstanceTuple(secondFeatures);
    List<TridentTuple> tuples = Arrays.asList(firstTuple, secondTuple);

    // The regressor answers by array identity, so each tuple maps to its own prediction.
    Regressor storedRegressor = mock(Regressor.class);
    given(storedRegressor.predict(same(firstFeatures))).willReturn(firstExpected);
    given(storedRegressor.predict(same(secondFeatures))).willReturn(secondExpected);

    // The state holds exactly one entry, keyed by the regressor's name.
    List<Object> stateKey = asList((Object) regressorName);
    List<List<Object>> stateKeys = asList(stateKey);
    MapState<Regressor> state = mock(MapState.class);
    given(state.multiGet(stateKeys)).willReturn(Arrays.asList(storedRegressor));

    // When
    List<Double> predictions = query.batchRetrieve(state, tuples);

    // Then
    assertEquals(2, predictions.size());
    assertEquals(firstExpected, predictions.get(0));
    assertEquals(secondExpected, predictions.get(1));
  }
 /**
  * Trains this model on {@code dataset}: first the {@code transformation}, then the
  * {@code regressor} on the output of that just-trained transformation.
  *
  * @param dataset the raw (untransformed) training data
  */
 @Override
 public void train(Dataset dataset) {
   // Fit the transformation first; transform() below is called only after this training step.
   transformation.train(dataset);
   regressor.train(transformation.transform(dataset));
 }
 /**
  * Predicts a value for {@code features} by applying {@code transformation} and passing the
  * result to {@code regressor}.
  *
  * @param features the raw (untransformed) input features
  * @return the regressor's prediction for the transformed features
  */
 @Override
 public double regress(FeatureVector features) {
   return regressor.regress(transformation.transform(features));
 }
Example #4
0
  /**
   * Grid-searches every combination of the configured search parameters. Each combination is
   * applied, a clone of {@code baseRegressor} is cross-validated, and the one that scores best on
   * {@code regressionTargetScore} is kept. When {@code trainFinalModel} is set, the winner is
   * re-trained on all of {@code dataSet}, warm-started when the regressor allows it. The winner
   * is stored in {@code trainedRegressor}.
   *
   * <p>Every evaluation job counts down the completion latch in a {@code finally} block. A
   * failing candidate therefore cannot leave this method blocked forever in
   * {@link CountDownLatch#await()}. Instead, the first such failure is re-thrown once all jobs
   * have finished.
   *
   * <p>If the thread is interrupted while waiting, the interruption is logged, the interrupt
   * status is restored, and {@code trainedRegressor} is left unchanged.
   *
   * @param dataSet the data used for cross validation and for training the final model
   * @param threadPool when {@code trainModelsInParallel} is set, candidates are evaluated
   *     concurrently on this pool; otherwise they are evaluated one after another and each
   *     evaluation is handed this pool instead
   * @throws RuntimeException wrapping the first failure thrown while evaluating a candidate
   */
  @Override
  public void train(final RegressionDataSet dataSet, final ExecutorService threadPool) {
    // The head of the queue is always the best model seen so far: ascending order when lower
    // scores are better, descending otherwise.
    final PriorityQueue<RegressionModelEvaluation> bestModels =
        new PriorityQueue<RegressionModelEvaluation>(
            folds,
            new Comparator<RegressionModelEvaluation>() {
              @Override
              public int compare(RegressionModelEvaluation t, RegressionModelEvaluation t1) {
                double v0 = t.getScoreStats(regressionTargetScore).getMean();
                double v1 = t1.getScoreStats(regressionTargetScore).getMean();
                int order = regressionTargetScore.lowerIsBetter() ? 1 : -1;
                return order * Double.compare(v0, v1);
              }
            });

    /*
     * Tracks which value each search parameter is currently set to. Index i corresponds to the
     * i-th parameter, and its value to which of that parameter's values is in use. Increment and
     * carry the counts to iterate over all possible combinations.
     */
    int[] setTo = new int[searchParams.size()];
    /*
     * One model per combination of parameters; each is evaluated below to determine the best.
     */
    final List<Regressor> paramsToEval = new ArrayList<Regressor>();

    while (true) {
      setParameters(setTo);

      paramsToEval.add(baseRegressor.clone());

      if (incrementCombination(setTo)) break;
    }
    /*
     * Executor used for evaluating the models in parallel. If we are not supposed to do that, it
     * is an executor that executes them sequentially.
     */
    final ExecutorService modelService;
    if (trainModelsInParallel) modelService = threadPool;
    else modelService = new FakeExecutor();

    final CountDownLatch latch; // used for stopping in both cases
    // First failure thrown by any evaluation job; re-thrown once every job has finished.
    final Throwable[] firstFailure = new Throwable[1];

    // if we are doing our CV splits ahead of time, get them done now
    final List<RegressionDataSet> preFolded;

    // Pre-combine our training combinations so that any caching can be re-used
    final List<RegressionDataSet> trainCombinations;

    if (reuseSameCVFolds) {
      preFolded = dataSet.cvSet(folds);
      trainCombinations = new ArrayList<RegressionDataSet>(preFolded.size());
      for (int i = 0; i < preFolded.size(); i++)
        trainCombinations.add(RegressionDataSet.comineAllBut(preFolded, i));
    } else {
      preFolded = null;
      trainCombinations = null;
    }

    boolean considerWarm = useWarmStarts && baseRegressor instanceof WarmRegressor;
    /*
     * Make sure we don't do a warm start if it is only supported when trained on the same data
     * but we aren't re-using the same CV splits. With a = warmFromSameDataOnly and
     * b = reuseSameCVFolds, warm starts are allowed when (a && b) || !a:
     *
     *   a | b | allowed
     *   T | T | T
     *   T | F | F
     *   F | T | T
     *   F | F | T
     *
     * which simplifies to !a || b.
     */
    if (considerWarm
        && (!((WarmRegressor) baseRegressor).warmFromSameDataOnly() || reuseSameCVFolds)) {
      /*
       * We want all values of the first parameter (the warm parameter, taken care of for us)
       * done in one group. We get this by splitting the list into adjacent sub lists whose size
       * is the number of values we wanted to try for that parameter. Within a group, each model
       * warm starts from the models kept by the previous one.
       */
      int stepSize = searchValues.get(0).size();
      int totalJobs = paramsToEval.size() / stepSize;
      latch = new CountDownLatch(totalJobs);
      for (int startPos = 0; startPos < paramsToEval.size(); startPos += stepSize) {
        final List<Regressor> subSet = paramsToEval.subList(startPos, startPos + stepSize);
        modelService.submit(
            guardedEvaluationJob(
                new Runnable() {
                  @Override
                  public void run() {
                    Regressor[] prevModels = null;
                    for (Regressor r : subSet) {
                      RegressionModelEvaluation rme =
                          evaluateSearchCandidate(
                              r, dataSet, threadPool, true, prevModels, preFolded,
                              trainCombinations);
                      prevModels = rme.getKeptModels();
                      synchronized (bestModels) {
                        bestModels.add(rme);
                      }
                    }
                  }
                },
                latch,
                firstFailure));
      }
    } else // regular CV, train a new model from scratch at every step
    {
      latch = new CountDownLatch(paramsToEval.size());

      for (final Regressor toTrain : paramsToEval) {
        modelService.submit(
            guardedEvaluationJob(
                new Runnable() {
                  @Override
                  public void run() {
                    RegressionModelEvaluation rme =
                        evaluateSearchCandidate(
                            toTrain, dataSet, threadPool, false, null, preFolded,
                            trainCombinations);
                    synchronized (bestModels) {
                      bestModels.add(rme);
                    }
                  }
                },
                latch,
                firstFailure));
      }
    }

    try {
      latch.await();

      Throwable failure;
      synchronized (firstFailure) {
        failure = firstFailure[0];
      }
      if (failure != null)
        throw new RuntimeException("Grid search failed while evaluating a candidate model", failure);

      // Now we know the best regressor; optionally re-train it on the whole data set.
      Regressor bestRegressor = bestModels.peek().getRegressor();
      if (trainFinalModel) {
        // Try to warm start the final model. The final model is trained on the whole data set
        // rather than on a CV split, so a warm start is only attempted when the regressor does
        // not require warm starts to come from the same data.
        if (useWarmStarts
            && bestRegressor instanceof WarmRegressor
            && !((WarmRegressor) bestRegressor).warmFromSameDataOnly()) {
          WarmRegressor wr = (WarmRegressor) bestRegressor;
          if (threadPool instanceof FakeExecutor) wr.train(dataSet, wr.clone());
          else wr.train(dataSet, wr.clone(), threadPool);
        } else {
          if (threadPool instanceof FakeExecutor) bestRegressor.train(dataSet);
          else bestRegressor.train(dataSet, threadPool);
        }
      }
      trainedRegressor = bestRegressor;

    } catch (InterruptedException ex) {
      // Restore the interrupt status so callers further up the stack can still observe it.
      Thread.currentThread().interrupt();
      Logger.getLogger(GridSearch.class.getName()).log(Level.SEVERE, null, ex);
    }
  }

  /**
   * Cross-validates one candidate regressor and scores it with a clone of
   * {@code regressionTargetScore}.
   *
   * @param candidate the configured regressor to evaluate
   * @param dataSet the full data set
   * @param threadPool handed to the evaluation only when models are not trained in parallel
   * @param warmStart if true, the evaluation keeps its models and warm starts from
   *     {@code warmModels}
   * @param warmModels models to warm start from (may be {@code null}); ignored unless
   *     {@code warmStart} is true
   * @param preFolded pre-computed CV folds, used when {@code reuseSameCVFolds} is set
   * @param trainCombinations training sets built alongside {@code preFolded}
   * @return the completed evaluation
   */
  private RegressionModelEvaluation evaluateSearchCandidate(
      Regressor candidate,
      RegressionDataSet dataSet,
      ExecutorService threadPool,
      boolean warmStart,
      Regressor[] warmModels,
      List<RegressionDataSet> preFolded,
      List<RegressionDataSet> trainCombinations) {
    RegressionModelEvaluation rme =
        trainModelsInParallel
            ? new RegressionModelEvaluation(candidate, dataSet)
            : new RegressionModelEvaluation(candidate, dataSet, threadPool);
    if (warmStart) {
      rme.setKeepModels(true); // we need these to do warm starts!
      rme.setWarmModels(warmModels);
    }
    rme.addScorer(regressionTargetScore.clone());
    if (reuseSameCVFolds) rme.evaluateCrossValidation(preFolded, trainCombinations);
    else rme.evaluateCrossValidation(folds);
    return rme;
  }

  /**
   * Wraps {@code job} so that {@code latch} is always counted down when the job ends. The first
   * {@link RuntimeException} or {@link Error} the job throws is recorded in {@code firstFailure}
   * and still re-thrown to the executor.
   */
  private static Runnable guardedEvaluationJob(
      final Runnable job, final CountDownLatch latch, final Throwable[] firstFailure) {
    return new Runnable() {
      @Override
      public void run() {
        try {
          job.run();
        } catch (RuntimeException e) {
          recordFirstFailure(firstFailure, e);
          throw e;
        } catch (Error e) {
          recordFirstFailure(firstFailure, e);
          throw e;
        } finally {
          latch.countDown();
        }
      }
    };
  }

  /** Stores {@code failure} in {@code slot[0]} unless an earlier failure is already recorded. */
  private static void recordFirstFailure(Throwable[] slot, Throwable failure) {
    synchronized (slot) {
      if (slot[0] == null) slot[0] = failure;
    }
  }