Exemplo n.º 1
0
  /**
   * Constructs a generalized SEM PM whose equations are rewritten as linear sums taken from the
   * given {@link SemPm}.
   *
   * <p>The graph of {@code semPm} is first passed to the graph-based constructor, which installs
   * template-based defaults. Each variable node's expression is then replaced with
   * {@code p1*X1 + p2*X2 + ... + E}. Here each {@code pk} is the name of the parameter
   * {@code semPm} reports for the edge {@code Xk -> node}, and {@code E} is the node's error term
   * (omitted if the node has none). Each coefficient parameter gets the expression
   * {@code Split(-1.5, -.5, .5, 1.5)}.
   *
   * <p>Each error term present is given the distribution {@code N(0,v)}. Here {@code v} is the
   * name of the parameter {@code semPm} reports for {@code (node, node)}, presumably the error
   * variance (TODO confirm), and {@code v} itself gets the expression {@code U(0, 1)}.
   *
   * @param semPm the SEM PM to translate.
   * @throws IllegalStateException if one of the generated expressions cannot be parsed.
   */
  public GeneralizedSemPm(SemPm semPm) {
    this(semPm.getGraph());

    try {
      List<Node> variableNodes = getVariableNodes();

      // Write down one linear equation per variable node.
      for (Node node : variableNodes) {
        List<String> terms = new ArrayList<>();

        for (Node parent : getVariableParents(node)) {
          // Only variable parents contribute a coefficient term.
          if (!variableNodes.contains(parent)) {
            continue;
          }

          Parameter _parameter = semPm.getParameter(parent, node);
          String parameter = _parameter.getName();
          Set<Node> nodes = new HashSet<>();
          nodes.add(node);

          referencedParameters.put(parameter, nodes);

          terms.add(parameter + "*" + parent.getName());

          setParameterExpression(parameter, "Split(-1.5, -.5, .5, 1.5)");
          setStartsWithParametersTemplate(parameter.substring(0, 1), "Split(-1.5, -.5, .5, 1.5)");
          setStartsWithParametersEstimationInitializaationTemplate(
              parameter.substring(0, 1), "Split(-1.5, -.5, .5, 1.5)");
        }

        // The error list may contain null for variables without an error parent; never write the
        // text "null" into an equation.
        Node errorNode = getErrorNode(node);
        if (errorNode != null) {
          terms.add(errorNode.getName());
        }

        // Joining the terms avoids a dangling " + " when a skipped parent was the last one.
        if (!terms.isEmpty()) {
          setNodeExpression(node, String.join(" + ", terms));
        }
      }

      // Give each error term a zero-mean normal distribution parameterized by semPm's
      // (node, node) parameter.
      for (Node node : variableNodes) {
        Node errorNode = getErrorNode(node);
        if (errorNode == null) {
          continue; // Nothing to describe for a variable without an error term.
        }

        Parameter _parameter = semPm.getParameter(node, node);
        String parameter = _parameter.getName();

        String distributionFormula = "N(0," + parameter + ")";
        setNodeExpression(errorNode, distributionFormula);
        setParameterExpression(parameter, "U(0, 1)");
        setStartsWithParametersTemplate(parameter.substring(0, 1), "U(0, 1)");
        setStartsWithParametersEstimationInitializaationTemplate(
            parameter.substring(0, 1), "U(0, 1)");
      }

      variableNames = new ArrayList<>();
      for (Node _node : variableNodes) {
        variableNames.add(_node.getName());
      }
      for (Node _node : errorNodes) {
        if (_node != null) {
          variableNames.add(_node.getName());
        }
      }

    } catch (ParseException e) {
      throw new IllegalStateException("Parse error in constructing initial model.", e);
    }
  }
Exemplo n.º 2
0
  /**
   * Constructs a generalized SEM PM over a private copy of the given SEM graph.
   *
   * <p>The graph is copied, and error terms are forced to be shown on the copy, so error-term
   * visibility cannot be changed from outside. Every measured or latent node is treated as a
   * variable node. Its error node is its first parent of type {@link NodeType#ERROR}, or
   * {@code null} if it has none.
   *
   * <p>Each parameterizable variable node without an expression receives one expanded from the
   * variables template. Its parameters default to the parameters template, or to a fixed
   * {@code Split(...)} range (narrower for time-lag models). Each non-null error node receives an
   * expression expanded from the errors template, and its parameters are set to
   * {@code U(1, 3)}.
   *
   * @param graph the SEM graph to parameterize; must not be null.
   * @throws NullPointerException if {@code graph} is null.
   * @throws IllegalArgumentException if the graph contains a bidirected edge.
   * @throws IllegalStateException if a generated expression cannot be parsed.
   */
  public GeneralizedSemPm(SemGraph graph) {
    if (graph == null) {
      throw new NullPointerException("Graph must not be null.");
    }

    // Error terms on this graph must not be shown or hidden from the outside. So keep a private
    // copy of it with error terms always shown.
    this.graph = new SemGraph(graph);
    this.graph.setShowErrorTerms(true);

    for (Edge edge : this.graph.getEdges()) {
      if (Edges.isBidirectedEdge(edge)) {
        throw new IllegalArgumentException(
            "The generalized SEM PM cannot currently deal with bidirected " + "edges. Sorry.");
      }
    }

    this.nodes = Collections.unmodifiableList(this.graph.getNodes());

    for (Node node : nodes) {
      namesToNodes.put(node.getName(), node);
    }

    this.variableNodes = new ArrayList<>();
    this.measuredNodes = new ArrayList<>();

    for (Node variable : this.nodes) {
      if (variable.getNodeType() == NodeType.MEASURED
          || variable.getNodeType() == NodeType.LATENT) {
        variableNodes.add(variable);
      }

      if (variable.getNodeType() == NodeType.MEASURED) {
        measuredNodes.add(variable);
      }
    }

    // errorNodes is index-aligned with variableNodes; null marks a variable without an error
    // parent.
    this.errorNodes = new ArrayList<>();

    for (Node variable : this.variableNodes) {
      List<Node> parents = this.graph.getParents(variable);
      boolean added = false;

      for (Node _node : parents) {
        if (_node.getNodeType() == NodeType.ERROR) {
          errorNodes.add(_node);
          added = true;
          break;
        }
      }

      if (!added) {
        errorNodes.add(null);
      }
    }

    this.referencedParameters = new HashMap<>();
    this.referencedNodes = new HashMap<>();
    this.nodeExpressions = new HashMap<>();
    this.nodeExpressionStrings = new HashMap<>();
    this.parameterExpressions = new HashMap<>();
    this.parameterExpressionStrings = new HashMap<>();
    this.parameterEstimationInitializationExpressions = new HashMap<>();
    this.parameterEstimationInitializationExpressionStrings = new HashMap<>();
    this.startsWithParametersTemplates = new HashMap<>();
    this.startsWithParametersEstimationInitializationTemplates = new HashMap<>();

    this.variableNames = new ArrayList<>();
    for (Node _node : variableNodes) {
      variableNames.add(_node.getName());
    }
    for (Node _node : errorNodes) {
      // Skip the null placeholders added above; dereferencing them would throw an NPE.
      if (_node != null) {
        variableNames.add(_node.getName());
      }
    }

    try {
      List<Node> variableNodes = getVariableNodes();

      for (Node node : variableNodes) {
        if (!this.graph.isParameterizable(node)) continue;

        if (nodeExpressions.get(node) != null) {
          continue;
        }

        String variablestemplate = getVariablesTemplate();
        String formula =
            TemplateExpander.getInstance().expandTemplate(variablestemplate, this, node);
        setNodeExpression(node, formula);
        Set<String> parameters = getReferencedParameters(node);

        String parametersTemplate = getParametersTemplate();

        for (String parameter : parameters) {
          if (parameterExpressions.get(parameter) == null) {
            if (parametersTemplate != null) {
              setParameterExpression(parameter, parametersTemplate);
            } else if (this.graph.isTimeLagModel()) {
              String expressionString = "Split(-0.9, -.1, .1, 0.9)";
              setParameterExpression(parameter, expressionString);
              setParametersTemplate(expressionString);
            } else {
              String expressionString = "Split(-1.5, -.5, .5, 1.5)";
              setParameterExpression(parameter, expressionString);
              setParametersTemplate(expressionString);
            }
          }
        }

        for (String parameter : parameters) {
          if (parameterEstimationInitializationExpressions.get(parameter) == null) {
            if (parametersTemplate != null) {
              setParameterEstimationInitializationExpression(parameter, parametersTemplate);
            } else if (this.graph.isTimeLagModel()) {
              String expressionString = "Split(-0.9, -.1, .1, 0.9)";
              setParameterEstimationInitializationExpression(parameter, expressionString);
            } else {
              String expressionString = "Split(-1.5, -.5, .5, 1.5)";
              setParameterEstimationInitializationExpression(parameter, expressionString);
            }
          }

          // NOTE(review): this is re-set on every iteration, and later overridden to U(1, 3)
          // whenever an error node exists. It is kept in place because later
          // setNodeExpression calls may consult these templates — TODO confirm before hoisting.
          setStartsWithParametersTemplate("s", "Split(-1.5, -.5, .5, 1.5)");
          setStartsWithParametersEstimationInitializaationTemplate(
              "s", "Split(-1.5, -.5, .5, 1.5)");
        }
      }

      for (Node node : errorNodes) {
        if (node == null) continue;

        String template = getErrorsTemplate();
        String formula = TemplateExpander.getInstance().expandTemplate(template, this, node);
        setNodeExpression(node, formula);
        Set<String> parameters = getReferencedParameters(node);

        setStartsWithParametersTemplate("s", "U(1, 3)");
        setStartsWithParametersEstimationInitializaationTemplate("s", "U(1, 3)");

        for (String parameter : parameters) {
          setParameterExpression(parameter, "U(1, 3)");
        }
      }
    } catch (ParseException e) {
      throw new IllegalStateException("Parse error in constructing initial model.", e);
    }
  }