/** * Product two factors, taking the multiplication at the intersections. * * @param other the other factor to be multiplied * @return a factor containing the union of both variable sets */ public TableFactor multiply(TableFactor other) { // Calculate the result domain List<Integer> domain = new ArrayList<>(); List<Integer> otherDomain = new ArrayList<>(); List<Integer> resultDomain = new ArrayList<>(); for (int n : neighborIndices) { domain.add(n); resultDomain.add(n); } for (int n : other.neighborIndices) { otherDomain.add(n); if (!resultDomain.contains(n)) resultDomain.add(n); } // Create result TableFactor int[] resultNeighborIndices = new int[resultDomain.size()]; int[] resultDimensions = new int[resultNeighborIndices.length]; for (int i = 0; i < resultDomain.size(); i++) { int var = resultDomain.get(i); resultNeighborIndices[i] = var; // assert consistency about variable size, we can't have the same variable with two different // sizes assert ((getVariableSize(var) == 0 && other.getVariableSize(var) > 0) || (getVariableSize(var) > 0 && other.getVariableSize(var) == 0) || (getVariableSize(var) == other.getVariableSize(var))); resultDimensions[i] = Math.max( getVariableSize(resultDomain.get(i)), other.getVariableSize(resultDomain.get(i))); } TableFactor result = new TableFactor(resultNeighborIndices, resultDimensions); // OPTIMIZATION: // If we're a factor of size 2 receiving a message of size 1, then we can optimize that pretty // heavily // We could just use the general algorithm at the end of this set of special cases, but this is // the fastest way if (otherDomain.size() == 1 && (resultDomain.size() == domain.size()) && domain.size() == 2) { int msgVar = otherDomain.get(0); int msgIndex = resultDomain.indexOf(msgVar); if (msgIndex == 0) { for (int i = 0; i < resultDimensions[0]; i++) { double d = other.values[i]; int k = i * resultDimensions[1]; for (int j = 0; j < resultDimensions[1]; j++) { int index = k + j; result.values[index] = values[index] + d; } } } else if (msgIndex == 1) { for (int i = 0; i < resultDimensions[0]; i++) { int k = i * resultDimensions[1]; for (int j = 0; j < resultDimensions[1]; j++) { int index = k + j; result.values[index] = values[index] + other.values[j]; } } } } // OPTIMIZATION: // The special case where we're a message of size 1, and the other factor is receiving the // message, and of size 2 else if (domain.size() == 1 && (resultDomain.size() == otherDomain.size()) && resultDomain.size() == 2) { return other.multiply(this); } // Otherwise we follow the big comprehensive, slow general purpose algorithm else { // Calculate back-pointers from the result domain indices to original indices int[] mapping = new int[result.neighborIndices.length]; int[] otherMapping = new int[result.neighborIndices.length]; for (int i = 0; i < result.neighborIndices.length; i++) { mapping[i] = domain.indexOf(result.neighborIndices[i]); otherMapping[i] = otherDomain.indexOf(result.neighborIndices[i]); } // Do the actual joining operation between the two tables, applying 'join' for each result // element. int[] assignment = new int[neighborIndices.length]; int[] otherAssignment = new int[other.neighborIndices.length]; // OPTIMIZATION: // Rather than use the standard iterator, which creates lots of int[] arrays on the heap, // which need to be GC'd, // we use the fast version that just mutates one array. Since this is read once for us here, // this is ideal. Iterator<int[]> fastPassByReferenceIterator = result.fastPassByReferenceIterator(); int[] resultAssignment = fastPassByReferenceIterator.next(); while (true) { // Set the assignment arrays correctly for (int i = 0; i < resultAssignment.length; i++) { if (mapping[i] != -1) assignment[mapping[i]] = resultAssignment[i]; if (otherMapping[i] != -1) otherAssignment[otherMapping[i]] = resultAssignment[i]; } result.setAssignmentLogValue( resultAssignment, getAssignmentLogValue(assignment) + other.getAssignmentLogValue(otherAssignment)); // This mutates the resultAssignment[] array, rather than creating a new one if (fastPassByReferenceIterator.hasNext()) fastPassByReferenceIterator.next(); else break; } } return result; }
/** * Marginalizes out a variable by applying an associative join operation for each possible * assignment to the marginalized variable. * * @param variable the variable (by 'name', not offset into neighborIndices) * @param startingValue associativeJoin is basically a foldr over a table, and this is the * initialization * @param curriedFoldr the associative function to use when applying the join operation, taking * first the assignment to the value being marginalized, and then a foldr operation * @return a new TableFactor that doesn't contain 'variable', where values were gotten through * associative marginalization. */ private TableFactor marginalize( int variable, double startingValue, BiFunction<Integer, int[], BiFunction<Double, Double, Double>> curriedFoldr) { // Can't marginalize the last variable assert (getDimensions().length > 1); // Calculate the result domain List<Integer> resultDomain = new ArrayList<>(); for (int n : neighborIndices) { if (n != variable) { resultDomain.add(n); } } // Create result TableFactor int[] resultNeighborIndices = new int[resultDomain.size()]; int[] resultDimensions = new int[resultNeighborIndices.length]; for (int i = 0; i < resultDomain.size(); i++) { int var = resultDomain.get(i); resultNeighborIndices[i] = var; resultDimensions[i] = getVariableSize(var); } TableFactor result = new TableFactor(resultNeighborIndices, resultDimensions); // Calculate forward-pointers from the old domain to new domain int[] mapping = new int[neighborIndices.length]; for (int i = 0; i < neighborIndices.length; i++) { mapping[i] = resultDomain.indexOf(neighborIndices[i]); } // Initialize for (int[] assignment : result) { result.setAssignmentLogValue(assignment, startingValue); } // Do the actual fold into the result int[] resultAssignment = new int[result.neighborIndices.length]; int marginalizedVariableValue = 0; // OPTIMIZATION: // Rather than use the standard iterator, which creates lots of int[] arrays on the heap, which // need to be GC'd, // we use the fast version that just mutates one array. Since this is read once for us here, // this is ideal. Iterator<int[]> fastPassByReferenceIterator = fastPassByReferenceIterator(); int[] assignment = fastPassByReferenceIterator.next(); while (true) { // Set the assignment arrays correctly for (int i = 0; i < assignment.length; i++) { if (mapping[i] != -1) resultAssignment[mapping[i]] = assignment[i]; else marginalizedVariableValue = assignment[i]; } result.setAssignmentLogValue( resultAssignment, curriedFoldr .apply(marginalizedVariableValue, resultAssignment) .apply( result.getAssignmentLogValue(resultAssignment), getAssignmentLogValue(assignment))); if (fastPassByReferenceIterator.hasNext()) fastPassByReferenceIterator.next(); else break; } return result; }
/** * Marginalize out a variable by taking a sum. * * @param variable the variable to be summed out * @return a factor with variable removed */ public TableFactor sumOut(int variable) { // OPTIMIZATION: This is by far the most common case, for linear chain inference, and is worth // making fast // We can use closed loops, and not bother with using the basic iterator to loop through // indices. // If this special case doesn't trip, we fall back to the standard (but slower) algorithm for // the general case if (getDimensions().length == 2) { if (neighborIndices[0] == variable) { TableFactor marginalized = new TableFactor(new int[] {neighborIndices[1]}, new int[] {getDimensions()[1]}); for (int i = 0; i < marginalized.values.length; i++) marginalized.values[i] = 0; // We use the stable log-sum-exp trick here, so first we calculate the max double[] max = new double[getDimensions()[1]]; for (int j = 0; j < getDimensions()[1]; j++) { max[j] = Double.NEGATIVE_INFINITY; } for (int i = 0; i < getDimensions()[0]; i++) { int k = i * getDimensions()[1]; for (int j = 0; j < getDimensions()[1]; j++) { int index = k + j; if (values[index] > max[j]) { max[j] = values[index]; } } } // Then we take the sum, minus the max for (int i = 0; i < getDimensions()[0]; i++) { int k = i * getDimensions()[1]; for (int j = 0; j < getDimensions()[1]; j++) { int index = k + j; if (Double.isFinite(max[j])) { marginalized.values[j] += Math.exp(values[index] - max[j]); } } } // And now we exponentiate, and add back in the values for (int j = 0; j < getDimensions()[1]; j++) { if (Double.isFinite(max[j])) { marginalized.values[j] = max[j] + Math.log(marginalized.values[j]); } else { marginalized.values[j] = max[j]; } } return marginalized; } else { assert (neighborIndices[1] == variable); TableFactor marginalized = new TableFactor(new int[] {neighborIndices[0]}, new int[] {getDimensions()[0]}); for (int i = 0; i < marginalized.values.length; i++) marginalized.values[i] = 0; // We use the stable log-sum-exp trick here, so first we calculate the max double[] max = new double[getDimensions()[0]]; for (int i = 0; i < getDimensions()[0]; i++) { max[i] = Double.NEGATIVE_INFINITY; } for (int i = 0; i < getDimensions()[0]; i++) { int k = i * getDimensions()[1]; for (int j = 0; j < getDimensions()[1]; j++) { int index = k + j; if (values[index] > max[i]) { max[i] = values[index]; } } } // Then we take the sum, minus the max for (int i = 0; i < getDimensions()[0]; i++) { int k = i * getDimensions()[1]; for (int j = 0; j < getDimensions()[1]; j++) { int index = k + j; if (Double.isFinite(max[i])) { marginalized.values[i] += Math.exp(values[index] - max[i]); } } } // And now we exponentiate, and add back in the values for (int i = 0; i < getDimensions()[0]; i++) { if (Double.isFinite(max[i])) { marginalized.values[i] = max[i] + Math.log(marginalized.values[i]); } else { marginalized.values[i] = max[i]; } } return marginalized; } } else { // This is a little tricky because we need to use the stable log-sum-exp trick on top of our // marginalize // dataflow operation. // First we calculate all the max values to use as pivots to prevent overflow TableFactor maxValues = maxOut(variable); // Then we do the sum against an offset from the pivots TableFactor marginalized = marginalize( variable, 0, (marginalizedVariableValue, assignment) -> (a, b) -> a + Math.exp(b - maxValues.getAssignmentLogValue(assignment))); // Then we factor the max values back in, and for (int[] assignment : marginalized) { marginalized.setAssignmentLogValue( assignment, maxValues.getAssignmentLogValue(assignment) + Math.log(marginalized.getAssignmentLogValue(assignment))); } return marginalized; } }