예제 #1
0
  /**
   * Builds a random 10-variable generalized SEM, simulates data from it, re-estimates the model
   * from that data, and checks the estimator's a-squared-star value against a pinned reference
   * (0.67, tolerance 0.01). The seed is fixed up front so the simulated data is reproducible.
   */
  @Test
  public void test7() {
    RandomUtil.getInstance().setSeed(29999483L);

    final int numVars = 10;
    final List<Node> variables = new ArrayList<>(numVars);
    for (int index = 1; index <= numVars; index++) {
      variables.add(new ContinuousVariable("X" + index));
    }

    final Graph randomGraph =
        GraphUtils.randomGraphRandomForwardEdges(variables, 0, numVars, 30, 15, 15, false, true);

    final GeneralizedSemPm semPm = new GeneralizedSemPm(randomGraph);
    final GeneralizedSemIm trueIm = new GeneralizedSemIm(semPm);
    print(trueIm);

    final DataSet sample = trueIm.simulateDataRecursive(1000, false);

    final GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    final GeneralizedSemIm fittedIm = estimator.estimate(semPm, sample);
    print(fittedIm);
    print(estimator.getReport());

    assertEquals(0.67, estimator.getaSquaredStar(), 0.01);
  }
예제 #2
0
  /**
   * Builds a random 5-variable linear SEM and wraps it as a generalized SEM. The linear model is
   * configured with coefficient range (0.5, 1.5) and variance range (1, 3). The test then
   * simulates data via {@code simulateData(1000, false)}, re-estimates the generalized model, and
   * checks the estimator's a-squared-star value against a pinned reference (0.59, tolerance 0.01).
   */
  @Test
  public void test6() {
    RandomUtil.getInstance().setSeed(29999483L);

    final int numVars = 5;
    final List<Node> variables = new ArrayList<>(numVars);
    for (int index = 1; index <= numVars; index++) {
      variables.add(new ContinuousVariable("X" + index));
    }

    final Graph randomGraph =
        GraphUtils.randomGraphRandomForwardEdges(variables, 0, numVars, 30, 15, 15, false, true);

    final SemPm linearPm = new SemPm(randomGraph);

    final SemImInitializationParams initParams = new SemImInitializationParams();
    initParams.setCoefRange(0.5, 1.5);
    initParams.setVarRange(1, 3);

    final SemIm linearIm = new SemIm(linearPm, initParams);

    // The generalized IM is built from the linear IM (presumably taking over its parameter
    // values — confirm against the GeneralizedSemIm(pm, semIm) constructor).
    final GeneralizedSemPm semPm = new GeneralizedSemPm(linearPm);
    final GeneralizedSemIm trueIm = new GeneralizedSemIm(semPm, linearIm);

    final DataSet sample = trueIm.simulateData(1000, false);
    print(trueIm);

    final GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    final GeneralizedSemIm fittedIm = estimator.estimate(semPm, sample);
    print(fittedIm);
    print(estimator.getReport());

    assertEquals(0.59, estimator.getaSquaredStar(), 0.01);
  }
예제 #3
0
  /**
   * Estimates a four-variable generalized SEM and checks the estimator's a-squared-star value
   * against a pinned reference (0.79, tolerance 0.01).
   *
   * <p>The graph is X1 -> X2 -> X3 -> X4 plus X1 -> X4, and the X4 equation contains an
   * {@code X3 ^ 2} term. Error terms are Gamma(c1, c2) and ChiSquare(c3..c5). Estimation starting
   * values for c1..c5 are given as U(1, 5). The seed is fixed up front so the simulated data is
   * reproducible.
   *
   * @throws ParseException if any node or parameter expression fails to parse. It is propagated
   *     so that a malformed expression fails the test instead of being printed and ignored, which
   *     previously let the test pass without ever reaching its assertion.
   */
  @Test
  public void test15() throws ParseException {
    RandomUtil.getInstance().setSeed(29999483L);

    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");

    Graph g = new EdgeListGraphSingleConnections();
    g.addNode(x1);
    g.addNode(x2);
    g.addNode(x3);
    g.addNode(x4);

    g.addDirectedEdge(x1, x2);
    g.addDirectedEdge(x2, x3);
    g.addDirectedEdge(x3, x4);
    g.addDirectedEdge(x1, x4);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);

    // Structural equations; X4 is nonlinear in X3.
    pm.setNodeExpression(x1, "E_X1");
    pm.setNodeExpression(x2, "a1 * X1 + E_X2");
    pm.setNodeExpression(x3, "a2 * X2 + E_X3");
    pm.setNodeExpression(x4, "a3 * X1 + a4 * X3 ^ 2 + E_X4");

    // Non-Gaussian error distributions.
    pm.setNodeExpression(pm.getErrorNode(x1), "Gamma(c1, c2)");
    pm.setNodeExpression(pm.getErrorNode(x2), "ChiSquare(c3)");
    pm.setNodeExpression(pm.getErrorNode(x3), "ChiSquare(c4)");
    pm.setNodeExpression(pm.getErrorNode(x4), "ChiSquare(c5)");

    // True values of the error-distribution parameters.
    pm.setParameterExpression("c1", "5");
    pm.setParameterExpression("c2", "2");
    pm.setParameterExpression("c3", "10");
    pm.setParameterExpression("c4", "10");
    pm.setParameterExpression("c5", "10");

    // Starting-value expressions the estimator uses for the same parameters.
    pm.setParameterEstimationInitializationExpression("c1", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c2", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c3", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c4", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c5", "U(1, 5)");

    GeneralizedSemIm im = new GeneralizedSemIm(pm);

    print("True model: ");
    print(im);

    DataSet data = im.simulateDataRecursive(1000, false);

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estIm = estimator.estimate(pm, data);

    print("\n\n\nEstimated model: ");
    print(estIm);
    print(estimator.getReport());

    double aSquaredStar = estimator.getaSquaredStar();

    assertEquals(.79, aSquaredStar, 0.01);
  }
예제 #4
0
  /**
   * Estimates a four-variable generalized SEM and checks the estimator's a-squared-star value
   * against a pinned reference (71.25, tolerance 0.01).
   *
   * <p>The graph is X1 -> X2 -> X3 -> X4 plus X1 -> X4. Each structural equation uses
   * {@code tan()} of its parents, and X4 includes a {@code tan(X3) ^ 2} term. Error terms are
   * N(0, c1..c4). All true parameter values are fixed (a1..a4 = 1, c1..c4 = 4). The seed is fixed
   * up front so the simulated data is reproducible.
   *
   * @throws ParseException if any node or parameter expression fails to parse. It is propagated
   *     so that a malformed expression fails the test instead of being printed and ignored, which
   *     previously let the test pass without ever reaching its assertion.
   */
  @Test
  public void test14() throws ParseException {
    RandomUtil.getInstance().setSeed(29999483L);

    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");

    Graph g = new EdgeListGraphSingleConnections();
    g.addNode(x1);
    g.addNode(x2);
    g.addNode(x3);
    g.addNode(x4);

    g.addDirectedEdge(x1, x2);
    g.addDirectedEdge(x2, x3);
    g.addDirectedEdge(x3, x4);
    g.addDirectedEdge(x1, x4);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);

    // Nonlinear (tan-based) structural equations.
    pm.setNodeExpression(x1, "E_X1");
    pm.setNodeExpression(x2, "a1 * tan(X1) + E_X2");
    pm.setNodeExpression(x3, "a2 * tan(X2) + E_X3");
    pm.setNodeExpression(x4, "a3 * tan(X1) + a4 * tan(X3) ^ 2 + E_X4");

    pm.setNodeExpression(pm.getErrorNode(x1), "N(0, c1)");
    pm.setNodeExpression(pm.getErrorNode(x2), "N(0, c2)");
    pm.setNodeExpression(pm.getErrorNode(x3), "N(0, c3)");
    pm.setNodeExpression(pm.getErrorNode(x4), "N(0, c4)");

    // True parameter values used for simulation.
    pm.setParameterExpression("a1", "1");
    pm.setParameterExpression("a2", "1");
    pm.setParameterExpression("a3", "1");
    pm.setParameterExpression("a4", "1");
    pm.setParameterExpression("c1", "4");
    pm.setParameterExpression("c2", "4");
    pm.setParameterExpression("c3", "4");
    pm.setParameterExpression("c4", "4");

    GeneralizedSemIm im = new GeneralizedSemIm(pm);

    print("True model: ");
    print(im);

    DataSet data = im.simulateDataRecursive(1000, false);

    // NOTE(review): imInit is configured below but never passed to the estimator, so these
    // values (c1..c4 = 8) do not influence the estimate. It is kept because constructing it
    // may advance the shared RandomUtil state that the pinned expected value depends on.
    // Confirm before removing it, or wire it into the estimator if that was the intent.
    GeneralizedSemIm imInit = new GeneralizedSemIm(pm);
    imInit.setParameterValue("c1", 8);
    imInit.setParameterValue("c2", 8);
    imInit.setParameterValue("c3", 8);
    imInit.setParameterValue("c4", 8);

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estIm = estimator.estimate(pm, data);

    print("\n\n\nEstimated model: ");
    print(estIm);
    print(estimator.getReport());

    double aSquaredStar = estimator.getaSquaredStar();

    assertEquals(71.25, aSquaredStar, 0.01);
  }