Example No. 1
    public void train(List<Tree<String>> trainTrees) {
      // Before the grammar is generated, the training trees need to be
      // binarized so that every rule is at most binary.
      List<Tree<String>> binarizedTrees = new ArrayList<Tree<String>>();

      for (Tree<String> trainTree : trainTrees) {
        Tree<String> newTree = TreeAnnotations.annotateTree(trainTree);
        binarizedTrees.add(newTree);
      }
      lexicon = new Lexicon(binarizedTrees);
      grammar = new Grammar(binarizedTrees);
      System.out.println("trained!!");
    }
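
A minimal usage sketch, assuming train and the getBestParse method shown in Example No. 3 live on the same parser object; the parser, trainTrees, and testSentence names below are placeholders, not taken from the original examples.

      // Hypothetical driver: 'parser' is an instance of the class that defines
      // train() and getBestParse(); trainTrees and testSentence stand in for real data.
      parser.train(trainTrees);
      Tree<String> bestParse = parser.getBestParse(testSentence);
      System.out.println(bestParse);
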
Example No. 2
    public void train(List<Tree<String>> trainTrees) {
      // Before the grammar is generated, the training trees need to be
      // binarized so that every rule is at most binary.

      List<Tree<String>> annotatedTrees = new ArrayList<Tree<String>>();
      for (Tree<String> tree : trainTrees) {
        annotatedTrees.add(TreeAnnotations.annotateTree(tree));
      }

      /*
      System.out.println("trainTrees: " );
      for (Tree<String> tree: trainTrees)  System.out.println("  " + tree);
      System.out.println("annotatedTrees: " );
      for (Tree<String> tree: annotatedTrees)  System.out.println("  " + tree);
       */
      lexicon = new Lexicon(annotatedTrees);
      grammar = new Grammar(annotatedTrees);

      /*
      System.out.println("lexicon: " + lexicon.getAllTags());
      System.out.println("grammar: " + grammar);
      */
      // System.exit(-1);
    }
Example No. 3
    public Tree<String> getBestParse(List<String> sentence) {
      // CKY parsing over the binarized grammar; all scores are kept in log space.
      int n = sentence.size();

      // scores.get(i).get(j) maps each rule spanning words i..j to its best
      // log-probability; backs holds the corresponding back pointers.
      List<List<Map<Object, Double>>> scores = new ArrayList<List<Map<Object, Double>>>(n + 1);
      for (int i = 0; i < n + 1; i++) {
        List<Map<Object, Double>> row = new ArrayList<Map<Object, Double>>(n + 1);
        for (int j = 0; j < n + 1; j++) {
          row.add(new HashMap<Object, Double>());
        }
        scores.add(row);
      }
      List<List<Map<Object, Triplet<Integer, Object, Object>>>> backs =
          new ArrayList<List<Map<Object, Triplet<Integer, Object, Object>>>>(n + 1);
      for (int i = 0; i < n + 1; i++) {
        List<Map<Object, Triplet<Integer, Object, Object>>> row =
            new ArrayList<Map<Object, Triplet<Integer, Object, Object>>>(n + 1);
        for (int j = 0; j < n + 1; j++) {
          row.add(new HashMap<Object, Triplet<Integer, Object, Object>>());
        }
        backs.add(row);
      }

      /*
      System.out.println("scores=" + scores.size() + "x" + scores.get(0).size());
      System.out.println("backs=" + backs.size() + "x" + backs.get(0).size());
      printChart(scores, backs, "scores");
      */
      // Seed the diagonal of the chart with tag scores from the lexicon.

      for (int i = 0; i < n; i++) {
        String word = sentence.get(i);
        for (String tag : lexicon.getAllTags()) {
          UnaryRule A = new UnaryRule(tag, word);
          A.setScore(Math.log(lexicon.scoreTagging(word, tag)));
          scores.get(i).get(i + 1).put(A, A.getScore());
          backs.get(i).get(i + 1).put(A, null);
        }

        // System.out.println("Starting unaries: i=" + i + ",n=" + n );
        // Handle unaries
        boolean added = true;
        while (added) {
          added = false;
          Map<Object, Double> A_scores = scores.get(i).get(i + 1);
          // Don't modify the map while iterating over it.
          List<Object> A_keys = copyKeys(A_scores);
          // for (int j = 0; j < 5 && j < A_keys.size(); j++) {
          //	System.out.print("," + j + "=" + A_scores.get(A_keys.get(j)));
          // }

          for (Object oB : A_keys) {
            UnaryRule B = (UnaryRule) oB;
            for (UnaryRule A : grammar.getUnaryRulesByChild(B.getParent())) {
              double prob = Math.log(A.getScore()) + A_scores.get(B);
              // Skip entries whose log-probability falls below an arbitrary floor.
              if (prob > -1000.0) {

                if (!A_scores.containsKey(A) || prob > A_scores.get(A)) {
                  // System.out.print(" *A=" + A + ", B=" + B);
                  // System.out.print(",  prob=" +  prob);
                  // System.out.println(",  A_scores.get(A)=" +  A_scores.get(A));
                  A_scores.put(A, prob);
                  backs.get(i).get(i + 1).put(A, new Triplet<Integer, Object, Object>(-1, B, null));
                  added = true;
                }
                // System.out.println(", added=" + added);
              }
            }
          }
          // System.out.println(", A_scores=" + A_scores.size() + ", added=" + added);
        }
      }

      // printChart(scores, backs, "scores with Lexicon");

      // Fill in the larger spans. Naming follows the rule form A -> B C.

      for (int span = 2; span < n + 1; span++) {

        for (int begin = 0; begin < n + 1 - span; begin++) {
          int end = begin + span;

          Map<Object, Double> A_scores = scores.get(begin).get(end);
          Map<Object, Triplet<Integer, Object, Object>> A_backs = backs.get(begin).get(end);

          for (int split = begin + 1; split < end; split++) {

            Map<Object, Double> B_scores = scores.get(begin).get(split);
            Map<Object, Double> C_scores = scores.get(split).get(end);

            List<Object> B_list = new ArrayList<Object>(B_scores.keySet());
            List<Object> C_list = new ArrayList<Object>(C_scores.keySet());

            // Key optimization: index the C cells by parent label so the loop over
            // B below avoids a B_list.size() x C_list.size() search.
            Map<String, List<Object>> C_map = new HashMap<String, List<Object>>();
            for (Object C : C_list) {
              String parent = getParent(C);
              if (!C_map.containsKey(parent)) {
                C_map.put(parent, new ArrayList<Object>());
              }
              C_map.get(parent).add(C);
            }

            for (Object B : B_list) {
              for (BinaryRule A : grammar.getBinaryRulesByLeftChild(getParent(B))) {
                if (C_map.containsKey(A.getRightChild())) {
                  for (Object C : C_map.get(A.getRightChild())) {
                    // We now have A which has B as left child and C as right child
                    double prob = Math.log(A.getScore()) + B_scores.get(B) + C_scores.get(C);
                    if (!A_scores.containsKey(A) || prob > A_scores.get(A)) {
                      A_scores.put(A, prob);
                      A_backs.put(A, new Triplet<Integer, Object, Object>(split, B, C));
                    }
                  }
                }
              }
            }
          }

          // Handle unaries: A -> B
          boolean added = true;
          while (added) {
            added = false;
            // Don't modify the map while iterating over it.
            List<Object> A_keys = copyKeys(A_scores);
            for (Object oB : A_keys) {
              for (UnaryRule A : grammar.getUnaryRulesByChild(getParent(oB))) {
                double prob = Math.log(A.getScore()) + A_scores.get(oB);
                if (!A_scores.containsKey(A) || prob > A_scores.get(A)) {
                  A_scores.put(A, prob);
                  A_backs.put(A, new Triplet<Integer, Object, Object>(-1, oB, null));
                  added = true;
                }
              }
            }
          }
        }
      }

      // printChart(scores, backs, "scores with Lexicon and Grammar");

      Map<Object, Double> topOfChart = scores.get(0).get(n);

      System.out.println("topOfChart: " + topOfChart.size());
      /*
      for (Object o: topOfChart.keySet()) {
          System.out.println("o=" + o + ", score=" + topOfChart.getCount(o));
      }
      */

      // Every complete parse should have "ROOT" at the top of the tree. Also keep
      // the best entry overall as a fallback in case no ROOT entry was built.
      Object bestKey = null;
      Object fallbackKey = null;
      double bestScore = Double.NEGATIVE_INFINITY;
      double fallbackScore = Double.NEGATIVE_INFINITY;
      for (Object key : topOfChart.keySet()) {
        double score = topOfChart.get(key);
        if (score >= fallbackScore || fallbackKey == null) {
          fallbackKey = key;
          fallbackScore = score;
        }
        if ("ROOT".equals(getParent(key)) && (score >= bestScore || bestKey == null)) {
          bestKey = key;
          bestScore = score;
        }
      }

      if (bestKey == null) {
        bestKey = fallbackKey;
        System.out.println("fallbackKey=" + fallbackKey);
      }
      if (bestKey == null) {
        for (Object key : topOfChart.keySet()) {
          System.out.println("val=" + topOfChart.get(key) + ", key=" + key);
        }
      }
      System.out.println("bestKey=" + bestKey + ", log(prob)=" + topOfChart.get(bestKey));

      Tree<String> result = makeTree(backs, 0, n, bestKey);
      if (!"ROOT".equals(result.getLabel())) {
        List<Tree<String>> children = new ArrayList<Tree<String>>();
        children.add(result);
        result = new Tree<String>("ROOT", children); // !@#$
      }

      /*
      System.out.println("==================================================");
      System.out.println(result);
      System.out.println("====================^^^^^^========================");
      */
      return TreeAnnotations.unAnnotateTree(result);
    }
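
Example No. 3 relies on three helpers that are not shown: getParent, copyKeys, and makeTree. The sketch below is inferred purely from their call sites (chart keys are UnaryRule or BinaryRule objects, a back pointer whose first element is -1 marks a unary step, and a null back pointer marks a lexical tag); the signatures, the Triplet accessors, and UnaryRule.getChild() are assumptions rather than part of the original code.

    // Hypothetical helpers for Example No. 3, reconstructed from their call sites.
    private String getParent(Object rule) {
      // Chart keys are grammar rules; the constituent label is the rule's parent.
      if (rule instanceof BinaryRule) {
        return ((BinaryRule) rule).getParent();
      }
      return ((UnaryRule) rule).getParent();
    }

    private List<Object> copyKeys(Map<Object, Double> scores) {
      // Snapshot the key set so the unary-closure loops can modify the map safely.
      return new ArrayList<Object>(scores.keySet());
    }

    private Tree<String> makeTree(
        List<List<Map<Object, Triplet<Integer, Object, Object>>>> backs,
        int begin, int end, Object key) {
      Triplet<Integer, Object, Object> back = backs.get(begin).get(end).get(key);
      String label = getParent(key);
      if (back == null) {
        // Lexical entry: the key is the UnaryRule(tag, word) seeded from the lexicon.
        List<Tree<String>> word = new ArrayList<Tree<String>>();
        word.add(new Tree<String>(((UnaryRule) key).getChild()));
        return new Tree<String>(label, word);
      }
      List<Tree<String>> children = new ArrayList<Tree<String>>();
      if (back.getFirst() == -1) {
        // Unary step: same span, child rule stored in the second slot.
        children.add(makeTree(backs, begin, end, back.getSecond()));
      } else {
        // Binary step: recurse on both halves around the stored split point.
        int split = back.getFirst();
        children.add(makeTree(backs, begin, split, back.getSecond()));
        children.add(makeTree(backs, split, end, back.getThird()));
      }
      return new Tree<String>(label, children);
    }
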
Example No. 4
    public Tree<String> getBestParseOld(List<String> sentence) {
      // An earlier CKY implementation that stores raw probabilities (not log
      // probabilities) in a CounterMap keyed by "begin end" span strings.

      CounterMap<String, String> parseScores = new CounterMap<String, String>();

      System.out.println(sentence.toString());
      // First fill the span-1 cells from the lexicon.
      int index = 0;
      int span = 1; // All spans are 1 at the lexicon level
      for (String word : sentence) {
        for (String tag : lexicon.getAllTags()) {
          double score = lexicon.scoreTagging(word, tag);
          if (score >= 0.0) { // the lexicon can tag this word
            // Store the tag score in the CounterMap, keyed by the "begin end" span.
            parseScores.setCount(index + " " + (index + span), tag, score);
          }
        }
        index = index + 1;
      }

      // Handle unary rules now.
      // HashMap to store back pointers, keyed by "begin end parent".
      HashMap<String, Triplet<Integer, String, String>> backHash =
          new HashMap<String, Triplet<Integer, String, String>>();

      // System.out.println("Lexicons found");
      boolean added = true;

      while (added) {
        added = false;
        for (index = 0; index < sentence.size(); index++) {
          // For each (index, index + span) cell, get its counter.
          Counter<String> count = parseScores.getCounter(index + " " + (index + span));
          PriorityQueue<String> countAsPQ = count.asPriorityQueue();
          while (countAsPQ.hasNext()) {
            String entry = countAsPQ.next();
            // System.out.println("I am fine here!!");
            List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
            for (UnaryRule rule : unaryRules) {
              // These are the unary rules which might give rise to the above preterminal
              double prob =
                  rule.getScore() * parseScores.getCount(index + " " + (index + span), entry);
              if (prob > parseScores.getCount(index + " " + (index + span), rule.parent)) {
                parseScores.setCount(index + " " + (index + span), rule.parent, prob);
                backHash.put(
                    index + " " + (index + span) + " " + rule.parent,
                    new Triplet<Integer, String, String>(-1, entry, null));
                added = true;
              }
            }
          }
        }
      }
      // System.out.println("Lexicon unaries dealt with");

      // Now work with the grammar to produce higher level probabilities
      for (span = 2; span <= sentence.size(); span++) {
        for (int begin = 0; begin <= (sentence.size() - span); begin++) {
          int end = begin + span;
          for (int split = begin + 1; split <= end - 1; split++) {
            Counter<String> countLeft = parseScores.getCounter(begin + " " + split);
            Counter<String> countRight = parseScores.getCounter(split + " " + end);
            // Rules reachable from the left and right cells, keyed by rule hash code
            // so the intersection below finds rules consistent with both children.
            HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>();
            HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>();

            for (String entry : countLeft.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) {
                if (!leftMap.containsKey(rule.hashCode())) {
                  leftMap.put(rule.hashCode(), rule);
                }
              }
            }

            for (String entry : countRight.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) {
                if (!rightMap.containsKey(rule.hashCode())) {
                  rightMap.put(rule.hashCode(), rule);
                }
              }
            }

            // System.out.println("About to enter the rules loops");
            for (Integer ruleHash : leftMap.keySet()) {
              if (rightMap.containsKey(ruleHash)) {
                BinaryRule ruleRight = rightMap.get(ruleHash);
                double prob =
                    ruleRight.getScore()
                        * parseScores.getCount(begin + " " + split, ruleRight.leftChild)
                        * parseScores.getCount(split + " " + end, ruleRight.rightChild);
                // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob);
                if (prob > parseScores.getCount(begin + " " + end, ruleRight.parent)) {
                  // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob);
                  // System.out.println("parentrule :"+ ruleRight.getParent());
                  parseScores.setCount(begin + " " + end, ruleRight.getParent(), prob);
                  backHash.put(
                      begin + " " + end + " " + ruleRight.parent,
                      new Triplet<Integer, String, String>(
                          split, ruleRight.leftChild, ruleRight.rightChild));
                }
              }
            }

            // System.out.println("Exited rules loop");

          }
          // System.out.println("Grammar found for " + begin + " "+ end);
          // Now handle unary rules
          added = true;
          while (added) {
            added = false;
            Counter<String> count = parseScores.getCounter(begin + " " + end);
            PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue();
            while (countAsPriorityQueue.hasNext()) {
              String entry = countAsPriorityQueue.next();
              List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
              for (UnaryRule rule : unaryRules) {
                double prob = rule.getScore() * parseScores.getCount(begin + " " + (end), entry);
                if (prob > parseScores.getCount(begin + " " + (end), rule.parent)) {
                  parseScores.setCount(begin + " " + (end), rule.parent, prob);

                  backHash.put(
                      begin + " " + (end) + " " + rule.parent,
                      new Triplet<Integer, String, String>(-1, entry, null));
                  added = true;
                }
              }
            }
          }

          // System.out.println("Unaries dealt for " + begin + " "+ end);

        }
      }

      // Build and return the parse tree from the back pointers.
      // The chart argmax over the full span is computed only as a sanity check;
      // the tree is always read out from "ROOT", since training trees are rooted there.
      String parent = parseScores.getCounter(0 + " " + sentence.size()).argMax();
      if (parent == null) {
        System.out.println(parseScores.getCounter(0 + " " + sentence.size()).toString());
        System.out.println("No entry found for the full span -- this should not happen.");
      }
      parent = "ROOT";
      Tree<String> parseTree = getParseTreeOld(sentence, backHash, 0, sentence.size(), parent);
      // System.out.println("PARSE SCORES");
      //	System.out.println(parseScores.toString());
      // System.out.println("BACK HASH");
      // System.out.println(backHash.toString());
      //	parseTree = addRoot(parseTree);
      // System.out.println(parseTree.toString());
      // return parseTree;
      return TreeAnnotations.unAnnotateTree(parseTree);
    }
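
The getParseTreeOld helper called at the end of Example No. 4 is not shown. Below is a sketch of how it could look, inferred only from how backHash is populated above ("begin end parent" string keys, a split of -1 for unary steps, and no entry at all for preterminal tags); the signature and the Triplet accessors are assumptions.

    // Hypothetical back-pointer readout for Example No. 4; the body is inferred from
    // how backHash is filled in getBestParseOld, not taken from the original code.
    private Tree<String> getParseTreeOld(
        List<String> sentence,
        HashMap<String, Triplet<Integer, String, String>> backHash,
        int begin, int end, String label) {
      Triplet<Integer, String, String> back = backHash.get(begin + " " + end + " " + label);
      if (back == null) {
        // No back pointer: label is a preterminal tag over the single word at 'begin'.
        List<Tree<String>> word = new ArrayList<Tree<String>>();
        word.add(new Tree<String>(sentence.get(begin)));
        return new Tree<String>(label, word);
      }
      List<Tree<String>> children = new ArrayList<Tree<String>>();
      if (back.getFirst() == -1) {
        // Unary step: same span, child label in the second slot.
        children.add(getParseTreeOld(sentence, backHash, begin, end, back.getSecond()));
      } else {
        // Binary step: recurse on the two halves around the stored split point.
        int split = back.getFirst();
        children.add(getParseTreeOld(sentence, backHash, begin, split, back.getSecond()));
        children.add(getParseTreeOld(sentence, backHash, split, end, back.getThird()));
      }
      return new Tree<String>(label, children);
    }
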
Example No. 5
    public Tree<String> getBestParse(List<String> sentence) {
      // This implements the CKY algorithm
      int nEntries = sentence.size();

      // HashMap to store back pointers, keyed by (begin, end, parent).
      HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>> backHash =
          new HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>>();

      // Arrays give faster access than a CounterMap; the unchecked cast is needed
      // because Java cannot create generic arrays directly.
      @SuppressWarnings("unchecked")
      Counter<String>[][] parseScores = (Counter<String>[][]) (new Counter[nEntries][nEntries]);

      for (int i = 0; i < nEntries; i++) {
        for (int j = 0; j < nEntries; j++) {
          parseScores[i][j] = new Counter<String>();
        }
      }

      System.out.println(sentence.toString());
      // First fill the span-1 cells from the lexicon.
      int index = 0;
      int span = 1; // All spans are 1 at the lexicon level
      for (String word : sentence) {
        for (String tag : lexicon.getAllTags()) {
          double score = lexicon.scoreTagging(word, tag);
          if (score >= 0.0) { // the lexicon can tag this word
            // Store the tag score in the span-1 cell of the chart.
            parseScores[index][index + span - 1].setCount(tag, score);
          }
        }
        index = index + 1;
      }

      // handle unary rules now

      // System.out.println("Lexicons found");
      boolean added = true;

      while (added) {
        added = false;
        for (index = 0; index < sentence.size(); index++) {
          // For each (index, index + span) cell, get its counter.
          Counter<String> count = parseScores[index][index + span - 1];
          PriorityQueue<String> countAsPQ = count.asPriorityQueue();
          while (countAsPQ.hasNext()) {
            String entry = countAsPQ.next();
            // System.out.println("I am fine here!!");
            List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
            for (UnaryRule rule : unaryRules) {
              // These are the unary rules which might give rise to the above preterminal
              double prob = rule.getScore() * parseScores[index][index + span - 1].getCount(entry);
              if (prob > parseScores[index][index + span - 1].getCount(rule.parent)) {
                parseScores[index][index + span - 1].setCount(rule.parent, prob);
                backHash.put(
                    new Triplet<Integer, Integer, String>(index, index + span, rule.parent),
                    new Triplet<Integer, String, String>(-1, entry, null));
                added = true;
              }
            }
          }
        }
      }
      // System.out.println("Lexicon unaries dealt with");

      // Now work with the grammar to produce higher level probabilities
      for (span = 2; span <= sentence.size(); span++) {
        for (int begin = 0; begin <= (sentence.size() - span); begin++) {
          int end = begin + span;
          for (int split = begin + 1; split <= end - 1; split++) {
            Counter<String> countLeft = parseScores[begin][split - 1];
            Counter<String> countRight = parseScores[split][end - 1];
            // Rules reachable from the left and right cells, keyed by rule hash code
            // so the intersection below finds rules consistent with both children.
            HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>();
            HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>();

            for (String entry : countLeft.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) {
                if (!leftMap.containsKey(rule.hashCode())) {
                  leftMap.put(rule.hashCode(), rule);
                }
              }
            }

            for (String entry : countRight.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) {
                if (!rightMap.containsKey(rule.hashCode())) {
                  rightMap.put(rule.hashCode(), rule);
                }
              }
            }

            // System.out.println("About to enter the rules loops");
            for (Integer ruleHash : leftMap.keySet()) {
              if (rightMap.containsKey(ruleHash)) {
                BinaryRule ruleRight = rightMap.get(ruleHash);
                double prob =
                    ruleRight.getScore()
                        * parseScores[begin][split - 1].getCount(ruleRight.leftChild)
                        * parseScores[split][end - 1].getCount(ruleRight.rightChild);
                // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob);
                if (prob > parseScores[begin][end - 1].getCount(ruleRight.parent)) {
                  // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob);
                  // System.out.println("parentrule :"+ ruleRight.getParent());
                  parseScores[begin][end - 1].setCount(ruleRight.getParent(), prob);
                  backHash.put(
                      new Triplet<Integer, Integer, String>(begin, end, ruleRight.parent),
                      new Triplet<Integer, String, String>(
                          split, ruleRight.leftChild, ruleRight.rightChild));
                }
              }
            }

            // System.out.println("Exited rules loop");

          }
          // System.out.println("Grammar found for " + begin + " "+ end);
          // Now handle unary rules
          added = true;
          while (added) {
            added = false;
            Counter<String> count = parseScores[begin][end - 1];
            PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue();
            while (countAsPriorityQueue.hasNext()) {
              String entry = countAsPriorityQueue.next();
              List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
              for (UnaryRule rule : unaryRules) {
                double prob = rule.getScore() * parseScores[begin][end - 1].getCount(entry);
                if (prob > parseScores[begin][end - 1].getCount(rule.parent)) {
                  parseScores[begin][end - 1].setCount(rule.parent, prob);

                  backHash.put(
                      new Triplet<Integer, Integer, String>(begin, end, rule.parent),
                      new Triplet<Integer, String, String>(-1, entry, null));
                  added = true;
                }
              }
            }
          }

          // System.out.println("Unaries dealt for " + begin + " "+ end);

        }
      }

      // Build and return the parse tree from the back pointers.

      // The chart argmax over the full span is available here, but reading the tree
      // out from "ROOT" is preferred, since every sentence is meant to have ROOT as
      // its root node.
      String parent = parseScores[0][nEntries - 1].argMax();
      parent = "ROOT";
      Tree<String> parseTree = getParseTree(sentence, backHash, 0, sentence.size(), parent);
      // System.out.println("PARSE SCORES");
      //	System.out.println(parseScores.toString());
      // System.out.println("BACK HASH");
      // System.out.println(backHash.toString());
      //	parseTree = addRoot(parseTree);
      // System.out.println(parseTree.toString());
      // return parseTree;
      return TreeAnnotations.unAnnotateTree(parseTree);
    }
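
The getParseTree helper called here would mirror the sketch shown after Example No. 4; the only difference is that backHash is keyed by a Triplet rather than a space-separated string, so the back-pointer lookup becomes (assuming Triplet implements equals and hashCode):

      // Hypothetical lookup matching how backHash is populated in Example No. 5.
      Triplet<Integer, String, String> back =
          backHash.get(new Triplet<Integer, Integer, String>(begin, end, label));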