コード例 #1
0
ファイル: GesConcurrent.java プロジェクト: renjiey/tetrad
  /**
   * Breadth-first search for a semi-directed path from {@code from} to {@code to} in {@code G}
   * that does not pass through any node in {@code cond}.
   *
   * @param from the node the search starts at.
   * @param to the node the search is trying to reach (compared by identity).
   * @param cond nodes that block the path; a step landing on one of them is discarded.
   * @param G the graph being searched.
   * @return true if {@code to} is reached, false once the frontier is exhausted.
   */
  private boolean existsUnblockedSemiDirectedPath(Node from, Node to, List<Node> cond, Graph G) {
    Queue<Node> frontier = new LinkedList<Node>();
    Set<Node> visited = new HashSet<Node>();
    frontier.offer(from);
    visited.add(from);

    while (!frontier.isEmpty()) {
      Node current = frontier.remove();
      if (current == to) {
        return true;
      }

      for (Node adjacent : G.getAdjacentNodes(current)) {
        // traverseSemiDirected yields null when the edge cannot be followed from `current`.
        Node next = Edges.traverseSemiDirected(current, G.getEdge(current, adjacent));
        if (next == null || cond.contains(next)) {
          continue;
        }
        if (next == to) {
          return true;
        }

        // Set.add reports whether the node was new, so each node is enqueued at most once.
        if (visited.add(next)) {
          frontier.offer(next);
        }
      }
    }

    return false;
  }
コード例 #2
0
  /**
   * Transforms a maximally directed pattern (PDAG) represented in graph <code>g</code> into an
   * arbitrary DAG by modifying <code>g</code> itself. Based on the algorithm described in
   * Chickering (2002) "Optimal structure identification with greedy search" Journal of Machine
   * Learning Research. R. Silva, June 2004
   *
   * <p>Outline of what the code does: all undirected (tail-tail) edges are removed from
   * <code>g</code>; a separate working graph <code>p</code> is then consumed node by node, and for
   * every node removed from <code>p</code> its undirected neighbors in <code>p</code> are connected
   * to it in <code>g</code> by directed edges pointing into that node.
   *
   * @param g the graph to transform; it is modified in place.
   */
  public static void pdagToDag(Graph g) {
    // Working graph built from g (presumably a copy -- confirm the EdgeListGraph constructor).
    Graph p = new EdgeListGraph(g);
    List<Edge> undirectedEdges = new ArrayList<Edge>();

    // Collect every tail-tail (undirected) edge of g once, then strip them from g.
    for (Edge edge : g.getEdges()) {
      if (edge.getEndpoint1() == Endpoint.TAIL
          && edge.getEndpoint2() == Endpoint.TAIL
          && !undirectedEdges.contains(edge)) {
        undirectedEdges.add(edge);
      }
    }
    g.removeEdges(undirectedEdges);
    List<Node> pNodes = p.getNodes();

    // Each pass removes exactly one entry from pNodes, so the loop ends when it is empty.
    do {
      Node x = null;

      for (Node pNode : pNodes) {
        // x is declared outside this loop on purpose: it is read again after the loop ends.
        x = pNode;

        // Only nodes without children in p are candidates.
        if (p.getChildren(x).size() > 0) {
          continue;
        }

        // Nodes joined to x by an undirected edge in p.
        Set<Node> neighbors = new HashSet<Node>();

        for (Edge edge : p.getEdges()) {
          if (edge.getNode1() == x || edge.getNode2() == x) {
            if (edge.getEndpoint1() == Endpoint.TAIL && edge.getEndpoint2() == Endpoint.TAIL) {
              if (edge.getNode1() == x) {
                neighbors.add(edge.getNode2());
              } else {
                neighbors.add(edge.getNode1());
              }
            }
          }
        }
        // If x has undirected neighbors, they together with x's parents must form a clique in p;
        // otherwise x is skipped for this pass.
        if (neighbors.size() > 0) {
          Collection<Node> parents = p.getParents(x);
          Set<Node> all = new HashSet<Node>(neighbors);
          all.addAll(parents);
          if (!GraphUtils.isClique(all, p)) {
            continue;
          }
        }

        // Orient each former undirected edge as neighbor --> x in g. Nodes are looked up in g by
        // name rather than reused from p.
        for (Node neighbor : neighbors) {
          Node node1 = g.getNode(neighbor.getName());
          Node node2 = g.getNode(x.getName());

          g.addDirectedEdge(node1, node2);
        }
        p.removeNode(x);
        break;
      }
      // NOTE(review): when no candidate was accepted above, x is simply the last node visited and
      // is still dropped from pNodes here (without being removed from p). That guarantees
      // termination -- confirm it is the intended fallback.
      pNodes.remove(x);
    } while (pNodes.size() > 0);
  }
コード例 #3
0
  /**
   * Builds the set {n1, n2, n3}.
   *
   * @return a new three-element set.
   * @throws IllegalArgumentException if any two of the arguments are equal.
   */
  private Set<Integer> triple(int n1, int n2, int n3) {
    boolean hasDuplicate = n1 == n2 || n1 == n3 || n2 == n3;

    if (hasDuplicate) {
      throw new IllegalArgumentException(
          "Triple elements must be unique: <" + n1 + ", " + n2 + ", " + n3 + ">");
    }

    Set<Integer> members = new HashSet<Integer>();
    members.add(n1);
    members.add(n2);
    members.add(n3);
    return members;
  }
コード例 #4
0
  /**
   * Builds the set {x, y, z, w}.
   *
   * @return a new four-element set.
   * @throws IllegalArgumentException if the four arguments are not pairwise distinct.
   */
  private Set<Integer> quartet(int x, int y, int z, int w) {
    Set<Integer> members = new HashSet<Integer>();

    // Non-short-circuit '&' so that every element is inserted regardless of earlier duplicates.
    boolean allDistinct = members.add(x) & members.add(y) & members.add(z) & members.add(w);

    if (!allDistinct) {
      throw new IllegalArgumentException(
          "Quartet elements must be unique: <" + x + ", " + y + ", " + z + ", " + w + ">");
    }

    return members;
  }
コード例 #5
0
  /**
   * Keeps only those orientations that take part in an unshielded collider; every other directed
   * edge of {@code graph} is replaced, in place, by an undirected edge between the same nodes.
   *
   * @param graph the graph to modify.
   */
  public static void basicPattern(Graph graph) {
    Set<Edge> toUndirect = new HashSet<Edge>();

    for (Edge edge : graph.getEdges()) {
      Endpoint end1 = edge.getEndpoint1();
      Endpoint end2 = edge.getEndpoint2();
      Node head = null;
      Node tail = null;

      // Work out which way a tail --> head edge points; other edge kinds leave head null.
      if (end1 == Endpoint.ARROW && end2 == Endpoint.TAIL) {
        head = edge.getNode1();
        tail = edge.getNode2();
      } else if (end2 == Endpoint.ARROW && end1 == Endpoint.TAIL) {
        head = edge.getNode2();
        tail = edge.getNode1();
      }

      if (head == null) {
        continue;
      }

      // The orientation is kept when some other parent of head is not adjacent to tail.
      boolean inUnshieldedCollider = false;
      for (Node otherParent : graph.getParents(head)) {
        if (otherParent != tail && !graph.isAdjacentTo(tail, otherParent)) {
          inUnshieldedCollider = true;
          break;
        }
      }

      if (!inUnshieldedCollider) {
        toUndirect.add(edge);
      }
    }

    // Edges are rewritten only after the scan so the graph is not modified while being iterated.
    for (Edge directed : toUndirect) {
      Node first = directed.getNode1();
      Node second = directed.getNode2();

      graph.removeEdge(directed);
      graph.addUndirectedEdge(first, second);
    }
  }
コード例 #6
0
ファイル: Knowledge2.java プロジェクト: ps7z/tetrad
 /**
  * Registers {@code varName} with this knowledge object. Names that are already registered, or
  * that {@code checkVarName} rejects, are ignored.
  *
  * @param varName the variable name to add.
  */
 public void addVariable(String varName) {
   if (namesToVars.containsKey(varName) || !checkVarName(varName)) {
     return;
   }

   MyNode added = new MyNode(varName);
   myNodes.add(added);
   namesToVars.put(varName, added);
 }
コード例 #7
0
ファイル: PcMax.java プロジェクト: renjiey/tetrad
 /**
  * Returns the edges of the current graph.
  *
  * @return a new set holding every edge reported by {@code graph.getEdges()}.
  */
 public Set<Edge> getAdjacencies() {
   Set<Edge> result = new HashSet<Edge>();
   result.addAll(graph.getEdges());
   return result;
 }
コード例 #8
0
 /**
  * Grows the effect seeds into effect clusters and then attaches each remaining variable to the
  * effect cluster it overlaps most.
  *
  * <p>Each returned list has two entries: index 0 is the set of attached ("C") variable indices
  * and index 1 is the effect cluster itself.
  *
  * <p>Fixes relative to the previous version: the per-cluster list is now built with both slots
  * (the old {@code add(1, ...)} on an empty list always threw IndexOutOfBoundsException); the
  * intersection is computed on a copy so the effect clusters are no longer emptied by
  * {@code retainAll}; and the overlap ratio uses floating-point instead of integer division.
  *
  * @param ESeeds the effect seeds, passed on to {@code finishESeeds}.
  * @param CSeeds for each variable index, the set of indices it is compared against the effect
  *     clusters with (presumably its cause seeds -- confirm against the caller).
  * @return the set of [attached variables, effect cluster] pairs.
  */
 private Set<List<Set<Integer>>> combineClusters(
     Set<Set<Integer>> ESeeds, List<Set<Integer>> CSeeds) {
   Set<Set<Integer>> EClusters = finishESeeds(ESeeds);

   // Cs = every variable index that is not a member of some effect cluster.
   Set<Integer> Cs = new HashSet<Integer>();
   for (int i = 0; i < variables.size(); i++) {
     Cs.add(i);
   }
   for (Set<Integer> ECluster : EClusters) {
     Cs.removeAll(ECluster);
   }

   // Fix an iteration order once so that positions in Clusters and EClustersArray line up.
   List<Set<Integer>> EClustersArray = new ArrayList<Set<Integer>>(EClusters);
   List<List<Set<Integer>>> Clusters = new ArrayList<List<Set<Integer>>>();
   for (Set<Integer> ECluster : EClustersArray) {
     List<Set<Integer>> newCluster = new ArrayList<Set<Integer>>();
     newCluster.add(new HashSet<Integer>()); // index 0: attached variables, initially none
     newCluster.add(ECluster); // index 1: the effect cluster
     Clusters.add(newCluster);
   }

   for (Integer c : Cs) {
     int match = -1;
     int overlap = 0;
     // NOTE(review): as before, pass stays true once any improving candidate clears the
     // threshold, even if a later, larger overlap does not -- confirm this is intended.
     boolean pass = false;
     for (int i = 0; i < EClustersArray.size(); i++) {
       Set<Integer> ECluster = EClustersArray.get(i);
       // Copy first: retainAll mutates its receiver.
       Set<Integer> intersection = new HashSet<Integer>(ECluster);
       intersection.retainAll(CSeeds.get(c));
       int _overlap = intersection.size();
       if (_overlap > overlap) {
         overlap = _overlap;
         match = i;
         if ((double) overlap / ECluster.size() > CIparameter) {
           pass = true;
         }
       }
     }
     if (pass) {
       Clusters.get(match).get(0).add(c);
     }
   }

   return new HashSet<List<Set<Integer>>>(Clusters);
 }
コード例 #9
0
ファイル: GesConcurrent.java プロジェクト: renjiey/tetrad
  /**
   * Evaluate the Insert(X, Y, T) operator (Definition 12 from Chickering, 2002).
   *
   * <p>Scores y against two sets: naYX, t and the current parents of y, once with x added and once
   * without.
   *
   * @return the value of {@code scoreGraphChange} for y with those two sets.
   */
  private double insertEval(Node x, Node y, List<Node> t, List<Node> naYX, Graph graph) {
    Set<Node> withX = new HashSet<Node>(naYX);
    withX.addAll(t);
    withX.addAll(graph.getParents(y));

    // Snapshot taken before x goes in, so this one is the "without x" set.
    Set<Node> withoutX = new HashSet<Node>(withX);
    withX.add(x);

    return scoreGraphChange(y, withX, withoutX);
  }
コード例 #10
0
ファイル: GeneralizedSemPm.java プロジェクト: ps7z/tetrad
  /**
   * Collects the names of the parameters that the given node references.
   *
   * @param node the node doing the referencing (variable node or error node).
   * @return a new set of every parameter name whose set of referencing nodes contains {@code
   *     node}.
   */
  public Set<String> getReferencedParameters(Node node) {
    Set<String> referenced = new HashSet<>();

    for (String name : this.referencedParameters.keySet()) {
      Set<Node> referencingNodes = this.referencedParameters.get(name);
      if (!referencingNodes.contains(node)) {
        continue;
      }
      referenced.add(name);
    }

    return referenced;
  }
コード例 #11
0
ファイル: GeneralizedSemPm.java プロジェクト: ps7z/tetrad
  /**
   * Collects the nodes that the given node references.
   *
   * @param node the node doing the referencing (variable node or error node).
   * @return a new set of every key of {@code referencedNodes} whose set of referencing nodes
   *     contains {@code node}.
   */
  public Set<Node> getReferencedNodes(Node node) {
    Set<Node> referenced = new HashSet<>();

    for (Node candidate : this.referencedNodes.keySet()) {
      Set<Node> referencingNodes = this.referencedNodes.get(candidate);
      if (!referencingNodes.contains(node)) {
        continue;
      }
      referenced.add(candidate);
    }

    return referenced;
  }
コード例 #12
0
ファイル: GesConcurrent.java プロジェクト: renjiey/tetrad
  /**
   * Files {@code arrow} under the ordered pair (i, j) in {@code lookupArrows}, creating the
   * pair's set on first use.
   */
  private void addLookupArrow(Node i, Node j, Arrow arrow) {
    OrderedPair<Node> key = new OrderedPair<Node>(i, j);
    Set<Arrow> bucket = lookupArrows.get(key);

    // The empty bucket is registered before the arrow is added, as in the original ordering.
    if (bucket == null) {
      bucket = new HashSet<Arrow>();
      lookupArrows.put(key, bucket);
    }

    bucket.add(arrow);
  }
コード例 #13
0
ファイル: Knowledge2.java プロジェクト: ps7z/tetrad
  /**
   * Splits a comma-separated spec into its distinct, non-blank tokens.
   *
   * <p>Each token is trimmed before it is stored. Previously blank tokens were filtered on their
   * trimmed form but the untrimmed text was kept, so {@code "X1, X2"} produced {@code " X2"},
   * which can never equal or match a variable name.
   *
   * @param spec the comma-separated specification.
   * @return a new set of trimmed tokens; empty if {@code spec} has no non-blank token.
   */
  private Set<String> split(String spec) {
    Set<String> tokens = new HashSet<>();

    for (String raw : spec.split(",")) {
      String token = raw.trim();
      if (!token.isEmpty()) {
        tokens.add(token);
      }
    }

    return tokens;
  }
コード例 #14
0
 /**
  * Enumerates every subset of {@code nodes}.
  *
  * <p>Subset number {@code mask} contains {@code nodes.get(k)} exactly when bit {@code k} of
  * {@code mask} is set, so the empty set comes first and the full set last.
  *
  * @param nodes the nodes to take subsets of.
  * @return a list of 2^n sets for n = {@code nodes.size()}.
  */
 public static List<Set<Node>> powerSet(List<Node> nodes) {
   List<Set<Node>> result = new ArrayList<Set<Node>>();
   int subsetCount = (int) Math.pow(2, nodes.size());

   for (int mask = 0; mask < subsetCount; mask++) {
     Set<Node> subset = new HashSet<Node>();

     // Walk the set bits from least to most significant; stop once none remain.
     for (int bit = 0; (mask >>> bit) != 0; bit++) {
       if (((mask >>> bit) & 1) == 1) {
         subset.add(nodes.get(bit));
       }
     }

     result.add(subset);
   }

   return result;
 }
コード例 #15
0
  /**
   * Translates clusters of variable indices into clusters of nodes (via {@code variables}) and
   * hands them to {@code convertSearchGraphNodes}.
   *
   * @param allClusters clusters expressed as indices into {@code variables}.
   * @return whatever graph {@code convertSearchGraphNodes} builds for the node clusters.
   */
  private Graph convertToGraph(Set<Set<Integer>> allClusters) {
    Set<Set<Node>> nodeClusters = new HashSet<Set<Node>>();

    for (Set<Integer> indexCluster : allClusters) {
      Set<Node> members = new HashSet<Node>();
      for (int index : indexCluster) {
        members.add(variables.get(index));
      }
      nodeClusters.add(members);
    }

    return convertSearchGraphNodes(nodeClusters);
  }
コード例 #16
0
ファイル: Knowledge2.java プロジェクト: ps7z/tetrad
  /**
   * Iterator over the KnowledgeEdges representing required edges.
   *
   * <p>Every required rule contributes one edge for each (first, second) pair of its two node
   * sets, except pairs whose two nodes are equal.
   */
  public final Iterator<KnowledgeEdge> requiredEdgesIterator() {
    Set<KnowledgeEdge> required = new HashSet<>();

    for (OrderedPair<Set<MyNode>> rule : requiredRulesSpecs) {
      final Set<MyNode> sources = rule.getFirst();
      final Set<MyNode> targets = rule.getSecond();

      for (MyNode source : sources) {
        for (MyNode target : targets) {
          if (source.equals(target)) {
            continue;
          }
          required.add(new KnowledgeEdge(source.getName(), target.getName()));
        }
      }
    }

    return required.iterator();
  }
コード例 #17
0
ファイル: Knowledge2.java プロジェクト: ps7z/tetrad
  /**
   * Resolves a spec to the registered nodes it names. The spec is broken into tokens by {@code
   * split}; each token is used as a regular expression after every {@code *} is rewritten to
   * {@code .*}, and a node is selected when its whole name matches.
   *
   * @param spec the specification to resolve.
   * @return a new set of the nodes in {@code myNodes} matched by at least one token.
   */
  private Set<MyNode> getExtent(String spec) {
    Set<MyNode> matched = new HashSet<>();

    for (String token : split(spec)) {
      java.util.regex.Pattern pattern =
          java.util.regex.Pattern.compile(token.replace("*", ".*"));

      for (MyNode candidate : myNodes) {
        if (pattern.matcher(candidate.getName()).matches()) {
          matched.add(candidate);
        }
      }
    }

    return matched;
  }
コード例 #18
0
  /**
   * Tests a quartet for purity against the given variables: the quartet itself must vanish, every
   * quartet obtained by swapping one member for one outside variable must vanish too, and the
   * quartet must finally pass {@code significant}.
   *
   * @param quartet the four indices under test.
   * @param variables the indices to draw replacement variables from.
   * @return true only if all of the checks above succeed.
   */
  private boolean pure(Set<Integer> quartet, List<Integer> variables) {
    if (!quartetVanishes(quartet)) {
      return false;
    }

    for (int outside : variables) {
      if (quartet.contains(outside)) {
        continue;
      }

      for (int member : quartet) {
        Set<Integer> swapped = new HashSet<Integer>(quartet);
        swapped.remove(member);
        swapped.add(outside);

        if (!quartetVanishes(swapped)) {
          return false;
        }
      }
    }

    return significant(new ArrayList<Integer>(quartet));
  }
コード例 #19
0
ファイル: Knowledge2.java プロジェクト: ps7z/tetrad
  /**
   * Iterator over the knowledge's explicitly forbidden edges.
   *
   * <p>Starts from a copy of the forbidden rule specs, drops the rules returned by {@code
   * forbiddenTierRules()} and the rule registered for each knowledge group, and expands what is
   * left into one edge per (first, second) node pair.
   */
  public final Iterator<KnowledgeEdge> explicitlyForbiddenEdgesIterator() {
    Set<OrderedPair<Set<MyNode>>> explicitRules = new HashSet<>(forbiddenRulesSpecs);
    explicitRules.removeAll(forbiddenTierRules());

    for (KnowledgeGroup group : knowledgeGroups) {
      explicitRules.remove(knowledgeGroupRules.get(group));
    }

    Set<KnowledgeEdge> forbidden = new HashSet<>();

    for (OrderedPair<Set<MyNode>> rule : explicitRules) {
      final Set<MyNode> sources = rule.getFirst();
      final Set<MyNode> targets = rule.getSecond();

      for (MyNode source : sources) {
        for (MyNode target : targets) {
          forbidden.add(new KnowledgeEdge(source.getName(), target.getName()));
        }
      }
    }

    return forbidden.iterator();
  }
コード例 #20
0
ファイル: Rfci.java プロジェクト: ajsedgewick/tetrad
  /**
   * Constructs a new search for the given independence test, restricted to the listed variables.
   *
   * <p>The search variables start as all variables of the test; any whose name does not equal the
   * name of some node in {@code searchVars} is then dropped.
   *
   * @param independenceTest the test to use; must not be null.
   * @param searchVars the variables to search over, matched by name.
   * @throws NullPointerException if {@code independenceTest} or the {@code knowledge} field is
   *     null.
   */
  public Rfci(IndependenceTest independenceTest, List<Node> searchVars) {
    if (independenceTest == null || knowledge == null) {
      throw new NullPointerException();
    }

    this.independenceTest = independenceTest;
    this.variables.addAll(independenceTest.getVariables());

    Set<Node> excluded = new HashSet<Node>();
    for (Node variable : this.variables) {
      boolean requested = false;
      for (Node searchVar : searchVars) {
        if (variable.getName().equals(searchVar.getName())) {
          requested = true;
          break;
        }
      }
      if (!requested) {
        excluded.add(variable);
      }
    }
    this.variables.removeAll(excluded);
  }
コード例 #21
0
  /**
   * Finds clusters of size 3.
   *
   * <p>Repeatedly scans all 3-combinations of {@code remaining}. A combination is accepted when it
   * is a clique in {@code adjacencies} and, for every x in {@code unionPure}, the quartet formed by
   * adding x both vanishes and is significant. On acceptance the three indices are moved from
   * {@code remaining} into {@code unionPure} and the scan starts over; the search stops when a full
   * scan accepts nothing or fewer than three indices remain.
   *
   * @param remaining candidate indices; accepted indices are removed from this list.
   * @param unionPure indices already clustered; accepted indices are added to this set.
   * @param adjacencies adjacency information passed to {@code clique}.
   * @return the accepted 3-clusters (empty if {@code unionPure} is empty on entry).
   */
  private Set<Set<Integer>> findMixedClusters(
      List<Integer> remaining, Set<Integer> unionPure, Map<Node, Set<Node>> adjacencies) {
    Set<Set<Integer>> accepted = new HashSet<Set<Integer>>();

    if (unionPure.isEmpty()) {
      return accepted;
    }

    boolean rescan = true;

    while (rescan && remaining.size() >= 3) {
      rescan = false;

      ChoiceGenerator combinations = new ChoiceGenerator(remaining.size(), 3);
      int[] pick;

      while ((pick = combinations.next()) != null) {
        int first = remaining.get(pick[0]);
        int second = remaining.get(pick[1]);
        int third = remaining.get(pick[2]);

        Set<Integer> candidate = new HashSet<Integer>();
        candidate.add(first);
        candidate.add(second);
        candidate.add(third);

        if (!clique(candidate, adjacencies)) {
          continue;
        }

        // Check all x as a cross check; really only one should be necessary.
        boolean holdsForEveryX = true;

        for (int x : unionPure) {
          Set<Integer> extended = new HashSet<Integer>(candidate);
          extended.add(x);

          if (!(quartetVanishes(extended) && significant(new ArrayList<Integer>(extended)))) {
            holdsForEveryX = false;
            break;
          }
        }

        if (!holdsForEveryX) {
          continue;
        }

        accepted.add(candidate);
        unionPure.addAll(candidate);
        remaining.removeAll(candidate);

        System.out.println(
            "3-cluster found: " + variablesForIndices(new ArrayList<Integer>(candidate)));

        // remaining has changed, so the combination generator must be rebuilt.
        rescan = true;
        break;
      }
    }

    return accepted;
  }
コード例 #22
0
  /**
   * Searches for pure clusters by first locating 4-cliques in {@code adjacencies} and then trying
   * to grow each pure 4-clique with further variables. (Trying to optimize the search for
   * 4-cliques a bit.)
   *
   * <p>Nodes that end up in a reported cluster are recorded in a local "found" set and are skipped
   * as members of later candidate quartets.
   *
   * @param _variables indices (into the {@code variables} field) to use as starting points and as
   *     candidates for growing a cluster.
   * @param adjacencies for each node, the set of nodes it is adjacent to.
   * @return the clusters found, each as a set of variable indices.
   */
  private Set<Set<Integer>> findPureClusters2(
      List<Integer> _variables, Map<Node, Set<Node>> adjacencies) {
    // Note: this prints the `variables` field, not the `_variables` parameter.
    System.out.println("Original variables = " + variables);

    Set<Set<Integer>> clusters = new HashSet<Set<Integer>>();
    List<Integer> allVariables = new ArrayList<Integer>();
    Set<Node> foundVariables = new HashSet<Node>();
    for (int i = 0; i < this.variables.size(); i++) allVariables.add(i);

    for (int x : _variables) {
      Node nodeX = variables.get(x);
      if (foundVariables.contains(nodeX)) continue;

      // Not-yet-clustered neighbors of X; a 4-clique through X needs at least three of them.
      List<Node> adjX = new ArrayList<Node>(adjacencies.get(nodeX));
      adjX.removeAll(foundVariables);

      if (adjX.size() < 3) continue;

      for (Node nodeY : adjX) {
        // foundVariables can grow while these loops run, hence the repeated membership checks.
        if (foundVariables.contains(nodeY)) continue;

        // Nodes adjacent to both X and Y.
        List<Node> commonXY = new ArrayList<Node>(adjacencies.get(nodeY));
        commonXY.retainAll(adjX);
        commonXY.removeAll(foundVariables);

        for (Node nodeZ : commonXY) {
          if (foundVariables.contains(nodeZ)) continue;

          // Nodes adjacent to X, Y and Z.
          List<Node> commonXZ = new ArrayList<Node>(commonXY);
          commonXZ.retainAll(adjacencies.get(nodeZ));
          commonXZ.removeAll(foundVariables);

          for (Node nodeW : commonXZ) {
            if (foundVariables.contains(nodeW)) continue;

            // NOTE(review): commonXZ was derived from adjacencies.get(nodeY), so this check looks
            // redundant -- confirm before removing.
            if (!adjacencies.get(nodeY).contains(nodeW)) {
              continue;
            }

            int y = variables.indexOf(nodeY);
            int w = variables.indexOf(nodeW);
            int z = variables.indexOf(nodeZ);

            Set<Integer> cluster = quartet(x, y, z, w);

            // Note that purity needs to be assessed with respect to all of the variables in order
            // to
            // remove all latent-measure impurities between pairs of latents.
            if (pure(cluster, allVariables)) {

              // Growth phase: tentatively add each other candidate o and keep it only if the
              // checks below do not reject it.
              O:
              for (int o : _variables) {
                if (cluster.contains(o)) continue;
                cluster.add(o);

                if (!clique(cluster, adjacencies)) {
                  cluster.remove(o);
                  continue O;
                }

                // (A disabled variant also removed o when !allVariablesDependent(cluster).)

                List<Integer> _cluster = new ArrayList<Integer>(cluster);

                ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 4);
                int[] choice2;
                int count = 0;

                while ((choice2 = gen2.next()) != null) {
                  int x2 = _cluster.get(choice2[0]);
                  int y2 = _cluster.get(choice2[1]);
                  int z2 = _cluster.get(choice2[2]);
                  int w2 = _cluster.get(choice2[3]);

                  Set<Integer> quartet = quartet(x2, y2, z2, w2);

                  // Optimizes for large clusters.
                  // On reaching the third quartet that contains o, the remaining quartets are
                  // skipped and o stays in the cluster (continue O does not remove it).
                  if (quartet.contains(o)) {
                    if (++count > 2) continue O;
                  }

                  // An impure quartet involving o rejects o.
                  if (quartet.contains(o) && !pure(quartet, allVariables)) {
                    cluster.remove(o);
                    continue O;
                  }
                }
              }

              System.out.println(
                  "Cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster)));
              clusters.add(cluster);
              foundVariables.addAll(variablesForIndices(new ArrayList<Integer>(cluster)));
            }
          }
        }
      }
    }

    return clusters;
  }
コード例 #23
0
  /**
   * Grows the effect seeds into larger clusters and then picks a disjoint subset of the grown
   * clusters.
   *
   * <p>Grow phase (one of three variants, selected by {@code algType}): each seed is extended by
   * every variable o whose triples {n1, n2, o}, for pairs n1, n2 already in the cluster, are
   * sufficiently represented in {@code ESeeds}. Seeds covered by a grown cluster are removed from
   * {@code ESeeds}.
   *
   * <p>Pick phase: grown clusters are sorted by decreasing size and accepted greedily, skipping any
   * cluster that shares an index with one accepted earlier.
   *
   * @param ESeeds the effect seeds (sets of variable indices). This set is modified: seeds are
   *     removed as they are absorbed.
   * @return the chosen, pairwise disjoint grown clusters.
   */
  private Set<Set<Integer>> finishESeeds(Set<Set<Integer>> ESeeds) {
    log("Growing Effect Seeds.", true);
    Set<Set<Integer>> grown = new HashSet<Set<Integer>>();

    List<Integer> _variables = new ArrayList<Integer>();
    for (int i = 0; i < variables.size(); i++) _variables.add(i);

    // Lax grow phase, do-while variant: a variable is added unless more of its triples are absent
    // from ESeeds than present.
    // NOTE(review): the original comments labelled this branch "with speedup" and the
    // laxWithSpeedup branch below "without speedup", which looks swapped relative to the enum
    // names -- confirm.
    if (algType == AlgType.lax) {
      // Scratch set reused for every membership test against ESeeds.
      Set<Integer> t = new HashSet<Integer>();
      int count = 0;
      int total = ESeeds.size();

      do {
        if (!ESeeds.iterator().hasNext()) {
          break;
        }

        // Take any remaining seed and grow a copy of it.
        Set<Integer> cluster = ESeeds.iterator().next();
        Set<Integer> _cluster = new HashSet<Integer>(cluster);

        if (extraShuffle) {
          Collections.shuffle(_variables);
        }

        for (int o : _variables) {
          if (_cluster.contains(o)) continue;

          List<Integer> _cluster2 = new ArrayList<Integer>(_cluster);
          int rejected = 0;
          int accepted = 0;

          // Count, over all pairs in the current cluster, whether {n1, n2, o} is a seed.
          ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2);
          int[] choice;

          while ((choice = gen.next()) != null) {
            int n1 = _cluster2.get(choice[0]);
            int n2 = _cluster2.get(choice[1]);

            t.clear();
            t.add(n1);
            t.add(n2);
            t.add(o);

            if (!ESeeds.contains(t)) {
              rejected++;
            } else {
              accepted++;
            }
          }

          // Ties count in favour of adding o.
          if (rejected > accepted) {
            continue;
          }

          _cluster.add(o);

          // (A disabled variant removed o again when avgSumLnP of the cluster was not > -10.)
        }

        // This takes out all pure clusters that are subsets of _cluster.
        ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3);
        int[] choice2;
        List<Integer> _cluster3 = new ArrayList<Integer>(_cluster);

        while ((choice2 = gen2.next()) != null) {
          int n1 = _cluster3.get(choice2[0]);
          int n2 = _cluster3.get(choice2[1]);
          int n3 = _cluster3.get(choice2[2]);

          t.clear();
          t.add(n1);
          t.add(n2);
          t.add(n3);

          ESeeds.remove(t);
        }

        if (verbose) {
          System.out.println(
              "Grown "
                  + (++count)
                  + " of "
                  + total
                  + ": "
                  + variablesForIndices(new ArrayList<Integer>(_cluster)));
        }
        grown.add(_cluster);
      } while (!ESeeds.isEmpty());
    }

    // Lax grow phase, snapshot variant: same acceptance rule as above, but it iterates over a
    // snapshot of ESeeds and builds a fresh triple (via triple(...)) for each test.
    if (algType == AlgType.laxWithSpeedup) {
      int count = 0;
      int total = ESeeds.size();

      // Optimized lax version of grow phase.
      // The loop runs over a copy, so seeds removed from ESeeds below are still visited.
      for (Set<Integer> cluster : new HashSet<Set<Integer>>(ESeeds)) {
        Set<Integer> _cluster = new HashSet<Integer>(cluster);

        if (extraShuffle) {
          Collections.shuffle(_variables);
        }

        for (int o : _variables) {
          if (_cluster.contains(o)) continue;

          List<Integer> _cluster2 = new ArrayList<Integer>(_cluster);
          int rejected = 0;
          int accepted = 0;

          ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2);
          int[] choice;

          while ((choice = gen.next()) != null) {
            int n1 = _cluster2.get(choice[0]);
            int n2 = _cluster2.get(choice[1]);

            Set<Integer> triple = triple(n1, n2, o);

            if (!ESeeds.contains(triple)) {
              rejected++;
            } else {
              accepted++;
            }
          }

          if (rejected > accepted) {
            continue;
          }

          _cluster.add(o);
        }

        // Drop every seed that the grown cluster fully contains.
        for (Set<Integer> c : new HashSet<Set<Integer>>(ESeeds)) {
          if (_cluster.containsAll(c)) {
            ESeeds.remove(c);
          }
        }

        if (verbose) {
          System.out.println("Grown " + (++count) + " of " + total + ": " + _cluster);
        }

        grown.add(_cluster);
      }
    }

    // Strict grow phase: a variable is added only if every triple {n1, n2, o} is a seed.
    if (algType == AlgType.strict) {
      // Scratch set reused for every membership test against ESeeds.
      Set<Integer> t = new HashSet<Integer>();
      int count = 0;
      int total = ESeeds.size();

      do {
        if (!ESeeds.iterator().hasNext()) {
          break;
        }

        Set<Integer> cluster = ESeeds.iterator().next();
        Set<Integer> _cluster = new HashSet<Integer>(cluster);

        if (extraShuffle) {
          Collections.shuffle(_variables);
        }

        VARIABLES:
        for (int o : _variables) {
          if (_cluster.contains(o)) continue;

          List<Integer> _cluster2 = new ArrayList<Integer>(_cluster);

          ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2);
          int[] choice;

          while ((choice = gen.next()) != null) {
            int n1 = _cluster2.get(choice[0]);
            int n2 = _cluster2.get(choice[1]);

            t.clear();
            t.add(n1);
            t.add(n2);
            t.add(o);

            // A single missing triple rejects o.
            if (!ESeeds.contains(t)) {
              continue VARIABLES;
            }

            // (A disabled variant also skipped when avgSumLnP of the triple was < -10.)
          }

          _cluster.add(o);
        }

        // This takes out all pure clusters that are subsets of _cluster.
        ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3);
        int[] choice2;
        List<Integer> _cluster3 = new ArrayList<Integer>(_cluster);

        while ((choice2 = gen2.next()) != null) {
          int n1 = _cluster3.get(choice2[0]);
          int n2 = _cluster3.get(choice2[1]);
          int n3 = _cluster3.get(choice2[2]);

          t.clear();
          t.add(n1);
          t.add(n2);
          t.add(n3);

          ESeeds.remove(t);
        }

        if (verbose) {
          System.out.println("Grown " + (++count) + " of " + total + ": " + _cluster);
        }
        grown.add(_cluster);
      } while (!ESeeds.isEmpty());
    }

    // Optimized pick phase.
    log("Choosing among grown Effect Clusters.", true);

    // Logging only: the sorted copy _l is not used for anything else.
    for (Set<Integer> l : grown) {
      ArrayList<Integer> _l = new ArrayList<Integer>(l);
      Collections.sort(_l);
      if (verbose) {
        log("Grown: " + variablesForIndices(_l), false);
      }
    }

    Set<Set<Integer>> out = new HashSet<Set<Integer>>();

    List<Set<Integer>> list = new ArrayList<Set<Integer>>(grown);

    // (A disabled variant computed a p-value per grown cluster via getP, used it to break size
    // ties in the comparator below, and skipped clusters with p-value below alpha.)

    // Largest clusters first.
    Collections.sort(
        list,
        new Comparator<Set<Integer>>() {
          @Override
          public int compare(Set<Integer> o1, Set<Integer> o2) {
            return o2.size() - o1.size();
          }
        });

    // Indices already claimed by an accepted cluster.
    Set<Integer> all = new HashSet<Integer>();

    CLUSTER:
    for (Set<Integer> cluster : list) {
      // Skip any cluster that overlaps an already accepted one.
      for (Integer i : cluster) {
        if (all.contains(i)) continue CLUSTER;
      }

      out.add(cluster);

      // (A disabled variant removed the cluster again when getPMulticluster(out) < alpha.)

      all.addAll(cluster);
    }

    return out;
  }
コード例 #24
0
ファイル: GeneralizedSemPm.java プロジェクト: ps7z/tetrad
  /**
   * Sets the expression for the given node and refreshes the bookkeeping of which parameters and
   * which nodes that node references.
   *
   * <p>Steps, in order: parse the expression; derive the parameter names by discarding names that
   * are variable names or names of nodes in this model; drop the node from all existing parameter
   * references (deleting parameters that are left unreferenced); register the new parameter
   * references; drop the node from all existing node references; register a node reference for
   * each variable name that equals the name of one of the node's parents in the graph; finally
   * store the parsed expression and the original string.
   *
   * @param node the node whose expression is being set; must not be null.
   * @param expressionString the expression text exactly as entered; must not be null.
   * @throws NullPointerException if either argument is null.
   * @throws ParseException if the expression cannot be parsed. It is passed up unchanged because
   *     the interface needs it.
   */
  public void setNodeExpression(Node node, String expressionString) throws ParseException {
    if (node == null) {
      throw new NullPointerException("Node was null.");
    }

    if (expressionString == null) {
      throw new NullPointerException("Expression string was null.");
    }

    ExpressionParser parser = new ExpressionParser();
    Expression expression = parser.parseExpression(expressionString);
    List<String> parameterNames = parser.getParameters();

    // Names of the node's parents in the graph.
    List<String> parentNames = new LinkedList<>();
    for (Node parent : this.graph.getParents(node)) {
      parentNames.add(parent.getName());
    }

    // Whatever the parser reported that is really a variable (or error term) is not a parameter.
    // Variables outside the parent set are silently dropped here rather than rejected, and
    // missing parents are not complained about either.
    parameterNames.removeAll(variableNames);
    for (Node variable : nodes) {
      parameterNames.remove(variable.getName());
    }

    // Remove old parameter references.
    List<String> orphanedParameters = new LinkedList<>();
    for (String parameter : this.referencedParameters.keySet()) {
      Set<Node> referrers = this.referencedParameters.get(parameter);
      referrers.remove(node);

      if (referrers.isEmpty()) {
        orphanedParameters.add(parameter);
      }
    }

    // Deletion happens in a second pass so the key set is not modified while being iterated.
    for (String parameter : orphanedParameters) {
      this.referencedParameters.remove(parameter);
      this.parameterExpressions.remove(parameter);
      this.parameterExpressionStrings.remove(parameter);
      this.parameterEstimationInitializationExpressions.remove(parameter);
      this.parameterEstimationInitializationExpressionStrings.remove(parameter);
    }

    // Add new parameter references.
    for (String parameter : parameterNames) {
      Set<Node> referrers = this.referencedParameters.get(parameter);
      if (referrers == null) {
        referrers = new HashSet<Node>();
        this.referencedParameters.put(parameter, referrers);
      }
      referrers.add(node);

      setSuitableParameterDistribution(parameter);
    }

    // Remove old node references.
    List<Node> orphanedNodes = new LinkedList<>();
    for (Node referenced : this.referencedNodes.keySet()) {
      Set<Node> referrers = this.referencedNodes.get(referenced);
      referrers.remove(node);

      if (referrers.isEmpty()) {
        orphanedNodes.add(referenced);
      }
    }

    for (Node referenced : orphanedNodes) {
      this.referencedNodes.remove(referenced);
    }

    // Add new node references: every variable gets an entry, and the entry of each variable that
    // is a parent of this node additionally records this node.
    for (String variableString : variableNames) {
      Node referenced = getNode(variableString);

      Set<Node> referrers = this.referencedNodes.get(referenced);
      if (referrers == null) {
        referrers = new HashSet<Node>();
        this.referencedNodes.put(referenced, referrers);
      }

      for (String parentName : parentNames) {
        if (parentName.equals(variableString)) {
          referrers.add(node);
        }
      }
    }

    // Keep the user's own text alongside the parsed form so spacing is not changed on them.
    nodeExpressions.put(node, expression);
    nodeExpressionStrings.put(node, expressionString);
  }
コード例 #25
0
ファイル: GeneralizedSemPm.java プロジェクト: ps7z/tetrad
  /**
   * Builds a generalized SEM PM from a linear SEM PM over the same graph.
   *
   * <p>Each variable node gets the expression {@code p1*Parent1 + ... + pk*Parentk + Error}, with
   * the coefficient parameters named as in {@code semPm}; each error node gets {@code N(0, v)}
   * where {@code v} is the name of the parameter {@code semPm} holds for (node, node).
   *
   * <p>The " + " separator is now written in front of every term after the first. The previous
   * index-based placement left a dangling separator whenever the last parent was skipped,
   * producing an unparsable formula.
   *
   * @param semPm the SEM PM to convert.
   * @throws IllegalStateException if one of the generated expressions fails to parse.
   */
  public GeneralizedSemPm(SemPm semPm) {
    this(semPm.getGraph());

    // Write down equations.
    try {
      List<Node> variableNodes = getVariableNodes();

      for (int i = 0; i < variableNodes.size(); i++) {
        Node node = variableNodes.get(i);
        List<Node> parents = getVariableParents(node);

        StringBuilder buf = new StringBuilder();

        for (Node parent : parents) {
          // Parents that are not variable nodes contribute no term.
          if (!variableNodes.contains(parent)) {
            continue;
          }

          Parameter _parameter = semPm.getParameter(parent, node);
          String parameter = _parameter.getName();
          Set<Node> nodes = new HashSet<>();
          nodes.add(node);

          referencedParameters.put(parameter, nodes);

          // Separator goes before the term, so skipped parents can never leave one dangling.
          if (buf.length() > 0) {
            buf.append(" + ");
          }

          buf.append(parameter);
          buf.append("*");
          buf.append(parent.getName());

          setParameterExpression(parameter, "Split(-1.5, -.5, .5, 1.5)");
          setStartsWithParametersTemplate(parameter.substring(0, 1), "Split(-1.5, -.5, .5, 1.5)");
          setStartsWithParametersEstimationInitializaationTemplate(
              parameter.substring(0, 1), "Split(-1.5, -.5, .5, 1.5)");
        }

        if (buf.length() > 0) {
          buf.append(" + ");
        }

        // The error term for variable i always closes the formula.
        buf.append(errorNodes.get(i));
        setNodeExpression(node, buf.toString());
      }

      for (Node node : variableNodes) {
        Parameter _parameter = semPm.getParameter(node, node);
        String parameter = _parameter.getName();

        String distributionFormula = "N(0," + parameter + ")";
        setNodeExpression(getErrorNode(node), distributionFormula);
        setParameterExpression(parameter, "U(0, 1)");
        setStartsWithParametersTemplate(parameter.substring(0, 1), "U(0, 1)");
        setStartsWithParametersEstimationInitializaationTemplate(
            parameter.substring(0, 1), "U(0, 1)");
      }

      // Rebuild the name list: variable nodes first, then error nodes.
      variableNames = new ArrayList<>();
      for (Node _node : variableNodes) variableNames.add(_node.getName());
      for (Node _node : errorNodes) variableNames.add(_node.getName());

    } catch (ParseException e) {
      throw new IllegalStateException("Parse error in constructing initial model.", e);
    }
  }
コード例 #26
0
  @Test
  public void test4() {
    // Maps each template to the names of the variable nodes expected to accept it (for X3).
    final String[] none = new String[] {};
    final String[] all = new String[] {"X1", "X2", "X3", "X4", "X5"};

    Map<String, String[]> templates = new HashMap<>();

    templates.put("NEW(b) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", all);
    templates.put("NEW(X1) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", none);
    templates.put("$", none);
    templates.put("TSUM($)", all);
    templates.put("TPROD($)", all);
    templates.put("TPROD($) + X2", all);
    templates.put("TPROD($) + TSUM($)", all);
    templates.put("tan(TSUM(NEW(a)*$))", all);
    templates.put("Normal(0, 1)", all);
    templates.put("Normal(m, s)", all);
    templates.put("Normal(NEW(m), s)", all);
    templates.put("Normal(NEW(m), NEW(s)) + m1 + s6", all);
    templates.put("TSUM($) + a", all);
    templates.put("TSUM($) + TSUM($) + TSUM($) + 1", all);

    for (Map.Entry<String, String[]> entry : templates.entrySet()) {
      String template = entry.getKey();

      // A fresh PM per template so earlier assignments cannot influence later ones.
      GeneralizedSemPm semPm = makeTypicalPm();
      print(semPm.getGraph().toString());

      Set<Node> expected = new HashSet<>();

      for (String name : entry.getValue()) {
        expected.add(semPm.getNode(name));
      }

      Set<Node> accepted = new HashSet<>();

      // Try the template as the expression of every node.
      for (int i = 0; i < semPm.getNodes().size(); i++) {
        print("-----------");
        Node node = semPm.getNodes().get(i);
        print(node.toString());
        print("Trying template: " + template);

        String expanded;

        try {
          expanded = TemplateExpander.getInstance().expandTemplate(template, semPm, node);
        } catch (Exception e) {
          print("Couldn't expand template: " + template);
          continue;
        }

        try {
          semPm.setNodeExpression(node, expanded);
          print("Set formula " + expanded + " for " + node);

          // Only variable nodes count toward the expectation.
          if (semPm.getVariableNodes().contains(node)) {
            accepted.add(node);
          }

        } catch (Exception e) {
          print("Couldn't set formula " + expanded + " for " + node);
        }
      }

      // Try the template as the expression of every parameter; outcomes are only printed.
      for (String parameter : semPm.getParameters()) {
        print("-----------");
        print(parameter);
        print("Trying template: " + template);

        String expanded;

        try {
          expanded = TemplateExpander.getInstance().expandTemplate(template, semPm, null);
        } catch (Exception e) {
          print("Couldn't expand template: " + template);
          continue;
        }

        try {
          semPm.setParameterExpression(parameter, expanded);
          print("Set formula " + expanded + " for " + parameter);
        } catch (Exception e) {
          print("Couldn't set formula " + expanded + " for " + parameter);
        }
      }

      assertEquals(expected, accepted);
    }
  }
コード例 #27
0
  /**
   * Transforms a DAG represented in graph <code>graph</code> into a maximally directed pattern
   * (PDAG) by modifying <code>graph</code> itself. Based on the algorithm described in Chickering
   * (2002) "Optimal structure identification with greedy search" Journal of Machine Learning
   * Research. It works for both BayesNets and SEMs. R. Silva, June 2004
   *
   * <p>The method proceeds in four steps: order the nodes topologically, order the edges based on
   * that node order, label each edge as compelled or reversible, and finally set both endpoints
   * of every reversible edge to {@link Endpoint#TAIL}.
   *
   * @param graph the graph to convert in place
   * @throws IllegalArgumentException if the nodes cannot be topologically ordered (no remaining
   *     node is exogenous) or if some edge cannot be placed in the edge order; in both situations
   *     the loops below would otherwise never terminate
   */
  public static void dagToPdag(Graph graph) {
    // do topological sort on the nodes
    Graph graphCopy = new EdgeListGraph(graph);
    Node[] orderedNodes = new Node[graphCopy.getNodes().size()];
    int count = 0;
    while (graphCopy.getNodes().size() > 0) {
      Set<Node> exogenousNodes = new HashSet<Node>();

      for (Node next : graphCopy.getNodes()) {
        if (graphCopy.isExogenous(next)) {
          exogenousNodes.add(next);
          orderedNodes[count++] = graph.getNode(next.getName());
        }
      }

      // Without an exogenous node nothing is removed below and the loop would spin forever.
      if (exogenousNodes.isEmpty()) {
        throw new IllegalArgumentException(
            "Cannot topologically order the nodes; the graph is not a DAG.");
      }

      graphCopy.removeNodes(new ArrayList<Node>(exogenousNodes));
    }
    // ordered edges - improvised, inefficient implementation
    count = 0;
    Edge[] edges = new Edge[graph.getNumEdges()];
    // Newly allocated boolean arrays are all false, so no explicit reset is needed.
    boolean[] edgeOrdered = new boolean[graph.getNumEdges()];
    Edge[] orderedEdges = new Edge[graph.getNumEdges()];

    for (Edge edge : graph.getEdges()) {
      edges[count++] = edge;
    }

    while (count > 0) {
      int remainingBeforePass = count;

      for (Node orderedNode : orderedNodes) {
        for (int k = orderedNodes.length - 1; k >= 0; k--) {
          for (int q = 0; q < edges.length; q++) {
            if (!edgeOrdered[q]
                && edges[q].getNode1() == orderedNodes[k]
                && edges[q].getNode2() == orderedNode) {
              edgeOrdered[q] = true;
              orderedEdges[orderedEdges.length - count] = edges[q];
              count--;
            }
          }
        }
      }

      // A pass that places no edge leaves the state unchanged, so every later pass would too.
      if (count == remainingBeforePass) {
        throw new IllegalArgumentException(
            "Cannot order the edges; some edge does not connect two nodes of the graph as"
                + " node1 --> node2.");
      }
    }

    // label edges
    boolean[] compelledEdges = new boolean[graph.getNumEdges()];
    boolean[] reversibleEdges = new boolean[graph.getNumEdges()];

    for (int i = 0; i < graph.getNumEdges(); i++) {
      // Skip edges that were already labeled while processing an earlier edge.
      if (compelledEdges[i] || reversibleEdges[i]) {
        continue;
      }
      Node x = orderedEdges[i].getNode1();
      Node y = orderedEdges[i].getNode2();
      // Look at compelled edges w --> x.
      for (int j = 0; j < orderedEdges.length; j++) {
        if (orderedEdges[j].getNode2() == x && compelledEdges[j]) {
          Node w = orderedEdges[j].getNode1();
          if (!graph.isParentOf(w, y)) {
            // w is not a parent of y: mark the first ordered edge into y as compelled.
            for (int k = 0; k < orderedEdges.length; k++) {
              if (orderedEdges[k].getNode2() == y) {
                compelledEdges[k] = true;
                break;
              }
            }
          } else {
            // w is a parent of y: mark the edge w --> y as compelled.
            for (int k = 0; k < orderedEdges.length; k++) {
              if (orderedEdges[k].getNode1() == w && orderedEdges[k].getNode2() == y) {
                compelledEdges[k] = true;
                break;
              }
            }
          }
        }
        if (compelledEdges[i]) {
          break;
        }
      }
      if (compelledEdges[i]) {
        continue;
      }
      boolean foundZ = false;

      // Look for an edge z --> y with z != x where z is not a parent of x.
      for (Edge orderedEdge : orderedEdges) {
        Node z = orderedEdge.getNode1();
        if (z != x && orderedEdge.getNode2() == y && !graph.isParentOf(z, x)) {
          compelledEdges[i] = true;
          for (int k = i + 1; k < graph.getNumEdges(); k++) {
            if (orderedEdges[k].getNode2() == y && !reversibleEdges[k]) {
              compelledEdges[k] = true;
            }
          }
          foundZ = true;
          break;
        }
      }

      // No such z: this edge and the later, not-yet-compelled edges into y are reversible.
      if (!foundZ) {
        reversibleEdges[i] = true;

        for (int j = i + 1; j < orderedEdges.length; j++) {
          if (!compelledEdges[j] && orderedEdges[j].getNode2() == y) {
            reversibleEdges[j] = true;
          }
        }
      }
    }

    // undirect edges that are reversible
    for (int i = 0; i < reversibleEdges.length; i++) {
      if (reversibleEdges[i]) {
        graph.setEndpoint(orderedEdges[i].getNode1(), orderedEdges[i].getNode2(), Endpoint.TAIL);
        graph.setEndpoint(orderedEdges[i].getNode2(), orderedEdges[i].getNode1(), Endpoint.TAIL);
      }
    }
  }
コード例 #28
0
  /**
   * Finds clusters of size 4 or higher.
   *
   * <p>Repeatedly picks a seed index from <code>_variables</code>, combines it with three of its
   * adjacent nodes into a quartet, and, if that quartet is a clique and pure, tries to grow it
   * with the other remaining indices. Each accepted cluster is printed, recorded, and its
   * members are removed from <code>_variables</code> before the search restarts. The search
   * stops when fewer than four indices remain or when a full scan yields no cluster.
   *
   * @param _variables indices still available for clustering; modified by this method
   * @param adjacencies the adjacent nodes of each node
   * @return the clusters found, each a set of variable indices
   */
  private Set<Set<Integer>> findPureClusters(
      List<Integer> _variables, Map<Node, Set<Node>> adjacencies) {
    Set<Set<Integer>> found = new HashSet<Set<Integer>>();

    List<Integer> everyIndex = new ArrayList<Integer>();
    int total = this.variables.size();
    for (int idx = 0; idx < total; idx++) {
      everyIndex.add(idx);
    }

    SEARCH:
    while (_variables.size() >= 4) {
      for (int seed : _variables) {
        Node seedNode = variables.get(seed);

        // Adjacent nodes of the seed that are still among the remaining variables.
        List<Node> remainingAdjacent = new ArrayList<Node>(adjacencies.get(seedNode));
        remainingAdjacent.retainAll(variablesForIndices(new ArrayList<Integer>(_variables)));

        // Keep only those that themselves have at least three adjacent nodes.
        List<Node> candidates = new ArrayList<Node>();
        for (Node adjacent : remainingAdjacent) {
          if (adjacencies.get(adjacent).size() >= 3) {
            candidates.add(adjacent);
          }
        }

        if (candidates.size() < 3) {
          continue;
        }

        ChoiceGenerator triples = new ChoiceGenerator(candidates.size(), 3);
        int[] triple;

        while ((triple = triples.next()) != null) {
          int y = variables.indexOf(candidates.get(triple[0]));
          int z = variables.indexOf(candidates.get(triple[1]));
          int w = variables.indexOf(candidates.get(triple[2]));

          Set<Integer> cluster = quartet(seed, y, z, w);

          // Note that purity needs to be assessed with respect to all of the variables in order to
          // remove all latent-measure impurities between pairs of latents.
          if (!clique(cluster, adjacencies) || !pure(cluster, everyIndex)) {
            continue;
          }

          GROW:
          for (int extra : _variables) {
            if (cluster.contains(extra)) {
              continue;
            }

            cluster.add(extra);

            if (!clique(cluster, adjacencies)) {
              cluster.remove(extra);
              continue;
            }

            List<Integer> members = new ArrayList<Integer>(cluster);
            ChoiceGenerator fours = new ChoiceGenerator(members.size(), 4);
            int[] pick;
            int checked = 0;

            while ((pick = fours.next()) != null) {
              int a = members.get(pick[0]);
              int b = members.get(pick[1]);
              int c = members.get(pick[2]);
              int d = members.get(pick[3]);

              Set<Integer> four = quartet(a, b, c, d);

              // Only quartets involving the newly added index matter.
              if (!four.contains(extra)) {
                continue;
              }

              // Optimizes for large clusters: after 50 such quartets, keep the index as is.
              if (++checked > 50) {
                continue GROW;
              }

              if (!pure(four, everyIndex)) {
                cluster.remove(extra);
                continue GROW;
              }
            }
          }

          System.out.println(
              "Cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster)));
          found.add(cluster);
          _variables.removeAll(cluster);

          continue SEARCH;
        }
      }

      // A complete scan produced no cluster; nothing more can be found.
      break;
    }

    return found;
  }