/** * Tests that models can be created that have multiple factors over the same variable, and that * potentialOfVertex returns the product in that case. */ public void testMultipleNodePotentials() { Variable var = new Variable(2); FactorGraph mdl = new FactorGraph(new Variable[] {var}); Factor ptl1 = new TableFactor(var, new double[] {0.5, 0.5}); mdl.addFactor(ptl1); Factor ptl2 = new TableFactor(var, new double[] {0.25, 0.25}); mdl.addFactor(ptl2); // verify that factorOf(var) doesn't work try { mdl.factorOf(var); fail(); } catch (RuntimeException e) { } // expected List factors = mdl.allFactorsOf(var); Factor total = TableFactor.multiplyAll(factors); double[] expected = {0.125, 0.125}; assertTrue( "Arrays not equal\n Expected " + ArrayUtils.toString(expected) + "\n Actual " + ArrayUtils.toString(((TableFactor) total).toValueArray()), Arrays.equals(expected, ((TableFactor) total).toValueArray())); }
/**
 * Checks that a model can hold several factors over the same pair of variables. In that case
 * {@code factorOf(a, b)} is expected to throw, {@code allFactorsContaining} must return both
 * factors, and their product must match the expected table to within 1e-10.
 */
public void testMultipleEdgePotentials() {
  Variable a = new Variable(2);
  Variable b = new Variable(2);
  Variable[] pair = new Variable[] {a, b};
  FactorGraph fg = new FactorGraph(pair);

  Factor first = new TableFactor(pair, new double[] {0.5, 0.5, 0.5, 0.5});
  fg.addFactor(first);
  Factor second = new TableFactor(pair, new double[] {0.25, 0.25, 0.5, 0.5});
  fg.addFactor(second);

  // Two factors share the edge (a, b), so the single-factor lookup must not succeed.
  try {
    fg.factorOf(a, b);
    fail();
  } catch (RuntimeException expected) {
    // expected: factorOf cannot return a single factor here
  }

  Collection found = fg.allFactorsContaining(new HashVarSet(pair));
  assertEquals(2, found.size());
  assertTrue(found.contains(first));
  assertTrue(found.contains(second));

  double[] productValues = {0.125, 0.125, 0.25, 0.25};
  Factor product = TableFactor.multiplyAll(found);
  Factor want = new TableFactor(pair, productValues);
  String msg =
      "Arrays not equal\n Expected "
          + ArrayUtils.toString(productValues)
          + "\n Actual "
          + ArrayUtils.toString(((TableFactor) product).toValueArray());
  assertTrue(msg, want.almostEquals(product, 1e-10));
}
private void verifyCachesConsistent(FactorGraph mdl) { Factor pot, pot2, pot3; for (Iterator it = mdl.factors().iterator(); it.hasNext(); ) { pot = (Factor) it.next(); // System.out.println("Testing model "+i+" potential "+pot); Object[] vars = pot.varSet().toArray(); switch (vars.length) { case 1: pot2 = mdl.factorOf((Variable) vars[0]); assertTrue(pot == pot2); break; case 2: Variable var1 = (Variable) vars[0]; Variable var2 = (Variable) vars[1]; pot2 = mdl.factorOf(var1, var2); pot3 = mdl.factorOf(var2, var1); assertTrue(pot == pot2); assertTrue(pot2 == pot3); break; // Factors of size > 2 aren't now cached. default: break; } } }
/**
 * Builds a random grid model with {@code TestInference.createRandomGrid} and writes it to
 * {@code grmm-model.dot} in Graphviz format.
 *
 * <p>The writer is closed in a {@code finally} block, so the file handle is released even if
 * {@code printAsDot} throws. {@link PrintWriter} never throws on write errors; it only records
 * them. Its error flag is therefore checked afterwards so that a failed write does not pass
 * silently.
 *
 * @throws IOException if the file cannot be opened, or if writing or closing it failed
 */
public void testOutputToDot() throws IOException {
  FactorGraph mdl = TestInference.createRandomGrid(3, 4, 2, new Random(4234));
  PrintWriter out = new PrintWriter(new FileWriter(new File("grmm-model.dot")));
  try {
    mdl.printAsDot(out);
  } finally {
    out.close();
  }
  // checkError() still reports errors recorded during writing or close() after the stream is closed.
  if (out.checkError()) {
    throw new IOException("Error writing grmm-model.dot");
  }
  System.out.println("Now you can open up grmm-model.dot in Graphviz.");
}
/**
 * Finds a starting assignment for {@code mdl} whose log value is not negative infinity.
 *
 * <p>The all-zeros assignment over the model's variables is tried first. If its log value is
 * negative infinity, a backtracking search over the model's factors, starting from an empty
 * assignment, is used instead.
 *
 * @param mdl model to find a starting point for
 * @return a feasible assignment, or {@code null} if the backtracking search finds none
 */
private Assignment initialAssignment(FactorGraph mdl) {
  Assignment allZeros = new Assignment(mdl, new int[mdl.numVariables()]);
  boolean feasible = mdl.logValue(allZeros) > Double.NEGATIVE_INFINITY;
  return feasible ? allZeros : initialAssignmentRec(mdl, new Assignment(), 0);
}
// backtracking search for a feasible assignment private Assignment initialAssignmentRec(FactorGraph mdl, Assignment assn, int fi) { if (fi >= mdl.factors().size()) return assn; Factor f = mdl.getFactor(fi); Factor sliced = f.slice(assn); if (sliced.varSet().isEmpty()) { double val = f.value(assn); if (val > 1e-50) { return initialAssignmentRec(mdl, assn, fi + 1); } else { return null; } } for (AssignmentIterator it = sliced.assignmentIterator(); it.hasNext(); ) { double val = sliced.value(it); if (val > 1e-50) { Assignment new_assn = Assignment.union(assn, it.assignment()); Assignment assn_ret = initialAssignmentRec(mdl, new_assn, fi + 1); if (assn_ret != null) return assn_ret; } it.advance(); } return null; }
/**
 * Checks that adding one factor over three variables makes every pair of those variables
 * adjacent in the factor graph.
 */
public void testPotentialConnections() {
  Variable[] trio = new Variable[3];
  for (int i = 0; i < trio.length; i++) {
    trio[i] = new Variable(2);
  }
  FactorGraph fg = new FactorGraph();
  fg.addFactor(new TableFactor(trio, new double[8]));

  assertTrue(fg.isAdjacent(trio[0], trio[1]));
  assertTrue(fg.isAdjacent(trio[1], trio[2]));
  assertTrue(fg.isAdjacent(trio[0], trio[2]));
}
/**
 * Checks {@code factorOf} lookup by variable set. The factor's own VarSet and a plain HashSet
 * with the same variables both find the factor. A strict subset of the variables finds nothing.
 */
public void testFactorOfSet() {
  Variable[] vars = new Variable[3];
  for (int i = 0; i < vars.length; i++) {
    vars[i] = new Variable(2);
  }
  double[] values = {0, 1, 2, 3, 4, 5, 6, 7};
  Factor factor = new TableFactor(vars, values);
  FactorGraph fg = new FactorGraph(vars);
  fg.addFactor(factor);

  assertTrue(factor == fg.factorOf(factor.varSet()));

  HashSet scope = new HashSet(factor.varSet());
  assertTrue(factor == fg.factorOf(scope));

  // Dropping one variable leaves a set that no factor covers exactly.
  scope.remove(vars[0]);
  assertTrue(null == fg.factorOf(scope));
}
/**
 * Performs one sampling sweep over a copy of {@code initial}. Each variable in turn is resampled
 * from the conditional table built by {@code constructConditionalCpt}, given the copy's current
 * values, and the sampled value is written back into the copy.
 *
 * @param mdl model being sampled
 * @param initial starting assignment; the sweep works on {@code initial.duplicate()}
 * @return the updated copy
 */
private Assignment doOnePass(FactorGraph mdl, Assignment initial) {
  Assignment current = (Assignment) initial.duplicate();
  // NOTE(review): the bound is the assignment's size, but variables are fetched by model index.
  // This assumes the assignment covers exactly the model's variables; confirm.
  for (int i = 0; i < current.size(); i++) {
    Variable target = mdl.get(i);
    int sampled = constructConditionalCpt(mdl, target, current).sampleLocation(r);
    current.setValue(target, sampled);
  }
  return current;
}
/**
 * Builds, or reuses, the junction tree for a factor graph and loads the graph's factors into its
 * cluster potentials.
 *
 * <p>Belief propagation is not run here. The result has the structure of a junction tree, but its
 * cluster potentials do not correspond to the true marginals until BP is run on it.
 *
 * <p>If {@code mdl} already has a junction tree cached under
 * {@code JunctionTreeInferencer.class}, that tree's CPFs are cleared and it is reused. Otherwise
 * a new tree is built by triangulation and stored in that cache.
 *
 * @param mdl factor graph to compute the junction tree for
 * @return the junction tree, which is also stored in {@code jtCurrent}
 */
public JunctionTree buildJunctionTree(FactorGraph mdl) {
  jtCurrent = (JunctionTree) mdl.getInferenceCache(JunctionTreeInferencer.class);
  if (jtCurrent == null) {
    /* The factor graph is triangulated via the topology of its corresponding MRF: mdlToGraph()
     * converts mdl to an undirected graph, and that graph is triangulated. Triangulating the
     * bipartite factor graph directly would also work; the MRF route exists for historical
     * reasons, since triangulate() was first written for MRFs. mdlToGraph() is also valid for
     * FactorGraphs that are DirectedModels, and in that case it moralizes them.
     */
    UndirectedGraph mrf = Graphs.mdlToGraph(mdl);
    triangulate(mrf);
    jtCurrent = buildJtStructure();
    mdl.setInferenceCache(JunctionTreeInferencer.class, jtCurrent);
  } else {
    jtCurrent.clearCPFs();
  }
  initJtCpts(mdl, jtCurrent);
  return jtCurrent;
}
public void testThreeNodeModel() { Random r = new Random(23534709); FactorGraph mdl = new FactorGraph(); Variable root = new Variable(2); Variable childL = new Variable(2); Variable childR = new Variable(2); mdl.addFactor(root, childL, RandomGraphs.generateMixedPotentialValues(r, 1.5)); mdl.addFactor(root, childR, RandomGraphs.generateMixedPotentialValues(r, 1.5)); // assertTrue (mdl.isConnected (root, childL)); // assertTrue (mdl.isConnected (root, childR)); // assertTrue (mdl.isConnected (childL, childR)); assertTrue(mdl.isAdjacent(root, childR)); assertTrue(mdl.isAdjacent(root, childL)); assertTrue(!mdl.isAdjacent(childL, childR)); assertTrue(mdl.factorOf(root, childL) != null); assertTrue(mdl.factorOf(root, childR) != null); }
// Warning: destructively modifies ret's assignment to fullAssn (I could save and restore, but I // don't care private DiscreteFactor constructConditionalCpt( FactorGraph mdl, Variable var, Assignment fullAssn) { List ptlList = mdl.allFactorsContaining(var); LogTableFactor ptl = new LogTableFactor(var); for (AssignmentIterator it = ptl.assignmentIterator(); it.hasNext(); it.advance()) { Assignment varAssn = it.assignment(); fullAssn.setValue(var, varAssn.get(var)); ptl.setRawValue(varAssn, sumValues(ptlList, fullAssn)); } ptl.normalize(); return ptl; }
// Verify that potentialOfVertex and potentialOfEdge (which use // caches) are consistent with the potentials set even if a vertex is removed. public void testUndirectedCachesAfterRemove() { List models = TestInference.createTestModels(); for (Iterator mdlIt = models.iterator(); mdlIt.hasNext(); ) { FactorGraph mdl = (FactorGraph) mdlIt.next(); mdl = (FactorGraph) mdl.duplicate(); mdl.remove(mdl.get(0)); // Verify that indexing correct for (Iterator it = mdl.variablesIterator(); it.hasNext(); ) { Variable var = (Variable) it.next(); int idx = mdl.getIndex(var); assertTrue(idx >= 0); assertTrue(idx < mdl.numVariables()); } // Verify that caches consistent verifyCachesConsistent(mdl); } }
private void initJtCpts(FactorGraph mdl, JunctionTree jt) { for (Iterator it = jt.getVerticesIterator(); it.hasNext(); ) { VarSet c = (VarSet) it.next(); // DiscreteFactor ptl = createBlankFactor (c); // jt.setCPF(c, ptl); jt.setCPF(c, new ConstantFactor(1.0)); } for (Iterator it = mdl.factors().iterator(); it.hasNext(); ) { Factor ptl = (Factor) it.next(); VarSet parent = jt.findParentCluster(ptl.varSet()); assert parent != null : "Unable to find parent cluster for ptl " + ptl + "in jt " + jt; Factor cpf = jt.getCPF(parent); Factor newCpf = cpf.multiply(ptl); jt.setCPF(parent, newCpf); /* debug if (jt.isNaN()) { throw new RuntimeException ("Got a NaN"); } */ } }
/**
 * Computes marginals for {@code mdl}. It builds (or reuses) the junction tree, runs
 * {@code propagator} on it, and adds {@code propagator.getTotalMessagesSent()} to
 * {@code totalMessagesSent}.
 *
 * <p>Log-space mode is decided from the first factor alone: {@code inLogSpace} is true exactly
 * when {@code mdl.getFactor(0)} is a LogTableFactor.
 *
 * <p>NOTE(review): a model with no factors presumably fails at {@code getFactor(0)}, and a model
 * mixing factor types is not detected. Confirm that callers never pass such models.
 *
 * @param mdl model to run inference on
 */
public void computeMarginals(FactorGraph mdl) {
  Factor first = mdl.getFactor(0);
  inLogSpace = first instanceof LogTableFactor;
  buildJunctionTree(mdl);
  propagator.computeMarginals(jtCurrent);
  totalMessagesSent += propagator.getTotalMessagesSent();
}