/**
 * Recursively clones {@code tree}.
 *
 * <p>Every node is re-created with freshly copied children, so the returned tree shares no
 * {@code Tree} nodes with the input. Labels are passed through as-is (not copied).
 *
 * @param tree root of the subtree to clone
 * @return a structurally identical tree built from new nodes
 */
private static <L> Tree<L> deepCopy(Tree<L> tree) {
  List<Tree<L>> originals = tree.getChildren();
  List<Tree<L>> clones = new ArrayList<Tree<L>>(originals.size());
  for (Tree<L> original : originals) {
    Tree<L> clone = deepCopy(original);
    clones.add(clone);
  }
  L sharedLabel = tree.getLabel();
  return new Tree<L>(sharedLabel, clones);
}
/**
 * Joins two subtrees under a newly created parent node.
 *
 * <p>The parent's label is the {@code argMax} of the {@code spanToCategories} counter stored
 * for the combined yield length of both subtrees.
 *
 * @param leftTree subtree that becomes the first child
 * @param rightTree subtree that becomes the second child
 * @return a new tree whose children are {@code leftTree} and {@code rightTree}, in that order
 */
private Tree<String> merge(Tree<String> leftTree, Tree<String> rightTree) {
  int leftLength = leftTree.getYield().size();
  int rightLength = rightTree.getYield().size();
  String parentLabel = spanToCategories.getCounter(leftLength + rightLength).argMax();

  List<Tree<String>> pair = new ArrayList<Tree<String>>();
  pair.add(leftTree);
  pair.add(rightTree);
  return new Tree<String>(parentLabel, pair);
}
/**
 * Appends the labels of all pre-terminal nodes under {@code tree} to {@code yield}, visiting
 * children in order.
 *
 * <p>Recursion stops at a pre-terminal node: its label is recorded and its children are not
 * visited.
 *
 * @param tree subtree to scan
 * @param yield output list that receives the pre-terminal labels
 */
private static <L> void appendPreTerminalYield(Tree<L> tree, List<L> yield) {
  if (!tree.isPreTerminal()) {
    for (Tree<L> subtree : tree.getChildren()) {
      appendPreTerminalYield(subtree, yield);
    }
  } else {
    yield.add(tree.getLabel());
  }
}
/**
 * Overwrites the labels of the leaves below this node with consecutive entries of
 * {@code words}, starting at index {@code wordNum}.
 *
 * @param words replacement leaf labels
 * @param wordNum index in {@code words} to assign to the first leaf under this node
 * @return the index of the next unused word after this subtree has been processed
 */
private int setWordsHelper(List<L> words, int wordNum) {
  if (!isLeaf()) {
    int cursor = wordNum;
    for (Tree<L> child : getChildren()) {
      cursor = child.setWordsHelper(words, cursor);
    }
    return cursor;
  }
  // Leaf: take the current word and advance by one.
  label = words.get(wordNum);
  return wordNum + 1;
}
/**
 * Collects one {@code Constituent} for every node of {@code tree} that is neither a leaf nor a
 * pre-terminal, in post-order (children before their parent).
 *
 * <p>Leaves and pre-terminals each count as a span of width 1 and produce no constituent.
 *
 * @param tree subtree to process
 * @param start offset of the first position covered by {@code tree}
 * @param constituents output list that receives the constituents
 * @return the number of positions covered by {@code tree}
 */
private static <L> int toConstituentCollectionHelper(
    Tree<L> tree, int start, List<Constituent<L>> constituents) {
  if (tree.isLeaf() || tree.isPreTerminal()) {
    return 1;
  }
  // Each child starts where the previous one ended.
  int end = start;
  for (Tree<L> child : tree.getChildren()) {
    end += toConstituentCollectionHelper(child, end, constituents);
  }
  constituents.add(new Constituent<L>(tree.getLabel(), start, end));
  return end - start;
}
/**
 * Trains this parser from a list of trees.
 *
 * <p>Rebuilds the lexicon, then resets and refills two tables: {@code knownParses}, which
 * counts each training tree under its pre-terminal yield, and {@code spanToCategories}, which
 * is filled by {@code tallySpans}.
 *
 * @param trainTrees training trees
 */
public void train(List<Tree<String>> trainTrees) {
  lexicon = new Lexicon(trainTrees);
  knownParses = new CounterMap<List<String>, Tree<String>>();
  spanToCategories = new CounterMap<Integer, String>();

  for (Tree<String> example : trainTrees) {
    knownParses.incrementCount(example.getPreTerminalYield(), example, 1.0);
    tallySpans(example, 0);
  }
}
/**
 * Counts the unary and binary rules found in {@code tree} and, for each counted rule, the
 * label of the node it was read from.
 *
 * <p>Leaves and pre-terminals are skipped. Any other node must have exactly one or two
 * children.
 *
 * @param tree subtree to tally
 * @param symbolCounter receives one count per rule-bearing node, keyed by the node's label
 * @param unaryRuleCounter receives one count per single-child node
 * @param binaryRuleCounter receives one count per two-child node
 * @throws RuntimeException if a non-leaf, non-pre-terminal node has a child count other than
 *     1 or 2
 */
private void tallyTree(
    Tree<String> tree,
    Counter<String> symbolCounter,
    Counter<UnaryRule> unaryRuleCounter,
    Counter<BinaryRule> binaryRuleCounter) {
  if (tree.isLeaf() || tree.isPreTerminal()) {
    return;
  }

  int arity = tree.getChildren().size();
  switch (arity) {
    case 1: {
      UnaryRule unaryRule = makeUnaryRule(tree);
      symbolCounter.incrementCount(tree.getLabel(), 1.0);
      unaryRuleCounter.incrementCount(unaryRule, 1.0);
      break;
    }
    case 2: {
      BinaryRule binaryRule = makeBinaryRule(tree);
      symbolCounter.incrementCount(tree.getLabel(), 1.0);
      binaryRuleCounter.incrementCount(binaryRule, 1.0);
      break;
    }
    default:
      throw new RuntimeException(
          "Attempted to construct a Grammar with an illegal tree: " + tree);
  }

  for (Tree<String> child : tree.getChildren()) {
    tallyTree(child, symbolCounter, unaryRuleCounter, binaryRuleCounter);
  }
}
/**
 * Records, for every node of {@code tree} that is neither a leaf nor a pre-terminal, how many
 * positions the node covers, keyed by span width in {@code spanToCategories}.
 *
 * <p>Leaves and pre-terminals count as width 1 and are not recorded. Nodes labeled
 * {@code "ROOT"} are measured but not recorded.
 *
 * @param tree subtree to measure
 * @param start offset of the first position covered by {@code tree}
 * @return the number of positions covered by {@code tree}
 */
private int tallySpans(Tree<String> tree, int start) {
  if (tree.isLeaf() || tree.isPreTerminal()) {
    return 1;
  }

  int width = 0;
  for (Tree<String> child : tree.getChildren()) {
    width += tallySpans(child, start + width);
  }

  String category = tree.getLabel();
  boolean isRoot = category.equals("ROOT");
  if (!isRoot) {
    spanToCategories.incrementCount(width, category, 1.0);
  }
  return width;
}
/**
 * Builds a lexicon from the observed tags in a list of training trees.
 *
 * <p>For each tree, the yield (words) and pre-terminal yield (tags) are read in parallel and
 * every (word, tag) pair at the same position is passed to {@code tallyTagging}.
 *
 * @param trainTrees training trees to read taggings from
 */
public Lexicon(List<Tree<String>> trainTrees) {
  for (Tree<String> trainTree : trainTrees) {
    List<String> words = trainTree.getYield();
    List<String> tags = trainTree.getPreTerminalYield();
    int wordCount = words.size();
    for (int i = 0; i < wordCount; i++) {
      tallyTagging(words.get(i), tags.get(i));
    }
  }
}
/**
 * Appends a bracketed rendering of this subtree to {@code sb}.
 *
 * <p>A leaf is written as its label only. Any other node is written as
 * {@code (label child1 child2 ...)}, with each child preceded by a single space. A
 * {@code null} label is omitted.
 *
 * @param sb builder that receives the rendering
 */
public void toStringBuilder(StringBuilder sb) {
  boolean leaf = isLeaf();
  L nodeLabel = getLabel();

  if (leaf) {
    if (nodeLabel != null) {
      sb.append(nodeLabel);
    }
    return;
  }

  sb.append('(');
  if (nodeLabel != null) {
    sb.append(nodeLabel);
  }
  for (Tree<L> child : getChildren()) {
    sb.append(' ');
    child.toStringBuilder(sb);
  }
  sb.append(')');
}
/**
 * Builds the right-branching chain for the children of {@code tree}, starting at child index
 * {@code numChildrenGenerated}.
 *
 * <p>The returned node is labeled {@code intermediateLabel}. Its first child is the binarized
 * child at the current index. If more children remain, its second child is the chain for the
 * rest, labeled with {@code intermediateLabel + "_" + <label of the current child>}.
 *
 * @param tree node whose children are being chained
 * @param numChildrenGenerated index of the child to place at this level
 * @param intermediateLabel label for the node created at this level
 * @return the intermediate node covering children from {@code numChildrenGenerated} onward
 */
private static Tree<String> binarizeTreeHelper(
    Tree<String> tree, int numChildrenGenerated, String intermediateLabel) {
  Tree<String> current = tree.getChildren().get(numChildrenGenerated);

  List<Tree<String>> children = new ArrayList<Tree<String>>();
  children.add(binarizeTree(current));

  int lastIndex = tree.getChildren().size() - 1;
  if (numChildrenGenerated < lastIndex) {
    // The label grows by the label of each child consumed so far.
    String nextLabel = intermediateLabel + "_" + current.getLabel();
    children.add(binarizeTreeHelper(tree, numChildrenGenerated + 1, nextLabel));
  }
  return new Tree<String>(intermediateLabel, children);
}
/**
 * Appends every node of {@code tree} to {@code traversal} in depth-first order.
 *
 * @param tree subtree to walk
 * @param traversal output list that receives the visited nodes
 * @param preOrder if {@code true}, each node is added before its children; otherwise after
 *     them
 */
private static <L> void traversalHelper(Tree<L> tree, List<Tree<L>> traversal, boolean preOrder) {
  if (preOrder) {
    traversal.add(tree);
    for (Tree<L> subtree : tree.getChildren()) {
      traversalHelper(subtree, traversal, true);
    }
  } else {
    for (Tree<L> subtree : tree.getChildren()) {
      traversalHelper(subtree, traversal, false);
    }
    traversal.add(tree);
  }
}
/**
 * Parses every test sentence of at most {@code MAX_LENGTH} words, prints the guessed and gold
 * trees, and feeds each pair to a labeled-constituent evaluator.
 *
 * @param parser parser under test
 * @param testTrees gold trees; their yields are used as the input sentences
 * @return the value returned by the evaluator's {@code display(true)}
 */
private static double testParser(Parser parser, List<Tree<String>> testTrees) {
  // NOTE(review): the two constructor arguments look like labels and punctuation tags to
  // exclude from scoring - confirm against LabeledConstituentEval.
  HashSet<String> punctuationTags =
      new HashSet<String>(Arrays.asList("''", "``", ".", ":", ","));
  EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String> eval =
      new EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String>(
          Collections.singleton("ROOT"), punctuationTags);

  for (Tree<String> goldTree : testTrees) {
    List<String> words = goldTree.getYield();
    if (words.size() > MAX_LENGTH) {
      continue;
    }
    Tree<String> guessedTree = parser.getBestParse(words);
    String guessRendering = Trees.PennTreeRenderer.render(guessedTree);
    System.out.println("Guess:\n" + guessRendering);
    String goldRendering = Trees.PennTreeRenderer.render(goldTree);
    System.out.println("Gold:\n" + goldRendering);
    eval.evaluate(guessedTree, goldTree);
  }

  System.out.println();
  return eval.display(true);
}
private static Tree<String> binarizeTree(Tree<String> tree) { String label = tree.getLabel(); if (tree.isLeaf()) { return new Tree<String>(label); } /* // [Average] P: 48.53 R: 47.67 F1: 48.10 EX: 4.85 if (tree.getChildren().size() == 1) { return new Tree<String>(label, Collections.singletonList(binarizeTree(tree.getChildren().get(0)))); } */ // [Average] P: 48.53 R: 47.67 F1: 48.10 EX: 4.85 if (tree.getChildren().size() <= 2) { List<Tree<String>> children = new ArrayList<Tree<String>>(2); for (Tree<String> child : tree.getChildren()) { children.add(binarizeTree(child)); } return new Tree<String>(label, children); } // otherwise, it's a binary-or-more local tree, // so decompose it into a sequence of binary and unary trees. String intermediateLabel = "@" + label + "->"; Tree<String> intermediateTree = binarizeTreeHelper(tree, 0, intermediateLabel); return new Tree<String>(label, intermediateTree.getChildren()); }
/**
 * Builds the binary rule {@code parent -> left right} from the label of {@code tree} and the
 * labels of its first two children.
 *
 * @param tree node with at least two children
 * @return the rule read off the node
 */
private BinaryRule makeBinaryRule(Tree<String> tree) {
  String parentLabel = tree.getLabel();
  Tree<String> leftChild = tree.getChildren().get(0);
  String leftLabel = leftChild.getLabel();
  Tree<String> rightChild = tree.getChildren().get(1);
  String rightLabel = rightChild.getLabel();
  return new BinaryRule(parentLabel, leftLabel, rightLabel);
}
/**
 * Returns a copy of the {@code argMax} tree stored in {@code knownParses} for {@code tags},
 * with its words replaced by {@code sentence}.
 *
 * @param tags tag sequence used as the lookup key
 * @param sentence words to write into the copied tree
 * @return a deep copy of the stored tree carrying the given words
 */
private Tree<String> getBestKnownParse(List<String> tags, List<String> sentence) {
  Tree<String> storedParse = knownParses.getCounter(tags).argMax();
  // Copy first so the stored training tree is not modified by setWords.
  Tree<String> parse = storedParse.deepCopy();
  parse.setWords(sentence);
  return parse;
}
public Tree<String> getBestParse(List<String> sentence) { // TODO: implement this method int n = sentence.size(); // System.out.println("getBestParse: n=" + n); List<List<Map<Object, Double>>> scores = new ArrayList<List<Map<Object, Double>>>(n + 1); for (int i = 0; i < n + 1; i++) { List<Map<Object, Double>> row = new ArrayList<Map<Object, Double>>(n + 1); for (int j = 0; j < n + 1; j++) { row.add(new HashMap<Object, Double>()); } scores.add(row); } List<List<Map<Object, Triplet<Integer, Object, Object>>>> backs = new ArrayList<List<Map<Object, Triplet<Integer, Object, Object>>>>(n + 1); for (int i = 0; i < n + 1; i++) { List<Map<Object, Triplet<Integer, Object, Object>>> row = new ArrayList<Map<Object, Triplet<Integer, Object, Object>>>(n + 1); for (int j = 0; j < n + 1; j++) { row.add(new HashMap<Object, Triplet<Integer, Object, Object>>()); } backs.add(row); } /* System.out.println("scores=" + scores.size() + "x" + scores.get(0).size()); System.out.println("backs=" + backs.size() + "x" + backs.get(0).size()); printChart(scores, backs, "scores"); */ // First the Lexicon for (int i = 0; i < n; i++) { String word = sentence.get(i); for (String tag : lexicon.getAllTags()) { UnaryRule A = new UnaryRule(tag, word); A.setScore(Math.log(lexicon.scoreTagging(word, tag))); scores.get(i).get(i + 1).put(A, A.getScore()); backs.get(i).get(i + 1).put(A, null); } // System.out.println("Starting unaries: i=" + i + ",n=" + n ); // Handle unaries boolean added = true; while (added) { added = false; Map<Object, Double> A_scores = scores.get(i).get(i + 1); // Don't modify the dict we are iterating List<Object> A_keys = copyKeys(A_scores); // for (int j = 0; j < 5 && j < A_keys.size(); j++) { // System.out.print("," + j + "=" + A_scores.get(A_keys.get(j))); // } for (Object oB : A_keys) { UnaryRule B = (UnaryRule) oB; for (UnaryRule A : grammar.getUnaryRulesByChild(B.getParent())) { double prob = Math.log(A.getScore()) + A_scores.get(B); if (prob > -1000.0) { if 
(!A_scores.containsKey(A) || prob > A_scores.get(A)) { // System.out.print(" *A=" + A + ", B=" + B); // System.out.print(", prob=" + prob); // System.out.println(", A_scores.get(A)=" + A_scores.get(A)); A_scores.put(A, prob); backs.get(i).get(i + 1).put(A, new Triplet<Integer, Object, Object>(-1, B, null)); added = true; } // System.out.println(", added=" + added); } } } // System.out.println(", A_scores=" + A_scores.size() + ", added=" + added); } } // printChart(scores, backs, "scores with Lexicon"); // Do higher layers // Naming is based on rules: A -> B,C long startTime = new Date().getTime(); for (int span = 2; span < n + 1; span++) { for (int begin = 0; begin < n + 1 - span; begin++) { int end = begin + span; Map<Object, Double> A_scores = scores.get(begin).get(end); Map<Object, Triplet<Integer, Object, Object>> A_backs = backs.get(begin).get(end); for (int split = begin + 1; split < end; split++) { Map<Object, Double> B_scores = scores.get(begin).get(split); Map<Object, Double> C_scores = scores.get(split).get(end); List<Object> B_list = new ArrayList<Object>(B_scores.keySet()); List<Object> C_list = new ArrayList<Object>(C_scores.keySet()); // This is a key optimization. 
!@#$ // It avoids a B_list.size() x C_list.size() search in the for (Object B : B_list) loop Map<String, List<Object>> C_map = new HashMap<String, List<Object>>(); for (Object C : C_list) { String parent = getParent(C); if (!C_map.containsKey(parent)) { C_map.put(parent, new ArrayList<Object>()); } C_map.get(parent).add(C); } for (Object B : B_list) { for (BinaryRule A : grammar.getBinaryRulesByLeftChild(getParent(B))) { if (C_map.containsKey(A.getRightChild())) { for (Object C : C_map.get(A.getRightChild())) { // We now have A which has B as left child and C as right child double prob = Math.log(A.getScore()) + B_scores.get(B) + C_scores.get(C); if (!A_scores.containsKey(A) || prob > A_scores.get(A)) { A_scores.put(A, prob); A_backs.put(A, new Triplet<Integer, Object, Object>(split, B, C)); } } } } } } // Handle unaries: A -> B boolean added = true; while (added) { added = false; // Don't modify the dict we are iterating List<Object> A_keys = copyKeys(A_scores); for (Object oB : A_keys) { for (UnaryRule A : grammar.getUnaryRulesByChild(getParent(oB))) { double prob = Math.log(A.getScore()) + A_scores.get(oB); if (!A_scores.containsKey(A) || prob > A_scores.get(A)) { A_scores.put(A, prob); A_backs.put(A, new Triplet<Integer, Object, Object>(-1, oB, null)); added = true; } } } } } } // printChart(scores, backs, "scores with Lexicon and Grammar"); Map<Object, Double> topOfChart = scores.get(0).get(n); System.out.println("topOfChart: " + topOfChart.size()); /* for (Object o: topOfChart.keySet()) { System.out.println("o=" + o + ", score=" + topOfChart.getCount(o)); } */ // All parses have "ROOT" at top of tree Object bestKey = null; Object secondBestKey = null; double bestScore = Double.NEGATIVE_INFINITY; double secondBestScore = Double.NEGATIVE_INFINITY; for (Object key : topOfChart.keySet()) { double score = topOfChart.get(key); if (score >= secondBestScore || secondBestKey == null) { secondBestKey = key; secondBestScore = score; } if ("ROOT".equals(getParent(key)) 
&& (score >= bestScore || bestKey == null)) { bestKey = key; bestScore = score; } } if (bestKey == null) { bestKey = secondBestKey; System.out.println("secondBestKey=" + secondBestKey); } if (bestKey == null) { for (Object key : topOfChart.keySet()) { System.out.println("val=" + topOfChart.get(key) + ", key=" + key); } } System.out.println("bestKey=" + bestKey + ", log(prob)=" + topOfChart.get(bestKey)); Tree<String> result = makeTree(backs, 0, n, bestKey); if (!"ROOT".equals(result.getLabel())) { List<Tree<String>> children = new ArrayList<Tree<String>>(); children.add(result); result = new Tree<String>("ROOT", children); // !@#$ } /* System.out.println("=================================================="); System.out.println(result); System.out.println("====================^^^^^^========================"); */ return TreeAnnotations.unAnnotateTree(result); }