예제 #1
0
  /**
   * Completes a pattern that was modified by an insertion/deletion operator Based on the algorithm
   * described on Appendix C of (Chickering, 2002).
   */
  private void rebuildPattern(Graph graph) {
    SearchGraphUtils.basicPattern(graph, false);
    addRequiredEdges(graph);
    meekOrient(graph, getKnowledge());

    if (TetradLogger.getInstance().isEventActive("rebuiltPatterns")) {
      TetradLogger.getInstance().log("rebuiltPatterns", "Rebuilt pattern = " + graph);
    }
  }
예제 #2
0
파일: Lofs.java 프로젝트: jdramsey/tetrad
  private void ruleR1(Graph skeleton, Graph graph, List<Node> nodes) {
    for (Node node : nodes) {
      SortedMap<Double, String> scoreReports = new TreeMap<Double, String>();

      List<Node> adj = skeleton.getAdjacentNodes(node);

      DepthChoiceGenerator gen = new DepthChoiceGenerator(adj.size(), adj.size());
      int[] choice;
      double maxScore = Double.NEGATIVE_INFINITY;
      List<Node> parents = null;

      while ((choice = gen.next()) != null) {
        List<Node> _parents = GraphUtils.asList(choice, adj);

        double score = score(node, _parents);
        scoreReports.put(-score, _parents.toString());

        if (score > maxScore) {
          maxScore = score;
          parents = _parents;
        }
      }

      for (double score : scoreReports.keySet()) {
        TetradLogger.getInstance()
            .log(
                "score",
                "For " + node + " parents = " + scoreReports.get(score) + " score = " + -score);
      }

      TetradLogger.getInstance().log("score", "");

      if (parents == null) {
        continue;
      }

      if (normal(node, parents)) continue;

      for (Node _node : adj) {
        if (parents.contains(_node)) {
          Edge parentEdge = Edges.directedEdge(_node, node);

          if (!graph.containsEdge(parentEdge)) {
            graph.addEdge(parentEdge);
          }
        }
      }
    }
  }
예제 #3
0
  private void bes(Graph graph) {
    TetradLogger.getInstance().log("info", "** BACKWARD EQUIVALENCE SEARCH");

    initializeArrowsBackward(graph);

    while (!sortedArrows.isEmpty()) {
      Arrow arrow = sortedArrows.first();
      sortedArrows.remove(arrow);

      Node x = arrow.getX();
      Node y = arrow.getY();

      clearArrow(x, y);

      if (!validDelete(arrow.getHOrT(), arrow.getNaYX(), graph)) {
        continue;
      }

      List<Node> h = arrow.getHOrT();
      double bump = arrow.getBump();

      delete(x, y, h, graph, bump);
      score += bump;
      rebuildPattern(graph);

      storeGraph(graph);

      initializeArrowsBackward(
          graph); // Rebuilds Arrows from scratch each time. Fast enough for backwards.
    }
  }
예제 #4
0
  /**
   * Forward equivalence search.
   *
   * @param graph The graph in the state prior to the forward equivalence search.
   */
  private void fes(Graph graph, List<Node> nodes) {
    TetradLogger.getInstance().log("info", "** FORWARD EQUIVALENCE SEARCH");

    lookupArrows = new HashMap<OrderedPair, Set<Arrow>>();

    initializeArrowsForward(nodes);

    while (!sortedArrows.isEmpty()) {
      Arrow arrow = sortedArrows.first();
      sortedArrows.remove(arrow);

      Node x = arrow.getX();
      Node y = arrow.getY();

      clearArrow(x, y);

      if (graph.isAdjacentTo(x, y)) {
        continue;
      }

      if (!validInsert(x, y, arrow.getHOrT(), arrow.getNaYX(), graph)) {
        continue;
      }

      List<Node> t = arrow.getHOrT();
      double bump = arrow.getBump();

      Set<Edge> edges = graph.getEdges();

      insert(x, y, t, graph, bump);
      score += bump;
      rebuildPattern(graph);

      // Try to avoid duplicating scoring calls. First clear out all of the edges that need to be
      // changed,
      // then change them, checking to see if they're already been changed. I know, roundabout, but
      // there's
      // a performance boost.
      for (Edge edge : graph.getEdges()) {
        if (!edges.contains(edge)) {
          reevaluateForward(graph, nodes, edge.getNode1(), edge.getNode2());
        }
      }

      storeGraph(graph);
    }
  }
예제 #5
0
  private void addRequiredEdges(Graph graph) {
    if (true) return;
    if (knowledgeEmpty()) return;

    for (Iterator<KnowledgeEdge> it = getKnowledge().requiredEdgesIterator(); it.hasNext(); ) {
      KnowledgeEdge next = it.next();

      Node nodeA = graph.getNode(next.getFrom());
      Node nodeB = graph.getNode(next.getTo());

      if (!graph.isAncestorOf(nodeB, nodeA)) {
        graph.removeEdges(nodeA, nodeB);
        graph.addDirectedEdge(nodeA, nodeB);
        TetradLogger.getInstance()
            .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB));
      }
    }
    for (Edge edge : graph.getEdges()) {
      final String A = edge.getNode1().getName();
      final String B = edge.getNode2().getName();

      if (knowledge.isForbidden(A, B)) {
        Node nodeA = edge.getNode1();
        Node nodeB = edge.getNode2();

        if (nodeA != null
            && nodeB != null
            && graph.isAdjacentTo(nodeA, nodeB)
            && !graph.isChildOf(nodeA, nodeB)) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
        if (!graph.isChildOf(nodeA, nodeB)
            && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
      } else if (knowledge.isForbidden(B, A)) {
        Node nodeA = edge.getNode2();
        Node nodeB = edge.getNode1();

        if (nodeA != null
            && nodeB != null
            && graph.isAdjacentTo(nodeA, nodeB)
            && !graph.isChildOf(nodeA, nodeB)) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
        if (!graph.isChildOf(nodeA, nodeB)
            && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
      }
    }
  }
예제 #6
0
  /** Do an actual deletion (Definition 13 from Chickering, 2002). */
  private void delete(Node x, Node y, List<Node> subset, Graph graph, double bump) {

    Edge trueEdge = null;

    if (trueGraph != null) {
      Node _x = trueGraph.getNode(x.getName());
      Node _y = trueGraph.getNode(y.getName());
      trueEdge = trueGraph.getEdge(_x, _y);
    }

    if (log && verbose) {
      Edge oldEdge = graph.getEdge(x, y);

      String label = trueGraph != null && trueEdge != null ? "*" : "";
      TetradLogger.getInstance()
          .log(
              "deletedEdges",
              (graph.getNumEdges() - 1)
                  + ". DELETE "
                  + oldEdge
                  + " "
                  + subset
                  + " ("
                  + bump
                  + ") "
                  + label);
      out.println(
          (graph.getNumEdges() - 1)
              + ". DELETE "
              + oldEdge
              + " "
              + subset
              + " ("
              + bump
              + ") "
              + label);
    } else {
      int numEdges = graph.getNumEdges() - 1;
      if (numEdges % 50 == 0) out.println(numEdges);
    }

    graph.removeEdge(x, y);

    for (Node h : subset) {
      Edge oldEdge = graph.getEdge(y, h);

      graph.removeEdge(y, h);
      graph.addDirectedEdge(y, h);

      if (log) {
        TetradLogger.getInstance()
            .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(y, h));
      }

      if (verbose) {
        out.println("--- Directing " + oldEdge + " to " + graph.getEdge(y, h));
      }

      if (Edges.isUndirectedEdge(graph.getEdge(x, h))) {
        if (!graph.isAdjacentTo(x, h))
          throw new IllegalArgumentException("Not adjacent: " + x + ", " + h);
        oldEdge = graph.getEdge(x, h);

        graph.removeEdge(x, h);
        graph.addDirectedEdge(x, h);

        if (log) {
          TetradLogger.getInstance()
              .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(x, h));
        }

        if (verbose) {
          out.println("--- Directing " + oldEdge + " to " + graph.getEdge(x, h));
        }
      }
    }
  }
예제 #7
0
  // serial.
  private void insert(Node x, Node y, List<Node> t, Graph graph, double bump) {
    if (graph.isAdjacentTo(x, y)) {
      return; // The initial graph may already have put this edge in the graph.
      //            throw new IllegalArgumentException(x + " and " + y + " are already adjacent in
      // the graph.");
    }

    Edge trueEdge = null;

    if (trueGraph != null) {
      Node _x = trueGraph.getNode(x.getName());
      Node _y = trueGraph.getNode(y.getName());
      trueEdge = trueGraph.getEdge(_x, _y);
    }

    graph.addDirectedEdge(x, y);

    if (log) {
      String label = trueGraph != null && trueEdge != null ? "*" : "";
      TetradLogger.getInstance()
          .log(
              "insertedEdges",
              graph.getNumEdges()
                  + ". INSERT "
                  + graph.getEdge(x, y)
                  + " "
                  + t
                  + " "
                  + bump
                  + " "
                  + label);
    } else {
      int numEdges = graph.getNumEdges() - 1;
      if (verbose) {
        if (numEdges % 50 == 0) out.println(numEdges);
      }
    }

    if (verbose) {
      String label = trueGraph != null && trueEdge != null ? "*" : "";
      out.println(
          graph.getNumEdges()
              + ". INSERT "
              + graph.getEdge(x, y)
              + " "
              + t
              + " "
              + bump
              + " "
              + label);
    } else {
      int numEdges = graph.getNumEdges() - 1;
      if (verbose) {
        if (numEdges % 50 == 0) out.println(numEdges);
      }
    }

    for (Node _t : t) {
      Edge oldEdge = graph.getEdge(_t, y);

      if (oldEdge == null) throw new IllegalArgumentException("Not adjacent: " + _t + ", " + y);

      graph.removeEdge(_t, y);
      graph.addDirectedEdge(_t, y);

      if (log && verbose) {
        TetradLogger.getInstance()
            .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(_t, y));
        out.println("--- Directing " + oldEdge + " to " + graph.getEdge(_t, y));
      }
    }
  }
예제 #8
0
/**
 * GesSearch is an implentation of the GES algorithm, as specified in Chickering (2002) "Optimal
 * structure identification with greedy search" Journal of Machine Learning Research. It works for
 * both BayesNets and SEMs.
 *
 * <p>Some code optimization could be done for the scoring part of the graph for discrete models
 * (method scoreGraphChange). Some of Andrew Moore's approaches for caching sufficient statistics,
 * for instance.
 *
 * @author Ricardo Silva, Summer 2003
 * @author Joseph Ramsey, Revisions 10/2005
 */
public final class GesConcurrent implements GraphSearch, GraphScorer {

  /** The data set, various variable subsets of which are to be scored. */
  private DataSet dataSet;

  /** The covariance matrix for the data set. */
  private ICovarianceMatrix covariances;

  /** Sample size, either from the data set or from the variances. */
  private int sampleSize;

  /** Specification of forbidden and required edges. */
  private IKnowledge knowledge = new Knowledge2();

  /** Map from variables to their column indices in the data set. */
  private HashMap<Node, Integer> hashIndices;

  /** List of variables in the data set, in order. */
  private List<Node> variables;

  /** True iff the data set is discrete. */
  private boolean discrete;

  /**
   * The true graph, if known. If this is provided, asterisks will be printed out next to false
   * positive added edges (that is, edges added that aren't adjacencies in the true graph).
   */
  private Graph trueGraph;

  /** An initial graph to start from. */
  private Graph initialGraph;

  /** Caches scores for discrete search. */
  private final LocalScoreCache localScoreCache = new LocalScoreCache();

  /** Elapsed time of the most recent search. */
  private long elapsedTime;

  /**
   * True if cycles are to be aggressively prevented. May be expensive for large graphs (but also
   * useful for large graphs).
   */
  private boolean aggressivelyPreventCycles = false;

  /** Listeners for graph change events. */
  private transient List<PropertyChangeListener> listeners;

  /** Penalty discount--the BIC penalty is multiplied by this (for continuous variables). */
  private double penaltyDiscount = 1.0;

  /** The score for discrete searches. */
  private LocalDiscreteScore discreteScore;

  /** The logger for this class. The config needs to be set. */
  private TetradLogger logger = TetradLogger.getInstance();

  /** The top n graphs found by the algorithm, where n is numPatternsToStore. */
  private SortedSet<ScoredGraph> topGraphs = new TreeSet<ScoredGraph>();

  /** The number of top patterns to store. */
  private int numPatternsToStore = 0;

  // Potential arrows sorted by bump high to low. The first one is a candidate for adding to the
  // graph.
  private SortedSet<Arrow> sortedArrows = new ConcurrentSkipListSet<Arrow>();

  // Arrows added to sortedArrows for each <i, j>.
  private Map<OrderedPair, Set<Arrow>> lookupArrows;

  /** True if graphs should be stored. */
  private boolean log = true;

  private boolean verbose = false;

  private int NTHREADS = Runtime.getRuntime().availableProcessors() * 5;
  //    private boolean checkedKnowledgeEmpty = false;
  //    private boolean knowledgeEmpty = false;
  private ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool();
  private double score;
  private PrintStream out;

  // ===========================CONSTRUCTORS=============================//

  public GesConcurrent(DataSet dataSet) {
    setDataSet(dataSet);
    if (dataSet.isDiscrete()) {
      BDeuScore score = new BDeuScore(dataSet);
      score.setSamplePrior(10);
      score.setStructurePrior(0.001);
    }
    setStructurePrior(0.001);
    setSamplePrior(10.);
  }

  public GesConcurrent(ICovarianceMatrix covMatrix) {
    setCovMatrix(covMatrix);
    setStructurePrior(0.001);
    setSamplePrior(10.);
  }

  // ==========================PUBLIC METHODS==========================//

  public boolean isAggressivelyPreventCycles() {
    return this.aggressivelyPreventCycles;
  }

  public void setAggressivelyPreventCycles(boolean aggressivelyPreventCycles) {
    this.aggressivelyPreventCycles = aggressivelyPreventCycles;
  }

  /**
   * Greedy equivalence search: Start from the empty graph, add edges till model is significant.
   * Then start deleting edges till a minimum is achieved.
   *
   * @return the resulting Pattern.
   */
  public Graph search() {

    Graph graph;

    if (initialGraph == null) {
      graph = new EdgeListGraphSingleConnections(getVariables());
    } else {
      graph = new EdgeListGraphSingleConnections(initialGraph);
    }

    fireGraphChange(graph);
    buildIndexing(graph);
    addRequiredEdges(graph);

    topGraphs.clear();

    storeGraph(graph);

    List<Node> nodes = graph.getNodes();

    long start = System.currentTimeMillis();
    score = 0.0;

    // Do forward search.
    fes(graph, nodes);

    // Do backward search.
    bes(graph);

    long endTime = System.currentTimeMillis();
    this.elapsedTime = endTime - start;
    this.logger.log("graph", "\nReturning this graph: " + graph);

    this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    this.logger.flush();

    return graph;
  }

  public Graph search(List<Node> nodes) {
    long startTime = System.currentTimeMillis();
    localScoreCache.clear();

    if (!dataSet().getVariables().containsAll(nodes)) {
      throw new IllegalArgumentException("All of the nodes must be in " + "the supplied data set.");
    }

    Graph graph;

    if (initialGraph == null) {
      graph = new EdgeListGraphSingleConnections(nodes);
    } else {
      initialGraph = GraphUtils.replaceNodes(initialGraph, variables);
      graph = new EdgeListGraphSingleConnections(initialGraph);
    }

    topGraphs.clear();

    buildIndexing(graph);
    addRequiredEdges(graph);
    score = 0.0;

    // Do forward search.
    fes(graph, nodes);

    // Do backward search.
    bes(graph);

    long endTime = System.currentTimeMillis();
    this.elapsedTime = endTime - startTime;
    this.logger.log("graph", "\nReturning this graph: " + graph);

    this.logger.log("info", "Elapsed time = " + (elapsedTime) / 1000. + " s");
    this.logger.flush();

    return graph;
  }

  public IKnowledge getKnowledge() {
    return knowledge;
  }

  /**
   * Sets the background knowledge.
   *
   * @param knowledge the knowledge object, specifying forbidden and required edges.
   */
  public void setKnowledge(IKnowledge knowledge) {
    if (knowledge == null) throw new NullPointerException();
    this.knowledge = knowledge;
  }

  public void setStructurePrior(double structurePrior) {
    if (getDiscreteScore() != null) {
      getDiscreteScore().setStructurePrior(structurePrior);
    }
  }

  public void setSamplePrior(double samplePrior) {
    if (getDiscreteScore() != null) {
      getDiscreteScore().setSamplePrior(samplePrior);
    }
  }

  public long getElapsedTime() {
    return elapsedTime;
  }

  public void addPropertyChangeListener(PropertyChangeListener l) {
    getListeners().add(l);
  }

  public double getPenaltyDiscount() {
    return penaltyDiscount;
  }

  public void setPenaltyDiscount(double penaltyDiscount) {
    if (penaltyDiscount < 0) {
      throw new IllegalArgumentException("Penalty discount must be >= 0: " + penaltyDiscount);
    }

    this.penaltyDiscount = penaltyDiscount;
  }

  public void setTrueGraph(Graph trueGraph) {
    this.trueGraph = trueGraph;
  }

  public double getScore(Graph dag) {
    return scoreDag(dag);
  }

  public SortedSet<ScoredGraph> getTopGraphs() {
    return topGraphs;
  }

  public int getNumPatternsToStore() {
    return numPatternsToStore;
  }

  public void setNumPatternsToStore(int numPatternsToStore) {
    if (numPatternsToStore < 0) {
      throw new IllegalArgumentException(
          "# graphs to store must at least 0: " + numPatternsToStore);
    }

    this.numPatternsToStore = numPatternsToStore;
  }

  public LocalDiscreteScore getDiscreteScore() {
    return discreteScore;
  }

  public void setDiscreteScore(LocalDiscreteScore discreteScore) {
    if (discreteScore.getDataSet() != dataSet) {
      throw new IllegalArgumentException("Must use the same data set.");
    }
    this.discreteScore = discreteScore;
  }

  // ===========================PRIVATE METHODS========================//

  /**
   * Forward equivalence search.
   *
   * @param graph The graph in the state prior to the forward equivalence search.
   */
  private void fes(Graph graph, List<Node> nodes) {
    TetradLogger.getInstance().log("info", "** FORWARD EQUIVALENCE SEARCH");

    lookupArrows = new HashMap<OrderedPair, Set<Arrow>>();

    initializeArrowsForward(nodes);

    while (!sortedArrows.isEmpty()) {
      Arrow arrow = sortedArrows.first();
      sortedArrows.remove(arrow);

      Node x = arrow.getX();
      Node y = arrow.getY();

      clearArrow(x, y);

      if (graph.isAdjacentTo(x, y)) {
        continue;
      }

      if (!validInsert(x, y, arrow.getHOrT(), arrow.getNaYX(), graph)) {
        continue;
      }

      List<Node> t = arrow.getHOrT();
      double bump = arrow.getBump();

      Set<Edge> edges = graph.getEdges();

      insert(x, y, t, graph, bump);
      score += bump;
      rebuildPattern(graph);

      // Try to avoid duplicating scoring calls. First clear out all of the edges that need to be
      // changed,
      // then change them, checking to see if they're already been changed. I know, roundabout, but
      // there's
      // a performance boost.
      for (Edge edge : graph.getEdges()) {
        if (!edges.contains(edge)) {
          reevaluateForward(graph, nodes, edge.getNode1(), edge.getNode2());
        }
      }

      storeGraph(graph);
    }
  }

  private void bes(Graph graph) {
    TetradLogger.getInstance().log("info", "** BACKWARD EQUIVALENCE SEARCH");

    initializeArrowsBackward(graph);

    while (!sortedArrows.isEmpty()) {
      Arrow arrow = sortedArrows.first();
      sortedArrows.remove(arrow);

      Node x = arrow.getX();
      Node y = arrow.getY();

      clearArrow(x, y);

      if (!validDelete(arrow.getHOrT(), arrow.getNaYX(), graph)) {
        continue;
      }

      List<Node> h = arrow.getHOrT();
      double bump = arrow.getBump();

      delete(x, y, h, graph, bump);
      score += bump;
      rebuildPattern(graph);

      storeGraph(graph);

      initializeArrowsBackward(
          graph); // Rebuilds Arrows from scratch each time. Fast enough for backwards.
    }
  }

  // Expensive
  // Concurrent.
  private void initializeArrowsForward(final List<Node> nodes) {
    final List<Node> emptyList = new ArrayList<Node>(0);
    final Set<Node> emptySet = new HashSet<Node>(0);
    List<Callable<Boolean>> callables = new ArrayList<Callable<Boolean>>();

    for (int t = 0; t < NTHREADS; t++) {
      final int _t = t;

      Callable<Boolean> worker =
          new Callable<Boolean>() {
            @Override
            public Boolean call() throws Exception {
              int chunk = nodes.size() / NTHREADS + 1;

              for (int j = _t * chunk; j < Math.min((_t + 1) * chunk, nodes.size()); j++) {
                if (log && verbose) {
                  if ((j + 1) % 10 == 0) out.println("Initializing arrows forward: " + (j + 1));
                }

                for (int i = j + 1; i < nodes.size(); i++) {
                  Node x = nodes.get(i);
                  Node y = nodes.get(j);

                  if (!knowledgeEmpty()) {
                    if (getKnowledge().isForbidden(x.getName(), y.getName())) {
                      continue;
                    }

                    if (!validSetByKnowledge(y, emptyList)) {
                      continue;
                    }
                  }

                  double bump = scoreGraphChange(y, Collections.singleton(x), emptySet);

                  if (bump > 0.0) {
                    Arrow arrow = new Arrow(bump, x, y, emptyList, emptyList);
                    sortedArrows.add(arrow);
                    addLookupArrow(x, y, arrow);

                    Arrow arrow2 = new Arrow(bump, y, x, emptyList, emptyList);
                    sortedArrows.add(arrow2);
                    addLookupArrow(y, x, arrow);
                  }
                }
              }

              return true;
            }
          };

      callables.add(worker);
    }

    pool.invokeAll(callables);
  }

  private boolean knowledgeEmpty() {
    //        if (!checkedKnowledgeEmpty) {
    //            knowledgeEmpty = knowledge.isEmpty();
    //            checkedKnowledgeEmpty = true;
    //        }
    //
    //        return knowledgeEmpty;

    //        return knowledge.isEmpty();
    return false;
  }

  private void initializeArrowsBackward(Graph graph) {
    sortedArrows.clear();
    lookupArrows.clear();

    for (Edge edge : graph.getEdges()) {
      Node x = edge.getNode1();
      Node y = edge.getNode2();

      if (!knowledgeEmpty()) {
        if (!getKnowledge().noEdgeRequired(x.getName(), y.getName())) {
          continue;
        }
      }

      if (Edges.isDirectedEdge(edge)) {
        calculateArrowsBackward(x, y, graph);
      } else {
        calculateArrowsBackward(x, y, graph);
        calculateArrowsBackward(y, x, graph);
      }
    }
  }

  private void reevaluateForward(
      final Graph graph, final List<Node> nodes, final Node x, final Node y) {
    List<Callable<Boolean>> callables = new ArrayList<Callable<Boolean>>();

    for (int t = 0; t < NTHREADS; t++) {
      final int _t = t;

      Callable<Boolean> worker =
          new Callable<Boolean>() {
            @Override
            public Boolean call() {
              int chunk = nodes.size() / NTHREADS + 1;
              for (int _w = _t * chunk; _w < Math.min((_t + 1) * chunk, nodes.size()); _w++) {
                final Node w = nodes.get(_w);

                if (w == x) continue;
                if (w == y) continue;

                if (!graph.isAdjacentTo(w, x)) {
                  calculateArrowsForward(w, x, graph);

                  if (graph.isAdjacentTo(w, y)) {
                    calculateArrowsForward(x, w, graph);
                  }
                }

                if (!graph.isAdjacentTo(w, y)) {
                  calculateArrowsForward(w, y, graph);

                  if (graph.isAdjacentTo(w, x)) {
                    calculateArrowsForward(y, w, graph);
                  }
                }
              }

              return true;
            }
          };

      callables.add(worker);
    }

    pool.invokeAll(callables);
  }

  private void calculateArrowsForward(Node x, Node y, Graph graph) {
    clearArrow(x, y);

    if (!knowledgeEmpty()) {
      if (getKnowledge().isForbidden(x.getName(), y.getName())) {
        return;
      }
    }

    List<Node> naYX = getNaYX(x, y, graph);
    List<Node> t = getTNeighbors(x, y, graph);

    DepthChoiceGenerator gen = new DepthChoiceGenerator(t.size(), t.size());
    int[] choice;

    while ((choice = gen.next()) != null) {
      List<Node> s = GraphUtils.asList(choice, t);

      if (!knowledgeEmpty()) {
        if (!validSetByKnowledge(y, s)) {
          continue;
        }
      }

      double bump = insertEval(x, y, s, naYX, graph);

      if (bump > 0.0) {
        Arrow arrow = new Arrow(bump, x, y, s, naYX);
        sortedArrows.add(arrow);
        addLookupArrow(x, y, arrow);
      }
    }
  }

  // Invalid if then nodes or graph changes.
  private void calculateArrowsBackward(Node x, Node y, Graph graph) {
    if (x == y) {
      return;
    }

    if (!graph.isAdjacentTo(x, y)) {
      return;
    }

    if (!knowledgeEmpty()) {
      if (!getKnowledge().noEdgeRequired(x.getName(), y.getName())) {
        return;
      }
    }

    List<Node> naYX = getNaYX(x, y, graph);

    clearArrow(x, y);

    List<Node> _naYX = new ArrayList<Node>(naYX);
    DepthChoiceGenerator gen = new DepthChoiceGenerator(_naYX.size(), _naYX.size());
    int[] choice;

    while ((choice = gen.next()) != null) {
      List<Node> H = GraphUtils.asList(choice, _naYX);

      if (!knowledgeEmpty()) {
        if (!validSetByKnowledge(y, H)) {
          continue;
        }
      }

      double bump = deleteEval(x, y, H, naYX, graph);

      if (bump > 0.0) {
        Arrow arrow = new Arrow(bump, x, y, H, naYX);
        sortedArrows.add(arrow);
        addLookupArrow(x, y, arrow);
      }
    }
  }

  /** True iff log output should be produced. */
  public boolean isLog() {
    return log;
  }

  public void setLog(boolean log) {
    this.log = log;
  }

  public Graph getInitialGraph() {
    return initialGraph;
  }

  // Cannot be done if the graph changes.
  public void setInitialGraph(Graph initialGraph) {
    initialGraph = GraphUtils.replaceNodes(initialGraph, variables);

    out.println("Initial graph variables: " + initialGraph.getNodes());
    out.println("Data set variables: " + variables);

    if (!new HashSet<Node>(initialGraph.getNodes()).equals(new HashSet<Node>(variables))) {
      throw new IllegalArgumentException("Variables aren't the same.");
    }

    this.initialGraph = initialGraph;
  }

  public void setVerbose(boolean verbose) {
    this.verbose = verbose;
  }

  public void setOut(PrintStream out) {
    this.out = out;
  }

  public PrintStream getOut() {
    return out;
  }

  // Concurrent OK.
  private static class Arrow implements Comparable {
    private double bump;
    private Node x;
    private Node y;
    private List<Node> hOrT;
    private List<Node> naYX;

    public Arrow(double bump, Node x, Node y, List<Node> hOrT, List<Node> naYX) {
      this.bump = bump;
      this.x = x;
      this.y = y;
      this.hOrT = hOrT;
      this.naYX = naYX;
    }

    public double getBump() {
      return bump;
    }

    public Node getX() {
      return x;
    }

    public Node getY() {
      return y;
    }

    public List<Node> getHOrT() {
      return hOrT;
    }

    public List<Node> getNaYX() {
      return naYX;
    }

    // Sorting is by bump, high to low.
    public int compareTo(Object o) {
      Arrow arrow = (Arrow) o;
      return Double.compare(arrow.getBump(), getBump());
    }

    public String toString() {
      return "Arrow<"
          + x
          + "->"
          + y
          + " bump = "
          + bump
          + " t/h = "
          + hOrT
          + " naYX = "
          + naYX
          + ">";
    }
  }

  /** Get all nodes that are connected to Y by an undirected edge and not adjacent to X. */
  private static List<Node> getTNeighbors(Node x, Node y, Graph graph) {
    List<Edge> yEdges = graph.getEdges(y);
    List<Node> tNeighbors = new ArrayList<Node>();

    for (Edge edge : yEdges) {
      if (!Edges.isUndirectedEdge(edge)) {
        continue;
      }

      Node z = edge.getDistalNode(y);

      if (graph.isAdjacentTo(z, x)) {
        continue;
      }

      tNeighbors.add(z);
    }

    return tNeighbors;
  }

  /** Evaluate the Insert(X, Y, T) operator (Definition 12 from Chickering, 2002). */
  private double insertEval(Node x, Node y, List<Node> t, List<Node> naYX, Graph graph) {
    Set<Node> set1 = new HashSet<Node>(naYX);
    set1.addAll(t);
    List<Node> paY = graph.getParents(y);
    set1.addAll(paY);
    Set<Node> set2 = new HashSet<Node>(set1);
    set1.add(x);

    return scoreGraphChange(y, set1, set2);
  }

  /** Evaluate the Delete(X, Y, T) operator (Definition 12 from Chickering, 2002). */
  // Can be done concurrently.
  private double deleteEval(Node x, Node y, List<Node> h, List<Node> naYX, Graph graph) {
    List<Node> paY = graph.getParents(y);
    Set<Node> paYMinuxX = new HashSet<Node>(paY);
    paYMinuxX.remove(x);

    Set<Node> set1 = new HashSet<Node>(naYX);
    set1.removeAll(h);
    set1.addAll(paYMinuxX);

    Set<Node> set2 = new HashSet<Node>(naYX);
    set2.removeAll(h);
    set2.addAll(paY);

    return scoreGraphChange(y, set1, set2);
  }

  /*
   * Do an actual insertion
   * (Definition 12 from Chickering, 2002).
   **/
  // serial.
  private void insert(Node x, Node y, List<Node> t, Graph graph, double bump) {
    if (graph.isAdjacentTo(x, y)) {
      return; // The initial graph may already have put this edge in the graph.
      //            throw new IllegalArgumentException(x + " and " + y + " are already adjacent in
      // the graph.");
    }

    Edge trueEdge = null;

    if (trueGraph != null) {
      Node _x = trueGraph.getNode(x.getName());
      Node _y = trueGraph.getNode(y.getName());
      trueEdge = trueGraph.getEdge(_x, _y);
    }

    graph.addDirectedEdge(x, y);

    if (log) {
      String label = trueGraph != null && trueEdge != null ? "*" : "";
      TetradLogger.getInstance()
          .log(
              "insertedEdges",
              graph.getNumEdges()
                  + ". INSERT "
                  + graph.getEdge(x, y)
                  + " "
                  + t
                  + " "
                  + bump
                  + " "
                  + label);
    } else {
      int numEdges = graph.getNumEdges() - 1;
      if (verbose) {
        if (numEdges % 50 == 0) out.println(numEdges);
      }
    }

    if (verbose) {
      String label = trueGraph != null && trueEdge != null ? "*" : "";
      out.println(
          graph.getNumEdges()
              + ". INSERT "
              + graph.getEdge(x, y)
              + " "
              + t
              + " "
              + bump
              + " "
              + label);
    } else {
      int numEdges = graph.getNumEdges() - 1;
      if (verbose) {
        if (numEdges % 50 == 0) out.println(numEdges);
      }
    }

    for (Node _t : t) {
      Edge oldEdge = graph.getEdge(_t, y);

      if (oldEdge == null) throw new IllegalArgumentException("Not adjacent: " + _t + ", " + y);

      graph.removeEdge(_t, y);
      graph.addDirectedEdge(_t, y);

      if (log && verbose) {
        TetradLogger.getInstance()
            .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(_t, y));
        out.println("--- Directing " + oldEdge + " to " + graph.getEdge(_t, y));
      }
    }
  }

  /** Do an actual deletion (Definition 13 from Chickering, 2002). */
  private void delete(Node x, Node y, List<Node> subset, Graph graph, double bump) {

    Edge trueEdge = null;

    if (trueGraph != null) {
      Node _x = trueGraph.getNode(x.getName());
      Node _y = trueGraph.getNode(y.getName());
      trueEdge = trueGraph.getEdge(_x, _y);
    }

    if (log && verbose) {
      Edge oldEdge = graph.getEdge(x, y);

      String label = trueGraph != null && trueEdge != null ? "*" : "";
      TetradLogger.getInstance()
          .log(
              "deletedEdges",
              (graph.getNumEdges() - 1)
                  + ". DELETE "
                  + oldEdge
                  + " "
                  + subset
                  + " ("
                  + bump
                  + ") "
                  + label);
      out.println(
          (graph.getNumEdges() - 1)
              + ". DELETE "
              + oldEdge
              + " "
              + subset
              + " ("
              + bump
              + ") "
              + label);
    } else {
      int numEdges = graph.getNumEdges() - 1;
      if (numEdges % 50 == 0) out.println(numEdges);
    }

    graph.removeEdge(x, y);

    for (Node h : subset) {
      Edge oldEdge = graph.getEdge(y, h);

      graph.removeEdge(y, h);
      graph.addDirectedEdge(y, h);

      if (log) {
        TetradLogger.getInstance()
            .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(y, h));
      }

      if (verbose) {
        out.println("--- Directing " + oldEdge + " to " + graph.getEdge(y, h));
      }

      if (Edges.isUndirectedEdge(graph.getEdge(x, h))) {
        if (!graph.isAdjacentTo(x, h))
          throw new IllegalArgumentException("Not adjacent: " + x + ", " + h);
        oldEdge = graph.getEdge(x, h);

        graph.removeEdge(x, h);
        graph.addDirectedEdge(x, h);

        if (log) {
          TetradLogger.getInstance()
              .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(x, h));
        }

        if (verbose) {
          out.println("--- Directing " + oldEdge + " to " + graph.getEdge(x, h));
        }
      }
    }
  }

  /*
   * Test if the candidate insertion is a valid operation
   * (Theorem 15 from Chickering, 2002).
   **/

  private boolean validInsert(Node x, Node y, List<Node> t, List<Node> naYX, Graph graph) {
    List<Node> union = new ArrayList<Node>(t); // t and nayx are disjoint
    union.addAll(naYX);

    return isClique(union, graph) && !existsUnblockedSemiDirectedPath(y, x, union, graph);
  }

  /** Test if the candidate deletion is a valid operation (Theorem 17 from Chickering, 2002). */
  private static boolean validDelete(List<Node> h, List<Node> naXY, Graph graph) {
    List<Node> list = new ArrayList<Node>(naXY);
    list.removeAll(h);
    return isClique(list, graph);
  }

  // ---Background knowledge methods.

  private void addRequiredEdges(Graph graph) {
    if (true) return;
    if (knowledgeEmpty()) return;

    for (Iterator<KnowledgeEdge> it = getKnowledge().requiredEdgesIterator(); it.hasNext(); ) {
      KnowledgeEdge next = it.next();

      Node nodeA = graph.getNode(next.getFrom());
      Node nodeB = graph.getNode(next.getTo());

      if (!graph.isAncestorOf(nodeB, nodeA)) {
        graph.removeEdges(nodeA, nodeB);
        graph.addDirectedEdge(nodeA, nodeB);
        TetradLogger.getInstance()
            .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB));
      }
    }
    for (Edge edge : graph.getEdges()) {
      final String A = edge.getNode1().getName();
      final String B = edge.getNode2().getName();

      if (knowledge.isForbidden(A, B)) {
        Node nodeA = edge.getNode1();
        Node nodeB = edge.getNode2();

        if (nodeA != null
            && nodeB != null
            && graph.isAdjacentTo(nodeA, nodeB)
            && !graph.isChildOf(nodeA, nodeB)) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
        if (!graph.isChildOf(nodeA, nodeB)
            && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
      } else if (knowledge.isForbidden(B, A)) {
        Node nodeA = edge.getNode2();
        Node nodeB = edge.getNode1();

        if (nodeA != null
            && nodeB != null
            && graph.isAdjacentTo(nodeA, nodeB)
            && !graph.isChildOf(nodeA, nodeB)) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
        if (!graph.isChildOf(nodeA, nodeB)
            && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
      }
    }
  }

  private String getString(KnowledgeEdge next) {
    return next.getTo();
  }

  /**
   * Use background knowledge to decide if an insert or delete operation does not orient edges in a
   * forbidden direction according to prior knowledge. If some orientation is forbidden in the
   * subset, the whole subset is forbidden.
   */
  private boolean validSetByKnowledge(Node y, List<Node> subset) {
    for (Node node : subset) {
      if (getKnowledge().isForbidden(node.getName(), y.getName())) {
        return false;
      }
    }
    return true;
  }

  // --Auxiliary methods.

  /**
   * Find all nodes that are connected to Y by an undirected edge that are adjacent to X (that is,
   * by undirected or directed edge).
   */
  private static List<Node> getNaYX(Node x, Node y, Graph graph) {
    List<Edge> yEdges = graph.getEdges(y);
    List<Node> nayx = new ArrayList<Node>();

    for (Edge edge : yEdges) {
      if (!Edges.isUndirectedEdge(edge)) {
        continue;
      }

      Node z = edge.getDistalNode(y);

      if (!graph.isAdjacentTo(z, x)) {
        continue;
      }

      nayx.add(z);
    }

    return nayx;
  }

  /** Returns true iif the given set forms a clique in the given graph. */
  private static boolean isClique(List<Node> nodes, Graph graph) {
    for (int i = 0; i < nodes.size() - 1; i++) {
      for (int j = i + 1; j < nodes.size(); j++) {
        if (!graph.isAdjacentTo(nodes.get(i), nodes.get(j))) {
          return false;
        }
      }
    }

    return true;
  }

  private boolean existsUnblockedSemiDirectedPath(Node from, Node to, List<Node> cond, Graph G) {
    Queue<Node> Q = new LinkedList<Node>();
    Set<Node> V = new HashSet<Node>();
    Q.offer(from);
    V.add(from);

    while (!Q.isEmpty()) {
      Node t = Q.remove();
      if (t == to) return true;

      for (Node u : G.getAdjacentNodes(t)) {
        Edge edge = G.getEdge(t, u);
        Node c = Edges.traverseSemiDirected(t, edge);
        if (c == null) continue;
        if (cond.contains(c)) continue;
        if (c == to) return true;

        if (!V.contains(c)) {
          V.add(c);
          Q.offer(c);
        }
      }
    }

    return false;
  }

  /**
   * Completes a pattern that was modified by an insertion/deletion operator Based on the algorithm
   * described on Appendix C of (Chickering, 2002).
   */
  private void rebuildPattern(Graph graph) {
    SearchGraphUtils.basicPattern(graph, false);
    addRequiredEdges(graph);
    meekOrient(graph, getKnowledge());

    if (TetradLogger.getInstance().isEventActive("rebuiltPatterns")) {
      TetradLogger.getInstance().log("rebuiltPatterns", "Rebuilt pattern = " + graph);
    }
  }

  private Graph pickDag(Graph graph) {
    SearchGraphUtils.basicPattern(graph, false);
    addRequiredEdges(graph);
    boolean containsUndirected;

    do {
      containsUndirected = false;

      for (Edge edge : graph.getEdges()) {
        if (Edges.isUndirectedEdge(edge)) {
          containsUndirected = true;
          graph.removeEdge(edge);
          Edge _edge = Edges.directedEdge(edge.getNode1(), edge.getNode2());
          graph.addEdge(_edge);
        }
      }

      meekOrient(graph, getKnowledge());
    } while (containsUndirected);

    return graph;
  }

  /**
   * Fully direct a graph with background knowledge. I am not sure how to adapt Chickering's
   * suggested algorithm above (dagToPdag) to incorporate background knowledge, so I am also
   * implementing this algorithm based on Meek's 1995 UAI paper. Notice it is the same implemented
   * in PcSearch. *IMPORTANT!* *It assumes all colliders are oriented, as well as arrows dictated by
   * time order.*
   */
  private void meekOrient(Graph graph, IKnowledge knowledge) {
    MeekRules rules = new MeekRules();
    rules.setOrientInPlace(false);
    rules.setKnowledge(knowledge);
    rules.orientImplied(graph);
  }

  private void setDataSet(DataSet dataSet) {
    List<String> _varNames = dataSet.getVariableNames();

    this.variables = dataSet.getVariables();
    this.dataSet = dataSet;
    this.discrete = dataSet.isDiscrete();

    if (!isDiscrete()) {
      this.covariances = new CovarianceMatrix(dataSet);
    }

    this.sampleSize = dataSet.getNumRows();
  }

  private void setCovMatrix(ICovarianceMatrix covarianceMatrix) {
    this.covariances = covarianceMatrix;
    this.variables = covarianceMatrix.getVariables();
    this.sampleSize = covarianceMatrix.getSampleSize();
  }

  private void buildIndexing(Graph graph) {
    this.hashIndices = new HashMap<Node, Integer>();
    for (Node node : graph.getNodes()) {
      this.hashIndices.put(node, variables.indexOf(node));
    }
  }

  private void clearArrow(Node x, Node y) {
    final OrderedPair<Node> pair = new OrderedPair<Node>(x, y);
    final Set<Arrow> lookupArrows = this.lookupArrows.get(pair);

    if (lookupArrows != null) {
      sortedArrows.removeAll(lookupArrows);
    }

    this.lookupArrows.remove(pair);
  }

  private void addLookupArrow(Node i, Node j, Arrow arrow) {
    OrderedPair<Node> pair = new OrderedPair<Node>(i, j);
    Set<Arrow> arrows = lookupArrows.get(pair);

    if (arrows == null) {
      arrows = new HashSet<Arrow>();
      lookupArrows.put(pair, arrows);
    }

    arrows.add(arrow);
  }

  // ===========================SCORING METHODS===================//
  public double scoreDag(Graph graph) {
    Graph dag = new EdgeListGraphSingleConnections(graph);
    buildIndexing(graph);

    double score = 0.0;

    for (Node y : dag.getNodes()) {
      Set<Node> parents = new HashSet<Node>(dag.getParents(y));
      int nextIndex = -1;
      for (int i = 0; i < getVariables().size(); i++) {
        nextIndex = hashIndices.get(variables.get(i));
      }
      int parentIndices[] = new int[parents.size()];
      Iterator<Node> pi = parents.iterator();
      int count = 0;
      while (pi.hasNext()) {
        Node nextParent = pi.next();
        parentIndices[count++] = hashIndices.get(nextParent);
      }

      if (this.isDiscrete()) {
        score += localDiscreteScore(nextIndex, parentIndices);
      } else {
        score += localSemScore(nextIndex, parentIndices);
      }
    }
    return score;
  }

  private double scoreGraphChange(Node y, Set<Node> parents1, Set<Node> parents2) {
    int yIndex = hashIndices.get(y);

    double score1, score2;

    int[] parentIndices1 = new int[parents1.size()];

    int count = -1;
    for (Node parent : parents1) {
      parentIndices1[++count] = hashIndices.get(parent);
    }

    if (isDiscrete()) {
      score1 = localDiscreteScore(yIndex, parentIndices1);
    } else {
      score1 = localSemScore(yIndex, parentIndices1);
    }

    int[] parentIndices2 = new int[parents2.size()];

    int count2 = -1;
    for (Node parent : parents2) {
      parentIndices2[++count2] = hashIndices.get(parent);
    }

    if (isDiscrete()) {
      score2 = localDiscreteScore(yIndex, parentIndices2);
    } else {
      score2 = localSemScore(yIndex, parentIndices2);
    }

    return score1 - score2;
  }

  /** Compute the local BDeu score of (i, parents(i)). See (Chickering, 2002). */
  private double localDiscreteScore(int i, int parents[]) {
    return getDiscreteScore().localScore(i, parents);
  }

  /**
   * Calculates the sample likelihood and BIC score for i given its parents in a simple SEM model.
   */
  private double localSemScore(int i, int[] parents) {
    try {
      ICovarianceMatrix cov = getCovMatrix();
      double varianceY = cov.getValue(i, i);
      double residualVariance = varianceY;
      int n = sampleSize();
      int p = parents.length;
      int k = (p * (p + 1)) / 2 + p;
      //            int k = (p + 1) * (p + 1);
      //            int k = p + 1;
      TetradMatrix covxx = cov.getSelection(parents, parents);
      TetradMatrix covxxInv = covxx.inverse();
      TetradVector covxy = cov.getSelection(parents, new int[] {i}).getColumn(0);
      TetradVector b = covxxInv.times(covxy);
      residualVariance -= covxy.dotProduct(b);

      if (residualVariance <= 0 && verbose) {
        out.println(
            "Nonpositive residual varianceY: resVar / varianceY = "
                + (residualVariance / varianceY));
        return Double.NaN;
      }

      double c = getPenaltyDiscount();

      //            return -n * log(residualVariance) - 2 * k; //AIC
      return -n * Math.log(residualVariance) - c * k * Math.log(n);
      //            return -n * log(residualVariance) - c * k * (log(n) - log(2 * PI));
    } catch (Exception e) {
      e.printStackTrace();
      throw new RuntimeException(e);
      //            throwMinimalLinearDependentSet(parents, cov);
    }
  }

  //    private void throwMinimalLinearDependentSet(int[] parents, TetradMatrix cov) {
  //        List<Node> _parents = new ArrayList<Node>();
  //        for (int p : parents) _parents.add(variables.get(p));
  //
  //        DepthChoiceGenerator gen = new DepthChoiceGenerator(_parents.size(), _parents.size());
  //        int[] choice;
  //
  //        while ((choice = gen.next()) != null) {
  //            int[] sel = new int[choice.length];
  //            List<Node> _sel = new ArrayList<Node>();
  //            for (int m = 0; m < choice.length; m++) {
  //                sel[m] = parents[m];
  //                _sel.add(variables.get(sel[m]));
  //            }
  //
  //            TetradMatrix m = cov.getSelection(sel, sel);
  //
  //            try {
  //                m.inverse();
  //            } catch (Exception e2) {
  //                throw new RuntimeException("Linear dependence among variables: " + _sel);
  //            }
  //        }
  //    }

  private int sampleSize() {
    return this.sampleSize;
  }

  private List<Node> getVariables() {
    return variables;
  }

  private ICovarianceMatrix getCovMatrix() {
    return covariances;
  }

  private DataSet dataSet() {
    return dataSet;
  }

  private boolean isDiscrete() {
    return discrete;
  }

  private void fireGraphChange(Graph graph) {
    for (PropertyChangeListener l : getListeners()) {
      l.propertyChange(new PropertyChangeEvent(this, "graph", null, graph));
    }
  }

  private List<PropertyChangeListener> getListeners() {
    if (listeners == null) {
      listeners = new ArrayList<PropertyChangeListener>();
    }
    return listeners;
  }

  private void storeGraph(Graph graph) {
    if (numPatternsToStore < 1) return;

    if (topGraphs.isEmpty() || score > topGraphs.first().getScore()) {
      Graph graphCopy = new EdgeListGraphSingleConnections(graph);

      topGraphs.add(new ScoredGraph(graphCopy, score));

      if (topGraphs.size() > getNumPatternsToStore()) {
        topGraphs.remove(topGraphs.first());
      }
    }
  }
}
예제 #9
0
파일: Lofs.java 프로젝트: jdramsey/tetrad
  private void resolveOneEdgeMax(Graph graph, Node x, Node y, boolean strong, Graph oldGraph) {
    if (RandomUtil.getInstance().nextDouble() > 0.5) {
      Node temp = x;
      x = y;
      y = temp;
    }

    TetradLogger.getInstance().log("info", "\nEDGE " + x + " --- " + y);

    SortedMap<Double, String> scoreReports = new TreeMap<Double, String>();

    List<Node> neighborsx = graph.getAdjacentNodes(x);
    neighborsx.remove(y);

    double max = Double.NEGATIVE_INFINITY;
    boolean left = false;
    boolean right = false;

    DepthChoiceGenerator genx = new DepthChoiceGenerator(neighborsx.size(), neighborsx.size());
    int[] choicex;

    while ((choicex = genx.next()) != null) {
      List<Node> condxMinus = GraphUtils.asList(choicex, neighborsx);

      List<Node> condxPlus = new ArrayList<Node>(condxMinus);
      condxPlus.add(y);

      double xPlus = score(x, condxPlus);
      double xMinus = score(x, condxMinus);

      List<Node> neighborsy = graph.getAdjacentNodes(y);
      neighborsy.remove(x);

      DepthChoiceGenerator geny = new DepthChoiceGenerator(neighborsy.size(), neighborsy.size());
      int[] choicey;

      while ((choicey = geny.next()) != null) {
        List<Node> condyMinus = GraphUtils.asList(choicey, neighborsy);

        //                List<Node> parentsY = oldGraph.getParents(y);
        //                parentsY.remove(x);
        //                if (!condyMinus.containsAll(parentsY)) {
        //                    continue;
        //                }

        List<Node> condyPlus = new ArrayList<Node>(condyMinus);
        condyPlus.add(x);

        double yPlus = score(y, condyPlus);
        double yMinus = score(y, condyMinus);

        // Checking them all at once is expensive but avoids lexical ordering problems in the
        // algorithm.
        if (normal(y, condyPlus)
            || normal(x, condxMinus)
            || normal(x, condxPlus)
            || normal(y, condyMinus)) {
          continue;
        }

        double delta = 0.0;

        if (strong) {
          if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) {
            double score = combinedScore(xPlus, yMinus);

            if (yPlus <= yMinus + delta && xMinus <= xPlus + delta) {
              StringBuilder builder = new StringBuilder();

              builder.append("\nStrong " + y + "->" + x + " " + score);
              builder.append("\n   Parents(" + x + ") = " + condxMinus);
              builder.append("\n   Parents(" + y + ") = " + condyMinus);

              scoreReports.put(-score, builder.toString());

              if (score > max) {
                max = score;
                left = true;
                right = false;
              }
            } else {
              StringBuilder builder = new StringBuilder();

              builder.append("\nNo directed edge " + x + "--" + y + " " + score);
              builder.append("\n   Parents(" + x + ") = " + condxMinus);
              builder.append("\n   Parents(" + y + ") = " + condyMinus);

              scoreReports.put(-score, builder.toString());
            }
          } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) {
            double score = combinedScore(yPlus, xMinus);

            if (yMinus <= yPlus + delta && xPlus <= xMinus + delta) {
              StringBuilder builder = new StringBuilder();

              builder.append("\nStrong " + x + "->" + y + " " + score);
              builder.append("\n   Parents(" + x + ") = " + condxMinus);
              builder.append("\n   Parents(" + y + ") = " + condyMinus);

              scoreReports.put(-score, builder.toString());

              if (score > max) {
                max = score;
                left = false;
                right = true;
              }
            } else {
              StringBuilder builder = new StringBuilder();

              builder.append("\nNo directed edge " + x + "--" + y + " " + score);
              builder.append("\n   Parents(" + x + ") = " + condxMinus);
              builder.append("\n   Parents(" + y + ") = " + condyMinus);

              scoreReports.put(-score, builder.toString());
            }
          } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) {
            double score = combinedScore(yPlus, xMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nNo directed edge " + x + "--" + y + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());
          } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) {
            double score = combinedScore(yPlus, xMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nNo directed edge " + x + "--" + y + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());
          }
        } else {
          if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) {
            double score = combinedScore(xPlus, yMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nWeak " + y + "->" + x + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());

            if (score > max) {
              max = score;
              left = true;
              right = false;
            }
          } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) {
            double score = combinedScore(yPlus, xMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nWeak " + x + "->" + y + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());

            if (score > max) {
              max = score;
              left = false;
              right = true;
            }
          } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) {
            double score = combinedScore(yPlus, xMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nNo directed edge " + x + "--" + y + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());
          } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) {
            double score = combinedScore(yPlus, xMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nNo directed edge " + x + "--" + y + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());
          }
        }
      }
    }

    for (double score : scoreReports.keySet()) {
      TetradLogger.getInstance().log("info", scoreReports.get(score));
    }

    graph.removeEdges(x, y);

    if (left) {
      graph.addDirectedEdge(y, x);
    }

    if (right) {
      graph.addDirectedEdge(x, y);
    }

    if (!graph.isAdjacentTo(x, y)) {
      graph.addUndirectedEdge(x, y);
    }
  }