/**
   * Transforms a maximally directed pattern (PDAG) represented in graph <code>g</code> into an
   * arbitrary DAG by modifying <code>g</code> itself. Based on the algorithm described in
   * Chickering (2002) "Optimal structure identification with greedy search" Journal of Machine
   * Learning Research (the consistent-extension procedure of Dor and Tarsi). R. Silva, June 2004
   *
   * <p>All undirected (tail-tail) edges are removed from <code>g</code> up front; directed edges are
   * kept as they are. A working copy of the input is then searched repeatedly, in node order, for a
   * node x that (1) has no children in the working copy and (2) whose every undirected neighbor y
   * is adjacent to every other node adjacent to x. For each such neighbor y the edge y --> x is
   * added to <code>g</code>, and x is deleted from the working copy.
   *
   * <p>Condition (2) does not require the parents of x to be adjacent to one another. Requiring the
   * whole neighborhood of x to be a clique would reject valid sinks such as x in {@code a --> x <--
   * b, x --- c, a --- c, b --- c}.
   *
   * <p>If no remaining node satisfies both conditions (the input has no consistent extension), the
   * last node still under consideration is dropped from the candidate list without being processed.
   * Its undirected edges then appear in <code>g</code> only if a later-processed neighbor orients
   * them, so some undirected edges of the input may be missing from the result.
   *
   * @param g the PDAG to convert; modified in place.
   */
  public static void pdagToDag(Graph g) {
    Graph p = new EdgeListGraph(g);
    List<Edge> undirectedEdges = new ArrayList<Edge>();

    for (Edge edge : g.getEdges()) {
      if (edge.getEndpoint1() == Endpoint.TAIL
          && edge.getEndpoint2() == Endpoint.TAIL
          && !undirectedEdges.contains(edge)) {
        undirectedEdges.add(edge);
      }
    }
    g.removeEdges(undirectedEdges);
    List<Node> pNodes = p.getNodes();

    do {
      Node x = null;

      for (Node pNode : pNodes) {
        x = pNode;

        // Condition 1: x must be a sink of the remaining graph.
        if (p.getChildren(x).size() > 0) {
          continue;
        }

        // Collect the nodes joined to x by an undirected (tail-tail) edge.
        Set<Node> neighbors = new HashSet<Node>();

        for (Edge edge : p.getEdges()) {
          if (edge.getNode1() == x || edge.getNode2() == x) {
            if (edge.getEndpoint1() == Endpoint.TAIL && edge.getEndpoint2() == Endpoint.TAIL) {
              neighbors.add(edge.getNode1() == x ? edge.getNode2() : edge.getNode1());
            }
          }
        }

        // Condition 2: every undirected neighbor of x must be adjacent to all other adjacents of x.
        // Because x is a sink, its adjacents are taken to be its undirected neighbors plus its
        // parents (assumes the graph holds only directed and undirected edges).
        Set<Node> adjacents = new HashSet<Node>(neighbors);
        adjacents.addAll(p.getParents(x));
        boolean orientable = true;

        for (Node y : neighbors) {
          for (Node z : adjacents) {
            if (!y.equals(z) && !p.isAdjacentTo(y, z)) {
              orientable = false;
              break;
            }
          }

          if (!orientable) {
            break;
          }
        }

        if (!orientable) {
          continue;
        }

        // Orient every undirected neighbor into x, using g's own node objects.
        for (Node neighbor : neighbors) {
          Node node1 = g.getNode(neighbor.getName());
          Node node2 = g.getNode(x.getName());

          g.addDirectedEdge(node1, node2);
        }
        p.removeNode(x);
        break;
      }

      // Removes the processed node. If no node qualified, this removes the last node examined,
      // which guarantees termination.
      pNodes.remove(x);
    } while (pNodes.size() > 0);
  }
// Example No. 2
// 0
 /**
  * Checks a local Markov condition for x. It asks whether x is independent of y given x's
  * boundary. The boundary is the set of nodes adjacent to x in <code>pattern</code> that are not
  * among the nodes returned by <code>pattern.getDescendants</code> for x.
  *
  * @param x the node whose local Markov condition is checked.
  * @param y the node tested against x.
  * @param pattern the graph supplying adjacencies and descendants.
  * @param test the independence test to consult.
  * @return false, without consulting the test, when y is among x's descendants or adjacent to x;
  *     otherwise the result of <code>test.isIndependent(x, y, boundary)</code>.
  */
 private boolean localMarkovIndep(Node x, Node y, Graph pattern, IndependenceTest test) {
   List<Node> future = pattern.getDescendants(Collections.singletonList(x));

   // Copy before pruning so the graph's own adjacency list can never be modified through it.
   List<Node> boundary = new ArrayList<>(pattern.getAdjacentNodes(x));
   boundary.removeAll(future);

   if (future.contains(y) || boundary.contains(y)) {
     return false;
   }

   return test.isIndependent(x, y, boundary);
 }
  /**
   * Applies Meek's orientation rule R3 to the graph.
   *
   * <p>For an undirected edge a--b, if there are two nodes c and d with a--c and a--d undirected,
   * c-->b and d-->b, and c not adjacent to d, then the edge is oriented a-->b. This happens only if
   * <code>isArrowpointAllowed(a, b, knowledge)</code> permits an arrowhead at b. Each orientation
   * is logged. Only nodes with at least three adjacents are considered in the role of a.
   *
   * @param graph the graph to orient; modified in place.
   * @param knowledge background knowledge consulted through <code>isArrowpointAllowed</code>.
   * @return true if at least one edge was oriented.
   */
  public static boolean meekR3(Graph graph, Knowledge knowledge) {
    boolean oriented = false;

    for (Node a : graph.getNodes()) {
      List<Node> aNeighbors = graph.getAdjacentNodes(a);

      // R3 needs b, c and d all adjacent to a.
      if (aNeighbors.size() < 3) {
        continue;
      }

      for (Node b : aNeighbors) {
        if (!graph.isUndirectedFromTo(a, b)) {
          continue;
        }

        List<Node> rest = new LinkedList<Node>(aNeighbors);
        rest.remove(b);

        ChoiceGenerator pairs = new ChoiceGenerator(rest.size(), 2);

        for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
          Node c = rest.get(pair[0]);
          Node d = rest.get(pair[1]);

          boolean matchesR3 =
              !graph.isAdjacentTo(c, d)
                  && graph.isUndirectedFromTo(a, c)
                  && graph.isUndirectedFromTo(a, d)
                  && graph.isDirectedFromTo(c, b)
                  && graph.isDirectedFromTo(d, b);

          if (matchesR3 && isArrowpointAllowed(a, b, knowledge)) {
            graph.setEndpoint(a, b, Endpoint.ARROW);
            TetradLogger.getInstance()
                .edgeOriented(SearchLogUtils.edgeOrientedMsg("Meek R3", graph.getEdge(a, b)));
            oriented = true;
            break;
          }
        }
      }
    }

    return oriented;
  }
// Example No. 4
// 0
  /**
   * Checks that d-separation is symmetric on a random DAG. For every pair of distinct nodes x, y
   * and every conditioning set z drawn from the remaining nodes, "x d-separated from y given z"
   * must agree with "y d-separated from x given z".
   */
  public void testDSeparation() {
    Dag dag = new Dag(GraphUtils.randomGraph(7, 0, 7, 30, 15, 15, true));
    EdgeListGraphSingleConnections graph = new EdgeListGraphSingleConnections(dag);
    System.out.println(graph);

    List<Node> nodes = graph.getNodes();

    // Depth passed to DepthChoiceGenerator; presumably -1 means "no size limit" -- confirm.
    int maxDepth = -1;

    for (int i = 0; i < nodes.size(); i++) {
      Node x = nodes.get(i);

      for (int j = i + 1; j < nodes.size(); j++) {
        Node y = nodes.get(j);

        List<Node> others = new ArrayList<Node>(nodes);
        others.remove(x);
        others.remove(y);

        DepthChoiceGenerator subsets = new DepthChoiceGenerator(others.size(), maxDepth);

        for (int[] subset = subsets.next(); subset != null; subset = subsets.next()) {
          List<Node> z = new ArrayList<Node>(subset.length);
          for (int index : subset) {
            z.add(others.get(index));
          }

          boolean forward = graph.isDSeparatedFrom(x, y, z);
          boolean backward = graph.isDSeparatedFrom(y, x, z);

          if (forward != backward) {
            fail(
                SearchLogUtils.independenceFact(x, y, z)
                    + " should have same d-sep result as "
                    + SearchLogUtils.independenceFact(y, x, z));
          }
        }
      }
    }
  }
// Example No. 5
// 0
  /**
   * Re-decides the edge between x and y and replaces it in <code>graph</code>. It searches over
   * conditioning sets drawn from the current neighbors of x and of y and keeps the best-scoring
   * orientation found, or an undirected edge if no orientation is supported.
   *
   * <p>The endpoints are first swapped when <code>RandomUtil.getInstance().nextDouble()</code>
   * exceeds 0.5. The method then loops over every subset of x's other neighbors and every subset
   * of y's other neighbors. For each combination it scores each endpoint with and without the other
   * endpoint added to its conditioning set. Combinations for which <code>normal(...)</code> holds
   * for any of the four sets are skipped.
   *
   * <p>Evidence for y --> x (or x --> y) always counts in weak mode. In strong mode it counts only
   * if the additional "Strong" inequalities also hold; otherwise it is reported as "No directed
   * edge". The counted direction with the highest <code>combinedScore</code> wins. All reports are
   * logged, highest score first, before the graph is updated; reports with equal scores are kept
   * together rather than overwriting one another.
   *
   * @param graph the graph whose x--y adjacency is replaced in place.
   * @param x one endpoint of the edge.
   * @param y the other endpoint of the edge.
   * @param strong whether the stricter "Strong" conditions are required to orient.
   * @param oldGraph currently unused.
   */
  private void resolveOneEdgeMax(Graph graph, Node x, Node y, boolean strong, Graph oldGraph) {
    if (RandomUtil.getInstance().nextDouble() > 0.5) {
      Node temp = x;
      x = y;
      y = temp;
    }

    TetradLogger.getInstance().log("info", "\nEDGE " + x + " --- " + y);

    // Keyed by -score so that iteration yields the best score first.
    SortedMap<Double, String> scoreReports = new TreeMap<Double, String>();

    List<Node> neighborsx = graph.getAdjacentNodes(x);
    neighborsx.remove(y);

    // graph is not modified until the search is over, so y's neighbors are computed only once.
    List<Node> neighborsy = graph.getAdjacentNodes(y);
    neighborsy.remove(x);

    // Tolerance applied to every score comparison below.
    final double delta = 0.0;
    final String noEdge = "No directed edge " + x + "--" + y;

    double max = Double.NEGATIVE_INFINITY;
    boolean left = false; // best so far: y --> x
    boolean right = false; // best so far: x --> y

    DepthChoiceGenerator genx = new DepthChoiceGenerator(neighborsx.size(), neighborsx.size());
    int[] choicex;

    while ((choicex = genx.next()) != null) {
      List<Node> condxMinus = GraphUtils.asList(choicex, neighborsx);
      List<Node> condxPlus = new ArrayList<Node>(condxMinus);
      condxPlus.add(y);

      double xPlus = score(x, condxPlus);
      double xMinus = score(x, condxMinus);

      DepthChoiceGenerator geny = new DepthChoiceGenerator(neighborsy.size(), neighborsy.size());
      int[] choicey;

      while ((choicey = geny.next()) != null) {
        List<Node> condyMinus = GraphUtils.asList(choicey, neighborsy);
        List<Node> condyPlus = new ArrayList<Node>(condyMinus);
        condyPlus.add(x);

        double yPlus = score(y, condyPlus);
        double yMinus = score(y, condyMinus);

        // Checking them all at once is expensive but avoids lexical ordering problems in the
        // algorithm.
        if (normal(y, condyPlus)
            || normal(x, condxMinus)
            || normal(x, condxPlus)
            || normal(y, condyMinus)) {
          continue;
        }

        if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) {
          // Evidence for y --> x.
          double score = combinedScore(xPlus, yMinus);
          boolean accepted = !strong || (yPlus <= yMinus + delta && xMinus <= xPlus + delta);
          String headline;

          if (!strong) {
            headline = "Weak " + y + "->" + x;
          } else if (accepted) {
            headline = "Strong " + y + "->" + x;
          } else {
            headline = noEdge;
          }

          recordScoreReport(scoreReports, headline, score, x, y, condxMinus, condyMinus);

          if (accepted && score > max) {
            max = score;
            left = true;
            right = false;
          }
        } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) {
          // Evidence for x --> y.
          double score = combinedScore(yPlus, xMinus);
          boolean accepted = !strong || (yMinus <= yPlus + delta && xPlus <= xMinus + delta);
          String headline;

          if (!strong) {
            headline = "Weak " + x + "->" + y;
          } else if (accepted) {
            headline = "Strong " + x + "->" + y;
          } else {
            headline = noEdge;
          }

          recordScoreReport(scoreReports, headline, score, x, y, condxMinus, condyMinus);

          if (accepted && score > max) {
            max = score;
            left = false;
            right = true;
          }
        } else if ((yPlus <= xPlus + delta && yMinus <= xMinus + delta)
            || (xPlus <= yPlus + delta && xMinus <= yMinus + delta)) {
          // Neither direction is supported; report only.
          // NOTE(review): both cases are scored with combinedScore(yPlus, xMinus); confirm the
          // second was not meant to use combinedScore(xPlus, yMinus).
          recordScoreReport(
              scoreReports,
              noEdge,
              combinedScore(yPlus, xMinus),
              x,
              y,
              condxMinus,
              condyMinus);
        }
      }
    }

    for (String report : scoreReports.values()) {
      TetradLogger.getInstance().log("info", report);
    }

    graph.removeEdges(x, y);

    if (left) {
      graph.addDirectedEdge(y, x);
    }

    if (right) {
      graph.addDirectedEdge(x, y);
    }

    if (!graph.isAdjacentTo(x, y)) {
      graph.addUndirectedEdge(x, y);
    }
  }

  /**
   * Stores a three-line score report under the key -score, so that the best score sorts first.
   * If a report is already stored under the same key, the new report is appended to it instead of
   * replacing it.
   */
  private static void recordScoreReport(
      SortedMap<Double, String> reports,
      String headline,
      double score,
      Node x,
      Node y,
      List<Node> condxMinus,
      List<Node> condyMinus) {
    String report =
        "\n" + headline + " " + score
            + "\n   Parents(" + x + ") = " + condxMinus
            + "\n   Parents(" + y + ") = " + condyMinus;
    String previous = reports.get(-score);
    reports.put(-score, previous == null ? report : previous + report);
  }
// Example No. 6
// 0
  /**
   * Orients colliders following RFCI Algorithm 4.4 (Colombo et al., 2012).
   *
   * <p>Phase 1 drains <code>rTuples</code>, a list of unshielded triples {i, j, k}. Triples are
   * consumed from the front, and the list is empty on return. Triples whose sepset S(i, k) is
   * unknown are skipped. For the others, the method tests i against j and j against k, each given
   * S(i, k) minus j. A test is run only when the knowledge does not require that edge, and a test
   * that throws counts as "independent".
   *
   * <p>If neither pair is independent, the triple is kept for phase 2. Otherwise, for each
   * independent pair the method:
   *
   * <ul>
   *   <li>records a minimal separating set and removes the edge;
   *   <li>queues the new unshielded triples created by the removal;
   *   <li>discards every queued or kept triple that used the removed edge.
   * </ul>
   *
   * <p>Phase 2 orients {@code i *-> j <-* k} for each kept triple whose two edges are still present,
   * whose sepset S(i, k) exists and excludes j, and whose arrowheads are allowed.
   *
   * @param rTuples unshielded triples {i, j, k} to process; consumed by this method.
   */
  private void ruleR0_RFCI(List<Node[]> rTuples) {
    List<Node[]> lTuples = new ArrayList<Node[]>();
    List<Node> nodes = graph.getNodes();

    // Phase 1: test the triples, removing edges and re-queuing triples as needed.
    while (!rTuples.isEmpty()) {
      Node[] triple = rTuples.remove(0);
      Node i = triple[0];
      Node j = triple[1];
      Node k = triple[2];

      List<Node> sepsetIk = getSepset(i, k);
      if (sepsetIk == null) {
        continue;
      }

      List<Node> condition = new ArrayList<Node>(sepsetIk);
      condition.remove(j);

      boolean dropIj = rfciEdgeRemovable(i, j, condition);
      boolean dropJk = rfciEdgeRemovable(j, k, condition);

      if (!dropIj && !dropJk) {
        lTuples.add(triple);
        continue;
      }

      // Record minimal separating sets and remove the edges found independent.
      if (dropIj) {
        setMinSepSet(condition, i, j);
        graph.removeEdge(i, j);
      }
      if (dropJk) {
        setMinSepSet(condition, j, k);
        graph.removeEdge(j, k);
      }

      // Every common neighbor of a newly separated pair forms a new unshielded triple.
      for (Node m : nodes) {
        List<Node> mAdjacents = graph.getAdjacentNodes(m);

        if (dropIj && mAdjacents.contains(i) && mAdjacents.contains(j)) {
          rTuples.add(new Node[] {i, m, j});
        }
        if (dropJk && mAdjacents.contains(j) && mAdjacents.contains(k)) {
          rTuples.add(new Node[] {j, m, k});
        }
      }

      rfciDropTriplesWithEdges(rTuples, i, j, k, dropIj, dropJk);
      rfciDropTriplesWithEdges(lTuples, i, j, k, dropIj, dropJk);
    }

    // Phase 2: orient colliders (similar to the original FCI ruleR0).
    for (Node[] triple : lTuples) {
      Node i = triple[0];
      Node j = triple[1];
      Node k = triple[2];

      List<Node> sepset = getSepset(i, k);

      if (sepset == null || sepset.contains(j)) {
        continue;
      }
      if (!graph.isAdjacentTo(i, j) || !graph.isAdjacentTo(j, k)) {
        continue;
      }
      if (!isArrowpointAllowed(i, j) || !isArrowpointAllowed(k, j)) {
        continue;
      }

      graph.setEndpoint(i, j, Endpoint.ARROW);
      graph.setEndpoint(k, j, Endpoint.ARROW);

      printWrongColliderMessage(i, j, k, "R0_RFCI");
    }
  }

  /**
   * Returns whether the edge a--b may be removed because a and b test independent given sepSet.
   * Returns false without testing when the knowledge requires the edge, and true when the test
   * itself throws.
   */
  private boolean rfciEdgeRemovable(Node a, Node b, List<Node> sepSet) {
    if (!knowledge.noEdgeRequired(a.getName(), b.getName())) {
      return false;
    }

    try {
      return independenceTest.isIndependent(a, b, sepSet);
    } catch (Exception ignored) {
      // A failing test is deliberately treated as independence (edge removed).
      // NOTE(review): this hides test errors; confirm it is intended.
      return true;
    }
  }

  /**
   * Removes from tuples every triple that contains the edge i--j (when dropIj) or the edge j--k
   * (when dropJk). A triple contains an edge when its middle node is one endpoint of the edge and
   * one of its outer nodes is the other endpoint.
   */
  private static void rfciDropTriplesWithEdges(
      List<Node[]> tuples, Node i, Node j, Node k, boolean dropIj, boolean dropJk) {
    Iterator<Node[]> it = tuples.iterator();

    while (it.hasNext()) {
      Node[] t = it.next();

      boolean hasIj =
          (t[1] == i && (t[0] == j || t[2] == j)) || (t[1] == j && (t[0] == i || t[2] == i));
      boolean hasJk =
          (t[1] == j && (t[0] == k || t[2] == k)) || (t[1] == k && (t[0] == j || t[2] == j));

      if ((dropIj && hasIj) || (dropJk && hasJk)) {
        it.remove();
      }
    }
  }
// Example No. 7
// 0
  /**
   * Sets the expression for a node, replacing any previous one. It also updates the bookkeeping of
   * which parameters and which nodes are referenced by node expressions.
   *
   * <p>Identifiers in the parsed expression that are neither variable names nor node names are
   * treated as free parameters. The bookkeeping is updated in four steps:
   *
   * <ol>
   *   <li>This node is removed from every parameter's and every node's reference set.
   *   <li>Parameters left with no referencing node are discarded together with their stored
   *       expressions.
   *   <li>Each free parameter of the new expression is registered as referenced by this node and
   *       passed to <code>setSuitableParameterDistribution</code>.
   *   <li>Every variable name gets an entry in the node-reference map. This node is recorded
   *       against a variable exactly when that variable is one of its parents in the graph.
   * </ol>
   *
   * <p>Variables used in the expression that are not parents are not rejected.
   *
   * @param node the node whose expression is set; must not be null.
   * @param expressionString the expression text; must not be null. It is stored verbatim, so the
   *     user's spacing is preserved.
   * @throws NullPointerException if node or expressionString is null.
   * @throws ParseException if the expression cannot be parsed.
   */
  public void setNodeExpression(Node node, String expressionString) throws ParseException {
    if (node == null) {
      throw new NullPointerException("Node was null.");
    }

    if (expressionString == null) {
      throw new NullPointerException("Expression string was null.");
    }

    // A ParseException is deliberately passed up: the interface needs it to report bad input.
    ExpressionParser parser = new ExpressionParser();
    Expression expression = parser.parseExpression(expressionString);
    List<String> parameterNames = parser.getParameters();

    List<String> parentNames = new LinkedList<>();
    for (Node parent : this.graph.getParents(node)) {
      parentNames.add(parent.getName());
    }

    // Whatever remains after stripping variable names and node names is a free parameter.
    parameterNames.removeAll(variableNames);
    for (Node variable : nodes) {
      // List.remove(Object) is a no-op when the name is absent.
      parameterNames.remove(variable.getName());
    }

    // Detach this node from every parameter; collect parameters that no node references anymore.
    List<String> unusedParameters = new LinkedList<>();
    for (String parameter : this.referencedParameters.keySet()) {
      Set<Node> referencers = this.referencedParameters.get(parameter);
      referencers.remove(node);

      if (referencers.isEmpty()) {
        unusedParameters.add(parameter);
      }
    }

    // NOTE(review): a parameter used only by this node is dropped here even if the new expression
    // still uses it. It is re-registered below, but any expressions stored for it are gone.
    for (String parameter : unusedParameters) {
      this.referencedParameters.remove(parameter);
      this.parameterExpressions.remove(parameter);
      this.parameterExpressionStrings.remove(parameter);
      this.parameterEstimationInitializationExpressions.remove(parameter);
      this.parameterEstimationInitializationExpressionStrings.remove(parameter);
    }

    // Register this node as a user of each free parameter of the new expression.
    for (String parameter : parameterNames) {
      Set<Node> referencers = this.referencedParameters.get(parameter);
      if (referencers == null) {
        referencers = new HashSet<Node>();
        this.referencedParameters.put(parameter, referencers);
      }

      referencers.add(node);
      setSuitableParameterDistribution(parameter);
    }

    // Detach this node from every referenced node; drop entries that become empty.
    List<Node> unreferencedNodes = new LinkedList<>();
    for (Node referenced : this.referencedNodes.keySet()) {
      Set<Node> referencers = this.referencedNodes.get(referenced);
      referencers.remove(node);

      if (referencers.isEmpty()) {
        unreferencedNodes.add(referenced);
      }
    }

    for (Node referenced : unreferencedNodes) {
      this.referencedNodes.remove(referenced);
    }

    // Ensure every variable has an entry, and record this node against each of its parents.
    for (String variableName : variableNames) {
      Node variableNode = getNode(variableName);

      Set<Node> referencers = this.referencedNodes.get(variableNode);
      if (referencers == null) {
        referencers = new HashSet<Node>();
        this.referencedNodes.put(variableNode, referencers);
      }

      if (parentNames.contains(variableName)) {
        referencers.add(node);
      }
    }

    // Save the parsed expression together with the user's original text.
    nodeExpressions.put(node, expression);
    nodeExpressionStrings.put(node, expressionString);
  }
  /**
   * Greedily extracts pure clusters of four or more variables from <code>_variables</code>.
   *
   * <p>While at least four candidates remain, the method repeats these steps:
   *
   * <ol>
   *   <li>Find the first seed quartet: a candidate x plus three of its remaining neighbors, each
   *       having at least three adjacencies, such that the quartet is a clique in
   *       <code>adjacencies</code> and passes <code>pure(...)</code> against all variable indices.
   *   <li>Grow the seed with every further candidate that keeps the cluster a clique and keeps the
   *       quartets containing the newcomer pure. Only the first 50 such quartets are examined; after
   *       that the newcomer is accepted.
   *   <li>Print and record the cluster, and remove its members from the candidates.
   * </ol>
   *
   * <p>Stops when no seed quartet can be found.
   *
   * @param _variables candidate variable indices; clustered indices are removed from this list.
   * @param adjacencies adjacency sets keyed by node.
   * @return the clusters found, as sets of variable indices.
   */
  private Set<Set<Integer>> findPureClusters(
      List<Integer> _variables, Map<Node, Set<Node>> adjacencies) {
    Set<Set<Integer>> clusters = new HashSet<Set<Integer>>();

    // Purity is judged against every variable index, so that latent-measure impurities between
    // pairs of latents are all caught, not just those among the remaining candidates.
    List<Integer> allVariables = new ArrayList<Integer>();
    for (int i = 0; i < this.variables.size(); i++) {
      allVariables.add(i);
    }

    while (_variables.size() >= 4) {
      Set<Integer> cluster = firstPureCliqueQuartet(_variables, adjacencies, allVariables);
      if (cluster == null) {
        break;
      }

      extendPureCluster(cluster, _variables, adjacencies, allVariables);

      System.out.println(
          "Cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster)));
      clusters.add(cluster);
      _variables.removeAll(cluster);
    }

    return clusters;
  }

  /**
   * Returns the first quartet, built from a candidate and three of its qualifying neighbors among
   * the candidates, that is a clique and pure with respect to allVariables; null if there is none.
   */
  private Set<Integer> firstPureCliqueQuartet(
      List<Integer> candidates, Map<Node, Set<Node>> adjacencies, List<Integer> allVariables) {
    for (int x : candidates) {
      Node nodeX = variables.get(x);
      List<Node> adjX = new ArrayList<Node>(adjacencies.get(nodeX));
      adjX.retainAll(variablesForIndices(new ArrayList<Integer>(candidates)));

      // Neighbors with fewer than three adjacencies cannot belong to a pure quartet.
      for (Node neighbor : new ArrayList<Node>(adjX)) {
        if (adjacencies.get(neighbor).size() < 3) {
          adjX.remove(neighbor);
        }
      }

      if (adjX.size() < 3) {
        continue;
      }

      ChoiceGenerator triples = new ChoiceGenerator(adjX.size(), 3);
      int[] triple;

      while ((triple = triples.next()) != null) {
        int y = variables.indexOf(adjX.get(triple[0]));
        int z = variables.indexOf(adjX.get(triple[1]));
        int w = variables.indexOf(adjX.get(triple[2]));

        Set<Integer> seed = quartet(x, y, z, w);

        if (clique(seed, adjacencies) && pure(seed, allVariables)) {
          return seed;
        }
      }
    }

    return null;
  }

  /**
   * Adds to cluster, in candidate order, every candidate that keeps the cluster a clique and keeps
   * the quartets containing the newcomer pure. After 50 such quartets the newcomer is accepted
   * without further checks.
   */
  private void extendPureCluster(
      Set<Integer> cluster,
      List<Integer> candidates,
      Map<Node, Set<Node>> adjacencies,
      List<Integer> allVariables) {
    for (int o : candidates) {
      if (cluster.contains(o)) {
        continue;
      }

      cluster.add(o);
      List<Integer> members = new ArrayList<Integer>(cluster);

      if (!clique(cluster, adjacencies)) {
        cluster.remove(o);
        continue;
      }

      ChoiceGenerator fours = new ChoiceGenerator(members.size(), 4);
      int[] four;
      int checked = 0;

      while ((four = fours.next()) != null) {
        int a = members.get(four[0]);
        int b = members.get(four[1]);
        int c = members.get(four[2]);
        int d = members.get(four[3]);

        Set<Integer> candidateQuartet = quartet(a, b, c, d);

        if (!candidateQuartet.contains(o)) {
          continue;
        }

        // Optimizes for large clusters: stop checking (and keep o) after 50 of its quartets.
        if (++checked > 50) {
          break;
        }

        if (!pure(candidateQuartet, allVariables)) {
          cluster.remove(o);
          break;
        }
      }
    }
  }