Example #1
 private void tallyTree(
     Tree<String> tree,
     Counter<String> symbolCounter,
     Counter<UnaryRule> unaryRuleCounter,
     Counter<BinaryRule> binaryRuleCounter) {
   if (tree.isLeaf()) return;
   if (tree.isPreTerminal()) return;
   if (tree.getChildren().size() == 1) {
     UnaryRule unaryRule = makeUnaryRule(tree);
     symbolCounter.incrementCount(tree.getLabel(), 1.0);
     unaryRuleCounter.incrementCount(unaryRule, 1.0);
   }
   if (tree.getChildren().size() == 2) {
     BinaryRule binaryRule = makeBinaryRule(tree);
     symbolCounter.incrementCount(tree.getLabel(), 1.0);
     binaryRuleCounter.incrementCount(binaryRule, 1.0);
   }
   if (tree.getChildren().size() < 1 || tree.getChildren().size() > 2) {
     throw new RuntimeException(
         "Attempted to construct a Grammar with an illegal tree: " + tree);
   }
   for (Tree<String> child : tree.getChildren()) {
     tallyTree(child, symbolCounter, unaryRuleCounter, binaryRuleCounter);
   }
 }
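Each node with one or two children contributes one count to its parent symbol and one to the corresponding rule: a node S with children NP and VP increments symbolCounter for "S" and binaryRuleCounter for S -> NP VP. Preterminals are skipped because taggings are handled by the lexicon. The Grammar constructor in Example #4 turns these counts into rule probabilities.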
Example #2
 private void tallyTagging(String word, String tag) {
   if (!isKnown(word)) {
     totalWordTypes += 1.0;
     typeTagCounter.incrementCount(tag, 1.0);
   }
   totalTokens += 1.0;
   tagCounter.incrementCount(tag, 1.0);
   wordCounter.incrementCount(word, 1.0);
   wordToTagCounters.incrementCount(word, tag, 1.0);
 }
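Note that totalWordTypes and typeTagCounter are updated only the first time a word type is seen, so typeTagCounter tracks how often each tag introduces a new word type. scoreTagging in Example #3 uses exactly these counters to smooth estimates for rare and unknown words.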
Example #3
 /* Returns a smoothed estimate of P(word|tag) */
 public double scoreTagging(String word, String tag) {
   double p_tag = tagCounter.getCount(tag) / totalTokens;
   double c_word = wordCounter.getCount(word);
   double c_tag_and_word = wordToTagCounters.getCount(word, tag);
   if (c_word < 10) { // rare or unknown
     c_word += 1.0;
     c_tag_and_word += typeTagCounter.getCount(tag) / totalWordTypes;
   }
   double p_word = (1.0 + c_word) / (totalTokens + totalWordTypes);
   double p_tag_given_word = c_tag_and_word / c_word;
   return p_tag_given_word / p_tag * p_word;
 }
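The return value is Bayes' rule applied to the smoothed counts:

   P(word | tag) = P(tag | word) * P(word) / P(tag)

with P(tag | word) = c_tag_and_word / c_word, an add-one-smoothed P(word) over totalTokens + totalWordTypes, and P(tag) estimated from tagCounter.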
Example #4
 /* Builds a PCFG using the observed counts of binary and unary
  * productions in the training trees to estimate the probabilities
  * for those rules.  */
 public Grammar(List<Tree<String>> trainTrees) {
   Counter<UnaryRule> unaryRuleCounter = new Counter<UnaryRule>();
   Counter<BinaryRule> binaryRuleCounter = new Counter<BinaryRule>();
   Counter<String> symbolCounter = new Counter<String>();
   for (Tree<String> trainTree : trainTrees) {
     tallyTree(trainTree, symbolCounter, unaryRuleCounter, binaryRuleCounter);
   }
   for (UnaryRule unaryRule : unaryRuleCounter.keySet()) {
     double unaryProbability =
         unaryRuleCounter.getCount(unaryRule) / symbolCounter.getCount(unaryRule.getParent());
     unaryRule.setScore(unaryProbability);
     addUnary(unaryRule);
   }
   for (BinaryRule binaryRule : binaryRuleCounter.keySet()) {
     double binaryProbability =
         binaryRuleCounter.getCount(binaryRule) / symbolCounter.getCount(binaryRule.getParent());
     binaryRule.setScore(binaryProbability);
     addBinary(binaryRule);
   }
 }
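The scores set here are maximum-likelihood estimates: for each rule,

   P(X -> alpha) = count(X -> alpha) / count(X)

so the probabilities of all unary and binary rules sharing the same parent X sum to one (tallyTree in Example #1 increments the parent's symbol count once per rule occurrence).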
Example #5
 public boolean isKnown(String word) {
   return wordCounter.keySet().contains(word);
 }
Example #6
 public Set<String> getAllTags() {
   return tagCounter.keySet();
 }
Example #7
    public Tree<String> getBestParseOld(List<String> sentence) {
      // This implements the CKY algorithm

      CounterMap<String, String> parseScores = new CounterMap<String, String>();

      System.out.println(sentence.toString());
      // First deal with the lexicons
      int index = 0;
      int span = 1; // All spans are 1 at the lexicon level
      for (String word : sentence) {
        for (String tag : lexicon.getAllTags()) {
          double score = lexicon.scoreTagging(word, tag);
          if (score > 0.0) { // This lexicon may generate this word
            // We use a counter map in order to store the scores for this sentence parse.
            parseScores.setCount(index + " " + (index + span), tag, score);
          }
        }
        index = index + 1;
      }

      // handle unary rules now
      // HashMap to store back pointers for reconstructing the tree
      HashMap<String, Triplet<Integer, String, String>> backHash =
          new HashMap<String, Triplet<Integer, String, String>>();

      // Apply unary rules to a fixed point (unary closure): keep relaxing
      // cell scores until nothing improves
      boolean added = true;

      while (added) {
        added = false;
        for (index = 0; index < sentence.size(); index++) {
          // For each index+ span pair, get the counter.
          Counter<String> count = parseScores.getCounter(index + " " + (index + span));
          PriorityQueue<String> countAsPQ = count.asPriorityQueue();
          while (countAsPQ.hasNext()) {
            String entry = countAsPQ.next();
            List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
            for (UnaryRule rule : unaryRules) {
              // These are the unary rules which might give rise to the above preterminal
              double prob =
                  rule.getScore() * parseScores.getCount(index + " " + (index + span), entry);
              if (prob > parseScores.getCount(index + " " + (index + span), rule.parent)) {
                parseScores.setCount(index + " " + (index + span), rule.parent, prob);
                backHash.put(
                    index + " " + (index + span) + " " + rule.parent,
                    new Triplet<Integer, String, String>(-1, entry, null));
                added = true;
              }
            }
          }
        }
      }

      // Now work with the grammar to produce higher level probabilities
      for (span = 2; span <= sentence.size(); span++) {
        for (int begin = 0; begin <= (sentence.size() - span); begin++) {
          int end = begin + span;
          for (int split = begin + 1; split <= end - 1; split++) {
            Counter<String> countLeft = parseScores.getCounter(begin + " " + split);
            Counter<String> countRight = parseScores.getCounter(split + " " + end);
            // Collect the binary rules applicable on each side, deduplicated by hash code
            HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>();
            HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>();

            for (String entry : countLeft.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) {
                if (!leftMap.containsKey(rule.hashCode())) {
                  leftMap.put(rule.hashCode(), rule);
                }
              }
            }

            for (String entry : countRight.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) {
                if (!rightMap.containsKey(rule.hashCode())) {
                  rightMap.put(rule.hashCode(), rule);
                }
              }
            }

            // A binary rule applies here only if it appears in both maps
            for (Integer ruleHash : leftMap.keySet()) {
              if (rightMap.containsKey(ruleHash)) {
                BinaryRule ruleRight = rightMap.get(ruleHash);
                double prob =
                    ruleRight.getScore()
                        * parseScores.getCount(begin + " " + split, ruleRight.leftChild)
                        * parseScores.getCount(split + " " + end, ruleRight.rightChild);
                if (prob > parseScores.getCount(begin + " " + end, ruleRight.parent)) {
                  parseScores.setCount(begin + " " + end, ruleRight.getParent(), prob);
                  backHash.put(
                      begin + " " + end + " " + ruleRight.parent,
                      new Triplet<Integer, String, String>(
                          split, ruleRight.leftChild, ruleRight.rightChild));
                }
              }
            }


          }
          // Now handle unary rules
          added = true;
          while (added) {
            added = false;
            Counter<String> count = parseScores.getCounter(begin + " " + end);
            PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue();
            while (countAsPriorityQueue.hasNext()) {
              String entry = countAsPriorityQueue.next();
              List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
              for (UnaryRule rule : unaryRules) {
                double prob = rule.getScore() * parseScores.getCount(begin + " " + (end), entry);
                if (prob > parseScores.getCount(begin + " " + (end), rule.parent)) {
                  parseScores.setCount(begin + " " + (end), rule.parent, prob);

                  backHash.put(
                      begin + " " + (end) + " " + rule.parent,
                      new Triplet<Integer, String, String>(-1, entry, null));
                  added = true;
                }
              }
            }
          }


        }
      }

      // Create and return the parse tree
      String parent = parseScores.getCounter(0 + " " + sentence.size()).argMax();
      if (parent == null) {
        // No symbol covers the full span; dump the top cell for debugging
        System.out.println(parseScores.getCounter(0 + " " + sentence.size()).toString());
      }
      // Prefer ROOT over the argmax, since sentences are meant to have ROOT as
      // their root node.
      parent = "ROOT";
      Tree<String> parseTree = getParseTreeOld(sentence, backHash, 0, sentence.size(), parent);
      return TreeAnnotations.unAnnotateTree(parseTree);
    }
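Example #8 below is a rewrite of this method: it replaces the string-keyed CounterMap (keys like "begin end") with a Counter[][] chart indexed by [begin][end - 1], and keys the back-pointer map on Triplet objects instead of concatenated strings.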
Example #8
    public Tree<String> getBestParse(List<String> sentence) {
      // This implements the CKY algorithm
      int nEntries = sentence.size();

      // hashmap to store back rules
      HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>> backHash =
          new HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>>();

      // more efficient access with arrays, but must cast each time :(
      @SuppressWarnings("unchecked")
      Counter<String>[][] parseScores = (Counter<String>[][]) (new Counter[nEntries][nEntries]);

      for (int i = 0; i < nEntries; i++) {
        for (int j = 0; j < nEntries; j++) {
          parseScores[i][j] = new Counter<String>();
        }
      }

      System.out.println(sentence.toString());
      // First deal with the lexicons
      int index = 0;
      int span = 1; // All spans are 1 at the lexicon level
      for (String word : sentence) {
        for (String tag : lexicon.getAllTags()) {
          double score = lexicon.scoreTagging(word, tag);
          if (score > 0.0) { // This lexicon may generate this word
            // We use a counter map in order to store the scores for this sentence parse.
            parseScores[index][index + span - 1].setCount(tag, score);
          }
        }
        index = index + 1;
      }

      // handle unary rules now

      boolean added = true;

      while (added) {
        added = false;
        for (index = 0; index < sentence.size(); index++) {
          // For each index+ span pair, get the counter.
          Counter<String> count = parseScores[index][index + span - 1];
          PriorityQueue<String> countAsPQ = count.asPriorityQueue();
          while (countAsPQ.hasNext()) {
            String entry = countAsPQ.next();
            List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
            for (UnaryRule rule : unaryRules) {
              // These are the unary rules which might give rise to the above preterminal
              double prob = rule.getScore() * parseScores[index][index + span - 1].getCount(entry);
              if (prob > parseScores[index][index + span - 1].getCount(rule.parent)) {
                parseScores[index][index + span - 1].setCount(rule.parent, prob);
                backHash.put(
                    new Triplet<Integer, Integer, String>(index, index + span, rule.parent),
                    new Triplet<Integer, String, String>(-1, entry, null));
                added = true;
              }
            }
          }
        }
      }

      // Now work with the grammar to produce higher level probabilities
      for (span = 2; span <= sentence.size(); span++) {
        for (int begin = 0; begin <= (sentence.size() - span); begin++) {
          int end = begin + span;
          for (int split = begin + 1; split <= end - 1; split++) {
            Counter<String> countLeft = parseScores[begin][split - 1];
            Counter<String> countRight = parseScores[split][end - 1];
            // Collect the binary rules applicable on each side, deduplicated by hash code
            HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>();
            HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>();

            for (String entry : countLeft.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) {
                if (!leftMap.containsKey(rule.hashCode())) {
                  leftMap.put(rule.hashCode(), rule);
                }
              }
            }

            for (String entry : countRight.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) {
                if (!rightMap.containsKey(rule.hashCode())) {
                  rightMap.put(rule.hashCode(), rule);
                }
              }
            }

            // A binary rule applies here only if it appears in both maps
            for (Integer ruleHash : leftMap.keySet()) {
              if (rightMap.containsKey(ruleHash)) {
                BinaryRule ruleRight = rightMap.get(ruleHash);
                double prob =
                    ruleRight.getScore()
                        * parseScores[begin][split - 1].getCount(ruleRight.leftChild)
                        * parseScores[split][end - 1].getCount(ruleRight.rightChild);
                if (prob > parseScores[begin][end - 1].getCount(ruleRight.parent)) {
                  parseScores[begin][end - 1].setCount(ruleRight.getParent(), prob);
                  backHash.put(
                      new Triplet<Integer, Integer, String>(begin, end, ruleRight.parent),
                      new Triplet<Integer, String, String>(
                          split, ruleRight.leftChild, ruleRight.rightChild));
                }
              }
            }


          }
          // Now handle unary rules
          added = true;
          while (added) {
            added = false;
            Counter<String> count = parseScores[begin][end - 1];
            PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue();
            while (countAsPriorityQueue.hasNext()) {
              String entry = countAsPriorityQueue.next();
              List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
              for (UnaryRule rule : unaryRules) {
                double prob = rule.getScore() * parseScores[begin][end - 1].getCount(entry);
                if (prob > parseScores[begin][end - 1].getCount(rule.parent)) {
                  parseScores[begin][end - 1].setCount(rule.parent, prob);

                  backHash.put(
                      new Triplet<Integer, Integer, String>(begin, end, rule.parent),
                      new Triplet<Integer, String, String>(-1, entry, null));
                  added = true;
                }
              }
            }
          }


        }
      }

      // Create and return the parse tree.
      // One option is the argmax symbol over the full sentence span:
      String parent = parseScores[0][nEntries - 1].argMax();
      // But picking ROOT is preferred, since sentences are meant to have ROOT as
      // their root node.
      parent = "ROOT";
      Tree<String> parseTree = getParseTree(sentence, backHash, 0, sentence.size(), parent);
      return TreeAnnotations.unAnnotateTree(parseTree);
    }
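A minimal usage sketch (hypothetical: it assumes a parser object that exposes getBestParse and was built elsewhere from the Grammar and Lexicon shown above, plus the usual java.util imports; none of that wiring appears in these snippets):

     List<String> sentence = Arrays.asList("The", "cat", "meowed", ".");
     Tree<String> best = parser.getBestParse(sentence); // hypothetical parser instance
     System.out.println(best.toString());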
Example #9
  public HashMap<String, Entity> getAllEntities(String handle) {
    HashMap<String, Entity> allEntities = new HashMap<String, Entity>();
    try {
      BufferedReader br = new BufferedReader(new FileReader("data/" + handle + ".txt"));
      BufferedWriter bw = new BufferedWriter(new FileWriter("data/" + handle + "_entities.txt"));
      BufferedWriter bw1 = new BufferedWriter(new FileWriter("data/" + handle + "_statistics.txt"));

      String line = "";
      Counter<String> nPhraseCounter = new Counter<String>();
      Counter<String> capitalsCounter = new Counter<String>();
      while ((line = br.readLine()) != null) {
        line = line.replaceAll("RT", ""); // strip retweet markers
        TwitterTokenizer tweetTokenizer = new TwitterTokenizer();
        for (String token : tweetTokenizer.tokenize(line)) {
          token = token.trim();
          token =
              token.replaceAll(
                  "( [^a-zA-Z0-9\\.]) | ( [^a-zA-Z0-9\\.] ) | ([^a-zA-Z0-9\\.] )", " ");
          ArrayList<String> nPhrases = new ArrayList<String>();
          HashSet<String> capitalWords = new HashSet<String>();
          try {
            Pattern p = Pattern.compile("^[A-Z]+.*");
            String[] split = token.split("\\s+");
            for (String s : split) {
              if (p.matcher(s).matches() && !stopWords.contains(s.toLowerCase())) {
                capitalWords.add(s.toLowerCase());
                capitalsCounter.incrementCount(s.toLowerCase(), 1.0);
                if (allEntities.containsKey(s.trim())) {
                  Entity e = allEntities.get(s.trim());
                  if (!e.tweets.contains(line)) {
                    e.tweets.add(line);
                    allEntities.put(s.trim(), e);
                  }
                } else {
                  Entity e = new Entity(s.trim());
                  e.tweets.add(line);
                  allEntities.put(s.trim(), e);
                }
              }
            }
          } catch (Exception e) {
            e.printStackTrace();
          }

          bw.write("===============================================\n");
          bw.write(token + "\n");
          System.out.println("token: " + token);
          for (String np : npe.extract(token)) {
            if (!stopWords.contains(np.trim().toLowerCase())) {
              nPhrases.add(np.trim());
              nPhraseCounter.incrementCount(np.trim(), 1.0);
              if (allEntities.containsKey(np.trim())) {
                Entity e = allEntities.get(np.trim());
                if (!e.tweets.contains(line)) {
                  e.tweets.add(line);
                  allEntities.put(np.trim(), e);
                }
              } else {
                Entity e = new Entity(np.trim());
                e.tweets.add(line);
                allEntities.put(np.trim(), e);
              }
            }
          }
          bw.write("===============================================\n");
          bw.write("Noun-Phrases: " + nPhrases.toString() + "\n");
          if (capitalWords.isEmpty()) {
            bw.write("No capitals\n\n");
          } else {
            bw.write("Capitals: " + capitalWords.toString() + "\n\n");
          }
        }

        bw.flush();
      }

      PriorityQueue<String> nPhraseQueue = nPhraseCounter.asPriorityQueue();
      PriorityQueue<String> capitalQueue = capitalsCounter.asPriorityQueue();
      while (nPhraseQueue.hasNext()) {
        String np = nPhraseQueue.next();
        bw1.write(np + " " + nPhraseCounter.getCount(np) + "\n");
      }
      bw1.write("=========================================================\n");
      while (capitalQueue.hasNext()) {
        String cap = capitalQueue.next();
        bw1.write(cap + " " + capitalsCounter.getCount(cap) + "\n");
      }
      bw1.flush();
      br.close();
      bw.close();
      bw1.close();
    } catch (Exception e) {
      e.printStackTrace();
    }
    return allEntities;
  }
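This method reads tweets from data/<handle>.txt, collects capitalized non-stop-words and extracted noun phrases (npe and stopWords are fields of the enclosing class, not shown here) into Entity objects keyed by surface string, and writes per-token annotations and frequency statistics to data/<handle>_entities.txt and data/<handle>_statistics.txt.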
Example #10
  public void train(List<SentencePair> trainingPairs) {
    sourceTargetCounts = new CounterMap<String, String>();
    sourceTargetDistortions = new CounterMap<Pair<Integer, Integer>, Pair<Integer, Integer>>();
    for (SentencePair pair : trainingPairs) {
      List<String> sourceSentence = pair.getSourceWords();
      List<String> targetSentence = pair.getTargetWords();
      targetSentence.add(WordAligner.NULL_WORD);
      int m = sourceSentence.size();
      int l = targetSentence.size();
      for (int i = 0; i < m; i++) {
        String sourceWord = sourceSentence.get(i);
        for (int j = 0; j < l; j++) {
          String targetWord = targetSentence.get(j);
          sourceTargetCounts.setCount(sourceWord, targetWord, 1.0);
          Pair<Integer, Integer> lmPair = new Pair<Integer, Integer>(l, m);
          Pair<Integer, Integer> jiPair = new Pair<Integer, Integer>(j, i);
          sourceTargetDistortions.setCount(jiPair, lmPair, 1.0);
        }
      }
    }
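    // At this point t(s|t) is 1.0 for every co-occurring (source, target) pair and
    // q(j|i,l,m) is 1.0 for every observed position tuple: a uniform, unnormalized
    // starting point that the EM iterations below renormalize.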

    // Use Model 1 to train params
    double delta = Double.POSITIVE_INFINITY;
    for (int i = 0; i < MAX_ITERS && delta > CONVERGENCE; i++) {
      CounterMap<String, String> tempSourceTargetCounts = new CounterMap<String, String>();
      Counter<String> targetCounts = new Counter<String>();
      delta = 0.0;
      for (SentencePair pair : trainingPairs) {
        List<String> sourceSentence = pair.getSourceWords();
        List<String> targetSentence = pair.getTargetWords();
        Counter<String> sourceTotals = new Counter<String>();

        for (String sourceWord : sourceSentence) {
          for (String targetWord : targetSentence) {
            sourceTotals.incrementCount(
                sourceWord, sourceTargetCounts.getCount(sourceWord, targetWord));
          }
        }
        for (String sourceWord : sourceSentence) {
          for (String targetWord : targetSentence) {
            double transProb = sourceTargetCounts.getCount(sourceWord, targetWord);
            double sourceTotal = sourceTotals.getCount(sourceWord);
            tempSourceTargetCounts.incrementCount(sourceWord, targetWord, transProb / sourceTotal);
            targetCounts.incrementCount(targetWord, transProb / sourceTotal);
          }
        }
      }

      // update t(s|t) values: t(s|t) = c(s,t) / c(t) from the expected counts above
      for (String sourceWord : tempSourceTargetCounts.keySet()) {
        for (String targetWord : tempSourceTargetCounts.getCounter(sourceWord).keySet()) {
          double oldProb = sourceTargetCounts.getCount(sourceWord, targetWord);
          double newProb =
              tempSourceTargetCounts.getCount(sourceWord, targetWord)
                  / targetCounts.getCount(targetWord);
          sourceTargetCounts.setCount(sourceWord, targetWord, newProb);
          delta += Math.pow(oldProb - newProb, 2.0);
        }
      }
      delta /= sourceTargetCounts.totalSize();
    }
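    // The converged Model 1 translation probabilities remain in sourceTargetCounts
    // and seed the Model 2 EM below.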

    // EM for IBM Model 2: re-estimate translation t(s|t) and distortion q(j|i,l,m) jointly

    delta = Double.POSITIVE_INFINITY;
    for (int iter = 0; iter < MAX_ITERS && delta > CONVERGENCE; iter++) {
      CounterMap<String, String> tempSourceTargetCounts = new CounterMap<String, String>();
      CounterMap<Pair<Integer, Integer>, Pair<Integer, Integer>> tempSourceTargetDistortions =
          new CounterMap<Pair<Integer, Integer>, Pair<Integer, Integer>>();
      Counter<String> targetCounts = new Counter<String>();
      CounterMap<Pair<Integer, Integer>, Integer> targetDistorts =
          new CounterMap<Pair<Integer, Integer>, Integer>();
      delta = 0.0;
      for (SentencePair pair : trainingPairs) {
        List<String> sourceSentence = pair.getSourceWords();
        List<String> targetSentence = pair.getTargetWords();
        CounterMap<Pair<Integer, Integer>, Integer> distortSourceTotals =
            new CounterMap<Pair<Integer, Integer>, Integer>();
        Pair<Integer, Integer> lmPair =
            new Pair<Integer, Integer>(targetSentence.size(), sourceSentence.size());
        for (int i = 0; i < sourceSentence.size(); i++) {
          String sourceWord = sourceSentence.get(i);
          for (int j = 0; j < targetSentence.size(); j++) {
            String targetWord = targetSentence.get(j);
            Pair<Integer, Integer> jiPair = new Pair<Integer, Integer>(j, i);
            double currTransProb = sourceTargetCounts.getCount(sourceWord, targetWord);
            double currAlignProb = sourceTargetDistortions.getCount(jiPair, lmPair);
            distortSourceTotals.incrementCount(lmPair, i, currTransProb * currAlignProb);
          }
        }
        for (int i = 0; i < sourceSentence.size(); i++) {
          String sourceWord = sourceSentence.get(i);
          double distortTransSourceTotal = distortSourceTotals.getCount(lmPair, i);
          for (int j = 0; j < targetSentence.size(); j++) {
            String targetWord = targetSentence.get(j);
            Pair<Integer, Integer> jiPair = new Pair<Integer, Integer>(j, i);
            double transProb = sourceTargetCounts.getCount(sourceWord, targetWord);
            double distortProb = sourceTargetDistortions.getCount(jiPair, lmPair);
            double update =
                (transProb * distortProb) / (distortTransSourceTotal); // q(j|ilm)t(f|e)/totals
            tempSourceTargetCounts.incrementCount(sourceWord, targetWord, update);
            tempSourceTargetDistortions.incrementCount(jiPair, lmPair, update);
            targetCounts.incrementCount(targetWord, update);
            targetDistorts.incrementCount(lmPair, i, update);
          }
        }
      }
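      // M-step: renormalize the expected counts into probabilities,
      // t(s|t) = c(s,t) / c(t) and q(j|i,l,m) = c(j,i,l,m) / c(i,l,m)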
      // update t(s|t) values
      double delta_trans = 0.0;
      for (String sourceWord : tempSourceTargetCounts.keySet()) {
        for (String targetWord : tempSourceTargetCounts.getCounter(sourceWord).keySet()) {
          double oldProb = sourceTargetCounts.getCount(sourceWord, targetWord);
          double newProb =
              tempSourceTargetCounts.getCount(sourceWord, targetWord)
                  / targetCounts.getCount(targetWord);
          sourceTargetCounts.setCount(sourceWord, targetWord, newProb);
          delta_trans += Math.pow(oldProb - newProb, 2.0);
        }
      }
      // update q(j|ilm) values
      double delta_dist = 0.0;
      for (Pair<Integer, Integer> jiPair : tempSourceTargetDistortions.keySet()) {
        for (Pair<Integer, Integer> lmPair :
            tempSourceTargetDistortions.getCounter(jiPair).keySet()) {
          double oldProb = sourceTargetDistortions.getCount(jiPair, lmPair);
          double tempAlignProb = tempSourceTargetDistortions.getCount(jiPair, lmPair);
          double tempTargetDist = targetDistorts.getCount(lmPair, jiPair.getSecond());
          double newProb = tempAlignProb / tempTargetDist;
          sourceTargetDistortions.setCount(jiPair, lmPair, newProb);
          delta_dist += Math.pow(oldProb - newProb, 2.0);
        }
      }
      delta =
          (delta_trans / sourceTargetCounts.totalSize()
                  + delta_dist / sourceTargetDistortions.totalSize())
              / 2.0;
    }
  }