/**
 * Asserts the invariants that every {@link Value} is expected to satisfy.
 *
 * <p>Checks that:
 *
 * <ul>
 *   <li>the value's object is a member of its own domain;
 *   <li>the int/double/boolean/object views of the value are mutually consistent;
 *   <li>{@code clone()}, {@code Value.create(domain)} followed by {@code setFrom}, and
 *       {@code Value.create(domain, object)} all reproduce an equivalent value with the same
 *       object and index.
 * </ul>
 *
 * @param value the value to check.
 */
private void assertInvariants(Value value) {
  Domain domain = value.getDomain();
  assertNotNull(domain);
  Object objValue = value.getObject();

  // The object representation must belong to the value's own domain. inDomain only reports
  // membership, so its result must be asserted for this invariant to be checked at all.
  assertTrue(domain.inDomain(objValue));

  // Integral domains: the int and double views agree, and the object equals the boxed int.
  if (domain.isIntegral()) {
    assertEquals(value.getInt(), value.getDouble(), 0.0);
    assertEquals(value.getInt(), objValue);
  }
  // Real domains: the object equals the boxed double.
  if (domain.isReal()) {
    assertEquals(value.getDouble(), objValue);
  }

  if (objValue instanceof Number) {
    // Primitive accessors agree with the Number's own conversions; non-zero means true.
    Number number = (Number) objValue;
    assertEquals(number.intValue(), value.getInt());
    assertEquals(number.doubleValue(), value.getDouble(), 0.0);
    assertEquals(number.doubleValue() != 0.0, value.getBoolean());
  } else if (objValue instanceof Boolean) {
    // A boolean object is true exactly when the numeric views are 1.
    Boolean bool = (Boolean) objValue;
    assertEquals(bool, value.getBoolean());
    assertEquals(bool, value.getInt() == 1);
    assertEquals(bool, value.getDouble() == 1.0);
  }

  // clone() yields a distinct instance of the same class holding an equal value.
  Value value2 = value.clone();
  assertNotSame(value, value2);
  assertSame(value.getClass(), value2.getClass());
  assertEquals(value.getObject(), value2.getObject());
  assertEquals(value.getIndex(), value2.getIndex());
  assertTrue(value.valueEquals(value2));

  // A fresh value for the same domain has the same class and can be copied into via setFrom.
  Value value3 = Value.create(domain);
  assertEquals(value.getClass(), value3.getClass());
  value3.setFrom(value);
  assertTrue(value3.valueEquals(value));
  assertEquals(objValue, value3.getObject());
  assertEquals(value.getIndex(), value3.getIndex());

  // Creating directly from (domain, object) reproduces the same object and index.
  Value value4 = Value.create(domain, objValue);
  assertEquals(objValue, value4.getObject());
  assertEquals(value.getIndex(), value4.getIndex());
}
@Override public @Nullable IDatum setPrior(@Nullable Object prior) { if (prior instanceof double[]) { return setPrior((double[]) prior); } if (prior instanceof Value) { Value value = (Value) prior; final DiscreteDomain domain = getDomain(); if (!domain.equals(value.getDomain())) { // If domain does not match, create a new value with the correct domain. This ensures // that indexing operations can be assumed to be correct for this variable. prior = Value.create(domain, requireNonNull(value.getObject())); } } return super.setPrior(prior); }
/**
 * Sets the contents of this message from {@code datum}, interpreted relative to {@code domain}.
 *
 * <p>A {@link DiscreteMessage} is copied directly. An exact {@link Value} selects a single index
 * via {@code setDeterministicIndex}. Anything else is treated as an {@link IUnaryFactorFunction}
 * whose energy is evaluated at every element of {@code domain}.
 *
 * @param domain discrete domain with size matching {@link #size()}.
 * @param datum either an exact {@link Value}, another {@link DiscreteMessage} or other {@link
 *     IUnaryFactorFunction} used to evaluate energies for all possible discrete values.
 * @since 0.08
 */
public void setFrom(DiscreteDomain domain, IDatum datum) {
  if (datum instanceof DiscreteMessage) {
    setFrom((DiscreteMessage) datum);
    return;
  }

  if (datum instanceof Value) {
    final Value exact = (Value) datum;
    // If the value already lives in this domain its index can be used as is; otherwise look up
    // the value's object in this domain.
    final int index =
        domain.equals(exact.getDomain())
            ? exact.getIndex()
            : domain.getIndex(exact.getObject());
    setDeterministicIndex(index);
    return;
  }

  final IUnaryFactorFunction energyFunction = (IUnaryFactorFunction) datum;
  assertSameSize(domain.size());
  final DiscreteValue probe = Value.create(domain);
  // Visit indices from last to first, evaluating the function's energy at each domain element.
  for (int index = domain.size() - 1; index >= 0; index--) {
    probe.setIndex(index);
    setEnergy(index, energyFunction.evalEnergy(probe));
  }
}
/**
 * Makes a block proposal for the multinomial factor's variables.
 *
 * <p>If N is not constant, a new N is first drawn uniformly over N's domain. The outputs are then
 * resampled from a multinomial with probabilities proportional to the factor's current alpha
 * values. The multinomial is sampled as a sequence of conditional binomials.
 *
 * <p>The returned {@code BlockProposal} carries the new values together with the forward and
 * reverse proposal energies, i.e. negative log proposal probabilities. Terms that are identical in
 * both directions (the uniform draw of N, and the N-dependent terms when N is constant) are not
 * accumulated. Presumably this is because they cancel in the acceptance ratio — confirm with the
 * caller.
 *
 * @param currentValue current values: N first (only when N is variable), then the outputs.
 * @param variableDomain domains parallel to {@code currentValue}.
 */
@Override
public BlockProposal next(Value[] currentValue, Domain[] variableDomain) {
  final DimpleRandom rand = activeRandom();
  double proposalForwardEnergy = 0;
  double proposalReverseEnergy = 0;
  int argumentIndex = 0;
  int argumentLength = currentValue.length;
  Value[] newValue = new Value[argumentLength];
  for (int i = 0; i < argumentLength; i++)
    newValue[i] = Value.create(variableDomain[i]);

  // Get the current alpha values, in both probability (alpha) and energy (-log alpha) form,
  // regardless of which representation the factor stores.
  double[] alpha;
  double[] alphaEnergy;
  double alphaSum = 0;
  if (_customFactor.isAlphaEnergyRepresentation()) {
    alphaEnergy = _customFactor.getCurrentAlpha();
    alpha = new double[alphaEnergy.length];
    for (int i = 0; i < alphaEnergy.length; i++) {
      alpha[i] = Math.exp(-alphaEnergy[i]);
      alphaSum += alpha[i];
    }
  } else {
    alpha = _customFactor.getCurrentAlpha();
    alphaEnergy = new double[alpha.length];
    for (int i = 0; i < alpha.length; i++) {
      alphaEnergy[i] = -Math.log(alpha[i]);
      alphaSum += alpha[i];
    }
  }
  if (alphaSum == 0) // Shouldn't happen, but can during initialization
  {
    // Fall back to uniform alphas (alpha = 1, energy = 0) so the normalization below is defined.
    Arrays.fill(alpha, 1);
    Arrays.fill(alphaEnergy, 0);
    alphaSum = alpha.length;
  }

  int nextN = _constantN;
  if (!_hasConstantN) {
    // If N is variable, sample N uniformly.
    // NOTE(review): N's index is used as its numeric value, which assumes N's domain is
    // 0..size-1 — confirm.
    int previousN = currentValue[argumentIndex].getIndex();
    int NDomainSize = requireNonNull(variableDomain[0].asDiscrete()).size();
    nextN = rand.nextInt(NDomainSize);
    newValue[argumentIndex].setIndex(nextN);
    argumentIndex++;

    // Add this portion of -log p(x_proposed -> x_previous): the -log N! + N*log(alphaSum) part
    // of the multinomial's negative log pmf.
    proposalReverseEnergy += -org.apache.commons.math3.special.Gamma.logGamma(previousN + 1) + previousN * Math.log(alphaSum);

    // Add this portion of -log p(x_previous -> x_proposed)
    proposalForwardEnergy += -org.apache.commons.math3.special.Gamma.logGamma(nextN + 1) + nextN * Math.log(alphaSum);
  }

  // Given N and alpha, resample the outputs.
  // Multinomial formed by successively sampling from a binomial and subtracting each count from
  // the total.
  // FIXME: Assumes all outputs are variable (no constant outputs)
  int remainingN = nextN;
  int alphaIndex = 0;
  for (; argumentIndex < argumentLength; argumentIndex++, alphaIndex++) {
    double alphai = alpha[alphaIndex];
    double alphaEnergyi = alphaEnergy[alphaIndex];
    int previousX = currentValue[argumentIndex].getIndex();
    int nextX;
    // NOTE(review): alphaSum is a running difference, so floating-point drift could in principle
    // push alphai / alphaSum slightly outside [0, 1]. Confirm that nextBinomial tolerates this.
    if (argumentIndex < argumentLength - 1)
      nextX = rand.nextBinomial(remainingN, alphai / alphaSum);
    else // Last value: takes whatever count remains
      nextX = remainingN;
    newValue[argumentIndex].setIndex(nextX);
    remainingN -= nextX; // Subtract the sample value from the remaining total count
    alphaSum -= alphai; // Subtract this alpha value from the sum used for normalization

    // x_i * (-log alpha_i), defined as 0 when alpha_i == 0 and x_i == 0. This avoids
    // 0 * Infinity = NaN, since alphaEnergy is +Infinity when alpha is 0.
    double previousXNegativeLogAlphai;
    double nextXNegativeLogAlphai;
    if (alphai == 0 && previousX == 0)
      previousXNegativeLogAlphai = 0;
    else
      previousXNegativeLogAlphai = previousX * alphaEnergyi;
    if (alphai == 0 && nextX == 0)
      nextXNegativeLogAlphai = 0;
    else
      nextXNegativeLogAlphai = nextX * alphaEnergyi;

    // Add this portion of -log p(x_proposed -> x_previous): -x_i*log(alpha_i) + log(x_i!)
    proposalReverseEnergy += previousXNegativeLogAlphai + org.apache.commons.math3.special.Gamma.logGamma(previousX + 1);

    // Add this portion of -log p(x_previous -> x_proposed)
    proposalForwardEnergy += nextXNegativeLogAlphai + org.apache.commons.math3.special.Gamma.logGamma(nextX + 1);
  }

  return new BlockProposal(newValue, proposalForwardEnergy, proposalReverseEnergy);
}
/**
 * Fixes this variable to {@code fixedValue}.
 *
 * <p>Converts {@code fixedValue} into a {@link Value} in this variable's domain and installs that
 * value as the prior.
 *
 * @param fixedValue the object to fix this variable to; converted via
 *     {@code Value.create(getDomain(), fixedValue)}.
 * @deprecated use {@link #setPrior} instead.
 */
@Deprecated
public void setFixedValue(Object fixedValue) {
  // Keep the argument expression inline so that setPrior overload resolution is unchanged.
  setPrior(
      Value.create(getDomain(), fixedValue));
}
/**
 * Exercises {@link Value}: creation for each kind of domain, then getters, setters and
 * {@code valueEquals} for integral, generic discrete, real and object values.
 */
@Test
public void test() {
  //
  // Test creation
  //

  // Value.create(domain): for the discrete domains below the new value is expected at index 0
  // (the first domain element), and a domain-specific Value subclass is selected.
  Domain domain = ObjectDomain.instance();
  Value value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());

  domain = DiscreteDomain.bit();
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(0, value.getInt());
  assertEquals(0, value.getIndex());
  assertEquals(0, value.getObject());
  assertTrue(value instanceof SimpleIntRangeValue);

  domain = DiscreteDomain.range(2, 5);
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(2, value.getInt());
  assertEquals(0, value.getIndex());
  assertTrue(value instanceof IntRangeValue);

  domain = DiscreteDomain.range(0, 4, 2);
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(0, value.getInt());
  assertEquals(0, value.getIndex());
  assertTrue(value instanceof IntRangeValue);

  domain = DiscreteDomain.create(1, 2, 3, 5, 8, 13);
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(1, value.getInt());
  assertEquals(0, value.getIndex());
  assertTrue(value instanceof GenericIntDiscreteValue);

  domain = DiscreteDomain.create(1, 2, 3.5);
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(1, value.getObject());
  assertEquals(0, value.getIndex());
  assertTrue(value instanceof GenericDiscreteValue);

  // Non-discrete domains.
  domain = RealDomain.unbounded();
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(0.0, value.getDouble(), 0.0);
  assertTrue(value instanceof RealValue);

  domain = RealJointDomain.create(2);
  value = Value.create(domain);
  assertInvariants(value);
  // Equality rather than identity here. Presumably RealJointDomain.create need not return a
  // canonical instance — TODO confirm.
  assertEquals(domain, value.getDomain());
  assertTrue(value instanceof RealJointValue);

  domain = IntDomain.unbounded();
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertTrue(value instanceof IntValue);

  // Value.create(domain, object): the index is the object's position within the domain, not the
  // numeric value itself (e.g. 3 is at index 7 in range(-4, 4)).
  domain = DiscreteDomain.range(0, 9);
  value = Value.create(domain, 3);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(3, value.getInt());
  assertEquals(3, value.getIndex());

  domain = DiscreteDomain.range(-4, 4);
  value = Value.create(domain, 3);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(3, value.getInt());
  assertEquals(7, value.getIndex());

  domain = DiscreteDomain.range(0, 6, 3);
  value = Value.create(domain, 3);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(3, value.getInt());
  assertEquals(1, value.getIndex());

  domain = DiscreteDomain.create(1, 2, 3, 5);
  value = Value.create(domain, 3);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(3, value.getInt());
  assertEquals(2, value.getIndex());

  domain = DiscreteDomain.create("rabbit", 3, 4.2);
  value = Value.create(domain, 3);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(3, value.getInt());
  assertEquals(3, value.getObject());
  assertEquals(1, value.getIndex());

  // Value.create(object): the domain is inferred from the Java type of the object.
  // int/short/byte -> IntDomain.unbounded(), double -> RealDomain.unbounded(),
  // double[] -> RealJointDomain of matching size, other objects -> ObjectDomain.
  value = Value.create(42);
  assertInvariants(value);
  assertSame(IntDomain.unbounded(), value.getDomain());
  assertEquals(42, value.getObject());
  assertEquals(42, value.getInt());

  value = Value.create((short) 42);
  assertInvariants(value);
  assertEquals(42, value.getObject());
  assertSame(IntDomain.unbounded(), value.getDomain());

  value = Value.create((byte) 42);
  assertInvariants(value);
  assertEquals(42, value.getObject());
  assertSame(IntDomain.unbounded(), value.getDomain());

  value = Value.create(42.0);
  assertInvariants(value);
  assertSame(RealDomain.unbounded(), value.getDomain());
  assertEquals(42.0, value.getObject());

  double[] array = new double[] {1.0, 3.0};
  value = Value.create(array);
  assertInvariants(value);
  assertEquals(RealJointDomain.create(2), value.getDomain());
  assertArrayEquals(array, (double[]) value.getObject(), 0.0);

  value = Value.create("foo");
  assertInvariants(value);
  assertSame(ObjectDomain.instance(), value.getDomain());
  assertEquals("foo", value.getObject());

  //
  // Test integral values
  //

  // A value in the unbounded int domain has no index: getIndex() is -1 and setIndex throws.
  // setDouble is expected to round to the nearest integer (1.6 -> 2).
  value = Value.create(42);
  assertEquals(42, value.getInt());
  assertEquals(-1, value.getIndex());
  value.setInt(23);
  expectThrow(DimpleException.class, value, "setIndex", 0);
  assertEquals(23, value.getInt());
  value.setDouble(1.6);
  assertEquals(2, value.getInt());
  assertEquals(2.0, value.getDouble(), 0.0);
  value.setObject(-40);
  assertEquals(-40, value.getInt());
  assertTrue(value.valueEquals(Value.create(-40.0)));
  assertFalse(value.valueEquals(Value.create(39)));
  assertInvariants(value);

  // Integer range [0, 9]: index equals value.
  Domain digit = DiscreteDomain.range(0, 9);
  value = Value.create(digit, 3);
  assertEquals(3, value.getInt());
  assertEquals(3, value.getIndex());
  value.setInt(4);
  assertEquals(4, value.getInt());
  value.setDouble(5.2);
  assertEquals(5, value.getInt());
  assertTrue(value.valueEquals(Value.create(5)));
  value.setIndex(2);
  assertEquals(2, value.getIndex());
  assertEquals(2, value.getObject());
  assertInvariants(value);

  // Stepped range 1, 3, 5, 7, 9: index = (value - 1) / 2.
  Domain oddDigits = DiscreteDomain.range(1, 9, 2);
  value = Value.create(oddDigits, 3);
  assertEquals(3, value.getInt());
  assertEquals(1, value.getIndex());
  value.setInt(5);
  assertEquals(5, value.getObject());
  assertEquals(2, value.getIndex());
  value.setIndex(0);
  assertEquals(0, value.getIndex());
  assertEquals(1, value.getInt());
  value.setObject(9);
  assertEquals(9.0, value.getDouble(), 0.0);
  assertEquals(4, value.getIndex());
  assertFalse(value.valueEquals(Value.create(10)));
  assertTrue(value.valueEquals(Value.create(9.1)));

  // Arbitrary integer list.
  Domain primeDigits = DiscreteDomain.create(2, 3, 5, 7);
  value = Value.create(primeDigits, 5);
  assertEquals(5, value.getInt());
  value.setInt(2);
  assertEquals(2, value.getInt());
  assertEquals(0, value.getIndex());
  value.setIndex(3);
  assertEquals(3, value.getIndex());
  assertEquals(7, value.getInt());

  /*
   * Test generic discrete values
   */
  DiscreteDomain stooges = DiscreteDomain.create("moe", "joe", "curly");
  value = Value.create(stooges);
  assertEquals("moe", value.getObject());
  assertEquals(0, value.getIndex());
  assertInvariants(value);
  value.setObject("joe");
  assertEquals("joe", value.getObject());
  assertEquals(1, value.getIndex());
  value.setIndex(2);
  assertEquals("curly", value.getObject());
  assertEquals(2, value.getIndex());
  assertFalse(value.valueEquals(Value.create(stooges, "moe")));
  // valueEquals and setFrom work across domains (here against an ObjectDomain value).
  assertTrue(value.valueEquals(Value.create("curly")));
  value.setFrom(Value.create("joe"));
  assertEquals("joe", value.getObject());

  /*
   * Test real values
   */
  value = Value.create(3.14159);
  assertEquals(3.14159, value.getDouble(), 0.0);
  value.setInt(42);
  assertEquals(42.0, value.getDouble(), 0.0);
  value.setDouble(2.3);
  assertEquals(2.3, value.getDouble(), 0.0);
  value.setObject(-123.4);
  assertEquals(-123.4, value.getDouble(), 0.0);
  assertFalse(value.valueEquals(Value.create(-123.4002)));

  // Real-valued stepped range 0, .5, 1, ..., 10: index = value / .5.
  Domain halves = DiscreteDomain.range(0, 10, .5);
  value = Value.create(halves, 3);
  assertEquals(3, value.getInt());
  assertEquals(3, value.getDouble(), 0.0);
  assertEquals(6, value.getIndex());
  assertInvariants(value);
  value.setIndex(1);
  assertEquals(.5, value.getDouble(), 0.0);
  assertEquals(1, value.getIndex());

  Domain realDigits = DiscreteDomain.range(0.0, 9.0);
  value = Value.create(realDigits, 3);
  assertEquals(3, value.getInt());
  assertEquals(3, value.getIndex());
  assertInvariants(value);
  value.setIndex(5);
  assertEquals(5, value.getInt());

  Domain powersOfTwo = DiscreteDomain.create(.125, .25, .5, 1.0, 2.0, 4.0, 8.0);
  value = Value.create(powersOfTwo);
  assertEquals(.125, value.getDouble(), 0.0);
  assertInvariants(value);
  value.setIndex(2);
  assertEquals(.5, value.getDouble(), 0.0);
  // setFrom locates the source's value in this domain, even when the source has another domain.
  value.setFrom(Value.create(.25));
  assertEquals(1, value.getIndex());
  value.setFrom(Value.create(DiscreteDomain.create(2.3, 4.0, 5.0), 4.0));
  assertEquals(4.0, value.getDouble(), 0.0);

  /*
   * Test object values
   */
  // An object-domain value accepts any type; getInt throws on a non-numeric object.
  value = Value.create("foo");
  assertEquals("foo", value.getObject());
  expectThrow(DimpleException.class, value, "getInt");
  value.setInt(42);
  assertEquals(42, value.getObject());
  assertEquals(42, value.getInt());
  assertEquals(42.0, value.getDouble(), 0.0);
  value.setDouble(23.1);
  assertEquals(23.1, value.getObject());
  assertEquals(23, value.getInt());
  assertEquals(23.1, value.getDouble(), 0.0);

  // A null domain falls back to ObjectDomain with a null object.
  value = Value.create((Domain) null);
  assertNull(value.getObject());
  assertSame(ObjectDomain.instance(), value.getDomain());
}