/**
 * Returns the tag that {@code lexicon.scoreTagging} scores highest for the given word.
 *
 * <p>Tags are visited in the order {@code lexicon.getAllTags()} yields them. The first tag visited
 * is always accepted, and later tags replace the current winner only on a strictly greater score,
 * so ties go to the earlier tag.
 *
 * @param word the word to tag
 * @return the best-scoring tag, or {@code null} when the lexicon yields no tags
 */
private String getBestTag(String word) {
  String winner = null;
  double winningScore = Double.NEGATIVE_INFINITY;
  for (String candidate : lexicon.getAllTags()) {
    final double candidateScore = lexicon.scoreTagging(word, candidate);
    final boolean improves = winner == null || candidateScore > winningScore;
    if (!improves) {
      continue;
    }
    winner = candidate;
    winningScore = candidateScore;
  }
  return winner;
}
/**
 * Selection handler for the lexeme table {@code leksēmuTabula}.
 *
 * <p>Reads the ID in column 0 of the selected row, resolves it with
 * {@code lexicon.lexemeByID}, and passes the result to {@code leksĪpMod.setAttributes}.
 * {@code null} is passed when no row is selected or the ID cell is empty. A non-numeric ID
 * propagates {@link NumberFormatException} from {@link Integer#parseInt(String)}.
 */
void leksēmaSelektēta() {
  final int row = leksēmuTabula.getSelectedRow();
  final Object idCell = row < 0 ? null : leksēmuTabula.getValueAt(row, 0);
  Lexeme selected = null;
  if (idCell != null) {
    selected = lexicon.lexemeByID(Integer.parseInt(idCell.toString()));
  }
  leksĪpMod.setAttributes(selected);
}
/**
 * Selection handler for the ending table {@code galotņuTabula}.
 *
 * <p>Reads the ID in column 0 of the selected row, resolves it with
 * {@code lexicon.endingByID}, and passes the result to {@code galĪpMod.setAttributes}.
 * {@code null} is passed when no row is selected or the ID cell is empty. A non-numeric ID
 * propagates {@link NumberFormatException} from {@link Integer#parseInt(String)}.
 */
void galotneSelektēta() {
  final int row = galotņuTabula.getSelectedRow();
  final Object idCell = row < 0 ? null : galotņuTabula.getValueAt(row, 0);
  Ending selected = null;
  if (idCell != null) {
    selected = lexicon.endingByID(Integer.parseInt(idCell.toString()));
  }
  galĪpMod.setAttributes(selected);
}
/** * @param previousGrammar * @param previousLexicon * @param grammar * @param lexicon * @param trainStateSetTrees * @return */ public static double doOneEStep( Grammar previousGrammar, Lexicon previousLexicon, Grammar grammar, Lexicon lexicon, StateSetTreeList trainStateSetTrees, boolean updateOnlyLexicon, int unkThreshold) { boolean secondHalf = false; ArrayParser parser = new ArrayParser(previousGrammar, previousLexicon); double trainingLikelihood = 0; int n = 0; int nTrees = trainStateSetTrees.size(); for (Tree<StateSet> stateSetTree : trainStateSetTrees) { secondHalf = (n++ > nTrees / 2.0); boolean noSmoothing = true, debugOutput = false; parser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput); // E Step double ll = stateSetTree.getLabel().getIScore(0); ll = Math.log(ll) + (100 * stateSetTree.getLabel().getIScale()); // System.out.println(stateSetTree); if ((Double.isInfinite(ll) || Double.isNaN(ll))) { if (VERBOSE) { System.out.println("Training sentence " + n + " is given " + ll + " log likelihood!"); System.out.println( "Root iScore " + stateSetTree.getLabel().getIScore(0) + " scale " + stateSetTree.getLabel().getIScale()); } } else { lexicon.trainTree(stateSetTree, -1, previousLexicon, secondHalf, noSmoothing, unkThreshold); if (!updateOnlyLexicon) grammar.tallyStateSetTree(stateSetTree, previousGrammar); // E Step trainingLikelihood += ll; // there are for some reason some sentences that are unparsable } } lexicon.tieRareWordStats(unkThreshold); // SSIE ((SophisticatedLexicon) lexicon).overwriteWithMaxent(); return trainingLikelihood; }
/**
 * Selection handler for the paradigm table {@code vārdgrupas}.
 *
 * <p>Reads the ID in column 0 of the selected row, resolves it with
 * {@code lexicon.paradigmByID}, and passes the result to {@code vgrĪpašībasAtjaunot}.
 * {@code null} is passed when no row is selected or the ID cell is empty. This matches the
 * lexeme and ending selection handlers, which already guard against an empty cell; previously
 * an empty cell here threw NullPointerException. A non-numeric ID propagates
 * {@link NumberFormatException} from {@link Integer#parseInt(String)}.
 */
void vārdgrupaSelektēta() {
  Paradigm vārdgrupa = null;
  final int row = vārdgrupas.getSelectedRow();
  if (row >= 0) {
    final Object idCell = vārdgrupas.getValueAt(row, 0);
    if (idCell != null) {
      vārdgrupa = lexicon.paradigmByID(Integer.parseInt(idCell.toString()));
    }
  }
  vgrĪpašībasAtjaunot(vārdgrupa);
}
private void readTree(String tree, boolean removeDigits) { parsed = Utils.getCatInventory( removeDigits ? treeRemoveDigits(tree.trim()) : tree.trim(), opts.combineNNVBcats); IdGenerator idgen = new IdGenerator(); StringTree gsTree = Lexicon.convertToTree(new ElementaryStringTree("1\t" + parsed, opts.useSemantics), idgen); gsTree.removeUnaryNodes(gsTree.getRoot()); // makeLexiconEntry(); goldStandardNoTraces = gsTree.printNoTraces(); // .printSimpleCat(); }
/** * Converts a {@link Grammar} and {@link Lexicon} to BUBS sparse-matrix format, specifically * {@link LeftCscSparseMatrixGrammar}. Prunes rules below the minimum rule probability threshold * specified when the {@link DiscriminativeMergeObjectiveFunction} was initialized with {@link * #init(CompleteClosureModel, List, List, float)}. * * @param grammar * @param lexicon * @return {@link LeftCscSparseMatrixGrammar} */ protected LeftCscSparseMatrixGrammar convertGrammarToSparseMatrix( final Grammar grammar, final Lexicon lexicon) { try { final Writer w = new StringWriter(150 * 1024 * 1024); // Note We could use a PipedOutputStream / PipedInputStream combination (with 2 threads) to // write and read // at the same time, and avoid using enough memory to serialize the entire grammar. But memory // isn't a huge // constraint during training, and the threading would add complexity, so we'll skip that for // now. w.write(grammar.toString(lexicon.totalRules(minRuleProbability), minRuleProbability, 0, 0)); w.write("===== LEXICON =====\n"); w.write(lexicon.toString(minRuleProbability)); return new LeftCscSparseMatrixGrammar( new StringReader(w.toString()), new DecisionTreeTokenClassifier()); } catch (final IOException e) { // StringWriter and StringReader should never IOException throw new AssertionError(e); } }
public void newGame(View v) { game.flagEN = dictionaryEnglish; if (dictionaryEnglish) { // reset and load the english lexicon if (lexicon == null) { // only create an english lexicon if it doesn't exist already lexicon = new Lexicon(getApplicationContext(), "english.txt"); } else { lexicon.reset(); } game.lexicon = lexicon; } else { // reload the alternative (dutch) lexicon if (alt_lexicon == null) { // only create an alternative lexicon if it doesn't exist already alt_lexicon = new Lexicon(getApplicationContext(), "dutch.txt"); } else { alt_lexicon.reset(); } game.lexicon = alt_lexicon; } game.resetGame(); updateView(); // if button was fired from settings menu, toggle settings fragment if (v == null || v.getId() == R.id.newGame_settings) toggleSettings(); // save new game new AsyncSaveGame(this).execute(game); }
/** * Outputs a textual representation of this <code>SparseWeightVector</code> to a stream just like * {@link #write(PrintStream)}, but without the <code>"Begin"</code> and <code>"End"</code> * annotations. With a <code>Lexicon</code> passed as a parameter, the feature is printed along * with each weight. * * @param out The stream to write to. * @param min Sets the minimum width for the textual representation of all features. * @param lex The feature lexicon. */ public void toStringJustWeights(PrintStream out, int min, Lexicon lex) { Map map = lex.getMap(); Map.Entry[] entries = (Map.Entry[]) map.entrySet().toArray(new Map.Entry[map.size()]); Arrays.sort( entries, new Comparator() { public int compare(Object o1, Object o2) { Map.Entry e1 = (Map.Entry) o1; Map.Entry e2 = (Map.Entry) o2; int i1 = ((Integer) e1.getValue()).intValue(); int i2 = ((Integer) e2.getValue()).intValue(); if ((i1 < weights.size()) != (i2 < weights.size())) return i1 - i2; return ((Feature) e1.getKey()).compareTo(e2.getKey()); } }); int i, biggest = min; for (i = 0; i < entries.length; ++i) { // for (i = 0; i < weights.size(); ++i) String key = entries[i].getKey().toString() + (((Integer) entries[i].getValue()).intValue() < weights.size() ? "" : " (pruned)"); biggest = Math.max(biggest, key.length()); } if (biggest % 2 == 0) biggest += 2; else ++biggest; for (i = 0; i < entries.length; ++i) { // for (i = 0; i < weights.size(); ++i) String key = entries[i].getKey().toString() + (((Integer) entries[i].getValue()).intValue() < weights.size() ? "" : " (pruned)"); out.print(key); for (int j = 0; key.length() + j < biggest; ++j) out.print(" "); int index = ((Integer) entries[i].getValue()).intValue(); out.println(weights.get(index)); } }
public static void main(String[] args) { Options op = new Options(new EnglishTreebankParserParams()); // op.tlpParams may be changed to something else later, so don't use it till // after options are parsed. System.out.println(StringUtils.toInvocationString("FactoredParser", args)); String path = "/u/nlp/stuff/corpora/Treebank3/parsed/mrg/wsj"; int trainLow = 200, trainHigh = 2199, testLow = 2200, testHigh = 2219; String serializeFile = null; int i = 0; while (i < args.length && args[i].startsWith("-")) { if (args[i].equalsIgnoreCase("-path") && (i + 1 < args.length)) { path = args[i + 1]; i += 2; } else if (args[i].equalsIgnoreCase("-train") && (i + 2 < args.length)) { trainLow = Integer.parseInt(args[i + 1]); trainHigh = Integer.parseInt(args[i + 2]); i += 3; } else if (args[i].equalsIgnoreCase("-test") && (i + 2 < args.length)) { testLow = Integer.parseInt(args[i + 1]); testHigh = Integer.parseInt(args[i + 2]); i += 3; } else if (args[i].equalsIgnoreCase("-serialize") && (i + 1 < args.length)) { serializeFile = args[i + 1]; i += 2; } else if (args[i].equalsIgnoreCase("-tLPP") && (i + 1 < args.length)) { try { op.tlpParams = (TreebankLangParserParams) Class.forName(args[i + 1]).newInstance(); } catch (ClassNotFoundException e) { System.err.println("Class not found: " + args[i + 1]); throw new RuntimeException(e); } catch (InstantiationException e) { System.err.println("Couldn't instantiate: " + args[i + 1] + ": " + e.toString()); throw new RuntimeException(e); } catch (IllegalAccessException e) { System.err.println("illegal access" + e); throw new RuntimeException(e); } i += 2; } else if (args[i].equals("-encoding")) { // sets encoding for TreebankLangParserParams op.tlpParams.setInputEncoding(args[i + 1]); op.tlpParams.setOutputEncoding(args[i + 1]); i += 2; } else { i = op.setOptionOrWarn(args, i); } } // System.out.println(tlpParams.getClass()); TreebankLanguagePack tlp = op.tlpParams.treebankLanguagePack(); op.trainOptions.sisterSplitters = new 
HashSet<String>(Arrays.asList(op.tlpParams.sisterSplitters())); // BinarizerFactory.TreeAnnotator.setTreebankLang(tlpParams); PrintWriter pw = op.tlpParams.pw(); op.testOptions.display(); op.trainOptions.display(); op.display(); op.tlpParams.display(); // setup tree transforms Treebank trainTreebank = op.tlpParams.memoryTreebank(); MemoryTreebank testTreebank = op.tlpParams.testMemoryTreebank(); // Treebank blippTreebank = ((EnglishTreebankParserParams) tlpParams).diskTreebank(); // String blippPath = "/afs/ir.stanford.edu/data/linguistic-data/BLLIP-WSJ/"; // blippTreebank.loadPath(blippPath, "", true); Timing.startTime(); System.err.print("Reading trees..."); testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true)); if (op.testOptions.increasingLength) { Collections.sort(testTreebank, new TreeLengthComparator()); } trainTreebank.loadPath(path, new NumberRangeFileFilter(trainLow, trainHigh, true)); Timing.tick("done."); System.err.print("Binarizing trees..."); TreeAnnotatorAndBinarizer binarizer; if (!op.trainOptions.leftToRight) { binarizer = new TreeAnnotatorAndBinarizer( op.tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), true, op); } else { binarizer = new TreeAnnotatorAndBinarizer( op.tlpParams.headFinder(), new LeftHeadFinder(), op.tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), true, op); } CollinsPuncTransformer collinsPuncTransformer = null; if (op.trainOptions.collinsPunc) { collinsPuncTransformer = new CollinsPuncTransformer(tlp); } TreeTransformer debinarizer = new Debinarizer(op.forceCNF); List<Tree> binaryTrainTrees = new ArrayList<Tree>(); if (op.trainOptions.selectiveSplit) { op.trainOptions.splitters = ParentAnnotationStats.getSplitCategories( trainTreebank, op.trainOptions.tagSelectiveSplit, 0, op.trainOptions.selectiveSplitCutOff, op.trainOptions.tagSelectiveSplitCutOff, op.tlpParams.treebankLanguagePack()); if (op.trainOptions.deleteSplitters != null) { List<String> deleted = new ArrayList<String>(); for 
(String del : op.trainOptions.deleteSplitters) { String baseDel = tlp.basicCategory(del); boolean checkBasic = del.equals(baseDel); for (Iterator<String> it = op.trainOptions.splitters.iterator(); it.hasNext(); ) { String elem = it.next(); String baseElem = tlp.basicCategory(elem); boolean delStr = checkBasic && baseElem.equals(baseDel) || elem.equals(del); if (delStr) { it.remove(); deleted.add(elem); } } } System.err.println("Removed from vertical splitters: " + deleted); } } if (op.trainOptions.selectivePostSplit) { TreeTransformer myTransformer = new TreeAnnotator(op.tlpParams.headFinder(), op.tlpParams, op); Treebank annotatedTB = trainTreebank.transform(myTransformer); op.trainOptions.postSplitters = ParentAnnotationStats.getSplitCategories( annotatedTB, true, 0, op.trainOptions.selectivePostSplitCutOff, op.trainOptions.tagSelectivePostSplitCutOff, op.tlpParams.treebankLanguagePack()); } if (op.trainOptions.hSelSplit) { binarizer.setDoSelectiveSplit(false); for (Tree tree : trainTreebank) { if (op.trainOptions.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } // tree.pennPrint(tlpParams.pw()); tree = binarizer.transformTree(tree); // binaryTrainTrees.add(tree); } binarizer.setDoSelectiveSplit(true); } for (Tree tree : trainTreebank) { if (op.trainOptions.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } tree = binarizer.transformTree(tree); binaryTrainTrees.add(tree); } if (op.testOptions.verbose) { binarizer.dumpStats(); } List<Tree> binaryTestTrees = new ArrayList<Tree>(); for (Tree tree : testTreebank) { if (op.trainOptions.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } tree = binarizer.transformTree(tree); binaryTestTrees.add(tree); } Timing.tick("done."); // binarization BinaryGrammar bg = null; UnaryGrammar ug = null; DependencyGrammar dg = null; // DependencyGrammar dgBLIPP = null; Lexicon lex = null; Index<String> stateIndex = new HashIndex<String>(); // extract grammars 
Extractor<Pair<UnaryGrammar, BinaryGrammar>> bgExtractor = new BinaryGrammarExtractor(op, stateIndex); // Extractor bgExtractor = new SmoothedBinaryGrammarExtractor();//new BinaryGrammarExtractor(); // Extractor lexExtractor = new LexiconExtractor(); // Extractor dgExtractor = new DependencyMemGrammarExtractor(); if (op.doPCFG) { System.err.print("Extracting PCFG..."); Pair<UnaryGrammar, BinaryGrammar> bgug = null; if (op.trainOptions.cheatPCFG) { List<Tree> allTrees = new ArrayList<Tree>(binaryTrainTrees); allTrees.addAll(binaryTestTrees); bgug = bgExtractor.extract(allTrees); } else { bgug = bgExtractor.extract(binaryTrainTrees); } bg = bgug.second; bg.splitRules(); ug = bgug.first; ug.purgeRules(); Timing.tick("done."); } System.err.print("Extracting Lexicon..."); Index<String> wordIndex = new HashIndex<String>(); Index<String> tagIndex = new HashIndex<String>(); lex = op.tlpParams.lex(op, wordIndex, tagIndex); lex.train(binaryTrainTrees); Timing.tick("done."); if (op.doDep) { System.err.print("Extracting Dependencies..."); binaryTrainTrees.clear(); Extractor<DependencyGrammar> dgExtractor = new MLEDependencyGrammarExtractor(op, wordIndex, tagIndex); // dgBLIPP = (DependencyGrammar) dgExtractor.extract(new // ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new // TransformTreeDependency(tlpParams,true)); // DependencyGrammar dg1 = dgExtractor.extract(trainTreebank.iterator(), new // TransformTreeDependency(op.tlpParams, true)); // dgBLIPP=(DependencyGrammar)dgExtractor.extract(blippTreebank.iterator(),new // TransformTreeDependency(tlpParams)); // dg = (DependencyGrammar) dgExtractor.extract(new // ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new // TransformTreeDependency(tlpParams)); // dg=new DependencyGrammarCombination(dg1,dgBLIPP,2); dg = dgExtractor.extract( binaryTrainTrees); // uses information whether the words are known or not, discards // unknown words Timing.tick("done."); // 
System.out.print("Extracting Unknown Word Model..."); // UnknownWordModel uwm = (UnknownWordModel)uwmExtractor.extract(binaryTrainTrees); // Timing.tick("done."); System.out.print("Tuning Dependency Model..."); dg.tune(binaryTestTrees); // System.out.println("TUNE DEPS: "+tuneDeps); Timing.tick("done."); } BinaryGrammar boundBG = bg; UnaryGrammar boundUG = ug; GrammarProjection gp = new NullGrammarProjection(bg, ug); // serialization if (serializeFile != null) { System.err.print("Serializing parser..."); LexicalizedParser.saveParserDataToSerialized( new ParserData(lex, bg, ug, dg, stateIndex, wordIndex, tagIndex, op), serializeFile); Timing.tick("done."); } // test: pcfg-parse and output ExhaustivePCFGParser parser = null; if (op.doPCFG) { parser = new ExhaustivePCFGParser(boundBG, boundUG, lex, op, stateIndex, wordIndex, tagIndex); } ExhaustiveDependencyParser dparser = ((op.doDep && !op.testOptions.useFastFactored) ? new ExhaustiveDependencyParser(dg, lex, op, wordIndex, tagIndex) : null); Scorer scorer = (op.doPCFG ? new TwinScorer(new ProjectionScorer(parser, gp, op), dparser) : null); // Scorer scorer = parser; BiLexPCFGParser bparser = null; if (op.doPCFG && op.doDep) { bparser = (op.testOptions.useN5) ? 
new BiLexPCFGParser.N5BiLexPCFGParser( scorer, parser, dparser, bg, ug, dg, lex, op, gp, stateIndex, wordIndex, tagIndex) : new BiLexPCFGParser( scorer, parser, dparser, bg, ug, dg, lex, op, gp, stateIndex, wordIndex, tagIndex); } Evalb pcfgPE = new Evalb("pcfg PE", true); Evalb comboPE = new Evalb("combo PE", true); AbstractEval pcfgCB = new Evalb.CBEval("pcfg CB", true); AbstractEval pcfgTE = new TaggingEval("pcfg TE"); AbstractEval comboTE = new TaggingEval("combo TE"); AbstractEval pcfgTEnoPunct = new TaggingEval("pcfg nopunct TE"); AbstractEval comboTEnoPunct = new TaggingEval("combo nopunct TE"); AbstractEval depTE = new TaggingEval("depnd TE"); AbstractEval depDE = new UnlabeledAttachmentEval("depnd DE", true, null, tlp.punctuationWordRejectFilter()); AbstractEval comboDE = new UnlabeledAttachmentEval("combo DE", true, null, tlp.punctuationWordRejectFilter()); if (op.testOptions.evalb) { EvalbFormatWriter.initEVALBfiles(op.tlpParams); } // int[] countByLength = new int[op.testOptions.maxLength+1]; // Use a reflection ruse, so one can run this without needing the // tagger. Using a function rather than a MaxentTagger means we // can distribute a version of the parser that doesn't include the // entire tagger. Function<List<? extends HasWord>, ArrayList<TaggedWord>> tagger = null; if (op.testOptions.preTag) { try { Class[] argsClass = {String.class}; Object[] arguments = new Object[] {op.testOptions.taggerSerializedFile}; tagger = (Function<List<? 
extends HasWord>, ArrayList<TaggedWord>>) Class.forName("edu.stanford.nlp.tagger.maxent.MaxentTagger") .getConstructor(argsClass) .newInstance(arguments); } catch (Exception e) { System.err.println(e); System.err.println("Warning: No pretagging of sentences will be done."); } } for (int tNum = 0, ttSize = testTreebank.size(); tNum < ttSize; tNum++) { Tree tree = testTreebank.get(tNum); int testTreeLen = tree.yield().size(); if (testTreeLen > op.testOptions.maxLength) { continue; } Tree binaryTree = binaryTestTrees.get(tNum); // countByLength[testTreeLen]++; System.out.println("-------------------------------------"); System.out.println("Number: " + (tNum + 1)); System.out.println("Length: " + testTreeLen); // tree.pennPrint(pw); // System.out.println("XXXX The binary tree is"); // binaryTree.pennPrint(pw); // System.out.println("Here are the tags in the lexicon:"); // System.out.println(lex.showTags()); // System.out.println("Here's the tagnumberer:"); // System.out.println(Numberer.getGlobalNumberer("tags").toString()); long timeMil1 = System.currentTimeMillis(); Timing.tick("Starting parse."); if (op.doPCFG) { // System.err.println(op.testOptions.forceTags); if (op.testOptions.forceTags) { if (tagger != null) { // System.out.println("Using a tagger to set tags"); // System.out.println("Tagged sentence as: " + // tagger.processSentence(cutLast(wordify(binaryTree.yield()))).toString(false)); parser.parse(addLast(tagger.apply(cutLast(wordify(binaryTree.yield()))))); } else { // System.out.println("Forcing tags to match input."); parser.parse(cleanTags(binaryTree.taggedYield(), tlp)); } } else { // System.out.println("XXXX Parsing " + binaryTree.yield()); parser.parse(binaryTree.yieldHasWord()); } // Timing.tick("Done with pcfg phase."); } if (op.doDep) { dparser.parse(binaryTree.yieldHasWord()); // Timing.tick("Done with dependency phase."); } boolean bothPassed = false; if (op.doPCFG && op.doDep) { bothPassed = bparser.parse(binaryTree.yieldHasWord()); // 
Timing.tick("Done with combination phase."); } long timeMil2 = System.currentTimeMillis(); long elapsed = timeMil2 - timeMil1; System.err.println("Time: " + ((int) (elapsed / 100)) / 10.00 + " sec."); // System.out.println("PCFG Best Parse:"); Tree tree2b = null; Tree tree2 = null; // System.out.println("Got full best parse..."); if (op.doPCFG) { tree2b = parser.getBestParse(); tree2 = debinarizer.transformTree(tree2b); } // System.out.println("Debinarized parse..."); // tree2.pennPrint(); // System.out.println("DepG Best Parse:"); Tree tree3 = null; Tree tree3db = null; if (op.doDep) { tree3 = dparser.getBestParse(); // was: but wrong Tree tree3db = debinarizer.transformTree(tree2); tree3db = debinarizer.transformTree(tree3); tree3.pennPrint(pw); } // tree.pennPrint(); // ((Tree)binaryTrainTrees.get(tNum)).pennPrint(); // System.out.println("Combo Best Parse:"); Tree tree4 = null; if (op.doPCFG && op.doDep) { try { tree4 = bparser.getBestParse(); if (tree4 == null) { tree4 = tree2b; } } catch (NullPointerException e) { System.err.println("Blocked, using PCFG parse!"); tree4 = tree2b; } } if (op.doPCFG && !bothPassed) { tree4 = tree2b; } // tree4.pennPrint(); if (op.doDep) { depDE.evaluate(tree3, binaryTree, pw); depTE.evaluate(tree3db, tree, pw); } TreeTransformer tc = op.tlpParams.collinizer(); TreeTransformer tcEvalb = op.tlpParams.collinizerEvalb(); if (op.doPCFG) { // System.out.println("XXXX Best PCFG was: "); // tree2.pennPrint(); // System.out.println("XXXX Transformed best PCFG is: "); // tc.transformTree(tree2).pennPrint(); // System.out.println("True Best Parse:"); // tree.pennPrint(); // tc.transformTree(tree).pennPrint(); pcfgPE.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw); pcfgCB.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw); Tree tree4b = null; if (op.doDep) { comboDE.evaluate((bothPassed ? 
tree4 : tree3), binaryTree, pw); tree4b = tree4; tree4 = debinarizer.transformTree(tree4); if (op.nodePrune) { NodePruner np = new NodePruner(parser, debinarizer); tree4 = np.prune(tree4); } // tree4.pennPrint(); comboPE.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw); } // pcfgTE.evaluate(tree2, tree); pcfgTE.evaluate(tcEvalb.transformTree(tree2), tcEvalb.transformTree(tree), pw); pcfgTEnoPunct.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw); if (op.doDep) { comboTE.evaluate(tcEvalb.transformTree(tree4), tcEvalb.transformTree(tree), pw); comboTEnoPunct.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw); } System.out.println("PCFG only: " + parser.scoreBinarizedTree(tree2b, 0)); // tc.transformTree(tree2).pennPrint(); tree2.pennPrint(pw); if (op.doDep) { System.out.println("Combo: " + parser.scoreBinarizedTree(tree4b, 0)); // tc.transformTree(tree4).pennPrint(pw); tree4.pennPrint(pw); } System.out.println("Correct:" + parser.scoreBinarizedTree(binaryTree, 0)); /* if (parser.scoreBinarizedTree(tree2b,true) < parser.scoreBinarizedTree(binaryTree,true)) { System.out.println("SCORE INVERSION"); parser.validateBinarizedTree(binaryTree,0); } */ tree.pennPrint(pw); } // end if doPCFG if (op.testOptions.evalb) { if (op.doPCFG && op.doDep) { EvalbFormatWriter.writeEVALBline( tcEvalb.transformTree(tree), tcEvalb.transformTree(tree4)); } else if (op.doPCFG) { EvalbFormatWriter.writeEVALBline( tcEvalb.transformTree(tree), tcEvalb.transformTree(tree2)); } else if (op.doDep) { EvalbFormatWriter.writeEVALBline( tcEvalb.transformTree(tree), tcEvalb.transformTree(tree3db)); } } } // end for each tree in test treebank if (op.testOptions.evalb) { EvalbFormatWriter.closeEVALBfiles(); } // op.testOptions.display(); if (op.doPCFG) { pcfgPE.display(false, pw); System.out.println("Grammar size: " + stateIndex.size()); pcfgCB.display(false, pw); if (op.doDep) { comboPE.display(false, pw); } pcfgTE.display(false, pw); pcfgTEnoPunct.display(false, pw); 
if (op.doDep) { comboTE.display(false, pw); comboTEnoPunct.display(false, pw); } } if (op.doDep) { depTE.display(false, pw); depDE.display(false, pw); } if (op.doPCFG && op.doDep) { comboDE.display(false, pw); } // pcfgPE.printGoodBad(); }
public static void main(String[] args) { OptionParser optParser = new OptionParser(Options.class); Options opts = (Options) optParser.parse(args, true); // provide feedback on command-line arguments System.out.println("Calling with " + optParser.getPassedInOptions()); String path = opts.path; // int lang = opts.lang; System.out.println("Loading trees from " + path + " and using language " + opts.treebank); double trainingFractionToKeep = opts.trainingFractionToKeep; int maxSentenceLength = opts.maxSentenceLength; System.out.println("Will remove sentences with more than " + maxSentenceLength + " words."); HORIZONTAL_MARKOVIZATION = opts.horizontalMarkovization; VERTICAL_MARKOVIZATION = opts.verticalMarkovization; System.out.println( "Using horizontal=" + HORIZONTAL_MARKOVIZATION + " and vertical=" + VERTICAL_MARKOVIZATION + " markovization."); Binarization binarization = opts.binarization; System.out.println( "Using " + binarization.name() + " binarization."); // and "+annotateString+"."); double randomness = opts.randomization; System.out.println("Using a randomness value of " + randomness); String outFileName = opts.outFileName; if (outFileName == null) { System.out.println("Output File name is required."); System.exit(-1); } else System.out.println("Using grammar output file " + outFileName + "."); VERBOSE = opts.verbose; RANDOM = new Random(opts.randSeed); System.out.println("Random number generator seeded at " + opts.randSeed + "."); boolean manualAnnotation = false; boolean baseline = opts.baseline; boolean noSplit = opts.noSplit; int numSplitTimes = opts.numSplits; if (baseline) numSplitTimes = 0; String splitGrammarFile = opts.inFile; int allowedDroppingIters = opts.di; int maxIterations = opts.splitMaxIterations; int minIterations = opts.splitMinIterations; if (minIterations > 0) System.out.println("I will do at least " + minIterations + " iterations."); double[] smoothParams = {opts.smoothingParameter1, opts.smoothingParameter2}; System.out.println("Using 
smoothing parameters " + smoothParams[0] + " and " + smoothParams[1]); boolean allowMoreSubstatesThanCounts = false; boolean findClosedUnaryPaths = opts.findClosedUnaryPaths; Corpus corpus = new Corpus( path, opts.treebank, trainingFractionToKeep, false, opts.skipSection, opts.skipBilingual); List<Tree<String>> trainTrees = Corpus.binarizeAndFilterTrees( corpus.getTrainTrees(), VERTICAL_MARKOVIZATION, HORIZONTAL_MARKOVIZATION, maxSentenceLength, binarization, manualAnnotation, VERBOSE); List<Tree<String>> validationTrees = Corpus.binarizeAndFilterTrees( corpus.getValidationTrees(), VERTICAL_MARKOVIZATION, HORIZONTAL_MARKOVIZATION, maxSentenceLength, binarization, manualAnnotation, VERBOSE); Numberer tagNumberer = Numberer.getGlobalNumberer("tags"); // for (Tree<String> t : trainTrees){ // System.out.println(t); // } if (opts.trainOnDevSet) { System.out.println("Adding devSet to training data."); trainTrees.addAll(validationTrees); } if (opts.lowercase) { System.out.println("Lowercasing the treebank."); Corpus.lowercaseWords(trainTrees); Corpus.lowercaseWords(validationTrees); } int nTrees = trainTrees.size(); System.out.println("There are " + nTrees + " trees in the training set."); double filter = opts.filter; if (filter > 0) System.out.println( "Will remove rules with prob under " + filter + ".\nEven though only unlikely rules are pruned the training LL is not guaranteed to increase in every round anymore " + "(especially when we are close to converging)." 
+ "\nFurthermore it increases the variance because 'good' rules can be pruned away in early stages."); short nSubstates = opts.nSubStates; short[] numSubStatesArray = initializeSubStateArray(trainTrees, validationTrees, tagNumberer, nSubstates); if (baseline) { short one = 1; Arrays.fill(numSubStatesArray, one); System.out.println("Training just the baseline grammar (1 substate for all states)"); randomness = 0.0f; } if (VERBOSE) { for (int i = 0; i < numSubStatesArray.length; i++) { System.out.println("Tag " + (String) tagNumberer.object(i) + " " + i); } } System.out.println("There are " + numSubStatesArray.length + " observed categories."); // initialize lexicon and grammar Lexicon lexicon = null, maxLexicon = null, previousLexicon = null; Grammar grammar = null, maxGrammar = null, previousGrammar = null; double maxLikelihood = Double.NEGATIVE_INFINITY; // String smootherStr = opts.smooth; // Smoother lexiconSmoother = null; // Smoother grammarSmoother = null; // if (splitGrammarFile!=null){ // lexiconSmoother = maxLexicon.smoother; // grammarSmoother = maxGrammar.smoother; // System.out.println("Using smoother from input grammar."); // } // else if (smootherStr.equals("NoSmoothing")) // lexiconSmoother = grammarSmoother = new NoSmoothing(); // else if (smootherStr.equals("SmoothAcrossParentBits")) { // lexiconSmoother = grammarSmoother = new SmoothAcrossParentBits(grammarSmoothing, // maxGrammar.splitTrees); // } // else // throw new Error("I didn't understand the type of smoother '"+smootherStr+"'"); // System.out.println("Using smoother "+smootherStr); // EM: iterate until the validation likelihood drops for four consecutive // iterations int iter = 0; int droppingIter = 0; // If we are splitting, we load the old grammar and start off by splitting. 
int startSplit = 0; if (splitGrammarFile != null) { System.out.println("Loading old grammar from " + splitGrammarFile); startSplit = 1; // we've already trained the grammar ParserData pData = ParserData.Load(splitGrammarFile); maxGrammar = pData.gr; maxLexicon = pData.lex; numSubStatesArray = maxGrammar.numSubStates; previousGrammar = grammar = maxGrammar; previousLexicon = lexicon = maxLexicon; Numberer.setNumberers(pData.getNumbs()); tagNumberer = Numberer.getGlobalNumberer("tags"); System.out.println("Loading old grammar complete."); if (noSplit) { System.out.println("Will NOT split the loaded grammar."); startSplit = 0; } } double mergingPercentage = opts.mergingPercentage; boolean separateMergingThreshold = opts.separateMergingThreshold; if (mergingPercentage > 0) { System.out.println( "Will merge " + (int) (mergingPercentage * 100) + "% of the splits in each round."); System.out.println( "The threshold for merging lexical and phrasal categories will be set separately: " + separateMergingThreshold); } StateSetTreeList trainStateSetTrees = new StateSetTreeList(trainTrees, numSubStatesArray, false, tagNumberer); StateSetTreeList validationStateSetTrees = new StateSetTreeList(validationTrees, numSubStatesArray, false, tagNumberer); // deletePC); // get rid of the old trees trainTrees = null; validationTrees = null; corpus = null; System.gc(); if (opts.simpleLexicon) { System.out.println( "Replacing words which have been seen less than 5 times with their signature."); Corpus.replaceRareWords( trainStateSetTrees, new SimpleLexicon(numSubStatesArray, -1), opts.rare); } // If we're training without loading a split grammar, then we run once without splitting. if (splitGrammarFile == null) { grammar = new Grammar(numSubStatesArray, findClosedUnaryPaths, new NoSmoothing(), null, filter); Lexicon tmp_lexicon = (opts.simpleLexicon) ? 
new SimpleLexicon( numSubStatesArray, -1, smoothParams, new NoSmoothing(), filter, trainStateSetTrees) : new SophisticatedLexicon( numSubStatesArray, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, smoothParams, new NoSmoothing(), filter); int n = 0; boolean secondHalf = false; for (Tree<StateSet> stateSetTree : trainStateSetTrees) { secondHalf = (n++ > nTrees / 2.0); tmp_lexicon.trainTree(stateSetTree, randomness, null, secondHalf, false, opts.rare); } lexicon = (opts.simpleLexicon) ? new SimpleLexicon( numSubStatesArray, -1, smoothParams, new NoSmoothing(), filter, trainStateSetTrees) : new SophisticatedLexicon( numSubStatesArray, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, smoothParams, new NoSmoothing(), filter); for (Tree<StateSet> stateSetTree : trainStateSetTrees) { secondHalf = (n++ > nTrees / 2.0); lexicon.trainTree(stateSetTree, randomness, tmp_lexicon, secondHalf, false, opts.rare); grammar.tallyUninitializedStateSetTree(stateSetTree); } lexicon.tieRareWordStats(opts.rare); lexicon.optimize(); // SSIE ((SophisticatedLexicon) lexicon).overwriteWithMaxent(); grammar.optimize(randomness); // System.out.println(grammar); previousGrammar = maxGrammar = grammar; // needed for baseline - when there is no EM loop previousLexicon = maxLexicon = lexicon; } // the main loop: split and train the grammar for (int splitIndex = startSplit; splitIndex < numSplitTimes * 3; splitIndex++) { // now do either a merge or a split and the end a smooth // on odd iterations merge, on even iterations split String opString = ""; if (splitIndex % 3 == 2) { // (splitIndex==numSplitTimes*2){ if (opts.smooth.equals("NoSmoothing")) continue; System.out.println("Setting smoother for grammar and lexicon."); Smoother grSmoother = new SmoothAcrossParentBits(0.01, maxGrammar.splitTrees); Smoother lexSmoother = new SmoothAcrossParentBits(0.1, maxGrammar.splitTrees); // Smoother grSmoother = new SmoothAcrossParentSubstate(0.01); // Smoother lexSmoother = new SmoothAcrossParentSubstate(0.1); 
maxGrammar.setSmoother(grSmoother); maxLexicon.setSmoother(lexSmoother); minIterations = maxIterations = opts.smoothMaxIterations; opString = "smoothing"; } else if (splitIndex % 3 == 0) { // the case where we split if (opts.noSplit) continue; System.out.println( "Before splitting, we have a total of " + maxGrammar.totalSubStates() + " substates."); CorpusStatistics corpusStatistics = new CorpusStatistics(tagNumberer, trainStateSetTrees); int[] counts = corpusStatistics.getSymbolCounts(); maxGrammar = maxGrammar.splitAllStates(randomness, counts, allowMoreSubstatesThanCounts, 0); maxLexicon = maxLexicon.splitAllStates(counts, allowMoreSubstatesThanCounts, 0); Smoother grSmoother = new NoSmoothing(); Smoother lexSmoother = new NoSmoothing(); maxGrammar.setSmoother(grSmoother); maxLexicon.setSmoother(lexSmoother); System.out.println( "After splitting, we have a total of " + maxGrammar.totalSubStates() + " substates."); System.out.println( "Rule probabilities are NOT normalized in the split, therefore the training LL is not guaranteed to improve between iteration 0 and 1!"); opString = "splitting"; maxIterations = opts.splitMaxIterations; minIterations = opts.splitMinIterations; } else { if (mergingPercentage == 0) continue; // the case where we merge double[][] mergeWeights = GrammarMerger.computeMergeWeights(maxGrammar, maxLexicon, trainStateSetTrees); double[][][] deltas = GrammarMerger.computeDeltas(maxGrammar, maxLexicon, mergeWeights, trainStateSetTrees); boolean[][][] mergeThesePairs = GrammarMerger.determineMergePairs( deltas, separateMergingThreshold, mergingPercentage, maxGrammar); grammar = GrammarMerger.doTheMerges(maxGrammar, maxLexicon, mergeThesePairs, mergeWeights); short[] newNumSubStatesArray = grammar.numSubStates; trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, newNumSubStatesArray, false); validationStateSetTrees = new StateSetTreeList(validationStateSetTrees, newNumSubStatesArray, false); // retrain lexicon to finish the lexicon 
merge (updates the unknown words model)... lexicon = (opts.simpleLexicon) ? new SimpleLexicon( newNumSubStatesArray, -1, smoothParams, maxLexicon.getSmoother(), filter, trainStateSetTrees) : new SophisticatedLexicon( newNumSubStatesArray, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, maxLexicon.getSmoothingParams(), maxLexicon.getSmoother(), maxLexicon.getPruningThreshold()); boolean updateOnlyLexicon = true; double trainingLikelihood = GrammarTrainer.doOneEStep( grammar, maxLexicon, null, lexicon, trainStateSetTrees, updateOnlyLexicon, opts.rare); // System.out.println("The training LL is "+trainingLikelihood); lexicon .optimize(); // Grammar.RandomInitializationType.INITIALIZE_WITH_SMALL_RANDOMIZATION); // // M Step GrammarMerger.printMergingStatistics(maxGrammar, grammar); opString = "merging"; maxGrammar = grammar; maxLexicon = lexicon; maxIterations = opts.mergeMaxIterations; minIterations = opts.mergeMinIterations; } // update the substate dependent objects previousGrammar = grammar = maxGrammar; previousLexicon = lexicon = maxLexicon; droppingIter = 0; numSubStatesArray = grammar.numSubStates; trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, numSubStatesArray, false); validationStateSetTrees = new StateSetTreeList(validationStateSetTrees, numSubStatesArray, false); maxLikelihood = calculateLogLikelihood(maxGrammar, maxLexicon, validationStateSetTrees); System.out.println( "After " + opString + " in the " + (splitIndex / 3 + 1) + "th round, we get a validation likelihood of " + maxLikelihood); iter = 0; // the inner loop: train the grammar via EM until validation likelihood reliably drops do { iter += 1; System.out.println("Beginning iteration " + (iter - 1) + ":"); // 1) Compute the validation likelihood of the previous iteration System.out.print("Calculating validation likelihood..."); double validationLikelihood = calculateLogLikelihood( previousGrammar, previousLexicon, validationStateSetTrees); // The validation LL of 
previousGrammar/previousLexicon System.out.println("done: " + validationLikelihood); // 2) Perform the E step while computing the training likelihood of the previous iteration System.out.print("Calculating training likelihood..."); grammar = new Grammar( grammar.numSubStates, grammar.findClosedPaths, grammar.smoother, grammar, grammar.threshold); lexicon = (opts.simpleLexicon) ? new SimpleLexicon( grammar.numSubStates, -1, smoothParams, lexicon.getSmoother(), filter, trainStateSetTrees) : new SophisticatedLexicon( grammar.numSubStates, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, lexicon.getSmoothingParams(), lexicon.getSmoother(), lexicon.getPruningThreshold()); boolean updateOnlyLexicon = false; double trainingLikelihood = doOneEStep( previousGrammar, previousLexicon, grammar, lexicon, trainStateSetTrees, updateOnlyLexicon, opts.rare); // The training LL of previousGrammar/previousLexicon System.out.println("done: " + trainingLikelihood); // 3) Perform the M-Step lexicon.optimize(); // M Step grammar.optimize(0); // M Step // 4) Check whether previousGrammar/previousLexicon was in fact better than the best if (iter < minIterations || validationLikelihood >= maxLikelihood) { maxLikelihood = validationLikelihood; maxGrammar = previousGrammar; maxLexicon = previousLexicon; droppingIter = 0; } else { droppingIter++; } // 5) advance the 'pointers' previousGrammar = grammar; previousLexicon = lexicon; } while ((droppingIter < allowedDroppingIters) && (!baseline) && (iter < maxIterations)); // Dump a grammar file to disk from time to time ParserData pData = new ParserData( maxLexicon, maxGrammar, null, Numberer.getNumberers(), numSubStatesArray, VERTICAL_MARKOVIZATION, HORIZONTAL_MARKOVIZATION, binarization); String outTmpName = outFileName + "_" + (splitIndex / 3 + 1) + "_" + opString + ".gr"; System.out.println("Saving grammar to " + outTmpName + "."); if (pData.Save(outTmpName)) System.out.println("Saving successful."); else System.out.println("Saving failed!"); 
pData = null; } // The last grammar/lexicon has not yet been evaluated. Even though the validation likelihood // has been dropping in the past few iteration, there is still a chance that the last one was in // fact the best so just in case we evaluate it. System.out.print("Calculating last validation likelihood..."); double validationLikelihood = calculateLogLikelihood(grammar, lexicon, validationStateSetTrees); System.out.println( "done.\n Iteration " + iter + " (final) gives validation likelihood " + validationLikelihood); if (validationLikelihood > maxLikelihood) { maxLikelihood = validationLikelihood; maxGrammar = previousGrammar; maxLexicon = previousLexicon; } ParserData pData = new ParserData( maxLexicon, maxGrammar, null, Numberer.getNumberers(), numSubStatesArray, VERTICAL_MARKOVIZATION, HORIZONTAL_MARKOVIZATION, binarization); System.out.println("Saving grammar to " + outFileName + "."); System.out.println("It gives a validation data log likelihood of: " + maxLikelihood); if (pData.Save(outFileName)) System.out.println("Saving successful."); else System.out.println("Saving failed!"); System.exit(0); }
public Tree<String> getBestParse(List<String> sentence) { // TODO: implement this method int n = sentence.size(); // System.out.println("getBestParse: n=" + n); List<List<Map<Object, Double>>> scores = new ArrayList<List<Map<Object, Double>>>(n + 1); for (int i = 0; i < n + 1; i++) { List<Map<Object, Double>> row = new ArrayList<Map<Object, Double>>(n + 1); for (int j = 0; j < n + 1; j++) { row.add(new HashMap<Object, Double>()); } scores.add(row); } List<List<Map<Object, Triplet<Integer, Object, Object>>>> backs = new ArrayList<List<Map<Object, Triplet<Integer, Object, Object>>>>(n + 1); for (int i = 0; i < n + 1; i++) { List<Map<Object, Triplet<Integer, Object, Object>>> row = new ArrayList<Map<Object, Triplet<Integer, Object, Object>>>(n + 1); for (int j = 0; j < n + 1; j++) { row.add(new HashMap<Object, Triplet<Integer, Object, Object>>()); } backs.add(row); } /* System.out.println("scores=" + scores.size() + "x" + scores.get(0).size()); System.out.println("backs=" + backs.size() + "x" + backs.get(0).size()); printChart(scores, backs, "scores"); */ // First the Lexicon for (int i = 0; i < n; i++) { String word = sentence.get(i); for (String tag : lexicon.getAllTags()) { UnaryRule A = new UnaryRule(tag, word); A.setScore(Math.log(lexicon.scoreTagging(word, tag))); scores.get(i).get(i + 1).put(A, A.getScore()); backs.get(i).get(i + 1).put(A, null); } // System.out.println("Starting unaries: i=" + i + ",n=" + n ); // Handle unaries boolean added = true; while (added) { added = false; Map<Object, Double> A_scores = scores.get(i).get(i + 1); // Don't modify the dict we are iterating List<Object> A_keys = copyKeys(A_scores); // for (int j = 0; j < 5 && j < A_keys.size(); j++) { // System.out.print("," + j + "=" + A_scores.get(A_keys.get(j))); // } for (Object oB : A_keys) { UnaryRule B = (UnaryRule) oB; for (UnaryRule A : grammar.getUnaryRulesByChild(B.getParent())) { double prob = Math.log(A.getScore()) + A_scores.get(B); if (prob > -1000.0) { if 
(!A_scores.containsKey(A) || prob > A_scores.get(A)) { // System.out.print(" *A=" + A + ", B=" + B); // System.out.print(", prob=" + prob); // System.out.println(", A_scores.get(A)=" + A_scores.get(A)); A_scores.put(A, prob); backs.get(i).get(i + 1).put(A, new Triplet<Integer, Object, Object>(-1, B, null)); added = true; } // System.out.println(", added=" + added); } } } // System.out.println(", A_scores=" + A_scores.size() + ", added=" + added); } } // printChart(scores, backs, "scores with Lexicon"); // Do higher layers // Naming is based on rules: A -> B,C long startTime = new Date().getTime(); for (int span = 2; span < n + 1; span++) { for (int begin = 0; begin < n + 1 - span; begin++) { int end = begin + span; Map<Object, Double> A_scores = scores.get(begin).get(end); Map<Object, Triplet<Integer, Object, Object>> A_backs = backs.get(begin).get(end); for (int split = begin + 1; split < end; split++) { Map<Object, Double> B_scores = scores.get(begin).get(split); Map<Object, Double> C_scores = scores.get(split).get(end); List<Object> B_list = new ArrayList<Object>(B_scores.keySet()); List<Object> C_list = new ArrayList<Object>(C_scores.keySet()); // This is a key optimization. 
!@#$ // It avoids a B_list.size() x C_list.size() search in the for (Object B : B_list) loop Map<String, List<Object>> C_map = new HashMap<String, List<Object>>(); for (Object C : C_list) { String parent = getParent(C); if (!C_map.containsKey(parent)) { C_map.put(parent, new ArrayList<Object>()); } C_map.get(parent).add(C); } for (Object B : B_list) { for (BinaryRule A : grammar.getBinaryRulesByLeftChild(getParent(B))) { if (C_map.containsKey(A.getRightChild())) { for (Object C : C_map.get(A.getRightChild())) { // We now have A which has B as left child and C as right child double prob = Math.log(A.getScore()) + B_scores.get(B) + C_scores.get(C); if (!A_scores.containsKey(A) || prob > A_scores.get(A)) { A_scores.put(A, prob); A_backs.put(A, new Triplet<Integer, Object, Object>(split, B, C)); } } } } } } // Handle unaries: A -> B boolean added = true; while (added) { added = false; // Don't modify the dict we are iterating List<Object> A_keys = copyKeys(A_scores); for (Object oB : A_keys) { for (UnaryRule A : grammar.getUnaryRulesByChild(getParent(oB))) { double prob = Math.log(A.getScore()) + A_scores.get(oB); if (!A_scores.containsKey(A) || prob > A_scores.get(A)) { A_scores.put(A, prob); A_backs.put(A, new Triplet<Integer, Object, Object>(-1, oB, null)); added = true; } } } } } } // printChart(scores, backs, "scores with Lexicon and Grammar"); Map<Object, Double> topOfChart = scores.get(0).get(n); System.out.println("topOfChart: " + topOfChart.size()); /* for (Object o: topOfChart.keySet()) { System.out.println("o=" + o + ", score=" + topOfChart.getCount(o)); } */ // All parses have "ROOT" at top of tree Object bestKey = null; Object secondBestKey = null; double bestScore = Double.NEGATIVE_INFINITY; double secondBestScore = Double.NEGATIVE_INFINITY; for (Object key : topOfChart.keySet()) { double score = topOfChart.get(key); if (score >= secondBestScore || secondBestKey == null) { secondBestKey = key; secondBestScore = score; } if ("ROOT".equals(getParent(key)) 
&& (score >= bestScore || bestKey == null)) { bestKey = key; bestScore = score; } } if (bestKey == null) { bestKey = secondBestKey; System.out.println("secondBestKey=" + secondBestKey); } if (bestKey == null) { for (Object key : topOfChart.keySet()) { System.out.println("val=" + topOfChart.get(key) + ", key=" + key); } } System.out.println("bestKey=" + bestKey + ", log(prob)=" + topOfChart.get(bestKey)); Tree<String> result = makeTree(backs, 0, n, bestKey); if (!"ROOT".equals(result.getLabel())) { List<Tree<String>> children = new ArrayList<Tree<String>>(); children.add(result); result = new Tree<String>("ROOT", children); // !@#$ } /* System.out.println("=================================================="); System.out.println(result); System.out.println("====================^^^^^^========================"); */ return TreeAnnotations.unAnnotateTree(result); }
/**
 * Sets up the game screen and decides which game to show.
 *
 * <p>In order of precedence the game comes from:
 *
 * <ul>
 *   <li>the saved instance state;
 *   <li>a serialized {@code "game"} intent extra, i.e. a recent game being reopened, whose
 *       players are refreshed from the database and whose lexicon is rebuilt;
 *   <li>a {@code "gameID"} extra looked up in the database, used when the user switches the
 *       system locale;
 *   <li>otherwise, a new game created for the {@code "p1"}/{@code "p2"} players and the
 *       {@code "flag_EN"} language extra, which is inserted into the database.
 * </ul>
 *
 * <p>The player avatars are shown when both players are known.
 */
@TargetApi(Build.VERSION_CODES.KITKAT)
@Override
protected void onCreate(Bundle savedInstanceState) {
  super.onCreate(savedInstanceState);
  setContentView(R.layout.activity_game);
  myDB = new DBHelper(this);
  Bundle extras = getIntent().getExtras();
  // NOTE(review): extras is dereferenced whenever there is no saved state, so an intent
  // without extras throws a NullPointerException here. Confirm every caller supplies extras.
  Game passedGame = (savedInstanceState == null) ? (Game) extras.getSerializable("game") : null;

  if (savedInstanceState != null) {
    // Restore game state if there is any
    game = (Game) savedInstanceState.getSerializable("game");
  } else if (passedGame != null) {
    // here a game is recreated (from recent games)
    game = passedGame;
    if (extras.containsKey("flag_EN")) {
      game.flagEN = extras.getBoolean("flag_EN");
    }
    // update players in this game from DB because score may have changed from other games
    player1 = myDB.getPlayer(game.p1.getId());
    player2 = myDB.getPlayer(game.p2.getId());
    game.setPlayers(player1, player2);
    // determine which language and rebuild lexicon, re-applying the letters already guessed
    if (game.flagEN) {
      lexicon = new Lexicon(this, "english.txt");
      lexicon.filter(game.guessedLetters);
      game.lexicon = lexicon;
    } else {
      alt_lexicon = new Lexicon(this, "dutch.txt");
      alt_lexicon.filter(game.guessedLetters);
      game.lexicon = alt_lexicon;
    }
    dictionaryEnglish = game.flagEN;
  } else if (extras.containsKey("gameID")) {
    // this is only true if a user switches locale (system) languages
    dictionaryEnglish = extras.getBoolean("flag_EN");
    // sync the game with game from DB
    game = myDB.getGame(extras.getString("gameID"));
    player1 = (Player) extras.getSerializable("p1");
    player2 = (Player) extras.getSerializable("p2");
  } else {
    // creating a new game from scratch: load up the two players from intent
    if (player1 == null) {
      player1 = (Player) extras.getSerializable("p1");
    }
    if (player2 == null) {
      player2 = (Player) extras.getSerializable("p2");
    }
    // get language from intent
    dictionaryEnglish = extras.getBoolean("flag_EN");
    // create new lexicon only once here for selected language
    if (dictionaryEnglish) {
      lexicon = new Lexicon(this, "english.txt");
      game = new Game(lexicon);
    } else {
      alt_lexicon = new Lexicon(this, "dutch.txt");
      game = new Game(alt_lexicon);
    }
    game.setPlayers(player1, player2);
    game.flagEN = dictionaryEnglish;
    myDB.insertGame(game);
  }

  // set the two avatars
  if (player1 != null && player2 != null) {
    ImageView p1image = (ImageView) findViewById(R.id.p1avatar);
    ImageView p2image = (ImageView) findViewById(R.id.p2avatar);
    p1image.setImageResource(player1.avatarId);
    p2image.setImageResource(player2.avatarId);
  }
  createSettingsFragment();
  updateView();
}
public void stocGradTrain(Parser parser, boolean testEachRound) { int numUpdates = 0; List<LexEntry> fixedEntries = new LinkedList<LexEntry>(); fixedEntries.addAll(parser.returnLex().getLexicon()); // add all sentential lexical entries. for (int l = 0; l < trainData.size(); l++) { parser.addLexEntries(trainData.getDataSet(l).makeSentEntries()); } parser.setGlobals(); DataSet data = null; // for each pass over the data for (int j = 0; j < EPOCHS; j++) { System.out.println("Training, iteration " + j); int total = 0, correct = 0, wrong = 0, looCorrect = 0, looWrong = 0; for (int l = 0; l < trainData.size(); l++) { // the variables to hold the current training example String words = null; Exp sem = null; data = trainData.getDataSet(l); if (verbose) System.out.println("---------------------"); String filename = trainData.getFilename(l); if (verbose) System.out.println("DataSet: " + filename); if (verbose) System.out.println("---------------------"); // loop through the training examples // try to create lexical entries for each training example for (int i = 0; i < data.size(); i++) { // print running stats if (verbose) { if (total != 0) { double r = (double) correct / total; double p = (double) correct / (correct + wrong); System.out.print(i + ": =========== r:" + r + " p:" + p); System.out.println(" (epoch:" + j + " file:" + l + " " + filename + ")"); } else System.out.println(i + ": ==========="); } // get the training example words = data.sent(i); sem = data.sem(i); if (verbose) { System.out.println(words); System.out.println(sem); } List<String> tokens = Parser.tokenize(words); if (tokens.size() > maxSentLen) continue; total++; String mes = null; boolean hasCorrect = false; // first, get all possible lexical entries from // a manipulation of the best parse. 
List<LexEntry> lex = makeLexEntriesChart(words, sem, parser); if (verbose) { System.out.println("Adding:"); for (LexEntry le : lex) { System.out.println(le + " : " + LexiconFeatSet.initialWeight(le)); } } parser.addLexEntries(lex); if (verbose) System.out.println("Lex Size: " + parser.returnLex().size()); // first parse to see if we are currently correct if (verbose) mes = "First"; parser.parseTimed(words, null, mes); Chart firstChart = parser.getChart(); Exp best = parser.bestSem(); // this just collates and outputs the training // accuracy. if (sem.equals(best)) { // System.out.println(parser.bestParses().get(0)); if (verbose) { System.out.println("CORRECT:" + best); lex = parser.getMaxLexEntriesFor(sem); System.out.println("Using:"); printLex(lex); if (lex.size() == 0) { System.out.println("ERROR: empty lex"); } } correct++; } else { if (verbose) { System.out.println("WRONG: " + best); lex = parser.getMaxLexEntriesFor(best); System.out.println("Using:"); printLex(lex); if (best != null && lex.size() == 0) { System.out.println("ERROR: empty lex"); } } wrong++; } // compute first half of parameter update: // subtract the expectation of parameters // under the distribution that is conditioned // on the sentence alone. double norm = firstChart.computeNorm(); HashVector update = new HashVector(); HashVector firstfeats = null, secondfeats = null; if (norm != 0.0) { firstfeats = firstChart.computeExpFeatVals(); firstfeats.divideBy(norm); firstfeats.dropSmallEntries(); firstfeats.addTimesInto(-1.0, update); } else continue; firstChart = null; if (verbose) mes = "Second"; parser.parseTimed(words, sem, mes); hasCorrect = parser.hasParseFor(sem); // compute second half of parameter update: // add the expectation of parameters // under the distribution that is conditioned // on the sentence and correct logical form. 
if (!hasCorrect) continue; Chart secondChart = parser.getChart(); double secnorm = secondChart.computeNorm(sem); if (norm != 0.0) { secondfeats = secondChart.computeExpFeatVals(sem); secondfeats.divideBy(secnorm); secondfeats.dropSmallEntries(); secondfeats.addTimesInto(1.0, update); lex = parser.getMaxLexEntriesFor(sem); data.setBestLex(i, lex); if (verbose) { System.out.println("Best LexEntries:"); printLex(lex); if (lex.size() == 0) { System.out.println("ERROR: empty lex"); } } } else continue; // now do the update double scale = alpha_0 / (1.0 + c * numUpdates); if (verbose) System.out.println("Scale: " + scale); update.multiplyBy(scale); update.dropSmallEntries(); numUpdates++; if (verbose) { System.out.println("Update:"); System.out.println(update); } if (!update.isBad()) { if (!update.valuesInRange(-100, 100)) { System.out.println("WARNING: large update"); System.out.println("first feats: " + firstfeats); System.out.println("second feats: " + secondfeats); } parser.updateParams(update); } else { System.out.println( "ERROR: Bad Update: " + update + " -- norm: " + norm + " -- feats: "); parser.getParams().printValues(update); System.out.println(); } } // end for each training example } // end for each data set double r = (double) correct / total; // we can prune the lexical items that were not used // in a max scoring parse. if (pruneLex) { Lexicon cur = new Lexicon(); cur.addLexEntries(fixedEntries); cur.addLexEntries(data.getBestLex()); parser.setLexicon(cur); } if (testEachRound) { System.out.println("Testing"); test(parser, false); } } // end epochs loop }
public Tree<String> getBestParseOld(List<String> sentence) { // TODO: This implements the CKY algorithm CounterMap<String, String> parseScores = new CounterMap<String, String>(); System.out.println(sentence.toString()); // First deal with the lexicons int index = 0; int span = 1; // All spans are 1 at the lexicon level for (String word : sentence) { for (String tag : lexicon.getAllTags()) { double score = lexicon.scoreTagging(word, tag); if (score >= 0.0) { // This lexicon may generate this word // We use a counter map in order to store the scores for this sentence parse. parseScores.setCount(index + " " + (index + span), tag, score); } } index = index + 1; } // handle unary rules now HashMap<String, Triplet<Integer, String, String>> backHash = new HashMap< String, Triplet<Integer, String, String>>(); // hashmap to store back propation // System.out.println("Lexicons found"); Boolean added = true; while (added) { added = false; for (index = 0; index < sentence.size(); index++) { // For each index+ span pair, get the counter. 
Counter<String> count = parseScores.getCounter(index + " " + (index + span)); PriorityQueue<String> countAsPQ = count.asPriorityQueue(); while (countAsPQ.hasNext()) { String entry = countAsPQ.next(); // System.out.println("I am fine here!!"); List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry); for (UnaryRule rule : unaryRules) { // These are the unary rules which might give rise to the above preterminal double prob = rule.getScore() * parseScores.getCount(index + " " + (index + span), entry); if (prob > parseScores.getCount(index + " " + (index + span), rule.parent)) { parseScores.setCount(index + " " + (index + span), rule.parent, prob); backHash.put( index + " " + (index + span) + " " + rule.parent, new Triplet<Integer, String, String>(-1, entry, null)); added = true; } } } } } // System.out.println("Lexicon unaries dealt with"); // Now work with the grammar to produce higher level probabilities for (span = 2; span <= sentence.size(); span++) { for (int begin = 0; begin <= (sentence.size() - span); begin++) { int end = begin + span; for (int split = begin + 1; split <= end - 1; split++) { Counter<String> countLeft = parseScores.getCounter(begin + " " + split); Counter<String> countRight = parseScores.getCounter(split + " " + end); // List<BinaryRule> leftRules= new ArrayList<BinaryRule>(); HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>(); // List<BinaryRule> rightRules=new ArrayList<BinaryRule>(); HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>(); for (String entry : countLeft.keySet()) { for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) { if (!leftMap.containsKey(rule.hashCode())) { leftMap.put(rule.hashCode(), rule); } } } for (String entry : countRight.keySet()) { for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) { if (!rightMap.containsKey(rule.hashCode())) { rightMap.put(rule.hashCode(), rule); } } } // System.out.println("About to enter the rules loops"); for 
(Integer ruleHash : leftMap.keySet()) { if (rightMap.containsKey(ruleHash)) { BinaryRule ruleRight = rightMap.get(ruleHash); double prob = ruleRight.getScore() * parseScores.getCount(begin + " " + split, ruleRight.leftChild) * parseScores.getCount(split + " " + end, ruleRight.rightChild); // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob); if (prob > parseScores.getCount(begin + " " + end, ruleRight.parent)) { // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob); // System.out.println("parentrule :"+ ruleRight.getParent()); parseScores.setCount(begin + " " + end, ruleRight.getParent(), prob); backHash.put( begin + " " + end + " " + ruleRight.parent, new Triplet<Integer, String, String>( split, ruleRight.leftChild, ruleRight.rightChild)); } } } // System.out.println("Exited rules loop"); } // System.out.println("Grammar found for " + begin + " "+ end); // Now handle unary rules added = true; while (added) { added = false; Counter<String> count = parseScores.getCounter(begin + " " + end); PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue(); while (countAsPriorityQueue.hasNext()) { String entry = countAsPriorityQueue.next(); List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry); for (UnaryRule rule : unaryRules) { double prob = rule.getScore() * parseScores.getCount(begin + " " + (end), entry); if (prob > parseScores.getCount(begin + " " + (end), rule.parent)) { parseScores.setCount(begin + " " + (end), rule.parent, prob); backHash.put( begin + " " + (end) + " " + rule.parent, new Triplet<Integer, String, String>(-1, entry, null)); added = true; } } } } // System.out.println("Unaries dealt for " + begin + " "+ end); } } // Create and return the parse tree Tree<String> parseTree = new Tree<String>("null"); // System.out.println(parseScores.getCounter(0+" "+sentence.size()).toString()); String parent = parseScores.getCounter(0 + " " + sentence.size()).argMax(); if (parent == null) { 
System.out.println(parseScores.getCounter(0 + " " + sentence.size()).toString()); System.out.println("THIS IS WEIRD"); } parent = "ROOT"; parseTree = getParseTreeOld(sentence, backHash, 0, sentence.size(), parent); // System.out.println("PARSE SCORES"); // System.out.println(parseScores.toString()); // System.out.println("BACK HASH"); // System.out.println(backHash.toString()); // parseTree = addRoot(parseTree); // System.out.println(parseTree.toString()); // return parseTree; return TreeAnnotations.unAnnotateTree(parseTree); }
public Tree<String> getBestParse(List<String> sentence) { // This implements the CKY algorithm int nEntries = sentence.size(); // hashmap to store back rules HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>> backHash = new HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>>(); // more efficient access with arrays, but must cast each time :( @SuppressWarnings("unchecked") Counter<String>[][] parseScores = (Counter<String>[][]) (new Counter[nEntries][nEntries]); for (int i = 0; i < nEntries; i++) { for (int j = 0; j < nEntries; j++) { parseScores[i][j] = new Counter<String>(); } } System.out.println(sentence.toString()); // First deal with the lexicons int index = 0; int span = 1; // All spans are 1 at the lexicon level for (String word : sentence) { for (String tag : lexicon.getAllTags()) { double score = lexicon.scoreTagging(word, tag); if (score >= 0.0) { // This lexicon may generate this word // We use a counter map in order to store the scores for this sentence parse. parseScores[index][index + span - 1].setCount(tag, score); } } index = index + 1; } // handle unary rules now // System.out.println("Lexicons found"); boolean added = true; while (added) { added = false; for (index = 0; index < sentence.size(); index++) { // For each index+ span pair, get the counter. 
Counter<String> count = parseScores[index][index + span - 1]; PriorityQueue<String> countAsPQ = count.asPriorityQueue(); while (countAsPQ.hasNext()) { String entry = countAsPQ.next(); // System.out.println("I am fine here!!"); List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry); for (UnaryRule rule : unaryRules) { // These are the unary rules which might give rise to the above preterminal double prob = rule.getScore() * parseScores[index][index + span - 1].getCount(entry); if (prob > parseScores[index][index + span - 1].getCount(rule.parent)) { parseScores[index][index + span - 1].setCount(rule.parent, prob); backHash.put( new Triplet<Integer, Integer, String>(index, index + span, rule.parent), new Triplet<Integer, String, String>(-1, entry, null)); added = true; } } } } } // System.out.println("Lexicon unaries dealt with"); // Now work with the grammar to produce higher level probabilities for (span = 2; span <= sentence.size(); span++) { for (int begin = 0; begin <= (sentence.size() - span); begin++) { int end = begin + span; for (int split = begin + 1; split <= end - 1; split++) { Counter<String> countLeft = parseScores[begin][split - 1]; Counter<String> countRight = parseScores[split][end - 1]; // List<BinaryRule> leftRules= new ArrayList<BinaryRule>(); HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>(); // List<BinaryRule> rightRules=new ArrayList<BinaryRule>(); HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>(); for (String entry : countLeft.keySet()) { for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) { if (!leftMap.containsKey(rule.hashCode())) { leftMap.put(rule.hashCode(), rule); } } } for (String entry : countRight.keySet()) { for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) { if (!rightMap.containsKey(rule.hashCode())) { rightMap.put(rule.hashCode(), rule); } } } // System.out.println("About to enter the rules loops"); for (Integer ruleHash : leftMap.keySet()) { 
if (rightMap.containsKey(ruleHash)) { BinaryRule ruleRight = rightMap.get(ruleHash); double prob = ruleRight.getScore() * parseScores[begin][split - 1].getCount(ruleRight.leftChild) * parseScores[split][end - 1].getCount(ruleRight.rightChild); // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob); if (prob > parseScores[begin][end - 1].getCount(ruleRight.parent)) { // System.out.println(begin+" "+ end +" "+ ruleRight.parent+ " "+ prob); // System.out.println("parentrule :"+ ruleRight.getParent()); parseScores[begin][end - 1].setCount(ruleRight.getParent(), prob); backHash.put( new Triplet<Integer, Integer, String>(begin, end, ruleRight.parent), new Triplet<Integer, String, String>( split, ruleRight.leftChild, ruleRight.rightChild)); } } } // System.out.println("Exited rules loop"); } // System.out.println("Grammar found for " + begin + " "+ end); // Now handle unary rules added = true; while (added) { added = false; Counter<String> count = parseScores[begin][end - 1]; PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue(); while (countAsPriorityQueue.hasNext()) { String entry = countAsPriorityQueue.next(); List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry); for (UnaryRule rule : unaryRules) { double prob = rule.getScore() * parseScores[begin][end - 1].getCount(entry); if (prob > parseScores[begin][end - 1].getCount(rule.parent)) { parseScores[begin][end - 1].setCount(rule.parent, prob); backHash.put( new Triplet<Integer, Integer, String>(begin, end, rule.parent), new Triplet<Integer, String, String>(-1, entry, null)); added = true; } } } } // System.out.println("Unaries dealt for " + begin + " "+ end); } } // Create and return the parse tree Tree<String> parseTree = new Tree<String>("null"); // System.out.println(parseScores.getCounter(0+" "+sentence.size()).toString()); // Pick the argmax String parent = parseScores[0][nEntries - 1].argMax(); // Or pick root. 
This second one is preferred since sentences are meant to have ROOT as their // root node. parent = "ROOT"; parseTree = getParseTree(sentence, backHash, 0, sentence.size(), parent); // System.out.println("PARSE SCORES"); // System.out.println(parseScores.toString()); // System.out.println("BACK HASH"); // System.out.println(backHash.toString()); // parseTree = addRoot(parseTree); // System.out.println(parseTree.toString()); // return parseTree; return TreeAnnotations.unAnnotateTree(parseTree); }