// Verify that potentialOfVertex and potentialOfEdge (which use // caches) are consistent with the potentials set. public void testUndirectedCaches() { List models = TestInference.createTestModels(); for (Iterator it = models.iterator(); it.hasNext(); ) { FactorGraph mdl = (FactorGraph) it.next(); verifyCachesConsistent(mdl); } }
private void verifyCachesConsistent(FactorGraph mdl) { Factor pot, pot2, pot3; for (Iterator it = mdl.factors().iterator(); it.hasNext(); ) { pot = (Factor) it.next(); // System.out.println("Testing model "+i+" potential "+pot); Object[] vars = pot.varSet().toArray(); switch (vars.length) { case 1: pot2 = mdl.factorOf((Variable) vars[0]); assertTrue(pot == pot2); break; case 2: Variable var1 = (Variable) vars[0]; Variable var2 = (Variable) vars[1]; pot2 = mdl.factorOf(var1, var2); pot3 = mdl.factorOf(var2, var1); assertTrue(pot == pot2); assertTrue(pot2 == pot3); break; // Factors of size > 2 aren't now cached. default: break; } } }
/**
 * Chooses the next vertex to eliminate from {@code lst}: the candidate with the smallest
 * {@code newEdgesRequired} value, ties broken by the smallest {@code weightRequired} value.
 * Among equally rated candidates the earliest one in {@code lst} is kept.
 *
 * @throws java.util.NoSuchElementException if {@code lst} is empty
 */
public Variable pickVertexToRemove(UndirectedGraph mdl, ArrayList lst) {
  Iterator candidates = lst.iterator();
  Variable chosen = (Variable) candidates.next();
  int chosenEdges = newEdgesRequired(mdl, chosen);
  int chosenWeight = weightRequired(mdl, chosen);
  while (candidates.hasNext()) {
    Variable candidate = (Variable) candidates.next();
    int edges = newEdgesRequired(mdl, candidate);
    if (edges > chosenEdges) {
      continue; // strictly worse on the primary criterion; weight is never consulted
    }
    int weight = weightRequired(mdl, candidate);
    if (edges < chosenEdges || weight < chosenWeight) {
      chosen = candidate;
      chosenEdges = edges;
      chosenWeight = weight;
    }
  }
  return chosen;
}
/**
 * Returns true iff some clique in {@code l} contains every variable of {@code c}.
 *
 * <p>The containment test is non-strict: a clique equal to {@code c} also counts. This lets the
 * triangulation avoid recording a clique that is equal to, or contained in, one it already has.
 *
 * @param l list of {@link VarSet} cliques found so far
 * @param c candidate clique
 * @return true if a clique in {@code l} is a (non-strict) superset of {@code c}
 */
private boolean findSuperClique(List l, VarSet c) {
  for (Object candidate : l) {
    if (((VarSet) candidate).containsAll(c)) {
      return true;
    }
  }
  return false;
}
/** * Returns the weight of the clique that would be added to a graph if a given vertex would be * removed in the triangulation procedure. The return value is the number of edges in the * elimination clique of V that are not already present. */ private int weightRequired(UndirectedGraph mdl, Variable v) { int rating = 1; for (Iterator it1 = neighborsIterator(mdl, v); it1.hasNext(); ) { Variable neighbor = (Variable) it1.next(); rating *= neighbor.getNumOutcomes(); } // System.out.println(v+" = "+rating); return rating; }
private JunctionTree buildJtStructure() { TreeSet pq = new TreeSet(sepsetChooser); // Initialize pq with all possible edges... for (Iterator it = cliques.iterator(); it.hasNext(); ) { BitVarSet c1 = (BitVarSet) it.next(); for (Iterator it2 = cliques.iterator(); it2.hasNext(); ) { BitVarSet c2 = (BitVarSet) it2.next(); if (c1 == c2) break; pq.add(new BitVarSet[] {c1, c2}); } } // ...and add the edges to jt that come to the top of the queue // and don't cause a cycle. // xxx OK, this sucks. openjgraph doesn't allow adding // disconnected edges to a tree, so what we'll do is create a // Graph frist, then convert it to a Tree. ListenableUndirectedGraph g = new ListenableUndirectedGraph(new SimpleGraph()); // first add every clique to the graph for (Iterator it = cliques.iterator(); it.hasNext(); ) { VarSet c = (VarSet) it.next(); g.addVertex(c); } ConnectivityInspector inspector = new ConnectivityInspector(g); g.addGraphListener(inspector); // then add n - 1 edges int numCliques = cliques.size(); int edgesAdded = 0; while (edgesAdded < numCliques - 1) { VarSet[] pair = (VarSet[]) pq.first(); pq.remove(pair); if (!inspector.pathExists(pair[0], pair[1])) { g.addEdge(pair[0], pair[1]); edgesAdded++; } } JunctionTree jt = graphToJt(g); if (logger.isLoggable(Level.FINER)) { logger.finer(" jt structure was " + jt); } return jt; }
/**
 * Converts the undirected graph {@code g} into a {@link JunctionTree}. The first vertex in
 * {@code g.vertexSet()} iteration order becomes the root. A breadth-first walk from the root then
 * adds, for each vertex, every neighbor other than that vertex's parent as its child.
 *
 * @throws java.util.NoSuchElementException if {@code g} has no vertices
 */
private JunctionTree graphToJt(UndirectedGraph g) {
  JunctionTree tree = new JunctionTree(g.vertexSet().size());
  Object root = g.vertexSet().iterator().next();
  tree.add(root);
  Iterator bfs = new BreadthFirstIterator(g, root);
  while (bfs.hasNext()) {
    Object node = bfs.next();
    Iterator neighbors = GraphHelper.neighborListOf(g, node).iterator();
    while (neighbors.hasNext()) {
      Object neighbor = neighbors.next();
      boolean isParent = tree.getParent(node) == neighbor;
      if (!isParent) {
        tree.addNode(node, neighbor);
      }
    }
  }
  return tree;
}
/**
 * Turns the neighborhood of {@code v} in {@code mdl} into a clique. An edge is added between every
 * two distinct neighbors of {@code v} that are not already adjacent. Any exception thrown by
 * {@code addEdge} is rethrown wrapped in a {@link RuntimeException}.
 */
private void connectNeighbors(UndirectedGraph mdl, Variable v) {
  Iterator outer = neighborsIterator(mdl, v);
  while (outer.hasNext()) {
    Variable a = (Variable) outer.next();
    for (Iterator inner = neighborsIterator(mdl, v); inner.hasNext(); ) {
      Variable b = (Variable) inner.next();
      if (a == b || isAdjacent(mdl, a, b)) {
        continue;
      }
      try {
        mdl.addEdge(a, b);
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    }
  }
}
// Verify that potentialOfVertex and potentialOfEdge (which use // caches) are consistent with the potentials set even if a vertex is removed. public void testUndirectedCachesAfterRemove() { List models = TestInference.createTestModels(); for (Iterator mdlIt = models.iterator(); mdlIt.hasNext(); ) { FactorGraph mdl = (FactorGraph) mdlIt.next(); mdl = (FactorGraph) mdl.duplicate(); mdl.remove(mdl.get(0)); // Verify that indexing correct for (Iterator it = mdl.variablesIterator(); it.hasNext(); ) { Variable var = (Variable) it.next(); int idx = mdl.getIndex(var); assertTrue(idx >= 0); assertTrue(idx < mdl.numVariables()); } // Verify that caches consistent verifyCachesConsistent(mdl); } }
/** * Returns the number of edges that would be added to a graph if a given vertex would be removed * in the triangulation procedure. The return value is the number of edges in the elimination * clique of V that are not already present. */ private int newEdgesRequired(UndirectedGraph mdl, Variable v) { int rating = 0; for (Iterator it1 = neighborsIterator(mdl, v); it1.hasNext(); ) { Variable neighbor1 = (Variable) it1.next(); Iterator it2 = neighborsIterator(mdl, v); while (it2.hasNext()) { Variable neighbor2 = (Variable) it2.next(); if (neighbor1 != neighbor2) { if (!isAdjacent(mdl, neighbor1, neighbor2)) { rating++; } } } } // System.out.println(v+" = "+rating); return rating; }
private void initJtCpts(FactorGraph mdl, JunctionTree jt) { for (Iterator it = jt.getVerticesIterator(); it.hasNext(); ) { VarSet c = (VarSet) it.next(); // DiscreteFactor ptl = createBlankFactor (c); // jt.setCPF(c, ptl); jt.setCPF(c, new ConstantFactor(1.0)); } for (Iterator it = mdl.factors().iterator(); it.hasNext(); ) { Factor ptl = (Factor) it.next(); VarSet parent = jt.findParentCluster(ptl.varSet()); assert parent != null : "Unable to find parent cluster for ptl " + ptl + "in jt " + jt; Factor cpf = jt.getCPF(parent); Factor newCpf = cpf.multiply(ptl); jt.setCPF(parent, newCpf); /* debug if (jt.isNaN()) { throw new RuntimeException ("Got a NaN"); } */ } }
public void testMdlToGraph() { List models = TestInference.createTestModels(); for (Iterator mdlIt = models.iterator(); mdlIt.hasNext(); ) { UndirectedModel mdl = (UndirectedModel) mdlIt.next(); UndirectedGraph g = Graphs.mdlToGraph(mdl); Set vertices = g.vertexSet(); // check the number of vertices assertEquals(mdl.numVariables(), vertices.size()); // check the number of edges int numEdgePtls = 0; for (Iterator factorIt = mdl.factors().iterator(); factorIt.hasNext(); ) { Factor factor = (Factor) factorIt.next(); if (factor.varSet().size() == 2) numEdgePtls++; } assertEquals(numEdgePtls, g.edgeSet().size()); // check that the neighbors of each vertex contain at least some of what they're supposed to Iterator it = vertices.iterator(); while (it.hasNext()) { Variable var = (Variable) it.next(); assertTrue(vertices.contains(var)); Set neighborsInG = new HashSet(GraphHelper.neighborListOf(g, var)); neighborsInG.add(var); Iterator factorIt = mdl.allFactorsContaining(var).iterator(); while (factorIt.hasNext()) { Factor factor = (Factor) factorIt.next(); assertTrue(neighborsInG.containsAll(factor.varSet())); } } } }
/** Adds edges to graph until it is triangulated. */ private void triangulate(final UndirectedGraph mdl) { UndirectedGraph mdl2 = dupGraph(mdl); ArrayList<Variable> vars = new ArrayList<Variable>(mdl.vertexSet()); Alphabet<Variable> varMap = makeVertexMap(vars); cliques = new ArrayList(); // debug if (logger.isLoggable(Level.FINER)) { logger.finer("Triangulating model: " + mdl); String ret = ""; for (int i = 0; i < vars.size(); i++) { Variable next = (Variable) vars.get(i); ret += next.toString() + "\n"; // " (" + mdl.getIndex(next) + ")\n "; } logger.finer(ret); } while (!vars.isEmpty()) { Variable v = (Variable) pickVertexToRemove(mdl2, vars); logger.finer("Triangulating vertex " + v); VarSet varSet = new BitVarSet(v.getUniverse(), GraphHelper.neighborListOf(mdl2, v)); varSet.add(v); if (!findSuperClique(cliques, varSet)) { cliques.add(varSet); if (logger.isLoggable(Level.FINER)) { logger.finer( " Elim clique " + varSet + " size " + varSet.size() + " weight " + varSet.weight()); } } // must remove V from graph first, because adding the edges // will change the rating of other vertices connectNeighbors(mdl2, v); vars.remove(v); mdl2.removeVertex(v); } if (logger.isLoggable(Level.FINE)) { logger.fine("Triangulation done. Cliques are: "); int totSize = 0, totWeight = 0, maxSize = 0, maxWeight = 0; for (Iterator it = cliques.iterator(); it.hasNext(); ) { VarSet c = (VarSet) it.next(); logger.finer(c.toString()); totSize += c.size(); maxSize = Math.max(c.size(), maxSize); totWeight += c.weight(); maxWeight = Math.max(c.weight(), maxWeight); } double sz = cliques.size(); logger.fine( "Jt created " + sz + " cliques. Size: avg " + (totSize / sz) + " max " + (maxSize) + " Weight: avg " + (totWeight / sz) + " max " + (maxWeight)); } }