示例#1
0
  /**
   * Tests that a model may hold several single-variable factors over the same variable, that
   * {@code factorOf(var)} rejects that ambiguous lookup by throwing, and that multiplying the
   * factors returned by {@code allFactorsOf(var)} yields the element-wise product of their tables.
   */
  public void testMultipleNodePotentials() {
    Variable var = new Variable(2);
    FactorGraph mdl = new FactorGraph(new Variable[] {var});

    Factor ptl1 = new TableFactor(var, new double[] {0.5, 0.5});
    mdl.addFactor(ptl1);

    Factor ptl2 = new TableFactor(var, new double[] {0.25, 0.25});
    mdl.addFactor(ptl2);

    // Two factors cover var, so a single-factor lookup is ambiguous and must throw.
    try {
      mdl.factorOf(var);
      fail("factorOf should throw when several factors cover the same variable");
    } catch (RuntimeException ignored) {
      // expected: lookup by variable is ambiguous in this model
    }

    List factors = mdl.allFactorsOf(var);
    assertEquals(2, factors.size());
    Factor total = TableFactor.multiplyAll(factors);

    double[] expected = {0.125, 0.125};
    Factor expectedFactor = new TableFactor(var, expected);
    // Compare with a tolerance rather than exact double equality, as the edge-potential test does.
    assertTrue(
        "Arrays not equal\n  Expected "
            + ArrayUtils.toString(expected)
            + "\n  Actual "
            + ArrayUtils.toString(((TableFactor) total).toValueArray()),
        expectedFactor.almostEquals(total, 1e-10));
  }
示例#2
0
  /**
   * Tests that a model may hold several factors over the same pair of variables, that {@code
   * factorOf(v1, v2)} rejects that ambiguous lookup by throwing, and that {@code
   * allFactorsContaining} returns both factors, whose product is the element-wise product of
   * their tables.
   */
  public void testMultipleEdgePotentials() {
    Variable v1 = new Variable(2);
    Variable v2 = new Variable(2);
    Variable[] vars = new Variable[] {v1, v2};

    FactorGraph mdl = new FactorGraph(vars);

    Factor ptl1 = new TableFactor(vars, new double[] {0.5, 0.5, 0.5, 0.5});
    mdl.addFactor(ptl1);

    Factor ptl2 = new TableFactor(vars, new double[] {0.25, 0.25, 0.5, 0.5});
    mdl.addFactor(ptl2);

    // Two factors cover {v1, v2}, so a single-factor lookup is ambiguous and must throw.
    try {
      mdl.factorOf(v1, v2);
      fail("factorOf should throw when several factors cover the same variables");
    } catch (RuntimeException ignored) {
      // expected: lookup by variable pair is ambiguous in this model
    }

    Collection factors = mdl.allFactorsContaining(new HashVarSet(vars));
    assertEquals(2, factors.size());
    assertTrue(factors.contains(ptl1));
    assertTrue(factors.contains(ptl2));

    double[] vals = {0.125, 0.125, 0.25, 0.25};
    Factor total = TableFactor.multiplyAll(factors);
    Factor expected = new TableFactor(vars, vals);

    assertTrue(
        "Arrays not equal\n  Expected "
            + ArrayUtils.toString(vals)
            + "\n  Actual "
            + ArrayUtils.toString(((TableFactor) total).toValueArray()),
        expected.almostEquals(total, 1e-10));
  }
示例#3
0
  /**
   * Asserts that the cached lookups {@code factorOf(Variable)} and {@code factorOf(Variable,
   * Variable)} hand back the very same objects the model stores. For two-variable factors, both
   * argument orders must give the same object. Factors over more than two variables are not
   * cached, so they are skipped.
   */
  private void verifyCachesConsistent(FactorGraph mdl) {
    Iterator factorIt = mdl.factors().iterator();
    while (factorIt.hasNext()) {
      Factor stored = (Factor) factorIt.next();
      Object[] scope = stored.varSet().toArray();

      if (scope.length == 1) {
        Factor cached = mdl.factorOf((Variable) scope[0]);
        assertTrue(stored == cached);
      } else if (scope.length == 2) {
        Variable first = (Variable) scope[0];
        Variable second = (Variable) scope[1];
        Factor forward = mdl.factorOf(first, second);
        Factor backward = mdl.factorOf(second, first);
        assertTrue(stored == forward);
        assertTrue(forward == backward);
      }
      // Any other scope size: not cached, nothing to verify.
    }
  }
示例#4
0
 /**
  * Writes a random grid model (from {@code TestInference.createRandomGrid}) to {@code
  * grmm-model.dot} in Graphviz format so it can be inspected by hand.
  *
  * <p>The writer is closed even if printing fails. Because {@link PrintWriter} swallows I/O
  * errors, its error flag is checked explicitly and reported as the declared IOException.
  *
  * @throws IOException if the file cannot be opened or written
  */
 public void testOutputToDot() throws IOException {
   FactorGraph mdl = TestInference.createRandomGrid(3, 4, 2, new Random(4234));
   PrintWriter out = new PrintWriter(new FileWriter(new File("grmm-model.dot")));
   try {
     mdl.printAsDot(out);
     if (out.checkError()) {
       throw new IOException("Error while writing grmm-model.dot");
     }
   } finally {
     out.close();
   }
   System.out.println("Now you can open up grmm-model.dot in Graphviz.");
 }
示例#5
0
  /**
   * Picks a starting assignment with nonzero model probability.
   *
   * <p>It first tries the assignment built from a zero-filled value array. If that has log value
   * negative infinity, it falls back to a backtracking search over the factors.
   *
   * @return a feasible assignment, or null if the backtracking search finds none
   */
  private Assignment initialAssignment(FactorGraph mdl) {
    Assignment allZeros = new Assignment(mdl, new int[mdl.numVariables()]);
    boolean feasible = mdl.logValue(allZeros) > Double.NEGATIVE_INFINITY;
    return feasible ? allZeros : initialAssignmentRec(mdl, new Assignment(), 0);
  }
示例#6
0
  /**
   * Backtracking search for a feasible assignment.
   *
   * <p>Starting at factor index {@code fi}, it extends {@code assn} one factor at a time. An
   * extension is kept only if that factor's value under it exceeds 1e-50.
   *
   * @param mdl model whose factors are visited in index order
   * @param assn values fixed so far
   * @param fi index of the next factor to satisfy
   * @return {@code assn} extended to satisfy factors {@code fi} onward, or null if impossible
   */
  private Assignment initialAssignmentRec(FactorGraph mdl, Assignment assn, int fi) {
    if (fi >= mdl.factors().size()) {
      return assn;
    }
    Factor f = mdl.getFactor(fi);
    Factor sliced = f.slice(assn);

    if (sliced.varSet().isEmpty()) {
      // All of f's variables are already fixed: only check feasibility, then move on.
      return f.value(assn) > 1e-50 ? initialAssignmentRec(mdl, assn, fi + 1) : null;
    }

    // Try each completion of f's still-free variables until one leads to a full solution.
    for (AssignmentIterator it = sliced.assignmentIterator(); it.hasNext(); it.advance()) {
      if (sliced.value(it) > 1e-50) {
        Assignment extended = Assignment.union(assn, it.assignment());
        Assignment found = initialAssignmentRec(mdl, extended, fi + 1);
        if (found != null) {
          return found;
        }
      }
    }
    return null;
  }
示例#7
0
  /** Checks that adding one factor over three variables makes every pair of them adjacent. */
  public void testPotentialConnections() {
    Variable a = new Variable(2);
    Variable b = new Variable(2);
    Variable c = new Variable(2);
    FactorGraph mdl = new FactorGraph();
    mdl.addFactor(new TableFactor(new Variable[] {a, b, c}, new double[8]));

    assertTrue(mdl.isAdjacent(a, b));
    assertTrue(mdl.isAdjacent(b, c));
    assertTrue(mdl.isAdjacent(a, c));
  }
示例#8
0
  /**
   * Checks three {@code factorOf} lookups:
   *
   * <ul>
   *   <li>by the factor's own VarSet, which finds the factor;
   *   <li>by an equal plain HashSet, which also finds the factor;
   *   <li>by a proper subset of its variables, which yields null.
   * </ul>
   */
  public void testFactorOfSet() {
    Variable[] vars = {new Variable(2), new Variable(2), new Variable(2)};
    double[] values = new double[8];
    for (int k = 0; k < values.length; k++) {
      values[k] = k;
    }
    Factor factor = new TableFactor(vars, values);

    FactorGraph fg = new FactorGraph(vars);
    fg.addFactor(factor);

    assertTrue(factor == fg.factorOf(factor.varSet()));

    HashSet lookup = new HashSet(factor.varSet());
    assertTrue(factor == fg.factorOf(lookup));

    lookup.remove(vars[0]);
    assertTrue(null == fg.factorOf(lookup));
  }
示例#9
0
  /**
   * Performs one sampling sweep on a duplicate of {@code initial}.
   *
   * <p>Variables are visited in model index order. Each one is resampled from the conditional
   * distribution built by {@code constructConditionalCpt}, given the current values of the
   * others, and the updated duplicate is returned.
   *
   * <p>NOTE(review): the loop is bounded by the assignment's size but fetches variables via
   * {@code mdl.get(index)}. This assumes the assignment covers exactly the model's variables;
   * confirm for assignments produced by the backtracking search.
   */
  private Assignment doOnePass(FactorGraph mdl, Assignment initial) {
    Assignment current = (Assignment) initial.duplicate();
    int idx = 0;
    while (idx < current.size()) {
      Variable v = mdl.get(idx);
      int sampled = constructConditionalCpt(mdl, v, current).sampleLocation(r);
      current.setValue(v, sampled);
      idx++;
    }
    return current;
  }
  /**
   * Constructs a junction tree from a given factor graph without running belief propagation on
   * it. The result has the structure of a junction tree, but its factors are not the true
   * marginals until BP is run separately.
   *
   * <p>If {@code mdl} already caches a tree from an earlier call, that tree is reused after its
   * CPFs are cleared. Otherwise a new structure is built and cached on {@code mdl}. In both cases
   * the clusters are then initialized from the model's factors, and the tree is also stored in
   * {@code jtCurrent}.
   *
   * @param mdl Factor graph to compute JT for.
   * @return the (re)initialized junction tree
   */
  public JunctionTree buildJunctionTree(FactorGraph mdl) {
    jtCurrent = (JunctionTree) mdl.getInferenceCache(JunctionTreeInferencer.class);
    if (jtCurrent == null) {
      /* Triangulation works on the MRF topology of mdl, not on the bipartite factor graph
       * itself. That is historical: triangulate() was first written for MRFs. mdlToGraph() is
       * also valid for FactorGraphs that are DirectedModels, where it has the effect of
       * moralizing them. */
      UndirectedGraph mrf = Graphs.mdlToGraph(mdl);
      triangulate(mrf);
      jtCurrent = buildJtStructure();
      mdl.setInferenceCache(JunctionTreeInferencer.class, jtCurrent);
    } else {
      jtCurrent.clearCPFs();
    }

    initJtCpts(mdl, jtCurrent);
    return jtCurrent;
  }
示例#11
0
  /**
   * Builds a three-variable model from two pairwise factors, (root, childL) and (root, childR).
   * Factor values come from {@code RandomGraphs.generateMixedPotentialValues}.
   *
   * <p>Checks that each child is adjacent to the root, that the children are not adjacent to each
   * other, and that both pairwise factors can be looked up.
   */
  public void testThreeNodeModel() {
    Random rand = new Random(23534709);

    FactorGraph mdl = new FactorGraph();
    Variable root = new Variable(2);
    Variable childL = new Variable(2);
    Variable childR = new Variable(2);

    mdl.addFactor(root, childL, RandomGraphs.generateMixedPotentialValues(rand, 1.5));
    mdl.addFactor(root, childR, RandomGraphs.generateMixedPotentialValues(rand, 1.5));

    assertTrue(mdl.isAdjacent(root, childR));
    assertTrue(mdl.isAdjacent(root, childL));
    assertTrue(!mdl.isAdjacent(childL, childR));

    assertTrue(mdl.factorOf(root, childL) != null);
    assertTrue(mdl.factorOf(root, childR) != null);
  }
示例#12
0
 /**
  * Builds the normalized conditional distribution of {@code var} given the other values in
  * {@code fullAssn}, using every factor of {@code mdl} that contains {@code var}.
  *
  * <p>Warning: destructively modifies {@code fullAssn}. Each value of {@code var} is written into
  * it in turn and never restored, so on return {@code var} holds the last enumerated value.
  * Callers that need the previous value must save and restore it themselves.
  *
  * <p>NOTE(review): each entry is set via {@code setRawValue} from {@code sumValues}, which is
  * defined elsewhere. Presumably it sums the log factor values; confirm.
  */
 private DiscreteFactor constructConditionalCpt(
     FactorGraph mdl, Variable var, Assignment fullAssn) {
   List adjacent = mdl.allFactorsContaining(var);
   LogTableFactor conditional = new LogTableFactor(var);
   AssignmentIterator it = conditional.assignmentIterator();
   while (it.hasNext()) {
     Assignment single = it.assignment();
     fullAssn.setValue(var, single.get(var));
     conditional.setRawValue(single, sumValues(adjacent, fullAssn));
     it.advance();
   }
   conditional.normalize();
   return conditional;
 }
示例#13
0
  /**
   * Removes the first variable from a copy of each test model, then checks two things. Every
   * remaining variable's index must lie in [0, numVariables()). The cached {@code factorOf}
   * lookups must still agree with the stored factors.
   */
  public void testUndirectedCachesAfterRemove() {
    for (Object model : TestInference.createTestModels()) {
      FactorGraph copy = (FactorGraph) ((FactorGraph) model).duplicate();
      copy.remove(copy.get(0));

      // Indices must be renumbered into range after the removal.
      Iterator varIt = copy.variablesIterator();
      while (varIt.hasNext()) {
        int idx = copy.getIndex((Variable) varIt.next());
        assertTrue(idx >= 0);
        assertTrue(idx < copy.numVariables());
      }

      verifyCachesConsistent(copy);
    }
  }
  /**
   * Initializes the cluster factors (CPFs) of {@code jt} from the factors of {@code mdl}.
   *
   * <p>Every cluster is first reset to the constant factor 1.0. Each model factor is then
   * multiplied into the cluster returned by {@code jt.findParentCluster} for its variable set.
   *
   * @param mdl model whose factors are distributed over the tree
   * @param jt junction tree whose CPFs are overwritten
   */
  private void initJtCpts(FactorGraph mdl, JunctionTree jt) {
    for (Iterator it = jt.getVerticesIterator(); it.hasNext(); ) {
      VarSet cluster = (VarSet) it.next();
      jt.setCPF(cluster, new ConstantFactor(1.0));
    }

    for (Iterator it = mdl.factors().iterator(); it.hasNext(); ) {
      Factor ptl = (Factor) it.next();
      VarSet parent = jt.findParentCluster(ptl.varSet());
      // Expected to hold for a tree built from this model. Only checked when run with -ea.
      assert parent != null : "Unable to find parent cluster for ptl " + ptl + " in jt " + jt;

      Factor cpf = jt.getCPF(parent);
      jt.setCPF(parent, cpf.multiply(ptl));
    }
  }
 /**
  * Computes marginals for {@code mdl}. It builds (or reuses) the junction tree, runs the
  * propagator over it, and adds the propagator's message count to {@code totalMessagesSent}.
  *
  * <p>Log-space mode is decided from the type of the model's first factor. This assumes the
  * model has at least one factor.
  */
 public void computeMarginals(FactorGraph mdl) {
   Factor firstFactor = mdl.getFactor(0);
   inLogSpace = firstFactor instanceof LogTableFactor;

   buildJunctionTree(mdl);
   propagator.computeMarginals(jtCurrent);

   totalMessagesSent += propagator.getTotalMessagesSent();
 }