示例#1
0
文件: PcMax.java 项目: renjiey/tetrad
  /**
   * Collects every unshielded triple of the given graph that the graph reports as a definite
   * collider.
   *
   * <p>Each node is taken in turn as the middle node {@code b}. Every pair {@code (a, c)} of its
   * adjacent nodes is examined, and the triple is kept when {@code a} and {@code c} are not
   * adjacent and {@code graph.isDefCollider(a, b, c)} holds.
   *
   * @param graph the graph to scan
   * @return the triples found, in encounter order (possibly empty)
   */
  public List<Triple> getUnshieldedCollidersFromGraph(Graph graph) {
    List<Triple> found = new ArrayList<>();

    for (Node b : graph.getNodes()) {
      List<Node> neighbors = graph.getAdjacentNodes(b);

      // A middle node needs at least two neighbours to form a triple.
      if (neighbors.size() >= 2) {
        ChoiceGenerator pairs = new ChoiceGenerator(neighbors.size(), 2);

        for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
          Node a = neighbors.get(pair[0]);
          Node c = neighbors.get(pair[1]);

          // Only unshielded triples (a and c not adjacent) are of interest.
          boolean unshielded = !graph.isAdjacentTo(a, c);
          if (unshielded && graph.isDefCollider(a, b, c)) {
            found.add(new Triple(a, b, c));
          }
        }
      }
    }

    return found;
  }
示例#2
0
  /**
   * Collects all unshielded triples {@code <i, j, k>} of the {@code graph} in scope: {@code j} is
   * adjacent to both {@code i} and {@code k}, while {@code i} and {@code k} are not adjacent to
   * each other.
   *
   * @return one three-element array {@code {i, j, k}} per unshielded triple, in encounter order
   */
  private List<Node[]> getRTuples() {
    List<Node[]> unshielded = new ArrayList<Node[]>();

    for (Node j : graph.getNodes()) {
      List<Node> neighbors = graph.getAdjacentNodes(j);

      // Fewer than two neighbours means j cannot be the middle of a triple.
      if (neighbors.size() >= 2) {
        ChoiceGenerator pairs = new ChoiceGenerator(neighbors.size(), 2);

        for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
          Node i = neighbors.get(pair[0]);
          Node k = neighbors.get(pair[1]);

          // Shielded triples (i adjacent to k) are not collected.
          if (graph.isAdjacentTo(i, k)) {
            continue;
          }

          unshielded.add(new Node[] {i, j, k});
        }
      }
    }

    return unshielded;
  }
示例#3
0
文件: PcMax.java 项目: renjiey/tetrad
  /**
   * Collider-finding part of step C of PC. Scans every unshielded triple a *-* b *-* c and records
   * it as a collider when the sepset produced for (a, c) exists and does NOT contain b.
   *
   * <p>The graph is only read, never re-oriented. Each recorded triple is mapped to the value
   * returned by {@code sepsetProducer.getPValue()} right after the sepset was requested.
   *
   * @param sepsetProducer source of sepsets for pairs of nodes
   * @param graph the graph whose unshielded triples are examined
   * @param verbose if true, each collider found is also printed to standard output
   * @param knowledge not used by this method
   * @return map from each collider triple found to its associated p-value
   */
  public Map<Triple, Double> findCollidersUsingSepsets(
      SepsetProducer sepsetProducer, Graph graph, boolean verbose, IKnowledge knowledge) {
    TetradLogger.getInstance().log("details", "Starting Collider Orientation:");
    Map<Triple, Double> result = new HashMap<>();

    System.out.println("Looking for colliders");

    for (Node b : graph.getNodes()) {
      List<Node> neighbors = graph.getAdjacentNodes(b);

      if (neighbors.size() >= 2) {
        ChoiceGenerator pairs = new ChoiceGenerator(neighbors.size(), 2);

        for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
          Node a = neighbors.get(pair[0]);
          Node c = neighbors.get(pair[1]);

          // Shielded triples are of no interest.
          if (graph.isAdjacentTo(a, c)) {
            continue;
          }

          // No sepset, or a sepset containing the middle node: not a collider.
          List<Node> sepset = sepsetProducer.getSepset(a, c);
          if (sepset == null || sepset.contains(b)) {
            continue;
          }

          if (verbose) {
            System.out.println(
                "\nCollider orientation <" + a + ", " + b + ", " + c + "> sepset = " + sepset);
          }

          result.put(new Triple(a, b, c), sepsetProducer.getPValue());

          TetradLogger.getInstance()
              .log("colliderOrientations", SearchLogUtils.colliderOrientedMsg(a, b, c, sepset));
        }
      }
    }

    TetradLogger.getInstance().log("details", "Finishing Collider Orientation.");

    System.out.println("Done finding colliders");

    return result;
  }
  /**
   * Applies a local variant of Meek's rule R1 repeatedly, until a complete pass over the graph
   * orients nothing further.
   *
   * <p>For every node {@code a} and every pair {@code (b, c)} of non-adjacent neighbours of
   * {@code a}: when {@code graph.getEndpoint(b, a) == Endpoint.ARROW} and {@code a--c} is
   * undirected, the edge is oriented {@code a-->c}, provided that
   * {@code existsLocalSepsetWithoutDet(b, a, c, test, graph, depth)} is false and the arrowpoint
   * is permitted by {@code knowledge}. The mirror case (arrow from {@code c}, undirected
   * {@code a--b}) is handled the same way.
   *
   * @param graph the graph to orient; modified in place
   * @param knowledge background knowledge consulted through {@code isArrowpointAllowed}
   * @param test independence test passed on to {@code existsLocalSepsetWithoutDet}
   * @param depth depth limit passed on to {@code existsLocalSepsetWithoutDet}
   * @return {@code true} if at least one edge was oriented during this call
   */
  public static boolean meekR1Locally2(
      Graph graph, Knowledge knowledge, IndependenceTest test, int depth) {
    List<Node> nodes = graph.getNodes();

    // The loop flag is necessarily false when the loop ends, so the overall result has to be
    // tracked separately; returning the loop flag would always report "nothing changed".
    boolean anyChange = false;
    boolean changed = true;

    while (changed) {
      changed = false;

      for (Node a : nodes) {
        List<Node> adjacentNodes = graph.getAdjacentNodes(a);

        if (adjacentNodes.size() < 2) {
          continue;
        }

        ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2);
        int[] combination;

        while ((combination = cg.next()) != null) {
          Node b = adjacentNodes.get(combination[0]);
          Node c = adjacentNodes.get(combination[1]);

          // Skip triples that are shielded.
          if (graph.isAdjacentTo(b, c)) {
            continue;
          }

          if (graph.getEndpoint(b, a) == Endpoint.ARROW && graph.isUndirectedFromTo(a, c)) {
            if (existsLocalSepsetWithoutDet(b, a, c, test, graph, depth)) {
              continue;
            }

            if (isArrowpointAllowed(a, c, knowledge)) {
              graph.setEndpoint(a, c, Endpoint.ARROW);
              TetradLogger.getInstance()
                  .edgeOriented(SearchLogUtils.edgeOrientedMsg("Meek R1", graph.getEdge(a, c)));
              changed = true;
            }
          } else if (graph.getEndpoint(c, a) == Endpoint.ARROW && graph.isUndirectedFromTo(a, b)) {
            if (existsLocalSepsetWithoutDet(b, a, c, test, graph, depth)) {
              continue;
            }

            if (isArrowpointAllowed(a, b, knowledge)) {
              graph.setEndpoint(a, b, Endpoint.ARROW);
              TetradLogger.getInstance()
                  .edgeOriented(SearchLogUtils.edgeOrientedMsg("Meek R1", graph.getEdge(a, b)));
              changed = true;
            }
          }
        }
      }

      if (changed) {
        anyChange = true;
      }
    }

    return anyChange;
  }
  /**
   * Collider-orientation step (step C) of the PCD variant of PC, which takes determinism into
   * account.
   *
   * <p>For every unshielded triple x--y--z for which {@code set} holds a sepset of (x, z), the
   * triple is oriented as x-->y<--z unless one of the following holds:
   *
   * <ul>
   *   <li>{@code test.determines(sepset, y)} is true;
   *   <li>the sepset alone does not split-determine (x, z) but the sepset augmented with y does;
   *   <li>{@code knowledge} forbids an arrowpoint at y from x or from z.
   * </ul>
   *
   * <p>NOTE(review): unlike the classic step C, no {@code sepset.contains(y)} check is made here;
   * confirm that this is intended.
   *
   * @param set sepsets recorded during the adjacency search
   * @param test independence test providing the determinism checks
   * @param knowledge background knowledge consulted through {@code isArrowpointAllowed}
   * @param graph the graph to orient; modified in place
   */
  public static void pcdOrientC(
      SepsetMap set, IndependenceTest test, Knowledge knowledge, Graph graph) {
    // Message spelling fixed ("Staring" -> "Starting") to match the sibling orientation methods.
    TetradLogger.getInstance().log("info", "Starting Collider Orientation:");

    List<Node> nodes = graph.getNodes();

    for (Node y : nodes) {
      List<Node> adjacentNodes = graph.getAdjacentNodes(y);

      if (adjacentNodes.size() < 2) {
        continue;
      }

      ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2);
      int[] combination;

      while ((combination = cg.next()) != null) {
        Node x = adjacentNodes.get(combination[0]);
        Node z = adjacentNodes.get(combination[1]);

        // Skip triples that are shielded.
        if (graph.isAdjacentTo(x, z)) {
          continue;
        }

        List<Node> sepset = set.get(x, z);

        if (sepset == null) {
          continue;
        }

        // The sepset plus the middle node, used for the split-determinism comparison below.
        List<Node> augmentedSet = new LinkedList<Node>(sepset);
        augmentedSet.add(y);

        // y is determined by the sepset: do not orient.
        if (test.determines(sepset, y)) {
          continue;
        }

        // Adding y is what makes the set split-determine (x, z): do not orient.
        if (!test.splitDetermines(sepset, x, z) && test.splitDetermines(augmentedSet, x, z)) {
          continue;
        }

        if (!isArrowpointAllowed(x, y, knowledge) || !isArrowpointAllowed(z, y, knowledge)) {
          continue;
        }

        graph.setEndpoint(x, y, Endpoint.ARROW);
        graph.setEndpoint(z, y, Endpoint.ARROW);

        TetradLogger.getInstance()
            .log("colliderOriented", SearchLogUtils.colliderOrientedMsg(x, y, z));
      }
    }

    TetradLogger.getInstance().log("info", "Finishing Collider Orientation.");
  }
  /**
   * Meek's rule R3. If a--b, a--c, a--d, c-->b and d-->b, with c and d not adjacent, then orient
   * a-->b (when the arrowpoint is allowed by the knowledge).
   *
   * @param graph the graph to orient; modified in place
   * @param knowledge background knowledge consulted through {@code isArrowpointAllowed}
   * @return {@code true} if at least one edge was oriented
   */
  public static boolean meekR3(Graph graph, Knowledge knowledge) {
    boolean oriented = false;

    for (Node a : graph.getNodes()) {
      List<Node> neighbors = graph.getAdjacentNodes(a);

      // The rule needs b plus two further neighbours c and d.
      if (neighbors.size() < 3) {
        continue;
      }

      for (Node b : neighbors) {
        List<Node> rest = new LinkedList<Node>(neighbors);
        rest.remove(b);

        if (graph.isUndirectedFromTo(a, b)) {
          ChoiceGenerator pairs = new ChoiceGenerator(rest.size(), 2);
          int[] pair;

          while ((pair = pairs.next()) != null) {
            Node c = rest.get(pair[0]);
            Node d = rest.get(pair[1]);

            // All structural premises of R3, checked in the same order as before.
            boolean premisesHold =
                !graph.isAdjacentTo(c, d)
                    && graph.isUndirectedFromTo(a, c)
                    && graph.isUndirectedFromTo(a, d)
                    && graph.isDirectedFromTo(c, b)
                    && graph.isDirectedFromTo(d, b);

            if (premisesHold && isArrowpointAllowed(a, b, knowledge)) {
              graph.setEndpoint(a, b, Endpoint.ARROW);
              TetradLogger.getInstance()
                  .edgeOriented(SearchLogUtils.edgeOrientedMsg("Meek R3", graph.getEdge(a, b)));
              oriented = true;
              // a-->b is now oriented; no need to try further (c, d) pairs for this b.
              break;
            }
          }
        }
      }
    }

    return oriented;
  }
示例#7
0
  /**
   * Stores, as the sepset of x and y, the first subset of the given sepSet (tried in order of
   * increasing size, starting with the empty set) that makes x and y independent.
   *
   * <p>If no subset passes the independence test, the stored sepsets are left untouched. An
   * exception thrown by the independence test is treated as "not independent". This method itself
   * does not consult background knowledge and does not remove any edge; it is assumed that
   * background knowledge has been considered before calling it (for example, setting independent1
   * and independent2 in ruleR0_RFCI).
   *
   * @param sepSet the superset from which candidate conditioning sets are drawn
   * @param x first node
   * @param y second node
   */
  private void setMinSepSet(List<Node> sepSet, Node x, Node y) {
    List<Node> empty = Collections.emptyList();

    // First try the empty conditioning set.
    boolean marginallyIndependent;
    try {
      marginallyIndependent = independenceTest.isIndependent(x, y, empty);
    } catch (Exception e) {
      marginallyIndependent = false;
    }

    if (marginallyIndependent) {
      getSepsets().set(x, y, empty);
      return;
    }

    // Then every non-empty subset, smallest first.
    final int n = sepSet.size();

    for (int size = 1; size <= n; size++) {
      ChoiceGenerator subsets = new ChoiceGenerator(n, size);

      for (int[] subset = subsets.next(); subset != null; subset = subsets.next()) {
        List<Node> candidate = GraphUtils.asList(subset, sepSet);

        boolean separated;
        try {
          separated = independenceTest.isIndependent(x, y, candidate);
        } catch (Exception e) {
          separated = false;
        }

        if (!separated) {
          continue;
        }

        getSepsets().set(x, y, candidate);
        return;
      }
    }
  }
  /**
   * Searches the combined neighbourhoods of x and z for a conditioning set that separates x from
   * z, ignoring sets that contain y or that determine y.
   *
   * <p>Candidate sets are drawn from the union of the nodes adjacent to x and to z (minus x and z
   * themselves) and are tried in order of increasing size, up to {@code depth} elements. A depth
   * of -1 is replaced by 1000, and the limit is always capped by the number of candidate nodes.
   *
   * @param x first endpoint of the triple
   * @param y middle node of the triple
   * @param z second endpoint of the triple
   * @param test independence test used for both the determinism and the independence checks
   * @param graph graph supplying the adjacencies
   * @param depth maximum conditioning-set size, or -1
   * @return {@code true} as soon as a qualifying set makes x and z independent, else {@code false}
   */
  public static boolean existsLocalSepsetWithoutDet(
      Node x, Node y, Node z, IndependenceTest test, Graph graph, int depth) {
    Set<Node> pool = new HashSet<Node>(graph.getAdjacentNodes(x));
    pool.addAll(graph.getAdjacentNodes(z));
    pool.remove(x);
    pool.remove(z);

    List<Node> candidates = new LinkedList<Node>(pool);
    TetradLogger.getInstance()
        .log("adjacencies", "Adjacents for " + x + "--" + y + "--" + z + " = " + candidates);

    // The limit never exceeds candidates.size(), so every d below is a valid subset size.
    int maxSize = Math.min(depth == -1 ? 1000 : depth, candidates.size());

    for (int d = 0; d <= maxSize; d++) {
      ChoiceGenerator subsets = new ChoiceGenerator(candidates.size(), d);

      for (int[] choice = subsets.next(); choice != null; choice = subsets.next()) {
        List<Node> condSet = asList(choice, candidates);

        // Only sets that neither contain nor determine y are eligible.
        boolean eligible = !condSet.contains(y) && !test.determines(condSet, y);

        if (eligible && test.isIndependent(x, z, condSet)) {
          return true;
        }
      }
    }

    return false;
  }
  /**
   * Step C of PC; orients colliders using the specified sepsets. That is, orients an unshielded
   * b *-* a *-* c as b *-> a <-* c just in case a sepset is recorded for (b, c), a is NOT in that
   * sepset, and the knowledge allows both arrowpoints.
   *
   * @param set sepsets recorded during the adjacency search
   * @param knowledge background knowledge consulted through {@code isArrowpointAllowed}
   * @param graph the graph to orient; modified in place
   */
  public static void orientCollidersUsingSepsets(SepsetMap set, Knowledge knowledge, Graph graph) {
    TetradLogger.getInstance().log("info", "Starting Collider Orientation:");

    for (Node a : graph.getNodes()) {
      List<Node> neighbors = graph.getAdjacentNodes(a);

      if (neighbors.size() >= 2) {
        ChoiceGenerator pairs = new ChoiceGenerator(neighbors.size(), 2);

        for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
          Node b = neighbors.get(pair[0]);
          Node c = neighbors.get(pair[1]);

          // Shielded triples are never oriented here.
          if (graph.isAdjacentTo(b, c)) {
            continue;
          }

          // A collider requires a recorded sepset that leaves out the middle node.
          List<Node> sepset = set.get(b, c);
          if (sepset == null || sepset.contains(a)) {
            continue;
          }

          if (!isArrowpointAllowed(b, a, knowledge) || !isArrowpointAllowed(c, a, knowledge)) {
            continue;
          }

          graph.setEndpoint(b, a, Endpoint.ARROW);
          graph.setEndpoint(c, a, Endpoint.ARROW);
          TetradLogger.getInstance()
              .log("colliderOriented", SearchLogUtils.colliderOrientedMsg(b, a, c, sepset));
        }
      }
    }

    TetradLogger.getInstance().log("info", "Finishing Collider Orientation.");
  }
  /**
   * Meek's rule R2. If b-->a-->c and b--c, then orient b-->c (and symmetrically with b and c
   * exchanged), provided the arrowpoint is allowed by the knowledge.
   *
   * @param graph the graph to orient; modified in place
   * @param knowledge background knowledge consulted through {@code isArrowpointAllowed}
   * @return {@code true} if at least one edge was oriented
   */
  public static boolean meekR2(Graph graph, Knowledge knowledge) {
    List<Node> nodes = graph.getNodes();
    boolean changed = false;

    for (Node a : nodes) {
      List<Node> adjacentNodes = graph.getAdjacentNodes(a);

      if (adjacentNodes.size() < 2) {
        continue;
      }

      ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2);
      int[] combination;

      while ((combination = cg.next()) != null) {
        Node b = adjacentNodes.get(combination[0]);
        Node c = adjacentNodes.get(combination[1]);

        if (graph.isDirectedFromTo(b, a)
            && graph.isDirectedFromTo(a, c)
            && graph.isUndirectedFromTo(b, c)) {
          if (isArrowpointAllowed(b, c, knowledge)) {
            graph.setEndpoint(b, c, Endpoint.ARROW);
            TetradLogger.getInstance()
                .edgeOriented(SearchLogUtils.edgeOrientedMsg("Meek R2", graph.getEdge(b, c)));
            // Report the orientation to the caller; without this the method always returned false.
            changed = true;
          }
        } else if (graph.isDirectedFromTo(c, a)
            && graph.isDirectedFromTo(a, b)
            && graph.isUndirectedFromTo(c, b)) {
          if (isArrowpointAllowed(c, b, knowledge)) {
            graph.setEndpoint(c, b, Endpoint.ARROW);
            TetradLogger.getInstance()
                .edgeOriented(SearchLogUtils.edgeOrientedMsg("Meek R2", graph.getEdge(c, b)));
            changed = true;
          }
        }
      }
    }

    return changed;
  }
  /**
   * Orients colliders using local independence checks rather than stored sepsets.
   *
   * <p>For every visited node a and every pair (b, c) of its non-adjacent neighbours, the triple
   * is oriented as b-->a<--c when both arrowpoints pass {@code isArrowpointAllowed1} and
   * {@code existsLocalSepsetWith(b, a, c, test, graph, depth)} returns false.
   *
   * @param knowledge background knowledge consulted through {@code isArrowpointAllowed1}
   * @param graph the graph to orient; modified in place
   * @param test independence test passed on to {@code existsLocalSepsetWith}
   * @param depth depth limit passed on to {@code existsLocalSepsetWith}
   * @param nodesToVisit the candidate middle nodes; {@code null} means all nodes of the graph
   */
  public static void orientCollidersLocally(
      Knowledge knowledge, Graph graph, IndependenceTest test, int depth, Set<Node> nodesToVisit) {
    TetradLogger.getInstance().log("info", "Starting Collider Orientation:");

    Set<Node> centers =
        (nodesToVisit != null) ? nodesToVisit : new HashSet<Node>(graph.getNodes());

    for (Node a : centers) {
      List<Node> neighbors = graph.getAdjacentNodes(a);

      if (neighbors.size() >= 2) {
        ChoiceGenerator pairs = new ChoiceGenerator(neighbors.size(), 2);

        for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
          Node b = neighbors.get(pair[0]);
          Node c = neighbors.get(pair[1]);

          // Shielded triples are never oriented here.
          if (graph.isAdjacentTo(b, c)) {
            continue;
          }

          boolean allowed =
              isArrowpointAllowed1(b, a, knowledge) && isArrowpointAllowed1(c, a, knowledge);

          // The sepset search runs only when both arrowpoints are permitted.
          if (!allowed || existsLocalSepsetWith(b, a, c, test, graph, depth)) {
            continue;
          }

          graph.setEndpoint(b, a, Endpoint.ARROW);
          graph.setEndpoint(c, a, Endpoint.ARROW);
          TetradLogger.getInstance()
              .log("colliderOriented", SearchLogUtils.colliderOrientedMsg(b, a, c));
        }
      }
    }

    TetradLogger.getInstance().log("info", "Finishing Collider Orientation.");
  }
  /**
   * Finds clusters of size 4 or higher.
   *
   * <p>Repeatedly looks for a 4-element set of variable indices that is a clique in
   * {@code adjacencies} and passes {@code pure(...)} against all variables, then tries to grow it
   * with further indices from {@code _variables}. Every cluster found is removed from
   * {@code _variables} (the list is modified in place) and the search restarts on what is left.
   * The search ends when fewer than 4 indices remain or a full scan finds no cluster.
   *
   * @param _variables indices (into the {@code variables} field) still available for clustering;
   *     modified in place
   * @param adjacencies adjacency sets used for the clique checks
   * @return the clusters found, each as a set of variable indices
   */
  private Set<Set<Integer>> findPureClusters(
      List<Integer> _variables, Map<Node, Set<Node>> adjacencies) {
    //        System.out.println("Original variables = " + variables);

    Set<Set<Integer>> clusters = new HashSet<Set<Integer>>();
    // Indices of all variables, used as the reference set for the purity checks.
    List<Integer> allVariables = new ArrayList<Integer>();
    for (int i = 0; i < this.variables.size(); i++) allVariables.add(i);

    // Each iteration either finds one cluster and restarts (continue VARIABLES) or stops.
    VARIABLES:
    while (!_variables.isEmpty()) {
      if (_variables.size() < 4) break;

      for (int x : _variables) {
        Node nodeX = variables.get(x);
        // Neighbours of x that are still available for clustering.
        List<Node> adjX = new ArrayList<Node>(adjacencies.get(nodeX));
        adjX.retainAll(variablesForIndices(new ArrayList<Integer>(_variables)));

        // Drop neighbours with fewer than 3 adjacencies; iterate over a copy so adjX can be
        // modified.
        for (Node node : new ArrayList<Node>(adjX)) {
          if (adjacencies.get(node).size() < 3) {
            adjX.remove(node);
          }
        }

        if (adjX.size() < 3) {
          continue;
        }

        // Try x together with every 3-subset of its remaining neighbours.
        ChoiceGenerator gen = new ChoiceGenerator(adjX.size(), 3);
        int[] choice;

        while ((choice = gen.next()) != null) {
          Node nodeY = adjX.get(choice[0]);
          Node nodeZ = adjX.get(choice[1]);
          Node nodeW = adjX.get(choice[2]);

          int y = variables.indexOf(nodeY);
          int w = variables.indexOf(nodeW);
          int z = variables.indexOf(nodeZ);

          Set<Integer> cluster = quartet(x, y, z, w);

          if (!clique(cluster, adjacencies)) {
            continue;
          }

          // Note that purity needs to be assessed with respect to all of the variables in order to
          // remove all latent-measure impurities between pairs of latents.
          if (pure(cluster, allVariables)) {

            //                        Collections.shuffle(_variables);

            // Grow phase: tentatively add each remaining index o and keep it only if the enlarged
            // cluster is still a clique and the quartets containing o that are checked are pure.
            O:
            for (int o : _variables) {
              if (cluster.contains(o)) continue;
              cluster.add(o);
              // Snapshot of the cluster including o, used to enumerate quartets.
              List<Integer> _cluster = new ArrayList<Integer>(cluster);

              if (!clique(cluster, adjacencies)) {
                cluster.remove(o);
                continue O;
              }

              //                            if (!allVariablesDependent(cluster)) {
              //                                cluster.remove(o);
              //                                continue O;
              //                            }

              ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 4);
              int[] choice2;
              // Number of quartets containing o seen so far.
              int count = 0;

              while ((choice2 = gen2.next()) != null) {
                int x2 = _cluster.get(choice2[0]);
                int y2 = _cluster.get(choice2[1]);
                int z2 = _cluster.get(choice2[2]);
                int w2 = _cluster.get(choice2[3]);

                Set<Integer> quartet = quartet(x2, y2, z2, w2);

                // Optimizes for large clusters: once more than 50 quartets containing o have been
                // reached, o is accepted (it stays in the cluster) without further purity checks.
                if (quartet.contains(o)) {
                  if (++count > 50) continue O;
                }

                // An impure quartet containing o rejects o.
                if (quartet.contains(o) && !pure(quartet, allVariables)) {
                  cluster.remove(o);
                  continue O;
                }
              }
            }

            System.out.println(
                "Cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster)));
            clusters.add(cluster);
            _variables.removeAll(cluster);

            // _variables was just modified, so leave the for-each over it immediately and restart
            // the scan.
            continue VARIABLES;
          }
        }
      }

      // A full scan found no cluster: nothing more to find.
      break;
    }

    return clusters;
  }
  /**
   * Finds clusters of size 3.
   *
   * <p>A triple of indices from {@code remaining} is accepted when it is a clique in
   * {@code adjacencies} and, for every index x in {@code unionPure}, the quartet formed by the
   * triple plus x passes both {@code quartetVanishes} and {@code significant}. Each accepted
   * triple is added to {@code unionPure} and removed from {@code remaining} (both are modified in
   * place), after which the scan restarts on the reduced list.
   *
   * @param remaining indices not yet assigned to a cluster; modified in place
   * @param unionPure indices already in clusters; extended with every triple found
   * @param adjacencies adjacency sets used for the clique check
   * @return the 3-clusters found; empty if {@code unionPure} is empty on entry
   */
  private Set<Set<Integer>> findMixedClusters(
      List<Integer> remaining, Set<Integer> unionPure, Map<Node, Set<Node>> adjacencies) {
    if (unionPure.isEmpty()) {
      return new HashSet<Set<Integer>>();
    }

    Set<Set<Integer>> threeClusters = new HashSet<Set<Integer>>();

    // Keep rescanning as long as the previous scan produced a cluster and enough indices remain.
    boolean foundOne = true;

    while (foundOne && remaining.size() >= 3) {
      foundOne = false;

      ChoiceGenerator triples = new ChoiceGenerator(remaining.size(), 3);
      int[] pick;

      // Stop pulling combinations as soon as a cluster is found: `remaining` has changed.
      while (!foundOne && (pick = triples.next()) != null) {
        Set<Integer> cluster = new HashSet<Integer>();
        cluster.add(remaining.get(pick[0]));
        cluster.add(remaining.get(pick[1]));
        cluster.add(remaining.get(pick[2]));

        if (!clique(cluster, adjacencies)) {
          continue;
        }

        // Check all x as a cross check; really only one should be necessary.
        boolean passesForAll = true;

        for (int x : unionPure) {
          Set<Integer> withX = new HashSet<Integer>(cluster);
          withX.add(x);

          boolean ok = quartetVanishes(withX) && significant(new ArrayList<Integer>(withX));
          if (!ok) {
            passesForAll = false;
            break;
          }
        }

        if (passesForAll) {
          threeClusters.add(cluster);
          unionPure.addAll(cluster);
          remaining.removeAll(cluster);

          System.out.println(
              "3-cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster)));

          foundOne = true;
        }
      }
    }

    return threeClusters;
  }
  /**
   * Scans every triple of variables that forms a clique in the adjacency structure and records it
   * as a seed.
   *
   * <p>The adjacency structure is the complete graph when {@code depth == -2}; otherwise it is the
   * result of a FAS adjacency search with the configured depth. For each clique triple
   * (v1, v2, v3) and every other variable v4, two tetrads are tested with {@code deltaTest}; when
   * a p-value exceeds {@code alpha} the triple loses its "EPure" status, and marginal dependence
   * of v1, v2 or v3 on v4 clears the corresponding "CPure" flag. A triple for which all four flags
   * are cleared is discarded.
   *
   * <p>An EPure triple is added to {@code ESeeds}. Otherwise, for each CPure flag still set, the
   * other two members of the triple are merged into the {@code CSeeds} entry of the flagged
   * member.
   *
   * @return always {@code null}
   */
  private Void findSeeds() {
    List<Node> empty = new ArrayList<Node>();

    // NOTE(review): the original declared an unused local `ESeeds` here when
    // variables.size() < 4; it shadowed nothing that was read and has been removed. If an early
    // exit for small inputs was intended, it still needs to be added.

    Map<Node, Set<Node>> adjacencies;

    if (depth == -2) {
      // Complete graph: every variable is adjacent to every other one.
      adjacencies = new HashMap<Node, Set<Node>>();

      for (Node node : variables) {
        HashSet<Node> _nodes = new HashSet<Node>(variables);
        _nodes.remove(node);
        adjacencies.put(node, _nodes);
      }
    } else {
      Graph graph = new EdgeListGraph(variables);
      Fas fas = new Fas(graph, indTest);
      fas.setVerbose(false);
      fas.setDepth(depth); // 1?
      adjacencies = fas.searchMapOnly();
    }

    List<Integer> allVariables = new ArrayList<Integer>();
    for (int i = 0; i < variables.size(); i++) allVariables.add(i);

    log("Finding seeds.", true);

    ChoiceGenerator gen = new ChoiceGenerator(allVariables.size(), 3);
    int[] choice;
    CHOICE:
    while ((choice = gen.next()) != null) {
      int n1 = allVariables.get(choice[0]);
      int n2 = allVariables.get(choice[1]);
      int n3 = allVariables.get(choice[2]);
      Node v1 = variables.get(choice[0]);
      Node v2 = variables.get(choice[1]);
      Node v3 = variables.get(choice[2]);

      Set<Integer> triple = triple(n1, n2, n3);

      if (!clique(triple, adjacencies)) {
        continue;
      }

      boolean EPure = true;
      boolean CPure1 = true;
      boolean CPure2 = true;
      boolean CPure3 = true;

      for (int o : allVariables) {
        if (triple.contains(o)) {
          continue;
        }

        Node v4 = variables.get(o);
        Tetrad tetrad = new Tetrad(v1, v2, v3, v4);

        if (deltaTest.getPValue(tetrad) > alpha) {
          EPure = false;
          if (indTest.isDependent(v1, v4, empty)) {
            CPure1 = false;
          }
          if (indTest.isDependent(v2, v4, empty)) {
            CPure2 = false;
          }
        }
        tetrad = new Tetrad(v1, v3, v2, v4);
        if (deltaTest.getPValue(tetrad) > alpha) {
          EPure = false;
          if (indTest.isDependent(v3, v4, empty)) {
            CPure3 = false;
          }
        }

        // No flag left: this triple cannot be any kind of seed.
        if (!(EPure || CPure1 || CPure2 || CPure3)) {
          continue CHOICE;
        }
      }

      HashSet<Integer> _cluster = new HashSet<Integer>(triple);

      if (verbose) {
        log("++" + variablesForIndices(new ArrayList<Integer>(triple)), false);
      }

      if (EPure) {
        ESeeds.add(_cluster);
      } else {
        // The two-int HashSet constructor is (initialCapacity, loadFactor) and creates an EMPTY
        // set, so the members have to be added explicitly.
        if (CPure1) {
          Set<Integer> _cluster1 = new HashSet<Integer>();
          _cluster1.add(n2);
          _cluster1.add(n3);
          _cluster1.addAll(CSeeds.get(n1));
          CSeeds.set(n1, _cluster1);
        }
        if (CPure2) {
          Set<Integer> _cluster2 = new HashSet<Integer>();
          _cluster2.add(n1);
          _cluster2.add(n3);
          _cluster2.addAll(CSeeds.get(n2));
          CSeeds.set(n2, _cluster2);
        }
        if (CPure3) {
          Set<Integer> _cluster3 = new HashSet<Integer>();
          _cluster3.add(n1);
          _cluster3.add(n2);
          _cluster3.addAll(CSeeds.get(n3));
          CSeeds.set(n3, _cluster3);
        }
      }
    }
    return null;
  }
  /**
   * Classifies the triple x--y--z for conservative PC by looking at which conditioning sets make
   * x and z independent.
   *
   * <p>Two passes are made: one over subsets of the nodes adjacent to x (excluding z), one over
   * subsets of the nodes adjacent to z (excluding x). Subsets are tried up to {@code depth}
   * elements; a depth of -1 is replaced by 1000, and the limit is capped by the number of
   * candidate nodes. Every subset that makes x and z independent is noted as either containing y
   * or not containing y.
   *
   * @param x first endpoint of the triple
   * @param y middle node of the triple
   * @param z second endpoint of the triple
   * @param test independence test to use
   * @param depth maximum conditioning-set size, or -1
   * @param graph graph supplying the adjacencies
   * @return {@code AMBIGUOUS} if both kinds of separating set were seen or neither was,
   *     {@code NONCOLLIDER} if only sets containing y were seen, {@code COLLIDER} if only sets
   *     not containing y were seen
   */
  public static CpcTripleType getCpcTripleType(
      Node x, Node y, Node z, IndependenceTest test, int depth, Graph graph) {
    boolean sawSepsetWithY = false;
    boolean sawSepsetWithoutY = false;

    // Pass 0 draws from the neighbours of x, pass 1 from the neighbours of z.
    Node[] endpoints = {x, z};

    for (int pass = 0; pass < endpoints.length; pass++) {
      Node anchor = endpoints[pass];
      Node opposite = endpoints[1 - pass];

      Set<Node> neighborSet = new HashSet<Node>(graph.getAdjacentNodes(anchor));
      neighborSet.remove(opposite);

      List<Node> pool = new LinkedList<Node>(neighborSet);
      TetradLogger.getInstance()
          .log("adjacencies", "Adjacents for " + x + "--" + y + "--" + z + " = " + pool);

      int maxSize = Math.min(depth == -1 ? 1000 : depth, pool.size());

      for (int d = 0; d <= maxSize; d++) {
        ChoiceGenerator subsets = new ChoiceGenerator(pool.size(), d);

        for (int[] choice = subsets.next(); choice != null; choice = subsets.next()) {
          List<Node> condSet = GraphUtils.asList(choice, pool);

          if (!test.isIndependent(x, z, condSet)) {
            continue;
          }

          if (condSet.contains(y)) {
            sawSepsetWithY = true;
          } else {
            sawSepsetWithoutY = true;
          }
        }
      }
    }

    if (sawSepsetWithY == sawSepsetWithoutY) {
      return CpcTripleType.AMBIGUOUS;
    }

    return sawSepsetWithoutY ? CpcTripleType.COLLIDER : CpcTripleType.NONCOLLIDER;
  }
  /**
   * Variant of the pure-cluster search that tries to optimize the search for 4-cliques a bit.
   *
   * <p>Instead of enumerating 3-subsets of a node's neighbours, 4-cliques are built by nesting
   * loops over successively intersected adjacency lists (x, then y adjacent to x, then z adjacent
   * to both, then w adjacent to all three). A 4-clique that passes {@code pure(...)} is grown
   * with further indices from {@code _variables}; the nodes of every cluster found are recorded in
   * a local "found" set. {@code _variables} itself is not modified.
   *
   * @param _variables indices (into the {@code variables} field) to consider
   * @param adjacencies adjacency sets used to build and check cliques
   * @return the clusters found, each as a set of variable indices
   */
  private Set<Set<Integer>> findPureClusters2(
      List<Integer> _variables, Map<Node, Set<Node>> adjacencies) {
    System.out.println("Original variables = " + variables);

    Set<Set<Integer>> clusters = new HashSet<Set<Integer>>();
    // Indices of all variables, used as the reference set for the purity checks.
    List<Integer> allVariables = new ArrayList<Integer>();
    // Nodes that already belong to a cluster found during this call.
    Set<Node> foundVariables = new HashSet<Node>();
    for (int i = 0; i < this.variables.size(); i++) allVariables.add(i);

    for (int x : _variables) {
      Node nodeX = variables.get(x);
      if (foundVariables.contains(nodeX)) continue;

      List<Node> adjX = new ArrayList<Node>(adjacencies.get(nodeX));
      adjX.removeAll(foundVariables);

      if (adjX.size() < 3) continue;

      for (Node nodeY : adjX) {
        // foundVariables can grow while these loops run, hence the repeated membership checks.
        if (foundVariables.contains(nodeY)) continue;

        // Nodes adjacent to both x and y.
        List<Node> commonXY = new ArrayList<Node>(adjacencies.get(nodeY));
        commonXY.retainAll(adjX);
        commonXY.removeAll(foundVariables);

        for (Node nodeZ : commonXY) {
          if (foundVariables.contains(nodeZ)) continue;

          // Nodes adjacent to x, y and z.
          List<Node> commonXZ = new ArrayList<Node>(commonXY);
          commonXZ.retainAll(adjacencies.get(nodeZ));
          commonXZ.removeAll(foundVariables);

          for (Node nodeW : commonXZ) {
            if (foundVariables.contains(nodeW)) continue;

            if (!adjacencies.get(nodeY).contains(nodeW)) {
              continue;
            }

            int y = variables.indexOf(nodeY);
            int w = variables.indexOf(nodeW);
            int z = variables.indexOf(nodeZ);

            Set<Integer> cluster = quartet(x, y, z, w);

            // Note that purity needs to be assessed with respect to all of the variables in order
            // to
            // remove all latent-measure impurities between pairs of latents.
            if (pure(cluster, allVariables)) {

              // Grow phase: tentatively add each index o and keep it only if the enlarged cluster
              // is still a clique and the quartets containing o that are checked are pure.
              O:
              for (int o : _variables) {
                if (cluster.contains(o)) continue;
                cluster.add(o);

                if (!clique(cluster, adjacencies)) {
                  cluster.remove(o);
                  continue O;
                }

                //                                if (!allVariablesDependent(cluster)) {
                //                                    cluster.remove(o);
                //                                    continue O;
                //                                }

                // Snapshot of the cluster including o, used to enumerate quartets.
                List<Integer> _cluster = new ArrayList<Integer>(cluster);

                ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 4);
                int[] choice2;
                // Number of quartets containing o seen so far.
                int count = 0;

                while ((choice2 = gen2.next()) != null) {
                  int x2 = _cluster.get(choice2[0]);
                  int y2 = _cluster.get(choice2[1]);
                  int z2 = _cluster.get(choice2[2]);
                  int w2 = _cluster.get(choice2[3]);

                  Set<Integer> quartet = quartet(x2, y2, z2, w2);

                  // Optimizes for large clusters: once more than 2 quartets containing o have been
                  // reached, o is accepted (it stays in the cluster) without further purity
                  // checks.
                  if (quartet.contains(o)) {
                    if (++count > 2) continue O;
                  }

                  // An impure quartet containing o rejects o.
                  if (quartet.contains(o) && !pure(quartet, allVariables)) {
                    cluster.remove(o);
                    continue O;
                  }
                }
              }

              System.out.println(
                  "Cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster)));
              clusters.add(cluster);
              // NOTE(review): the enclosing loops keep running after a cluster is recorded, and
              // nodeX/nodeY/nodeZ are not re-checked against foundVariables inside the innermost
              // loop -- confirm that overlapping clusters cannot result.
              foundVariables.addAll(variablesForIndices(new ArrayList<Integer>(cluster)));
            }
          }
        }
      }
    }

    return clusters;
  }
  /**
   * Grows every effect seed in {@code ESeeds} into a larger cluster of variable indices and then
   * chooses a collection of pairwise disjoint grown clusters, giving preference to larger ones.
   *
   * <p>The grow procedure is selected by {@code algType}:
   *
   * <ul>
   *   <li>{@code lax}: a candidate index is added when at least as many of the triples it forms
   *       with pairs of current cluster members are contained in {@code ESeeds} as are not.
   *   <li>{@code laxWithSpeedup}: same acceptance rule, but iterating over a snapshot of the seeds
   *       and discarding absorbed seeds with {@code containsAll}.
   *   <li>{@code strict}: a candidate index is added only when every such triple is contained in
   *       {@code ESeeds}.
   * </ul>
   *
   * <p>NOTE(review): the inline comments in the previous version labelled the {@code lax} branch
   * "with speedup" and the {@code laxWithSpeedup} branch "without speedup". The branch-to-enum
   * mapping is kept exactly as it was; confirm which naming is intended.
   *
   * @param ESeeds the effect seeds, each a set of variable indices. This set is modified in place:
   *     seeds absorbed by a grown cluster are removed from it.
   * @return the selected grown clusters; no variable index occurs in more than one of them.
   */
  private Set<Set<Integer>> finishESeeds(Set<Set<Integer>> ESeeds) {
    log("Growing Effect Seeds.", true);
    Set<Set<Integer>> grown = new HashSet<Set<Integer>>();

    List<Integer> _variables = new ArrayList<Integer>();
    for (int i = 0; i < variables.size(); i++) {
      _variables.add(i);
    }

    if (algType == AlgType.lax) {
      growESeedsLax(ESeeds, _variables, grown);
    }

    if (algType == AlgType.laxWithSpeedup) {
      growESeedsLaxWithSpeedup(ESeeds, _variables, grown);
    }

    if (algType == AlgType.strict) {
      growESeedsStrict(ESeeds, _variables, grown);
    }

    log("Choosing among grown Effect Clusters.", true);

    // The sorted copies exist only for logging, so skip the work entirely when not verbose.
    if (verbose) {
      for (Set<Integer> l : grown) {
        List<Integer> _l = new ArrayList<Integer>(l);
        Collections.sort(_l);
        log("Grown: " + variablesForIndices(_l), false);
      }
    }

    return pickDisjointESeedClusters(grown);
  }

  /**
   * Grow phase for {@code AlgType.lax}: repeatedly takes a seed from {@code ESeeds}, grows it by
   * majority vote over the triples a candidate forms with pairs of cluster members, and removes
   * from {@code ESeeds} every triple contained in the grown cluster.
   */
  private void growESeedsLax(
      Set<Set<Integer>> ESeeds, List<Integer> _variables, Set<Set<Integer>> grown) {
    // Scratch set reused for lookups only; it is never stored in ESeeds.
    Set<Integer> t = new HashSet<Integer>();
    int count = 0;
    int total = ESeeds.size();

    while (!ESeeds.isEmpty()) {
      Set<Integer> cluster = ESeeds.iterator().next();
      Set<Integer> _cluster = new HashSet<Integer>(cluster);

      if (extraShuffle) {
        Collections.shuffle(_variables);
      }

      for (int o : _variables) {
        if (_cluster.contains(o)) {
          continue;
        }

        List<Integer> _cluster2 = new ArrayList<Integer>(_cluster);
        int rejected = 0;
        int accepted = 0;

        ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2);
        int[] choice;

        while ((choice = gen.next()) != null) {
          t.clear();
          t.add(_cluster2.get(choice[0]));
          t.add(_cluster2.get(choice[1]));
          t.add(o);

          if (ESeeds.contains(t)) {
            accepted++;
          } else {
            rejected++;
          }
        }

        // Ties count in favour of adding the candidate.
        if (rejected <= accepted) {
          _cluster.add(o);
        }
      }

      removeESeedTriplesWithin(_cluster, ESeeds);

      // Guarantees progress: a seed that is not exactly three elements is not matched by the
      // triple removal above and would otherwise be picked again forever.
      ESeeds.remove(cluster);

      if (verbose) {
        System.out.println(
            "Grown "
                + (++count)
                + " of "
                + total
                + ": "
                + variablesForIndices(new ArrayList<Integer>(_cluster)));
      }
      grown.add(_cluster);
    }
  }

  /**
   * Grow phase for {@code AlgType.laxWithSpeedup}: grows every seed in a snapshot of {@code
   * ESeeds} by the same majority vote as the lax phase, and after each one removes from {@code
   * ESeeds} all seeds fully contained in the grown cluster. Because the loop runs over the
   * snapshot, seeds removed from {@code ESeeds} earlier are still visited.
   */
  private void growESeedsLaxWithSpeedup(
      Set<Set<Integer>> ESeeds, List<Integer> _variables, Set<Set<Integer>> grown) {
    int count = 0;
    int total = ESeeds.size();

    for (Set<Integer> cluster : new HashSet<Set<Integer>>(ESeeds)) {
      Set<Integer> _cluster = new HashSet<Integer>(cluster);

      if (extraShuffle) {
        Collections.shuffle(_variables);
      }

      for (int o : _variables) {
        if (_cluster.contains(o)) {
          continue;
        }

        List<Integer> _cluster2 = new ArrayList<Integer>(_cluster);
        int rejected = 0;
        int accepted = 0;

        ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2);
        int[] choice;

        while ((choice = gen.next()) != null) {
          int n1 = _cluster2.get(choice[0]);
          int n2 = _cluster2.get(choice[1]);

          if (ESeeds.contains(triple(n1, n2, o))) {
            accepted++;
          } else {
            rejected++;
          }
        }

        // Ties count in favour of adding the candidate.
        if (rejected <= accepted) {
          _cluster.add(o);
        }
      }

      // Iterate over a copy so that removing from ESeeds is safe.
      for (Set<Integer> c : new HashSet<Set<Integer>>(ESeeds)) {
        if (_cluster.containsAll(c)) {
          ESeeds.remove(c);
        }
      }

      if (verbose) {
        System.out.println("Grown " + (++count) + " of " + total + ": " + _cluster);
      }

      grown.add(_cluster);
    }
  }

  /**
   * Grow phase for {@code AlgType.strict}: repeatedly takes a seed from {@code ESeeds}, adds a
   * candidate only if every triple it forms with a pair of cluster members is in {@code ESeeds},
   * and removes from {@code ESeeds} every triple contained in the grown cluster.
   */
  private void growESeedsStrict(
      Set<Set<Integer>> ESeeds, List<Integer> _variables, Set<Set<Integer>> grown) {
    // Scratch set reused for lookups only; it is never stored in ESeeds.
    Set<Integer> t = new HashSet<Integer>();
    int count = 0;
    int total = ESeeds.size();

    while (!ESeeds.isEmpty()) {
      Set<Integer> cluster = ESeeds.iterator().next();
      Set<Integer> _cluster = new HashSet<Integer>(cluster);

      if (extraShuffle) {
        Collections.shuffle(_variables);
      }

      VARIABLES:
      for (int o : _variables) {
        if (_cluster.contains(o)) {
          continue;
        }

        List<Integer> _cluster2 = new ArrayList<Integer>(_cluster);

        ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2);
        int[] choice;

        while ((choice = gen.next()) != null) {
          t.clear();
          t.add(_cluster2.get(choice[0]));
          t.add(_cluster2.get(choice[1]));
          t.add(o);

          // A single missing triple disqualifies the candidate.
          if (!ESeeds.contains(t)) {
            continue VARIABLES;
          }
        }

        _cluster.add(o);
      }

      removeESeedTriplesWithin(_cluster, ESeeds);

      // Guarantees progress: a seed that is not exactly three elements is not matched by the
      // triple removal above and would otherwise be picked again forever.
      ESeeds.remove(cluster);

      if (verbose) {
        System.out.println("Grown " + (++count) + " of " + total + ": " + _cluster);
      }
      grown.add(_cluster);
    }
  }

  /** Removes from {@code ESeeds} every three-element subset of {@code _cluster}. */
  private void removeESeedTriplesWithin(Set<Integer> _cluster, Set<Set<Integer>> ESeeds) {
    ChoiceGenerator gen = new ChoiceGenerator(_cluster.size(), 3);
    List<Integer> members = new ArrayList<Integer>(_cluster);
    Set<Integer> t = new HashSet<Integer>();
    int[] choice;

    while ((choice = gen.next()) != null) {
      t.clear();
      t.add(members.get(choice[0]));
      t.add(members.get(choice[1]));
      t.add(members.get(choice[2]));

      ESeeds.remove(t);
    }
  }

  /**
   * Pick phase: visits the grown clusters from largest to smallest and keeps each one that shares
   * no index with the clusters already kept.
   */
  private Set<Set<Integer>> pickDisjointESeedClusters(Set<Set<Integer>> grown) {
    List<Set<Integer>> list = new ArrayList<Set<Integer>>(grown);

    Collections.sort(
        list,
        new Comparator<Set<Integer>>() {
          @Override
          public int compare(Set<Integer> o1, Set<Integer> o2) {
            // Descending by size. Sizes are non-negative, so the subtraction cannot overflow.
            return o2.size() - o1.size();
          }
        });

    Set<Set<Integer>> out = new HashSet<Set<Integer>>();
    Set<Integer> all = new HashSet<Integer>();

    for (Set<Integer> cluster : list) {
      if (!Collections.disjoint(cluster, all)) {
        continue;
      }

      out.add(cluster);
      all.addAll(cluster);
    }

    return out;
  }
示例#18
0
  /**
   * Derives the variance of an error term from the current parameter values. The error node is
   * looked up by name in {@code semGraph}, and its first child is taken as the variable it feeds.
   * The result is 1 minus the sum of (a) the squared coefficient of every other parent of that
   * child and (b) for every pair of the child's parents, twice the product of their coefficients
   * (an error parent counts as coefficient 1) and the summed trek products between them.
   *
   * @param error An error term in the model--i.e. a variable with NodeType.ERROR.
   * @return The value of the error variance, or Double.NaN if the computed value is not positive.
   */
  private double calculateErrorVarianceFromParams(Node error) {
    Node errorNode = semGraph.getNode(error.getName());

    Node target = semGraph.getChildren(errorNode).get(0);
    List<Node> targetParents = semGraph.getParents(target);

    double explained = 0;

    // Contribution of each non-error parent on its own.
    for (Node parent : targetParents) {
      if (parent != errorNode) {
        double weight = getEdgeCoefficient(parent, target);
        explained += weight * weight;
      }
    }

    // Contribution of every pair of parents through the treks that connect them.
    if (targetParents.size() >= 2) {
      ChoiceGenerator pairs = new ChoiceGenerator(targetParents.size(), 2);

      for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
        Node first = targetParents.get(pair[0]);
        Node second = targetParents.get(pair[1]);

        double firstCoef =
            first.getNodeType() == NodeType.ERROR ? 1.0 : getEdgeCoefficient(first, target);
        double secondCoef =
            second.getNodeType() == NodeType.ERROR ? 1.0 : getEdgeCoefficient(second, target);

        double covariance = 0.0;

        for (List<Node> trek : GraphUtils.treksIncludingBidirected(semGraph, first, second)) {
          double trekProduct = 1.0;

          // Multiply the factors of consecutive edges along the trek.
          for (int i = 0; i + 1 < trek.size(); i++) {
            Node from = trek.get(i);
            Node to = trek.get(i + 1);

            Edge link = semGraph.getEdge(from, to);
            double weight;

            if (Edges.isBidirectedEdge(link)) {
              weight = edgeParameters.get(link);
            } else if (edgeParameters.containsKey(link)) {
              // Directed edge with a parameter: read the coefficient in the parent -> child sense.
              weight =
                  semGraph.isParentOf(from, to)
                      ? getEdgeCoefficient(from, to)
                      : getEdgeCoefficient(to, from);
            } else {
              // Edges without a stored parameter contribute a neutral factor.
              weight = 1;
            }

            trekProduct *= weight;
          }

          covariance += trekProduct;
        }

        explained += 2 * firstCoef * secondCoef * covariance;
      }
    }

    double remaining = 1.0 - explained;
    return remaining <= 0 ? Double.NaN : remaining;
  }