private Tensor getFeatureWeights(SufficientStatistics parameters) { TensorSufficientStatistics featureParameters = (TensorSufficientStatistics) parameters; // Check that the parameters are a vector of the appropriate size. Preconditions.checkArgument(featureParameters.get().getDimensionNumbers().length == 1); Preconditions.checkArgument( featureParameters.get().getDimensionSizes()[0] == initialWeights.getWeights().getValues().length); return featureParameters.get(); }
@Override public void incrementSufficientStatisticsFromMarginal( SufficientStatistics gradient, SufficientStatistics currentParameters, Factor marginal, Assignment conditionalAssignment, double count, double partitionFunction) { if (conditionalAssignment.containsAll(getVars().getVariableNumsArray())) { // Short-circuit the slow computation below if possible. double multiplier = marginal.getTotalUnnormalizedProbability() * count / partitionFunction; incrementSufficientStatisticsFromAssignment( gradient, currentParameters, conditionalAssignment, multiplier); } else { VariableNumMap conditionedVars = initialWeights.getVars().intersection(conditionalAssignment.getVariableNumsArray()); TableFactor productFactor = (TableFactor) initialWeights .product( TableFactor.pointDistribution( conditionedVars, conditionalAssignment.intersection(conditionedVars))) .product(marginal); Tensor productFactorWeights = productFactor.getWeights(); double[] productFactorValues = productFactorWeights.getValues(); int tensorSize = productFactorWeights.size(); double multiplier = count / partitionFunction; TensorSufficientStatistics tensorGradient = (TensorSufficientStatistics) gradient; for (int i = 0; i < tensorSize; i++) { int builderIndex = (int) productFactorWeights.indexToKeyNum(i); tensorGradient.incrementFeatureByIndex(productFactorValues[i] * multiplier, builderIndex); } } }
/**
 * Adds {@code count} to the gradient feature selected by {@code assignment}'s values for this
 * factor's variables.
 *
 * <p>The feature position is found in two steps. The assignment restricted to this factor's
 * variables is first converted to a key number of {@code initialWeights}' weight tensor. That
 * key number is then converted to a position in the tensor's value array.
 *
 * @param gradient accumulator to increment; must be a {@code TensorSufficientStatistics}
 * @param currentParameters unused by this implementation
 * @param assignment must assign a value to every variable of this factor
 * @param count amount added to the selected feature
 * @throws IllegalArgumentException if {@code assignment} does not cover this factor's variables
 */
@Override
public void incrementSufficientStatisticsFromAssignment(
    SufficientStatistics gradient, SufficientStatistics currentParameters,
    Assignment assignment, double count) {
  Preconditions.checkArgument(assignment.containsAll(getVars().getVariableNumsArray()));

  Assignment factorAssignment = assignment.intersection(getVars().getVariableNumsArray());
  Tensor featureTensor = initialWeights.getWeights();
  int[] dimKey = initialWeights.getVars().assignmentToIntArray(factorAssignment);
  long featureKey = featureTensor.dimKeyToKeyNum(dimKey);
  int featureIndex = featureTensor.keyNumToIndex(featureKey);

  TensorSufficientStatistics tensorGradient = (TensorSufficientStatistics) gradient;
  tensorGradient.incrementFeatureByIndex(count, featureIndex);
}