Ejemplo n.º 1
0
  /**
   * Rebuilds the candidate "insert x -> y" arrows for the pair (x, y).
   *
   * <p>Any arrows previously registered for the pair are dropped via {@code clearArrow} first. If
   * background knowledge is present and forbids the edge x -> y, nothing further is registered.
   * Otherwise every index choice produced by the {@code DepthChoiceGenerator} over the
   * T-neighbors of the pair is evaluated with {@code insertEval}, and each choice with a strictly
   * positive bump is stored both in {@code sortedArrows} and in the arrow lookup.
   *
   * @param x the node the candidate edge starts from.
   * @param y the node the candidate edge points to.
   * @param graph the graph the candidate insertions are evaluated against.
   */
  private void calculateArrowsForward(Node x, Node y, Graph graph) {
    clearArrow(x, y);

    // Knowledge is consulted only when there is any; a forbidden edge yields no arrows at all.
    if (!knowledgeEmpty() && getKnowledge().isForbidden(x.getName(), y.getName())) {
      return;
    }

    List<Node> naYX = getNaYX(x, y, graph);
    List<Node> tNeighbors = getTNeighbors(x, y, graph);

    DepthChoiceGenerator subsets = new DepthChoiceGenerator(tNeighbors.size(), tNeighbors.size());

    for (int[] indices = subsets.next(); indices != null; indices = subsets.next()) {
      List<Node> subset = GraphUtils.asList(indices, tNeighbors);

      // Skip candidate sets that the background knowledge rules out for y.
      if (!knowledgeEmpty() && !validSetByKnowledge(y, subset)) {
        continue;
      }

      double bump = insertEval(x, y, subset, naYX, graph);

      // Only insertions that strictly improve the score are worth remembering.
      if (bump > 0.0) {
        Arrow candidate = new Arrow(bump, x, y, subset, naYX);
        sortedArrows.add(candidate);
        addLookupArrow(x, y, candidate);
      }
    }
  }
Ejemplo n.º 2
0
  /**
   * Produces an oriented graph from the current pattern.
   *
   * <p>The pattern is first reduced to its undirected skeleton and an empty graph over the same
   * nodes is created. Rule R1, rule R2 and the Meek rules are then applied in that order, each
   * only if its corresponding flag ({@code isR1Done}, {@code isR2Done}, {@code isMeekDone}) is
   * set. Between R1 and R2, every skeleton adjacency that is still missing from the result is
   * added as an undirected edge.
   *
   * @return the newly built graph; the skeleton itself is not returned.
   */
  public Graph orient() {
    Graph skeleton = GraphUtils.undirectedGraph(getPattern());
    Graph oriented = new EdgeListGraph(skeleton.getNodes());

    List<Node> skeletonNodes = skeleton.getNodes();

    if (isR1Done()) {
      ruleR1(skeleton, oriented, skeletonNodes);
    }

    // Whatever adjacency R1 left untouched is carried over as an undirected edge.
    for (Edge skeletonEdge : skeleton.getEdges()) {
      if (oriented.isAdjacentTo(skeletonEdge.getNode1(), skeletonEdge.getNode2())) {
        continue;
      }

      oriented.addUndirectedEdge(skeletonEdge.getNode1(), skeletonEdge.getNode2());
    }

    if (isR2Done()) {
      ruleR2(skeleton, oriented);
    }

    if (isMeekDone()) {
      new MeekRules().orientImplied(oriented);
    }

    return oriented;
  }
Ejemplo n.º 3
0
  /**
   * Sets the graph the search should start from.
   *
   * <p>The nodes of the supplied graph are replaced by the corresponding entries of
   * {@code variables}, and the resulting node set must equal the set of {@code variables}.
   * Passing {@code null} clears any previously set initial graph; {@code search} checks the field
   * for {@code null} and then starts from a graph built from the nodes it is given.
   *
   * <p>Original note: cannot be done if the graph changes.
   *
   * @param initialGraph the starting graph, or {@code null} for no initial graph.
   * @throws IllegalArgumentException if the graph's nodes (after replacement) are not exactly the
   *     data set variables.
   */
  public void setInitialGraph(Graph initialGraph) {
    // A null argument means "no initial graph"; do not hand it to replaceNodes/getNodes.
    if (initialGraph == null) {
      this.initialGraph = null;
      return;
    }

    Graph replaced = GraphUtils.replaceNodes(initialGraph, variables);

    out.println("Initial graph variables: " + replaced.getNodes());
    out.println("Data set variables: " + variables);

    // Compare as sets: order is irrelevant, membership must match exactly.
    if (!new HashSet<Node>(replaced.getNodes()).equals(new HashSet<Node>(variables))) {
      throw new IllegalArgumentException("Variables aren't the same.");
    }

    this.initialGraph = replaced;
  }
Ejemplo n.º 4
0
  /**
   * Applies rule R1: for each node, picks the best-scoring parent set among its skeleton
   * neighbors and directs those neighbors into the node.
   *
   * <p>For every node, each index choice produced by the {@code DepthChoiceGenerator} over the
   * node's adjacent nodes in {@code skeleton} is scored with {@code score(node, parents)}; all
   * scores are logged under the "score" channel, highest first. The node is skipped when no
   * choice beat negative infinity or when {@code normal(node, parents)} holds for the winning
   * set. Otherwise a directed edge parent -> node is added to {@code graph} for each winning
   * parent, unless that exact edge is already present.
   *
   * @param skeleton the undirected graph supplying adjacencies.
   * @param graph the graph receiving the directed edges.
   * @param nodes the nodes to process, in iteration order.
   */
  private void ruleR1(Graph skeleton, Graph graph, List<Node> nodes) {
    for (Node node : nodes) {
      // Keyed by negated score so ascending key order lists the best score first. Candidate
      // sets with an identical score share a key, so only the last one is reported.
      SortedMap<Double, String> reportsByNegatedScore = new TreeMap<Double, String>();

      List<Node> adjacent = skeleton.getAdjacentNodes(node);
      DepthChoiceGenerator subsets = new DepthChoiceGenerator(adjacent.size(), adjacent.size());

      double bestScore = Double.NEGATIVE_INFINITY;
      List<Node> bestParents = null;

      for (int[] indices = subsets.next(); indices != null; indices = subsets.next()) {
        List<Node> candidate = GraphUtils.asList(indices, adjacent);
        double candidateScore = score(node, candidate);

        reportsByNegatedScore.put(-candidateScore, candidate.toString());

        if (candidateScore > bestScore) {
          bestScore = candidateScore;
          bestParents = candidate;
        }
      }

      for (Double negatedScore : reportsByNegatedScore.keySet()) {
        String parentsText = reportsByNegatedScore.get(negatedScore);
        TetradLogger.getInstance()
            .log("score", "For " + node + " parents = " + parentsText + " score = " + -negatedScore);
      }

      TetradLogger.getInstance().log("score", "");

      // Nothing to orient if no candidate won, or if the winning set passes the normal(...) check.
      if (bestParents == null || normal(node, bestParents)) {
        continue;
      }

      // Walk the adjacency list (not the winning set) so edges are added in adjacency order.
      for (Node neighbor : adjacent) {
        if (!bestParents.contains(neighbor)) {
          continue;
        }

        Edge intoNode = Edges.directedEdge(neighbor, node);

        if (!graph.containsEdge(intoNode)) {
          graph.addEdge(intoNode);
        }
      }
    }
  }
Ejemplo n.º 5
0
  /**
   * Rebuilds the candidate "delete x - y" arrows for the pair (x, y).
   *
   * <p>Nothing happens when x and y are the same object, when they are not adjacent in
   * {@code graph}, or when background knowledge is present and {@code noEdgeRequired} is false
   * for the pair. Otherwise the existing arrows for the pair are cleared, every index choice
   * produced by the {@code DepthChoiceGenerator} over naYX is evaluated with {@code deleteEval},
   * and each choice with a strictly positive bump is stored in {@code sortedArrows} and in the
   * arrow lookup.
   *
   * <p>Original note: the result is invalid if the nodes or the graph change.
   *
   * @param x one endpoint of the edge considered for deletion.
   * @param y the other endpoint; knowledge validity of each subset is checked against this node.
   * @param graph the graph the candidate deletions are evaluated against.
   */
  private void calculateArrowsBackward(Node x, Node y, Graph graph) {
    // Identity comparison is intentional: same node object, or no edge to delete.
    if (x == y || !graph.isAdjacentTo(x, y)) {
      return;
    }

    if (!knowledgeEmpty() && !getKnowledge().noEdgeRequired(x.getName(), y.getName())) {
      return;
    }

    List<Node> naYX = getNaYX(x, y, graph);

    clearArrow(x, y);

    // Subsets are drawn from a private copy so naYX itself is passed on unchanged.
    List<Node> candidates = new ArrayList<Node>(naYX);
    DepthChoiceGenerator subsets = new DepthChoiceGenerator(candidates.size(), candidates.size());

    for (int[] indices = subsets.next(); indices != null; indices = subsets.next()) {
      List<Node> subset = GraphUtils.asList(indices, candidates);

      if (!knowledgeEmpty() && !validSetByKnowledge(y, subset)) {
        continue;
      }

      double bump = deleteEval(x, y, subset, naYX, graph);

      // Only deletions that strictly improve the score are kept.
      if (bump > 0.0) {
        Arrow candidate = new Arrow(bump, x, y, subset, naYX);
        sortedArrows.add(candidate);
        addLookupArrow(x, y, candidate);
      }
    }
  }
Ejemplo n.º 6
0
  /**
   * Runs the search over the given nodes: a forward phase ({@code fes}) followed by a backward
   * phase ({@code bes}).
   *
   * <p>The local score cache and {@code topGraphs} are cleared, the working graph is created
   * (from {@code initialGraph} when one is set, after replacing its nodes with {@code variables};
   * otherwise empty over {@code nodes}), indexing is built, required edges are added, and the
   * running score is reset to zero before the two phases run. The elapsed wall-clock time is
   * stored in {@code elapsedTime} and logged.
   *
   * @param nodes the nodes to search over; all must be variables of the data set.
   * @return the graph left after the backward phase.
   * @throws IllegalArgumentException if some node is not among the data set's variables.
   */
  public Graph search(List<Node> nodes) {
    final long startTime = System.currentTimeMillis();
    localScoreCache.clear();

    if (!dataSet().getVariables().containsAll(nodes)) {
      throw new IllegalArgumentException("All of the nodes must be in " + "the supplied data set.");
    }

    final Graph graph;

    if (initialGraph != null) {
      // Re-map the stored initial graph onto the current variables before copying it.
      initialGraph = GraphUtils.replaceNodes(initialGraph, variables);
      graph = new EdgeListGraphSingleConnections(initialGraph);
    } else {
      graph = new EdgeListGraphSingleConnections(nodes);
    }

    topGraphs.clear();

    buildIndexing(graph);
    addRequiredEdges(graph);
    score = 0.0;

    // Forward phase, then backward phase, both on the same working graph.
    fes(graph, nodes);
    bes(graph);

    this.elapsedTime = System.currentTimeMillis() - startTime;
    this.logger.log("graph", "\nReturning this graph: " + graph);

    this.logger.log("info", "Elapsed time = " + elapsedTime / 1000. + " s");
    this.logger.flush();

    return graph;
  }
Ejemplo n.º 7
0
  /**
   * Decides the orientation of the edge between {@code x} and {@code y} by comparing local scores
   * over all pairs of conditioning sets, and rewrites that edge in {@code graph}.
   *
   * <p>The roles of x and y are swapped with probability one half. For every choice of
   * conditioning set drawn from the other neighbors of x, and every choice drawn from the other
   * neighbors of y, four scores are computed: each node scored with ("plus") and without
   * ("minus") the other endpoint added to its set. Combinations where any of the four
   * {@code normal(...)} checks holds are skipped. The remaining combinations are classified as
   * favoring y -> x, favoring x -> y, or neither; the favored direction with the highest combined
   * score wins. In strong mode a favored direction additionally has to pass a second pair of
   * comparisons before it may win; otherwise it is only reported as "No directed edge".
   *
   * <p>Finally all edges between x and y are removed and replaced by the winning directed edge,
   * or by an undirected edge when no direction won. All evaluated combinations are logged under
   * "info", highest combined score first; combinations with identical scores share a map key, so
   * only the last one is reported.
   *
   * @param graph the graph to read neighbors from and to re-orient the edge in.
   * @param x one endpoint of the edge.
   * @param y the other endpoint of the edge.
   * @param strong whether the stricter (strong) acceptance test is applied.
   * @param oldGraph currently unused.
   */
  private void resolveOneEdgeMax(Graph graph, Node x, Node y, boolean strong, Graph oldGraph) {
    // Randomize which endpoint plays the role of x.
    if (RandomUtil.getInstance().nextDouble() > 0.5) {
      Node temp = x;
      x = y;
      y = temp;
    }

    TetradLogger.getInstance().log("info", "\nEDGE " + x + " --- " + y);

    // Keyed by negated score so that ascending key order lists the best score first.
    SortedMap<Double, String> scoreReports = new TreeMap<Double, String>();

    List<Node> neighborsx = graph.getAdjacentNodes(x);
    neighborsx.remove(y);

    // The graph is not modified until after the loops, so y's neighbors are computed once.
    List<Node> neighborsy = graph.getAdjacentNodes(y);
    neighborsy.remove(x);

    // Tolerance added to the right-hand side of every score comparison; currently zero.
    final double delta = 0.0;

    double max = Double.NEGATIVE_INFINITY;
    boolean left = false;
    boolean right = false;

    DepthChoiceGenerator genx = new DepthChoiceGenerator(neighborsx.size(), neighborsx.size());
    int[] choicex;

    while ((choicex = genx.next()) != null) {
      List<Node> condxMinus = GraphUtils.asList(choicex, neighborsx);

      List<Node> condxPlus = new ArrayList<Node>(condxMinus);
      condxPlus.add(y);

      double xPlus = score(x, condxPlus);
      double xMinus = score(x, condxMinus);

      DepthChoiceGenerator geny = new DepthChoiceGenerator(neighborsy.size(), neighborsy.size());
      int[] choicey;

      while ((choicey = geny.next()) != null) {
        List<Node> condyMinus = GraphUtils.asList(choicey, neighborsy);

        List<Node> condyPlus = new ArrayList<Node>(condyMinus);
        condyPlus.add(x);

        double yPlus = score(y, condyPlus);
        double yMinus = score(y, condyMinus);

        // Checking them all at once is expensive but avoids lexical ordering problems in the
        // algorithm.
        if (normal(y, condyPlus)
            || normal(x, condxMinus)
            || normal(x, condxPlus)
            || normal(y, condyMinus)) {
          continue;
        }

        boolean favorsYIntoX = yPlus <= xPlus + delta && xMinus <= yMinus + delta;
        boolean favorsXIntoY = xPlus <= yPlus + delta && yMinus <= xMinus + delta;
        boolean xAtLeastY = yPlus <= xPlus + delta && yMinus <= xMinus + delta;
        boolean yAtLeastX = xPlus <= yPlus + delta && xMinus <= yMinus + delta;

        String directedLabel = strong ? "\nStrong " : "\nWeak ";
        String undirectedHeader = "\nNo directed edge " + x + "--" + y;

        if (favorsYIntoX) {
          double combined = combinedScore(xPlus, yMinus);

          // Weak mode accepts immediately; strong mode needs the second pair of comparisons.
          boolean accepted = !strong || (yPlus <= yMinus + delta && xMinus <= xPlus + delta);

          if (accepted) {
            scoreReports.put(
                -combined,
                formatEdgeReport(
                    directedLabel + y + "->" + x, combined, x, condxMinus, y, condyMinus));

            if (combined > max) {
              max = combined;
              left = true;
              right = false;
            }
          } else {
            scoreReports.put(
                -combined,
                formatEdgeReport(undirectedHeader, combined, x, condxMinus, y, condyMinus));
          }
        } else if (favorsXIntoY) {
          double combined = combinedScore(yPlus, xMinus);

          boolean accepted = !strong || (yMinus <= yPlus + delta && xPlus <= xMinus + delta);

          if (accepted) {
            scoreReports.put(
                -combined,
                formatEdgeReport(
                    directedLabel + x + "->" + y, combined, x, condxMinus, y, condyMinus));

            if (combined > max) {
              max = combined;
              left = false;
              right = true;
            }
          } else {
            scoreReports.put(
                -combined,
                formatEdgeReport(undirectedHeader, combined, x, condxMinus, y, condyMinus));
          }
        } else if (xAtLeastY || yAtLeastX) {
          // Neither direction is favored: report only, never a candidate for the maximum.
          double combined = combinedScore(yPlus, xMinus);

          scoreReports.put(
              -combined,
              formatEdgeReport(undirectedHeader, combined, x, condxMinus, y, condyMinus));
        }
      }
    }

    for (String report : scoreReports.values()) {
      TetradLogger.getInstance().log("info", report);
    }

    graph.removeEdges(x, y);

    if (left) {
      graph.addDirectedEdge(y, x);
    }

    if (right) {
      graph.addDirectedEdge(x, y);
    }

    // No direction won: keep the adjacency as an undirected edge.
    if (!graph.isAdjacentTo(x, y)) {
      graph.addUndirectedEdge(x, y);
    }
  }

  /** Formats one log entry: the header, the combined score, and the sets used for x and y. */
  private String formatEdgeReport(
      String header,
      double combined,
      Node x,
      List<Node> parentsOfX,
      Node y,
      List<Node> parentsOfY) {
    return header
        + " "
        + combined
        + "\n   Parents("
        + x
        + ") = "
        + parentsOfX
        + "\n   Parents("
        + y
        + ") = "
        + parentsOfY;
  }
Ejemplo n.º 8
0
  /**
   * Derives the variance of the given error term from the coefficient values currently held by
   * the model.
   *
   * <p>The result is {@code 1.0} minus the following sum, taken over the parents of the error
   * term's first child: the squared edge coefficient of every parent other than the error term,
   * plus, for every pair of parents, twice the product of their coefficients (an ERROR-type
   * parent counts as coefficient 1) and a covariance term. That covariance term is the sum, over
   * all treks between the pair, of the product of the per-edge factors along the trek.
   *
   * @param error An error term in the model--i.e. a variable with NodeType.ERROR.
   * @return The value of the error variance, or Double.NaN if the value is undefined (that is,
   *     when the computed value is not strictly positive).
   */
  private double calculateErrorVarianceFromParams(Node error) {
    // Work with the graph's own node object so the identity comparison below is meaningful.
    error = semGraph.getNode(error.getName());

    Node child = semGraph.getChildren(error).get(0);
    List<Node> parents = semGraph.getParents(child);

    double explained = 0;

    for (Node parent : parents) {
      if (parent != error) {
        double coef = getEdgeCoefficient(parent, child);
        explained += coef * coef;
      }
    }

    if (parents.size() >= 2) {
      ChoiceGenerator pairs = new ChoiceGenerator(parents.size(), 2);

      for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
        Node first = parents.get(pair[0]);
        Node second = parents.get(pair[1]);

        double firstCoef = 1;

        if (first.getNodeType() != NodeType.ERROR) {
          firstCoef = getEdgeCoefficient(first, child);
        }

        double secondCoef = 1;

        if (second.getNodeType() != NodeType.ERROR) {
          secondCoef = getEdgeCoefficient(second, child);
        }

        double cov = 0.0;

        for (List<Node> trek : GraphUtils.treksIncludingBidirected(semGraph, first, second)) {
          double product = 1.0;

          // Multiply the factors of consecutive node pairs along the trek.
          for (int i = 0; i + 1 < trek.size(); i++) {
            Node from = trek.get(i);
            Node to = trek.get(i + 1);

            Edge edge = semGraph.getEdge(from, to);
            double factor;

            if (Edges.isBidirectedEdge(edge)) {
              factor = edgeParameters.get(edge);
            } else if (!edgeParameters.containsKey(edge)) {
              // Edges without a stored parameter contribute a neutral factor.
              factor = 1;
            } else if (semGraph.isParentOf(from, to)) {
              factor = getEdgeCoefficient(from, to);
            } else {
              factor = getEdgeCoefficient(to, from);
            }

            product *= factor;
          }

          cov += product;
        }

        explained += 2 * firstCoef * secondCoef * cov;
      }
    }

    double remaining = 1.0 - explained;
    return remaining <= 0 ? Double.NaN : remaining;
  }