/**
 * Builds a uniform prior over every element of {@code domain}.
 *
 * @param domain the domain whose size determines the length of the result.
 * @return an array of {@code domain.size()} entries, each equal to {@code 1.0 / domain.size()}.
 */
private double[] getDefaultPriors(DiscreteDomain domain) {
  final int n = domain.size();
  final double uniform = 1.0 / n;
  final double[] priors = new double[n];
  for (int k = n; --k >= 0; ) {
    priors[k] = uniform;
  }
  return priors;
}
/**
 * Constructs a discrete variable whose domain consists of the given elements.
 *
 * @param domain the elements of the variable's domain, in index order; must contain at least one
 *     element.
 * @throws DimpleException if {@code domain} is empty (unless {@code DiscreteDomain.create} has
 *     already rejected it, which is not visible here).
 */
public Discrete(Object... domain) {
  this(DiscreteDomain.create(domain), "Discrete");
  // The length check has to follow the delegating this(...) call, which Java requires to be the
  // first statement of the constructor. The message now matches the actual condition (< 1).
  if (domain.length < 1) {
    throw new DimpleException(
        String.format("ERROR Variable domain length %d must be at least 1", domain.length));
  }
}
@Override public @Nullable IDatum setPrior(@Nullable Object prior) { if (prior instanceof double[]) { return setPrior((double[]) prior); } if (prior instanceof Value) { Value value = (Value) prior; final DiscreteDomain domain = getDomain(); if (!domain.equals(value.getDomain())) { // If domain does not match, create a new value with the correct domain. This ensures // that indexing operations can be assumed to be correct for this variable. prior = Value.create(domain, requireNonNull(value.getObject())); } } return super.setPrior(prior); }
/**
 * Sets values for {@code domain} from {@code datum}.
 *
 * <p>A {@link DiscreteMessage} is copied directly. A {@link Value} makes this message
 * deterministic at the value's index; the value's element is looked up in {@code domain} when
 * the value belongs to a different domain. Any other datum is treated as an {@link
 * IUnaryFactorFunction} and evaluated once per domain element to obtain that element's energy.
 *
 * @param domain discrete domain with size matching {@link #size()}.
 * @param datum either an exact {@link Value}, another {@link DiscreteMessage} or other {@link
 *     IUnaryFactorFunction} used to evaluate energies for all possible discrete values.
 * @since 0.08
 */
public void setFrom(DiscreteDomain domain, IDatum datum) {
  if (datum instanceof DiscreteMessage) {
    setFrom((DiscreteMessage) datum);
    return;
  }

  if (datum instanceof Value) {
    final Value fixed = (Value) datum;
    final int index =
        domain.equals(fixed.getDomain())
            ? fixed.getIndex()
            : domain.getIndex(fixed.getObject());
    setDeterministicIndex(index);
    return;
  }

  final IUnaryFactorFunction function = (IUnaryFactorFunction) datum;
  assertSameSize(domain.size());
  final DiscreteValue probe = Value.create(domain);
  // Walk the domain from the last index down to 0, evaluating the energy of each element.
  for (int idx = domain.size() - 1; idx >= 0; --idx) {
    probe.setIndex(idx);
    setEnergy(idx, function.evalEnergy(probe));
  }
}
/**
 * {@inheritDoc}
 *
 * <p>All {@code variables} must be of type {@link Discrete}. The domain of the returned variable
 * will be a {@link JointDiscreteDomain} whose first subdomain is this variable's domain, followed
 * by the domains of {@code variables} in order. If this variable is already {@code variables[0]}
 * it is not repeated, so the subdomains are in the same order as {@code variables}.
 *
 * <p>The joint variable's prior is derived from the subdomain variables' priors.
 */
@Internal
@Override
public Variable createJointNoFactors(Variable... variables) {
  final boolean thisIsFirst = (variables[0] == this);
  // When this variable is not already first it is prepended, shifting every entry of
  // `variables` one slot to the right in the joint ordering.
  final int offset = thisIsFirst ? 0 : 1;
  final int dimensions = variables.length + offset;
  final DiscreteDomain[] domains = new DiscreteDomain[dimensions];
  final IDatum[] subdomainPriors = new IDatum[dimensions];
  domains[0] = getDomain();
  subdomainPriors[0] = getPrior();
  // Slot 0 is always this variable; fill the remaining slots from `variables`. (Previously the
  // loop started at 0 when this was not first, overwriting slot 0 and reading past the end.)
  for (int i = 1; i < dimensions; ++i) {
    final Discrete var = variables[i - offset].asDiscreteVariable();
    domains[i] = var.getDomain();
    subdomainPriors[i] = var.getPrior();
  }
  final JointDiscreteDomain<?> jointDomain = DiscreteDomain.joint(domains);
  final Discrete jointVar = new Discrete(jointDomain);
  jointVar.setPrior(joinPriors(jointDomain, subdomainPriors));
  return jointVar;
}
public Object[] getElements() { DiscreteDomain domain = getModelerObject(); // For some reason MATLAB cares about the component type of the array and for example will // implicitly convert strings from an Object[] but not a String[]! return domain.getElements(new Object[domain.size()]); }
/**
 * Combines the priors of the subdomains of a joint variable into a single prior for the joint
 * domain.
 *
 * <p>Side effect: each non-null entry of {@code subdomainPriors} is overwritten with an
 * equivalent {@link DiscreteMessage}. A {@link Value} becomes a {@link DiscreteEnergyMessage};
 * any other non-message datum becomes a {@link DiscreteWeightMessage}.
 *
 * @param jointDomain joint domain whose subdomains correspond, in order, to {@code
 *     subdomainPriors}.
 * @param subdomainPriors prior for each subdomain, or {@code null} where a subdomain has none.
 * @return {@code null} if no subdomain has a prior. If every subdomain prior pins down a single
 *     element, a {@link Value} in {@code jointDomain} at the corresponding joint index. Otherwise
 *     a {@link DiscreteEnergyMessage} whose energy for each joint element is the sum of the
 *     subdomain energies; subdomains without a prior contribute zero.
 */
private @Nullable IDatum joinPriors(
    JointDiscreteDomain<?> jointDomain, IDatum[] subdomainPriors) {
  final JointDomainIndexer domains = jointDomain.getDomainIndexer();
  final int dimensions = jointDomain.getDimensions();
  boolean hasPrior = false;
  // -1 marks a subdomain without a deterministic prior, including one with no prior at all.
  int[] fixedIndices = new int[dimensions];
  Arrays.fill(fixedIndices, -1);
  // Normalize each non-null prior to a DiscreteMessage (written back into subdomainPriors) and
  // record the single index it fixes, if any.
  for (int i = 0; i < dimensions; ++i) {
    DiscreteDomain domain = domains.get(i);
    IDatum prior = subdomainPriors[i];
    if (prior != null) {
      hasPrior = true;
      if (prior instanceof Value) {
        Value value = (Value) prior;
        // Fall back to looking up the element when the value was created for another domain.
        fixedIndices[i] =
            domain.equals(value.getDomain())
                ? value.getIndex()
                : domain.getIndex(value.getObject());
        subdomainPriors[i] = new DiscreteEnergyMessage(domain, value);
      } else {
        DiscreteMessage msg =
            prior instanceof DiscreteMessage
                ? (DiscreteMessage) prior
                : new DiscreteWeightMessage(domain, prior);
        subdomainPriors[i] = msg;
        // NOTE(review): presumably negative when msg is not deterministic (see the i < 0 check
        // below); confirm against DiscreteMessage.toDeterministicValueIndex.
        fixedIndices[i] = msg.toDeterministicValueIndex();
      }
    }
  }
  if (!hasPrior) {
    // If none of the component variables has a prior, then neither will the joint variable.
    return null;
  }
  boolean hasAllFixedPriors = true;
  for (int i : fixedIndices) {
    if (i < 0) {
      hasAllFixedPriors = false;
      break;
    }
  }
  if (hasAllFixedPriors) {
    // Return fixed value with appropriate joint index.
    return Value.createWithIndex(jointDomain, domains.jointIndexFromIndices(fixedIndices));
  }
  // Accumulate energies into a flat array over the joint domain. For dimension `dim`, each of its
  // energies is repeated `inner` times (inner = product of sizes of dimensions before `dim`), and
  // that pattern is tiled `outer` times (outer = product of sizes of dimensions after `dim`). So
  // dimension 0 has stride 1 in the flat array.
  // NOTE(review): this assumes the joint domain enumerates elements with the first dimension
  // varying fastest; confirm against JointDomainIndexer's ordering.
  int cardinality = jointDomain.size();
  double[] energies = new double[cardinality];
  int inner = 1, outer = cardinality;
  for (int dim = 0; dim < dimensions; ++dim) {
    final DiscreteDomain domain = domains.get(dim);
    // Cast is safe: the first loop replaced every non-null prior with a DiscreteMessage.
    final DiscreteMessage prior = (DiscreteMessage) subdomainPriors[dim];
    final int size = domain.size();
    int i = 0;
    outer /= size;
    if (prior != null) {
      for (int o = 0; o < outer; ++o) {
        for (double energy : prior.getEnergies()) {
          for (int r = 0; r < inner; ++r) {
            energies[i++] += energy;
          }
        }
      }
    }
    inner *= size;
  }
  return new DiscreteEnergyMessage(energies);
}
/**
 * Converts each array of raw Z domain elements into a {@link DiscreteDomain} and delegates to the
 * {@code DiscreteDomain[]} overload.
 *
 * @param zDomains one element array per Z variable.
 * @param oneBased whether the selector A's values start at 1 instead of 0.
 * @param aAsDouble whether the selector A's values are doubles rather than ints.
 * @return the result of the delegated {@code create} call.
 */
private MultiplexerCPD create(Object[][] zDomains, boolean oneBased, boolean aAsDouble) {
  final DiscreteDomain[] converted = new DiscreteDomain[zDomains.length];
  int k = 0;
  for (Object[] elements : zDomains) {
    converted[k++] = DiscreteDomain.create(elements);
  }
  return create(converted, oneBased, aAsDouble);
}
/**
 * Convenience overload that first turns the raw {@code domain} elements into a {@link
 * DiscreteDomain}.
 *
 * @param domain raw domain elements.
 * @param numZs number of Z domains to build.
 * @return the result of the {@code DiscreteDomain}-based overload.
 */
public static DiscreteDomain[] buildDomains(Object[] domain, int numZs) {
  final DiscreteDomain discreteDomain = DiscreteDomain.create(domain);
  return buildDomains(discreteDomain, numZs);
}
private MultiplexerCPD create( Discrete Y, Discrete[] Zs, int zasize, boolean oneBased, boolean aAsDouble) { Y.setLabel("Y"); java.util.Hashtable<Object, Integer> yDomainObj2index = new java.util.Hashtable<Object, Integer>(); final DiscreteDomain yDomain = Y.getDiscreteDomain(); for (int i = 0, end = yDomain.size(); i < end; i++) yDomainObj2index.put(yDomain.getElement(i), i); // Create a variable Object[] adomain = new Object[Zs.length]; for (int i = 0; i < adomain.length; i++) { int val = oneBased ? i + 1 : i; if (aAsDouble) adomain[i] = (double) val; else adomain[i] = (int) val; } Discrete A = new Discrete(adomain); A.setLabel("A"); addBoundaryVariables(Y); addBoundaryVariables(A); addBoundaryVariables(Zs); // Make all of those boundary variables Variable[] vars = new Variable[Zs.length + 2]; vars[0] = Y; vars[1] = A; for (int i = 0; i < Zs.length; i++) vars[i + 2] = Zs[i]; // Create ZA variable Object[] zaDomain = new Object[zasize]; for (int i = 0; i < zaDomain.length; i++) zaDomain[i] = i; Discrete ZA = new Discrete(zaDomain); ZA.setLabel("ZA"); // Create Z* variables Discrete[] Zstars = new Discrete[Zs.length]; for (int i = 0; i < Zstars.length; i++) { Object[] domain = new Object[Zs[i].getDiscreteDomain().size() + 1]; for (int j = 0; j < domain.length; j++) domain[j] = j; Zstars[i] = new Discrete(domain); } // Create ZA Y factor int[][] indices = new int[zasize][2]; double[] weights = new double[zasize]; int index = 0; for (int i = 0; i < Zs.length; i++) { for (int j = 0; j < Zs[i].getDiscreteDomain().size(); j++) { indices[index][0] = index; indices[index][1] = yDomainObj2index.get(Zs[i].getDiscreteDomain().getElement(j)); weights[index] = 1; index++; } } Factor f = this.addFactor(indices, weights, ZA, Y); f.setLabel("Y2ZA"); // Create ZA A factor indices = new int[zasize][2]; weights = new double[zasize]; index = 0; for (int i = 0; i < Zs.length; i++) { for (int j = 0; j < Zs[i].getDiscreteDomain().size(); j++) { indices[index][0] = index; 
indices[index][1] = i; weights[index] = 1; index++; } } f = this.addFactor(indices, weights, ZA, A); f.setLabel("ZA2A"); // Create ZA Z* factors // Create Z* Z factors for (int a = 0; a < Zs.length; a++) { Zs[a].setLabel("Z" + a); Zstars[a].setLabel("Z*" + a); indices = new int[zasize][2]; weights = new double[zasize]; index = 0; // Factor from ZA to Z* for (int i = 0; i < Zs.length; i++) { for (int j = 0; j < Zs[i].getDiscreteDomain().size(); j++) { indices[index][0] = index; if (a == i) { indices[index][1] = j; } else { int sz = Zs[a].getDiscreteDomain().size(); indices[index][1] = sz; } weights[index] = 1; index++; } } f = this.addFactor(indices, weights, ZA, Zstars[a]); f.setLabel("ZA2Z*"); // From Z* to Z indices = new int[Zs[a].getDiscreteDomain().size() * 2][2]; weights = new double[indices.length]; int ds = Zs[a].getDiscreteDomain().size(); for (int i = 0; i < ds; i++) { indices[i][0] = i; indices[ds + i][0] = ds; indices[i][1] = i; indices[ds + i][1] = i; weights[i] = 1; weights[ds + i] = 1; } f = this.addFactor(indices, weights, Zstars[a], Zs[a]); f.setLabel("Z*2Z"); } this._y = Y; this._a = A; this._za = ZA; this._zs = Zs; return this; }
/**
 * Exercises {@link Value} creation and domain association. It then tests the accessors and
 * mutators of integral, generic discrete, real and object values.
 */
@Test
public void test() {
  //
  // Test creation
  //

  // Default value for the generic object domain.
  Domain domain = ObjectDomain.instance();
  Value value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());

  // Discrete domains: the default value is the first element (index 0), and the concrete Value
  // subclass depends on the kind of discrete domain.
  domain = DiscreteDomain.bit();
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(0, value.getInt());
  assertEquals(0, value.getIndex());
  assertEquals(0, value.getObject());
  assertTrue(value instanceof SimpleIntRangeValue);

  domain = DiscreteDomain.range(2, 5);
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(2, value.getInt());
  assertEquals(0, value.getIndex());
  assertTrue(value instanceof IntRangeValue);

  domain = DiscreteDomain.range(0, 4, 2);
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(0, value.getInt());
  assertEquals(0, value.getIndex());
  assertTrue(value instanceof IntRangeValue);

  domain = DiscreteDomain.create(1, 2, 3, 5, 8, 13);
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(1, value.getInt());
  assertEquals(0, value.getIndex());
  assertTrue(value instanceof GenericIntDiscreteValue);

  domain = DiscreteDomain.create(1, 2, 3.5);
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(1, value.getObject());
  assertEquals(0, value.getIndex());
  assertTrue(value instanceof GenericDiscreteValue);

  // Continuous, real-joint and unbounded integer domains.
  domain = RealDomain.unbounded();
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(0.0, value.getDouble(), 0.0);
  assertTrue(value instanceof RealValue);

  domain = RealJointDomain.create(2);
  value = Value.create(domain);
  assertInvariants(value);
  // Compared with equals rather than identity; presumably the value's domain may be a distinct
  // but equal instance.
  assertEquals(domain, value.getDomain());
  assertTrue(value instanceof RealJointValue);

  domain = IntDomain.unbounded();
  value = Value.create(domain);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertTrue(value instanceof IntValue);

  // Creation with an explicit initial element: the index is that element's position.
  domain = DiscreteDomain.range(0, 9);
  value = Value.create(domain, 3);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(3, value.getInt());
  assertEquals(3, value.getIndex());

  domain = DiscreteDomain.range(-4, 4);
  value = Value.create(domain, 3);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(3, value.getInt());
  assertEquals(7, value.getIndex());

  domain = DiscreteDomain.range(0, 6, 3);
  value = Value.create(domain, 3);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(3, value.getInt());
  assertEquals(1, value.getIndex());

  domain = DiscreteDomain.create(1, 2, 3, 5);
  value = Value.create(domain, 3);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(3, value.getInt());
  assertEquals(2, value.getIndex());

  domain = DiscreteDomain.create("rabbit", 3, 4.2);
  value = Value.create(domain, 3);
  assertInvariants(value);
  assertSame(domain, value.getDomain());
  assertEquals(3, value.getInt());
  assertEquals(3, value.getObject());
  assertEquals(1, value.getIndex());

  // Creation from a bare Java object picks the domain from the object's type.
  value = Value.create(42);
  assertInvariants(value);
  assertSame(IntDomain.unbounded(), value.getDomain());
  assertEquals(42, value.getObject());
  assertEquals(42, value.getInt());

  value = Value.create((short) 42);
  assertInvariants(value);
  assertEquals(42, value.getObject());
  assertSame(IntDomain.unbounded(), value.getDomain());

  value = Value.create((byte) 42);
  assertInvariants(value);
  assertEquals(42, value.getObject());
  assertSame(IntDomain.unbounded(), value.getDomain());

  value = Value.create(42.0);
  assertInvariants(value);
  assertSame(RealDomain.unbounded(), value.getDomain());
  assertEquals(42.0, value.getObject());

  double[] array = new double[] {1.0, 3.0};
  value = Value.create(array);
  assertInvariants(value);
  assertEquals(RealJointDomain.create(2), value.getDomain());
  assertArrayEquals(array, (double[]) value.getObject(), 0.0);

  value = Value.create("foo");
  assertInvariants(value);
  assertSame(ObjectDomain.instance(), value.getDomain());
  assertEquals("foo", value.getObject());

  //
  // Test integral values
  //

  // Unbounded integer: reports index -1, rejects setIndex, and rounds values set via setDouble.
  value = Value.create(42);
  assertEquals(42, value.getInt());
  assertEquals(-1, value.getIndex());
  value.setInt(23);
  expectThrow(DimpleException.class, value, "setIndex", 0);
  assertEquals(23, value.getInt());
  value.setDouble(1.6);
  assertEquals(2, value.getInt());
  assertEquals(2.0, value.getDouble(), 0.0);
  value.setObject(-40);
  assertEquals(-40, value.getInt());
  assertTrue(value.valueEquals(Value.create(-40.0)));
  assertFalse(value.valueEquals(Value.create(39)));
  assertInvariants(value);

  // Contiguous integer range 0..9.
  Domain digit = DiscreteDomain.range(0, 9);
  value = Value.create(digit, 3);
  assertEquals(3, value.getInt());
  assertEquals(3, value.getIndex());
  value.setInt(4);
  assertEquals(4, value.getInt());
  value.setDouble(5.2);
  assertEquals(5, value.getInt());
  assertTrue(value.valueEquals(Value.create(5)));
  value.setIndex(2);
  assertEquals(2, value.getIndex());
  assertEquals(2, value.getObject());
  assertInvariants(value);

  // Strided integer range 1..9 step 2.
  Domain oddDigits = DiscreteDomain.range(1, 9, 2);
  value = Value.create(oddDigits, 3);
  assertEquals(3, value.getInt());
  assertEquals(1, value.getIndex());
  value.setInt(5);
  assertEquals(5, value.getObject());
  assertEquals(2, value.getIndex());
  value.setIndex(0);
  assertEquals(0, value.getIndex());
  assertEquals(1, value.getInt());
  value.setObject(9);
  assertEquals(9.0, value.getDouble(), 0.0);
  assertEquals(4, value.getIndex());
  assertFalse(value.valueEquals(Value.create(10)));
  // NOTE(review): expects the integral value 9 to equal the real 9.1. Presumably valueEquals
  // rounds the other value for integral domains; confirm this is intended.
  assertTrue(value.valueEquals(Value.create(9.1)));

  // Integer elements that are not an arithmetic range.
  Domain primeDigits = DiscreteDomain.create(2, 3, 5, 7);
  value = Value.create(primeDigits, 5);
  assertEquals(5, value.getInt());
  value.setInt(2);
  assertEquals(2, value.getInt());
  assertEquals(0, value.getIndex());
  value.setIndex(3);
  assertEquals(3, value.getIndex());
  assertEquals(7, value.getInt());

  /*
   * Test generic discrete values
   */
  DiscreteDomain stooges = DiscreteDomain.create("moe", "joe", "curly");
  value = Value.create(stooges);
  assertEquals("moe", value.getObject());
  assertEquals(0, value.getIndex());
  assertInvariants(value);
  value.setObject("joe");
  assertEquals("joe", value.getObject());
  assertEquals(1, value.getIndex());
  value.setIndex(2);
  assertEquals("curly", value.getObject());
  assertEquals(2, value.getIndex());
  assertFalse(value.valueEquals(Value.create(stooges, "moe")));
  // Equality with a value from a different (object) domain that holds the same element.
  assertTrue(value.valueEquals(Value.create("curly")));
  value.setFrom(Value.create("joe"));
  assertEquals("joe", value.getObject());

  /*
   * Test real values
   */
  value = Value.create(3.14159);
  assertEquals(3.14159, value.getDouble(), 0.0);
  value.setInt(42);
  assertEquals(42.0, value.getDouble(), 0.0);
  value.setDouble(2.3);
  assertEquals(2.3, value.getDouble(), 0.0);
  value.setObject(-123.4);
  assertEquals(-123.4, value.getDouble(), 0.0);
  assertFalse(value.valueEquals(Value.create(-123.4002)));

  // Discrete real range 0..10 step .5.
  Domain halves = DiscreteDomain.range(0, 10, .5);
  value = Value.create(halves, 3);
  assertEquals(3, value.getInt());
  assertEquals(3, value.getDouble(), 0.0);
  assertEquals(6, value.getIndex());
  assertInvariants(value);
  value.setIndex(1);
  assertEquals(.5, value.getDouble(), 0.0);
  assertEquals(1, value.getIndex());

  // Real-valued range 0.0..9.0.
  Domain realDigits = DiscreteDomain.range(0.0, 9.0);
  value = Value.create(realDigits, 3);
  assertEquals(3, value.getInt());
  assertEquals(3, value.getIndex());
  assertInvariants(value);
  value.setIndex(5);
  assertEquals(5, value.getInt());

  // Explicit real elements; setFrom locates the matching element, including from another domain.
  Domain powersOfTwo = DiscreteDomain.create(.125, .25, .5, 1.0, 2.0, 4.0, 8.0);
  value = Value.create(powersOfTwo);
  assertEquals(.125, value.getDouble(), 0.0);
  assertInvariants(value);
  value.setIndex(2);
  assertEquals(.5, value.getDouble(), 0.0);
  value.setFrom(Value.create(.25));
  assertEquals(1, value.getIndex());
  value.setFrom(Value.create(DiscreteDomain.create(2.3, 4.0, 5.0), 4.0));
  assertEquals(4.0, value.getDouble(), 0.0);

  /*
   * Test object values
   */
  value = Value.create("foo");
  assertEquals("foo", value.getObject());
  expectThrow(DimpleException.class, value, "getInt");
  value.setInt(42);
  assertEquals(42, value.getObject());
  assertEquals(42, value.getInt());
  assertEquals(42.0, value.getDouble(), 0.0);
  value.setDouble(23.1);
  assertEquals(23.1, value.getObject());
  assertEquals(23, value.getInt());
  assertEquals(23.1, value.getDouble(), 0.0);

  // A null domain falls back to the object domain with a null element.
  value = Value.create((Domain) null);
  assertNull(value.getObject());
  assertSame(ObjectDomain.instance(), value.getDomain());
}