예제 #1
0
 /**
  * Checks every model from {@code TestInference.createTestModels()} with
  * {@code verifyCachesConsistent}, i.e. that the cached {@code factorOf} lookups return the same
  * factor objects that the model stores.
  */
 public void testUndirectedCaches() {
   List models = TestInference.createTestModels();
   for (Object model : models) {
     verifyCachesConsistent((FactorGraph) model);
   }
 }
예제 #2
0
  /**
   * Asserts that {@code mdl.factorOf(...)} hands back the identical factor object stored in
   * {@code mdl.factors()} for every factor over one or two variables.
   *
   * <p>For pairwise factors both argument orders are looked up, and the two results must also be
   * the same object. Factors over any other number of variables are not cached and are skipped.
   *
   * @param mdl the factor graph whose lookup caches are checked
   */
  private void verifyCachesConsistent(FactorGraph mdl) {
    Iterator factorIt = mdl.factors().iterator();
    while (factorIt.hasNext()) {
      Factor stored = (Factor) factorIt.next();
      Object[] vars = stored.varSet().toArray();

      if (vars.length == 1) {
        Factor cached = mdl.factorOf((Variable) vars[0]);
        assertTrue(stored == cached);
      } else if (vars.length == 2) {
        Variable first = (Variable) vars[0];
        Variable second = (Variable) vars[1];
        Factor forward = mdl.factorOf(first, second);
        Factor backward = mdl.factorOf(second, first);
        assertTrue(stored == forward);
        assertTrue(forward == backward);
      }
      // Factors over more than two variables are not cached, so there is nothing to compare.
    }
  }
  /**
   * Chooses the next vertex to eliminate from {@code lst}.
   *
   * <p>The primary criterion is the smallest {@code newEdgesRequired} score (least fill-in). Ties
   * are broken by the smallest {@code weightRequired} value. When two candidates tie on both, the
   * one that appears earlier in {@code lst} is kept.
   *
   * @param mdl the current, partially eliminated graph
   * @param lst candidate vertices; must be non-empty, otherwise the first {@code next()} throws
   *     {@link java.util.NoSuchElementException}
   * @return the selected vertex
   */
  public Variable pickVertexToRemove(UndirectedGraph mdl, ArrayList lst) {
    Iterator candidates = lst.iterator();
    Variable chosen = (Variable) candidates.next();
    int chosenFill = newEdgesRequired(mdl, chosen);
    int chosenWeight = weightRequired(mdl, chosen);

    while (candidates.hasNext()) {
      Variable candidate = (Variable) candidates.next();
      int fill = newEdgesRequired(mdl, candidate);
      if (fill > chosenFill) {
        // Strictly worse on the primary criterion; its weight is never needed.
        continue;
      }
      int weight = weightRequired(mdl, candidate);
      if (fill < chosenFill || weight < chosenWeight) {
        chosen = candidate;
        chosenFill = fill;
        chosenWeight = weight;
      }
    }
    return chosen;
  }
 /**
  * Returns true iff some clique in {@code l} contains every variable of {@code c}.
  *
  * <p>The test uses {@code containsAll}, so containment is not strict: a clique equal to
  * {@code c} also counts. This lets the caller skip duplicate elimination cliques as well as
  * subsumed ones.
  */
 private boolean findSuperClique(List l, VarSet c) {
   for (Object candidate : l) {
     if (((VarSet) candidate).containsAll(c)) {
       return true;
     }
   }
   return false;
 }
  /**
   * Returns the weight of the neighborhood of {@code v}: the product of {@code getNumOutcomes()}
   * over all current neighbors of {@code v} in {@code mdl}. {@code v}'s own outcome count is not
   * included. A vertex with no neighbors has weight 1.
   *
   * <p>This is the tie-breaking criterion used by {@link #pickVertexToRemove}.
   *
   * <p>The product is saturated at {@link Integer#MAX_VALUE} rather than being allowed to
   * overflow. A wrapped {@code int} product can become small or negative, which would make a very
   * expensive vertex look like the cheapest choice.
   *
   * @param mdl the current graph
   * @param v the vertex whose neighborhood weight is computed
   * @return the (saturated) product of the neighbors' outcome counts
   */
  private int weightRequired(UndirectedGraph mdl, Variable v) {
    int rating = 1;

    for (Iterator it1 = neighborsIterator(mdl, v); it1.hasNext(); ) {
      Variable neighbor = (Variable) it1.next();
      // Multiply in long arithmetic so the int result cannot silently wrap around.
      long product = (long) rating * neighbor.getNumOutcomes();
      rating = product > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) product;
    }

    return rating;
  }
  /**
   * Builds the junction-tree structure connecting the cliques in {@code cliques}.
   *
   * <p>Every unordered pair of distinct cliques is placed in a {@code TreeSet} ordered by
   * {@code sepsetChooser}. Pairs are then taken from the front of that ordering. A pair is kept as
   * an edge only if its two cliques are not already connected. This continues until
   * {@code cliques.size() - 1} edges have been added, which is a Kruskal-style spanning-tree
   * construction. The resulting graph is converted with {@link #graphToJt}.
   *
   * <p>NOTE(review): a {@code TreeSet} discards elements that its comparator reports as equal. If
   * {@code sepsetChooser} can return 0 for two different pairs, candidate edges are dropped, and
   * {@code pq.first()} may eventually throw {@link java.util.NoSuchElementException}. Confirm that
   * the comparator always breaks ties.
   */
  private JunctionTree buildJtStructure() {
    TreeSet pq = new TreeSet(sepsetChooser);

    // Candidate edges: each clique paired with every clique that precedes it in the list.
    for (Object outer : cliques) {
      BitVarSet c1 = (BitVarSet) outer;
      for (Object inner : cliques) {
        BitVarSet c2 = (BitVarSet) inner;
        if (c1 == c2) {
          break;
        }
        pq.add(new BitVarSet[] {c1, c2});
      }
    }

    // openjgraph does not allow adding edges between disconnected parts of a tree. So the
    // spanning tree is first built as a plain graph and converted to a JunctionTree afterwards.
    ListenableUndirectedGraph g = new ListenableUndirectedGraph(new SimpleGraph());
    for (Object clique : cliques) {
      g.addVertex((VarSet) clique);
    }

    // Registered as a listener so its connectivity answers track the edges added below.
    ConnectivityInspector inspector = new ConnectivityInspector(g);
    g.addGraphListener(inspector);

    int edgesNeeded = cliques.size() - 1;
    for (int edgesAdded = 0; edgesAdded < edgesNeeded; ) {
      VarSet[] pair = (VarSet[]) pq.first();
      pq.remove(pair);
      if (inspector.pathExists(pair[0], pair[1])) {
        continue; // this edge would close a cycle
      }
      g.addEdge(pair[0], pair[1]);
      edgesAdded++;
    }

    JunctionTree jt = graphToJt(g);
    if (logger.isLoggable(Level.FINER)) {
      logger.finer("  jt structure was " + jt);
    }
    return jt;
  }
  /**
   * Converts a tree-shaped undirected graph into a {@link JunctionTree}.
   *
   * <p>The first vertex returned by {@code g.vertexSet().iterator()} becomes the root. The graph
   * is then walked breadth-first from the root. For each visited vertex, every neighbor other than
   * that vertex's parent in {@code jt} is passed to {@code jt.addNode(vertex, neighbor)},
   * presumably attaching it as a child.
   *
   * <p>Assumes {@code g} is non-empty (otherwise {@code next()} throws) and acyclic. TODO confirm
   * for all callers.
   */
  private JunctionTree graphToJt(UndirectedGraph g) {
    JunctionTree jt = new JunctionTree(g.vertexSet().size());
    Object root = g.vertexSet().iterator().next();
    jt.add(root);

    Iterator bfs = new BreadthFirstIterator(g, root);
    while (bfs.hasNext()) {
      Object node = bfs.next();
      for (Object neighbor : GraphHelper.neighborListOf(g, node)) {
        boolean isParent = jt.getParent(node) == neighbor;
        if (!isParent) {
          jt.addNode(node, neighbor);
        }
      }
    }
    return jt;
  }
 /**
  * Adds an edge between every pair of distinct neighbors of {@code v} that are not already
  * adjacent, turning {@code v}'s neighborhood into a clique (the fill-in step of elimination).
  *
  * <p>Any exception thrown by {@code mdl.addEdge} is rethrown wrapped in a
  * {@link RuntimeException}.
  *
  * <p>NOTE(review): edges are added while neighbor iterators over {@code v} are still open. This
  * is only safe if {@code neighborsIterator} is not invalidated by edges that do not touch
  * {@code v}. Confirm.
  */
 private void connectNeighbors(UndirectedGraph mdl, Variable v) {
   Iterator outer = neighborsIterator(mdl, v);
   while (outer.hasNext()) {
     Variable a = (Variable) outer.next();
     for (Iterator inner = neighborsIterator(mdl, v); inner.hasNext(); ) {
       Variable b = (Variable) inner.next();
       if (a == b || isAdjacent(mdl, a, b)) {
         continue;
       }
       try {
         mdl.addEdge(a, b);
       } catch (Exception e) {
         throw new RuntimeException(e);
       }
     }
   }
 }
예제 #9
0
  /**
   * For every test model, works on a duplicate of it and removes the element returned by
   * {@code get(0)} from the copy. It then checks two things:
   *
   * <ul>
   *   <li>every remaining variable's index lies in {@code [0, numVariables())};
   *   <li>the cached {@code factorOf} lookups are still consistent, via
   *       {@code verifyCachesConsistent}.
   * </ul>
   */
  public void testUndirectedCachesAfterRemove() {
    List models = TestInference.createTestModels();
    for (Object original : models) {
      FactorGraph copy = (FactorGraph) ((FactorGraph) original).duplicate();
      copy.remove(copy.get(0));

      // Indices of the surviving variables must remain within range after the removal.
      Iterator varIt = copy.variablesIterator();
      while (varIt.hasNext()) {
        int idx = copy.getIndex((Variable) varIt.next());
        assertTrue(idx >= 0);
        assertTrue(idx < copy.numVariables());
      }

      verifyCachesConsistent(copy);
    }
  }
예제 #10
0
  /**
   * Scores how much fill-in eliminating {@code v} would require in the triangulation procedure.
   *
   * <p>The score is the number of ordered pairs {@code (a, b)} of distinct neighbors of {@code v}
   * for which {@code isAdjacent(mdl, a, b)} is false. Assuming {@code isAdjacent} is symmetric,
   * each missing edge is counted twice, once as {@code (a, b)} and once as {@code (b, a)}. The
   * score is therefore twice the number of edges elimination would add. {@link
   * #pickVertexToRemove} only compares scores, so the factor of two does not affect its choice.
   *
   * @param mdl the current graph
   * @param v the candidate vertex
   * @return the ordered-pair count of missing edges among {@code v}'s neighbors
   */
  private int newEdgesRequired(UndirectedGraph mdl, Variable v) {
    int missing = 0;
    Iterator outer = neighborsIterator(mdl, v);
    while (outer.hasNext()) {
      Variable a = (Variable) outer.next();
      Iterator inner = neighborsIterator(mdl, v);
      while (inner.hasNext()) {
        Variable b = (Variable) inner.next();
        if (a != b && !isAdjacent(mdl, a, b)) {
          missing++;
        }
      }
    }
    return missing;
  }
예제 #11
0
  /**
   * Initializes the cluster potentials of {@code jt} from the factors of {@code mdl}.
   *
   * <p>Every cluster first receives the constant potential {@code 1.0}. Each factor of
   * {@code mdl} is then multiplied into the potential of the cluster returned by
   * {@code jt.findParentCluster(factor.varSet())}.
   *
   * @param mdl the model whose factors are absorbed into the tree
   * @param jt the junction tree whose cluster potentials are overwritten
   */
  private void initJtCpts(FactorGraph mdl, JunctionTree jt) {
    for (Iterator it = jt.getVerticesIterator(); it.hasNext(); ) {
      VarSet c = (VarSet) it.next();
      jt.setCPF(c, new ConstantFactor(1.0));
    }

    for (Iterator it = mdl.factors().iterator(); it.hasNext(); ) {
      Factor ptl = (Factor) it.next();
      VarSet parent = jt.findParentCluster(ptl.varSet());
      // Checked only when assertions are enabled (-ea). Without -ea a missing cluster is not
      // reported here.
      assert parent != null : "Unable to find parent cluster for ptl " + ptl + " in jt " + jt;

      Factor cpf = jt.getCPF(parent);
      jt.setCPF(parent, cpf.multiply(ptl));
    }
  }
예제 #12
0
  /**
   * For every test model, checks the graph produced by {@code Graphs.mdlToGraph}:
   *
   * <ul>
   *   <li>it has one vertex per model variable;
   *   <li>its edge count equals the number of factors over exactly two variables;
   *   <li>for every vertex {@code var}, each factor containing {@code var} only mentions
   *       {@code var} and its graph neighbors.
   * </ul>
   */
  public void testMdlToGraph() {
    List models = TestInference.createTestModels();
    for (Object m : models) {
      UndirectedModel mdl = (UndirectedModel) m;
      UndirectedGraph g = Graphs.mdlToGraph(mdl);
      Set vertices = g.vertexSet();

      assertEquals(mdl.numVariables(), vertices.size());

      int pairwise = 0;
      Iterator factors = mdl.factors().iterator();
      while (factors.hasNext()) {
        Factor factor = (Factor) factors.next();
        if (factor.varSet().size() == 2) {
          pairwise++;
        }
      }
      assertEquals(pairwise, g.edgeSet().size());

      for (Object vertex : vertices) {
        Variable var = (Variable) vertex;
        // Trivially true, since var was drawn from vertices.
        assertTrue(vertices.contains(var));

        Set closedNeighborhood = new HashSet(GraphHelper.neighborListOf(g, var));
        closedNeighborhood.add(var);

        Iterator containing = mdl.allFactorsContaining(var).iterator();
        while (containing.hasNext()) {
          Factor factor = (Factor) containing.next();
          assertTrue(closedNeighborhood.containsAll(factor.varSet()));
        }
      }
    }
  }
예제 #13
0
  /**
   * Triangulates {@code mdl} by simulated vertex elimination and stores the elimination cliques in
   * the {@code cliques} field.
   *
   * <p>Elimination runs on the graph returned by {@code dupGraph(mdl)}. That is presumably a copy,
   * so fill-in edges would not be added to {@code mdl} itself. TODO confirm.
   *
   * <p>Vertices are eliminated one at a time in the order chosen by {@link #pickVertexToRemove}.
   * Each eliminated vertex plus its current neighbors forms an elimination clique. The clique is
   * recorded unless an already recorded clique contains it. The vertex's neighbors are then
   * connected pairwise, and the vertex is removed from the working graph.
   *
   * @param mdl the graph to triangulate
   */
  private void triangulate(final UndirectedGraph mdl) {
    UndirectedGraph mdl2 = dupGraph(mdl);
    ArrayList<Variable> vars = new ArrayList<Variable>(mdl.vertexSet());
    // NOTE(review): varMap is never read below. The call is kept in case makeVertexMap has side
    // effects; confirm and remove if it does not.
    Alphabet<Variable> varMap = makeVertexMap(vars);
    cliques = new ArrayList();

    // debug
    if (logger.isLoggable(Level.FINER)) {
      logger.finer("Triangulating model: " + mdl);
      StringBuilder ret = new StringBuilder();
      for (Variable next : vars) {
        ret.append(next.toString()).append("\n");
      }
      logger.finer(ret.toString());
    }

    while (!vars.isEmpty()) {
      Variable v = pickVertexToRemove(mdl2, vars);
      if (logger.isLoggable(Level.FINER)) {
        logger.finer("Triangulating vertex " + v);
      }

      // The elimination clique is v together with its neighbors in the working graph.
      VarSet varSet = new BitVarSet(v.getUniverse(), GraphHelper.neighborListOf(mdl2, v));
      varSet.add(v);
      if (!findSuperClique(cliques, varSet)) {
        cliques.add(varSet);
        if (logger.isLoggable(Level.FINER)) {
          logger.finer(
              "  Elim clique " + varSet + " size " + varSet.size() + " weight " + varSet.weight());
        }
      }

      // Eliminate v. Its neighbors are found through v, so connect them before removing v.
      // Removing v before the next pick ensures later ratings see the new fill-in edges.
      connectNeighbors(mdl2, v);
      vars.remove(v);
      mdl2.removeVertex(v);
    }

    if (logger.isLoggable(Level.FINE)) {
      logger.fine("Triangulation done. Cliques are: ");
      int totSize = 0, totWeight = 0, maxSize = 0, maxWeight = 0;
      for (Iterator it = cliques.iterator(); it.hasNext(); ) {
        VarSet c = (VarSet) it.next();
        logger.finer(c.toString());
        totSize += c.size();
        maxSize = Math.max(c.size(), maxSize);
        totWeight += c.weight();
        maxWeight = Math.max(c.weight(), maxWeight);
      }
      double sz = cliques.size();
      logger.fine(
          "Jt created "
              + sz
              + " cliques. Size: avg "
              + (totSize / sz)
              + " max "
              + (maxSize)
              + " Weight: avg "
              + (totWeight / sz)
              + " max "
              + (maxWeight));
    }
  }