示例#1
0
 /**
  * Builds the trellis of tag states reachable for {@code sentence}.
  *
  * <p>Starting from the start state, every state in the current frontier (other than the stop
  * state, which is never expanded) is extended at each position 0..sentence.size()+1 by every tag
  * that the local trigram scorer returns for that context. The scorer's value for the tag becomes
  * the transition weight.
  *
  * @param sentence the words to tag
  * @return a trellis whose start state is State.getStartState() and whose stop state is
  *     State.getStopState(sentence.size() + 2)
  */
 private Trellis<State> buildTrellis(List<String> sentence) {
   Trellis<State> trellis = new Trellis<State>();
   trellis.setStartState(State.getStartState());
   final State stopState = State.getStopState(sentence.size() + 2);
   trellis.setStopState(stopState);

   final int lastPosition = sentence.size() + 1;
   Set<State> frontier = Collections.singleton(State.getStartState());
   int position = 0;
   while (position <= lastPosition) {
     Set<State> successors = new HashSet<State>();
     for (State current : frontier) {
       if (!current.equals(stopState)) {
         LocalTrigramContext context =
             new LocalTrigramContext(
                 sentence, position, current.getPreviousPreviousTag(), current.getPreviousTag());
         Counter<String> scores = localTrigramScorer.getLogScoreCounter(context);
         for (String tag : scores.keySet()) {
           double weight = scores.getCount(tag);
           State successor = current.getNextState(tag);
           trellis.setTransitionCount(current, successor, weight);
           successors.add(successor);
         }
       }
     }
     frontier = successors;
     position++;
   }
   return trellis;
 }
示例#2
0
 /**
  * Collects the labeled trigram contexts of every sentence, in sentence order.
  *
  * @param taggedSentences the training sentences
  * @return the concatenation of the per-sentence context lists
  */
 private List<LabeledLocalTrigramContext> extractLabeledLocalTrigramContexts(
     List<TaggedSentence> taggedSentences) {
   List<LabeledLocalTrigramContext> allContexts = new ArrayList<LabeledLocalTrigramContext>();
   for (TaggedSentence sentence : taggedSentences) {
     List<LabeledLocalTrigramContext> sentenceContexts =
         extractLabeledLocalTrigramContexts(sentence);
     allContexts.addAll(sentenceContexts);
   }
   return allContexts;
 }
示例#3
0
 /**
  * Converts a state path into a flat tag sequence.
  *
  * <p>The result is the first state's previous-previous tag, followed by each state's previous
  * tag in path order. An empty path yields an empty list.
  *
  * @param states the decoded state path
  * @return a new mutable list of tags (size states.size() + 1 unless states is empty)
  */
 public static List<String> toTagList(List<State> states) {
   List<String> tagSequence = new ArrayList<String>();
   if (states.isEmpty()) {
     return tagSequence;
   }
   State first = states.get(0);
   tagSequence.add(first.getPreviousPreviousTag());
   for (State s : states) {
     tagSequence.add(s.getPreviousTag());
   }
   return tagSequence;
 }
 /**
  * Reads sentence pairs from the English/French file pairs found under {@code path}, returning at
  * most {@code maxSentencePairs} of them.
  *
  * @param path directory passed to getBaseFileNames
  * @param maxSentencePairs upper bound on the number of pairs returned (a value of 0 or less
  *     yields an empty list)
  * @return the pairs, in file order, truncated to {@code maxSentencePairs}
  */
 private static List<SentencePair> readSentencePairs(String path, int maxSentencePairs) {
   List<SentencePair> sentencePairs = new ArrayList<SentencePair>();
   List<String> baseFileNames = getBaseFileNames(path);
   for (String baseFileName : baseFileNames) {
     // Stop opening files once the cap is reached instead of looping over the rest for nothing.
     if (sentencePairs.size() >= maxSentencePairs) {
       break;
     }
     sentencePairs.addAll(readSentencePairs(baseFileName));
   }
   // A single file may overshoot the cap; trim so the result honours maxSentencePairs.
   if (sentencePairs.size() > maxSentencePairs) {
     return new ArrayList<SentencePair>(sentencePairs.subList(0, Math.max(0, maxSentencePairs)));
   }
   return sentencePairs;
 }
示例#5
0
 /**
  * Renders the sentence as space-separated {@code word_TAG} tokens, e.g. {@code "a_DT cat_NN"}.
  *
  * <p>Previously tokens were concatenated with no separator ({@code "a_DTcat_NN"}), which made the
  * output unreadable and impossible to split back into tokens.
  */
 @Override
 public String toString() {
   StringBuilder sb = new StringBuilder();
   for (int position = 0; position < words.size(); position++) {
     if (position > 0) {
       sb.append(' ');
     }
     sb.append(words.get(position)).append('_').append(tags.get(position));
   }
   return sb.toString();
 }
示例#6
0
 /**
  * Greedily decodes a path through the trellis. From the start state it repeatedly follows the
  * highest-weighted forward transition (Counter.argMax) until the end state is reached.
  *
  * @param trellis the trellis to decode
  * @return the visited states, starting with the start state and ending with the end state
  * @throws IllegalStateException if some state reached before the end state has no forward
  *     transition (previously this surfaced as an unexplained NullPointerException)
  */
 public List<S> getBestPath(Trellis<S> trellis) {
   List<S> states = new ArrayList<S>();
   S currentState = trellis.getStartState();
   states.add(currentState);
   while (!currentState.equals(trellis.getEndState())) {
     Counter<S> transitions = trellis.getForwardTransitions(currentState);
     S nextState = (transitions == null) ? null : transitions.argMax();
     if (nextState == null) {
       throw new IllegalStateException(
           "No forward transition from state " + currentState + " before reaching the end state");
     }
     states.add(nextState);
     currentState = nextState;
   }
   return states;
 }
示例#7
0
    /**
     * Two tagged sentences are equal when both their tag lists and their word lists are equal
     * (null-safe). This is consistent with hashCode, which combines the same two fields.
     */
    public boolean equals(Object o) {
      if (o == this) {
        return true;
      }
      if (!(o instanceof TaggedSentence)) {
        return false;
      }
      TaggedSentence other = (TaggedSentence) o;
      // Tags are compared first; words are only compared when the tags already match.
      boolean tagsMatch = (tags == null) ? (other.tags == null) : tags.equals(other.tags);
      return tagsMatch && ((words == null) ? (other.words == null) : words.equals(other.words));
    }
示例#8
0
 /**
  * Tags every test sentence and writes the result to {@code path}. Each token is written as one
  * {@code word<TAB>tag} line, and sentences are separated by a blank line.
  *
  * @param posTagger the tagger used to label each sentence
  * @param testSentences sentences whose words are tagged (their gold tags are ignored)
  * @param path output file (written with the platform default charset via FileWriter)
  * @throws Exception if the file cannot be written
  */
 private static void labelTestSet(
     POSTagger posTagger, List<TaggedSentence> testSentences, String path) throws Exception {
   // try-with-resources: the writer is closed (and flushed) even if tagging or writing throws.
   try (BufferedWriter writer = new BufferedWriter(new FileWriter(path))) {
     for (TaggedSentence sentence : testSentences) {
       List<String> words = sentence.getWords();
       List<String> guessedTags = posTagger.tag(words);
       for (int i = 0; i < words.size(); i++) {
         writer.write(words.get(i) + "\t" + guessedTags.get(i) + "\n");
       }
       writer.write("\n");
     }
   }
 }
 /**
  * Lists the English files under {@code path} (with IOUtils.getFilesUnder, which is also allowed to
  * descend into directories) and returns their absolute paths with the "." + ENGLISH_EXTENSION
  * suffix chopped off.
  *
  * @param path root directory to search
  * @return base file names to which a language extension can be re-appended
  */
 private static List<String> getBaseFileNames(String path) {
   FileFilter englishOrDirectory = new FileFilter() {
     public boolean accept(File candidate) {
       return candidate.isDirectory() || candidate.getName().endsWith(ENGLISH_EXTENSION);
     }
   };
   List<File> englishFiles = IOUtils.getFilesUnder(path, englishOrDirectory);
   String suffix = "." + ENGLISH_EXTENSION;
   List<String> stems = new ArrayList<String>();
   for (File englishFile : englishFiles) {
     stems.add(chop(englishFile.getAbsolutePath(), suffix));
   }
   return stems;
 }
 /**
  * Parses one line of the form {@code <s snum=ID> w1 w2 ... </s>} into (ID, words).
  *
  * <p>The {@code <s}, {@code </s>} and {@code snum=...} tokens are not words. Words are interned.
  * Empty tokens (produced by split() for an empty line or leading whitespace) are skipped rather
  * than being returned as empty-string words.
  *
  * @param line a single line from an English or French corpus file
  * @return the sentence ID (-1 if no snum token was seen) and the list of words
  * @throws NumberFormatException if the snum value is not an integer
  */
 private static Pair<Integer, List<String>> readSentence(String line) {
   int id = -1;
   List<String> words = new ArrayList<String>();
   for (String token : line.split("\\s+")) {
     if (token.isEmpty()) continue;
     if (token.equals("<s")) continue;
     if (token.equals("</s>")) continue;
     if (token.startsWith("snum=")) {
       // The ID token is normally "snum=123>"; strip the closing '>' only when it is present so
       // a bare "snum=123" does not lose its last digit.
       String idString = token.endsWith(">")
           ? token.substring(5, token.length() - 1)
           : token.substring(5);
       id = Integer.parseInt(idString);
       continue;
     }
     words.add(token.intern());
   }
   return new Pair<Integer, List<String>>(id, words);
 }
示例#11
0
 /**
  * Builds one labeled trigram context per position 0..size+1 of the sentence.
  *
  * <p>Words and tags are wrapped in BoundedList with the START/STOP markers, so the lookups at
  * position - 2, position - 1 and the positions past the end read boundary symbols.
  *
  * @param taggedSentence a gold-tagged sentence
  * @return size + 2 contexts, each carrying (prev-prev tag, prev tag, current tag)
  */
 private List<LabeledLocalTrigramContext> extractLabeledLocalTrigramContexts(
     TaggedSentence taggedSentence) {
   List<LabeledLocalTrigramContext> contexts = new ArrayList<LabeledLocalTrigramContext>();
   List<String> boundedWords =
       new BoundedList<String>(taggedSentence.getWords(), START_WORD, STOP_WORD);
   List<String> boundedTags =
       new BoundedList<String>(taggedSentence.getTags(), START_TAG, STOP_TAG);
   int lastPosition = taggedSentence.size() + 1;
   for (int pos = 0; pos <= lastPosition; pos++) {
     String prevPrevTag = boundedTags.get(pos - 2);
     String prevTag = boundedTags.get(pos - 1);
     String currentTag = boundedTags.get(pos);
     contexts.add(
         new LabeledLocalTrigramContext(boundedWords, pos, prevPrevTag, prevTag, currentTag));
   }
   return contexts;
 }
 /**
  * Renders the pair as two lines, English first and then French. Each word is written as
  * {@code index:word} followed by a space, and each line ends with a newline.
  */
 public String toString() {
   StringBuilder out = new StringBuilder();
   int index = 0;
   for (String englishWord : englishWords) {
     out.append(index++).append(':').append(englishWord).append(' ');
   }
   out.append('\n');
   index = 0;
   for (String frenchWord : frenchWords) {
     out.append(index++).append(':').append(frenchWord).append(' ');
   }
   out.append('\n');
   return out.toString();
 }
	  /**
	   * Aligns each French word to the English word with the highest Dice coefficient.
	   *
	   * <p>If no English word has a Dice coefficient above 0, the French word falls back to the
	   * English word at the same position. If the French sentence is longer than the English one,
	   * the fallback is -1 (the NULL position in this file's aligners).
	   *
	   * @param sentencePair the sentence pair to align
	   * @return one alignment link per French position
	   */
	  public Alignment alignSentencePair(SentencePair sentencePair) {
		  Alignment alignment = new Alignment();
		  List<String> frenchWords = sentencePair.getFrenchWords();
		  List<String> englishWords = sentencePair.getEnglishWords();
		  int frenchLength = frenchWords.size();
		  int englishLength = englishWords.size();

		  for (int j = 0; j < frenchLength; j++) {
			  String frenchWord = frenchWords.get(j);
			  int bestEnglish = (j < englishLength) ? j : -1;
			  double bestDice = 0;
			  for (int i = 0; i < englishLength; i++) {
				  double candidate = getDiceCoefficient(frenchWord, englishWords.get(i));
				  if (candidate > bestDice) {
					  bestEnglish = i;
					  bestDice = candidate;
				  }
			  }
			  alignment.addAlignment(bestEnglish, j, true);
		  }
		  return alignment;
	  }
	  /**
	   * Model 1 decoding: since all alignments are equally likely a priori, each French word f is
	   * aligned independently to the English position maximising t(f|e) = translationProbs(f, e).
	   *
	   * <p>The NULL word (position -1) is the initial candidate, so an English word must score
	   * strictly higher than t(f|NULL) to win.
	   *
	   * @param sentencePair the sentence pair to align
	   * @return one alignment link per French position
	   */
	  public Alignment alignSentencePair(SentencePair sentencePair) {
		  Alignment alignment = new Alignment();
		  List<String> frenchWords = sentencePair.getFrenchWords();
		  List<String> englishWords = sentencePair.getEnglishWords();
		  int frenchLength = frenchWords.size();
		  int englishLength = englishWords.size();

		  int j = 0;
		  while (j < frenchLength) {
			  String frenchWord = frenchWords.get(j);
			  int bestEnglish = -1;
			  double bestProb = translationProbs.getCount(frenchWord, NULL);
			  for (int i = 0; i < englishLength; i++) {
				  double candidate = translationProbs.getCount(frenchWord, englishWords.get(i));
				  if (candidate > bestProb) {
					  bestEnglish = i;
					  bestProb = candidate;
				  }
			  }
			  alignment.addAlignment(bestEnglish, j, true);
			  j++;
		  }
		  return alignment;
	  }
	  /**
	   * Aligns each French word f to the English word e maximising
	   * collocationCounts(f, e) / eCounts(e).
	   *
	   * <p>If no English word scores above 0, f falls back to the English word at the same position.
	   * If the French sentence is longer than the English one, the fallback is -1 (the NULL position
	   * in this file's aligners).
	   *
	   * @param sentencePair the sentence pair to align
	   * @return one alignment link per French position
	   */
	  public Alignment alignSentencePair(SentencePair sentencePair) {
		  Alignment alignment = new Alignment();
		  List<String> frenchWords = sentencePair.getFrenchWords();
		  List<String> englishWords = sentencePair.getEnglishWords();
		  int frenchLength = frenchWords.size();
		  int englishLength = englishWords.size();

		  for (int j = 0; j < frenchLength; j++) {
			  String frenchWord = frenchWords.get(j);
			  int bestEnglish = (j < englishLength) ? j : -1;
			  double bestScore = 0;
			  for (int i = 0; i < englishLength; i++) {
				  String englishWord = englishWords.get(i);
				  double candidate =
						  collocationCounts.getCount(frenchWord, englishWord) / (eCounts.getCount(englishWord));
				  if (candidate > bestScore) {
					  bestEnglish = i;
					  bestScore = candidate;
				  }
			  }
			  alignment.addAlignment(bestEnglish, j, true);
		  }
		  return alignment;
	  }
示例#16
0
 // Pretty-prints the gold and guessed taggings of a sentence above its words, column-aligned via
 // equalizeLengths. When suppressCorrectTags is set, tags where gold == guessed are left blank so
 // only the errors stand out.
 private static String alignedTaggings(
     List<String> words,
     List<String> goldTags,
     List<String> guessedTags,
     boolean suppressCorrectTags) {
   StringBuilder goldLine = new StringBuilder("Gold Tags: ");
   StringBuilder guessedLine = new StringBuilder("Guessed Tags: ");
   StringBuilder wordLine = new StringBuilder("Words: ");
   int lastIndex = words.size() - 1;
   for (int i = 0; i <= lastIndex; i++) {
     equalizeLengths(wordLine, goldLine, guessedLine);
     String word = words.get(i);
     String gold = goldTags.get(i);
     String guessed = guessedTags.get(i);
     wordLine.append(word);
     if (i < lastIndex) {
       wordLine.append(' ');
     }
     boolean hideTags = gold.equals(guessed) && suppressCorrectTags;
     if (!hideTags) {
       guessedLine.append(guessed);
       goldLine.append(gold);
     }
   }
   return goldLine + "\n" + guessedLine + "\n" + wordLine;
 }
 /**
  * Reads the parallel English/French files {@code baseFileName.ENGLISH_EXTENSION} and
  * {@code baseFileName.FRENCH_EXTENSION} line by line into sentence pairs.
  *
  * <p>The English file is read with FileReader (the platform charset) and the French file as
  * ISO-8859-1. Reading stops as soon as either file is exhausted. Both readers are closed even on
  * error. Earlier versions leaked them and used {@code ready()}, which reports whether a read would
  * block rather than end-of-stream, as the loop condition.
  *
  * @param baseFileName path without the language extension
  * @return the pairs in file order
  * @throws RuntimeException wrapping any IOException, or if the two lines of a pair carry
  *     different sentence IDs
  */
 private static List<SentencePair> readSentencePairs(String baseFileName) {
   List<SentencePair> sentencePairs = new ArrayList<SentencePair>();
   String englishFileName = baseFileName + "." + ENGLISH_EXTENSION;
   String frenchFileName = baseFileName + "." + FRENCH_EXTENSION;
   try (BufferedReader englishIn = new BufferedReader(new FileReader(englishFileName));
       BufferedReader frenchIn = new BufferedReader(new InputStreamReader(
           new FileInputStream(frenchFileName), StandardCharsets.ISO_8859_1))) {
     String englishLine;
     String frenchLine;
     while ((englishLine = englishIn.readLine()) != null
         && (frenchLine = frenchIn.readLine()) != null) {
       Pair<Integer,List<String>> englishSentenceAndID = readSentence(englishLine);
       Pair<Integer,List<String>> frenchSentenceAndID = readSentence(frenchLine);
       if (! englishSentenceAndID.getFirst().equals(frenchSentenceAndID.getFirst()))
         throw new RuntimeException("Sentence ID confusion in file "+baseFileName+", lines were:\n\t"+englishLine+"\n\t"+frenchLine);
       sentencePairs.add(new SentencePair(englishSentenceAndID.getFirst(), baseFileName, englishSentenceAndID.getSecond(), frenchSentenceAndID.getSecond()));
     }
   } catch (IOException e) {
     throw new RuntimeException(e);
   }
   return sentencePairs;
 }
示例#18
0
 /**
  * Tags each sentence and prints overall accuracy, accuracy on words outside
  * {@code trainingVocabulary}, and the number of sentences where the gold tagging scores higher
  * than the guessed one (a sign of decoder suboptimality).
  *
  * <p>Every token is scored. The previous loop bound {@code words.size() - 1} silently skipped the
  * last token of each sentence, even though tag() returns one tag per word and labelTestSet
  * emits every token.
  *
  * @param posTagger the tagger under evaluation
  * @param taggedSentences gold-tagged evaluation sentences
  * @param trainingVocabulary words seen in training; others count as unknown
  * @param verbose if true, print each sentence's aligned taggings and inversion warnings
  */
 private static void evaluateTagger(
     POSTagger posTagger,
     List<TaggedSentence> taggedSentences,
     Set<String> trainingVocabulary,
     boolean verbose) {
   double numTags = 0.0;
   double numTagsCorrect = 0.0;
   double numUnknownWords = 0.0;
   double numUnknownWordsCorrect = 0.0;
   int numDecodingInversions = 0;
   for (TaggedSentence taggedSentence : taggedSentences) {
     List<String> words = taggedSentence.getWords();
     List<String> goldTags = taggedSentence.getTags();
     List<String> guessedTags = posTagger.tag(words);
     for (int position = 0; position < words.size(); position++) {
       String word = words.get(position);
       String goldTag = goldTags.get(position);
       String guessedTag = guessedTags.get(position);
       boolean correct = guessedTag.equals(goldTag);
       if (correct) numTagsCorrect += 1.0;
       numTags += 1.0;
       if (!trainingVocabulary.contains(word)) {
         if (correct) numUnknownWordsCorrect += 1.0;
         numUnknownWords += 1.0;
       }
     }
     double scoreOfGoldTagging = posTagger.scoreTagging(taggedSentence);
     double scoreOfGuessedTagging = posTagger.scoreTagging(new TaggedSentence(words, guessedTags));
     if (scoreOfGoldTagging > scoreOfGuessedTagging) {
       numDecodingInversions++;
       if (verbose)
         System.out.println(
             "WARNING: Decoder suboptimality detected.  Gold tagging has higher score than guessed tagging.");
     }
     if (verbose) System.out.println(alignedTaggings(words, goldTags, guessedTags, true) + "\n");
   }
   System.out.println(
       "Tag Accuracy: "
           + (numTagsCorrect / numTags)
           + " (Unknown Accuracy: "
           + (numUnknownWordsCorrect / numUnknownWords)
           + ")  Decoder Suboptimalities Detected: "
           + numDecodingInversions);
 }
示例#19
0
 /**
  * Reads a one-token-per-line file ({@code word [tag]} separated by whitespace) in which blank
  * lines end sentences.
  *
  * <p>The reader is closed even on error. A final sentence that is not followed by a blank line is
  * no longer dropped. Token lists are ArrayLists because callers index into them by position.
  *
  * @param path file to read (platform default charset, via FileReader)
  * @param hasTags if false, every tag is stored as the empty string
  * @return sentences whose word and tag lists are wrapped in BoundedList
  * @throws Exception on I/O failure, or if hasTags is set and a line has no second field
  */
 private static List<TaggedSentence> readTaggedSentences(String path, boolean hasTags)
     throws Exception {
   List<TaggedSentence> taggedSentences = new ArrayList<TaggedSentence>();
   try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
     List<String> words = new ArrayList<String>();
     List<String> tags = new ArrayList<String>();
     String line;
     while ((line = reader.readLine()) != null) {
       if (line.equals("")) {
         taggedSentences.add(newBoundedTaggedSentence(words, tags));
         words = new ArrayList<String>();
         tags = new ArrayList<String>();
       } else {
         String[] fields = line.split("\\s+");
         words.add(fields[0]);
         tags.add(hasTags ? fields[1] : "");
       }
     }
     // Flush the last sentence when the file does not end with a blank line.
     if (!words.isEmpty()) {
       taggedSentences.add(newBoundedTaggedSentence(words, tags));
     }
   }
   System.out.println("Read " + taggedSentences.size() + " sentences.");
   return taggedSentences;
 }

 /** Wraps a sentence's words and tags in BoundedLists using the START_WORD/STOP_WORD markers. */
 private static TaggedSentence newBoundedTaggedSentence(List<String> words, List<String> tags) {
   return new TaggedSentence(
       new BoundedList<String>(words, START_WORD, STOP_WORD),
       new BoundedList<String>(tags, START_WORD, STOP_WORD));
 }
示例#20
0
 /**
  * Combines the hashes of words and tags (null counts as 0) as 29 * h(words) + h(tags). This is
  * consistent with equals, which compares the same two fields.
  */
 public int hashCode() {
   int wordsHash = (words == null) ? 0 : words.hashCode();
   int tagsHash = (tags == null) ? 0 : tags.hashCode();
   return 29 * wordsHash + tagsHash;
 }
	  /**
	   * Estimates translation probabilities t(f|e), stored as translations.getCount(f, e), by running
	   * EM over trainingSentencePairs. Every English position in a sentence, plus NULL, is a
	   * candidate source for each French token.
	   *
	   * <p>Unseen (f, e) pairs are seeded at 1 / |English vocabulary + NULL| when the
	   * {@code initialize} field is set, and at {@code thresholdProb} otherwise. Without
	   * {@code initialize}, the M-step prunes pairs whose probability falls to or below the
	   * threshold. Pruned English pairs are re-seeded at the threshold the next time they
	   * co-occur. The NULL entry is only seeded the first time f is seen.
	   *
	   * <p>Fix: the E-step normaliser sum_e t(f|e) is now computed per French token. Previously it
	   * was accumulated in a Counter keyed by the word, so a French word occurring k times in a
	   * sentence had its normaliser summed k times. Each occurrence then contributed only 1/k of an
	   * expected count.
	   *
	   * @param maxIterations number of EM iterations to run (no convergence test is performed)
	   * @return the translation table, keyed (french word, english word)
	   */
	  private CounterMap<String,String> trainEM(int maxIterations) {
		  final double thresholdProb = 0.0001;
		  CounterMap<String,String> translations = new CounterMap<String,String>();

		  // The English vocabulary (plus NULL) is only needed to size the uniform initial probability.
		  Set<String> englishVocab = new HashSet<String>();
		  englishVocab.add(NULL);
		  for (SentencePair sentencePair : trainingSentencePairs) {
			  englishVocab.addAll(sentencePair.getEnglishWords());
		  }
		  System.out.println("Ready");

		  // Uniform start: t(f|e) summed over all e in {E + NULL} = 1
		  final double initialCount = 1.0 / englishVocab.size();
		  final double seedProb = initialize ? initialCount : thresholdProb;

		  for (int iteration = 1; iteration <= maxIterations; iteration++) {
			  CounterMap<String,String> counts = new CounterMap<String,String>(); // expected count(e, f)
			  Counter<String> totalEnglish = new Counter<String>(); // total(e) = sum_f count(e, f)

			  // E-step: each French token distributes one unit of count over NULL and every English position.
			  for (SentencePair sentencePair : trainingSentencePairs) {
				  List<String> frenchWords = sentencePair.getFrenchWords();
				  List<String> englishWords = sentencePair.getEnglishWords();
				  for (String f : frenchWords) {
					  if (!translations.containsKey(f))
						  translations.setCount(f, NULL, seedProb);
					  double normalizer = translations.getCount(f, NULL);
					  for (String e : englishWords) {
						  if (!(translations.getCounter(f)).containsKey(e))
							  translations.setCount(f, e, seedProb);
						  normalizer += translations.getCount(f, e);
					  }

					  double nullCount = translations.getCount(f, NULL) / normalizer;
					  counts.incrementCount(NULL, f, nullCount);
					  totalEnglish.incrementCount(NULL, nullCount);
					  for (String e : englishWords) {
						  double count = translations.getCount(f, e) / normalizer;
						  counts.incrementCount(e, f, count);
						  totalEnglish.incrementCount(e, count);
					  }
				  }
			  }
			  System.out.println("Completed E-step");

			  // M-step: t(f|e) = count(e, f) / total(e)
			  for (String e : counts.keySet()) {
				  double normalizer = totalEnglish.getCount(e);
				  for (String f : (counts.getCounter(e)).keySet()) {
					  double prob = counts.getCount(e, f) / normalizer;
					  if (initialize || prob > thresholdProb)
						  translations.setCount(f, e, prob);
					  else
						  (translations.getCounter(f)).removeKey(e); // prune to keep the table small
				  }
			  }
			  System.out.println("Completed iteration " + iteration);
		  }

		  System.out.println("Trained!");
		  return translations;
	  }
  /**
   * Command-line entry point. Reads the data, builds the selected word aligner, evaluates it on
   * the chosen test set, and writes predictions for basePath/test to basePath/MODEL.out.
   *
   * <p>Flags: -path, -sentences, -data (mini | miniTest | validate), -model (baseline | heuristic
   * | dice | ibm1 | ibmModel1 | ibm2 | ibmModel2), -verbose, -iterations, -initialize.
   *
   * @throws RuntimeException if -data or -model names an unsupported value
   * @throws IOException declared for the I/O helpers this method calls
   */
  public static void main(String[] args) throws IOException {
    // Parse command line flags and arguments
    Map<String,String> argMap = CommandLineUtils.simpleCommandLineParser(args);

    // Set up default parameters and settings
    String basePath = ".";
    int maxTrainingSentences = 0;
    int maxIterations = 20;
    boolean verbose = false;
    boolean initialize = false;
    String dataset = "mini";
    String model = "baseline";

    // Update defaults using command line specifications
    if (argMap.containsKey("-path")) {
      basePath = argMap.get("-path");
      System.out.println("Using base path: "+basePath);
    }
    if (argMap.containsKey("-sentences")) {
      maxTrainingSentences = Integer.parseInt(argMap.get("-sentences"));
      System.out.println("Using an additional "+maxTrainingSentences+" training sentences.");
    }
    if (argMap.containsKey("-data")) {
      dataset = argMap.get("-data");
      System.out.println("Running with data: "+dataset);
    } else {
      System.out.println("No data set specified.  Use -data [miniTest, validate].");
    }
    if (argMap.containsKey("-model")) {
      model = argMap.get("-model");
      System.out.println("Running with model: "+model);
    } else {
      System.out.println("No model specified.  Use -model modelname.");
    }
    if (argMap.containsKey("-verbose")) {
      verbose = true;
    }
    if (argMap.containsKey("-iterations")) {
      maxIterations = Integer.parseInt(argMap.get("-iterations"));
    }
    if (argMap.containsKey("-initialize")) {
      initialize = true;
    }

    // Read appropriate training and testing sets.
    List<SentencePair> trainingSentencePairs = new ArrayList<SentencePair>();
    if (! (dataset.equals("miniTest") || dataset.equals("mini")) && maxTrainingSentences > 0)
      trainingSentencePairs = readSentencePairs(basePath+"/training", maxTrainingSentences);
    List<SentencePair> testSentencePairs = new ArrayList<SentencePair>();
    Map<Integer,Alignment> testAlignments = new HashMap<Integer, Alignment>();
    if (dataset.equalsIgnoreCase("validate")) {
      testSentencePairs = readSentencePairs(basePath+"/trial", Integer.MAX_VALUE);
      testAlignments = readAlignments(basePath+"/trial/trial.wa");
    } else if (dataset.equals("miniTest") || dataset.equals("mini")) {
      testSentencePairs = readSentencePairs(basePath+"/mini", Integer.MAX_VALUE);
      testAlignments = readAlignments(basePath+"/mini/mini.wa");
    } else {
      throw new RuntimeException("Bad data set mode: "+ dataset+", use validate or miniTest.");
    }
    trainingSentencePairs.addAll(testSentencePairs);

    // Build model
    WordAligner wordAligner;
    if (model.equalsIgnoreCase("baseline")) {
      wordAligner = new BaselineWordAligner();
    } else if (model.equalsIgnoreCase("heuristic")) {
      wordAligner = new HeuristicWordAligner(trainingSentencePairs);
    } else if (model.equalsIgnoreCase("dice")) {
      wordAligner = new DiceWordAligner(trainingSentencePairs);
    } else if (model.equalsIgnoreCase("ibm1") || model.equalsIgnoreCase("ibmModel1")) {
      wordAligner = new IBMmodel1WordAligner(trainingSentencePairs, maxIterations, initialize);
    } else if (model.equalsIgnoreCase("ibm2") || model.equalsIgnoreCase("ibmModel2")) {
      wordAligner = new IBMmodel2WordAligner(trainingSentencePairs, maxIterations, initialize);
    } else {
      // Fail fast, mirroring the -data check; a null aligner would otherwise NPE inside test().
      throw new RuntimeException("Unknown model: " + model
          + ", use baseline, heuristic, dice, ibm1 or ibm2.");
    }

    // Test model
    test(wordAligner, testSentencePairs, testAlignments, verbose);

    // Generate file for submission //can comment out if not ready for submission
    testSentencePairs = readSentencePairs(basePath+"/test", Integer.MAX_VALUE);
    predict(wordAligner, testSentencePairs, basePath+"/"+model+".out");
  }
示例#23
0
 /** Returns the word at this context's position in its word list. */
 public String getCurrentWord() {
   String current = words.get(position);
   return current;
 }
示例#24
0
 /**
  * Drops the first two and last two entries, which presumably are the boundary tags added around
  * a decoded sentence (confirm against toTagList/buildTrellis).
  *
  * @param tags a tag list with at least four entries
  * @return a subList view (not a copy) of the interior tags
  */
 private List<String> stripBoundaryTags(List<String> tags) {
   int from = 2;
   int to = tags.size() - 2;
   return tags.subList(from, to);
 }
示例#25
0
 /** Returns the number of words in this sentence. */
 public int size() {
   int wordCount = words.size();
   return wordCount;
 }