Ejemplo n.º 1
0
  /**
   * Trains {@code sequenceModel} on {@code trainingData} with the given oracle and checks that the
   * trained model reproduces every training output exactly (zero training error).
   *
   * @param oracle gradient oracle used by the chosen optimizer
   * @param useLbfgs {@code true} to optimize with {@link RetryingLbfgs}; {@code false} to use an
   *     L2-regularized {@link StochasticGradientTrainer}
   */
  private void testZeroTrainingError(
      GradientOracle<DynamicFactorGraph, Example<DynamicAssignment, DynamicAssignment>> oracle,
      boolean useLbfgs) {
    // Install a local map-reduce executor (constructed with arguments 1, 1) before training.
    MapReduceConfiguration.setMapReduceExecutor(new LocalMapReduceExecutor(1, 1));

    final SufficientStatistics trainedParameters;
    if (!useLbfgs) {
      DefaultLogFunction sgdLog = new DefaultLogFunction();
      StochasticGradientTrainer sgd =
          StochasticGradientTrainer.createWithL2Regularization(
              100, 1, 1.0, true, false, 0.1, sgdLog);
      trainedParameters =
          sgd.train(oracle, sequenceModel.getNewSufficientStatistics(), trainingData);
    } else {
      RetryingLbfgs lbfgs =
          new RetryingLbfgs(50, 10, 0.1, new DefaultLogFunction(1, false));
      trainedParameters =
          lbfgs.train(oracle, sequenceModel.getNewSufficientStatistics(), trainingData);
    }
    DynamicFactorGraph fittedGraph = sequenceModel.getModelFromParameters(trainedParameters);

    // With enough capacity the model should fit the training set perfectly.
    VariableNamePattern outputPattern =
        VariableNamePattern.fromTemplateVariables(y, VariableNumMap.EMPTY);
    FactorGraphPredictor bestAssignmentPredictor =
        new FactorGraphPredictor(fittedGraph, outputPattern, new JunctionTree());
    for (Example<DynamicAssignment, DynamicAssignment> example : trainingData) {
      DynamicAssignment predicted =
          bestAssignmentPredictor.getBestPrediction(example.getInput()).getBestPrediction();
      assertEquals(example.getOutput(), predicted);
    }
  }
  /**
   * Trains a {@link Predictor} on {@code trainingData}.
   *
   * <p>Builds a parametric factor graph from the data and converts every example into
   * {@link DynamicAssignment} form. It then optimizes the graph's parameters with
   * {@code trainer}, starting from the graph's fresh sufficient statistics perturbed by
   * {@code parameterPerturbation}. Finally it wraps a junction-tree {@link FactorGraphPredictor}
   * over the trained graph so that callers can predict directly on {@code I} / {@code O} values.
   *
   * @param trainingData labeled examples to train on
   * @return a predictor mapping inputs of type {@code I} to outputs of type {@code O}
   */
  @Override
  public Predictor<I, O> train(Iterable<Example<I, O>> trainingData) {
    ParametricFactorGraph model = constructGraphicalModel(trainingData);

    Converter<I, DynamicAssignment> inputConverter = getInputConverter(model);
    Converter<O, DynamicAssignment> outputConverter = getOutputConverter(model);
    Converter<Example<I, O>, Example<DynamicAssignment, DynamicAssignment>> exampleConverter =
        Example.converter(inputConverter, outputConverter);
    // Iterables.transform is a lazy view; copying it into a list performs the conversion once
    // instead of on every pass the trainer might make over the data.
    List<Example<DynamicAssignment, DynamicAssignment>> trainingDataAssignments =
        Lists.newArrayList(Iterables.transform(trainingData, exampleConverter));

    // Perturb the starting point in place.
    // NOTE(review): presumably a small random perturbation to break symmetry between
    // parameters -- confirm against SufficientStatistics.perturb.
    SufficientStatistics initialParameters = model.getNewSufficientStatistics();
    initialParameters.perturb(parameterPerturbation);
    SufficientStatistics finalParameters =
        trainer.train(model, initialParameters, trainingDataAssignments);

    Predictor<DynamicAssignment, DynamicAssignment> assignmentPredictor =
        new FactorGraphPredictor(
            model.getModelFromParameters(finalParameters),
            getOutputVariables(model),
            new JunctionTree());

    // Adapt the assignment-level predictor back to the caller's I/O types.
    return ForwardingPredictor.create(assignmentPredictor, inputConverter, outputConverter);
  }