/* Returns a smoothed estimate of P(word|tag), computed by Bayes' rule
 * as P(tag|word) * P(word) / P(tag). */
 public double scoreTagging(String word, String tag) {
   double p_tag = tagCounter.getCount(tag) / totalTokens;
   double c_word = wordCounter.getCount(word);
   double c_tag_and_word = wordToTagCounters.getCount(word, tag);
   if (c_word < 10) { // rare or unknown word: back off to the unknown-word model
     c_word += 1.0;
     // estimate P(tag | rare word) from how often the tag labels a word type
     c_tag_and_word += typeTagCounter.getCount(tag) / totalWordTypes;
   }
   double p_word = (1.0 + c_word) / (totalTokens + totalWordTypes); // add-one smoothing
   double p_tag_given_word = c_tag_and_word / c_word;
   return p_tag_given_word * p_word / p_tag;
 }
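 /* For context, a minimal sketch of how the counters used above might be
  * populated from tagged training data. The TaggedWord type, this method's
  * name, and the exact unknown-word bookkeeping are assumptions, not part
  * of the original code. */
 public void trainCounts(List<TaggedWord> taggedWords) {
   Set<String> seenWords = new HashSet<String>();
   for (TaggedWord tw : taggedWords) {
     tagCounter.incrementCount(tw.tag, 1.0);
     wordCounter.incrementCount(tw.word, 1.0);
     wordToTagCounters.incrementCount(tw.word, tw.tag, 1.0);
     totalTokens += 1.0;
     if (seenWords.add(tw.word)) { // first occurrence of this word type
       typeTagCounter.incrementCount(tw.tag, 1.0);
       totalWordTypes += 1.0;
     }
   }
 }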
 /* Builds a PCFG from the observed counts of unary and binary
  * productions in the training trees, estimating each rule's
  * probability by relative frequency. */
 public Grammar(List<Tree<String>> trainTrees) {
   Counter<UnaryRule> unaryRuleCounter = new Counter<UnaryRule>();
   Counter<BinaryRule> binaryRuleCounter = new Counter<BinaryRule>();
   Counter<String> symbolCounter = new Counter<String>();
   for (Tree<String> trainTree : trainTrees) {
     tallyTree(trainTree, symbolCounter, unaryRuleCounter, binaryRuleCounter);
   }
    // MLE estimates: P(rule) = count(rule) / count(parent symbol).
    for (UnaryRule unaryRule : unaryRuleCounter.keySet()) {
     double unaryProbability =
         unaryRuleCounter.getCount(unaryRule) / symbolCounter.getCount(unaryRule.getParent());
     unaryRule.setScore(unaryProbability);
     addUnary(unaryRule);
   }
   for (BinaryRule binaryRule : binaryRuleCounter.keySet()) {
     double binaryProbability =
         binaryRuleCounter.getCount(binaryRule) / symbolCounter.getCount(binaryRule.getParent());
     binaryRule.setScore(binaryProbability);
     addBinary(binaryRule);
   }
 }
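  /* A plausible sketch of the tallyTree helper referenced above: recursively
   * count each nonterminal symbol and each unary/binary production. This
   * assumes preterminal tag-to-word emissions are scored by the lexicon
   * rather than the grammar; the actual helper may differ. */
  private void tallyTree(
      Tree<String> tree,
      Counter<String> symbolCounter,
      Counter<UnaryRule> unaryRuleCounter,
      Counter<BinaryRule> binaryRuleCounter) {
    if (tree.isLeaf() || tree.isPreTerminal()) return;
    List<Tree<String>> children = tree.getChildren();
    symbolCounter.incrementCount(tree.getLabel(), 1.0);
    if (children.size() == 1) {
      unaryRuleCounter.incrementCount(
          new UnaryRule(tree.getLabel(), children.get(0).getLabel()), 1.0);
    } else if (children.size() == 2) {
      binaryRuleCounter.incrementCount(
          new BinaryRule(
              tree.getLabel(), children.get(0).getLabel(), children.get(1).getLabel()), 1.0);
    }
    for (Tree<String> child : children) {
      tallyTree(child, symbolCounter, unaryRuleCounter, binaryRuleCounter);
    }
  }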
   /* Extracts candidate entities (capitalized words and noun phrases) from the
    * tweets in data/<handle>.txt, writing per-token diagnostics to
    * data/<handle>_entities.txt and frequency statistics to
    * data/<handle>_statistics.txt. */
   public HashMap<String, Entity> getAllEntities(String handle) {
    HashMap<String, Entity> allEntities = new HashMap<String, Entity>();
    try {
      BufferedReader br = new BufferedReader(new FileReader("data/" + handle + ".txt"));
      BufferedWriter bw = new BufferedWriter(new FileWriter("data/" + handle + "_entities.txt"));
      BufferedWriter bw1 = new BufferedWriter(new FileWriter("data/" + handle + "_statistics.txt"));

      String line = "";
      Counter<String> nPhraseCounter = new Counter<String>();
      Counter<String> capitalsCounter = new Counter<String>();
       TwitterTokenizer tweetTokenizer = new TwitterTokenizer(); // reuse one tokenizer for all lines
       Pattern capitalized = Pattern.compile("^[A-Z]+.*"); // token begins with a capital letter
       while ((line = br.readLine()) != null) {
         line = line.replaceAll("\\bRT\\b", ""); // strip retweet markers, not "RT" inside words
         for (String token : tweetTokenizer.tokenize(line)) {
           token = token.trim();
           // Replace non-alphanumeric characters (other than '.') adjoining a space.
           token =
               token.replaceAll(
                   "( [^a-zA-Z0-9\\.])|( [^a-zA-Z0-9\\.] )|([^a-zA-Z0-9\\.] )", " ");
          ArrayList<String> nPhrases = new ArrayList<String>();
          HashSet<String> capitalWords = new HashSet<String>();
           try {
             String[] split = token.split("\\s+");
             for (String s : split) {
               if (capitalized.matcher(s).matches() && !stopWords.contains(s.toLowerCase())) {
                capitalWords.add(s.toLowerCase());
                capitalsCounter.incrementCount(s.toLowerCase(), 1.0);
                 String key = s.trim();
                 Entity e = allEntities.get(key);
                 if (e == null) {
                   e = new Entity(key);
                   allEntities.put(key, e);
                 }
                 if (!e.tweets.contains(line)) { // record each tweet at most once per entity
                   e.tweets.add(line);
                 }
              }
            }
          } catch (Exception e) {
            e.printStackTrace();
          }

          bw.write("===============================================\n");
          bw.write(token + "\n");
          System.out.println("token: " + token);
          for (String np : npe.extract(token)) {
             String key = np.trim();
             if (!stopWords.contains(key.toLowerCase())) {
               nPhrases.add(key);
               nPhraseCounter.incrementCount(key, 1.0);
               Entity e = allEntities.get(key);
               if (e == null) {
                 e = new Entity(key);
                 allEntities.put(key, e);
               }
               if (!e.tweets.contains(line)) { // record each tweet at most once per entity
                 e.tweets.add(line);
               }
             }
          }
          bw.write("===============================================\n");
          bw.write("Noun-Phrases: " + nPhrases.toString() + "\n");
           if (capitalWords.isEmpty()) {
             bw.write("No capitals\n\n");
           } else {
             bw.write("Capitals: " + capitalWords.toString() + "\n\n");
           }
        }

         bw.flush();
       }

       // asPriorityQueue() yields keys in order of decreasing count.
       PriorityQueue<String> nPhraseQueue = nPhraseCounter.asPriorityQueue();
       PriorityQueue<String> capitalQueue = capitalsCounter.asPriorityQueue();
      while (nPhraseQueue.hasNext()) {
        String np = nPhraseQueue.next();
        bw1.write(np + " " + nPhraseCounter.getCount(np) + "\n");
      }
      bw1.write("=========================================================\n");
      while (capitalQueue.hasNext()) {
        String cap = capitalQueue.next();
        bw1.write(cap + " " + capitalsCounter.getCount(cap) + "\n");
      }
       bw1.flush();
       br.close();
       bw.close();
       bw1.close();
    } catch (Exception e) {
      e.printStackTrace();
    }
    return allEntities;
  }
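   /* Usage sketch; the enclosing class name and the handle value here are
    * hypothetical, not part of the original code:
    *
    *   EntityExtractor extractor = new EntityExtractor();
    *   HashMap<String, Entity> entities = extractor.getAllEntities("someHandle");
    *   System.out.println(entities.size() + " candidate entities found");
    */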
  /* Trains translation probabilities t(s|t) with IBM Model 1 EM, then uses
   * them to initialize EM over both the translation table and the
   * distortion table q(j|i,l,m) for IBM Model 2. */
  public void train(List<SentencePair> trainingPairs) {
    sourceTargetCounts = new CounterMap<String, String>();
    sourceTargetDistortions = new CounterMap<Pair<Integer, Integer>, Pair<Integer, Integer>>();
    // Initialize every co-occurring pair and distortion entry uniformly to 1.0.
    for (SentencePair pair : trainingPairs) {
      List<String> sourceSentence = pair.getSourceWords();
      List<String> targetSentence = pair.getTargetWords();
      targetSentence.add(WordAligner.NULL_WORD); // let source words align to NULL
      int m = sourceSentence.size();
      int l = targetSentence.size();
      for (int i = 0; i < m; i++) {
        String sourceWord = sourceSentence.get(i);
        for (int j = 0; j < l; j++) {
          String targetWord = targetSentence.get(j);
          sourceTargetCounts.setCount(sourceWord, targetWord, 1.0);
          Pair<Integer, Integer> lmPair = new Pair<Integer, Integer>(l, m);
          Pair<Integer, Integer> jiPair = new Pair<Integer, Integer>(j, i);
          sourceTargetDistortions.setCount(jiPair, lmPair, 1.0);
        }
      }
    }
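    // A sketch of the Model 1 EM math (standard formulation, stated here for
    // reference): the E-step computes alignment posteriors
    //   P(a_i = j | s, t) = t(s_i|t_j) / sum_{j'} t(s_i|t_j'),
    // and the M-step renormalizes the resulting expected counts per target
    // word to obtain new t(s|t) values.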

    // EM loop for Model 1: repeat until the parameters converge or MAX_ITERS is reached.
    double delta = Double.POSITIVE_INFINITY;
    for (int i = 0; i < MAX_ITERS && delta > CONVERGENCE; i++) {
      CounterMap<String, String> tempSourceTargetCounts = new CounterMap<String, String>();
      Counter<String> targetCounts = new Counter<String>();
      delta = 0.0;
      for (SentencePair pair : trainingPairs) {
        List<String> sourceSentence = pair.getSourceWords();
        List<String> targetSentence = pair.getTargetWords();
        Counter<String> sourceTotals = new Counter<String>(); // E-step normalizers per source word

        for (String sourceWord : sourceSentence) {
          for (String targetWord : targetSentence) {
            sourceTotals.incrementCount(
                sourceWord, sourceTargetCounts.getCount(sourceWord, targetWord));
          }
        }
        for (String sourceWord : sourceSentence) {
          for (String targetWord : targetSentence) {
            double transProb = sourceTargetCounts.getCount(sourceWord, targetWord);
            double sourceTotal = sourceTotals.getCount(sourceWord);
            tempSourceTargetCounts.incrementCount(sourceWord, targetWord, transProb / sourceTotal);
            targetCounts.incrementCount(targetWord, transProb / sourceTotal);
          }
        }
      }

      // update t(s|t) values
      for (String sourceWord : tempSourceTargetCounts.keySet()) {
        for (String targetWord : tempSourceTargetCounts.getCounter(sourceWord).keySet()) {
          double oldProb = sourceTargetCounts.getCount(sourceWord, targetWord);
          double newProb =
              tempSourceTargetCounts.getCount(sourceWord, targetWord)
                  / targetCounts.getCount(targetWord);
          sourceTargetCounts.setCount(sourceWord, targetWord, newProb);
          delta += Math.pow(oldProb - newProb, 2.0);
        }
      }
      delta /= sourceTargetCounts.totalSize(); // mean squared change in t(s|t)
    }
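    // A sketch of the Model 2 EM math (standard formulation): the alignment
    // posterior becomes
    //   P(a_i = j | s, t) = q(j|i,l,m) t(s_i|t_j) / sum_{j'} q(j'|i,l,m) t(s_i|t_j'),
    // with l the target length and m the source length; expected counts are
    // renormalized per target word for t and per (i,l,m) for q.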

    // EM loop for IBM Model 2, starting from the Model 1 translation table.

    delta = Double.POSITIVE_INFINITY;
    for (int iter = 0; iter < MAX_ITERS && delta > CONVERGENCE; iter++) {
      CounterMap<String, String> tempSourceTargetCounts = new CounterMap<String, String>();
      CounterMap<Pair<Integer, Integer>, Pair<Integer, Integer>> tempSourceTargetDistortions =
          new CounterMap<Pair<Integer, Integer>, Pair<Integer, Integer>>();
      Counter<String> targetCounts = new Counter<String>();
      CounterMap<Pair<Integer, Integer>, Integer> targetDistorts =
          new CounterMap<Pair<Integer, Integer>, Integer>();
      delta = 0.0;
      for (SentencePair pair : trainingPairs) {
        List<String> sourceSentence = pair.getSourceWords();
        List<String> targetSentence = pair.getTargetWords();
        CounterMap<Pair<Integer, Integer>, Integer> distortSourceTotals =
            new CounterMap<Pair<Integer, Integer>, Integer>();
        Pair<Integer, Integer> lmPair =
            new Pair<Integer, Integer>(targetSentence.size(), sourceSentence.size());
        for (int i = 0; i < sourceSentence.size(); i++) {
          String sourceWord = sourceSentence.get(i);
          for (int j = 0; j < targetSentence.size(); j++) {
            String targetWord = targetSentence.get(j);
            Pair<Integer, Integer> jiPair = new Pair<Integer, Integer>(j, i);
            double currTransProb = sourceTargetCounts.getCount(sourceWord, targetWord);
            double currAlignProb = sourceTargetDistortions.getCount(jiPair, lmPair);
            distortSourceTotals.incrementCount(lmPair, i, currTransProb * currAlignProb);
          }
        }
        for (int i = 0; i < sourceSentence.size(); i++) {
          String sourceWord = sourceSentence.get(i);
          double distortTransSourceTotal = distortSourceTotals.getCount(lmPair, i);
          for (int j = 0; j < targetSentence.size(); j++) {
            String targetWord = targetSentence.get(j);
            Pair<Integer, Integer> jiPair = new Pair<Integer, Integer>(j, i);
            double transProb = sourceTargetCounts.getCount(sourceWord, targetWord);
            double distortProb = sourceTargetDistortions.getCount(jiPair, lmPair);
            double update =
                (transProb * distortProb) / distortTransSourceTotal; // posterior P(a_i = j | s, t)
            tempSourceTargetCounts.incrementCount(sourceWord, targetWord, update);
            tempSourceTargetDistortions.incrementCount(jiPair, lmPair, update);
            targetCounts.incrementCount(targetWord, update);
            targetDistorts.incrementCount(lmPair, i, update);
          }
        }
      }
      // update t(s|t) values
      double delta_trans = 0.0;
      for (String sourceWord : tempSourceTargetCounts.keySet()) {
        for (String targetWord : tempSourceTargetCounts.getCounter(sourceWord).keySet()) {
          double oldProb = sourceTargetCounts.getCount(sourceWord, targetWord);
          double newProb =
              tempSourceTargetCounts.getCount(sourceWord, targetWord)
                  / targetCounts.getCount(targetWord);
          sourceTargetCounts.setCount(sourceWord, targetWord, newProb);
          delta_trans += Math.pow(oldProb - newProb, 2.0);
        }
      }
      // update q(j|ilm) values
      double delta_dist = 0.0;
      for (Pair<Integer, Integer> jiPair : tempSourceTargetDistortions.keySet()) {
        for (Pair<Integer, Integer> lmPair :
            tempSourceTargetDistortions.getCounter(jiPair).keySet()) {
          double oldProb = sourceTargetDistortions.getCount(jiPair, lmPair);
          double tempAlignProb = tempSourceTargetDistortions.getCount(jiPair, lmPair);
          double tempTargetDist = targetDistorts.getCount(lmPair, jiPair.getSecond());
          double newProb = tempAlignProb / tempTargetDist;
          sourceTargetDistortions.setCount(jiPair, lmPair, newProb);
          delta_dist += Math.pow(oldProb - newProb, 2.0);
        }
      }
      delta =
          (delta_trans / sourceTargetCounts.totalSize()
                  + delta_dist / sourceTargetDistortions.totalSize())
              / 2.0;
    }
  }
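  /* A minimal decoding sketch under the Model 2 parameterization trained
   * above: align source position i to the target position j maximizing
   * q(j|i,l,m) * t(s|t). The method name is an assumption, and targetSentence
   * is assumed to already include WordAligner.NULL_WORD, as in train(); the
   * original aligner's decoding step is not shown here. */
  public int alignSourceWord(List<String> sourceSentence, List<String> targetSentence, int i) {
    Pair<Integer, Integer> lmPair =
        new Pair<Integer, Integer>(targetSentence.size(), sourceSentence.size());
    int bestJ = 0;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (int j = 0; j < targetSentence.size(); j++) {
      Pair<Integer, Integer> jiPair = new Pair<Integer, Integer>(j, i);
      double score =
          sourceTargetCounts.getCount(sourceSentence.get(i), targetSentence.get(j))
              * sourceTargetDistortions.getCount(jiPair, lmPair);
      if (score > bestScore) {
        bestScore = score;
        bestJ = j;
      }
    }
    return bestJ;
  }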