Пример #1
0
  /**
   * Applies background knowledge to {@code graph} in place.
   *
   * <p>For every forbidden knowledge edge from→to whose names resolve to nodes that are adjacent in
   * {@code graph}, calls {@code setEndpoint(to, from, ARROW)} and {@code setEndpoint(from, to,
   * CIRCLE)} (i.e. "to *-> from"). For every required knowledge edge whose nodes resolve and are
   * adjacent, calls {@code setEndpoint(to, from, TAIL)} and {@code setEndpoint(from, to, ARROW)}.
   * Knowledge edges whose names cannot be resolved, or whose nodes are not adjacent, are skipped.
   * Each applied orientation sets {@code changeFlag} and is logged.
   *
   * @param bk background knowledge supplying forbidden and required edges
   * @param graph the graph whose endpoints are rewritten
   * @param variables the nodes used to resolve knowledge-edge names
   */
  private void fciOrientbk(IKnowledge bk, Graph graph, List<Node> variables) {
    logger.log("info", "Starting BK Orientation.");

    Iterator<KnowledgeEdge> forbiddenEdges = bk.forbiddenEdgesIterator();
    while (forbiddenEdges.hasNext()) {
      KnowledgeEdge knowledgeEdge = forbiddenEdges.next();
      Node fromNode = SearchGraphUtils.translate(knowledgeEdge.getFrom(), variables);
      Node toNode = SearchGraphUtils.translate(knowledgeEdge.getTo(), variables);

      // Skip names that do not match a variable, and pairs that are not adjacent.
      if (fromNode == null || toNode == null || graph.getEdge(fromNode, toNode) == null) {
        continue;
      }

      // Orient to*->from
      graph.setEndpoint(toNode, fromNode, Endpoint.ARROW);
      graph.setEndpoint(fromNode, toNode, Endpoint.CIRCLE);
      changeFlag = true;
      logger.log(
          "knowledgeOrientation",
          SearchLogUtils.edgeOrientedMsg("Knowledge", graph.getEdge(fromNode, toNode)));
    }

    Iterator<KnowledgeEdge> requiredEdges = bk.requiredEdgesIterator();
    while (requiredEdges.hasNext()) {
      KnowledgeEdge knowledgeEdge = requiredEdges.next();
      Node fromNode = SearchGraphUtils.translate(knowledgeEdge.getFrom(), variables);
      Node toNode = SearchGraphUtils.translate(knowledgeEdge.getTo(), variables);

      if (fromNode == null || toNode == null || graph.getEdge(fromNode, toNode) == null) {
        continue;
      }

      // Orient from-->to
      graph.setEndpoint(toNode, fromNode, Endpoint.TAIL);
      graph.setEndpoint(fromNode, toNode, Endpoint.ARROW);
      changeFlag = true;
      logger.log(
          "knowledgeOrientation",
          SearchLogUtils.edgeOrientedMsg("Knowledge", graph.getEdge(fromNode, toNode)));
    }

    logger.log("info", "Finishing BK Orientation.");
  }
Пример #2
0
  /**
   * Runs PC over the given list of nodes, using the configured independence test and knowledge,
   * and returns the resultant graph. The returned graph will be a pattern if the independence
   * information is consistent with the hypothesis that there are no latent common causes. It may,
   * however, contain cycles or bidirected edges if this assumption is not borne out, either because
   * latent common causes are actually present or because of statistical errors in conditional
   * independence judgments.
   *
   * <p>Steps: adjacency search with {@code FasStableConcurrent} (constructed with {@code
   * initialGraph}), background-knowledge orientation, collider orientation from the sepsets the
   * adjacency search recorded, and Meek-rule propagation. The result is stored in the {@code graph}
   * field, the sepsets in {@code sepsets}, and the run time (ms) in {@code elapsedTime}.
   *
   * <p>All of the given nodes must be in the domain of the given conditional independence test.
   *
   * @param nodes the nodes to search over
   * @return the oriented graph
   * @throws NullPointerException if no independence test has been set
   * @throws IllegalArgumentException if some node is not a variable of the independence test
   */
  public Graph search(List<Node> nodes) {
    // Validate before logging so a missing test fails fast with a clear message.
    if (getIndependenceTest() == null) {
      throw new NullPointerException("Independence test not set; cannot run PC.");
    }

    this.logger.log("info", "Starting PC algorithm");
    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");

    long startTime = System.currentTimeMillis();

    List<Node> allNodes = getIndependenceTest().getVariables();
    if (!allNodes.containsAll(nodes)) {
      throw new IllegalArgumentException(
          "All of the given nodes must " + "be in the domain of the independence test provided.");
    }

    IFas fas = new FasStableConcurrent(initialGraph, getIndependenceTest());
    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);

    graph = fas.search();
    sepsets = fas.getSepsets();

    SearchGraphUtils.pcOrientbk(knowledge, graph, nodes);
    SearchGraphUtils.orientCollidersUsingSepsets(this.sepsets, knowledge, graph, verbose);

    MeekRules rules = new MeekRules();
    rules.setAggressivelyPreventCycles(this.aggressivelyPreventCycles);
    rules.setKnowledge(knowledge);
    rules.orientImplied(graph);

    this.logger.log("graph", "\nReturning this graph: " + graph);

    this.elapsedTime = System.currentTimeMillis() - startTime;

    this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    this.logger.log("info", "Finishing PC Algorithm.");
    this.logger.flush();

    return graph;
  }
 /**
  * Returns the pattern to which the given DAG belongs.
  *
  * <p>Works on a copy made with {@code new EdgeListGraph(dag)}: the copy is reduced with {@code
  * SearchGraphUtils.basicPattern} and then closed under the Meek orientation rules.
  *
  * @param dag the DAG whose pattern is wanted
  * @return the pattern graph
  */
 public static Graph patternFromDag(Graph dag) {
   final Graph pattern = new EdgeListGraph(dag);
   SearchGraphUtils.basicPattern(pattern);

   final MeekRules meek = new MeekRules();
   meek.orientImplied(pattern);

   return pattern;
 }
Пример #4
0
 /**
  * Arranges the result workbench's graph into tiers according to the knowledge held by the
  * algorithm runner's parameters.
  */
 public void layoutByKnowledge() {
   Graph workbenchGraph = getWorkbench().getGraph();
   IKnowledge tiers = getAlgorithmRunner().getParams().getKnowledge();
   SearchGraphUtils.arrangeByKnowledgeTiers(workbenchGraph, tiers);
 }
Пример #5
0
  /**
   * Completes a pattern that was modified by an insertion/deletion operator, based on the algorithm
   * described in Appendix C of (Chickering, 2002): reduce to the basic pattern, re-add required
   * edges, then apply the Meek rules under the current knowledge. When the "rebuiltPatterns" log
   * event is active, the rebuilt pattern is logged.
   *
   * @param graph the pattern to complete; modified in place
   */
  private void rebuildPattern(Graph graph) {
    SearchGraphUtils.basicPattern(graph, false);
    addRequiredEdges(graph);
    meekOrient(graph, getKnowledge());

    TetradLogger tetradLogger = TetradLogger.getInstance();

    // Guarded so the graph is only rendered to a string when the event is enabled.
    if (!tetradLogger.isEventActive("rebuiltPatterns")) {
      return;
    }

    tetradLogger.log("rebuiltPatterns", "Rebuilt pattern = " + graph);
  }
Пример #6
0
 /**
  * Chooses an initial layout for {@code resultGraph}. It mirrors the latest workbench graph when
  * there is one. Otherwise it arranges by knowledge tiers when the knowledge asks for that layout,
  * and falls back to a circle layout in all other cases.
  *
  * @param resultGraph the graph whose node positions are set
  */
 protected void doDefaultArrangement(Graph resultGraph) {
   Graph latest = getLatestWorkbenchGraph();

   if (latest != null) {
     GraphUtils.arrangeBySourceGraph(resultGraph, latest);
     return;
   }

   if (getKnowledge().isDefaultToKnowledgeLayout()) {
     SearchGraphUtils.arrangeByKnowledgeTiers(resultGraph, getKnowledge());
     return;
   }

   GraphUtils.circleLayout(resultGraph, 200, 200, 150);
 }
Пример #7
0
  /**
   * Runs RFCI when the independence-test parameters request it, and FCI otherwise. Either search
   * uses the configured knowledge, rule set, maximum path length and depth; FCI also uses the
   * possible-D-sep setting. The resulting graph is laid out and published as the result graph.
   *
   * <p>The layout mirrors the source graph when there is one, uses knowledge tiers when the
   * knowledge asks for that layout, and is a circle layout otherwise.
   */
  public void execute() {
    SearchParams searchParams = getParams();
    IKnowledge knowledge = searchParams.getKnowledge();
    FciIndTestParams testParams = (FciIndTestParams) searchParams.getIndTestParams();

    final Graph result;

    if (testParams.isRFCI_Used()) {
      Rfci rfci = new Rfci(getIndependenceTest());
      rfci.setKnowledge(knowledge);
      rfci.setCompleteRuleSetUsed(testParams.isCompleteRuleSetUsed());
      rfci.setMaxPathLength(testParams.getMaxReachablePathLength());
      rfci.setDepth(testParams.getDepth());
      result = rfci.search();
    } else {
      Fci fullFci = new Fci(getIndependenceTest());
      fullFci.setKnowledge(knowledge);
      fullFci.setCompleteRuleSetUsed(testParams.isCompleteRuleSetUsed());
      fullFci.setPossibleDsepSearchDone(testParams.isPossibleDsepDone());
      fullFci.setMaxPathLength(testParams.getMaxReachablePathLength());
      fullFci.setDepth(testParams.getDepth());
      result = fullFci.search();
    }

    Graph sourceGraph = getSourceGraph();

    if (sourceGraph != null) {
      GraphUtils.arrangeBySourceGraph(result, sourceGraph);
    } else if (knowledge.isDefaultToKnowledgeLayout()) {
      SearchGraphUtils.arrangeByKnowledgeTiers(result, knowledge);
    } else {
      GraphUtils.circleLayout(result, 200, 200, 150);
    }

    setResultGraph(result);
  }
Пример #8
0
  /**
   * Runs TFciGes with the configured knowledge and independence-test parameters, lays out the
   * resulting graph, and publishes it as the result graph.
   *
   * <p>The layout mirrors the source graph when there is one, uses knowledge tiers when the
   * knowledge asks for that layout, and is a circle layout otherwise.
   */
  public void execute() {
    SearchParams searchParams = getParams();
    IKnowledge knowledge = searchParams.getKnowledge();
    FciGesIndTestParams indTestParams = (FciGesIndTestParams) searchParams.getIndTestParams();

    TFciGes fci = new TFciGes(getIndependenceTest());
    fci.setKnowledge(knowledge);
    // The user's complete-rule-set choice is applied as given and must not be overridden below.
    fci.setCompleteRuleSetUsed(indTestParams.isCompleteRuleSetUsed());
    fci.setPossibleDsepSearchDone(indTestParams.isPossibleDsepDone());
    fci.setMaxPathLength(indTestParams.getMaxReachablePathLength());
    fci.setDepth(indTestParams.getDepth());
    fci.setPenaltyDiscount(indTestParams.getPenaltyDiscount());
    fci.setSamplePrior(indTestParams.getSamplePrior());
    fci.setStructurePrior(indTestParams.getStructurePrior());
    fci.setFaithfulnessAssumed(indTestParams.isFaithfulnessAssumed());
    Graph graph = fci.search();

    if (getSourceGraph() != null) {
      GraphUtils.arrangeBySourceGraph(graph, getSourceGraph());
    } else if (knowledge.isDefaultToKnowledgeLayout()) {
      SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge);
    } else {
      GraphUtils.circleLayout(graph, 200, 200, 150);
    }

    setResultGraph(graph);
  }
Пример #9
0
  /**
   * Reduces {@code graph} to its basic pattern, adds required edges, and then extends it to a fully
   * directed graph. Each round orients one undirected edge (node1 -> node2) and calls {@link
   * #meekOrient} to propagate the consequences. Rounds repeat until no undirected edge remains.
   *
   * <p>Only one edge is oriented per round. Orienting every undirected edge at once in an arbitrary
   * direction can introduce unshielded colliders or cycles the pattern does not contain. For
   * example, A - B - C could become A -> B <- C, and the result would then not be a DAG in the
   * pattern's equivalence class. Breaking out of the edge loop right after modifying the graph
   * also avoids continuing to iterate an edge collection that was just changed.
   *
   * @param graph the graph to convert; modified in place
   * @return {@code graph}
   */
  private Graph pickDag(Graph graph) {
    SearchGraphUtils.basicPattern(graph, false);
    addRequiredEdges(graph);
    boolean containsUndirected;

    do {
      containsUndirected = false;

      for (Edge edge : graph.getEdges()) {
        if (Edges.isUndirectedEdge(edge)) {
          containsUndirected = true;
          graph.removeEdge(edge);
          graph.addEdge(Edges.directedEdge(edge.getNode1(), edge.getNode2()));
          // One arbitrary choice per round; Meek propagates it before the next choice.
          break;
        }
      }

      meekOrient(graph, getKnowledge());
    } while (containsUndirected);

    return graph;
  }
Пример #10
0
  /**
   * Runs VcpcAlt with the configured knowledge, depth and cycle-prevention setting, lays out the
   * resulting graph, and publishes it as the result graph.
   *
   * <p>The layout mirrors the source graph when there is one, uses knowledge tiers when the
   * knowledge asks for that layout, and is a circle layout otherwise.
   */
  public void execute() {
    IKnowledge knowledge = getParams().getKnowledge();
    PcSearchParams searchParams = (PcSearchParams) getParams();
    PcIndTestParams testParams = (PcIndTestParams) searchParams.getIndTestParams();

    // Local named in lowerCamelCase so it no longer shadows the VcpcAlt type.
    VcpcAlt vcpcAlt = new VcpcAlt(getIndependenceTest());
    vcpcAlt.setKnowledge(knowledge);
    vcpcAlt.setAggressivelyPreventCycles(this.isAggressivelyPreventCycles());
    vcpcAlt.setDepth(testParams.getDepth());
    Graph found = vcpcAlt.search();

    Graph sourceGraph = getSourceGraph();

    if (sourceGraph != null) {
      GraphUtils.arrangeBySourceGraph(found, sourceGraph);
    } else if (knowledge.isDefaultToKnowledgeLayout()) {
      SearchGraphUtils.arrangeByKnowledgeTiers(found, knowledge);
    } else {
      GraphUtils.circleLayout(found, 200, 200, 150);
    }

    setResultGraph(found);
  }
Пример #11
0
  /**
   * Builds a PAG, lays it out, and publishes it as the result graph. When the independence test is
   * an {@code IndTestDSep}, its graph is converted directly with {@code DagToPag}; otherwise GFCI
   * is run with a score wrapping the independence test.
   *
   * <p>Parameters read: "knowledge" (default: empty {@code Knowledge2}), "completeRuleSetUsed"
   * (default false), "maxReachablePathLength" (default -1), "maxIndegree", and
   * "faithfulnessAssumed" (default true). The layout mirrors the source graph when there is one,
   * uses knowledge tiers when the knowledge asks for that layout, and is a circle layout otherwise.
   */
  public void execute() {
    Parameters params = getParams();
    IKnowledge knowledge = (IKnowledge) params.get("knowledge", new Knowledge2());
    // Read once and honored on both branches; it must not be overridden after being applied.
    boolean completeRuleSetUsed = params.getBoolean("completeRuleSetUsed", false);

    IndependenceTest independenceTest = getIndependenceTest();
    Graph graph;

    if (independenceTest instanceof IndTestDSep) {
      DagToPag dagToPag = new DagToPag(((IndTestDSep) independenceTest).getGraph());
      dagToPag.setCompleteRuleSetUsed(completeRuleSetUsed);
      graph = dagToPag.convert();
    } else {
      Score score = new ScoredIndTest(independenceTest);
      GFci fci = new GFci(independenceTest, score);
      fci.setKnowledge(knowledge);
      fci.setCompleteRuleSetUsed(completeRuleSetUsed);
      fci.setMaxPathLength(params.getInt("maxReachablePathLength", -1));
      fci.setMaxDegree(params.getInt("maxIndegree"));
      fci.setFaithfulnessAssumed(params.getBoolean("faithfulnessAssumed", true));
      graph = fci.search();
    }

    if (getSourceGraph() != null) {
      GraphUtils.arrangeBySourceGraph(graph, getSourceGraph());
    } else if (knowledge.isDefaultToKnowledgeLayout()) {
      SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge);
    } else {
      GraphUtils.circleLayout(graph, 200, 200, 150);
    }

    setResultGraph(graph);
  }
Пример #12
0
  /**
   * Simulates 300 samples from a random single-factor model and recovers measurement clusters with
   * FOFC (BPC is selectable via {@code algorithm}). For Mimbuild variants 3 (Mimbuild2) and 4
   * (MimbuildTrek) it prints the estimated latent covariance, the estimated latent structure, and
   * that structure's structural Hamming distance (SHD) to the true latent structure.
   */
  public void test1() {
    for (int rep = 0; rep < 1; rep++) {
      Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 6, 0, 0, 0);
      Graph trueStructure = structure(mim);

      SemImInitializationParams imParams = new SemImInitializationParams();
      imParams.setCoefRange(.5, 1.5);

      SemIm im = new SemIm(new SemPm(mim), imParams);
      DataSet data = im.simulateData(300, false);

      String algorithm = "FOFC";
      Graph searchGraph;
      List<List<Node>> partition;

      switch (algorithm) {
        case "FOFC": {
          FindOneFactorClusters fofc =
              new FindOneFactorClusters(data, TestType.TETRAD_WISHART, 0.001);
          searchGraph = fofc.search();
          partition = fofc.getClusters();
          break;
        }
        case "BPC": {
          BuildPureClusters bpc =
              new BuildPureClusters(data, 0.001, TestType.TETRAD_WISHART, TestType.TETRAD_BASED2);
          searchGraph = bpc.search();
          partition = MimUtils.convertToClusters2(searchGraph);
          break;
        }
        default:
          throw new IllegalStateException();
      }

      List<String> latentVarList = reidentifyVariables(mim, data, partition, 2);

      System.out.println(partition);
      System.out.println(latentVarList);
      System.out.println("True\n" + trueStructure);

      for (int method : new int[] {3, 4}) {
        switch (method) {
          case 1: {
            System.out.println("Mimbuild 1\n");
            Clusters measurements = ClusterUtils.mimClusters(searchGraph);
            IndTestMimBuild test = new IndTestMimBuild(data, 0.001, measurements);
            MimBuild mimbuild = new MimBuild(test, new Knowledge2());
            Graph full = changeLatentNames(mimbuild.search(), measurements, latentVarList);
            Graph estimated = structure(full);
            System.out.println(
                "SHD = " + SearchGraphUtils.structuralHammingDistance(trueStructure, estimated));
            System.out.println("Estimated\n" + estimated);
            System.out.println();
            break;
          }
          case 3: {
            System.out.println("Mimbuild 3\n");
            Mimbuild2 mimbuild = new Mimbuild2();
            mimbuild.setAlpha(0.001);
            mimbuild.setMinClusterSize(3);
            Graph estimated = mimbuild.search(partition, latentVarList, new CovarianceMatrix(data));
            ICovarianceMatrix latentCov = mimbuild.getLatentsCov();
            System.out.println("\nCovariance over the latents");
            System.out.println(latentCov);
            System.out.println("Estimated\n" + estimated);
            System.out.println(
                "SHD = " + SearchGraphUtils.structuralHammingDistance(trueStructure, estimated));
            System.out.println();
            break;
          }
          case 4: {
            System.out.println("Mimbuild Trek\n");
            MimbuildTrek mimbuild = new MimbuildTrek();
            mimbuild.setAlpha(0.1);
            mimbuild.setMinClusterSize(3);
            Graph estimated = mimbuild.search(partition, latentVarList, new CovarianceMatrix(data));
            ICovarianceMatrix latentCov = mimbuild.getLatentsCov();
            System.out.println("\nCovariance over the latents");
            System.out.println(latentCov);
            System.out.println("Estimated\n" + estimated);
            System.out.println(
                "SHD = " + SearchGraphUtils.structuralHammingDistance(trueStructure, estimated));
            System.out.println();
            break;
          }
          default:
            throw new IllegalStateException();
        }
      }
    }
  }
Пример #13
0
  /**
   * Simulates 1000 samples from a random single-factor model and clusters the measures with FOFC
   * (BPC is selectable via {@code algorithm}). It then runs Mimbuild2 on the clusters and prints,
   * per repetition, the SHD between the true and estimated latent structures, Mimbuild's p-value,
   * and the number of clustered measures.
   *
   * <p>Repetitions in which exactly 40 measures were clustered are "flagged" and excluded from the
   * reported average SHD. The repetition with the highest p-value so far is also reported.
   */
  public void rtest4() {
    System.out.println("SHD\tP");

    Graph mim = DataGraphUtils.randomSingleFactorModel(5, 5, 8, 0, 0, 0);
    Graph mimStructure = structure(mim);

    SemPm pm = new SemPm(mim);
    SemImInitializationParams params = new SemImInitializationParams();
    params.setCoefRange(0.5, 1.5);

    NumberFormat nf = new DecimalFormat("0.0000");

    int totalError = 0;
    int errorCount = 0;

    int maxScore = 0;
    int maxNumMeasures = 0;
    double maxP = 0.0;

    for (int r = 0; r < 1; r++) {
      SemIm im = new SemIm(pm, params);
      DataSet data = im.simulateData(1000, false);

      mim = GraphUtils.replaceNodes(mim, data.getVariables());
      List<List<Node>> trueClusters = MimUtils.convertToClusters2(mim);

      CovarianceMatrix sampleCov = new CovarianceMatrix(data);
      ICovarianceMatrix cov = DataUtils.reorderColumns(sampleCov);

      String algorithm = "FOFC";
      Graph searchGraph;
      List<List<Node>> partition;

      if (algorithm.equals("FOFC")) {
        // Double literal: a float 0.001f would widen to 0.0010000000474974513.
        FindOneFactorClusters fofc =
            new FindOneFactorClusters(cov, TestType.TETRAD_WISHART, 0.001);
        searchGraph = fofc.search();
        searchGraph = GraphUtils.replaceNodes(searchGraph, data.getVariables());
        partition = MimUtils.convertToClusters2(searchGraph);
      } else if (algorithm.equals("BPC")) {
        TestType testType = TestType.TETRAD_WISHART;
        TestType purifyType = TestType.TETRAD_BASED2;

        BuildPureClusters bpc = new BuildPureClusters(data, 0.001, testType, purifyType);
        searchGraph = bpc.search();

        partition = MimUtils.convertToClusters2(searchGraph);
      } else {
        throw new IllegalStateException();
      }

      mimStructure = GraphUtils.replaceNodes(mimStructure, data.getVariables());

      List<String> latentVarList = reidentifyVariables(mim, data, partition, 2);

      Mimbuild2 mimbuild = new Mimbuild2();
      mimbuild.setAlpha(0.001);
      mimbuild.setMinClusterSize(3);

      Graph mimbuildStructure;

      try {
        mimbuildStructure = mimbuild.search(partition, latentVarList, cov);
      } catch (Exception e) {
        // Best effort: report the failed Mimbuild run and move on to the next repetition.
        e.printStackTrace();
        continue;
      }

      mimbuildStructure = GraphUtils.replaceNodes(mimbuildStructure, data.getVariables());
      mimbuildStructure = condense(mimStructure, mimbuildStructure);

      int shd = SearchGraphUtils.structuralHammingDistance(mimStructure, mimbuildStructure);
      double pValue = mimbuild.getpValue();
      int clusteredCount = numClustered(partition);

      // NOTE(review): 40 presumably equals the total number of measures generated above (5 x 8);
      // confirm against randomSingleFactorModel's parameter order.
      boolean flagged = clusteredCount == 40;

      if (!flagged) {
        totalError += shd;
        errorCount++;
      }

      if (pValue > maxP) {
        maxScore = shd;
        maxP = pValue;
        maxNumMeasures = clusteredCount;
        System.out.println("maxNumMeasures = " + maxNumMeasures);
        System.out.println("maxScore = " + maxScore);
        System.out.println("maxP = " + maxP);
        System.out.println("clusters = " + clusterSizes(partition, trueClusters));
      }

      System.out.println(shd + "\t" + nf.format(pValue) + " " + clusteredCount);
    }

    // Avoid printing NaN when every repetition was flagged or failed.
    String avgShd =
        errorCount == 0
            ? "n/a (no not-flagged cases)"
            : String.valueOf(totalError / (double) errorCount);
    System.out.println("\nAvg SHD for not-flagged cases = " + avgShd);

    System.out.println("maxNumMeasures = " + maxNumMeasures);
    System.out.println("maxScore = " + maxScore);
    System.out.println("maxP = " + maxP);
  }
Пример #14
0
  /**
   * Runs PC over the given list of nodes, using the configured independence test and knowledge,
   * and returns the resultant graph. The returned graph will be a pattern if the independence
   * information is consistent with the hypothesis that there are no latent common causes. It may,
   * however, contain cycles or bidirected edges if this assumption is not borne out, either because
   * latent common causes are actually present or because of statistical errors in conditional
   * independence judgments.
   *
   * <p>Steps: adjacency search with {@code Fas2} (seeded with {@code initialGraph}), and
   * background-knowledge orientation. Colliders are then added using sepsets chosen by {@code
   * SepsetsMaxPValue}, and the Meek rules propagate the orientations. When {@code trueDag} is set, a
   * fresh {@code IndTestDSep} over it is stored in {@code dsep} and handed to the sepset producer.
   * The result is stored in the {@code graph} field, and the run time (ms) in {@code elapsedTime}.
   *
   * <p>All of the given nodes must be in the domain of the given conditional independence test.
   *
   * @param nodes the nodes to search over
   * @return the oriented graph
   * @throws NullPointerException if no independence test has been set
   * @throws IllegalArgumentException if some node is not a variable of the independence test
   */
  public Graph search(List<Node> nodes) {
    // Validate before logging so a missing test fails fast with a clear message.
    if (getIndependenceTest() == null) {
      throw new NullPointerException("Independence test not set; cannot run PC.");
    }

    this.logger.log("info", "Starting PC algorithm");
    this.logger.log("info", "Independence test = " + getIndependenceTest() + ".");

    if (trueDag != null) {
      this.dsep = new IndTestDSep(trueDag);
    }

    long startTime = System.currentTimeMillis();

    List<Node> allNodes = getIndependenceTest().getVariables();
    if (!allNodes.containsAll(nodes)) {
      throw new IllegalArgumentException(
          "All of the given nodes must " + "be in the domain of the independence test provided.");
    }

    IFas fas = new Fas2(getIndependenceTest());
    fas.setInitialGraph(initialGraph);
    fas.setKnowledge(getKnowledge());
    fas.setDepth(getDepth());
    fas.setVerbose(verbose);
    graph = fas.search();

    SearchGraphUtils.pcOrientbk(knowledge, graph, nodes);

    // Use the same (validated) test as the adjacency search rather than reading the field directly.
    SepsetsMaxPValue sepsetProducer =
        new SepsetsMaxPValue(graph, getIndependenceTest(), null, getDepth());
    sepsetProducer.setDsep(dsep);

    addColliders(graph, sepsetProducer, knowledge);

    MeekRules rules = new MeekRules();
    rules.setKnowledge(knowledge);
    rules.orientImplied(graph);

    this.logger.log("graph", "\nReturning this graph: " + graph);

    this.elapsedTime = System.currentTimeMillis() - startTime;

    this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    this.logger.log("info", "Finishing PC Algorithm.");
    this.logger.flush();

    return graph;
  }
Пример #15
0
  /**
   * Runs a BFF search over the selected data model, choosing the beam or GES variant per {@code
   * getAlgorithmType()}. The search starts from {@code initialGraph}; when that is unset, it is
   * first initialised to an empty graph over the data's variables. Afterwards this method records
   * the original and new SEMs, fires a graph change, lays out the found graph, and publishes {@code
   * SearchGraphUtils.patternForDag} of it as the result graph.
   *
   * @throws IllegalArgumentException if GES is requested on a covariance matrix
   * @throws IllegalStateException if the data model or algorithm type is not supported
   */
  public void execute() {
    DataModel dataModel = (DataModel) dataWrapper.getSelectedDataModel();
    IKnowledge knowledge = params2.getKnowledge();

    if (initialGraph == null) {
      initialGraph = new EdgeListGraph(dataModel.getVariables());
    }

    // Copy the start graph and rebind it to the data model's own variable objects.
    Graph startGraph =
        GraphUtils.replaceNodes(new EdgeListGraph(initialGraph), dataModel.getVariables());

    Bff search = createBffSearch(dataModel, startGraph, knowledge);

    PcIndTestParams testParams = (PcIndTestParams) getParams().getIndTestParams();
    search.setAlpha(testParams.getAlpha());
    search.setBeamWidth(testParams.getBeamWidth());
    search.setHighPValueAlpha(testParams.getZeroEdgeP());
    this.graph = search.search();

    setOriginalSemIm(search.getOriginalSemIm());
    this.newSemIm = search.getNewSemIm();
    fireGraphChange(graph);

    if (getSourceGraph() != null) {
      GraphUtils.arrangeBySourceGraph(graph, getSourceGraph());
    } else if (knowledge.isDefaultToKnowledgeLayout()) {
      SearchGraphUtils.arrangeByKnowledgeTiers(graph, knowledge);
    } else {
      GraphUtils.circleLayout(graph, 200, 200, 150);
    }

    setResultGraph(SearchGraphUtils.patternForDag(graph, knowledge));
  }

  /**
   * Builds the BFF search for {@code dataModel}: {@code BffBeam} for the beam algorithm (data set or
   * covariance matrix), or {@code BffGes} with knowledge for GES (data set only).
   */
  private Bff createBffSearch(DataModel dataModel, Graph startGraph, IKnowledge knowledge) {
    if (dataModel instanceof DataSet) {
      DataSet dataSet = (DataSet) dataModel;

      if (getAlgorithmType() == AlgorithmType.BEAM) {
        return new BffBeam(startGraph, dataSet, knowledge);
      }
      if (getAlgorithmType() == AlgorithmType.GES) {
        Bff ges = new BffGes(startGraph, dataSet);
        ges.setKnowledge(knowledge);
        return ges;
      }
      throw new IllegalStateException();
    }

    if (dataModel instanceof CovarianceMatrix) {
      CovarianceMatrix covarianceMatrix = (CovarianceMatrix) dataModel;

      if (getAlgorithmType() == AlgorithmType.BEAM) {
        return new BffBeam(startGraph, covarianceMatrix, knowledge);
      }
      if (getAlgorithmType() == AlgorithmType.GES) {
        throw new IllegalArgumentException(
            "GES method requires a dataset; a covariance matrix was provided.");
      }
      throw new IllegalStateException();
    }

    throw new IllegalStateException();
  }
Пример #16
0
  /**
   * Compares PC-Stable under three mixed-data independence tests on the bundled "test_data"
   * example. The tests are multinomial logistic regression, Wald with a linear flag of true, and
   * Wald with that flag false. The method loads the true graph ("DAG_0_graph.txt", converted with
   * {@code patternFromDag}) and the data set ("DAG_0_data.txt"). It then runs each search, prints
   * its wall-clock time, and prints edge statistics of each result against the true graph. Any
   * failure is printed as a stack trace rather than propagated.
   *
   * @param args unused
   */
  public static void main(String[] args) {
    try {
      java.net.URL testData = ExampleMixedSearch.class.getResource("test_data");
      if (testData == null) {
        throw new IllegalStateException("Resource 'test_data' not found on the classpath.");
      }
      // URL.getPath() keeps percent-escapes (e.g. %20 for a space); converting via the URI decodes
      // them so the directory can be opened as a file.
      String path = new File(testData.toURI()).getPath();

      Graph trueGraph =
          SearchGraphUtils.patternFromDag(
              GraphUtils.loadGraphTxt(new File(path, "DAG_0_graph.txt")));
      DataSet ds = MixedUtils.loadDataSet(path, "DAG_0_data.txt");

      IndTestMultinomialLogisticRegression indMix =
          new IndTestMultinomialLogisticRegression(ds, .05);
      IndTestMultinomialLogisticRegressionWald indWalLin =
          new IndTestMultinomialLogisticRegressionWald(ds, .05, true);
      IndTestMultinomialLogisticRegressionWald indWalLog =
          new IndTestMultinomialLogisticRegressionWald(ds, .05, false);

      PcStable s1 = new PcStable(indMix);
      PcStable s2 = new PcStable(indWalLin);
      PcStable s3 = new PcStable(indWalLog);

      long time = System.currentTimeMillis();
      Graph g1 = SearchGraphUtils.patternFromDag(s1.search());
      System.out.println("Mix Time " + ((System.currentTimeMillis() - time) / 1000.0));

      time = System.currentTimeMillis();
      Graph g2 = SearchGraphUtils.patternFromDag(s2.search());
      System.out.println("Wald lin Time " + ((System.currentTimeMillis() - time) / 1000.0));

      time = System.currentTimeMillis();
      Graph g3 = SearchGraphUtils.patternFromDag(s3.search());
      System.out.println("Wald log Time " + ((System.currentTimeMillis() - time) / 1000.0));

      System.out.println(MixedUtils.EdgeStatHeader);
      System.out.println(MixedUtils.stringFrom2dArray(MixedUtils.allEdgeStats(trueGraph, g1)));
      System.out.println(MixedUtils.stringFrom2dArray(MixedUtils.allEdgeStats(trueGraph, g2)));
      System.out.println(MixedUtils.stringFrom2dArray(MixedUtils.allEdgeStats(trueGraph, g3)));
    } catch (Throwable t) {
      // Example entry point: report any failure (including a missing resource) and exit normally.
      t.printStackTrace();
    }
  }