Example #1
 /*
  * All-pairs shortest paths via iterated min-plus matrix multiplication.
  * Each min-plus update that shortens a path enqueues the affected pair
  * for the next round of relaxations.
  */
 public SparseMatrix apsp() {
   SparseMatrix shortestPaths = new SparseMatrix(this);
   SparseMatrix currentPairs;
   SparseMatrix newPairs = new SparseMatrix(this);
   for (int d = 0; d < this.rowDim; d++) {
     shortestPaths.set(d, d, 0.0);
     newPairs.set(d, d, 0.0);
   }
   while (!newPairs.isEmpty()) {
     currentPairs = new SparseMatrix(newPairs);
     newPairs = new SparseMatrix(this.rowDim, this.colDim);
     for (int r : currentPairs.rows) {
       Counter row = currentPairs.getRow(r);
       for (int c : row.keySet()) {
         Counter oRow = this.getRow(c);
         for (int oc : oRow.keySet()) {
           double pathLength = currentPairs.get(r, c) + oRow.get(oc);
           if (pathLength < shortestPaths.getPath(r, oc)) {
             newPairs.set(r, oc, pathLength);
             shortestPaths.set(r, oc, pathLength);
           }
         }
       }
     }
   }
   return shortestPaths;
 }
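
A minimal usage sketch for apsp(), assuming the SparseMatrix constructor and the set/get accessors used throughout these examples; the 4-node graph and its weights are invented for illustration.

  SparseMatrix graph = new SparseMatrix(4, 4);
  graph.set(0, 1, 2.0); // edge 0 -> 1 with weight 2
  graph.set(1, 2, 1.0);
  graph.set(0, 2, 5.0);
  graph.set(2, 3, 3.0);
  SparseMatrix dist = graph.apsp();
  // dist.get(0, 2) is now 3.0 (via 0 -> 1 -> 2), beating the direct 5.0 edge,
  // and dist.get(0, 3) is 6.0 (via 0 -> 1 -> 2 -> 3).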
Example #2
 /*
  * Returns shortest-path distances from initNode, stopping early once
  * terminalNode has been reached; recovering the actual path would need a
  * separate backtracking structure, which this method does not build.
  */
 public Counter bfs(int initNode, int terminalNode) {
   Counter shortestPaths = new Counter();
   Counter currentNodes = new Counter();
   Counter newNodes = new Counter();
   shortestPaths.add(initNode, 0);
   newNodes.add(initNode, 0);
   while (!newNodes.isEmpty()) {
     currentNodes = new Counter(newNodes);
     newNodes = new Counter();
     for (int r : currentNodes.keySet()) {
       Counter row = this.getRow(r);
       for (int c : row.keySet()) {
         // Weighted case: later paths can be shorter than first path
         double pathLength = currentNodes.get(r) + row.get(c);
         if (pathLength < shortestPaths.getPath(c)) {
           newNodes.put(c, pathLength);
           shortestPaths.put(c, pathLength);
         }
       }
     }
      // NOTE: this exits as soon as *some* path reaches the terminal; on a
      // weighted graph a later iteration could still shorten that path.
      if (shortestPaths.getPath(terminalNode) < Double.MAX_VALUE / 2.0) {
       break;
     }
   }
   return shortestPaths;
 }
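
A companion sketch on the same invented graph, assuming getPath returns a large sentinel for unreached nodes, as the loop's comparison implies. It also illustrates the early-exit caveat noted above.

  Counter dists = graph.bfs(0, 3);
  double d = dists.getPath(3);
  // d is 8.0 here (first discovered path 0 -> 2 -> 3): the early break fires
  // before the shorter path 0 -> 1 -> 2 -> 3 of length 6.0 can be relaxed.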
Example #3
File: UnigramLM.java (project: hltcoe/parma)
  /**
   * GT smoothing with least squares interpolation. This follows the procedure in Jurafsky and
   * Martin sect. 4.5.3.
   */
  public void smoothAndNormalize() {
    Counter<Integer> cntCounter = new Counter<Integer>();
    for (K tok : lm.keySet()) {
      int cnt = (int) lm.getCount(tok);
      cntCounter.incrementCount(cnt);
    }

    final double[] coeffs = runLogSpaceRegression(cntCounter);

    UNK_PROB = cntCounter.getCount(1) / lm.totalCount();

    for (K tok : lm.keySet()) {
      double tokCnt = lm.getCount(tok);
      if (tokCnt <= unkCutoff) { // Treat as unknown
        unkTokens.add(tok);
      }
      if (tokCnt <= kCutoff) { // Smooth
        double cSmooth = katzEstimate(cntCounter, tokCnt, coeffs);
        lm.setCount(tok, cSmooth);
      }
    }

    // No explicit normalization needed: this Counter implementation keeps
    // itself normalized, so Counters.normalize(lm) is not called here.
  }
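
For reference, the Good-Turing discount applied above replaces a raw count c with c* = (c + 1) * N(c + 1) / N(c), where N(c) is the number of types seen exactly c times; runLogSpaceRegression (Example #16 below) fits log N(c) against c so that a smoothed, strictly positive N(c) is available even where the raw counts-of-counts are zero, which is the least-squares interpolation the Jurafsky and Martin reference describes.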
Example #4
File: UnigramLM.java (project: hltcoe/parma)
 public double getCount(K token) {
    if (!lm.keySet().contains(token)) {
      throw new RuntimeException(
          "token not in keyset: " + token + " (vocab size " + lm.keySet().size() + ")");
    }
   return lm.getCount(token);
 }
Example #5
 public static Counter PersonalizedPageRank(Counter seedVector, SparseMatrix transitionMat) {
   double trimEps = 0.0;
   int iterLimit = 200;
   transitionMat = transitionMat.stochasticizeRows();
   double beta = 0.85;
   long start;
   long end;
   Counter vector = new Counter(seedVector);
   double svCard = 0;
   for (int i : seedVector.keySet()) {
     svCard += seedVector.get(i);
   }
   Counter oldVector = new Counter(vector);
   vector = transitionMat.multiply(vector);
   double diff = 1;
   int t = 0;
   start = System.currentTimeMillis();
    // Iterate until the squared L2 change is negligible; the window around
    // 4.0 additionally stops a sign-flipping oscillation of a unit iterate
    // (the same test appears in TopEig, Example #17).
    while (diff > 1e-10 && (diff < 3.99999 || diff > 4.00001) && t < iterLimit) {
     t += 1;
     vector.trimKeys(trimEps);
     vector = transitionMat.multiply(vector);
     Set<Integer> vecSeedUnion = vector.concreteKeySet();
     vecSeedUnion.addAll(seedVector.concreteKeySet());
      for (int i : vecSeedUnion) {
        vector.set(i, beta * vector.get(i) + (1 - beta) * seedVector.get(i) / svCard);
      }
      double norm = 0; // L1 mass of the iterate, used only for the printout below
      for (int i : vector.keySet()) {
        norm += vector.get(i);
      }
     diff = 0;
     Set<Integer> vecOldUnion = vector.concreteKeySet();
     vecOldUnion.addAll(oldVector.concreteKeySet());
     for (int i : vecOldUnion) {
       diff += (oldVector.get(i) - vector.get(i)) * (oldVector.get(i) - vector.get(i));
     }
      System.out.println(diff + " " + norm);
     oldVector = new Counter(vector);
   }
   end = System.currentTimeMillis();
   System.out.println("Time: " + (end - start) + " iterations: " + t);
   return vector;
 }
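
A minimal invocation sketch under the same assumed Counter/SparseMatrix API as the earlier examples: the seed Counter is the restart distribution, and with beta = 0.85 the walk teleports back to it with probability 0.15 per step. Rows are re-normalized internally via stochasticizeRows(), so an unnormalized adjacency matrix is acceptable input.

  Counter seed = new Counter();
  seed.add(0, 1.0); // restart mass concentrated on node 0
  Counter ppr = PersonalizedPageRank(seed, graph);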
Example #6
 /**
  * Builds a Trellis over a sentence by starting at the start state and
  * advancing through all legal extensions of each state already in the
  * trellis. You should not have to modify this code (or even read it,
  * really).
  */
 private Trellis<State> buildTrellis(List<String> sentence) {
   Trellis<State> trellis = new Trellis<State>();
   trellis.setStartState(State.getStartState());
   State stopState = State.getStopState(sentence.size() + 2);
   trellis.setStopState(stopState);
   Set<State> states = Collections.singleton(State.getStartState());
   for (int position = 0; position <= sentence.size() + 1; position++) {
     Set<State> nextStates = new HashSet<State>();
     for (State state : states) {
       if (state.equals(stopState)) continue;
       LocalTrigramContext localTrigramContext =
           new LocalTrigramContext(
               sentence, position, state.getPreviousPreviousTag(), state.getPreviousTag());
       Counter<String> tagScores = localTrigramScorer.getLogScoreCounter(localTrigramContext);
       for (String tag : tagScores.keySet()) {
         double score = tagScores.getCount(tag);
         State nextState = state.getNextState(tag);
         trellis.setTransitionCount(state, nextState, score);
         nextStates.add(nextState);
       }
     }
      states = nextStates;
   }
   return trellis;
 }
Example #7
 /*
  * Takes a set of sketch nodes and returns an ArrayList<Integer> such that
  * arr.get(i) gives the slot of the sketch node that node i is closest to.
  *
  * The return values still need work: a proper data structure should bundle
  * the reverse index, the forward index, and an id-to-slot bimap.
  */
 public ArrayList<ArrayList<Integer>> distSketch(int len, Counter sketchNodes) {
   ArrayList<Integer> closestIndex = new ArrayList<Integer>();
   for (int i = 0; i < len; i++) closestIndex.add(-1); // set() on an empty list would throw
   ArrayList<Double> closestDist = new ArrayList<Double>();
   for (int i = 0; i < len; i++) closestDist.add(Double.MAX_VALUE);
   ArrayList<ArrayList<Integer>> sketchReverseIndex = new ArrayList<ArrayList<Integer>>();
   int slot = 0; // consecutive slot per sketch node, parallel to sketchReverseIndex
   for (int index : sketchNodes.keySet()) {
     Counter distances = this.bfs(index);
     for (int j = 0; j < len; j++) {
       double dist = distances.getPath(j);
       if (dist < closestDist.get(j)) {
         closestDist.set(j, dist);
         closestIndex.set(j, slot);
       }
     }
     sketchReverseIndex.add(new ArrayList<Integer>());
     slot++;
   }
   for (int j = 0; j < len; j++) {
     int closest = closestIndex.get(j);
     if (closest >= 0) { // nodes that reached no sketch node keep the -1 sentinel
       sketchReverseIndex.get(closest).add(j);
     }
   }
   // Return sketchReverseIndex, closestIndex forward index, and index
   // correspondence bimap
   return sketchReverseIndex;
 }
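
With the repairs above, sketchReverseIndex.get(s) lists every node whose nearest sketch node occupies slot s of sketchNodes.keySet() (in iteration order), closestIndex is the forward map, and unreachable nodes keep the -1 sentinel. As the header comment says, these belong together in a proper return type, ideally with an explicit id-to-slot bimap.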
Example #8
 public void removeEntries(SparseMatrix redundant) {
   for (int r : redundant.getRows()) {
     Counter row = redundant.getRow(r);
     for (int c : row.keySet()) {
       this.remove(r, c);
     }
   }
 }
Example #9
 public SparseMatrix transpose() {
    SparseMatrix transp = new SparseMatrix(this.colDim, this.rowDim); // dims swap under transposition
   for (int r : rows) {
     Counter row = this.getRow(r);
     for (int c : row.keySet()) {
       double v = row.get(c);
       transp.set(c, r, v);
     }
   }
   return transp;
 }
Example #10
 public SparseMatrix makeLaplacian() {
   SparseMatrix laplacian = new SparseMatrix(this.rowDim, this.colDim);
   for (int r : this.getRows()) {
      Counter row = this.getRow(r);
      for (int c : row.keySet()) {
        laplacian.set(r, c, -1 * row.get(c));
      }
      // Set the degree entry last so a self-loop cannot overwrite it.
      laplacian.set(r, r, row.sum());
   }
   return laplacian;
 }
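
This assembles the combinatorial graph Laplacian L = D - A row by row: the diagonal carries the weighted degree row.sum() and each off-diagonal entry is a negated edge weight, so every row of L sums to zero for nodes without self-loops.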
Example #11
 public SparseMatrix multiply(SparseMatrix other) {
   SparseMatrix mult = new SparseMatrix(this.rowDim, other.colDim);
   for (int r : rows) {
     // System.out.println("multiplying row: "+ r);
     Counter row = this.getRow(r);
     // for(int c: other.cols){
     // Counter col = other.getCol(c);
     // double dotProd = row.dot(col);
     // System.out.println(row.toString()+" "+col.toString()+" "+dotProd);
     // mult.set(r, c, dotProd);
     // }
     for (int c : row.keySet()) {
       Counter oRow = other.getRow(c);
       for (int oc : oRow.keySet()) {
         mult.add(r, oc, row.get(c) * oRow.get(oc));
       }
     }
   }
   return mult;
 }
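
The live loop is the standard row-oriented sparse product: each nonzero this[r][c] scales row c of other and accumulates into row r of the result via mult.add, so the work is proportional to the number of contributing nonzero pairs rather than to the dense rowDim * colDim * other.colDim bound. A hedged use, continuing the invented graph above:

  // graph2.get(r, c) sums graph.get(r, k) * graph.get(k, c) over k,
  // i.e. it scores weighted two-step paths.
  SparseMatrix graph2 = graph.multiply(graph);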
Example #12
  /* Builds a PCFG, using the observed counts of binary and unary
   * productions in the training trees to estimate the probabilities
   * for those rules.
   */
 public Grammar(List<Tree<String>> trainTrees) {
   Counter<UnaryRule> unaryRuleCounter = new Counter<UnaryRule>();
   Counter<BinaryRule> binaryRuleCounter = new Counter<BinaryRule>();
   Counter<String> symbolCounter = new Counter<String>();
   for (Tree<String> trainTree : trainTrees) {
     tallyTree(trainTree, symbolCounter, unaryRuleCounter, binaryRuleCounter);
   }
   for (UnaryRule unaryRule : unaryRuleCounter.keySet()) {
     double unaryProbability =
         unaryRuleCounter.getCount(unaryRule) / symbolCounter.getCount(unaryRule.getParent());
     unaryRule.setScore(unaryProbability);
     addUnary(unaryRule);
   }
   for (BinaryRule binaryRule : binaryRuleCounter.keySet()) {
     double binaryProbability =
         binaryRuleCounter.getCount(binaryRule) / symbolCounter.getCount(binaryRule.getParent());
     binaryRule.setScore(binaryProbability);
     addBinary(binaryRule);
   }
 }
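
The scores assigned here are relative-frequency (maximum-likelihood) estimates, e.g. P(A -> B C) = count(A -> B C) / count(A) for a binary rule, so the probabilities of all rules rewriting the same parent symbol sum to one.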
Example #13
 public void addRow(int r, Counter other) {
   // System.out.println("MSG: added row "+r);
   Counter row = this.getRow(r);
   if (row.isEmpty()) {
     mat.put(r, row);
     rows.add(r);
   }
   for (int c : other.keySet()) {
     cols.add(c);
   }
   row.addAll(other);
 }
Example #14
 public Counter<String> getLogScoreCounter(LocalTrigramContext localTrigramContext) {
   int position = localTrigramContext.getPosition();
   String word = localTrigramContext.getWords().get(position);
   Counter<String> tagCounter = unknownWordTags;
   if (wordsToTags.keySet().contains(word)) {
     tagCounter = wordsToTags.getCounter(word);
   }
   Set<String> allowedFollowingTags =
       allowedFollowingTags(
           tagCounter.keySet(),
           localTrigramContext.getPreviousPreviousTag(),
           localTrigramContext.getPreviousTag());
   Counter<String> logScoreCounter = new Counter<String>();
   for (String tag : tagCounter.keySet()) {
     double logScore = Math.log(tagCounter.getCount(tag));
      if (!restrictTrigrams
          || allowedFollowingTags.isEmpty()
          || allowedFollowingTags.contains(tag)) {
        logScoreCounter.setCount(tag, logScore);
      }
   }
   return logScoreCounter;
 }
Example #15
 public SparseMatrix stochasticizeRows() {
   SparseMatrix stochasticMat = new SparseMatrix(this.rowDim, this.colDim);
   double[] rowSums = new double[this.rowDim];
   for (int r : this.rows) {
     Counter row = this.getRow(r);
     for (int c : row.keySet()) {
       rowSums[r] += row.get(c);
     }
   }
   for (int r : this.rows) {
     Counter row = this.getRow(r);
     for (int c : row.keySet()) {
        double value = 0;
        if (rowSums[r] != 0) {
          value = this.get(r, c) / rowSums[r];
        }
       stochasticMat.set(r, c, value);
     }
   }
   return stochasticMat;
 }
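
Rows with nonzero mass are scaled to sum to one, giving the row-stochastic matrix PersonalizedPageRank above expects; all-zero rows (dangling nodes) are left as zero rows rather than patched with a uniform distribution, so probability mass can drain out through them.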
Example #16
File: UnigramLM.java (project: hltcoe/parma)
  private double[] runLogSpaceRegression(Counter<Integer> cntCounter) {
    SimpleRegression reg = new SimpleRegression();

    for (int cnt : cntCounter.keySet()) {
      reg.addData(cnt, Math.log(cntCounter.getCount(cnt)));
    }


    double[] coeffs = new double[] {reg.getIntercept(), reg.getSlope()};

    return coeffs;
  }
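
SimpleRegression here is the ordinary least-squares fitter from Apache Commons Math, so the returned coeffs model log N(c) as intercept + slope * c; exponentiating that line gives the smoothed counts-of-counts, presumably what katzEstimate consumes together with the raw cntCounter.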
Example #17
 public static Counter TopEig(SparseMatrix mat) {
   double trimEps = 0.0;
   int iterLimit = 1000;
   long start;
   long end;
   Counter vector = new Counter();
   int vecLen = mat.colDim;
    // Uniform start vector; power iteration then drifts toward the dominant
    // eigenvector (up to sign).
    for (int i = 0; i < vecLen; i++) {
     vector.add(i, 1.0 / (double) vecLen);
   }
   Counter oldVector = new Counter(vector);
   vector = mat.multiply(vector);
   double diff = 1;
   int t = 0;
   start = System.currentTimeMillis();
    // Iterate until the squared L2 change is negligible; a diff that settles
    // near 4.0 signals a sign-flipping oscillation of the unit iterate, so it
    // also ends the loop.
    while (diff > 1e-10 && (diff < 3.99999 || diff > 4.00001) && t < iterLimit) {
     t += 1;
     vector.trimKeys(trimEps);
     vector = mat.multiply(vector);
     double norm = 0;
     for (int i : vector.keySet()) {
        norm += vector.get(i) * vector.get(i);
     }
     norm = Math.sqrt(norm);
     vector.multiply(1.0 / norm);
     diff = 0;
     Set<Integer> vecOldUnion = vector.concreteKeySet();
     vecOldUnion.addAll(oldVector.concreteKeySet());
     for (int i : vecOldUnion) {
       diff += (oldVector.get(i) - vector.get(i)) * (oldVector.get(i) - vector.get(i));
     }
     System.out.println(diff + " " + norm);
     oldVector = new Counter(vector);
   }
   end = System.currentTimeMillis();
   System.out.println("Time: " + (end - start) + " iterations: " + t);
   return vector;
 }
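
A hedged usage sketch: TopEig is power iteration, so for a matrix with a unique dominant eigenvalue the returned Counter approximates the dominant eigenvector up to sign, and the norm printed each round (taken after the multiply, before renormalization) tends to that eigenvalue's magnitude.

  Counter eig = TopEig(graph); // approximately the dominant eigenvector, L2-normalized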
Example #18
 public boolean isKnown(String word) {
   return wordCounter.keySet().contains(word);
 }
Example #19
 public Set<String> getAllTags() {
   return tagCounter.keySet();
 }
Example #20
File: UnigramLM.java (project: hltcoe/parma)
 public Set<K> getVocab() {
   return Collections.unmodifiableSet(lm.keySet());
 }
Example #21
    public Tree<String> getBestParse(List<String> sentence) {
      // This implements the CKY algorithm
      int nEntries = sentence.size();

      // hashmap to store backpointers for reconstructing the parse tree
      HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>> backHash =
          new HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>>();

      // more efficient access with arrays, but must cast each time :(
      @SuppressWarnings("unchecked")
      Counter<String>[][] parseScores = (Counter<String>[][]) (new Counter[nEntries][nEntries]);

      for (int i = 0; i < nEntries; i++) {
        for (int j = 0; j < nEntries; j++) {
          parseScores[i][j] = new Counter<String>();
        }
      }

      System.out.println(sentence.toString());
      // First deal with the lexicons
      int index = 0;
      int span = 1; // All spans are 1 at the lexicon level
      for (String word : sentence) {
        for (String tag : lexicon.getAllTags()) {
          double score = lexicon.scoreTagging(word, tag);
          if (score >= 0.0) { // This lexicon may generate this word
            // We use a counter map in order to store the scores for this sentence parse.
            parseScores[index][index + span - 1].setCount(tag, score);
          }
        }
        index = index + 1;
      }

      // handle unary rules now

      // System.out.println("Lexicons found");
      boolean added = true;

      while (added) {
        added = false;
        for (index = 0; index < sentence.size(); index++) {
          // For each index+ span pair, get the counter.
          Counter<String> count = parseScores[index][index + span - 1];
          PriorityQueue<String> countAsPQ = count.asPriorityQueue();
          while (countAsPQ.hasNext()) {
            String entry = countAsPQ.next();
            // System.out.println("I am fine here!!");
            List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
            for (UnaryRule rule : unaryRules) {
              // These are the unary rules which might give rise to the above preterminal
              double prob = rule.getScore() * parseScores[index][index + span - 1].getCount(entry);
              if (prob > parseScores[index][index + span - 1].getCount(rule.parent)) {
                parseScores[index][index + span - 1].setCount(rule.parent, prob);
                backHash.put(
                    new Triplet<Integer, Integer, String>(index, index + span, rule.parent),
                    new Triplet<Integer, String, String>(-1, entry, null));
                added = true;
              }
            }
          }
        }
      }
      // System.out.println("Lexicon unaries dealt with");

      // Now work with the grammar to produce higher level probabilities
      for (span = 2; span <= sentence.size(); span++) {
        for (int begin = 0; begin <= (sentence.size() - span); begin++) {
          int end = begin + span;
          for (int split = begin + 1; split <= end - 1; split++) {
            Counter<String> countLeft = parseScores[begin][split - 1];
            Counter<String> countRight = parseScores[split][end - 1];
            // List<BinaryRule> leftRules= new ArrayList<BinaryRule>();
            HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>();
            // List<BinaryRule> rightRules=new ArrayList<BinaryRule>();
            HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>();

            for (String entry : countLeft.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) {
                if (!leftMap.containsKey(rule.hashCode())) {
                  leftMap.put(rule.hashCode(), rule);
                }
              }
            }

            for (String entry : countRight.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) {
                if (!rightMap.containsKey(rule.hashCode())) {
                  rightMap.put(rule.hashCode(), rule);
                }
              }
            }

            // System.out.println("About to enter the rules loops");
            for (Integer ruleHash : leftMap.keySet()) {
              if (rightMap.containsKey(ruleHash)) {
                BinaryRule ruleRight = rightMap.get(ruleHash);
                double prob =
                    ruleRight.getScore()
                        * parseScores[begin][split - 1].getCount(ruleRight.leftChild)
                        * parseScores[split][end - 1].getCount(ruleRight.rightChild);
                if (prob > parseScores[begin][end - 1].getCount(ruleRight.parent)) {
                  parseScores[begin][end - 1].setCount(ruleRight.getParent(), prob);
                  backHash.put(
                      new Triplet<Integer, Integer, String>(begin, end, ruleRight.parent),
                      new Triplet<Integer, String, String>(
                          split, ruleRight.leftChild, ruleRight.rightChild));
                }
              }
            }

            // System.out.println("Exited rules loop");

          }
          // System.out.println("Grammar found for " + begin + " "+ end);
          // Now handle unary rules
          added = true;
          while (added) {
            added = false;
            Counter<String> count = parseScores[begin][end - 1];
            PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue();
            while (countAsPriorityQueue.hasNext()) {
              String entry = countAsPriorityQueue.next();
              List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
              for (UnaryRule rule : unaryRules) {
                double prob = rule.getScore() * parseScores[begin][end - 1].getCount(entry);
                if (prob > parseScores[begin][end - 1].getCount(rule.parent)) {
                  parseScores[begin][end - 1].setCount(rule.parent, prob);

                  backHash.put(
                      new Triplet<Integer, Integer, String>(begin, end, rule.parent),
                      new Triplet<Integer, String, String>(-1, entry, null));
                  added = true;
                }
              }
            }
          }

          // System.out.println("Unaries dealt for " + begin + " "+ end);

        }
      }

      // Build the parse tree rooted at ROOT: sentences are meant to have ROOT
      // as their root node, so we use it rather than the argmax symbol of the
      // top chart cell parseScores[0][nEntries - 1].
      String parent = "ROOT";
      Tree<String> parseTree = getParseTree(sentence, backHash, 0, sentence.size(), parent);
      return TreeAnnotations.unAnnotateTree(parseTree);
    }
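
Compared with getBestParseOld below, this version indexes the chart as parseScores[begin][end - 1] with one Counter per cell rather than concatenating "begin end" string keys on every lookup, removing string allocation from the innermost loops. The overall shape is textbook probabilistic CKY: lexicon scores fill the diagonal, a split-point loop applies binary rules over spans of increasing length, and a unary-closure pass follows each cell.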
Example #22
    public Tree<String> getBestParseOld(List<String> sentence) {
      // This implements the CKY algorithm (older, string-keyed chart version)

      CounterMap<String, String> parseScores = new CounterMap<String, String>();

      System.out.println(sentence.toString());
      // First deal with the lexicons
      int index = 0;
      int span = 1; // All spans are 1 at the lexicon level
      for (String word : sentence) {
        for (String tag : lexicon.getAllTags()) {
          double score = lexicon.scoreTagging(word, tag);
          if (score >= 0.0) { // This lexicon may generate this word
            // We use a counter map in order to store the scores for this sentence parse.
            parseScores.setCount(index + " " + (index + span), tag, score);
          }
        }
        index = index + 1;
      }

      // handle unary rules now
      // hashmap to store backpointers for reconstructing the parse tree
      HashMap<String, Triplet<Integer, String, String>> backHash =
          new HashMap<String, Triplet<Integer, String, String>>();

      // System.out.println("Lexicons found");
      Boolean added = true;

      while (added) {
        added = false;
        for (index = 0; index < sentence.size(); index++) {
          // For each index+ span pair, get the counter.
          Counter<String> count = parseScores.getCounter(index + " " + (index + span));
          PriorityQueue<String> countAsPQ = count.asPriorityQueue();
          while (countAsPQ.hasNext()) {
            String entry = countAsPQ.next();
            // System.out.println("I am fine here!!");
            List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
            for (UnaryRule rule : unaryRules) {
              // These are the unary rules which might give rise to the above preterminal
              double prob =
                  rule.getScore() * parseScores.getCount(index + " " + (index + span), entry);
              if (prob > parseScores.getCount(index + " " + (index + span), rule.parent)) {
                parseScores.setCount(index + " " + (index + span), rule.parent, prob);
                backHash.put(
                    index + " " + (index + span) + " " + rule.parent,
                    new Triplet<Integer, String, String>(-1, entry, null));
                added = true;
              }
            }
          }
        }
      }
      // System.out.println("Lexicon unaries dealt with");

      // Now work with the grammar to produce higher level probabilities
      for (span = 2; span <= sentence.size(); span++) {
        for (int begin = 0; begin <= (sentence.size() - span); begin++) {
          int end = begin + span;
          for (int split = begin + 1; split <= end - 1; split++) {
            Counter<String> countLeft = parseScores.getCounter(begin + " " + split);
            Counter<String> countRight = parseScores.getCounter(split + " " + end);
            // List<BinaryRule> leftRules= new ArrayList<BinaryRule>();
            HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>();
            // List<BinaryRule> rightRules=new ArrayList<BinaryRule>();
            HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>();

            for (String entry : countLeft.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) {
                if (!leftMap.containsKey(rule.hashCode())) {
                  leftMap.put(rule.hashCode(), rule);
                }
              }
            }

            for (String entry : countRight.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) {
                if (!rightMap.containsKey(rule.hashCode())) {
                  rightMap.put(rule.hashCode(), rule);
                }
              }
            }

            // System.out.println("About to enter the rules loops");
            for (Integer ruleHash : leftMap.keySet()) {
              if (rightMap.containsKey(ruleHash)) {
                BinaryRule ruleRight = rightMap.get(ruleHash);
                double prob =
                    ruleRight.getScore()
                        * parseScores.getCount(begin + " " + split, ruleRight.leftChild)
                        * parseScores.getCount(split + " " + end, ruleRight.rightChild);
                if (prob > parseScores.getCount(begin + " " + end, ruleRight.parent)) {
                  parseScores.setCount(begin + " " + end, ruleRight.getParent(), prob);
                  backHash.put(
                      begin + " " + end + " " + ruleRight.parent,
                      new Triplet<Integer, String, String>(
                          split, ruleRight.leftChild, ruleRight.rightChild));
                }
              }
            }

            // System.out.println("Exited rules loop");

          }
          // System.out.println("Grammar found for " + begin + " "+ end);
          // Now handle unary rules
          added = true;
          while (added) {
            added = false;
            Counter<String> count = parseScores.getCounter(begin + " " + end);
            PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue();
            while (countAsPriorityQueue.hasNext()) {
              String entry = countAsPriorityQueue.next();
              List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
              for (UnaryRule rule : unaryRules) {
                double prob = rule.getScore() * parseScores.getCount(begin + " " + (end), entry);
                if (prob > parseScores.getCount(begin + " " + (end), rule.parent)) {
                  parseScores.setCount(begin + " " + (end), rule.parent, prob);

                  backHash.put(
                      begin + " " + (end) + " " + rule.parent,
                      new Triplet<Integer, String, String>(-1, entry, null));
                  added = true;
                }
              }
            }
          }

          // System.out.println("Unaries dealt for " + begin + " "+ end);

        }
      }

      // Build the parse tree rooted at ROOT: sentences are meant to have ROOT
      // as their root node; the argmax symbol of the top cell is inspected
      // only as a sanity check.
      String parent = parseScores.getCounter(0 + " " + sentence.size()).argMax();
      if (parent == null) {
        System.out.println(parseScores.getCounter(0 + " " + sentence.size()).toString());
        System.out.println("Unexpected: no argmax symbol for the full sentence span");
      }
      parent = "ROOT";
      Tree<String> parseTree = getParseTreeOld(sentence, backHash, 0, sentence.size(), parent);
      return TreeAnnotations.unAnnotateTree(parseTree);
    }
Example #23
File: UnigramLM.java (project: hltcoe/parma)
 public int vocabSize() {
   return lm.keySet().size();
 }