/**
 * Collects the edges of the complete graph over {@code graph} that have no counterpart among the
 * edges of the undirected version of {@code graph}.
 *
 * @return a new {@code HashSet} with the remaining (non-adjacent) edges
 */
public Set<Edge> getNonadjacencies() {
  // Start from every possible edge ...
  Set<Edge> candidates = GraphUtils.completeGraph(graph).getEdges();

  // ... and strike out the ones that are actually present, ignoring orientation.
  Set<Edge> present = GraphUtils.undirectedGraph(graph).getEdges();
  candidates.removeAll(present);

  return new HashSet<Edge>(candidates);
}
// TODO Fix this. private List<ScoredGraph> arrangeGraphs() { IGesRunner runner = (IGesRunner) getAlgorithmRunner(); Graph resultGraph = runner.getResultGraph(); List<ScoredGraph> topGraphs = runner.getTopGraphs(); if (topGraphs == null) topGraphs = new ArrayList<ScoredGraph>(); Graph latestWorkbenchGraph = runner.getParams().getSourceGraph(); Graph sourceGraph = runner.getSourceGraph(); boolean arrangedAll = false; for (ScoredGraph topGraph1 : topGraphs) { arrangedAll = GraphUtils.arrangeBySourceGraph(topGraph1.getGraph(), latestWorkbenchGraph); } if (!arrangedAll) { arrangedAll = GraphUtils.arrangeBySourceGraph(resultGraph, sourceGraph); } if (!arrangedAll) { for (ScoredGraph topGraph : topGraphs) { GraphUtils.circleLayout(topGraph.getGraph(), 200, 200, 150); GraphUtils.circleLayout(resultGraph, 200, 200, 150); } } return topGraphs; }
/**
 * Builds a vertex holding the longest run of trailing bases shared by all supplied vertices.
 *
 * <p>For example, the sequences AC, CC and ATC share the suffix C, while ACC and TCC share CC.
 *
 * <p>NOTE(review): a new vertex is always constructed, so this method never returns null. For
 * sequences with no common ending (e.g. AC and TG) the computed suffix length is presumably 0,
 * which yields a vertex over an empty byte array; confirm against {@code compSuffixLen}.
 *
 * @param middleVertices a non-empty collection of vertices
 * @return a single new vertex containing the common suffix of all middle vertices
 */
@Requires("!middleVertices.isEmpty()")
protected static SeqVertex commonSuffix(final Collection<SeqVertex> middleVertices) {
  final List<byte[]> sequences = GraphUtils.getKmers(middleVertices);

  // The suffix can never be longer than the shortest sequence.
  final int sharedLen =
      GraphUtils.compSuffixLen(sequences, GraphUtils.minKmerLength(sequences));

  // Any sequence ends with the shared suffix, so slice it off the first one.
  final byte[] reference = sequences.get(0);
  final int end = reference.length;
  return new SeqVertex(Arrays.copyOfRange(reference, end - sharedLen, end));
}
protected void doDefaultArrangement(Graph resultGraph) { if (getLatestWorkbenchGraph() != null) { // (alreadyLaidOut) { GraphUtils.arrangeBySourceGraph(resultGraph, getLatestWorkbenchGraph()); } else if (getKnowledge().isDefaultToKnowledgeLayout()) { SearchGraphUtils.arrangeByKnowledgeTiers(resultGraph, getKnowledge()); // alreadyLaidOut = true; } else { GraphUtils.circleLayout(resultGraph, 200, 200, 150); // alreadyLaidOut = true; } }
/**
 * Simulates recursive data from a generalized SEM over a seeded random 10-variable graph,
 * re-estimates the model and checks the estimator's A-squared-star statistic.
 */
@Test
public void test7() {
  // Fixed seed: the expected statistic below depends on the exact random sequence.
  RandomUtil.getInstance().setSeed(29999483L);

  final int variableCount = 10;
  List<Node> variables = new ArrayList<>();
  for (int index = 1; index <= variableCount; index++) {
    variables.add(new ContinuousVariable("X" + index));
  }

  Graph dag =
      GraphUtils.randomGraphRandomForwardEdges(
          variables, 0, variableCount, 30, 15, 15, false, true);

  GeneralizedSemPm pm = new GeneralizedSemPm(dag);
  GeneralizedSemIm trueIm = new GeneralizedSemIm(pm);
  print(trueIm);

  DataSet sample = trueIm.simulateDataRecursive(1000, false);

  GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
  GeneralizedSemIm estimatedIm = estimator.estimate(pm, sample);
  print(estimatedIm);
  print(estimator.getReport());

  assertEquals(0.67, estimator.getaSquaredStar(), 0.01);
}
/**
 * Recomputes the candidate insertion arrows for the ordered pair (x, y).
 *
 * <p>Existing arrows for the pair are cleared first. Unless knowledge forbids x -> y, every
 * subset of the T-neighbours of the pair is scored with {@code insertEval}; each subset with a
 * strictly positive bump is recorded as an {@code Arrow} in the sorted set and the lookup table.
 *
 * @param x the tail candidate
 * @param y the head candidate
 * @param graph the graph being searched
 */
private void calculateArrowsForward(Node x, Node y, Graph graph) {
  clearArrow(x, y);

  if (!knowledgeEmpty() && getKnowledge().isForbidden(x.getName(), y.getName())) {
    return;
  }

  List<Node> naYX = getNaYX(x, y, graph);
  List<Node> tNeighbors = getTNeighbors(x, y, graph);

  // Depth equal to the pool size: enumerate all subsets.
  int poolSize = tNeighbors.size();
  DepthChoiceGenerator subsets = new DepthChoiceGenerator(poolSize, poolSize);

  for (int[] indices = subsets.next(); indices != null; indices = subsets.next()) {
    List<Node> candidate = GraphUtils.asList(indices, tNeighbors);

    if (!knowledgeEmpty() && !validSetByKnowledge(y, candidate)) {
      continue;
    }

    double bump = insertEval(x, y, candidate, naYX, graph);

    if (bump > 0.0) {
      Arrow arrow = new Arrow(bump, x, y, candidate, naYX);
      sortedArrows.add(arrow);
      addLookupArrow(x, y, arrow);
    }
  }
}
/**
 * Checks {@code GraphUtils.compPrefixLen} and {@code GraphUtils.compSuffixLen} against the
 * expected lengths supplied by the data provider.
 *
 * @param strings the input sequences
 * @param expectedPrefixLen expected length of the common prefix
 * @param expectedSuffixLen expected length of the common suffix once the prefix is excluded
 */
@Test(dataProvider = "PrefixSuffixData")
public void testPrefixSuffix(
    final List<String> strings, int expectedPrefixLen, int expectedSuffixLen) {
  final List<byte[]> encoded = new ArrayList<>();
  int shortest = Integer.MAX_VALUE;

  for (final String text : strings) {
    encoded.add(text.getBytes());
    if (text.length() < shortest) {
      shortest = text.length();
    }
  }

  final int prefixLen = GraphUtils.compPrefixLen(encoded, shortest);
  Assert.assertEquals(prefixLen, expectedPrefixLen, "Failed prefix test");

  // The suffix may only use the bases not already claimed by the prefix.
  final int suffixLen = GraphUtils.compSuffixLen(encoded, shortest - prefixLen);
  Assert.assertEquals(suffixLen, expectedSuffixLen, "Failed suffix test");
}
public Graph orient() { Graph skeleton = GraphUtils.undirectedGraph(getPattern()); Graph graph = new EdgeListGraph(skeleton.getNodes()); List<Node> nodes = skeleton.getNodes(); // Collections.shuffle(nodes); if (isR1Done()) { ruleR1(skeleton, graph, nodes); } for (Edge edge : skeleton.getEdges()) { if (!graph.isAdjacentTo(edge.getNode1(), edge.getNode2())) { graph.addUndirectedEdge(edge.getNode1(), edge.getNode2()); } } if (isR2Done()) { ruleR2(skeleton, graph); } if (isMeekDone()) { new MeekRules().orientImplied(graph); } return graph; }
/**
 * Transforms a maximally directed pattern (PDAG) represented in graph <code>g</code> into an
 * arbitrary DAG by modifying <code>g</code> itself. Based on the algorithm described in
 * Chickering (2002) "Optimal structure identification with greedy search" Journal of Machine
 * Learning Research. R. Silva, June 2004
 *
 * <p>A working copy <code>p</code> of the graph is consumed node by node; <code>g</code> first
 * loses all of its tail-tail (undirected) edges and then receives one directed edge per
 * orientation decided on <code>p</code>.
 *
 * @param g the PDAG to convert; modified in place
 */
public static void pdagToDag(Graph g) {
  // Working copy that is dismantled while g is being re-oriented.
  Graph p = new EdgeListGraph(g);

  // Collect every undirected (tail-tail) edge of g, without duplicates, and strip them from g.
  List<Edge> undirectedEdges = new ArrayList<Edge>();

  for (Edge edge : g.getEdges()) {
    if (edge.getEndpoint1() == Endpoint.TAIL
        && edge.getEndpoint2() == Endpoint.TAIL
        && !undirectedEdges.contains(edge)) {
      undirectedEdges.add(edge);
    }
  }
  g.removeEdges(undirectedEdges);

  List<Node> pNodes = p.getNodes();

  do {
    Node x = null;

    // Look for a node x of p that has no children and whose undirected neighbours together
    // with its parents form a clique.
    for (Node pNode : pNodes) {
      x = pNode;

      // x must be a sink with respect to directed edges.
      if (p.getChildren(x).size() > 0) {
        continue;
      }

      // Nodes joined to x by a tail-tail edge in p.
      Set<Node> neighbors = new HashSet<Node>();

      for (Edge edge : p.getEdges()) {
        if (edge.getNode1() == x || edge.getNode2() == x) {
          if (edge.getEndpoint1() == Endpoint.TAIL && edge.getEndpoint2() == Endpoint.TAIL) {
            if (edge.getNode1() == x) {
              neighbors.add(edge.getNode2());
            } else {
              neighbors.add(edge.getNode1());
            }
          }
        }
      }

      // The clique condition is only checked when x has undirected neighbours.
      if (neighbors.size() > 0) {
        Collection<Node> parents = p.getParents(x);
        Set<Node> all = new HashSet<Node>(neighbors);
        all.addAll(parents);
        if (!GraphUtils.isClique(all, p)) {
          continue;
        }
      }

      // Orient each undirected neighbour into x in g (nodes are looked up in g by name).
      for (Node neighbor : neighbors) {
        Node node1 = g.getNode(neighbor.getName());
        Node node2 = g.getNode(x.getName());

        g.addDirectedEdge(node1, node2);
      }

      // x is done: drop it from the working copy and restart the scan.
      p.removeNode(x);
      break;
    }

    // NOTE(review): if no node passed the checks above, x is simply the last node scanned; it
    // is dropped from pNodes here without being oriented or removed from p. This guarantees
    // termination but should not happen for a valid PDAG - confirm.
    pNodes.remove(x);
  } while (pNodes.size() > 0);
}
// Cannot be done if the graph changes. public void setInitialGraph(Graph initialGraph) { initialGraph = GraphUtils.replaceNodes(initialGraph, variables); out.println("Initial graph variables: " + initialGraph.getNodes()); out.println("Data set variables: " + variables); if (!new HashSet<Node>(initialGraph.getNodes()).equals(new HashSet<Node>(variables))) { throw new IllegalArgumentException("Variables aren't the same."); } this.initialGraph = initialGraph; }
/**
 * Repeatedly breaks directed cycles in {@code graph} until {@code GraphUtils.directedCycle}
 * finds none.
 *
 * <p>For each cycle, every consecutive triple (x, y, z) around the cycle is tested for
 * independence of x and z given y. The triple with the smallest p-value has its two edges
 * replaced by x -> y and z -> y. Progress is printed to standard output.
 *
 * @param graph the graph to modify in place
 * @param test the independence test supplying variables and p-values
 * @return the same {@code graph} instance
 */
public static Graph bestGuessCycleOrientation(Graph graph, IndependenceTest test) {
  for (List<Node> cycle = GraphUtils.directedCycle(graph);
      cycle != null;
      cycle = GraphUtils.directedCycle(graph)) {

    // Pad the cycle with its last node in front and its first node at the end so that every
    // original node has both a predecessor and a successor.
    LinkedList<Node> ring = new LinkedList<Node>(cycle);
    Node head = ring.getFirst();
    Node tail = ring.getLast();
    ring.addFirst(tail);
    ring.addLast(head);

    int weakest = -1;
    double lowestP = Double.POSITIVE_INFINITY;

    int mid = 1;
    while (mid < ring.size() - 1) {
      Node left = test.getVariable(ring.get(mid - 1).getName());
      Node center = test.getVariable(ring.get(mid).getName());
      Node right = test.getVariable(ring.get(mid + 1).getName());

      // Called for its side effect: the p-value is read back just below.
      test.isIndependent(left, right, Collections.singletonList(center));

      System.out.println("Testing " + left + " _||_ " + right + " | " + center);

      double p = test.getPValue();

      System.out.println("p = " + p);

      if (p < lowestP) {
        weakest = mid;
        lowestP = p;
      }

      mid++;
    }

    // Turn the weakest triple into a collider.
    Node x = ring.get(weakest - 1);
    Node y = ring.get(weakest);
    Node z = ring.get(weakest + 1);

    graph.removeEdge(x, y);
    graph.removeEdge(z, y);
    graph.addDirectedEdge(x, y);
    graph.addDirectedEdge(z, y);
  }

  return graph;
}
/**
 * Rule R1: for each node, picks the best-scoring subset of its skeleton neighbours as parents
 * and adds the corresponding directed edges to {@code graph}.
 *
 * <p>All subsets of the node's adjacents are scored; the scores are logged in descending order
 * under the "score" log channel. Nothing is added for a node when no subset was scored or when
 * {@code normal(node, parents)} holds for the winning subset.
 *
 * @param skeleton the undirected skeleton supplying adjacencies
 * @param graph the graph receiving the directed edges
 * @param nodes the nodes to process, in order
 */
private void ruleR1(Graph skeleton, Graph graph, List<Node> nodes) {
  for (Node node : nodes) {
    // Keyed by negated score so ascending key order means descending score.
    SortedMap<Double, String> reports = new TreeMap<Double, String>();

    List<Node> adjacents = skeleton.getAdjacentNodes(node);
    DepthChoiceGenerator subsets =
        new DepthChoiceGenerator(adjacents.size(), adjacents.size());

    double bestScore = Double.NEGATIVE_INFINITY;
    List<Node> bestParents = null;

    for (int[] picked = subsets.next(); picked != null; picked = subsets.next()) {
      List<Node> candidate = GraphUtils.asList(picked, adjacents);
      double candidateScore = score(node, candidate);

      reports.put(-candidateScore, candidate.toString());

      if (candidateScore > bestScore) {
        bestScore = candidateScore;
        bestParents = candidate;
      }
    }

    for (double negated : reports.keySet()) {
      TetradLogger.getInstance()
          .log(
              "score",
              "For " + node + " parents = " + reports.get(negated) + " score = " + -negated);
    }

    TetradLogger.getInstance().log("score", "");

    if (bestParents == null || normal(node, bestParents)) {
      continue;
    }

    for (Node neighbor : adjacents) {
      if (!bestParents.contains(neighbor)) {
        continue;
      }

      Edge parentEdge = Edges.directedEdge(neighbor, node);
      if (!graph.containsEdge(parentEdge)) {
        graph.addEdge(parentEdge);
      }
    }
  }
}
///////////////////////////////////////////////////////////////////////////// // set the sepSet of x and y to the minimal such subset of the given sepSet // and remove the edge <x, y> if background knowledge allows ///////////////////////////////////////////////////////////////////////////// private void setMinSepSet(List<Node> sepSet, Node x, Node y) { // It is assumed that BK has been considered before calling this method // (for example, setting independent1 and independent2 in ruleR0_RFCI) /* // background knowledge requires this edge if (knowledge.noEdgeRequired(x.getNode(), y.getNode())) { return; } */ List<Node> empty = Collections.emptyList(); boolean indep; try { indep = independenceTest.isIndependent(x, y, empty); } catch (Exception e) { indep = false; } if (indep) { getSepsets().set(x, y, empty); return; } int sepSetSize = sepSet.size(); for (int i = 1; i <= sepSetSize; i++) { ChoiceGenerator cg = new ChoiceGenerator(sepSetSize, i); int[] combination; while ((combination = cg.next()) != null) { List<Node> condSet = GraphUtils.asList(combination, sepSet); try { indep = independenceTest.isIndependent(x, y, condSet); } catch (Exception e) { indep = false; } if (indep) { getSepsets().set(x, y, condSet); return; } } } }
public static List<Dag> getAllDagsInUndirectedGraph(Graph graph) { Graph undirected = GraphUtils.undirectedGraph(graph); DagIterator iterator = new DagIterator(undirected); List<Dag> dags = new ArrayList<Dag>(); while (iterator.hasNext()) { Graph _graph = iterator.next(); try { Dag dag = new Dag(_graph); dags.add(dag); } catch (IllegalArgumentException e) { // } } return dags; }
/** Tests to see if d separation facts are symmetric. */ public void testDSeparation2() { EdgeListGraphSingleConnections graph = new EdgeListGraphSingleConnections( new Dag(GraphUtils.randomGraph(7, 0, 14, 30, 15, 15, true))); List<Node> nodes = graph.getNodes(); int depth = -1; for (int i = 0; i < nodes.size(); i++) { for (int j = i; j < nodes.size(); j++) { Node x = nodes.get(i); Node y = nodes.get(j); List<Node> theRest = new ArrayList<Node>(nodes); // theRest.remove(x); // theRest.remove(y); DepthChoiceGenerator gen = new DepthChoiceGenerator(theRest.size(), depth); int[] choice; while ((choice = gen.next()) != null) { List<Node> z = new LinkedList<Node>(); for (int k = 0; k < choice.length; k++) { z.add(theRest.get(choice[k])); } boolean dConnectedTo = graph.isDConnectedTo(x, y, z); boolean dConnectedTo1 = graph.isDConnectedTo(y, x, z); if (dConnectedTo != dConnectedTo1) { System.out.println(x + " d connected to " + y + " given " + z); System.out.println(graph); System.out.println("dconnectedto = " + dConnectedTo); System.out.println("dconnecteto1 = " + dConnectedTo1); fail(); } } } } }
// Invalid if then nodes or graph changes. private void calculateArrowsBackward(Node x, Node y, Graph graph) { if (x == y) { return; } if (!graph.isAdjacentTo(x, y)) { return; } if (!knowledgeEmpty()) { if (!getKnowledge().noEdgeRequired(x.getName(), y.getName())) { return; } } List<Node> naYX = getNaYX(x, y, graph); clearArrow(x, y); List<Node> _naYX = new ArrayList<Node>(naYX); DepthChoiceGenerator gen = new DepthChoiceGenerator(_naYX.size(), _naYX.size()); int[] choice; while ((choice = gen.next()) != null) { List<Node> H = GraphUtils.asList(choice, _naYX); if (!knowledgeEmpty()) { if (!validSetByKnowledge(y, H)) { continue; } } double bump = deleteEval(x, y, H, naYX, graph); if (bump > 0.0) { Arrow arrow = new Arrow(bump, x, y, H, naYX); sortedArrows.add(arrow); addLookupArrow(x, y, arrow); } } }
/**
 * Tests that d-separation is symmetric: for every pair of distinct nodes (x, y) and every
 * conditioning set drawn from the remaining nodes, {@code isDSeparatedFrom(x, y, z)} must equal
 * {@code isDSeparatedFrom(y, x, z)}.
 */
public void testDSeparation() {
  EdgeListGraphSingleConnections graph =
      new EdgeListGraphSingleConnections(
          new Dag(GraphUtils.randomGraph(7, 0, 7, 30, 15, 15, true)));
  System.out.println(graph);

  List<Node> nodes = graph.getNodes();

  int depth = -1;

  for (int i = 0; i < nodes.size(); i++) {
    for (int j = i + 1; j < nodes.size(); j++) {
      Node x = nodes.get(i);
      Node y = nodes.get(j);

      // Conditioning candidates: every node except the pair under test.
      List<Node> others = new ArrayList<Node>(nodes);
      others.remove(x);
      others.remove(y);

      DepthChoiceGenerator subsets = new DepthChoiceGenerator(others.size(), depth);

      for (int[] picked = subsets.next(); picked != null; picked = subsets.next()) {
        List<Node> z = new LinkedList<Node>();
        for (int index : picked) {
          z.add(others.get(index));
        }

        boolean forward = graph.isDSeparatedFrom(x, y, z);
        boolean backward = graph.isDSeparatedFrom(y, x, z);

        if (forward != backward) {
          fail(
              SearchLogUtils.independenceFact(x, y, z)
                  + " should have same d-sep result as "
                  + SearchLogUtils.independenceFact(y, x, z));
        }
      }
    }
  }
}
public Graph search(List<Node> nodes) { long startTime = System.currentTimeMillis(); localScoreCache.clear(); if (!dataSet().getVariables().containsAll(nodes)) { throw new IllegalArgumentException("All of the nodes must be in " + "the supplied data set."); } Graph graph; if (initialGraph == null) { graph = new EdgeListGraphSingleConnections(nodes); } else { initialGraph = GraphUtils.replaceNodes(initialGraph, variables); graph = new EdgeListGraphSingleConnections(initialGraph); } topGraphs.clear(); buildIndexing(graph); addRequiredEdges(graph); score = 0.0; // Do forward search. fes(graph, nodes); // Do backward search. bes(graph); long endTime = System.currentTimeMillis(); this.elapsedTime = endTime - startTime; this.logger.log("graph", "\nReturning this graph: " + graph); this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s"); this.logger.flush(); return graph; }
/**
 * Timing comparison of RFCI on a random 100-node DAG, first with a d-separation oracle and then
 * with a Fisher Z test on 1000 simulated records. Results are printed to standard output; nothing
 * is asserted.
 *
 * <p>The Fisher Z report now reads its independence-test count from {@code fas2}, the adjacency
 * search that actually ran with the Fisher Z test. It previously reused the count from the
 * d-separation run.
 */
public void rtestDSeparation4() {
  Graph graph = new Dag(GraphUtils.randomGraph(100, 20, 100, 5, 5, 5, false));

  long start, stop;
  int depth = -1;

  // --- Run 1: d-separation oracle ---
  IndependenceTest test = new IndTestDSep(graph);

  Rfci fci = new Rfci(test);
  Fas fas = new Fas(test);
  start = System.currentTimeMillis();
  fci.setDepth(depth);
  fci.setVerbose(true);
  fci.search(fas, fas.getNodes());
  stop = System.currentTimeMillis();

  System.out.println("DSEP RFCI");
  System.out.println("# dsep checks = " + fas.getNumIndependenceTests());
  System.out.println("Elapsed " + (stop - start));
  System.out.println("Per " + fas.getNumIndependenceTests() / (double) (stop - start));

  // --- Run 2: Fisher Z on data simulated from the same graph ---
  SemPm pm = new SemPm(graph);
  SemIm im = new SemIm(pm);
  DataSet data = im.simulateData(1000, false);
  IndependenceTest test2 = new IndTestFisherZ(data, 0.001);

  Rfci fci3 = new Rfci(test2);
  Fas fas2 = new Fas(test2);
  start = System.currentTimeMillis();
  fci3.setDepth(depth);
  fci3.search(fas2, fas2.getNodes());
  stop = System.currentTimeMillis();

  System.out.println("FISHER Z RFCI");
  // Report the counts of this run's adjacency search (fas2), not the oracle run's.
  System.out.println("# indep checks = " + fas2.getNumIndependenceTests());
  System.out.println("Elapsed " + (stop - start));
  System.out.println("Per " + fas2.getNumIndependenceTests() / (double) (stop - start));
}
/**
 * Builds a linear SEM over a seeded random 5-variable graph, wraps it as a generalized SEM,
 * simulates data, re-estimates the model and checks the estimator's A-squared-star statistic.
 */
@Test
public void test6() {
  // Fixed seed: the expected statistic below depends on the exact random sequence.
  RandomUtil.getInstance().setSeed(29999483L);

  final int variableCount = 5;
  List<Node> variables = new ArrayList<>();
  for (int index = 1; index <= variableCount; index++) {
    variables.add(new ContinuousVariable("X" + index));
  }

  Graph dag =
      GraphUtils.randomGraphRandomForwardEdges(
          variables, 0, variableCount, 30, 15, 15, false, true);

  SemPm linearPm = new SemPm(dag);

  SemImInitializationParams initParams = new SemImInitializationParams();
  initParams.setCoefRange(0.5, 1.5);
  initParams.setVarRange(1, 3);

  SemIm linearIm = new SemIm(linearPm, initParams);

  GeneralizedSemPm pm = new GeneralizedSemPm(linearPm);
  GeneralizedSemIm trueIm = new GeneralizedSemIm(pm, linearIm);

  DataSet sample = trueIm.simulateData(1000, false);

  print(trueIm);

  GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
  GeneralizedSemIm estimatedIm = estimator.estimate(pm, sample);

  print(estimatedIm);
  print(estimator.getReport());

  assertEquals(0.59, estimator.getaSquaredStar(), 0.01);
}
/**
 * Checks that a linear SEM and the generalized SEM built from it produce the same record when
 * fed the same vector of error terms.
 */
@Test
public void test5() {
  RandomUtil.getInstance().setSeed(29999483L);

  List<Node> variables = new ArrayList<>();
  for (int index = 1; index <= 5; index++) {
    variables.add(new ContinuousVariable("X" + index));
  }

  Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, 5, 30, 15, 15, false));

  SemPm semPm = new SemPm(dag);
  SemIm semIm = new SemIm(semPm);

  // Result unused, but the call is kept: it is part of the original sequence of operations.
  semIm.simulateDataReducedForm(1000, false);

  GeneralizedSemPm pm = new GeneralizedSemPm(semPm);
  GeneralizedSemIm im = new GeneralizedSemIm(pm, semIm);

  // One shared draw of standard-normal errors for both models.
  TetradVector errors = new TetradVector(5);
  for (int index = 0; index < errors.size(); index++) {
    errors.set(index, RandomUtil.getInstance().nextNormal(0, 1));
  }

  TetradVector linearRecord = semIm.simulateOneRecord(errors);
  TetradVector generalizedRecord = im.simulateOneRecord(errors);

  print("XXX1" + errors);
  print("XXX2" + linearRecord);
  print("XXX3" + generalizedRecord);

  for (int index = 0; index < linearRecord.size(); index++) {
    assertEquals(linearRecord.get(index), generalizedRecord.get(index), 1e-10);
  }
}
/**
 * Exploratory run (no assertions): simulates data from a random single-factor measurement
 * model, clusters the measures, runs Mimbuild on the clusters and prints the structural Hamming
 * distance (SHD) between the true and recovered latent structures together with Mimbuild's
 * p-value.
 *
 * <p>Several locals ({@code impureCluster}, {@code pBelow05}, {@code numClustersGreaterThan5},
 * {@code error}) are only consumed by commented-out reporting code that is kept for reference.
 */
public void rtest4() {
  System.out.println("SHD\tP");
  // System.out.println("MB1\tMB2\tMB3\tMB4\tMB5\tMB6");

  Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 8, 0, 0, 0);

  Graph mimStructure = structure(mim);

  SemPm pm = new SemPm(mim);
  SemImInitializationParams params = new SemImInitializationParams();
  params.setCoefRange(0.5, 1.5);

  NumberFormat nf = new DecimalFormat("0.0000");

  // Accumulators over runs; totalError / errorCount only cover runs where `condition` is false.
  int totalError = 0;
  int errorCount = 0;

  // Statistics of the run with the highest p-value seen so far.
  int maxScore = 0;
  int maxNumMeasures = 0;
  double maxP = 0.0;

  // Currently a single run.
  for (int r = 0; r < 1; r++) {
    SemIm im = new SemIm(pm, params);

    DataSet data = im.simulateData(1000, false);

    // Re-express the true model over the data set's variable objects.
    mim = GraphUtils.replaceNodes(mim, data.getVariables());
    List<List<Node>> trueClusters = MimUtils.convertToClusters2(mim);

    CovarianceMatrix _cov = new CovarianceMatrix(data);

    ICovarianceMatrix cov = DataUtils.reorderColumns(_cov);

    // Hard-wired choice of clustering algorithm; the "BPC" branch below is currently unreachable.
    String algorithm = "FOFC";
    Graph searchGraph;
    List<List<Node>> partition;

    if (algorithm.equals("FOFC")) {
      FindOneFactorClusters fofc =
          new FindOneFactorClusters(cov, TestType.TETRAD_WISHART, 0.001f);
      searchGraph = fofc.search();
      searchGraph = GraphUtils.replaceNodes(searchGraph, data.getVariables());
      partition = MimUtils.convertToClusters2(searchGraph);
    } else if (algorithm.equals("BPC")) {
      TestType testType = TestType.TETRAD_WISHART;
      TestType purifyType = TestType.TETRAD_BASED2;

      BuildPureClusters bpc = new BuildPureClusters(data, 0.001, testType, purifyType);
      searchGraph = bpc.search();

      partition = MimUtils.convertToClusters2(searchGraph);
    } else {
      throw new IllegalStateException();
    }

    mimStructure = GraphUtils.replaceNodes(mimStructure, data.getVariables());

    List<String> latentVarList = reidentifyVariables(mim, data, partition, 2);

    Graph mimbuildStructure;

    Mimbuild2 mimbuild = new Mimbuild2();
    mimbuild.setAlpha(0.001);
    mimbuild.setMinClusterSize(3);

    // A failing Mimbuild search is reported and the run is skipped.
    try {
      mimbuildStructure = mimbuild.search(partition, latentVarList, cov);
    } catch (Exception e) {
      e.printStackTrace();
      continue;
    }

    mimbuildStructure = GraphUtils.replaceNodes(mimbuildStructure, data.getVariables());
    mimbuildStructure = condense(mimStructure, mimbuildStructure);

    // Graph mimSubgraph = new EdgeListGraph(mimStructure);
    //
    // for (Node node : mimSubgraph.getNodes()) {
    //   if (!mimStructure.getNodes().contains(node)) {
    //     mimSubgraph.removeNode(node);
    //   }
    // }

    int shd = SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure);
    boolean impureCluster = containsImpureCluster(partition, trueClusters);
    double pValue = mimbuild.getpValue();
    boolean pBelow05 = pValue < 0.05;
    boolean numClustersGreaterThan5 = partition.size() != 5;
    boolean error = false;

    // boolean condition = impureCluster || numClustersGreaterThan5 || pBelow05;
    // boolean condition = numClustersGreaterThan5 || pBelow05;
    // Current flag: true when exactly 40 measures ended up clustered.
    boolean condition = numClustered(partition) == 40;

    if (!condition && (shd > 5)) {
      error = true;
    }

    if (!condition) {
      totalError += shd;
      errorCount++;
    }

    // if (numClustered(partition) > maxNumMeasures) {
    //   maxNumMeasures = numClustered(partition);
    //   maxP = pValue;
    //   maxScore = shd;
    //   System.out.println("maxNumMeasures = " + maxNumMeasures);
    //   System.out.println("maxScore = " + maxScore);
    //   System.out.println("maxP = " + maxP);
    //   System.out.println("clusters = " + clusterSizes(partition, trueClusters));
    // }
    // else
    // Track the run with the highest p-value.
    if (pValue > maxP) {
      maxScore = shd;
      maxP = mimbuild.getpValue();
      maxNumMeasures = numClustered(partition);
      System.out.println("maxNumMeasures = " + maxNumMeasures);
      System.out.println("maxScore = " + maxScore);
      System.out.println("maxP = " + maxP);
      System.out.println("clusters = " + clusterSizes(partition, trueClusters));
    }

    // One report line per run: SHD, formatted p-value, number of clustered measures.
    System.out.print(
        shd
            + "\t"
            + nf.format(pValue)
            + " "
            // + (error ? 1 : 0) + " "
            // + (pBelow05 ? "P < 0.05 " : "")
            // + (impureCluster ? "Impure cluster " : "")
            // + (numClustersGreaterThan5 ? "# Clusters != 5 " : "")
            // + clusterSizes(partition, trueClusters)
            + numClustered(partition));

    System.out.println();
  }

  // The average below is NaN when errorCount is 0 (floating-point division of 0 by 0).
  System.out.println("\nAvg SHD for not-flagged cases = " + (totalError / (double) errorCount));

  System.out.println("maxNumMeasures = " + maxNumMeasures);
  System.out.println("maxScore = " + maxScore);
  System.out.println("maxP = " + maxP);
}
/** * Calculates the error variance for the given error node, given all of the coefficient values in * the model. * * @param error An error term in the model--i.e. a variable with NodeType.ERROR. * @return The value of the error variance, or Double.NaN is the value is undefined. */ private double calculateErrorVarianceFromParams(Node error) { error = semGraph.getNode(error.getName()); Node child = semGraph.getChildren(error).get(0); List<Node> parents = semGraph.getParents(child); double otherVariance = 0; for (Node parent : parents) { if (parent == error) continue; double coef = getEdgeCoefficient(parent, child); otherVariance += coef * coef; } if (parents.size() >= 2) { ChoiceGenerator gen = new ChoiceGenerator(parents.size(), 2); int[] indices; while ((indices = gen.next()) != null) { Node node1 = parents.get(indices[0]); Node node2 = parents.get(indices[1]); double coef1, coef2; if (node1.getNodeType() != NodeType.ERROR) { coef1 = getEdgeCoefficient(node1, child); } else { coef1 = 1; } if (node2.getNodeType() != NodeType.ERROR) { coef2 = getEdgeCoefficient(node2, child); } else { coef2 = 1; } List<List<Node>> treks = GraphUtils.treksIncludingBidirected(semGraph, node1, node2); double cov = 0.0; for (List<Node> trek : treks) { double product = 1.0; for (int i = 1; i < trek.size(); i++) { Node _node1 = trek.get(i - 1); Node _node2 = trek.get(i); Edge edge = semGraph.getEdge(_node1, _node2); double factor; if (Edges.isBidirectedEdge(edge)) { factor = edgeParameters.get(edge); } else if (!edgeParameters.containsKey(edge)) { factor = 1; } else if (semGraph.isParentOf(_node1, _node2)) { factor = getEdgeCoefficient(_node1, _node2); } else { factor = getEdgeCoefficient(_node2, _node1); } product *= factor; } cov += product; } otherVariance += 2 * coef1 * coef2 * cov; } } return 1.0 - otherVariance <= 0 ? Double.NaN : 1.0 - otherVariance; }
public static CpcTripleType getCpcTripleType( Node x, Node y, Node z, IndependenceTest test, int depth, Graph graph) { // System.out.println("getCpcTripleType 1"); boolean existsSepsetContainingY = false; boolean existsSepsetNotContainingY = false; Set<Node> __nodes = new HashSet<Node>(graph.getAdjacentNodes(x)); __nodes.remove(z); // System.out.println("getCpcTripleType 2"); List<Node> _nodes = new LinkedList<Node>(__nodes); TetradLogger.getInstance() .log("adjacencies", "Adjacents for " + x + "--" + y + "--" + z + " = " + _nodes); // System.out.println("getCpcTripleType 3"); int _depth = depth; if (_depth == -1) { _depth = 1000; } _depth = Math.min(_depth, _nodes.size()); // System.out.println("getCpcTripleType 4"); for (int d = 0; d <= _depth; d++) { // System.out.println("getCpcTripleType 5"); ChoiceGenerator cg = new ChoiceGenerator(_nodes.size(), d); int[] choice; while ((choice = cg.next()) != null) { // System.out.println("getCpcTripleType 6"); List<Node> condSet = GraphUtils.asList(choice, _nodes); // System.out.println("getCpcTripleType 7"); if (test.isIndependent(x, z, condSet)) { if (condSet.contains(y)) { existsSepsetContainingY = true; } else { existsSepsetNotContainingY = true; } } } } // System.out.println("getCpcTripleType 8"); __nodes = new HashSet<Node>(graph.getAdjacentNodes(z)); __nodes.remove(x); _nodes = new LinkedList<Node>(__nodes); TetradLogger.getInstance() .log("adjacencies", "Adjacents for " + x + "--" + y + "--" + z + " = " + _nodes); // System.out.println("getCpcTripleType 9"); _depth = depth; if (_depth == -1) { _depth = 1000; } _depth = Math.min(_depth, _nodes.size()); // System.out.println("getCpcTripleType 10"); for (int d = 0; d <= _depth; d++) { // System.out.println("getCpcTripleType 11"); ChoiceGenerator cg = new ChoiceGenerator(_nodes.size(), d); int[] choice; while ((choice = cg.next()) != null) { List<Node> condSet = GraphUtils.asList(choice, _nodes); if (test.isIndependent(x, z, condSet)) { if (condSet.contains(y)) { 
existsSepsetContainingY = true; } else { existsSepsetNotContainingY = true; } } } } // System.out.println("getCpcTripleType 12"); if (existsSepsetContainingY == existsSepsetNotContainingY) { return CpcTripleType.AMBIGUOUS; } else if (!existsSepsetNotContainingY) { return CpcTripleType.NONCOLLIDER; } else { return CpcTripleType.COLLIDER; } }
/**
 * Decides the orientation of the edge between x and y by scoring every combination of candidate
 * parent sets for the two nodes and keeping the direction with the highest combined score.
 *
 * <p>The roles of x and y are swapped at random with probability about one half. For every
 * subset of x's other adjacents and every subset of y's other adjacents, four scores are
 * computed (each node with and without the other as an extra parent). Combinations for which any
 * of the four {@code normal(...)} checks holds are skipped. The remaining combinations are
 * classified by comparing the four scores; directed outcomes update the running maximum.
 * Finally all edges between x and y are removed and replaced by y -> x, x -> y, or an undirected
 * edge if no direction won.
 *
 * @param graph the graph to modify
 * @param x one endpoint (may be swapped with y)
 * @param y the other endpoint (may be swapped with x)
 * @param strong if true, a direction is only accepted when the extra "strong" inequalities hold
 * @param oldGraph only referenced by commented-out code; currently unused
 */
private void resolveOneEdgeMax(Graph graph, Node x, Node y, boolean strong, Graph oldGraph) {
  // Randomise which endpoint plays the role of x.
  if (RandomUtil.getInstance().nextDouble() > 0.5) {
    Node temp = x;
    x = y;
    y = temp;
  }

  TetradLogger.getInstance().log("info", "\nEDGE " + x + " --- " + y);

  // Keyed by negated score so iteration yields reports from best to worst. Reports with equal
  // scores overwrite one another.
  SortedMap<Double, String> scoreReports = new TreeMap<Double, String>();

  List<Node> neighborsx = graph.getAdjacentNodes(x);
  neighborsx.remove(y);

  double max = Double.NEGATIVE_INFINITY;
  // left means y -> x won, right means x -> y won (see the end of the method).
  boolean left = false;
  boolean right = false;

  // Enumerate all subsets of x's other adjacents.
  DepthChoiceGenerator genx = new DepthChoiceGenerator(neighborsx.size(), neighborsx.size());
  int[] choicex;

  while ((choicex = genx.next()) != null) {
    List<Node> condxMinus = GraphUtils.asList(choicex, neighborsx);

    // Same conditioning set with y added as a parent of x.
    List<Node> condxPlus = new ArrayList<Node>(condxMinus);
    condxPlus.add(y);

    double xPlus = score(x, condxPlus);
    double xMinus = score(x, condxMinus);

    List<Node> neighborsy = graph.getAdjacentNodes(y);
    neighborsy.remove(x);

    // Enumerate all subsets of y's other adjacents.
    DepthChoiceGenerator geny = new DepthChoiceGenerator(neighborsy.size(), neighborsy.size());
    int[] choicey;

    while ((choicey = geny.next()) != null) {
      List<Node> condyMinus = GraphUtils.asList(choicey, neighborsy);

      // List<Node> parentsY = oldGraph.getParents(y);
      // parentsY.remove(x);
      // if (!condyMinus.containsAll(parentsY)) {
      //   continue;
      // }

      // Same conditioning set with x added as a parent of y.
      List<Node> condyPlus = new ArrayList<Node>(condyMinus);
      condyPlus.add(x);

      double yPlus = score(y, condyPlus);
      double yMinus = score(y, condyMinus);

      // Checking them all at once is expensive but avoids lexical ordering problems in the
      // algorithm.
      if (normal(y, condyPlus)
          || normal(x, condxMinus)
          || normal(x, condxPlus)
          || normal(y, condyMinus)) {
        continue;
      }

      // Tolerance for the score comparisons; currently zero.
      double delta = 0.0;

      if (strong) {
        if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) {
          // Scores favour y -> x.
          double score = combinedScore(xPlus, yMinus);

          if (yPlus <= yMinus + delta && xMinus <= xPlus + delta) {
            StringBuilder builder = new StringBuilder();

            builder.append("\nStrong " + y + "->" + x + " " + score);
            builder.append("\n Parents(" + x + ") = " + condxMinus);
            builder.append("\n Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());

            if (score > max) {
              max = score;
              left = true;
              right = false;
            }
          } else {
            // Direction favoured but the strong condition failed: report only.
            StringBuilder builder = new StringBuilder();

            builder.append("\nNo directed edge " + x + "--" + y + " " + score);
            builder.append("\n Parents(" + x + ") = " + condxMinus);
            builder.append("\n Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());
          }
        } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) {
          // Scores favour x -> y.
          double score = combinedScore(yPlus, xMinus);

          if (yMinus <= yPlus + delta && xPlus <= xMinus + delta) {
            StringBuilder builder = new StringBuilder();

            builder.append("\nStrong " + x + "->" + y + " " + score);
            builder.append("\n Parents(" + x + ") = " + condxMinus);
            builder.append("\n Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());

            if (score > max) {
              max = score;
              left = false;
              right = true;
            }
          } else {
            // Direction favoured but the strong condition failed: report only.
            StringBuilder builder = new StringBuilder();

            builder.append("\nNo directed edge " + x + "--" + y + " " + score);
            builder.append("\n Parents(" + x + ") = " + condxMinus);
            builder.append("\n Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());
          }
        } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) {
          // NOTE(review): this branch and the next both use combinedScore(yPlus, xMinus);
          // confirm that is intended. They only produce report entries.
          double score = combinedScore(yPlus, xMinus);

          StringBuilder builder = new StringBuilder();

          builder.append("\nNo directed edge " + x + "--" + y + " " + score);
          builder.append("\n Parents(" + x + ") = " + condxMinus);
          builder.append("\n Parents(" + y + ") = " + condyMinus);

          scoreReports.put(-score, builder.toString());
        } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) {
          double score = combinedScore(yPlus, xMinus);

          StringBuilder builder = new StringBuilder();

          builder.append("\nNo directed edge " + x + "--" + y + " " + score);
          builder.append("\n Parents(" + x + ") = " + condxMinus);
          builder.append("\n Parents(" + y + ") = " + condyMinus);

          scoreReports.put(-score, builder.toString());
        }
      } else {
        // Weak mode: the first-level comparison alone decides the direction.
        if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) {
          double score = combinedScore(xPlus, yMinus);

          StringBuilder builder = new StringBuilder();

          builder.append("\nWeak " + y + "->" + x + " " + score);
          builder.append("\n Parents(" + x + ") = " + condxMinus);
          builder.append("\n Parents(" + y + ") = " + condyMinus);

          scoreReports.put(-score, builder.toString());

          if (score > max) {
            max = score;
            left = true;
            right = false;
          }
        } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) {
          double score = combinedScore(yPlus, xMinus);

          StringBuilder builder = new StringBuilder();

          builder.append("\nWeak " + x + "->" + y + " " + score);
          builder.append("\n Parents(" + x + ") = " + condxMinus);
          builder.append("\n Parents(" + y + ") = " + condyMinus);

          scoreReports.put(-score, builder.toString());

          if (score > max) {
            max = score;
            left = false;
            right = true;
          }
        } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) {
          // NOTE(review): as in strong mode, both remaining branches use
          // combinedScore(yPlus, xMinus) and only produce report entries.
          double score = combinedScore(yPlus, xMinus);

          StringBuilder builder = new StringBuilder();

          builder.append("\nNo directed edge " + x + "--" + y + " " + score);
          builder.append("\n Parents(" + x + ") = " + condxMinus);
          builder.append("\n Parents(" + y + ") = " + condyMinus);

          scoreReports.put(-score, builder.toString());
        } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) {
          double score = combinedScore(yPlus, xMinus);

          StringBuilder builder = new StringBuilder();

          builder.append("\nNo directed edge " + x + "--" + y + " " + score);
          builder.append("\n Parents(" + x + ") = " + condxMinus);
          builder.append("\n Parents(" + y + ") = " + condyMinus);

          scoreReports.put(-score, builder.toString());
        }
      }
    }
  }

  // Emit the collected reports, best score first.
  for (double score : scoreReports.keySet()) {
    TetradLogger.getInstance().log("info", scoreReports.get(score));
  }

  // Replace whatever joined x and y with the winning orientation.
  graph.removeEdges(x, y);

  if (left) {
    graph.addDirectedEdge(y, x);
  }

  if (right) {
    graph.addDirectedEdge(x, y);
  }

  // No direction won: keep the adjacency as an undirected edge.
  if (!graph.isAdjacentTo(x, y)) {
    graph.addUndirectedEdge(x, y);
  }
}