private boolean existsUnblockedSemiDirectedPath(Node from, Node to, List<Node> cond, Graph G) { Queue<Node> Q = new LinkedList<Node>(); Set<Node> V = new HashSet<Node>(); Q.offer(from); V.add(from); while (!Q.isEmpty()) { Node t = Q.remove(); if (t == to) return true; for (Node u : G.getAdjacentNodes(t)) { Edge edge = G.getEdge(t, u); Node c = Edges.traverseSemiDirected(t, edge); if (c == null) continue; if (cond.contains(c)) continue; if (c == to) return true; if (!V.contains(c)) { V.add(c); Q.offer(c); } } } return false; }
// ===========================SCORING METHODS===================// public double scoreDag(Graph graph) { Graph dag = new EdgeListGraphSingleConnections(graph); buildIndexing(graph); double score = 0.0; for (Node y : dag.getNodes()) { Set<Node> parents = new HashSet<Node>(dag.getParents(y)); int nextIndex = -1; for (int i = 0; i < getVariables().size(); i++) { nextIndex = hashIndices.get(variables.get(i)); } int parentIndices[] = new int[parents.size()]; Iterator<Node> pi = parents.iterator(); int count = 0; while (pi.hasNext()) { Node nextParent = pi.next(); parentIndices[count++] = hashIndices.get(nextParent); } if (this.isDiscrete()) { score += localDiscreteScore(nextIndex, parentIndices); } else { score += localSemScore(nextIndex, parentIndices); } } return score; }
private boolean isUndirected(Graph graph, Node x, Node y) { List<Edge> edges = graph.getEdges(x, y); if (edges.size() == 1) { Edge edge = graph.getEdge(x, y); return Edges.isUndirectedEdge(edge); } return false; }
// Cannot be done if the graph changes. public void setInitialGraph(Graph initialGraph) { initialGraph = GraphUtils.replaceNodes(initialGraph, variables); out.println("Initial graph variables: " + initialGraph.getNodes()); out.println("Data set variables: " + variables); if (!new HashSet<Node>(initialGraph.getNodes()).equals(new HashSet<Node>(variables))) { throw new IllegalArgumentException("Variables aren't the same."); } this.initialGraph = initialGraph; }
private void ruleR1(Graph skeleton, Graph graph, List<Node> nodes) { for (Node node : nodes) { SortedMap<Double, String> scoreReports = new TreeMap<Double, String>(); List<Node> adj = skeleton.getAdjacentNodes(node); DepthChoiceGenerator gen = new DepthChoiceGenerator(adj.size(), adj.size()); int[] choice; double maxScore = Double.NEGATIVE_INFINITY; List<Node> parents = null; while ((choice = gen.next()) != null) { List<Node> _parents = GraphUtils.asList(choice, adj); double score = score(node, _parents); scoreReports.put(-score, _parents.toString()); if (score > maxScore) { maxScore = score; parents = _parents; } } for (double score : scoreReports.keySet()) { TetradLogger.getInstance() .log( "score", "For " + node + " parents = " + scoreReports.get(score) + " score = " + -score); } TetradLogger.getInstance().log("score", ""); if (parents == null) { continue; } if (normal(node, parents)) continue; for (Node _node : adj) { if (parents.contains(_node)) { Edge parentEdge = Edges.directedEdge(_node, node); if (!graph.containsEdge(parentEdge)) { graph.addEdge(parentEdge); } } } } }
/** * Forward equivalence search. * * @param graph The graph in the state prior to the forward equivalence search. */ private void fes(Graph graph, List<Node> nodes) { TetradLogger.getInstance().log("info", "** FORWARD EQUIVALENCE SEARCH"); lookupArrows = new HashMap<OrderedPair, Set<Arrow>>(); initializeArrowsForward(nodes); while (!sortedArrows.isEmpty()) { Arrow arrow = sortedArrows.first(); sortedArrows.remove(arrow); Node x = arrow.getX(); Node y = arrow.getY(); clearArrow(x, y); if (graph.isAdjacentTo(x, y)) { continue; } if (!validInsert(x, y, arrow.getHOrT(), arrow.getNaYX(), graph)) { continue; } List<Node> t = arrow.getHOrT(); double bump = arrow.getBump(); Set<Edge> edges = graph.getEdges(); insert(x, y, t, graph, bump); score += bump; rebuildPattern(graph); // Try to avoid duplicating scoring calls. First clear out all of the edges that need to be // changed, // then change them, checking to see if they're already been changed. I know, roundabout, but // there's // a performance boost. for (Edge edge : graph.getEdges()) { if (!edges.contains(edge)) { reevaluateForward(graph, nodes, edge.getNode1(), edge.getNode2()); } } storeGraph(graph); } }
public Graph orient() { Graph skeleton = GraphUtils.undirectedGraph(getPattern()); Graph graph = new EdgeListGraph(skeleton.getNodes()); List<Node> nodes = skeleton.getNodes(); // Collections.shuffle(nodes); if (isR1Done()) { ruleR1(skeleton, graph, nodes); } for (Edge edge : skeleton.getEdges()) { if (!graph.isAdjacentTo(edge.getNode1(), edge.getNode2())) { graph.addUndirectedEdge(edge.getNode1(), edge.getNode2()); } } if (isR2Done()) { ruleR2(skeleton, graph); } if (isMeekDone()) { new MeekRules().orientImplied(graph); } return graph; }
/** Evaluate the Insert(X, Y, T) operator (Definition 12 from Chickering, 2002). */ private double insertEval(Node x, Node y, List<Node> t, List<Node> naYX, Graph graph) { Set<Node> set1 = new HashSet<Node>(naYX); set1.addAll(t); List<Node> paY = graph.getParents(y); set1.addAll(paY); Set<Node> set2 = new HashSet<Node>(set1); set1.add(x); return scoreGraphChange(y, set1, set2); }
/** * Find all nodes that are connected to Y by an undirected edge that are adjacent to X (that is, * by undirected or directed edge). */ private static List<Node> getNaYX(Node x, Node y, Graph graph) { List<Edge> yEdges = graph.getEdges(y); List<Node> nayx = new ArrayList<Node>(); for (Edge edge : yEdges) { if (!Edges.isUndirectedEdge(edge)) { continue; } Node z = edge.getDistalNode(y); if (!graph.isAdjacentTo(z, x)) { continue; } nayx.add(z); } return nayx; }
/** Returns true iif the given set forms a clique in the given graph. */ private static boolean isClique(List<Node> nodes, Graph graph) { for (int i = 0; i < nodes.size() - 1; i++) { for (int j = i + 1; j < nodes.size(); j++) { if (!graph.isAdjacentTo(nodes.get(i), nodes.get(j))) { return false; } } } return true; }
/** Get all nodes that are connected to Y by an undirected edge and not adjacent to X. */ private static List<Node> getTNeighbors(Node x, Node y, Graph graph) { List<Edge> yEdges = graph.getEdges(y); List<Node> tNeighbors = new ArrayList<Node>(); for (Edge edge : yEdges) { if (!Edges.isUndirectedEdge(edge)) { continue; } Node z = edge.getDistalNode(y); if (graph.isAdjacentTo(z, x)) { continue; } tNeighbors.add(z); } return tNeighbors; }
private Graph pickDag(Graph graph) { SearchGraphUtils.basicPattern(graph, false); addRequiredEdges(graph); boolean containsUndirected; do { containsUndirected = false; for (Edge edge : graph.getEdges()) { if (Edges.isUndirectedEdge(edge)) { containsUndirected = true; graph.removeEdge(edge); Edge _edge = Edges.directedEdge(edge.getNode1(), edge.getNode2()); graph.addEdge(_edge); } } meekOrient(graph, getKnowledge()); } while (containsUndirected); return graph; }
/** * Greedy equivalence search: Start from the empty graph, add edges till model is significant. * Then start deleting edges till a minimum is achieved. * * @return the resulting Pattern. */ public Graph search() { Graph graph; if (initialGraph == null) { graph = new EdgeListGraphSingleConnections(getVariables()); } else { graph = new EdgeListGraphSingleConnections(initialGraph); } fireGraphChange(graph); buildIndexing(graph); addRequiredEdges(graph); topGraphs.clear(); storeGraph(graph); List<Node> nodes = graph.getNodes(); long start = System.currentTimeMillis(); score = 0.0; // Do forward search. fes(graph, nodes); // Do backward search. bes(graph); long endTime = System.currentTimeMillis(); this.elapsedTime = endTime - start; this.logger.log("graph", "\nReturning this graph: " + graph); this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s"); this.logger.flush(); return graph; }
// Can be done concurrently. private double deleteEval(Node x, Node y, List<Node> h, List<Node> naYX, Graph graph) { List<Node> paY = graph.getParents(y); Set<Node> paYMinuxX = new HashSet<Node>(paY); paYMinuxX.remove(x); Set<Node> set1 = new HashSet<Node>(naYX); set1.removeAll(h); set1.addAll(paYMinuxX); Set<Node> set2 = new HashSet<Node>(naYX); set2.removeAll(h); set2.addAll(paY); return scoreGraphChange(y, set1, set2); }
private void ruleR2(Graph skeleton, Graph graph) { Set<Edge> edgeList1 = skeleton.getEdges(); // Collections.shuffle(edgeList1); for (Edge adj : edgeList1) { Node x = adj.getNode1(); Node y = adj.getNode2(); if (!isR2Orient2Cycles() && isTwoCycle(graph, x, y)) { continue; } if (!isTwoCycle(graph, x, y) && !isUndirected(graph, x, y)) { continue; } resolveOneEdgeMax(graph, x, y, isStrongR2(), new EdgeListGraph(graph)); } }
// Invalid if then nodes or graph changes. private void calculateArrowsBackward(Node x, Node y, Graph graph) { if (x == y) { return; } if (!graph.isAdjacentTo(x, y)) { return; } if (!knowledgeEmpty()) { if (!getKnowledge().noEdgeRequired(x.getName(), y.getName())) { return; } } List<Node> naYX = getNaYX(x, y, graph); clearArrow(x, y); List<Node> _naYX = new ArrayList<Node>(naYX); DepthChoiceGenerator gen = new DepthChoiceGenerator(_naYX.size(), _naYX.size()); int[] choice; while ((choice = gen.next()) != null) { List<Node> H = GraphUtils.asList(choice, _naYX); if (!knowledgeEmpty()) { if (!validSetByKnowledge(y, H)) { continue; } } double bump = deleteEval(x, y, H, naYX, graph); if (bump > 0.0) { Arrow arrow = new Arrow(bump, x, y, H, naYX); sortedArrows.add(arrow); addLookupArrow(x, y, arrow); } } }
private void initializeArrowsBackward(Graph graph) { sortedArrows.clear(); lookupArrows.clear(); for (Edge edge : graph.getEdges()) { Node x = edge.getNode1(); Node y = edge.getNode2(); if (!knowledgeEmpty()) { if (!getKnowledge().noEdgeRequired(x.getName(), y.getName())) { continue; } } if (Edges.isDirectedEdge(edge)) { calculateArrowsBackward(x, y, graph); } else { calculateArrowsBackward(x, y, graph); calculateArrowsBackward(y, x, graph); } } }
private void addRequiredEdges(Graph graph) { if (true) return; if (knowledgeEmpty()) return; for (Iterator<KnowledgeEdge> it = getKnowledge().requiredEdgesIterator(); it.hasNext(); ) { KnowledgeEdge next = it.next(); Node nodeA = graph.getNode(next.getFrom()); Node nodeB = graph.getNode(next.getTo()); if (!graph.isAncestorOf(nodeB, nodeA)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeA, nodeB); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB)); } } for (Edge edge : graph.getEdges()) { final String A = edge.getNode1().getName(); final String B = edge.getNode2().getName(); if (knowledge.isForbidden(A, B)) { Node nodeA = edge.getNode1(); Node nodeB = edge.getNode2(); if (nodeA != null && nodeB != null && graph.isAdjacentTo(nodeA, nodeB) && !graph.isChildOf(nodeA, nodeB)) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } if (!graph.isChildOf(nodeA, nodeB) && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } } else if (knowledge.isForbidden(B, A)) { Node nodeA = edge.getNode2(); Node nodeB = edge.getNode1(); if (nodeA != null && nodeB != null && graph.isAdjacentTo(nodeA, nodeB) && !graph.isChildOf(nodeA, nodeB)) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } if (!graph.isChildOf(nodeA, nodeB) && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } } } }
/** Do an actual deletion (Definition 13 from Chickering, 2002). */ private void delete(Node x, Node y, List<Node> subset, Graph graph, double bump) { Edge trueEdge = null; if (trueGraph != null) { Node _x = trueGraph.getNode(x.getName()); Node _y = trueGraph.getNode(y.getName()); trueEdge = trueGraph.getEdge(_x, _y); } if (log && verbose) { Edge oldEdge = graph.getEdge(x, y); String label = trueGraph != null && trueEdge != null ? "*" : ""; TetradLogger.getInstance() .log( "deletedEdges", (graph.getNumEdges() - 1) + ". DELETE " + oldEdge + " " + subset + " (" + bump + ") " + label); out.println( (graph.getNumEdges() - 1) + ". DELETE " + oldEdge + " " + subset + " (" + bump + ") " + label); } else { int numEdges = graph.getNumEdges() - 1; if (numEdges % 50 == 0) out.println(numEdges); } graph.removeEdge(x, y); for (Node h : subset) { Edge oldEdge = graph.getEdge(y, h); graph.removeEdge(y, h); graph.addDirectedEdge(y, h); if (log) { TetradLogger.getInstance() .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(y, h)); } if (verbose) { out.println("--- Directing " + oldEdge + " to " + graph.getEdge(y, h)); } if (Edges.isUndirectedEdge(graph.getEdge(x, h))) { if (!graph.isAdjacentTo(x, h)) throw new IllegalArgumentException("Not adjacent: " + x + ", " + h); oldEdge = graph.getEdge(x, h); graph.removeEdge(x, h); graph.addDirectedEdge(x, h); if (log) { TetradLogger.getInstance() .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(x, h)); } if (verbose) { out.println("--- Directing " + oldEdge + " to " + graph.getEdge(x, h)); } } } }
// serial. private void insert(Node x, Node y, List<Node> t, Graph graph, double bump) { if (graph.isAdjacentTo(x, y)) { return; // The initial graph may already have put this edge in the graph. // throw new IllegalArgumentException(x + " and " + y + " are already adjacent in // the graph."); } Edge trueEdge = null; if (trueGraph != null) { Node _x = trueGraph.getNode(x.getName()); Node _y = trueGraph.getNode(y.getName()); trueEdge = trueGraph.getEdge(_x, _y); } graph.addDirectedEdge(x, y); if (log) { String label = trueGraph != null && trueEdge != null ? "*" : ""; TetradLogger.getInstance() .log( "insertedEdges", graph.getNumEdges() + ". INSERT " + graph.getEdge(x, y) + " " + t + " " + bump + " " + label); } else { int numEdges = graph.getNumEdges() - 1; if (verbose) { if (numEdges % 50 == 0) out.println(numEdges); } } if (verbose) { String label = trueGraph != null && trueEdge != null ? "*" : ""; out.println( graph.getNumEdges() + ". INSERT " + graph.getEdge(x, y) + " " + t + " " + bump + " " + label); } else { int numEdges = graph.getNumEdges() - 1; if (verbose) { if (numEdges % 50 == 0) out.println(numEdges); } } for (Node _t : t) { Edge oldEdge = graph.getEdge(_t, y); if (oldEdge == null) throw new IllegalArgumentException("Not adjacent: " + _t + ", " + y); graph.removeEdge(_t, y); graph.addDirectedEdge(_t, y); if (log && verbose) { TetradLogger.getInstance() .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(_t, y)); out.println("--- Directing " + oldEdge + " to " + graph.getEdge(_t, y)); } } }
private void resolveOneEdgeMax(Graph graph, Node x, Node y, boolean strong, Graph oldGraph) { if (RandomUtil.getInstance().nextDouble() > 0.5) { Node temp = x; x = y; y = temp; } TetradLogger.getInstance().log("info", "\nEDGE " + x + " --- " + y); SortedMap<Double, String> scoreReports = new TreeMap<Double, String>(); List<Node> neighborsx = graph.getAdjacentNodes(x); neighborsx.remove(y); double max = Double.NEGATIVE_INFINITY; boolean left = false; boolean right = false; DepthChoiceGenerator genx = new DepthChoiceGenerator(neighborsx.size(), neighborsx.size()); int[] choicex; while ((choicex = genx.next()) != null) { List<Node> condxMinus = GraphUtils.asList(choicex, neighborsx); List<Node> condxPlus = new ArrayList<Node>(condxMinus); condxPlus.add(y); double xPlus = score(x, condxPlus); double xMinus = score(x, condxMinus); List<Node> neighborsy = graph.getAdjacentNodes(y); neighborsy.remove(x); DepthChoiceGenerator geny = new DepthChoiceGenerator(neighborsy.size(), neighborsy.size()); int[] choicey; while ((choicey = geny.next()) != null) { List<Node> condyMinus = GraphUtils.asList(choicey, neighborsy); // List<Node> parentsY = oldGraph.getParents(y); // parentsY.remove(x); // if (!condyMinus.containsAll(parentsY)) { // continue; // } List<Node> condyPlus = new ArrayList<Node>(condyMinus); condyPlus.add(x); double yPlus = score(y, condyPlus); double yMinus = score(y, condyMinus); // Checking them all at once is expensive but avoids lexical ordering problems in the // algorithm. if (normal(y, condyPlus) || normal(x, condxMinus) || normal(x, condxPlus) || normal(y, condyMinus)) { continue; } double delta = 0.0; if (strong) { if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(xPlus, yMinus); if (yPlus <= yMinus + delta && xMinus <= xPlus + delta) { StringBuilder builder = new StringBuilder(); builder.append("\nStrong " + y + "->" + x + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = true; right = false; } } else { StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); if (yMinus <= yPlus + delta && xPlus <= xMinus + delta) { StringBuilder builder = new StringBuilder(); builder.append("\nStrong " + x + "->" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = false; right = true; } } else { StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } else { if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(xPlus, yMinus); StringBuilder builder = new StringBuilder(); builder.append("\nWeak " + y + "->" + x + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = true; right = false; } } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nWeak " + x + "->" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = false; right = true; } } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } } } for (double score : scoreReports.keySet()) { TetradLogger.getInstance().log("info", scoreReports.get(score)); } graph.removeEdges(x, y); if (left) { graph.addDirectedEdge(y, x); } if (right) { graph.addDirectedEdge(x, y); } if (!graph.isAdjacentTo(x, y)) { graph.addUndirectedEdge(x, y); } }
private void buildIndexing(Graph graph) { this.hashIndices = new HashMap<Node, Integer>(); for (Node node : graph.getNodes()) { this.hashIndices.put(node, variables.indexOf(node)); } }
private boolean isTwoCycle(Graph graph, Node x, Node y) { List<Edge> edges = graph.getEdges(x, y); return edges.size() == 2; }