/**
 * Tests that a model may hold several factors over the same pair of variables. In that case the
 * single-factor lookup {@code factorOf(v1, v2)} is expected to throw a RuntimeException, while
 * {@code allFactorsContaining} must return every such factor, and multiplying them together must
 * give the entry-wise product of their tables.
 */
public void testMultipleEdgePotentials() {
  Variable v1 = new Variable(2);
  Variable v2 = new Variable(2);
  Variable[] vars = {v1, v2};
  FactorGraph mdl = new FactorGraph(vars);

  // Two distinct factors over the identical pair {v1, v2}.
  Factor ptl1 = new TableFactor(vars, new double[] {0.5, 0.5, 0.5, 0.5});
  mdl.addFactor(ptl1);
  Factor ptl2 = new TableFactor(vars, new double[] {0.25, 0.25, 0.5, 0.5});
  mdl.addFactor(ptl2);

  // With two factors on the same pair, the pairwise lookup is required to fail.
  try {
    mdl.factorOf(v1, v2);
    fail();
  } catch (RuntimeException ignored) {
    // Expected: the test passes only if factorOf throws here.
  }

  Collection factors = mdl.allFactorsContaining(new HashVarSet(vars));
  assertEquals(2, factors.size());
  assertTrue(factors.contains(ptl1));
  assertTrue(factors.contains(ptl2));

  // Entry-wise 0.5 * {0.25, 0.25, 0.5, 0.5}.
  double[] vals = {0.125, 0.125, 0.25, 0.25};
  Factor total = TableFactor.multiplyAll(factors);
  Factor expected = new TableFactor(vars, vals);
  String message =
      "Arrays not equal\n Expected "
          + ArrayUtils.toString(vals)
          + "\n Actual "
          + ArrayUtils.toString(((TableFactor) total).toValueArray());
  assertTrue(message, expected.almostEquals(total, 1e-10));
}
public void testMdlToGraph() { List models = TestInference.createTestModels(); for (Iterator mdlIt = models.iterator(); mdlIt.hasNext(); ) { UndirectedModel mdl = (UndirectedModel) mdlIt.next(); UndirectedGraph g = Graphs.mdlToGraph(mdl); Set vertices = g.vertexSet(); // check the number of vertices assertEquals(mdl.numVariables(), vertices.size()); // check the number of edges int numEdgePtls = 0; for (Iterator factorIt = mdl.factors().iterator(); factorIt.hasNext(); ) { Factor factor = (Factor) factorIt.next(); if (factor.varSet().size() == 2) numEdgePtls++; } assertEquals(numEdgePtls, g.edgeSet().size()); // check that the neighbors of each vertex contain at least some of what they're supposed to Iterator it = vertices.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); assertTrue(vertices.contains(var)); Set neighborsInG = new HashSet(GraphHelper.neighborListOf(g, var)); neighborsInG.add(var); Iterator factorIt = mdl.allFactorsContaining(var).iterator(); while (factorIt.hasNext()) { Factor factor = (Factor) factorIt.next(); assertTrue(neighborsInG.containsAll(factor.varSet())); } } } }
private void verifyCachesConsistent(FactorGraph mdl) { Factor pot, pot2, pot3; for (Iterator it = mdl.factors().iterator(); it.hasNext(); ) { pot = (Factor) it.next(); // System.out.println("Testing model "+i+" potential "+pot); Object[] vars = pot.varSet().toArray(); switch (vars.length) { case 1: pot2 = mdl.factorOf((Variable) vars[0]); assertTrue(pot == pot2); break; case 2: Variable var1 = (Variable) vars[0]; Variable var2 = (Variable) vars[1]; pot2 = mdl.factorOf(var1, var2); pot3 = mdl.factorOf(var2, var1); assertTrue(pot == pot2); assertTrue(pot2 == pot3); break; // Factors of size > 2 aren't now cached. default: break; } } }
/**
 * Tests lookup of a factor by a set of variables. Querying with the factor's own var set, or
 * with a HashSet copy of it, returns that exact factor instance. Querying with a strict subset
 * of its variables returns null.
 */
public void testFactorOfSet() {
  Variable[] vars = new Variable[3];
  for (int k = 0; k < vars.length; k++) {
    vars[k] = new Variable(2);
  }
  // Three binary variables -> 2^3 = 8 table entries.
  Factor ptl = new TableFactor(vars, new double[] {0, 1, 2, 3, 4, 5, 6, 7});
  FactorGraph graph = new FactorGraph(vars);
  graph.addFactor(ptl);

  assertTrue(ptl == graph.factorOf(ptl.varSet()));

  // Static type kept as HashSet so the same factorOf overload is selected.
  HashSet query = new HashSet(ptl.varSet());
  assertTrue(ptl == graph.factorOf(query));

  // A strict subset of the variables does not identify any factor.
  query.remove(vars[0]);
  assertTrue(null == graph.factorOf(query));
}
private void initJtCpts(FactorGraph mdl, JunctionTree jt) { for (Iterator it = jt.getVerticesIterator(); it.hasNext(); ) { VarSet c = (VarSet) it.next(); // DiscreteFactor ptl = createBlankFactor (c); // jt.setCPF(c, ptl); jt.setCPF(c, new ConstantFactor(1.0)); } for (Iterator it = mdl.factors().iterator(); it.hasNext(); ) { Factor ptl = (Factor) it.next(); VarSet parent = jt.findParentCluster(ptl.varSet()); assert parent != null : "Unable to find parent cluster for ptl " + ptl + "in jt " + jt; Factor cpf = jt.getCPF(parent); Factor newCpf = cpf.multiply(ptl); jt.setCPF(parent, newCpf); /* debug if (jt.isNaN()) { throw new RuntimeException ("Got a NaN"); } */ } }
/**
 * Adds {@code factor} through the superclass. If the factor spans exactly two variables, its
 * var set is also recorded in {@code edges}.
 *
 * @param factor the factor to add
 */
public void addFactor(Factor factor) {
  super.addFactor(factor);
  // Only pairwise factors are recorded as edges.
  if (factor.varSet().size() != 2) {
    return;
  }
  edges.add(factor.varSet());
}