private String getBestTag(String word) {
  double bestScore = Double.NEGATIVE_INFINITY;
  String bestTag = null;
  for (String tag : lexicon.getAllTags()) {
    double score = lexicon.scoreTagging(word, tag);
    if (bestTag == null || score > bestScore) {
      bestScore = score;
      bestTag = tag;
    }
  }
  return bestTag;
}
/**
 * Performs one E step of EM over the training trees.
 *
 * @param previousGrammar grammar from the previous iteration, used to compute inside/outside scores
 * @param previousLexicon lexicon from the previous iteration, used to compute inside/outside scores
 * @param grammar grammar in which the expected rule counts are accumulated
 * @param lexicon lexicon in which the expected tagging counts are accumulated
 * @param trainStateSetTrees the training trees
 * @param updateOnlyLexicon if true, only the lexicon counts are updated
 * @param unkThreshold rare-word threshold for the unknown-word model
 * @return the training log likelihood under previousGrammar/previousLexicon
 */
public static double doOneEStep(
    Grammar previousGrammar,
    Lexicon previousLexicon,
    Grammar grammar,
    Lexicon lexicon,
    StateSetTreeList trainStateSetTrees,
    boolean updateOnlyLexicon,
    int unkThreshold) {
  boolean secondHalf = false;
  ArrayParser parser = new ArrayParser(previousGrammar, previousLexicon);
  double trainingLikelihood = 0;
  int n = 0;
  int nTrees = trainStateSetTrees.size();
  for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
    secondHalf = (n++ > nTrees / 2.0);
    boolean noSmoothing = true, debugOutput = false;
    parser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput); // E Step
    double ll = stateSetTree.getLabel().getIScore(0);
    ll = Math.log(ll) + (100 * stateSetTree.getLabel().getIScale());
    // System.out.println(stateSetTree);
    if ((Double.isInfinite(ll) || Double.isNaN(ll))) {
      if (VERBOSE) {
        System.out.println("Training sentence " + n + " is given " + ll + " log likelihood!");
        System.out.println(
            "Root iScore " + stateSetTree.getLabel().getIScore(0)
                + " scale " + stateSetTree.getLabel().getIScale());
      }
    } else {
      lexicon.trainTree(stateSetTree, -1, previousLexicon, secondHalf, noSmoothing, unkThreshold);
      if (!updateOnlyLexicon)
        grammar.tallyStateSetTree(stateSetTree, previousGrammar); // E Step
      trainingLikelihood += ll; // some sentences are, for whatever reason, unparsable and are skipped above
    }
  }
  lexicon.tieRareWordStats(unkThreshold);
  // SSIE ((SophisticatedLexicon) lexicon).overwriteWithMaxent();
  return trainingLikelihood;
}
public static void main(String[] args) { OptionParser optParser = new OptionParser(Options.class); Options opts = (Options) optParser.parse(args, true); // provide feedback on command-line arguments System.out.println("Calling with " + optParser.getPassedInOptions()); String path = opts.path; // int lang = opts.lang; System.out.println("Loading trees from " + path + " and using language " + opts.treebank); double trainingFractionToKeep = opts.trainingFractionToKeep; int maxSentenceLength = opts.maxSentenceLength; System.out.println("Will remove sentences with more than " + maxSentenceLength + " words."); HORIZONTAL_MARKOVIZATION = opts.horizontalMarkovization; VERTICAL_MARKOVIZATION = opts.verticalMarkovization; System.out.println( "Using horizontal=" + HORIZONTAL_MARKOVIZATION + " and vertical=" + VERTICAL_MARKOVIZATION + " markovization."); Binarization binarization = opts.binarization; System.out.println( "Using " + binarization.name() + " binarization."); // and "+annotateString+"."); double randomness = opts.randomization; System.out.println("Using a randomness value of " + randomness); String outFileName = opts.outFileName; if (outFileName == null) { System.out.println("Output File name is required."); System.exit(-1); } else System.out.println("Using grammar output file " + outFileName + "."); VERBOSE = opts.verbose; RANDOM = new Random(opts.randSeed); System.out.println("Random number generator seeded at " + opts.randSeed + "."); boolean manualAnnotation = false; boolean baseline = opts.baseline; boolean noSplit = opts.noSplit; int numSplitTimes = opts.numSplits; if (baseline) numSplitTimes = 0; String splitGrammarFile = opts.inFile; int allowedDroppingIters = opts.di; int maxIterations = opts.splitMaxIterations; int minIterations = opts.splitMinIterations; if (minIterations > 0) System.out.println("I will do at least " + minIterations + " iterations."); double[] smoothParams = {opts.smoothingParameter1, opts.smoothingParameter2}; System.out.println("Using smoothing parameters " + smoothParams[0] + " and " + smoothParams[1]); boolean allowMoreSubstatesThanCounts = false; boolean findClosedUnaryPaths = opts.findClosedUnaryPaths; Corpus corpus = new Corpus( path, opts.treebank, trainingFractionToKeep, false, opts.skipSection, opts.skipBilingual); List<Tree<String>> trainTrees = Corpus.binarizeAndFilterTrees( corpus.getTrainTrees(), VERTICAL_MARKOVIZATION, HORIZONTAL_MARKOVIZATION, maxSentenceLength, binarization, manualAnnotation, VERBOSE); List<Tree<String>> validationTrees = Corpus.binarizeAndFilterTrees( corpus.getValidationTrees(), VERTICAL_MARKOVIZATION, HORIZONTAL_MARKOVIZATION, maxSentenceLength, binarization, manualAnnotation, VERBOSE); Numberer tagNumberer = Numberer.getGlobalNumberer("tags"); // for (Tree<String> t : trainTrees){ // System.out.println(t); // } if (opts.trainOnDevSet) { System.out.println("Adding devSet to training data."); trainTrees.addAll(validationTrees); } if (opts.lowercase) { System.out.println("Lowercasing the treebank."); Corpus.lowercaseWords(trainTrees); Corpus.lowercaseWords(validationTrees); } int nTrees = trainTrees.size(); System.out.println("There are " + nTrees + " trees in the training set."); double filter = opts.filter; if (filter > 0) System.out.println( "Will remove rules with prob under " + filter + ".\nEven though only unlikely rules are pruned the training LL is not guaranteed to increase in every round anymore " + "(especially when we are close to converging)." 
+ "\nFurthermore it increases the variance because 'good' rules can be pruned away in early stages."); short nSubstates = opts.nSubStates; short[] numSubStatesArray = initializeSubStateArray(trainTrees, validationTrees, tagNumberer, nSubstates); if (baseline) { short one = 1; Arrays.fill(numSubStatesArray, one); System.out.println("Training just the baseline grammar (1 substate for all states)"); randomness = 0.0f; } if (VERBOSE) { for (int i = 0; i < numSubStatesArray.length; i++) { System.out.println("Tag " + (String) tagNumberer.object(i) + " " + i); } } System.out.println("There are " + numSubStatesArray.length + " observed categories."); // initialize lexicon and grammar Lexicon lexicon = null, maxLexicon = null, previousLexicon = null; Grammar grammar = null, maxGrammar = null, previousGrammar = null; double maxLikelihood = Double.NEGATIVE_INFINITY; // String smootherStr = opts.smooth; // Smoother lexiconSmoother = null; // Smoother grammarSmoother = null; // if (splitGrammarFile!=null){ // lexiconSmoother = maxLexicon.smoother; // grammarSmoother = maxGrammar.smoother; // System.out.println("Using smoother from input grammar."); // } // else if (smootherStr.equals("NoSmoothing")) // lexiconSmoother = grammarSmoother = new NoSmoothing(); // else if (smootherStr.equals("SmoothAcrossParentBits")) { // lexiconSmoother = grammarSmoother = new SmoothAcrossParentBits(grammarSmoothing, // maxGrammar.splitTrees); // } // else // throw new Error("I didn't understand the type of smoother '"+smootherStr+"'"); // System.out.println("Using smoother "+smootherStr); // EM: iterate until the validation likelihood drops for four consecutive // iterations int iter = 0; int droppingIter = 0; // If we are splitting, we load the old grammar and start off by splitting. int startSplit = 0; if (splitGrammarFile != null) { System.out.println("Loading old grammar from " + splitGrammarFile); startSplit = 1; // we've already trained the grammar ParserData pData = ParserData.Load(splitGrammarFile); maxGrammar = pData.gr; maxLexicon = pData.lex; numSubStatesArray = maxGrammar.numSubStates; previousGrammar = grammar = maxGrammar; previousLexicon = lexicon = maxLexicon; Numberer.setNumberers(pData.getNumbs()); tagNumberer = Numberer.getGlobalNumberer("tags"); System.out.println("Loading old grammar complete."); if (noSplit) { System.out.println("Will NOT split the loaded grammar."); startSplit = 0; } } double mergingPercentage = opts.mergingPercentage; boolean separateMergingThreshold = opts.separateMergingThreshold; if (mergingPercentage > 0) { System.out.println( "Will merge " + (int) (mergingPercentage * 100) + "% of the splits in each round."); System.out.println( "The threshold for merging lexical and phrasal categories will be set separately: " + separateMergingThreshold); } StateSetTreeList trainStateSetTrees = new StateSetTreeList(trainTrees, numSubStatesArray, false, tagNumberer); StateSetTreeList validationStateSetTrees = new StateSetTreeList(validationTrees, numSubStatesArray, false, tagNumberer); // deletePC); // get rid of the old trees trainTrees = null; validationTrees = null; corpus = null; System.gc(); if (opts.simpleLexicon) { System.out.println( "Replacing words which have been seen less than 5 times with their signature."); Corpus.replaceRareWords( trainStateSetTrees, new SimpleLexicon(numSubStatesArray, -1), opts.rare); } // If we're training without loading a split grammar, then we run once without splitting. 
if (splitGrammarFile == null) { grammar = new Grammar(numSubStatesArray, findClosedUnaryPaths, new NoSmoothing(), null, filter); Lexicon tmp_lexicon = (opts.simpleLexicon) ? new SimpleLexicon( numSubStatesArray, -1, smoothParams, new NoSmoothing(), filter, trainStateSetTrees) : new SophisticatedLexicon( numSubStatesArray, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, smoothParams, new NoSmoothing(), filter); int n = 0; boolean secondHalf = false; for (Tree<StateSet> stateSetTree : trainStateSetTrees) { secondHalf = (n++ > nTrees / 2.0); tmp_lexicon.trainTree(stateSetTree, randomness, null, secondHalf, false, opts.rare); } lexicon = (opts.simpleLexicon) ? new SimpleLexicon( numSubStatesArray, -1, smoothParams, new NoSmoothing(), filter, trainStateSetTrees) : new SophisticatedLexicon( numSubStatesArray, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, smoothParams, new NoSmoothing(), filter); for (Tree<StateSet> stateSetTree : trainStateSetTrees) { secondHalf = (n++ > nTrees / 2.0); lexicon.trainTree(stateSetTree, randomness, tmp_lexicon, secondHalf, false, opts.rare); grammar.tallyUninitializedStateSetTree(stateSetTree); } lexicon.tieRareWordStats(opts.rare); lexicon.optimize(); // SSIE ((SophisticatedLexicon) lexicon).overwriteWithMaxent(); grammar.optimize(randomness); // System.out.println(grammar); previousGrammar = maxGrammar = grammar; // needed for baseline - when there is no EM loop previousLexicon = maxLexicon = lexicon; } // the main loop: split and train the grammar for (int splitIndex = startSplit; splitIndex < numSplitTimes * 3; splitIndex++) { // now do either a merge or a split and the end a smooth // on odd iterations merge, on even iterations split String opString = ""; if (splitIndex % 3 == 2) { // (splitIndex==numSplitTimes*2){ if (opts.smooth.equals("NoSmoothing")) continue; System.out.println("Setting smoother for grammar and lexicon."); Smoother grSmoother = new SmoothAcrossParentBits(0.01, maxGrammar.splitTrees); Smoother lexSmoother = new SmoothAcrossParentBits(0.1, maxGrammar.splitTrees); // Smoother grSmoother = new SmoothAcrossParentSubstate(0.01); // Smoother lexSmoother = new SmoothAcrossParentSubstate(0.1); maxGrammar.setSmoother(grSmoother); maxLexicon.setSmoother(lexSmoother); minIterations = maxIterations = opts.smoothMaxIterations; opString = "smoothing"; } else if (splitIndex % 3 == 0) { // the case where we split if (opts.noSplit) continue; System.out.println( "Before splitting, we have a total of " + maxGrammar.totalSubStates() + " substates."); CorpusStatistics corpusStatistics = new CorpusStatistics(tagNumberer, trainStateSetTrees); int[] counts = corpusStatistics.getSymbolCounts(); maxGrammar = maxGrammar.splitAllStates(randomness, counts, allowMoreSubstatesThanCounts, 0); maxLexicon = maxLexicon.splitAllStates(counts, allowMoreSubstatesThanCounts, 0); Smoother grSmoother = new NoSmoothing(); Smoother lexSmoother = new NoSmoothing(); maxGrammar.setSmoother(grSmoother); maxLexicon.setSmoother(lexSmoother); System.out.println( "After splitting, we have a total of " + maxGrammar.totalSubStates() + " substates."); System.out.println( "Rule probabilities are NOT normalized in the split, therefore the training LL is not guaranteed to improve between iteration 0 and 1!"); opString = "splitting"; maxIterations = opts.splitMaxIterations; minIterations = opts.splitMinIterations; } else { if (mergingPercentage == 0) continue; // the case where we merge double[][] mergeWeights = GrammarMerger.computeMergeWeights(maxGrammar, maxLexicon, trainStateSetTrees); 
double[][][] deltas = GrammarMerger.computeDeltas(maxGrammar, maxLexicon, mergeWeights, trainStateSetTrees); boolean[][][] mergeThesePairs = GrammarMerger.determineMergePairs( deltas, separateMergingThreshold, mergingPercentage, maxGrammar); grammar = GrammarMerger.doTheMerges(maxGrammar, maxLexicon, mergeThesePairs, mergeWeights); short[] newNumSubStatesArray = grammar.numSubStates; trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, newNumSubStatesArray, false); validationStateSetTrees = new StateSetTreeList(validationStateSetTrees, newNumSubStatesArray, false); // retrain lexicon to finish the lexicon merge (updates the unknown words model)... lexicon = (opts.simpleLexicon) ? new SimpleLexicon( newNumSubStatesArray, -1, smoothParams, maxLexicon.getSmoother(), filter, trainStateSetTrees) : new SophisticatedLexicon( newNumSubStatesArray, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, maxLexicon.getSmoothingParams(), maxLexicon.getSmoother(), maxLexicon.getPruningThreshold()); boolean updateOnlyLexicon = true; double trainingLikelihood = GrammarTrainer.doOneEStep( grammar, maxLexicon, null, lexicon, trainStateSetTrees, updateOnlyLexicon, opts.rare); // System.out.println("The training LL is "+trainingLikelihood); lexicon .optimize(); // Grammar.RandomInitializationType.INITIALIZE_WITH_SMALL_RANDOMIZATION); // // M Step GrammarMerger.printMergingStatistics(maxGrammar, grammar); opString = "merging"; maxGrammar = grammar; maxLexicon = lexicon; maxIterations = opts.mergeMaxIterations; minIterations = opts.mergeMinIterations; } // update the substate dependent objects previousGrammar = grammar = maxGrammar; previousLexicon = lexicon = maxLexicon; droppingIter = 0; numSubStatesArray = grammar.numSubStates; trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, numSubStatesArray, false); validationStateSetTrees = new StateSetTreeList(validationStateSetTrees, numSubStatesArray, false); maxLikelihood = calculateLogLikelihood(maxGrammar, maxLexicon, validationStateSetTrees); System.out.println( "After " + opString + " in the " + (splitIndex / 3 + 1) + "th round, we get a validation likelihood of " + maxLikelihood); iter = 0; // the inner loop: train the grammar via EM until validation likelihood reliably drops do { iter += 1; System.out.println("Beginning iteration " + (iter - 1) + ":"); // 1) Compute the validation likelihood of the previous iteration System.out.print("Calculating validation likelihood..."); double validationLikelihood = calculateLogLikelihood( previousGrammar, previousLexicon, validationStateSetTrees); // The validation LL of previousGrammar/previousLexicon System.out.println("done: " + validationLikelihood); // 2) Perform the E step while computing the training likelihood of the previous iteration System.out.print("Calculating training likelihood..."); grammar = new Grammar( grammar.numSubStates, grammar.findClosedPaths, grammar.smoother, grammar, grammar.threshold); lexicon = (opts.simpleLexicon) ? 
new SimpleLexicon( grammar.numSubStates, -1, smoothParams, lexicon.getSmoother(), filter, trainStateSetTrees) : new SophisticatedLexicon( grammar.numSubStates, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, lexicon.getSmoothingParams(), lexicon.getSmoother(), lexicon.getPruningThreshold()); boolean updateOnlyLexicon = false; double trainingLikelihood = doOneEStep( previousGrammar, previousLexicon, grammar, lexicon, trainStateSetTrees, updateOnlyLexicon, opts.rare); // The training LL of previousGrammar/previousLexicon System.out.println("done: " + trainingLikelihood); // 3) Perform the M-Step lexicon.optimize(); // M Step grammar.optimize(0); // M Step // 4) Check whether previousGrammar/previousLexicon was in fact better than the best if (iter < minIterations || validationLikelihood >= maxLikelihood) { maxLikelihood = validationLikelihood; maxGrammar = previousGrammar; maxLexicon = previousLexicon; droppingIter = 0; } else { droppingIter++; } // 5) advance the 'pointers' previousGrammar = grammar; previousLexicon = lexicon; } while ((droppingIter < allowedDroppingIters) && (!baseline) && (iter < maxIterations)); // Dump a grammar file to disk from time to time ParserData pData = new ParserData( maxLexicon, maxGrammar, null, Numberer.getNumberers(), numSubStatesArray, VERTICAL_MARKOVIZATION, HORIZONTAL_MARKOVIZATION, binarization); String outTmpName = outFileName + "_" + (splitIndex / 3 + 1) + "_" + opString + ".gr"; System.out.println("Saving grammar to " + outTmpName + "."); if (pData.Save(outTmpName)) System.out.println("Saving successful."); else System.out.println("Saving failed!"); pData = null; } // The last grammar/lexicon has not yet been evaluated. Even though the validation likelihood // has been dropping in the past few iteration, there is still a chance that the last one was in // fact the best so just in case we evaluate it. System.out.print("Calculating last validation likelihood..."); double validationLikelihood = calculateLogLikelihood(grammar, lexicon, validationStateSetTrees); System.out.println( "done.\n Iteration " + iter + " (final) gives validation likelihood " + validationLikelihood); if (validationLikelihood > maxLikelihood) { maxLikelihood = validationLikelihood; maxGrammar = previousGrammar; maxLexicon = previousLexicon; } ParserData pData = new ParserData( maxLexicon, maxGrammar, null, Numberer.getNumberers(), numSubStatesArray, VERTICAL_MARKOVIZATION, HORIZONTAL_MARKOVIZATION, binarization); System.out.println("Saving grammar to " + outFileName + "."); System.out.println("It gives a validation data log likelihood of: " + maxLikelihood); if (pData.Save(outFileName)) System.out.println("Saving successful."); else System.out.println("Saving failed!"); System.exit(0); }
public Tree<String> getBestParse(List<String> sentence) {
  // CKY chart parse: scores.get(i).get(j) maps a chart entry (rule) to the best log
  // probability of building it over the span [i, j).
  int n = sentence.size();
  // System.out.println("getBestParse: n=" + n);
  List<List<Map<Object, Double>>> scores = new ArrayList<List<Map<Object, Double>>>(n + 1);
  for (int i = 0; i < n + 1; i++) {
    List<Map<Object, Double>> row = new ArrayList<Map<Object, Double>>(n + 1);
    for (int j = 0; j < n + 1; j++) {
      row.add(new HashMap<Object, Double>());
    }
    scores.add(row);
  }
  List<List<Map<Object, Triplet<Integer, Object, Object>>>> backs =
      new ArrayList<List<Map<Object, Triplet<Integer, Object, Object>>>>(n + 1);
  for (int i = 0; i < n + 1; i++) {
    List<Map<Object, Triplet<Integer, Object, Object>>> row =
        new ArrayList<Map<Object, Triplet<Integer, Object, Object>>>(n + 1);
    for (int j = 0; j < n + 1; j++) {
      row.add(new HashMap<Object, Triplet<Integer, Object, Object>>());
    }
    backs.add(row);
  }
  /*
  System.out.println("scores=" + scores.size() + "x" + scores.get(0).size());
  System.out.println("backs=" + backs.size() + "x" + backs.get(0).size());
  printChart(scores, backs, "scores");
  */

  // First the Lexicon
  for (int i = 0; i < n; i++) {
    String word = sentence.get(i);
    for (String tag : lexicon.getAllTags()) {
      UnaryRule A = new UnaryRule(tag, word);
      A.setScore(Math.log(lexicon.scoreTagging(word, tag)));
      scores.get(i).get(i + 1).put(A, A.getScore());
      backs.get(i).get(i + 1).put(A, null);
    }
    // System.out.println("Starting unaries: i=" + i + ",n=" + n );

    // Handle unaries
    boolean added = true;
    while (added) {
      added = false;
      Map<Object, Double> A_scores = scores.get(i).get(i + 1);
      // Don't modify the dict we are iterating
      List<Object> A_keys = copyKeys(A_scores);
      // for (int j = 0; j < 5 && j < A_keys.size(); j++) {
      //   System.out.print("," + j + "=" + A_scores.get(A_keys.get(j)));
      // }
      for (Object oB : A_keys) {
        UnaryRule B = (UnaryRule) oB;
        for (UnaryRule A : grammar.getUnaryRulesByChild(B.getParent())) {
          double prob = Math.log(A.getScore()) + A_scores.get(B);
          if (prob > -1000.0) {
            if (!A_scores.containsKey(A) || prob > A_scores.get(A)) {
              // System.out.print(" *A=" + A + ", B=" + B);
              // System.out.print(", prob=" + prob);
              // System.out.println(", A_scores.get(A)=" + A_scores.get(A));
              A_scores.put(A, prob);
              backs.get(i).get(i + 1).put(A, new Triplet<Integer, Object, Object>(-1, B, null));
              added = true;
            }
            // System.out.println(", added=" + added);
          }
        }
      }
      // System.out.println(", A_scores=" + A_scores.size() + ", added=" + added);
    }
  }
  // printChart(scores, backs, "scores with Lexicon");

  // Do higher layers
  // Naming is based on rules: A -> B,C
  long startTime = new Date().getTime();
  for (int span = 2; span < n + 1; span++) {
    for (int begin = 0; begin < n + 1 - span; begin++) {
      int end = begin + span;
      Map<Object, Double> A_scores = scores.get(begin).get(end);
      Map<Object, Triplet<Integer, Object, Object>> A_backs = backs.get(begin).get(end);
      for (int split = begin + 1; split < end; split++) {
        Map<Object, Double> B_scores = scores.get(begin).get(split);
        Map<Object, Double> C_scores = scores.get(split).get(end);
        List<Object> B_list = new ArrayList<Object>(B_scores.keySet());
        List<Object> C_list = new ArrayList<Object>(C_scores.keySet());
        // This is a key optimization. !@#$
        // It avoids a B_list.size() x C_list.size() search in the for (Object B : B_list) loop
        Map<String, List<Object>> C_map = new HashMap<String, List<Object>>();
        for (Object C : C_list) {
          String parent = getParent(C);
          if (!C_map.containsKey(parent)) {
            C_map.put(parent, new ArrayList<Object>());
          }
          C_map.get(parent).add(C);
        }
        for (Object B : B_list) {
          for (BinaryRule A : grammar.getBinaryRulesByLeftChild(getParent(B))) {
            if (C_map.containsKey(A.getRightChild())) {
              for (Object C : C_map.get(A.getRightChild())) {
                // We now have A which has B as left child and C as right child
                double prob = Math.log(A.getScore()) + B_scores.get(B) + C_scores.get(C);
                if (!A_scores.containsKey(A) || prob > A_scores.get(A)) {
                  A_scores.put(A, prob);
                  A_backs.put(A, new Triplet<Integer, Object, Object>(split, B, C));
                }
              }
            }
          }
        }
      }

      // Handle unaries: A -> B
      boolean added = true;
      while (added) {
        added = false;
        // Don't modify the dict we are iterating
        List<Object> A_keys = copyKeys(A_scores);
        for (Object oB : A_keys) {
          for (UnaryRule A : grammar.getUnaryRulesByChild(getParent(oB))) {
            double prob = Math.log(A.getScore()) + A_scores.get(oB);
            if (!A_scores.containsKey(A) || prob > A_scores.get(A)) {
              A_scores.put(A, prob);
              A_backs.put(A, new Triplet<Integer, Object, Object>(-1, oB, null));
              added = true;
            }
          }
        }
      }
    }
  }
  // printChart(scores, backs, "scores with Lexicon and Grammar");

  Map<Object, Double> topOfChart = scores.get(0).get(n);
  System.out.println("topOfChart: " + topOfChart.size());
  /*
  for (Object o: topOfChart.keySet()) {
    System.out.println("o=" + o + ", score=" + topOfChart.getCount(o));
  }
  */

  // All parses have "ROOT" at top of tree
  Object bestKey = null;
  Object secondBestKey = null;
  double bestScore = Double.NEGATIVE_INFINITY;
  double secondBestScore = Double.NEGATIVE_INFINITY;
  for (Object key : topOfChart.keySet()) {
    double score = topOfChart.get(key);
    if (score >= secondBestScore || secondBestKey == null) {
      secondBestKey = key;
      secondBestScore = score;
    }
    if ("ROOT".equals(getParent(key)) && (score >= bestScore || bestKey == null)) {
      bestKey = key;
      bestScore = score;
    }
  }
  if (bestKey == null) {
    bestKey = secondBestKey;
    System.out.println("secondBestKey=" + secondBestKey);
  }
  if (bestKey == null) {
    for (Object key : topOfChart.keySet()) {
      System.out.println("val=" + topOfChart.get(key) + ", key=" + key);
    }
  }
  System.out.println("bestKey=" + bestKey + ", log(prob)=" + topOfChart.get(bestKey));

  Tree<String> result = makeTree(backs, 0, n, bestKey);
  if (!"ROOT".equals(result.getLabel())) {
    List<Tree<String>> children = new ArrayList<Tree<String>>();
    children.add(result);
    result = new Tree<String>("ROOT", children); // !@#$
  }
  /*
  System.out.println("==================================================");
  System.out.println(result);
  System.out.println("====================^^^^^^========================");
  */
  return TreeAnnotations.unAnnotateTree(result);
}
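The helpers copyKeys, getParent, makeTree, and printChart are referenced above but not shown in this listing. Below is a minimal, hypothetical sketch of the two simplest ones, inferred only from how they are called; it assumes that UnaryRule and BinaryRule expose their parent category as a String via getParent(), and it is not the original implementation.

// Hypothetical sketches of helpers referenced above (not the original code).

// Snapshot of a chart cell's keys so the cell can be modified while we iterate
// (see "Don't modify the dict we are iterating" above).
private static List<Object> copyKeys(Map<Object, Double> cellScores) {
  return new ArrayList<Object>(cellScores.keySet());
}

// The label a chart entry produces: the parent symbol of a unary or binary rule,
// or the entry itself rendered as a category string.
private static String getParent(Object entry) {
  if (entry instanceof UnaryRule) {
    return ((UnaryRule) entry).getParent();   // assumed to return the parent category
  }
  if (entry instanceof BinaryRule) {
    return ((BinaryRule) entry).getParent();  // assumed to return the parent category
  }
  return entry.toString();
}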
public static void main(String[] args) { Options op = new Options(new EnglishTreebankParserParams()); // op.tlpParams may be changed to something else later, so don't use it till // after options are parsed. System.out.println("Currently " + new Date()); System.out.print("Invoked with arguments:"); for (String arg : args) { System.out.print(" " + arg); } System.out.println(); String path = "/u/nlp/stuff/corpora/Treebank3/parsed/mrg/wsj"; int trainLow = 200, trainHigh = 2199, testLow = 2200, testHigh = 2219; String serializeFile = null; int i = 0; while (i < args.length && args[i].startsWith("-")) { if (args[i].equalsIgnoreCase("-path") && (i + 1 < args.length)) { path = args[i + 1]; i += 2; } else if (args[i].equalsIgnoreCase("-train") && (i + 2 < args.length)) { trainLow = Integer.parseInt(args[i + 1]); trainHigh = Integer.parseInt(args[i + 2]); i += 3; } else if (args[i].equalsIgnoreCase("-test") && (i + 2 < args.length)) { testLow = Integer.parseInt(args[i + 1]); testHigh = Integer.parseInt(args[i + 2]); i += 3; } else if (args[i].equalsIgnoreCase("-serialize") && (i + 1 < args.length)) { serializeFile = args[i + 1]; i += 2; } else if (args[i].equalsIgnoreCase("-tLPP") && (i + 1 < args.length)) { try { op.tlpParams = (TreebankLangParserParams) Class.forName(args[i + 1]).newInstance(); } catch (ClassNotFoundException e) { System.err.println("Class not found: " + args[i + 1]); } catch (InstantiationException e) { System.err.println("Couldn't instantiate: " + args[i + 1] + ": " + e.toString()); } catch (IllegalAccessException e) { System.err.println("illegal access" + e); } i += 2; } else if (args[i].equals("-encoding")) { // sets encoding for TreebankLangParserParams op.tlpParams.setInputEncoding(args[i + 1]); op.tlpParams.setOutputEncoding(args[i + 1]); i += 2; } else { i = op.setOptionOrWarn(args, i); } } // System.out.println(tlpParams.getClass()); TreebankLanguagePack tlp = op.tlpParams.treebankLanguagePack(); Train.sisterSplitters = new HashSet(Arrays.asList(op.tlpParams.sisterSplitters())); // BinarizerFactory.TreeAnnotator.setTreebankLang(tlpParams); PrintWriter pw = op.tlpParams.pw(); Test.display(); Train.display(); op.display(); op.tlpParams.display(); // setup tree transforms Treebank trainTreebank = op.tlpParams.memoryTreebank(); MemoryTreebank testTreebank = op.tlpParams.testMemoryTreebank(); // Treebank blippTreebank = ((EnglishTreebankParserParams) tlpParams).diskTreebank(); // String blippPath = "/afs/ir.stanford.edu/data/linguistic-data/BLLIP-WSJ/"; // blippTreebank.loadPath(blippPath, "", true); Timing.startTime(); System.err.print("Reading trees..."); testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true)); if (Test.increasingLength) { Collections.sort(testTreebank, new TreeLengthComparator()); } trainTreebank.loadPath(path, new NumberRangeFileFilter(trainLow, trainHigh, true)); Timing.tick("done."); System.err.print("Binarizing trees..."); TreeAnnotatorAndBinarizer binarizer = null; if (!Train.leftToRight) { binarizer = new TreeAnnotatorAndBinarizer(op.tlpParams, op.forceCNF, !Train.outsideFactor(), true); } else { binarizer = new TreeAnnotatorAndBinarizer( op.tlpParams.headFinder(), new LeftHeadFinder(), op.tlpParams, op.forceCNF, !Train.outsideFactor(), true); } CollinsPuncTransformer collinsPuncTransformer = null; if (Train.collinsPunc) { collinsPuncTransformer = new CollinsPuncTransformer(tlp); } TreeTransformer debinarizer = new Debinarizer(op.forceCNF); List<Tree> binaryTrainTrees = new ArrayList<Tree>(); if (Train.selectiveSplit) { 
Train.splitters = ParentAnnotationStats.getSplitCategories( trainTreebank, Train.tagSelectiveSplit, 0, Train.selectiveSplitCutOff, Train.tagSelectiveSplitCutOff, op.tlpParams.treebankLanguagePack()); if (Train.deleteSplitters != null) { List<String> deleted = new ArrayList<String>(); for (String del : Train.deleteSplitters) { String baseDel = tlp.basicCategory(del); boolean checkBasic = del.equals(baseDel); for (Iterator<String> it = Train.splitters.iterator(); it.hasNext(); ) { String elem = it.next(); String baseElem = tlp.basicCategory(elem); boolean delStr = checkBasic && baseElem.equals(baseDel) || elem.equals(del); if (delStr) { it.remove(); deleted.add(elem); } } } System.err.println("Removed from vertical splitters: " + deleted); } } if (Train.selectivePostSplit) { TreeTransformer myTransformer = new TreeAnnotator(op.tlpParams.headFinder(), op.tlpParams); Treebank annotatedTB = trainTreebank.transform(myTransformer); Train.postSplitters = ParentAnnotationStats.getSplitCategories( annotatedTB, true, 0, Train.selectivePostSplitCutOff, Train.tagSelectivePostSplitCutOff, op.tlpParams.treebankLanguagePack()); } if (Train.hSelSplit) { binarizer.setDoSelectiveSplit(false); for (Tree tree : trainTreebank) { if (Train.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } // tree.pennPrint(tlpParams.pw()); tree = binarizer.transformTree(tree); // binaryTrainTrees.add(tree); } binarizer.setDoSelectiveSplit(true); } for (Tree tree : trainTreebank) { if (Train.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } tree = binarizer.transformTree(tree); binaryTrainTrees.add(tree); } if (Test.verbose) { binarizer.dumpStats(); } List<Tree> binaryTestTrees = new ArrayList<Tree>(); for (Tree tree : testTreebank) { if (Train.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } tree = binarizer.transformTree(tree); binaryTestTrees.add(tree); } Timing.tick("done."); // binarization BinaryGrammar bg = null; UnaryGrammar ug = null; DependencyGrammar dg = null; // DependencyGrammar dgBLIPP = null; Lexicon lex = null; // extract grammars Extractor bgExtractor = new BinaryGrammarExtractor(); // Extractor bgExtractor = new SmoothedBinaryGrammarExtractor();//new BinaryGrammarExtractor(); // Extractor lexExtractor = new LexiconExtractor(); // Extractor dgExtractor = new DependencyMemGrammarExtractor(); Extractor dgExtractor = new MLEDependencyGrammarExtractor(op); if (op.doPCFG) { System.err.print("Extracting PCFG..."); Pair bgug = null; if (Train.cheatPCFG) { List allTrees = new ArrayList(binaryTrainTrees); allTrees.addAll(binaryTestTrees); bgug = (Pair) bgExtractor.extract(allTrees); } else { bgug = (Pair) bgExtractor.extract(binaryTrainTrees); } bg = (BinaryGrammar) bgug.second; bg.splitRules(); ug = (UnaryGrammar) bgug.first; ug.purgeRules(); Timing.tick("done."); } System.err.print("Extracting Lexicon..."); lex = op.tlpParams.lex(op.lexOptions); lex.train(binaryTrainTrees); Timing.tick("done."); if (op.doDep) { System.err.print("Extracting Dependencies..."); binaryTrainTrees.clear(); // dgBLIPP = (DependencyGrammar) dgExtractor.extract(new // ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new // TransformTreeDependency(tlpParams,true)); DependencyGrammar dg1 = (DependencyGrammar) dgExtractor.extract( trainTreebank.iterator(), new TransformTreeDependency(op.tlpParams, true)); // dgBLIPP=(DependencyGrammar)dgExtractor.extract(blippTreebank.iterator(),new // TransformTreeDependency(tlpParams)); // dg = (DependencyGrammar) 
dgExtractor.extract(new // ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new // TransformTreeDependency(tlpParams)); // dg=new DependencyGrammarCombination(dg1,dgBLIPP,2); // dg = (DependencyGrammar) dgExtractor.extract(binaryTrainTrees); //uses information whether // the words are known or not, discards unknown words Timing.tick("done."); // System.out.print("Extracting Unknown Word Model..."); // UnknownWordModel uwm = (UnknownWordModel)uwmExtractor.extract(binaryTrainTrees); // Timing.tick("done."); System.out.print("Tuning Dependency Model..."); dg.tune(binaryTestTrees); // System.out.println("TUNE DEPS: "+tuneDeps); Timing.tick("done."); } BinaryGrammar boundBG = bg; UnaryGrammar boundUG = ug; GrammarProjection gp = new NullGrammarProjection(bg, ug); // serialization if (serializeFile != null) { System.err.print("Serializing parser..."); LexicalizedParser.saveParserDataToSerialized( new ParserData(lex, bg, ug, dg, Numberer.getNumberers(), op), serializeFile); Timing.tick("done."); } // test: pcfg-parse and output ExhaustivePCFGParser parser = null; if (op.doPCFG) { parser = new ExhaustivePCFGParser(boundBG, boundUG, lex, op); } ExhaustiveDependencyParser dparser = ((op.doDep && !Test.useFastFactored) ? new ExhaustiveDependencyParser(dg, lex, op) : null); Scorer scorer = (op.doPCFG ? new TwinScorer(new ProjectionScorer(parser, gp), dparser) : null); // Scorer scorer = parser; BiLexPCFGParser bparser = null; if (op.doPCFG && op.doDep) { bparser = (Test.useN5) ? new BiLexPCFGParser.N5BiLexPCFGParser( scorer, parser, dparser, bg, ug, dg, lex, op, gp) : new BiLexPCFGParser(scorer, parser, dparser, bg, ug, dg, lex, op, gp); } LabeledConstituentEval pcfgPE = new LabeledConstituentEval("pcfg PE", true, tlp); LabeledConstituentEval comboPE = new LabeledConstituentEval("combo PE", true, tlp); AbstractEval pcfgCB = new LabeledConstituentEval.CBEval("pcfg CB", true, tlp); AbstractEval pcfgTE = new AbstractEval.TaggingEval("pcfg TE"); AbstractEval comboTE = new AbstractEval.TaggingEval("combo TE"); AbstractEval pcfgTEnoPunct = new AbstractEval.TaggingEval("pcfg nopunct TE"); AbstractEval comboTEnoPunct = new AbstractEval.TaggingEval("combo nopunct TE"); AbstractEval depTE = new AbstractEval.TaggingEval("depnd TE"); AbstractEval depDE = new AbstractEval.DependencyEval("depnd DE", true, tlp.punctuationWordAcceptFilter()); AbstractEval comboDE = new AbstractEval.DependencyEval("combo DE", true, tlp.punctuationWordAcceptFilter()); if (Test.evalb) { EvalB.initEVALBfiles(op.tlpParams); } // int[] countByLength = new int[Test.maxLength+1]; // use a reflection ruse, so one can run this without needing the tagger // edu.stanford.nlp.process.SentenceTagger tagger = (Test.preTag ? 
new // edu.stanford.nlp.process.SentenceTagger("/u/nlp/data/tagger.params/wsj0-21.holder") : null); SentenceProcessor tagger = null; if (Test.preTag) { try { Class[] argsClass = new Class[] {String.class}; Object[] arguments = new Object[] {"/u/nlp/data/pos-tagger/wsj3t0-18-bidirectional/train-wsj-0-18.holder"}; tagger = (SentenceProcessor) Class.forName("edu.stanford.nlp.tagger.maxent.MaxentTagger") .getConstructor(argsClass) .newInstance(arguments); } catch (Exception e) { System.err.println(e); System.err.println("Warning: No pretagging of sentences will be done."); } } for (int tNum = 0, ttSize = testTreebank.size(); tNum < ttSize; tNum++) { Tree tree = testTreebank.get(tNum); int testTreeLen = tree.yield().size(); if (testTreeLen > Test.maxLength) { continue; } Tree binaryTree = binaryTestTrees.get(tNum); // countByLength[testTreeLen]++; System.out.println("-------------------------------------"); System.out.println("Number: " + (tNum + 1)); System.out.println("Length: " + testTreeLen); // tree.pennPrint(pw); // System.out.println("XXXX The binary tree is"); // binaryTree.pennPrint(pw); // System.out.println("Here are the tags in the lexicon:"); // System.out.println(lex.showTags()); // System.out.println("Here's the tagnumberer:"); // System.out.println(Numberer.getGlobalNumberer("tags").toString()); long timeMil1 = System.currentTimeMillis(); Timing.tick("Starting parse."); if (op.doPCFG) { // System.err.println(Test.forceTags); if (Test.forceTags) { if (tagger != null) { // System.out.println("Using a tagger to set tags"); // System.out.println("Tagged sentence as: " + // tagger.processSentence(cutLast(wordify(binaryTree.yield()))).toString(false)); parser.parse(addLast(tagger.processSentence(cutLast(wordify(binaryTree.yield()))))); } else { // System.out.println("Forcing tags to match input."); parser.parse(cleanTags(binaryTree.taggedYield(), tlp)); } } else { // System.out.println("XXXX Parsing " + binaryTree.yield()); parser.parse(binaryTree.yield()); } // Timing.tick("Done with pcfg phase."); } if (op.doDep) { dparser.parse(binaryTree.yield()); // Timing.tick("Done with dependency phase."); } boolean bothPassed = false; if (op.doPCFG && op.doDep) { bothPassed = bparser.parse(binaryTree.yield()); // Timing.tick("Done with combination phase."); } long timeMil2 = System.currentTimeMillis(); long elapsed = timeMil2 - timeMil1; System.err.println("Time: " + ((int) (elapsed / 100)) / 10.00 + " sec."); // System.out.println("PCFG Best Parse:"); Tree tree2b = null; Tree tree2 = null; // System.out.println("Got full best parse..."); if (op.doPCFG) { tree2b = parser.getBestParse(); tree2 = debinarizer.transformTree(tree2b); } // System.out.println("Debinarized parse..."); // tree2.pennPrint(); // System.out.println("DepG Best Parse:"); Tree tree3 = null; Tree tree3db = null; if (op.doDep) { tree3 = dparser.getBestParse(); // was: but wrong Tree tree3db = debinarizer.transformTree(tree2); tree3db = debinarizer.transformTree(tree3); tree3.pennPrint(pw); } // tree.pennPrint(); // ((Tree)binaryTrainTrees.get(tNum)).pennPrint(); // System.out.println("Combo Best Parse:"); Tree tree4 = null; if (op.doPCFG && op.doDep) { try { tree4 = bparser.getBestParse(); if (tree4 == null) { tree4 = tree2b; } } catch (NullPointerException e) { System.err.println("Blocked, using PCFG parse!"); tree4 = tree2b; } } if (op.doPCFG && !bothPassed) { tree4 = tree2b; } // tree4.pennPrint(); if (op.doDep) { depDE.evaluate(tree3, binaryTree, pw); depTE.evaluate(tree3db, tree, pw); } TreeTransformer tc = 
op.tlpParams.collinizer(); TreeTransformer tcEvalb = op.tlpParams.collinizerEvalb(); Tree tree4b = null; if (op.doPCFG) { // System.out.println("XXXX Best PCFG was: "); // tree2.pennPrint(); // System.out.println("XXXX Transformed best PCFG is: "); // tc.transformTree(tree2).pennPrint(); // System.out.println("True Best Parse:"); // tree.pennPrint(); // tc.transformTree(tree).pennPrint(); pcfgPE.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw); pcfgCB.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw); if (op.doDep) { comboDE.evaluate((bothPassed ? tree4 : tree3), binaryTree, pw); tree4b = tree4; tree4 = debinarizer.transformTree(tree4); if (op.nodePrune) { NodePruner np = new NodePruner(parser, debinarizer); tree4 = np.prune(tree4); } // tree4.pennPrint(); comboPE.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw); } // pcfgTE.evaluate(tree2, tree); pcfgTE.evaluate(tcEvalb.transformTree(tree2), tcEvalb.transformTree(tree), pw); pcfgTEnoPunct.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw); if (op.doDep) { comboTE.evaluate(tcEvalb.transformTree(tree4), tcEvalb.transformTree(tree), pw); comboTEnoPunct.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw); } System.out.println("PCFG only: " + parser.scoreBinarizedTree(tree2b, 0)); // tc.transformTree(tree2).pennPrint(); tree2.pennPrint(pw); if (op.doDep) { System.out.println("Combo: " + parser.scoreBinarizedTree(tree4b, 0)); // tc.transformTree(tree4).pennPrint(pw); tree4.pennPrint(pw); } System.out.println("Correct:" + parser.scoreBinarizedTree(binaryTree, 0)); /* if (parser.scoreBinarizedTree(tree2b,true) < parser.scoreBinarizedTree(binaryTree,true)) { System.out.println("SCORE INVERSION"); parser.validateBinarizedTree(binaryTree,0); } */ tree.pennPrint(pw); } // end if doPCFG if (Test.evalb) { if (op.doPCFG && op.doDep) { EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree4)); } else if (op.doPCFG) { EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree2)); } else if (op.doDep) { EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree3db)); } } } // end for each tree in test treebank if (Test.evalb) { EvalB.closeEVALBfiles(); } // Test.display(); if (op.doPCFG) { pcfgPE.display(false, pw); System.out.println("Grammar size: " + Numberer.getGlobalNumberer("states").total()); pcfgCB.display(false, pw); if (op.doDep) { comboPE.display(false, pw); } pcfgTE.display(false, pw); pcfgTEnoPunct.display(false, pw); if (op.doDep) { comboTE.display(false, pw); comboTEnoPunct.display(false, pw); } } if (op.doDep) { depTE.display(false, pw); depDE.display(false, pw); } if (op.doPCFG && op.doDep) { comboDE.display(false, pw); } // pcfgPE.printGoodBad(); }
public void stocGradTrain(Parser parser, boolean testEachRound) { int numUpdates = 0; List<LexEntry> fixedEntries = new LinkedList<LexEntry>(); fixedEntries.addAll(parser.returnLex().getLexicon()); // add all sentential lexical entries. for (int l = 0; l < trainData.size(); l++) { parser.addLexEntries(trainData.getDataSet(l).makeSentEntries()); } parser.setGlobals(); DataSet data = null; // for each pass over the data for (int j = 0; j < EPOCHS; j++) { System.out.println("Training, iteration " + j); int total = 0, correct = 0, wrong = 0, looCorrect = 0, looWrong = 0; for (int l = 0; l < trainData.size(); l++) { // the variables to hold the current training example String words = null; Exp sem = null; data = trainData.getDataSet(l); if (verbose) System.out.println("---------------------"); String filename = trainData.getFilename(l); if (verbose) System.out.println("DataSet: " + filename); if (verbose) System.out.println("---------------------"); // loop through the training examples // try to create lexical entries for each training example for (int i = 0; i < data.size(); i++) { // print running stats if (verbose) { if (total != 0) { double r = (double) correct / total; double p = (double) correct / (correct + wrong); System.out.print(i + ": =========== r:" + r + " p:" + p); System.out.println(" (epoch:" + j + " file:" + l + " " + filename + ")"); } else System.out.println(i + ": ==========="); } // get the training example words = data.sent(i); sem = data.sem(i); if (verbose) { System.out.println(words); System.out.println(sem); } List<String> tokens = Parser.tokenize(words); if (tokens.size() > maxSentLen) continue; total++; String mes = null; boolean hasCorrect = false; // first, get all possible lexical entries from // a manipulation of the best parse. List<LexEntry> lex = makeLexEntriesChart(words, sem, parser); if (verbose) { System.out.println("Adding:"); for (LexEntry le : lex) { System.out.println(le + " : " + LexiconFeatSet.initialWeight(le)); } } parser.addLexEntries(lex); if (verbose) System.out.println("Lex Size: " + parser.returnLex().size()); // first parse to see if we are currently correct if (verbose) mes = "First"; parser.parseTimed(words, null, mes); Chart firstChart = parser.getChart(); Exp best = parser.bestSem(); // this just collates and outputs the training // accuracy. if (sem.equals(best)) { // System.out.println(parser.bestParses().get(0)); if (verbose) { System.out.println("CORRECT:" + best); lex = parser.getMaxLexEntriesFor(sem); System.out.println("Using:"); printLex(lex); if (lex.size() == 0) { System.out.println("ERROR: empty lex"); } } correct++; } else { if (verbose) { System.out.println("WRONG: " + best); lex = parser.getMaxLexEntriesFor(best); System.out.println("Using:"); printLex(lex); if (best != null && lex.size() == 0) { System.out.println("ERROR: empty lex"); } } wrong++; } // compute first half of parameter update: // subtract the expectation of parameters // under the distribution that is conditioned // on the sentence alone. 
double norm = firstChart.computeNorm(); HashVector update = new HashVector(); HashVector firstfeats = null, secondfeats = null; if (norm != 0.0) { firstfeats = firstChart.computeExpFeatVals(); firstfeats.divideBy(norm); firstfeats.dropSmallEntries(); firstfeats.addTimesInto(-1.0, update); } else continue; firstChart = null; if (verbose) mes = "Second"; parser.parseTimed(words, sem, mes); hasCorrect = parser.hasParseFor(sem); // compute second half of parameter update: // add the expectation of parameters // under the distribution that is conditioned // on the sentence and correct logical form. if (!hasCorrect) continue; Chart secondChart = parser.getChart(); double secnorm = secondChart.computeNorm(sem); if (norm != 0.0) { secondfeats = secondChart.computeExpFeatVals(sem); secondfeats.divideBy(secnorm); secondfeats.dropSmallEntries(); secondfeats.addTimesInto(1.0, update); lex = parser.getMaxLexEntriesFor(sem); data.setBestLex(i, lex); if (verbose) { System.out.println("Best LexEntries:"); printLex(lex); if (lex.size() == 0) { System.out.println("ERROR: empty lex"); } } } else continue; // now do the update double scale = alpha_0 / (1.0 + c * numUpdates); if (verbose) System.out.println("Scale: " + scale); update.multiplyBy(scale); update.dropSmallEntries(); numUpdates++; if (verbose) { System.out.println("Update:"); System.out.println(update); } if (!update.isBad()) { if (!update.valuesInRange(-100, 100)) { System.out.println("WARNING: large update"); System.out.println("first feats: " + firstfeats); System.out.println("second feats: " + secondfeats); } parser.updateParams(update); } else { System.out.println( "ERROR: Bad Update: " + update + " -- norm: " + norm + " -- feats: "); parser.getParams().printValues(update); System.out.println(); } } // end for each training example } // end for each data set double r = (double) correct / total; // we can prune the lexical items that were not used // in a max scoring parse. if (pruneLex) { Lexicon cur = new Lexicon(); cur.addLexEntries(fixedEntries); cur.addLexEntries(data.getBestLex()); parser.setLexicon(cur); } if (testEachRound) { System.out.println("Testing"); test(parser, false); } } // end epochs loop }
public Tree<String> getBestParseOld(List<String> sentence) {
  // This implements the CKY algorithm
  CounterMap<String, String> parseScores = new CounterMap<String, String>();
  System.out.println(sentence.toString());

  // First deal with the lexicons
  int index = 0;
  int span = 1; // All spans are 1 at the lexicon level
  for (String word : sentence) {
    for (String tag : lexicon.getAllTags()) {
      double score = lexicon.scoreTagging(word, tag);
      if (score >= 0.0) { // This lexicon may generate this word
        // We use a counter map in order to store the scores for this sentence parse.
        parseScores.setCount(index + " " + (index + span), tag, score);
      }
    }
    index = index + 1;
  }

  // handle unary rules now
  HashMap<String, Triplet<Integer, String, String>> backHash =
      new HashMap<String, Triplet<Integer, String, String>>(); // hashmap to store back propagation
  // System.out.println("Lexicons found");
  boolean added = true;
  while (added) {
    added = false;
    for (index = 0; index < sentence.size(); index++) {
      // For each index + span pair, get the counter.
      Counter<String> count = parseScores.getCounter(index + " " + (index + span));
      PriorityQueue<String> countAsPQ = count.asPriorityQueue();
      while (countAsPQ.hasNext()) {
        String entry = countAsPQ.next();
        // System.out.println("I am fine here!!");
        List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
        for (UnaryRule rule : unaryRules) {
          // These are the unary rules which might give rise to the above preterminal
          double prob =
              rule.getScore() * parseScores.getCount(index + " " + (index + span), entry);
          if (prob > parseScores.getCount(index + " " + (index + span), rule.parent)) {
            parseScores.setCount(index + " " + (index + span), rule.parent, prob);
            backHash.put(
                index + " " + (index + span) + " " + rule.parent,
                new Triplet<Integer, String, String>(-1, entry, null));
            added = true;
          }
        }
      }
    }
  }
  // System.out.println("Lexicon unaries dealt with");

  // Now work with the grammar to produce higher level probabilities
  for (span = 2; span <= sentence.size(); span++) {
    for (int begin = 0; begin <= (sentence.size() - span); begin++) {
      int end = begin + span;
      for (int split = begin + 1; split <= end - 1; split++) {
        Counter<String> countLeft = parseScores.getCounter(begin + " " + split);
        Counter<String> countRight = parseScores.getCounter(split + " " + end);
        // List<BinaryRule> leftRules= new ArrayList<BinaryRule>();
        HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>();
        // List<BinaryRule> rightRules=new ArrayList<BinaryRule>();
        HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>();
        for (String entry : countLeft.keySet()) {
          for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) {
            if (!leftMap.containsKey(rule.hashCode())) {
              leftMap.put(rule.hashCode(), rule);
            }
          }
        }
        for (String entry : countRight.keySet()) {
          for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) {
            if (!rightMap.containsKey(rule.hashCode())) {
              rightMap.put(rule.hashCode(), rule);
            }
          }
        }
        // System.out.println("About to enter the rules loops");
        for (Integer ruleHash : leftMap.keySet()) {
          if (rightMap.containsKey(ruleHash)) {
            BinaryRule ruleRight = rightMap.get(ruleHash);
            double prob =
                ruleRight.getScore()
                    * parseScores.getCount(begin + " " + split, ruleRight.leftChild)
                    * parseScores.getCount(split + " " + end, ruleRight.rightChild);
            // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob);
            if (prob > parseScores.getCount(begin + " " + end, ruleRight.parent)) {
              // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob);
              // System.out.println("parentrule :"+ ruleRight.getParent());
              parseScores.setCount(begin + " " + end, ruleRight.getParent(), prob);
              backHash.put(
                  begin + " " + end + " " + ruleRight.parent,
                  new Triplet<Integer, String, String>(
                      split, ruleRight.leftChild, ruleRight.rightChild));
            }
          }
        }
        // System.out.println("Exited rules loop");
      }
      // System.out.println("Grammar found for " + begin + " "+ end);

      // Now handle unary rules
      added = true;
      while (added) {
        added = false;
        Counter<String> count = parseScores.getCounter(begin + " " + end);
        PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue();
        while (countAsPriorityQueue.hasNext()) {
          String entry = countAsPriorityQueue.next();
          List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
          for (UnaryRule rule : unaryRules) {
            double prob = rule.getScore() * parseScores.getCount(begin + " " + (end), entry);
            if (prob > parseScores.getCount(begin + " " + (end), rule.parent)) {
              parseScores.setCount(begin + " " + (end), rule.parent, prob);
              backHash.put(
                  begin + " " + (end) + " " + rule.parent,
                  new Triplet<Integer, String, String>(-1, entry, null));
              added = true;
            }
          }
        }
      }
      // System.out.println("Unaries dealt for " + begin + " "+ end);
    }
  }

  // Create and return the parse tree
  Tree<String> parseTree = new Tree<String>("null");
  // System.out.println(parseScores.getCounter(0+" "+sentence.size()).toString());
  String parent = parseScores.getCounter(0 + " " + sentence.size()).argMax();
  if (parent == null) {
    System.out.println(parseScores.getCounter(0 + " " + sentence.size()).toString());
    System.out.println("THIS IS WEIRD");
  }
  parent = "ROOT";
  parseTree = getParseTreeOld(sentence, backHash, 0, sentence.size(), parent);
  // System.out.println("PARSE SCORES");
  // System.out.println(parseScores.toString());
  // System.out.println("BACK HASH");
  // System.out.println(backHash.toString());
  // parseTree = addRoot(parseTree);
  // System.out.println(parseTree.toString());
  // return parseTree;
  return TreeAnnotations.unAnnotateTree(parseTree);
}
public Tree<String> getBestParse(List<String> sentence) {
  // This implements the CKY algorithm
  int nEntries = sentence.size();
  // hashmap to store back rules
  HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>> backHash =
      new HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>>();
  // more efficient access with arrays, but must cast each time :(
  @SuppressWarnings("unchecked")
  Counter<String>[][] parseScores = (Counter<String>[][]) (new Counter[nEntries][nEntries]);
  for (int i = 0; i < nEntries; i++) {
    for (int j = 0; j < nEntries; j++) {
      parseScores[i][j] = new Counter<String>();
    }
  }
  System.out.println(sentence.toString());

  // First deal with the lexicons
  int index = 0;
  int span = 1; // All spans are 1 at the lexicon level
  for (String word : sentence) {
    for (String tag : lexicon.getAllTags()) {
      double score = lexicon.scoreTagging(word, tag);
      if (score >= 0.0) { // This lexicon may generate this word
        // We use a counter map in order to store the scores for this sentence parse.
        parseScores[index][index + span - 1].setCount(tag, score);
      }
    }
    index = index + 1;
  }

  // handle unary rules now
  // System.out.println("Lexicons found");
  boolean added = true;
  while (added) {
    added = false;
    for (index = 0; index < sentence.size(); index++) {
      // For each index + span pair, get the counter.
      Counter<String> count = parseScores[index][index + span - 1];
      PriorityQueue<String> countAsPQ = count.asPriorityQueue();
      while (countAsPQ.hasNext()) {
        String entry = countAsPQ.next();
        // System.out.println("I am fine here!!");
        List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
        for (UnaryRule rule : unaryRules) {
          // These are the unary rules which might give rise to the above preterminal
          double prob = rule.getScore() * parseScores[index][index + span - 1].getCount(entry);
          if (prob > parseScores[index][index + span - 1].getCount(rule.parent)) {
            parseScores[index][index + span - 1].setCount(rule.parent, prob);
            backHash.put(
                new Triplet<Integer, Integer, String>(index, index + span, rule.parent),
                new Triplet<Integer, String, String>(-1, entry, null));
            added = true;
          }
        }
      }
    }
  }
  // System.out.println("Lexicon unaries dealt with");

  // Now work with the grammar to produce higher level probabilities
  for (span = 2; span <= sentence.size(); span++) {
    for (int begin = 0; begin <= (sentence.size() - span); begin++) {
      int end = begin + span;
      for (int split = begin + 1; split <= end - 1; split++) {
        Counter<String> countLeft = parseScores[begin][split - 1];
        Counter<String> countRight = parseScores[split][end - 1];
        // List<BinaryRule> leftRules= new ArrayList<BinaryRule>();
        HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>();
        // List<BinaryRule> rightRules=new ArrayList<BinaryRule>();
        HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>();
        for (String entry : countLeft.keySet()) {
          for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) {
            if (!leftMap.containsKey(rule.hashCode())) {
              leftMap.put(rule.hashCode(), rule);
            }
          }
        }
        for (String entry : countRight.keySet()) {
          for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) {
            if (!rightMap.containsKey(rule.hashCode())) {
              rightMap.put(rule.hashCode(), rule);
            }
          }
        }
        // System.out.println("About to enter the rules loops");
        for (Integer ruleHash : leftMap.keySet()) {
          if (rightMap.containsKey(ruleHash)) {
            BinaryRule ruleRight = rightMap.get(ruleHash);
            double prob =
                ruleRight.getScore()
                    * parseScores[begin][split - 1].getCount(ruleRight.leftChild)
                    * parseScores[split][end - 1].getCount(ruleRight.rightChild);
            // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob);
            if (prob > parseScores[begin][end - 1].getCount(ruleRight.parent)) {
              // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob);
              // System.out.println("parentrule :"+ ruleRight.getParent());
              parseScores[begin][end - 1].setCount(ruleRight.getParent(), prob);
              backHash.put(
                  new Triplet<Integer, Integer, String>(begin, end, ruleRight.parent),
                  new Triplet<Integer, String, String>(
                      split, ruleRight.leftChild, ruleRight.rightChild));
            }
          }
        }
        // System.out.println("Exited rules loop");
      }
      // System.out.println("Grammar found for " + begin + " "+ end);

      // Now handle unary rules
      added = true;
      while (added) {
        added = false;
        Counter<String> count = parseScores[begin][end - 1];
        PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue();
        while (countAsPriorityQueue.hasNext()) {
          String entry = countAsPriorityQueue.next();
          List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
          for (UnaryRule rule : unaryRules) {
            double prob = rule.getScore() * parseScores[begin][end - 1].getCount(entry);
            if (prob > parseScores[begin][end - 1].getCount(rule.parent)) {
              parseScores[begin][end - 1].setCount(rule.parent, prob);
              backHash.put(
                  new Triplet<Integer, Integer, String>(begin, end, rule.parent),
                  new Triplet<Integer, String, String>(-1, entry, null));
              added = true;
            }
          }
        }
      }
      // System.out.println("Unaries dealt for " + begin + " "+ end);
    }
  }

  // Create and return the parse tree
  Tree<String> parseTree = new Tree<String>("null");
  // System.out.println(parseScores.getCounter(0+" "+sentence.size()).toString());
  // Pick the argmax
  String parent = parseScores[0][nEntries - 1].argMax();
  // Or pick root. This second one is preferred since sentences are meant to have ROOT as their
  // root node.
  parent = "ROOT";
  parseTree = getParseTree(sentence, backHash, 0, sentence.size(), parent);
  // System.out.println("PARSE SCORES");
  // System.out.println(parseScores.toString());
  // System.out.println("BACK HASH");
  // System.out.println(backHash.toString());
  // parseTree = addRoot(parseTree);
  // System.out.println(parseTree.toString());
  // return parseTree;
  return TreeAnnotations.unAnnotateTree(parseTree);
}
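getParseTree (and its counterpart getParseTreeOld) reconstructs the Viterbi tree from backHash but is not included in this listing. Below is a rough, hypothetical sketch of the assumed recursion. It presumes Triplet provides getFirst()/getSecond()/getThird() accessors and value-based hashCode/equals (the backHash lookups above already rely on the latter); the single-argument and two-argument Tree<String> constructors are the ones used in the method above. Treat this as an illustration, not the original code.

// Hypothetical sketch: rebuild the best tree from the back pointers stored in backHash.
private Tree<String> getParseTree(
    List<String> sentence,
    HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>> backHash,
    int begin,
    int end,
    String parent) {
  Triplet<Integer, String, String> back =
      backHash.get(new Triplet<Integer, Integer, String>(begin, end, parent));
  if (back == null) {
    // No back pointer: this is a preterminal, so attach the word itself.
    // (For a well-formed chart this only happens when end == begin + 1.)
    List<Tree<String>> word = new ArrayList<Tree<String>>();
    word.add(new Tree<String>(sentence.get(begin)));
    return new Tree<String>(parent, word);
  }
  List<Tree<String>> children = new ArrayList<Tree<String>>();
  if (back.getFirst() == -1) {
    // Unary back pointer: a single child covering the same span.
    children.add(getParseTree(sentence, backHash, begin, end, back.getSecond()));
  } else {
    // Binary back pointer: recurse on the two halves at the recorded split point.
    int split = back.getFirst();
    children.add(getParseTree(sentence, backHash, begin, split, back.getSecond()));
    children.add(getParseTree(sentence, backHash, split, end, back.getThird()));
  }
  return new Tree<String>(parent, children);
}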