/**
 * Builds a random ten-variable graph, wraps it in a generalized SEM with default expressions,
 * simulates 1000 rows recursively, re-estimates the model from that sample, and checks the
 * estimator's aSquaredStar statistic against the value pinned for this seed.
 */
@Test
public void test7() {
    RandomUtil.getInstance().setSeed(29999483L);

    final int variableCount = 10;
    List<Node> variables = new ArrayList<>();
    int index = 1;
    while (index <= variableCount) {
        variables.add(new ContinuousVariable("X" + index));
        index++;
    }

    Graph dag = GraphUtils.randomGraphRandomForwardEdges(
            variables, 0, variableCount, 30, 15, 15, false, true);

    GeneralizedSemPm modelPm = new GeneralizedSemPm(dag);
    GeneralizedSemIm trueModel = new GeneralizedSemIm(modelPm);
    print(trueModel);

    DataSet sample = trueModel.simulateDataRecursive(1000, false);

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm fittedModel = estimator.estimate(modelPm, sample);
    print(fittedModel);
    print(estimator.getReport());

    assertEquals(0.67, estimator.getaSquaredStar(), 0.01);
}
@Test public void test2() { RandomUtil.getInstance().setSeed(2999983L); int sampleSize = 1000; List<Node> variableNodes = new ArrayList<>(); ContinuousVariable x1 = new ContinuousVariable("X1"); ContinuousVariable x2 = new ContinuousVariable("X2"); ContinuousVariable x3 = new ContinuousVariable("X3"); ContinuousVariable x4 = new ContinuousVariable("X4"); ContinuousVariable x5 = new ContinuousVariable("X5"); variableNodes.add(x1); variableNodes.add(x2); variableNodes.add(x3); variableNodes.add(x4); variableNodes.add(x5); Graph _graph = new EdgeListGraph(variableNodes); SemGraph graph = new SemGraph(_graph); graph.addDirectedEdge(x1, x3); graph.addDirectedEdge(x2, x3); graph.addDirectedEdge(x3, x4); graph.addDirectedEdge(x2, x4); graph.addDirectedEdge(x4, x5); graph.addDirectedEdge(x2, x5); SemPm semPm = new SemPm(graph); SemIm semIm = new SemIm(semPm); DataSet dataSet = semIm.simulateData(sampleSize, false); print(semPm); GeneralizedSemPm _semPm = new GeneralizedSemPm(semPm); GeneralizedSemIm _semIm = new GeneralizedSemIm(_semPm, semIm); DataSet _dataSet = _semIm.simulateDataMinimizeSurface(sampleSize, false); print(_semPm); // System.out.println(_dataSet); for (int j = 0; j < dataSet.getNumColumns(); j++) { double[] col = dataSet.getDoubleData().getColumn(j).toArray(); double[] _col = _dataSet.getDoubleData().getColumn(j).toArray(); double mean = StatUtils.mean(col); double _mean = StatUtils.mean(_col); double variance = StatUtils.variance(col); double _variance = StatUtils.variance(_col); assertEquals(mean, _mean, 0.3); assertEquals(1.0, variance / _variance, .2); } }
@Test public void test8() { RandomUtil.getInstance().setSeed(29999483L); Node x = new GraphNode("X"); Node y = new GraphNode("Y"); List<Node> nodes = new ArrayList<>(); nodes.add(x); nodes.add(y); Graph graph = new EdgeListGraphSingleConnections(nodes); graph.addDirectedEdge(x, y); SemPm spm = new SemPm(graph); SemIm sim = new SemIm(spm); sim.setEdgeCoef(x, y, 20); sim.setErrVar(x, 1); sim.setErrVar(y, 1); GeneralizedSemPm pm = new GeneralizedSemPm(spm); GeneralizedSemIm im = new GeneralizedSemIm(pm, sim); print(im); try { pm.setParameterEstimationInitializationExpression("b1", "U(10, 30)"); pm.setParameterEstimationInitializationExpression("T1", "U(.1, 3)"); pm.setParameterEstimationInitializationExpression("T2", "U(.1, 3)"); } catch (ParseException e) { e.printStackTrace(); } DataSet data = im.simulateDataRecursive(1000, false); GeneralizedSemEstimator estimator = new GeneralizedSemEstimator(); GeneralizedSemIm estIm = estimator.estimate(pm, data); print(estIm); // System.out.println(estimator.getReport()); double aSquaredStar = estimator.getaSquaredStar(); assertEquals(0.69, aSquaredStar, 0.01); }
/**
 * Builds a random five-variable linear SEM with edge coefficients drawn from [0.5, 1.5] and
 * error variances drawn from [1, 3]. The model is converted to a generalized SEM, which is used
 * to simulate 1000 rows and then re-estimated. The test checks aSquaredStar against the value
 * pinned for this seed.
 */
@Test
public void test6() {
    RandomUtil.getInstance().setSeed(29999483L);

    final int variableCount = 5;
    List<Node> variables = new ArrayList<>();
    for (int n = 1; n <= variableCount; n++) {
        variables.add(new ContinuousVariable("X" + n));
    }

    Graph dag = GraphUtils.randomGraphRandomForwardEdges(
            variables, 0, variableCount, 30, 15, 15, false, true);

    SemPm linearPm = new SemPm(dag);

    SemImInitializationParams initParams = new SemImInitializationParams();
    initParams.setCoefRange(0.5, 1.5);
    initParams.setVarRange(1, 3);

    SemIm linearIm = new SemIm(linearPm, initParams);

    GeneralizedSemPm generalizedPm = new GeneralizedSemPm(linearPm);
    GeneralizedSemIm trueModel = new GeneralizedSemIm(generalizedPm, linearIm);

    DataSet sample = trueModel.simulateData(1000, false);
    print(trueModel);

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm fittedModel = estimator.estimate(generalizedPm, sample);
    print(fittedModel);
    print(estimator.getReport());

    assertEquals(0.59, estimator.getaSquaredStar(), 0.01);
}
/**
 * Checks that a linear SemIm and the GeneralizedSemIm built from it produce the same record
 * (to within 1e-10) when both are driven by the same vector of standard-normal error draws.
 */
@Test
public void test5() {
    RandomUtil.getInstance().setSeed(29999483L);

    final int variableCount = 5;
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= variableCount; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }

    Graph dag = new Dag(GraphUtils.randomGraph(variables, 0, variableCount, 30, 15, 15, false));
    SemPm linearPm = new SemPm(dag);
    SemIm linearIm = new SemIm(linearPm);

    // The returned data set is unused. NOTE(review): presumably this call is kept because it
    // advances the shared random state that the rest of the test runs under; confirm before removing.
    linearIm.simulateDataReducedForm(1000, false);

    GeneralizedSemPm generalizedPm = new GeneralizedSemPm(linearPm);
    GeneralizedSemIm generalizedIm = new GeneralizedSemIm(generalizedPm, linearIm);

    TetradVector shocks = new TetradVector(variableCount);
    for (int k = 0; k < shocks.size(); k++) {
        shocks.set(k, RandomUtil.getInstance().nextNormal(0, 1));
    }

    TetradVector linearRecord = linearIm.simulateOneRecord(shocks);
    TetradVector generalizedRecord = generalizedIm.simulateOneRecord(shocks);

    print("XXX1" + shocks);
    print("XXX2" + linearRecord);
    print("XXX3" + generalizedRecord);

    for (int k = 0; k < linearRecord.size(); k++) {
        assertEquals(linearRecord.get(k), generalizedRecord.get(k), 1e-10);
    }
}
/**
 * Four-node nonlinear model (X1 -> X2 -> X3 -> X4, X1 -> X4; X4 depends on X3 squared) with
 * Gamma/ChiSquare error distributions whose parameters c1..c5 are fixed. The model is simulated
 * recursively and re-estimated from U(1, 5) starting values. The test checks aSquaredStar against
 * the value pinned for this seed.
 */
@Test
public void test15() {
    RandomUtil.getInstance().setSeed(29999483L);

    try {
        Node x1 = new GraphNode("X1");
        Node x2 = new GraphNode("X2");
        Node x3 = new GraphNode("X3");
        Node x4 = new GraphNode("X4");

        Graph g = new EdgeListGraphSingleConnections();
        g.addNode(x1);
        g.addNode(x2);
        g.addNode(x3);
        g.addNode(x4);

        g.addDirectedEdge(x1, x2);
        g.addDirectedEdge(x2, x3);
        g.addDirectedEdge(x3, x4);
        g.addDirectedEdge(x1, x4);

        GeneralizedSemPm pm = new GeneralizedSemPm(g);

        pm.setNodeExpression(x1, "E_X1");
        pm.setNodeExpression(x2, "a1 * X1 + E_X2");
        pm.setNodeExpression(x3, "a2 * X2 + E_X3");
        pm.setNodeExpression(x4, "a3 * X1 + a4 * X3 ^ 2 + E_X4");

        pm.setNodeExpression(pm.getErrorNode(x1), "Gamma(c1, c2)");
        pm.setNodeExpression(pm.getErrorNode(x2), "ChiSquare(c3)");
        pm.setNodeExpression(pm.getErrorNode(x3), "ChiSquare(c4)");
        pm.setNodeExpression(pm.getErrorNode(x4), "ChiSquare(c5)");

        pm.setParameterExpression("c1", "5");
        pm.setParameterExpression("c2", "2");
        pm.setParameterExpression("c3", "10");
        pm.setParameterExpression("c4", "10");
        pm.setParameterExpression("c5", "10");

        pm.setParameterEstimationInitializationExpression("c1", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c2", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c3", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c4", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c5", "U(1, 5)");

        GeneralizedSemIm im = new GeneralizedSemIm(pm);

        print("True model: ");
        print(im);

        DataSet data = im.simulateDataRecursive(1000, false);

        GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
        GeneralizedSemIm estIm = estimator.estimate(pm, data);

        print("\n\n\nEstimated model: ");
        print(estIm);
        print(estimator.getReport());

        double aSquaredStar = estimator.getaSquaredStar();
        assertEquals(.79, aSquaredStar, 0.01);
    } catch (ParseException e) {
        // A malformed expression must fail the test; swallowing it would skip the assertion
        // above and report a pass.
        throw new AssertionError("Failed to parse a model expression", e);
    }
}
/**
 * Four-node model built from tan() terms (X1 -> X2 -> X3 -> X4, X1 -> X4) with N(0, c) errors.
 * All coefficients are fixed at 1 and all error parameters at 4. The model is simulated
 * recursively and re-estimated, and aSquaredStar is checked against the value pinned for this
 * seed.
 */
@Test
public void test14() {
    RandomUtil.getInstance().setSeed(29999483L);

    try {
        Node x1 = new GraphNode("X1");
        Node x2 = new GraphNode("X2");
        Node x3 = new GraphNode("X3");
        Node x4 = new GraphNode("X4");

        Graph g = new EdgeListGraphSingleConnections();
        g.addNode(x1);
        g.addNode(x2);
        g.addNode(x3);
        g.addNode(x4);

        g.addDirectedEdge(x1, x2);
        g.addDirectedEdge(x2, x3);
        g.addDirectedEdge(x3, x4);
        g.addDirectedEdge(x1, x4);

        GeneralizedSemPm pm = new GeneralizedSemPm(g);

        pm.setNodeExpression(x1, "E_X1");
        pm.setNodeExpression(x2, "a1 * tan(X1) + E_X2");
        pm.setNodeExpression(x3, "a2 * tan(X2) + E_X3");
        pm.setNodeExpression(x4, "a3 * tan(X1) + a4 * tan(X3) ^ 2 + E_X4");

        pm.setNodeExpression(pm.getErrorNode(x1), "N(0, c1)");
        pm.setNodeExpression(pm.getErrorNode(x2), "N(0, c2)");
        pm.setNodeExpression(pm.getErrorNode(x3), "N(0, c3)");
        pm.setNodeExpression(pm.getErrorNode(x4), "N(0, c4)");

        pm.setParameterExpression("a1", "1");
        pm.setParameterExpression("a2", "1");
        pm.setParameterExpression("a3", "1");
        pm.setParameterExpression("a4", "1");
        pm.setParameterExpression("c1", "4");
        pm.setParameterExpression("c2", "4");
        pm.setParameterExpression("c3", "4");
        pm.setParameterExpression("c4", "4");

        GeneralizedSemIm im = new GeneralizedSemIm(pm);

        print("True model: ");
        print(im);

        DataSet data = im.simulateDataRecursive(1000, false);

        // NOTE(review): imInit is never passed to the estimator below, so its parameter values
        // have no direct effect on the estimate. It is kept because constructing it may
        // consume random draws that the pinned expected value depends on. Confirm before
        // removing or wiring it into estimation.
        GeneralizedSemIm imInit = new GeneralizedSemIm(pm);
        imInit.setParameterValue("c1", 8);
        imInit.setParameterValue("c2", 8);
        imInit.setParameterValue("c3", 8);
        imInit.setParameterValue("c4", 8);

        GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
        GeneralizedSemIm estIm = estimator.estimate(pm, data);

        print("\n\n\nEstimated model: ");
        print(estIm);
        print(estimator.getReport());

        double aSquaredStar = estimator.getaSquaredStar();
        assertEquals(71.25, aSquaredStar, 0.01);
    } catch (ParseException e) {
        // A malformed expression must fail the test; swallowing it would skip the assertion
        // above and report a pass.
        throw new AssertionError("Failed to parse a model expression", e);
    }
}
@Test public void test1() { GeneralizedSemPm pm = makeTypicalPm(); print(pm); Node x1 = pm.getNode("X1"); Node x2 = pm.getNode("X2"); Node x3 = pm.getNode("X3"); Node x4 = pm.getNode("X4"); Node x5 = pm.getNode("X5"); SemGraph graph = pm.getGraph(); List<Node> variablesNodes = pm.getVariableNodes(); print(variablesNodes); List<Node> errorNodes = pm.getErrorNodes(); print(errorNodes); try { pm.setNodeExpression(x1, "cos(B1) + E_X1"); print(pm); String b1 = "B1"; String b2 = "B2"; String b3 = "B3"; Set<Node> nodes = pm.getReferencingNodes(b1); assertTrue(nodes.contains(x1)); assertTrue(!nodes.contains(x2) && !nodes.contains(x2)); Set<String> referencedParameters = pm.getReferencedParameters(x3); print("Parameters referenced by X3 are: " + referencedParameters); assertTrue(referencedParameters.contains(b1) && referencedParameters.contains(b2)); assertTrue(!(referencedParameters.contains(b1) && referencedParameters.contains(b3))); Node e_x3 = pm.getNode("E_X3"); // for (Node node : pm.getNodes()) { Set<Node> referencingNodes = pm.getReferencingNodes(node); print("Nodes referencing " + node + " are: " + referencingNodes); } for (Node node : pm.getVariableNodes()) { Set<Node> referencingNodes = pm.getReferencedNodes(node); print("Nodes referenced by " + node + " are: " + referencingNodes); } Set<Node> referencingX3 = pm.getReferencingNodes(x3); assertTrue(referencingX3.contains(x4)); assertTrue(!referencingX3.contains(x5)); Set<Node> referencedByX3 = pm.getReferencedNodes(x3); assertTrue( referencedByX3.contains(x1) && referencedByX3.contains(x2) && referencedByX3.contains(e_x3) && !referencedByX3.contains(x4)); pm.setNodeExpression(x5, "a * E^X2 + X4 + E_X5"); Node e_x5 = pm.getErrorNode(x5); graph.setShowErrorTerms(true); assertTrue(e_x5.equals(graph.getExogenous(x5))); pm.setNodeExpression(e_x5, "Beta(3, 5)"); print(pm); assertEquals("Split(-1.5,-.5,.5,1.5)", pm.getParameterExpressionString(b1)); pm.setParameterExpression(b1, "N(0, 2)"); assertEquals("N(0, 2)", 
pm.getParameterExpressionString(b1)); GeneralizedSemIm im = new GeneralizedSemIm(pm); print(im); DataSet dataSet = im.simulateDataAvoidInfinity(10, false); print(dataSet); } catch (ParseException e) { e.printStackTrace(); } }
/**
 * Builds a five-variable cyclic model (X5 -> X1 closes a cycle) with nonlinear node expressions
 * and U(0, 1) errors. It simulates 1000 rows with simulateDataNSteps and checks the covariance
 * of the first two columns against the value pinned for this seed.
 */
@Test
public void test3() {
    RandomUtil.getInstance().setSeed(49293843L);

    List<Node> variableNodes = new ArrayList<>();
    ContinuousVariable x1 = new ContinuousVariable("X1");
    ContinuousVariable x2 = new ContinuousVariable("X2");
    ContinuousVariable x3 = new ContinuousVariable("X3");
    ContinuousVariable x4 = new ContinuousVariable("X4");
    ContinuousVariable x5 = new ContinuousVariable("X5");

    variableNodes.add(x1);
    variableNodes.add(x2);
    variableNodes.add(x3);
    variableNodes.add(x4);
    variableNodes.add(x5);

    Graph _graph = new EdgeListGraph(variableNodes);
    SemGraph graph = new SemGraph(_graph);
    graph.setShowErrorTerms(true);

    Node e1 = graph.getExogenous(x1);
    Node e2 = graph.getExogenous(x2);
    Node e3 = graph.getExogenous(x3);
    Node e4 = graph.getExogenous(x4);
    Node e5 = graph.getExogenous(x5);

    graph.addDirectedEdge(x1, x3);
    graph.addDirectedEdge(x1, x2);
    graph.addDirectedEdge(x2, x3);
    graph.addDirectedEdge(x3, x4);
    graph.addDirectedEdge(x2, x4);
    graph.addDirectedEdge(x4, x5);
    graph.addDirectedEdge(x2, x5);
    graph.addDirectedEdge(x5, x1);

    GeneralizedSemPm pm = new GeneralizedSemPm(graph);

    List<Node> variablesNodes = pm.getVariableNodes();
    print(variablesNodes);

    List<Node> errorNodes = pm.getErrorNodes();
    print(errorNodes);

    try {
        pm.setNodeExpression(x1, "cos(b1) + a1 * X5 + E_X1");
        pm.setNodeExpression(x2, "a2 * X1 + E_X2");
        pm.setNodeExpression(x3, "tan(a3*X2 + a4*X1) + E_X3");
        pm.setNodeExpression(x4, "0.1 * E^X2 + X3 + E_X4");
        pm.setNodeExpression(x5, "0.1 * E^X4 + a6* X2 + E_X5");
        pm.setNodeExpression(e1, "U(0, 1)");
        pm.setNodeExpression(e2, "U(0, 1)");
        pm.setNodeExpression(e3, "U(0, 1)");
        pm.setNodeExpression(e4, "U(0, 1)");
        pm.setNodeExpression(e5, "U(0, 1)");

        GeneralizedSemIm im = new GeneralizedSemIm(pm);
        print(im);

        DataSet dataSet = im.simulateDataNSteps(1000, false);

        double[] d1 = dataSet.getDoubleData().getColumn(0).toArray();
        double[] d2 = dataSet.getDoubleData().getColumn(1).toArray();

        double cov = StatUtils.covariance(d1, d2);

        assertEquals(-0.002, cov, 0.001);
    } catch (ParseException e) {
        // A malformed expression must fail the test; swallowing it would skip the covariance
        // assertion and report a pass.
        throw new AssertionError("Failed to parse a model expression", e);
    }
}