Example #1
0
  /**
   * Supplies default expressions for a parameter wherever none has been assigned yet.
   *
   * <p>Every prefix registered in {@code startsWithParametersTemplates} that the parameter name
   * starts with offers its template, both for the parameter expression and for the estimation
   * initialization expression. When no registered prefix matches, the template returned by {@code
   * getParametersTemplate()} is used for both instead. An expression that is already present is
   * never replaced.
   *
   * <p>NOTE(review): the estimation initialization default is also taken from {@code
   * startsWithParametersTemplates}; if a separate estimation-initialization template map was meant
   * to be consulted here, this needs revisiting. Confirm the intent before changing it.
   *
   * @param parameter the name of the parameter to supply with defaults.
   * @throws ParseException propagated from the expression setters called here.
   */
  private void setSuitableParameterDistribution(String parameter) throws ParseException {
    boolean prefixMatched = false;

    for (String prefix : startsWithParametersTemplates.keySet()) {
      if (!parameter.startsWith(prefix)) {
        continue;
      }

      prefixMatched = true;

      if (parameterExpressions.get(parameter) == null) {
        setParameterExpression(parameter, startsWithParametersTemplates.get(prefix));
      }

      if (parameterEstimationInitializationExpressions.get(parameter) == null) {
        setParameterEstimationInitializationExpression(
            parameter, startsWithParametersTemplates.get(prefix));
      }
    }

    // A prefix-specific template was applicable, so the general template is not consulted.
    if (prefixMatched) {
      return;
    }

    if (parameterExpressions.get(parameter) == null) {
      setParameterExpression(parameter, getParametersTemplate());
    }

    if (parameterEstimationInitializationExpressions.get(parameter) == null) {
      setParameterEstimationInitializationExpression(parameter, getParametersTemplate());
    }
  }
Example #2
0
  /**
   * Constructs a generalized SEM PM that mirrors the linear structure of the given SemPm.
   *
   * <p>After delegating to the graph-based constructor, each variable node is given an expression
   * of the form {@code p1*Parent1 + p2*Parent2 + ... + error}, where each {@code p} is the name of
   * the parameter {@code semPm} returns for the (parent, node) pair. Each error node is given the
   * expression {@code N(0,p)}, where {@code p} is the name of the parameter {@code semPm} returns
   * for the (node, node) pair. Default parameter expressions and one-letter prefix templates are
   * registered along the way.
   *
   * @param semPm the SEM PM whose graph and parameter names are used.
   * @throws IllegalStateException if one of the generated expressions cannot be parsed.
   */
  public GeneralizedSemPm(SemPm semPm) {
    this(semPm.getGraph());

    final String coefficientTemplate = "Split(-1.5, -.5, .5, 1.5)";
    final String varianceTemplate = "U(0, 1)";

    // Write down equations.
    try {
      List<Node> variableNodes = getVariableNodes();

      for (int i = 0; i < variableNodes.size(); i++) {
        Node node = variableNodes.get(i);
        List<Node> parents = getVariableParents(node);

        StringBuilder buf = new StringBuilder();

        for (Node parent : parents) {
          if (!variableNodes.contains(parent)) {
            continue;
          }

          Parameter _parameter = semPm.getParameter(parent, node);
          String parameter = _parameter.getName();
          Set<Node> nodes = new HashSet<>();
          nodes.add(node);

          referencedParameters.put(parameter, nodes);

          // Emit the separator before each term that is actually appended. Deciding it by the
          // parent's index would leave a dangling or doubled " + " whenever a later parent is
          // skipped above.
          if (buf.length() > 0) {
            buf.append(" + ");
          }

          buf.append(parameter);
          buf.append("*");
          buf.append(parent.getName());

          setParameterExpression(parameter, coefficientTemplate);
          setStartsWithParametersTemplate(parameter.substring(0, 1), coefficientTemplate);
          setStartsWithParametersEstimationInitializaationTemplate(
              parameter.substring(0, 1), coefficientTemplate);
        }

        if (buf.length() > 0) {
          buf.append(" + ");
        }

        // errorNodes is indexed in parallel with the variable nodes.
        buf.append(errorNodes.get(i));
        setNodeExpression(node, buf.toString());
      }

      for (Node node : variableNodes) {
        Parameter _parameter = semPm.getParameter(node, node);
        String parameter = _parameter.getName();

        String distributionFormula = "N(0," + parameter + ")";
        setNodeExpression(getErrorNode(node), distributionFormula);
        setParameterExpression(parameter, varianceTemplate);
        setStartsWithParametersTemplate(parameter.substring(0, 1), varianceTemplate);
        setStartsWithParametersEstimationInitializaationTemplate(
            parameter.substring(0, 1), varianceTemplate);
      }

      variableNames = new ArrayList<>();

      for (Node _node : variableNodes) {
        variableNames.add(_node.getName());
      }

      for (Node _node : errorNodes) {
        // Skip null entries rather than dereferencing them.
        if (_node != null) {
          variableNames.add(_node.getName());
        }
      }

    } catch (ParseException e) {
      throw new IllegalStateException("Parse error in constructing initial model.", e);
    }
  }
Example #3
0
  /**
   * Constructs a new GeneralizedSemPm from the given SemGraph.
   *
   * <p>The graph is copied with its error terms shown, its nodes are sorted into variable nodes
   * (measured or latent), measured nodes and error nodes, and default expressions are assigned to
   * every parameterizable variable node, every error node and the parameters they reference.
   *
   * @param graph the graph to build the model from; must not be null.
   * @throws NullPointerException if {@code graph} is null.
   * @throws IllegalArgumentException if the graph contains a bidirected edge.
   * @throws IllegalStateException if one of the default expressions cannot be parsed.
   */
  public GeneralizedSemPm(SemGraph graph) {
    if (graph == null) {
      throw new NullPointerException("Graph must not be null.");
    }

    // A check rejecting graphs with directed cycles used to live here; it is disabled.

    // Cannot afford to allow error terms on this graph to be shown or hidden from the outside;
    // must make a hidden copy of it and make sure error terms are shown.
    this.graph = new SemGraph(graph);
    this.graph.setShowErrorTerms(true);

    for (Edge edge : this.graph.getEdges()) {
      if (Edges.isBidirectedEdge(edge)) {
        throw new IllegalArgumentException(
            "The generalized SEM PM cannot currently deal with bidirected " + "edges. Sorry.");
      }
    }

    this.nodes = Collections.unmodifiableList(this.graph.getNodes());

    for (Node node : nodes) {
      namesToNodes.put(node.getName(), node);
    }

    this.variableNodes = new ArrayList<>();
    this.measuredNodes = new ArrayList<>();

    for (Node variable : this.nodes) {
      if (variable.getNodeType() == NodeType.MEASURED
          || variable.getNodeType() == NodeType.LATENT) {
        variableNodes.add(variable);
      }

      if (variable.getNodeType() == NodeType.MEASURED) {
        measuredNodes.add(variable);
      }
    }

    // errorNodes is kept parallel to variableNodes: entry i is the first ERROR-typed parent of
    // variable i, or null when that variable has none.
    this.errorNodes = new ArrayList<>();

    for (Node variable : this.variableNodes) {
      List<Node> parents = this.graph.getParents(variable);
      boolean added = false;

      for (Node _node : parents) {
        if (_node.getNodeType() == NodeType.ERROR) {
          errorNodes.add(_node);
          added = true;
          break;
        }
      }

      if (!added) {
        errorNodes.add(null);
      }
    }

    this.referencedParameters = new HashMap<>();
    this.referencedNodes = new HashMap<>();
    this.nodeExpressions = new HashMap<>();
    this.nodeExpressionStrings = new HashMap<>();
    this.parameterExpressions = new HashMap<>();
    this.parameterExpressionStrings = new HashMap<>();
    this.parameterEstimationInitializationExpressions = new HashMap<>();
    this.parameterEstimationInitializationExpressionStrings = new HashMap<>();
    this.startsWithParametersTemplates = new HashMap<>();
    this.startsWithParametersEstimationInitializationTemplates = new HashMap<>();

    this.variableNames = new ArrayList<>();

    for (Node _node : variableNodes) {
      variableNames.add(_node.getName());
    }

    for (Node _node : errorNodes) {
      // errorNodes may hold null placeholders (see above); dereferencing one would throw.
      if (_node != null) {
        variableNames.add(_node.getName());
      }
    }

    try {
      List<Node> variableNodes = getVariableNodes();

      for (Node node : variableNodes) {
        if (!this.graph.isParameterizable(node)) {
          continue;
        }

        if (nodeExpressions.get(node) != null) {
          continue;
        }

        String variablestemplate = getVariablesTemplate();
        String formula =
            TemplateExpander.getInstance().expandTemplate(variablestemplate, this, node);
        setNodeExpression(node, formula);
        Set<String> parameters = getReferencedParameters(node);

        // Read once per node; the branches below that call setParametersTemplate do not update
        // this local.
        String parametersTemplate = getParametersTemplate();

        for (String parameter : parameters) {
          if (parameterExpressions.get(parameter) == null) {
            if (parametersTemplate != null) {
              setParameterExpression(parameter, parametersTemplate);
            } else if (this.graph.isTimeLagModel()) {
              String expressionString = "Split(-0.9, -.1, .1, 0.9)";
              setParameterExpression(parameter, expressionString);
              setParametersTemplate(expressionString);
            } else {
              String expressionString = "Split(-1.5, -.5, .5, 1.5)";
              setParameterExpression(parameter, expressionString);
              setParametersTemplate(expressionString);
            }
          }
        }

        for (String parameter : parameters) {
          if (parameterEstimationInitializationExpressions.get(parameter) == null) {
            if (parametersTemplate != null) {
              setParameterEstimationInitializationExpression(parameter, parametersTemplate);
            } else if (this.graph.isTimeLagModel()) {
              String expressionString = "Split(-0.9, -.1, .1, 0.9)";
              setParameterEstimationInitializationExpression(parameter, expressionString);
            } else {
              String expressionString = "Split(-1.5, -.5, .5, 1.5)";
              setParameterEstimationInitializationExpression(parameter, expressionString);
            }
          }

          setStartsWithParametersTemplate("s", "Split(-1.5, -.5, .5, 1.5)");
          setStartsWithParametersEstimationInitializaationTemplate(
              "s", "Split(-1.5, -.5, .5, 1.5)");
        }
      }

      for (Node node : errorNodes) {
        if (node == null) {
          continue;
        }

        String template = getErrorsTemplate();
        String formula = TemplateExpander.getInstance().expandTemplate(template, this, node);
        setNodeExpression(node, formula);
        Set<String> parameters = getReferencedParameters(node);

        // Replaces the "s" prefix templates registered in the variable-node loop above.
        setStartsWithParametersTemplate("s", "U(1, 3)");
        setStartsWithParametersEstimationInitializaationTemplate("s", "U(1, 3)");

        for (String parameter : parameters) {
          setParameterExpression(parameter, "U(1, 3)");
        }
      }
    } catch (ParseException e) {
      throw new IllegalStateException("Parse error in constructing initial model.", e);
    }
  }
  /**
   * Builds a four-variable model (X1-&gt;X2-&gt;X3-&gt;X4, X1-&gt;X4) with Gamma and ChiSquare
   * error expressions, simulates 1000 rows from it with a fixed seed, re-estimates the model and
   * checks the value reported by {@code getaSquaredStar()}.
   */
  @Test
  public void test15() {
    RandomUtil.getInstance().setSeed(29999483L);

    try {
      Node x1 = new GraphNode("X1");
      Node x2 = new GraphNode("X2");
      Node x3 = new GraphNode("X3");
      Node x4 = new GraphNode("X4");

      Graph g = new EdgeListGraphSingleConnections();
      g.addNode(x1);
      g.addNode(x2);
      g.addNode(x3);
      g.addNode(x4);

      g.addDirectedEdge(x1, x2);
      g.addDirectedEdge(x2, x3);
      g.addDirectedEdge(x3, x4);
      g.addDirectedEdge(x1, x4);

      GeneralizedSemPm pm = new GeneralizedSemPm(g);

      pm.setNodeExpression(x1, "E_X1");
      pm.setNodeExpression(x2, "a1 * X1 + E_X2");
      pm.setNodeExpression(x3, "a2 * X2 + E_X3");
      pm.setNodeExpression(x4, "a3 * X1 + a4 * X3 ^ 2 + E_X4");

      pm.setNodeExpression(pm.getErrorNode(x1), "Gamma(c1, c2)");
      pm.setNodeExpression(pm.getErrorNode(x2), "ChiSquare(c3)");
      pm.setNodeExpression(pm.getErrorNode(x3), "ChiSquare(c4)");
      pm.setNodeExpression(pm.getErrorNode(x4), "ChiSquare(c5)");

      pm.setParameterExpression("c1", "5");
      pm.setParameterExpression("c2", "2");
      pm.setParameterExpression("c3", "10");
      pm.setParameterExpression("c4", "10");
      pm.setParameterExpression("c5", "10");

      pm.setParameterEstimationInitializationExpression("c1", "U(1, 5)");
      pm.setParameterEstimationInitializationExpression("c2", "U(1, 5)");
      pm.setParameterEstimationInitializationExpression("c3", "U(1, 5)");
      pm.setParameterEstimationInitializationExpression("c4", "U(1, 5)");
      pm.setParameterEstimationInitializationExpression("c5", "U(1, 5)");

      GeneralizedSemIm im = new GeneralizedSemIm(pm);

      print("True model: ");
      print(im);

      DataSet data = im.simulateDataRecursive(1000, false);

      GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
      GeneralizedSemIm estIm = estimator.estimate(pm, data);

      print("\n\n\nEstimated model: ");
      print(estIm);
      print(estimator.getReport());

      double aSquaredStar = estimator.getaSquaredStar();

      assertEquals(.79, aSquaredStar, 0.01);
    } catch (ParseException e) {
      // A parse failure means the assertion above never ran; fail loudly instead of passing.
      throw new AssertionError("Unexpected parse error while setting up the model.", e);
    }
  }
  /**
   * Builds a four-variable model (X1-&gt;X2-&gt;X3-&gt;X4, X1-&gt;X4) whose equations apply {@code
   * tan} to the parents and whose errors are {@code N(0, c)}, simulates 1000 rows from it with a
   * fixed seed, re-estimates the model and checks the value reported by {@code getaSquaredStar()}.
   */
  @Test
  public void test14() {
    RandomUtil.getInstance().setSeed(29999483L);

    try {
      Node x1 = new GraphNode("X1");
      Node x2 = new GraphNode("X2");
      Node x3 = new GraphNode("X3");
      Node x4 = new GraphNode("X4");

      Graph g = new EdgeListGraphSingleConnections();
      g.addNode(x1);
      g.addNode(x2);
      g.addNode(x3);
      g.addNode(x4);

      g.addDirectedEdge(x1, x2);
      g.addDirectedEdge(x2, x3);
      g.addDirectedEdge(x3, x4);
      g.addDirectedEdge(x1, x4);

      GeneralizedSemPm pm = new GeneralizedSemPm(g);

      pm.setNodeExpression(x1, "E_X1");
      pm.setNodeExpression(x2, "a1 * tan(X1) + E_X2");
      pm.setNodeExpression(x3, "a2 * tan(X2) + E_X3");
      pm.setNodeExpression(x4, "a3 * tan(X1) + a4 * tan(X3) ^ 2 + E_X4");

      pm.setNodeExpression(pm.getErrorNode(x1), "N(0, c1)");
      pm.setNodeExpression(pm.getErrorNode(x2), "N(0, c2)");
      pm.setNodeExpression(pm.getErrorNode(x3), "N(0, c3)");
      pm.setNodeExpression(pm.getErrorNode(x4), "N(0, c4)");

      pm.setParameterExpression("a1", "1");
      pm.setParameterExpression("a2", "1");
      pm.setParameterExpression("a3", "1");
      pm.setParameterExpression("a4", "1");
      pm.setParameterExpression("c1", "4");
      pm.setParameterExpression("c2", "4");
      pm.setParameterExpression("c3", "4");
      pm.setParameterExpression("c4", "4");

      GeneralizedSemIm im = new GeneralizedSemIm(pm);

      print("True model: ");
      print(im);

      DataSet data = im.simulateDataRecursive(1000, false);

      // NOTE(review): imInit is never handed to the estimator below. It is kept because
      // constructing it may draw from the seeded RandomUtil stream that the expected value
      // depends on -- confirm before removing it or wiring it into the estimate call.
      GeneralizedSemIm imInit = new GeneralizedSemIm(pm);
      imInit.setParameterValue("c1", 8);
      imInit.setParameterValue("c2", 8);
      imInit.setParameterValue("c3", 8);
      imInit.setParameterValue("c4", 8);

      GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
      GeneralizedSemIm estIm = estimator.estimate(pm, data);

      print("\n\n\nEstimated model: ");
      print(estIm);
      print(estimator.getReport());

      double aSquaredStar = estimator.getaSquaredStar();

      assertEquals(71.25, aSquaredStar, 0.01);
    } catch (ParseException e) {
      // A parse failure means the assertion above never ran; fail loudly instead of passing.
      throw new AssertionError("Unexpected parse error while setting up the model.", e);
    }
  }
  /**
   * Exercises the bookkeeping of the PM returned by {@code makeTypicalPm()}: which nodes reference
   * which parameters and nodes, how error nodes relate to the graph, and how parameter expression
   * strings are reported before and after being replaced. Finishes by instantiating the model and
   * simulating a small data set.
   */
  @Test
  public void test1() {
    GeneralizedSemPm pm = makeTypicalPm();

    print(pm);

    Node x1 = pm.getNode("X1");
    Node x2 = pm.getNode("X2");
    Node x3 = pm.getNode("X3");
    Node x4 = pm.getNode("X4");
    Node x5 = pm.getNode("X5");

    SemGraph graph = pm.getGraph();

    List<Node> variablesNodes = pm.getVariableNodes();
    print(variablesNodes);

    List<Node> errorNodes = pm.getErrorNodes();
    print(errorNodes);

    try {
      pm.setNodeExpression(x1, "cos(B1) + E_X1");
      print(pm);

      String b1 = "B1";
      String b2 = "B2";
      String b3 = "B3";

      Set<Node> nodes = pm.getReferencingNodes(b1);

      assertTrue(nodes.contains(x1));
      // The original condition tested x2 twice; a single check is equivalent.
      assertTrue(!nodes.contains(x2));

      Set<String> referencedParameters = pm.getReferencedParameters(x3);

      print("Parameters referenced by X3 are: " + referencedParameters);

      assertTrue(referencedParameters.contains(b1) && referencedParameters.contains(b2));
      assertTrue(!(referencedParameters.contains(b1) && referencedParameters.contains(b3)));

      Node e_x3 = pm.getNode("E_X3");

      for (Node node : pm.getNodes()) {
        Set<Node> referencingNodes = pm.getReferencingNodes(node);
        print("Nodes referencing " + node + " are: " + referencingNodes);
      }

      for (Node node : pm.getVariableNodes()) {
        Set<Node> referencingNodes = pm.getReferencedNodes(node);
        print("Nodes referenced by " + node + " are: " + referencingNodes);
      }

      Set<Node> referencingX3 = pm.getReferencingNodes(x3);
      assertTrue(referencingX3.contains(x4));
      assertTrue(!referencingX3.contains(x5));

      Set<Node> referencedByX3 = pm.getReferencedNodes(x3);
      assertTrue(
          referencedByX3.contains(x1)
              && referencedByX3.contains(x2)
              && referencedByX3.contains(e_x3)
              && !referencedByX3.contains(x4));

      pm.setNodeExpression(x5, "a * E^X2 + X4 + E_X5");

      Node e_x5 = pm.getErrorNode(x5);

      graph.setShowErrorTerms(true);
      assertTrue(e_x5.equals(graph.getExogenous(x5)));

      pm.setNodeExpression(e_x5, "Beta(3, 5)");

      print(pm);

      assertEquals("Split(-1.5,-.5,.5,1.5)", pm.getParameterExpressionString(b1));
      pm.setParameterExpression(b1, "N(0, 2)");
      assertEquals("N(0, 2)", pm.getParameterExpressionString(b1));

      GeneralizedSemIm im = new GeneralizedSemIm(pm);

      print(im);

      DataSet dataSet = im.simulateDataAvoidInfinity(10, false);

      print(dataSet);

    } catch (ParseException e) {
      // A parse failure means the remaining assertions never ran; fail loudly instead of passing.
      throw new AssertionError("Unexpected parse error while exercising the model.", e);
    }
  }
  /**
   * Tries each template against every node and every parameter of a fresh PM from {@code
   * makeTypicalPm()}.
   *
   * <p>Each template is mapped to the names of the nodes it is expected to be accepted for. A
   * template counts as accepted for a node when it can be expanded and set as that node's
   * expression without an exception and the node is one of the PM's variable nodes. Failures to
   * expand or set a template are only printed; the test asserts that the set of accepting nodes
   * matches the expectation.
   */
  @Test
  public void test4() {
    // Expectation arrays are only read, so the two distinct cases can be shared.
    String[] allVariables = {"X1", "X2", "X3", "X4", "X5"};
    String[] noVariables = {};

    Map<String, String[]> templates = new HashMap<>();

    templates.put("NEW(b) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", allVariables);
    templates.put("NEW(X1) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", noVariables);
    templates.put("$", noVariables);
    templates.put("TSUM($)", allVariables);
    templates.put("TPROD($)", allVariables);
    templates.put("TPROD($) + X2", allVariables);
    templates.put("TPROD($) + TSUM($)", allVariables);
    templates.put("tan(TSUM(NEW(a)*$))", allVariables);
    templates.put("Normal(0, 1)", allVariables);
    templates.put("Normal(m, s)", allVariables);
    templates.put("Normal(NEW(m), s)", allVariables);
    templates.put("Normal(NEW(m), NEW(s)) + m1 + s6", allVariables);
    templates.put("TSUM($) + a", allVariables);
    templates.put("TSUM($) + TSUM($) + TSUM($) + 1", allVariables);

    for (Map.Entry<String, String[]> entry : templates.entrySet()) {
      String template = entry.getKey();

      GeneralizedSemPm semPm = makeTypicalPm();
      print(semPm.getGraph().toString());

      Set<Node> expected = new HashSet<>();

      for (String name : entry.getValue()) {
        expected.add(semPm.getNode(name));
      }

      Set<Node> accepted = new HashSet<>();

      // Pass 1: try the template as the expression of every node.
      for (int i = 0; i < semPm.getNodes().size(); i++) {
        print("-----------");
        Node node = semPm.getNodes().get(i);
        print(node.toString());
        print("Trying template: " + template);

        String expanded;

        try {
          expanded = TemplateExpander.getInstance().expandTemplate(template, semPm, node);
        } catch (Exception e) {
          print("Couldn't expand template: " + template);
          continue;
        }

        try {
          semPm.setNodeExpression(node, expanded);
          print("Set formula " + expanded + " for " + node);

          if (semPm.getVariableNodes().contains(node)) {
            accepted.add(node);
          }
        } catch (Exception e) {
          print("Couldn't set formula " + expanded + " for " + node);
        }
      }

      // Pass 2: try the template as the expression of every parameter (outcome is only printed).
      for (String parameter : semPm.getParameters()) {
        print("-----------");
        print(parameter);
        print("Trying template: " + template);

        String expanded;

        try {
          expanded = TemplateExpander.getInstance().expandTemplate(template, semPm, null);
        } catch (Exception e) {
          print("Couldn't expand template: " + template);
          continue;
        }

        try {
          semPm.setParameterExpression(parameter, expanded);
          print("Set formula " + expanded + " for " + parameter);
        } catch (Exception e) {
          print("Couldn't set formula " + expanded + " for " + parameter);
        }
      }

      assertEquals(expected, accepted);
    }
  }