Example #1
0
  /**
   * Checks {@code Graphs.mdlToGraph} against every model from {@code TestInference.createTestModels}:
   * one vertex per model variable, one edge per two-variable factor, and every factor touching a
   * vertex lies within that vertex's closed neighborhood (the vertex plus its graph neighbors).
   */
  public void testMdlToGraph() {
    List models = TestInference.createTestModels();
    Iterator modelIt = models.iterator();
    while (modelIt.hasNext()) {
      UndirectedModel model = (UndirectedModel) modelIt.next();
      UndirectedGraph graph = Graphs.mdlToGraph(model);
      Set vertexSet = graph.vertexSet();

      // Vertex count must equal the number of model variables.
      assertEquals(model.numVariables(), vertexSet.size());

      // Edge count must equal the number of pairwise (two-variable) factors.
      int pairwiseFactors = 0;
      Iterator allFactorsIt = model.factors().iterator();
      while (allFactorsIt.hasNext()) {
        Factor candidate = (Factor) allFactorsIt.next();
        if (candidate.varSet().size() == 2) {
          pairwiseFactors++;
        }
      }
      assertEquals(pairwiseFactors, graph.edgeSet().size());

      // Every factor containing a vertex must fit inside that vertex's closed neighborhood.
      for (Iterator vertexIt = vertexSet.iterator(); vertexIt.hasNext(); ) {
        Variable vertex = (Variable) vertexIt.next();
        assertTrue(vertexSet.contains(vertex));
        Set closedNeighborhood = new HashSet(GraphHelper.neighborListOf(graph, vertex));
        closedNeighborhood.add(vertex);

        for (Iterator containingIt = model.allFactorsContaining(vertex).iterator();
            containingIt.hasNext(); ) {
          Factor containing = (Factor) containingIt.next();
          assertTrue(closedNeighborhood.containsAll(containing.varSet()));
        }
      }
    }
  }
  /**
   * Builds a {@code JunctionTree} by breadth-first traversal of {@code g}.
   *
   * <p>The root is the first vertex returned by {@code g.vertexSet().iterator()}, so an empty graph
   * makes that {@code next()} call fail. For each visited vertex, every neighbor other than the
   * vertex's current parent in the tree (compared by reference) is passed to
   * {@code tree.addNode(vertex, neighbor)}.
   *
   * <p>NOTE(review): this appears to assume {@code g} is a tree (connected, acyclic). In a graph
   * with cycles, a vertex could be handed to {@code addNode} more than once — confirm with callers.
   *
   * @param g the graph to convert
   * @return the junction tree built from {@code g}
   */
  private JunctionTree graphToJt(UndirectedGraph g) {
    JunctionTree tree = new JunctionTree(g.vertexSet().size());
    Object root = g.vertexSet().iterator().next();
    tree.add(root);

    Iterator bfs = new BreadthFirstIterator(g, root);
    while (bfs.hasNext()) {
      Object node = bfs.next();
      Iterator nbrs = GraphHelper.neighborListOf(g, node).iterator();
      while (nbrs.hasNext()) {
        Object nbr = nbrs.next();
        // Skip the edge leading back to the node's parent; every other neighbor becomes a child.
        boolean isParentEdge = tree.getParent(node) == nbr;
        if (!isParentEdge) {
          tree.addNode(node, nbr);
        }
      }
    }
    return tree;
  }
 /**
  * Makes the neighbors of {@code v} pairwise adjacent in {@code mdl}.
  *
  * <p>For every ordered pair of neighbors yielded by {@code neighborsIterator(mdl, v)} that are
  * distinct objects and not already adjacent, an edge is added. Any exception thrown by
  * {@code addEdge} is rethrown wrapped in a {@link RuntimeException}.
  *
  * @param mdl the graph to modify in place
  * @param v the vertex whose neighborhood is turned into a clique
  */
 private void connectNeighbors(UndirectedGraph mdl, Variable v) {
   Iterator outer = neighborsIterator(mdl, v);
   while (outer.hasNext()) {
     Variable a = (Variable) outer.next();
     for (Iterator inner = neighborsIterator(mdl, v); inner.hasNext(); ) {
       Variable b = (Variable) inner.next();
       // Nothing to do for the same vertex or for a pair that is already connected.
       if (a == b || isAdjacent(mdl, a, b)) {
         continue;
       }
       try {
         mdl.addEdge(a, b);
       } catch (Exception e) {
         throw new RuntimeException(e);
       }
     }
   }
 }
 /** Returns whether {@code g.getEdge(v1, v2)} finds an edge between the two vertices. */
 private boolean isAdjacent(UndirectedGraph g, Variable v1, Variable v2) {
   Object edge = g.getEdge(v1, v2);
   return edge != null;
 }
  /**
   * Triangulates {@code mdl} by simulated vertex elimination and records the elimination cliques.
   *
   * <p>The elimination runs on a working copy obtained from {@code dupGraph(mdl)}. Fill-in edges
   * are added to that copy, and {@code mdl} itself is only read here (assuming {@code dupGraph}
   * returns an independent copy — confirm). Each step removes the vertex chosen by
   * {@code pickVertexToRemove}. That vertex plus its current neighbors forms an elimination clique,
   * which is appended to the {@code cliques} field unless {@code findSuperClique} reports it as
   * already covered. The {@code cliques} field is reset to a new list on entry.
   *
   * @param mdl the graph to triangulate; its vertices are treated as {@code Variable}s
   */
  private void triangulate(final UndirectedGraph mdl) {
    UndirectedGraph mdl2 = dupGraph(mdl);
    ArrayList<Variable> vars = new ArrayList<Variable>(mdl.vertexSet());
    // NOTE(review): varMap is not read in this method; the call is kept in case makeVertexMap has
    // side effects on this object — confirm before removing.
    Alphabet<Variable> varMap = makeVertexMap(vars);
    cliques = new ArrayList();

    // debug
    if (logger.isLoggable(Level.FINER)) {
      logger.finer("Triangulating model: " + mdl);
      // StringBuilder avoids re-copying the accumulated text on every iteration.
      StringBuilder ret = new StringBuilder();
      for (int i = 0; i < vars.size(); i++) {
        Variable next = vars.get(i);
        ret.append(next.toString()).append("\n");
      }
      logger.finer(ret.toString());
    }

    while (!vars.isEmpty()) {
      Variable v = (Variable) pickVertexToRemove(mdl2, vars);
      // Guarded like the other diagnostics so the message is not built when FINER is disabled.
      if (logger.isLoggable(Level.FINER)) {
        logger.finer("Triangulating vertex " + v);
      }

      // Elimination clique: v together with its neighbors in the current working graph.
      VarSet varSet = new BitVarSet(v.getUniverse(), GraphHelper.neighborListOf(mdl2, v));
      varSet.add(v);
      if (!findSuperClique(cliques, varSet)) {
        cliques.add(varSet);
        if (logger.isLoggable(Level.FINER)) {
          logger.finer(
              "  Elim clique " + varSet + " size " + varSet.size() + " weight " + varSet.weight());
        }
      }

      // Connect v's neighbors while v is still in the working graph (its neighbor set is needed).
      // Then drop v from both the candidate list and the graph, so the next pickVertexToRemove
      // call sees the graph with the fill-in edges added and v gone.
      connectNeighbors(mdl2, v);
      vars.remove(v);
      mdl2.removeVertex(v);
    }

    if (logger.isLoggable(Level.FINE)) {
      logger.fine("Triangulation done. Cliques are: ");
      // Totals are long so that summing many large clique weights cannot overflow an int.
      long totSize = 0, totWeight = 0;
      int maxSize = 0, maxWeight = 0;
      for (Iterator it = cliques.iterator(); it.hasNext(); ) {
        VarSet c = (VarSet) it.next();
        logger.finer(c.toString());
        totSize += c.size();
        maxSize = Math.max(c.size(), maxSize);
        totWeight += c.weight();
        maxWeight = Math.max(c.weight(), maxWeight);
      }
      double sz = cliques.size();
      logger.fine(
          "Jt created "
              + sz
              + " cliques. Size: avg "
              + (totSize / sz)
              + " max "
              + (maxSize)
              + " Weight: avg "
              + (totWeight / sz)
              + " max "
              + (maxWeight));
    }
  }
 /**
  * Returns whether {@code v1} and {@code v2} are joined by a path in this model's graph.
  *
  * <p>Returns false when either variable is not a vertex of the graph built by
  * {@code Graphs.mdlToGraph(this)}. XXX: the whole graph is rebuilt on every call, which is very
  * inefficient.
  *
  * @param v1 first variable
  * @param v2 second variable
  * @return true if both variables are present and connected
  */
 public boolean isConnected(Variable v1, Variable v2) {
   UndirectedGraph graph = Graphs.mdlToGraph(this);
   ConnectivityInspector inspector = new ConnectivityInspector(graph);
   if (!graph.containsVertex(v1) || !graph.containsVertex(v2)) {
     return false;
   }
   return inspector.pathExists(v1, v2);
 }