@Test public void test1() { GeneralizedSemPm pm = makeTypicalPm(); print(pm); Node x1 = pm.getNode("X1"); Node x2 = pm.getNode("X2"); Node x3 = pm.getNode("X3"); Node x4 = pm.getNode("X4"); Node x5 = pm.getNode("X5"); SemGraph graph = pm.getGraph(); List<Node> variablesNodes = pm.getVariableNodes(); print(variablesNodes); List<Node> errorNodes = pm.getErrorNodes(); print(errorNodes); try { pm.setNodeExpression(x1, "cos(B1) + E_X1"); print(pm); String b1 = "B1"; String b2 = "B2"; String b3 = "B3"; Set<Node> nodes = pm.getReferencingNodes(b1); assertTrue(nodes.contains(x1)); assertTrue(!nodes.contains(x2) && !nodes.contains(x2)); Set<String> referencedParameters = pm.getReferencedParameters(x3); print("Parameters referenced by X3 are: " + referencedParameters); assertTrue(referencedParameters.contains(b1) && referencedParameters.contains(b2)); assertTrue(!(referencedParameters.contains(b1) && referencedParameters.contains(b3))); Node e_x3 = pm.getNode("E_X3"); // for (Node node : pm.getNodes()) { Set<Node> referencingNodes = pm.getReferencingNodes(node); print("Nodes referencing " + node + " are: " + referencingNodes); } for (Node node : pm.getVariableNodes()) { Set<Node> referencingNodes = pm.getReferencedNodes(node); print("Nodes referenced by " + node + " are: " + referencingNodes); } Set<Node> referencingX3 = pm.getReferencingNodes(x3); assertTrue(referencingX3.contains(x4)); assertTrue(!referencingX3.contains(x5)); Set<Node> referencedByX3 = pm.getReferencedNodes(x3); assertTrue( referencedByX3.contains(x1) && referencedByX3.contains(x2) && referencedByX3.contains(e_x3) && !referencedByX3.contains(x4)); pm.setNodeExpression(x5, "a * E^X2 + X4 + E_X5"); Node e_x5 = pm.getErrorNode(x5); graph.setShowErrorTerms(true); assertTrue(e_x5.equals(graph.getExogenous(x5))); pm.setNodeExpression(e_x5, "Beta(3, 5)"); print(pm); assertEquals("Split(-1.5,-.5,.5,1.5)", pm.getParameterExpressionString(b1)); pm.setParameterExpression(b1, "N(0, 2)"); assertEquals("N(0, 2)", pm.getParameterExpressionString(b1)); GeneralizedSemIm im = new GeneralizedSemIm(pm); print(im); DataSet dataSet = im.simulateDataAvoidInfinity(10, false); print(dataSet); } catch (ParseException e) { e.printStackTrace(); } }
@Test public void test4() { // For X3 Map<String, String[]> templates = new HashMap<>(); templates.put( "NEW(b) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("NEW(X1) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", new String[] {}); templates.put("$", new String[] {}); templates.put("TSUM($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($) + X2", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($) + TSUM($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("tan(TSUM(NEW(a)*$))", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(0, 1)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(m, s)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(NEW(m), s)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(NEW(m), NEW(s)) + m1 + s6", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TSUM($) + a", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TSUM($) + TSUM($) + TSUM($) + 1", new String[] {"X1", "X2", "X3", "X4", "X5"}); for (String template : templates.keySet()) { GeneralizedSemPm semPm = makeTypicalPm(); print(semPm.getGraph().toString()); Set<Node> shouldWork = new HashSet<>(); for (String name : templates.get(template)) { shouldWork.add(semPm.getNode(name)); } Set<Node> works = new HashSet<>(); for (int i = 0; i < semPm.getNodes().size(); i++) { print("-----------"); print(semPm.getNodes().get(i).toString()); print("Trying template: " + template); String _template = template; Node node = semPm.getNodes().get(i); try { _template = TemplateExpander.getInstance().expandTemplate(_template, semPm, node); } catch (Exception e) { print("Couldn't expand template: " + template); continue; } try { semPm.setNodeExpression(node, _template); print("Set formula " + _template + " for " + node); if (semPm.getVariableNodes().contains(node)) { works.add(node); } } catch (Exception e) { print("Couldn't set formula " + _template + " for " + node); } } for (String parameter : semPm.getParameters()) { print("-----------"); print(parameter); print("Trying template: " + template); String _template = template; try { _template = TemplateExpander.getInstance().expandTemplate(_template, semPm, null); } catch (Exception e) { print("Couldn't expand template: " + template); continue; } try { semPm.setParameterExpression(parameter, _template); print("Set formula " + _template + " for " + parameter); } catch (Exception e) { print("Couldn't set formula " + _template + " for " + parameter); } } assertEquals(shouldWork, works); } }