/**
 * Regression check for GeneralizedSemEstimator on a default-parameterized generalized SEM.
 *
 * <p>Builds 10 continuous variables X1..X10 and a random forward-edge graph over them. It then
 * simulates 1000 rows recursively from a default GeneralizedSemIm and re-estimates the model.
 * Finally it compares the estimator's aSquaredStar statistic with a value pinned under the fixed
 * RandomUtil seed.
 */
@Test
public void test7() {
    RandomUtil.getInstance().setSeed(29999483L);

    List<Node> variables = new ArrayList<>();
    int numVars = 10;
    int index = 1;
    while (index <= numVars) {
        variables.add(new ContinuousVariable("X" + index));
        index++;
    }

    Graph dag = GraphUtils.randomGraphRandomForwardEdges(
            variables, 0, numVars, 30, 15, 15, false, true);
    GeneralizedSemPm semPm = new GeneralizedSemPm(dag);
    GeneralizedSemIm trueIm = new GeneralizedSemIm(semPm);
    print(trueIm);

    DataSet simulated = trueIm.simulateDataRecursive(1000, false);

    GeneralizedSemEstimator fitter = new GeneralizedSemEstimator();
    GeneralizedSemIm fittedIm = fitter.estimate(semPm, simulated);
    print(fittedIm);
    print(fitter.getReport());

    double observed = fitter.getaSquaredStar();
    assertEquals(0.67, observed, 0.01);
}
/**
 * Regression check for GeneralizedSemEstimator when the true model comes from a linear SemIm.
 *
 * <p>Builds 5 continuous variables X1..X5 and a random forward-edge graph. It instantiates a
 * SemIm with coefficients drawn from [0.5, 1.5] and variances from [1, 3], then wraps it as a
 * GeneralizedSemIm. The test simulates 1000 rows, re-estimates, and compares aSquaredStar with a
 * value pinned under the fixed RandomUtil seed.
 */
@Test
public void test6() {
    RandomUtil.getInstance().setSeed(29999483L);

    int numVars = 5;
    List<Node> variables = new ArrayList<>();
    for (int k = 1; k <= numVars; k++) {
        variables.add(new ContinuousVariable("X" + k));
    }

    Graph dag = GraphUtils.randomGraphRandomForwardEdges(
            variables, 0, numVars, 30, 15, 15, false, true);

    SemPm linearPm = new SemPm(dag);
    SemImInitializationParams initParams = new SemImInitializationParams();
    initParams.setCoefRange(0.5, 1.5);
    initParams.setVarRange(1, 3);
    SemIm linearIm = new SemIm(linearPm, initParams);

    GeneralizedSemPm semPm = new GeneralizedSemPm(linearPm);
    GeneralizedSemIm trueIm = new GeneralizedSemIm(semPm, linearIm);
    DataSet simulated = trueIm.simulateData(1000, false);
    print(trueIm);

    GeneralizedSemEstimator fitter = new GeneralizedSemEstimator();
    GeneralizedSemIm fittedIm = fitter.estimate(semPm, simulated);
    print(fittedIm);
    print(fitter.getReport());

    double observed = fitter.getaSquaredStar();
    assertEquals(0.59, observed, 0.01);
}
/**
 * Regression check for GeneralizedSemEstimator on a nonlinear model with non-Gaussian errors.
 *
 * <p>Model over X1 -> X2 -> X3 -> X4 plus X1 -> X4:
 * X2 = a1*X1 + E, X3 = a2*X2 + E, X4 = a3*X1 + a4*X3^2 + E. The error for X1 is Gamma(c1, c2)
 * and the others are ChiSquare(c3..c5). The true values are c1=5, c2=2 and c3..c5=10. Estimation
 * initializes every c parameter from U(1, 5).
 *
 * <p>A ParseException from any model expression is propagated so the test fails. Previously it
 * was caught and only printed, which let the test pass without checking anything.
 *
 * @throws ParseException if a node or parameter expression cannot be parsed
 */
@Test
public void test15() throws ParseException {
    RandomUtil.getInstance().setSeed(29999483L);

    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");

    Graph g = new EdgeListGraphSingleConnections();
    g.addNode(x1);
    g.addNode(x2);
    g.addNode(x3);
    g.addNode(x4);
    g.addDirectedEdge(x1, x2);
    g.addDirectedEdge(x2, x3);
    g.addDirectedEdge(x3, x4);
    g.addDirectedEdge(x1, x4);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);

    pm.setNodeExpression(x1, "E_X1");
    pm.setNodeExpression(x2, "a1 * X1 + E_X2");
    pm.setNodeExpression(x3, "a2 * X2 + E_X3");
    pm.setNodeExpression(x4, "a3 * X1 + a4 * X3 ^ 2 + E_X4");

    pm.setNodeExpression(pm.getErrorNode(x1), "Gamma(c1, c2)");
    pm.setNodeExpression(pm.getErrorNode(x2), "ChiSquare(c3)");
    pm.setNodeExpression(pm.getErrorNode(x3), "ChiSquare(c4)");
    pm.setNodeExpression(pm.getErrorNode(x4), "ChiSquare(c5)");

    // True values used for simulation.
    pm.setParameterExpression("c1", "5");
    pm.setParameterExpression("c2", "2");
    pm.setParameterExpression("c3", "10");
    pm.setParameterExpression("c4", "10");
    pm.setParameterExpression("c5", "10");

    // Starting-point distribution for estimation (same order as before: c1..c5).
    for (String param : new String[] {"c1", "c2", "c3", "c4", "c5"}) {
        pm.setParameterEstimationInitializationExpression(param, "U(1, 5)");
    }

    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    print("True model: ");
    print(im);

    DataSet data = im.simulateDataRecursive(1000, false);

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estIm = estimator.estimate(pm, data);
    print("\n\n\nEstimated model: ");
    print(estIm);
    print(estimator.getReport());

    double aSquaredStar = estimator.getaSquaredStar();
    assertEquals(.79, aSquaredStar, 0.01);
}
/**
 * Regression check for GeneralizedSemEstimator on a tan-nonlinear model with normal errors.
 *
 * <p>Model over X1 -> X2 -> X3 -> X4 plus X1 -> X4:
 * X2 = a1*tan(X1) + E, X3 = a2*tan(X2) + E, X4 = a3*tan(X1) + a4*tan(X3)^2 + E. Each error is
 * N(0, c_i). The true values are a1..a4=1 and c1..c4=4.
 *
 * <p>A ParseException from any model expression is propagated so the test fails. Previously it
 * was caught and only printed, which let the test pass without checking anything.
 *
 * @throws ParseException if a node or parameter expression cannot be parsed
 */
@Test
public void test14() throws ParseException {
    RandomUtil.getInstance().setSeed(29999483L);

    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");

    Graph g = new EdgeListGraphSingleConnections();
    g.addNode(x1);
    g.addNode(x2);
    g.addNode(x3);
    g.addNode(x4);
    g.addDirectedEdge(x1, x2);
    g.addDirectedEdge(x2, x3);
    g.addDirectedEdge(x3, x4);
    g.addDirectedEdge(x1, x4);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);

    pm.setNodeExpression(x1, "E_X1");
    pm.setNodeExpression(x2, "a1 * tan(X1) + E_X2");
    pm.setNodeExpression(x3, "a2 * tan(X2) + E_X3");
    pm.setNodeExpression(x4, "a3 * tan(X1) + a4 * tan(X3) ^ 2 + E_X4");

    pm.setNodeExpression(pm.getErrorNode(x1), "N(0, c1)");
    pm.setNodeExpression(pm.getErrorNode(x2), "N(0, c2)");
    pm.setNodeExpression(pm.getErrorNode(x3), "N(0, c3)");
    pm.setNodeExpression(pm.getErrorNode(x4), "N(0, c4)");

    // True values used for simulation, set in the original order (a1..a4, then c1..c4).
    for (String coef : new String[] {"a1", "a2", "a3", "a4"}) {
        pm.setParameterExpression(coef, "1");
    }
    for (String errParam : new String[] {"c1", "c2", "c3", "c4"}) {
        pm.setParameterExpression(errParam, "4");
    }

    GeneralizedSemIm im = new GeneralizedSemIm(pm);
    print("True model: ");
    print(im);

    DataSet data = im.simulateDataRecursive(1000, false);

    // NOTE(review): imInit is never passed to the estimator below, so its c-values of 8 have no
    // effect on estimation. It is kept because constructing a GeneralizedSemIm may advance the
    // seeded RandomUtil state that the pinned expected value depends on. Confirm before removing,
    // or wire it into the estimator if that was the intent.
    GeneralizedSemIm imInit = new GeneralizedSemIm(pm);
    for (String errParam : new String[] {"c1", "c2", "c3", "c4"}) {
        imInit.setParameterValue(errParam, 8);
    }

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estIm = estimator.estimate(pm, data);
    print("\n\n\nEstimated model: ");
    print(estIm);
    print(estimator.getReport());

    double aSquaredStar = estimator.getaSquaredStar();
    assertEquals(71.25, aSquaredStar, 0.01);
}