/** * If this tree contains max-marginals, recover the best parse subtree for a given symbol with the * specified span. */ public CfgParseTree getBestParseTreeWithSpan(Object root, int spanStart, int spanEnd) { Preconditions.checkState(!sumProduct); Assignment rootAssignment = parentVar.outcomeArrayToAssignment(root); int rootNonterminalNum = parentVar.assignmentToIntArray(rootAssignment)[0]; double prob = insideChart[spanStart][spanEnd][rootNonterminalNum] * outsideChart[spanStart][spanEnd][rootNonterminalNum]; if (prob == 0.0) { return null; } int splitInd = splitBackpointers[spanStart][spanEnd][rootNonterminalNum]; if (splitInd < 0) { long terminalKey = backpointers[spanStart][spanEnd][rootNonterminalNum]; int positiveSplitInd = (-1 * splitInd) - 1; int terminalSpanStart = positiveSplitInd / numTerminals; int terminalSpanEnd = positiveSplitInd % numTerminals; // This is a really sucky way to transform the keys back to objects. VariableNumMap vars = parentVar.union(ruleTypeVar); int[] dimKey = TableFactor.zero(vars).getWeights().keyNumToDimKey(terminalKey); Assignment a = vars.intArrayToAssignment(dimKey); Object ruleType = a.getValue(ruleTypeVar.getOnlyVariableNum()); List<Object> terminalList = Lists.newArrayList(); terminalList.addAll(terminals.subList(terminalSpanStart, terminalSpanEnd + 1)); return new CfgParseTree(root, ruleType, terminalList, prob, spanStart, spanEnd); } else { long binaryRuleKey = backpointers[spanStart][spanEnd][rootNonterminalNum]; int[] binaryRuleComponents = binaryRuleDistribution.coerceToDiscrete().getWeights().keyNumToDimKey(binaryRuleKey); Assignment best = binaryRuleDistribution.getVars().intArrayToAssignment(binaryRuleComponents); Object leftRoot = best.getValue(leftVar.getOnlyVariableNum()); Object rightRoot = best.getValue(rightVar.getOnlyVariableNum()); Object ruleType = best.getValue(ruleTypeVar.getOnlyVariableNum()); Preconditions.checkArgument( spanStart + splitInd != spanEnd, "CFG parse decoding error: %s %s %s", 
spanStart, spanEnd, splitInd); CfgParseTree leftTree = getBestParseTreeWithSpan(leftRoot, spanStart, spanStart + splitInd); CfgParseTree rightTree = getBestParseTreeWithSpan(rightRoot, spanStart + splitInd + 1, spanEnd); Preconditions.checkState(leftTree != null); Preconditions.checkState(rightTree != null); return new CfgParseTree(root, ruleType, leftTree, rightTree, prob); } }
/**
 * Folds a new production into the outside chart cell for the given span. The chart's mode decides
 * how it is combined with what is already there: sum-product charts go through
 * {@code updateEntrySumProduct}, all others through {@code updateEntryMaxProduct}.
 *
 * @param spanStart first index of the span whose cell is updated
 * @param spanEnd second index of the span whose cell is updated
 * @param values values handed to the combining routine alongside the factor's weights
 * @param factor factor whose discrete weights describe the production
 * @param var single-variable map; its only variable number is passed to the combining routine
 */
public void updateOutsideEntry(
    int spanStart, int spanEnd, double[] values, Factor factor, VariableNumMap var) {
  double[] chartCell = outsideChart[spanStart][spanEnd];
  if (!sumProduct) {
    updateEntryMaxProduct(
        chartCell, values, factor.coerceToDiscrete().getWeights(), var.getOnlyVariableNum());
    return;
  }
  updateEntrySumProduct(
      chartCell, values, factor.coerceToDiscrete().getWeights(), var.getOnlyVariableNum());
}
/**
 * Folds a binary production into the inside chart cell for the given span. Sum-product charts
 * accumulate via {@code updateEntrySumProduct}; otherwise the update is delegated to
 * {@code updateInsideEntryMaxProduct}, which also receives the split index.
 *
 * @param spanStart first index of the span whose cell is updated
 * @param spanEnd second index of the span whose cell is updated
 * @param splitInd split index of the production; only used in the max-product case
 * @param values values handed to the combining routine alongside the rule weights
 * @param binaryRuleProbabilities binary rule factor; must be defined over exactly four variables
 */
public void updateInsideEntry(
    int spanStart, int spanEnd, int splitInd, double[] values, Factor binaryRuleProbabilities) {
  Preconditions.checkArgument(binaryRuleProbabilities.getVars().size() == 4);

  if (!sumProduct) {
    Tensor ruleWeights = binaryRuleProbabilities.coerceToDiscrete().getWeights();
    updateInsideEntryMaxProduct(spanStart, spanEnd, values, ruleWeights, splitInd);
    return;
  }

  double[] chartCell = insideChart[spanStart][spanEnd];
  updateEntrySumProduct(
      chartCell,
      values,
      binaryRuleProbabilities.coerceToDiscrete().getWeights(),
      parentVar.getOnlyVariableNum());
}
/** * Update an entry of the inside chart with a new production. Depending on the type of the chart, * this performs either a sum or max over productions of the same type in the same entry. */ public void updateInsideEntryTerminal( int spanStart, int spanEnd, int terminalSpanStart, int terminalSpanEnd, Factor factor) { Preconditions.checkArgument(factor.getVars().size() == 2); // The first entry initializes the chart at this span. if (sumProduct) { updateEntrySumProduct( insideChart[spanStart][spanEnd], factor.coerceToDiscrete().getWeights().getValues(), factor.coerceToDiscrete().getWeights(), parentVar.getOnlyVariableNum()); } else { Tensor weights = factor.coerceToDiscrete().getWeights(); // Negative split indexes are used to represent terminal rules. int splitInd = -1 * (terminalSpanStart * numTerminals + terminalSpanEnd + 1); updateInsideEntryMaxProduct(spanStart, spanEnd, weights.getValues(), weights, splitInd); } }
/**
 * Shared test body: checks that a factor with one indicator feature per assignment starts with
 * all-zero log probabilities and accumulates sufficient statistics correctly from single
 * assignments, from an unconditioned marginal, and from a marginal combined with a conditional
 * assignment.
 */
private void runTestAllFeatures(DenseIndicatorLogLinearFactor parametricFactor) {
  SufficientStatistics parameters = parametricFactor.getNewSufficientStatistics();

  // Fresh parameters: every assignment is expected to have log probability 0.
  Factor untrained = parametricFactor.getModelFromParameters(parameters);
  assertEquals(0.0, untrained.getUnnormalizedLogProbability("A", "T"), TOLERANCE);
  assertEquals(0.0, untrained.getUnnormalizedLogProbability("B", "T"), TOLERANCE);
  assertEquals(0.0, untrained.getUnnormalizedLogProbability("C", "F"), TOLERANCE);

  // Increments from individual assignments; the two (B, F) updates net out to 2.0 - 3.0 = -1.0.
  parametricFactor.incrementSufficientStatisticsFromAssignment(
      parameters, parameters, vars.outcomeArrayToAssignment("B", "F"), 2.0);
  parametricFactor.incrementSufficientStatisticsFromAssignment(
      parameters, parameters, vars.outcomeArrayToAssignment("B", "F"), -3.0);
  parametricFactor.incrementSufficientStatisticsFromAssignment(
      parameters, parameters, vars.outcomeArrayToAssignment("A", "T"), 1.0);
  parametricFactor.incrementSufficientStatisticsFromAssignment(
      parameters, parameters, vars.outcomeArrayToAssignment("C", "T"), 2.0);

  Factor afterAssignments = parametricFactor.getModelFromParameters(parameters);
  assertEquals(-1.0, afterAssignments.getUnnormalizedLogProbability("B", "F"), TOLERANCE);
  assertEquals(1.0, afterAssignments.getUnnormalizedLogProbability("A", "T"), TOLERANCE);
  assertEquals(2.0, afterAssignments.getUnnormalizedLogProbability("C", "T"), TOLERANCE);

  // Increment from a full marginal with count 3.0 and partition function 2.0. The expected
  // values below equal each marginal weight scaled by 3.0 / 2.0.
  TableFactorBuilder marginalBuilder =
      new TableFactorBuilder(vars, SparseTensorBuilder.getFactory());
  marginalBuilder.setWeight(4.0, "A", "F");
  marginalBuilder.setWeight(6.0, "C", "F");
  Factor marginalIncrement = marginalBuilder.build();
  parametricFactor.incrementSufficientStatisticsFromMarginal(
      parameters, parameters, marginalIncrement, Assignment.EMPTY, 3.0, 2.0);

  Factor afterMarginal = parametricFactor.getModelFromParameters(parameters);
  // Entries untouched by the marginal keep their earlier values.
  assertEquals(-1.0, afterMarginal.getUnnormalizedLogProbability("B", "F"), TOLERANCE);
  assertEquals(1.0, afterMarginal.getUnnormalizedLogProbability("A", "T"), TOLERANCE);
  assertEquals(2.0, afterMarginal.getUnnormalizedLogProbability("C", "T"), TOLERANCE);
  assertEquals(6.0, afterMarginal.getUnnormalizedLogProbability("A", "F"), TOLERANCE);
  assertEquals(9.0, afterMarginal.getUnnormalizedLogProbability("C", "F"), TOLERANCE);

  // Marginal over the truth variable only, with the alphabet variable fixed to "B" through the
  // conditional assignment.
  TableFactor truthPoint =
      TableFactor.logPointDistribution(truthVar, truthVar.outcomeArrayToAssignment("T"));
  parametricFactor.incrementSufficientStatisticsFromMarginal(
      parameters, parameters, truthPoint, alphabetVar.outcomeArrayToAssignment("B"), 3, 2.0);

  Factor afterConditional = parametricFactor.getModelFromParameters(parameters);
  assertEquals(1.5, afterConditional.getUnnormalizedLogProbability("B", "T"), TOLERANCE);
}
/** * Create a parse chart with the specified number of terminal symbols. * * <p>If "sumProduct" is true, then updates for the same symbol add (i.e., the ParseChart computes * marginals). Otherwise, updates use the maximum probability, meaning ParseChart computes * max-marginals. */ public CfgParseChart( List<?> terminals, VariableNumMap parent, VariableNumMap left, VariableNumMap right, VariableNumMap terminal, VariableNumMap ruleTypeVar, Factor binaryRuleDistribution, boolean sumProduct) { this.terminals = terminals; this.parentVar = parent; this.leftVar = left; this.rightVar = right; this.ruleTypeVar = ruleTypeVar; this.terminalVar = terminal; this.sumProduct = sumProduct; this.numTerminals = terminals.size(); this.numNonterminals = parentVar.getDiscreteVariables().get(0).numValues(); insideChart = new double[numTerminals][numTerminals][numNonterminals]; outsideChart = new double[numTerminals][numTerminals][numNonterminals]; this.binaryRuleDistribution = binaryRuleDistribution; binaryRuleExpectations = new double[binaryRuleDistribution.coerceToDiscrete().getWeights().getValues().length]; terminalRuleExpectations = TableFactor.zero(VariableNumMap.unionAll(parentVar, terminalVar, ruleTypeVar)); insideCalculated = false; outsideCalculated = false; partitionFunction = 0.0; if (!sumProduct) { backpointers = new long[numTerminals][numTerminals][numNonterminals]; splitBackpointers = new int[numTerminals][numTerminals][numNonterminals]; } else { backpointers = null; splitBackpointers = null; } }
@Override public void incrementSufficientStatisticsFromMarginal( SufficientStatistics gradient, SufficientStatistics currentParameters, Factor marginal, Assignment conditionalAssignment, double count, double partitionFunction) { if (conditionalAssignment.containsAll(getVars().getVariableNumsArray())) { // Short-circuit the slow computation below if possible. double multiplier = marginal.getTotalUnnormalizedProbability() * count / partitionFunction; incrementSufficientStatisticsFromAssignment( gradient, currentParameters, conditionalAssignment, multiplier); } else { VariableNumMap conditionedVars = initialWeights.getVars().intersection(conditionalAssignment.getVariableNumsArray()); TableFactor productFactor = (TableFactor) initialWeights .product( TableFactor.pointDistribution( conditionedVars, conditionalAssignment.intersection(conditionedVars))) .product(marginal); Tensor productFactorWeights = productFactor.getWeights(); double[] productFactorValues = productFactorWeights.getValues(); int tensorSize = productFactorWeights.size(); double multiplier = count / partitionFunction; TensorSufficientStatistics tensorGradient = (TensorSufficientStatistics) gradient; for (int i = 0; i < tensorSize; i++) { int builderIndex = (int) productFactorWeights.indexToKeyNum(i); tensorGradient.incrementFeatureByIndex(productFactorValues[i] * multiplier, builderIndex); } } }
/**
 * Accumulates the expected usage counts of terminal production rules by adding the given
 * marginal to the running total held by this chart.
 *
 * @param terminalRuleMarginal marginal over terminal rules to add to the running expectations
 */
public void updateTerminalRuleExpectations(Factor terminalRuleMarginal) {
  this.terminalRuleExpectations = this.terminalRuleExpectations.add(terminalRuleMarginal);
}
/**
 * Returns the expected *unnormalized* probability of every binary rule as a factor over the
 * binary rule distribution's variables. The accumulated expectations are copied into a sparse
 * tensor shaped like the rule weights, with zero entries removed.
 *
 * @return a factor holding the accumulated binary rule expectations
 */
public Factor getBinaryRuleExpectations() {
  SparseTensor expectations =
      SparseTensor.copyRemovingZeros(
          binaryRuleDistribution.coerceToDiscrete().getWeights(), binaryRuleExpectations);
  return new TableFactor(binaryRuleDistribution.getVars(), expectations);
}
/**
 * Gets the best parse tree spanning the entire sentence. The root symbol is the single most
 * likely assignment of the marginal over the span from 0 to {@code chartSize() - 1}.
 *
 * @return the best parse tree rooted at that symbol, as produced by
 *     {@code getBestParseTree(Object)}
 */
public CfgParseTree getBestParseTree() {
  int lastTerminalIndex = chartSize() - 1;
  Assignment topRoot =
      getMarginalEntries(0, lastTerminalIndex).getMostLikelyAssignments(1).get(0);
  return getBestParseTree(topRoot.getOnlyValue());
}