private Graph condense(Graph mimStructure, Graph mimbuildStructure) { // System.out.println("Uncondensed: " + mimbuildStructure); Map<Node, Node> substitutions = new HashMap<Node, Node>(); for (Node node : mimbuildStructure.getNodes()) { for (Node _node : mimStructure.getNodes()) { if (node.getName().startsWith(_node.getName())) { substitutions.put(node, _node); break; } substitutions.put(node, node); } } HashSet<Node> nodes = new HashSet<Node>(substitutions.values()); Graph graph = new EdgeListGraph(new ArrayList<Node>(nodes)); for (Edge edge : mimbuildStructure.getEdges()) { Node node1 = substitutions.get(edge.getNode1()); Node node2 = substitutions.get(edge.getNode2()); if (node1 == node2) continue; if (graph.isAdjacentTo(node1, node2)) continue; graph.addEdge(new Edge(node1, node2, edge.getEndpoint1(), edge.getEndpoint2())); } // System.out.println("Condensed: " + graph); return graph; }
public Graph orient() { Graph skeleton = GraphUtils.undirectedGraph(getPattern()); Graph graph = new EdgeListGraph(skeleton.getNodes()); List<Node> nodes = skeleton.getNodes(); // Collections.shuffle(nodes); if (isR1Done()) { ruleR1(skeleton, graph, nodes); } for (Edge edge : skeleton.getEdges()) { if (!graph.isAdjacentTo(edge.getNode1(), edge.getNode2())) { graph.addUndirectedEdge(edge.getNode1(), edge.getNode2()); } } if (isR2Done()) { ruleR2(skeleton, graph); } if (isMeekDone()) { new MeekRules().orientImplied(graph); } return graph; }
// Cannot be done if the graph changes. public void setInitialGraph(Graph initialGraph) { initialGraph = GraphUtils.replaceNodes(initialGraph, variables); out.println("Initial graph variables: " + initialGraph.getNodes()); out.println("Data set variables: " + variables); if (!new HashSet<Node>(initialGraph.getNodes()).equals(new HashSet<Node>(variables))) { throw new IllegalArgumentException("Variables aren't the same."); } this.initialGraph = initialGraph; }
//////////////////////////////////////////////// // collect in rTupleList all unshielded tuples //////////////////////////////////////////////// private List<Node[]> getRTuples() { List<Node[]> rTuples = new ArrayList<Node[]>(); List<Node> nodes = graph.getNodes(); for (Node j : nodes) { List<Node> adjacentNodes = graph.getAdjacentNodes(j); if (adjacentNodes.size() < 2) { continue; } ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); int[] combination; while ((combination = cg.next()) != null) { Node i = adjacentNodes.get(combination[0]); Node k = adjacentNodes.get(combination[1]); // Skip triples that are shielded. if (!graph.isAdjacentTo(i, k)) { Node[] newTuple = {i, j, k}; rTuples.add(newTuple); } } } return (rTuples); }
public List<Triple> getUnshieldedCollidersFromGraph(Graph graph) { List<Triple> colliders = new ArrayList<>(); List<Node> nodes = graph.getNodes(); for (Node b : nodes) { List<Node> adjacentNodes = graph.getAdjacentNodes(b); if (adjacentNodes.size() < 2) { continue; } ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); int[] combination; while ((combination = cg.next()) != null) { Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); // Skip triples that are shielded. if (graph.isAdjacentTo(a, c)) { continue; } if (graph.isDefCollider(a, b, c)) { colliders.add(new Triple(a, b, c)); } } } return colliders; }
/**
 * Double checks a sepset map against a pattern: X must be adjacent to Y in the pattern iff the
 * sepset map holds no entry for {X, Y}. The first inconsistent pair is printed to standard output
 * and ends the check.
 *
 * @param sepset a sepset map, over variables V.
 * @param pattern a pattern over variables W, V subset of W.
 * @return true if the sepset map is consistent with the pattern.
 */
public static boolean verifySepsetIntegrity(SepsetMap sepset, Graph pattern) {
  for (Node x : pattern.getNodes()) {
    for (Node y : pattern.getNodes()) {
      if (x != y) {
        boolean adjacent = pattern.isAdjacentTo(y, x);
        boolean noSepset = sepset.get(x, y) == null;

        // Adjacent pairs must have no sepset; non-adjacent pairs must have one.
        if (adjacent != noSepset) {
          System.out.println("Sepset not consistent with graph for {" + x + ", " + y + "}");
          return false;
        }
      }
    }
  }

  return true;
}
/**
 * Builds a wrapper around a new time lag graph derived from the graph of the given wrapper.
 *
 * <p>Every node of the wrapped graph is copied under the name {@code <name>:0}, keeping its node
 * type. Every edge must be directed; it is re-created between the nodes looked up with
 * {@code getNode(fromName, 1)} and {@code getNode(toName, 0)} (presumably the lag-1 and lag-0
 * nodes of the time lag graph; confirm against {@code TimeLagGraph}).
 *
 * @param graphWrapper the wrapper whose graph is converted.
 * @throws NullPointerException if {@code graphWrapper} is null.
 * @throws IllegalArgumentException if the wrapped graph contains an edge that is not directed.
 */
public TimeLagGraphWrapper(GraphWrapper graphWrapper) {
  if (graphWrapper == null) {
    throw new NullPointerException("No graph wrapper.");
  }

  TimeLagGraph graph = new TimeLagGraph();
  Graph _graph = graphWrapper.getGraph();

  for (Node node : _graph.getNodes()) {
    Node _node = node.like(node.getName() + ":0");
    _node.setNodeType(node.getNodeType());
    graph.addNode(_node);
  }

  for (Edge edge : _graph.getEdges()) {
    if (!Edges.isDirectedEdge(edge)) {
      // Name the offending edge so the caller can find it in the source graph.
      throw new IllegalArgumentException(
          "Only directed edges can be converted to a time lag graph; found: " + edge);
    }

    Node from = edge.getNode1();
    Node to = edge.getNode2();

    Node _from = graph.getNode(from.getName(), 1);
    Node _to = graph.getNode(to.getName(), 0);

    graph.addDirectedEdge(_from, _to);
  }

  this.graph = graph;
}
// ===========================SCORING METHODS===================// public double scoreDag(Graph graph) { Graph dag = new EdgeListGraphSingleConnections(graph); buildIndexing(graph); double score = 0.0; for (Node y : dag.getNodes()) { Set<Node> parents = new HashSet<Node>(dag.getParents(y)); int nextIndex = -1; for (int i = 0; i < getVariables().size(); i++) { nextIndex = hashIndices.get(variables.get(i)); } int parentIndices[] = new int[parents.size()]; Iterator<Node> pi = parents.iterator(); int count = 0; while (pi.hasNext()) { Node nextParent = pi.next(); parentIndices[count++] = hashIndices.get(nextParent); } if (this.isDiscrete()) { score += localDiscreteScore(nextIndex, parentIndices); } else { score += localSemScore(nextIndex, parentIndices); } } return score; }
/**
 * Returns a copy of {@code full} in which latent nodes are renamed after the clusters they
 * measure.
 *
 * <p>For cluster {@code i}, every latent node of {@code full} whose non-latent children have
 * exactly the names in that cluster is renamed, in the copy, to {@code latentVarList.get(i)}.
 *
 * @param full the graph to copy; it must be serializable, since the copy is made by marshalling.
 * @param measurements the clusters of measured variable names.
 * @param latentVarList the new latent names, indexed like the clusters.
 * @return the renamed copy of {@code full}.
 * @throws IllegalStateException if the graph cannot be copied.
 */
private Graph changeLatentNames(Graph full, Clusters measurements, List<String> latentVarList) {
  Graph g2;

  // Fail at the point of the copy, with the cause attached, instead of continuing with a
  // null graph that would only surface later as a NullPointerException or a null result.
  try {
    g2 = (Graph) new MarshalledObject(full).get();
  } catch (IOException e) {
    throw new IllegalStateException("Could not copy the graph.", e);
  } catch (ClassNotFoundException e) {
    throw new IllegalStateException("Could not copy the graph.", e);
  }

  for (int i = 0; i < measurements.getNumClusters(); i++) {
    List<String> d = measurements.getCluster(i);
    String latentName = latentVarList.get(i);

    for (Node node : full.getNodes()) {
      if (!(node.getNodeType() == NodeType.LATENT)) {
        continue;
      }

      // Only the non-latent children of the latent are compared with the cluster.
      List<Node> _children = full.getChildren(node);
      _children.removeAll(ReidentifyVariables.getLatents(full));

      List<String> childNames = getNames(_children);

      if (new HashSet<String>(childNames).equals(new HashSet<String>(d))) {
        // Names are matched in the original graph; the rename is applied to the copy.
        g2.getNode(node.getName()).setName(latentName);
      }
    }
  }

  return g2;
}
/**
 * Constructs a calculator that estimates conditional probabilities on the fly from the given
 * discrete data set, for the given Bayes PM.
 *
 * @param bayesPm the given Bayes PM, which specifies a directed acyclic graph for a Bayes net and
 *     parametrization for the Bayes net, but not actual values for the parameters.
 * @param dataSet the discrete data set from which conditional probabilities should be estimated
 *     on the fly.
 * @throws NullPointerException if either argument is null.
 */
public OnTheFlyMarginalCalculator(BayesPm bayesPm, DataSet dataSet)
    throws IllegalArgumentException {
  if (bayesPm == null || dataSet == null) {
    throw new NullPointerException();
  }

  // Estimation is impossible unless every variable of the PM occurs in the data set.
  BayesUtils.ensureVarsInData(bayesPm.getVariables(), dataSet);

  this.bayesPm = new BayesPm(bayesPm);

  // Taking the nodes from the DAG once fixes their order, independently of any later change
  // to the BayesPm. (This order must be maintained.)
  this.nodes = bayesPm.getDag().getNodes().toArray(new Node[0]);

  initialize();

  // Restrict the data to the variables of this model, in the order of the model.
  this.dataSet = dataSet.subsetColumns(getVariables());

  // Start from evidence that rules nothing out.
  this.evidence = new Evidence(Proposition.tautology(this));
}
/** * Step C of PC; orients colliders using specified sepset. That is, orients x *-* y *-* z as x *-> * y <-* z just in case y is in Sepset({x, z}). */ public Map<Triple, Double> findCollidersUsingSepsets( SepsetProducer sepsetProducer, Graph graph, boolean verbose, IKnowledge knowledge) { TetradLogger.getInstance().log("details", "Starting Collider Orientation:"); Map<Triple, Double> colliders = new HashMap<>(); System.out.println("Looking for colliders"); List<Node> nodes = graph.getNodes(); for (Node b : nodes) { List<Node> adjacentNodes = graph.getAdjacentNodes(b); if (adjacentNodes.size() < 2) { continue; } ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); int[] combination; while ((combination = cg.next()) != null) { Node a = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); // Skip triples that are shielded. if (graph.isAdjacentTo(a, c)) { continue; } List<Node> sepset = sepsetProducer.getSepset(a, c); if (sepset == null) continue; // if (sepsetProducer.getPValue() < 0.5) continue; if (!sepset.contains(b)) { if (verbose) { // boolean dsep = this.dsep.isIndependent(a, c); // System.out.println("QQQ p = " + independenceTest.getPValue() + // " " + dsep); System.out.println( "\nCollider orientation <" + a + ", " + b + ", " + c + "> sepset = " + sepset); } colliders.put(new Triple(a, b, c), sepsetProducer.getPValue()); TetradLogger.getInstance() .log("colliderOrientations", SearchLogUtils.colliderOrientedMsg(a, b, c, sepset)); } } } TetradLogger.getInstance().log("details", "Finishing Collider Orientation."); System.out.println("Done finding colliders"); return colliders; }
/**
 * Transforms a maximally directed pattern (PDAG) represented in graph <code>g</code> into an
 * arbitrary DAG by modifying <code>g</code> itself. Based on the algorithm described in
 * Chickering (2002) "Optimal structure identification with greedy search" Journal of Machine
 * Learning Research. R. Silva, June 2004
 *
 * <p>The method works on two graphs: <code>g</code>, from which all undirected (tail-tail) edges
 * are removed up front and to which directed edges are then added, and a copy <code>p</code> of
 * the original pattern, from which nodes are removed one at a time.
 *
 * @param g the pattern to transform; modified in place.
 */
public static void pdagToDag(Graph g) {
  // p keeps the original pattern (including its undirected edges) while g is rewritten.
  Graph p = new EdgeListGraph(g);
  List<Edge> undirectedEdges = new ArrayList<Edge>();

  // Collect every tail-tail edge of g, once each, and remove them all from g.
  for (Edge edge : g.getEdges()) {
    if (edge.getEndpoint1() == Endpoint.TAIL
        && edge.getEndpoint2() == Endpoint.TAIL
        && !undirectedEdges.contains(edge)) {
      undirectedEdges.add(edge);
    }
  }
  g.removeEdges(undirectedEdges);
  List<Node> pNodes = p.getNodes();

  do {
    Node x = null;

    // Look for a node x of p that has no children and whose undirected neighbors, together
    // with its parents, form a clique in p.
    for (Node pNode : pNodes) {
      x = pNode;

      // x must have no children in p.
      if (p.getChildren(x).size() > 0) {
        continue;
      }

      // Gather the nodes joined to x by a tail-tail edge in p.
      Set<Node> neighbors = new HashSet<Node>();

      for (Edge edge : p.getEdges()) {
        if (edge.getNode1() == x || edge.getNode2() == x) {
          if (edge.getEndpoint1() == Endpoint.TAIL && edge.getEndpoint2() == Endpoint.TAIL) {
            if (edge.getNode1() == x) {
              neighbors.add(edge.getNode2());
            } else {
              neighbors.add(edge.getNode1());
            }
          }
        }
      }

      // The clique condition is only checked when x has undirected neighbors.
      if (neighbors.size() > 0) {
        Collection<Node> parents = p.getParents(x);
        Set<Node> all = new HashSet<Node>(neighbors);
        all.addAll(parents);
        if (!GraphUtils.isClique(all, p)) {
          continue;
        }
      }

      // Orient every undirected edge at x into x, in g; nodes are looked up in g by name.
      for (Node neighbor : neighbors) {
        Node node1 = g.getNode(neighbor.getName());
        Node node2 = g.getNode(x.getName());

        g.addDirectedEdge(node1, node2);
      }
      p.removeNode(x);
      break;
    }

    // x is dropped from the work list either way: after a break it is the node just processed;
    // if the loop ran to completion it is the last node examined (or null for an empty list).
    // Removing one entry per pass is what makes the do-while terminate.
    pNodes.remove(x);
  } while (pNodes.size() > 0);
}
/**
 * Returns the names of the measured nodes of the given graph, in the order the graph lists its
 * nodes.
 *
 * @param graph the graph to inspect.
 * @return the names of all nodes whose type is {@code NodeType.MEASURED}.
 */
private List<String> measuredNames(Graph graph) {
  List<String> measured = new ArrayList<>();

  for (Node candidate : graph.getNodes()) {
    if (candidate.getNodeType() != NodeType.MEASURED) {
      continue;
    }

    measured.add(candidate.getName());
  }

  return measured;
}
public static boolean meekR1Locally2( Graph graph, Knowledge knowledge, IndependenceTest test, int depth) { List<Node> nodes = graph.getNodes(); boolean changed = true; while (changed) { changed = false; for (Node a : nodes) { List<Node> adjacentNodes = graph.getAdjacentNodes(a); if (adjacentNodes.size() < 2) { continue; } ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); int[] combination; while ((combination = cg.next()) != null) { Node b = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); // Skip triples that are shielded. if (graph.isAdjacentTo(b, c)) { continue; } if (graph.getEndpoint(b, a) == Endpoint.ARROW && graph.isUndirectedFromTo(a, c)) { if (existsLocalSepsetWithoutDet(b, a, c, test, graph, depth)) { continue; } if (isArrowpointAllowed(a, c, knowledge)) { graph.setEndpoint(a, c, Endpoint.ARROW); TetradLogger.getInstance() .edgeOriented(SearchLogUtils.edgeOrientedMsg("Meek R1", graph.getEdge(a, c))); changed = true; } } else if (graph.getEndpoint(c, a) == Endpoint.ARROW && graph.isUndirectedFromTo(a, b)) { if (existsLocalSepsetWithoutDet(b, a, c, test, graph, depth)) { continue; } if (isArrowpointAllowed(a, b, knowledge)) { graph.setEndpoint(a, b, Endpoint.ARROW); TetradLogger.getInstance() .edgeOriented(SearchLogUtils.edgeOrientedMsg("Meek R1", graph.getEdge(a, b))); changed = true; } } } } } return changed; }
/** * Performs step C of the algorithm, as indicated on page xxx of CPS, with the modification that * X--W--Y is oriented as X-->W<--Y if W is *determined by* the sepset of (X, Y), rather than W * just being *in* the sepset of (X, Y). */ public static void pcdOrientC( SepsetMap set, IndependenceTest test, Knowledge knowledge, Graph graph) { TetradLogger.getInstance().log("info", "Staring Collider Orientation:"); List<Node> nodes = graph.getNodes(); for (Node y : nodes) { List<Node> adjacentNodes = graph.getAdjacentNodes(y); if (adjacentNodes.size() < 2) { continue; } ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); int[] combination; while ((combination = cg.next()) != null) { Node x = adjacentNodes.get(combination[0]); Node z = adjacentNodes.get(combination[1]); // Skip triples that are shielded. if (graph.isAdjacentTo(x, z)) { continue; } List<Node> sepset = set.get(x, z); if (sepset == null) { continue; } List<Node> augmentedSet = new LinkedList<Node>(sepset); augmentedSet.add(y); if (test.determines(sepset, y)) { continue; } // if (!test.splitDetermines(sepset, x, z) && test.splitDetermines(augmentedSet, x, z)) { continue; } if (!isArrowpointAllowed(x, y, knowledge) || !isArrowpointAllowed(z, y, knowledge)) { continue; } graph.setEndpoint(x, y, Endpoint.ARROW); graph.setEndpoint(z, y, Endpoint.ARROW); TetradLogger.getInstance() .log("colliderOriented", SearchLogUtils.colliderOrientedMsg(x, y, z)); } } TetradLogger.getInstance().log("info", "Finishing Collider Orientation."); }
/** * Provides methods for estimating a Bayes IM from an existing BayesIM and a discrete dataset * using EM (Expectation Maximization). The data columns in the given data must be equal to a * variable in the given Bayes IM but the latter may contain variables which don't occur in the * dataset (latent variables). The first argument of the constructoris the BayesPm whose graph * contains latent and observed variables. The second is the dataset of observed variables; * missing value codes may be present. */ public EmBayesEstimator(BayesPm bayesPm, DataSet dataSet) { if (bayesPm == null) { throw new NullPointerException(); } if (dataSet == null) { throw new NullPointerException(); } List<Node> observedVars = new ArrayList<Node>(); this.bayesPm = bayesPm; this.dataSet = dataSet; // this.variables = Collections.unmodifiableList(vars); graph = bayesPm.getDag(); this.nodes = new Node[graph.getNumNodes()]; Iterator<Node> it = graph.getNodes().iterator(); for (int i = 0; i < this.nodes.length; i++) { this.nodes[i] = it.next(); // System.out.println("node " + i + " " + nodes[i]); } for (int i = 0; i < this.nodes.length; i++) { if (nodes[i].getNodeType() == NodeType.MEASURED) { observedVars.add(bayesPm.getVariable(nodes[i])); } } // Make sure variables in dataset are measured variables in the BayesPM // for (Node dataSetVariable : this.dataSet.getVariables()) { // if (!observedVars.contains(dataSetVariable)) { // throw new IllegalArgumentException( // "Some var in the dataset is not a " + // "measured variable in the Bayes net"); // } // } // Make sure all measured variables in the BayesPm are in the discrete dataset for (Node observedVar : observedVars) { try { this.dataSet.getVariable(observedVar.getName()); } catch (Exception e) { throw new IllegalArgumentException( "Some observed var in the Bayes net " + "is not in the dataset: " + observedVar); } } findBayesNetObserved(); // Sets bayesPmObs initialize(); }
/**
 * Returns the subgraph of the given graph over its latent nodes.
 *
 * @param mim the graph to restrict.
 * @return the subgraph over all nodes whose type is {@code NodeType.LATENT}.
 */
private Graph structure(Graph mim) {
  List<Node> latentNodes = new ArrayList<Node>();

  for (Node candidate : mim.getNodes()) {
    if (candidate.getNodeType() != NodeType.LATENT) {
      continue;
    }

    latentNodes.add(candidate);
  }

  return mim.subgraph(latentNodes);
}
/**
 * Returns a copy of {@code mimStructure} restricted to the nodes that also occur in
 * {@code mimbuildStructure}.
 *
 * @param mimStructure the structure to restrict; not modified.
 * @param mimbuildStructure the structure whose nodes are kept.
 * @return the restricted copy of {@code mimStructure}.
 */
private Graph restrictToEmpiricalLatents(Graph mimStructure, Graph mimbuildStructure) {
  Graph _mim = new EdgeListGraph(mimStructure);

  // Walk the nodes of mimStructure and test each against mimbuildStructure. Testing the nodes
  // of mimbuildStructure against mimbuildStructure itself can never fail, so nothing would
  // ever be removed.
  // NOTE(review): membership is whatever containsNode implements; if the two graphs hold
  // distinct node objects for the same variable, a lookup by name may be needed instead.
  for (Node node : mimStructure.getNodes()) {
    if (!mimbuildStructure.containsNode(node)) {
      _mim.removeNode(node);
    }
  }

  return _mim;
}
/**
 * Meek's rule R3. If a--b, a--c, a--d, c-->b, d-->b, with c and d not adjacent, then orient
 * a-->b (provided the knowledge allows the arrowpoint).
 *
 * @param graph the graph to orient; modified in place.
 * @param knowledge the background knowledge restricting orientations.
 * @return true if at least one edge was oriented.
 */
public static boolean meekR3(Graph graph, Knowledge knowledge) {
  boolean changed = false;

  for (Node a : graph.getNodes()) {
    List<Node> neighbors = graph.getAdjacentNodes(a);

    // The rule needs b, c and d, so at least three neighbors.
    if (neighbors.size() < 3) {
      continue;
    }

    for (Node b : neighbors) {
      List<Node> others = new LinkedList<Node>(neighbors);
      others.remove(b);

      if (!graph.isUndirectedFromTo(a, b)) {
        continue;
      }

      ChoiceGenerator pairs = new ChoiceGenerator(others.size(), 2);

      for (int[] pair = pairs.next(); pair != null; pair = pairs.next()) {
        Node c = others.get(pair[0]);
        Node d = others.get(pair[1]);

        boolean applies =
            !graph.isAdjacentTo(c, d)
                && graph.isUndirectedFromTo(a, c)
                && graph.isUndirectedFromTo(a, d)
                && graph.isDirectedFromTo(c, b)
                && graph.isDirectedFromTo(d, b);

        if (applies && isArrowpointAllowed(a, b, knowledge)) {
          graph.setEndpoint(a, b, Endpoint.ARROW);
          TetradLogger.getInstance()
              .edgeOriented(SearchLogUtils.edgeOrientedMsg("Meek R3", graph.getEdge(a, b)));
          changed = true;

          // a-->b is oriented; no need to try further pairs for this b.
          break;
        }
      }
    }
  }

  return changed;
}
/**
 * Finds the Markov blanket of the named target by running a PC search and trimming the result
 * to the Markov blanket nodes of the target.
 *
 * @param targetName the name of the target variable.
 * @return the nodes remaining after trimming, without the target itself.
 */
public List<Node> findMb(String targetName) {
  Node target = getVariableForName(targetName);

  Pc pc = new Pc(test);
  pc.setDepth(depth);
  Graph searchResult = pc.search();

  MbUtils.trimToMbNodes(searchResult, target, false);

  // The target is not part of its own Markov blanket.
  List<Node> blanket = searchResult.getNodes();
  blanket.remove(target);
  return blanket;
}
/**
 * Lays out the nodes of the graph in rows, one row per knowledge tier, preceded by a row for the
 * variables that are in no tier (if there are any). Nodes in a row are spaced 90 units apart;
 * rows are spaced {@code 500 / numTiers} units apart, but at least 50.
 *
 * @param graph the graph whose nodes are positioned.
 * @param knowledge the knowledge supplying the tiers.
 * @throws IllegalArgumentException if the knowledge has no tiers.
 */
public static void arrangeByKnowledgeTiers(Graph graph, Knowledge knowledge) {
  int numTiers = knowledge.getNumTiers();

  if (numTiers == 0) {
    throw new IllegalArgumentException("There are no Tiers to arrange.");
  }

  int ySpace = Math.max(500 / numTiers, 50);

  List<String> varNames = new ArrayList<String>();
  for (Node graphNode : graph.getNodes()) {
    varNames.add(graphNode.getName());
  }

  List<String> notInTier = knowledge.getVarsNotInTier(varNames);

  // Start one row above the first row so that each row can advance y before placing nodes.
  int y = 50 - ySpace;

  if (!notInTier.isEmpty()) {
    y += ySpace;
    int x = 0;

    for (String name : notInTier) {
      // x advances for every name, whether or not the graph has a node of that name.
      x += 90;
      Node node = graph.getNode(name);

      if (node != null) {
        node.setCenterX(x);
        node.setCenterY(y);
      }
    }
  }

  for (int tierIndex = 0; tierIndex < numTiers; tierIndex++) {
    y += ySpace;
    int x = -25;

    for (String name : knowledge.getTier(tierIndex)) {
      x += 90;
      Node node = graph.getNode(name);

      if (node != null) {
        node.setCenterX(x);
        node.setCenterY(y);
      }
    }
  }
}
/**
 * Computes the characteristic path length of the graph: the mean of the shortest path lengths
 * over all unordered pairs of distinct nodes.
 *
 * @param g the graph to measure.
 * @return the mean shortest path length, or NaN if the graph has fewer than two nodes.
 */
private static double characteristicPathLength(Graph g) {
  List<Node> nodes = g.getNodes();

  long total = 0;
  int count = 0;

  for (int i = 0; i < nodes.size(); i++) {
    // Start at i + 1: pairing a node with itself would add a pair to the count and so skew
    // the average.
    for (int j = i + 1; j < nodes.size(); j++) {
      total += shortestPath(nodes.get(i), nodes.get(j), g);
      count++;
    }
  }

  // With no pairs this is 0 / 0.0, i.e. NaN.
  return total / (double) count;
}
public void testAlternativeGraphs() { // UniformGraphGenerator gen = new UniformGraphGenerator(UniformGraphGenerator.ANY_DAG); // gen.setNumNodes(100); // gen.setMaxEdges(200); // gen.setMaxDegree(30); // gen.setMaxInDegree(30); // gen.setMaxOutDegree(30); //// gen.setNumIterations(3000000); // gen.setResamplingDegree(10); // // gen.generate(); // // Graph graph = gen.getDag(); Graph graph = weightedRandomGraph(250, 400); List<Integer> degreeCounts = new ArrayList<Integer>(); Map<Integer, Integer> degreeCount = new HashMap<Integer, Integer>(); for (Node node : graph.getNodes()) { int degree = graph.getNumEdges(node); degreeCounts.add(degree); if (degreeCount.get(degree) == null) { degreeCount.put(degree, 0); } degreeCount.put(degree, degreeCount.get(degree) + 1); } Collections.sort(degreeCounts); System.out.println(degreeCounts); List<Integer> _degrees = new ArrayList<Integer>(degreeCount.keySet()); Collections.sort(_degrees); for (int i : _degrees) { int j = degreeCount.get(i); // System.out.println(i + " " + j); System.out.println(log(i + 1) + " " + log(j)); } System.out.println("\nCPL = " + characteristicPathLength(graph)); Graph erGraph = erdosRenyiGraph(200, 200); System.out.println("\n ER CPL = " + characteristicPathLength(erGraph)); }
/** * Step C of PC; orients colliders using specified sepset. That is, orients x *-* y *-* z as x *-> * y <-* z just in case y is in Sepset({x, z}). */ public static void orientCollidersUsingSepsets(SepsetMap set, Knowledge knowledge, Graph graph) { TetradLogger.getInstance().log("info", "Starting Collider Orientation:"); // verifySepsetIntegrity(set, graph); List<Node> nodes = graph.getNodes(); for (Node a : nodes) { List<Node> adjacentNodes = graph.getAdjacentNodes(a); if (adjacentNodes.size() < 2) { continue; } ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); int[] combination; while ((combination = cg.next()) != null) { Node b = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); // Skip triples that are shielded. if (graph.isAdjacentTo(b, c)) { continue; } List<Node> sepset = set.get(b, c); if (sepset != null && !sepset.contains(a) && isArrowpointAllowed(b, a, knowledge) && isArrowpointAllowed(c, a, knowledge)) { graph.setEndpoint(b, a, Endpoint.ARROW); graph.setEndpoint(c, a, Endpoint.ARROW); TetradLogger.getInstance() .log("colliderOriented", SearchLogUtils.colliderOrientedMsg(b, a, c, sepset)); } } } TetradLogger.getInstance().log("info", "Finishing Collider Orientation."); }
/**
 * Meek's rule R2. If b-->a, a-->c and b--c, then orient b-->c (provided the knowledge allows the
 * arrowpoint). Both assignments of each pair of neighbors of a to the roles of b and c are tried.
 *
 * @param graph the graph to orient; modified in place.
 * @param knowledge the background knowledge restricting orientations.
 * @return true if at least one edge was oriented.
 */
public static boolean meekR2(Graph graph, Knowledge knowledge) {
  List<Node> nodes = graph.getNodes();
  boolean changed = false;

  for (Node a : nodes) {
    List<Node> adjacentNodes = graph.getAdjacentNodes(a);

    if (adjacentNodes.size() < 2) {
      continue;
    }

    ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2);
    int[] combination;

    while ((combination = cg.next()) != null) {
      Node b = adjacentNodes.get(combination[0]);
      Node c = adjacentNodes.get(combination[1]);

      if (graph.isDirectedFromTo(b, a)
          && graph.isDirectedFromTo(a, c)
          && graph.isUndirectedFromTo(b, c)) {
        if (isArrowpointAllowed(b, c, knowledge)) {
          graph.setEndpoint(b, c, Endpoint.ARROW);
          TetradLogger.getInstance()
              .edgeOriented(SearchLogUtils.edgeOrientedMsg("Meek R2", graph.getEdge(b, c)));
          // Report the orientation to the caller; without this the method always returns false.
          changed = true;
        }
      } else if (graph.isDirectedFromTo(c, a)
          && graph.isDirectedFromTo(a, b)
          && graph.isUndirectedFromTo(c, b)) {
        // Mirror case with the roles of b and c swapped.
        if (isArrowpointAllowed(c, b, knowledge)) {
          graph.setEndpoint(c, b, Endpoint.ARROW);
          TetradLogger.getInstance()
              .edgeOriented(SearchLogUtils.edgeOrientedMsg("Meek R2", graph.getEdge(c, b)));
          changed = true;
        }
      }
    }
  }

  return changed;
}
public static void orientCollidersLocally( Knowledge knowledge, Graph graph, IndependenceTest test, int depth, Set<Node> nodesToVisit) { TetradLogger.getInstance().log("info", "Starting Collider Orientation:"); if (nodesToVisit == null) { nodesToVisit = new HashSet<Node>(graph.getNodes()); } for (Node a : nodesToVisit) { List<Node> adjacentNodes = graph.getAdjacentNodes(a); if (adjacentNodes.size() < 2) { continue; } ChoiceGenerator cg = new ChoiceGenerator(adjacentNodes.size(), 2); int[] combination; while ((combination = cg.next()) != null) { Node b = adjacentNodes.get(combination[0]); Node c = adjacentNodes.get(combination[1]); // Skip triples that are shielded. if (graph.isAdjacentTo(b, c)) { continue; } if (isArrowpointAllowed1(b, a, knowledge) && isArrowpointAllowed1(c, a, knowledge)) { if (!existsLocalSepsetWith(b, a, c, test, graph, depth)) { graph.setEndpoint(b, a, Endpoint.ARROW); graph.setEndpoint(c, a, Endpoint.ARROW); TetradLogger.getInstance() .log("colliderOriented", SearchLogUtils.colliderOrientedMsg(b, a, c)); } } } } TetradLogger.getInstance().log("info", "Finishing Collider Orientation."); }
/**
 * Adds a bidirected edge between every pair of exogenous nodes of the workbench graph that is
 * not already joined by one. Shows a message and does nothing if the graph is a DAG.
 */
private void correlateExogenousVariables() {
  Graph graph = getWorkbench().getGraph();

  if (graph instanceof Dag) {
    JOptionPane.showMessageDialog(
        JOptionUtils.centeringComp(), "Cannot add bidirected edges to DAG's.");
    return;
  }

  // Determine the exogenous nodes up front, before any edge is added.
  List<Node> exoNodes = new LinkedList<Node>();
  for (Node candidate : graph.getNodes()) {
    if (graph.isExogenous(candidate)) {
      exoNodes.add(candidate);
    }
  }

  for (int i = 0; i < exoNodes.size(); i++) {
    for (int j = i + 1; j < exoNodes.size(); j++) {
      Node node1 = exoNodes.get(i);
      Node node2 = exoNodes.get(j);

      boolean alreadyCorrelated = false;
      for (Edge existing : graph.getEdges(node1, node2)) {
        if (Edges.isBidirectedEdge(existing)) {
          alreadyCorrelated = true;
          break;
        }
      }

      if (!alreadyCorrelated) {
        graph.addBidirectedEdge(node1, node2);
      }
    }
  }
}
/** * Greedy equivalence search: Start from the empty graph, add edges till model is significant. * Then start deleting edges till a minimum is achieved. * * @return the resulting Pattern. */ public Graph search() { Graph graph; if (initialGraph == null) { graph = new EdgeListGraphSingleConnections(getVariables()); } else { graph = new EdgeListGraphSingleConnections(initialGraph); } fireGraphChange(graph); buildIndexing(graph); addRequiredEdges(graph); topGraphs.clear(); storeGraph(graph); List<Node> nodes = graph.getNodes(); long start = System.currentTimeMillis(); score = 0.0; // Do forward search. fes(graph, nodes); // Do backward search. bes(graph); long endTime = System.currentTimeMillis(); this.elapsedTime = endTime - start; this.logger.log("graph", "\nReturning this graph: " + graph); this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s"); this.logger.flush(); return graph; }
/**
 * Returns the nodes of the underlying graph, exactly as that graph reports them.
 *
 * @return the list returned by the graph's own {@code getNodes()}.
 */
public List<Node> getNodes() {
  List<Node> graphNodes = graph.getNodes();
  return graphNodes;
}
/**
 * Transforms a DAG represented in graph <code>graph</code> into a maximally directed pattern
 * (PDAG) by modifying <code>graph</code> itself. Based on the algorithm described in Chickering
 * (2002) "Optimal structure identification with greedy search" Journal of Machine Learning
 * Research. It works for both BayesNets and SEMs. R. Silva, June 2004
 *
 * <p>The method runs in four phases: order the nodes, order the edges, label each edge as
 * compelled or reversible, and finally replace both endpoints of every reversible edge by tails.
 *
 * @param graph the DAG to transform; modified in place.
 */
public static void dagToPdag(Graph graph) {
  // do topological sort on the nodes
  // Exogenous nodes are stripped, round by round, from a copy of the graph; the order in which
  // they are stripped is the node order. The nodes stored are those of the original graph,
  // looked up by name.
  Graph graphCopy = new EdgeListGraph(graph);
  Node orderedNodes[] = new Node[graphCopy.getNodes().size()];
  int count = 0;
  while (graphCopy.getNodes().size() > 0) {
    Set<Node> exogenousNodes = new HashSet<Node>();
    for (Node next : graphCopy.getNodes()) {
      if (graphCopy.isExogenous(next)) {
        exogenousNodes.add(next);
        orderedNodes[count++] = graph.getNode(next.getName());
      }
    }
    graphCopy.removeNodes(new ArrayList<Node>(exogenousNodes));
  }

  // ordered edges - improvised, inefficient implementation
  // After the copy loop, count is the number of edges still to be placed into orderedEdges.
  count = 0;
  Edge edges[] = new Edge[graph.getNumEdges()];
  boolean edgeOrdered[] = new boolean[graph.getNumEdges()];
  Edge orderedEdges[] = new Edge[graph.getNumEdges()];
  for (Edge edge : graph.getEdges()) {
    edges[count++] = edge;
  }
  for (int i = 0; i < edges.length; i++) {
    edgeOrdered[i] = false;
  }

  // Edges are placed by scanning their second node in node order (outer loop) and their first
  // node in reverse node order (inner loop over k).
  // NOTE(review): the while loop only ends once every edge has been matched by node1/node2
  // identity against orderedNodes; presumably that always holds for a DAG - confirm.
  while (count > 0) {
    for (Node orderedNode : orderedNodes) {
      for (int k = orderedNodes.length - 1; k >= 0; k--) {
        for (int q = 0; q < edges.length; q++) {
          if (!edgeOrdered[q]
              && edges[q].getNode1() == orderedNodes[k]
              && edges[q].getNode2() == orderedNode) {
            edgeOrdered[q] = true;
            orderedEdges[orderedEdges.length - count] = edges[q];
            count--;
          }
        }
      }
    }
  }

  // label edges
  // compelledEdges[i] / reversibleEdges[i] refer to orderedEdges[i]; an edge with neither flag
  // set has not been labeled yet.
  boolean compelledEdges[] = new boolean[graph.getNumEdges()];
  boolean reversibleEdges[] = new boolean[graph.getNumEdges()];
  for (int i = 0; i < graph.getNumEdges(); i++) {
    compelledEdges[i] = false;
    reversibleEdges[i] = false;
  }
  for (int i = 0; i < graph.getNumEdges(); i++) {
    // Skip edges that were already labeled while processing an earlier edge.
    if (compelledEdges[i] || reversibleEdges[i]) {
      continue;
    }
    Node x = orderedEdges[i].getNode1();
    Node y = orderedEdges[i].getNode2();

    // For every compelled edge w-->x: if w is not a parent of y, mark the first edge into y
    // (in edge order) as compelled; otherwise mark the edge w-->y as compelled.
    for (int j = 0; j < orderedEdges.length; j++) {
      if (orderedEdges[j].getNode2() == x && compelledEdges[j]) {
        Node w = orderedEdges[j].getNode1();
        if (!graph.isParentOf(w, y)) {
          for (int k = 0; k < orderedEdges.length; k++) {
            if (orderedEdges[k].getNode2() == y) {
              compelledEdges[k] = true;
              break;
            }
          }
        } else {
          for (int k = 0; k < orderedEdges.length; k++) {
            if (orderedEdges[k].getNode1() == w && orderedEdges[k].getNode2() == y) {
              compelledEdges[k] = true;
              break;
            }
          }
        }
      }
      // Stop scanning as soon as the current edge itself has become compelled.
      if (compelledEdges[i]) {
        break;
      }
    }
    if (compelledEdges[i]) {
      continue;
    }

    // Look for an edge z-->y with z != x where z is not a parent of x. If one exists, the
    // current edge and every later edge into y not marked reversible become compelled.
    boolean foundZ = false;
    for (Edge orderedEdge : orderedEdges) {
      Node z = orderedEdge.getNode1();
      if (z != x && orderedEdge.getNode2() == y && !graph.isParentOf(z, x)) {
        compelledEdges[i] = true;
        for (int k = i + 1; k < graph.getNumEdges(); k++) {
          if (orderedEdges[k].getNode2() == y && !reversibleEdges[k]) {
            compelledEdges[k] = true;
          }
        }
        foundZ = true;
        break;
      }
    }

    // Otherwise the current edge and every later edge into y not marked compelled become
    // reversible.
    if (!foundZ) {
      reversibleEdges[i] = true;
      for (int j = i + 1; j < orderedEdges.length; j++) {
        if (!compelledEdges[j] && orderedEdges[j].getNode2() == y) {
          reversibleEdges[j] = true;
        }
      }
    }
  }

  // undirect edges that are reversible
  // Both endpoints are set to TAIL, which turns the directed edge into an undirected one.
  for (int i = 0; i < reversibleEdges.length; i++) {
    if (reversibleEdges[i]) {
      graph.setEndpoint(orderedEdges[i].getNode1(), orderedEdges[i].getNode2(), Endpoint.TAIL);
      graph.setEndpoint(orderedEdges[i].getNode2(), orderedEdges[i].getNode1(), Endpoint.TAIL);
    }
  }
}