@Override public void updateEdgeMessage(int portNum) { if (portNum < _numParameterEdges) { // Port is a parameter input // Determine sample alpha and beta parameters // NOTE: This case works for either CategoricalUnnormalizedParameters or // CategoricalEnergyParameters factor functions // since the actual parameter value doesn't come into play in determining the message in this // direction GammaParameters outputMsg = (GammaParameters) _outputMsgs[portNum]; // The parameter being updated corresponds to this value int parameterIndex = _factorFunction.getIndexByEdge(portNum); // Start with the ports to variable outputs int count = 0; for (int i = 0; i < _numOutputEdges; i++) { int outputIndex = _outputVariables[i].getCurrentSampleIndex(); if (outputIndex == parameterIndex) count++; } // Include any constant outputs also if (_hasConstantOutputs) count += _constantOutputCounts[parameterIndex]; outputMsg.setAlphaMinusOne(count); // Sample alpha outputMsg.setBeta(0); // Sample beta } else super.updateEdgeMessage(portNum); }
@Override public void initialize() { super.initialize(); _numPorts = _factor.getSiblingCount(); // Pre-compute parity associated with any constant edges _constantParity = 1; FactorFunction factorFunction = _factor.getFactorFunction(); if (factorFunction.hasConstants()) { Object[] constantValues = factorFunction.getConstants(); int constantSum = 0; for (int i = 0; i < constantValues.length; i++) constantSum += FactorFunctionUtilities.toInteger(constantValues[i]); _constantParity = ((constantSum & 1) == 0) ? 1 : -1; } }
private void determineParameterConstantsAndEdges() { // Get the factor function and related state FactorFunction factorFunction = _factor.getFactorFunction(); FactorFunction containedFactorFunction = factorFunction.getContainedFactorFunction(); // In case the factor function is wrapped _factorFunction = factorFunction; boolean hasFactorFunctionConstants = factorFunction.hasConstants(); if (containedFactorFunction instanceof CategoricalUnnormalizedParameters) { CategoricalUnnormalizedParameters specificFactorFunction = (CategoricalUnnormalizedParameters) containedFactorFunction; _hasFactorFunctionConstructorConstants = specificFactorFunction.hasConstantParameters(); _numParameters = specificFactorFunction.getDimension(); _useEnergyParameters = false; } else if (containedFactorFunction instanceof CategoricalEnergyParameters) { CategoricalEnergyParameters specificFactorFunction = (CategoricalEnergyParameters) containedFactorFunction; _hasFactorFunctionConstructorConstants = specificFactorFunction.hasConstantParameters(); _numParameters = specificFactorFunction.getDimension(); _useEnergyParameters = true; } else throw new DimpleException("Invalid factor function"); // Pre-determine whether or not the parameters are constant; if so save the value; if not save // reference to the variable _numParameterEdges = _numParameters; _hasConstantOutputs = false; if (_hasFactorFunctionConstructorConstants) { // The factor function has fixed parameters provided in the factor-function constructor _numParameterEdges = 0; _hasConstantOutputs = hasFactorFunctionConstants; } else if (hasFactorFunctionConstants) { _hasConstantOutputs = factorFunction.hasConstantAtOrAboveIndex(_numParameters); int numConstantParameters = factorFunction.numConstantsInIndexRange(0, _numParameters - 1); _numParameterEdges = _numParameters - numConstantParameters; } _numOutputEdges = _numPorts - _numParameterEdges; // Save output variables List<? extends VariableBase> siblings = _factor.getSiblings(); _outputVariables = new SDiscreteVariable[_numOutputEdges]; for (int i = 0; i < _numOutputEdges; i++) _outputVariables[i] = (SDiscreteVariable) ((siblings.get(i + _numParameterEdges)).getSolver()); }
@Override public void initialize() { super.initialize(); // Determine what parameters are constants or edges, and save the state determineParameterConstantsAndEdges(); // Pre-compute statistics associated with any constant output values _constantOutputCounts = null; if (_hasConstantOutputs) { FactorFunction factorFunction = _factor.getFactorFunction(); Object[] constantValues = factorFunction.getConstants(); int[] constantIndices = factorFunction.getConstantIndices(); _constantOutputCounts = new int[_numParameters]; for (int i = 0; i < constantIndices.length; i++) { if (_hasFactorFunctionConstructorConstants || constantIndices[i] >= _numParameters) { int outputValue = FactorFunctionUtilities.toInteger(constantValues[i]); _constantOutputCounts[outputValue]++; // Histogram among constant outputs } } } }