Beispiel #1
0
  /**
   * Tests that a model can hold several factors over the same pair of variables. In that case
   * {@code factorOf(v1, v2)} has no single answer and must throw, while
   * {@code allFactorsContaining} returns every such factor so that their product can be formed
   * explicitly with {@code TableFactor.multiplyAll}.
   */
  public void testMultipleEdgePotentials() {
    Variable v1 = new Variable(2);
    Variable v2 = new Variable(2);
    Variable[] vars = new Variable[] {v1, v2};

    FactorGraph mdl = new FactorGraph(vars);

    Factor ptl1 = new TableFactor(vars, new double[] {0.5, 0.5, 0.5, 0.5});
    mdl.addFactor(ptl1);

    Factor ptl2 = new TableFactor(vars, new double[] {0.25, 0.25, 0.5, 0.5});
    mdl.addFactor(ptl2);

    // Two factors share {v1, v2}, so the single-factor lookup must refuse to pick one.
    try {
      mdl.factorOf(v1, v2);
      fail("factorOf(v1, v2) should throw when several factors cover the same variables");
    } catch (RuntimeException expected) {
      // expected
    }

    Collection factors = mdl.allFactorsContaining(new HashVarSet(vars));
    assertEquals(2, factors.size());
    assertTrue(factors.contains(ptl1));
    assertTrue(factors.contains(ptl2));

    // Elementwise product of ptl1 and ptl2.
    double[] vals = {0.125, 0.125, 0.25, 0.25};
    Factor total = TableFactor.multiplyAll(factors);
    Factor expected = new TableFactor(vars, vals);

    assertTrue(
        "Arrays not equal\n  Expected "
            + ArrayUtils.toString(vals)
            + "\n  Actual "
            + ArrayUtils.toString(((TableFactor) total).toValueArray()),
        expected.almostEquals(total, 1e-10));
  }
Beispiel #2
0
 /**
  * Writes a random grid model (from {@code TestInference.createRandomGrid}) in Graphviz DOT
  * format to {@code grmm-model.dot} in the working directory, for manual inspection.
  *
  * @throws IOException if the output file cannot be opened or writing to it fails
  */
 public void testOutputToDot() throws IOException {
   FactorGraph mdl = TestInference.createRandomGrid(3, 4, 2, new Random(4234));
   PrintWriter out = new PrintWriter(new FileWriter(new File("grmm-model.dot")));
   try {
     mdl.printAsDot(out);
     // PrintWriter never throws on write failures; it only records them, so check explicitly.
     if (out.checkError()) {
       throw new IOException("Error while writing grmm-model.dot");
     }
   } finally {
     // Close even when printAsDot throws, so the file handle is not leaked.
     out.close();
   }
   System.out.println("Now you can open up grmm-model.dot in Graphviz.");
 }
Beispiel #3
0
  /**
   * Tests that a model can hold several factors over the same single variable. In that case
   * {@code factorOf(var)} has no single answer and must throw, while {@code allFactorsOf(var)}
   * returns both factors so that their product can be formed explicitly with
   * {@code TableFactor.multiplyAll}.
   */
  public void testMultipleNodePotentials() {
    Variable var = new Variable(2);
    FactorGraph mdl = new FactorGraph(new Variable[] {var});

    Factor ptl1 = new TableFactor(var, new double[] {0.5, 0.5});
    mdl.addFactor(ptl1);

    Factor ptl2 = new TableFactor(var, new double[] {0.25, 0.25});
    mdl.addFactor(ptl2);

    // Two factors share var, so the single-factor lookup must refuse to pick one.
    try {
      mdl.factorOf(var);
      fail("factorOf(var) should throw when several factors cover the same variable");
    } catch (RuntimeException expected) {
      // expected
    }

    List factors = mdl.allFactorsOf(var);
    assertEquals(2, factors.size());

    // Elementwise product of ptl1 and ptl2; compared with a tolerance (as in the edge test)
    // rather than exact double equality.
    double[] vals = {0.125, 0.125};
    Factor total = TableFactor.multiplyAll(factors);
    Factor expected = new TableFactor(var, vals);
    assertTrue(
        "Arrays not equal\n  Expected "
            + ArrayUtils.toString(vals)
            + "\n  Actual "
            + ArrayUtils.toString(((TableFactor) total).toValueArray()),
        expected.almostEquals(total, 1e-10));
  }
Beispiel #4
0
  /**
   * Asserts that the factor lookups of {@code mdl} agree with its list of factors. Each factor
   * over one variable must be exactly what {@code factorOf(v)} returns. Each factor over two
   * variables must be what {@code factorOf(a, b)} returns, in both argument orders. Factors over
   * more variables are skipped, because they are not cached.
   */
  private void verifyCachesConsistent(FactorGraph mdl) {
    Iterator factorIt = mdl.factors().iterator();
    while (factorIt.hasNext()) {
      Factor factor = (Factor) factorIt.next();
      Object[] scope = factor.varSet().toArray();

      if (scope.length == 1) {
        Factor lookedUp = mdl.factorOf((Variable) scope[0]);
        assertTrue(factor == lookedUp);
      } else if (scope.length == 2) {
        Variable first = (Variable) scope[0];
        Variable second = (Variable) scope[1];
        Factor forward = mdl.factorOf(first, second);
        Factor backward = mdl.factorOf(second, first);
        assertTrue(factor == forward);
        assertTrue(forward == backward);
      }
      // Larger factors are not cached, so there is nothing to compare for them.
    }
  }
Beispiel #5
0
  /** Tests that a single factor over three variables makes every pair of them adjacent. */
  public void testPotentialConnections() {
    Variable a = new Variable(2);
    Variable b = new Variable(2);
    Variable c = new Variable(2);
    FactorGraph graph = new FactorGraph();

    TableFactor triple = new TableFactor(new Variable[] {a, b, c}, new double[8]);
    graph.addFactor(triple);

    assertTrue(graph.isAdjacent(a, b));
    assertTrue(graph.isAdjacent(b, c));
    assertTrue(graph.isAdjacent(a, c));
  }
Beispiel #6
0
  /**
   * Tests {@code factorOf} with a set of variables. The factor's exact variable set, given
   * either as its own VarSet or as a plain HashSet copy, finds the factor. A strict subset finds
   * nothing.
   */
  public void testFactorOfSet() {
    Variable[] vars = {new Variable(2), new Variable(2), new Variable(2)};
    Factor factor = new TableFactor(vars, new double[] {0, 1, 2, 3, 4, 5, 6, 7});

    FactorGraph fg = new FactorGraph(vars);
    fg.addFactor(factor);

    assertTrue(fg.factorOf(factor.varSet()) == factor);

    HashSet lookup = new HashSet(factor.varSet());
    assertTrue(fg.factorOf(lookup) == factor);

    // Dropping one variable leaves a set that no factor covers exactly.
    lookup.remove(vars[0]);
    assertTrue(fg.factorOf(lookup) == null);
  }
  /**
   * Builds a junction tree for a factor graph, or reuses the one cached on it. No belief
   * propagation is run on the result. The returned tree has the structure of a junction tree,
   * and its cluster factors are initialized from the model's factors. They do not equal the true
   * marginals until BP is run separately.
   *
   * <p>The tree is cached on {@code mdl} under {@code JunctionTreeInferencer.class}. On a cache
   * hit, the cached tree's CPFs are cleared and then re-initialized.
   *
   * @param mdl factor graph to compute the junction tree for
   * @return the junction tree, which is also stored in {@code jtCurrent}
   */
  public JunctionTree buildJunctionTree(FactorGraph mdl) {
    jtCurrent = (JunctionTree) mdl.getInferenceCache(JunctionTreeInferencer.class);
    if (jtCurrent == null) {
      // No cached tree. Triangulate the MRF topology of mdl rather than the bipartite factor
      // graph itself; this is for historical reasons, since a triangulate() for MRFs already
      // existed. For factor graphs that are also DirectedModels, mdlToGraph() moralizes.
      UndirectedGraph mrf = Graphs.mdlToGraph(mdl);
      triangulate(mrf);
      jtCurrent = buildJtStructure();
      mdl.setInferenceCache(JunctionTreeInferencer.class, jtCurrent);
    } else {
      // Reuse the cached structure; its CPFs are cleared before being re-initialized below.
      jtCurrent.clearCPFs();
    }

    initJtCpts(mdl, jtCurrent);
    return jtCurrent;
  }
Beispiel #8
0
  /**
   * Tests a three-variable model with random pairwise factors root-left and root-right. Only
   * root-child pairs are adjacent, and each of those pairs has a retrievable factor.
   */
  public void testThreeNodeModel() {
    Random rand = new Random(23534709);

    FactorGraph model = new FactorGraph();
    Variable root = new Variable(2);
    Variable left = new Variable(2);
    Variable right = new Variable(2);

    // The seeded generator is drawn in this fixed order: the left edge first, then the right.
    model.addFactor(root, left, RandomGraphs.generateMixedPotentialValues(rand, 1.5));
    model.addFactor(root, right, RandomGraphs.generateMixedPotentialValues(rand, 1.5));

    assertTrue(model.isAdjacent(root, right));
    assertTrue(model.isAdjacent(root, left));
    boolean siblingsAdjacent = model.isAdjacent(left, right);
    assertTrue(!siblingsAdjacent);

    Object leftEdge = model.factorOf(root, left);
    assertTrue(leftEdge != null);
    Object rightEdge = model.factorOf(root, right);
    assertTrue(rightEdge != null);
  }
Beispiel #9
0
  /**
   * Checks that, after a variable is removed from a copy of each test model, variable indices
   * stay within range and the cached factor lookups still agree with the model's factors.
   */
  public void testUndirectedCachesAfterRemove() {
    List models = TestInference.createTestModels();
    Iterator modelIt = models.iterator();
    while (modelIt.hasNext()) {
      FactorGraph original = (FactorGraph) modelIt.next();
      FactorGraph copy = (FactorGraph) original.duplicate();
      copy.remove(copy.get(0));

      // Every remaining variable must map to an index in [0, numVariables).
      Iterator varIt = copy.variablesIterator();
      while (varIt.hasNext()) {
        int idx = copy.getIndex((Variable) varIt.next());
        assertTrue(idx >= 0);
        assertTrue(idx < copy.numVariables());
      }

      verifyCachesConsistent(copy);
    }
  }
  /**
   * Initializes the cluster factors (CPFs) of {@code jt} from the factors of {@code mdl}. Every
   * cluster first gets a constant factor of 1.0. Each model factor is then multiplied into the
   * cluster that {@code jt.findParentCluster} returns for that factor's variables.
   *
   * @param mdl model whose factors are distributed over the clusters
   * @param jt junction tree whose CPFs are overwritten
   */
  private void initJtCpts(FactorGraph mdl, JunctionTree jt) {
    // Start every cluster at the multiplicative identity.
    for (Iterator it = jt.getVerticesIterator(); it.hasNext(); ) {
      VarSet c = (VarSet) it.next();
      jt.setCPF(c, new ConstantFactor(1.0));
    }

    for (Iterator it = mdl.factors().iterator(); it.hasNext(); ) {
      Factor ptl = (Factor) it.next();
      VarSet parent = jt.findParentCluster(ptl.varSet());
      assert parent != null : "Unable to find parent cluster for ptl " + ptl + " in jt " + jt;

      Factor cpf = jt.getCPF(parent);
      Factor newCpf = cpf.multiply(ptl);
      jt.setCPF(parent, newCpf);
    }
  }
 /**
  * Computes marginals for {@code mdl} by building (or reusing) its junction tree and running the
  * propagator over it. Log-space mode is chosen by whether the model's first factor is a
  * {@code LogTableFactor}.
  *
  * @param mdl factor graph to run inference on
  */
 public void computeMarginals(FactorGraph mdl) {
   boolean firstFactorIsLog = mdl.getFactor(0) instanceof LogTableFactor;
   inLogSpace = firstFactorIsLog;
   buildJunctionTree(mdl);
   propagator.computeMarginals(jtCurrent);
   // Add the propagator's message count to this inferencer's running total.
   totalMessagesSent += propagator.getTotalMessagesSent();
 }
 /**
  * Adds {@code factor} to the model. A factor over exactly two variables also has its variable
  * set recorded in {@code edges}.
  *
  * @param factor factor to add
  */
 public void addFactor(Factor factor) {
   super.addFactor(factor);
   boolean pairwise = factor.varSet().size() == 2;
   if (!pairwise) {
     return;
   }
   edges.add(factor.varSet());
 }