Exemplo n.º 1
0
  /**
   * Trains {@code sequenceModel} on {@code trainingData} with the supplied oracle, then checks that
   * the resulting model predicts every training output exactly (zero training error).
   *
   * @param oracle gradient oracle that the chosen trainer optimizes
   * @param useLbfgs when {@code true} the parameters are fit with {@link RetryingLbfgs}; otherwise
   *     an L2-regularized {@link StochasticGradientTrainer} is used
   */
  private void testZeroTrainingError(
      GradientOracle<DynamicFactorGraph, Example<DynamicAssignment, DynamicAssignment>> oracle,
      boolean useLbfgs) {
    // Training below goes through the map-reduce layer; install a local executor for it.
    MapReduceConfiguration.setMapReduceExecutor(new LocalMapReduceExecutor(1, 1));

    final SufficientStatistics fittedParameters;
    if (!useLbfgs) {
      StochasticGradientTrainer sgd =
          StochasticGradientTrainer.createWithL2Regularization(
              100, 1, 1.0, true, false, 0.1, new DefaultLogFunction());
      fittedParameters =
          sgd.train(oracle, sequenceModel.getNewSufficientStatistics(), trainingData);
    } else {
      RetryingLbfgs lbfgs = new RetryingLbfgs(50, 10, 0.1, new DefaultLogFunction(1, false));
      fittedParameters =
          lbfgs.train(oracle, sequenceModel.getNewSufficientStatistics(), trainingData);
    }

    // Predict the template variable y using exact junction-tree inference on the fitted model.
    FactorGraphPredictor bestPredictor =
        new FactorGraphPredictor(
            sequenceModel.getModelFromParameters(fittedParameters),
            VariableNamePattern.fromTemplateVariables(y, VariableNumMap.EMPTY),
            new JunctionTree());

    // Every training example must be reproduced exactly by the trained model.
    for (Example<DynamicAssignment, DynamicAssignment> example : trainingData) {
      DynamicAssignment predicted =
          bestPredictor.getBestPrediction(example.getInput()).getBestPrediction();
      assertEquals(example.getOutput(), predicted);
    }
  }
Exemplo n.º 2
0
 /**
  * Builds an example whose input and output are plate assignments over the plate variable
  * named {@code "plateVar"}.
  *
  * @param xs assignments used to build the input plate assignment
  * @param ys assignments used to build the output plate assignment; must have the same size as
  *     {@code xs}
  * @return an example pairing the plate assignment built from {@code xs} (input) with the one
  *     built from {@code ys} (output)
  * @throws IllegalArgumentException if {@code xs} and {@code ys} have different sizes
  */
 private Example<DynamicAssignment, DynamicAssignment> getListVarAssignment(
     List<Assignment> xs, List<Assignment> ys) {
   // Input and output lists must line up element-for-element; report both sizes on mismatch.
   Preconditions.checkArgument(
       xs.size() == ys.size(),
       "xs and ys must have the same size, but got %s and %s",
       xs.size(),
       ys.size());
   // Single source for the plate name so input and output cannot drift apart.
   final String plateName = "plateVar";
   DynamicAssignment input = DynamicAssignment.createPlateAssignment(plateName, xs);
   DynamicAssignment output = DynamicAssignment.createPlateAssignment(plateName, ys);
   return Example.create(input, output);
 }