예제 #1
0
 /**
  * Constructs the search trellis for {@code sentence}.
  *
  * <p>Beginning with a frontier containing only the start state, each pass over positions 0
  * through {@code sentence.size() + 1} (inclusive) asks {@code localTrigramScorer} for the
  * log-scored tags at that position. It then records one transition from every frontier state
  * to the successor reached by each scored tag. The stop state is never expanded. The successors
  * become the next frontier. You should not have to modify this code.
  *
  * @param sentence the words to tag
  * @return the trellis with start/stop states set and all transitions recorded
  */
 private Trellis<State> buildTrellis(List<String> sentence) {
   Trellis<State> lattice = new Trellis<State>();
   lattice.setStartState(State.getStartState());
   State terminal = State.getStopState(sentence.size() + 2);
   lattice.setStopState(terminal);

   Set<State> frontier = Collections.singleton(State.getStartState());
   int lastPosition = sentence.size() + 1;
   for (int pos = 0; pos <= lastPosition; pos++) {
     Set<State> successors = new HashSet<State>();
     for (State current : frontier) {
       if (!current.equals(terminal)) {
         LocalTrigramContext context =
             new LocalTrigramContext(
                 sentence, pos, current.getPreviousPreviousTag(), current.getPreviousTag());
         Counter<String> logScores = localTrigramScorer.getLogScoreCounter(context);
         for (String candidateTag : logScores.keySet()) {
           double transitionScore = logScores.getCount(candidateTag);
           State successor = current.getNextState(candidateTag);
           lattice.setTransitionCount(current, successor, transitionScore);
           successors.add(successor);
         }
       }
     }
     frontier = successors;
   }
   return lattice;
 }
예제 #2
0
 /**
  * Gathers every distinct word appearing in the given tagged sentences.
  *
  * @param taggedSentences sentences whose words are collected
  * @return a new mutable set containing each word that occurs at least once
  */
 private static Set<String> extractVocabulary(List<TaggedSentence> taggedSentences) {
   Set<String> seenWords = new HashSet<String>();
   for (TaggedSentence sentence : taggedSentences) {
     for (String token : sentence.getWords()) {
       seenWords.add(token);
     }
   }
   return seenWords;
 }
예제 #3
0
 /**
  * Keeps only those candidate tags that complete a tag trigram recorded in
  * {@code seenTagTrigrams}.
  *
  * @param tags candidate tags for the current position
  * @param previousPreviousTag the tag two positions back
  * @param previousTag the tag one position back
  * @return a new set with each candidate whose trigram string was seen
  */
 private Set<String> allowedFollowingTags(
     Set<String> tags, String previousPreviousTag, String previousTag) {
   Set<String> permitted = new HashSet<String>();
   for (String candidate : tags) {
     String trigramKey = makeTrigramString(previousPreviousTag, previousTag, candidate);
     boolean seenBefore = seenTagTrigrams.contains(trigramKey);
     if (seenBefore) {
       permitted.add(candidate);
     }
   }
   return permitted;
 }
예제 #4
0
 /**
  * Scores each candidate tag for the word at the context's position by the log of its
  * probability.
  *
  * <p>Candidates come from {@code wordsToTags} when the word was seen in training, otherwise
  * from {@code unknownWordTags}. When {@code restrictTrigrams} is set, only tags forming a seen
  * trigram with the two previous tags are kept. If no candidate forms a seen trigram, all
  * candidates are kept.
  *
  * @param localTrigramContext the sentence, position and two previous tags
  * @return a new counter mapping each kept tag to its log score
  */
 public Counter<String> getLogScoreCounter(LocalTrigramContext localTrigramContext) {
   int index = localTrigramContext.getPosition();
   String currentWord = localTrigramContext.getWords().get(index);
   Counter<String> candidates =
       wordsToTags.keySet().contains(currentWord)
           ? wordsToTags.getCounter(currentWord)
           : unknownWordTags;

   Set<String> permitted =
       allowedFollowingTags(
           candidates.keySet(),
           localTrigramContext.getPreviousPreviousTag(),
           localTrigramContext.getPreviousTag());
   boolean acceptEverything = !restrictTrigrams || permitted.isEmpty();

   Counter<String> scored = new Counter<String>();
   for (String candidate : candidates.keySet()) {
     double logProbability = Math.log(candidates.getCount(candidate));
     if (acceptEverything || permitted.contains(candidate)) {
       scored.setCount(candidate, logProbability);
     }
   }
   return scored;
 }
예제 #5
0
 /**
  * Tags every sentence with {@code posTagger}, compares the result to the gold tags and prints a
  * one-line summary.
  *
  * <p>The summary contains overall tag accuracy, accuracy on words missing from
  * {@code trainingVocabulary}, and the number of sentences whose gold tagging scored higher than
  * the guessed tagging (decoder suboptimality). With {@code verbose}, each sentence's aligned
  * taggings and any suboptimality warning are also printed.
  *
  * <p>Only positions 0 through {@code words.size() - 2} are scored, so the last position of each
  * sentence is skipped. NOTE(review): confirm that excluding the last word is intentional.
  */
 private static void evaluateTagger(
     POSTagger posTagger,
     List<TaggedSentence> taggedSentences,
     Set<String> trainingVocabulary,
     boolean verbose) {
   double totalTags = 0.0;
   double correctTags = 0.0;
   double totalUnknown = 0.0;
   double correctUnknown = 0.0;
   int inversions = 0;

   for (TaggedSentence goldSentence : taggedSentences) {
     List<String> words = goldSentence.getWords();
     List<String> goldTags = goldSentence.getTags();
     List<String> guessedTags = posTagger.tag(words);

     int scoredPositions = words.size() - 1;
     for (int i = 0; i < scoredPositions; i++) {
       String word = words.get(i);
       String goldTag = goldTags.get(i);
       String guessedTag = guessedTags.get(i);
       boolean isCorrect = guessedTag.equals(goldTag);
       if (isCorrect) {
         correctTags += 1.0;
       }
       totalTags += 1.0;
       if (!trainingVocabulary.contains(word)) {
         if (isCorrect) {
           correctUnknown += 1.0;
         }
         totalUnknown += 1.0;
       }
     }

     double goldScore = posTagger.scoreTagging(goldSentence);
     double guessedScore = posTagger.scoreTagging(new TaggedSentence(words, guessedTags));
     boolean suboptimal = goldScore > guessedScore;
     if (suboptimal) {
       inversions++;
     }
     if (verbose) {
       if (suboptimal) {
         System.out.println(
             "WARNING: Decoder suboptimality detected.  Gold tagging has higher score than guessed tagging.");
       }
       System.out.println(alignedTaggings(words, goldTags, guessedTags, true) + "\n");
     }
   }

   double tagAccuracy = correctTags / totalTags;
   double unknownAccuracy = correctUnknown / totalUnknown;
   System.out.println(
       "Tag Accuracy: "
           + tagAccuracy
           + " (Unknown Accuracy: "
           + unknownAccuracy
           + ")  Decoder Suboptimalities Detected: "
           + inversions);
 }
예제 #6
0
 /**
  * Estimates the scorer's tag statistics from labeled trigram contexts.
  *
  * <p>For each context, the (word, tag) pair is counted in {@code wordsToTags}. If the word had
  * not yet been seen, its tag is also counted in {@code unknownWordTags}. The context's tag
  * trigram is added to {@code seenTagTrigrams}. At the end, {@code wordsToTags} is normalized
  * per word and {@code unknownWordTags} is normalized overall.
  *
  * @param labeledLocalTrigramContexts training examples, processed in order
  */
 public void train(List<LabeledLocalTrigramContext> labeledLocalTrigramContexts) {
   for (LabeledLocalTrigramContext example : labeledLocalTrigramContexts) {
     String word = example.getCurrentWord();
     String tag = example.getCurrentTag();

     // First time this word is seen: its tag also feeds the unknown-word distribution.
     boolean firstSighting = !wordsToTags.keySet().contains(word);
     if (firstSighting) {
       unknownWordTags.incrementCount(tag, 1.0);
     }
     wordsToTags.incrementCount(word, tag, 1.0);

     String trigram =
         makeTrigramString(
             example.getPreviousPreviousTag(), example.getPreviousTag(), example.getCurrentTag());
     seenTagTrigrams.add(trigram);
   }
   wordsToTags = Counters.conditionalNormalize(wordsToTags);
   unknownWordTags = Counters.normalize(unknownWordTags);
 }
	  /**
	   * Learns word-translation probabilities t(f|e) with an IBM Model 1-style EM loop over
	   * {@code trainingSentencePairs}. Every French word may also align to the special English
	   * word {@code NULL}.
	   *
	   * <p>If {@code initialize} is true, a (f, e) pair needed for the first time starts at
	   * 1 / (number of English word types including NULL), and every re-estimated probability is
	   * stored. Otherwise such a pair starts at a small floor probability. Re-estimated
	   * probabilities at or below that floor are removed from the table to keep it small. They are
	   * re-seeded at the floor the next time the pair co-occurs. NULL is treated the same way as
	   * the English words.
	   *
	   * <p>The E-step normalizes each French <em>position</em> separately. A French word that
	   * occurs several times in one sentence therefore contributes a full unit of expected count
	   * per occurrence.
	   *
	   * @param maxIterations number of EM iterations to run; no early convergence test is made
	   * @return the probability table, keyed as (French word, English word)
	   */
	  private CounterMap<String,String> trainEM(int maxIterations) {
		  final double thresholdProb = 0.0001;

		  // Gather the English vocabulary (plus NULL); its size fixes the uniform starting value.
		  Set<String> englishVocab = new HashSet<String>();
		  englishVocab.add(NULL);
		  for (SentencePair sentencePair : trainingSentencePairs) {
			  englishVocab.addAll(sentencePair.getEnglishWords());
		  }
		  System.out.println("Ready");

		  final double initialCount = 1.0 / englishVocab.size();
		  // Value given to a (f, e) pair when it is first needed, or after it was pruned.
		  final double seedProb = initialize ? initialCount : thresholdProb;

		  CounterMap<String,String> translations = new CounterMap<String,String>();
		  for (int iteration = 0; iteration < maxIterations; iteration++) {
			  // Expected counts keyed (english, french), and their per-English-word totals.
			  CounterMap<String,String> counts = new CounterMap<String,String>();
			  Counter<String> totalEnglish = new Counter<String>();

			  // E-step: distribute one unit of count per French position over its possible
			  // alignments, i.e. NULL and every English position of the same sentence.
			  for (SentencePair sentencePair : trainingSentencePairs) {
				  List<String> frenchWords = sentencePair.getFrenchWords();
				  List<String> englishWords = sentencePair.getEnglishWords();
				  for (String f : frenchWords) {
					  if (!translations.containsKey(f) || !translations.getCounter(f).containsKey(NULL)) {
						  translations.setCount(f, NULL, seedProb);
					  }
					  // Normalizer for THIS position only (a per-word-type sum would shrink the
					  // counts of words repeated within the sentence).
					  double positionTotal = translations.getCount(f, NULL);
					  for (String e : englishWords) {
						  if (!translations.getCounter(f).containsKey(e)) {
							  translations.setCount(f, e, seedProb);
						  }
						  positionTotal += translations.getCount(f, e);
					  }

					  double nullCount = translations.getCount(f, NULL) / positionTotal;
					  counts.incrementCount(NULL, f, nullCount);
					  totalEnglish.incrementCount(NULL, nullCount);
					  for (String e : englishWords) {
						  double count = translations.getCount(f, e) / positionTotal;
						  counts.incrementCount(e, f, count);
						  totalEnglish.incrementCount(e, count);
					  }
				  }
			  }
			  System.out.println("Completed E-step");

			  // M-step: t(f|e) = count(e, f) / total(e).
			  for (String e : counts.keySet()) {
				  double normalizer = totalEnglish.getCount(e);
				  for (String f : counts.getCounter(e).keySet()) {
					  double prob = counts.getCount(e, f) / normalizer;
					  if (initialize || prob > thresholdProb) {
						  translations.setCount(f, e, prob);
					  } else {
						  // Drop negligible entries to keep the table small.
						  translations.getCounter(f).removeKey(e);
					  }
				  }
			  }
			  System.out.println("Completed iteration " + (iteration + 1));
		  }

		  System.out.println("Trained!");
		  return translations;
	  }