// Causes a package cycle. public void testManualDiscretize2() { Graph graph = new Dag(GraphUtils.randomGraph(5, 0, 5, 3, 3, 3, false)); SemPm pm = new SemPm(graph); SemIm im = new SemIm(pm); DataSet data = im.simulateData(100, false); List<Node> nodes = data.getVariables(); Discretizer discretizer = new Discretizer(data); // discretizer.setVariablesCopied(true); discretizer.equalCounts(nodes.get(0), 3); discretizer.equalIntervals(nodes.get(1), 2); discretizer.equalCounts(nodes.get(2), 5); discretizer.equalIntervals(nodes.get(3), 8); discretizer.equalCounts(nodes.get(4), 4); DataSet discretized = discretizer.discretize(); System.out.println(discretized); assertEquals(2, maxInColumn(discretized, 0)); assertEquals(1, maxInColumn(discretized, 1)); assertEquals(4, maxInColumn(discretized, 2)); assertEquals(7, maxInColumn(discretized, 3)); assertEquals(3, maxInColumn(discretized, 4)); }
@Test public void test2() { RandomUtil.getInstance().setSeed(2999983L); int sampleSize = 1000; List<Node> variableNodes = new ArrayList<>(); ContinuousVariable x1 = new ContinuousVariable("X1"); ContinuousVariable x2 = new ContinuousVariable("X2"); ContinuousVariable x3 = new ContinuousVariable("X3"); ContinuousVariable x4 = new ContinuousVariable("X4"); ContinuousVariable x5 = new ContinuousVariable("X5"); variableNodes.add(x1); variableNodes.add(x2); variableNodes.add(x3); variableNodes.add(x4); variableNodes.add(x5); Graph _graph = new EdgeListGraph(variableNodes); SemGraph graph = new SemGraph(_graph); graph.addDirectedEdge(x1, x3); graph.addDirectedEdge(x2, x3); graph.addDirectedEdge(x3, x4); graph.addDirectedEdge(x2, x4); graph.addDirectedEdge(x4, x5); graph.addDirectedEdge(x2, x5); SemPm semPm = new SemPm(graph); SemIm semIm = new SemIm(semPm); DataSet dataSet = semIm.simulateData(sampleSize, false); print(semPm); GeneralizedSemPm _semPm = new GeneralizedSemPm(semPm); GeneralizedSemIm _semIm = new GeneralizedSemIm(_semPm, semIm); DataSet _dataSet = _semIm.simulateDataMinimizeSurface(sampleSize, false); print(_semPm); // System.out.println(_dataSet); for (int j = 0; j < dataSet.getNumColumns(); j++) { double[] col = dataSet.getDoubleData().getColumn(j).toArray(); double[] _col = _dataSet.getDoubleData().getColumn(j).toArray(); double mean = StatUtils.mean(col); double _mean = StatUtils.mean(_col); double variance = StatUtils.variance(col); double _variance = StatUtils.variance(_col); assertEquals(mean, _mean, 0.3); assertEquals(1.0, variance / _variance, .2); } }
/**
 * Timing harness comparing two RFCI runs on the same random 100-node DAG.
 *
 * <ol>
 *   <li>RFCI with a d-separation oracle ({@link IndTestDSep}) over the true graph.</li>
 *   <li>RFCI with a Fisher Z test (alpha 0.001) on 1000 rows simulated from a SEM over that
 *       graph.</li>
 * </ol>
 *
 * <p>For each run it prints the number of independence checks made by that run's own
 * {@link Fas}, the elapsed wall-clock milliseconds, and the checks per millisecond.
 */
public void rtestDSeparation4() {
    Graph graph = new Dag(GraphUtils.randomGraph(100, 20, 100, 5, 5, 5, false));
    int depth = -1;

    // Run 1: oracle independence facts from d-separation in the true graph.
    IndependenceTest dsepTest = new IndTestDSep(graph);
    Rfci dsepRfci = new Rfci(dsepTest);
    Fas dsepFas = new Fas(dsepTest);

    long start = System.currentTimeMillis();
    dsepRfci.setDepth(depth);
    dsepRfci.setVerbose(true);
    dsepRfci.search(dsepFas, dsepFas.getNodes());
    long stop = System.currentTimeMillis();

    System.out.println("DSEP RFCI");
    System.out.println("# dsep checks = " + dsepFas.getNumIndependenceTests());
    System.out.println("Elapsed " + (stop - start));
    System.out.println("Per " + dsepFas.getNumIndependenceTests() / (double) (stop - start));

    // Run 2: statistical independence facts from simulated data.
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(1000, false);

    IndependenceTest fisherTest = new IndTestFisherZ(data, 0.001);
    Rfci fisherRfci = new Rfci(fisherTest);
    Fas fisherFas = new Fas(fisherTest);

    start = System.currentTimeMillis();
    fisherRfci.setDepth(depth);
    fisherRfci.search(fisherFas, fisherFas.getNodes());
    stop = System.currentTimeMillis();

    // Report this run's own FAS counts (not the d-separation run's).
    System.out.println("FISHER Z RFCI");
    System.out.println("# indep checks = " + fisherFas.getNumIndependenceTests());
    System.out.println("Elapsed " + (stop - start));
    System.out.println("Per " + fisherFas.getNumIndependenceTests() / (double) (stop - start));
}
/**
 * Discretizes only the first of five simulated continuous columns (equal counts, 3 bins) with
 * variable copying enabled. It checks that column 0 becomes a {@link DiscreteVariable} while
 * columns 1-4 remain {@link ContinuousVariable}s.
 */
public void testManualDiscretize3() {
    Graph graph = new Dag(GraphUtils.randomGraph(5, 0, 5, 3, 3, 3, false));
    SemPm pm = new SemPm(graph);
    SemIm im = new SemIm(pm);
    DataSet data = im.simulateData(100, false);
    List<Node> nodes = data.getVariables();

    Discretizer discretizer = new Discretizer(data);
    discretizer.setVariablesCopied(true);
    discretizer.equalCounts(nodes.get(0), 3);

    DataSet discretized = discretizer.discretize();
    System.out.println(discretized);

    // Only the column given a discretization spec becomes discrete.
    assertTrue(discretized.getVariable(0) instanceof DiscreteVariable);
    for (int col = 1; col < 5; col++) {
        assertTrue("column " + col + " should remain continuous",
                discretized.getVariable(col) instanceof ContinuousVariable);
    }
}
public void test1() { for (int r = 0; r < 1; r++) { Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 6, 0, 0, 0); Graph mimStructure = structure(mim); SemImInitializationParams params = new SemImInitializationParams(); params.setCoefRange(.5, 1.5); SemPm pm = new SemPm(mim); SemIm im = new SemIm(pm, params); DataSet data = im.simulateData(300, false); String algorithm = "FOFC"; Graph searchGraph; List<List<Node>> partition; if (algorithm.equals("FOFC")) { FindOneFactorClusters fofc = new FindOneFactorClusters(data, TestType.TETRAD_WISHART, 0.001); searchGraph = fofc.search(); partition = fofc.getClusters(); } else if (algorithm.equals("BPC")) { TestType testType = TestType.TETRAD_WISHART; TestType purifyType = TestType.TETRAD_BASED2; BuildPureClusters bpc = new BuildPureClusters(data, 0.001, testType, purifyType); searchGraph = bpc.search(); partition = MimUtils.convertToClusters2(searchGraph); } else { throw new IllegalStateException(); } List<String> latentVarList = reidentifyVariables(mim, data, partition, 2); System.out.println(partition); System.out.println(latentVarList); System.out.println("True\n" + mimStructure); Graph mimbuildStructure; for (int mimbuildMethod : new int[] {3, 4}) { if (mimbuildMethod == 1) { System.out.println("Mimbuild 1\n"); Clusters measurements = ClusterUtils.mimClusters(searchGraph); IndTestMimBuild test = new IndTestMimBuild(data, 0.001, measurements); MimBuild mimbuild = new MimBuild(test, new Knowledge2()); Graph full = mimbuild.search(); full = changeLatentNames(full, measurements, latentVarList); mimbuildStructure = structure(full); System.out.println( "SHD = " + SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure)); System.out.println("Estimated\n" + mimbuildStructure); System.out.println(); } // else if (mimbuildMethod == 2) { // System.out.println("Mimbuild 2\n"); // Mimbuild2 mimbuild = new Mimbuild2(); // mimbuild.setAlpha(0.001); // mimbuildStructure = mimbuild.search(partition, 
latentVarList, data); // TetradMatrix latentcov = mimbuild.getLatentsCov(); // List<String> latentnames = mimbuild.getLatentNames(); // System.out.println("\nCovariance over the latents"); // System.out.println(MatrixUtils.toStringSquare(latentcov.toArray(), // latentnames)); // System.out.println("Estimated\n" + mimbuildStructure); // System.out.println("SHD = " + // SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure)); // System.out.println(); // } else if (mimbuildMethod == 3) { System.out.println("Mimbuild 3\n"); Mimbuild2 mimbuild = new Mimbuild2(); mimbuild.setAlpha(0.001); mimbuild.setMinClusterSize(3); mimbuildStructure = mimbuild.search(partition, latentVarList, new CovarianceMatrix(data)); ICovarianceMatrix latentcov = mimbuild.getLatentsCov(); System.out.println("\nCovariance over the latents"); System.out.println(latentcov); System.out.println("Estimated\n" + mimbuildStructure); System.out.println( "SHD = " + SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure)); System.out.println(); } else if (mimbuildMethod == 4) { System.out.println("Mimbuild Trek\n"); MimbuildTrek mimbuild = new MimbuildTrek(); mimbuild.setAlpha(0.1); mimbuild.setMinClusterSize(3); mimbuildStructure = mimbuild.search(partition, latentVarList, new CovarianceMatrix(data)); ICovarianceMatrix latentcov = mimbuild.getLatentsCov(); System.out.println("\nCovariance over the latents"); System.out.println(latentcov); System.out.println("Estimated\n" + mimbuildStructure); System.out.println( "SHD = " + SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure)); System.out.println(); } else { throw new IllegalStateException(); } } } }
public void rtest4() { System.out.println("SHD\tP"); // System.out.println("MB1\tMB2\tMB3\tMB4\tMB5\tMB6"); Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 8, 0, 0, 0); Graph mimStructure = structure(mim); SemPm pm = new SemPm(mim); SemImInitializationParams params = new SemImInitializationParams(); params.setCoefRange(0.5, 1.5); NumberFormat nf = new DecimalFormat("0.0000"); int totalError = 0; int errorCount = 0; int maxScore = 0; int maxNumMeasures = 0; double maxP = 0.0; for (int r = 0; r < 1; r++) { SemIm im = new SemIm(pm, params); DataSet data = im.simulateData(1000, false); mim = GraphUtils.replaceNodes(mim, data.getVariables()); List<List<Node>> trueClusters = MimUtils.convertToClusters2(mim); CovarianceMatrix _cov = new CovarianceMatrix(data); ICovarianceMatrix cov = DataUtils.reorderColumns(_cov); String algorithm = "FOFC"; Graph searchGraph; List<List<Node>> partition; if (algorithm.equals("FOFC")) { FindOneFactorClusters fofc = new FindOneFactorClusters(cov, TestType.TETRAD_WISHART, 0.001f); searchGraph = fofc.search(); searchGraph = GraphUtils.replaceNodes(searchGraph, data.getVariables()); partition = MimUtils.convertToClusters2(searchGraph); } else if (algorithm.equals("BPC")) { TestType testType = TestType.TETRAD_WISHART; TestType purifyType = TestType.TETRAD_BASED2; BuildPureClusters bpc = new BuildPureClusters(data, 0.001, testType, purifyType); searchGraph = bpc.search(); partition = MimUtils.convertToClusters2(searchGraph); } else { throw new IllegalStateException(); } mimStructure = GraphUtils.replaceNodes(mimStructure, data.getVariables()); List<String> latentVarList = reidentifyVariables(mim, data, partition, 2); Graph mimbuildStructure; Mimbuild2 mimbuild = new Mimbuild2(); mimbuild.setAlpha(0.001); mimbuild.setMinClusterSize(3); try { mimbuildStructure = mimbuild.search(partition, latentVarList, cov); } catch (Exception e) { e.printStackTrace(); continue; } mimbuildStructure = GraphUtils.replaceNodes(mimbuildStructure, 
data.getVariables()); mimbuildStructure = condense(mimStructure, mimbuildStructure); // Graph mimSubgraph = new EdgeListGraph(mimStructure); // // for (Node node : mimSubgraph.getNodes()) { // if (!mimStructure.getNodes().contains(node)) { // mimSubgraph.removeNode(node); // } // } int shd = SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure); boolean impureCluster = containsImpureCluster(partition, trueClusters); double pValue = mimbuild.getpValue(); boolean pBelow05 = pValue < 0.05; boolean numClustersGreaterThan5 = partition.size() != 5; boolean error = false; // boolean condition = impureCluster || numClustersGreaterThan5 || pBelow05; // boolean condition = numClustersGreaterThan5 || pBelow05; boolean condition = numClustered(partition) == 40; if (!condition && (shd > 5)) { error = true; } if (!condition) { totalError += shd; errorCount++; } // if (numClustered(partition) > maxNumMeasures) { // maxNumMeasures = numClustered(partition); // maxP = pValue; // maxScore = shd; // System.out.println("maxNumMeasures = " + maxNumMeasures); // System.out.println("maxScore = " + maxScore); // System.out.println("maxP = " + maxP); // System.out.println("clusters = " + clusterSizes(partition, trueClusters)); // } // else if (pValue > maxP) { maxScore = shd; maxP = mimbuild.getpValue(); maxNumMeasures = numClustered(partition); System.out.println("maxNumMeasures = " + maxNumMeasures); System.out.println("maxScore = " + maxScore); System.out.println("maxP = " + maxP); System.out.println("clusters = " + clusterSizes(partition, trueClusters)); } System.out.print( shd + "\t" + nf.format(pValue) + " " // + (error ? 1 : 0) + " " // + (pBelow05 ? "P < 0.05 " : "") // + (impureCluster ? "Impure cluster " : "") // + (numClustersGreaterThan5 ? 
"# Clusters != 5 " : "") // + clusterSizes(partition, trueClusters) + numClustered(partition)); System.out.println(); } System.out.println("\nAvg SHD for not-flagged cases = " + (totalError / (double) errorCount)); System.out.println("maxNumMeasures = " + maxNumMeasures); System.out.println("maxScore = " + maxScore); System.out.println("maxP = " + maxP); }
public void rtest3() { Node x = new GraphNode("X"); Node y = new GraphNode("Y"); Node z = new GraphNode("Z"); Node w = new GraphNode("W"); List<Node> nodes = new ArrayList<Node>(); nodes.add(x); nodes.add(y); nodes.add(z); nodes.add(w); Graph g = new EdgeListGraph(nodes); g.addDirectedEdge(x, y); g.addDirectedEdge(x, z); g.addDirectedEdge(y, w); g.addDirectedEdge(z, w); Graph maxGraph = null; double maxPValue = -1.0; ICovarianceMatrix maxLatentCov = null; Graph mim = DataGraphUtils.randomMim(g, 8, 0, 0, 0, true); // Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 8, 0, 0, 0); Graph mimStructure = structure(mim); SemPm pm = new SemPm(mim); System.out.println("\n\nTrue graph:"); System.out.println(mimStructure); SemImInitializationParams params = new SemImInitializationParams(); params.setCoefRange(0.5, 1.5); SemIm im = new SemIm(pm, params); int N = 1000; DataSet data = im.simulateData(N, false); CovarianceMatrix cov = new CovarianceMatrix(data); for (int i = 0; i < 1; i++) { ICovarianceMatrix _cov = DataUtils.reorderColumns(cov); List<List<Node>> partition; FindOneFactorClusters fofc = new FindOneFactorClusters(_cov, TestType.TETRAD_WISHART, .001); fofc.search(); partition = fofc.getClusters(); System.out.println(partition); List<String> latentVarList = reidentifyVariables(mim, data, partition, 2); Mimbuild2 mimbuild = new Mimbuild2(); mimbuild.setAlpha(0.001); // mimbuild.setMinimumSize(5); // To test knowledge. 
// Knowledge knowledge = new Knowledge2(); // knowledge.setEdgeForbidden("L.Y", "L.W", true); // knowledge.setEdgeRequired("L.Y", "L.Z", true); // mimbuild.setKnowledge(knowledge); Graph mimbuildStructure = mimbuild.search(partition, latentVarList, _cov); double pValue = mimbuild.getpValue(); System.out.println(mimbuildStructure); System.out.println("P = " + pValue); System.out.println("Latent Cov = " + mimbuild.getLatentsCov()); if (pValue > maxPValue) { maxPValue = pValue; maxGraph = new EdgeListGraph(mimbuildStructure); maxLatentCov = mimbuild.getLatentsCov(); } } System.out.println("\n\nTrue graph:"); System.out.println(mimStructure); System.out.println("\nBest graph:"); System.out.println(maxGraph); System.out.println("P = " + maxPValue); System.out.println("Latent Cov = " + maxLatentCov); System.out.println(); }
@Test public void testHistogram() { RandomUtil.getInstance().setSeed(4829384L); List<Node> nodes = new ArrayList<Node>(); for (int i = 0; i < 5; i++) { nodes.add(new ContinuousVariable("X" + (i + 1))); } Dag trueGraph = new Dag(GraphUtils.randomGraph(nodes, 0, 5, 30, 15, 15, false)); int sampleSize = 1000; // Continuous SemPm semPm = new SemPm(trueGraph); SemIm semIm = new SemIm(semPm); DataSet data = semIm.simulateData(sampleSize, false); Histogram histogram = new Histogram(data); histogram.setTarget("X1"); histogram.setNumBins(20); assertEquals(3.76, histogram.getMax(), 0.01); assertEquals(-3.83, histogram.getMin(), 0.01); assertEquals(1000, histogram.getN()); histogram.setTarget("X1"); histogram.setNumBins(10); histogram.addConditioningVariable("X3", 0, 1); histogram.addConditioningVariable("X4", 0, 1); histogram.removeConditioningVariable("X3"); assertEquals(3.76, histogram.getMax(), 0.01); assertEquals(-3.83, histogram.getMin(), 0.01); assertEquals(188, histogram.getN()); double[] arr = histogram.getContinuousData("X2"); histogram.addConditioningVariable("X2", StatUtils.min(arr), StatUtils.mean(arr)); // Discrete BayesPm bayesPm = new BayesPm(trueGraph); BayesIm bayesIm = new MlBayesIm(bayesPm, MlBayesIm.RANDOM); DataSet data2 = bayesIm.simulateData(sampleSize, false); // For some reason these are giving different // values when all of the unit tests are run are // once. TODO They produce stable values when // this particular test is run repeatedly. Histogram histogram2 = new Histogram(data2); histogram2.setTarget("X1"); int[] frequencies1 = histogram2.getFrequencies(); // assertEquals(928, frequencies1[0]); // assertEquals(72, frequencies1[1]); histogram2.setTarget("X1"); histogram2.addConditioningVariable("X2", 0); histogram2.addConditioningVariable("X3", 1); int[] frequencies = histogram2.getFrequencies(); // assertEquals(377, frequencies[0]); // assertEquals(28, frequencies[1]); }