コード例 #1
0
  /**
   * Simulate-then-estimate round trip on a randomly generated graph over ten continuous
   * variables (X1..X10).
   *
   * <p>With the random seed fixed, a {@code GeneralizedSemIm} built directly from the graph's
   * {@code GeneralizedSemPm} generates data via {@code simulateDataRecursive(1000, false)}; the
   * same PM is then re-estimated from that data and the value returned by
   * {@code getaSquaredStar()} is expected to be 0.67 (tolerance 0.01).
   */
  @Test
  public void test7() {
    RandomUtil.getInstance().setSeed(29999483L);

    final int variableCount = 10;
    List<Node> variables = new ArrayList<>();
    for (int index = 1; index <= variableCount; index++) {
      variables.add(new ContinuousVariable("X" + index));
    }

    Graph randomGraph =
        GraphUtils.randomGraphRandomForwardEdges(
            variables, 0, variableCount, 30, 15, 15, false, true);

    GeneralizedSemPm parametricModel = new GeneralizedSemPm(randomGraph);
    GeneralizedSemIm trueModel = new GeneralizedSemIm(parametricModel);
    print(trueModel);

    DataSet simulated = trueModel.simulateDataRecursive(1000, false);

    GeneralizedSemEstimator semEstimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estimatedModel = semEstimator.estimate(parametricModel, simulated);

    print(estimatedModel);
    print(semEstimator.getReport());

    assertEquals(0.67, semEstimator.getaSquaredStar(), 0.01);
  }
コード例 #2
0
  /**
   * Compares data simulated from a linear {@code SemIm} with data simulated from the
   * {@code GeneralizedSemIm} constructed from that same SEM.
   *
   * <p>Both data sets have {@code sampleSize} rows. For every column the two means must agree to
   * within 0.3 and the ratio of the two variances must be within 0.2 of 1.0.
   */
  @Test
  public void test2() {
    RandomUtil.getInstance().setSeed(2999983L);

    final int sampleSize = 1000;

    // Variables X1..X5, kept in an array so edges can be declared by index below.
    ContinuousVariable[] vars = new ContinuousVariable[5];
    List<Node> variableNodes = new ArrayList<>();
    for (int k = 0; k < vars.length; k++) {
      vars[k] = new ContinuousVariable("X" + (k + 1));
      variableNodes.add(vars[k]);
    }

    Graph _graph = new EdgeListGraph(variableNodes);
    SemGraph graph = new SemGraph(_graph);

    // {from, to} index pairs: X1->X3, X2->X3, X3->X4, X2->X4, X4->X5, X2->X5.
    int[][] directedEdges = {{0, 2}, {1, 2}, {2, 3}, {1, 3}, {3, 4}, {1, 4}};
    for (int[] edge : directedEdges) {
      graph.addDirectedEdge(vars[edge[0]], vars[edge[1]]);
    }

    SemPm linearPm = new SemPm(graph);
    SemIm linearIm = new SemIm(linearPm);
    DataSet linearData = linearIm.simulateData(sampleSize, false);

    print(linearPm);

    GeneralizedSemPm generalizedPm = new GeneralizedSemPm(linearPm);
    GeneralizedSemIm generalizedIm = new GeneralizedSemIm(generalizedPm, linearIm);
    DataSet generalizedData = generalizedIm.simulateDataMinimizeSurface(sampleSize, false);

    print(generalizedPm);

    int columnCount = linearData.getNumColumns();
    for (int column = 0; column < columnCount; column++) {
      double[] linearColumn = linearData.getDoubleData().getColumn(column).toArray();
      double[] generalizedColumn = generalizedData.getDoubleData().getColumn(column).toArray();

      assertEquals(StatUtils.mean(linearColumn), StatUtils.mean(generalizedColumn), 0.3);

      double varianceRatio =
          StatUtils.variance(linearColumn) / StatUtils.variance(generalizedColumn);
      assertEquals(1.0, varianceRatio, .2);
    }
  }
コード例 #3
0
  /**
   * Estimates a two-node generalized SEM (X -> Y) after supplying explicit
   * estimation-initialization expressions for its parameters.
   *
   * <p>The underlying linear model has edge coefficient 20 on X -> Y and error variance 1 on both
   * nodes. Data come from {@code simulateDataRecursive(1000, false)}; the value returned by
   * {@code getaSquaredStar()} after estimation is expected to be 0.69 (tolerance 0.01).
   *
   * <p>A {@code ParseException} from any of the initialization expressions fails the test
   * immediately instead of being printed and ignored, so the estimate is never run against a
   * partially configured model.
   */
  @Test
  public void test8() {
    RandomUtil.getInstance().setSeed(29999483L);

    Node x = new GraphNode("X");
    Node y = new GraphNode("Y");

    List<Node> nodes = new ArrayList<>();
    nodes.add(x);
    nodes.add(y);

    Graph graph = new EdgeListGraphSingleConnections(nodes);

    graph.addDirectedEdge(x, y);

    SemPm spm = new SemPm(graph);
    SemIm sim = new SemIm(spm);

    sim.setEdgeCoef(x, y, 20);
    sim.setErrVar(x, 1);
    sim.setErrVar(y, 1);

    GeneralizedSemPm pm = new GeneralizedSemPm(spm);
    GeneralizedSemIm im = new GeneralizedSemIm(pm, sim);

    print(im);

    try {
      pm.setParameterEstimationInitializationExpression("b1", "U(10, 30)");
      pm.setParameterEstimationInitializationExpression("T1", "U(.1, 3)");
      pm.setParameterEstimationInitializationExpression("T2", "U(.1, 3)");
    } catch (ParseException e) {
      // Surface the failure with its cause rather than continuing with defaults.
      throw new AssertionError("Could not parse an estimation-initialization expression", e);
    }

    DataSet data = im.simulateDataRecursive(1000, false);

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estIm = estimator.estimate(pm, data);

    print(estIm);

    double aSquaredStar = estimator.getaSquaredStar();

    assertEquals(0.69, aSquaredStar, 0.01);
  }
コード例 #4
0
  /**
   * Simulate-then-estimate round trip for a generalized SEM derived from a linear SEM over five
   * continuous variables (X1..X5) on a randomly generated graph.
   *
   * <p>The linear {@code SemIm} is initialized with coefficient range [0.5, 1.5] and variance
   * range [1, 3]. Data come from {@code simulateData(1000, false)} on the generalized IM; after
   * estimation the value returned by {@code getaSquaredStar()} is expected to be 0.59
   * (tolerance 0.01).
   */
  @Test
  public void test6() {
    RandomUtil.getInstance().setSeed(29999483L);

    final int variableCount = 5;
    List<Node> variables = new ArrayList<>();
    for (int index = 1; index <= variableCount; index++) {
      variables.add(new ContinuousVariable("X" + index));
    }

    Graph randomGraph =
        GraphUtils.randomGraphRandomForwardEdges(
            variables, 0, variableCount, 30, 15, 15, false, true);

    SemPm linearPm = new SemPm(randomGraph);

    SemImInitializationParams initParams = new SemImInitializationParams();
    initParams.setCoefRange(0.5, 1.5);
    initParams.setVarRange(1, 3);

    SemIm linearIm = new SemIm(linearPm, initParams);

    GeneralizedSemPm generalizedPm = new GeneralizedSemPm(linearPm);
    GeneralizedSemIm generalizedIm = new GeneralizedSemIm(generalizedPm, linearIm);

    DataSet simulated = generalizedIm.simulateData(1000, false);

    print(generalizedIm);

    GeneralizedSemEstimator semEstimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estimatedModel = semEstimator.estimate(generalizedPm, simulated);

    print(estimatedModel);
    print(semEstimator.getReport());

    assertEquals(0.59, semEstimator.getaSquaredStar(), 0.01);
  }
コード例 #5
0
  /**
   * Checks that a linear {@code SemIm} and the {@code GeneralizedSemIm} built from it produce the
   * same record when both are given the same error vector.
   *
   * <p>The error vector has five entries drawn with {@code nextNormal(0, 1)}; the two resulting
   * records must match element-wise to within 1e-10.
   */
  @Test
  public void test5() {
    RandomUtil.getInstance().setSeed(29999483L);

    List<Node> variables = new ArrayList<>();
    for (int index = 1; index <= 5; index++) {
      variables.add(new ContinuousVariable("X" + index));
    }

    Graph graph = new Dag(GraphUtils.randomGraph(variables, 0, 5, 30, 15, 15, false));
    SemPm linearPm = new SemPm(graph);
    SemIm linearIm = new SemIm(linearPm);

    // Result intentionally unused, exactly as in the original test.
    linearIm.simulateDataReducedForm(1000, false);

    GeneralizedSemPm generalizedPm = new GeneralizedSemPm(linearPm);
    GeneralizedSemIm generalizedIm = new GeneralizedSemIm(generalizedPm, linearIm);

    TetradVector errors = new TetradVector(5);
    int errorCount = errors.size();
    for (int k = 0; k < errorCount; k++) {
      errors.set(k, RandomUtil.getInstance().nextNormal(0, 1));
    }

    TetradVector linearRecord = linearIm.simulateOneRecord(errors);
    TetradVector generalizedRecord = generalizedIm.simulateOneRecord(errors);

    print("XXX1" + errors);
    print("XXX2" + linearRecord);
    print("XXX3" + generalizedRecord);

    int recordLength = linearRecord.size();
    for (int k = 0; k < recordLength; k++) {
      assertEquals(linearRecord.get(k), generalizedRecord.get(k), 1e-10);
    }
  }
コード例 #6
0
  /**
   * Simulate-then-estimate round trip for a four-node generalized SEM with non-Gaussian error
   * terms.
   *
   * <p>Graph: X1 -> X2 -> X3 -> X4 and X1 -> X4. X4's expression includes a squared X3 term. The
   * error of X1 is {@code Gamma(c1, c2)}; the errors of X2..X4 are {@code ChiSquare(c3..c5)}.
   * Parameters c1..c5 are fixed at 5, 2, 10, 10, 10 for the generating model and each is given
   * the estimation-initialization expression {@code U(1, 5)}. After estimating from
   * {@code simulateDataRecursive(1000, false)}, {@code getaSquaredStar()} is expected to be 0.79
   * (tolerance 0.01).
   *
   * <p>A {@code ParseException} from any expression fails the test; it is no longer printed and
   * swallowed, which previously let the test pass without reaching its assertion.
   */
  @Test
  public void test15() {
    RandomUtil.getInstance().setSeed(29999483L);

    try {
      Node x1 = new GraphNode("X1");
      Node x2 = new GraphNode("X2");
      Node x3 = new GraphNode("X3");
      Node x4 = new GraphNode("X4");

      Graph g = new EdgeListGraphSingleConnections();
      g.addNode(x1);
      g.addNode(x2);
      g.addNode(x3);
      g.addNode(x4);

      g.addDirectedEdge(x1, x2);
      g.addDirectedEdge(x2, x3);
      g.addDirectedEdge(x3, x4);
      g.addDirectedEdge(x1, x4);

      GeneralizedSemPm pm = new GeneralizedSemPm(g);

      // Structural equations.
      pm.setNodeExpression(x1, "E_X1");
      pm.setNodeExpression(x2, "a1 * X1 + E_X2");
      pm.setNodeExpression(x3, "a2 * X2 + E_X3");
      pm.setNodeExpression(x4, "a3 * X1 + a4 * X3 ^ 2 + E_X4");

      // Error-term distributions.
      pm.setNodeExpression(pm.getErrorNode(x1), "Gamma(c1, c2)");
      pm.setNodeExpression(pm.getErrorNode(x2), "ChiSquare(c3)");
      pm.setNodeExpression(pm.getErrorNode(x3), "ChiSquare(c4)");
      pm.setNodeExpression(pm.getErrorNode(x4), "ChiSquare(c5)");

      // Values used by the generating model.
      pm.setParameterExpression("c1", "5");
      pm.setParameterExpression("c2", "2");
      pm.setParameterExpression("c3", "10");
      pm.setParameterExpression("c4", "10");
      pm.setParameterExpression("c5", "10");

      // Expressions registered for initializing the parameters during estimation.
      pm.setParameterEstimationInitializationExpression("c1", "U(1, 5)");
      pm.setParameterEstimationInitializationExpression("c2", "U(1, 5)");
      pm.setParameterEstimationInitializationExpression("c3", "U(1, 5)");
      pm.setParameterEstimationInitializationExpression("c4", "U(1, 5)");
      pm.setParameterEstimationInitializationExpression("c5", "U(1, 5)");

      GeneralizedSemIm im = new GeneralizedSemIm(pm);

      print("True model: ");
      print(im);

      DataSet data = im.simulateDataRecursive(1000, false);

      GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
      GeneralizedSemIm estIm = estimator.estimate(pm, data);

      print("\n\n\nEstimated model: ");
      print(estIm);
      print(estimator.getReport());

      double aSquaredStar = estimator.getaSquaredStar();

      assertEquals(.79, aSquaredStar, 0.01);
    } catch (ParseException e) {
      // Fail loudly: a parse error means the model under test was never fully built.
      throw new AssertionError("Could not parse a model expression", e);
    }
  }
コード例 #7
0
  /**
   * Simulate-then-estimate round trip for a four-node generalized SEM whose structural equations
   * apply {@code tan} to the parents.
   *
   * <p>Graph: X1 -> X2 -> X3 -> X4 and X1 -> X4. Errors are {@code N(0, c_i)}; coefficients
   * a1..a4 are fixed at 1 and c1..c4 at 4 for the generating model. After estimating from
   * {@code simulateDataRecursive(1000, false)}, {@code getaSquaredStar()} is expected to be 71.25
   * (tolerance 0.01).
   *
   * <p>A {@code ParseException} from any expression fails the test; it is no longer printed and
   * swallowed, which previously let the test pass without reaching its assertion.
   */
  @Test
  public void test14() {
    RandomUtil.getInstance().setSeed(29999483L);

    try {
      Node x1 = new GraphNode("X1");
      Node x2 = new GraphNode("X2");
      Node x3 = new GraphNode("X3");
      Node x4 = new GraphNode("X4");

      Graph g = new EdgeListGraphSingleConnections();
      g.addNode(x1);
      g.addNode(x2);
      g.addNode(x3);
      g.addNode(x4);

      g.addDirectedEdge(x1, x2);
      g.addDirectedEdge(x2, x3);
      g.addDirectedEdge(x3, x4);
      g.addDirectedEdge(x1, x4);

      GeneralizedSemPm pm = new GeneralizedSemPm(g);

      // Structural equations.
      pm.setNodeExpression(x1, "E_X1");
      pm.setNodeExpression(x2, "a1 * tan(X1) + E_X2");
      pm.setNodeExpression(x3, "a2 * tan(X2) + E_X3");
      pm.setNodeExpression(x4, "a3 * tan(X1) + a4 * tan(X3) ^ 2 + E_X4");

      // Error-term distributions.
      pm.setNodeExpression(pm.getErrorNode(x1), "N(0, c1)");
      pm.setNodeExpression(pm.getErrorNode(x2), "N(0, c2)");
      pm.setNodeExpression(pm.getErrorNode(x3), "N(0, c3)");
      pm.setNodeExpression(pm.getErrorNode(x4), "N(0, c4)");

      // Values used by the generating model.
      pm.setParameterExpression("a1", "1");
      pm.setParameterExpression("a2", "1");
      pm.setParameterExpression("a3", "1");
      pm.setParameterExpression("a4", "1");
      pm.setParameterExpression("c1", "4");
      pm.setParameterExpression("c2", "4");
      pm.setParameterExpression("c3", "4");
      pm.setParameterExpression("c4", "4");

      GeneralizedSemIm im = new GeneralizedSemIm(pm);

      print("True model: ");
      print(im);

      DataSet data = im.simulateDataRecursive(1000, false);

      // NOTE(review): imInit is configured but never handed to the estimator below. It is kept
      // because constructing it may consume draws from the seeded RandomUtil, which the expected
      // value could depend on -- confirm before removing or wiring it in.
      GeneralizedSemIm imInit = new GeneralizedSemIm(pm);
      imInit.setParameterValue("c1", 8);
      imInit.setParameterValue("c2", 8);
      imInit.setParameterValue("c3", 8);
      imInit.setParameterValue("c4", 8);

      GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
      GeneralizedSemIm estIm = estimator.estimate(pm, data);

      print("\n\n\nEstimated model: ");
      print(estIm);
      print(estimator.getReport());

      double aSquaredStar = estimator.getaSquaredStar();

      assertEquals(71.25, aSquaredStar, 0.01);
    } catch (ParseException e) {
      // Fail loudly: a parse error means the model under test was never fully built.
      throw new AssertionError("Could not parse a model expression", e);
    }
  }
コード例 #8
0
  /**
   * Exercises the expression-editing and reference-tracking API of {@code GeneralizedSemPm} on
   * the model returned by {@code makeTypicalPm()}.
   *
   * <p>Covers: replacing node expressions, querying which nodes reference a parameter or node,
   * querying which parameters and nodes an expression references, the error node of a variable,
   * reading and replacing a parameter's expression string, and finally instantiating the model
   * and simulating ten records with {@code simulateDataAvoidInfinity}.
   *
   * <p>A {@code ParseException} from any expression fails the test; it is no longer printed and
   * swallowed, which previously let the test pass with the remaining assertions skipped.
   */
  @Test
  public void test1() {
    GeneralizedSemPm pm = makeTypicalPm();

    print(pm);

    Node x1 = pm.getNode("X1");
    Node x2 = pm.getNode("X2");
    Node x3 = pm.getNode("X3");
    Node x4 = pm.getNode("X4");
    Node x5 = pm.getNode("X5");

    SemGraph graph = pm.getGraph();

    List<Node> variablesNodes = pm.getVariableNodes();
    print(variablesNodes);

    List<Node> errorNodes = pm.getErrorNodes();
    print(errorNodes);

    try {
      pm.setNodeExpression(x1, "cos(B1) + E_X1");
      print(pm);

      String b1 = "B1";
      String b2 = "B2";
      String b3 = "B3";

      Set<Node> nodes = pm.getReferencingNodes(b1);

      assertTrue(nodes.contains(x1));
      // NOTE(review): the original condition tested x2 twice. The duplicate is dropped here; if
      // a different node was intended for the second operand, add that check explicitly.
      assertTrue(!nodes.contains(x2));

      Set<String> referencedParameters = pm.getReferencedParameters(x3);

      print("Parameters referenced by X3 are: " + referencedParameters);

      assertTrue(referencedParameters.contains(b1) && referencedParameters.contains(b2));
      // b1 was just asserted present, so this holds exactly when b3 is absent.
      assertTrue(!(referencedParameters.contains(b1) && referencedParameters.contains(b3)));

      Node e_x3 = pm.getNode("E_X3");

      for (Node node : pm.getNodes()) {
        Set<Node> referencingNodes = pm.getReferencingNodes(node);
        print("Nodes referencing " + node + " are: " + referencingNodes);
      }

      for (Node node : pm.getVariableNodes()) {
        Set<Node> referencedNodes = pm.getReferencedNodes(node);
        print("Nodes referenced by " + node + " are: " + referencedNodes);
      }

      Set<Node> referencingX3 = pm.getReferencingNodes(x3);
      assertTrue(referencingX3.contains(x4));
      assertTrue(!referencingX3.contains(x5));

      Set<Node> referencedByX3 = pm.getReferencedNodes(x3);
      assertTrue(
          referencedByX3.contains(x1)
              && referencedByX3.contains(x2)
              && referencedByX3.contains(e_x3)
              && !referencedByX3.contains(x4));

      pm.setNodeExpression(x5, "a * E^X2 + X4 + E_X5");

      Node e_x5 = pm.getErrorNode(x5);

      graph.setShowErrorTerms(true);
      // assertEquals reports both nodes on failure; it still evaluates e_x5.equals(...).
      assertEquals(e_x5, graph.getExogenous(x5));

      pm.setNodeExpression(e_x5, "Beta(3, 5)");

      print(pm);

      assertEquals("Split(-1.5,-.5,.5,1.5)", pm.getParameterExpressionString(b1));
      pm.setParameterExpression(b1, "N(0, 2)");
      assertEquals("N(0, 2)", pm.getParameterExpressionString(b1));

      GeneralizedSemIm im = new GeneralizedSemIm(pm);

      print(im);

      DataSet dataSet = im.simulateDataAvoidInfinity(10, false);

      print(dataSet);

    } catch (ParseException e) {
      // Fail loudly: a parse error means the later assertions were skipped.
      throw new AssertionError("Could not parse a model expression", e);
    }
  }
コード例 #9
0
  /**
   * Simulates from a generalized SEM whose graph contains a directed cycle (X5 -> X1 is added on
   * top of the forward edges) and checks the covariance of the first two data columns.
   *
   * <p>All five error terms are {@code U(0, 1)}. Data come from
   * {@code simulateDataNSteps(1000, false)}; the covariance of columns 0 and 1 is expected to be
   * -0.002 (tolerance 0.001) under the fixed seed.
   *
   * <p>A {@code ParseException} from any expression fails the test; it is no longer printed and
   * swallowed, which previously let the test pass without reaching its assertion.
   */
  @Test
  public void test3() {
    RandomUtil.getInstance().setSeed(49293843L);

    List<Node> variableNodes = new ArrayList<>();
    ContinuousVariable x1 = new ContinuousVariable("X1");
    ContinuousVariable x2 = new ContinuousVariable("X2");
    ContinuousVariable x3 = new ContinuousVariable("X3");
    ContinuousVariable x4 = new ContinuousVariable("X4");
    ContinuousVariable x5 = new ContinuousVariable("X5");

    variableNodes.add(x1);
    variableNodes.add(x2);
    variableNodes.add(x3);
    variableNodes.add(x4);
    variableNodes.add(x5);

    Graph _graph = new EdgeListGraph(variableNodes);
    SemGraph graph = new SemGraph(_graph);
    graph.setShowErrorTerms(true);

    Node e1 = graph.getExogenous(x1);
    Node e2 = graph.getExogenous(x2);
    Node e3 = graph.getExogenous(x3);
    Node e4 = graph.getExogenous(x4);
    Node e5 = graph.getExogenous(x5);

    graph.addDirectedEdge(x1, x3);
    graph.addDirectedEdge(x1, x2);
    graph.addDirectedEdge(x2, x3);
    graph.addDirectedEdge(x3, x4);
    graph.addDirectedEdge(x2, x4);
    graph.addDirectedEdge(x4, x5);
    graph.addDirectedEdge(x2, x5);
    // Closes the cycle X1 -> ... -> X5 -> X1.
    graph.addDirectedEdge(x5, x1);

    GeneralizedSemPm pm = new GeneralizedSemPm(graph);

    List<Node> variablesNodes = pm.getVariableNodes();
    print(variablesNodes);

    List<Node> errorNodes = pm.getErrorNodes();
    print(errorNodes);

    try {
      // Structural equations.
      pm.setNodeExpression(x1, "cos(b1) + a1 * X5 + E_X1");
      pm.setNodeExpression(x2, "a2 * X1 + E_X2");
      pm.setNodeExpression(x3, "tan(a3*X2 + a4*X1) + E_X3");
      pm.setNodeExpression(x4, "0.1 * E^X2 + X3 + E_X4");
      pm.setNodeExpression(x5, "0.1 * E^X4 + a6* X2 + E_X5");

      // Error-term distributions.
      pm.setNodeExpression(e1, "U(0, 1)");
      pm.setNodeExpression(e2, "U(0, 1)");
      pm.setNodeExpression(e3, "U(0, 1)");
      pm.setNodeExpression(e4, "U(0, 1)");
      pm.setNodeExpression(e5, "U(0, 1)");

      GeneralizedSemIm im = new GeneralizedSemIm(pm);

      print(im);

      DataSet dataSet = im.simulateDataNSteps(1000, false);

      double[] d1 = dataSet.getDoubleData().getColumn(0).toArray();
      double[] d2 = dataSet.getDoubleData().getColumn(1).toArray();

      double cov = StatUtils.covariance(d1, d2);

      assertEquals(-0.002, cov, 0.001);
    } catch (ParseException e) {
      // Fail loudly: a parse error means the covariance assertion was never run.
      throw new AssertionError("Could not parse a model expression", e);
    }
  }