/** * If this tree contains max-marginals, recover the best parse subtree for a given symbol with the * specified span. */ public CfgParseTree getBestParseTreeWithSpan(Object root, int spanStart, int spanEnd) { Preconditions.checkState(!sumProduct); Assignment rootAssignment = parentVar.outcomeArrayToAssignment(root); int rootNonterminalNum = parentVar.assignmentToIntArray(rootAssignment)[0]; double prob = insideChart[spanStart][spanEnd][rootNonterminalNum] * outsideChart[spanStart][spanEnd][rootNonterminalNum]; if (prob == 0.0) { return null; } int splitInd = splitBackpointers[spanStart][spanEnd][rootNonterminalNum]; if (splitInd < 0) { long terminalKey = backpointers[spanStart][spanEnd][rootNonterminalNum]; int positiveSplitInd = (-1 * splitInd) - 1; int terminalSpanStart = positiveSplitInd / numTerminals; int terminalSpanEnd = positiveSplitInd % numTerminals; // This is a really sucky way to transform the keys back to objects. VariableNumMap vars = parentVar.union(ruleTypeVar); int[] dimKey = TableFactor.zero(vars).getWeights().keyNumToDimKey(terminalKey); Assignment a = vars.intArrayToAssignment(dimKey); Object ruleType = a.getValue(ruleTypeVar.getOnlyVariableNum()); List<Object> terminalList = Lists.newArrayList(); terminalList.addAll(terminals.subList(terminalSpanStart, terminalSpanEnd + 1)); return new CfgParseTree(root, ruleType, terminalList, prob, spanStart, spanEnd); } else { long binaryRuleKey = backpointers[spanStart][spanEnd][rootNonterminalNum]; int[] binaryRuleComponents = binaryRuleDistribution.coerceToDiscrete().getWeights().keyNumToDimKey(binaryRuleKey); Assignment best = binaryRuleDistribution.getVars().intArrayToAssignment(binaryRuleComponents); Object leftRoot = best.getValue(leftVar.getOnlyVariableNum()); Object rightRoot = best.getValue(rightVar.getOnlyVariableNum()); Object ruleType = best.getValue(ruleTypeVar.getOnlyVariableNum()); Preconditions.checkArgument( spanStart + splitInd != spanEnd, "CFG parse decoding error: %s %s %s", spanStart, spanEnd, splitInd); CfgParseTree leftTree = getBestParseTreeWithSpan(leftRoot, spanStart, spanStart + splitInd); CfgParseTree rightTree = getBestParseTreeWithSpan(rightRoot, spanStart + splitInd + 1, spanEnd); Preconditions.checkState(leftTree != null); Preconditions.checkState(rightTree != null); return new CfgParseTree(root, ruleType, leftTree, rightTree, prob); } }
@Override public double getUnnormalizedLogProbability(Assignment assignment) { Preconditions.checkArgument(assignment.containsAll(getVars().getVariableNumsArray())); Tensor inputFeatureVector = (Tensor) assignment.getValue(getInputVariable().getOnlyVariableNum()); if (conditionalVars.size() == 0) { // No normalization for any conditioned-on variables. This case // allows a more efficient implementation than the default // in ClassifierFactor. VariableNumMap outputVars = getOutputVariables(); Tensor outputTensor = SparseTensor.singleElement( outputVars.getVariableNumsArray(), outputVars.getVariableSizes(), outputVars.assignmentToIntArray(assignment), 1.0); Tensor featureIndicator = outputTensor.outerProduct(inputFeatureVector); return logWeights.innerProduct(featureIndicator).getByDimKey(); } else { // Default to looking up the answer in the output log probabilities int[] outputIndexes = getOutputVariables().assignmentToIntArray(assignment); Tensor logProbs = getOutputLogProbTensor(inputFeatureVector); return logProbs.getByDimKey(outputIndexes); } }