/**
 * Recursively walks {@code tree} and records, for every node that is neither a leaf nor a
 * pre-terminal, one observation of its label and one observation of the rule it expands to.
 *
 * @param tree the (sub)tree to count; nodes must have exactly one or two children
 * @param symbolCounter receives one count per visited node label
 * @param unaryRuleCounter receives one count per single-child expansion
 * @param binaryRuleCounter receives one count per two-child expansion
 * @throws RuntimeException if a counted node has fewer than one or more than two children
 */
private void tallyTree(
    Tree<String> tree,
    Counter<String> symbolCounter,
    Counter<UnaryRule> unaryRuleCounter,
    Counter<BinaryRule> binaryRuleCounter) {
  // Leaves and pre-terminals contribute nothing to these counters.
  if (tree.isLeaf() || tree.isPreTerminal()) {
    return;
  }

  switch (tree.getChildren().size()) {
    case 1: {
      UnaryRule observedUnary = makeUnaryRule(tree);
      symbolCounter.incrementCount(tree.getLabel(), 1.0);
      unaryRuleCounter.incrementCount(observedUnary, 1.0);
      break;
    }
    case 2: {
      BinaryRule observedBinary = makeBinaryRule(tree);
      symbolCounter.incrementCount(tree.getLabel(), 1.0);
      binaryRuleCounter.incrementCount(observedBinary, 1.0);
      break;
    }
    default:
      // Any other arity cannot be expressed as a unary or binary rule.
      throw new RuntimeException(
          "Attempted to construct a Grammar with an illegal tree: " + tree);
  }

  for (Tree<String> subtree : tree.getChildren()) {
    tallyTree(subtree, symbolCounter, unaryRuleCounter, binaryRuleCounter);
  }
}
/**
 * Records one observed (word, tag) token in the lexicon's counters.
 *
 * <p>When {@link #isKnown(String)} reports the word as not yet seen, the word-type total and the
 * per-tag type counter are bumped as well.
 *
 * @param word the observed word
 * @param tag the tag it was observed with
 */
private void tallyTagging(String word, String tag) {
  // Ask before the word is counted below; afterwards it would always be reported as known.
  boolean firstSighting = !isKnown(word);
  if (firstSighting) {
    totalWordTypes += 1.0;
    typeTagCounter.incrementCount(tag, 1.0);
  }

  // Token-level statistics are updated for every observation.
  totalTokens += 1.0;
  tagCounter.incrementCount(tag, 1.0);
  wordCounter.incrementCount(word, 1.0);
  wordToTagCounters.incrementCount(word, tag, 1.0);
}
/* Returns a smoothed estimate of P(word|tag) */ public double scoreTagging(String word, String tag) { double p_tag = tagCounter.getCount(tag) / totalTokens; double c_word = wordCounter.getCount(word); double c_tag_and_word = wordToTagCounters.getCount(word, tag); if (c_word < 10) { // rare or unknown c_word += 1.0; c_tag_and_word += typeTagCounter.getCount(tag) / totalWordTypes; } double p_word = (1.0 + c_word) / (totalTokens + totalWordTypes); double p_tag_given_word = c_tag_and_word / c_word; return p_tag_given_word / p_tag * p_word; }
/**
 * Builds a PCFG, using the observed counts of binary and unary productions in the training trees
 * to estimate the probabilities for those rules.
 *
 * <p>Each rule's score is its own count divided by the count of its parent symbol; the scored
 * rule is then registered through {@code addUnary} / {@code addBinary}.
 *
 * @param trainTrees the trees whose productions are counted
 */
public Grammar(List<Tree<String>> trainTrees) {
  Counter<String> parentCounts = new Counter<String>();
  Counter<UnaryRule> unaryCounts = new Counter<UnaryRule>();
  Counter<BinaryRule> binaryCounts = new Counter<BinaryRule>();

  for (Tree<String> tree : trainTrees) {
    tallyTree(tree, parentCounts, unaryCounts, binaryCounts);
  }

  for (UnaryRule rule : unaryCounts.keySet()) {
    double ruleCount = unaryCounts.getCount(rule);
    double parentCount = parentCounts.getCount(rule.getParent());
    rule.setScore(ruleCount / parentCount);
    addUnary(rule);
  }

  for (BinaryRule rule : binaryCounts.keySet()) {
    double ruleCount = binaryCounts.getCount(rule);
    double parentCount = parentCounts.getCount(rule.getParent());
    rule.setScore(ruleCount / parentCount);
    addBinary(rule);
  }
}
/**
 * Reports whether {@code word} is currently a key of the word counter.
 *
 * @param word the word to look up
 * @return {@code true} if the word counter's key set contains {@code word}
 */
public boolean isKnown(String word) {
  boolean presentInWordCounter = wordCounter.keySet().contains(word);
  return presentInWordCounter;
}
/**
 * Returns every tag that is a key of the tag counter.
 *
 * @return the tag counter's key set, exactly as the counter hands it out
 */
public Set<String> getAllTags() {
  Set<String> knownTags = tagCounter.keySet();
  return knownTags;
}
public Tree<String> getBestParseOld(List<String> sentence) { // TODO: This implements the CKY algorithm CounterMap<String, String> parseScores = new CounterMap<String, String>(); System.out.println(sentence.toString()); // First deal with the lexicons int index = 0; int span = 1; // All spans are 1 at the lexicon level for (String word : sentence) { for (String tag : lexicon.getAllTags()) { double score = lexicon.scoreTagging(word, tag); if (score >= 0.0) { // This lexicon may generate this word // We use a counter map in order to store the scores for this sentence parse. parseScores.setCount(index + " " + (index + span), tag, score); } } index = index + 1; } // handle unary rules now HashMap<String, Triplet<Integer, String, String>> backHash = new HashMap< String, Triplet<Integer, String, String>>(); // hashmap to store back propation // System.out.println("Lexicons found"); Boolean added = true; while (added) { added = false; for (index = 0; index < sentence.size(); index++) { // For each index+ span pair, get the counter. 
Counter<String> count = parseScores.getCounter(index + " " + (index + span)); PriorityQueue<String> countAsPQ = count.asPriorityQueue(); while (countAsPQ.hasNext()) { String entry = countAsPQ.next(); // System.out.println("I am fine here!!"); List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry); for (UnaryRule rule : unaryRules) { // These are the unary rules which might give rise to the above preterminal double prob = rule.getScore() * parseScores.getCount(index + " " + (index + span), entry); if (prob > parseScores.getCount(index + " " + (index + span), rule.parent)) { parseScores.setCount(index + " " + (index + span), rule.parent, prob); backHash.put( index + " " + (index + span) + " " + rule.parent, new Triplet<Integer, String, String>(-1, entry, null)); added = true; } } } } } // System.out.println("Lexicon unaries dealt with"); // Now work with the grammar to produce higher level probabilities for (span = 2; span <= sentence.size(); span++) { for (int begin = 0; begin <= (sentence.size() - span); begin++) { int end = begin + span; for (int split = begin + 1; split <= end - 1; split++) { Counter<String> countLeft = parseScores.getCounter(begin + " " + split); Counter<String> countRight = parseScores.getCounter(split + " " + end); // List<BinaryRule> leftRules= new ArrayList<BinaryRule>(); HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>(); // List<BinaryRule> rightRules=new ArrayList<BinaryRule>(); HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>(); for (String entry : countLeft.keySet()) { for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) { if (!leftMap.containsKey(rule.hashCode())) { leftMap.put(rule.hashCode(), rule); } } } for (String entry : countRight.keySet()) { for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) { if (!rightMap.containsKey(rule.hashCode())) { rightMap.put(rule.hashCode(), rule); } } } // System.out.println("About to enter the rules loops"); for 
(Integer ruleHash : leftMap.keySet()) { if (rightMap.containsKey(ruleHash)) { BinaryRule ruleRight = rightMap.get(ruleHash); double prob = ruleRight.getScore() * parseScores.getCount(begin + " " + split, ruleRight.leftChild) * parseScores.getCount(split + " " + end, ruleRight.rightChild); // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob); if (prob > parseScores.getCount(begin + " " + end, ruleRight.parent)) { // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob); // System.out.println("parentrule :"+ ruleRight.getParent()); parseScores.setCount(begin + " " + end, ruleRight.getParent(), prob); backHash.put( begin + " " + end + " " + ruleRight.parent, new Triplet<Integer, String, String>( split, ruleRight.leftChild, ruleRight.rightChild)); } } } // System.out.println("Exited rules loop"); } // System.out.println("Grammar found for " + begin + " "+ end); // Now handle unary rules added = true; while (added) { added = false; Counter<String> count = parseScores.getCounter(begin + " " + end); PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue(); while (countAsPriorityQueue.hasNext()) { String entry = countAsPriorityQueue.next(); List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry); for (UnaryRule rule : unaryRules) { double prob = rule.getScore() * parseScores.getCount(begin + " " + (end), entry); if (prob > parseScores.getCount(begin + " " + (end), rule.parent)) { parseScores.setCount(begin + " " + (end), rule.parent, prob); backHash.put( begin + " " + (end) + " " + rule.parent, new Triplet<Integer, String, String>(-1, entry, null)); added = true; } } } } // System.out.println("Unaries dealt for " + begin + " "+ end); } } // Create and return the parse tree Tree<String> parseTree = new Tree<String>("null"); // System.out.println(parseScores.getCounter(0+" "+sentence.size()).toString()); String parent = parseScores.getCounter(0 + " " + sentence.size()).argMax(); if (parent == null) { 
System.out.println(parseScores.getCounter(0 + " " + sentence.size()).toString()); System.out.println("THIS IS WEIRD"); } parent = "ROOT"; parseTree = getParseTreeOld(sentence, backHash, 0, sentence.size(), parent); // System.out.println("PARSE SCORES"); // System.out.println(parseScores.toString()); // System.out.println("BACK HASH"); // System.out.println(backHash.toString()); // parseTree = addRoot(parseTree); // System.out.println(parseTree.toString()); // return parseTree; return TreeAnnotations.unAnnotateTree(parseTree); }
public Tree<String> getBestParse(List<String> sentence) { // This implements the CKY algorithm int nEntries = sentence.size(); // hashmap to store back rules HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>> backHash = new HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>>(); // more efficient access with arrays, but must cast each time :( @SuppressWarnings("unchecked") Counter<String>[][] parseScores = (Counter<String>[][]) (new Counter[nEntries][nEntries]); for (int i = 0; i < nEntries; i++) { for (int j = 0; j < nEntries; j++) { parseScores[i][j] = new Counter<String>(); } } System.out.println(sentence.toString()); // First deal with the lexicons int index = 0; int span = 1; // All spans are 1 at the lexicon level for (String word : sentence) { for (String tag : lexicon.getAllTags()) { double score = lexicon.scoreTagging(word, tag); if (score >= 0.0) { // This lexicon may generate this word // We use a counter map in order to store the scores for this sentence parse. parseScores[index][index + span - 1].setCount(tag, score); } } index = index + 1; } // handle unary rules now // System.out.println("Lexicons found"); boolean added = true; while (added) { added = false; for (index = 0; index < sentence.size(); index++) { // For each index+ span pair, get the counter. 
Counter<String> count = parseScores[index][index + span - 1]; PriorityQueue<String> countAsPQ = count.asPriorityQueue(); while (countAsPQ.hasNext()) { String entry = countAsPQ.next(); // System.out.println("I am fine here!!"); List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry); for (UnaryRule rule : unaryRules) { // These are the unary rules which might give rise to the above preterminal double prob = rule.getScore() * parseScores[index][index + span - 1].getCount(entry); if (prob > parseScores[index][index + span - 1].getCount(rule.parent)) { parseScores[index][index + span - 1].setCount(rule.parent, prob); backHash.put( new Triplet<Integer, Integer, String>(index, index + span, rule.parent), new Triplet<Integer, String, String>(-1, entry, null)); added = true; } } } } } // System.out.println("Lexicon unaries dealt with"); // Now work with the grammar to produce higher level probabilities for (span = 2; span <= sentence.size(); span++) { for (int begin = 0; begin <= (sentence.size() - span); begin++) { int end = begin + span; for (int split = begin + 1; split <= end - 1; split++) { Counter<String> countLeft = parseScores[begin][split - 1]; Counter<String> countRight = parseScores[split][end - 1]; // List<BinaryRule> leftRules= new ArrayList<BinaryRule>(); HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>(); // List<BinaryRule> rightRules=new ArrayList<BinaryRule>(); HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>(); for (String entry : countLeft.keySet()) { for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) { if (!leftMap.containsKey(rule.hashCode())) { leftMap.put(rule.hashCode(), rule); } } } for (String entry : countRight.keySet()) { for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) { if (!rightMap.containsKey(rule.hashCode())) { rightMap.put(rule.hashCode(), rule); } } } // System.out.println("About to enter the rules loops"); for (Integer ruleHash : leftMap.keySet()) { 
if (rightMap.containsKey(ruleHash)) { BinaryRule ruleRight = rightMap.get(ruleHash); double prob = ruleRight.getScore() * parseScores[begin][split - 1].getCount(ruleRight.leftChild) * parseScores[split][end - 1].getCount(ruleRight.rightChild); // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob); if (prob > parseScores[begin][end - 1].getCount(ruleRight.parent)) { // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob); // System.out.println("parentrule :"+ ruleRight.getParent()); parseScores[begin][end - 1].setCount(ruleRight.getParent(), prob); backHash.put( new Triplet<Integer, Integer, String>(begin, end, ruleRight.parent), new Triplet<Integer, String, String>( split, ruleRight.leftChild, ruleRight.rightChild)); } } } // System.out.println("Exited rules loop"); } // System.out.println("Grammar found for " + begin + " "+ end); // Now handle unary rules added = true; while (added) { added = false; Counter<String> count = parseScores[begin][end - 1]; PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue(); while (countAsPriorityQueue.hasNext()) { String entry = countAsPriorityQueue.next(); List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry); for (UnaryRule rule : unaryRules) { double prob = rule.getScore() * parseScores[begin][end - 1].getCount(entry); if (prob > parseScores[begin][end - 1].getCount(rule.parent)) { parseScores[begin][end - 1].setCount(rule.parent, prob); backHash.put( new Triplet<Integer, Integer, String>(begin, end, rule.parent), new Triplet<Integer, String, String>(-1, entry, null)); added = true; } } } } // System.out.println("Unaries dealt for " + begin + " "+ end); } } // Create and return the parse tree Tree<String> parseTree = new Tree<String>("null"); // System.out.println(parseScores.getCounter(0+" "+sentence.size()).toString()); // Pick the argmax String parent = parseScores[0][nEntries - 1].argMax(); // Or pick root. 
This second one is preferred since sentences are meant to have ROOT as their // root node. parent = "ROOT"; parseTree = getParseTree(sentence, backHash, 0, sentence.size(), parent); // System.out.println("PARSE SCORES"); // System.out.println(parseScores.toString()); // System.out.println("BACK HASH"); // System.out.println(backHash.toString()); // parseTree = addRoot(parseTree); // System.out.println(parseTree.toString()); // return parseTree; return TreeAnnotations.unAnnotateTree(parseTree); }
public HashMap<String, Entity> getAllEntities(String handle) { HashMap<String, Entity> allEntities = new HashMap<String, Entity>(); try { BufferedReader br = new BufferedReader(new FileReader("data/" + handle + ".txt")); BufferedWriter bw = new BufferedWriter(new FileWriter("data/" + handle + "_entities.txt")); BufferedWriter bw1 = new BufferedWriter(new FileWriter("data/" + handle + "_statistics.txt")); String line = ""; Counter<String> nPhraseCounter = new Counter<String>(); Counter<String> capitalsCounter = new Counter<String>(); while ((line = br.readLine()) != null) { line = line.replaceAll("RT", ""); TwitterTokenizer tweetTokenizer = new TwitterTokenizer(); for (String token : tweetTokenizer.tokenize(line)) { token = token.trim(); token = token.replaceAll( "( [^a-zA-Z0-9\\.]) | ( [^a-zA-Z0-9\\.] ) | ([^a-zA-Z0-9\\.] )", " "); ArrayList<String> nPhrases = new ArrayList<String>(); HashSet<String> capitalWords = new HashSet<String>(); try { Pattern p = Pattern.compile("^[A-Z]+.*"); String[] split = token.split("\\s+"); for (String s : split) { if (p.matcher(s).matches() && !stopWords.contains(s.toLowerCase())) { capitalWords.add(s.toLowerCase()); capitalsCounter.incrementCount(s.toLowerCase(), 1.0); if (allEntities.containsKey(s.trim())) { Entity e = allEntities.get(s.trim()); if (!e.tweets.contains(line)) { e.tweets.add(line); allEntities.put(s.trim(), e); } } else { Entity e = new Entity(s.trim()); e.tweets.add(line); allEntities.put(s.trim(), e); } } } } catch (Exception e) { e.printStackTrace(); } bw.write("===============================================\n"); bw.write(token + "\n"); System.out.println("token: " + token); for (String np : npe.extract(token)) { if (!stopWords.contains(np.trim().toLowerCase())) { nPhrases.add(np.trim()); nPhraseCounter.incrementCount(np.trim(), 1.0); if (allEntities.containsKey(np.trim())) { Entity e = allEntities.get(np.trim()); if (!e.tweets.contains(line)) { e.tweets.add(line); allEntities.put(np.trim(), e); } } else { 
Entity e = new Entity(np.trim()); e.tweets.add(line); allEntities.put(np.trim(), e); } } } bw.write("===============================================\n"); bw.write("Noun-Phrases: " + nPhrases.toString() + "\n"); // HashSet<String> capitalWords = // getCapitalizedWords(token); if (capitalWords == null) { bw.write("No capitals\n\n"); } else { bw.write("Capitals: " + capitalWords.toString() + "\n\n"); } } bw.flush(); if (true) continue; } PriorityQueue<String> nPhraseQueue = nPhraseCounter.asPriorityQueue(); PriorityQueue<String> capitalQueue = capitalsCounter.asPriorityQueue(); while (nPhraseQueue.hasNext()) { String np = nPhraseQueue.next(); bw1.write(np + " " + nPhraseCounter.getCount(np) + "\n"); } bw1.write("=========================================================\n"); while (capitalQueue.hasNext()) { String cap = capitalQueue.next(); bw1.write(cap + " " + capitalsCounter.getCount(cap) + "\n"); } bw1.flush(); } catch (Exception e) { e.printStackTrace(); } return allEntities; }
public void train(List<SentencePair> trainingPairs) { sourceTargetCounts = new CounterMap<String, String>(); sourceTargetDistortions = new CounterMap<Pair<Integer, Integer>, Pair<Integer, Integer>>(); for (SentencePair pair : trainingPairs) { List<String> sourceSentence = pair.getSourceWords(); List<String> targetSentence = pair.getTargetWords(); targetSentence.add(WordAligner.NULL_WORD); int m = sourceSentence.size(); int l = targetSentence.size(); for (int i = 0; i < m; i++) { String sourceWord = sourceSentence.get(i); for (int j = 0; j < l; j++) { String targetWord = targetSentence.get(j); sourceTargetCounts.setCount(sourceWord, targetWord, 1.0); Pair<Integer, Integer> lmPair = new Pair<Integer, Integer>(l, m); Pair<Integer, Integer> jiPair = new Pair<Integer, Integer>(j, i); sourceTargetDistortions.setCount(jiPair, lmPair, 1.0); } } } // Use Model 1 to train params double delta = Double.POSITIVE_INFINITY; for (int i = 0; i < MAX_ITERS && delta > CONVERGENCE; i++) { CounterMap<String, String> tempSourceTargetCounts = new CounterMap<String, String>(); Counter<String> targetCounts = new Counter<String>(); delta = 0.0; for (SentencePair pair : trainingPairs) { List<String> sourceSentence = pair.getSourceWords(); List<String> targetSentence = pair.getTargetWords(); Counter<String> sourceTotals = new Counter<String>(); for (String sourceWord : sourceSentence) { for (String targetWord : targetSentence) { sourceTotals.incrementCount( sourceWord, sourceTargetCounts.getCount(sourceWord, targetWord)); } } for (String sourceWord : sourceSentence) { for (String targetWord : targetSentence) { double transProb = sourceTargetCounts.getCount(sourceWord, targetWord); double sourceTotal = sourceTotals.getCount(sourceWord); tempSourceTargetCounts.incrementCount(sourceWord, targetWord, transProb / sourceTotal); targetCounts.incrementCount(targetWord, transProb / sourceTotal); } } } // update t(s|t) values for (String sourceWord : tempSourceTargetCounts.keySet()) { for (String 
targetWord : tempSourceTargetCounts.getCounter(sourceWord).keySet()) { double oldProb = sourceTargetCounts.getCount(sourceWord, targetWord); double newProb = tempSourceTargetCounts.getCount(sourceWord, targetWord) / targetCounts.getCount(targetWord); sourceTargetCounts.setCount(sourceWord, targetWord, newProb); delta += Math.pow(oldProb - newProb, 2.0); } } delta /= sourceTargetCounts.totalSize(); } // Maximizing for ibm model 2 delta = Double.POSITIVE_INFINITY; for (int iter = 0; iter < MAX_ITERS && delta > CONVERGENCE; iter++) { CounterMap<String, String> tempSourceTargetCounts = new CounterMap<String, String>(); CounterMap<Pair<Integer, Integer>, Pair<Integer, Integer>> tempSourceTargetDistortions = new CounterMap<Pair<Integer, Integer>, Pair<Integer, Integer>>(); Counter<String> targetCounts = new Counter<String>(); CounterMap<Pair<Integer, Integer>, Integer> targetDistorts = new CounterMap<Pair<Integer, Integer>, Integer>(); delta = 0.0; for (SentencePair pair : trainingPairs) { List<String> sourceSentence = pair.getSourceWords(); List<String> targetSentence = pair.getTargetWords(); CounterMap<Pair<Integer, Integer>, Integer> distortSourceTotals = new CounterMap<Pair<Integer, Integer>, Integer>(); Pair<Integer, Integer> lmPair = new Pair<Integer, Integer>(targetSentence.size(), sourceSentence.size()); for (int i = 0; i < sourceSentence.size(); i++) { String sourceWord = sourceSentence.get(i); for (int j = 0; j < targetSentence.size(); j++) { String targetWord = targetSentence.get(j); Pair<Integer, Integer> jiPair = new Pair<Integer, Integer>(j, i); double currTransProb = sourceTargetCounts.getCount(sourceWord, targetWord); double currAlignProb = sourceTargetDistortions.getCount(jiPair, lmPair); distortSourceTotals.incrementCount(lmPair, i, currTransProb * currAlignProb); } } for (int i = 0; i < sourceSentence.size(); i++) { String sourceWord = sourceSentence.get(i); double distortTransSourceTotal = distortSourceTotals.getCount(lmPair, i); for (int j = 0; j < 
targetSentence.size(); j++) { String targetWord = targetSentence.get(j); Pair<Integer, Integer> jiPair = new Pair<Integer, Integer>(j, i); double transProb = sourceTargetCounts.getCount(sourceWord, targetWord); double distortProb = sourceTargetDistortions.getCount(jiPair, lmPair); double update = (transProb * distortProb) / (distortTransSourceTotal); // q(j|ilm)t(f|e)/totals tempSourceTargetCounts.incrementCount(sourceWord, targetWord, update); tempSourceTargetDistortions.incrementCount(jiPair, lmPair, update); targetCounts.incrementCount(targetWord, update); targetDistorts.incrementCount(lmPair, i, update); } } } // update t(s|t) values double delta_trans = 0.0; for (String sourceWord : tempSourceTargetCounts.keySet()) { for (String targetWord : tempSourceTargetCounts.getCounter(sourceWord).keySet()) { double oldProb = sourceTargetCounts.getCount(sourceWord, targetWord); double newProb = tempSourceTargetCounts.getCount(sourceWord, targetWord) / targetCounts.getCount(targetWord); sourceTargetCounts.setCount(sourceWord, targetWord, newProb); delta += Math.pow(oldProb - newProb, 2.0); } } // update q(j|ilm) values double delta_dist = 0.0; for (Pair<Integer, Integer> jiPair : tempSourceTargetDistortions.keySet()) { for (Pair<Integer, Integer> lmPair : tempSourceTargetDistortions.getCounter(jiPair).keySet()) { double oldProb = sourceTargetDistortions.getCount(jiPair, lmPair); double tempAlignProb = tempSourceTargetDistortions.getCount(jiPair, lmPair); double tempTargetDist = targetDistorts.getCount(lmPair, jiPair.getSecond()); double newProb = tempAlignProb / tempTargetDist; sourceTargetDistortions.setCount(jiPair, lmPair, newProb); delta_dist += Math.pow(oldProb - newProb, 2.0); } } delta = (delta_trans / sourceTargetCounts.totalSize() + delta_dist / sourceTargetDistortions.totalSize()) / 2.0; } }