/** * Completes a pattern that was modified by an insertion/deletion operator Based on the algorithm * described on Appendix C of (Chickering, 2002). */ private void rebuildPattern(Graph graph) { SearchGraphUtils.basicPattern(graph, false); addRequiredEdges(graph); meekOrient(graph, getKnowledge()); if (TetradLogger.getInstance().isEventActive("rebuiltPatterns")) { TetradLogger.getInstance().log("rebuiltPatterns", "Rebuilt pattern = " + graph); } }
public void configurationActived(TetradLoggerEvent evt) { TetradLoggerConfig config = evt.getTetradLoggerConfig(); // if logging is actually turned on, then open display. if (TetradLogger.getInstance().isLogging() && config.isActive() && TetradLogger.getInstance().isDisplayLogEnabled()) { // if the log display isn't already up, open it. if (!isDisplayLogging() && allowAutomaticLogPopup()) { setDisplayLogging(true); } } }
private void ruleR1(Graph skeleton, Graph graph, List<Node> nodes) { for (Node node : nodes) { SortedMap<Double, String> scoreReports = new TreeMap<Double, String>(); List<Node> adj = skeleton.getAdjacentNodes(node); DepthChoiceGenerator gen = new DepthChoiceGenerator(adj.size(), adj.size()); int[] choice; double maxScore = Double.NEGATIVE_INFINITY; List<Node> parents = null; while ((choice = gen.next()) != null) { List<Node> _parents = GraphUtils.asList(choice, adj); double score = score(node, _parents); scoreReports.put(-score, _parents.toString()); if (score > maxScore) { maxScore = score; parents = _parents; } } for (double score : scoreReports.keySet()) { TetradLogger.getInstance() .log( "score", "For " + node + " parents = " + scoreReports.get(score) + " score = " + -score); } TetradLogger.getInstance().log("score", ""); if (parents == null) { continue; } if (normal(node, parents)) continue; for (Node _node : adj) { if (parents.contains(_node)) { Edge parentEdge = Edges.directedEdge(_node, node); if (!graph.containsEdge(parentEdge)) { graph.addEdge(parentEdge); } } } } }
private void bes(Graph graph) { TetradLogger.getInstance().log("info", "** BACKWARD EQUIVALENCE SEARCH"); initializeArrowsBackward(graph); while (!sortedArrows.isEmpty()) { Arrow arrow = sortedArrows.first(); sortedArrows.remove(arrow); Node x = arrow.getX(); Node y = arrow.getY(); clearArrow(x, y); if (!validDelete(arrow.getHOrT(), arrow.getNaYX(), graph)) { continue; } List<Node> h = arrow.getHOrT(); double bump = arrow.getBump(); delete(x, y, h, graph, bump); score += bump; rebuildPattern(graph); storeGraph(graph); initializeArrowsBackward( graph); // Rebuilds Arrows from scratch each time. Fast enough for backwards. } }
/** * States whether the log display should be automatically displayed, if there is no value in the * user's prefs then it will display a prompt asking the user whether they would like to disable * automatic popups. */ private boolean allowAutomaticLogPopup() { Boolean allowed = TetradLogger.getInstance().isAutomaticLogDisplayEnabled(); // ask the user whether they way the feature etc. if (allowed == null) { String message = "<html>Whenever Tetrad's logging features are active any generated log <br>" + "output will be automatically display in Tetrad's log display. Would you like Tetrad<br>" + "to continue to automatically open the log display window whenever there is logging output?</html>"; int option = JOptionPane.showConfirmDialog( this, message, "Automatic Logging", JOptionPane.YES_NO_OPTION); if (option == JOptionPane.NO_OPTION) { JOptionPane.showMessageDialog( this, "This feature can be enabled later by going to Logging>Setup Logging."); } TetradLogger.getInstance().setAutomaticLogDisplayEnabled(option == JOptionPane.YES_OPTION); // return true, so that opens this time, in the future the user's pref will be used. return true; } return allowed; }
/** * Sets whether the display log output should be displayed or not. If true then a text area * roughly 20% of the screen size will appear on the bottom and will display any log output, * otherwise just the standard tetrad workbend is shown. * * @param displayLogging */ public void setDisplayLogging(boolean displayLogging) { if (displayLogging) { this.logArea = new TetradLogArea(this); } else { if (this.logArea != null) { TetradLogger.getInstance().removeOutputStream(this.logArea.getOutputStream()); } this.logArea = null; } setupDesktop(); revalidate(); repaint(); }
/** Constructs a new desktop. */ public TetradDesktop() { setBackground(new Color(204, 204, 204)); sessionNodeKeys = new ArrayList(); // Create the desktop pane. this.desktopPane = new JDesktopPane(); // Do layout. setLayout(new BorderLayout()); desktopPane.setDesktopManager(new DefaultDesktopManager()); desktopPane.setBorder(new BevelBorder(BevelBorder.LOWERED)); desktopPane.addPropertyChangeListener(this); this.setupDesktop(); TetradLogger.getInstance().addTetradLoggerListener(new LoggerListener()); }
/** * Forward equivalence search. * * @param graph The graph in the state prior to the forward equivalence search. */ private void fes(Graph graph, List<Node> nodes) { TetradLogger.getInstance().log("info", "** FORWARD EQUIVALENCE SEARCH"); lookupArrows = new HashMap<OrderedPair, Set<Arrow>>(); initializeArrowsForward(nodes); while (!sortedArrows.isEmpty()) { Arrow arrow = sortedArrows.first(); sortedArrows.remove(arrow); Node x = arrow.getX(); Node y = arrow.getY(); clearArrow(x, y); if (graph.isAdjacentTo(x, y)) { continue; } if (!validInsert(x, y, arrow.getHOrT(), arrow.getNaYX(), graph)) { continue; } List<Node> t = arrow.getHOrT(); double bump = arrow.getBump(); Set<Edge> edges = graph.getEdges(); insert(x, y, t, graph, bump); score += bump; rebuildPattern(graph); // Try to avoid duplicating scoring calls. First clear out all of the edges that need to be // changed, // then change them, checking to see if they're already been changed. I know, roundabout, but // there's // a performance boost. for (Edge edge : graph.getEdges()) { if (!edges.contains(edge)) { reevaluateForward(graph, nodes, edge.getNode1(), edge.getNode2()); } } storeGraph(graph); } }
/** * Determines whether variable x is independent of variable y given a list of conditioning * variables z. * * @param x the one variable being compared. * @param y the second variable being compared. * @param z the list of conditioning variables. * @return true iff x _||_ y | z. * @throws RuntimeException if a matrix singularity is encountered. */ public boolean isIndependent(Node x, Node y, List<Node> z) { TetradMatrix submatrix = subMatrix(x, y, z); double r = 0; try { r = StatUtils.partialCorrelation(submatrix); if (Double.isNaN((r)) || r < -1. || r > 1.) throw new RuntimeException(); } catch (Exception e) { DepthChoiceGenerator gen = new DepthChoiceGenerator(z.size(), z.size()); int[] choice; while ((choice = gen.next()) != null) { try { List<Node> z2 = new ArrayList<Node>(z); z2.removeAll(GraphUtils.asList(choice, z)); submatrix = subMatrix(x, y, z2); r = StatUtils.partialCorrelation(submatrix); } catch (Exception e2) { continue; } // if (Double.isNaN(r)) continue; // // if (r > 1.) r = 1.; // if (r < -1.) r = -1.; if (Double.isNaN(r) || r < -1. || r > 1.) continue; break; } } // Either dividing by a zero standard deviation (in which case it's dependent) or doing a // regression // (effectively) with a multicolliarity if (Double.isNaN(r)) { int[] _z = new int[z.size()]; // for (int i = 0; i < _z.length; i++) _z[i] = i + 2; // //// double varx = StatUtils.partialVariance(submatrix, 0, _z); // submatrix.get(0, // 0); //// double vary = StatUtils.partialVariance(submatrix, 1, _z); //submatrix.get(1, // 1); // // double varx = submatrix.get(0, 0); // double vary = submatrix.get(1, 1); // // if (varx * vary == 0) { return true; // } } if (r > 1.) r = 1.; if (r < -1.) r = -1.; this.fisherZ = Math.sqrt(sampleSize() - z.size() - 3.0) * 0.5 * (Math.log(1.0 + r) - Math.log(1.0 - r)); if (Double.isNaN(this.fisherZ)) { throw new IllegalArgumentException( "The Fisher's Z " + "score for independence fact " + x + " _||_ " + y + " | " + z + " is undefined. r = " + r); } boolean independent = getPValue() > alpha; if (independent) { TetradLogger.getInstance() .log("independencies", SearchLogUtils.independenceFactMsg(x, y, z, getPValue())); } else { TetradLogger.getInstance() .log("dependencies", SearchLogUtils.dependenceFactMsg(x, y, z, getPValue())); } return independent; }
private void addRequiredEdges(Graph graph) { if (true) return; if (knowledgeEmpty()) return; for (Iterator<KnowledgeEdge> it = getKnowledge().requiredEdgesIterator(); it.hasNext(); ) { KnowledgeEdge next = it.next(); Node nodeA = graph.getNode(next.getFrom()); Node nodeB = graph.getNode(next.getTo()); if (!graph.isAncestorOf(nodeB, nodeA)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeA, nodeB); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB)); } } for (Edge edge : graph.getEdges()) { final String A = edge.getNode1().getName(); final String B = edge.getNode2().getName(); if (knowledge.isForbidden(A, B)) { Node nodeA = edge.getNode1(); Node nodeB = edge.getNode2(); if (nodeA != null && nodeB != null && graph.isAdjacentTo(nodeA, nodeB) && !graph.isChildOf(nodeA, nodeB)) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } if (!graph.isChildOf(nodeA, nodeB) && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } } else if (knowledge.isForbidden(B, A)) { Node nodeA = edge.getNode2(); Node nodeB = edge.getNode1(); if (nodeA != null && nodeB != null && graph.isAdjacentTo(nodeA, nodeB) && !graph.isChildOf(nodeA, nodeB)) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } if (!graph.isChildOf(nodeA, nodeB) && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } } } }
/** Do an actual deletion (Definition 13 from Chickering, 2002). */ private void delete(Node x, Node y, List<Node> subset, Graph graph, double bump) { Edge trueEdge = null; if (trueGraph != null) { Node _x = trueGraph.getNode(x.getName()); Node _y = trueGraph.getNode(y.getName()); trueEdge = trueGraph.getEdge(_x, _y); } if (log && verbose) { Edge oldEdge = graph.getEdge(x, y); String label = trueGraph != null && trueEdge != null ? "*" : ""; TetradLogger.getInstance() .log( "deletedEdges", (graph.getNumEdges() - 1) + ". DELETE " + oldEdge + " " + subset + " (" + bump + ") " + label); out.println( (graph.getNumEdges() - 1) + ". DELETE " + oldEdge + " " + subset + " (" + bump + ") " + label); } else { int numEdges = graph.getNumEdges() - 1; if (numEdges % 50 == 0) out.println(numEdges); } graph.removeEdge(x, y); for (Node h : subset) { Edge oldEdge = graph.getEdge(y, h); graph.removeEdge(y, h); graph.addDirectedEdge(y, h); if (log) { TetradLogger.getInstance() .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(y, h)); } if (verbose) { out.println("--- Directing " + oldEdge + " to " + graph.getEdge(y, h)); } if (Edges.isUndirectedEdge(graph.getEdge(x, h))) { if (!graph.isAdjacentTo(x, h)) throw new IllegalArgumentException("Not adjacent: " + x + ", " + h); oldEdge = graph.getEdge(x, h); graph.removeEdge(x, h); graph.addDirectedEdge(x, h); if (log) { TetradLogger.getInstance() .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(x, h)); } if (verbose) { out.println("--- Directing " + oldEdge + " to " + graph.getEdge(x, h)); } } } }
// serial. private void insert(Node x, Node y, List<Node> t, Graph graph, double bump) { if (graph.isAdjacentTo(x, y)) { return; // The initial graph may already have put this edge in the graph. // throw new IllegalArgumentException(x + " and " + y + " are already adjacent in // the graph."); } Edge trueEdge = null; if (trueGraph != null) { Node _x = trueGraph.getNode(x.getName()); Node _y = trueGraph.getNode(y.getName()); trueEdge = trueGraph.getEdge(_x, _y); } graph.addDirectedEdge(x, y); if (log) { String label = trueGraph != null && trueEdge != null ? "*" : ""; TetradLogger.getInstance() .log( "insertedEdges", graph.getNumEdges() + ". INSERT " + graph.getEdge(x, y) + " " + t + " " + bump + " " + label); } else { int numEdges = graph.getNumEdges() - 1; if (verbose) { if (numEdges % 50 == 0) out.println(numEdges); } } if (verbose) { String label = trueGraph != null && trueEdge != null ? "*" : ""; out.println( graph.getNumEdges() + ". INSERT " + graph.getEdge(x, y) + " " + t + " " + bump + " " + label); } else { int numEdges = graph.getNumEdges() - 1; if (verbose) { if (numEdges % 50 == 0) out.println(numEdges); } } for (Node _t : t) { Edge oldEdge = graph.getEdge(_t, y); if (oldEdge == null) throw new IllegalArgumentException("Not adjacent: " + _t + ", " + y); graph.removeEdge(_t, y); graph.addDirectedEdge(_t, y); if (log && verbose) { TetradLogger.getInstance() .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(_t, y)); out.println("--- Directing " + oldEdge + " to " + graph.getEdge(_t, y)); } } }
/** * GesSearch is an implentation of the GES algorithm, as specified in Chickering (2002) "Optimal * structure identification with greedy search" Journal of Machine Learning Research. It works for * both BayesNets and SEMs. * * <p>Some code optimization could be done for the scoring part of the graph for discrete models * (method scoreGraphChange). Some of Andrew Moore's approaches for caching sufficient statistics, * for instance. * * @author Ricardo Silva, Summer 2003 * @author Joseph Ramsey, Revisions 10/2005 */ public final class GesConcurrent implements GraphSearch, GraphScorer { /** The data set, various variable subsets of which are to be scored. */ private DataSet dataSet; /** The covariance matrix for the data set. */ private ICovarianceMatrix covariances; /** Sample size, either from the data set or from the variances. */ private int sampleSize; /** Specification of forbidden and required edges. */ private IKnowledge knowledge = new Knowledge2(); /** Map from variables to their column indices in the data set. */ private HashMap<Node, Integer> hashIndices; /** List of variables in the data set, in order. */ private List<Node> variables; /** True iff the data set is discrete. */ private boolean discrete; /** * The true graph, if known. If this is provided, asterisks will be printed out next to false * positive added edges (that is, edges added that aren't adjacencies in the true graph). */ private Graph trueGraph; /** An initial graph to start from. */ private Graph initialGraph; /** Caches scores for discrete search. */ private final LocalScoreCache localScoreCache = new LocalScoreCache(); /** Elapsed time of the most recent search. */ private long elapsedTime; /** * True if cycles are to be aggressively prevented. May be expensive for large graphs (but also * useful for large graphs). */ private boolean aggressivelyPreventCycles = false; /** Listeners for graph change events. */ private transient List<PropertyChangeListener> listeners; /** Penalty discount--the BIC penalty is multiplied by this (for continuous variables). */ private double penaltyDiscount = 1.0; /** The score for discrete searches. */ private LocalDiscreteScore discreteScore; /** The logger for this class. The config needs to be set. */ private TetradLogger logger = TetradLogger.getInstance(); /** The top n graphs found by the algorithm, where n is numPatternsToStore. */ private SortedSet<ScoredGraph> topGraphs = new TreeSet<ScoredGraph>(); /** The number of top patterns to store. */ private int numPatternsToStore = 0; // Potential arrows sorted by bump high to low. The first one is a candidate for adding to the // graph. private SortedSet<Arrow> sortedArrows = new ConcurrentSkipListSet<Arrow>(); // Arrows added to sortedArrows for each <i, j>. private Map<OrderedPair, Set<Arrow>> lookupArrows; /** True if graphs should be stored. */ private boolean log = true; private boolean verbose = false; private int NTHREADS = Runtime.getRuntime().availableProcessors() * 5; // private boolean checkedKnowledgeEmpty = false; // private boolean knowledgeEmpty = false; private ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool(); private double score; private PrintStream out; // ===========================CONSTRUCTORS=============================// public GesConcurrent(DataSet dataSet) { setDataSet(dataSet); if (dataSet.isDiscrete()) { BDeuScore score = new BDeuScore(dataSet); score.setSamplePrior(10); score.setStructurePrior(0.001); } setStructurePrior(0.001); setSamplePrior(10.); } public GesConcurrent(ICovarianceMatrix covMatrix) { setCovMatrix(covMatrix); setStructurePrior(0.001); setSamplePrior(10.); } // ==========================PUBLIC METHODS==========================// public boolean isAggressivelyPreventCycles() { return this.aggressivelyPreventCycles; } public void setAggressivelyPreventCycles(boolean aggressivelyPreventCycles) { this.aggressivelyPreventCycles = aggressivelyPreventCycles; } /** * Greedy equivalence search: Start from the empty graph, add edges till model is significant. * Then start deleting edges till a minimum is achieved. * * @return the resulting Pattern. */ public Graph search() { Graph graph; if (initialGraph == null) { graph = new EdgeListGraphSingleConnections(getVariables()); } else { graph = new EdgeListGraphSingleConnections(initialGraph); } fireGraphChange(graph); buildIndexing(graph); addRequiredEdges(graph); topGraphs.clear(); storeGraph(graph); List<Node> nodes = graph.getNodes(); long start = System.currentTimeMillis(); score = 0.0; // Do forward search. fes(graph, nodes); // Do backward search. bes(graph); long endTime = System.currentTimeMillis(); this.elapsedTime = endTime - start; this.logger.log("graph", "\nReturning this graph: " + graph); this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s"); this.logger.flush(); return graph; } public Graph search(List<Node> nodes) { long startTime = System.currentTimeMillis(); localScoreCache.clear(); if (!dataSet().getVariables().containsAll(nodes)) { throw new IllegalArgumentException("All of the nodes must be in " + "the supplied data set."); } Graph graph; if (initialGraph == null) { graph = new EdgeListGraphSingleConnections(nodes); } else { initialGraph = GraphUtils.replaceNodes(initialGraph, variables); graph = new EdgeListGraphSingleConnections(initialGraph); } topGraphs.clear(); buildIndexing(graph); addRequiredEdges(graph); score = 0.0; // Do forward search. fes(graph, nodes); // Do backward search. bes(graph); long endTime = System.currentTimeMillis(); this.elapsedTime = endTime - startTime; this.logger.log("graph", "\nReturning this graph: " + graph); this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s"); this.logger.flush(); return graph; } public IKnowledge getKnowledge() { return knowledge; } /** * Sets the background knowledge. * * @param knowledge the knowledge object, specifying forbidden and required edges. */ public void setKnowledge(IKnowledge knowledge) { if (knowledge == null) throw new NullPointerException(); this.knowledge = knowledge; } public void setStructurePrior(double structurePrior) { if (getDiscreteScore() != null) { getDiscreteScore().setStructurePrior(structurePrior); } } public void setSamplePrior(double samplePrior) { if (getDiscreteScore() != null) { getDiscreteScore().setSamplePrior(samplePrior); } } public long getElapsedTime() { return elapsedTime; } public void addPropertyChangeListener(PropertyChangeListener l) { getListeners().add(l); } public double getPenaltyDiscount() { return penaltyDiscount; } public void setPenaltyDiscount(double penaltyDiscount) { if (penaltyDiscount < 0) { throw new IllegalArgumentException("Penalty discount must be >= 0: " + penaltyDiscount); } this.penaltyDiscount = penaltyDiscount; } public void setTrueGraph(Graph trueGraph) { this.trueGraph = trueGraph; } public double getScore(Graph dag) { return scoreDag(dag); } public SortedSet<ScoredGraph> getTopGraphs() { return topGraphs; } public int getNumPatternsToStore() { return numPatternsToStore; } public void setNumPatternsToStore(int numPatternsToStore) { if (numPatternsToStore < 0) { throw new IllegalArgumentException( "# graphs to store must at least 0: " + numPatternsToStore); } this.numPatternsToStore = numPatternsToStore; } public LocalDiscreteScore getDiscreteScore() { return discreteScore; } public void setDiscreteScore(LocalDiscreteScore discreteScore) { if (discreteScore.getDataSet() != dataSet) { throw new IllegalArgumentException("Must use the same data set."); } this.discreteScore = discreteScore; } // ===========================PRIVATE METHODS========================// /** * Forward equivalence search. * * @param graph The graph in the state prior to the forward equivalence search. */ private void fes(Graph graph, List<Node> nodes) { TetradLogger.getInstance().log("info", "** FORWARD EQUIVALENCE SEARCH"); lookupArrows = new HashMap<OrderedPair, Set<Arrow>>(); initializeArrowsForward(nodes); while (!sortedArrows.isEmpty()) { Arrow arrow = sortedArrows.first(); sortedArrows.remove(arrow); Node x = arrow.getX(); Node y = arrow.getY(); clearArrow(x, y); if (graph.isAdjacentTo(x, y)) { continue; } if (!validInsert(x, y, arrow.getHOrT(), arrow.getNaYX(), graph)) { continue; } List<Node> t = arrow.getHOrT(); double bump = arrow.getBump(); Set<Edge> edges = graph.getEdges(); insert(x, y, t, graph, bump); score += bump; rebuildPattern(graph); // Try to avoid duplicating scoring calls. First clear out all of the edges that need to be // changed, // then change them, checking to see if they're already been changed. I know, roundabout, but // there's // a performance boost. for (Edge edge : graph.getEdges()) { if (!edges.contains(edge)) { reevaluateForward(graph, nodes, edge.getNode1(), edge.getNode2()); } } storeGraph(graph); } } private void bes(Graph graph) { TetradLogger.getInstance().log("info", "** BACKWARD EQUIVALENCE SEARCH"); initializeArrowsBackward(graph); while (!sortedArrows.isEmpty()) { Arrow arrow = sortedArrows.first(); sortedArrows.remove(arrow); Node x = arrow.getX(); Node y = arrow.getY(); clearArrow(x, y); if (!validDelete(arrow.getHOrT(), arrow.getNaYX(), graph)) { continue; } List<Node> h = arrow.getHOrT(); double bump = arrow.getBump(); delete(x, y, h, graph, bump); score += bump; rebuildPattern(graph); storeGraph(graph); initializeArrowsBackward( graph); // Rebuilds Arrows from scratch each time. Fast enough for backwards. } } // Expensive // Concurrent. private void initializeArrowsForward(final List<Node> nodes) { final List<Node> emptyList = new ArrayList<Node>(0); final Set<Node> emptySet = new HashSet<Node>(0); List<Callable<Boolean>> callables = new ArrayList<Callable<Boolean>>(); for (int t = 0; t < NTHREADS; t++) { final int _t = t; Callable<Boolean> worker = new Callable<Boolean>() { @Override public Boolean call() throws Exception { int chunk = nodes.size() / NTHREADS + 1; for (int j = _t * chunk; j < Math.min((_t + 1) * chunk, nodes.size()); j++) { if (log && verbose) { if ((j + 1) % 10 == 0) out.println("Initializing arrows forward: " + (j + 1)); } for (int i = j + 1; i < nodes.size(); i++) { Node x = nodes.get(i); Node y = nodes.get(j); if (!knowledgeEmpty()) { if (getKnowledge().isForbidden(x.getName(), y.getName())) { continue; } if (!validSetByKnowledge(y, emptyList)) { continue; } } double bump = scoreGraphChange(y, Collections.singleton(x), emptySet); if (bump > 0.0) { Arrow arrow = new Arrow(bump, x, y, emptyList, emptyList); sortedArrows.add(arrow); addLookupArrow(x, y, arrow); Arrow arrow2 = new Arrow(bump, y, x, emptyList, emptyList); sortedArrows.add(arrow2); addLookupArrow(y, x, arrow); } } } return true; } }; callables.add(worker); } pool.invokeAll(callables); } private boolean knowledgeEmpty() { // if (!checkedKnowledgeEmpty) { // knowledgeEmpty = knowledge.isEmpty(); // checkedKnowledgeEmpty = true; // } // // return knowledgeEmpty; // return knowledge.isEmpty(); return false; } private void initializeArrowsBackward(Graph graph) { sortedArrows.clear(); lookupArrows.clear(); for (Edge edge : graph.getEdges()) { Node x = edge.getNode1(); Node y = edge.getNode2(); if (!knowledgeEmpty()) { if (!getKnowledge().noEdgeRequired(x.getName(), y.getName())) { continue; } } if (Edges.isDirectedEdge(edge)) { calculateArrowsBackward(x, y, graph); } else { calculateArrowsBackward(x, y, graph); calculateArrowsBackward(y, x, graph); } } } private void reevaluateForward( final Graph graph, final List<Node> nodes, final Node x, final Node y) { List<Callable<Boolean>> callables = new ArrayList<Callable<Boolean>>(); for (int t = 0; t < NTHREADS; t++) { final int _t = t; Callable<Boolean> worker = new Callable<Boolean>() { @Override public Boolean call() { int chunk = nodes.size() / NTHREADS + 1; for (int _w = _t * chunk; _w < Math.min((_t + 1) * chunk, nodes.size()); _w++) { final Node w = nodes.get(_w); if (w == x) continue; if (w == y) continue; if (!graph.isAdjacentTo(w, x)) { calculateArrowsForward(w, x, graph); if (graph.isAdjacentTo(w, y)) { calculateArrowsForward(x, w, graph); } } if (!graph.isAdjacentTo(w, y)) { calculateArrowsForward(w, y, graph); if (graph.isAdjacentTo(w, x)) { calculateArrowsForward(y, w, graph); } } } return true; } }; callables.add(worker); } pool.invokeAll(callables); } private void calculateArrowsForward(Node x, Node y, Graph graph) { clearArrow(x, y); if (!knowledgeEmpty()) { if (getKnowledge().isForbidden(x.getName(), y.getName())) { return; } } List<Node> naYX = getNaYX(x, y, graph); List<Node> t = getTNeighbors(x, y, graph); DepthChoiceGenerator gen = new DepthChoiceGenerator(t.size(), t.size()); int[] choice; while ((choice = gen.next()) != null) { List<Node> s = GraphUtils.asList(choice, t); if (!knowledgeEmpty()) { if (!validSetByKnowledge(y, s)) { continue; } } double bump = insertEval(x, y, s, naYX, graph); if (bump > 0.0) { Arrow arrow = new Arrow(bump, x, y, s, naYX); sortedArrows.add(arrow); addLookupArrow(x, y, arrow); } } } // Invalid if then nodes or graph changes. private void calculateArrowsBackward(Node x, Node y, Graph graph) { if (x == y) { return; } if (!graph.isAdjacentTo(x, y)) { return; } if (!knowledgeEmpty()) { if (!getKnowledge().noEdgeRequired(x.getName(), y.getName())) { return; } } List<Node> naYX = getNaYX(x, y, graph); clearArrow(x, y); List<Node> _naYX = new ArrayList<Node>(naYX); DepthChoiceGenerator gen = new DepthChoiceGenerator(_naYX.size(), _naYX.size()); int[] choice; while ((choice = gen.next()) != null) { List<Node> H = GraphUtils.asList(choice, _naYX); if (!knowledgeEmpty()) { if (!validSetByKnowledge(y, H)) { continue; } } double bump = deleteEval(x, y, H, naYX, graph); if (bump > 0.0) { Arrow arrow = new Arrow(bump, x, y, H, naYX); sortedArrows.add(arrow); addLookupArrow(x, y, arrow); } } } /** True iff log output should be produced. */ public boolean isLog() { return log; } public void setLog(boolean log) { this.log = log; } public Graph getInitialGraph() { return initialGraph; } // Cannot be done if the graph changes. public void setInitialGraph(Graph initialGraph) { initialGraph = GraphUtils.replaceNodes(initialGraph, variables); out.println("Initial graph variables: " + initialGraph.getNodes()); out.println("Data set variables: " + variables); if (!new HashSet<Node>(initialGraph.getNodes()).equals(new HashSet<Node>(variables))) { throw new IllegalArgumentException("Variables aren't the same."); } this.initialGraph = initialGraph; } public void setVerbose(boolean verbose) { this.verbose = verbose; } public void setOut(PrintStream out) { this.out = out; } public PrintStream getOut() { return out; } // Concurrent OK. private static class Arrow implements Comparable { private double bump; private Node x; private Node y; private List<Node> hOrT; private List<Node> naYX; public Arrow(double bump, Node x, Node y, List<Node> hOrT, List<Node> naYX) { this.bump = bump; this.x = x; this.y = y; this.hOrT = hOrT; this.naYX = naYX; } public double getBump() { return bump; } public Node getX() { return x; } public Node getY() { return y; } public List<Node> getHOrT() { return hOrT; } public List<Node> getNaYX() { return naYX; } // Sorting is by bump, high to low. public int compareTo(Object o) { Arrow arrow = (Arrow) o; return Double.compare(arrow.getBump(), getBump()); } public String toString() { return "Arrow<" + x + "->" + y + " bump = " + bump + " t/h = " + hOrT + " naYX = " + naYX + ">"; } } /** Get all nodes that are connected to Y by an undirected edge and not adjacent to X. */ private static List<Node> getTNeighbors(Node x, Node y, Graph graph) { List<Edge> yEdges = graph.getEdges(y); List<Node> tNeighbors = new ArrayList<Node>(); for (Edge edge : yEdges) { if (!Edges.isUndirectedEdge(edge)) { continue; } Node z = edge.getDistalNode(y); if (graph.isAdjacentTo(z, x)) { continue; } tNeighbors.add(z); } return tNeighbors; } /** Evaluate the Insert(X, Y, T) operator (Definition 12 from Chickering, 2002). */ private double insertEval(Node x, Node y, List<Node> t, List<Node> naYX, Graph graph) { Set<Node> set1 = new HashSet<Node>(naYX); set1.addAll(t); List<Node> paY = graph.getParents(y); set1.addAll(paY); Set<Node> set2 = new HashSet<Node>(set1); set1.add(x); return scoreGraphChange(y, set1, set2); } /** Evaluate the Delete(X, Y, T) operator (Definition 12 from Chickering, 2002). */ // Can be done concurrently. private double deleteEval(Node x, Node y, List<Node> h, List<Node> naYX, Graph graph) { List<Node> paY = graph.getParents(y); Set<Node> paYMinuxX = new HashSet<Node>(paY); paYMinuxX.remove(x); Set<Node> set1 = new HashSet<Node>(naYX); set1.removeAll(h); set1.addAll(paYMinuxX); Set<Node> set2 = new HashSet<Node>(naYX); set2.removeAll(h); set2.addAll(paY); return scoreGraphChange(y, set1, set2); } /* * Do an actual insertion * (Definition 12 from Chickering, 2002). **/ // serial. private void insert(Node x, Node y, List<Node> t, Graph graph, double bump) { if (graph.isAdjacentTo(x, y)) { return; // The initial graph may already have put this edge in the graph. // throw new IllegalArgumentException(x + " and " + y + " are already adjacent in // the graph."); } Edge trueEdge = null; if (trueGraph != null) { Node _x = trueGraph.getNode(x.getName()); Node _y = trueGraph.getNode(y.getName()); trueEdge = trueGraph.getEdge(_x, _y); } graph.addDirectedEdge(x, y); if (log) { String label = trueGraph != null && trueEdge != null ? "*" : ""; TetradLogger.getInstance() .log( "insertedEdges", graph.getNumEdges() + ". INSERT " + graph.getEdge(x, y) + " " + t + " " + bump + " " + label); } else { int numEdges = graph.getNumEdges() - 1; if (verbose) { if (numEdges % 50 == 0) out.println(numEdges); } } if (verbose) { String label = trueGraph != null && trueEdge != null ? "*" : ""; out.println( graph.getNumEdges() + ". INSERT " + graph.getEdge(x, y) + " " + t + " " + bump + " " + label); } else { int numEdges = graph.getNumEdges() - 1; if (verbose) { if (numEdges % 50 == 0) out.println(numEdges); } } for (Node _t : t) { Edge oldEdge = graph.getEdge(_t, y); if (oldEdge == null) throw new IllegalArgumentException("Not adjacent: " + _t + ", " + y); graph.removeEdge(_t, y); graph.addDirectedEdge(_t, y); if (log && verbose) { TetradLogger.getInstance() .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(_t, y)); out.println("--- Directing " + oldEdge + " to " + graph.getEdge(_t, y)); } } } /** Do an actual deletion (Definition 13 from Chickering, 2002). */ private void delete(Node x, Node y, List<Node> subset, Graph graph, double bump) { Edge trueEdge = null; if (trueGraph != null) { Node _x = trueGraph.getNode(x.getName()); Node _y = trueGraph.getNode(y.getName()); trueEdge = trueGraph.getEdge(_x, _y); } if (log && verbose) { Edge oldEdge = graph.getEdge(x, y); String label = trueGraph != null && trueEdge != null ? "*" : ""; TetradLogger.getInstance() .log( "deletedEdges", (graph.getNumEdges() - 1) + ". DELETE " + oldEdge + " " + subset + " (" + bump + ") " + label); out.println( (graph.getNumEdges() - 1) + ". DELETE " + oldEdge + " " + subset + " (" + bump + ") " + label); } else { int numEdges = graph.getNumEdges() - 1; if (numEdges % 50 == 0) out.println(numEdges); } graph.removeEdge(x, y); for (Node h : subset) { Edge oldEdge = graph.getEdge(y, h); graph.removeEdge(y, h); graph.addDirectedEdge(y, h); if (log) { TetradLogger.getInstance() .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(y, h)); } if (verbose) { out.println("--- Directing " + oldEdge + " to " + graph.getEdge(y, h)); } if (Edges.isUndirectedEdge(graph.getEdge(x, h))) { if (!graph.isAdjacentTo(x, h)) throw new IllegalArgumentException("Not adjacent: " + x + ", " + h); oldEdge = graph.getEdge(x, h); graph.removeEdge(x, h); graph.addDirectedEdge(x, h); if (log) { TetradLogger.getInstance() .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(x, h)); } if (verbose) { out.println("--- Directing " + oldEdge + " to " + graph.getEdge(x, h)); } } } } /* * Test if the candidate insertion is a valid operation * (Theorem 15 from Chickering, 2002). **/ private boolean validInsert(Node x, Node y, List<Node> t, List<Node> naYX, Graph graph) { List<Node> union = new ArrayList<Node>(t); // t and nayx are disjoint union.addAll(naYX); return isClique(union, graph) && !existsUnblockedSemiDirectedPath(y, x, union, graph); } /** Test if the candidate deletion is a valid operation (Theorem 17 from Chickering, 2002). */ private static boolean validDelete(List<Node> h, List<Node> naXY, Graph graph) { List<Node> list = new ArrayList<Node>(naXY); list.removeAll(h); return isClique(list, graph); } // ---Background knowledge methods. private void addRequiredEdges(Graph graph) { if (true) return; if (knowledgeEmpty()) return; for (Iterator<KnowledgeEdge> it = getKnowledge().requiredEdgesIterator(); it.hasNext(); ) { KnowledgeEdge next = it.next(); Node nodeA = graph.getNode(next.getFrom()); Node nodeB = graph.getNode(next.getTo()); if (!graph.isAncestorOf(nodeB, nodeA)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeA, nodeB); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB)); } } for (Edge edge : graph.getEdges()) { final String A = edge.getNode1().getName(); final String B = edge.getNode2().getName(); if (knowledge.isForbidden(A, B)) { Node nodeA = edge.getNode1(); Node nodeB = edge.getNode2(); if (nodeA != null && nodeB != null && graph.isAdjacentTo(nodeA, nodeB) && !graph.isChildOf(nodeA, nodeB)) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } if (!graph.isChildOf(nodeA, nodeB) && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } } else if (knowledge.isForbidden(B, A)) { Node nodeA = edge.getNode2(); Node nodeB = edge.getNode1(); if (nodeA != null && nodeB != null && graph.isAdjacentTo(nodeA, nodeB) && !graph.isChildOf(nodeA, nodeB)) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } if (!graph.isChildOf(nodeA, nodeB) && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) { if (!graph.isAncestorOf(nodeA, nodeB)) { graph.removeEdges(nodeA, nodeB); graph.addDirectedEdge(nodeB, nodeA); TetradLogger.getInstance() .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA)); } } } } } private String getString(KnowledgeEdge next) { return next.getTo(); } /** * Use background knowledge to decide if an insert or delete operation does not orient edges in a * forbidden direction according to prior knowledge. If some orientation is forbidden in the * subset, the whole subset is forbidden. */ private boolean validSetByKnowledge(Node y, List<Node> subset) { for (Node node : subset) { if (getKnowledge().isForbidden(node.getName(), y.getName())) { return false; } } return true; } // --Auxiliary methods. /** * Find all nodes that are connected to Y by an undirected edge that are adjacent to X (that is, * by undirected or directed edge). */ private static List<Node> getNaYX(Node x, Node y, Graph graph) { List<Edge> yEdges = graph.getEdges(y); List<Node> nayx = new ArrayList<Node>(); for (Edge edge : yEdges) { if (!Edges.isUndirectedEdge(edge)) { continue; } Node z = edge.getDistalNode(y); if (!graph.isAdjacentTo(z, x)) { continue; } nayx.add(z); } return nayx; } /** Returns true iif the given set forms a clique in the given graph. */ private static boolean isClique(List<Node> nodes, Graph graph) { for (int i = 0; i < nodes.size() - 1; i++) { for (int j = i + 1; j < nodes.size(); j++) { if (!graph.isAdjacentTo(nodes.get(i), nodes.get(j))) { return false; } } } return true; } private boolean existsUnblockedSemiDirectedPath(Node from, Node to, List<Node> cond, Graph G) { Queue<Node> Q = new LinkedList<Node>(); Set<Node> V = new HashSet<Node>(); Q.offer(from); V.add(from); while (!Q.isEmpty()) { Node t = Q.remove(); if (t == to) return true; for (Node u : G.getAdjacentNodes(t)) { Edge edge = G.getEdge(t, u); Node c = Edges.traverseSemiDirected(t, edge); if (c == null) continue; if (cond.contains(c)) continue; if (c == to) return true; if (!V.contains(c)) { V.add(c); Q.offer(c); } } } return false; } /** * Completes a pattern that was modified by an insertion/deletion operator Based on the algorithm * described on Appendix C of (Chickering, 2002). */ private void rebuildPattern(Graph graph) { SearchGraphUtils.basicPattern(graph, false); addRequiredEdges(graph); meekOrient(graph, getKnowledge()); if (TetradLogger.getInstance().isEventActive("rebuiltPatterns")) { TetradLogger.getInstance().log("rebuiltPatterns", "Rebuilt pattern = " + graph); } } private Graph pickDag(Graph graph) { SearchGraphUtils.basicPattern(graph, false); addRequiredEdges(graph); boolean containsUndirected; do { containsUndirected = false; for (Edge edge : graph.getEdges()) { if (Edges.isUndirectedEdge(edge)) { containsUndirected = true; graph.removeEdge(edge); Edge _edge = Edges.directedEdge(edge.getNode1(), edge.getNode2()); graph.addEdge(_edge); } } meekOrient(graph, getKnowledge()); } while (containsUndirected); return graph; } /** * Fully direct a graph with background knowledge. I am not sure how to adapt Chickering's * suggested algorithm above (dagToPdag) to incorporate background knowledge, so I am also * implementing this algorithm based on Meek's 1995 UAI paper. Notice it is the same implemented * in PcSearch. *IMPORTANT!* *It assumes all colliders are oriented, as well as arrows dictated by * time order.* */ private void meekOrient(Graph graph, IKnowledge knowledge) { MeekRules rules = new MeekRules(); rules.setOrientInPlace(false); rules.setKnowledge(knowledge); rules.orientImplied(graph); } private void setDataSet(DataSet dataSet) { List<String> _varNames = dataSet.getVariableNames(); this.variables = dataSet.getVariables(); this.dataSet = dataSet; this.discrete = dataSet.isDiscrete(); if (!isDiscrete()) { this.covariances = new CovarianceMatrix(dataSet); } this.sampleSize = dataSet.getNumRows(); } private void setCovMatrix(ICovarianceMatrix covarianceMatrix) { this.covariances = covarianceMatrix; this.variables = covarianceMatrix.getVariables(); this.sampleSize = covarianceMatrix.getSampleSize(); } private void buildIndexing(Graph graph) { this.hashIndices = new HashMap<Node, Integer>(); for (Node node : graph.getNodes()) { this.hashIndices.put(node, variables.indexOf(node)); } } private void clearArrow(Node x, Node y) { final OrderedPair<Node> pair = new OrderedPair<Node>(x, y); final Set<Arrow> lookupArrows = this.lookupArrows.get(pair); if (lookupArrows != null) { sortedArrows.removeAll(lookupArrows); } this.lookupArrows.remove(pair); } private void addLookupArrow(Node i, Node j, Arrow arrow) { OrderedPair<Node> pair = new OrderedPair<Node>(i, j); Set<Arrow> arrows = lookupArrows.get(pair); if (arrows == null) { arrows = new HashSet<Arrow>(); lookupArrows.put(pair, arrows); } arrows.add(arrow); } // ===========================SCORING METHODS===================// public double scoreDag(Graph graph) { Graph dag = new EdgeListGraphSingleConnections(graph); buildIndexing(graph); double score = 0.0; for (Node y : dag.getNodes()) { Set<Node> parents = new HashSet<Node>(dag.getParents(y)); int nextIndex = -1; for (int i = 0; i < getVariables().size(); i++) { nextIndex = hashIndices.get(variables.get(i)); } int parentIndices[] = new int[parents.size()]; Iterator<Node> pi = parents.iterator(); int count = 0; while (pi.hasNext()) { Node nextParent = pi.next(); parentIndices[count++] = hashIndices.get(nextParent); } if (this.isDiscrete()) { score += localDiscreteScore(nextIndex, parentIndices); } else { score += localSemScore(nextIndex, parentIndices); } } return score; } private double scoreGraphChange(Node y, Set<Node> parents1, Set<Node> parents2) { int yIndex = hashIndices.get(y); double score1, score2; int[] parentIndices1 = new int[parents1.size()]; int count = -1; for (Node parent : parents1) { parentIndices1[++count] = hashIndices.get(parent); } if (isDiscrete()) { score1 = localDiscreteScore(yIndex, parentIndices1); } else { score1 = localSemScore(yIndex, parentIndices1); } int[] parentIndices2 = new int[parents2.size()]; int count2 = -1; for (Node parent : parents2) { parentIndices2[++count2] = hashIndices.get(parent); } if (isDiscrete()) { score2 = localDiscreteScore(yIndex, parentIndices2); } else { score2 = localSemScore(yIndex, parentIndices2); } return score1 - score2; } /** Compute the local BDeu score of (i, parents(i)). See (Chickering, 2002). */ private double localDiscreteScore(int i, int parents[]) { return getDiscreteScore().localScore(i, parents); } /** * Calculates the sample likelihood and BIC score for i given its parents in a simple SEM model. */ private double localSemScore(int i, int[] parents) { try { ICovarianceMatrix cov = getCovMatrix(); double varianceY = cov.getValue(i, i); double residualVariance = varianceY; int n = sampleSize(); int p = parents.length; int k = (p * (p + 1)) / 2 + p; // int k = (p + 1) * (p + 1); // int k = p + 1; TetradMatrix covxx = cov.getSelection(parents, parents); TetradMatrix covxxInv = covxx.inverse(); TetradVector covxy = cov.getSelection(parents, new int[] {i}).getColumn(0); TetradVector b = covxxInv.times(covxy); residualVariance -= covxy.dotProduct(b); if (residualVariance <= 0 && verbose) { out.println( "Nonpositive residual varianceY: resVar / varianceY = " + (residualVariance / varianceY)); return Double.NaN; } double c = getPenaltyDiscount(); // return -n * log(residualVariance) - 2 * k; //AIC return -n * Math.log(residualVariance) - c * k * Math.log(n); // return -n * log(residualVariance) - c * k * (log(n) - log(2 * PI)); } catch (Exception e) { e.printStackTrace(); throw new RuntimeException(e); // throwMinimalLinearDependentSet(parents, cov); } } // private void throwMinimalLinearDependentSet(int[] parents, TetradMatrix cov) { // List<Node> _parents = new ArrayList<Node>(); // for (int p : parents) _parents.add(variables.get(p)); // // DepthChoiceGenerator gen = new DepthChoiceGenerator(_parents.size(), _parents.size()); // int[] choice; // // while ((choice = gen.next()) != null) { // int[] sel = new int[choice.length]; // List<Node> _sel = new ArrayList<Node>(); // for (int m = 0; m < choice.length; m++) { // sel[m] = parents[m]; // _sel.add(variables.get(sel[m])); // } // // TetradMatrix m = cov.getSelection(sel, sel); // // try { // m.inverse(); // } catch (Exception e2) { // throw new RuntimeException("Linear dependence among variables: " + _sel); // } // } // } private int sampleSize() { return this.sampleSize; } private List<Node> getVariables() { return variables; } private ICovarianceMatrix getCovMatrix() { return covariances; } private DataSet dataSet() { return dataSet; } private boolean isDiscrete() { return discrete; } private void fireGraphChange(Graph graph) { for (PropertyChangeListener l : getListeners()) { l.propertyChange(new PropertyChangeEvent(this, "graph", null, graph)); } } private List<PropertyChangeListener> getListeners() { if (listeners == null) { listeners = new ArrayList<PropertyChangeListener>(); } return listeners; } private void storeGraph(Graph graph) { if (numPatternsToStore < 1) return; if (topGraphs.isEmpty() || score > topGraphs.first().getScore()) { Graph graphCopy = new EdgeListGraphSingleConnections(graph); topGraphs.add(new ScoredGraph(graphCopy, score)); if (topGraphs.size() > getNumPatternsToStore()) { topGraphs.remove(topGraphs.first()); } } } }
private void resolveOneEdgeMax(Graph graph, Node x, Node y, boolean strong, Graph oldGraph) { if (RandomUtil.getInstance().nextDouble() > 0.5) { Node temp = x; x = y; y = temp; } TetradLogger.getInstance().log("info", "\nEDGE " + x + " --- " + y); SortedMap<Double, String> scoreReports = new TreeMap<Double, String>(); List<Node> neighborsx = graph.getAdjacentNodes(x); neighborsx.remove(y); double max = Double.NEGATIVE_INFINITY; boolean left = false; boolean right = false; DepthChoiceGenerator genx = new DepthChoiceGenerator(neighborsx.size(), neighborsx.size()); int[] choicex; while ((choicex = genx.next()) != null) { List<Node> condxMinus = GraphUtils.asList(choicex, neighborsx); List<Node> condxPlus = new ArrayList<Node>(condxMinus); condxPlus.add(y); double xPlus = score(x, condxPlus); double xMinus = score(x, condxMinus); List<Node> neighborsy = graph.getAdjacentNodes(y); neighborsy.remove(x); DepthChoiceGenerator geny = new DepthChoiceGenerator(neighborsy.size(), neighborsy.size()); int[] choicey; while ((choicey = geny.next()) != null) { List<Node> condyMinus = GraphUtils.asList(choicey, neighborsy); // List<Node> parentsY = oldGraph.getParents(y); // parentsY.remove(x); // if (!condyMinus.containsAll(parentsY)) { // continue; // } List<Node> condyPlus = new ArrayList<Node>(condyMinus); condyPlus.add(x); double yPlus = score(y, condyPlus); double yMinus = score(y, condyMinus); // Checking them all at once is expensive but avoids lexical ordering problems in the // algorithm. if (normal(y, condyPlus) || normal(x, condxMinus) || normal(x, condxPlus) || normal(y, condyMinus)) { continue; } double delta = 0.0; if (strong) { if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(xPlus, yMinus); if (yPlus <= yMinus + delta && xMinus <= xPlus + delta) { StringBuilder builder = new StringBuilder(); builder.append("\nStrong " + y + "->" + x + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = true; right = false; } } else { StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); if (yMinus <= yPlus + delta && xPlus <= xMinus + delta) { StringBuilder builder = new StringBuilder(); builder.append("\nStrong " + x + "->" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = false; right = true; } } else { StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } else { if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(xPlus, yMinus); StringBuilder builder = new StringBuilder(); builder.append("\nWeak " + y + "->" + x + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = true; right = false; } } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nWeak " + x + "->" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); if (score > max) { max = score; left = false; right = true; } } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) { double score = combinedScore(yPlus, xMinus); StringBuilder builder = new StringBuilder(); builder.append("\nNo directed edge " + x + "--" + y + " " + score); builder.append("\n Parents(" + x + ") = " + condxMinus); builder.append("\n Parents(" + y + ") = " + condyMinus); scoreReports.put(-score, builder.toString()); } } } } for (double score : scoreReports.keySet()) { TetradLogger.getInstance().log("info", scoreReports.get(score)); } graph.removeEdges(x, y); if (left) { graph.addDirectedEdge(y, x); } if (right) { graph.addDirectedEdge(x, y); } if (!graph.isAdjacentTo(x, y)) { graph.addUndirectedEdge(x, y); } }