/**
 * Gives {@code parameter} default expressions for whichever of its two slots (value expression
 * and estimation-initialization expression) have not been set yet.
 *
 * <p>If the parameter name starts with a prefix registered in
 * {@code startsWithParametersTemplates}, that prefix's template is used for the value
 * expression. The estimation-initialization expression uses the prefix's entry in
 * {@code startsWithParametersEstimationInitializationTemplates}, falling back to the value
 * template when no initialization template is registered for that prefix. If no prefix
 * matches, the global parameters template is used for both slots.
 *
 * @param parameter the parameter name.
 * @throws ParseException if a template expression cannot be parsed.
 */
private void setSuitableParameterDistribution(String parameter) throws ParseException {
    boolean found = false;

    for (String prefix : startsWithParametersTemplates.keySet()) {
        if (!parameter.startsWith(prefix)) {
            continue;
        }

        String valueTemplate = startsWithParametersTemplates.get(prefix);

        if (parameterExpressions.get(parameter) == null) {
            setParameterExpression(parameter, valueTemplate);
        }

        if (parameterEstimationInitializationExpressions.get(parameter) == null) {
            // Prefer the dedicated estimation-initialization template for this prefix; fall
            // back to the value template so prefixes registered only for values still work.
            String initTemplate = startsWithParametersEstimationInitializationTemplates.get(prefix);
            setParameterEstimationInitializationExpression(
                    parameter, initTemplate != null ? initTemplate : valueTemplate);
        }

        found = true;
    }

    if (!found) {
        if (parameterExpressions.get(parameter) == null) {
            setParameterExpression(parameter, getParametersTemplate());
        }

        if (parameterEstimationInitializationExpressions.get(parameter) == null) {
            setParameterEstimationInitializationExpression(parameter, getParametersTemplate());
        }
    }
}
@Test public void test8() { RandomUtil.getInstance().setSeed(29999483L); Node x = new GraphNode("X"); Node y = new GraphNode("Y"); List<Node> nodes = new ArrayList<>(); nodes.add(x); nodes.add(y); Graph graph = new EdgeListGraphSingleConnections(nodes); graph.addDirectedEdge(x, y); SemPm spm = new SemPm(graph); SemIm sim = new SemIm(spm); sim.setEdgeCoef(x, y, 20); sim.setErrVar(x, 1); sim.setErrVar(y, 1); GeneralizedSemPm pm = new GeneralizedSemPm(spm); GeneralizedSemIm im = new GeneralizedSemIm(pm, sim); print(im); try { pm.setParameterEstimationInitializationExpression("b1", "U(10, 30)"); pm.setParameterEstimationInitializationExpression("T1", "U(.1, 3)"); pm.setParameterEstimationInitializationExpression("T2", "U(.1, 3)"); } catch (ParseException e) { e.printStackTrace(); } DataSet data = im.simulateDataRecursive(1000, false); GeneralizedSemEstimator estimator = new GeneralizedSemEstimator(); GeneralizedSemIm estIm = estimator.estimate(pm, data); print(estIm); // System.out.println(estimator.getReport()); double aSquaredStar = estimator.getaSquaredStar(); assertEquals(0.69, aSquaredStar, 0.01); }
/** Constructs a new SemPm from the given SemGraph. */ public GeneralizedSemPm(SemGraph graph) { if (graph == null) { throw new NullPointerException("Graph must not be null."); } // if (graph.existsDirectedCycle()) { // throw new IllegalArgumentExcneption("Cycles are not supported."); // } // Cannot afford to allow error terms on this graph to be shown or hidden from the outside; must // make a // hidden copy of it and make sure error terms are shown. this.graph = new SemGraph(graph); this.graph.setShowErrorTerms(true); for (Edge edge : this.graph.getEdges()) { if (Edges.isBidirectedEdge(edge)) { throw new IllegalArgumentException( "The generalized SEM PM cannot currently deal with bidirected " + "edges. Sorry."); } } this.nodes = Collections.unmodifiableList(this.graph.getNodes()); for (Node node : nodes) { namesToNodes.put(node.getName(), node); } this.variableNodes = new ArrayList<>(); this.measuredNodes = new ArrayList<>(); for (Node variable : this.nodes) { if (variable.getNodeType() == NodeType.MEASURED || variable.getNodeType() == NodeType.LATENT) { variableNodes.add(variable); } if (variable.getNodeType() == NodeType.MEASURED) { measuredNodes.add(variable); } } this.errorNodes = new ArrayList<>(); for (Node variable : this.variableNodes) { List<Node> parents = this.graph.getParents(variable); boolean added = false; for (Node _node : parents) { if (_node.getNodeType() == NodeType.ERROR) { errorNodes.add(_node); added = true; break; } } if (!added) { errorNodes.add(null); } } this.referencedParameters = new HashMap<>(); this.referencedNodes = new HashMap<>(); this.nodeExpressions = new HashMap<>(); this.nodeExpressionStrings = new HashMap<>(); this.parameterExpressions = new HashMap<>(); this.parameterExpressionStrings = new HashMap<>(); this.parameterEstimationInitializationExpressions = new HashMap<>(); this.parameterEstimationInitializationExpressionStrings = new HashMap<>(); this.startsWithParametersTemplates = new HashMap<>(); 
this.startsWithParametersEstimationInitializationTemplates = new HashMap<>(); this.variableNames = new ArrayList<>(); for (Node _node : variableNodes) variableNames.add(_node.getName()); for (Node _node : errorNodes) variableNames.add(_node.getName()); try { List<Node> variableNodes = getVariableNodes(); for (Node node : variableNodes) { if (!this.graph.isParameterizable(node)) continue; if (nodeExpressions.get(node) != null) { continue; } String variablestemplate = getVariablesTemplate(); String formula = TemplateExpander.getInstance().expandTemplate(variablestemplate, this, node); setNodeExpression(node, formula); Set<String> parameters = getReferencedParameters(node); String parametersTemplate = getParametersTemplate(); for (String parameter : parameters) { if (parameterExpressions.get(parameter) == null) { if (parametersTemplate != null) { setParameterExpression(parameter, parametersTemplate); } else if (this.graph.isTimeLagModel()) { String expressionString = "Split(-0.9, -.1, .1, 0.9)"; setParameterExpression(parameter, expressionString); setParametersTemplate(expressionString); } else { String expressionString = "Split(-1.5, -.5, .5, 1.5)"; setParameterExpression(parameter, expressionString); setParametersTemplate(expressionString); } } } for (String parameter : parameters) { if (parameterEstimationInitializationExpressions.get(parameter) == null) { if (parametersTemplate != null) { setParameterEstimationInitializationExpression(parameter, parametersTemplate); } else if (this.graph.isTimeLagModel()) { String expressionString = "Split(-0.9, -.1, .1, 0.9)"; setParameterEstimationInitializationExpression(parameter, expressionString); } else { String expressionString = "Split(-1.5, -.5, .5, 1.5)"; setParameterEstimationInitializationExpression(parameter, expressionString); } } setStartsWithParametersTemplate("s", "Split(-1.5, -.5, .5, 1.5)"); setStartsWithParametersEstimationInitializaationTemplate( "s", "Split(-1.5, -.5, .5, 1.5)"); } } for (Node node : 
errorNodes) { if (node == null) continue; String template = getErrorsTemplate(); String formula = TemplateExpander.getInstance().expandTemplate(template, this, node); setNodeExpression(node, formula); Set<String> parameters = getReferencedParameters(node); setStartsWithParametersTemplate("s", "U(1, 3)"); setStartsWithParametersEstimationInitializaationTemplate("s", "U(1, 3)"); for (String parameter : parameters) { setParameterExpression(parameter, "U(1, 3)"); } } } catch (ParseException e) { throw new IllegalStateException("Parse error in constructing initial model.", e); } }
/**
 * Simulates the chain X1 -> X2 -> X3 -> X4 with an extra edge X1 -> X4 and a quadratic term on
 * X3. Errors are Gamma / ChiSquare distributed. The test re-estimates the model with the
 * generalized SEM estimator and checks the estimator's {@code aSquaredStar} statistic.
 */
@Test
public void test15() {
    RandomUtil.getInstance().setSeed(29999483L);

    try {
        Node x1 = new GraphNode("X1");
        Node x2 = new GraphNode("X2");
        Node x3 = new GraphNode("X3");
        Node x4 = new GraphNode("X4");

        Graph g = new EdgeListGraphSingleConnections();
        g.addNode(x1);
        g.addNode(x2);
        g.addNode(x3);
        g.addNode(x4);

        g.addDirectedEdge(x1, x2);
        g.addDirectedEdge(x2, x3);
        g.addDirectedEdge(x3, x4);
        g.addDirectedEdge(x1, x4);

        GeneralizedSemPm pm = new GeneralizedSemPm(g);

        pm.setNodeExpression(x1, "E_X1");
        pm.setNodeExpression(x2, "a1 * X1 + E_X2");
        pm.setNodeExpression(x3, "a2 * X2 + E_X3");
        pm.setNodeExpression(x4, "a3 * X1 + a4 * X3 ^ 2 + E_X4");

        pm.setNodeExpression(pm.getErrorNode(x1), "Gamma(c1, c2)");
        pm.setNodeExpression(pm.getErrorNode(x2), "ChiSquare(c3)");
        pm.setNodeExpression(pm.getErrorNode(x3), "ChiSquare(c4)");
        pm.setNodeExpression(pm.getErrorNode(x4), "ChiSquare(c5)");

        pm.setParameterExpression("c1", "5");
        pm.setParameterExpression("c2", "2");
        pm.setParameterExpression("c3", "10");
        pm.setParameterExpression("c4", "10");
        pm.setParameterExpression("c5", "10");

        pm.setParameterEstimationInitializationExpression("c1", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c2", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c3", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c4", "U(1, 5)");
        pm.setParameterEstimationInitializationExpression("c5", "U(1, 5)");

        GeneralizedSemIm im = new GeneralizedSemIm(pm);

        print("True model: ");
        print(im);

        DataSet data = im.simulateDataRecursive(1000, false);

        GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
        GeneralizedSemIm estIm = estimator.estimate(pm, data);

        print("\n\n\nEstimated model: ");
        print(estIm);
        print(estimator.getReport());

        double aSquaredStar = estimator.getaSquaredStar();
        assertEquals(.79, aSquaredStar, 0.01);
    } catch (ParseException e) {
        // Previously only printed, which let the test pass without ever reaching the
        // assertion; a parse failure in the model setup must fail the test.
        throw new AssertionError("Could not parse model expression", e);
    }
}