Beispiel #1
0
  /**
   * Multiplies this factor by another factor, combining values wherever the two variable sets
   * intersect.
   *
   * <p>Entries are combined by adding the stored values of the two tables (the per-assignment
   * accessors are named {@code getAssignmentLogValue}/{@code setAssignmentLogValue}, so the
   * product is taken by summing those stored values).
   *
   * <p>The result's variables are normally this factor's neighbors, in order, followed by any
   * neighbors of {@code other} that this factor does not have. When this factor has exactly one
   * variable and {@code other} has exactly two variables including it, the work is delegated to
   * {@code other.multiply(this)}, so the result then follows {@code other}'s variable order.
   *
   * @param other the other factor to be multiplied
   * @return a factor containing the union of both variable sets
   */
  public TableFactor multiply(TableFactor other) {

    // Calculate the result domain: our variables first, then any new ones from the other factor

    List<Integer> domain = new ArrayList<>();
    List<Integer> otherDomain = new ArrayList<>();
    List<Integer> resultDomain = new ArrayList<>();

    for (int n : neighborIndices) {
      domain.add(n);
      resultDomain.add(n);
    }
    for (int n : other.neighborIndices) {
      otherDomain.add(n);
      if (!resultDomain.contains(n)) {
        resultDomain.add(n);
      }
    }

    // OPTIMIZATION:
    // The special case where we're a message of size 1, and the other factor is receiving the
    // message, and of size 2. We hand the work to the other factor's fast path. This is decided
    // before the result table is allocated, so nothing is built here only to be thrown away.
    if (domain.size() == 1
        && resultDomain.size() == otherDomain.size()
        && resultDomain.size() == 2) {
      return other.multiply(this);
    }

    // Create result TableFactor

    int[] resultNeighborIndices = new int[resultDomain.size()];
    int[] resultDimensions = new int[resultNeighborIndices.length];
    for (int i = 0; i < resultNeighborIndices.length; i++) {
      int variable = resultDomain.get(i);
      int size = getVariableSize(variable);
      int otherSize = other.getVariableSize(variable);
      // assert consistency about variable size, we can't have the same variable with two different
      // sizes
      assert ((size == 0 && otherSize > 0)
          || (size > 0 && otherSize == 0)
          || (size == otherSize));
      resultNeighborIndices[i] = variable;
      resultDimensions[i] = Math.max(size, otherSize);
    }
    TableFactor result = new TableFactor(resultNeighborIndices, resultDimensions);

    // OPTIMIZATION:
    // If we're a factor of size 2 receiving a message of size 1, then we can optimize that pretty
    // heavily. We could just use the general algorithm below, but this is the fastest way.
    if (otherDomain.size() == 1 && resultDomain.size() == domain.size() && domain.size() == 2) {
      int msgIndex = resultDomain.indexOf(otherDomain.get(0));
      int rows = resultDimensions[0];
      int cols = resultDimensions[1];

      if (msgIndex == 0) {
        // The message is over our first variable: one message value per row
        for (int i = 0; i < rows; i++) {
          double d = other.values[i];
          int rowStart = i * cols;
          for (int j = 0; j < cols; j++) {
            int index = rowStart + j;
            result.values[index] = values[index] + d;
          }
        }
      } else if (msgIndex == 1) {
        // The message is over our second variable: one message value per column
        for (int i = 0; i < rows; i++) {
          int rowStart = i * cols;
          for (int j = 0; j < cols; j++) {
            int index = rowStart + j;
            result.values[index] = values[index] + other.values[j];
          }
        }
      }
      return result;
    }

    // Otherwise we follow the big comprehensive, slow general purpose algorithm

    // Calculate back-pointers from the result domain indices to original indices
    // (-1 marks a result variable that the corresponding input factor does not have)

    int[] mapping = new int[result.neighborIndices.length];
    int[] otherMapping = new int[result.neighborIndices.length];
    for (int i = 0; i < result.neighborIndices.length; i++) {
      mapping[i] = domain.indexOf(result.neighborIndices[i]);
      otherMapping[i] = otherDomain.indexOf(result.neighborIndices[i]);
    }

    // Do the actual joining operation between the two tables, applying 'join' for each result
    // element.

    int[] assignment = new int[neighborIndices.length];
    int[] otherAssignment = new int[other.neighborIndices.length];

    // OPTIMIZATION:
    // Rather than use the standard iterator, which creates lots of int[] arrays on the heap,
    // which need to be GC'd, we use the fast version that just mutates one array. Since this is
    // read once for us here, this is ideal.
    Iterator<int[]> fastPassByReferenceIterator = result.fastPassByReferenceIterator();
    int[] resultAssignment = fastPassByReferenceIterator.next();
    while (true) {
      // Project the result assignment onto each input factor's own variable order
      for (int i = 0; i < resultAssignment.length; i++) {
        if (mapping[i] != -1) {
          assignment[mapping[i]] = resultAssignment[i];
        }
        if (otherMapping[i] != -1) {
          otherAssignment[otherMapping[i]] = resultAssignment[i];
        }
      }
      result.setAssignmentLogValue(
          resultAssignment,
          getAssignmentLogValue(assignment) + other.getAssignmentLogValue(otherAssignment));
      // This mutates the resultAssignment[] array, rather than creating a new one
      if (fastPassByReferenceIterator.hasNext()) {
        fastPassByReferenceIterator.next();
      } else {
        break;
      }
    }

    return result;
  }
Beispiel #2
0
  /**
   * Removes one variable from this factor by folding, for every assignment to the remaining
   * variables, all table entries that differ only in the removed variable.
   *
   * @param variable the variable to remove (by 'name', not offset into neighborIndices)
   * @param startingValue the value every cell of the result holds before any entry is folded in
   * @param curriedFoldr called once per entry of this table with the removed variable's value and
   *     the assignment to the remaining variables; the function it returns is then applied to
   *     (current result cell, this table's entry) to produce the new result cell
   * @return a new TableFactor that doesn't contain 'variable', whose cells are the folded values
   */
  private TableFactor marginalize(
      int variable,
      double startingValue,
      BiFunction<Integer, int[], BiFunction<Double, Double, Double>> curriedFoldr) {
    // A factor has to keep at least one variable
    assert (getDimensions().length > 1);

    // Every neighbor except the one being removed survives, in the original order
    List<Integer> keptVariables = new ArrayList<>();
    for (int neighbor : neighborIndices) {
      if (neighbor == variable) {
        continue;
      }
      keptVariables.add(neighbor);
    }

    // Build the (still empty) result table over the surviving variables
    int keptCount = keptVariables.size();
    int[] keptNeighborIndices = new int[keptCount];
    int[] keptDimensions = new int[keptCount];
    for (int slot = 0; slot < keptCount; slot++) {
      int kept = keptVariables.get(slot);
      keptNeighborIndices[slot] = kept;
      keptDimensions[slot] = getVariableSize(kept);
    }
    TableFactor result = new TableFactor(keptNeighborIndices, keptDimensions);

    // Forward pointers: position in this factor -> position in the result (-1 for the removed one)
    int[] forward = new int[neighborIndices.length];
    for (int pos = 0; pos < forward.length; pos++) {
      forward[pos] = keptVariables.indexOf(neighborIndices[pos]);
    }

    // Seed every result cell with the fold's initial value
    for (int[] cell : result) {
      result.setAssignmentLogValue(cell, startingValue);
    }

    // Fold every entry of this table into the matching result cell

    int[] projected = new int[result.neighborIndices.length];
    int removedValue = 0;

    // OPTIMIZATION:
    // The pass-by-reference iterator reuses a single int[] instead of allocating one per step;
    // each assignment is only read once here, so that is safe.
    Iterator<int[]> cursor = fastPassByReferenceIterator();
    int[] source = cursor.next();
    boolean more = true;
    while (more) {
      // Split the current assignment into the kept part and the removed variable's value
      for (int pos = 0; pos < source.length; pos++) {
        int target = forward[pos];
        if (target == -1) {
          removedValue = source[pos];
        } else {
          projected[target] = source[pos];
        }
      }

      BiFunction<Double, Double, Double> fold = curriedFoldr.apply(removedValue, projected);
      double accumulated = result.getAssignmentLogValue(projected);
      double incoming = getAssignmentLogValue(source);
      result.setAssignmentLogValue(projected, fold.apply(accumulated, incoming));

      // Advancing mutates 'source' in place rather than handing back a new array
      more = cursor.hasNext();
      if (more) {
        cursor.next();
      }
    }

    return result;
  }
Beispiel #3
0
  /**
   * Marginalize out a variable by taking a sum.
   *
   * <p>The sum is computed with the log-sum-exp trick: for each assignment to the remaining
   * variables, the largest stored value is used as a pivot, the exponentials of (value - pivot) are
   * summed, and the pivot is added back to the log of that sum. When the pivot is not finite (for
   * example every entry being summed is negative infinity) the pivot itself is stored as the
   * result, in both the two-variable fast path and the general path.
   *
   * @param variable the variable to be summed out
   * @return a factor with variable removed
   */
  public TableFactor sumOut(int variable) {

    // OPTIMIZATION: This is by far the most common case, for linear chain inference, and is worth
    // making fast.
    // We can use closed loops, and not bother with using the basic iterator to loop through
    // indices.
    // If this special case doesn't trip, we fall back to the standard (but slower) algorithm for
    // the general case

    int[] dimensions = getDimensions();
    if (dimensions.length == 2) {
      int rows = dimensions[0];
      int cols = dimensions[1];

      if (neighborIndices[0] == variable) {
        // Summing out the first variable: one output cell per value of the second variable
        TableFactor marginalized =
            new TableFactor(new int[] {neighborIndices[1]}, new int[] {cols});

        for (int i = 0; i < marginalized.values.length; i++) {
          marginalized.values[i] = 0;
        }

        // We use the stable log-sum-exp trick here, so first we calculate the max

        double[] max = new double[cols];
        for (int j = 0; j < cols; j++) {
          max[j] = Double.NEGATIVE_INFINITY;
        }

        for (int i = 0; i < rows; i++) {
          int rowStart = i * cols;
          for (int j = 0; j < cols; j++) {
            int index = rowStart + j;
            if (values[index] > max[j]) {
              max[j] = values[index];
            }
          }
        }

        // Then we take the sum, minus the max

        for (int i = 0; i < rows; i++) {
          int rowStart = i * cols;
          for (int j = 0; j < cols; j++) {
            int index = rowStart + j;
            if (Double.isFinite(max[j])) {
              marginalized.values[j] += Math.exp(values[index] - max[j]);
            }
          }
        }

        // And now we take the log, and add back in the max values

        for (int j = 0; j < cols; j++) {
          if (Double.isFinite(max[j])) {
            marginalized.values[j] = max[j] + Math.log(marginalized.values[j]);
          } else {
            marginalized.values[j] = max[j];
          }
        }

        return marginalized;
      } else {
        // Summing out the second variable: one output cell per value of the first variable
        assert (neighborIndices[1] == variable);
        TableFactor marginalized =
            new TableFactor(new int[] {neighborIndices[0]}, new int[] {rows});

        for (int i = 0; i < marginalized.values.length; i++) {
          marginalized.values[i] = 0;
        }

        // We use the stable log-sum-exp trick here, so first we calculate the max

        double[] max = new double[rows];
        for (int i = 0; i < rows; i++) {
          max[i] = Double.NEGATIVE_INFINITY;
        }

        for (int i = 0; i < rows; i++) {
          int rowStart = i * cols;
          for (int j = 0; j < cols; j++) {
            int index = rowStart + j;
            if (values[index] > max[i]) {
              max[i] = values[index];
            }
          }
        }

        // Then we take the sum, minus the max

        for (int i = 0; i < rows; i++) {
          int rowStart = i * cols;
          for (int j = 0; j < cols; j++) {
            int index = rowStart + j;
            if (Double.isFinite(max[i])) {
              marginalized.values[i] += Math.exp(values[index] - max[i]);
            }
          }
        }

        // And now we take the log, and add back in the max values

        for (int i = 0; i < rows; i++) {
          if (Double.isFinite(max[i])) {
            marginalized.values[i] = max[i] + Math.log(marginalized.values[i]);
          } else {
            marginalized.values[i] = max[i];
          }
        }

        return marginalized;
      }
    } else {
      // This is a little tricky because we need to use the stable log-sum-exp trick on top of our
      // marginalize dataflow operation.

      // First we calculate all the max values to use as pivots to prevent overflow
      TableFactor maxValues = maxOut(variable);

      // Then we do the sum against an offset from the pivots.
      // A non-finite pivot would make (b - pivot) NaN when b is the same infinity, so in that
      // case nothing is accumulated, exactly as the two-variable fast path above does.
      TableFactor marginalized =
          marginalize(
              variable,
              0,
              (marginalizedVariableValue, assignment) -> {
                double pivot = maxValues.getAssignmentLogValue(assignment);
                if (!Double.isFinite(pivot)) {
                  return (a, b) -> a;
                }
                return (a, b) -> a + Math.exp(b - pivot);
              });

      // Then we factor the max values back in; a non-finite pivot is stored as the result itself
      for (int[] assignment : marginalized) {
        double offset = maxValues.getAssignmentLogValue(assignment);
        double logSum =
            Double.isFinite(offset)
                ? offset + Math.log(marginalized.getAssignmentLogValue(assignment))
                : offset;
        marginalized.setAssignmentLogValue(assignment, logSum);
      }

      return marginalized;
    }
  }