コード例 #1
0
ファイル: GdistanceApply.java プロジェクト: bd2kccd/r-causal
  /**
   * Loads two graphs and a location map from fixed file names, runs {@code
   * GdistanceVic.distances} on them, prints the result together with per-step timings, and writes
   * the distances to {@code Gdistances.txt}.
   *
   * <p>Exceptions raised while reading the location map, computing the distances, or writing the
   * output file are caught and reported through their stack trace; they are not rethrown. The two
   * graph loads happen outside that handler.
   *
   * @param args command-line arguments (unused)
   */
  public static void main(String... args) {
    // System.nanoTime() deltas are divided by this to report whole (truncated) seconds.
    final long nanosPerSecond = 1_000_000_000L;

    long timestart = System.nanoTime();
    System.out.println("Loading first graph");
    Graph graph1 = GraphUtils.loadGraphTxt(new File("images_graph_10sub_pd40_group1.txt"));
    long timegraph1 = System.nanoTime();
    System.out.println(
        "Done loading first graph. Elapsed time: "
            + (timegraph1 - timestart) / nanosPerSecond
            + "s");

    System.out.println("Loading second graph");
    Graph graph2 = GraphUtils.loadGraphTxt(new File("images_graph_10sub_pd40_group2.txt"));
    long timegraph2 = System.nanoTime();
    System.out.println(
        "Done loading second graph. Elapsed time: "
            + (timegraph2 - timegraph1) / nanosPerSecond
            + "s");

    // Load the location map; the relative path is printed alongside the working directory so the
    // resolved location can be checked from the console output.
    String workingDirectory = System.getProperty("user.dir");
    System.out.println(workingDirectory);
    Path mapPath = Paths.get("erich_coordinates.txt");
    System.out.println(mapPath);
    edu.cmu.tetrad.io.DataReader dataReaderMap = new TabularContinuousDataReader(mapPath, ',');
    try {
      DataSet locationMap = dataReaderMap.readInData();
      long timegraph3 = System.nanoTime();
      System.out.println(
          "Done loading location map. Elapsed time: "
              + (timegraph3 - timegraph2) / nanosPerSecond
              + "s");

      System.out.println("Running Gdistance");
      // Make this either Gdistance or GdistanceVic
      List<Double> distance = GdistanceVic.distances(graph1, graph2, locationMap);
      System.out.println(distance);
      System.out.println(
          "Done running Distance. Elapsed time: "
              + (System.nanoTime() - timegraph3) / nanosPerSecond
              + "s");
      System.out.println(
          "Total elapsed time: " + (System.nanoTime() - timestart) / nanosPerSecond + "s");

      // try-with-resources closes the writer even if printing the result throws.
      try (PrintWriter writer = new PrintWriter("Gdistances.txt", "UTF-8")) {
        writer.println(distance);
        // PrintWriter never throws on write failure; checkError() flushes and reports it.
        if (writer.checkError()) {
          System.err.println("Failed to write Gdistances.txt");
        }
      }
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
コード例 #2
0
  /**
   * Loads a reference DAG and a data set from the {@code test_data} resource directory, runs
   * {@code PcStable} with three different independence tests, and prints the run time of each
   * search followed by edge statistics of each result against the pattern of the reference DAG.
   *
   * <p>Any failure is caught and reported through its stack trace; nothing is rethrown.
   *
   * @param args command-line arguments (unused)
   */
  public static void main(String[] args) {
    // Disabled alternative setup, kept for reference: builds a small graph and simulates a
    // mixed data set instead of loading one from disk.
    //        Graph g = new EdgeListGraph();
    //        g.addNode(new ContinuousVariable("X1"));
    //        g.addNode(new ContinuousVariable("X2"));
    //        g.addNode(new DiscreteVariable("X3", 4));
    //        g.addNode(new DiscreteVariable("X4", 4));
    //        g.addNode(new ContinuousVariable("X5"));
    //
    //        g.addDirectedEdge(g.getNode("X1"), g.getNode("X2"));
    //        g.addDirectedEdge(g.getNode("X2"), g.getNode("X3"));
    //        g.addDirectedEdge(g.getNode("X3"), g.getNode("X4"));
    //        g.addDirectedEdge(g.getNode("X4"), g.getNode("X5"));
    //
    //        GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(g, "Split(-1.5,-.5,.5,1.5)");
    ////        System.out.println(pm);
    //
    //        GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm);
    ////        System.out.println(im);
    //
    //        int samps = 200;
    //        DataSet ds = im.simulateDataAvoidInfinity(samps, false);
    //        ds = MixedUtils.makeMixedData(ds, MixedUtils.getNodeDists(g));
    //        //System.out.println(ds);
    //        System.out.println(ds.isMixed());
    try {
      java.net.URL testData = ExampleMixedSearch.class.getResource("test_data");
      if (testData == null) {
        // getResource returns null when the resource is missing; fail with a clear message
        // instead of a bare NullPointerException.
        throw new IllegalStateException(
            "Resource directory 'test_data' not found next to ExampleMixedSearch");
      }
      // URL.getPath() keeps percent-encoding (e.g. "%20" for a space), which breaks file access;
      // going through a URI yields the decoded file-system path.
      String path = new File(testData.toURI()).getPath();

      Graph trueGraph =
          SearchGraphUtils.patternFromDag(
              GraphUtils.loadGraphTxt(new File(path, "DAG_0_graph.txt")));
      DataSet ds = MixedUtils.loadDataSet(path, "DAG_0_data.txt");

      // All three independence tests are built on the same data with the same .05 argument.
      IndTestMultinomialLogisticRegression indMix =
          new IndTestMultinomialLogisticRegression(ds, .05);
      IndTestMultinomialLogisticRegressionWald indWalLin =
          new IndTestMultinomialLogisticRegressionWald(ds, .05, true);
      IndTestMultinomialLogisticRegressionWald indWalLog =
          new IndTestMultinomialLogisticRegressionWald(ds, .05, false);

      PcStable s1 = new PcStable(indMix);
      PcStable s2 = new PcStable(indWalLin);
      PcStable s3 = new PcStable(indWalLog);

      // Each search is timed separately; elapsed time is printed in seconds.
      long time = System.currentTimeMillis();
      Graph g1 = SearchGraphUtils.patternFromDag(s1.search());
      System.out.println("Mix Time " + ((System.currentTimeMillis() - time) / 1000.0));

      time = System.currentTimeMillis();
      Graph g2 = SearchGraphUtils.patternFromDag(s2.search());
      System.out.println("Wald lin Time " + ((System.currentTimeMillis() - time) / 1000.0));

      time = System.currentTimeMillis();
      Graph g3 = SearchGraphUtils.patternFromDag(s3.search());
      System.out.println("Wald log Time " + ((System.currentTimeMillis() - time) / 1000.0));

      // Disabled debug output:
      //            System.out.println(g);
      //            System.out.println("IndMix: " + s1.search());
      //            System.out.println("IndWalLin: " + s2.search());
      //            System.out.println("IndWalLog: " + s3.search());

      // One statistics table per search result, each compared against the reference pattern.
      System.out.println(MixedUtils.EdgeStatHeader);
      System.out.println(MixedUtils.stringFrom2dArray(MixedUtils.allEdgeStats(trueGraph, g1)));
      System.out.println(MixedUtils.stringFrom2dArray(MixedUtils.allEdgeStats(trueGraph, g2)));
      System.out.println(MixedUtils.stringFrom2dArray(MixedUtils.allEdgeStats(trueGraph, g3)));
    } catch (Throwable t) {
      t.printStackTrace();
    }
  }