Ejemplo n.º 1
0
 /**
  * Returns a structural copy of {@code tree}: every node is rebuilt with a fresh child list,
  * while the label objects themselves are shared with the original (they are not copied).
  */
 private static <L> Tree<L> deepCopy(Tree<L> tree) {
   List<Tree<L>> copies = new ArrayList<Tree<L>>();
   for (Tree<L> subtree : tree.getChildren()) {
     Tree<L> copy = deepCopy(subtree);
     copies.add(copy);
   }
   L label = tree.getLabel();
   return new Tree<L>(label, copies);
 }
Ejemplo n.º 2
0
 /**
  * Joins two adjacent subtrees under a new parent node.
  *
  * <p>The parent's label is the argMax of the {@code spanToCategories} counter for a span as
  * wide as the two yields combined. NOTE(review): if that width was never seen in training the
  * counter may be empty; what argMax returns then is not visible here.
  */
 private Tree<String> merge(Tree<String> leftTree, Tree<String> rightTree) {
   int width = leftTree.getYield().size();
   width += rightTree.getYield().size();
   String label = spanToCategories.getCounter(width).argMax();
   List<Tree<String>> pair = new ArrayList<Tree<String>>();
   pair.add(leftTree);
   pair.add(rightTree);
   return new Tree<String>(label, pair);
 }
Ejemplo n.º 3
0
 /**
  * Appends, left to right, the label of every pre-terminal node under {@code tree} to
  * {@code yield}. Recursion stops at a pre-terminal, so nodes below one are never visited.
  */
 private static <L> void appendPreTerminalYield(Tree<L> tree, List<L> yield) {
   if (!tree.isPreTerminal()) {
     for (Tree<L> child : tree.getChildren()) {
       appendPreTerminalYield(child, yield);
     }
   } else {
     yield.add(tree.getLabel());
   }
 }
Ejemplo n.º 4
0
 /**
  * Overwrites leaf labels, left to right, with consecutive entries of {@code words}, starting at
  * index {@code wordNum}.
  *
  * @return the index of the first word not consumed by this subtree
  */
 private int setWordsHelper(List<L> words, int wordNum) {
   if (!isLeaf()) {
     int next = wordNum;
     for (Tree<L> child : getChildren()) {
       next = child.setWordsHelper(words, next);
     }
     return next;
   }
   label = words.get(wordNum);
   return wordNum + 1;
 }
Ejemplo n.º 5
0
 /**
  * Appends a {@code Constituent} for every node above the pre-terminal level under {@code tree},
  * children before their parent, where {@code start} is the position of the subtree's first word.
  * Leaves and pre-terminals occupy one position and produce no constituent.
  *
  * @return the number of positions spanned by {@code tree}
  */
 private static <L> int toConstituentCollectionHelper(
     Tree<L> tree, int start, List<Constituent<L>> constituents) {
   if (tree.isLeaf() || tree.isPreTerminal()) {
     return 1;
   }
   int end = start;
   for (Tree<L> child : tree.getChildren()) {
     end += toConstituentCollectionHelper(child, end, constituents);
   }
   constituents.add(new Constituent<L>(tree.getLabel(), start, end));
   return end - start;
 }
Ejemplo n.º 6
0
 /**
  * Trains on {@code trainTrees}. It builds the lexicon, then for every tree:
  * <ul>
  *   <li>counts the whole tree under its pre-terminal (tag) sequence in {@code knownParses};
  *   <li>tallies the categories seen over each span width into {@code spanToCategories}.
  * </ul>
  */
 public void train(List<Tree<String>> trainTrees) {
   lexicon = new Lexicon(trainTrees);
   knownParses = new CounterMap<List<String>, Tree<String>>();
   spanToCategories = new CounterMap<Integer, String>();
   for (Tree<String> tree : trainTrees) {
     knownParses.incrementCount(tree.getPreTerminalYield(), tree, 1.0);
     tallySpans(tree, 0);
   }
 }
Ejemplo n.º 7
0
 /**
  * Recursively counts the rules used in {@code tree}. Each node above the pre-terminal level adds
  * one count for its label to {@code symbolCounter}. It also adds one count for its unary or
  * binary rule to the matching rule counter.
  *
  * @throws RuntimeException if such a node has neither one nor two children
  */
 private void tallyTree(
     Tree<String> tree,
     Counter<String> symbolCounter,
     Counter<UnaryRule> unaryRuleCounter,
     Counter<BinaryRule> binaryRuleCounter) {
   if (tree.isLeaf() || tree.isPreTerminal()) {
     return;
   }
   switch (tree.getChildren().size()) {
     case 1:
       {
         UnaryRule unaryRule = makeUnaryRule(tree);
         symbolCounter.incrementCount(tree.getLabel(), 1.0);
         unaryRuleCounter.incrementCount(unaryRule, 1.0);
         break;
       }
     case 2:
       {
         BinaryRule binaryRule = makeBinaryRule(tree);
         symbolCounter.incrementCount(tree.getLabel(), 1.0);
         binaryRuleCounter.incrementCount(binaryRule, 1.0);
         break;
       }
     default:
       throw new RuntimeException(
           "Attempted to construct a Grammar with an illegal tree: " + tree);
   }
   for (Tree<String> child : tree.getChildren()) {
     tallyTree(child, symbolCounter, unaryRuleCounter, binaryRuleCounter);
   }
 }
Ejemplo n.º 8
0
 /**
  * Counts into {@code spanToCategories} the label of every node above the pre-terminal level,
  * keyed by the number of positions the node covers. Nodes labelled {@code "ROOT"} are skipped.
  *
  * @param start position of the subtree's first word, threaded through the recursion
  * @return the number of positions {@code tree} covers (1 for a leaf or pre-terminal)
  */
 private int tallySpans(Tree<String> tree, int start) {
   if (tree.isLeaf() || tree.isPreTerminal()) {
     return 1;
   }
   int width = 0;
   for (Tree<String> child : tree.getChildren()) {
     width += tallySpans(child, start + width);
   }
   String category = tree.getLabel();
   if (!category.equals("ROOT")) {
     spanToCategories.incrementCount(width, category, 1.0);
   }
   return width;
 }
Ejemplo n.º 9
0
 /**
  * Builds a lexicon from the observed tags in a list of training trees. It calls
  * {@code tallyTagging} once per word, pairing each word of a tree's yield with the pre-terminal
  * label at the same position.
  */
 public Lexicon(List<Tree<String>> trainTrees) {
   for (Tree<String> tree : trainTrees) {
     List<String> words = tree.getYield();
     List<String> tags = tree.getPreTerminalYield();
     int position = 0;
     for (String word : words) {
       String tag = tags.get(position);
       tallyTagging(word, tag);
       position++;
     }
   }
 }
Ejemplo n.º 10
0
 /**
  * Appends this tree to {@code sb} in bracketed form.
  * <ul>
  *   <li>A leaf is written as its label alone.
  *   <li>An internal node is written as {@code (label child1 child2 ...)}.
  *   <li>A {@code null} label is omitted.
  * </ul>
  */
 public void toStringBuilder(StringBuilder sb) {
   if (isLeaf()) {
     if (getLabel() != null) {
       sb.append(getLabel());
     }
     return;
   }
   sb.append('(');
   if (getLabel() != null) {
     sb.append(getLabel());
   }
   for (Tree<L> child : getChildren()) {
     sb.append(' ');
     child.toStringBuilder(sb);
   }
   sb.append(')');
 }
Ejemplo n.º 11
0
 /**
  * Builds the right-branching chain that replaces the children of {@code tree} from index
  * {@code numChildrenGenerated} onwards. The returned node is labelled {@code intermediateLabel}.
  * Its first child is the binarized child at that index. If more children remain, its second
  * child is the rest of the chain, labelled {@code intermediateLabel + "_" + <that child's label>}.
  */
 private static Tree<String> binarizeTreeHelper(
     Tree<String> tree, int numChildrenGenerated, String intermediateLabel) {
   List<Tree<String>> siblings = tree.getChildren();
   Tree<String> head = siblings.get(numChildrenGenerated);
   List<Tree<String>> children = new ArrayList<Tree<String>>();
   children.add(binarizeTree(head));
   boolean moreRemain = numChildrenGenerated + 1 < siblings.size();
   if (moreRemain) {
     String nextLabel = intermediateLabel + "_" + head.getLabel();
     children.add(binarizeTreeHelper(tree, numChildrenGenerated + 1, nextLabel));
   }
   return new Tree<String>(intermediateLabel, children);
 }
Ejemplo n.º 12
0
 /**
  * Appends {@code tree} and all of its descendants to {@code traversal}. Each node comes before
  * its children when {@code preOrder} is true, and after them otherwise.
  */
 private static <L> void traversalHelper(Tree<L> tree, List<Tree<L>> traversal, boolean preOrder) {
   if (preOrder) {
     traversal.add(tree);
   }
   for (Tree<L> subtree : tree.getChildren()) {
     traversalHelper(subtree, traversal, preOrder);
   }
   if (!preOrder) {
     traversal.add(tree);
   }
 }
Ejemplo n.º 13
0
  /**
   * Evaluates {@code parser} on every test tree whose yield has at most {@code MAX_LENGTH} words.
   * For each such tree it prints the guessed and gold trees, then scores the pair with a
   * labeled-constituent evaluator.
   *
   * @return the value returned by {@code eval.display(true)} once all trees are scored
   */
  private static double testParser(Parser parser, List<Tree<String>> testTrees) {
    // NOTE(review): both constructor arguments are presumably label sets the evaluator ignores
    // (ROOT and punctuation tags); confirm against LabeledConstituentEval.
    HashSet<String> punctuation = new HashSet<String>(Arrays.asList("''", "``", ".", ":", ","));
    EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String> eval =
        new EnglishPennTreebankParseEvaluator.LabeledConstituentEval<String>(
            Collections.singleton("ROOT"), punctuation);
    for (Tree<String> gold : testTrees) {
      List<String> words = gold.getYield();
      if (words.size() <= MAX_LENGTH) {
        Tree<String> guess = parser.getBestParse(words);
        System.out.println("Guess:\n" + Trees.PennTreeRenderer.render(guess));
        System.out.println("Gold:\n" + Trees.PennTreeRenderer.render(gold));
        eval.evaluate(guess, gold);
      }
    }
    System.out.println();
    return eval.display(true);
  }
Ejemplo n.º 14
0
    /**
     * Returns a copy of {@code tree} in which no node has more than two children.
     * <ul>
     *   <li>Leaves are copied by label.
     *   <li>Nodes with one or two children keep their shape, with each child binarized.
     *   <li>A node with three or more children keeps its label. Its children are replaced by
     *       those of a right-branching chain of intermediate nodes labelled
     *       {@code "@" + label + "->"}, extended with {@code "_" + <child label>} at each level.
     * </ul>
     *
     * <p>Recorded result with this scheme (unary nodes kept as-is): [Average] P: 48.53 R: 47.67
     * F1: 48.10 EX: 4.85.
     */
    private static Tree<String> binarizeTree(Tree<String> tree) {
      String label = tree.getLabel();
      if (tree.isLeaf()) {
        return new Tree<String>(label);
      }
      List<Tree<String>> kids = tree.getChildren();
      if (kids.size() > 2) {
        Tree<String> chain = binarizeTreeHelper(tree, 0, "@" + label + "->");
        return new Tree<String>(label, chain.getChildren());
      }
      List<Tree<String>> binarized = new ArrayList<Tree<String>>(2);
      for (Tree<String> kid : kids) {
        binarized.add(binarizeTree(kid));
      }
      return new Tree<String>(label, binarized);
    }
Ejemplo n.º 15
0
 /**
  * Builds the {@code BinaryRule} whose parent is this node's label and whose left and right
  * children are the labels of its first and second children.
  */
 private BinaryRule makeBinaryRule(Tree<String> tree) {
   List<Tree<String>> children = tree.getChildren();
   String parent = tree.getLabel();
   String left = children.get(0).getLabel();
   String right = children.get(1).getLabel();
   return new BinaryRule(parent, left, right);
 }
Ejemplo n.º 16
0
 /**
  * Returns the argMax tree recorded in {@code knownParses} for exactly this tag sequence, with its
  * leaves relabelled by the words of {@code sentence}. The tree is deep-copied first so that
  * {@code setWords} does not modify the stored training tree.
  */
 private Tree<String> getBestKnownParse(List<String> tags, List<String> sentence) {
   Tree<String> mostFrequent = knownParses.getCounter(tags).argMax();
   Tree<String> parse = mostFrequent.deepCopy();
   parse.setWords(sentence);
   return parse;
 }
Ejemplo n.º 17
0
    /**
     * Returns the most probable parse of {@code sentence}, found by a CKY chart parser with unary
     * closure over the trained {@code lexicon} and {@code grammar}.
     *
     * <p>Chart cell {@code [begin][end]} covers words {@code begin .. end-1}. It maps each entry to
     * the best log-probability found for it over that span. An entry is one of:
     * <ul>
     *   <li>a lexical {@code UnaryRule(tag, word)};
     *   <li>a grammar {@code UnaryRule};
     *   <li>a {@code BinaryRule}.
     * </ul>
     * A parallel table holds the back-pointers that {@code makeTree} uses to rebuild the tree.
     * Entries are keyed by rule rather than by symbol, so the symbol an entry produces is read
     * with {@code getParent}. Grammar rule scores are passed through {@code Math.log}; presumably
     * they are probabilities.
     *
     * <p>The best whole-sentence entry whose symbol is {@code "ROOT"} is chosen. If there is none,
     * the best entry of any symbol is used instead. If the tree rebuilt by {@code makeTree} is not
     * labelled {@code "ROOT"}, it is wrapped in a new {@code "ROOT"} node. Progress lines are
     * printed to standard output.
     *
     * @param sentence the words to parse
     * @return the chosen parse, passed through {@code TreeAnnotations.unAnnotateTree}
     */
    public Tree<String> getBestParse(List<String> sentence) {
      int n = sentence.size();

      List<List<Map<Object, Double>>> scores = newChartTable(n + 1);
      List<List<Map<Object, Triplet<Integer, Object, Object>>>> backs = newChartTable(n + 1);

      // Width-1 spans: every (tag, word) pairing, then every unary rule that can sit above them.
      for (int i = 0; i < n; i++) {
        Map<Object, Double> cellScores = scores.get(i).get(i + 1);
        Map<Object, Triplet<Integer, Object, Object>> cellBacks = backs.get(i).get(i + 1);
        addLexicalEntries(sentence.get(i), cellScores, cellBacks);
        applyUnaryClosure(cellScores, cellBacks, true);
      }

      // Wider spans, shortest first, so every sub-span is complete before it is combined.
      for (int span = 2; span <= n; span++) {
        for (int begin = 0; begin + span <= n; begin++) {
          int end = begin + span;
          Map<Object, Double> cellScores = scores.get(begin).get(end);
          Map<Object, Triplet<Integer, Object, Object>> cellBacks = backs.get(begin).get(end);
          for (int split = begin + 1; split < end; split++) {
            combineBinary(
                scores.get(begin).get(split),
                scores.get(split).get(end),
                split,
                cellScores,
                cellBacks);
          }
          applyUnaryClosure(cellScores, cellBacks, false);
        }
      }

      Map<Object, Double> topOfChart = scores.get(0).get(n);
      System.out.println("topOfChart: " + topOfChart.size());
      Object bestKey = chooseTopKey(topOfChart);
      System.out.println("bestKey=" + bestKey + ", log(prob)=" + topOfChart.get(bestKey));

      Tree<String> result = makeTree(backs, 0, n, bestKey);
      if (!"ROOT".equals(result.getLabel())) {
        List<Tree<String>> children = new ArrayList<Tree<String>>();
        children.add(result);
        result = new Tree<String>("ROOT", children);
      }
      return TreeAnnotations.unAnnotateTree(result);
    }

    /** Allocates a {@code size x size} table whose cells are independent, empty HashMaps. */
    private <V> List<List<Map<Object, V>>> newChartTable(int size) {
      List<List<Map<Object, V>>> table = new ArrayList<List<Map<Object, V>>>(size);
      for (int row = 0; row < size; row++) {
        List<Map<Object, V>> cells = new ArrayList<Map<Object, V>>(size);
        for (int col = 0; col < size; col++) {
          cells.add(new HashMap<Object, V>());
        }
        table.add(cells);
      }
      return table;
    }

    /**
     * Seeds a width-1 cell with one entry per tag known to the lexicon. Each entry is
     * {@code UnaryRule(tag, word)}, scored {@code log(lexicon.scoreTagging(word, tag))}, with a
     * {@code null} back-pointer (presumably read by {@code makeTree} as a leaf).
     */
    private void addLexicalEntries(
        String word,
        Map<Object, Double> cellScores,
        Map<Object, Triplet<Integer, Object, Object>> cellBacks) {
      for (String tag : lexicon.getAllTags()) {
        UnaryRule tagging = new UnaryRule(tag, word);
        tagging.setScore(Math.log(lexicon.scoreTagging(word, tag)));
        cellScores.put(tagging, tagging.getScore());
        cellBacks.put(tagging, null);
      }
    }

    /**
     * Repeatedly applies every grammar unary rule whose child symbol is produced by an entry of
     * the cell, until a full pass changes nothing.
     *
     * <p>A rule's entry is stored or replaced only when the new log-probability is strictly
     * better, with back-pointer {@code (-1, child, null)}. When {@code pruneImprobable} is true,
     * candidates whose log-probability is not above -1000 are discarded. Only width-1 cells are
     * pruned this way. NOTE(review): termination relies on cycles not strictly improving a score,
     * i.e. on rule probabilities being at most 1.
     */
    private void applyUnaryClosure(
        Map<Object, Double> cellScores,
        Map<Object, Triplet<Integer, Object, Object>> cellBacks,
        boolean pruneImprobable) {
      final double pruneFloor = -1000.0;
      boolean added = true;
      while (added) {
        added = false;
        // Iterate over a snapshot because the cell is updated inside the loop.
        List<Object> children = copyKeys(cellScores);
        for (Object child : children) {
          for (UnaryRule rule : grammar.getUnaryRulesByChild(getParent(child))) {
            double prob = Math.log(rule.getScore()) + cellScores.get(child);
            if (pruneImprobable && !(prob > pruneFloor)) {
              continue;
            }
            Double current = cellScores.get(rule);
            if (current == null || prob > current) {
              cellScores.put(rule, prob);
              cellBacks.put(rule, new Triplet<Integer, Object, Object>(-1, child, null));
              added = true;
            }
          }
        }
      }
    }

    /**
     * Combines every entry of the left sub-span with every compatible entry of the right sub-span
     * through the grammar's binary rules. The result is kept in the target cell when it strictly
     * improves that rule's score, with back-pointer {@code (split, left, right)}.
     */
    private void combineBinary(
        Map<Object, Double> leftScores,
        Map<Object, Double> rightScores,
        int split,
        Map<Object, Double> cellScores,
        Map<Object, Triplet<Integer, Object, Object>> cellBacks) {
      // Bucket right-hand entries by the symbol they produce, so each left entry is only paired
      // with right entries that some binary rule accepts (avoids a full left x right scan).
      Map<String, List<Object>> rightBySymbol = new HashMap<String, List<Object>>();
      for (Object right : rightScores.keySet()) {
        String symbol = getParent(right);
        List<Object> bucket = rightBySymbol.get(symbol);
        if (bucket == null) {
          bucket = new ArrayList<Object>();
          rightBySymbol.put(symbol, bucket);
        }
        bucket.add(right);
      }

      for (Object left : leftScores.keySet()) {
        double leftScore = leftScores.get(left);
        for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(getParent(left))) {
          List<Object> bucket = rightBySymbol.get(rule.getRightChild());
          if (bucket == null) {
            continue;
          }
          double ruleLogProb = Math.log(rule.getScore());
          for (Object right : bucket) {
            double prob = ruleLogProb + leftScore + rightScores.get(right);
            Double current = cellScores.get(rule);
            if (current == null || prob > current) {
              cellScores.put(rule, prob);
              cellBacks.put(rule, new Triplet<Integer, Object, Object>(split, left, right));
            }
          }
        }
      }
    }

    /**
     * Picks the whole-sentence entry to build the final tree from.
     *
     * <p>Returns the highest-scoring entry whose symbol is {@code "ROOT"}; on a tie, the entry
     * visited last wins. If there is no ROOT entry, it prints and returns the highest-scoring entry
     * of any symbol. That value is {@code null} only when {@code cell} is empty.
     */
    private Object chooseTopKey(Map<Object, Double> cell) {
      Object bestRoot = null;
      double bestRootScore = Double.NEGATIVE_INFINITY;
      Object bestAny = null;
      double bestAnyScore = Double.NEGATIVE_INFINITY;
      for (Map.Entry<Object, Double> entry : cell.entrySet()) {
        Object key = entry.getKey();
        double score = entry.getValue();
        if (bestAny == null || score >= bestAnyScore) {
          bestAny = key;
          bestAnyScore = score;
        }
        if ("ROOT".equals(getParent(key)) && (bestRoot == null || score >= bestRootScore)) {
          bestRoot = key;
          bestRootScore = score;
        }
      }
      if (bestRoot == null) {
        System.out.println("secondBestKey=" + bestAny);
        return bestAny;
      }
      return bestRoot;
    }