/**
 * Breadth-first search for a semi-directed path from {@code from} to {@code to} in {@code G} that
 * never enters a node of {@code cond}.
 *
 * @param from the start node; if it is {@code to} the answer is immediately true.
 * @param to the target node.
 * @param cond nodes the path may not pass through.
 * @param G the graph to search.
 * @return true if such a path exists.
 */
private boolean existsUnblockedSemiDirectedPath(Node from, Node to, List<Node> cond, Graph G) {
  Queue<Node> frontier = new LinkedList<Node>();
  Set<Node> visited = new HashSet<Node>();
  frontier.offer(from);
  visited.add(from);

  while (!frontier.isEmpty()) {
    Node current = frontier.remove();
    if (current == to) {
      return true;
    }

    for (Node adjacent : G.getAdjacentNodes(current)) {
      // traverseSemiDirected yields null when the edge cannot be followed from current.
      Node next = Edges.traverseSemiDirected(current, G.getEdge(current, adjacent));
      if (next == null || cond.contains(next)) {
        continue;
      }
      if (next == to) {
        return true;
      }
      // Set.add returns false for nodes already seen, so each node is enqueued at most once.
      if (visited.add(next)) {
        frontier.offer(next);
      }
    }
  }

  return false;
}
/**
 * Transforms a maximally directed pattern (PDAG) represented in graph <code>g</code> into an
 * arbitrary DAG by modifying <code>g</code> itself. Based on the algorithm described in
 * Chickering (2002) "Optimal structure identification with greedy search" Journal of Machine
 * Learning Research. R. Silva, June 2004
 *
 * <p>The undirected (tail-tail) edges are first removed from <code>g</code>. A working copy
 * <code>p</code> is then eliminated one node at a time. A node x is eligible when (a) it has no
 * children in <code>p</code> and (b) every node y joined to x by an undirected edge is adjacent
 * to every other node adjacent to x. Each undirected edge y - x at an eligible x is re-added to
 * <code>g</code> as y -> x, and x is removed from <code>p</code>.
 *
 * @param g the PDAG to orient; modified in place.
 */
public static void pdagToDag(Graph g) {
  Graph p = new EdgeListGraph(g);
  List<Edge> undirectedEdges = new ArrayList<Edge>();

  for (Edge edge : g.getEdges()) {
    if (edge.getEndpoint1() == Endpoint.TAIL
        && edge.getEndpoint2() == Endpoint.TAIL
        && !undirectedEdges.contains(edge)) {
      undirectedEdges.add(edge);
    }
  }

  // These edges are re-added to g, directed, as their endpoints are eliminated below.
  g.removeEdges(undirectedEdges);
  List<Node> pNodes = p.getNodes();

  do {
    Node x = null;

    for (Node pNode : pNodes) {
      x = pNode;

      // (a) x must be a sink in what is left of p.
      if (p.getChildren(x).size() > 0) {
        continue;
      }

      Set<Node> neighbors = new HashSet<Node>();

      for (Edge edge : p.getEdges()) {
        if (edge.getNode1() == x || edge.getNode2() == x) {
          if (edge.getEndpoint1() == Endpoint.TAIL && edge.getEndpoint2() == Endpoint.TAIL) {
            if (edge.getNode1() == x) {
              neighbors.add(edge.getNode2());
            } else {
              neighbors.add(edge.getNode1());
            }
          }
        }
      }

      // (b) Every undirected neighbor must be adjacent to all other nodes adjacent to x. The
      // parents of x need not be adjacent to one another (they may form an unshielded collider
      // at x), so requiring neighbors and parents together to form a clique would be too strict.
      if (!neighbors.isEmpty()) {
        Collection<Node> adjacents = p.getAdjacentNodes(x);
        boolean eligible = true;

        CHECK:
        for (Node y : neighbors) {
          for (Node z : adjacents) {
            if (z != y && !p.isAdjacentTo(y, z)) {
              eligible = false;
              break CHECK;
            }
          }
        }

        if (!eligible) {
          continue;
        }
      }

      for (Node neighbor : neighbors) {
        Node node1 = g.getNode(neighbor.getName());
        Node node2 = g.getNode(x.getName());
        g.addDirectedEdge(node1, node2);
      }

      p.removeNode(x);
      break;
    }

    // Normally x is the node just eliminated. If no node was eligible, x is the last candidate;
    // it is dropped here without its undirected edges being restored to g.
    pNodes.remove(x);
  } while (pNodes.size() > 0);
}
/**
 * Builds the set {n1, n2, n3}.
 *
 * @return a new mutable set holding the three indices.
 * @throws IllegalArgumentException if the three indices are not pairwise distinct.
 */
private Set<Integer> triple(int n1, int n2, int n3) {
  if (n1 == n2 || n1 == n3 || n2 == n3) {
    throw new IllegalArgumentException(
        "Triple elements must be unique: <" + n1 + ", " + n2 + ", " + n3 + ">");
  }

  Set<Integer> result = new HashSet<Integer>();
  result.add(n1);
  result.add(n2);
  result.add(n3);
  return result;
}
/**
 * Builds the set {x, y, z, w}.
 *
 * @return a new mutable set holding the four indices.
 * @throws IllegalArgumentException if the four indices are not pairwise distinct.
 */
private Set<Integer> quartet(int x, int y, int z, int w) {
  Set<Integer> members = new HashSet<Integer>();
  for (int element : new int[] {x, y, z, w}) {
    members.add(element);
  }

  // Any repeated index collapses the set below four elements.
  if (members.size() != 4) {
    throw new IllegalArgumentException(
        "Quartet elements must be unique: <" + x + ", " + y + ", " + z + ", " + w + ">");
  }

  return members;
}
/**
 * Gets a graph and directs only the unshielded colliders.
 *
 * <p>Each directed edge tail -> head is kept directed only if head has another parent that is
 * not adjacent to tail. Every other directed edge is replaced by an undirected edge between the
 * same nodes. Edges that are not directed (tail-arrow) are left untouched.
 *
 * @param graph the graph to modify in place.
 */
public static void basicPattern(Graph graph) {
  Set<Edge> toUndirect = new HashSet<Edge>();

  for (Edge edge : graph.getEdges()) {
    Node head;
    Node tail;

    if (edge.getEndpoint1() == Endpoint.ARROW && edge.getEndpoint2() == Endpoint.TAIL) {
      head = edge.getNode1();
      tail = edge.getNode2();
    } else if (edge.getEndpoint1() == Endpoint.TAIL && edge.getEndpoint2() == Endpoint.ARROW) {
      head = edge.getNode2();
      tail = edge.getNode1();
    } else {
      // Not a directed edge.
      continue;
    }

    boolean inUnshieldedCollider = false;
    for (Node otherParent : graph.getParents(head)) {
      if (otherParent != tail && !graph.isAdjacentTo(tail, otherParent)) {
        inUnshieldedCollider = true;
        break;
      }
    }

    if (!inUnshieldedCollider) {
      toUndirect.add(edge);
    }
  }

  for (Edge edge : toUndirect) {
    Node node1 = edge.getNode1();
    Node node2 = edge.getNode2();
    graph.removeEdge(edge);
    graph.addUndirectedEdge(node1, node2);
  }
}
/**
 * Adds the given variable name to knowledge. Duplicates are ignored, and nothing is added when
 * {@code checkVarName} returns false for the name.
 */
public void addVariable(String varName) {
  if (namesToVars.containsKey(varName)) {
    return;
  }
  if (!checkVarName(varName)) {
    return;
  }

  MyNode node = new MyNode(varName);
  myNodes.add(node);
  namesToVars.put(varName, node);
}
/**
 * Returns the edges of {@code graph}.
 *
 * @return a new mutable set holding every edge currently in the graph.
 */
public Set<Edge> getAdjacencies() {
  return new HashSet<Edge>(graph.getEdges());
}
/**
 * Grows the effect seeds into effect clusters (via {@code finishESeeds}) and assigns each
 * variable that is in no effect cluster, as a cause, to the effect cluster its cause seed
 * overlaps most.
 *
 * <p>Each element of the result is a two-element list. Index 0 is the set of cause variables
 * assigned to the cluster; index 1 is the effect cluster itself. A candidate cause c is assigned
 * to the effect cluster E that maximizes the size of the intersection of E with
 * {@code CSeeds.get(c)}. The assignment is made only when that overlap divided by |E| exceeds
 * {@code CIparameter}.
 *
 * @param ESeeds effect seeds; handed to {@code finishESeeds}.
 * @param CSeeds cause seed for each variable, indexed by variable index.
 * @return one (causes, effects) pair per effect cluster.
 */
private Set<List<Set<Integer>>> combineClusters(
    Set<Set<Integer>> ESeeds, List<Set<Integer>> CSeeds) {
  Set<Set<Integer>> EClusters = finishESeeds(ESeeds);

  // Fix one iteration order so that Clusters.get(i) always pairs with EClustersArray.get(i).
  List<Set<Integer>> EClustersArray = new ArrayList<Set<Integer>>(EClusters);

  // Candidate causes: every variable that is not in some effect cluster.
  Set<Integer> Cs = new HashSet<Integer>();
  for (int i = 0; i < variables.size(); i++) {
    Cs.add(i);
  }
  for (Set<Integer> ECluster : EClustersArray) {
    Cs.removeAll(ECluster);
  }

  List<List<Set<Integer>>> Clusters = new ArrayList<List<Set<Integer>>>();
  for (Set<Integer> ECluster : EClustersArray) {
    List<Set<Integer>> newCluster = new ArrayList<Set<Integer>>();
    newCluster.add(new HashSet<Integer>()); // Index 0: causes assigned to this cluster.
    newCluster.add(ECluster); // Index 1: the effect cluster.
    Clusters.add(newCluster);
  }

  for (Integer c : Cs) {
    Set<Integer> cSeed = CSeeds.get(c);
    int match = -1;
    int overlap = 0;

    for (int i = 0; i < EClustersArray.size(); i++) {
      // Intersect a copy: retaining on the cluster itself would corrupt it (and its hash code
      // inside EClusters).
      Set<Integer> intersection = new HashSet<Integer>(EClustersArray.get(i));
      intersection.retainAll(cSeed);

      if (intersection.size() > overlap) {
        overlap = intersection.size();
        match = i;
      }
    }

    // The threshold is checked for the best match only, in floating point; integer division
    // would truncate every partial overlap to 0.
    if (match >= 0 && (double) overlap / EClustersArray.get(match).size() > CIparameter) {
      Clusters.get(match).get(0).add(c);
    }
  }

  return new HashSet<List<Set<Integer>>>(Clusters);
}
/**
 * Evaluates the Insert(X, Y, T) operator (Definition 12 from Chickering, 2002).
 *
 * @return {@code scoreGraphChange} for y with parent set NA_YX union T union Pa(y) union {x},
 *     compared against the same set without the added x.
 */
private double insertEval(Node x, Node y, List<Node> t, List<Node> naYX, Graph graph) {
  // Parent set of y before inserting x: NA_YX union T union Pa(y).
  Set<Node> withoutX = new HashSet<Node>(naYX);
  withoutX.addAll(t);
  withoutX.addAll(graph.getParents(y));

  // Parent set of y after inserting x.
  Set<Node> withX = new HashSet<Node>(withoutX);
  withX.add(x);

  return scoreGraphChange(y, withX, withoutX);
}
/**
 * Returns the free parameters whose recorded set of referencing nodes contains {@code node}.
 *
 * @param node the node doing the referencing.
 * @return the free parameters referenced by the given variable (variable node or error node).
 */
public Set<String> getReferencedParameters(Node node) {
  Set<String> referenced = new HashSet<>();

  for (String name : this.referencedParameters.keySet()) {
    Set<Node> referencers = this.referencedParameters.get(name);
    if (referencers.contains(node)) {
      referenced.add(name);
    }
  }

  return referenced;
}
/**
 * Returns every node whose recorded set of referencing nodes contains {@code node}.
 *
 * @param node the node doing the referencing.
 * @return the variables referenced by the expression for the given node (variable node or error
 *     node).
 */
public Set<Node> getReferencedNodes(Node node) {
  Set<Node> referenced = new HashSet<>();

  for (Node candidate : this.referencedNodes.keySet()) {
    Set<Node> referencers = this.referencedNodes.get(candidate);
    if (referencers.contains(node)) {
      referenced.add(candidate);
    }
  }

  return referenced;
}
/**
 * Records {@code arrow} under the ordered pair (i, j) in {@code lookupArrows}, creating the set
 * for that pair on first use.
 */
private void addLookupArrow(Node i, Node j, Arrow arrow) {
  OrderedPair<Node> key = new OrderedPair<Node>(i, j);

  if (lookupArrows.get(key) == null) {
    lookupArrows.put(key, new HashSet<Arrow>());
  }

  lookupArrows.get(key).add(arrow);
}
/**
 * Splits a comma-separated spec into its tokens.
 *
 * <p>Tokens are trimmed before being stored, so {@code "X1, X2"} yields {@code "X1"} and
 * {@code "X2"} rather than {@code " X2"}. A token with leading whitespace would never match a
 * variable name when used as a pattern.
 *
 * @param spec comma-separated list, e.g. {@code "X1, X2,X3*"}.
 * @return the distinct non-blank tokens, trimmed.
 */
private Set<String> split(String spec) {
  Set<String> tokens = new HashSet<>();

  for (String token : spec.split(",")) {
    String trimmed = token.trim();
    if (!trimmed.isEmpty()) {
      tokens.add(trimmed);
    }
  }

  return tokens;
}
/**
 * Returns every subset of {@code nodes}.
 *
 * <p>Subset number {@code i} of the returned list contains {@code nodes.get(k)} exactly when bit
 * {@code k} of {@code i} is set. The empty set therefore comes first and the full set last.
 *
 * @param nodes the elements; at most 30, since the result has 2^size entries.
 * @return a new list of 2^{@code nodes.size()} newly allocated sets.
 * @throws IllegalArgumentException if {@code nodes} has more than 30 elements. The subset count
 *     would not fit in an {@code int}; previously {@code (int) Math.pow(2, n)} silently
 *     saturated.
 */
public static <T> List<Set<T>> powerSet(List<T> nodes) {
  int n = nodes.size();
  if (n > 30) {
    throw new IllegalArgumentException("Too many elements for a power set: " + n);
  }

  int total = 1 << n;
  List<Set<T>> subsets = new ArrayList<Set<T>>(total);

  for (int i = 0; i < total; i++) {
    Set<T> subset = new HashSet<T>();
    for (int k = 0; k < n; k++) {
      if ((i & (1 << k)) != 0) {
        subset.add(nodes.get(k));
      }
    }
    subsets.add(subset);
  }

  return subsets;
}
/**
 * Maps each cluster of variable indices to the corresponding set of nodes from
 * {@code variables} and builds the graph for that clustering with
 * {@code convertSearchGraphNodes}.
 */
private Graph convertToGraph(Set<Set<Integer>> allClusters) {
  Set<Set<Node>> nodeClusters = new HashSet<Set<Node>>();

  for (Set<Integer> indexCluster : allClusters) {
    Set<Node> members = new HashSet<Node>();
    for (int index : indexCluster) {
      members.add(variables.get(index));
    }
    nodeClusters.add(members);
  }

  return convertSearchGraphNodes(nodeClusters);
}
/**
 * Iterator over the KnowledgeEdge's representing required edges.
 *
 * <p>Each required rule contributes one edge for every (from, to) pair drawn from its first and
 * second node sets, skipping pairs where from equals to.
 */
public final Iterator<KnowledgeEdge> requiredEdgesIterator() {
  Set<KnowledgeEdge> required = new HashSet<>();

  for (OrderedPair<Set<MyNode>> rule : requiredRulesSpecs) {
    for (MyNode from : rule.getFirst()) {
      for (MyNode to : rule.getSecond()) {
        if (from.equals(to)) {
          continue; // No self-edges.
        }
        required.add(new KnowledgeEdge(from.getName(), to.getName()));
      }
    }
  }

  return required.iterator();
}
/**
 * Returns the variables whose names match the given spec.
 *
 * <p>The spec is split on commas (via {@code split}). Each token, trimmed, is a name pattern in
 * which {@code *} matches any run of characters and every other character matches itself
 * literally. For example {@code "X*, Y1"} matches X, X1, X12, ... and Y1. Quoting the
 * non-wildcard parts keeps characters such as '.', '(' or '+' in names from being read as regex
 * syntax (which could over-match or throw PatternSyntaxException).
 *
 * @param spec comma-separated list of name patterns.
 * @return the matching nodes from {@code myNodes}.
 */
private Set<MyNode> getExtent(String spec) {
  Set<MyNode> matches = new HashSet<>();

  for (String token : split(spec)) {
    java.util.regex.Pattern pattern =
        java.util.regex.Pattern.compile(wildcardToRegex(token.trim()));

    for (MyNode var : myNodes) {
      if (pattern.matcher(var.getName()).matches()) {
        matches.add(var);
      }
    }
  }

  return matches;
}

/**
 * Translates a name pattern in which '*' is the only wildcard into a regular expression. Every
 * other character is quoted so it matches literally.
 */
private static String wildcardToRegex(String wildcard) {
  StringBuilder regex = new StringBuilder();
  // The -1 limit keeps trailing empty pieces, so a trailing '*' still produces ".*".
  String[] pieces = wildcard.split("\\*", -1);

  for (int k = 0; k < pieces.length; k++) {
    if (k > 0) {
      regex.append(".*");
    }
    if (!pieces[k].isEmpty()) {
      regex.append(java.util.regex.Pattern.quote(pieces[k]));
    }
  }

  return regex.toString();
}
/**
 * Tests whether {@code quartet} is pure with respect to {@code variables}.
 *
 * <p>The quartet must vanish. Replacing any one member by any variable outside the quartet must
 * also yield a vanishing quartet. Finally the quartet must pass {@code significant}.
 */
private boolean pure(Set<Integer> quartet, List<Integer> variables) {
  if (!quartetVanishes(quartet)) {
    return false;
  }

  for (int outside : variables) {
    if (quartet.contains(outside)) {
      continue;
    }
    for (int member : quartet) {
      Set<Integer> swapped = new HashSet<Integer>(quartet);
      swapped.remove(member);
      swapped.add(outside);
      if (!quartetVanishes(swapped)) {
        return false;
      }
    }
  }

  return significant(new ArrayList<Integer>(quartet));
}
/**
 * Iterator over the knowledge's explicitly forbidden edges.
 *
 * <p>Starts from all forbidden rules and drops those that come from tiers
 * ({@code forbiddenTierRules()}) or from knowledge groups. Each remaining rule then contributes
 * one edge per (from, to) pair drawn from its first and second node sets.
 */
public final Iterator<KnowledgeEdge> explicitlyForbiddenEdgesIterator() {
  Set<OrderedPair<Set<MyNode>>> explicitRules = new HashSet<>(forbiddenRulesSpecs);
  explicitRules.removeAll(forbiddenTierRules());

  for (KnowledgeGroup group : knowledgeGroups) {
    explicitRules.remove(knowledgeGroupRules.get(group));
  }

  Set<KnowledgeEdge> forbidden = new HashSet<>();

  for (OrderedPair<Set<MyNode>> rule : explicitRules) {
    for (MyNode from : rule.getFirst()) {
      for (MyNode to : rule.getSecond()) {
        forbidden.add(new KnowledgeEdge(from.getName(), to.getName()));
      }
    }
  }

  return forbidden.iterator();
}
/**
 * Constructs a new RFCI search for the given independence test, restricted to those of the
 * test's variables whose names match the names of {@code searchVars}.
 *
 * @param independenceTest the independence test; supplies the candidate variables.
 * @param searchVars nodes naming the variables to search over; matched by name.
 * @throws NullPointerException if {@code independenceTest} or {@code searchVars} is null, or if
 *     the {@code knowledge} field is null.
 */
public Rfci(IndependenceTest independenceTest, List<Node> searchVars) {
  if (independenceTest == null || knowledge == null || searchVars == null) {
    throw new NullPointerException();
  }

  this.independenceTest = independenceTest;
  this.variables.addAll(independenceTest.getVariables());

  // Hash the wanted names once instead of scanning searchVars for every variable.
  Set<String> searchNames = new HashSet<String>();
  for (Node node : searchVars) {
    searchNames.add(node.getName());
  }

  Set<Node> remVars = new HashSet<Node>();
  for (Node node1 : this.variables) {
    if (!searchNames.contains(node1.getName())) {
      remVars.add(node1);
    }
  }

  this.variables.removeAll(remVars);
}
// Finds clusters of size 3. private Set<Set<Integer>> findMixedClusters( List<Integer> remaining, Set<Integer> unionPure, Map<Node, Set<Node>> adjacencies) { Set<Set<Integer>> threeClusters = new HashSet<Set<Integer>>(); if (unionPure.isEmpty()) { return new HashSet<Set<Integer>>(); } REMAINING: while (true) { if (remaining.size() < 3) break; ChoiceGenerator gen = new ChoiceGenerator(remaining.size(), 3); int[] choice; while ((choice = gen.next()) != null) { int y = remaining.get(choice[0]); int z = remaining.get(choice[1]); int w = remaining.get(choice[2]); Set<Integer> cluster = new HashSet<Integer>(); cluster.add(y); cluster.add(z); cluster.add(w); // if (!allVariablesDependent(cluster)) { // continue; // } if (!clique(cluster, adjacencies)) { continue; } // Check all x as a cross check; really only one should be necessary. boolean allX = true; for (int x : unionPure) { Set<Integer> _cluster = new HashSet<Integer>(cluster); _cluster.add(x); if (!quartetVanishes(_cluster) || !significant(new ArrayList<Integer>(_cluster))) { allX = false; break; } } if (allX) { threeClusters.add(cluster); unionPure.addAll(cluster); remaining.removeAll(cluster); System.out.println( "3-cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster))); continue REMAINING; } } break; } return threeClusters; }
// Trying to optimize the search for 4-cliques a bit. private Set<Set<Integer>> findPureClusters2( List<Integer> _variables, Map<Node, Set<Node>> adjacencies) { System.out.println("Original variables = " + variables); Set<Set<Integer>> clusters = new HashSet<Set<Integer>>(); List<Integer> allVariables = new ArrayList<Integer>(); Set<Node> foundVariables = new HashSet<Node>(); for (int i = 0; i < this.variables.size(); i++) allVariables.add(i); for (int x : _variables) { Node nodeX = variables.get(x); if (foundVariables.contains(nodeX)) continue; List<Node> adjX = new ArrayList<Node>(adjacencies.get(nodeX)); adjX.removeAll(foundVariables); if (adjX.size() < 3) continue; for (Node nodeY : adjX) { if (foundVariables.contains(nodeY)) continue; List<Node> commonXY = new ArrayList<Node>(adjacencies.get(nodeY)); commonXY.retainAll(adjX); commonXY.removeAll(foundVariables); for (Node nodeZ : commonXY) { if (foundVariables.contains(nodeZ)) continue; List<Node> commonXZ = new ArrayList<Node>(commonXY); commonXZ.retainAll(adjacencies.get(nodeZ)); commonXZ.removeAll(foundVariables); for (Node nodeW : commonXZ) { if (foundVariables.contains(nodeW)) continue; if (!adjacencies.get(nodeY).contains(nodeW)) { continue; } int y = variables.indexOf(nodeY); int w = variables.indexOf(nodeW); int z = variables.indexOf(nodeZ); Set<Integer> cluster = quartet(x, y, z, w); // Note that purity needs to be assessed with respect to all of the variables in order // to // remove all latent-measure impurities between pairs of latents. 
if (pure(cluster, allVariables)) { O: for (int o : _variables) { if (cluster.contains(o)) continue; cluster.add(o); if (!clique(cluster, adjacencies)) { cluster.remove(o); continue O; } // if (!allVariablesDependent(cluster)) { // cluster.remove(o); // continue O; // } List<Integer> _cluster = new ArrayList<Integer>(cluster); ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 4); int[] choice2; int count = 0; while ((choice2 = gen2.next()) != null) { int x2 = _cluster.get(choice2[0]); int y2 = _cluster.get(choice2[1]); int z2 = _cluster.get(choice2[2]); int w2 = _cluster.get(choice2[3]); Set<Integer> quartet = quartet(x2, y2, z2, w2); // Optimizes for large clusters. if (quartet.contains(o)) { if (++count > 2) continue O; } if (quartet.contains(o) && !pure(quartet, allVariables)) { cluster.remove(o); continue O; } } } System.out.println( "Cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster))); clusters.add(cluster); foundVariables.addAll(variablesForIndices(new ArrayList<Integer>(cluster))); } } } } } return clusters; }
/**
 * Grows each effect seed into a larger effect cluster, then picks a set of pairwise-disjoint
 * grown clusters, preferring larger ones.
 *
 * <p>Seeds are presumably triples of variable indices (the membership tests below build
 * three-element sets) — TODO confirm against callers. The grow phase that runs depends on
 * {@code algType}:
 *
 * <ul>
 *   <li>{@code lax}: a variable o is added to the cluster unless more of the triples {n1, n2, o}
 *       (n1, n2 already in the cluster) are missing from {@code ESeeds} than present. After a
 *       cluster is grown, every 3-subset of it is removed from {@code ESeeds}, and the next
 *       cluster is grown from whichever seed remains.
 *   <li>{@code laxWithSpeedup}: same acceptance rule. Every seed of a snapshot of {@code ESeeds}
 *       is grown, and the seeds contained in each grown cluster are removed from {@code ESeeds}
 *       afterwards.
 *   <li>{@code strict}: o is added only if every such triple is in {@code ESeeds}; removal is as
 *       in {@code lax}.
 * </ul>
 *
 * <p>Pick phase: the grown clusters are sorted by decreasing size and accepted greedily when
 * they share no variable with an already-accepted cluster.
 *
 * @param ESeeds effect seeds; modified in place (seeds are removed as clusters cover them).
 * @return pairwise-disjoint grown effect clusters.
 */
private Set<Set<Integer>> finishESeeds(Set<Set<Integer>> ESeeds) {
  log("Growing Effect Seeds.", true);
  Set<Set<Integer>> grown = new HashSet<Set<Integer>>();

  List<Integer> _variables = new ArrayList<Integer>();
  for (int i = 0; i < variables.size(); i++) _variables.add(i);

  // AlgType.lax: grow one remaining seed at a time, removing covered triples as we go.
  // NOTE(review): the original comments labeled this branch "with speedup" and the
  // laxWithSpeedup branch "without speedup"; confirm which labeling is intended.
  if (algType == AlgType.lax) {
    // Scratch set reused for every triple lookup/removal below.
    Set<Integer> t = new HashSet<Integer>();
    int count = 0;
    int total = ESeeds.size();

    do {
      if (!ESeeds.iterator().hasNext()) {
        break;
      }

      Set<Integer> cluster = ESeeds.iterator().next();
      Set<Integer> _cluster = new HashSet<Integer>(cluster);

      if (extraShuffle) {
        Collections.shuffle(_variables);
      }

      for (int o : _variables) {
        if (_cluster.contains(o)) continue;

        List<Integer> _cluster2 = new ArrayList<Integer>(_cluster);
        int rejected = 0;
        int accepted = 0;

        ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2);
        int[] choice;

        while ((choice = gen.next()) != null) {
          int n1 = _cluster2.get(choice[0]);
          int n2 = _cluster2.get(choice[1]);

          t.clear();
          t.add(n1);
          t.add(n2);
          t.add(o);

          if (!ESeeds.contains(t)) {
            rejected++;
          } else {
            accepted++;
          }
        }

        // Majority vote: o joins unless more of its triples are missing than present.
        if (rejected > accepted) {
          continue;
        }

        _cluster.add(o);

        //                if (!(avgSumLnP(new ArrayList<Integer>(_cluster)) > -10)) {
        //                    _cluster.remove(o);
        //                }
      }

      // This takes out all pure clusters that are subsets of _cluster.
      // (The starting seed is one of these 3-subsets, so the do-while always makes progress.)
      ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3);
      int[] choice2;
      List<Integer> _cluster3 = new ArrayList<Integer>(_cluster);

      while ((choice2 = gen2.next()) != null) {
        int n1 = _cluster3.get(choice2[0]);
        int n2 = _cluster3.get(choice2[1]);
        int n3 = _cluster3.get(choice2[2]);

        t.clear();
        t.add(n1);
        t.add(n2);
        t.add(n3);

        ESeeds.remove(t);
      }

      if (verbose) {
        System.out.println(
            "Grown "
                + (++count)
                + " of "
                + total
                + ": "
                + variablesForIndices(new ArrayList<Integer>(_cluster)));
      }
      grown.add(_cluster);
    } while (!ESeeds.isEmpty());
  }

  // AlgType.laxWithSpeedup: grow every seed of a snapshot of ESeeds.
  if (algType == AlgType.laxWithSpeedup) {
    int count = 0;
    int total = ESeeds.size();

    // Optimized lax version of grow phase.
    // NOTE(review): iterating the snapshot also grows seeds that an earlier grown cluster has
    // already removed from ESeeds; duplicates collapse in `grown`, but the work is repeated.
    for (Set<Integer> cluster : new HashSet<Set<Integer>>(ESeeds)) {
      Set<Integer> _cluster = new HashSet<Integer>(cluster);

      if (extraShuffle) {
        Collections.shuffle(_variables);
      }

      for (int o : _variables) {
        if (_cluster.contains(o)) continue;

        List<Integer> _cluster2 = new ArrayList<Integer>(_cluster);
        int rejected = 0;
        int accepted = 0;
        //
        ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2);
        int[] choice;

        while ((choice = gen.next()) != null) {
          int n1 = _cluster2.get(choice[0]);
          int n2 = _cluster2.get(choice[1]);

          Set<Integer> triple = triple(n1, n2, o);

          if (!ESeeds.contains(triple)) {
            rejected++;
          } else {
            accepted++;
          }
        }

        //
        if (rejected > accepted) {
          continue;
        }

        //                System.out.println("Adding " + o  + " to " + cluster);
        _cluster.add(o);
      }

      // Remove every remaining seed covered by the grown cluster.
      for (Set<Integer> c : new HashSet<Set<Integer>>(ESeeds)) {
        if (_cluster.containsAll(c)) {
          ESeeds.remove(c);
        }
      }

      if (verbose) {
        System.out.println("Grown " + (++count) + " of " + total + ": " + _cluster);
      }

      grown.add(_cluster);
    }
  }

  // Strict grow phase: o joins only if every triple {n1, n2, o} is a seed.
  if (algType == AlgType.strict) {
    Set<Integer> t = new HashSet<Integer>();
    int count = 0;
    int total = ESeeds.size();

    do {
      if (!ESeeds.iterator().hasNext()) {
        break;
      }

      Set<Integer> cluster = ESeeds.iterator().next();
      Set<Integer> _cluster = new HashSet<Integer>(cluster);

      if (extraShuffle) {
        Collections.shuffle(_variables);
      }

      VARIABLES:
      for (int o : _variables) {
        if (_cluster.contains(o)) continue;

        List<Integer> _cluster2 = new ArrayList<Integer>(_cluster);

        ChoiceGenerator gen = new ChoiceGenerator(_cluster2.size(), 2);
        int[] choice;

        while ((choice = gen.next()) != null) {
          int n1 = _cluster2.get(choice[0]);
          int n2 = _cluster2.get(choice[1]);

          t.clear();
          t.add(n1);
          t.add(n2);
          t.add(o);

          if (!ESeeds.contains(t)) {
            continue VARIABLES;
          }

          //                        if (avgSumLnP(new ArrayList<Integer>(t)) < -10) continue
          // CLUSTER;
        }

        _cluster.add(o);
      }

      // This takes out all pure clusters that are subsets of _cluster.
      ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 3);
      int[] choice2;
      List<Integer> _cluster3 = new ArrayList<Integer>(_cluster);

      while ((choice2 = gen2.next()) != null) {
        int n1 = _cluster3.get(choice2[0]);
        int n2 = _cluster3.get(choice2[1]);
        int n3 = _cluster3.get(choice2[2]);

        t.clear();
        t.add(n1);
        t.add(n2);
        t.add(n3);

        ESeeds.remove(t);
      }

      if (verbose) {
        System.out.println("Grown " + (++count) + " of " + total + ": " + _cluster);
      }

      grown.add(_cluster);
    } while (!ESeeds.isEmpty());
  }

  // Optimized pick phase.
  log("Choosing among grown Effect Clusters.", true);

  for (Set<Integer> l : grown) {
    ArrayList<Integer> _l = new ArrayList<Integer>(l);
    Collections.sort(_l);
    if (verbose) {
      log("Grown: " + variablesForIndices(_l), false);
    }
  }

  Set<Set<Integer>> out = new HashSet<Set<Integer>>();

  List<Set<Integer>> list = new ArrayList<Set<Integer>>(grown);

  //        final Map<Set<Integer>, Double> pValues = new HashMap<Set<Integer>, Double>();
  //
  //        for (Set<Integer> o : grown) {
  //            pValues.put(o, getP(new ArrayList<Integer>(o)));
  //        }

  // Largest clusters first. (Subtraction is safe here: sizes are small and non-negative.)
  Collections.sort(
      list,
      new Comparator<Set<Integer>>() {
        @Override
        public int compare(Set<Integer> o1, Set<Integer> o2) {
          //                if (o1.size() == o2.size()) {
          //                    double chisq1 = pValues.get(o1);
          //                    double chisq2 = pValues.get(o2);
          //                    return Double.compare(chisq2, chisq1);
          //                }

          return o2.size() - o1.size();
        }
      });

  //        for (Set<Integer> o : list) {
  //            if (pValues.get(o) < alpha) continue;
  //            System.out.println(variablesForIndices(new ArrayList<Integer>(o)) + " p = " +
  // pValues.get(o));
  //        }

  // Greedily accept clusters that share no variable with an already-accepted one.
  Set<Integer> all = new HashSet<Integer>();

  CLUSTER:
  for (Set<Integer> cluster : list) {
    //            if (pValues.get(cluster) < alpha) continue;

    for (Integer i : cluster) {
      if (all.contains(i)) continue CLUSTER;
    }

    out.add(cluster);

    //            if (getPMulticluster(out) < alpha) {
    //                out.remove(cluster);
    //                continue;
    //            }

    all.addAll(cluster);
  }

  return out;
}
/**
 * Parses {@code expressionString} and installs it as the expression for {@code node}. It also
 * updates the bookkeeping of which nodes reference which parameters and nodes.
 *
 * <p>Names the parser reports as parameters are treated as free parameters unless they are
 * variable names (in {@code variableNames} or the names of {@code nodes}). Steps:
 *
 * <ol>
 *   <li>{@code node} is removed from every parameter's set of referencing nodes.
 *   <li>Parameters left with no referencing node are dropped together with their expressions.
 *   <li>Each parameter of the new expression is recorded as referenced by {@code node} and
 *       passed to {@code setSuitableParameterDistribution}.
 * </ol>
 *
 * @param node the node whose expression is set.
 * @param expressionString the expression text, stored verbatim.
 * @throws NullPointerException if {@code node} or {@code expressionString} is null.
 * @throws ParseException if the expression cannot be parsed; this is raised before any
 *     bookkeeping is changed.
 */
public void setNodeExpression(Node node, String expressionString) throws ParseException {
  if (node == null) {
    throw new NullPointerException("Node was null.");
  }

  if (expressionString == null) {
    //            return;
    throw new NullPointerException("Expression string was null.");
  }

  // Parse the expression. This could throw a ParseException, but that exception needs to be
  // handed up the chain, because the interface will need it.
  ExpressionParser parser = new ExpressionParser();
  Expression expression = parser.parseExpression(expressionString);
  List<String> parameterNames = parser.getParameters();

  // Make a list of parent names.
  List<Node> parents = this.graph.getParents(node);
  List<String> parentNames = new LinkedList<>();

  for (Node parent : parents) {
    parentNames.add(parent.getName());
  }

  //        List<String> _params = new ArrayList<String>(parameterNames);
  //        _params.retainAll(variableNames);
  //        _params.removeAll(parentNames);
  //
  //        if (!_params.isEmpty()) {
  //            throw new IllegalArgumentException("Conditioning on a variable other than the
  // parents: " + node);
  //        }

  // Make a list of parameter names by removing from the parser's list of free parameters any
  // that correspond to variables (including error terms).
  // NOTE(review): the intent was to throw when the expression names a variable that is not a
  // parent ("we must respect the graph"), but that check is commented out above and below;
  // such names are silently dropped from the parameter list instead.
  parameterNames.removeAll(variableNames);

  for (Node variable : nodes) {
    if (parameterNames.contains(variable.getName())) {
      parameterNames.remove(variable.getName());
      //                throw new IllegalArgumentException("The list of parameter names may not
      // include variables: " + variable.getName());
    }
  }

  // Remove old parameter references. Any parameter whose set of referencing nodes ends up empty
  // (including one that was already empty) is dropped with its expressions; a parameter the new
  // expression still uses is re-created in the next step.
  List<String> parametersToRemove = new LinkedList<>();

  for (String parameter : this.referencedParameters.keySet()) {
    Set<Node> nodes = this.referencedParameters.get(parameter);

    if (nodes.contains(node)) {
      nodes.remove(node);
    }

    if (nodes.isEmpty()) {
      parametersToRemove.add(parameter);
    }
  }

  for (String parameter : parametersToRemove) {
    this.referencedParameters.remove(parameter);
    this.parameterExpressions.remove(parameter);
    this.parameterExpressionStrings.remove(parameter);
    this.parameterEstimationInitializationExpressions.remove(parameter);
    this.parameterEstimationInitializationExpressionStrings.remove(parameter);
  }

  // Add new parameter references.
  for (String parameter : parameterNames) {
    if (this.referencedParameters.get(parameter) == null) {
      this.referencedParameters.put(parameter, new HashSet<Node>());
    }

    Set<Node> nodes = this.referencedParameters.get(parameter);
    nodes.add(node);

    setSuitableParameterDistribution(parameter);
  }

  // Remove old node references, dropping entries whose referencing set becomes empty.
  List<Node> nodesToRemove = new LinkedList<>();

  for (Node _node : this.referencedNodes.keySet()) {
    Set<Node> nodes = this.referencedNodes.get(_node);

    if (nodes.contains(node)) {
      nodes.remove(node);
    }

    if (nodes.isEmpty()) {
      nodesToRemove.add(_node);
    }
  }

  for (Node _node : nodesToRemove) {
    this.referencedNodes.remove(_node);
  }

  // Add new node references: ensure every variable has an entry, and record node as
  // referencing each of its graph parents.
  // NOTE(review): references come from the graph's parents of node, not from the names that
  // actually appear in the expression — confirm this is intended.
  for (String variableString : variableNames) {
    Node _node = getNode(variableString);

    if (this.referencedNodes.get(_node) == null) {
      this.referencedNodes.put(_node, new HashSet<Node>());
    }

    for (String s : parentNames) {
      if (s.equals(variableString)) {
        Set<Node> nodes = this.referencedNodes.get(_node);
        nodes.add(node);
      }
    }
  }

  // Finally, save the parsed expression and the original string that the user entered. No need
  // to annoy the user by changing spacing.
  nodeExpressions.put(node, expression);
  nodeExpressionStrings.put(node, expressionString);
}
public GeneralizedSemPm(SemPm semPm) { this(semPm.getGraph()); // Write down equations. try { List<Node> variableNodes = getVariableNodes(); for (int i = 0; i < variableNodes.size(); i++) { Node node = variableNodes.get(i); List<Node> parents = getVariableParents(node); StringBuilder buf = new StringBuilder(); for (int j = 0; j < parents.size(); j++) { if (!(variableNodes.contains(parents.get(j)))) { continue; } Node parent = parents.get(j); Parameter _parameter = semPm.getParameter(parent, node); String parameter = _parameter.getName(); Set<Node> nodes = new HashSet<>(); nodes.add(node); referencedParameters.put(parameter, nodes); buf.append(parameter); buf.append("*"); buf.append(parents.get(j).getName()); setParameterExpression(parameter, "Split(-1.5, -.5, .5, 1.5)"); setStartsWithParametersTemplate(parameter.substring(0, 1), "Split(-1.5, -.5, .5, 1.5)"); setStartsWithParametersEstimationInitializaationTemplate( parameter.substring(0, 1), "Split(-1.5, -.5, .5, 1.5)"); if (j < parents.size() - 1) { buf.append(" + "); } } if (buf.toString().trim().length() != 0) { buf.append(" + "); } buf.append(errorNodes.get(i)); setNodeExpression(node, buf.toString()); } for (Node node : variableNodes) { Parameter _parameter = semPm.getParameter(node, node); String parameter = _parameter.getName(); String distributionFormula = "N(0," + parameter + ")"; setNodeExpression(getErrorNode(node), distributionFormula); setParameterExpression(parameter, "U(0, 1)"); setStartsWithParametersTemplate(parameter.substring(0, 1), "U(0, 1)"); setStartsWithParametersEstimationInitializaationTemplate( parameter.substring(0, 1), "U(0, 1)"); } variableNames = new ArrayList<>(); for (Node _node : variableNodes) variableNames.add(_node.getName()); for (Node _node : errorNodes) variableNames.add(_node.getName()); } catch (ParseException e) { throw new IllegalStateException("Parse error in constructing initial model.", e); } }
/**
 * Checks which nodes accept each expression template.
 *
 * <p>For every template, a fresh PM is built with {@code makeTypicalPm}. The template is then
 * expanded and set on each of its nodes and parameters, and every outcome is printed. The test
 * asserts that the variable nodes (those in {@code getVariableNodes()}) for which
 * {@code setNodeExpression} succeeded are exactly the nodes listed for that template. Nodes
 * whose template fails to expand, or whose expression is rejected, count as not working.
 * Parameter results are printed only, not asserted.
 */
@Test
public void test4() {
  // For X3

  Map<String, String[]> templates = new HashMap<>();

  templates.put(
      "NEW(b) + NEW(b) + NEW(c) + NEW(c) + NEW(c)",
      new String[] {"X1", "X2", "X3", "X4", "X5"});
  templates.put("NEW(X1) + NEW(b) + NEW(c) + NEW(c) + NEW(c)", new String[] {});
  templates.put("$", new String[] {});
  templates.put("TSUM($)", new String[] {"X1", "X2", "X3", "X4", "X5"});
  templates.put("TPROD($)", new String[] {"X1", "X2", "X3", "X4", "X5"});
  templates.put("TPROD($) + X2", new String[] {"X1", "X2", "X3", "X4", "X5"});
  templates.put("TPROD($) + TSUM($)", new String[] {"X1", "X2", "X3", "X4", "X5"});
  templates.put("tan(TSUM(NEW(a)*$))", new String[] {"X1", "X2", "X3", "X4", "X5"});
  templates.put("Normal(0, 1)", new String[] {"X1", "X2", "X3", "X4", "X5"});
  templates.put("Normal(m, s)", new String[] {"X1", "X2", "X3", "X4", "X5"});
  templates.put("Normal(NEW(m), s)", new String[] {"X1", "X2", "X3", "X4", "X5"});
  templates.put("Normal(NEW(m), NEW(s)) + m1 + s6", new String[] {"X1", "X2", "X3", "X4", "X5"});
  templates.put("TSUM($) + a", new String[] {"X1", "X2", "X3", "X4", "X5"});
  templates.put("TSUM($) + TSUM($) + TSUM($) + 1", new String[] {"X1", "X2", "X3", "X4", "X5"});

  for (String template : templates.keySet()) {
    // A fresh PM per template so earlier templates cannot affect later ones.
    GeneralizedSemPm semPm = makeTypicalPm();
    print(semPm.getGraph().toString());

    Set<Node> shouldWork = new HashSet<>();

    for (String name : templates.get(template)) {
      shouldWork.add(semPm.getNode(name));
    }

    Set<Node> works = new HashSet<>();

    for (int i = 0; i < semPm.getNodes().size(); i++) {
      print("-----------");
      print(semPm.getNodes().get(i).toString());
      print("Trying template: " + template);
      String _template = template;

      Node node = semPm.getNodes().get(i);

      try {
        _template = TemplateExpander.getInstance().expandTemplate(_template, semPm, node);
      } catch (Exception e) {
        print("Couldn't expand template: " + template);
        continue;
      }

      try {
        semPm.setNodeExpression(node, _template);
        print("Set formula " + _template + " for " + node);

        // Only variable nodes are compared against the expected set.
        if (semPm.getVariableNodes().contains(node)) {
          works.add(node);
        }

      } catch (Exception e) {
        print("Couldn't set formula " + _template + " for " + node);
      }
    }

    for (String parameter : semPm.getParameters()) {
      print("-----------");
      print(parameter);
      print("Trying template: " + template);
      String _template = template;

      try {
        _template = TemplateExpander.getInstance().expandTemplate(_template, semPm, null);
      } catch (Exception e) {
        print("Couldn't expand template: " + template);
        continue;
      }

      try {
        semPm.setParameterExpression(parameter, _template);
        print("Set formula " + _template + " for " + parameter);
      } catch (Exception e) {
        print("Couldn't set formula " + _template + " for " + parameter);
      }
    }

    assertEquals(shouldWork, works);
  }
}
/**
 * Transforms a DAG represented in graph <code>graph</code> into a maximally directed pattern
 * (PDAG) by modifying <code>graph</code> itself. Based on the algorithm described in Chickering
 * (2002) "Optimal structure identification with greedy search" Journal of Machine Learning
 * Research. It works for both BayesNets and SEMs. R. Silva, June 2004
 *
 * <p>The nodes are topologically sorted. The edges are then ordered by increasing order of their
 * head y and, for a fixed head, by decreasing order of their tail x. Each edge is labeled
 * compelled or reversible, and the reversible edges are made undirected.
 *
 * @param graph a DAG whose directed edges run node1 -> node2; modified in place.
 * @throws IllegalArgumentException if the nodes cannot be topologically ordered (e.g. a directed
 *     cycle), or if some edge cannot be placed in the edge order. Both cases previously looped
 *     forever.
 */
public static void dagToPdag(Graph graph) {
  // Topological sort: repeatedly take the nodes graphCopy reports as exogenous, then remove them.
  Graph graphCopy = new EdgeListGraph(graph);
  Node[] orderedNodes = new Node[graphCopy.getNodes().size()];
  int count = 0;

  while (graphCopy.getNodes().size() > 0) {
    Set<Node> exogenousNodes = new HashSet<Node>();

    for (Node next : graphCopy.getNodes()) {
      if (graphCopy.isExogenous(next)) {
        exogenousNodes.add(next);
        orderedNodes[count++] = graph.getNode(next.getName());
      }
    }

    if (exogenousNodes.isEmpty()) {
      throw new IllegalArgumentException(
          "Graph is not a DAG: its nodes cannot be topologically ordered.");
    }

    graphCopy.removeNodes(new ArrayList<Node>(exogenousNodes));
  }

  // Order the edges: increasing order of head y, then decreasing order of tail x.
  int numEdges = graph.getNumEdges();
  Edge[] edges = new Edge[numEdges];
  count = 0;

  for (Edge edge : graph.getEdges()) {
    edges[count++] = edge;
  }

  boolean[] edgeOrdered = new boolean[numEdges];
  Edge[] orderedEdges = new Edge[numEdges];
  int numOrdered = 0;

  for (Node y : orderedNodes) {
    for (int k = orderedNodes.length - 1; k >= 0; k--) {
      for (int q = 0; q < numEdges; q++) {
        if (!edgeOrdered[q]
            && edges[q].getNode1() == orderedNodes[k]
            && edges[q].getNode2() == y) {
          edgeOrdered[q] = true;
          orderedEdges[numOrdered++] = edges[q];
        }
      }
    }
  }

  // A single pass places every edge directed node1 -> node2 between ordered nodes; repeating it
  // cannot place any other edge.
  if (numOrdered < numEdges) {
    throw new IllegalArgumentException(
        "Graph is not a DAG: some edge is not directed from node1 to node2.");
  }

  // Label edges.
  boolean[] compelledEdges = new boolean[numEdges];
  boolean[] reversibleEdges = new boolean[numEdges];

  for (int i = 0; i < numEdges; i++) {
    if (compelledEdges[i] || reversibleEdges[i]) {
      continue;
    }

    // x -> y is the lowest-ordered edge not yet labeled.
    Node x = orderedEdges[i].getNode1();
    Node y = orderedEdges[i].getNode2();

    for (int j = 0; j < numEdges; j++) {
      if (orderedEdges[j].getNode2() == x && compelledEdges[j]) {
        Node w = orderedEdges[j].getNode1();

        if (!graph.isParentOf(w, y)) {
          // w -> x -> y with w not a parent of y: x -> y and every edge into y are compelled.
          // Labeling only the first edge into y would leave the others to be labeled
          // independently, possibly as reversible.
          for (int k = 0; k < numEdges; k++) {
            if (orderedEdges[k].getNode2() == y) {
              compelledEdges[k] = true;
            }
          }
        } else {
          // w is also a parent of y, so w -> y is compelled.
          for (int k = 0; k < numEdges; k++) {
            if (orderedEdges[k].getNode1() == w && orderedEdges[k].getNode2() == y) {
              compelledEdges[k] = true;
              break;
            }
          }
        }
      }

      if (compelledEdges[i]) {
        break;
      }
    }

    if (compelledEdges[i]) {
      continue;
    }

    // If some z -> y exists with z != x and z not a parent of x, then x -> y and the edges into
    // y not already reversible are compelled. Otherwise x -> y and the edges into y not already
    // compelled are reversible.
    boolean foundZ = false;

    for (Edge orderedEdge : orderedEdges) {
      Node z = orderedEdge.getNode1();

      if (z != x && orderedEdge.getNode2() == y && !graph.isParentOf(z, x)) {
        compelledEdges[i] = true;

        for (int k = i + 1; k < numEdges; k++) {
          if (orderedEdges[k].getNode2() == y && !reversibleEdges[k]) {
            compelledEdges[k] = true;
          }
        }

        foundZ = true;
        break;
      }
    }

    if (!foundZ) {
      reversibleEdges[i] = true;

      for (int j = i + 1; j < numEdges; j++) {
        if (!compelledEdges[j] && orderedEdges[j].getNode2() == y) {
          reversibleEdges[j] = true;
        }
      }
    }
  }

  // Undirect edges that are reversible.
  for (int i = 0; i < numEdges; i++) {
    if (reversibleEdges[i]) {
      graph.setEndpoint(orderedEdges[i].getNode1(), orderedEdges[i].getNode2(), Endpoint.TAIL);
      graph.setEndpoint(orderedEdges[i].getNode2(), orderedEdges[i].getNode1(), Endpoint.TAIL);
    }
  }
}
/**
 * Finds pure clusters of size 4 or higher among {@code _variables}.
 *
 * <p>For each x, triples of its neighbors (restricted to {@code _variables}, keeping only
 * neighbors with at least three adjacencies) are combined with x into quartets. The first
 * quartet that passes {@code clique} and is {@code pure} with respect to all variables is grown.
 * A variable o is added when the cluster stays a {@code clique} and the checked quartets
 * containing o are {@code pure}. After at most 50 quartets containing o, o is accepted without
 * checking the rest. The cluster's members are then removed from {@code _variables} and the
 * search restarts.
 *
 * @param _variables candidate variable indices; modified in place (clustered indices removed).
 * @param adjacencies adjacency map over the nodes of {@code variables}.
 * @return the clusters found, as sets of indices into {@code variables}.
 */
private Set<Set<Integer>> findPureClusters(
    List<Integer> _variables, Map<Node, Set<Node>> adjacencies) {
  //        System.out.println("Original variables = " + variables);

  Set<Set<Integer>> clusters = new HashSet<Set<Integer>>();
  List<Integer> allVariables = new ArrayList<Integer>();
  for (int i = 0; i < this.variables.size(); i++) allVariables.add(i);

  VARIABLES:
  while (!_variables.isEmpty()) {
    if (_variables.size() < 4) break;

    for (int x : _variables) {
      Node nodeX = variables.get(x);
      List<Node> adjX = new ArrayList<Node>(adjacencies.get(nodeX));
      adjX.retainAll(variablesForIndices(new ArrayList<Integer>(_variables)));

      // Drop neighbors with fewer than three adjacencies.
      for (Node node : new ArrayList<Node>(adjX)) {
        if (adjacencies.get(node).size() < 3) {
          adjX.remove(node);
        }
      }

      if (adjX.size() < 3) {
        continue;
      }

      ChoiceGenerator gen = new ChoiceGenerator(adjX.size(), 3);
      int[] choice;

      while ((choice = gen.next()) != null) {
        Node nodeY = adjX.get(choice[0]);
        Node nodeZ = adjX.get(choice[1]);
        Node nodeW = adjX.get(choice[2]);

        int y = variables.indexOf(nodeY);
        int w = variables.indexOf(nodeW);
        int z = variables.indexOf(nodeZ);

        Set<Integer> cluster = quartet(x, y, z, w);

        if (!clique(cluster, adjacencies)) {
          continue;
        }

        // Note that purity needs to be assessed with respect to all of the variables in order to
        // remove all latent-measure impurities between pairs of latents.
        if (pure(cluster, allVariables)) {

          //                        Collections.shuffle(_variables);

          O:
          for (int o : _variables) {
            if (cluster.contains(o)) continue;
            cluster.add(o);
            List<Integer> _cluster = new ArrayList<Integer>(cluster);

            if (!clique(cluster, adjacencies)) {
              cluster.remove(o);
              continue O;
            }

            //                            if (!allVariablesDependent(cluster)) {
            //                                cluster.remove(o);
            //                                continue O;
            //                            }

            ChoiceGenerator gen2 = new ChoiceGenerator(_cluster.size(), 4);
            int[] choice2;
            int count = 0;

            while ((choice2 = gen2.next()) != null) {
              int x2 = _cluster.get(choice2[0]);
              int y2 = _cluster.get(choice2[1]);
              int z2 = _cluster.get(choice2[2]);
              int w2 = _cluster.get(choice2[3]);

              Set<Integer> quartet = quartet(x2, y2, z2, w2);

              // Optimizes for large clusters: after 50 quartets containing o have passed,
              // o is accepted without checking the rest.
              if (quartet.contains(o)) {
                if (++count > 50) continue O;
              }

              if (quartet.contains(o) && !pure(quartet, allVariables)) {
                cluster.remove(o);
                continue O;
              }
            }
          }

          System.out.println(
              "Cluster found: " + variablesForIndices(new ArrayList<Integer>(cluster)));
          clusters.add(cluster);
          _variables.removeAll(cluster);

          // Jump out of the for-each over _variables right after modifying it, so its iterator
          // is never advanced again (avoids ConcurrentModificationException).
          continue VARIABLES;
        }
      }
    }

    // No x produced a cluster in this pass.
    break;
  }

  return clusters;
}