/**
 * Collects every index that occurs in at least one of the given clusters.
 *
 * @param pureClusters the clusters to merge; the argument itself is not modified
 * @return a newly allocated set containing the members of all clusters
 */
private Set<Integer> unionPure(Set<Set<Integer>> pureClusters) {
  Set<Integer> merged = new HashSet<Integer>();

  for (Set<Integer> members : pureClusters) {
    for (Integer member : members) {
      merged.add(member);
    }
  }

  return merged;
}
/**
 * Builds a three-element set from the given indices.
 *
 * @param n1 first index
 * @param n2 second index
 * @param n3 third index
 * @return a new set containing exactly {@code n1}, {@code n2} and {@code n3}
 * @throws IllegalArgumentException if any two of the indices are equal
 */
private Set<Integer> triple(int n1, int n2, int n3) {
  boolean allDistinct = n1 != n2 && n1 != n3 && n2 != n3;

  if (!allDistinct) {
    throw new IllegalArgumentException(
        "Triple elements must be unique: <" + n1 + ", " + n2 + ", " + n3 + ">");
  }

  Set<Integer> members = new HashSet<Integer>();
  members.add(n1);
  members.add(n2);
  members.add(n3);
  return members;
}
/**
 * Checks a four-element set by handing its members, in the set's iteration order, to
 * {@code testVanishing}.
 *
 * @param quartet a set that must contain exactly four indices
 * @return whatever {@code testVanishing} reports for the four members
 * @throws IllegalArgumentException if the set does not have exactly four elements
 */
private boolean quartetVanishes(Set<Integer> quartet) {
  final int size = quartet.size();
  if (size != 4) {
    throw new IllegalArgumentException("Expecting a quartet, size = " + size);
  }

  Iterator<Integer> members = quartet.iterator();
  int first = members.next();
  int second = members.next();
  int third = members.next();
  int fourth = members.next();

  return testVanishing(first, second, third, fourth);
}
/**
 * Builds a four-element set from the given indices.
 *
 * @param x first index
 * @param y second index
 * @param z third index
 * @param w fourth index
 * @return a new set containing exactly the four indices
 * @throws IllegalArgumentException if any two of the indices are equal
 */
private Set<Integer> quartet(int x, int y, int z, int w) {
  Set<Integer> members = new HashSet<Integer>();

  // Set.add returns false for a value that is already present, i.e. a duplicate argument.
  boolean allNew = members.add(x);
  allNew &= members.add(y);
  allNew &= members.add(z);
  allNew &= members.add(w);

  if (!allNew) {
    throw new IllegalArgumentException(
        "Quartet elements must be unique: <" + x + ", " + y + ", " + z + ", " + w + ">");
  }

  return members;
}
/**
 * Translates index clusters into node clusters by looking each index up in {@code variables},
 * then passes the result to {@code convertSearchGraphNodes}.
 *
 * @param allClusters clusters expressed as indices into {@code variables}
 * @return the graph produced by {@code convertSearchGraphNodes} for the node clusters
 */
private Graph convertToGraph(Set<Set<Integer>> allClusters) {
  Set<Set<Node>> nodeClusters = new HashSet<Set<Node>>();

  for (Set<Integer> indices : allClusters) {
    Set<Node> members = new HashSet<Node>();

    for (int index : indices) {
      Node member = variables.get(index);
      members.add(member);
    }

    nodeClusters.add(members);
  }

  return convertSearchGraphNodes(nodeClusters);
}
/**
 * Tests whether every pair of cluster members is adjacent according to the given map.
 *
 * @param cluster indices into {@code variables}
 * @param adjacencies for each node, the set of nodes it is adjacent to
 * @return {@code true} if, for every pair taken in list order, the later member is in the
 *     adjacency set of the earlier one; {@code false} as soon as one pair is not
 */
private boolean clique(Set<Integer> cluster, Map<Node, Set<Node>> adjacencies) {
  List<Integer> members = new ArrayList<Integer>(cluster);
  final int n = cluster.size();

  for (int a = 0; a < n; a++) {
    for (int b = a + 1; b < n; b++) {
      Node first = variables.get(members.get(a));
      Node second = variables.get(members.get(b));

      boolean adjacent = adjacencies.get(first).contains(second);
      if (!adjacent) {
        return false;
      }
    }
  }

  return true;
}
/**
 * Decides whether a quartet counts as pure with respect to the given variables.
 *
 * <p>The quartet itself must pass {@code quartetVanishes}; so must every quartet obtained by
 * swapping one member for one variable outside the quartet. If all of those pass, the result is
 * whatever {@code significant} reports for the original quartet.
 *
 * @param quartet the four indices under test
 * @param variables the indices to try as replacements
 * @return {@code false} if any vanishing check fails, otherwise the result of {@code significant}
 */
private boolean pure(Set<Integer> quartet, List<Integer> variables) {
  if (!quartetVanishes(quartet)) {
    return false;
  }

  for (int outsider : variables) {
    if (quartet.contains(outsider)) {
      continue;
    }

    for (int replaced : quartet) {
      Set<Integer> swapped = new HashSet<Integer>(quartet);
      swapped.remove(replaced);
      swapped.add(outsider);

      if (!quartetVanishes(swapped)) {
        return false;
      }
    }
  }

  return significant(new ArrayList<Integer>(quartet));
}
@Test public void test1() { GeneralizedSemPm pm = makeTypicalPm(); print(pm); Node x1 = pm.getNode("X1"); Node x2 = pm.getNode("X2"); Node x3 = pm.getNode("X3"); Node x4 = pm.getNode("X4"); Node x5 = pm.getNode("X5"); SemGraph graph = pm.getGraph(); List<Node> variablesNodes = pm.getVariableNodes(); print(variablesNodes); List<Node> errorNodes = pm.getErrorNodes(); print(errorNodes); try { pm.setNodeExpression(x1, "cos(B1) + E_X1"); print(pm); String b1 = "B1"; String b2 = "B2"; String b3 = "B3"; Set<Node> nodes = pm.getReferencingNodes(b1); assertTrue(nodes.contains(x1)); assertTrue(!nodes.contains(x2) && !nodes.contains(x2)); Set<String> referencedParameters = pm.getReferencedParameters(x3); print("Parameters referenced by X3 are: " + referencedParameters); assertTrue(referencedParameters.contains(b1) && referencedParameters.contains(b2)); assertTrue(!(referencedParameters.contains(b1) && referencedParameters.contains(b3))); Node e_x3 = pm.getNode("E_X3"); // for (Node node : pm.getNodes()) { Set<Node> referencingNodes = pm.getReferencingNodes(node); print("Nodes referencing " + node + " are: " + referencingNodes); } for (Node node : pm.getVariableNodes()) { Set<Node> referencingNodes = pm.getReferencedNodes(node); print("Nodes referenced by " + node + " are: " + referencingNodes); } Set<Node> referencingX3 = pm.getReferencingNodes(x3); assertTrue(referencingX3.contains(x4)); assertTrue(!referencingX3.contains(x5)); Set<Node> referencedByX3 = pm.getReferencedNodes(x3); assertTrue( referencedByX3.contains(x1) && referencedByX3.contains(x2) && referencedByX3.contains(e_x3) && !referencedByX3.contains(x4)); pm.setNodeExpression(x5, "a * E^X2 + X4 + E_X5"); Node e_x5 = pm.getErrorNode(x5); graph.setShowErrorTerms(true); assertTrue(e_x5.equals(graph.getExogenous(x5))); pm.setNodeExpression(e_x5, "Beta(3, 5)"); print(pm); assertEquals("Split(-1.5,-.5,.5,1.5)", pm.getParameterExpressionString(b1)); pm.setParameterExpression(b1, "N(0, 2)"); assertEquals("N(0, 2)", 
pm.getParameterExpressionString(b1)); GeneralizedSemIm im = new GeneralizedSemIm(pm); print(im); DataSet dataSet = im.simulateDataAvoidInfinity(10, false); print(dataSet); } catch (ParseException e) { e.printStackTrace(); } }
@Test public void test4() { // For X3 Map<String, String[]> templates = new HashMap<>(); templates.put( "NEW(b) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("NEW(X1) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", new String[] {}); templates.put("$", new String[] {}); templates.put("TSUM($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($) + X2", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TPROD($) + TSUM($)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("tan(TSUM(NEW(a)*$))", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(0, 1)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(m, s)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(NEW(m), s)", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("Normal(NEW(m), NEW(s)) + m1 + s6", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TSUM($) + a", new String[] {"X1", "X2", "X3", "X4", "X5"}); templates.put("TSUM($) + TSUM($) + TSUM($) + 1", new String[] {"X1", "X2", "X3", "X4", "X5"}); for (String template : templates.keySet()) { GeneralizedSemPm semPm = makeTypicalPm(); print(semPm.getGraph().toString()); Set<Node> shouldWork = new HashSet<>(); for (String name : templates.get(template)) { shouldWork.add(semPm.getNode(name)); } Set<Node> works = new HashSet<>(); for (int i = 0; i < semPm.getNodes().size(); i++) { print("-----------"); print(semPm.getNodes().get(i).toString()); print("Trying template: " + template); String _template = template; Node node = semPm.getNodes().get(i); try { _template = TemplateExpander.getInstance().expandTemplate(_template, semPm, node); } catch (Exception e) { print("Couldn't expand template: " + template); continue; } try { semPm.setNodeExpression(node, _template); print("Set formula " + _template + " for " + node); if 
(semPm.getVariableNodes().contains(node)) { works.add(node); } } catch (Exception e) { print("Couldn't set formula " + _template + " for " + node); } } for (String parameter : semPm.getParameters()) { print("-----------"); print(parameter); print("Trying template: " + template); String _template = template; try { _template = TemplateExpander.getInstance().expandTemplate(_template, semPm, null); } catch (Exception e) { print("Couldn't expand template: " + template); continue; } try { semPm.setParameterExpression(parameter, _template); print("Set formula " + _template + " for " + parameter); } catch (Exception e) { print("Couldn't set formula " + _template + " for " + parameter); } } assertEquals(shouldWork, works); } }
private Set<Set<Integer>> finishESeeds(Set<Set<Integer>> ESeeds) { log("Growing Effect Seeds.", true); Set<Set<Integer>> grown = new HashSet<Set<Integer>>(); List<Integer> _variables = new ArrayList<Integer>(); for (int i = 0; i < variables.size(); i++) _variables.add(i); // Lax grow phase with speedup. if (algType == AlgType.lax) { Set<Integer> t = new HashSet<Integer>(); int count = 0; int total = ESeeds.size(); do { if (!ESeeds.iterator().hasNext()) { break; } Set<Integer> cluster = ESeeds.iterator().next(); Set<Integer> _cluster = new HashSet<Integer>(cluster); if (extraShuffle) { Collections.shuffle(_variables); } for (int o : _variables) { if (_cluster.contains(o)) continue; List<Integer> _cluster2 = new ArrayList<Integer>(_cluster); int rejected = 0; int accepted = 0; ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2); int[] choice; while ((choice = gen.next()) != null) { int n1 = _cluster2.get(choice[0]); int n2 = _cluster2.get(choice[1]); t.clear(); t.add(n1); t.add(n2); t.add(o); if (!ESeeds.contains(t)) { rejected++; } else { accepted++; } } if (rejected > accepted) { continue; } _cluster.add(o); // if (!(avgSumLnP(new ArrayList<Integer>(_cluster)) > -10)) { // _cluster.remove(o); // } } // This takes out all pure clusters that are subsets of _cluster. ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3); int[] choice2; List<Integer> _cluster3 = new ArrayList<Integer>(_cluster); while ((choice2 = gen2.next()) != null) { int n1 = _cluster3.get(choice2[0]); int n2 = _cluster3.get(choice2[1]); int n3 = _cluster3.get(choice2[2]); t.clear(); t.add(n1); t.add(n2); t.add(n3); ESeeds.remove(t); } if (verbose) { System.out.println( "Grown " + (++count) + " of " + total + ": " + variablesForIndices(new ArrayList<Integer>(_cluster))); } grown.add(_cluster); } while (!ESeeds.isEmpty()); } // Lax grow phase without speedup. if (algType == AlgType.laxWithSpeedup) { int count = 0; int total = ESeeds.size(); // Optimized lax version of grow phase. 
for (Set<Integer> cluster : new HashSet<Set<Integer>>(ESeeds)) { Set<Integer> _cluster = new HashSet<Integer>(cluster); if (extraShuffle) { Collections.shuffle(_variables); } for (int o : _variables) { if (_cluster.contains(o)) continue; List<Integer> _cluster2 = new ArrayList<Integer>(_cluster); int rejected = 0; int accepted = 0; // ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2); int[] choice; while ((choice = gen.next()) != null) { int n1 = _cluster2.get(choice[0]); int n2 = _cluster2.get(choice[1]); Set<Integer> triple = triple(n1, n2, o); if (!ESeeds.contains(triple)) { rejected++; } else { accepted++; } } // if (rejected > accepted) { continue; } // System.out.println("Adding " + o + " to " + cluster); _cluster.add(o); } for (Set<Integer> c : new HashSet<Set<Integer>>(ESeeds)) { if (_cluster.containsAll(c)) { ESeeds.remove(c); } } if (verbose) { System.out.println("Grown " + (++count) + " of " + total + ": " + _cluster); } grown.add(_cluster); } } // Strict grow phase. if (algType == AlgType.strict) { Set<Integer> t = new HashSet<Integer>(); int count = 0; int total = ESeeds.size(); do { if (!ESeeds.iterator().hasNext()) { break; } Set<Integer> cluster = ESeeds.iterator().next(); Set<Integer> _cluster = new HashSet<Integer>(cluster); if (extraShuffle) { Collections.shuffle(_variables); } VARIABLES: for (int o : _variables) { if (_cluster.contains(o)) continue; List<Integer> _cluster2 = new ArrayList<Integer>(_cluster); ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2); int[] choice; while ((choice = gen.next()) != null) { int n1 = _cluster2.get(choice[0]); int n2 = _cluster2.get(choice[1]); t.clear(); t.add(n1); t.add(n2); t.add(o); if (!ESeeds.contains(t)) { continue VARIABLES; } // if (avgSumLnP(new ArrayList<Integer>(t)) < -10) continue // CLUSTER; } _cluster.add(o); } // This takes out all pure clusters that are subsets of _cluster. 
ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3); int[] choice2; List<Integer> _cluster3 = new ArrayList<Integer>(_cluster); while ((choice2 = gen2.next()) != null) { int n1 = _cluster3.get(choice2[0]); int n2 = _cluster3.get(choice2[1]); int n3 = _cluster3.get(choice2[2]); t.clear(); t.add(n1); t.add(n2); t.add(n3); ESeeds.remove(t); } if (verbose) { System.out.println("Grown " + (++count) + " of " + total + ": " + _cluster); } grown.add(_cluster); } while (!ESeeds.isEmpty()); } // Optimized pick phase. log("Choosing among grown Effect Clusters.", true); for (Set<Integer> l : grown) { ArrayList<Integer> _l = new ArrayList<Integer>(l); Collections.sort(_l); if (verbose) { log("Grown: " + variablesForIndices(_l), false); } } Set<Set<Integer>> out = new HashSet<Set<Integer>>(); List<Set<Integer>> list = new ArrayList<Set<Integer>>(grown); // final Map<Set<Integer>, Double> pValues = new HashMap<Set<Integer>, Double>(); // // for (Set<Integer> o : grown) { // pValues.put(o, getP(new ArrayList<Integer>(o))); // } Collections.sort( list, new Comparator<Set<Integer>>() { @Override public int compare(Set<Integer> o1, Set<Integer> o2) { // if (o1.size() == o2.size()) { // double chisq1 = pValues.get(o1); // double chisq2 = pValues.get(o2); // return Double.compare(chisq2, chisq1); // } return o2.size() - o1.size(); } }); // for (Set<Integer> o : list) { // if (pValues.get(o) < alpha) continue; // System.out.println(variablesForIndices(new ArrayList<Integer>(o)) + " p = " + // pValues.get(o)); // } Set<Integer> all = new HashSet<Integer>(); CLUSTER: for (Set<Integer> cluster : list) { // if (pValues.get(cluster) < alpha) continue; for (Integer i : cluster) { if (all.contains(i)) continue CLUSTER; } out.add(cluster); // if (getPMulticluster(out) < alpha) { // out.remove(cluster); // continue; // } all.addAll(cluster); } return out; }
private Void findSeeds() { Tetrad tetrad = null; List<Node> empty = new ArrayList(); if (variables.size() < 4) { Set<Set<Integer>> ESeeds = new HashSet<Set<Integer>>(); } Map<Node, Set<Node>> adjacencies; if (depth == -2) { adjacencies = new HashMap<Node, Set<Node>>(); for (Node node : variables) { HashSet<Node> _nodes = new HashSet<Node>(variables); _nodes.remove(node); adjacencies.put(node, _nodes); } } else { // System.out.println("Running PC adjacency search..."); Graph graph = new EdgeListGraph(variables); Fas fas = new Fas(graph, indTest); fas.setVerbose(false); fas.setDepth(depth); // 1? adjacencies = fas.searchMapOnly(); // System.out.println("...done."); } List<Integer> allVariables = new ArrayList<Integer>(); for (int i = 0; i < variables.size(); i++) allVariables.add(i); log("Finding seeds.", true); ChoiceGenerator gen = new ChoiceGenerator(allVariables.size(), 3); int[] choice; CHOICE: while ((choice = gen.next()) != null) { int n1 = allVariables.get(choice[0]); int n2 = allVariables.get(choice[1]); int n3 = allVariables.get(choice[2]); Node v1 = variables.get(choice[0]); Node v2 = variables.get(choice[1]); Node v3 = variables.get(choice[2]); Set<Integer> triple = triple(n1, n2, n3); if (!clique(triple, adjacencies)) { continue; } boolean EPure = true; boolean CPure1 = true; boolean CPure2 = true; boolean CPure3 = true; for (int o : allVariables) { if (triple.contains(o)) { continue; } Node v4 = variables.get(o); tetrad = new Tetrad(v1, v2, v3, v4); if (deltaTest.getPValue(tetrad) > alpha) { EPure = false; if (indTest.isDependent(v1, v4, empty)) { CPure1 = false; } if (indTest.isDependent(v2, v4, empty)) { CPure2 = false; } } tetrad = new Tetrad(v1, v3, v2, v4); if (deltaTest.getPValue(tetrad) > alpha) { EPure = false; if (indTest.isDependent(v3, v4, empty)) { CPure3 = false; } } if (!(EPure || CPure1 || CPure2 || CPure3)) { continue CHOICE; } } HashSet<Integer> _cluster = new HashSet<Integer>(triple); if (verbose) { log("++" + variablesForIndices(new 
ArrayList<Integer>(triple)), false); } if (EPure) { ESeeds.add(_cluster); } if (!EPure) { if (CPure1) { Set<Integer> _cluster1 = new HashSet<Integer>(n2, n3); _cluster1.addAll(CSeeds.get(n1)); CSeeds.set(n1, _cluster1); } if (CPure2) { Set<Integer> _cluster2 = new HashSet<Integer>(n1, n3); _cluster2.addAll(CSeeds.get(n2)); CSeeds.set(n2, _cluster2); } if (CPure3) { Set<Integer> _cluster3 = new HashSet<Integer>(n1, n2); _cluster3.addAll(CSeeds.get(n3)); CSeeds.set(n3, _cluster3); } } } return null; }
/**
 * Grows the effect seeds into effect clusters and attaches each leftover variable to the effect
 * cluster its C-seed overlaps most.
 *
 * <p>Every result entry is a two-element list: index 0 holds the attached leftover variables
 * (possibly empty), index 1 holds the effect cluster. A leftover variable {@code c} is attached
 * to the effect cluster with the largest intersection with {@code CSeeds.get(c)} (the first such
 * cluster on ties), and only if that intersection covers a fraction of the cluster greater than
 * {@code CIparameter}.
 *
 * @param ESeeds effect seeds; passed to {@code finishESeeds}, which modifies the set
 * @param CSeeds for each variable index, the seed set compared against the effect clusters
 * @return the set of [attached variables, effect cluster] pairs
 */
private Set<List<Set<Integer>>> combineClusters(
    Set<Set<Integer>> ESeeds, List<Set<Integer>> CSeeds) {
  Set<Set<Integer>> EClusters = finishESeeds(ESeeds);

  // Cs = all variable indices that are not in any effect cluster.
  Set<Integer> Cs = new HashSet<Integer>();
  for (int i = 0; i < variables.size(); i++) {
    Cs.add(i);
  }
  for (Set<Integer> ECluster : EClusters) {
    Cs.removeAll(ECluster);
  }

  // Fix an order so that EClustersArray and Clusters can be indexed in parallel.
  List<Set<Integer>> EClustersArray = new ArrayList<Set<Integer>>(EClusters);

  List<List<Set<Integer>>> Clusters = new ArrayList<List<Set<Integer>>>();
  for (Set<Integer> ECluster : EClustersArray) {
    List<Set<Integer>> newCluster = new ArrayList<Set<Integer>>();
    newCluster.add(new HashSet<Integer>());
    newCluster.add(ECluster);
    Clusters.add(newCluster);
  }

  for (Integer c : Cs) {
    int match = -1;
    int overlap = 0;
    boolean pass = false;

    for (int i = 0; i < EClustersArray.size(); i++) {
      Set<Integer> ECluster = EClustersArray.get(i);

      // Intersect on a copy; retainAll on ECluster itself would shrink the cluster.
      Set<Integer> intersection = new HashSet<Integer>(ECluster);
      intersection.retainAll(CSeeds.get(c));
      int _overlap = intersection.size();

      if (_overlap > overlap) {
        overlap = _overlap;
        match = i;
        // Re-evaluated for every new best match so that pass always refers to match.
        // ECluster is non-empty here because the intersection is.
        pass = (double) overlap / ECluster.size() > CIparameter;
      }
    }

    if (pass) {
      Clusters.get(match).get(0).add(c);
    }
  }

  return new HashSet<List<Set<Integer>>>(Clusters);
}
// Finds clusters of size 3. private Set<Set<Integer>> findMixedClusters( List<Integer> remaining, Set<Integer> unionPure, Map<Node, Set<Node>> adjacencies) { Set<Set<Integer>> threeClusters = new HashSet<Set<Integer>>(); if (unionPure.isEmpty()) { return new HashSet<Set<Integer>>(); } REMAINING: while (true) { if (remaining.size() < 3) break; ChoiceGenerator gen = new ChoiceGenerator(remaining.size(), 3); int[] choice; while ((choice = gen.next()) != null) { int y = remaining.get(choice[0]); int z = remaining.get(choice[1]); int w = remaining.get(choice[2]); Set<Integer> cluster = new HashSet<Integer>(); cluster.add(y); cluster.add(z); cluster.add(w); // if (!allVariablesDependent(cluster)) { // continue; // } if (!clique(cluster, adjacencies)) { continue; } // Check all x as a cross check; really only one should be necessary. boolean allX = true; for (int x : unionPure) { Set<Integer> _cluster = new HashSet<Integer>(cluster); _cluster.add(x); if (!quartetVanishes(_cluster) || !significant(new ArrayList<Integer>(_cluster))) { allX = false; break; } } if (allX) { threeClusters.add(cluster); unionPure.addAll(cluster); remaining.removeAll(cluster); System.out.println( "3-cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster))); continue REMAINING; } } break; } return threeClusters; }
// Trying to optimize the search for 4-cliques a bit. private Set<Set<Integer>> findPureClusters2( List<Integer> _variables, Map<Node, Set<Node>> adjacencies) { System.out.println("Original variables = " + variables); Set<Set<Integer>> clusters = new HashSet<Set<Integer>>(); List<Integer> allVariables = new ArrayList<Integer>(); Set<Node> foundVariables = new HashSet<Node>(); for (int i = 0; i < this.variables.size(); i++) allVariables.add(i); for (int x : _variables) { Node nodeX = variables.get(x); if (foundVariables.contains(nodeX)) continue; List<Node> adjX = new ArrayList<Node>(adjacencies.get(nodeX)); adjX.removeAll(foundVariables); if (adjX.size() < 3) continue; for (Node nodeY : adjX) { if (foundVariables.contains(nodeY)) continue; List<Node> commonXY = new ArrayList<Node>(adjacencies.get(nodeY)); commonXY.retainAll(adjX); commonXY.removeAll(foundVariables); for (Node nodeZ : commonXY) { if (foundVariables.contains(nodeZ)) continue; List<Node> commonXZ = new ArrayList<Node>(commonXY); commonXZ.retainAll(adjacencies.get(nodeZ)); commonXZ.removeAll(foundVariables); for (Node nodeW : commonXZ) { if (foundVariables.contains(nodeW)) continue; if (!adjacencies.get(nodeY).contains(nodeW)) { continue; } int y = variables.indexOf(nodeY); int w = variables.indexOf(nodeW); int z = variables.indexOf(nodeZ); Set<Integer> cluster = quartet(x, y, z, w); // Note that purity needs to be assessed with respect to all of the variables in order // to // remove all latent-measure impurities between pairs of latents. 
if (pure(cluster, allVariables)) { O: for (int o : _variables) { if (cluster.contains(o)) continue; cluster.add(o); if (!clique(cluster, adjacencies)) { cluster.remove(o); continue O; } // if (!allVariablesDependent(cluster)) { // cluster.remove(o); // continue O; // } List<Integer> _cluster = new ArrayList<Integer>(cluster); ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 4); int[] choice2; int count = 0; while ((choice2 = gen2.next()) != null) { int x2 = _cluster.get(choice2[0]); int y2 = _cluster.get(choice2[1]); int z2 = _cluster.get(choice2[2]); int w2 = _cluster.get(choice2[3]); Set<Integer> quartet = quartet(x2, y2, z2, w2); // Optimizes for large clusters. if (quartet.contains(o)) { if (++count > 2) continue O; } if (quartet.contains(o) && !pure(quartet, allVariables)) { cluster.remove(o); continue O; } } } System.out.println( "Cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster))); clusters.add(cluster); foundVariables.addAll(variablesForIndices(new ArrayList<Integer>(cluster))); } } } } } return clusters; }
// Finds clusters of size 4 or higher. private Set<Set<Integer>> findPureClusters( List<Integer> _variables, Map<Node, Set<Node>> adjacencies) { // System.out.println("Original variables = " + variables); Set<Set<Integer>> clusters = new HashSet<Set<Integer>>(); List<Integer> allVariables = new ArrayList<Integer>(); for (int i = 0; i < this.variables.size(); i++) allVariables.add(i); VARIABLES: while (!_variables.isEmpty()) { if (_variables.size() < 4) break; for (int x : _variables) { Node nodeX = variables.get(x); List<Node> adjX = new ArrayList<Node>(adjacencies.get(nodeX)); adjX.retainAll(variablesForIndices(new ArrayList<Integer>(_variables))); for (Node node : new ArrayList<Node>(adjX)) { if (adjacencies.get(node).size() < 3) { adjX.remove(node); } } if (adjX.size() < 3) { continue; } ChoiceGenerator gen = new ChoiceGenerator(adjX.size(), 3); int[] choice; while ((choice = gen.next()) != null) { Node nodeY = adjX.get(choice[0]); Node nodeZ = adjX.get(choice[1]); Node nodeW = adjX.get(choice[2]); int y = variables.indexOf(nodeY); int w = variables.indexOf(nodeW); int z = variables.indexOf(nodeZ); Set<Integer> cluster = quartet(x, y, z, w); if (!clique(cluster, adjacencies)) { continue; } // Note that purity needs to be assessed with respect to all of the variables in order to // remove all latent-measure impurities between pairs of latents. 
if (pure(cluster, allVariables)) { // Collections.shuffle(_variables); O: for (int o : _variables) { if (cluster.contains(o)) continue; cluster.add(o); List<Integer> _cluster = new ArrayList<Integer>(cluster); if (!clique(cluster, adjacencies)) { cluster.remove(o); continue O; } // if (!allVariablesDependent(cluster)) { // cluster.remove(o); // continue O; // } ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 4); int[] choice2; int count = 0; while ((choice2 = gen2.next()) != null) { int x2 = _cluster.get(choice2[0]); int y2 = _cluster.get(choice2[1]); int z2 = _cluster.get(choice2[2]); int w2 = _cluster.get(choice2[3]); Set<Integer> quartet = quartet(x2, y2, z2, w2); // Optimizes for large clusters. if (quartet.contains(o)) { if (++count > 50) continue O; } if (quartet.contains(o) && !pure(quartet, allVariables)) { cluster.remove(o); continue O; } } } System.out.println( "Cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster))); clusters.add(cluster); _variables.removeAll(cluster); continue VARIABLES; } } } break; } return clusters; }