/** Orients according to background knowledge */ private void fciOrientbk(IKnowledge bk, Graph graph, List<Node> variables) { logger.log("info", "Starting BK Orientation."); for (Iterator<KnowledgeEdge> it = bk.forbiddenEdgesIterator(); it.hasNext(); ) { KnowledgeEdge edge = it.next(); // match strings to variables in the graph. Node from = SearchGraphUtils.translate(edge.getFrom(), variables); Node to = SearchGraphUtils.translate(edge.getTo(), variables); if (from == null || to == null) { continue; } if (graph.getEdge(from, to) == null) { continue; } // Orient to*->from graph.setEndpoint(to, from, Endpoint.ARROW); graph.setEndpoint(from, to, Endpoint.CIRCLE); changeFlag = true; logger.log( "knowledgeOrientation", SearchLogUtils.edgeOrientedMsg("Knowledge", graph.getEdge(from, to))); } for (Iterator<KnowledgeEdge> it = bk.requiredEdgesIterator(); it.hasNext(); ) { KnowledgeEdge edge = it.next(); // match strings to variables in this graph Node from = SearchGraphUtils.translate(edge.getFrom(), variables); Node to = SearchGraphUtils.translate(edge.getTo(), variables); if (from == null || to == null) { continue; } if (graph.getEdge(from, to) == null) { continue; } graph.setEndpoint(to, from, Endpoint.TAIL); graph.setEndpoint(from, to, Endpoint.ARROW); changeFlag = true; logger.log( "knowledgeOrientation", SearchLogUtils.edgeOrientedMsg("Knowledge", graph.getEdge(from, to))); } logger.log("info", "Finishing BK Orientation."); }
/**
 * Runs PC over the given list of nodes, using the configured independence test and knowledge, and
 * returns the resultant graph. The returned graph will be a pattern if the independence
 * information is consistent with the hypothesis that there are no latent common causes. It may,
 * however, contain cycles or bidirected edges if this assumption is not borne out, either due to
 * the actual presence of latent common causes, or due to statistical errors in conditional
 * independence judgments.
 *
 * <p>All of the given nodes must be in the domain of the configured conditional independence
 * test. The adjacency search is constructed with the {@code initialGraph} field.
 *
 * @param nodes the nodes to search over
 * @return the graph after adjacency search, knowledge orientation, collider orientation and Meek
 *     rules
 * @throws NullPointerException if no independence test has been set
 * @throws IllegalArgumentException if some given node is not a variable of the independence test
 */
public Graph search(List<Node> nodes) {
  this.logger.log("info", "Starting PC algorithm");
  this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");

  long startTime = System.currentTimeMillis();

  if (getIndependenceTest() == null) {
    throw new NullPointerException("The independence test must not be null.");
  }

  List<Node> allNodes = getIndependenceTest().getVariables();

  if (!allNodes.containsAll(nodes)) {
    throw new IllegalArgumentException(
        "All of the given nodes must " + "be in the domain of the independence test provided.");
  }

  // NOTE(review): this graph is replaced by the result of fas.search() a few lines below. The
  // construction is kept because the EdgeListGraph constructor may validate the node list —
  // confirm, and remove if it does not.
  graph = new EdgeListGraph(nodes);

  // Adjacency phase.
  IFas fas = new FasStableConcurrent(initialGraph, getIndependenceTest());
  fas.setKnowledge(getKnowledge());
  fas.setDepth(getDepth());
  fas.setVerbose(verbose);
  graph = fas.search();
  sepsets = fas.getSepsets();

  // Orientation phase: background knowledge, then colliders from sepsets, then Meek rules.
  SearchGraphUtils.pcOrientbk(knowledge, graph, nodes);
  SearchGraphUtils.orientCollidersUsingSepsets(this.sepsets, knowledge, graph, verbose);

  MeekRules rules = new MeekRules();
  rules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
  rules.setKnowledge(knowledge);
  rules.orientImplied(graph);

  this.logger.log("graph", "\nReturning this graph: " + graph);

  this.elapsedTime = System.currentTimeMillis() - startTime;

  this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
  this.logger.log("info", "Finishing PC Algorithm.");
  this.logger.flush();

  return graph;
}
/**
 * Returns the pattern to which the given DAG belongs.
 *
 * <p>A new {@code EdgeListGraph} is constructed from the DAG, reduced with
 * {@code SearchGraphUtils.basicPattern}, and then passed through the Meek orientation rules; that
 * new graph is what is returned.
 *
 * @param dag the DAG to convert
 * @return the resulting pattern graph
 */
public static Graph patternFromDag(Graph dag) {
  Graph pattern = new EdgeListGraph(dag);
  SearchGraphUtils.basicPattern(pattern);

  MeekRules meekRules = new MeekRules();
  meekRules.orientImplied(pattern);

  return pattern;
}
public void layoutByKnowledge() { GraphWorkbench resultWorkbench = getWorkbench(); Graph graph = resultWorkbench.getGraph(); IKnowledge knowledge = getAlgorithmRunner().getParams().getKnowledge(); SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge); // resultWorkbench.setGraph(graph); }
/**
 * Completes a pattern that was modified by an insertion/deletion operator, based on the algorithm
 * described in Appendix C of (Chickering, 2002).
 *
 * <p>The graph is reduced to its basic pattern, required edges are added, and Meek orientation
 * is applied with the current knowledge. The result is logged when the "rebuiltPatterns" event
 * is active.
 *
 * @param graph the graph to rebuild in place
 */
private void rebuildPattern(Graph graph) {
  SearchGraphUtils.basicPattern(graph, false);
  addRequiredEdges(graph);
  meekOrient(graph, getKnowledge());

  TetradLogger tetradLogger = TetradLogger.getInstance();
  if (!tetradLogger.isEventActive("rebuiltPatterns")) {
    return;
  }

  tetradLogger.log("rebuiltPatterns", "Rebuilt pattern = " + graph);
}
protected void doDefaultArrangement(Graph resultGraph) { if (getLatestWorkbenchGraph() != null) { // (alreadyLaidOut) { GraphUtils.arrangeBySourceGraph(resultGraph, getLatestWorkbenchGraph()); } else if (getKnowledge().isDefaultToKnowledgeLayout()) { SearchGraphUtils.arrangeByKnowledgeTiers(resultGraph, getKnowledge()); // alreadyLaidOut = true; } else { GraphUtils.circleLayout(resultGraph, 200, 200, 150); // alreadyLaidOut = true; } }
/** * Executes the algorithm, producing (at least) a result workbench. Must be implemented in the * extending class. */ public void execute() { IKnowledge knowledge = getParams().getKnowledge(); SearchParams searchParams = getParams(); FciIndTestParams indTestParams = (FciIndTestParams) searchParams.getIndTestParams(); // Cfci fciSearch = // new Cfci(getIndependenceTest(), knowledge); // fciSearch.setMaxIndegree(indTestParams.depth()); // Graph graph = fciSearch.search(); // // if (knowledge.isDefaultToKnowledgeLayout()) { // SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge); // } // // setResultGraph(graph); Graph graph; if (indTestParams.isRFCI_Used()) { Rfci fci = new Rfci(getIndependenceTest()); fci.setKnowledge(knowledge); fci.setCompleteRuleSetUsed(indTestParams.isCompleteRuleSetUsed()); fci.setMaxPathLength(indTestParams.getMaxReachablePathLength()); fci.setDepth(indTestParams.getDepth()); graph = fci.search(); } else { Fci fci = new Fci(getIndependenceTest()); fci.setKnowledge(knowledge); fci.setCompleteRuleSetUsed(indTestParams.isCompleteRuleSetUsed()); fci.setPossibleDsepSearchDone(indTestParams.isPossibleDsepDone()); fci.setMaxPathLength(indTestParams.getMaxReachablePathLength()); fci.setDepth(indTestParams.getDepth()); graph = fci.search(); } if (getSourceGraph() != null) { GraphUtils.arrangeBySourceGraph(graph, getSourceGraph()); } else if (knowledge.isDefaultToKnowledgeLayout()) { SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge); } else { GraphUtils.circleLayout(graph, 200, 200, 150); } setResultGraph(graph); }
/** * Executes the algorithm, producing (at least) a result workbench. Must be implemented in the * extending class. */ public void execute() { IKnowledge knowledge = getParams().getKnowledge(); SearchParams searchParams = getParams(); FciGesIndTestParams indTestParams = (FciGesIndTestParams) searchParams.getIndTestParams(); // Cfci fciSearch = // new Cfci(getIndependenceTest(), knowledge); // fciSearch.setDepth(indTestParams.depth()); // Graph graph = fciSearch.search(); // // if (knowledge.isDefaultToKnowledgeLayout()) { // SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge); // } // // setResultGraph(graph); Graph graph; TFciGes fci = new TFciGes(getIndependenceTest()); fci.setKnowledge(knowledge); fci.setCompleteRuleSetUsed(indTestParams.isCompleteRuleSetUsed()); fci.setPossibleDsepSearchDone(indTestParams.isPossibleDsepDone()); fci.setMaxPathLength(indTestParams.getMaxReachablePathLength()); fci.setDepth(indTestParams.getDepth()); fci.setPenaltyDiscount(indTestParams.getPenaltyDiscount()); fci.setSamplePrior(indTestParams.getSamplePrior()); fci.setStructurePrior(indTestParams.getStructurePrior()); fci.setCompleteRuleSetUsed(false); fci.setPenaltyDiscount(indTestParams.getPenaltyDiscount()); fci.setFaithfulnessAssumed(indTestParams.isFaithfulnessAssumed()); graph = fci.search(); if (getSourceGraph() != null) { GraphUtils.arrangeBySourceGraph(graph, getSourceGraph()); } else if (knowledge.isDefaultToKnowledgeLayout()) { SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge); } else { GraphUtils.circleLayout(graph, 200, 200, 150); } setResultGraph(graph); }
/**
 * Turns the given graph into a fully directed graph, in place.
 *
 * <p>The graph is first reduced to its basic pattern and required edges are added. Then, while
 * a pass finds any undirected edge, every undirected edge seen in that pass is replaced by a
 * directed edge from its first node to its second node, and Meek orientation is applied with the
 * current knowledge.
 *
 * @param graph the graph to direct; it is modified and returned
 * @return the same graph instance that was passed in
 */
private Graph pickDag(Graph graph) {
  SearchGraphUtils.basicPattern(graph, false);
  addRequiredEdges(graph);
  boolean containsUndirected;

  do {
    containsUndirected = false;

    // Iterate over a snapshot of the edges: the loop body removes edges from and adds edges to
    // the graph, which must not happen while iterating a live collection.
    for (Edge edge : new java.util.ArrayList<Edge>(graph.getEdges())) {
      if (Edges.isUndirectedEdge(edge)) {
        containsUndirected = true;
        graph.removeEdge(edge);
        Edge directed = Edges.directedEdge(edge.getNode1(), edge.getNode2());
        graph.addEdge(directed);
      }
    }

    meekOrient(graph, getKnowledge());
  } while (containsUndirected);

  return graph;
}
/**
 * Executes the algorithm and stores the laid-out result graph.
 *
 * <p>Runs {@code VcpcAlt} configured with the runner's knowledge, cycle-prevention flag and
 * depth. The result is arranged by the source graph if there is one, else by knowledge tiers if
 * the knowledge asks for that, else in a circle.
 */
public void execute() {
  IKnowledge knowledge = getParams().getKnowledge();
  PcSearchParams pcParams = (PcSearchParams) getParams();
  PcIndTestParams testParams = (PcIndTestParams) pcParams.getIndTestParams();

  VcpcAlt search = new VcpcAlt(getIndependenceTest());
  search.setKnowledge(knowledge);
  search.setAggressivelyPreventCycles(this.isAggressivelyPreventCycles());
  search.setDepth(testParams.getDepth());

  Graph result = search.search();

  // Lay out the result before publishing it.
  if (getSourceGraph() != null) {
    GraphUtils.arrangeBySourceGraph(result, getSourceGraph());
  } else if (knowledge.isDefaultToKnowledgeLayout()) {
    SearchGraphUtils.arrangeByKnowledgeTiers(result, knowledge);
  } else {
    GraphUtils.circleLayout(result, 200, 200, 150);
  }

  setResultGraph(result);
}
/**
 * Executes the algorithm and stores the laid-out result graph.
 *
 * <p>When the independence test is an {@code IndTestDSep}, its graph is converted with
 * {@code DagToPag}; otherwise {@code GFci} is run with a {@code ScoredIndTest} score built from
 * the independence test. The result is arranged by the source graph if there is one, else by
 * knowledge tiers if the knowledge asks for that, else in a circle.
 */
public void execute() {
  IKnowledge knowledge = (IKnowledge) getParams().get("knowledge", new Knowledge2());
  Parameters parameters = getParams();

  IndependenceTest test = getIndependenceTest();
  Score score = new ScoredIndTest(test);

  Graph result;

  if (!(test instanceof IndTestDSep)) {
    GFci gfci = new GFci(test, score);
    gfci.setKnowledge(knowledge);
    gfci.setCompleteRuleSetUsed(parameters.getBoolean("completeRuleSetUsed", false));
    gfci.setMaxPathLength(parameters.getInt("maxReachablePathLength", -1));
    gfci.setMaxDegree(parameters.getInt("maxIndegree"));
    // NOTE(review): this overrides the "completeRuleSetUsed" parameter applied above, so the
    // search always runs with the complete rule set disabled — confirm that this is intended.
    gfci.setCompleteRuleSetUsed(false);
    gfci.setFaithfulnessAssumed(parameters.getBoolean("faithfulnessAssumed", true));
    result = gfci.search();
  } else {
    final DagToPag converter = new DagToPag(((IndTestDSep) test).getGraph());
    converter.setCompleteRuleSetUsed(parameters.getBoolean("completeRuleSetUsed", false));
    result = converter.convert();
  }

  // Lay out the result before publishing it.
  if (getSourceGraph() != null) {
    GraphUtils.arrangeBySourceGraph(result, getSourceGraph());
  } else if (knowledge.isDefaultToKnowledgeLayout()) {
    SearchGraphUtils.arrangeByKnowledgeTiers(result, knowledge);
  } else {
    GraphUtils.circleLayout(result, 200, 200, 150);
  }

  setResultGraph(result);
}
public void test1() { for (int r = 0; r < 1; r++) { Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 6, 0, 0, 0); Graph mimStructure = structure(mim); SemImInitializationParams params = new SemImInitializationParams(); params.setCoefRange(.5, 1.5); SemPm pm = new SemPm(mim); SemIm im = new SemIm(pm, params); DataSet data = im.simulateData(300, false); String algorithm = "FOFC"; Graph searchGraph; List<List<Node>> partition; if (algorithm.equals("FOFC")) { FindOneFactorClusters fofc = new FindOneFactorClusters(data, TestType.TETRAD_WISHART, 0.001); searchGraph = fofc.search(); partition = fofc.getClusters(); } else if (algorithm.equals("BPC")) { TestType testType = TestType.TETRAD_WISHART; TestType purifyType = TestType.TETRAD_BASED2; BuildPureClusters bpc = new BuildPureClusters(data, 0.001, testType, purifyType); searchGraph = bpc.search(); partition = MimUtils.convertToClusters2(searchGraph); } else { throw new IllegalStateException(); } List<String> latentVarList = reidentifyVariables(mim, data, partition, 2); System.out.println(partition); System.out.println(latentVarList); System.out.println("True\n" + mimStructure); Graph mimbuildStructure; for (int mimbuildMethod : new int[] {3, 4}) { if (mimbuildMethod == 1) { System.out.println("Mimbuild 1\n"); Clusters measurements = ClusterUtils.mimClusters(searchGraph); IndTestMimBuild test = new IndTestMimBuild(data, 0.001, measurements); MimBuild mimbuild = new MimBuild(test, new Knowledge2()); Graph full = mimbuild.search(); full = changeLatentNames(full, measurements, latentVarList); mimbuildStructure = structure(full); System.out.println( "SHD = " + SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure)); System.out.println("Estimated\n" + mimbuildStructure); System.out.println(); } // else if (mimbuildMethod == 2) { // System.out.println("Mimbuild 2\n"); // Mimbuild2 mimbuild = new Mimbuild2(); // mimbuild.setAlpha(0.001); // mimbuildStructure = mimbuild.search(partition, 
latentVarList, data); // TetradMatrix latentcov = mimbuild.getLatentsCov(); // List<String> latentnames = mimbuild.getLatentNames(); // System.out.println("\nCovariance over the latents"); // System.out.println(MatrixUtils.toStringSquare(latentcov.toArray(), // latentnames)); // System.out.println("Estimated\n" + mimbuildStructure); // System.out.println("SHD = " + // SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure)); // System.out.println(); // } else if (mimbuildMethod == 3) { System.out.println("Mimbuild 3\n"); Mimbuild2 mimbuild = new Mimbuild2(); mimbuild.setAlpha(0.001); mimbuild.setMinClusterSize(3); mimbuildStructure = mimbuild.search(partition, latentVarList, new CovarianceMatrix(data)); ICovarianceMatrix latentcov = mimbuild.getLatentsCov(); System.out.println("\nCovariance over the latents"); System.out.println(latentcov); System.out.println("Estimated\n" + mimbuildStructure); System.out.println( "SHD = " + SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure)); System.out.println(); } else if (mimbuildMethod == 4) { System.out.println("Mimbuild Trek\n"); MimbuildTrek mimbuild = new MimbuildTrek(); mimbuild.setAlpha(0.1); mimbuild.setMinClusterSize(3); mimbuildStructure = mimbuild.search(partition, latentVarList, new CovarianceMatrix(data)); ICovarianceMatrix latentcov = mimbuild.getLatentsCov(); System.out.println("\nCovariance over the latents"); System.out.println(latentcov); System.out.println("Estimated\n" + mimbuildStructure); System.out.println( "SHD = " + SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure)); System.out.println(); } else { throw new IllegalStateException(); } } } }
public void rtest4() { System.out.println("SHD\tP"); // System.out.println("MB1\tMB2\tMB3\tMB4\tMB5\tMB6"); Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 8, 0, 0, 0); Graph mimStructure = structure(mim); SemPm pm = new SemPm(mim); SemImInitializationParams params = new SemImInitializationParams(); params.setCoefRange(0.5, 1.5); NumberFormat nf = new DecimalFormat("0.0000"); int totalError = 0; int errorCount = 0; int maxScore = 0; int maxNumMeasures = 0; double maxP = 0.0; for (int r = 0; r < 1; r++) { SemIm im = new SemIm(pm, params); DataSet data = im.simulateData(1000, false); mim = GraphUtils.replaceNodes(mim, data.getVariables()); List<List<Node>> trueClusters = MimUtils.convertToClusters2(mim); CovarianceMatrix _cov = new CovarianceMatrix(data); ICovarianceMatrix cov = DataUtils.reorderColumns(_cov); String algorithm = "FOFC"; Graph searchGraph; List<List<Node>> partition; if (algorithm.equals("FOFC")) { FindOneFactorClusters fofc = new FindOneFactorClusters(cov, TestType.TETRAD_WISHART, 0.001f); searchGraph = fofc.search(); searchGraph = GraphUtils.replaceNodes(searchGraph, data.getVariables()); partition = MimUtils.convertToClusters2(searchGraph); } else if (algorithm.equals("BPC")) { TestType testType = TestType.TETRAD_WISHART; TestType purifyType = TestType.TETRAD_BASED2; BuildPureClusters bpc = new BuildPureClusters(data, 0.001, testType, purifyType); searchGraph = bpc.search(); partition = MimUtils.convertToClusters2(searchGraph); } else { throw new IllegalStateException(); } mimStructure = GraphUtils.replaceNodes(mimStructure, data.getVariables()); List<String> latentVarList = reidentifyVariables(mim, data, partition, 2); Graph mimbuildStructure; Mimbuild2 mimbuild = new Mimbuild2(); mimbuild.setAlpha(0.001); mimbuild.setMinClusterSize(3); try { mimbuildStructure = mimbuild.search(partition, latentVarList, cov); } catch (Exception e) { e.printStackTrace(); continue; } mimbuildStructure = GraphUtils.replaceNodes(mimbuildStructure, 
data.getVariables()); mimbuildStructure = condense(mimStructure, mimbuildStructure); // Graph mimSubgraph = new EdgeListGraph(mimStructure); // // for (Node node : mimSubgraph.getNodes()) { // if (!mimStructure.getNodes().contains(node)) { // mimSubgraph.removeNode(node); // } // } int shd = SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure); boolean impureCluster = containsImpureCluster(partition, trueClusters); double pValue = mimbuild.getpValue(); boolean pBelow05 = pValue < 0.05; boolean numClustersGreaterThan5 = partition.size() != 5; boolean error = false; // boolean condition = impureCluster || numClustersGreaterThan5 || pBelow05; // boolean condition = numClustersGreaterThan5 || pBelow05; boolean condition = numClustered(partition) == 40; if (!condition && (shd > 5)) { error = true; } if (!condition) { totalError += shd; errorCount++; } // if (numClustered(partition) > maxNumMeasures) { // maxNumMeasures = numClustered(partition); // maxP = pValue; // maxScore = shd; // System.out.println("maxNumMeasures = " + maxNumMeasures); // System.out.println("maxScore = " + maxScore); // System.out.println("maxP = " + maxP); // System.out.println("clusters = " + clusterSizes(partition, trueClusters)); // } // else if (pValue > maxP) { maxScore = shd; maxP = mimbuild.getpValue(); maxNumMeasures = numClustered(partition); System.out.println("maxNumMeasures = " + maxNumMeasures); System.out.println("maxScore = " + maxScore); System.out.println("maxP = " + maxP); System.out.println("clusters = " + clusterSizes(partition, trueClusters)); } System.out.print( shd + "\t" + nf.format(pValue) + " " // + (error ? 1 : 0) + " " // + (pBelow05 ? "P < 0.05 " : "") // + (impureCluster ? "Impure cluster " : "") // + (numClustersGreaterThan5 ? 
"# Clusters != 5 " : "") // + clusterSizes(partition, trueClusters) + numClustered(partition)); System.out.println(); } System.out.println("\nAvg SHD for not-flagged cases = " + (totalError / (double) errorCount)); System.out.println("maxNumMeasures = " + maxNumMeasures); System.out.println("maxScore = " + maxScore); System.out.println("maxP = " + maxP); }
/** * Runs PC starting with a commplete graph over the given list of nodes, using the given * independence test and knowledge and returns the resultant graph. The returned graph will be a * pattern if the independence information is consistent with the hypothesis that there are no * latent common causes. It may, however, contain cycles or bidirected edges if this assumption is * not born out, either due to the actual presence of latent common causes, or due to statistical * errors in conditional independence judgments. * * <p>All of the given nodes must be in the domain of the given conditional independence test. */ public Graph search(List<Node> nodes) { this.logger.log("info", "Starting PC algorithm"); this.logger.log("info", "Independence test = " + getIndependenceTest() + "."); if (trueDag != null) { this.dsep = new IndTestDSep(trueDag); } long startTime = System.currentTimeMillis(); if (getIndependenceTest() == null) { throw new NullPointerException(); } List<Node> allNodes = getIndependenceTest().getVariables(); if (!allNodes.containsAll(nodes)) { throw new IllegalArgumentException( "All of the given nodes must " + "be in the domain of the independence test provided."); } IFas fas = new Fas2(getIndependenceTest()); fas.setInitialGraph(initialGraph); fas.setKnowledge(getKnowledge()); fas.setDepth(getDepth()); fas.setVerbose(verbose); graph = fas.search(); SearchGraphUtils.pcOrientbk(knowledge, graph, nodes); // independenceTest = new ProbabilisticMAPIndependence((DataSet) // independenceTest.getData()); SepsetsMaxPValue sepsetProducer = new SepsetsMaxPValue(graph, independenceTest, null, getDepth()); sepsetProducer.setDsep(dsep); addColliders(graph, sepsetProducer, knowledge); MeekRules rules = new MeekRules(); rules.setKnowledge(knowledge); rules.orientImplied(graph); // Graph pattern = new EdgeListGraphSingleConnections(graph); // // for (Node x : getNodes()) { // for (Node y : getNodes()) { // if (x == y) continue; // // if (!localMarkovIndep(x, y, pattern, 
independenceTest)) { // graph.addUndirectedEdge(x, y); // } // } // } // // fas = new FasStableConcurrent(getIndependenceTest()); // fas.setInitialGraph(new EdgeListGraphSingleConnections(graph)); // fas.setKnowledge(getKnowledge()); // fas.setDepth(getDepth()); // fas.setVerbose(verbose); // graph = fas.search(); // // sepsetProducer = new SepsetsMaxPValue(graph, independenceTest, null, getDepth()); // // addColliders(graph, sepsetProducer, knowledge); // // rules = new MeekRules(); // rules.setKnowledge(knowledge); // rules.orientImplied(graph); this.logger.log("graph", "\nReturning this graph: " + graph); this.elapsedTime = System.currentTimeMillis() - startTime; this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s"); this.logger.log("info", "Finishing PC Algorithm."); this.logger.flush(); return graph; }
/** * Executes the algorithm, producing (at least) a result workbench. Must be implemented in the * extending class. */ public void execute() { Object source = dataWrapper.getSelectedDataModel(); DataModel dataModel = (DataModel) source; IKnowledge knowledge = params2.getKnowledge(); if (initialGraph == null) { initialGraph = new EdgeListGraph(dataModel.getVariables()); } Graph graph2 = new EdgeListGraph(initialGraph); graph2 = GraphUtils.replaceNodes(graph2, dataModel.getVariables()); Bff search; if (dataModel instanceof DataSet) { DataSet dataSet = (DataSet) dataModel; if (getAlgorithmType() == AlgorithmType.BEAM) { search = new BffBeam(graph2, dataSet, knowledge); } else if (getAlgorithmType() == AlgorithmType.GES) { search = new BffGes(graph2, dataSet); search.setKnowledge(knowledge); } else { throw new IllegalStateException(); } } else if (dataModel instanceof CovarianceMatrix) { CovarianceMatrix covarianceMatrix = (CovarianceMatrix) dataModel; if (getAlgorithmType() == AlgorithmType.BEAM) { search = new BffBeam(graph2, covarianceMatrix, knowledge); } else if (getAlgorithmType() == AlgorithmType.GES) { throw new IllegalArgumentException( "GES method requires a dataset; a covariance matrix was provided."); // search = new BffGes(graph2, covarianceMatrix); // search.setKnowledge(knowledge); } else { throw new IllegalStateException(); } } else { throw new IllegalStateException(); } PcIndTestParams indTestParams = (PcIndTestParams) getParams().getIndTestParams(); search.setAlpha(indTestParams.getAlpha()); search.setBeamWidth(indTestParams.getBeamWidth()); search.setHighPValueAlpha(indTestParams.getZeroEdgeP()); this.graph = search.search(); // this.graph = search.getNewSemIm().getSemPm().getGraph(); setOriginalSemIm(search.getOriginalSemIm()); this.newSemIm = search.getNewSemIm(); fireGraphChange(graph); if (getSourceGraph() != null) { GraphUtils.arrangeBySourceGraph(graph, getSourceGraph()); } else if (knowledge.isDefaultToKnowledgeLayout()) { 
SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge); } else { GraphUtils.circleLayout(graph, 200, 200, 150); } setResultGraph(SearchGraphUtils.patternForDag(graph, knowledge)); }
public static void main(String[] args) { // Graph g = new EdgeListGraph(); // g.addNode(new ContinuousVariable("X1")); // g.addNode(new ContinuousVariable("X2")); // g.addNode(new DiscreteVariable("X3", 4)); // g.addNode(new DiscreteVariable("X4", 4)); // g.addNode(new ContinuousVariable("X5")); // // g.addDirectedEdge(g.getNode("X1"), g.getNode("X2")); // g.addDirectedEdge(g.getNode("X2"), g.getNode("X3")); // g.addDirectedEdge(g.getNode("X3"), g.getNode("X4")); // g.addDirectedEdge(g.getNode("X4"), g.getNode("X5")); // // GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(g, "Split(-1.5,-.5,.5,1.5)"); //// System.out.println(pm); // // GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm); //// System.out.println(im); // // int samps = 200; // DataSet ds = im.simulateDataAvoidInfinity(samps, false); // ds = MixedUtils.makeMixedData(ds, MixedUtils.getNodeDists(g)); // //System.out.println(ds); // System.out.println(ds.isMixed()); try { String path = ExampleMixedSearch.class.getResource("test_data").getPath(); Graph trueGraph = SearchGraphUtils.patternFromDag( GraphUtils.loadGraphTxt(new File(path, "DAG_0_graph.txt"))); DataSet ds = MixedUtils.loadDataSet(path, "DAG_0_data.txt"); IndTestMultinomialLogisticRegression indMix = new IndTestMultinomialLogisticRegression(ds, .05); IndTestMultinomialLogisticRegressionWald indWalLin = new IndTestMultinomialLogisticRegressionWald(ds, .05, true); IndTestMultinomialLogisticRegressionWald indWalLog = new IndTestMultinomialLogisticRegressionWald(ds, .05, false); PcStable s1 = new PcStable(indMix); PcStable s2 = new PcStable(indWalLin); PcStable s3 = new PcStable(indWalLog); long time = System.currentTimeMillis(); Graph g1 = SearchGraphUtils.patternFromDag(s1.search()); System.out.println("Mix Time " + ((System.currentTimeMillis() - time) / 1000.0)); time = System.currentTimeMillis(); Graph g2 = SearchGraphUtils.patternFromDag(s2.search()); System.out.println("Wald lin Time " + ((System.currentTimeMillis() - time) / 
1000.0)); time = System.currentTimeMillis(); Graph g3 = SearchGraphUtils.patternFromDag(s3.search()); System.out.println("Wald log Time " + ((System.currentTimeMillis() - time) / 1000.0)); // System.out.println(g); // System.out.println("IndMix: " + s1.search()); // System.out.println("IndWalLin: " + s2.search()); // System.out.println("IndWalLog: " + s3.search()); System.out.println(MixedUtils.EdgeStatHeader); System.out.println(MixedUtils.stringFrom2dArray(MixedUtils.allEdgeStats(trueGraph, g1))); System.out.println(MixedUtils.stringFrom2dArray(MixedUtils.allEdgeStats(trueGraph, g2))); System.out.println(MixedUtils.stringFrom2dArray(MixedUtils.allEdgeStats(trueGraph, g3))); } catch (Throwable t) { t.printStackTrace(); } }