Example #1
0
  /**
   * Recovers the highest-scoring parse subtree rooted at {@code root} whose yield is the terminal
   * span {@code [spanStart, spanEnd]} (inclusive), following this chart's max-product backpointers.
   *
   * <p>A negative split backpointer marks a terminal rule, and the covered terminal span is decoded
   * from it. A non-negative split backpointer marks a binary rule. Its left child covers
   * {@code [spanStart, spanStart + split]} and its right child covers the remainder of the span.
   *
   * @return the best subtree, or {@code null} if the inside score times the outside score of
   *     {@code root} over this span is zero
   * @throws IllegalStateException if this is a sum-product chart, or if a child subtree cannot be
   *     recovered
   * @throws IllegalArgumentException if a binary backpointer's split would leave the right child
   *     with an empty span
   */
  public CfgParseTree getBestParseTreeWithSpan(Object root, int spanStart, int spanEnd) {
    // Backpointers only exist for max-product charts.
    Preconditions.checkState(!sumProduct);

    int rootIndex =
        parentVar.assignmentToIntArray(parentVar.outcomeArrayToAssignment(root))[0];
    double score =
        insideChart[spanStart][spanEnd][rootIndex] * outsideChart[spanStart][spanEnd][rootIndex];
    if (score == 0.0) {
      return null;
    }

    int split = splitBackpointers[spanStart][spanEnd][rootIndex];
    long ruleKey = backpointers[spanStart][spanEnd][rootIndex];

    if (split < 0) {
      // Terminal rule: split == -(firstTerminal * numTerminals + lastTerminal + 1).
      int encoded = (-1 * split) - 1;
      int firstTerminal = encoded / numTerminals;
      int lastTerminal = encoded % numTerminals;

      // Map the stored key back to (parent, ruleType) values via a zero table over those vars.
      VariableNumMap keyVars = parentVar.union(ruleTypeVar);
      int[] keyDims = TableFactor.zero(keyVars).getWeights().keyNumToDimKey(ruleKey);
      Assignment decoded = keyVars.intArrayToAssignment(keyDims);
      Object terminalRuleType = decoded.getValue(ruleTypeVar.getOnlyVariableNum());

      List<Object> words = Lists.newArrayList();
      words.addAll(terminals.subList(firstTerminal, lastTerminal + 1));
      return new CfgParseTree(root, terminalRuleType, words, score, spanStart, spanEnd);
    }

    // Binary rule: the key indexes the binary rule distribution's tensor.
    int[] ruleDims =
        binaryRuleDistribution.coerceToDiscrete().getWeights().keyNumToDimKey(ruleKey);
    Assignment bestRule = binaryRuleDistribution.getVars().intArrayToAssignment(ruleDims);
    Object leftChildRoot = bestRule.getValue(leftVar.getOnlyVariableNum());
    Object rightChildRoot = bestRule.getValue(rightVar.getOnlyVariableNum());
    Object binaryRuleType = bestRule.getValue(ruleTypeVar.getOnlyVariableNum());

    int leftEnd = spanStart + split;
    Preconditions.checkArgument(
        leftEnd != spanEnd, "CFG parse decoding error: %s %s %s", spanStart, spanEnd, split);

    CfgParseTree leftSubtree = getBestParseTreeWithSpan(leftChildRoot, spanStart, leftEnd);
    CfgParseTree rightSubtree = getBestParseTreeWithSpan(rightChildRoot, leftEnd + 1, spanEnd);
    Preconditions.checkState(leftSubtree != null);
    Preconditions.checkState(rightSubtree != null);

    return new CfgParseTree(root, binaryRuleType, leftSubtree, rightSubtree, score);
  }
Example #2
0
 /**
  * Folds a new production into the outside chart entry for the span {@code [spanStart, spanEnd]}.
  * Sum-product charts combine it via {@code updateEntrySumProduct}; max-product charts combine it
  * via {@code updateEntryMaxProduct}.
  *
  * @param values production values to combine into the entry
  * @param factor factor whose discrete weights accompany {@code values}
  * @param var single-variable map naming the variable the entry is indexed by
  */
 public void updateOutsideEntry(
     int spanStart, int spanEnd, double[] values, Factor factor, VariableNumMap var) {
   double[] entry = outsideChart[spanStart][spanEnd];
   Tensor factorWeights = factor.coerceToDiscrete().getWeights();
   int varNum = var.getOnlyVariableNum();

   if (!sumProduct) {
     updateEntryMaxProduct(entry, values, factorWeights, varNum);
     return;
   }
   updateEntrySumProduct(entry, values, factorWeights, varNum);
 }
Example #3
0
  /**
   * Folds a binary production into the inside chart entry for the span {@code [spanStart, spanEnd]}.
   * Sum-product charts combine it via {@code updateEntrySumProduct} over the parent variable.
   * Max-product charts delegate to {@code updateInsideEntryMaxProduct}, which also receives
   * {@code splitInd} for backpointer bookkeeping.
   *
   * @param binaryRuleProbabilities binary rule factor; must have exactly four variables
   * @throws IllegalArgumentException if {@code binaryRuleProbabilities} does not have four variables
   */
  public void updateInsideEntry(
      int spanStart, int spanEnd, int splitInd, double[] values, Factor binaryRuleProbabilities) {
    int ruleVarCount = binaryRuleProbabilities.getVars().size();
    Preconditions.checkArgument(ruleVarCount == 4);

    if (!sumProduct) {
      Tensor ruleWeights = binaryRuleProbabilities.coerceToDiscrete().getWeights();
      updateInsideEntryMaxProduct(spanStart, spanEnd, values, ruleWeights, splitInd);
      return;
    }

    double[] entry = insideChart[spanStart][spanEnd];
    Tensor ruleWeights = binaryRuleProbabilities.coerceToDiscrete().getWeights();
    updateEntrySumProduct(entry, values, ruleWeights, parentVar.getOnlyVariableNum());
  }
Example #4
0
  /**
   * Folds a terminal production into the inside chart entry for {@code [spanStart, spanEnd]}. The
   * production covers the terminals {@code [terminalSpanStart, terminalSpanEnd]}, and its values
   * are the factor's own weight values.
   *
   * <p>Sum-product charts combine the production via {@code updateEntrySumProduct} over the parent
   * variable. Max-product charts delegate to {@code updateInsideEntryMaxProduct} and record the
   * terminal span as a negative split index (see below).
   *
   * @param factor terminal rule factor; must have exactly two variables
   * @throws IllegalArgumentException if {@code factor} does not have exactly two variables
   */
  public void updateInsideEntryTerminal(
      int spanStart, int spanEnd, int terminalSpanStart, int terminalSpanEnd, Factor factor) {
    Preconditions.checkArgument(factor.getVars().size() == 2);
    // Convert once; the same tensor supplies both the production values and the weights.
    Tensor weights = factor.coerceToDiscrete().getWeights();

    if (sumProduct) {
      updateEntrySumProduct(
          insideChart[spanStart][spanEnd],
          weights.getValues(),
          weights,
          parentVar.getOnlyVariableNum());
    } else {
      // Terminal rules are stored as split index -(start * numTerminals + end + 1). The +1 keeps
      // the code strictly negative, so it cannot be confused with a binary split index (>= 0).
      // getBestParseTreeWithSpan inverts this, which requires terminalSpanEnd < numTerminals.
      int splitInd = -1 * (terminalSpanStart * numTerminals + terminalSpanEnd + 1);
      updateInsideEntryMaxProduct(spanStart, spanEnd, weights.getValues(), weights, splitInd);
    }
  }
  /**
   * Checks that {@code parametricFactor} applies assignment-based and marginal-based sufficient
   * statistic updates to the expected (alphabet, truth) log-weights.
   */
  private void runTestAllFeatures(DenseIndicatorLogLinearFactor parametricFactor) {
    SufficientStatistics stats = parametricFactor.getNewSufficientStatistics();

    // Fresh statistics give every assignment a log-weight of zero.
    Factor model = parametricFactor.getModelFromParameters(stats);
    assertEquals(0.0, model.getUnnormalizedLogProbability("A", "T"), TOLERANCE);
    assertEquals(0.0, model.getUnnormalizedLogProbability("B", "T"), TOLERANCE);
    assertEquals(0.0, model.getUnnormalizedLogProbability("C", "F"), TOLERANCE);

    // Assignment increments accumulate; (B, F) receives 2.0 + (-3.0) = -1.0.
    parametricFactor.incrementSufficientStatisticsFromAssignment(
        stats, stats, vars.outcomeArrayToAssignment("B", "F"), 2.0);
    parametricFactor.incrementSufficientStatisticsFromAssignment(
        stats, stats, vars.outcomeArrayToAssignment("B", "F"), -3.0);
    parametricFactor.incrementSufficientStatisticsFromAssignment(
        stats, stats, vars.outcomeArrayToAssignment("A", "T"), 1.0);
    parametricFactor.incrementSufficientStatisticsFromAssignment(
        stats, stats, vars.outcomeArrayToAssignment("C", "T"), 2.0);

    model = parametricFactor.getModelFromParameters(stats);
    assertEquals(-1.0, model.getUnnormalizedLogProbability("B", "F"), TOLERANCE);
    assertEquals(1.0, model.getUnnormalizedLogProbability("A", "T"), TOLERANCE);
    assertEquals(2.0, model.getUnnormalizedLogProbability("C", "T"), TOLERANCE);

    // An unconditioned marginal scales each entry by count / partitionFunction = 3.0 / 2.0,
    // so (A, F) gains 4.0 * 1.5 and (C, F) gains 6.0 * 1.5; earlier entries are untouched.
    TableFactorBuilder marginalBuilder =
        new TableFactorBuilder(vars, SparseTensorBuilder.getFactory());
    marginalBuilder.setWeight(4.0, "A", "F");
    marginalBuilder.setWeight(6.0, "C", "F");
    Factor marginal = marginalBuilder.build();
    parametricFactor.incrementSufficientStatisticsFromMarginal(
        stats, stats, marginal, Assignment.EMPTY, 3.0, 2.0);

    model = parametricFactor.getModelFromParameters(stats);
    assertEquals(-1.0, model.getUnnormalizedLogProbability("B", "F"), TOLERANCE);
    assertEquals(1.0, model.getUnnormalizedLogProbability("A", "T"), TOLERANCE);
    assertEquals(2.0, model.getUnnormalizedLogProbability("C", "T"), TOLERANCE);
    assertEquals(6.0, model.getUnnormalizedLogProbability("A", "F"), TOLERANCE);
    assertEquals(9.0, model.getUnnormalizedLogProbability("C", "F"), TOLERANCE);

    // A point mass on truth = T, conditioned on alphabet = B, moves (B, T) to 1.5.
    TableFactor truthIsT =
        TableFactor.logPointDistribution(truthVar, truthVar.outcomeArrayToAssignment("T"));
    parametricFactor.incrementSufficientStatisticsFromMarginal(
        stats, stats, truthIsT, alphabetVar.outcomeArrayToAssignment("B"), 3, 2.0);

    model = parametricFactor.getModelFromParameters(stats);
    assertEquals(1.5, model.getUnnormalizedLogProbability("B", "T"), TOLERANCE);
  }
Example #6
0
  /**
   * Create a parse chart with the specified number of terminal symbols.
   *
   * <p>If "sumProduct" is true, then updates for the same symbol add (i.e., the ParseChart computes
   * marginals). Otherwise, updates use the maximum probability, meaning ParseChart computes
   * max-marginals.
   */
  public CfgParseChart(
      List<?> terminals,
      VariableNumMap parent,
      VariableNumMap left,
      VariableNumMap right,
      VariableNumMap terminal,
      VariableNumMap ruleTypeVar,
      Factor binaryRuleDistribution,
      boolean sumProduct) {
    this.terminals = terminals;
    this.parentVar = parent;
    this.leftVar = left;
    this.rightVar = right;
    this.ruleTypeVar = ruleTypeVar;

    this.terminalVar = terminal;
    this.sumProduct = sumProduct;

    this.numTerminals = terminals.size();
    this.numNonterminals = parentVar.getDiscreteVariables().get(0).numValues();

    insideChart = new double[numTerminals][numTerminals][numNonterminals];
    outsideChart = new double[numTerminals][numTerminals][numNonterminals];
    this.binaryRuleDistribution = binaryRuleDistribution;
    binaryRuleExpectations =
        new double[binaryRuleDistribution.coerceToDiscrete().getWeights().getValues().length];
    terminalRuleExpectations =
        TableFactor.zero(VariableNumMap.unionAll(parentVar, terminalVar, ruleTypeVar));

    insideCalculated = false;
    outsideCalculated = false;
    partitionFunction = 0.0;

    if (!sumProduct) {
      backpointers = new long[numTerminals][numTerminals][numNonterminals];
      splitBackpointers = new int[numTerminals][numTerminals][numNonterminals];
    } else {
      backpointers = null;
      splitBackpointers = null;
    }
  }
  /**
   * Adds {@code count / partitionFunction} times the (conditioned) marginal to {@code gradient}.
   *
   * <p>If {@code conditionalAssignment} fixes every variable of this factor, the update reduces to
   * a single weighted assignment increment. Otherwise the marginal is multiplied by the initial
   * weights and an indicator on the conditioned variables. Each entry of that product is then added
   * to the gradient feature whose index is the entry's key number.
   */
  @Override
  public void incrementSufficientStatisticsFromMarginal(
      SufficientStatistics gradient,
      SufficientStatistics currentParameters,
      Factor marginal,
      Assignment conditionalAssignment,
      double count,
      double partitionFunction) {
    boolean fullyConditioned =
        conditionalAssignment.containsAll(getVars().getVariableNumsArray());
    if (fullyConditioned) {
      // Fast path: a single assignment carries the marginal's entire mass.
      double weight = marginal.getTotalUnnormalizedProbability() * count / partitionFunction;
      incrementSufficientStatisticsFromAssignment(
          gradient, currentParameters, conditionalAssignment, weight);
      return;
    }

    VariableNumMap observedVars =
        initialWeights.getVars().intersection(conditionalAssignment.getVariableNumsArray());
    TableFactor joint =
        (TableFactor)
            initialWeights
                .product(
                    TableFactor.pointDistribution(
                        observedVars, conditionalAssignment.intersection(observedVars)))
                .product(marginal);

    Tensor jointWeights = joint.getWeights();
    double[] jointValues = jointWeights.getValues();
    int entryCount = jointWeights.size();
    double scale = count / partitionFunction;
    TensorSufficientStatistics tensorGradient = (TensorSufficientStatistics) gradient;

    int entry = 0;
    while (entry < entryCount) {
      int featureIndex = (int) jointWeights.indexToKeyNum(entry);
      tensorGradient.incrementFeatureByIndex(jointValues[entry] * scale, featureIndex);
      entry++;
    }
  }
Example #8
0
 /**
  * Adds {@code terminalRuleMarginal} to the running total of expected terminal production rule
  * usage stored in this chart.
  */
 public void updateTerminalRuleExpectations(Factor terminalRuleMarginal) {
   this.terminalRuleExpectations = this.terminalRuleExpectations.add(terminalRuleMarginal);
 }
Example #9
0
  /**
   * Returns the accumulated *unnormalized* expectation of every binary rule as a factor over the
   * binary rule distribution's variables. The tensor is built by
   * {@code SparseTensor.copyRemovingZeros} from the rule weights' layout and
   * {@code binaryRuleExpectations}.
   */
  public Factor getBinaryRuleExpectations() {
    SparseTensor expectations =
        SparseTensor.copyRemovingZeros(
            binaryRuleDistribution.coerceToDiscrete().getWeights(), binaryRuleExpectations);
    return new TableFactor(binaryRuleDistribution.getVars(), expectations);
  }
Example #10
0
 /**
  * Gets the best parse tree spanning the entire sentence.
  *
  * <p>The root symbol is the most likely assignment of the marginal over the full span
  * {@code [0, chartSize() - 1]}. The tree is then recovered with {@code getBestParseTree(Object)}.
  *
  * @return the best parse tree, or {@code null} if the root marginal yields no most-likely
  *     assignment. NOTE(review): {@code getBestParseTree(Object)} may itself return {@code null}
  *     when no parse exists — confirm.
  */
 public CfgParseTree getBestParseTree() {
   Factor rootMarginal = getMarginalEntries(0, chartSize() - 1);
   List<Assignment> bestAssignments = rootMarginal.getMostLikelyAssignments(1);
   if (bestAssignments.isEmpty()) {
     // Nothing to decode; report "no parse" like getBestParseTreeWithSpan does, instead of
     // letting get(0) throw IndexOutOfBoundsException.
     return null;
   }
   return getBestParseTree(bestAssignments.get(0).getOnlyValue());
 }