예제 #1
0
  /**
   * Tries each template in a fixed table against every node and every parameter of a fresh model
   * from {@code makeTypicalPm()}.
   *
   * <p>For each template, the set of nodes that (a) accepted the expanded template as their
   * expression and (b) are variable nodes of the model must equal the set of nodes named in the
   * table for that template. Expansion or assignment failures are expected for some
   * template/node combinations; they are logged via {@code print} and otherwise ignored.
   */
  @Test
  public void test4() {
    // The table only ever expects "all five variables" or "none"; the arrays are read-only.
    String[] allVariables = {"X1", "X2", "X3", "X4", "X5"};
    String[] noVariables = {};

    Map<String, String[]> expectedByTemplate = new HashMap<>();

    expectedByTemplate.put("NEW(b) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", allVariables);
    expectedByTemplate.put("NEW(X1) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", noVariables);
    expectedByTemplate.put("$", noVariables);
    expectedByTemplate.put("TSUM($)", allVariables);
    expectedByTemplate.put("TPROD($)", allVariables);
    expectedByTemplate.put("TPROD($) + X2", allVariables);
    expectedByTemplate.put("TPROD($) + TSUM($)", allVariables);
    expectedByTemplate.put("tan(TSUM(NEW(a)*$))", allVariables);
    expectedByTemplate.put("Normal(0, 1)", allVariables);
    expectedByTemplate.put("Normal(m, s)", allVariables);
    expectedByTemplate.put("Normal(NEW(m), s)", allVariables);
    expectedByTemplate.put("Normal(NEW(m), NEW(s)) + m1 + s6", allVariables);
    expectedByTemplate.put("TSUM($) + a", allVariables);
    expectedByTemplate.put("TSUM($) + TSUM($) + TSUM($) + 1", allVariables);

    for (Map.Entry<String, String[]> entry : expectedByTemplate.entrySet()) {
      String template = entry.getKey();

      // Each template gets its own model so earlier assignments cannot leak into later ones.
      GeneralizedSemPm semPm = makeTypicalPm();
      print(semPm.getGraph().toString());

      Set<Node> expected = new HashSet<>();

      for (String variableName : entry.getValue()) {
        expected.add(semPm.getNode(variableName));
      }

      Set<Node> accepted = new HashSet<>();

      // Pass 1: try the template as the expression of every node in the model.
      for (int index = 0; index < semPm.getNodes().size(); index++) {
        print("-----------");
        Node node = semPm.getNodes().get(index);
        print(node.toString());
        print("Trying template: " + template);

        String formula;

        try {
          formula = TemplateExpander.getInstance().expandTemplate(template, semPm, node);
        } catch (Exception ignored) {
          print("Couldn't expand template: " + template);
          continue;
        }

        try {
          semPm.setNodeExpression(node, formula);
          print("Set formula " + formula + " for " + node);

          // Only variable nodes count towards the comparison at the end.
          if (semPm.getVariableNodes().contains(node)) {
            accepted.add(node);
          }
        } catch (Exception ignored) {
          print("Couldn't set formula " + formula + " for " + node);
        }
      }

      // Pass 2: try the template as the expression of every parameter (expanded with no node).
      // Outcomes here are only logged; they do not feed the assertion.
      for (String parameter : semPm.getParameters()) {
        print("-----------");
        print(parameter);
        print("Trying template: " + template);

        String formula;

        try {
          formula = TemplateExpander.getInstance().expandTemplate(template, semPm, null);
        } catch (Exception ignored) {
          print("Couldn't expand template: " + template);
          continue;
        }

        try {
          semPm.setParameterExpression(parameter, formula);
          print("Set formula " + formula + " for " + parameter);
        } catch (Exception ignored) {
          print("Couldn't set formula " + formula + " for " + parameter);
        }
      }

      assertEquals(expected, accepted);
    }
  }
예제 #2
0
  /**
   * Exercises the cross-reference queries of the model returned by {@code makeTypicalPm()}
   * (which nodes reference a parameter or node, and which parameters/nodes a node references),
   * then overrides several node and parameter expressions, builds a {@code GeneralizedSemIm} from
   * the model and simulates a small data set.
   *
   * <p>A {@code ParseException} raised anywhere in the sequence fails the test. Printing the
   * stack trace and returning normally would let the test pass with its assertions skipped.
   */
  @Test
  public void test1() {
    GeneralizedSemPm pm = makeTypicalPm();

    print(pm);

    Node x1 = pm.getNode("X1");
    Node x2 = pm.getNode("X2");
    Node x3 = pm.getNode("X3");
    Node x4 = pm.getNode("X4");
    Node x5 = pm.getNode("X5");

    SemGraph graph = pm.getGraph();

    List<Node> variablesNodes = pm.getVariableNodes();
    print(variablesNodes);

    List<Node> errorNodes = pm.getErrorNodes();
    print(errorNodes);

    try {
      pm.setNodeExpression(x1, "cos(B1) + E_X1");
      print(pm);

      String b1 = "B1";
      String b2 = "B2";
      String b3 = "B3";

      Set<Node> nodes = pm.getReferencingNodes(b1);

      // X1 now mentions B1 in its expression; X2 must not be reported as referencing it.
      assertTrue(nodes.contains(x1));
      assertTrue(!nodes.contains(x2));

      Set<String> referencedParameters = pm.getReferencedParameters(x3);

      print("Parameters referenced by X3 are: " + referencedParameters);

      assertTrue(referencedParameters.contains(b1) && referencedParameters.contains(b2));
      // Given the assertion just above, this amounts to requiring that X3 does not reference B3.
      assertTrue(!(referencedParameters.contains(b1) && referencedParameters.contains(b3)));

      Node e_x3 = pm.getNode("E_X3");

      // Diagnostic output only: dump the reference relations in both directions.
      for (Node node : pm.getNodes()) {
        Set<Node> referencingNodes = pm.getReferencingNodes(node);
        print("Nodes referencing " + node + " are: " + referencingNodes);
      }

      for (Node node : pm.getVariableNodes()) {
        Set<Node> referencedNodes = pm.getReferencedNodes(node);
        print("Nodes referenced by " + node + " are: " + referencedNodes);
      }

      Set<Node> referencingX3 = pm.getReferencingNodes(x3);
      assertTrue(referencingX3.contains(x4));
      assertTrue(!referencingX3.contains(x5));

      Set<Node> referencedByX3 = pm.getReferencedNodes(x3);
      assertTrue(
          referencedByX3.contains(x1)
              && referencedByX3.contains(x2)
              && referencedByX3.contains(e_x3)
              && !referencedByX3.contains(x4));

      pm.setNodeExpression(x5, "a * E^X2 + X4 + E_X5");

      Node e_x5 = pm.getErrorNode(x5);

      // Error terms are switched on in the graph before its exogenous node for X5 is queried.
      graph.setShowErrorTerms(true);
      assertTrue(e_x5.equals(graph.getExogenous(x5)));

      pm.setNodeExpression(e_x5, "Beta(3, 5)");

      print(pm);

      // B1 must still carry this expression before it is replaced below.
      assertEquals("Split(-1.5,-.5,.5,1.5)", pm.getParameterExpressionString(b1));
      pm.setParameterExpression(b1, "N(0, 2)");
      assertEquals("N(0, 2)", pm.getParameterExpressionString(b1));

      GeneralizedSemIm im = new GeneralizedSemIm(pm);

      print(im);

      DataSet dataSet = im.simulateDataAvoidInfinity(10, false);

      print(dataSet);

    } catch (ParseException e) {
      // Fail loudly and keep the cause; swallowing this would skip every assertion after it.
      throw new AssertionError("Unexpected ParseException in test1", e);
    }
  }
예제 #3
0
  /**
   * Builds a five-variable model by hand, including the edge X5 -> X1 that closes a directed
   * cycle with X1 -> X2 -> X5, assigns explicit expressions to every variable and error node,
   * simulates data with a fixed random seed, and checks the covariance of the first two data
   * columns against a recorded value.
   *
   * <p>A {@code ParseException} raised anywhere in the sequence fails the test. Printing the
   * stack trace and returning normally would let the test pass with the covariance assertion
   * skipped.
   */
  @Test
  public void test3() {
    // Fixed seed so the simulated data, and therefore the covariance below, is reproducible.
    RandomUtil.getInstance().setSeed(49293843L);

    List<Node> variableNodes = new ArrayList<>();
    ContinuousVariable x1 = new ContinuousVariable("X1");
    ContinuousVariable x2 = new ContinuousVariable("X2");
    ContinuousVariable x3 = new ContinuousVariable("X3");
    ContinuousVariable x4 = new ContinuousVariable("X4");
    ContinuousVariable x5 = new ContinuousVariable("X5");

    variableNodes.add(x1);
    variableNodes.add(x2);
    variableNodes.add(x3);
    variableNodes.add(x4);
    variableNodes.add(x5);

    Graph _graph = new EdgeListGraph(variableNodes);
    SemGraph graph = new SemGraph(_graph);
    graph.setShowErrorTerms(true);

    Node e1 = graph.getExogenous(x1);
    Node e2 = graph.getExogenous(x2);
    Node e3 = graph.getExogenous(x3);
    Node e4 = graph.getExogenous(x4);
    Node e5 = graph.getExogenous(x5);

    graph.addDirectedEdge(x1, x3);
    graph.addDirectedEdge(x1, x2);
    graph.addDirectedEdge(x2, x3);
    graph.addDirectedEdge(x3, x4);
    graph.addDirectedEdge(x2, x4);
    graph.addDirectedEdge(x4, x5);
    graph.addDirectedEdge(x2, x5);
    graph.addDirectedEdge(x5, x1); // closes the cycle X1 -> X2 -> X5 -> X1

    GeneralizedSemPm pm = new GeneralizedSemPm(graph);

    List<Node> pmVariableNodes = pm.getVariableNodes();
    print(pmVariableNodes);

    List<Node> errorNodes = pm.getErrorNodes();
    print(errorNodes);

    try {
      pm.setNodeExpression(x1, "cos(b1) + a1 * X5 + E_X1");
      pm.setNodeExpression(x2, "a2 * X1 + E_X2");
      pm.setNodeExpression(x3, "tan(a3*X2 + a4*X1) + E_X3");
      pm.setNodeExpression(x4, "0.1 * E^X2 + X3 + E_X4");
      pm.setNodeExpression(x5, "0.1 * E^X4 + a6* X2 + E_X5");
      pm.setNodeExpression(e1, "U(0, 1)");
      pm.setNodeExpression(e2, "U(0, 1)");
      pm.setNodeExpression(e3, "U(0, 1)");
      pm.setNodeExpression(e4, "U(0, 1)");
      pm.setNodeExpression(e5, "U(0, 1)");

      GeneralizedSemIm im = new GeneralizedSemIm(pm);

      print(im);

      DataSet dataSet = im.simulateDataNSteps(1000, false);

      // NOTE(review): columns 0 and 1 are presumably X1 and X2 - confirm the column order.
      double[] d1 = dataSet.getDoubleData().getColumn(0).toArray();
      double[] d2 = dataSet.getDoubleData().getColumn(1).toArray();

      double cov = StatUtils.covariance(d1, d2);

      assertEquals(-0.002, cov, 0.001);
    } catch (ParseException e) {
      // Fail loudly and keep the cause; swallowing this would skip the covariance assertion.
      throw new AssertionError("Unexpected ParseException in test3", e);
    }
  }