コード例 #1
0
  /**
   * Estimates clusters over the indices of {@code variables}. Pure clusters are located first via
   * {@code findPureClusters}; their members are then removed from the candidate index list before
   * {@code findMixedClusters} is called. (Original note: quartets first, then triples.)
   *
   * <p>When {@code depth} is -2 no adjacency search is run and every variable is treated as
   * adjacent to every other variable; otherwise a {@code Fas} search with that depth supplies the
   * adjacency map.
   *
   * @return a new set containing the pure clusters followed by the mixed clusters
   */
  private Set<Set<Integer>> estimateClustersSAG() {
    final Map<Node, Set<Node>> adjacencies;

    if (depth != -2) {
      System.out.println("Running PC adjacency search...");
      Graph graph = new EdgeListGraph(variables);
      Fas adjacencySearch = new Fas(graph, indTest);
      adjacencySearch.setDepth(depth); // 1?
      adjacencies = adjacencySearch.searchMapOnly();
      System.out.println("...done.");
    } else {
      // Complete graph: each variable is adjacent to all of the others.
      adjacencies = new HashMap<Node, Set<Node>>();
      for (Node variable : variables) {
        HashSet<Node> others = new HashSet<Node>(variables);
        others.remove(variable);
        adjacencies.put(variable, others);
      }
    }

    // Candidate pool of variable indices 0..size-1.
    List<Integer> candidates = new ArrayList<Integer>();
    int count = 0;
    while (count < variables.size()) {
      candidates.add(count);
      count++;
    }

    Set<Set<Integer>> pure = findPureClusters(candidates, adjacencies);

    // Indices already claimed by a pure cluster are not offered to the mixed-cluster search.
    for (Set<Integer> claimed : pure) {
      candidates.removeAll(claimed);
    }

    Set<Set<Integer>> mixed = findMixedClusters(candidates, unionPure(pure), adjacencies);

    Set<Set<Integer>> combined = new HashSet<Set<Integer>>(pure);
    combined.addAll(mixed);
    return combined;
  }
コード例 #2
0
  /**
   * Examines every 3-element combination of {@code variables} and records seed clusters.
   *
   * <p>A triple is considered only if it forms a clique in the adjacency map. For each variable
   * outside the triple, two tetrads are tested with {@code deltaTest}; a p-value above
   * {@code alpha} clears the triple's "E-pure" flag, and a subsequent dependence reported by
   * {@code indTest} clears the matching "C-pure" flag. As soon as all four flags are cleared the
   * triple is abandoned.
   *
   * <p>Surviving triples are stored as follows: an E-pure triple is added to {@code ESeeds};
   * otherwise, for each member whose C-pure flag is still set, the other two members are added to
   * that member's entry in {@code CSeeds}.
   *
   * <p>When {@code depth} is -2 every variable is treated as adjacent to every other variable;
   * otherwise a non-verbose {@code Fas} search with that depth supplies the adjacency map.
   *
   * @return always {@code null}; results are written to {@code ESeeds} and {@code CSeeds}
   */
  private Void findSeeds() {
    List<Node> empty = new ArrayList<Node>();

    Map<Node, Set<Node>> adjacencies;

    if (depth == -2) {
      adjacencies = new HashMap<Node, Set<Node>>();

      for (Node node : variables) {
        HashSet<Node> _nodes = new HashSet<Node>(variables);
        _nodes.remove(node);
        adjacencies.put(node, _nodes);
      }
    } else {
      Graph graph = new EdgeListGraph(variables);
      Fas fas = new Fas(graph, indTest);
      fas.setVerbose(false);
      fas.setDepth(depth); // 1?
      adjacencies = fas.searchMapOnly();
    }

    List<Integer> allVariables = new ArrayList<Integer>();
    for (int i = 0; i < variables.size(); i++) {
      allVariables.add(i);
    }

    log("Finding seeds.", true);

    ChoiceGenerator gen = new ChoiceGenerator(allVariables.size(), 3);
    int[] choice;
    CHOICE:
    while ((choice = gen.next()) != null) {
      int n1 = allVariables.get(choice[0]);
      int n2 = allVariables.get(choice[1]);
      int n3 = allVariables.get(choice[2]);
      Node v1 = variables.get(choice[0]);
      Node v2 = variables.get(choice[1]);
      Node v3 = variables.get(choice[2]);

      Set<Integer> triple = triple(n1, n2, n3);

      if (!clique(triple, adjacencies)) {
        continue;
      }

      boolean EPure = true;
      boolean CPure1 = true;
      boolean CPure2 = true;
      boolean CPure3 = true;

      for (int o : allVariables) {
        if (triple.contains(o)) {
          continue;
        }

        Node v4 = variables.get(o);

        Tetrad first = new Tetrad(v1, v2, v3, v4);
        if (deltaTest.getPValue(first) > alpha) {
          EPure = false;
          if (indTest.isDependent(v1, v4, empty)) {
            CPure1 = false;
          }
          if (indTest.isDependent(v2, v4, empty)) {
            CPure2 = false;
          }
        }

        Tetrad second = new Tetrad(v1, v3, v2, v4);
        if (deltaTest.getPValue(second) > alpha) {
          EPure = false;
          if (indTest.isDependent(v3, v4, empty)) {
            CPure3 = false;
          }
        }

        // Nothing left that could make this triple a seed; move on to the next combination.
        if (!(EPure || CPure1 || CPure2 || CPure3)) {
          continue CHOICE;
        }
      }

      if (verbose) {
        log("++" + variablesForIndices(new ArrayList<Integer>(triple)), false);
      }

      if (EPure) {
        ESeeds.add(new HashSet<Integer>(triple));
      } else {
        if (CPure1) {
          addToCSeed(n1, n2, n3);
        }
        if (CPure2) {
          addToCSeed(n2, n1, n3);
        }
        if (CPure3) {
          addToCSeed(n3, n1, n2);
        }
      }
    }
    return null;
  }

  /**
   * Replaces the {@code CSeeds} entry at {@code anchor} with a copy that additionally contains
   * {@code first} and {@code second}.
   *
   * <p>The elements are added explicitly: {@code new HashSet<Integer>(a, b)} with two ints selects
   * the (initialCapacity, loadFactor) constructor and yields an empty set.
   */
  private void addToCSeed(int anchor, int first, int second) {
    Set<Integer> extended = new HashSet<Integer>(CSeeds.get(anchor));
    extended.add(first);
    extended.add(second);
    CSeeds.set(anchor, extended);
  }