コード例 #1
0
ファイル: TestMimbuild3.java プロジェクト: jdramsey/tetrad
  /**
   * Simulates a 300-case data set from a SEM over a graph built by {@code
   * DataGraphUtils.randomSingleFactorModel}, partitions the measured variables with FOFC (BPC is
   * available by changing {@code clusterAlgorithm}), and runs the Mimbuild variants listed in the
   * inner loop on that partition. For each variant it prints the estimated structure and its
   * structural Hamming distance (SHD) against {@code structure(...)} of the generating graph.
   */
  public void test1() {
    for (int rep = 0; rep < 1; rep++) {
      Graph trueGraph = DataGraphUtils.randomSingleFactorModel(5, 5, 6, 0, 0, 0);
      Graph trueStructure = structure(trueGraph);

      SemImInitializationParams initParams = new SemImInitializationParams();
      initParams.setCoefRange(.5, 1.5);

      SemIm semIm = new SemIm(new SemPm(trueGraph), initParams);
      DataSet sample = semIm.simulateData(300, false);

      final String clusterAlgorithm = "FOFC";
      Graph clusterGraph;
      List<List<Node>> clusters;

      switch (clusterAlgorithm) {
        case "FOFC": {
          FindOneFactorClusters fofc =
              new FindOneFactorClusters(sample, TestType.TETRAD_WISHART, 0.001);
          clusterGraph = fofc.search();
          clusters = fofc.getClusters();
          break;
        }
        case "BPC": {
          BuildPureClusters bpc =
              new BuildPureClusters(
                  sample, 0.001, TestType.TETRAD_WISHART, TestType.TETRAD_BASED2);
          clusterGraph = bpc.search();
          clusters = MimUtils.convertToClusters2(clusterGraph);
          break;
        }
        default:
          throw new IllegalStateException();
      }

      List<String> latentNames = reidentifyVariables(trueGraph, sample, clusters, 2);

      System.out.println(clusters);
      System.out.println(latentNames);
      System.out.println("True\n" + trueStructure);

      // Only methods 3 and 4 are run; case 1 is kept so it can be re-enabled by adding it to the
      // list. There is no method 2, so it (like any other value) falls through to the exception.
      for (int method : new int[] {3, 4}) {
        switch (method) {
          case 1: {
            System.out.println("Mimbuild 1\n");
            Clusters measurements = ClusterUtils.mimClusters(clusterGraph);
            IndTestMimBuild indTest = new IndTestMimBuild(sample, 0.001, measurements);
            Graph full = new MimBuild(indTest, new Knowledge2()).search();
            Graph estimated = structure(changeLatentNames(full, measurements, latentNames));
            System.out.println(
                "SHD = " + SearchGraphUtils.structuralHammingDistance(trueStructure, estimated));
            System.out.println("Estimated\n" + estimated);
            System.out.println();
            break;
          }
          case 3: {
            System.out.println("Mimbuild 3\n");
            Mimbuild2 mimbuild = new Mimbuild2();
            mimbuild.setAlpha(0.001);
            mimbuild.setMinClusterSize(3);
            Graph estimated =
                mimbuild.search(clusters, latentNames, new CovarianceMatrix(sample));
            ICovarianceMatrix latentCov = mimbuild.getLatentsCov();
            System.out.println("\nCovariance over the latents");
            System.out.println(latentCov);
            System.out.println("Estimated\n" + estimated);
            System.out.println(
                "SHD = " + SearchGraphUtils.structuralHammingDistance(trueStructure, estimated));
            System.out.println();
            break;
          }
          case 4: {
            System.out.println("Mimbuild Trek\n");
            MimbuildTrek mimbuild = new MimbuildTrek();
            mimbuild.setAlpha(0.1);
            mimbuild.setMinClusterSize(3);
            Graph estimated =
                mimbuild.search(clusters, latentNames, new CovarianceMatrix(sample));
            ICovarianceMatrix latentCov = mimbuild.getLatentsCov();
            System.out.println("\nCovariance over the latents");
            System.out.println(latentCov);
            System.out.println("Estimated\n" + estimated);
            System.out.println(
                "SHD = " + SearchGraphUtils.structuralHammingDistance(trueStructure, estimated));
            System.out.println();
            break;
          }
          default:
            throw new IllegalStateException();
        }
      }
    }
  }
コード例 #2
0
ファイル: TestMimbuild3.java プロジェクト: jdramsey/tetrad
  /**
   * Simulates a 1000-case data set from a SEM over a graph built by {@code
   * DataGraphUtils.randomSingleFactorModel}, partitions the measured variables with FOFC (BPC is
   * available by changing {@code algorithm}) using a column-reordered covariance matrix, and runs
   * Mimbuild2 on that partition.
   *
   * <p>For each replication it prints one tab-separated line: the SHD against the generating
   * structure, Mimbuild2's p-value and the number of clustered measures. It also reports the
   * replication with the largest p-value seen so far. At the end it prints the mean SHD over
   * replications that were not flagged.
   */
  public void rtest4() {
    System.out.println("SHD\tP");

    Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 8, 0, 0, 0);
    Graph mimStructure = structure(mim);

    SemPm pm = new SemPm(mim);
    SemImInitializationParams params = new SemImInitializationParams();
    params.setCoefRange(0.5, 1.5);

    NumberFormat nf = new DecimalFormat("0.0000");

    int totalError = 0;
    int errorCount = 0;

    int maxScore = 0;
    int maxNumMeasures = 0;
    double maxP = 0.0;

    for (int r = 0; r < 1; r++) {
      SemIm im = new SemIm(pm, params);
      DataSet data = im.simulateData(1000, false);

      // Rebind the generating graph onto the simulated variables (presumably so node identity
      // matches the data set's variables).
      mim = GraphUtils.replaceNodes(mim, data.getVariables());
      List<List<Node>> trueClusters = MimUtils.convertToClusters2(mim);

      ICovarianceMatrix cov = DataUtils.reorderColumns(new CovarianceMatrix(data));

      String algorithm = "FOFC";
      Graph searchGraph;
      List<List<Node>> partition;

      if (algorithm.equals("FOFC")) {
        // Double literal so alpha is exactly 0.001, matching the BPC branch and Mimbuild2 below;
        // the float literal 0.001f would widen to roughly 0.00100000005.
        FindOneFactorClusters fofc =
            new FindOneFactorClusters(cov, TestType.TETRAD_WISHART, 0.001);
        searchGraph = fofc.search();
        searchGraph = GraphUtils.replaceNodes(searchGraph, data.getVariables());
        partition = MimUtils.convertToClusters2(searchGraph);
      } else if (algorithm.equals("BPC")) {
        BuildPureClusters bpc =
            new BuildPureClusters(data, 0.001, TestType.TETRAD_WISHART, TestType.TETRAD_BASED2);
        searchGraph = bpc.search();
        partition = MimUtils.convertToClusters2(searchGraph);
      } else {
        throw new IllegalStateException("Unknown clustering algorithm: " + algorithm);
      }

      mimStructure = GraphUtils.replaceNodes(mimStructure, data.getVariables());

      List<String> latentVarList = reidentifyVariables(mim, data, partition, 2);

      Mimbuild2 mimbuild = new Mimbuild2();
      mimbuild.setAlpha(0.001);
      mimbuild.setMinClusterSize(3);

      Graph mimbuildStructure;
      try {
        mimbuildStructure = mimbuild.search(partition, latentVarList, cov);
      } catch (Exception e) {
        // Best effort: report the failure and skip this replication.
        e.printStackTrace();
        continue;
      }

      mimbuildStructure = GraphUtils.replaceNodes(mimbuildStructure, data.getVariables());
      mimbuildStructure = condense(mimStructure, mimbuildStructure);

      int shd = SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure);
      double pValue = mimbuild.getpValue();
      int numMeasuresClustered = numClustered(partition);

      // Flagged replications are excluded from the average SHD. 40 looks like the total number of
      // measures (presumably 5 latents x 8 measures each) -- confirm.
      // NOTE(review): this excludes runs where every measure was clustered; verify that this
      // polarity is intended.
      boolean flagged = numMeasuresClustered == 40;

      if (!flagged) {
        totalError += shd;
        errorCount++;
      }

      if (pValue > maxP) {
        maxScore = shd;
        maxP = pValue;
        maxNumMeasures = numMeasuresClustered;
        System.out.println("maxNumMeasures = " + maxNumMeasures);
        System.out.println("maxScore = " + maxScore);
        System.out.println("maxP = " + maxP);
        System.out.println("clusters = " + clusterSizes(partition, trueClusters));
      }

      System.out.println(shd + "\t" + nf.format(pValue) + " " + numMeasuresClustered);
    }

    // Avoid printing NaN when every replication was flagged or failed.
    String avgShd =
        errorCount == 0 ? "n/a" : String.valueOf(totalError / (double) errorCount);
    System.out.println("\nAvg SHD for not-flagged cases = " + avgShd);

    System.out.println("maxNumMeasures = " + maxNumMeasures);
    System.out.println("maxScore = " + maxScore);
    System.out.println("maxP = " + maxP);
  }