private void calculateArrowsForward(Node x, Node y, Graph graph) { clearArrow(x, y); if (!knowledgeEmpty()) { if (getKnowledge().isForbidden(x.getName(), y.getName())) { return; } } List<Node> naYX = getNaYX(x, y, graph); List<Node> t = getTNeighbors(x, y, graph); DepthChoiceGenerator gen = new DepthChoiceGenerator(t.size(), t.size()); int[] choice; while ((choice = gen.next()) != null) { List<Node> s = GraphUtils.asList(choice, t); if (!knowledgeEmpty()) { if (!validSetByKnowledge(y, s)) { continue; } } double bump = insertEval(x, y, s, naYX, graph); if (bump > 0.0) { Arrow arrow = new Arrow(bump, x, y, s, naYX); sortedArrows.add(arrow); addLookupArrow(x, y, arrow); } } }
public Graph orient() { Graph skeleton = GraphUtils.undirectedGraph(getPattern()); Graph graph = new EdgeListGraph(skeleton.getNodes()); List<Node> nodes = skeleton.getNodes(); // Collections.shuffle(nodes); if (isR1Done()) { ruleR1(skeleton, graph, nodes); } for (Edge edge : skeleton.getEdges()) { if (!graph.isAdjacentTo(edge.getNode1(), edge.getNode2())) { graph.addUndirectedEdge(edge.getNode1(), edge.getNode2()); } } if (isR2Done()) { ruleR2(skeleton, graph); } if (isMeekDone()) { new MeekRules().orientImplied(graph); } return graph; }
// Cannot be done if the graph changes. public void setInitialGraph(Graph initialGraph) { initialGraph = GraphUtils.replaceNodes(initialGraph, variables); out.println("Initial graph variables: " + initialGraph.getNodes()); out.println("Data set variables: " + variables); if (!new HashSet<Node>(initialGraph.getNodes()).equals(new HashSet<Node>(variables))) { throw new IllegalArgumentException("Variables aren't the same."); } this.initialGraph = initialGraph; }
private void ruleR1(Graph skeleton, Graph graph, List<Node> nodes) { for (Node node : nodes) { SortedMap<Double, String> scoreReports = new TreeMap<Double, String>(); List<Node> adj = skeleton.getAdjacentNodes(node); DepthChoiceGenerator gen = new DepthChoiceGenerator(adj.size(), adj.size()); int[] choice; double maxScore = Double.NEGATIVE_INFINITY; List<Node> parents = null; while ((choice = gen.next()) != null) { List<Node> _parents = GraphUtils.asList(choice, adj); double score = score(node, _parents); scoreReports.put(-score, _parents.toString()); if (score > maxScore) { maxScore = score; parents = _parents; } } for (double score : scoreReports.keySet()) { TetradLogger.getInstance() .log( "score", "For " + node + " parents = " + scoreReports.get(score) + " score = " + -score); } TetradLogger.getInstance().log("score", ""); if (parents == null) { continue; } if (normal(node, parents)) continue; for (Node _node : adj) { if (parents.contains(_node)) { Edge parentEdge = Edges.directedEdge(_node, node); if (!graph.containsEdge(parentEdge)) { graph.addEdge(parentEdge); } } } } }
// Invalid if then nodes or graph changes. private void calculateArrowsBackward(Node x, Node y, Graph graph) { if (x == y) { return; } if (!graph.isAdjacentTo(x, y)) { return; } if (!knowledgeEmpty()) { if (!getKnowledge().noEdgeRequired(x.getName(), y.getName())) { return; } } List<Node> naYX = getNaYX(x, y, graph); clearArrow(x, y); List<Node> _naYX = new ArrayList<Node>(naYX); DepthChoiceGenerator gen = new DepthChoiceGenerator(_naYX.size(), _naYX.size()); int[] choice; while ((choice = gen.next()) != null) { List<Node> H = GraphUtils.asList(choice, _naYX); if (!knowledgeEmpty()) { if (!validSetByKnowledge(y, H)) { continue; } } double bump = deleteEval(x, y, H, naYX, graph); if (bump > 0.0) { Arrow arrow = new Arrow(bump, x, y, H, naYX); sortedArrows.add(arrow); addLookupArrow(x, y, arrow); } } }
public Graph search(List<Node> nodes) { long startTime = System.currentTimeMillis(); localScoreCache.clear(); if (!dataSet().getVariables().containsAll(nodes)) { throw new IllegalArgumentException("All of the nodes must be in " + "the supplied data set."); } Graph graph; if (initialGraph == null) { graph = new EdgeListGraphSingleConnections(nodes); } else { initialGraph = GraphUtils.replaceNodes(initialGraph, variables); graph = new EdgeListGraphSingleConnections(initialGraph); } topGraphs.clear(); buildIndexing(graph); addRequiredEdges(graph); score = 0.0; // Do forward search. fes(graph, nodes); // Do backward search. bes(graph); long endTime = System.currentTimeMillis(); this.elapsedTime = endTime - startTime; this.logger.log("graph", "\nReturning this graph: " + graph); this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s"); this.logger.flush(); return graph; }
private void resolveOneEdgeMax(Graph graph, Node x, Node y, boolean strong, Graph oldGraph) { if (RandomUtil.getInstance().nextDouble() > 0.5) { Node temp = x; x = y; y = temp; } TetradLogger.getInstance().log("info", "\nEDGE " + x + " --- " + y); SortedMap<Double, String> scoreReports = new TreeMap<Double, String>(); List<Node> neighborsx = graph.getAdjacentNodes(x); neighborsx.remove(y); double max = Double.NEGATIVE_INFINITY; boolean left = false; boolean right = false; DepthChoiceGenerator genx = new DepthChoiceGenerator(neighborsx.size(), neighborsx.size()); int[] choicex; while ((choicex = genx.next()) != null) { List<Node> condxMinus = GraphUtils.asList(choicex, neighborsx); List<Node> condxPlus = new ArrayList<Node>(condxMinus); condxPlus.add(y); double xPlus = score(x, condxPlus); double xMinus = score(x, condxMinus); List<Node> neighborsy = graph.getAdjacentNodes(y); neighborsy.remove(x); DepthChoiceGenerator geny = new DepthChoiceGenerator(neighborsy.size(), neighborsy.size()); int[] choicey; while ((choicey = geny.next()) != null) { List<Node> condyMinus = GraphUtils.asList(choicey, neighborsy); // List<Node> parentsY = oldGraph.getParents(y); // parentsY.remove(x); // if (!condyMinus.containsAll(parentsY)) { // continue; // } List<Node> condyPlus = new ArrayList<Node>(condyMinus); condyPlus.add(x); double yPlus = score(y, condyPlus); double yMinus = score(y, condyMinus); // Checking them all at once is expensive but avoids lexical ordering problems in the // algorithm. if (normal(y, condyPlus) || normal(x, condxMinus) || normal(x, condxPlus) || normal(y, condyMinus)) { continue; } double delta = 0.0; if (strong) { if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(xPlus, yMinus); if (yPlus <= yMinus + delta && xMinus <= xPlus + delta) { StringBuilder builder = new StringBuilder(); builder.append("\nStrong " + y + "->" + x + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = true; right = false; } } else { StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); if (yMinus <= yPlus + delta && xPlus <= xMinus + delta) { StringBuilder builder = new StringBuilder(); builder.append("\nStrong " + x + "->" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = false; right = true; } } else { StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } else { if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(xPlus, yMinus); StringBuilder builder = new StringBuilder(); builder.append("\nWeak " + y + "->" + x + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = true; right = false; } } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nWeak " + x + "->" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = false; right = true; } } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } } } for (double score : scoreReports.keySet()) { TetradLogger.getInstance().log("info", scoreReports.get(score)); } graph.removeEdges(x, y); if (left) { graph.addDirectedEdge(y, x); } if (right) { graph.addDirectedEdge(x, y); } if (!graph.isAdjacentTo(x, y)) { graph.addUndirectedEdge(x, y); } }
/** * Calculates the error variance for the given error node, given all of the coefficient values in * the model. * * @param error An error term in the model--i.e. a variable with NodeType.ERROR. * @return The value of the error variance, or Double.NaN is the value is undefined. */ private double calculateErrorVarianceFromParams(Node error) { error = semGraph.getNode(error.getName()); Node child = semGraph.getChildren(error).get(0); List<Node> parents = semGraph.getParents(child); double otherVariance = 0; for (Node parent : parents) { if (parent == error) continue; double coef = getEdgeCoefficient(parent, child); otherVariance += coef * coef; } if (parents.size() >= 2) { ChoiceGenerator gen = new ChoiceGenerator(parents.size(), 2); int[] indices; while ((indices = gen.next()) != null) { Node node1 = parents.get(indices[0]); Node node2 = parents.get(indices[1]); double coef1, coef2; if (node1.getNodeType() != NodeType.ERROR) { coef1 = getEdgeCoefficient(node1, child); } else { coef1 = 1; } if (node2.getNodeType() != NodeType.ERROR) { coef2 = getEdgeCoefficient(node2, child); } else { coef2 = 1; } List<List<Node>> treks = GraphUtils.treksIncludingBidirected(semGraph, node1, node2); double cov = 0.0; for (List<Node> trek : treks) { double product = 1.0; for (int i = 1; i < trek.size(); i++) { Node _node1 = trek.get(i - 1); Node _node2 = trek.get(i); Edge edge = semGraph.getEdge(_node1, _node2); double factor; if (Edges.isBidirectedEdge(edge)) { factor = edgeParameters.get(edge); } else if (!edgeParameters.containsKey(edge)) { factor = 1; } else if (semGraph.isParentOf(_node1, _node2)) { factor = getEdgeCoefficient(_node1, _node2); } else { factor = getEdgeCoefficient(_node2, _node1); } product *= factor; } cov += product; } otherVariance += 2 * coef1 * coef2 * cov; } } return 1.0 - otherVariance <= 0 ? Double.NaN : 1.0 - otherVariance; }