コード例 #1
0
  /**
   * Drives the sufficient-statistics update paths of {@code parametricFactor} against a single
   * parameter vector and checks the resulting model's unnormalized log probabilities after each
   * stage.
   *
   * <p>Stages, in order:
   * <ol>
   *   <li>A freshly created parameter vector is expected to give log probability 0.0 for the
   *       probed assignments.</li>
   *   <li>Per-assignment increments accumulate. ("B", "F") receives 2.0 and then -3.0, so it is
   *       expected to end at -1.0.</li>
   *   <li>A table marginal with weights 4.0 on ("A", "F") and 6.0 on ("C", "F") is passed with
   *       arguments 3.0 and 2.0. The expected results are 6.0 and 9.0, i.e. each weight scaled
   *       by 3.0 / 2.0. Entries set in stage 2 must be unchanged.</li>
   *   <li>A point distribution on {@code truthVar} = "T", conditioned on
   *       {@code alphabetVar} = "B" and passed with arguments 3 and 2.0, is expected to give
   *       ("B", "T") the value 1.5.</li>
   * </ol>
   *
   * @param parametricFactor the factor under test; presumably defined over {@code vars} with the
   *     alphabet variable first and the truth variable second (TODO confirm against callers)
   */
  private void runTestAllFeatures(DenseIndicatorLogLinearFactor parametricFactor) {
    SufficientStatistics stats = parametricFactor.getNewSufficientStatistics();

    // Stage 1: untouched parameters.
    Factor untrained = parametricFactor.getModelFromParameters(stats);
    assertEquals(0.0, untrained.getUnnormalizedLogProbability("A", "T"), TOLERANCE);
    assertEquals(0.0, untrained.getUnnormalizedLogProbability("B", "T"), TOLERANCE);
    assertEquals(0.0, untrained.getUnnormalizedLogProbability("C", "F"), TOLERANCE);

    // Stage 2: observed assignments, applied in order. The two ("B", "F") updates
    // deliberately net out to a negative weight.
    Assignment[] observed = {
        vars.outcomeArrayToAssignment("B", "F"),
        vars.outcomeArrayToAssignment("B", "F"),
        vars.outcomeArrayToAssignment("A", "T"),
        vars.outcomeArrayToAssignment("C", "T")};
    double[] observedCounts = {2.0, -3.0, 1.0, 2.0};
    for (int idx = 0; idx < observed.length; idx++) {
      parametricFactor.incrementSufficientStatisticsFromAssignment(
          stats, stats, observed[idx], observedCounts[idx]);
    }

    Factor afterAssignments = parametricFactor.getModelFromParameters(stats);
    assertEquals(-1.0, afterAssignments.getUnnormalizedLogProbability("B", "F"), TOLERANCE);
    assertEquals(1.0, afterAssignments.getUnnormalizedLogProbability("A", "T"), TOLERANCE);
    assertEquals(2.0, afterAssignments.getUnnormalizedLogProbability("C", "T"), TOLERANCE);

    // Stage 3: unconditional marginal increment. The expected new entries are the builder
    // weights scaled by 3.0 / 2.0; the stage-2 entries are re-checked to be unaffected.
    TableFactorBuilder marginalBuilder =
        new TableFactorBuilder(vars, SparseTensorBuilder.getFactory());
    marginalBuilder.setWeight(4.0, "A", "F");
    marginalBuilder.setWeight(6.0, "C", "F");
    Factor marginal = marginalBuilder.build();
    parametricFactor.incrementSufficientStatisticsFromMarginal(
        stats, stats, marginal, Assignment.EMPTY, 3.0, 2.0);

    Factor afterMarginal = parametricFactor.getModelFromParameters(stats);
    assertEquals(-1.0, afterMarginal.getUnnormalizedLogProbability("B", "F"), TOLERANCE);
    assertEquals(1.0, afterMarginal.getUnnormalizedLogProbability("A", "T"), TOLERANCE);
    assertEquals(2.0, afterMarginal.getUnnormalizedLogProbability("C", "T"), TOLERANCE);
    assertEquals(6.0, afterMarginal.getUnnormalizedLogProbability("A", "F"), TOLERANCE);
    assertEquals(9.0, afterMarginal.getUnnormalizedLogProbability("C", "F"), TOLERANCE);

    // Stage 4: marginal over truthVar only, with alphabetVar fixed to "B" via the
    // conditional assignment argument.
    TableFactor truthIsT =
        TableFactor.logPointDistribution(truthVar, truthVar.outcomeArrayToAssignment("T"));
    Assignment alphabetIsB = alphabetVar.outcomeArrayToAssignment("B");
    parametricFactor.incrementSufficientStatisticsFromMarginal(
        stats, stats, truthIsT, alphabetIsB, 3, 2.0);

    Factor afterPointMarginal = parametricFactor.getModelFromParameters(stats);
    assertEquals(1.5, afterPointMarginal.getUnnormalizedLogProbability("B", "T"), TOLERANCE);
  }