/**
 * Builds a uniform prior over {@code domain}: every element receives weight {@code 1 / size}.
 *
 * @param domain the discrete domain whose size determines the array length.
 * @return a new array of length {@code domain.size()} filled with {@code 1.0 / domain.size()}.
 */
private double[] getDefaultPriors(DiscreteDomain domain) {
  final int size = domain.size();
  final double uniform = 1.0 / size;
  final double[] priors = new double[size];
  for (int i = size; --i >= 0; ) {
    priors[i] = uniform;
  }
  return priors;
}
  /**
   * Constructs a discrete variable whose domain consists of the given elements.
   *
   * <p>NOTE(review): Java requires the {@code this(...)} delegation to be the first statement, so
   * the length check below only runs after {@link DiscreteDomain#create} has already been called
   * on {@code domain}. That call may reject an empty array on its own — confirm.
   *
   * @param domain the domain elements; must contain at least one element.
   * @throws DimpleException if {@code domain} is empty.
   */
  public Discrete(Object... domain) {
    this(DiscreteDomain.create(domain), "Discrete");

    // Single-element domains are allowed; only an empty domain is rejected, and the message
    // states the same bound the condition enforces.
    if (domain.length < 1) {
      throw new DimpleException(
          String.format("ERROR Variable domain length %d must be at least 1", domain.length));
    }
  }
  /**
   * Sets the prior, normalizing discrete {@link Value} priors to this variable's own domain.
   *
   * <p>A {@code double[]} is forwarded to the {@code double[]} overload. A {@link Value} whose
   * domain differs from this variable's domain is re-created over this variable's domain (from
   * its non-null object) before being passed to the superclass. Anything else is passed through
   * unchanged.
   */
  @Override
  public @Nullable IDatum setPrior(@Nullable Object prior) {
    if (prior instanceof double[]) {
      return setPrior((double[]) prior);
    }

    if (prior instanceof Value) {
      final Value candidate = (Value) prior;
      final DiscreteDomain ownDomain = getDomain();
      if (!ownDomain.equals(candidate.getDomain())) {
        // Re-create the value over this variable's domain so that indexing operations can
        // assume its index refers to this variable's domain.
        return super.setPrior(Value.create(ownDomain, requireNonNull(candidate.getObject())));
      }
    }

    return super.setPrior(prior);
  }
 /**
  * Sets values for {@code domain} from {@code datum}.
  *
  * <p>A {@link DiscreteMessage} is delegated to {@code setFrom(DiscreteMessage)}. A {@link Value}
  * makes this deterministic at the value's index in {@code domain}: the value's own index is used
  * when its domain equals {@code domain`}, otherwise its object is looked up in {@code domain}.
  * Any other datum is cast to {@link IUnaryFactorFunction}, and its energy is evaluated for every
  * element of {@code domain}.
  *
  * @param domain discrete domain with size matching {@link #size()}.
  * @param datum either an exact {@link Value}, another {@link DiscreteMessage} or other {@link
  *     IUnaryFactorFunction} used to evaluate energies for all possible discrete values.
  * @throws ClassCastException if {@code datum} is none of the supported kinds.
  * @since 0.08
  */
 public void setFrom(DiscreteDomain domain, IDatum datum) {
   if (datum instanceof DiscreteMessage) {
     setFrom((DiscreteMessage) datum);
     return;
   }

   if (datum instanceof Value) {
     final Value exact = (Value) datum;
     final int index =
         domain.equals(exact.getDomain())
             ? exact.getIndex()
             : domain.getIndex(exact.getObject());
     setDeterministicIndex(index);
     return;
   }

   final IUnaryFactorFunction energyFunction = (IUnaryFactorFunction) datum;
   assertSameSize(domain.size());
   // A single mutable probe value is re-pointed at each index in turn.
   final DiscreteValue probe = Value.create(domain);
   for (int index = domain.size() - 1; index >= 0; --index) {
     probe.setIndex(index);
     setEnergy(index, energyFunction.evalEnergy(probe));
   }
 }
  /**
   * {@inheritDoc}
   *
   * <p>All {@code variables} must be of type {@link Discrete}. The domain of the returned variable
   * will be a {@link JointDiscreteDomain} with the subdomains in the same order as {@code
   * variables}. If this variable is not {@code variables[0]}, its domain is prepended as the first
   * subdomain.
   */
  @Internal
  @Override
  public Variable createJointNoFactors(Variable... variables) {
    final boolean thisIsFirst = (variables[0] == this);
    // Slot 0 always holds this variable. When it is not already variables[0], every entry of
    // variables is shifted one slot to the right to make room for it. (Previously the loop read
    // variables[variables.length], which is out of bounds, and overwrote slot 0.)
    final int offset = thisIsFirst ? 0 : 1;
    final int dimensions = variables.length + offset;
    final DiscreteDomain[] domains = new DiscreteDomain[dimensions];
    final IDatum[] subdomainPriors = new IDatum[dimensions];
    domains[0] = getDomain();
    subdomainPriors[0] = getPrior();

    for (int i = thisIsFirst ? 1 : 0; i < variables.length; ++i) {
      final Discrete var = variables[i].asDiscreteVariable();
      domains[i + offset] = var.getDomain();
      subdomainPriors[i + offset] = var.getPrior();
    }

    final JointDiscreteDomain<?> jointDomain = DiscreteDomain.joint(domains);
    final Discrete jointVar = new Discrete(jointDomain);
    jointVar.setPrior(joinPriors(jointDomain, subdomainPriors));
    return jointVar;
  }
 /**
  * Returns the elements of the discrete domain obtained from {@code getModelerObject()}.
  *
  * @return the result of {@code getElements} applied to a fresh {@code Object[]} sized to the
  *     domain.
  */
 public Object[] getElements() {
   final DiscreteDomain discreteDomain = getModelerObject();
   // MATLAB is sensitive to the array's component type: it will implicitly convert strings held
   // in an Object[] but not in a String[], so the target array is deliberately an Object[].
   final Object[] elements = new Object[discreteDomain.size()];
   return discreteDomain.getElements(elements);
 }
  /**
   * Combines per-subdomain priors into a single prior over {@code jointDomain}.
   *
   * <p>Side effect: each non-null entry of {@code subdomainPriors} is replaced in place by an
   * equivalent {@link DiscreteMessage}.
   *
   * @param jointDomain the joint domain; dimension {@code i} corresponds to {@code
   *     subdomainPriors[i]}.
   * @param subdomainPriors per-dimension priors, any of which may be null (no prior).
   * @return null if no dimension has a prior. If every dimension has a prior that resolves to a
   *     single index, returns a {@link Value} with the corresponding joint index. Otherwise returns
   *     a {@link DiscreteEnergyMessage} whose energy for each joint element is the sum of the
   *     subdomain energies.
   */
  private @Nullable IDatum joinPriors(
      JointDiscreteDomain<?> jointDomain, IDatum[] subdomainPriors) {
    final JointDomainIndexer indexer = jointDomain.getDomainIndexer();
    final int nDims = jointDomain.getDimensions();
    final int[] fixedIndices = new int[nDims];
    Arrays.fill(fixedIndices, -1);
    boolean anyPrior = false;

    // Pass 1: convert each non-null prior to a DiscreteMessage, and record the single index it
    // pins (or leave -1).
    for (int dim = 0; dim < nDims; ++dim) {
      final IDatum prior = subdomainPriors[dim];
      if (prior == null) {
        continue;
      }
      anyPrior = true;
      final DiscreteDomain subdomain = indexer.get(dim);

      if (prior instanceof Value) {
        final Value fixed = (Value) prior;
        fixedIndices[dim] =
            subdomain.equals(fixed.getDomain())
                ? fixed.getIndex()
                : subdomain.getIndex(fixed.getObject());
        subdomainPriors[dim] = new DiscreteEnergyMessage(subdomain, fixed);
      } else {
        final DiscreteMessage message;
        if (prior instanceof DiscreteMessage) {
          message = (DiscreteMessage) prior;
        } else {
          message = new DiscreteWeightMessage(subdomain, prior);
        }
        subdomainPriors[dim] = message;
        fixedIndices[dim] = message.toDeterministicValueIndex();
      }
    }

    if (!anyPrior) {
      // No component variable has a prior, so neither does the joint variable.
      return null;
    }

    // Pass 2: if every dimension is pinned to one index, the joint prior is a single joint value.
    int pinned = 0;
    while (pinned < nDims && fixedIndices[pinned] >= 0) {
      ++pinned;
    }
    if (pinned == nDims) {
      return Value.createWithIndex(jointDomain, indexer.jointIndexFromIndices(fixedIndices));
    }

    // Pass 3: sum the subdomain energies into the joint energy array. Each element of dimension d
    // is repeated 'stride' times (the product of the sizes of earlier dimensions), and that
    // pattern is tiled 'blocks' times (the product of the sizes of later dimensions).
    // NOTE(review): this layout assumes dimension 0 varies fastest in the joint index — confirm it
    // matches JointDomainIndexer.jointIndexFromIndices.
    final int cardinality = jointDomain.size();
    final double[] energies = new double[cardinality];
    int stride = 1;
    int blocks = cardinality;
    for (int d = 0; d < nDims; ++d) {
      final int size = indexer.get(d).size();
      final DiscreteMessage message = (DiscreteMessage) subdomainPriors[d];
      blocks /= size;

      if (message != null) {
        int pos = 0;
        for (int b = 0; b < blocks; ++b) {
          for (double energy : message.getEnergies()) {
            for (int r = 0; r < stride; ++r) {
              energies[pos++] += energy;
            }
          }
        }
      }

      stride *= size;
    }

    return new DiscreteEnergyMessage(energies);
  }
  /**
   * Converts each array of Z domain elements to a {@link DiscreteDomain}, then delegates to the
   * {@code DiscreteDomain[]} overload.
   */
  private MultiplexerCPD create(Object[][] zDomains, boolean oneBased, boolean aAsDouble) {
    final DiscreteDomain[] converted = new DiscreteDomain[zDomains.length];
    int k = 0;
    for (Object[] elements : zDomains) {
      converted[k++] = DiscreteDomain.create(elements);
    }
    return create(converted, oneBased, aAsDouble);
  }
 /**
  * Convenience overload that first wraps {@code domain} in a {@link DiscreteDomain}.
  *
  * @param domain the domain elements, passed to {@link DiscreteDomain#create}.
  * @param numZs forwarded unchanged to the {@code DiscreteDomain}-based overload.
  * @return whatever the {@code DiscreteDomain}-based overload returns.
  */
 public static DiscreteDomain[] buildDomains(Object[] domain, int numZs) {
   final DiscreteDomain discreteDomain = DiscreteDomain.create(domain);
   return buildDomains(discreteDomain, numZs);
 }
  /**
   * Builds the multiplexer graph connecting {@code Y} to the candidate variables {@code Zs}.
   *
   * <p>The method creates:
   *
   * <ul>
   *   <li>a selector variable {@code A} with one value per Z;
   *   <li>an auxiliary variable {@code ZA} with one value per (Z index, Z element) pair;
   *   <li>one auxiliary {@code Z*} per Z, whose extra trailing value is used when {@code ZA}
   *       refers to a different Z.
   * </ul>
   *
   * <p>These are connected to {@code Y} and the {@code Zs} by weight-1 sparse factors. {@code Y},
   * {@code A} and the {@code Zs} are added as boundary variables.
   *
   * @param Y the output variable; every element of every Z domain must be in its domain.
   * @param Zs the candidate input variables.
   * @param zasize size of the {@code ZA} domain; must equal the sum of the Z domain sizes.
   * @param oneBased if true, {@code A}'s values are {@code 1..Zs.length}, else {@code
   *     0..Zs.length-1}.
   * @param aAsDouble if true, {@code A}'s values are {@code Double}s, else {@code Integer}s.
   * @return this graph.
   * @throws IllegalArgumentException if {@code zasize} does not equal the total Z domain size, or
   *     if some Z domain element is not in Y's domain. Validation happens before the graph is
   *     modified.
   */
  private MultiplexerCPD create(
      Discrete Y, Discrete[] Zs, int zasize, boolean oneBased, boolean aAsDouble) {
    // Map each element of Y's domain to its index, so Z elements can be matched to Y values.
    final DiscreteDomain yDomain = Y.getDiscreteDomain();
    final java.util.Map<Object, Integer> yDomainObj2index =
        new java.util.HashMap<Object, Integer>();
    for (int i = 0, end = yDomain.size(); i < end; i++) {
      yDomainObj2index.put(yDomain.getElement(i), i);
    }

    // Validate before touching the graph. Otherwise a bad input surfaces later as an unboxing
    // NullPointerException or an ArrayIndexOutOfBoundsException after variables were added.
    int totalZSize = 0;
    for (int i = 0; i < Zs.length; i++) {
      final DiscreteDomain zDomain = Zs[i].getDiscreteDomain();
      for (int j = 0, end = zDomain.size(); j < end; j++) {
        final Object element = zDomain.getElement(j);
        if (!yDomainObj2index.containsKey(element)) {
          throw new IllegalArgumentException(
              "Element " + element + " of Z" + i + " is not in the domain of Y");
        }
      }
      totalZSize += zDomain.size();
    }
    if (totalZSize != zasize) {
      throw new IllegalArgumentException(
          String.format(
              "zasize %d does not match the total size %d of the Z domains", zasize, totalZSize));
    }

    Y.setLabel("Y");

    // Selector variable A: one value per Z, 0- or 1-based, as Integer or Double.
    // (An if/else is used rather than a ?: expression, because ?: would promote the int to double.)
    final Object[] adomain = new Object[Zs.length];
    for (int i = 0; i < adomain.length; i++) {
      final int val = oneBased ? i + 1 : i;
      if (aAsDouble) {
        adomain[i] = (double) val;
      } else {
        adomain[i] = val;
      }
    }
    final Discrete A = new Discrete(adomain);
    A.setLabel("A");

    addBoundaryVariables(Y);
    addBoundaryVariables(A);
    addBoundaryVariables(Zs);

    // ZA enumerates every (Z index, Z element) pair as 0 .. zasize-1.
    final Object[] zaDomain = new Object[zasize];
    for (int i = 0; i < zaDomain.length; i++) {
      zaDomain[i] = i;
    }
    final Discrete ZA = new Discrete(zaDomain);
    ZA.setLabel("ZA");

    // One Z* per Z, with one extra trailing value used when ZA refers to a different Z.
    final Discrete[] Zstars = new Discrete[Zs.length];
    for (int i = 0; i < Zstars.length; i++) {
      final Object[] domain = new Object[Zs[i].getDiscreteDomain().size() + 1];
      for (int j = 0; j < domain.length; j++) {
        domain[j] = j;
      }
      Zstars[i] = new Discrete(domain);
    }

    // Factor Y2ZA: the ZA value for (Z i, element j) pairs with the Y value equal to that element.
    int[][] indices = new int[zasize][2];
    double[] weights = new double[zasize];
    int index = 0;
    for (int i = 0; i < Zs.length; i++) {
      final DiscreteDomain zDomain = Zs[i].getDiscreteDomain();
      for (int j = 0; j < zDomain.size(); j++) {
        indices[index][0] = index;
        indices[index][1] = yDomainObj2index.get(zDomain.getElement(j));
        weights[index] = 1;
        index++;
      }
    }
    Factor f = this.addFactor(indices, weights, ZA, Y);
    f.setLabel("Y2ZA");

    // Factor ZA2A: the ZA value for (Z i, element j) pairs with A's i-th value.
    indices = new int[zasize][2];
    weights = new double[zasize];
    index = 0;
    for (int i = 0; i < Zs.length; i++) {
      for (int j = 0; j < Zs[i].getDiscreteDomain().size(); j++) {
        indices[index][0] = index;
        indices[index][1] = i;
        weights[index] = 1;
        index++;
      }
    }
    f = this.addFactor(indices, weights, ZA, A);
    f.setLabel("ZA2A");

    // For each Z: factors ZA -> Z* and Z* -> Z.
    for (int a = 0; a < Zs.length; a++) {
      Zs[a].setLabel("Z" + a);
      Zstars[a].setLabel("Z*" + a);

      // Index of the extra trailing value of Z*[a].
      final int otherZIndex = Zs[a].getDiscreteDomain().size();

      // ZA2Z*: ZA values belonging to Z a map to the same element index in Z*[a]; all other ZA
      // values map to Z*[a]'s trailing value.
      indices = new int[zasize][2];
      weights = new double[zasize];
      index = 0;
      for (int i = 0; i < Zs.length; i++) {
        for (int j = 0; j < Zs[i].getDiscreteDomain().size(); j++) {
          indices[index][0] = index;
          indices[index][1] = (a == i) ? j : otherZIndex;
          weights[index] = 1;
          index++;
        }
      }
      f = this.addFactor(indices, weights, ZA, Zstars[a]);
      f.setLabel("ZA2Z*");

      // Z*2Z: Z*[a] value i pairs with Z[a] value i; Z*[a]'s trailing value pairs with every
      // Z[a] value.
      final int ds = otherZIndex;
      indices = new int[ds * 2][2];
      weights = new double[indices.length];
      for (int i = 0; i < ds; i++) {
        indices[i][0] = i;
        indices[ds + i][0] = ds;

        indices[i][1] = i;
        indices[ds + i][1] = i;

        weights[i] = 1;
        weights[ds + i] = 1;
      }
      f = this.addFactor(indices, weights, Zstars[a], Zs[a]);
      f.setLabel("Z*2Z");
    }

    this._y = Y;
    this._a = A;
    this._za = ZA;
    this._zs = Zs;

    return this;
  }
Beispiel #11
0
  /**
   * Exercises {@link Value} creation through the {@code Value.create} overloads and the
   * getters/setters of integral, generic discrete, real and object values.
   *
   * <p>The checks run in sequence against a shared {@code value} that is repeatedly mutated, so
   * each assertion depends on the statements before it.
   */
  @Test
  public void test() {
    //
    // Test creation
    //

    // Value.create(domain) with no initial value. Each check pins the concrete Value subclass
    // chosen for the domain and, for discrete domains, that the value starts at index 0 (the
    // first element).
    Domain domain = ObjectDomain.instance();
    Value value = Value.create(domain);
    assertInvariants(value);
    assertSame(domain, value.getDomain());

    domain = DiscreteDomain.bit();
    value = Value.create(domain);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertEquals(0, value.getInt());
    assertEquals(0, value.getIndex());
    assertEquals(0, value.getObject());
    assertTrue(value instanceof SimpleIntRangeValue);

    domain = DiscreteDomain.range(2, 5);
    value = Value.create(domain);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertEquals(2, value.getInt());
    assertEquals(0, value.getIndex());
    assertTrue(value instanceof IntRangeValue);

    domain = DiscreteDomain.range(0, 4, 2);
    value = Value.create(domain);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertEquals(0, value.getInt());
    assertEquals(0, value.getIndex());
    assertTrue(value instanceof IntRangeValue);

    domain = DiscreteDomain.create(1, 2, 3, 5, 8, 13);
    value = Value.create(domain);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertEquals(1, value.getInt());
    assertEquals(0, value.getIndex());
    assertTrue(value instanceof GenericIntDiscreteValue);

    // A mixed int/double domain is not integral, so a GenericDiscreteValue is expected.
    domain = DiscreteDomain.create(1, 2, 3.5);
    value = Value.create(domain);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertEquals(1, value.getObject());
    assertEquals(0, value.getIndex());
    assertTrue(value instanceof GenericDiscreteValue);

    domain = RealDomain.unbounded();
    value = Value.create(domain);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertEquals(0.0, value.getDouble(), 0.0);
    assertTrue(value instanceof RealValue);

    // assertEquals rather than assertSame: only domain equality is checked here.
    domain = RealJointDomain.create(2);
    value = Value.create(domain);
    assertInvariants(value);
    assertEquals(domain, value.getDomain());
    assertTrue(value instanceof RealJointValue);

    domain = IntDomain.unbounded();
    value = Value.create(domain);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertTrue(value instanceof IntValue);

    // Value.create(domain, x) with an initial value: getIndex() is x's position in the domain.
    domain = DiscreteDomain.range(0, 9);
    value = Value.create(domain, 3);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertEquals(3, value.getInt());
    assertEquals(3, value.getIndex());

    // -4..4: 3 is the eighth element, hence index 7.
    domain = DiscreteDomain.range(-4, 4);
    value = Value.create(domain, 3);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertEquals(3, value.getInt());
    assertEquals(7, value.getIndex());

    // {0, 3, 6}: 3 is at index 1.
    domain = DiscreteDomain.range(0, 6, 3);
    value = Value.create(domain, 3);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertEquals(3, value.getInt());
    assertEquals(1, value.getIndex());

    domain = DiscreteDomain.create(1, 2, 3, 5);
    value = Value.create(domain, 3);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertEquals(3, value.getInt());
    assertEquals(2, value.getIndex());

    domain = DiscreteDomain.create("rabbit", 3, 4.2);
    value = Value.create(domain, 3);
    assertInvariants(value);
    assertSame(domain, value.getDomain());
    assertEquals(3, value.getInt());
    assertEquals(3, value.getObject());
    assertEquals(1, value.getIndex());

    // Value.create(object) with no domain: the domain is inferred from the object's type
    // (integral -> IntDomain, double -> RealDomain, double[] -> RealJointDomain,
    // other -> ObjectDomain).
    value = Value.create(42);
    assertInvariants(value);
    assertSame(IntDomain.unbounded(), value.getDomain());
    assertEquals(42, value.getObject());
    assertEquals(42, value.getInt());

    value = Value.create((short) 42);
    assertInvariants(value);
    assertEquals(42, value.getObject());
    assertSame(IntDomain.unbounded(), value.getDomain());

    value = Value.create((byte) 42);
    assertInvariants(value);
    assertEquals(42, value.getObject());
    assertSame(IntDomain.unbounded(), value.getDomain());

    value = Value.create(42.0);
    assertInvariants(value);
    assertSame(RealDomain.unbounded(), value.getDomain());
    assertEquals(42.0, value.getObject());

    double[] array = new double[] {1.0, 3.0};
    value = Value.create(array);
    assertInvariants(value);
    assertEquals(RealJointDomain.create(2), value.getDomain());
    assertArrayEquals(array, (double[]) value.getObject(), 0.0);

    value = Value.create("foo");
    assertInvariants(value);
    assertSame(ObjectDomain.instance(), value.getDomain());
    assertEquals("foo", value.getObject());

    //
    // Test integral values
    //

    // A value in the non-discrete IntDomain has no index: getIndex() is -1 and setIndex throws.
    value = Value.create(42);
    assertEquals(42, value.getInt());
    assertEquals(-1, value.getIndex());
    value.setInt(23);
    expectThrow(DimpleException.class, value, "setIndex", 0);
    assertEquals(23, value.getInt());
    // setDouble(1.6) on an integral value yields 2 from both getInt and getDouble
    // (consistent with rounding to nearest — NOTE(review): confirm the rounding mode).
    value.setDouble(1.6);
    assertEquals(2, value.getInt());
    assertEquals(2.0, value.getDouble(), 0.0);
    value.setObject(-40);
    assertEquals(-40, value.getInt());
    // valueEquals compares content across domains: int -40 equals real -40.0.
    assertTrue(value.valueEquals(Value.create(-40.0)));
    assertFalse(value.valueEquals(Value.create(39)));
    assertInvariants(value);

    Domain digit = DiscreteDomain.range(0, 9);
    value = Value.create(digit, 3);
    assertEquals(3, value.getInt());
    assertEquals(3, value.getIndex());
    value.setInt(4);
    assertEquals(4, value.getInt());
    value.setDouble(5.2);
    assertEquals(5, value.getInt());
    assertTrue(value.valueEquals(Value.create(5)));
    value.setIndex(2);
    assertEquals(2, value.getIndex());
    assertEquals(2, value.getObject());
    assertInvariants(value);

    // Odd digits {1, 3, 5, 7, 9}: index = (n - 1) / 2.
    Domain oddDigits = DiscreteDomain.range(1, 9, 2);
    value = Value.create(oddDigits, 3);
    assertEquals(3, value.getInt());
    assertEquals(1, value.getIndex());
    value.setInt(5);
    assertEquals(5, value.getObject());
    assertEquals(2, value.getIndex());
    value.setIndex(0);
    assertEquals(0, value.getIndex());
    assertEquals(1, value.getInt());
    value.setObject(9);
    assertEquals(9.0, value.getDouble(), 0.0);
    assertEquals(4, value.getIndex());
    assertFalse(value.valueEquals(Value.create(10)));
    // NOTE(review): 9.1 is asserted equal to the integral 9 — presumably valueEquals compares via
    // this value's int representation; confirm.
    assertTrue(value.valueEquals(Value.create(9.1)));

    Domain primeDigits = DiscreteDomain.create(2, 3, 5, 7);
    value = Value.create(primeDigits, 5);
    assertEquals(5, value.getInt());
    value.setInt(2);
    assertEquals(2, value.getInt());
    assertEquals(0, value.getIndex());
    value.setIndex(3);
    assertEquals(3, value.getIndex());
    assertEquals(7, value.getInt());

    /*
     * Test generic discrete values
     */

    DiscreteDomain stooges = DiscreteDomain.create("moe", "joe", "curly");
    value = Value.create(stooges);
    assertEquals("moe", value.getObject());
    assertEquals(0, value.getIndex());
    assertInvariants(value);
    value.setObject("joe");
    assertEquals("joe", value.getObject());
    assertEquals(1, value.getIndex());
    value.setIndex(2);
    assertEquals("curly", value.getObject());
    assertEquals(2, value.getIndex());
    // valueEquals compares the held object, not the domain: "curly" in an ObjectDomain matches.
    assertFalse(value.valueEquals(Value.create(stooges, "moe")));
    assertTrue(value.valueEquals(Value.create("curly")));
    // setFrom copies by object from a value in a different (Object) domain.
    value.setFrom(Value.create("joe"));
    assertEquals("joe", value.getObject());

    /*
     * Test real values
     */

    value = Value.create(3.14159);
    assertEquals(3.14159, value.getDouble(), 0.0);
    value.setInt(42);
    assertEquals(42.0, value.getDouble(), 0.0);
    value.setDouble(2.3);
    assertEquals(2.3, value.getDouble(), 0.0);
    value.setObject(-123.4);
    assertEquals(-123.4, value.getDouble(), 0.0);
    // Nearby but distinct reals are not valueEquals.
    assertFalse(value.valueEquals(Value.create(-123.4002)));

    // Halves 0, 0.5, ..., 10: 3 is at index 6.
    Domain halves = DiscreteDomain.range(0, 10, .5);
    value = Value.create(halves, 3);
    assertEquals(3, value.getInt());
    assertEquals(3, value.getDouble(), 0.0);
    assertEquals(6, value.getIndex());
    assertInvariants(value);
    value.setIndex(1);
    assertEquals(.5, value.getDouble(), 0.0);
    assertEquals(1, value.getIndex());

    Domain realDigits = DiscreteDomain.range(0.0, 9.0);
    value = Value.create(realDigits, 3);
    assertEquals(3, value.getInt());
    assertEquals(3, value.getIndex());
    assertInvariants(value);
    value.setIndex(5);
    assertEquals(5, value.getInt());

    Domain powersOfTwo = DiscreteDomain.create(.125, .25, .5, 1.0, 2.0, 4.0, 8.0);
    value = Value.create(powersOfTwo);
    assertEquals(.125, value.getDouble(), 0.0);
    assertInvariants(value);
    value.setIndex(2);
    assertEquals(.5, value.getDouble(), 0.0);
    // setFrom from a different domain resolves by numeric value, not by the source's index.
    value.setFrom(Value.create(.25));
    assertEquals(1, value.getIndex());
    value.setFrom(Value.create(DiscreteDomain.create(2.3, 4.0, 5.0), 4.0));
    assertEquals(4.0, value.getDouble(), 0.0);

    /*
     * Test object values
     */

    // An object value holding a String cannot be read as an int...
    value = Value.create("foo");
    assertEquals("foo", value.getObject());
    expectThrow(DimpleException.class, value, "getInt");
    // ...but the numeric setters store numbers, which can then be read back numerically.
    value.setInt(42);
    assertEquals(42, value.getObject());
    assertEquals(42, value.getInt());
    assertEquals(42.0, value.getDouble(), 0.0);
    value.setDouble(23.1);
    assertEquals(23.1, value.getObject());
    assertEquals(23, value.getInt());
    assertEquals(23.1, value.getDouble(), 0.0);

    // A null domain falls back to ObjectDomain, with a null object.
    value = Value.create((Domain) null);
    assertNull(value.getObject());
    assertSame(ObjectDomain.instance(), value.getDomain());
  }