/**
 * Aligns every French word to the English word (or NULL) with the largest translation
 * probability t(f|e). IBM Model 1 treats all alignments as equally likely a priori, so
 * the per-word argmax of t(f|e) is the best alignment.
 *
 * @param sentencePair the sentence pair to align
 * @return an alignment linking each French position to its best English position, or to
 *     -1 when the NULL word scores highest
 */
public Alignment alignSentencePair(SentencePair sentencePair) {
  Alignment result = new Alignment();
  List<String> french = sentencePair.getFrenchWords();
  List<String> english = sentencePair.getEnglishWords();
  int frenchIndex = 0;
  for (String frenchWord : french) {
    // The NULL word is the baseline candidate; it wins ties and is reported as -1.
    int bestEnglishIndex = -1;
    double bestProb = translationProbs.getCount(frenchWord, NULL);
    int englishIndex = 0;
    for (String englishWord : english) {
      double candidate = translationProbs.getCount(frenchWord, englishWord);
      if (candidate > bestProb) {
        bestProb = candidate;
        bestEnglishIndex = englishIndex;
      }
      englishIndex++;
    }
    result.addAlignment(bestEnglishIndex, frenchIndex, true);
    frenchIndex++;
  }
  return result;
}
	  /**
	   * Aligns each French word to the English word maximizing the ratio
	   * c(f,e) / c(e), where c(f,e) is read from collocationCounts and c(e) from eCounts.
	   *
	   * <p>If no English word gives a positive ratio, the French word falls back to the
	   * diagonal English position, or to -1 (NULL) when the English sentence is too short
	   * to have one.
	   */
	  public Alignment alignSentencePair(SentencePair sentencePair) {
	    Alignment alignment = new Alignment();
	    List<String> frenchWords = sentencePair.getFrenchWords();
	    List<String> englishWords = sentencePair.getEnglishWords();
	    int englishLength = englishWords.size();
	    for (int j = 0; j < frenchWords.size(); j++) {
	      String f = frenchWords.get(j);
	      // Fallback target: diagonal position j if it exists, otherwise NULL (-1).
	      int best = j < englishLength ? j : -1;
	      double bestScore = 0;
	      for (int i = 0; i < englishLength; i++) {
	        String e = englishWords.get(i);
	        double score = collocationCounts.getCount(f, e) / (eCounts.getCount(e));
	        if (score > bestScore) {
	          bestScore = score;
	          best = i;
	        }
	      }
	      alignment.addAlignment(best, j, true);
	    }
	    return alignment;
	  }
 /**
  * Joins two subtrees under a new node. The node is labelled with the argmax of
  * spanToCategories for a span equal to the combined yield length of both subtrees.
  */
 private Tree<String> merge(Tree<String> leftTree, Tree<String> rightTree) {
   int combinedLength = leftTree.getYield().size() + rightTree.getYield().size();
   String label = spanToCategories.getCounter(combinedLength).argMax();
   List<Tree<String>> kids = new ArrayList<Tree<String>>();
   kids.add(leftTree);
   kids.add(rightTree);
   return new Tree<String>(label, kids);
 }
	  /**
	   * Computes the Dice coefficient 2 * c(f,e) / (c(f) + c(e)) using sentence-level counts:
	   * c(f,e) from collocationCountSentences, c(f) from fCountSentences and c(e) from
	   * eCountSentences.
	   *
	   * @return the Dice score, or 0.0 when neither word has a count (this avoids 0/0 = NaN)
	   */
	  private double getDiceCoefficient(String f, String e) {
	    double intersection = collocationCountSentences.getCount(f, e);
	    double cardinalityF = fCountSentences.getCount(f);
	    double cardinalityE = eCountSentences.getCount(e);
	    double denominator = cardinalityF + cardinalityE;
	    if (denominator == 0.0) {
	      // Both words are unseen (e.g. at test time): there is no evidence of association.
	      return 0.0;
	    }
	    return 2 * intersection / denominator;
	  }
 /**
  * Records one (word, tag) training observation in the token, tag, word and
  * word-tag counters. When isKnown(word) is false (presumably the word's first
  * occurrence; confirm against isKnown), the tag is also counted per word type.
  */
 private void tallyTagging(String word, String tag) {
   boolean newType = !isKnown(word);
   if (newType) {
     typeTagCounter.incrementCount(tag, 1.0);
     totalWordTypes += 1.0;
   }
   totalTokens += 1.0;
   tagCounter.incrementCount(tag, 1.0);
   wordCounter.incrementCount(word, 1.0);
   wordToTagCounters.incrementCount(word, tag, 1.0);
 }
示例#6
0
 /**
  * Accumulates word-to-tag counts and observed tag trigrams from labeled contexts.
  * Each word's first occurrence also feeds the unknown-word tag counter. Finally,
  * wordsToTags is conditionally normalized and unknownWordTags is normalized.
  */
 public void train(List<LabeledLocalTrigramContext> labeledLocalTrigramContexts) {
   for (LabeledLocalTrigramContext context : labeledLocalTrigramContexts) {
     String currentWord = context.getCurrentWord();
     String currentTag = context.getCurrentTag();
     boolean unseenSoFar = !wordsToTags.keySet().contains(currentWord);
     if (unseenSoFar) {
       // First sighting of this word: its tag approximates tags of unknown words.
       unknownWordTags.incrementCount(currentTag, 1.0);
     }
     wordsToTags.incrementCount(currentWord, currentTag, 1.0);
     String trigram =
         makeTrigramString(
             context.getPreviousPreviousTag(), context.getPreviousTag(), currentTag);
     seenTagTrigrams.add(trigram);
   }
   wordsToTags = Counters.conditionalNormalize(wordsToTags);
   unknownWordTags = Counters.normalize(unknownWordTags);
 }
示例#7
0
 /**
  * Recursively counts, for every non-ROOT constituent, its label against its span
  * length in spanToCategories. Leaves and preterminals are treated as covering one word
  * and are not counted.
  *
  * @param tree the subtree to tally
  * @param start word offset at which {@code tree} begins
  * @return the number of words covered by {@code tree}
  */
 private int tallySpans(Tree<String> tree, int start) {
   if (tree.isLeaf() || tree.isPreTerminal()) {
     return 1;
   }
   int width = 0;
   for (Tree<String> child : tree.getChildren()) {
     width += tallySpans(child, start + width);
   }
   String label = tree.getLabel();
   if (!label.equals("ROOT")) {
     spanToCategories.incrementCount(width, label, 1.0);
   }
   return width;
 }
 /**
  * Returns a smoothed estimate of P(word|tag), obtained via Bayes' rule as
  * P(tag|word) * P(word) / P(tag).
  *
  * <p>Words seen fewer than 10 times receive one extra pseudo-count. The tag mass of
  * that pseudo-count follows typeTagCounter / totalWordTypes.
  */
 public double scoreTagging(String word, String tag) {
   final double tagProbability = tagCounter.getCount(tag) / totalTokens;
   double wordCount = wordCounter.getCount(word);
   double jointCount = wordToTagCounters.getCount(word, tag);
   boolean rareOrUnknown = wordCount < 10;
   if (rareOrUnknown) {
     wordCount += 1.0;
     jointCount += typeTagCounter.getCount(tag) / totalWordTypes;
   }
   // NOTE(review): for rare words wordCount already includes the +1 above, so this
   // numerator adds a second pseudo-count — confirm that is intended.
   double wordProbability = (1.0 + wordCount) / (totalTokens + totalWordTypes);
   double tagGivenWord = jointCount / wordCount;
   return tagGivenWord / tagProbability * wordProbability;
 }
示例#9
0
 /**
  * Scores candidate tags for the current word with the log of its tag weights. Weights
  * come from wordsToTags for known words and from unknownWordTags otherwise.
  *
  * <p>When restrictTrigrams is set and allowedFollowingTags returns a non-empty set,
  * only tags in that set are kept.
  */
 public Counter<String> getLogScoreCounter(LocalTrigramContext localTrigramContext) {
   int position = localTrigramContext.getPosition();
   String word = localTrigramContext.getWords().get(position);
   Counter<String> candidates =
       wordsToTags.keySet().contains(word) ? wordsToTags.getCounter(word) : unknownWordTags;
   Set<String> permitted =
       allowedFollowingTags(
           candidates.keySet(),
           localTrigramContext.getPreviousPreviousTag(),
           localTrigramContext.getPreviousTag());
   boolean filterByTrigram = restrictTrigrams && !permitted.isEmpty();
   Counter<String> logScores = new Counter<String>();
   for (String tag : candidates.keySet()) {
     double logScore = Math.log(candidates.getCount(tag));
     if (!filterByTrigram || permitted.contains(tag)) {
       logScores.setCount(tag, logScore);
     }
   }
   return logScores;
 }
	  /**
	   * Counts English tokens (eCounts) and every French/English token pair that co-occurs
	   * within a training sentence pair (collocationCounts). French token counts are not
	   * collected because they would not change the argmax over English words.
	   */
	  private void trainCounters() {
	    for (SentencePair pair : trainingSentencePairs) {
	      List<String> frenchTokens = pair.getFrenchWords();
	      List<String> englishTokens = pair.getEnglishWords();
	      eCounts.incrementAll(englishTokens, 1.0);
	      for (String frenchToken : frenchTokens) {
	        for (String englishToken : englishTokens) {
	          collocationCounts.incrementCount(frenchToken, englishToken, 1.0);
	        }
	      }
	    }
	    System.out.println("Trained!");
	  }
	  /**
	   * Collects sentence-level counts. Each distinct French word, each distinct English
	   * word, and each distinct (French, English) word pair is counted at most once per
	   * training sentence pair.
	   */
	  private void trainCounters() {
	    for (SentencePair pair : trainingSentencePairs) {
	      List<String> frenchTokens = pair.getFrenchWords();
	      List<String> englishTokens = pair.getEnglishWords();
	      // Deduplicate so that counts measure sentences, not tokens.
	      Set<String> distinctFrench = new HashSet<String>(frenchTokens);
	      Set<String> distinctEnglish = new HashSet<String>(englishTokens);
	      fCountSentences.incrementAll(distinctFrench, 1.0);
	      eCountSentences.incrementAll(distinctEnglish, 1.0);
	      for (String frenchType : distinctFrench) {
	        for (String englishType : distinctEnglish) {
	          collocationCountSentences.incrementCount(frenchType, englishType, 1.0);
	        }
	      }
	    }
	    System.out.println("Trained!");
	  }
示例#12
0
 /**
  * Sets the count of the transition start -> end. The backward table (indexed by end)
  * and the forward table (indexed by start) are updated together.
  */
 public void setTransitionCount(S start, S end, double count) {
   backwardTransitions.setCount(end, start, count);
   forwardTransitions.setCount(start, end, count);
 }
示例#13
0
 /**
  * Returns a counter over the states that can precede {@code state} in the markov
  * process, each paired with the cost of that transition.
  */
 public Counter<S> getBackwardTransitions(S state) {
   Counter<S> predecessors = backwardTransitions.getCounter(state);
   return predecessors;
 }
示例#14
0
 /**
  * Returns a counter over the states that can follow {@code state} in the markov process,
  * each paired with the cost of that transition.
  *
  * <p>Caution: a state missing from the counter is illegal and should be treated as
  * having cost Double.NEGATIVE_INFINITY. Counters, however, score missing items as 0.
  */
 public Counter<S> getForwardTransitions(S state) {
   Counter<S> successors = forwardTransitions.getCounter(state);
   return successors;
 }
示例#15
0
    /**
     * Older CKY (Viterbi) parser. It fills a chart with the best score for every
     * (span, label), applies unary-rule closure inside each chart cell, and rebuilds
     * the tree from back-pointers starting at "ROOT".
     *
     * <p>Chart cells are keyed by {@code "begin end"}. Back-pointers are keyed by
     * {@code "begin end label"} and hold (split, left-or-child label, right label);
     * a split of {@code -1} marks a unary rule.
     *
     * @param sentence the words to parse
     * @return the best parse, passed through TreeAnnotations.unAnnotateTree
     */
    public Tree<String> getBestParseOld(List<String> sentence) {
      CounterMap<String, String> parseScores = new CounterMap<String, String>();
      HashMap<String, Triplet<Integer, String, String>> backHash =
          new HashMap<String, Triplet<Integer, String, String>>();
      int n = sentence.size();

      System.out.println(sentence.toString());

      // Width-1 cells: every tag the lexicon scores non-negatively, then unary closure.
      int index = 0;
      for (String word : sentence) {
        String cell = index + " " + (index + 1);
        for (String tag : lexicon.getAllTags()) {
          double score = lexicon.scoreTagging(word, tag);
          if (score >= 0.0) {
            parseScores.setCount(cell, tag, score);
          }
        }
        index++;
      }
      for (index = 0; index < n; index++) {
        closeUnariesInCellOld(parseScores, backHash, index + " " + (index + 1));
      }

      // Wider spans: best binary combination over all split points, then unary closure.
      for (int span = 2; span <= n; span++) {
        for (int begin = 0; begin + span <= n; begin++) {
          int end = begin + span;
          String cell = begin + " " + end;
          for (int split = begin + 1; split < end; split++) {
            String leftCell = begin + " " + split;
            String rightCell = split + " " + end;
            Counter<String> leftScores = parseScores.getCounter(leftCell);
            Counter<String> rightScores = parseScores.getCounter(rightCell);
            // Each binary rule is reached exactly once, through its left child; its right
            // child must also be present in the right cell. This replaces the former
            // maps keyed by rule.hashCode(), where distinct rules with colliding hash
            // codes could silently overwrite or shadow each other.
            for (String leftLabel : leftScores.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(leftLabel)) {
                if (!rightScores.containsKey(rule.rightChild)) {
                  continue;
                }
                double prob =
                    rule.getScore()
                        * parseScores.getCount(leftCell, rule.leftChild)
                        * parseScores.getCount(rightCell, rule.rightChild);
                if (prob > parseScores.getCount(cell, rule.parent)) {
                  parseScores.setCount(cell, rule.getParent(), prob);
                  backHash.put(
                      cell + " " + rule.parent,
                      new Triplet<Integer, String, String>(
                          split, rule.leftChild, rule.rightChild));
                }
              }
            }
          }
          closeUnariesInCellOld(parseScores, backHash, cell);
        }
      }

      String fullSpan = 0 + " " + n;
      if (parseScores.getCounter(fullSpan).argMax() == null) {
        // Diagnostic only: reconstruction below still starts from ROOT.
        System.out.println(parseScores.getCounter(fullSpan).toString());
        System.out.println("THIS IS WEIRD");
      }
      Tree<String> parseTree = getParseTreeOld(sentence, backHash, 0, n, "ROOT");
      return TreeAnnotations.unAnnotateTree(parseTree);
    }

    /**
     * Repeatedly applies unary rules inside a single chart cell until no label's score
     * improves, recording a unary back-pointer (split -1) for each improvement.
     */
    private void closeUnariesInCellOld(
        CounterMap<String, String> parseScores,
        HashMap<String, Triplet<Integer, String, String>> backHash,
        String cell) {
      boolean improved = true;
      while (improved) {
        improved = false;
        Counter<String> cellScores = parseScores.getCounter(cell);
        PriorityQueue<String> labels = cellScores.asPriorityQueue();
        while (labels.hasNext()) {
          String child = labels.next();
          List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(child);
          for (UnaryRule rule : unaryRules) {
            double prob = rule.getScore() * parseScores.getCount(cell, child);
            if (prob > parseScores.getCount(cell, rule.parent)) {
              parseScores.setCount(cell, rule.parent, prob);
              backHash.put(
                  cell + " " + rule.parent,
                  new Triplet<Integer, String, String>(-1, child, null));
              improved = true;
            }
          }
        }
      }
    }
	  /**
	   * Trains IBM Model 1 translation probabilities t(f|e) with EM over
	   * trainingSentencePairs. Each French word may also align to the NULL English word.
	   *
	   * <p>A (f, e) pair seen for the first time is seeded with 1 / |English vocab + NULL|
	   * when {@code initialize} is set, and with thresholdProb otherwise. When
	   * {@code initialize} is not set, the M-step prunes probabilities at or below
	   * thresholdProb.
	   *
	   * @param maxIterations number of EM iterations to run
	   * @return the translation table, keyed f -> e -> t(f|e)
	   */
	  private CounterMap<String,String> trainEM(int maxIterations) {
	    final double thresholdProb = 0.0001;
	    CounterMap<String,String> translations = new CounterMap<String,String>();

	    // English vocabulary (plus NULL) sizes the uniform seed probability.
	    Set<String> englishVocab = new HashSet<String>();
	    englishVocab.add(NULL);
	    for (SentencePair sentencePair : trainingSentencePairs) {
	      englishVocab.addAll(sentencePair.getEnglishWords());
	    }
	    System.out.println("Ready");

	    final double initialCount = 1.0 / englishVocab.size();
	    // Probability given to a pair the first time the E-step encounters it.
	    final double seedProb = initialize ? initialCount : thresholdProb;

	    for (int iteration = 1; iteration <= maxIterations; iteration++) {
	      CounterMap<String,String> counts = new CounterMap<String,String>(); // expected count, keyed e -> f
	      Counter<String> totalEnglish = new Counter<String>(); // expected count of e

	      // E-step: distribute each French token's unit mass over NULL + the English words.
	      for (SentencePair sentencePair : trainingSentencePairs) {
	        List<String> frenchWords = sentencePair.getFrenchWords();
	        List<String> englishWords = sentencePair.getEnglishWords();
	        for (String f : frenchWords) {
	          if (!translations.containsKey(f))
	            translations.setCount(f, NULL, seedProb);
	          // Normalizer for THIS position only. It was previously accumulated per word
	          // type, which under-weighted French words repeated within a sentence.
	          double positionTotal = translations.getCount(f, NULL);
	          for (String e : englishWords) {
	            if (!(translations.getCounter(f)).containsKey(e))
	              translations.setCount(f, e, seedProb);
	            positionTotal += translations.getCount(f, e);
	          }
	          if (positionTotal <= 0.0) {
	            // Every candidate probability is 0 (e.g. a pruned NULL and an empty English
	            // side); skip rather than inject 0/0 = NaN into the counts.
	            continue;
	          }
	          double nullCount = translations.getCount(f, NULL) / positionTotal;
	          counts.incrementCount(NULL, f, nullCount);
	          totalEnglish.incrementCount(NULL, nullCount);
	          for (String e : englishWords) {
	            double count = translations.getCount(f, e) / positionTotal;
	            counts.incrementCount(e, f, count);
	            totalEnglish.incrementCount(e, count);
	          }
	        }
	      }
	      System.out.println("Completed E-step");

	      // M-step: t(f|e) = count(f, e) / count(e).
	      for (String e : counts.keySet()) {
	        double normalizer = totalEnglish.getCount(e);
	        for (String f : (counts.getCounter(e)).keySet()) {
	          double prob = counts.getCount(e, f) / normalizer;
	          if (initialize || prob > thresholdProb) {
	            translations.setCount(f, e, prob);
	          } else {
	            // Prune tiny probabilities to keep the table small.
	            (translations.getCounter(f)).removeKey(e);
	          }
	        }
	      }
	      System.out.println("Completed iteration " + iteration);
	    }

	    System.out.println("Trained!");
	    return translations;
	  }