Exemplo n.º 1
0
  /**
   * {@inheritDoc}
   *
   * <p>All {@code variables} must be of type {@link Discrete}. The domain of the returned variable
   * will be a {@link JointDiscreteDomain} whose first subdomain is this variable's domain,
   * followed by the domains of {@code variables} in order. If {@code variables[0]} is this
   * variable it is not repeated; otherwise this variable's domain is prepended. The prior of the
   * returned variable is built with {@code joinPriors} from the subdomain priors, in the same
   * order.
   */
  @Internal
  @Override
  public Variable createJointNoFactors(Variable... variables) {
    final boolean thisIsFirst = (variables[0] == this);
    // When this variable is not already variables[0] it is prepended, so every entry of
    // variables lands one slot to the right in the subdomain arrays.
    final int offset = thisIsFirst ? 0 : 1;
    final int dimensions = variables.length + offset;
    final DiscreteDomain[] domains = new DiscreteDomain[dimensions];
    final IDatum[] subdomainPriors = new IDatum[dimensions];
    domains[0] = getDomain();
    subdomainPriors[0] = getPrior();

    // Slot 0 already holds this variable, so skip variables[0] when it is this variable.
    for (int i = thisIsFirst ? 1 : 0; i < variables.length; ++i) {
      final Discrete var = variables[i].asDiscreteVariable();
      domains[i + offset] = var.getDomain();
      subdomainPriors[i + offset] = var.getPrior();
    }

    final JointDiscreteDomain<?> jointDomain = DiscreteDomain.joint(domains);
    final Discrete jointVar = new Discrete(jointDomain);
    jointVar.setPrior(joinPriors(jointDomain, subdomainPriors));
    return jointVar;
  }
  /**
   * Builds the internal variables and factors of this multiplexer CPD and returns {@code this}.
   *
   * <p>Creates a selector variable {@code A} (one value per entry of {@code Zs}) and a variable
   * {@code ZA} whose first values enumerate every (input, input value) pair in input order. It
   * also creates one auxiliary {@code Z*} variable per input, whose domain is the input's domain
   * size plus one extra last value. {@code Y}, {@code A} and every {@code Zs[i]} are registered
   * as boundary variables. The following factors are added, each from an (indices, weights)
   * table with unit weights:
   *
   * <ul>
   *   <li>{@code Y2ZA}: maps each {@code ZA} value to the index of the matching element in
   *       {@code Y}'s domain;
   *   <li>{@code ZA2A}: maps each {@code ZA} value to the index of the input it belongs to;
   *   <li>{@code ZA2Z*} (one per input {@code a}): maps {@code ZA} values of input {@code a} to
   *       the corresponding value of {@code Z*a}, and values of every other input to the extra
   *       last value of {@code Z*a};
   *   <li>{@code Z*2Z} (one per input {@code a}): pairs {@code Z*a == i} with {@code Z_a == i},
   *       and pairs the extra last value of {@code Z*a} with every value of {@code Z_a}.
   * </ul>
   *
   * <p>Side effects: sets the labels of {@code Y} ("Y") and of each {@code Zs[a]} ("Z" + a), and
   * stores {@code Y}, {@code A}, {@code ZA} and {@code Zs} in this object's fields.
   *
   * @param Y output variable; every element of every input domain must also be in its domain
   * @param Zs input variables
   * @param zasize number of values of {@code ZA}; must be at least the sum of the input domain
   *     sizes (any extra rows of the Y2ZA / ZA2A / ZA2Z* tables are left as zero-weight entries)
   * @param oneBased whether the values of {@code A} start at 1 instead of 0
   * @param aAsDouble whether the values of {@code A} are {@code Double} rather than {@code
   *     Integer}
   * @return this CPD
   * @throws IllegalArgumentException if {@code zasize} is smaller than the total number of input
   *     values, or if an input domain contains an element that is not in {@code Y}'s domain. Both
   *     checks run before any variable is registered or any factor is added.
   */
  private MultiplexerCPD create(
      Discrete Y, Discrete[] Zs, int zasize, boolean oneBased, boolean aAsDouble) {
    Y.setLabel("Y");

    // Index every element of Y's domain so that each input value can be located in Y.
    final DiscreteDomain yDomain = Y.getDiscreteDomain();
    final java.util.Map<Object, Integer> yDomainObj2index =
        new java.util.HashMap<Object, Integer>();
    for (int i = 0, end = yDomain.size(); i < end; i++) {
      yDomainObj2index.put(yDomain.getElement(i), i);
    }

    // Validate before creating or registering anything, so a bad call fails fast instead of
    // leaving a partially built graph behind. The Y index of every (input, value) pair is
    // precomputed here in ZA order.
    int totalZValues = 0;
    for (Discrete z : Zs) {
      totalZValues += z.getDiscreteDomain().size();
    }
    if (zasize < totalZValues) {
      throw new IllegalArgumentException(
          "zasize (" + zasize + ") is smaller than the total number of input values ("
              + totalZValues + ")");
    }
    final int[] yIndexOfZa = new int[totalZValues];
    int index = 0;
    for (int i = 0; i < Zs.length; i++) {
      final DiscreteDomain zDomain = Zs[i].getDiscreteDomain();
      for (int j = 0, size = zDomain.size(); j < size; j++) {
        final Object element = zDomain.getElement(j);
        final Integer yIndex = yDomainObj2index.get(element);
        if (yIndex == null) {
          throw new IllegalArgumentException(
              "Element " + element + " of input " + i + " is not in the domain of Y");
        }
        yIndexOfZa[index++] = yIndex;
      }
    }

    // Create the selector variable A, with one value per input.
    final Object[] adomain = new Object[Zs.length];
    for (int i = 0; i < adomain.length; i++) {
      final int val = oneBased ? i + 1 : i;
      // An if/else (not a ternary) is deliberate: `aAsDouble ? (double) val : val` would
      // promote both branches to double and never box as Integer.
      if (aAsDouble) {
        adomain[i] = (double) val;
      } else {
        adomain[i] = val;
      }
    }
    final Discrete A = new Discrete(adomain);
    A.setLabel("A");

    addBoundaryVariables(Y);
    addBoundaryVariables(A);
    addBoundaryVariables(Zs);

    // Create the ZA variable with values 0 .. zasize-1.
    final Object[] zaDomain = new Object[zasize];
    for (int i = 0; i < zaDomain.length; i++) {
      zaDomain[i] = i;
    }
    final Discrete ZA = new Discrete(zaDomain);
    ZA.setLabel("ZA");

    // Create the Z* variables: one value per value of the matching input plus one extra value.
    final Discrete[] Zstars = new Discrete[Zs.length];
    for (int i = 0; i < Zstars.length; i++) {
      final Object[] domain = new Object[Zs[i].getDiscreteDomain().size() + 1];
      for (int j = 0; j < domain.length; j++) {
        domain[j] = j;
      }
      Zstars[i] = new Discrete(domain);
    }

    // ZA -> Y factor: each ZA value selects the Y element equal to its input value.
    int[][] indices = new int[zasize][2];
    double[] weights = new double[zasize];
    for (int k = 0; k < totalZValues; k++) {
      indices[k][0] = k;
      indices[k][1] = yIndexOfZa[k];
      weights[k] = 1;
    }
    Factor f = this.addFactor(indices, weights, ZA, Y);
    f.setLabel("Y2ZA");

    // ZA -> A factor: each ZA value selects the input it belongs to.
    indices = new int[zasize][2];
    weights = new double[zasize];
    index = 0;
    for (int i = 0; i < Zs.length; i++) {
      for (int j = 0, size = Zs[i].getDiscreteDomain().size(); j < size; j++) {
        indices[index][0] = index;
        indices[index][1] = i;
        weights[index] = 1;
        index++;
      }
    }
    f = this.addFactor(indices, weights, ZA, A);
    f.setLabel("ZA2A");

    // Per input a: a ZA -> Z*a factor and a Z*a -> Z_a factor.
    for (int a = 0; a < Zs.length; a++) {
      Zs[a].setLabel("Z" + a);
      Zstars[a].setLabel("Z*" + a);

      final int ds = Zs[a].getDiscreteDomain().size();

      // ZA -> Z*a: values of input a map to themselves; values of other inputs map to the
      // extra last value (ds) of Z*a.
      indices = new int[zasize][2];
      weights = new double[zasize];
      index = 0;
      for (int i = 0; i < Zs.length; i++) {
        for (int j = 0, size = Zs[i].getDiscreteDomain().size(); j < size; j++) {
          indices[index][0] = index;
          indices[index][1] = (a == i) ? j : ds;
          weights[index] = 1;
          index++;
        }
      }
      f = this.addFactor(indices, weights, ZA, Zstars[a]);
      f.setLabel("ZA2Z*");

      // Z*a -> Z_a: Z*a == i pairs with Z_a == i; the extra value ds pairs with every Z_a value.
      indices = new int[ds * 2][2];
      weights = new double[indices.length];
      for (int i = 0; i < ds; i++) {
        indices[i][0] = i;
        indices[i][1] = i;
        weights[i] = 1;

        indices[ds + i][0] = ds;
        indices[ds + i][1] = i;
        weights[ds + i] = 1;
      }
      f = this.addFactor(indices, weights, Zstars[a], Zs[a]);
      f.setLabel("Z*2Z");
    }

    this._y = Y;
    this._a = A;
    this._za = ZA;
    this._zs = Zs;

    return this;
  }