コード例 #1
0
ファイル: Knowledge2.java プロジェクト: ps7z/tetrad
  /**
   * Replaces the contents of the given tier with the named variables.
   *
   * <p>The tier is first made available via {@code ensureTiers(tier)}. If a set is already
   * registered for that tier in {@code tierSpecs} it is emptied, and then every name in
   * {@code vars} is added through {@code addToTier(tier, name)}.
   *
   * @param tier the index of the tier to overwrite
   * @param vars the variable names the tier should contain afterwards
   */
  public void setTier(int tier, List<String> vars) {
    ensureTiers(tier);

    Set<MyNode> existing = tierSpecs.get(tier);
    if (existing != null) {
      existing.clear();
    }

    for (String name : vars) {
      addToTier(tier, name);
    }
  }
コード例 #2
0
  /**
   * Grows the given effect seeds into larger clusters and then selects a pairwise disjoint
   * subset of the grown clusters, preferring larger clusters.
   *
   * <p>The grow strategy is chosen by {@code algType}. In every strategy a candidate variable
   * index {@code o} is tested by forming the triple {n1, n2, o} for each pair (n1, n2) of the
   * current cluster and looking it up in {@code ESeeds}. Seeds covered by a grown cluster are
   * removed from {@code ESeeds}, so the argument may be modified by this call.
   *
   * @param ESeeds the seed clusters, as sets of variable indices; mutated in place
   * @return the selected grown clusters, no two of which share a variable index
   */
  private Set<Set<Integer>> finishESeeds(Set<Set<Integer>> ESeeds) {
    log("Growing Effect Seeds.", true);
    Set<Set<Integer>> grown = new HashSet<Set<Integer>>();

    // Candidate indices 0 .. variables.size() - 1; shuffled in place per seed when extraShuffle
    // is set.
    List<Integer> candidates = new ArrayList<Integer>();
    for (int i = 0; i < variables.size(); i++) {
      candidates.add(i);
    }

    // Lax grow phase: majority vote per candidate, seeds consumed as clusters are grown.
    if (algType == AlgType.lax) {
      growConsumingSeeds(ESeeds, candidates, grown, false);
    }

    // Lax grow phase that grows every seed present when the phase starts.
    if (algType == AlgType.laxWithSpeedup) {
      growFromSeedSnapshot(ESeeds, candidates, grown);
    }

    // Strict grow phase: every triple must be a seed, seeds consumed as clusters are grown.
    if (algType == AlgType.strict) {
      growConsumingSeeds(ESeeds, candidates, grown, true);
    }

    log("Choosing among grown Effect Clusters.", true);

    // Only copy and sort the clusters when the result is actually going to be logged.
    if (verbose) {
      for (Set<Integer> cluster : grown) {
        ArrayList<Integer> sorted = new ArrayList<Integer>(cluster);
        Collections.sort(sorted);
        log("Grown: " + variablesForIndices(sorted), false);
      }
    }

    return pickDisjointLargestFirst(grown);
  }

  /**
   * Repeatedly takes the next remaining seed, grows it, and removes every seed triple contained
   * in the grown cluster, until no seeds remain.
   *
   * @param seeds the remaining seeds; emptied by this method
   * @param candidates the variable indices to try adding to each cluster
   * @param grown receives each grown cluster
   * @param strict if true a candidate needs every triple to be a seed; otherwise it is added
   *     unless more triples are missing than present
   */
  private void growConsumingSeeds(
      Set<Set<Integer>> seeds,
      List<Integer> candidates,
      Set<Set<Integer>> grown,
      boolean strict) {
    int count = 0;
    int total = seeds.size();

    while (!seeds.isEmpty()) {
      Set<Integer> seed = seeds.iterator().next();
      Set<Integer> cluster = new HashSet<Integer>(seed);

      if (extraShuffle) {
        Collections.shuffle(candidates);
      }

      for (int o : candidates) {
        if (cluster.contains(o)) {
          continue;
        }

        if (isSupportedBySeeds(cluster, o, seeds, strict)) {
          cluster.add(o);
        }
      }

      removeCoveredTriples(cluster, seeds);

      // The triple enumeration above only removes 3-element seeds. Dropping the seed explicitly
      // guarantees the loop makes progress for a seed of any other size; for a 3-element seed
      // this is a no-op because it has already been removed.
      seeds.remove(seed);

      if (verbose) {
        String prefix = "Grown " + (++count) + " of " + total + ": ";
        if (strict) {
          System.out.println(prefix + cluster);
        } else {
          System.out.println(prefix + variablesForIndices(new ArrayList<Integer>(cluster)));
        }
      }

      grown.add(cluster);
    }
  }

  /**
   * Decides whether candidate {@code o} may join {@code cluster} by checking, for every pair of
   * current members, whether the triple formed with {@code o} is a seed.
   *
   * @return in strict mode, true only if no triple is missing; otherwise true unless missing
   *     triples outnumber present ones (a cluster with fewer than two members yields no pairs
   *     and therefore true)
   */
  private boolean isSupportedBySeeds(
      Set<Integer> cluster, int o, Set<Set<Integer>> seeds, boolean strict) {
    List<Integer> members = new ArrayList<Integer>(cluster);
    ChoiceGenerator gen = new ChoiceGenerator(members.size(), 2);
    Set<Integer> t = new HashSet<Integer>();
    int rejected = 0;
    int accepted = 0;
    int[] choice;

    while ((choice = gen.next()) != null) {
      t.clear();
      t.add(members.get(choice[0]));
      t.add(members.get(choice[1]));
      t.add(o);

      if (seeds.contains(t)) {
        accepted++;
      } else {
        if (strict) {
          return false;
        }
        rejected++;
      }
    }

    return rejected <= accepted;
  }

  /** Removes from {@code seeds} every 3-element subset of {@code cluster}. */
  private void removeCoveredTriples(Set<Integer> cluster, Set<Set<Integer>> seeds) {
    ChoiceGenerator gen = new ChoiceGenerator(cluster.size(), 3);
    List<Integer> members = new ArrayList<Integer>(cluster);
    Set<Integer> t = new HashSet<Integer>();
    int[] choice;

    while ((choice = gen.next()) != null) {
      t.clear();
      t.add(members.get(choice[0]));
      t.add(members.get(choice[1]));
      t.add(members.get(choice[2]));

      seeds.remove(t);
    }
  }

  /**
   * Grows every seed that is present when the method starts, iterating over a snapshot so that
   * seeds removed along the way are still grown. After each cluster is grown, all seeds fully
   * contained in it are removed from {@code seeds}.
   *
   * @param seeds the seeds; covered seeds are removed in place
   * @param candidates the variable indices to try adding to each cluster
   * @param grown receives each grown cluster
   */
  private void growFromSeedSnapshot(
      Set<Set<Integer>> seeds, List<Integer> candidates, Set<Set<Integer>> grown) {
    int count = 0;
    int total = seeds.size();

    for (Set<Integer> seed : new HashSet<Set<Integer>>(seeds)) {
      Set<Integer> cluster = new HashSet<Integer>(seed);

      if (extraShuffle) {
        Collections.shuffle(candidates);
      }

      for (int o : candidates) {
        if (cluster.contains(o)) {
          continue;
        }

        List<Integer> members = new ArrayList<Integer>(cluster);
        int rejected = 0;
        int accepted = 0;

        ChoiceGenerator gen = new ChoiceGenerator(members.size(), 2);
        int[] choice;

        while ((choice = gen.next()) != null) {
          int n1 = members.get(choice[0]);
          int n2 = members.get(choice[1]);

          if (seeds.contains(triple(n1, n2, o))) {
            accepted++;
          } else {
            rejected++;
          }
        }

        // Majority vote: add the candidate unless missing triples outnumber present ones.
        if (rejected <= accepted) {
          cluster.add(o);
        }
      }

      // Iterate over a copy because seeds is modified inside the loop.
      for (Set<Integer> c : new HashSet<Set<Integer>>(seeds)) {
        if (cluster.containsAll(c)) {
          seeds.remove(c);
        }
      }

      if (verbose) {
        System.out.println("Grown " + (++count) + " of " + total + ": " + cluster);
      }

      grown.add(cluster);
    }
  }

  /**
   * Greedily selects clusters in descending order of size, skipping any cluster that shares a
   * variable index with one already selected.
   *
   * @param grown the grown clusters to choose from
   * @return the selected, pairwise disjoint clusters
   */
  private Set<Set<Integer>> pickDisjointLargestFirst(Set<Set<Integer>> grown) {
    List<Set<Integer>> bySize = new ArrayList<Set<Integer>>(grown);

    Collections.sort(
        bySize,
        new Comparator<Set<Integer>>() {
          @Override
          public int compare(Set<Integer> o1, Set<Integer> o2) {
            // Descending by size; sizes are non-negative, so the subtraction cannot overflow.
            return o2.size() - o1.size();
          }
        });

    Set<Set<Integer>> out = new HashSet<Set<Integer>>();
    Set<Integer> used = new HashSet<Integer>();

    for (Set<Integer> cluster : bySize) {
      if (!Collections.disjoint(used, cluster)) {
        continue;
      }

      out.add(cluster);
      used.addAll(cluster);
    }

    return out;
  }