@Override
  public void updateEdgeMessage(int portNum) {
    // Ports at or beyond the parameter edges are not parameter inputs; the base class
    // computes those messages.
    if (portNum >= _numParameterEdges) {
      super.updateEdgeMessage(portNum);
      return;
    }

    // Parameter port: fill in the Gamma message from sample statistics of the outputs.
    // This works for both CategoricalUnnormalizedParameters and CategoricalEnergyParameters
    // factor functions, because the current parameter value does not enter into the message
    // sent in this direction.
    final GammaParameters message = (GammaParameters) _outputMsgs[portNum];

    // Categorical value index that this parameter edge corresponds to
    final int targetIndex = _factorFunction.getIndexByEdge(portNum);

    // Count the output variables whose current sample equals targetIndex
    int occurrences = 0;
    for (int edge = 0; edge < _numOutputEdges; edge++) {
      if (_outputVariables[edge].getCurrentSampleIndex() == targetIndex) {
        occurrences++;
      }
    }

    // Outputs supplied as constants contribute their precomputed histogram counts
    if (_hasConstantOutputs) {
      occurrences += _constantOutputCounts[targetIndex];
    }

    message.setAlphaMinusOne(occurrences); // alpha - 1 is the observed count
    message.setBeta(0); // no beta contribution from this factor
  }
Ejemplo n.º 2
0
  @Override
  public void initialize() {
    super.initialize();

    _numPorts = _factor.getSiblingCount();

    // Fold any constant inputs into one parity sign: +1 when their sum is even, -1 when
    // it is odd. Only the lowest bit of each constant affects the parity, so XOR-ing
    // those bits gives the same result as summing the values.
    _constantParity = 1;
    final FactorFunction function = _factor.getFactorFunction();
    if (!function.hasConstants()) {
      return;
    }

    int oddBit = 0;
    for (Object constant : function.getConstants()) {
      oddBit ^= FactorFunctionUtilities.toInteger(constant) & 1;
    }
    _constantParity = (oddBit == 0) ? 1 : -1;
  }
  private void determineParameterConstantsAndEdges() {
    // The categorical implementation may be wrapped in another factor function, so
    // dispatch on the contained one. The outer function is the one whose constants we read.
    final FactorFunction outerFunction = _factor.getFactorFunction();
    final FactorFunction innerFunction = outerFunction.getContainedFactorFunction();
    _factorFunction = outerFunction;
    final boolean outerHasConstants = outerFunction.hasConstants();

    // Read the dimension and the constructor-constant flag from whichever supported
    // variant this is.
    if (innerFunction instanceof CategoricalUnnormalizedParameters) {
      final CategoricalUnnormalizedParameters unnormalized =
          (CategoricalUnnormalizedParameters) innerFunction;
      _hasFactorFunctionConstructorConstants = unnormalized.hasConstantParameters();
      _numParameters = unnormalized.getDimension();
      _useEnergyParameters = false;
    } else if (innerFunction instanceof CategoricalEnergyParameters) {
      final CategoricalEnergyParameters energy = (CategoricalEnergyParameters) innerFunction;
      _hasFactorFunctionConstructorConstants = energy.hasConstantParameters();
      _numParameters = energy.getDimension();
      _useEnergyParameters = true;
    } else {
      throw new DimpleException("Invalid factor function");
    }

    // Work out how many parameters arrive over edges (the rest are constants), and
    // whether any output values are supplied as constants.
    _hasConstantOutputs = false;
    _numParameterEdges = _numParameters;
    if (_hasFactorFunctionConstructorConstants) {
      // Parameters were fixed in the factor-function constructor: there are no parameter
      // edges, and any factor-function constants are outputs.
      _numParameterEdges = 0;
      _hasConstantOutputs = outerHasConstants;
    } else if (outerHasConstants) {
      // Constants among the first _numParameters indices replace parameter edges, and
      // constants at higher indices are outputs. (The range passed below is presumably
      // inclusive of both bounds; confirm against numConstantsInIndexRange.)
      _hasConstantOutputs = outerFunction.hasConstantAtOrAboveIndex(_numParameters);
      final int constantParameterCount =
          outerFunction.numConstantsInIndexRange(0, _numParameters - 1);
      _numParameterEdges = _numParameters - constantParameterCount;
    }
    // Every port after the parameter edges connects to an output variable
    _numOutputEdges = _numPorts - _numParameterEdges;

    // Cache the solver objects of the output variables, which follow the parameter edges
    final List<? extends VariableBase> neighbors = _factor.getSiblings();
    _outputVariables = new SDiscreteVariable[_numOutputEdges];
    for (int edge = 0; edge < _numOutputEdges; edge++) {
      final VariableBase neighbor = neighbors.get(_numParameterEdges + edge);
      _outputVariables[edge] = (SDiscreteVariable) neighbor.getSolver();
    }
  }
  @Override
  public void initialize() {
    super.initialize();

    // Classify parameters as constants or edges, and cache the output variables
    determineParameterConstantsAndEdges();

    // Histogram of constant output values, indexed by output value. It stays null when
    // no output is supplied as a constant.
    _constantOutputCounts = null;
    if (!_hasConstantOutputs) {
      return;
    }

    final FactorFunction function = _factor.getFactorFunction();
    final Object[] values = function.getConstants();
    final int[] indices = function.getConstantIndices();
    final int[] histogram = new int[_numParameters];
    _constantOutputCounts = histogram;
    for (int k = 0; k < indices.length; k++) {
      // With constructor-supplied parameters every constant is an output; otherwise only
      // constants positioned at or after _numParameters are outputs.
      final boolean isOutput =
          _hasFactorFunctionConstructorConstants || indices[k] >= _numParameters;
      if (isOutput) {
        histogram[FactorFunctionUtilities.toInteger(values[k])]++;
      }
    }
  }