/**
 * Turns a sentence into a flat phrasal tree. The structure is S -> tag*, and each tag dominates
 * a word. The tag is either taken from the word's label or set to "WD". The tag and the phrasal
 * node each have a StringLabel.
 *
 * @param s The Sentence to make the Tree from
 * @param lf The LabelFactory with which to create the new Tree labels
 * @return The one-phrasal-level Tree
 */
public static Tree toFlatTree(Sentence<?> s, LabelFactory lf) {
  List<Tree> daughters = new ArrayList<Tree>(s.length());
  for (HasWord word : s) {
    Tree wordNode = new LabeledScoredTreeLeaf(lf.newLabel(word.word()));
    if (word instanceof TaggedWord) {
      TaggedWord taggedWord = (TaggedWord) word;
      wordNode =
          new LabeledScoredTreeNode(
              new StringLabel(taggedWord.tag()), Collections.singletonList(wordNode));
    } else {
      wordNode =
          new LabeledScoredTreeNode(lf.newLabel("WD"), Collections.singletonList(wordNode));
    }
    daughters.add(wordNode);
  }
  return new LabeledScoredTreeNode(new StringLabel("S"), daughters);
}
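// A minimal usage sketch (added illustration, not part of the original source): building a
// flat tree from a pre-tagged sentence. The Sentence/TaggedWord construction below assumes
// the older edu.stanford.nlp.ling API in which Sentence is a collection with a no-arg
// constructor; treat it as a sketch rather than the canonical way to build input.
//
//   Sentence<TaggedWord> s = new Sentence<TaggedWord>();
//   s.add(new TaggedWord("Dogs", "NNS"));
//   s.add(new TaggedWord("bark", "VBP"));
//   Tree flat = toFlatTree(s, new StringLabelFactory());
//   flat.pennPrint();  // expected shape: (S (NNS Dogs) (VBP bark))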
/**
 * Prints out all matches of a semgrex pattern on a file of dependencies. <br>
 * Usage:<br>
 * java edu.stanford.nlp.semgraph.semgrex.SemgrexPattern [args] <br>
 * See the help() function for a list of possible arguments to provide.
 */
public static void main(String[] args) throws IOException {
  Map<String, Integer> flagMap = Generics.newHashMap();
  flagMap.put(PATTERN, 1);
  flagMap.put(TREE_FILE, 1);
  flagMap.put(MODE, 1);
  flagMap.put(EXTRAS, 1);
  flagMap.put(CONLLU_FILE, 1);
  flagMap.put(OUTPUT_FORMAT_OPTION, 1);

  Map<String, String[]> argsMap = StringUtils.argsToMap(args, flagMap);
  args = argsMap.get(null);

  // TODO: allow patterns to be extracted from a file
  if (!(argsMap.containsKey(PATTERN)) || argsMap.get(PATTERN).length == 0) {
    help();
    System.exit(2);
  }
  SemgrexPattern semgrex = SemgrexPattern.compile(argsMap.get(PATTERN)[0]);

  String modeString = DEFAULT_MODE;
  if (argsMap.containsKey(MODE) && argsMap.get(MODE).length > 0) {
    modeString = argsMap.get(MODE)[0].toUpperCase();
  }
  SemanticGraphFactory.Mode mode = SemanticGraphFactory.Mode.valueOf(modeString);

  String outputFormatString = DEFAULT_OUTPUT_FORMAT;
  if (argsMap.containsKey(OUTPUT_FORMAT_OPTION) && argsMap.get(OUTPUT_FORMAT_OPTION).length > 0) {
    outputFormatString = argsMap.get(OUTPUT_FORMAT_OPTION)[0].toUpperCase();
  }
  OutputFormat outputFormat = OutputFormat.valueOf(outputFormatString);

  boolean useExtras = true;
  if (argsMap.containsKey(EXTRAS) && argsMap.get(EXTRAS).length > 0) {
    useExtras = Boolean.valueOf(argsMap.get(EXTRAS)[0]);
  }

  List<SemanticGraph> graphs = Generics.newArrayList();
  // TODO: allow other sources of graphs, such as dependency files
  if (argsMap.containsKey(TREE_FILE) && argsMap.get(TREE_FILE).length > 0) {
    for (String treeFile : argsMap.get(TREE_FILE)) {
      System.err.println("Loading file " + treeFile);
      MemoryTreebank treebank = new MemoryTreebank(new TreeNormalizer());
      treebank.loadPath(treeFile);
      for (Tree tree : treebank) {
        // TODO: allow other languages... this defaults to English
        SemanticGraph graph =
            SemanticGraphFactory.makeFromTree(
                tree,
                mode,
                useExtras
                    ? GrammaticalStructure.Extras.MAXIMAL
                    : GrammaticalStructure.Extras.NONE,
                true);
        graphs.add(graph);
      }
    }
  }

  if (argsMap.containsKey(CONLLU_FILE) && argsMap.get(CONLLU_FILE).length > 0) {
    CoNLLUDocumentReader reader = new CoNLLUDocumentReader();
    for (String conlluFile : argsMap.get(CONLLU_FILE)) {
      System.err.println("Loading file " + conlluFile);
      Iterator<SemanticGraph> it = reader.getIterator(IOUtils.readerFromString(conlluFile));
      while (it.hasNext()) {
        SemanticGraph graph = it.next();
        graphs.add(graph);
      }
    }
  }

  for (SemanticGraph graph : graphs) {
    SemgrexMatcher matcher = semgrex.matcher(graph);
    if (!(matcher.find())) {
      continue;
    }

    if (outputFormat == OutputFormat.LIST) {
      System.err.println("Matched graph:");
      System.err.println(graph.toString(SemanticGraph.OutputFormat.LIST));
      boolean found = true;
      while (found) {
        System.err.println(
            "Matches at: " + matcher.getMatch().value() + "-" + matcher.getMatch().index());
        List<String> nodeNames = Generics.newArrayList();
        nodeNames.addAll(matcher.getNodeNames());
        Collections.sort(nodeNames);
        for (String name : nodeNames) {
          System.err.println(
              " " + name + ": " + matcher.getNode(name).value() + "-" + matcher.getNode(name).index());
        }
        System.err.println();
        found = matcher.find();
      }
    } else if (outputFormat == OutputFormat.OFFSET) {
      if (graph.vertexListSorted().isEmpty()) {
        continue;
      }
      System.out.printf(
          "+%d %s%n",
          graph.vertexListSorted().get(0).get(CoreAnnotations.LineNumberAnnotation.class),
          argsMap.get(CONLLU_FILE)[0]);
    }
  }
}
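// A hypothetical invocation sketch (added illustration, not part of the original source).
// The flag spellings correspond to the constants above (PATTERN, CONLLU_FILE,
// OUTPUT_FORMAT_OPTION); the exact strings and the file path are assumptions:
//
//   java edu.stanford.nlp.semgraph.semgrex.SemgrexPattern \
//       -pattern "{pos:NN} >nsubj {}" \
//       -conlluFile corpus.conllu \
//       -outputFormat LIST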
public static void main(String[] args) {
  Options op = new Options(new EnglishTreebankParserParams());
  // op.tlpParams may be changed to something else later, so don't use it till
  // after options are parsed.
  System.out.println("Currently " + new Date());
  System.out.print("Invoked with arguments:");
  for (String arg : args) {
    System.out.print(" " + arg);
  }
  System.out.println();

  String path = "/u/nlp/stuff/corpora/Treebank3/parsed/mrg/wsj";
  int trainLow = 200, trainHigh = 2199, testLow = 2200, testHigh = 2219;
  String serializeFile = null;

  int i = 0;
  while (i < args.length && args[i].startsWith("-")) {
    if (args[i].equalsIgnoreCase("-path") && (i + 1 < args.length)) {
      path = args[i + 1];
      i += 2;
    } else if (args[i].equalsIgnoreCase("-train") && (i + 2 < args.length)) {
      trainLow = Integer.parseInt(args[i + 1]);
      trainHigh = Integer.parseInt(args[i + 2]);
      i += 3;
    } else if (args[i].equalsIgnoreCase("-test") && (i + 2 < args.length)) {
      testLow = Integer.parseInt(args[i + 1]);
      testHigh = Integer.parseInt(args[i + 2]);
      i += 3;
    } else if (args[i].equalsIgnoreCase("-serialize") && (i + 1 < args.length)) {
      serializeFile = args[i + 1];
      i += 2;
    } else if (args[i].equalsIgnoreCase("-tLPP") && (i + 1 < args.length)) {
      try {
        op.tlpParams = (TreebankLangParserParams) Class.forName(args[i + 1]).newInstance();
      } catch (ClassNotFoundException e) {
        System.err.println("Class not found: " + args[i + 1]);
      } catch (InstantiationException e) {
        System.err.println("Couldn't instantiate: " + args[i + 1] + ": " + e.toString());
      } catch (IllegalAccessException e) {
        System.err.println("Illegal access: " + e);
      }
      i += 2;
    } else if (args[i].equals("-encoding")) {
      // sets encoding for TreebankLangParserParams
      op.tlpParams.setInputEncoding(args[i + 1]);
      op.tlpParams.setOutputEncoding(args[i + 1]);
      i += 2;
    } else {
      i = op.setOptionOrWarn(args, i);
    }
  }
  // System.out.println(tlpParams.getClass());

  TreebankLanguagePack tlp = op.tlpParams.treebankLanguagePack();

  Train.sisterSplitters = new HashSet(Arrays.asList(op.tlpParams.sisterSplitters()));
  // BinarizerFactory.TreeAnnotator.setTreebankLang(tlpParams);
  PrintWriter pw = op.tlpParams.pw();

  Test.display();
  Train.display();
  op.display();
  op.tlpParams.display();

  // setup tree transforms
  Treebank trainTreebank = op.tlpParams.memoryTreebank();
  MemoryTreebank testTreebank = op.tlpParams.testMemoryTreebank();
  // Treebank blippTreebank = ((EnglishTreebankParserParams) tlpParams).diskTreebank();
  // String blippPath = "/afs/ir.stanford.edu/data/linguistic-data/BLLIP-WSJ/";
  // blippTreebank.loadPath(blippPath, "", true);

  Timing.startTime();
  System.err.print("Reading trees...");
  testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true));
  if (Test.increasingLength) {
    Collections.sort(testTreebank, new TreeLengthComparator());
  }

  trainTreebank.loadPath(path, new NumberRangeFileFilter(trainLow, trainHigh, true));
  Timing.tick("done.");

  System.err.print("Binarizing trees...");
  TreeAnnotatorAndBinarizer binarizer = null;
  if (!Train.leftToRight) {
    binarizer =
        new TreeAnnotatorAndBinarizer(op.tlpParams, op.forceCNF, !Train.outsideFactor(), true);
  } else {
    binarizer =
        new TreeAnnotatorAndBinarizer(
            op.tlpParams.headFinder(),
            new LeftHeadFinder(),
            op.tlpParams,
            op.forceCNF,
            !Train.outsideFactor(),
            true);
  }
  CollinsPuncTransformer collinsPuncTransformer = null;
  if (Train.collinsPunc) {
    collinsPuncTransformer = new CollinsPuncTransformer(tlp);
  }
  TreeTransformer debinarizer = new Debinarizer(op.forceCNF);
  List<Tree> binaryTrainTrees = new ArrayList<Tree>();

  if (Train.selectiveSplit) {
    Train.splitters =
        ParentAnnotationStats.getSplitCategories(
            trainTreebank,
            Train.tagSelectiveSplit,
            0,
            Train.selectiveSplitCutOff,
            Train.tagSelectiveSplitCutOff,
            op.tlpParams.treebankLanguagePack());
    if (Train.deleteSplitters != null) {
      List<String> deleted = new ArrayList<String>();
      for (String del : Train.deleteSplitters) {
        String baseDel = tlp.basicCategory(del);
        boolean checkBasic = del.equals(baseDel);
        for (Iterator<String> it = Train.splitters.iterator(); it.hasNext(); ) {
          String elem = it.next();
          String baseElem = tlp.basicCategory(elem);
          boolean delStr = checkBasic && baseElem.equals(baseDel) || elem.equals(del);
          if (delStr) {
            it.remove();
            deleted.add(elem);
          }
        }
      }
      System.err.println("Removed from vertical splitters: " + deleted);
    }
  }
  if (Train.selectivePostSplit) {
    TreeTransformer myTransformer = new TreeAnnotator(op.tlpParams.headFinder(), op.tlpParams);
    Treebank annotatedTB = trainTreebank.transform(myTransformer);
    Train.postSplitters =
        ParentAnnotationStats.getSplitCategories(
            annotatedTB,
            true,
            0,
            Train.selectivePostSplitCutOff,
            Train.tagSelectivePostSplitCutOff,
            op.tlpParams.treebankLanguagePack());
  }

  if (Train.hSelSplit) {
    binarizer.setDoSelectiveSplit(false);
    for (Tree tree : trainTreebank) {
      if (Train.collinsPunc) {
        tree = collinsPuncTransformer.transformTree(tree);
      }
      // tree.pennPrint(tlpParams.pw());
      tree = binarizer.transformTree(tree);
      // binaryTrainTrees.add(tree);
    }
    binarizer.setDoSelectiveSplit(true);
  }
  for (Tree tree : trainTreebank) {
    if (Train.collinsPunc) {
      tree = collinsPuncTransformer.transformTree(tree);
    }
    tree = binarizer.transformTree(tree);
    binaryTrainTrees.add(tree);
  }
  if (Test.verbose) {
    binarizer.dumpStats();
  }

  List<Tree> binaryTestTrees = new ArrayList<Tree>();
  for (Tree tree : testTreebank) {
    if (Train.collinsPunc) {
      tree = collinsPuncTransformer.transformTree(tree);
    }
    tree = binarizer.transformTree(tree);
    binaryTestTrees.add(tree);
  }
  Timing.tick("done."); // binarization

  BinaryGrammar bg = null;
  UnaryGrammar ug = null;
  DependencyGrammar dg = null;
  // DependencyGrammar dgBLIPP = null;
  Lexicon lex = null;

  // extract grammars
  Extractor bgExtractor = new BinaryGrammarExtractor();
  // Extractor bgExtractor = new SmoothedBinaryGrammarExtractor(); // new BinaryGrammarExtractor();
  // Extractor lexExtractor = new LexiconExtractor();
  // Extractor dgExtractor = new DependencyMemGrammarExtractor();
  Extractor dgExtractor = new MLEDependencyGrammarExtractor(op);
  if (op.doPCFG) {
    System.err.print("Extracting PCFG...");
    Pair bgug = null;
    if (Train.cheatPCFG) {
      List allTrees = new ArrayList(binaryTrainTrees);
      allTrees.addAll(binaryTestTrees);
      bgug = (Pair) bgExtractor.extract(allTrees);
    } else {
      bgug = (Pair) bgExtractor.extract(binaryTrainTrees);
    }
    bg = (BinaryGrammar) bgug.second;
    bg.splitRules();
    ug = (UnaryGrammar) bgug.first;
    ug.purgeRules();
    Timing.tick("done.");
  }
  System.err.print("Extracting Lexicon...");
  lex = op.tlpParams.lex(op.lexOptions);
  lex.train(binaryTrainTrees);
  Timing.tick("done.");

  if (op.doDep) {
    System.err.print("Extracting Dependencies...");
    binaryTrainTrees.clear();
    // dgBLIPP = (DependencyGrammar) dgExtractor.extract(new
    //     ConcatenationIterator(trainTreebank.iterator(), blippTreebank.iterator()),
    //     new TransformTreeDependency(tlpParams, true));
    DependencyGrammar dg1 =
        (DependencyGrammar)
            dgExtractor.extract(
                trainTreebank.iterator(), new TransformTreeDependency(op.tlpParams, true));
    // dgBLIPP = (DependencyGrammar) dgExtractor.extract(blippTreebank.iterator(),
    //     new TransformTreeDependency(tlpParams));
    // dg = (DependencyGrammar) dgExtractor.extract(new
    //     ConcatenationIterator(trainTreebank.iterator(), blippTreebank.iterator()),
    //     new TransformTreeDependency(tlpParams));
    // dg = new DependencyGrammarCombination(dg1, dgBLIPP, 2);
    // dg = (DependencyGrammar) dgExtractor.extract(binaryTrainTrees); // uses information on
    //     whether the words are known or not, discards unknown words
    dg = dg1; // without this assignment, dg stays null and dg.tune() below throws an NPE
    Timing.tick("done.");

    // System.out.print("Extracting Unknown Word Model...");
    // UnknownWordModel uwm = (UnknownWordModel) uwmExtractor.extract(binaryTrainTrees);
    // Timing.tick("done.");
    System.out.print("Tuning Dependency Model...");
    dg.tune(binaryTestTrees);
    // System.out.println("TUNE DEPS: " + tuneDeps);
    Timing.tick("done.");
  }

  BinaryGrammar boundBG = bg;
  UnaryGrammar boundUG = ug;
  GrammarProjection gp = new NullGrammarProjection(bg, ug);

  // serialization
  if (serializeFile != null) {
    System.err.print("Serializing parser...");
    LexicalizedParser.saveParserDataToSerialized(
        new ParserData(lex, bg, ug, dg, Numberer.getNumberers(), op), serializeFile);
    Timing.tick("done.");
  }

  // test: pcfg-parse and output
  ExhaustivePCFGParser parser = null;
  if (op.doPCFG) {
    parser = new ExhaustivePCFGParser(boundBG, boundUG, lex, op);
  }
  ExhaustiveDependencyParser dparser =
      ((op.doDep && !Test.useFastFactored)
          ? new ExhaustiveDependencyParser(dg, lex, op)
          : null);

  Scorer scorer = (op.doPCFG ? new TwinScorer(new ProjectionScorer(parser, gp), dparser) : null);
  // Scorer scorer = parser;
  BiLexPCFGParser bparser = null;
  if (op.doPCFG && op.doDep) {
    bparser =
        (Test.useN5)
            ? new BiLexPCFGParser.N5BiLexPCFGParser(
                scorer, parser, dparser, bg, ug, dg, lex, op, gp)
            : new BiLexPCFGParser(scorer, parser, dparser, bg, ug, dg, lex, op, gp);
  }

  LabeledConstituentEval pcfgPE = new LabeledConstituentEval("pcfg PE", true, tlp);
  LabeledConstituentEval comboPE = new LabeledConstituentEval("combo PE", true, tlp);
  AbstractEval pcfgCB = new LabeledConstituentEval.CBEval("pcfg CB", true, tlp);
  AbstractEval pcfgTE = new AbstractEval.TaggingEval("pcfg TE");
  AbstractEval comboTE = new AbstractEval.TaggingEval("combo TE");
  AbstractEval pcfgTEnoPunct = new AbstractEval.TaggingEval("pcfg nopunct TE");
  AbstractEval comboTEnoPunct = new AbstractEval.TaggingEval("combo nopunct TE");
  AbstractEval depTE = new AbstractEval.TaggingEval("depnd TE");
  AbstractEval depDE =
      new AbstractEval.DependencyEval("depnd DE", true, tlp.punctuationWordAcceptFilter());
  AbstractEval comboDE =
      new AbstractEval.DependencyEval("combo DE", true, tlp.punctuationWordAcceptFilter());

  if (Test.evalb) {
    EvalB.initEVALBfiles(op.tlpParams);
  }

  // int[] countByLength = new int[Test.maxLength + 1];

  // use a reflection ruse, so one can run this without needing the tagger
  // edu.stanford.nlp.process.SentenceTagger tagger = (Test.preTag ? new
  //     edu.stanford.nlp.process.SentenceTagger("/u/nlp/data/tagger.params/wsj0-21.holder") : null);
  SentenceProcessor tagger = null;
  if (Test.preTag) {
    try {
      Class[] argsClass = new Class[] {String.class};
      Object[] arguments =
          new Object[] {"/u/nlp/data/pos-tagger/wsj3t0-18-bidirectional/train-wsj-0-18.holder"};
      tagger =
          (SentenceProcessor)
              Class.forName("edu.stanford.nlp.tagger.maxent.MaxentTagger")
                  .getConstructor(argsClass)
                  .newInstance(arguments);
    } catch (Exception e) {
      System.err.println(e);
      System.err.println("Warning: No pretagging of sentences will be done.");
    }
  }

  for (int tNum = 0, ttSize = testTreebank.size(); tNum < ttSize; tNum++) {
    Tree tree = testTreebank.get(tNum);
    int testTreeLen = tree.yield().size();
    if (testTreeLen > Test.maxLength) {
      continue;
    }
    Tree binaryTree = binaryTestTrees.get(tNum);
    // countByLength[testTreeLen]++;
    System.out.println("-------------------------------------");
    System.out.println("Number: " + (tNum + 1));
    System.out.println("Length: " + testTreeLen);

    // tree.pennPrint(pw);
    // System.out.println("XXXX The binary tree is");
    // binaryTree.pennPrint(pw);
    // System.out.println("Here are the tags in the lexicon:");
    // System.out.println(lex.showTags());
    // System.out.println("Here's the tagnumberer:");
    // System.out.println(Numberer.getGlobalNumberer("tags").toString());

    long timeMil1 = System.currentTimeMillis();
    Timing.tick("Starting parse.");
    if (op.doPCFG) {
      // System.err.println(Test.forceTags);
      if (Test.forceTags) {
        if (tagger != null) {
          // System.out.println("Using a tagger to set tags");
          // System.out.println("Tagged sentence as: " +
          //     tagger.processSentence(cutLast(wordify(binaryTree.yield()))).toString(false));
          parser.parse(addLast(tagger.processSentence(cutLast(wordify(binaryTree.yield())))));
        } else {
          // System.out.println("Forcing tags to match input.");
          parser.parse(cleanTags(binaryTree.taggedYield(), tlp));
        }
      } else {
        // System.out.println("XXXX Parsing " + binaryTree.yield());
        parser.parse(binaryTree.yield());
      }
      // Timing.tick("Done with pcfg phase.");
    }
    if (op.doDep) {
      dparser.parse(binaryTree.yield());
      // Timing.tick("Done with dependency phase.");
    }
    boolean bothPassed = false;
    if (op.doPCFG && op.doDep) {
      bothPassed = bparser.parse(binaryTree.yield());
      // Timing.tick("Done with combination phase.");
    }
    long timeMil2 = System.currentTimeMillis();
    long elapsed = timeMil2 - timeMil1;
    System.err.println("Time: " + ((int) (elapsed / 100)) / 10.00 + " sec.");

    // System.out.println("PCFG Best Parse:");
    Tree tree2b = null;
    Tree tree2 = null;
    // System.out.println("Got full best parse...");
    if (op.doPCFG) {
      tree2b = parser.getBestParse();
      tree2 = debinarizer.transformTree(tree2b);
    }
    // System.out.println("Debinarized parse...");
    // tree2.pennPrint();
    // System.out.println("DepG Best Parse:");
    Tree tree3 = null;
    Tree tree3db = null;
    if (op.doDep) {
      tree3 = dparser.getBestParse();
      // was (but wrong): Tree tree3db = debinarizer.transformTree(tree2);
      tree3db = debinarizer.transformTree(tree3);
      tree3.pennPrint(pw);
    }
    // tree.pennPrint();
    // ((Tree) binaryTrainTrees.get(tNum)).pennPrint();
    // System.out.println("Combo Best Parse:");
    Tree tree4 = null;
    if (op.doPCFG && op.doDep) {
      try {
        tree4 = bparser.getBestParse();
        if (tree4 == null) {
          tree4 = tree2b;
        }
      } catch (NullPointerException e) {
        System.err.println("Blocked, using PCFG parse!");
        tree4 = tree2b;
      }
    }
    if (op.doPCFG && !bothPassed) {
      tree4 = tree2b;
    }
    // tree4.pennPrint();
    if (op.doDep) {
      depDE.evaluate(tree3, binaryTree, pw);
      depTE.evaluate(tree3db, tree, pw);
    }
    TreeTransformer tc = op.tlpParams.collinizer();
    TreeTransformer tcEvalb = op.tlpParams.collinizerEvalb();
    Tree tree4b = null;
    if (op.doPCFG) {
      // System.out.println("XXXX Best PCFG was: ");
      // tree2.pennPrint();
      // System.out.println("XXXX Transformed best PCFG is: ");
      // tc.transformTree(tree2).pennPrint();
      // System.out.println("True Best Parse:");
      // tree.pennPrint();
      // tc.transformTree(tree).pennPrint();
      pcfgPE.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
      pcfgCB.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
      if (op.doDep) {
        comboDE.evaluate((bothPassed ? tree4 : tree3), binaryTree, pw);
        tree4b = tree4;
        tree4 = debinarizer.transformTree(tree4);
        if (op.nodePrune) {
          NodePruner np = new NodePruner(parser, debinarizer);
          tree4 = np.prune(tree4);
        }
        // tree4.pennPrint();
        comboPE.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw);
      }
      // pcfgTE.evaluate(tree2, tree);
      pcfgTE.evaluate(tcEvalb.transformTree(tree2), tcEvalb.transformTree(tree), pw);
      pcfgTEnoPunct.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
      if (op.doDep) {
        comboTE.evaluate(tcEvalb.transformTree(tree4), tcEvalb.transformTree(tree), pw);
        comboTEnoPunct.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw);
      }
      System.out.println("PCFG only: " + parser.scoreBinarizedTree(tree2b, 0));
      // tc.transformTree(tree2).pennPrint();
      tree2.pennPrint(pw);
      if (op.doDep) {
        System.out.println("Combo: " + parser.scoreBinarizedTree(tree4b, 0));
        // tc.transformTree(tree4).pennPrint(pw);
        tree4.pennPrint(pw);
      }
      System.out.println("Correct: " + parser.scoreBinarizedTree(binaryTree, 0));
      /*
      if (parser.scoreBinarizedTree(tree2b, true) < parser.scoreBinarizedTree(binaryTree, true)) {
        System.out.println("SCORE INVERSION");
        parser.validateBinarizedTree(binaryTree, 0);
      }
      */
      tree.pennPrint(pw);
    } // end if doPCFG

    if (Test.evalb) {
      if (op.doPCFG && op.doDep) {
        EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree4));
      } else if (op.doPCFG) {
        EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree2));
      } else if (op.doDep) {
        EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree3db));
      }
    }
  } // end for each tree in test treebank

  if (Test.evalb) {
    EvalB.closeEVALBfiles();
  }

  // Test.display();
  if (op.doPCFG) {
    pcfgPE.display(false, pw);
    System.out.println("Grammar size: " + Numberer.getGlobalNumberer("states").total());
    pcfgCB.display(false, pw);
    if (op.doDep) {
      comboPE.display(false, pw);
    }
    pcfgTE.display(false, pw);
    pcfgTEnoPunct.display(false, pw);
    if (op.doDep) {
      comboTE.display(false, pw);
      comboTEnoPunct.display(false, pw);
    }
  }
  if (op.doDep) {
    depTE.display(false, pw);
    depDE.display(false, pw);
  }
  if (op.doPCFG && op.doDep) {
    comboDE.display(false, pw);
  }
  // pcfgPE.printGoodBad();
}
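// A hypothetical invocation of this train-and-test harness (added illustration, not part of
// the original source). The class name and paths are assumptions; the section ranges mirror
// the defaults above (train on WSJ files 200-2199, test on 2200-2219):
//
//   java edu.stanford.nlp.parser.lexparser.FactoredParser \
//       -path /path/to/wsj -train 200 2199 -test 2200 2219 -serialize factored.ser.gz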
/**
 * A search problem for finding clauses in a sentence.
 *
 * <p>For usage at test time, load a model from {@link ClauseSplitter#load(String)}, and then take
 * the top clauses of a given tree with {@link ClauseSplitterSearchProblem#topClauses(double)},
 * yielding a list of {@link edu.stanford.nlp.naturalli.SentenceFragment}s.
 *
 * <pre>{@code
 * ClauseSearcher searcher = ClauseSearcher.factory("/model/path/");
 * List<SentenceFragment> sentences = searcher.topClauses(threshold);
 * }</pre>
 *
 * <p>For training, see {@link ClauseSplitter#train(Stream, File, File)}.
 *
 * @author Gabor Angeli
 */
public class ClauseSplitterSearchProblem {

  /**
   * A specification for clause splits we _always_ want to do. The format is a map from the edge
   * label we are splitting, to the preference for the type of split we should do. The most
   * preferred is at the front of the list, and then it backs off to less and less preferred
   * split types.
   */
  protected static final Map<String, List<String>> HARD_SPLITS =
      Collections.unmodifiableMap(
          new HashMap<String, List<String>>() {
            {
              put(
                  "comp",
                  new ArrayList<String>() {
                    {
                      add("simple");
                    }
                  });
              put(
                  "ccomp",
                  new ArrayList<String>() {
                    {
                      add("simple");
                    }
                  });
              put(
                  "xcomp",
                  new ArrayList<String>() {
                    {
                      add("clone_dobj");
                      add("clone_nsubj");
                      add("simple");
                    }
                  });
              put(
                  "vmod",
                  new ArrayList<String>() {
                    {
                      add("clone_nsubj");
                      add("simple");
                    }
                  });
              put(
                  "csubj",
                  new ArrayList<String>() {
                    {
                      add("clone_dobj");
                      add("simple");
                    }
                  });
              put(
                  "advcl",
                  new ArrayList<String>() {
                    {
                      add("clone_nsubj");
                      add("simple");
                    }
                  });
              put(
                  "conj:*",
                  new ArrayList<String>() {
                    {
                      add("clone_nsubj");
                      add("clone_dobj");
                      add("simple");
                    }
                  });
              put(
                  "acl:relcl",
                  new ArrayList<String>() {
                    {
                      // no doubt (-> that cats have tails <-)
                      add("simple");
                    }
                  });
            }
          });

  /**
   * A set of words which indicate that the complement clause is not factual, or at least not
   * necessarily factual.
   */
  protected static final Set<String> INDIRECT_SPEECH_LEMMAS =
      Collections.unmodifiableSet(
          new HashSet<String>() {
            {
              add("report");
              add("say");
              add("told");
              add("claim");
              add("assert");
              add("think");
              add("believe");
              add("suppose");
            }
          });

  /** The tree to search over. */
  public final SemanticGraph tree;
  /** The assumed truth of the original clause. */
  public final boolean assumedTruth;
  /** The length of the sentence, as determined from the tree. */
  public final int sentenceLength;
  /** A mapping from a word to the extra edges that come out of it. */
  private final Map<IndexedWord, Collection<SemanticGraphEdge>> extraEdgesByGovernor =
      new HashMap<>();
  /** A mapping from a word to the extra edges that go into it. */
  private final Map<IndexedWord, Collection<SemanticGraphEdge>> extraEdgesByDependent =
      new HashMap<>();
  /** The classifier for whether a particular dependency edge defines a clause boundary. */
  private final Optional<Classifier<ClauseSplitter.ClauseClassifierLabel, String>>
      isClauseClassifier;
  /**
   * An optional featurizer to use with the clause classifier ({@link
   * ClauseSplitterSearchProblem#isClauseClassifier}). If that classifier is defined, this should
   * be as well.
   */
  private final Optional<
          Function<
              Triple<
                  ClauseSplitterSearchProblem.State,
                  ClauseSplitterSearchProblem.Action,
                  ClauseSplitterSearchProblem.State>,
              Counter<String>>>
      featurizer;

  /** A mapping from edges in the tree to an index. */
  @SuppressWarnings("Convert2Diamond") // It's lying -- type inference times out with a diamond
  private final Index<SemanticGraphEdge> edgeToIndex =
      new HashIndex<SemanticGraphEdge>(ArrayList::new, IdentityHashMap::new);

  /** A search state. */
  public class State {
    public final SemanticGraphEdge edge;
    public final int edgeIndex;
    public final SemanticGraphEdge subjectOrNull;
    public final int distanceFromSubj;
    public final SemanticGraphEdge objectOrNull;
    public final Consumer<SemanticGraph> thunk;
    public boolean isDone;

    public State(
        SemanticGraphEdge edge,
        SemanticGraphEdge subjectOrNull,
        int distanceFromSubj,
        SemanticGraphEdge objectOrNull,
        Consumer<SemanticGraph> thunk,
        boolean isDone) {
      this.edge = edge;
      this.edgeIndex = edgeToIndex.indexOf(edge);
      this.subjectOrNull = subjectOrNull;
      this.distanceFromSubj = distanceFromSubj;
      this.objectOrNull = objectOrNull;
      this.thunk = thunk;
      this.isDone = isDone;
    }

    public State(State source, boolean isDone) {
      this.edge = source.edge;
      this.edgeIndex = edgeToIndex.indexOf(edge);
      this.subjectOrNull = source.subjectOrNull;
      this.distanceFromSubj = source.distanceFromSubj;
      this.objectOrNull = source.objectOrNull;
      this.thunk = source.thunk;
      this.isDone = isDone;
    }

    public SemanticGraph originalTree() {
      return ClauseSplitterSearchProblem.this.tree;
    }

    public State withIsDone(ClauseClassifierLabel argmax) {
      if (argmax == ClauseClassifierLabel.CLAUSE_SPLIT) {
        isDone = true;
      } else if (argmax == ClauseClassifierLabel.CLAUSE_INTERM) {
        isDone = false;
      } else {
        throw new IllegalStateException("Invalid classifier label for isDone: " + argmax);
      }
      return this;
    }
  }

  /** An action being taken; that is, the type of clause splitting going on. */
  public interface Action {
    /** The name of this action. */
    String signature();

    /**
     * A check to make sure this is actually a valid action to take, in the context of the given
     * tree.
     *
     * @param originalTree The _original_ tree we are searching over. This is before any clauses
     *     are split off.
     * @param edge The edge that we are traversing with this clause.
     * @return True if this is a valid action.
     */
    @SuppressWarnings("UnusedParameters")
    default boolean prerequisitesMet(SemanticGraph originalTree, SemanticGraphEdge edge) {
      return true;
    }

    /**
     * Apply this action to the given state.
     *
     * @param tree The original tree we are applying the action to.
     * @param source The source state we are mutating from.
     * @param outgoingEdge The edge we are splitting off as a clause.
     * @param subjectOrNull The subject of the parent tree, if there is one.
     * @param ppOrNull The preposition attachment of the parent tree, if there is one.
     * @return A new state, or {@link Optional#empty()} if this action was not successful.
     */
    Optional<State> applyTo(
        SemanticGraph tree,
        State source,
        SemanticGraphEdge outgoingEdge,
        SemanticGraphEdge subjectOrNull,
        SemanticGraphEdge ppOrNull);
  }

  /** The options used for training the clause searcher. */
  public static class TrainingOptions {
    @ArgumentParser.Option(
        name = "negativeSubsampleRatio",
        gloss = "The percent of negative datums to take")
    public double negativeSubsampleRatio = 1.00;

    @ArgumentParser.Option(
        name = "positiveDatumWeight",
        gloss = "The weight to assign every positive datum.")
    public float positiveDatumWeight = 100.0f;

    @ArgumentParser.Option(
        name = "unknownDatumWeight",
        gloss = "The weight to assign every unknown datum (everything extracted with an unconfirmed relation).")
    public float unknownDatumWeight = 1.0f;

    @ArgumentParser.Option(
        name = "clauseSplitWeight",
        gloss = "The weight to assign for clause splitting datums. Higher values push towards higher recall.")
    public float clauseSplitWeight = 1.0f;

    @ArgumentParser.Option(
        name = "clauseIntermWeight",
        gloss = "The weight to assign for intermediate splits. Higher values push towards higher recall.")
    public float clauseIntermWeight = 2.0f;

    @ArgumentParser.Option(name = "seed", gloss = "The random seed to use")
    public int seed = 42;

    @SuppressWarnings("unchecked")
    @ArgumentParser.Option(
        name = "classifierFactory",
        gloss = "The class of the classifier factory to use for training the various classifiers")
    public Class<
            ? extends ClassifierFactory<
                ClauseSplitter.ClauseClassifierLabel,
                String,
                Classifier<ClauseSplitter.ClauseClassifierLabel, String>>>
        classifierFactory =
            (Class<
                    ? extends ClassifierFactory<
                        ClauseSplitter.ClauseClassifierLabel,
                        String,
                        Classifier<ClauseSplitter.ClauseClassifierLabel, String>>>)
                ((Object) LinearClassifierFactory.class);
  }

  /** Mostly just an alias, but make sure our featurizer is serializable! */
  public interface Featurizer
      extends Function<
              Triple<
                  ClauseSplitterSearchProblem.State,
                  ClauseSplitterSearchProblem.Action,
                  ClauseSplitterSearchProblem.State>,
              Counter<String>>,
          Serializable {
    boolean isSimpleSplit(Counter<String> feats);
  }

  /**
   * Create a searcher manually, supplying a dependency tree, an optional classifier for when to
   * split clauses, and a featurizer for that classifier. You almost certainly want to use {@link
   * ClauseSplitter#load(String)} instead of this constructor.
   *
   * @param tree The dependency tree to search over.
   * @param assumedTruth The assumed truth of the tree (relevant for natural logic inference). If
   *     in doubt, pass in true.
   * @param isClauseClassifier The classifier for whether a given dependency arc should be a new
   *     clause. If this is not given, all arcs are treated as clause separators.
   * @param featurizer The featurizer for the classifier. If no featurizer is given, one should be
   *     given in {@link ClauseSplitterSearchProblem#search(java.util.function.Predicate,
   *     Classifier, Map, java.util.function.Function, int)}, or else the classifier will be
   *     useless.
   * @see ClauseSplitter#load(String)
   */
  protected ClauseSplitterSearchProblem(
      SemanticGraph tree,
      boolean assumedTruth,
      Optional<Classifier<ClauseSplitter.ClauseClassifierLabel, String>> isClauseClassifier,
      Optional<
              Function<
                  Triple<
                      ClauseSplitterSearchProblem.State,
                      ClauseSplitterSearchProblem.Action,
                      ClauseSplitterSearchProblem.State>,
                  Counter<String>>>
          featurizer) {
    this.tree = new SemanticGraph(tree);
    this.assumedTruth = assumedTruth;
    this.isClauseClassifier = isClauseClassifier;
    this.featurizer = featurizer;
    // Index edges
    this.tree.edgeIterable().forEach(edgeToIndex::addToIndex);
    // Get length
    List<IndexedWord> sortedVertices = tree.vertexListSorted();
    sentenceLength = sortedVertices.get(sortedVertices.size() - 1).index();
    // Register extra edges
    for (IndexedWord vertex : sortedVertices) {
      extraEdgesByGovernor.put(vertex, new ArrayList<>());
      extraEdgesByDependent.put(vertex, new ArrayList<>());
    }
    List<SemanticGraphEdge> extraEdges = Util.cleanTree(this.tree);
    assert Util.isTree(this.tree);
    for (SemanticGraphEdge edge : extraEdges) {
      extraEdgesByGovernor.get(edge.getGovernor()).add(edge);
      extraEdgesByDependent.get(edge.getDependent()).add(edge);
    }
  }

  /**
   * Create a clause searcher which searches naively through every possible subtree as a clause.
   * For an end-user, this is almost certainly not what you want. However, it is very useful at
   * training time.
   *
   * @param tree The dependency tree to search over.
   * @param assumedTruth The truth of the premise. Almost always True.
   */
  public ClauseSplitterSearchProblem(SemanticGraph tree, boolean assumedTruth) {
    this(tree, assumedTruth, Optional.empty(), Optional.empty());
  }

  /**
   * The basic method for splitting off a clause of a tree. This modifies the tree in place.
   *
   * @param tree The tree to split a clause from.
   * @param toKeep The edge representing the clause to keep.
   */
  static void splitToChildOfEdge(SemanticGraph tree, SemanticGraphEdge toKeep) {
    Queue<IndexedWord> fringe = new LinkedList<>();
    List<IndexedWord> nodesToRemove = new ArrayList<>();
    // Find nodes to remove
    // (from the root)
    for (IndexedWord root : tree.getRoots()) {
      nodesToRemove.add(root);
      for (SemanticGraphEdge out : tree.outgoingEdgeIterable(root)) {
        if (!out.equals(toKeep)) {
          fringe.add(out.getDependent());
        }
      }
    }
    // (recursively)
    while (!fringe.isEmpty()) {
      IndexedWord node = fringe.poll();
      nodesToRemove.add(node);
      for (SemanticGraphEdge out : tree.outgoingEdgeIterable(node)) {
        if (!out.equals(toKeep)) {
          fringe.add(out.getDependent());
        }
      }
    }
    // Remove nodes
    nodesToRemove.forEach(tree::removeVertex);
    // Set new root
    tree.setRoot(toKeep.getDependent());
  }

  /**
   * The basic method for splitting off a clause of a tree. This modifies the tree in place. This
   * method additionally follows ref edges.
   *
   * @param tree The tree to split a clause from.
   * @param toKeep The edge representing the clause to keep.
   */
  @SuppressWarnings("unchecked")
  private void simpleClause(SemanticGraph tree, SemanticGraphEdge toKeep) {
    splitToChildOfEdge(tree, toKeep);
    // Follow 'ref' edges
    Map<IndexedWord, IndexedWord> refReplaceMap = new HashMap<>();
    // (find replacements)
    for (IndexedWord vertex : tree.vertexSet()) {
      for (SemanticGraphEdge edge : extraEdgesByDependent.get(vertex)) {
        if ("ref".equals(edge.getRelation().toString()) // it's a ref edge...
            && !tree.containsVertex(edge.getGovernor())) { // ...that doesn't already exist in the tree.
          refReplaceMap.put(vertex, edge.getGovernor());
        }
      }
    }
    // (do replacements)
    for (Map.Entry<IndexedWord, IndexedWord> entry : refReplaceMap.entrySet()) {
      Iterator<SemanticGraphEdge> iter = tree.incomingEdgeIterator(entry.getKey());
      if (!iter.hasNext()) {
        continue;
      }
      SemanticGraphEdge incomingEdge = iter.next();
      IndexedWord governor = incomingEdge.getGovernor();
      tree.removeVertex(entry.getKey());
      addSubtree(
          tree,
          governor,
          incomingEdge.getRelation().toString(),
          this.tree,
          entry.getValue(),
          this.tree.incomingEdgeList(tree.getFirstRoot()));
    }
  }

  /**
   * A helper to add a single word to a given dependency tree.
   *
   * @param toModify The tree to add the word to.
   * @param root The root of the tree where we should be adding the word.
   * @param rel The relation to add the word with.
   * @param coreLabel The word to add.
   */
  @SuppressWarnings("UnusedDeclaration")
  private static void addWord(
      SemanticGraph toModify, IndexedWord root, String rel, CoreLabel coreLabel) {
    IndexedWord dependent = new IndexedWord(coreLabel);
    toModify.addVertex(dependent);
    toModify.addEdge(
        root,
        dependent,
        GrammaticalRelation.valueOf(Language.English, rel),
        Double.NEGATIVE_INFINITY,
        false);
  }

  /**
   * A helper to add an entire subtree to a given dependency tree.
   *
   * @param toModify The tree to add the subtree to.
   * @param root The root of the tree where we should be adding the subtree.
   * @param rel The relation to add the subtree with.
   * @param originalTree The original tree (i.e., {@link ClauseSplitterSearchProblem#tree}).
   * @param subject The root of the clause to add.
   * @param ignoredEdges The edges to ignore adding when adding this subtree.
   */
  private static void addSubtree(
      SemanticGraph toModify,
      IndexedWord root,
      String rel,
      SemanticGraph originalTree,
      IndexedWord subject,
      Collection<SemanticGraphEdge> ignoredEdges) {
    if (toModify.containsVertex(subject)) {
      return; // This subtree already exists.
    }
    Queue<IndexedWord> fringe = new LinkedList<>();
    Collection<IndexedWord> wordsToAdd = new ArrayList<>();
    Collection<SemanticGraphEdge> edgesToAdd = new ArrayList<>();
    // Search for subtree to add
    for (SemanticGraphEdge edge : originalTree.outgoingEdgeIterable(subject)) {
      if (!ignoredEdges.contains(edge)) {
        if (toModify.containsVertex(edge.getDependent())) {
          // Case: we're adding a subtree that's not disjoint from toModify. This is bad news.
          return;
        }
        edgesToAdd.add(edge);
        fringe.add(edge.getDependent());
      }
    }
    while (!fringe.isEmpty()) {
      IndexedWord node = fringe.poll();
      wordsToAdd.add(node);
      for (SemanticGraphEdge edge : originalTree.outgoingEdgeIterable(node)) {
        if (!ignoredEdges.contains(edge)) {
          if (toModify.containsVertex(edge.getDependent())) {
            // Case: we're adding a subtree that's not disjoint from toModify. This is bad news.
            return;
          }
          edgesToAdd.add(edge);
          fringe.add(edge.getDependent());
        }
      }
    }
    // Add subtree
    // (add subject)
    toModify.addVertex(subject);
    toModify.addEdge(
        root,
        subject,
        GrammaticalRelation.valueOf(Language.English, rel),
        Double.NEGATIVE_INFINITY,
        false);
    // (add nodes)
    wordsToAdd.forEach(toModify::addVertex);
    // (add edges)
    for (SemanticGraphEdge edge : edgesToAdd) {
      assert !toModify.incomingEdgeIterator(edge.getDependent()).hasNext();
      toModify.addEdge(
          edge.getGovernor(),
          edge.getDependent(),
          edge.getRelation(),
          edge.getWeight(),
          edge.isExtra());
    }
  }

  /**
   * Strips aux and mark edges when we are splitting into a clause.
   *
   * @param toModify The tree we are stripping the edges from.
   */
  private void stripAuxMark(SemanticGraph toModify) {
    List<SemanticGraphEdge> toClean = new ArrayList<>();
    for (SemanticGraphEdge edge : toModify.outgoingEdgeIterable(toModify.getFirstRoot())) {
      String rel = edge.getRelation().toString();
      if (("aux".equals(rel) || "mark".equals(rel))
          && !toModify.outgoingEdgeIterator(edge.getDependent()).hasNext()) {
        toClean.add(edge);
      }
    }
    for (SemanticGraphEdge edge : toClean) {
      toModify.removeEdge(edge);
      toModify.removeVertex(edge.getDependent());
    }
  }

  /**
   * Create a mock node, to be added to the dependency tree but which is not part of the original
   * sentence.
   *
   * @param toCopy The CoreLabel to copy from initially.
   * @param word The new word to add.
   * @param POS The new part of speech to add.
   * @return A CoreLabel copying most fields from toCopy, but with a new word and POS tag (as
   *     well as a new index).
   */
  @SuppressWarnings("UnusedDeclaration")
  private CoreLabel mockNode(CoreLabel toCopy, String word, String POS) {
    CoreLabel mock = new CoreLabel(toCopy);
    mock.setWord(word);
    mock.setLemma(word);
    mock.setValue(word);
    mock.setNER("O");
    mock.setTag(POS);
    mock.setIndex(sentenceLength + 5);
    return mock;
  }

  /**
   * Get the top few clauses from this searcher, cutting off at the given minimum probability.
   *
   * @param thresholdProbability The threshold under which to stop returning clauses. This should
   *     be between 0 and 1.
   * @return The resulting {@link edu.stanford.nlp.naturalli.SentenceFragment} objects,
   *     representing the top clauses of the sentence.
   */
  public List<SentenceFragment> topClauses(double thresholdProbability) {
    List<SentenceFragment> results = new ArrayList<>();
    search(
        triple -> {
          assert triple.first <= 0.0;
          double prob = Math.exp(triple.first);
          assert prob <= 1.0;
          assert prob >= 0.0;
          assert !Double.isNaN(prob);
          if (prob >= thresholdProbability) {
            SentenceFragment fragment = triple.third.get();
            fragment.score = prob;
            results.add(fragment);
            return true;
          } else {
            return false;
          }
        });
    return results;
  }

  /**
   * Search, using the default weights / featurizer. This is the most common entry method for the
   * raw search, though {@link ClauseSplitterSearchProblem#topClauses(double)} may be a more
   * convenient method for an end user.
   *
   * @param candidateFragments The callback function for results. The return value defines
   *     whether to continue searching.
   */
  public void search(
      final Predicate<Triple<Double, List<Counter<String>>, Supplier<SentenceFragment>>>
          candidateFragments) {
    if (!isClauseClassifier.isPresent()) {
      search(
          candidateFragments,
          new LinearClassifier<>(new ClassicCounter<>()),
          HARD_SPLITS,
          this.featurizer.isPresent() ? this.featurizer.get() : DEFAULT_FEATURIZER,
          1000);
    } else {
      if (!(isClauseClassifier.get() instanceof LinearClassifier)) {
        throw new IllegalArgumentException("For now, only linear classifiers are supported");
      }
      search(candidateFragments, isClauseClassifier.get(), HARD_SPLITS, this.featurizer.get(), 1000);
    }
  }

  /**
   * Search from the root of the tree. This function also defines the default action space to use
   * during search. This is NOT recommended to be used at test time.
   *
   * @see edu.stanford.nlp.naturalli.ClauseSplitterSearchProblem#search(Predicate)
   * @param candidateFragments The callback function.
   * @param classifier The classifier for whether an arc should be on the path to a clause split,
   *     a clause split itself, or neither.
   * @param featurizer The featurizer to use during search, to be dot-producted with the weights.
   */
  public void search(
      // The output specs
      final Predicate<Triple<Double, List<Counter<String>>, Supplier<SentenceFragment>>>
          candidateFragments,
      // The learning specs
      final Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
      final Map<String, List<String>> hardCodedSplits,
      final Function<Triple<State, Action, State>, Counter<String>> featurizer,
      final int maxTicks) {
    Collection<Action> actionSpace = new ArrayList<>();

    // SIMPLE SPLIT
    actionSpace.add(
        new Action() {
          @Override
          public String signature() {
            return "simple";
          }

          @Override
          public boolean prerequisitesMet(SemanticGraph originalTree, SemanticGraphEdge edge) {
            char tag = edge.getDependent().tag().charAt(0);
            return !(tag != 'V' && tag != 'N' && tag != 'J' && tag != 'P' && tag != 'D');
          }

          @Override
          public Optional<State> applyTo(
              SemanticGraph tree,
              State source,
              SemanticGraphEdge outgoingEdge,
              SemanticGraphEdge subjectOrNull,
              SemanticGraphEdge objectOrNull) {
            return Optional.of(
                new State(
                    outgoingEdge,
                    subjectOrNull == null ? source.subjectOrNull : subjectOrNull,
                    subjectOrNull == null ? (source.distanceFromSubj + 1) : 0,
                    objectOrNull == null ? source.objectOrNull : objectOrNull,
                    source.thunk.andThen(
                        toModify -> {
                          assert Util.isTree(toModify);
                          simpleClause(toModify, outgoingEdge);
                          if (outgoingEdge.getRelation().toString().endsWith("comp")) {
                            stripAuxMark(toModify);
                          }
                          assert Util.isTree(toModify);
                        }),
                    false));
          }
        });

    // CLONE ROOT
    actionSpace.add(
        new Action() {
          @Override
          public String signature() {
            return "clone_root_as_nsubjpass";
          }

          @Override
          public boolean prerequisitesMet(SemanticGraph originalTree, SemanticGraphEdge edge) {
            // Only valid if there's a single nontrivial outgoing edge from a node. Otherwise
            // it's a whole can of worms.
            Iterator<SemanticGraphEdge> iter =
                originalTree.outgoingEdgeIterable(edge.getGovernor()).iterator();
            if (!iter.hasNext()) {
              return false; // what?
            }
            boolean nontrivialEdge = false;
            while (iter.hasNext()) {
              SemanticGraphEdge outEdge = iter.next();
              switch (outEdge.getRelation().toString()) {
                case "nn":
                case "amod":
                  break;
                default:
                  if (nontrivialEdge) {
                    return false;
                  }
                  nontrivialEdge = true;
              }
            }
            return true;
          }

          @Override
          public Optional<State> applyTo(
              SemanticGraph tree,
              State source,
              SemanticGraphEdge outgoingEdge,
              SemanticGraphEdge subjectOrNull,
              SemanticGraphEdge objectOrNull) {
            return Optional.of(
                new State(
                    outgoingEdge,
                    subjectOrNull == null ? source.subjectOrNull : subjectOrNull,
                    subjectOrNull == null ? (source.distanceFromSubj + 1) : 0,
                    objectOrNull == null ? source.objectOrNull : objectOrNull,
                    source.thunk.andThen(
                        toModify -> {
                          assert Util.isTree(toModify);
                          simpleClause(toModify, outgoingEdge);
                          addSubtree(
                              toModify,
                              outgoingEdge.getDependent(),
                              "nsubjpass",
                              tree,
                              outgoingEdge.getGovernor(),
                              Collections.singleton(outgoingEdge));
                          // addWord(toModify, outgoingEdge.getDependent(), "auxpass",
                          //     mockNode(outgoingEdge.getDependent().backingLabel(), "is", "VBZ"));
                          assert Util.isTree(toModify);
                        }),
                    true));
          }
        });

    // COPY SUBJECT
    actionSpace.add(
        new Action() {
          @Override
          public String signature() {
            return "clone_nsubj";
          }

          @Override
          public boolean prerequisitesMet(SemanticGraph originalTree, SemanticGraphEdge edge) {
            // Don't split into anything but verbs or nouns
            char tag = edge.getDependent().tag().charAt(0);
            if (tag != 'V' && tag != 'N') {
              return false;
            }
            for (SemanticGraphEdge grandchild :
                originalTree.outgoingEdgeIterable(edge.getDependent())) {
              if (grandchild.getRelation().toString().contains("subj")) {
                return false;
              }
            }
            return true;
          }

          @Override
          public Optional<State> applyTo(
              SemanticGraph tree,
              State source,
              SemanticGraphEdge outgoingEdge,
              SemanticGraphEdge subjectOrNull,
              SemanticGraphEdge objectOrNull) {
            if (subjectOrNull != null && !outgoingEdge.equals(subjectOrNull)) {
              return Optional.of(
                  new State(
                      outgoingEdge,
                      subjectOrNull,
                      0,
                      objectOrNull == null ? source.objectOrNull : objectOrNull,
                      source.thunk.andThen(
                          toModify -> {
                            assert Util.isTree(toModify);
                            simpleClause(toModify, outgoingEdge);
                            addSubtree(
                                toModify,
                                outgoingEdge.getDependent(),
                                "nsubj",
                                tree,
                                subjectOrNull.getDependent(),
                                Collections.singleton(outgoingEdge));
                            assert Util.isTree(toModify);
                            stripAuxMark(toModify);
                            assert Util.isTree(toModify);
                          }),
                      false));
            } else {
              return Optional.empty();
            }
          }
        });

    // COPY OBJECT
    actionSpace.add(
        new Action() {
          @Override
          public String signature() {
            return "clone_dobj";
          }

          @Override
          public boolean prerequisitesMet(SemanticGraph originalTree, SemanticGraphEdge edge) {
            // Don't split into anything but verbs or nouns
            char tag = edge.getDependent().tag().charAt(0);
            if (tag != 'V' && tag != 'N') {
              return false;
            }
            for (SemanticGraphEdge grandchild :
                originalTree.outgoingEdgeIterable(edge.getDependent())) {
              if (grandchild.getRelation().toString().contains("subj")) {
                return false;
              }
            }
            return true;
          }

          @Override
          public Optional<State> applyTo(
              SemanticGraph tree,
              State source,
              SemanticGraphEdge outgoingEdge,
              SemanticGraphEdge subjectOrNull,
              SemanticGraphEdge objectOrNull) {
            if (objectOrNull != null && !outgoingEdge.equals(objectOrNull)) {
              return Optional.of(
                  new State(
                      outgoingEdge,
                      subjectOrNull == null ? source.subjectOrNull : subjectOrNull,
                      subjectOrNull == null ? (source.distanceFromSubj + 1) : 0,
                      objectOrNull,
                      source.thunk.andThen(
                          toModify -> {
                            assert Util.isTree(toModify);
                            // Split the clause
                            simpleClause(toModify, outgoingEdge);
                            // Attach the new subject
                            addSubtree(
                                toModify,
                                outgoingEdge.getDependent(),
                                "nsubj",
                                tree,
                                objectOrNull.getDependent(),
                                Collections.singleton(outgoingEdge));
                            // Strip bits we don't want
                            assert Util.isTree(toModify);
                            stripAuxMark(toModify);
                            assert Util.isTree(toModify);
                          }),
                      false));
            } else {
              return Optional.empty();
            }
          }
        });

    for (IndexedWord root : tree.getRoots()) {
      search(
          root, candidateFragments, classifier, hardCodedSplits, featurizer, actionSpace,
          maxTicks);
    }
  }

  /** Re-order the action space based on the specified order of names. */
  private Collection<Action> orderActions(Collection<Action> actionSpace, List<String> order) {
    List<Action> tmp = new ArrayList<>(actionSpace);
    List<Action> out = new ArrayList<>();
    for (String key : order) {
      Iterator<Action> iter = tmp.iterator();
      while (iter.hasNext()) {
        Action a = iter.next();
        if (a.signature().equals(key)) {
          out.add(a);
          iter.remove();
        }
      }
    }
    out.addAll(tmp);
    return out;
  }

  /**
   * The core implementation of the search.
   *
   * @param root The root word to search from. Traditionally, this is the root of the sentence.
   * @param candidateFragments The callback for the resulting sentence fragments. This is a
   *     predicate of a triple of values. The return value of the predicate determines whether we
   *     should continue searching. The triple is a triple of
   *     <ol>
   *       <li>The log probability of the sentence fragment, according to the featurizer and the
   *           weights
   *       <li>The features along the path to this fragment. The last element of this is the
   *           features from the most recent step.
   *       <li>The sentence fragment. Because it is relatively expensive to compute the resulting
   *           tree, this is returned as a lazy {@link Supplier}.
   *     </ol>
   *
   * @param classifier The classifier for whether an arc should be on the path to a clause split,
   *     a clause split itself, or neither.
   * @param featurizer The featurizer to use. Make sure this matches the weights!
   * @param actionSpace The action space we are allowed to take. Each action defines a means of
   *     splitting a clause on a dependency boundary.
   */
  protected void search(
      // The root to search from
      IndexedWord root,
      // The output specs
      final Predicate<Triple<Double, List<Counter<String>>, Supplier<SentenceFragment>>>
          candidateFragments,
      // The learning specs
      final Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
      Map<String, ? extends List<String>> hardCodedSplits,
      final Function<Triple<State, Action, State>, Counter<String>> featurizer,
      final Collection<Action> actionSpace,
      final int maxTicks) {
    // (the fringe)
    PriorityQueue<Pair<State, List<Counter<String>>>> fringe =
        new FixedPrioritiesPriorityQueue<>();
    // (avoid duplicate work)
    Set<IndexedWord> seenWords = new HashSet<>();

    State firstState = new State(null, null, -9000, null, x -> {}, true); // implicitly "done"
    fringe.add(Pair.makePair(firstState, new ArrayList<>(0)), -0.0);
    int ticks = 0;

    while (!fringe.isEmpty()) {
      if (++ticks > maxTicks) {
        // System.err.println("WARNING! Timed out on search with " + ticks + " ticks");
        return;
      }
      // Useful variables
      double logProbSoFar = fringe.getPriority();
      assert logProbSoFar <= 0.0;
      Pair<State, List<Counter<String>>> lastStatePair = fringe.removeFirst();
      State lastState = lastStatePair.first;
      List<Counter<String>> featuresSoFar = lastStatePair.second;
      IndexedWord rootWord = lastState.edge == null ? root : lastState.edge.getDependent();

      // Register thunk
      if (lastState.isDone) {
        if (!candidateFragments.test(
            Triple.makeTriple(
                logProbSoFar,
                featuresSoFar,
                () -> {
                  SemanticGraph copy = new SemanticGraph(tree);
                  lastState
                      .thunk
                      .andThen(
                          x -> {
                            // Add the extra edges back in, if they don't break the tree-ness of
                            // the extraction
                            for (IndexedWord newTreeRoot : x.getRoots()) {
                              if (newTreeRoot != null) { // what a strange thing to have happen...
                                for (SemanticGraphEdge extraEdge :
                                    extraEdgesByGovernor.get(newTreeRoot)) {
                                  assert Util.isTree(x);
                                  //noinspection unchecked
                                  addSubtree(
                                      x,
                                      newTreeRoot,
                                      extraEdge.getRelation().toString(),
                                      tree,
                                      extraEdge.getDependent(),
                                      tree.getIncomingEdgesSorted(newTreeRoot));
                                  assert Util.isTree(x);
                                }
                              }
                            }
                          })
                      .accept(copy);
                  return new SentenceFragment(copy, assumedTruth, false);
                }))) {
          break;
        }
      }

      // Find relevant auxiliary terms
      SemanticGraphEdge subjOrNull = null;
      SemanticGraphEdge objOrNull = null;
      for (SemanticGraphEdge auxEdge : tree.outgoingEdgeIterable(rootWord)) {
        String relString = auxEdge.getRelation().toString();
        if (relString.contains("obj")) {
          objOrNull = auxEdge;
        } else if (relString.contains("subj")) {
          subjOrNull = auxEdge;
        }
      }

      // Iterate over children
      // For each outgoing edge...
      for (SemanticGraphEdge outgoingEdge : tree.outgoingEdgeIterable(rootWord)) {
        // Prohibit indirect speech verbs from splitting off clauses
        // (e.g., 'said', 'think')
        // This fires if the governor is an indirect speech verb, and the outgoing edge is a
        // ccomp
        if (outgoingEdge.getRelation().toString().equals("ccomp")
            && ((outgoingEdge.getGovernor().lemma() != null
                    && INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().lemma()))
                || INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().word()))) {
          continue;
        }
        // Get some variables
        String outgoingEdgeRelation = outgoingEdge.getRelation().toString();
        List<String> forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation);
        if (forcedArcOrder == null && outgoingEdgeRelation.contains(":")) {
          forcedArcOrder =
              hardCodedSplits.get(
                  outgoingEdgeRelation.substring(0, outgoingEdgeRelation.indexOf(":")) + ":*");
        }
        boolean doneForcedArc = false;
        // For each action...
        for (Action action :
            (forcedArcOrder == null ? actionSpace : orderActions(actionSpace, forcedArcOrder))) {
          // Check the prerequisite
          if (!action.prerequisitesMet(tree, outgoingEdge)) {
            continue;
          }
          if (forcedArcOrder != null && doneForcedArc) {
            break;
          }
          // 1. Compute the child state
          Optional<State> candidate =
              action.applyTo(tree, lastState, outgoingEdge, subjOrNull, objOrNull);
          if (candidate.isPresent()) {
            double logProbability;
            ClauseClassifierLabel bestLabel;
            Counter<String> features =
                featurizer.apply(Triple.makeTriple(lastState, action, candidate.get()));
            if (forcedArcOrder != null && !doneForcedArc) {
              logProbability = 0.0;
              bestLabel = ClauseClassifierLabel.CLAUSE_SPLIT;
              doneForcedArc = true;
            } else if (features.containsKey("__undocumented_junit_no_classifier")) {
              logProbability = Double.NEGATIVE_INFINITY;
              bestLabel = ClauseClassifierLabel.CLAUSE_INTERM;
            } else {
              Counter<ClauseClassifierLabel> scores =
                  classifier.scoresOf(new RVFDatum<>(features));
              if (scores.size() > 0) {
                Counters.logNormalizeInPlace(scores);
              }
              String rel = outgoingEdge.getRelation().toString();
              if ("nsubj".equals(rel) || "dobj".equals(rel)) {
                // Always at least yield on nsubj and dobj
                scores.remove(ClauseClassifierLabel.NOT_A_CLAUSE);
              }
              logProbability = Counters.max(scores, Double.NEGATIVE_INFINITY);
              bestLabel =
                  Counters.argmax(scores, (x, y) -> 0, ClauseClassifierLabel.CLAUSE_SPLIT);
            }

            if (bestLabel != ClauseClassifierLabel.NOT_A_CLAUSE) {
              Pair<State, List<Counter<String>>> childState =
                  Pair.makePair(
                      candidate.get().withIsDone(bestLabel),
                      new ArrayList<Counter<String>>(featuresSoFar) {
                        {
                          add(features);
                        }
                      });
              // 2. Register the child state
              if (!seenWords.contains(childState.first.edge.getDependent())) {
                // System.err.println("  pushing " + action.signature() + " with " +
                //     argmax.first.edge);
                fringe.add(childState, logProbability);
              }
            }
          }
        }
      }

      seenWords.add(rootWord);
    }
    // System.err.println("Search finished in " + ticks + " ticks and " + classifierEvals +
    //     " classifier evaluations.");
  }

  /** The default featurizer to use during training. */
  public static final Featurizer DEFAULT_FEATURIZER =
      new Featurizer() {
        private static final long serialVersionUID = 4145523451314579506L;

        @Override
        public boolean isSimpleSplit(Counter<String> feats) {
          for (String key : feats.keySet()) {
            if (key.startsWith("simple&")) {
              return true;
            }
          }
          return false;
        }

        @Override
        public Counter<String> apply(Triple<State, Action, State> triple) {
          // Variables
          State from = triple.first;
          Action action = triple.second;
          State to = triple.third;
          String signature = action.signature();
          String edgeRelTaken = to.edge == null ? "root" : to.edge.getRelation().toString();
          String edgeRelShort = to.edge == null ? "root" : to.edge.getRelation().getShortName();
          if (edgeRelShort.contains("_")) {
            edgeRelShort = edgeRelShort.substring(0, edgeRelShort.indexOf("_"));
          }

          // -- Featurize --
          // Variables to aggregate
          boolean parentHasSubj = false;
          boolean parentHasObj = false;
          boolean childHasSubj = false;
          boolean childHasObj = false;
          Counter<String> feats = new ClassicCounter<>();

          // 1. edge taken
          feats.incrementCount(signature + "&edge:" + edgeRelTaken);
          feats.incrementCount(signature + "&edge_type:" + edgeRelShort);

          // 2. last edge taken
          if (from.edge == null) {
            assert to.edge == null
                || to.originalTree().getRoots().contains(to.edge.getGovernor());
            feats.incrementCount(signature + "&at_root");
            feats.incrementCount(
                signature + "&at_root&root_pos:" + to.originalTree().getFirstRoot().tag());
          } else {
            feats.incrementCount(signature + "&not_root");
            String lastRelShort = from.edge.getRelation().getShortName();
            if (lastRelShort.contains("_")) {
              lastRelShort = lastRelShort.substring(0, lastRelShort.indexOf("_"));
            }
            feats.incrementCount(signature + "&last_edge:" + lastRelShort);
          }

          if (to.edge != null) {
            // 3. other edges at parent
            for (SemanticGraphEdge parentNeighbor :
                from.originalTree().outgoingEdgeIterable(to.edge.getGovernor())) {
              if (parentNeighbor != to.edge) {
                String parentNeighborRel = parentNeighbor.getRelation().toString();
                if (parentNeighborRel.contains("subj")) {
                  parentHasSubj = true;
                }
                if (parentNeighborRel.contains("obj")) {
                  parentHasObj = true;
                }
                // (add feature)
                feats.incrementCount(signature + "&parent_neighbor:" + parentNeighborRel);
                feats.incrementCount(
                    signature
                        + "&edge_type:"
                        + edgeRelShort
                        + "&parent_neighbor:"
                        + parentNeighborRel);
              }
            }

            // 4. Other edges at child
            int childNeighborCount = 0;
            for (SemanticGraphEdge childNeighbor :
                from.originalTree().outgoingEdgeIterable(to.edge.getDependent())) {
              String childNeighborRel = childNeighbor.getRelation().toString();
              if (childNeighborRel.contains("subj")) {
                childHasSubj = true;
              }
              if (childNeighborRel.contains("obj")) {
                childHasObj = true;
              }
              childNeighborCount += 1;
              // (add feature)
              feats.incrementCount(signature + "&child_neighbor:" + childNeighborRel);
              feats.incrementCount(
                  signature
                      + "&edge_type:"
                      + edgeRelShort
                      + "&child_neighbor:"
                      + childNeighborRel);
            }
            // 4.1 Number of other edges at child
            feats.incrementCount(
                signature
                    + "&child_neighbor_count:"
                    + (childNeighborCount < 3 ? childNeighborCount : ">2"));
            feats.incrementCount(
                signature
                    + "&edge_type:"
                    + edgeRelShort
                    + "&child_neighbor_count:"
                    + (childNeighborCount < 3 ? childNeighborCount : ">2"));

            // 5. Subject/Object stats
            feats.incrementCount(signature + "&parent_neighbor_subj:" + parentHasSubj);
            feats.incrementCount(signature + "&parent_neighbor_obj:" + parentHasObj);
            feats.incrementCount(signature + "&child_neighbor_subj:" + childHasSubj);
            feats.incrementCount(signature + "&child_neighbor_obj:" + childHasObj);

            // 6. POS tag info
            feats.incrementCount(signature + "&parent_pos:" + to.edge.getGovernor().tag());
            feats.incrementCount(signature + "&child_pos:" + to.edge.getDependent().tag());
            feats.incrementCount(
                signature
                    + "&pos_signature:"
                    + to.edge.getGovernor().tag()
                    + "_"
                    + to.edge.getDependent().tag());
            feats.incrementCount(
                signature
                    + "&edge_type:"
                    + edgeRelShort
                    + "&pos_signature:"
                    + to.edge.getGovernor().tag()
                    + "_"
                    + to.edge.getDependent().tag());
          }
          return feats;
        }
      };
}
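// A minimal end-to-end sketch (added illustration, not part of the original source): split a
// parsed sentence into clauses with the naive, classifier-free searcher. The annotation key
// below is an assumption about typical CoreNLP usage; only ClauseSplitterSearchProblem itself
// is defined in this file.
//
//   SemanticGraph graph = sentence.get(
//       SemanticGraphCoreAnnotations.CollapsedCCProcessedDependenciesAnnotation.class);
//   ClauseSplitterSearchProblem problem = new ClauseSplitterSearchProblem(graph, true);
//   for (SentenceFragment clause : problem.topClauses(0.5)) {
//     System.out.println(clause.score + "\t" + clause);
//   }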