/**
   * Scores {@code assignment} against this factor's log weights.
   *
   * <p>With no conditional variables the score is the inner product of {@code logWeights} with
   * the outer product of a one-hot tensor for the output values and the input feature vector.
   * Otherwise the score is read out of the tensor produced by {@code getOutputLogProbTensor}
   * for the input feature vector.
   *
   * @param assignment must contain a value for every variable of this factor; the value bound
   *     to the input variable is cast to a {@code Tensor}
   * @return the score selected as described above
   * @throws IllegalArgumentException if {@code assignment} is missing any of this factor's
   *     variables
   */
  @Override
  public double getUnnormalizedLogProbability(Assignment assignment) {
    Preconditions.checkArgument(assignment.containsAll(getVars().getVariableNumsArray()));
    int inputVarNum = getInputVariable().getOnlyVariableNum();
    Tensor features = (Tensor) assignment.getValue(inputVarNum);

    if (conditionalVars.size() != 0) {
      // General case: look the answer up in the output log probabilities.
      int[] outputKey = getOutputVariables().assignmentToIntArray(assignment);
      Tensor outputLogProbs = getOutputLogProbTensor(features);
      return outputLogProbs.getByDimKey(outputKey);
    }

    // Nothing is conditioned on, so no normalization is required. This permits a cheaper
    // computation than the default in ClassifierFactor: build a one-hot tensor for the
    // output values and take a single inner product with the weights.
    VariableNumMap outputs = getOutputVariables();
    int[] outputDims = outputs.getVariableNumsArray();
    int[] outputSizes = outputs.getVariableSizes();
    int[] outputKey = outputs.assignmentToIntArray(assignment);
    Tensor oneHotOutput = SparseTensor.singleElement(outputDims, outputSizes, outputKey, 1.0);

    Tensor activeFeatures = oneHotOutput.outerProduct(features);
    return logWeights.innerProduct(activeFeatures).getByDimKey();
  }
  /**
   * Adds {@code count} to the single gradient entry selected by {@code assignment}.
   *
   * <p>The assignment is restricted to this factor's variables, converted to a key number of
   * {@code initialWeights}' weight tensor, and that key number is translated to an index in the
   * same tensor. The gradient entry at that index is incremented.
   *
   * @param gradient statistics to update; cast to {@code TensorSufficientStatistics}
   * @param currentParameters not read by this method
   * @param assignment must contain a value for every variable of this factor
   * @param count amount added to the selected entry
   * @throws IllegalArgumentException if {@code assignment} is missing any of this factor's
   *     variables
   */
  @Override
  public void incrementSufficientStatisticsFromAssignment(
      SufficientStatistics gradient,
      SufficientStatistics currentParameters,
      Assignment assignment,
      double count) {
    int[] factorVarNums = getVars().getVariableNumsArray();
    Preconditions.checkArgument(assignment.containsAll(factorVarNums));

    // Drop any variables in the assignment that do not belong to this factor.
    Assignment restricted = assignment.intersection(factorVarNums);
    int[] dimKey = initialWeights.getVars().assignmentToIntArray(restricted);

    // Key number -> position within the weight tensor's stored entries.
    long weightKeyNum = initialWeights.getWeights().dimKeyToKeyNum(dimKey);
    int featureIndex = initialWeights.getWeights().keyNumToIndex(weightKeyNum);

    TensorSufficientStatistics tensorGradient = (TensorSufficientStatistics) gradient;
    tensorGradient.incrementFeatureByIndex(count, featureIndex);
  }
  /**
   * Recovers the best parse subtree for the symbol {@code root} over the given span. Only
   * usable when this chart holds max-marginals (i.e. {@code sumProduct} is false).
   *
   * <p>The recursion splits a span into {@code [spanStart, spanStart + splitInd]} and
   * {@code [spanStart + splitInd + 1, spanEnd]}, so both span bounds are treated as inclusive.
   *
   * @param root the symbol at the root of the requested subtree
   * @param spanStart first position of the span
   * @param spanEnd last position of the span
   * @return the decoded subtree, or {@code null} when the product of the inside and outside
   *     chart entries for {@code root} on this span is exactly zero
   * @throws IllegalStateException if this chart was computed with sum-product, or if a child
   *     subtree cannot be decoded
   * @throws IllegalArgumentException if the stored split point would leave an empty right child
   */
  public CfgParseTree getBestParseTreeWithSpan(Object root, int spanStart, int spanEnd) {
    Preconditions.checkState(!sumProduct);

    int rootNum = parentVar.assignmentToIntArray(parentVar.outcomeArrayToAssignment(root))[0];
    double prob =
        insideChart[spanStart][spanEnd][rootNum] * outsideChart[spanStart][spanEnd][rootNum];
    if (prob == 0.0) {
      return null;
    }

    int split = splitBackpointers[spanStart][spanEnd][rootNum];
    // Both the binary and the terminal case read the same backpointer cell.
    long ruleKey = backpointers[spanStart][spanEnd][rootNum];

    if (split >= 0) {
      // Binary rule: decode the rule, then recurse into the two children.
      int[] ruleDimKey =
          binaryRuleDistribution.coerceToDiscrete().getWeights().keyNumToDimKey(ruleKey);
      Assignment bestRule = binaryRuleDistribution.getVars().intArrayToAssignment(ruleDimKey);
      Object leftRoot = bestRule.getValue(leftVar.getOnlyVariableNum());
      Object rightRoot = bestRule.getValue(rightVar.getOnlyVariableNum());
      Object binaryRuleType = bestRule.getValue(ruleTypeVar.getOnlyVariableNum());

      int leftEnd = spanStart + split;
      Preconditions.checkArgument(
          leftEnd != spanEnd, "CFG parse decoding error: %s %s %s", spanStart, spanEnd, split);
      CfgParseTree left = getBestParseTreeWithSpan(leftRoot, spanStart, leftEnd);
      CfgParseTree right = getBestParseTreeWithSpan(rightRoot, leftEnd + 1, spanEnd);

      Preconditions.checkState(left != null);
      Preconditions.checkState(right != null);

      return new CfgParseTree(root, binaryRuleType, left, right, prob);
    }

    // Terminal rule: a negative split encodes the terminal span as
    // -(start * numTerminals + end) - 1, so undo the negation and unpack both ends.
    int encodedSpan = -split - 1;
    int terminalStart = encodedSpan / numTerminals;
    int terminalEnd = encodedSpan % numTerminals;

    // Roundabout conversion of the key back to objects: an all-zero factor over the
    // (parent, rule type) variables is built only for its key-number arithmetic.
    VariableNumMap terminalRuleVars = parentVar.union(ruleTypeVar);
    int[] terminalDimKey =
        TableFactor.zero(terminalRuleVars).getWeights().keyNumToDimKey(ruleKey);
    Assignment terminalRule = terminalRuleVars.intArrayToAssignment(terminalDimKey);
    Object terminalRuleType = terminalRule.getValue(ruleTypeVar.getOnlyVariableNum());

    // Copy the covered terminals; terminalEnd is inclusive, subList's upper bound is not.
    List<Object> coveredTerminals = Lists.newArrayList();
    coveredTerminals.addAll(terminals.subList(terminalStart, terminalEnd + 1));
    return new CfgParseTree(root, terminalRuleType, coveredTerminals, prob, spanStart, spanEnd);
  }
  /**
   * Adds the statistics implied by {@code marginal} (scaled by
   * {@code count / partitionFunction}) to {@code gradient}.
   *
   * <p>If {@code conditionalAssignment} already fixes every variable of this factor, the update
   * is delegated to {@code incrementSufficientStatisticsFromAssignment} with the marginal's
   * total unnormalized probability folded into the count. Otherwise {@code initialWeights} is
   * multiplied by a point distribution on the conditioned variables and by the marginal, and
   * each stored entry of the product is added to the gradient.
   *
   * @param gradient statistics to update; cast to {@code TensorSufficientStatistics} on the
   *     slow path
   * @param currentParameters only forwarded to the assignment-based update
   * @param marginal marginal distribution to accumulate
   * @param conditionalAssignment values of the conditioned-on variables
   * @param count weight of this update
   * @param partitionFunction divisor applied to every increment
   */
  @Override
  public void incrementSufficientStatisticsFromMarginal(
      SufficientStatistics gradient,
      SufficientStatistics currentParameters,
      Factor marginal,
      Assignment conditionalAssignment,
      double count,
      double partitionFunction) {
    if (conditionalAssignment.containsAll(getVars().getVariableNumsArray())) {
      // Fast path: everything is observed, so a single entry receives the whole mass.
      double scaledCount = marginal.getTotalUnnormalizedProbability() * count / partitionFunction;
      incrementSufficientStatisticsFromAssignment(
          gradient, currentParameters, conditionalAssignment, scaledCount);
      return;
    }

    // Slow path: restrict the weights to the observed values, then weight by the marginal.
    VariableNumMap observedVars =
        initialWeights.getVars().intersection(conditionalAssignment.getVariableNumsArray());
    Assignment observedValues = conditionalAssignment.intersection(observedVars);
    TableFactor weighted =
        (TableFactor)
            initialWeights
                .product(TableFactor.pointDistribution(observedVars, observedValues))
                .product(marginal);

    Tensor weightedTensor = weighted.getWeights();
    double[] weightedValues = weightedTensor.getValues();
    int numEntries = weightedTensor.size();
    double scale = count / partitionFunction;
    TensorSufficientStatistics tensorGradient = (TensorSufficientStatistics) gradient;
    for (int entry = 0; entry < numEntries; entry++) {
      // NOTE(review): the key number itself (narrowed to int) is used as the gradient index
      // here, whereas incrementSufficientStatisticsFromAssignment maps the key number through
      // initialWeights.getWeights().keyNumToIndex(...). These agree only if key numbers and
      // indexes coincide for the weight tensor - confirm against the gradient's layout.
      int gradientIndex = (int) weightedTensor.indexToKeyNum(entry);
      tensorGradient.incrementFeatureByIndex(weightedValues[entry] * scale, gradientIndex);
    }
  }
 /**
  * Prints a training example when example logging is enabled and {@code iteration} is a
  * multiple of {@code logInterval}.
  *
  * <p>The printed line has the form {@code <iteration>.<exampleNum> <prob>: example: <example>},
  * where {@code <prob>} is the graph's unnormalized log probability of the example if the
  * example assigns every variable of the graph, and empty otherwise.
  *
  * @param iteration current iteration number
  * @param exampleNum number of the example being logged
  * @param example the example to print
  * @param graph factor graph used to score the example and to render it
  */
 @Override
 public void log(long iteration, int exampleNum, Assignment example, FactorGraph graph) {
   if (!showExamples || iteration % logInterval != 0) {
     return;
   }

   // Only a complete assignment can be scored; otherwise the score is left blank.
   boolean fullyAssigned = example.containsAll(graph.getVariables().getVariableNumsArray());
   String prob = fullyAssigned ? Double.toString(graph.getUnnormalizedLogProbability(example)) : "";

   String header = iteration + "." + exampleNum + " " + prob;
   print(header + ": example: " + graph.assignmentToObject(example));
 }
 /**
  * Gets the best parse tree spanning the entire sentence.
  *
  * <p>The root symbol is chosen as the single most likely assignment of the marginal over the
  * span from position 0 to {@code chartSize() - 1}; decoding is then delegated to
  * {@code getBestParseTree(Object)} for that symbol.
  *
  * @return whatever {@code getBestParseTree(Object)} returns for the selected root symbol
  */
 public CfgParseTree getBestParseTree() {
   int lastPosition = chartSize() - 1;
   Factor rootDistribution = getMarginalEntries(0, lastPosition);
   // Top-1 query: only the highest-scoring root symbol is needed.
   Assignment likeliestRoot = rootDistribution.getMostLikelyAssignments(1).get(0);
   return getBestParseTree(likeliestRoot.getOnlyValue());
 }