예제 #1
0
  /**
   * Generates a random directed graph over {@code n} nodes named {@code X0 ... X(n-1)} with
   * {@code e} edges.
   *
   * <p>Edges are first sampled as undirected pairs whose endpoints are drawn with {@code
   * weightedRandom}. A drawn pair is rejected and redrawn when:
   *
   * <ul>
   *   <li>both endpoints are the same node,
   *   <li>{@code shortestPath} between them in the graph built so far is not below 9, or
   *   <li>the edge is already present.
   * </ul>
   *
   * <p>Each sampled edge is then oriented n1 → n2, unless {@code graph.isAncestorOf(n2, n1)}
   * holds, in which case it is oriented n2 → n1.
   *
   * @param n number of nodes
   * @param e number of edges to add; a non-positive value yields a graph with no edges
   * @return the generated graph
   * @throws IllegalArgumentException if {@code e} exceeds n(n-1)/2, the number of distinct node
   *     pairs; in that case the rejection loop could never terminate
   */
  public static Graph weightedRandomGraph(int n, int e) {
    long maxEdges = (long) n * (n - 1) / 2;
    if (e > maxEdges) {
      throw new IllegalArgumentException(
          "Cannot place " + e + " edges among " + n + " nodes; at most " + maxEdges + " fit.");
    }

    List<Node> nodes = new ArrayList<Node>();
    for (int i = 0; i < n; i++) {
      nodes.add(new GraphNode("X" + i));
    }

    Graph graph = new EdgeListGraph(nodes);

    // Rejection sampling: only accepted pairs advance the counter.
    // NOTE(review): the shortestPath constraint can still make the target unreachable for some
    // graphs; this depends on shortestPath's value for unconnected pairs — confirm.
    int added = 0;
    while (added < e) {
      int i1 = weightedRandom(nodes, graph);
      int i2 = weightedRandom(nodes, graph);

      if (i1 == i2) {
        continue;
      }

      if (!(shortestPath(nodes.get(i1), nodes.get(i2), graph) < 9)) {
        continue;
      }

      Edge edge = Edges.undirectedEdge(nodes.get(i1), nodes.get(i2));

      if (graph.containsEdge(edge)) {
        continue;
      }

      graph.addEdge(edge);
      added++;
    }

    // Iterate over a snapshot, because edges are removed and re-added while looping.
    for (Edge edge : new ArrayList<Edge>(graph.getEdges())) {
      Node n1 = edge.getNode1();
      Node n2 = edge.getNode2();

      // Decide the direction before removing the undirected edge, as the original did.
      boolean reverse = graph.isAncestorOf(n2, n1);
      graph.removeEdge(edge);

      if (reverse) {
        graph.addDirectedEdge(n2, n1);
      } else {
        graph.addDirectedEdge(n1, n2);
      }
    }

    return graph;
  }
  /**
   * Transforms a maximally directed pattern (PDAG) represented in graph <code>g</code> into an
   * arbitrary DAG by modifying <code>g</code> itself. Based on the algorithm described in
   * Chickering (2002) "Optimal structure identification with greedy search" Journal of Machine
   * Learning Research. R. Silva, June 2004
   *
   * <p>As implemented, all undirected (tail–tail) edges are first removed from {@code g}. A
   * working copy {@code p} of the original graph is then consumed one node at a time. At each
   * step the first remaining node x is selected that:
   *
   * <ul>
   *   <li>has no children in {@code p}, and
   *   <li>either has no undirected neighbors, or has undirected neighbors that, together with
   *       its parents, form a clique in {@code p}.
   * </ul>
   *
   * <p>Each undirected neighbor of x is re-added to {@code g} as neighbor → x (matched by node
   * name), and x is removed from {@code p}.
   *
   * @param g the PDAG to convert; modified in place
   */
  public static void pdagToDag(Graph g) {
    // Working copy that is consumed node by node; g receives the final orientations.
    Graph p = new EdgeListGraph(g);
    List<Edge> undirectedEdges = new ArrayList<Edge>();

    // Collect every undirected (tail-tail) edge once; they are stripped from g here and re-added
    // in directed form below.
    for (Edge edge : g.getEdges()) {
      if (edge.getEndpoint1() == Endpoint.TAIL
          && edge.getEndpoint2() == Endpoint.TAIL
          && !undirectedEdges.contains(edge)) {
        undirectedEdges.add(edge);
      }
    }
    g.removeEdges(undirectedEdges);
    // NOTE(review): assumes getNodes() returns a copy, so p.removeNode(x) below does not also
    // shrink this list — confirm.
    List<Node> pNodes = p.getNodes();

    do {
      Node x = null;

      for (Node pNode : pNodes) {
        x = pNode;

        // Only nodes without children in p are candidates.
        if (p.getChildren(x).size() > 0) {
          continue;
        }

        // Undirected neighbors of x in p, found by scanning every edge of p.
        Set<Node> neighbors = new HashSet<Node>();

        for (Edge edge : p.getEdges()) {
          if (edge.getNode1() == x || edge.getNode2() == x) {
            if (edge.getEndpoint1() == Endpoint.TAIL && edge.getEndpoint2() == Endpoint.TAIL) {
              if (edge.getNode1() == x) {
                neighbors.add(edge.getNode2());
              } else {
                neighbors.add(edge.getNode1());
              }
            }
          }
        }
        // Orienting the neighbors into x is accepted only if neighbors plus parents of x form a
        // clique in p.
        if (neighbors.size() > 0) {
          Collection<Node> parents = p.getParents(x);
          Set<Node> all = new HashSet<Node>(neighbors);
          all.addAll(parents);
          if (!GraphUtils.isClique(all, p)) {
            continue;
          }
        }

        // Re-add each undirected neighbor edge to g as neighbor --> x, looking nodes up by name.
        for (Node neighbor : neighbors) {
          Node node1 = g.getNode(neighbor.getName());
          Node node2 = g.getNode(x.getName());

          g.addDirectedEdge(node1, node2);
        }
        p.removeNode(x);
        break;
      }
      // NOTE(review): if no node qualified in the loop above, x is the last node examined. It is
      // dropped from the candidate list here without being removed from p, and without its
      // undirected edges being oriented from its side — confirm this fallback is intended.
      pNodes.remove(x);
    } while (pNodes.size() > 0);
  }
  /**
   * Returns the largest p-value obtained over {@code numRestarts} runs of {@link Mimbuild2} on
   * the given clusters, with cluster j labelled by the latent name {@code "L" + j}.
   *
   * <p>Each run uses a fresh {@code Mimbuild2} and freshly built cluster and name lists, searched
   * against {@code cov}. NaN p-values never replace the running maximum.
   *
   * @param clusters clusters of measured-variable indices, resolved via {@code variablesForIndices}
   * @param numRestarts number of independent Mimbuild2 runs
   * @return the best p-value, or {@link Double#NEGATIVE_INFINITY} if {@code numRestarts <= 0}
   */
  private double getPMulticluster(List<List<Integer>> clusters, int numRestarts) {
    double max = Double.NEGATIVE_INFINITY;

    for (int i = 0; i < numRestarts; i++) {
      Mimbuild2 mimbuild = new Mimbuild2();

      // Rebuilt on every run so each search receives its own lists.
      List<List<Node>> _clusters = new ArrayList<List<Node>>();

      for (List<Integer> _cluster : clusters) {
        _clusters.add(variablesForIndices(_cluster));
      }

      List<String> names = new ArrayList<String>();

      for (int j = 0; j < clusters.size(); j++) {
        names.add("L" + j);
      }

      mimbuild.search(_clusters, names, cov);

      // Plain comparison (not Math.max) so that a NaN p-value is ignored.
      double c = mimbuild.getpValue();
      if (c > max) max = c;
    }

    return max;
  }
  /**
   * Re-orients edges of {@code graph} until {@link GraphUtils#directedCycle} reports no directed
   * cycle.
   *
   * <p>For each reported cycle, every cycle node y, with cycle neighbors x and z, is scored by the
   * p-value of the test x _||_ z | y. The node with the smallest p-value becomes a collider: its
   * two cycle edges are replaced with x → y and z → y. If no p-value is smaller than +∞ (for
   * example, all are NaN), the first node of the cycle is used. Each test and its p-value are
   * printed to standard output.
   *
   * @param graph the graph to modify in place
   * @param test independence test whose variables are matched to graph nodes by name
   * @return {@code graph}
   */
  public static Graph bestGuessCycleOrientation(Graph graph, IndependenceTest test) {
    while (true) {
      List<Node> cycle = GraphUtils.directedCycle(graph);

      // An empty cycle has nothing to re-orient; treat it like "no cycle".
      if (cycle == null || cycle.isEmpty()) {
        break;
      }

      // Pad with the last node in front and the first node at the end. Each original cycle node
      // (positions 1..size-2) then has its cycle neighbors at j-1 and j+1. ArrayList gives
      // O(1) positional access inside the loop below.
      List<Node> _cycle = new ArrayList<Node>(cycle.size() + 2);
      _cycle.add(cycle.get(cycle.size() - 1));
      _cycle.addAll(cycle);
      _cycle.add(cycle.get(0));

      // Default to the first candidate. If no p-value ever satisfies "p < minP" (e.g. all NaN),
      // an unset index of -1 would make _cycle.get(_j - 1) throw.
      int _j = 1;
      double minP = Double.POSITIVE_INFINITY;

      for (int j = 1; j < _cycle.size() - 1; j++) {
        int i = j - 1;
        int k = j + 1;

        Node x = test.getVariable(_cycle.get(i).getName());
        Node y = test.getVariable(_cycle.get(j).getName());
        Node z = test.getVariable(_cycle.get(k).getName());

        test.isIndependent(x, z, Collections.singletonList(y));

        System.out.println("Testing " + x + " _||_ " + z + " | " + y);

        double p = test.getPValue();

        System.out.println("p = " + p);

        if (p < minP) {
          _j = j;
          minP = p;
        }
      }

      Node x = _cycle.get(_j - 1);
      Node y = _cycle.get(_j);
      Node z = _cycle.get(_j + 1);

      // Make y a collider on the cycle: x --> y <-- z.
      graph.removeEdge(x, y);
      graph.removeEdge(z, y);
      graph.addDirectedEdge(x, y);
      graph.addDirectedEdge(z, y);
    }

    return graph;
  }
  /**
   * Orients edges of {@code graph} according to background knowledge.
   *
   * <p>Forbidden edges are processed first. For each forbidden edge from → to whose endpoints both
   * translate to nodes and are connected in {@code graph}, the edge is re-oriented as to → from.
   * Required edges are processed next: each required edge from → to meeting the same conditions is
   * oriented as from → to. Knowledge edges whose names do not translate, or whose nodes are not
   * connected, are skipped. Every orientation is reported through {@link TetradLogger}.
   *
   * @param bk background knowledge supplying forbidden and required edges
   * @param graph the graph whose edges are oriented in place
   * @param nodes the variables used to translate knowledge-edge names into graph nodes
   */
  public static void pcOrientbk(Knowledge bk, Graph graph, List<Node> nodes) {
    TetradLogger.getInstance().log("info", "Starting BK Orientation.");
    for (Iterator<KnowledgeEdge> it = bk.forbiddenEdgesIterator(); it.hasNext(); ) {
      KnowledgeEdge edge = it.next();

      // match strings to variables in the graph.
      Node from = translate(edge.getFrom(), nodes);
      Node to = translate(edge.getTo(), nodes);

      if (from == null || to == null) {
        continue;
      }

      if (graph.getEdge(from, to) == null) {
        continue;
      }

      // Orient to-->from. The edge is rebuilt and then both endpoints are set explicitly: tail at
      // "to", arrow at "from".
      graph.removeEdge(from, to);
      graph.addDirectedEdge(from, to);
      graph.setEndpoint(from, to, Endpoint.TAIL);
      graph.setEndpoint(to, from, Endpoint.ARROW);

      TetradLogger.getInstance()
          .edgeOriented(SearchLogUtils.edgeOrientedMsg("Knowledge", graph.getEdge(to, from)));
    }

    for (Iterator<KnowledgeEdge> it = bk.requiredEdgesIterator(); it.hasNext(); ) {
      KnowledgeEdge edge = it.next();

      // match strings to variables in this graph
      Node from = translate(edge.getFrom(), nodes);
      Node to = translate(edge.getTo(), nodes);

      if (from == null || to == null) {
        continue;
      }

      if (graph.getEdge(from, to) == null) {
        continue;
      }

      // Orient from-->to
      graph.setEndpoint(to, from, Endpoint.TAIL);
      graph.setEndpoint(from, to, Endpoint.ARROW);
      TetradLogger.getInstance()
          .edgeOriented(SearchLogUtils.edgeOrientedMsg("Knowledge", graph.getEdge(from, to)));
    }
    TetradLogger.getInstance().log("info", "Finishing BK Orientation.");
  }
예제 #6
0
  /**
   * Simulates 1000 rows from X → Y (coefficient 20, error variances 1), re-estimates the
   * generalized SEM, and checks the estimator's a-squared-star statistic for the seeded run.
   *
   * <p>A malformed initialization expression throws {@link ParseException}, which fails the test
   * instead of being printed and ignored.
   */
  @Test
  public void test8() throws ParseException {
    RandomUtil.getInstance().setSeed(29999483L);

    Node x = new GraphNode("X");
    Node y = new GraphNode("Y");

    List<Node> nodes = new ArrayList<>();
    nodes.add(x);
    nodes.add(y);

    Graph graph = new EdgeListGraphSingleConnections(nodes);

    graph.addDirectedEdge(x, y);

    SemPm spm = new SemPm(graph);
    SemIm sim = new SemIm(spm);

    sim.setEdgeCoef(x, y, 20);
    sim.setErrVar(x, 1);
    sim.setErrVar(y, 1);

    GeneralizedSemPm pm = new GeneralizedSemPm(spm);
    GeneralizedSemIm im = new GeneralizedSemIm(pm, sim);

    print(im);

    // Distributions the estimator draws starting values from.
    pm.setParameterEstimationInitializationExpression("b1", "U(10, 30)");
    pm.setParameterEstimationInitializationExpression("T1", "U(.1, 3)");
    pm.setParameterEstimationInitializationExpression("T2", "U(.1, 3)");

    DataSet data = im.simulateDataRecursive(1000, false);

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estIm = estimator.estimate(pm, data);

    print(estIm);

    double aSquaredStar = estimator.getaSquaredStar();

    assertEquals(0.69, aSquaredStar, 0.01);
  }
  /**
   * Returns the p-value of a one-factor SEM in which a single latent {@code "L"} points at every
   * measured variable of {@code cluster}.
   *
   * <p>The model is fit to {@code cov} with the FGLS score type, using a Powell optimizer
   * configured with {@code numRestarts} restarts.
   *
   * @param cluster indices of the measured variables, resolved via {@code variablesForIndices}
   * @param numRestarts number of optimizer restarts
   * @return the p-value of the estimated SEM
   */
  private double getP(List<Integer> cluster, int numRestarts) {
    Node latent = new GraphNode("L");
    latent.setNodeType(NodeType.LATENT);
    Graph g = new EdgeListGraph();
    g.addNode(latent);
    List<Node> measures = variablesForIndices(cluster);
    for (Node node : measures) {
      g.addNode(node);
      g.addDirectedEdge(latent, node);
    }
    SemPm pm = new SemPm(g);

    SemOptimizerPowell semOptimizer = new SemOptimizerPowell();
    semOptimizer.setNumRestarts(numRestarts);

    SemEstimator est = new SemEstimator(cov, pm, semOptimizer);
    est.setScoreType(SemIm.ScoreType.Fgls);
    est.estimate();
    return est.getEstimatedSem().getPValue();
  }
  /**
   * Fits a one-factor model, with latent {@code "L"} pointing at every node of {@code c}, using
   * the EM optimizer. The model is fit to {@code dataModel}, treated as a {@link DataSet} when it
   * is one and otherwise as a {@link CovarianceMatrix}.
   *
   * @param c the measured nodes of the cluster
   * @return the fitted model's p-value, or {@link Double#NaN} when that p-value is exactly 1
   */
  private double getClusterP2(List<Node> c) {
    Node factor = new GraphNode("L");
    factor.setNodeType(NodeType.LATENT);

    Graph oneFactorModel = new EdgeListGraph(c);
    oneFactorModel.addNode(factor);
    for (Node indicator : c) {
      oneFactorModel.addDirectedEdge(factor, indicator);
    }

    SemPm pm = new SemPm(oneFactorModel);
    SemEstimator estimator =
        dataModel instanceof DataSet
            ? new SemEstimator((DataSet) dataModel, pm, new SemOptimizerEm())
            : new SemEstimator((CovarianceMatrix) dataModel, pm, new SemOptimizerEm());

    double p = estimator.estimate().getPValue();
    if (p == 1) {
      return Double.NaN;
    }
    return p;
  }
  /**
   * Builds a measurement graph over {@code variables}.
   *
   * <p>One latent node is created per cluster, named {@code MimBuild.LATENT_PREFIX} followed by
   * 1, 2, .... A directed edge runs from each latent to every member of its cluster. Latents are
   * paired with clusters in the set's iteration order. Cluster members not already in the graph
   * are added before their edge.
   *
   * @param clusters the clusters of measured nodes
   * @return the new graph
   */
  private Graph convertSearchGraphNodes(Set<Set<Node>> clusters) {
    Graph measurementGraph = new EdgeListGraph(variables);

    // First pass: create and add every latent, in order.
    List<Node> latentNodes = new ArrayList<Node>();
    for (int k = 1; k <= clusters.size(); k++) {
      Node latentNode = new GraphNode(MimBuild.LATENT_PREFIX + k);
      latentNode.setNodeType(NodeType.LATENT);
      latentNodes.add(latentNode);
      measurementGraph.addNode(latentNode);
    }

    // Second pass: connect the k-th latent to each member of the k-th cluster.
    int position = 0;
    for (Set<Node> cluster : clusters) {
      Node latentNode = latentNodes.get(position);
      position++;

      for (Node member : cluster) {
        if (!measurementGraph.containsNode(member)) {
          measurementGraph.addNode(member);
        }
        measurementGraph.addDirectedEdge(latentNode, member);
      }
    }

    return measurementGraph;
  }
예제 #10
0
  /**
   * Manual experiment (not annotated with {@code @Test}). It:
   *
   * <ol>
   *   <li>simulates 1000 rows from a random MIM built on a diamond-shaped latent structure;
   *   <li>clusters the measures with {@code FindOneFactorClusters} on a column-reordered
   *       covariance matrix;
   *   <li>runs {@code Mimbuild2} on those clusters;
   *   <li>prints the true structure next to the estimate.
   * </ol>
   */
  public void rtest3() {
    Node x = new GraphNode("X");
    Node y = new GraphNode("Y");
    Node z = new GraphNode("Z");
    Node w = new GraphNode("W");

    List<Node> latentVariables = new ArrayList<Node>();
    for (Node latent : new Node[] {x, y, z, w}) {
      latentVariables.add(latent);
    }

    // Latent structure: X -> Y -> W and X -> Z -> W.
    Graph latentStructure = new EdgeListGraph(latentVariables);
    latentStructure.addDirectedEdge(x, y);
    latentStructure.addDirectedEdge(x, z);
    latentStructure.addDirectedEdge(y, w);
    latentStructure.addDirectedEdge(z, w);

    Graph mim = DataGraphUtils.randomMim(latentStructure, 8, 0, 0, 0, true);
    Graph trueStructure = structure(mim);
    SemPm pm = new SemPm(mim);

    System.out.println("\n\nTrue graph:");
    System.out.println(trueStructure);

    SemImInitializationParams params = new SemImInitializationParams();
    params.setCoefRange(0.5, 1.5);
    SemIm im = new SemIm(pm, params);

    int sampleSize = 1000;
    DataSet data = im.simulateData(sampleSize, false);
    CovarianceMatrix cov = new CovarianceMatrix(data);

    // A single trial: reorder columns, find clusters, then build the structural model.
    ICovarianceMatrix reorderedCov = DataUtils.reorderColumns(cov);

    FindOneFactorClusters fofc =
        new FindOneFactorClusters(reorderedCov, TestType.TETRAD_WISHART, .001);
    fofc.search();
    List<List<Node>> partition = fofc.getClusters();
    System.out.println(partition);

    List<String> latentNames = reidentifyVariables(mim, data, partition, 2);

    Mimbuild2 mimbuild = new Mimbuild2();
    mimbuild.setAlpha(0.001);

    Graph estimatedStructure = mimbuild.search(partition, latentNames, reorderedCov);
    double trialP = mimbuild.getpValue();

    System.out.println(estimatedStructure);
    System.out.println("P = " + trialP);
    System.out.println("Latent Cov = " + mimbuild.getLatentsCov());

    Graph bestGraph = null;
    double bestP = -1.0;
    ICovarianceMatrix bestLatentCov = null;

    if (trialP > bestP) {
      bestP = trialP;
      bestGraph = new EdgeListGraph(estimatedStructure);
      bestLatentCov = mimbuild.getLatentsCov();
    }

    System.out.println("\n\nTrue graph:");
    System.out.println(trueStructure);
    System.out.println("\nBest graph:");
    System.out.println(bestGraph);
    System.out.println("P = " + bestP);
    System.out.println("Latent Cov = " + bestLatentCov);
    System.out.println();
  }
예제 #11
0
  /**
   * Intended to enforce background knowledge on {@code graph}: orient required edges, and
   * reverse edges whose current direction is forbidden.
   *
   * <p>NOTE(review): the unconditional {@code if (true) return;} on the first line makes this
   * method a no-op; everything after it is currently unreachable. The comments below describe
   * what that code would do if the early return were removed.
   *
   * @param graph the graph that would be adjusted in place (currently left untouched)
   */
  private void addRequiredEdges(Graph graph) {
    // Disables the whole method; see the NOTE in the Javadoc.
    if (true) return;
    if (knowledgeEmpty()) return;

    // Required edges A --> B: unless B is already an ancestor of A, replace whatever connects A
    // and B with A --> B. NOTE(review): getNode may return null for unknown names, and that is
    // not checked here.
    for (Iterator<KnowledgeEdge> it = getKnowledge().requiredEdgesIterator(); it.hasNext(); ) {
      KnowledgeEdge next = it.next();

      Node nodeA = graph.getNode(next.getFrom());
      Node nodeB = graph.getNode(next.getTo());

      if (!graph.isAncestorOf(nodeB, nodeA)) {
        graph.removeEdges(nodeA, nodeB);
        graph.addDirectedEdge(nodeA, nodeB);
        TetradLogger.getInstance()
            .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeA, nodeB));
      }
    }
    // Forbidden directions: for each edge, if A --> B is forbidden (in either endpoint order),
    // try to replace the connection with B --> A. This is skipped if A is already a child of B
    // or an ancestor of B.
    // NOTE(review): edges are removed and added while iterating graph.getEdges(); safe only if
    // that returns a copy — confirm. Also, the first check reads the `knowledge` field while the
    // second uses getKnowledge(); these are presumably the same object — confirm.
    for (Edge edge : graph.getEdges()) {
      final String A = edge.getNode1().getName();
      final String B = edge.getNode2().getName();

      if (knowledge.isForbidden(A, B)) {
        Node nodeA = edge.getNode1();
        Node nodeB = edge.getNode2();

        if (nodeA != null
            && nodeB != null
            && graph.isAdjacentTo(nodeA, nodeB)
            && !graph.isChildOf(nodeA, nodeB)) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
        // Repeats the check above without the null/adjacency guards.
        if (!graph.isChildOf(nodeA, nodeB)
            && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
      } else if (knowledge.isForbidden(B, A)) {
        // Same handling with the endpoints swapped, so nodeA is the tail of the forbidden edge.
        Node nodeA = edge.getNode2();
        Node nodeB = edge.getNode1();

        if (nodeA != null
            && nodeB != null
            && graph.isAdjacentTo(nodeA, nodeB)
            && !graph.isChildOf(nodeA, nodeB)) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
        if (!graph.isChildOf(nodeA, nodeB)
            && getKnowledge().isForbidden(nodeA.getName(), nodeB.getName())) {
          if (!graph.isAncestorOf(nodeA, nodeB)) {
            graph.removeEdges(nodeA, nodeB);
            graph.addDirectedEdge(nodeB, nodeA);
            TetradLogger.getInstance()
                .log("insertedEdges", "Adding edge by knowledge: " + graph.getEdge(nodeB, nodeA));
          }
        }
      }
    }
  }
예제 #12
0
  /**
   * Do an actual deletion (Definition 13 from Chickering, 2002).
   *
   * <p>Removes the edge between {@code x} and {@code y}. Then, for each {@code h} in {@code
   * subset}, it replaces the y–h edge with y → h and, when the x–h edge is undirected, replaces
   * it with x → h. The deletion is logged when both {@code log} and {@code verbose} are set;
   * otherwise a progress count is printed to {@code out} every 50 edges. Individual orientations
   * are logged under {@code log} and printed under {@code verbose}.
   *
   * @param x one endpoint of the edge being deleted
   * @param y the other endpoint of the edge being deleted
   * @param subset nodes h whose edge to y (and undirected edge to x) is oriented
   * @param graph the graph modified in place
   * @param bump value printed with the operation (presumably the score change — confirm)
   * @throws IllegalArgumentException if some h in {@code subset} is not adjacent to {@code x}
   */
  private void delete(Node x, Node y, List<Node> subset, Graph graph, double bump) {

    Edge trueEdge = null;

    if (trueGraph != null) {
      Node _x = trueGraph.getNode(x.getName());
      Node _y = trueGraph.getNode(y.getName());
      trueEdge = trueGraph.getEdge(_x, _y);
    }

    if (log && verbose) {
      Edge oldEdge = graph.getEdge(x, y);

      // "*" marks a deletion of an edge that also exists in the reference (true) graph.
      String label = trueGraph != null && trueEdge != null ? "*" : "";
      TetradLogger.getInstance()
          .log(
              "deletedEdges",
              (graph.getNumEdges() - 1)
                  + ". DELETE "
                  + oldEdge
                  + " "
                  + subset
                  + " ("
                  + bump
                  + ") "
                  + label);
      out.println(
          (graph.getNumEdges() - 1)
              + ". DELETE "
              + oldEdge
              + " "
              + subset
              + " ("
              + bump
              + ") "
              + label);
    } else {
      int numEdges = graph.getNumEdges() - 1;
      if (numEdges % 50 == 0) out.println(numEdges);
    }

    graph.removeEdge(x, y);

    for (Node h : subset) {
      Edge oldEdge = graph.getEdge(y, h);

      graph.removeEdge(y, h);
      graph.addDirectedEdge(y, h);

      if (log) {
        TetradLogger.getInstance()
            .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(y, h));
      }

      if (verbose) {
        out.println("--- Directing " + oldEdge + " to " + graph.getEdge(y, h));
      }

      // Look up the x-h edge once. A missing edge means h is not adjacent to x, which the
      // operator's precondition rules out. Report that clearly rather than passing null to
      // Edges.isUndirectedEdge.
      Edge xh = graph.getEdge(x, h);
      if (xh == null) {
        throw new IllegalArgumentException("Not adjacent: " + x + ", " + h);
      }

      if (Edges.isUndirectedEdge(xh)) {
        oldEdge = xh;

        graph.removeEdge(x, h);
        graph.addDirectedEdge(x, h);

        if (log) {
          TetradLogger.getInstance()
              .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(x, h));
        }

        if (verbose) {
          out.println("--- Directing " + oldEdge + " to " + graph.getEdge(x, h));
        }
      }
    }
  }
예제 #13
0
  // serial.
  /**
   * Adds the directed edge x → y, then replaces the edge between each node in {@code t} and y
   * with a directed edge into y.
   *
   * <p>Does nothing if x and y are already adjacent. Logging and printing depend on two flags:
   *
   * <ul>
   *   <li>{@code log} set: the insertion is written to the TetradLogger.
   *   <li>{@code verbose} set: the insertion is also printed to {@code out}.
   *   <li>{@code verbose} set and {@code log} unset: the edge count before this insertion is
   *       printed first whenever it is a multiple of 50.
   *   <li>both set: each re-orientation into y is also logged and printed.
   * </ul>
   *
   * @param x tail of the inserted edge
   * @param y head of the inserted edge
   * @param t nodes whose edges to y are oriented into y
   * @param graph the graph modified in place
   * @param bump value printed with the operation (presumably the score change — confirm)
   * @throws IllegalArgumentException if some node in {@code t} is not adjacent to y
   */
  private void insert(Node x, Node y, List<Node> t, Graph graph, double bump) {
    if (graph.isAdjacentTo(x, y)) {
      return; // The initial graph may already have put this edge in the graph.
    }

    Edge trueEdge = null;

    if (trueGraph != null) {
      Node _x = trueGraph.getNode(x.getName());
      Node _y = trueGraph.getNode(y.getName());
      trueEdge = trueGraph.getEdge(_x, _y);
    }

    graph.addDirectedEdge(x, y);

    // "*" marks an insertion that matches an edge of the reference (true) graph.
    String label = trueGraph != null && trueEdge != null ? "*" : "";

    if (log) {
      TetradLogger.getInstance()
          .log(
              "insertedEdges",
              graph.getNumEdges()
                  + ". INSERT "
                  + graph.getEdge(x, y)
                  + " "
                  + t
                  + " "
                  + bump
                  + " "
                  + label);
    } else if (verbose) {
      // Progress indicator: the edge count before this insertion, every 50 edges.
      int numEdges = graph.getNumEdges() - 1;
      if (numEdges % 50 == 0) out.println(numEdges);
    }

    if (verbose) {
      out.println(
          graph.getNumEdges()
              + ". INSERT "
              + graph.getEdge(x, y)
              + " "
              + t
              + " "
              + bump
              + " "
              + label);
    }

    for (Node _t : t) {
      Edge oldEdge = graph.getEdge(_t, y);

      if (oldEdge == null) throw new IllegalArgumentException("Not adjacent: " + _t + ", " + y);

      graph.removeEdge(_t, y);
      graph.addDirectedEdge(_t, y);

      if (log && verbose) {
        TetradLogger.getInstance()
            .log("directedEdges", "--- Directing " + oldEdge + " to " + graph.getEdge(_t, y));
        out.println("--- Directing " + oldEdge + " to " + graph.getEdge(_t, y));
      }
    }
  }
예제 #14
0
파일: Lofs.java 프로젝트: jdramsey/tetrad
  /**
   * Decides the orientation of the x—y adjacency by comparing scores of each endpoint with and
   * without the other in its conditioning set, then replaces the adjacency with the winner.
   *
   * <p>For every choice of x's other neighbors and every choice of y's other neighbors (presumably
   * all subsets, via DepthChoiceGenerator with depth equal to the neighbor count — confirm), four
   * scores are computed:
   *
   * <ul>
   *   <li>{@code xPlus}: x conditioned on its chosen neighbors plus y;
   *   <li>{@code xMinus}: x conditioned on its chosen neighbors without y;
   *   <li>{@code yPlus}: y conditioned on its chosen neighbors plus x;
   *   <li>{@code yMinus}: y conditioned on its chosen neighbors without x.
   * </ul>
   *
   * <p>Combinations for which {@code normal(...)} holds for any of the four (node, set) pairs are
   * skipped. The remaining combinations are classified as y → x, x → y or "no directed edge" by
   * comparing the four scores, with extra conditions when {@code strong} is true. The directed
   * classification with the highest {@code combinedScore} wins. The report for every classified
   * combination is logged, sorted best score first. Reports with identical scores overwrite each
   * other, because the map is keyed by score.
   *
   * <p>Finally, all x—y edges are removed and replaced with y → x or x → y for the winner. If no
   * directed classification won, an undirected edge is added instead.
   *
   * @param graph graph modified in place
   * @param x one endpoint (swapped with y at random before scoring)
   * @param y the other endpoint
   * @param strong whether to apply the strong orientation rules
   * @param oldGraph unused except in commented-out code
   */
  private void resolveOneEdgeMax(Graph graph, Node x, Node y, boolean strong, Graph oldGraph) {
    // Randomize which endpoint plays the role of x (one draw from RandomUtil).
    if (RandomUtil.getInstance().nextDouble() > 0.5) {
      Node temp = x;
      x = y;
      y = temp;
    }

    TetradLogger.getInstance().log("info", "\nEDGE " + x + " --- " + y);

    // Keyed by -score, so ascending TreeMap order lists the best-scoring reports first.
    SortedMap<Double, String> scoreReports = new TreeMap<Double, String>();

    List<Node> neighborsx = graph.getAdjacentNodes(x);
    neighborsx.remove(y);

    // Best directed classification so far: left means y --> x, right means x --> y.
    double max = Double.NEGATIVE_INFINITY;
    boolean left = false;
    boolean right = false;

    DepthChoiceGenerator genx = new DepthChoiceGenerator(neighborsx.size(), neighborsx.size());
    int[] choicex;

    while ((choicex = genx.next()) != null) {
      List<Node> condxMinus = GraphUtils.asList(choicex, neighborsx);

      List<Node> condxPlus = new ArrayList<Node>(condxMinus);
      condxPlus.add(y);

      double xPlus = score(x, condxPlus);
      double xMinus = score(x, condxMinus);

      // NOTE(review): recomputed on every outer iteration although graph is not modified inside
      // these loops; it could be hoisted.
      List<Node> neighborsy = graph.getAdjacentNodes(y);
      neighborsy.remove(x);

      DepthChoiceGenerator geny = new DepthChoiceGenerator(neighborsy.size(), neighborsy.size());
      int[] choicey;

      while ((choicey = geny.next()) != null) {
        List<Node> condyMinus = GraphUtils.asList(choicey, neighborsy);

        //                List<Node> parentsY = oldGraph.getParents(y);
        //                parentsY.remove(x);
        //                if (!condyMinus.containsAll(parentsY)) {
        //                    continue;
        //                }

        List<Node> condyPlus = new ArrayList<Node>(condyMinus);
        condyPlus.add(x);

        double yPlus = score(y, condyPlus);
        double yMinus = score(y, condyMinus);

        // Checking them all at once is expensive but avoids lexical ordering problems in the
        // algorithm.
        if (normal(y, condyPlus)
            || normal(x, condxMinus)
            || normal(x, condxPlus)
            || normal(y, condyMinus)) {
          continue;
        }

        // Tolerance for the score comparisons below; fixed at zero here.
        double delta = 0.0;

        if (strong) {
          if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) {
            // Candidate y --> x; the strong rule additionally requires the reverse inequalities.
            double score = combinedScore(xPlus, yMinus);

            if (yPlus <= yMinus + delta && xMinus <= xPlus + delta) {
              StringBuilder builder = new StringBuilder();

              builder.append("\nStrong " + y + "->" + x + " " + score);
              builder.append("\n   Parents(" + x + ") = " + condxMinus);
              builder.append("\n   Parents(" + y + ") = " + condyMinus);

              scoreReports.put(-score, builder.toString());

              if (score > max) {
                max = score;
                left = true;
                right = false;
              }
            } else {
              StringBuilder builder = new StringBuilder();

              builder.append("\nNo directed edge " + x + "--" + y + " " + score);
              builder.append("\n   Parents(" + x + ") = " + condxMinus);
              builder.append("\n   Parents(" + y + ") = " + condyMinus);

              scoreReports.put(-score, builder.toString());
            }
          } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) {
            // Candidate x --> y; symmetric to the branch above.
            double score = combinedScore(yPlus, xMinus);

            if (yMinus <= yPlus + delta && xPlus <= xMinus + delta) {
              StringBuilder builder = new StringBuilder();

              builder.append("\nStrong " + x + "->" + y + " " + score);
              builder.append("\n   Parents(" + x + ") = " + condxMinus);
              builder.append("\n   Parents(" + y + ") = " + condyMinus);

              scoreReports.put(-score, builder.toString());

              if (score > max) {
                max = score;
                left = false;
                right = true;
              }
            } else {
              StringBuilder builder = new StringBuilder();

              builder.append("\nNo directed edge " + x + "--" + y + " " + score);
              builder.append("\n   Parents(" + x + ") = " + condxMinus);
              builder.append("\n   Parents(" + y + ") = " + condyMinus);

              scoreReports.put(-score, builder.toString());
            }
          } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) {
            // Inconclusive: report only, no orientation recorded.
            double score = combinedScore(yPlus, xMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nNo directed edge " + x + "--" + y + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());
          } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) {
            // Inconclusive: report only, no orientation recorded.
            double score = combinedScore(yPlus, xMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nNo directed edge " + x + "--" + y + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());
          }
        } else {
          // Weak rules: the first pair of inequalities alone decides the direction.
          if (yPlus <= xPlus + delta && xMinus <= yMinus + delta) {
            double score = combinedScore(xPlus, yMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nWeak " + y + "->" + x + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());

            if (score > max) {
              max = score;
              left = true;
              right = false;
            }
          } else if (xPlus <= yPlus + delta && yMinus <= xMinus + delta) {
            double score = combinedScore(yPlus, xMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nWeak " + x + "->" + y + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());

            if (score > max) {
              max = score;
              left = false;
              right = true;
            }
          } else if (yPlus <= xPlus + delta && yMinus <= xMinus + delta) {
            double score = combinedScore(yPlus, xMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nNo directed edge " + x + "--" + y + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());
          } else if (xPlus <= yPlus + delta && xMinus <= yMinus + delta) {
            double score = combinedScore(yPlus, xMinus);

            StringBuilder builder = new StringBuilder();

            builder.append("\nNo directed edge " + x + "--" + y + " " + score);
            builder.append("\n   Parents(" + x + ") = " + condxMinus);
            builder.append("\n   Parents(" + y + ") = " + condyMinus);

            scoreReports.put(-score, builder.toString());
          }
        }
      }
    }

    for (double score : scoreReports.keySet()) {
      TetradLogger.getInstance().log("info", scoreReports.get(score));
    }

    // Replace the adjacency with the winning orientation (at most one of left/right is set).
    graph.removeEdges(x, y);

    if (left) {
      graph.addDirectedEdge(y, x);
    }

    if (right) {
      graph.addDirectedEdge(x, y);
    }

    // No directed winner: keep the adjacency as an undirected edge.
    if (!graph.isAdjacentTo(x, y)) {
      graph.addUndirectedEdge(x, y);
    }
  }
예제 #15
0
  /**
   * Simulates 1000 rows from a four-variable generalized SEM and re-estimates it. The model has
   * a squared term and Gamma / chi-square errors. The test checks the estimator's a-squared-star
   * statistic for the seeded run.
   *
   * <p>A malformed expression throws {@link ParseException}, which fails the test instead of
   * being printed and ignored.
   */
  @Test
  public void test15() throws ParseException {
    RandomUtil.getInstance().setSeed(29999483L);

    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");

    Graph g = new EdgeListGraphSingleConnections();
    g.addNode(x1);
    g.addNode(x2);
    g.addNode(x3);
    g.addNode(x4);

    g.addDirectedEdge(x1, x2);
    g.addDirectedEdge(x2, x3);
    g.addDirectedEdge(x3, x4);
    g.addDirectedEdge(x1, x4);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);

    pm.setNodeExpression(x1, "E_X1");
    pm.setNodeExpression(x2, "a1 * X1 + E_X2");
    pm.setNodeExpression(x3, "a2 * X2 + E_X3");
    pm.setNodeExpression(x4, "a3 * X1 + a4 * X3 ^ 2 + E_X4");

    pm.setNodeExpression(pm.getErrorNode(x1), "Gamma(c1, c2)");
    pm.setNodeExpression(pm.getErrorNode(x2), "ChiSquare(c3)");
    pm.setNodeExpression(pm.getErrorNode(x3), "ChiSquare(c4)");
    pm.setNodeExpression(pm.getErrorNode(x4), "ChiSquare(c5)");

    pm.setParameterExpression("c1", "5");
    pm.setParameterExpression("c2", "2");
    pm.setParameterExpression("c3", "10");
    pm.setParameterExpression("c4", "10");
    pm.setParameterExpression("c5", "10");

    // Distributions the estimator draws starting values from.
    pm.setParameterEstimationInitializationExpression("c1", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c2", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c3", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c4", "U(1, 5)");
    pm.setParameterEstimationInitializationExpression("c5", "U(1, 5)");

    GeneralizedSemIm im = new GeneralizedSemIm(pm);

    print("True model: ");
    print(im);

    DataSet data = im.simulateDataRecursive(1000, false);

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estIm = estimator.estimate(pm, data);

    print("\n\n\nEstimated model: ");
    print(estIm);
    print(estimator.getReport());

    double aSquaredStar = estimator.getaSquaredStar();

    assertEquals(.79, aSquaredStar, 0.01);
  }
예제 #16
0
  /**
   * Simulates 1000 rows from a four-variable generalized SEM with tan() and squared terms and
   * Gaussian errors, then re-estimates it. The test checks the estimator's a-squared-star
   * statistic for the seeded run.
   *
   * <p>A malformed expression throws {@link ParseException}, which fails the test instead of
   * being printed and ignored.
   */
  @Test
  public void test14() throws ParseException {
    RandomUtil.getInstance().setSeed(29999483L);

    Node x1 = new GraphNode("X1");
    Node x2 = new GraphNode("X2");
    Node x3 = new GraphNode("X3");
    Node x4 = new GraphNode("X4");

    Graph g = new EdgeListGraphSingleConnections();
    g.addNode(x1);
    g.addNode(x2);
    g.addNode(x3);
    g.addNode(x4);

    g.addDirectedEdge(x1, x2);
    g.addDirectedEdge(x2, x3);
    g.addDirectedEdge(x3, x4);
    g.addDirectedEdge(x1, x4);

    GeneralizedSemPm pm = new GeneralizedSemPm(g);

    pm.setNodeExpression(x1, "E_X1");
    pm.setNodeExpression(x2, "a1 * tan(X1) + E_X2");
    pm.setNodeExpression(x3, "a2 * tan(X2) + E_X3");
    pm.setNodeExpression(x4, "a3 * tan(X1) + a4 * tan(X3) ^ 2 + E_X4");

    pm.setNodeExpression(pm.getErrorNode(x1), "N(0, c1)");
    pm.setNodeExpression(pm.getErrorNode(x2), "N(0, c2)");
    pm.setNodeExpression(pm.getErrorNode(x3), "N(0, c3)");
    pm.setNodeExpression(pm.getErrorNode(x4), "N(0, c4)");

    pm.setParameterExpression("a1", "1");
    pm.setParameterExpression("a2", "1");
    pm.setParameterExpression("a3", "1");
    pm.setParameterExpression("a4", "1");
    pm.setParameterExpression("c1", "4");
    pm.setParameterExpression("c2", "4");
    pm.setParameterExpression("c3", "4");
    pm.setParameterExpression("c4", "4");

    GeneralizedSemIm im = new GeneralizedSemIm(pm);

    print("True model: ");
    print(im);

    DataSet data = im.simulateDataRecursive(1000, false);

    // NOTE(review): imInit is never passed to the estimator below, so these starting values do
    // not affect estimation. It is kept because constructing it may consume draws from the
    // seeded RandomUtil stream, which the expected value asserted below may depend on — confirm
    // before removing.
    GeneralizedSemIm imInit = new GeneralizedSemIm(pm);
    imInit.setParameterValue("c1", 8);
    imInit.setParameterValue("c2", 8);
    imInit.setParameterValue("c3", 8);
    imInit.setParameterValue("c4", 8);

    GeneralizedSemEstimator estimator = new GeneralizedSemEstimator();
    GeneralizedSemIm estIm = estimator.estimate(pm, data);

    print("\n\n\nEstimated model: ");
    print(estIm);
    print(estimator.getReport());

    double aSquaredStar = estimator.getaSquaredStar();

    assertEquals(71.25, aSquaredStar, 0.01);
  }