@Test public void test8() { RandomUtil.getInstance().setSeed(29999483L); Node x = new GraphNode("X"); Node y = new GraphNode("Y"); List<Node> nodes = new ArrayList<>(); nodes.add(x); nodes.add(y); Graph graph = new EdgeListGraphSingleConnections(nodes); graph.addDirectedEdge(x, y); SemPm spm = new SemPm(graph); SemIm sim = new SemIm(spm); sim.setEdgeCoef(x, y, 20); sim.setErrVar(x, 1); sim.setErrVar(y, 1); GeneralizedSemPm pm = new GeneralizedSemPm(spm); GeneralizedSemIm im = new GeneralizedSemIm(pm, sim); print(im); try { pm.setParameterEstimationInitializationExpression("b1", "U(10, 30)"); pm.setParameterEstimationInitializationExpression("T1", "U(.1, 3)"); pm.setParameterEstimationInitializationExpression("T2", "U(.1, 3)"); } catch (ParseException e) { e.printStackTrace(); } DataSet data = im.simulateDataRecursive(1000, false); GeneralizedSemEstimator estimator = new GeneralizedSemEstimator(); GeneralizedSemIm estIm = estimator.estimate(pm, data); print(estIm); // System.out.println(estimator.getReport()); double aSquaredStar = estimator.getaSquaredStar(); assertEquals(0.69, aSquaredStar, 0.01); }
/**
 * Estimates a chain model with non-Gaussian (Gamma / chi-square) errors.
 *
 * <p>Graph: X1 -> X2 -> X3 -> X4 plus X1 -> X4, with a squared X3 term in X4's equation. The
 * error-distribution parameters c1..c5 are fixed for simulation and given U(1, 5) estimation
 * initialization expressions. The test simulates 1000 rows, re-estimates, and checks the
 * estimator's {@code aSquaredStar} value. The expected value depends on the seed fixed at the
 * top.
 */
@Test
public void test15() {
    RandomUtil.getInstance().setSeed(29999483L);

    try {
        Node x1 = new GraphNode("X1");
        Node x2 = new GraphNode("X2");
        Node x3 = new GraphNode("X3");
        Node x4 = new GraphNode("X4");

        Graph g = new EdgeListGraphSingleConnections();
        g.addNode(x1);
        g.addNode(x2);
        g.addNode(x3);
        g.addNode(x4);

        g.addDirectedEdge(x1, x2);
        g.addDirectedEdge(x2, x3);
        g.addDirectedEdge(x3, x4);
        g.addDirectedEdge(x1, x4);

        GeneralizedSemPm pm = new GeneralizedSemPm(g);

        pm.setNodeExpression(x1, "E_X1");
        pm.setNodeExpression(x2, "a1 * X1 + E_X2");
        pm.setNodeExpression(x3, "a2 * X2 + E_X3");
        pm.setNodeExpression(x4, "a3 * X1 + a4 * X3 ^ 2 + E_X4");

        pm.setNodeExpression(pm.getErrorNode(x1), "Gamma(c1, c2)");
        pm.setNodeExpression(pm.getErrorNode(x2), "ChiSquare(c3)");
        pm.setNodeExpression(pm.getErrorNode(x3), "ChiSquare(c4)");
        pm.setNodeExpression(pm.getErrorNode(x4), "ChiSquare(c5)");

        pm.setParameterExpression("c1", "5");
        pm.setParameterExpression("c2", "2");
        pm.setParameterExpression("c3", "10");
        pm.setParameterExpression("c4", "10");
        pm.setParameterExpression("c5", "10");

        pm.setParameterEstimationInitializationExpression("c1", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c2", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c3", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c4", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c5", "U(1, 5)");

        GeneralizedSemIm im = new GeneralizedSemIm(pm);

        print("True model: ");
        print(im);

        DataSet data = im.simulateDataRecursive(1000, false);

        GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
        GeneralizedSemIm estIm = estimator.estimate(pm, data);

        print("\n\n\nEstimated model: ");
        print(estIm);
        print(estimator.getReport());

        double aSquaredStar = estimator.getaSquaredStar();
        assertEquals(.79, aSquaredStar, 0.01);
    } catch (ParseException e) {
        // Previously only printed, which let the test pass without ever reaching its assertion.
        throw new AssertionError("Failed to parse a node or parameter expression", e);
    }
}
/**
 * Estimates a non-linear (tan-based) chain model with Gaussian errors.
 *
 * <p>Graph: X1 -> X2 -> X3 -> X4 plus X1 -> X4. Each child depends on tan() of its parents, and
 * X4 also has a squared tan(X3) term. Coefficients a1..a4 are fixed at 1 and error parameters
 * c1..c4 at 4. The test simulates 1000 rows, re-estimates, and checks the estimator's
 * {@code aSquaredStar} value. The expected value depends on the seed fixed at the top.
 */
@Test
public void test14() {
    RandomUtil.getInstance().setSeed(29999483L);

    try {
        Node x1 = new GraphNode("X1");
        Node x2 = new GraphNode("X2");
        Node x3 = new GraphNode("X3");
        Node x4 = new GraphNode("X4");

        Graph g = new EdgeListGraphSingleConnections();
        g.addNode(x1);
        g.addNode(x2);
        g.addNode(x3);
        g.addNode(x4);

        g.addDirectedEdge(x1, x2);
        g.addDirectedEdge(x2, x3);
        g.addDirectedEdge(x3, x4);
        g.addDirectedEdge(x1, x4);

        GeneralizedSemPm pm = new GeneralizedSemPm(g);

        pm.setNodeExpression(x1, "E_X1");
        pm.setNodeExpression(x2, "a1 * tan(X1) + E_X2");
        pm.setNodeExpression(x3, "a2 * tan(X2) + E_X3");
        pm.setNodeExpression(x4, "a3 * tan(X1) + a4 * tan(X3) ^ 2 + E_X4");

        pm.setNodeExpression(pm.getErrorNode(x1), "N(0, c1)");
        pm.setNodeExpression(pm.getErrorNode(x2), "N(0, c2)");
        pm.setNodeExpression(pm.getErrorNode(x3), "N(0, c3)");
        pm.setNodeExpression(pm.getErrorNode(x4), "N(0, c4)");

        pm.setParameterExpression("a1", "1");
        pm.setParameterExpression("a2", "1");
        pm.setParameterExpression("a3", "1");
        pm.setParameterExpression("a4", "1");
        pm.setParameterExpression("c1", "4");
        pm.setParameterExpression("c2", "4");
        pm.setParameterExpression("c3", "4");
        pm.setParameterExpression("c4", "4");

        GeneralizedSemIm im = new GeneralizedSemIm(pm);

        print("True model: ");
        print(im);

        DataSet data = im.simulateDataRecursive(1000, false);

        // NOTE(review): imInit is configured but never handed to the estimator below, which
        // estimates from pm alone. It is kept because constructing it may consume values from
        // the seeded RNG that the recorded expected value depends on. Confirm whether the
        // estimation was meant to start from it.
        GeneralizedSemIm imInit = new GeneralizedSemIm(pm);
        imInit.setParameterValue("c1", 8);
        imInit.setParameterValue("c2", 8);
        imInit.setParameterValue("c3", 8);
        imInit.setParameterValue("c4", 8);

        GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
        GeneralizedSemIm estIm = estimator.estimate(pm, data);

        print("\n\n\nEstimated model: ");
        print(estIm);
        print(estimator.getReport());

        double aSquaredStar = estimator.getaSquaredStar();
        assertEquals(71.25, aSquaredStar, 0.01);
    } catch (ParseException e) {
        // Previously only printed, which let the test pass without ever reaching its assertion.
        throw new AssertionError("Failed to parse a node or parameter expression", e);
    }
}
@Test public void test1() { GeneralizedSemPm pm = makeTypicalPm(); print(pm); Node x1 = pm.getNode("X1"); Node x2 = pm.getNode("X2"); Node x3 = pm.getNode("X3"); Node x4 = pm.getNode("X4"); Node x5 = pm.getNode("X5"); SemGraph graph = pm.getGraph(); List<Node> variablesNodes = pm.getVariableNodes(); print(variablesNodes); List<Node> errorNodes = pm.getErrorNodes(); print(errorNodes); try { pm.setNodeExpression(x1, "cos(B1) + E_X1"); print(pm); String b1 = "B1"; String b2 = "B2"; String b3 = "B3"; Set<Node> nodes = pm.getReferencingNodes(b1); assertTrue(nodes.contains(x1)); assertTrue(!nodes.contains(x2) && !nodes.contains(x2)); Set<String> referencedParameters = pm.getReferencedParameters(x3); print("Parameters referenced by X3 are: " + referencedParameters); assertTrue(referencedParameters.contains(b1) && referencedParameters.contains(b2)); assertTrue(!(referencedParameters.contains(b1) && referencedParameters.contains(b3))); Node e_x3 = pm.getNode("E_X3"); // for (Node node : pm.getNodes()) { Set<Node> referencingNodes = pm.getReferencingNodes(node); print("Nodes referencing " + node + " are: " + referencingNodes); } for (Node node : pm.getVariableNodes()) { Set<Node> referencingNodes = pm.getReferencedNodes(node); print("Nodes referenced by " + node + " are: " + referencingNodes); } Set<Node> referencingX3 = pm.getReferencingNodes(x3); assertTrue(referencingX3.contains(x4)); assertTrue(!referencingX3.contains(x5)); Set<Node> referencedByX3 = pm.getReferencedNodes(x3); assertTrue( referencedByX3.contains(x1) && referencedByX3.contains(x2) && referencedByX3.contains(e_x3) && !referencedByX3.contains(x4)); pm.setNodeExpression(x5, "a * E^X2 + X4 + E_X5"); Node e_x5 = pm.getErrorNode(x5); graph.setShowErrorTerms(true); assertTrue(e_x5.equals(graph.getExogenous(x5))); pm.setNodeExpression(e_x5, "Beta(3, 5)"); print(pm); assertEquals("Split(-1.5,-.5,.5,1.5)", pm.getParameterExpressionString(b1)); pm.setParameterExpression(b1, "N(0, 2)"); assertEquals("N(0, 2)", 
pm.getParameterExpressionString(b1)); GeneralizedSemIm im = new GeneralizedSemIm(pm); print(im); DataSet dataSet = im.simulateDataAvoidInfinity(10, false); print(dataSet); } catch (ParseException e) { e.printStackTrace(); } }
@Test public void test4() { // For X3 Map<String, String[]> templates = new HashMap<>(); templates.put( "NEW(b) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("NEW(X1) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", new String[] {}); templates.put("$", new String[] {}); templates.put("TSUM($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($) + X2", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($) + TSUM($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("tan(TSUM(NEW(a)*$))", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(0, 1)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(m, s)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(NEW(m), s)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(NEW(m), NEW(s)) + m1 + s6", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TSUM($) + a", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TSUM($) + TSUM($) + TSUM($) + 1", new String[] {"X1", "X2", "X3", "X4", "X5"}); for (String template : templates.keySet()) { GeneralizedSemPm semPm = makeTypicalPm(); print(semPm.getGraph().toString()); Set<Node> shouldWork = new HashSet<>(); for (String name : templates.get(template)) { shouldWork.add(semPm.getNode(name)); } Set<Node> works = new HashSet<>(); for (int i = 0; i < semPm.getNodes().size(); i++) { print("-----------"); print(semPm.getNodes().get(i).toString()); print("Trying template: " + template); String _template = template; Node node = semPm.getNodes().get(i); try { _template = TemplateExpander.getInstance().expandTemplate(_template, semPm, node); } catch (Exception e) { print("Couldn't expand template: " + template); continue; } try { semPm.setNodeExpression(node, _template); print("Set formula " + _template + " for " + node); if 
(semPm.getVariableNodes().contains(node)) { works.add(node); } } catch (Exception e) { print("Couldn't set formula " + _template + " for " + node); } } for (String parameter : semPm.getParameters()) { print("-----------"); print(parameter); print("Trying template: " + template); String _template = template; try { _template = TemplateExpander.getInstance().expandTemplate(_template, semPm, null); } catch (Exception e) { print("Couldn't expand template: " + template); continue; } try { semPm.setParameterExpression(parameter, _template); print("Set formula " + _template + " for " + parameter); } catch (Exception e) { print("Couldn't set formula " + _template + " for " + parameter); } } assertEquals(shouldWork, works); } }
/**
 * Simulates a cyclic five-variable model with uniform errors using N-step simulation.
 *
 * <p>Edges: X1 -> X2, X1 -> X3, X2 -> X3, X2 -> X4, X3 -> X4, X2 -> X5, X4 -> X5 and X5 -> X1.
 * The last edge closes the cycle X1 -> X2 -> X4 -> X5 -> X1. Node equations mix tan, cos and
 * exponential terms, and all errors are U(0, 1). The test simulates 1000 rows and checks the
 * covariance of the first two data columns against a value that depends on the seed fixed at
 * the top.
 */
@Test
public void test3() {
    RandomUtil.getInstance().setSeed(49293843L);

    List<Node> variableNodes = new ArrayList<>();
    ContinuousVariable x1 = new ContinuousVariable("X1");
    ContinuousVariable x2 = new ContinuousVariable("X2");
    ContinuousVariable x3 = new ContinuousVariable("X3");
    ContinuousVariable x4 = new ContinuousVariable("X4");
    ContinuousVariable x5 = new ContinuousVariable("X5");

    variableNodes.add(x1);
    variableNodes.add(x2);
    variableNodes.add(x3);
    variableNodes.add(x4);
    variableNodes.add(x5);

    Graph baseGraph = new EdgeListGraph(variableNodes);
    SemGraph graph = new SemGraph(baseGraph);
    graph.setShowErrorTerms(true);

    Node e1 = graph.getExogenous(x1);
    Node e2 = graph.getExogenous(x2);
    Node e3 = graph.getExogenous(x3);
    Node e4 = graph.getExogenous(x4);
    Node e5 = graph.getExogenous(x5);

    // Edge insertion order is kept as-is; the seeded expected value may depend on it.
    graph.addDirectedEdge(x1, x3);
    graph.addDirectedEdge(x1, x2);
    graph.addDirectedEdge(x2, x3);
    graph.addDirectedEdge(x3, x4);
    graph.addDirectedEdge(x2, x4);
    graph.addDirectedEdge(x4, x5);
    graph.addDirectedEdge(x2, x5);
    graph.addDirectedEdge(x5, x1);

    GeneralizedSemPm pm = new GeneralizedSemPm(graph);

    List<Node> variablesNodes = pm.getVariableNodes();
    print(variablesNodes);

    List<Node> errorNodes = pm.getErrorNodes();
    print(errorNodes);

    try {
        pm.setNodeExpression(x1, "cos(b1) + a1 * X5 + E_X1");
        pm.setNodeExpression(x2, "a2 * X1 + E_X2");
        pm.setNodeExpression(x3, "tan(a3*X2 + a4*X1) + E_X3");
        pm.setNodeExpression(x4, "0.1 * E^X2 + X3 + E_X4");
        pm.setNodeExpression(x5, "0.1 * E^X4 + a6* X2 + E_X5");

        pm.setNodeExpression(e1, "U(0, 1)");
        pm.setNodeExpression(e2, "U(0, 1)");
        pm.setNodeExpression(e3, "U(0, 1)");
        pm.setNodeExpression(e4, "U(0, 1)");
        pm.setNodeExpression(e5, "U(0, 1)");

        GeneralizedSemIm im = new GeneralizedSemIm(pm);
        print(im);

        DataSet dataSet = im.simulateDataNSteps(1000, false);

        // Columns 0 and 1 are presumably X1 and X2 (variable order) — confirm if this matters.
        double[] d1 = dataSet.getDoubleData().getColumn(0).toArray();
        double[] d2 = dataSet.getDoubleData().getColumn(1).toArray();

        double cov = StatUtils.covariance(d1, d2);
        assertEquals(-0.002, cov, 0.001);
    } catch (ParseException e) {
        // Previously only printed, which let the test pass without ever reaching its assertion.
        throw new AssertionError("Failed to parse a node expression", e);
    }
}