/**
 * Fills in default expressions for {@code parameter} wherever none is stored yet.
 *
 * <p>Every key of {@code startsWithParametersTemplates} that {@code parameter} starts with
 * supplies its template for both the parameter expression and the estimation-initialization
 * expression. If no key matches, the value of {@code getParametersTemplate()} is used for both
 * instead. Expressions that are already present are never overwritten.
 *
 * <p>NOTE(review): the estimation-initialization expression is taken from
 * {@code startsWithParametersTemplates}; the separate
 * {@code startsWithParametersEstimationInitializationTemplates} map is not consulted here.
 * Confirm that this is intended.
 *
 * @param parameter the name of the parameter to supply defaults for
 * @throws ParseException propagated from the expression setters
 */
private void setSuitableParameterDistribution(String parameter) throws ParseException {
  boolean matchedPrefix = false;

  for (String candidate : startsWithParametersTemplates.keySet()) {
    if (!parameter.startsWith(candidate)) {
      continue;
    }

    matchedPrefix = true;

    if (parameterExpressions.get(parameter) == null) {
      setParameterExpression(parameter, startsWithParametersTemplates.get(candidate));
    }

    if (parameterEstimationInitializationExpressions.get(parameter) == null) {
      setParameterEstimationInitializationExpression(
          parameter, startsWithParametersTemplates.get(candidate));
    }
  }

  if (matchedPrefix) {
    return;
  }

  // No prefix applied: fall back to the general parameters template.
  if (parameterExpressions.get(parameter) == null) {
    setParameterExpression(parameter, getParametersTemplate());
  }

  if (parameterEstimationInitializationExpressions.get(parameter) == null) {
    setParameterEstimationInitializationExpression(parameter, getParametersTemplate());
  }
}
public GeneralizedSemPm(SemPm semPm) { this(semPm.getGraph()); // Write down equations. try { List<Node> variableNodes = getVariableNodes(); for (int i = 0; i < variableNodes.size(); i++) { Node node = variableNodes.get(i); List<Node> parents = getVariableParents(node); StringBuilder buf = new StringBuilder(); for (int j = 0; j < parents.size(); j++) { if (!(variableNodes.contains(parents.get(j)))) { continue; } Node parent = parents.get(j); Parameter _parameter = semPm.getParameter(parent, node); String parameter = _parameter.getName(); Set<Node> nodes = new HashSet<>(); nodes.add(node); referencedParameters.put(parameter, nodes); buf.append(parameter); buf.append("*"); buf.append(parents.get(j).getName()); setParameterExpression(parameter, "Split(-1.5, -.5, .5, 1.5)"); setStartsWithParametersTemplate(parameter.substring(0, 1), "Split(-1.5, -.5, .5, 1.5)"); setStartsWithParametersEstimationInitializaationTemplate( parameter.substring(0, 1), "Split(-1.5, -.5, .5, 1.5)"); if (j < parents.size() - 1) { buf.append(" + "); } } if (buf.toString().trim().length() != 0) { buf.append(" + "); } buf.append(errorNodes.get(i)); setNodeExpression(node, buf.toString()); } for (Node node : variableNodes) { Parameter _parameter = semPm.getParameter(node, node); String parameter = _parameter.getName(); String distributionFormula = "N(0," + parameter + ")"; setNodeExpression(getErrorNode(node), distributionFormula); setParameterExpression(parameter, "U(0, 1)"); setStartsWithParametersTemplate(parameter.substring(0, 1), "U(0, 1)"); setStartsWithParametersEstimationInitializaationTemplate( parameter.substring(0, 1), "U(0, 1)"); } variableNames = new ArrayList<>(); for (Node _node : variableNodes) variableNames.add(_node.getName()); for (Node _node : errorNodes) variableNames.add(_node.getName()); } catch (ParseException e) { throw new IllegalStateException("Parse error in constructing initial model.", e); } }
/** Constructs a new SemPm from the given SemGraph. */ public GeneralizedSemPm(SemGraph graph) { if (graph == null) { throw new NullPointerException("Graph must not be null."); } // if (graph.existsDirectedCycle()) { // throw new IllegalArgumentExcneption("Cycles are not supported."); // } // Cannot afford to allow error terms on this graph to be shown or hidden from the outside; must // make a // hidden copy of it and make sure error terms are shown. this.graph = new SemGraph(graph); this.graph.setShowErrorTerms(true); for (Edge edge : this.graph.getEdges()) { if (Edges.isBidirectedEdge(edge)) { throw new IllegalArgumentException( "The generalized SEM PM cannot currently deal with bidirected " + "edges. Sorry."); } } this.nodes = Collections.unmodifiableList(this.graph.getNodes()); for (Node node : nodes) { namesToNodes.put(node.getName(), node); } this.variableNodes = new ArrayList<>(); this.measuredNodes = new ArrayList<>(); for (Node variable : this.nodes) { if (variable.getNodeType() == NodeType.MEASURED || variable.getNodeType() == NodeType.LATENT) { variableNodes.add(variable); } if (variable.getNodeType() == NodeType.MEASURED) { measuredNodes.add(variable); } } this.errorNodes = new ArrayList<>(); for (Node variable : this.variableNodes) { List<Node> parents = this.graph.getParents(variable); boolean added = false; for (Node _node : parents) { if (_node.getNodeType() == NodeType.ERROR) { errorNodes.add(_node); added = true; break; } } if (!added) { errorNodes.add(null); } } this.referencedParameters = new HashMap<>(); this.referencedNodes = new HashMap<>(); this.nodeExpressions = new HashMap<>(); this.nodeExpressionStrings = new HashMap<>(); this.parameterExpressions = new HashMap<>(); this.parameterExpressionStrings = new HashMap<>(); this.parameterEstimationInitializationExpressions = new HashMap<>(); this.parameterEstimationInitializationExpressionStrings = new HashMap<>(); this.startsWithParametersTemplates = new HashMap<>(); 
this.startsWithParametersEstimationInitializationTemplates = new HashMap<>(); this.variableNames = new ArrayList<>(); for (Node _node : variableNodes) variableNames.add(_node.getName()); for (Node _node : errorNodes) variableNames.add(_node.getName()); try { List<Node> variableNodes = getVariableNodes(); for (Node node : variableNodes) { if (!this.graph.isParameterizable(node)) continue; if (nodeExpressions.get(node) != null) { continue; } String variablestemplate = getVariablesTemplate(); String formula = TemplateExpander.getInstance().expandTemplate(variablestemplate, this, node); setNodeExpression(node, formula); Set<String> parameters = getReferencedParameters(node); String parametersTemplate = getParametersTemplate(); for (String parameter : parameters) { if (parameterExpressions.get(parameter) == null) { if (parametersTemplate != null) { setParameterExpression(parameter, parametersTemplate); } else if (this.graph.isTimeLagModel()) { String expressionString = "Split(-0.9, -.1, .1, 0.9)"; setParameterExpression(parameter, expressionString); setParametersTemplate(expressionString); } else { String expressionString = "Split(-1.5, -.5, .5, 1.5)"; setParameterExpression(parameter, expressionString); setParametersTemplate(expressionString); } } } for (String parameter : parameters) { if (parameterEstimationInitializationExpressions.get(parameter) == null) { if (parametersTemplate != null) { setParameterEstimationInitializationExpression(parameter, parametersTemplate); } else if (this.graph.isTimeLagModel()) { String expressionString = "Split(-0.9, -.1, .1, 0.9)"; setParameterEstimationInitializationExpression(parameter, expressionString); } else { String expressionString = "Split(-1.5, -.5, .5, 1.5)"; setParameterEstimationInitializationExpression(parameter, expressionString); } } setStartsWithParametersTemplate("s", "Split(-1.5, -.5, .5, 1.5)"); setStartsWithParametersEstimationInitializaationTemplate( "s", "Split(-1.5, -.5, .5, 1.5)"); } } for (Node node : 
errorNodes) { if (node == null) continue; String template = getErrorsTemplate(); String formula = TemplateExpander.getInstance().expandTemplate(template, this, node); setNodeExpression(node, formula); Set<String> parameters = getReferencedParameters(node); setStartsWithParametersTemplate("s", "U(1, 3)"); setStartsWithParametersEstimationInitializaationTemplate("s", "U(1, 3)"); for (String parameter : parameters) { setParameterExpression(parameter, "U(1, 3)"); } } } catch (ParseException e) { throw new IllegalStateException("Parse error in constructing initial model.", e); } }
/**
 * Builds a four-variable model (X1 -> X2 -> X3 -> X4, X1 -> X4) with Gamma and ChiSquare error
 * terms, simulates 1000 records from it, re-estimates the model from that data, and checks the
 * estimator's reported aSquaredStar against a pinned value for the fixed seed.
 */
@Test
public void test15() {
  RandomUtil.getInstance().setSeed(29999483L);

  try {
    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");

    Graph g = new EdgeListGraphSingleConnections();
    g.addNode(x1);
    g.addNode(x2);
    g.addNode(x3);
    g.addNode(x4);

    g.addDirectedEdge(x1, x2);
    g.addDirectedEdge(x2, x3);
    g.addDirectedEdge(x3, x4);
    g.addDirectedEdge(x1, x4);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);

    pm.setNodeExpression(x1, "E_X1");
    pm.setNodeExpression(x2, "a1 * X1 + E_X2");
    pm.setNodeExpression(x3, "a2 * X2 + E_X3");
    pm.setNodeExpression(x4, "a3 * X1 + a4 * X3 ^ 2 + E_X4");

    pm.setNodeExpression(pm.getErrorNode(x1), "Gamma(c1, c2)");
    pm.setNodeExpression(pm.getErrorNode(x2), "ChiSquare(c3)");
    pm.setNodeExpression(pm.getErrorNode(x3), "ChiSquare(c4)");
    pm.setNodeExpression(pm.getErrorNode(x4), "ChiSquare(c5)");

    pm.setParameterExpression("c1", "5");
    pm.setParameterExpression("c2", "2");
    pm.setParameterExpression("c3", "10");
    pm.setParameterExpression("c4", "10");
    pm.setParameterExpression("c5", "10");

    pm.setParameterEstimationInitializationExpression("c1", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c2", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c3", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c4", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c5", "U(1, 5)");

    GeneralizedSemIm im = new GeneralizedSemIm(pm);

    print("True model: ");
    print(im);

    DataSet data = im.simulateDataRecursive(1000, false);

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estIm = estimator.estimate(pm, data);

    print("\n\n\nEstimated model: ");
    print(estIm);
    print(estimator.getReport());

    double aSquaredStar = estimator.getaSquaredStar();

    assertEquals(.79, aSquaredStar, 0.01);
  } catch (ParseException e) {
    // A parse failure means the model was never built; fail instead of silently passing.
    throw new AssertionError("Unexpected parse error while setting up the model.", e);
  }
}
/**
 * Builds a four-variable model (X1 -> X2 -> X3 -> X4, X1 -> X4) whose structural equations
 * apply {@code tan} to the parents and whose error terms are normal, simulates 1000 records
 * from it, re-estimates the model from that data, and checks the estimator's reported
 * aSquaredStar against a pinned value for the fixed seed.
 */
@Test
public void test14() {
  RandomUtil.getInstance().setSeed(29999483L);

  try {
    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");

    Graph g = new EdgeListGraphSingleConnections();
    g.addNode(x1);
    g.addNode(x2);
    g.addNode(x3);
    g.addNode(x4);

    g.addDirectedEdge(x1, x2);
    g.addDirectedEdge(x2, x3);
    g.addDirectedEdge(x3, x4);
    g.addDirectedEdge(x1, x4);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);

    pm.setNodeExpression(x1, "E_X1");
    pm.setNodeExpression(x2, "a1 * tan(X1) + E_X2");
    pm.setNodeExpression(x3, "a2 * tan(X2) + E_X3");
    pm.setNodeExpression(x4, "a3 * tan(X1) + a4 * tan(X3) ^ 2 + E_X4");

    pm.setNodeExpression(pm.getErrorNode(x1), "N(0, c1)");
    pm.setNodeExpression(pm.getErrorNode(x2), "N(0, c2)");
    pm.setNodeExpression(pm.getErrorNode(x3), "N(0, c3)");
    pm.setNodeExpression(pm.getErrorNode(x4), "N(0, c4)");

    pm.setParameterExpression("a1", "1");
    pm.setParameterExpression("a2", "1");
    pm.setParameterExpression("a3", "1");
    pm.setParameterExpression("a4", "1");

    pm.setParameterExpression("c1", "4");
    pm.setParameterExpression("c2", "4");
    pm.setParameterExpression("c3", "4");
    pm.setParameterExpression("c4", "4");

    GeneralizedSemIm im = new GeneralizedSemIm(pm);

    print("True model: ");
    print(im);

    DataSet data = im.simulateDataRecursive(1000, false);

    // NOTE(review): imInit is never handed to the estimator. It is kept because constructing
    // it may advance the seeded random source that the pinned value below depends on;
    // confirm before removing.
    GeneralizedSemIm imInit = new GeneralizedSemIm(pm);
    imInit.setParameterValue("c1", 8);
    imInit.setParameterValue("c2", 8);
    imInit.setParameterValue("c3", 8);
    imInit.setParameterValue("c4", 8);

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estIm = estimator.estimate(pm, data);

    print("\n\n\nEstimated model: ");
    print(estIm);
    print(estimator.getReport());

    double aSquaredStar = estimator.getaSquaredStar();

    assertEquals(71.25, aSquaredStar, 0.01);
  } catch (ParseException e) {
    // A parse failure means the model was never built; fail instead of silently passing.
    throw new AssertionError("Unexpected parse error while setting up the model.", e);
  }
}
@Test public void test1() { GeneralizedSemPm pm = makeTypicalPm(); print(pm); Node x1 = pm.getNode("X1"); Node x2 = pm.getNode("X2"); Node x3 = pm.getNode("X3"); Node x4 = pm.getNode("X4"); Node x5 = pm.getNode("X5"); SemGraph graph = pm.getGraph(); List<Node> variablesNodes = pm.getVariableNodes(); print(variablesNodes); List<Node> errorNodes = pm.getErrorNodes(); print(errorNodes); try { pm.setNodeExpression(x1, "cos(B1) + E_X1"); print(pm); String b1 = "B1"; String b2 = "B2"; String b3 = "B3"; Set<Node> nodes = pm.getReferencingNodes(b1); assertTrue(nodes.contains(x1)); assertTrue(!nodes.contains(x2) && !nodes.contains(x2)); Set<String> referencedParameters = pm.getReferencedParameters(x3); print("Parameters referenced by X3 are: " + referencedParameters); assertTrue(referencedParameters.contains(b1) && referencedParameters.contains(b2)); assertTrue(!(referencedParameters.contains(b1) && referencedParameters.contains(b3))); Node e_x3 = pm.getNode("E_X3"); // for (Node node : pm.getNodes()) { Set<Node> referencingNodes = pm.getReferencingNodes(node); print("Nodes referencing " + node + " are: " + referencingNodes); } for (Node node : pm.getVariableNodes()) { Set<Node> referencingNodes = pm.getReferencedNodes(node); print("Nodes referenced by " + node + " are: " + referencingNodes); } Set<Node> referencingX3 = pm.getReferencingNodes(x3); assertTrue(referencingX3.contains(x4)); assertTrue(!referencingX3.contains(x5)); Set<Node> referencedByX3 = pm.getReferencedNodes(x3); assertTrue( referencedByX3.contains(x1) && referencedByX3.contains(x2) && referencedByX3.contains(e_x3) && !referencedByX3.contains(x4)); pm.setNodeExpression(x5, "a * E^X2 + X4 + E_X5"); Node e_x5 = pm.getErrorNode(x5); graph.setShowErrorTerms(true); assertTrue(e_x5.equals(graph.getExogenous(x5))); pm.setNodeExpression(e_x5, "Beta(3, 5)"); print(pm); assertEquals("Split(-1.5,-.5,.5,1.5)", pm.getParameterExpressionString(b1)); pm.setParameterExpression(b1, "N(0, 2)"); assertEquals("N(0, 2)", 
pm.getParameterExpressionString(b1)); GeneralizedSemIm im = new GeneralizedSemIm(pm); print(im); DataSet dataSet = im.simulateDataAvoidInfinity(10, false); print(dataSet); } catch (ParseException e) { e.printStackTrace(); } }
@Test public void test4() { // For X3 Map<String, String[]> templates = new HashMap<>(); templates.put( "NEW(b) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("NEW(X1) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", new String[] {}); templates.put("$", new String[] {}); templates.put("TSUM($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($) + X2", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($) + TSUM($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("tan(TSUM(NEW(a)*$))", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(0, 1)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(m, s)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(NEW(m), s)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(NEW(m), NEW(s)) + m1 + s6", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TSUM($) + a", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TSUM($) + TSUM($) + TSUM($) + 1", new String[] {"X1", "X2", "X3", "X4", "X5"}); for (String template : templates.keySet()) { GeneralizedSemPm semPm = makeTypicalPm(); print(semPm.getGraph().toString()); Set<Node> shouldWork = new HashSet<>(); for (String name : templates.get(template)) { shouldWork.add(semPm.getNode(name)); } Set<Node> works = new HashSet<>(); for (int i = 0; i < semPm.getNodes().size(); i++) { print("-----------"); print(semPm.getNodes().get(i).toString()); print("Trying template: " + template); String _template = template; Node node = semPm.getNodes().get(i); try { _template = TemplateExpander.getInstance().expandTemplate(_template, semPm, node); } catch (Exception e) { print("Couldn't expand template: " + template); continue; } try { semPm.setNodeExpression(node, _template); print("Set formula " + _template + " for " + node); if 
(semPm.getVariableNodes().contains(node)) { works.add(node); } } catch (Exception e) { print("Couldn't set formula " + _template + " for " + node); } } for (String parameter : semPm.getParameters()) { print("-----------"); print(parameter); print("Trying template: " + template); String _template = template; try { _template = TemplateExpander.getInstance().expandTemplate(_template, semPm, null); } catch (Exception e) { print("Couldn't expand template: " + template); continue; } try { semPm.setParameterExpression(parameter, _template); print("Set formula " + _template + " for " + parameter); } catch (Exception e) { print("Couldn't set formula " + _template + " for " + parameter); } } assertEquals(shouldWork, works); } }