/** * Create a searcher manually, suppling a dependency tree, an optional classifier for when to * split clauses, and a featurizer for that classifier. You almost certainly want to use {@link * ClauseSplitter#load(String)} instead of this constructor. * * @param tree The dependency tree to search over. * @param assumedTruth The assumed truth of the tree (relevant for natural logic inference). If in * doubt, pass in true. * @param isClauseClassifier The classifier for whether a given dependency arc should be a new * clause. If this is not given, all arcs are treated as clause separators. * @param featurizer The featurizer for the classifier. If no featurizer is given, one should be * given in {@link ClauseSplitterSearchProblem#search(java.util.function.Predicate, * Classifier, Map, java.util.function.Function, int)}, or else the classifier will be * useless. * @see ClauseSplitter#load(String) */ protected ClauseSplitterSearchProblem( SemanticGraph tree, boolean assumedTruth, Optional<Classifier<ClauseSplitter.ClauseClassifierLabel, String>> isClauseClassifier, Optional< Function< Triple< ClauseSplitterSearchProblem.State, ClauseSplitterSearchProblem.Action, ClauseSplitterSearchProblem.State>, Counter<String>>> featurizer) { this.tree = new SemanticGraph(tree); this.assumedTruth = assumedTruth; this.isClauseClassifier = isClauseClassifier; this.featurizer = featurizer; // Index edges this.tree.edgeIterable().forEach(edgeToIndex::addToIndex); // Get length List<IndexedWord> sortedVertices = tree.vertexListSorted(); sentenceLength = sortedVertices.get(sortedVertices.size() - 1).index(); // Register extra edges for (IndexedWord vertex : sortedVertices) { extraEdgesByGovernor.put(vertex, new ArrayList<>()); extraEdgesByDependent.put(vertex, new ArrayList<>()); } List<SemanticGraphEdge> extraEdges = Util.cleanTree(this.tree); assert Util.isTree(this.tree); for (SemanticGraphEdge edge : extraEdges) { extraEdgesByGovernor.get(edge.getGovernor()).add(edge); 
extraEdgesByDependent.get(edge.getDependent()).add(edge); } }
/** * The basic method for splitting off a clause of a tree. This modifies the tree in place. * * @param tree The tree to split a clause from. * @param toKeep The edge representing the clause to keep. */ static void splitToChildOfEdge(SemanticGraph tree, SemanticGraphEdge toKeep) { Queue<IndexedWord> fringe = new LinkedList<>(); List<IndexedWord> nodesToRemove = new ArrayList<>(); // Find nodes to remove // (from the root) for (IndexedWord root : tree.getRoots()) { nodesToRemove.add(root); for (SemanticGraphEdge out : tree.outgoingEdgeIterable(root)) { if (!out.equals(toKeep)) { fringe.add(out.getDependent()); } } } // (recursively) while (!fringe.isEmpty()) { IndexedWord node = fringe.poll(); nodesToRemove.add(node); for (SemanticGraphEdge out : tree.outgoingEdgeIterable(node)) { if (!out.equals(toKeep)) { fringe.add(out.getDependent()); } } } // Remove nodes nodesToRemove.forEach(tree::removeVertex); // Set new root tree.setRoot(toKeep.getDependent()); }
/** * 29% in FactorTable.getValue() 28% in CRFCliqueTree.getCalibratedCliqueTree() 12.6% waiting for * threads * * <p>Single threaded: 15000 ms - 26000 ms Multi threaded: 4500 ms - 7000 ms * * <p>with 8 cpus, 3.3x - 3.7x speedup, around 800% utilization */ public static void benchmarkCRF() { Properties props = new Properties(); props.setProperty("macro", "true"); // use a generic CRF configuration props.setProperty("useIfInteger", "true"); props.setProperty("featureFactory", "edu.stanford.nlp.benchmarks.BenchmarkFeatureFactory"); props.setProperty("saveFeatureIndexToDisk", "false"); CRFClassifier<CoreLabel> crf = new CRFClassifier<CoreLabel>(props); Random r = new Random(42); List<List<CoreLabel>> data = new ArrayList<>(); for (int i = 0; i < 100; i++) { List<CoreLabel> sentence = new ArrayList<>(); for (int j = 0; j < 20; j++) { CoreLabel l = new CoreLabel(); l.setWord("j:" + j); boolean tag = j % 2 == 0 ^ (r.nextDouble() > 0.7); l.set(CoreAnnotations.AnswerAnnotation.class, "target:" + tag); sentence.add(l); } data.add(sentence); } long msStart = System.currentTimeMillis(); crf.train(data); long delay = System.currentTimeMillis() - msStart; System.out.println("Training took " + delay + " ms"); }
/**
 * Describes a local tree as a list of category strings: the label value of the mother first,
 * followed by the label value of each daughter in order.
 *
 * @param t The local tree to describe.
 * @return The mother's category followed by the daughters' categories.
 */
public static List<String> localTreeAsCatList(Tree t) {
  Tree[] daughters = t.children();
  List<String> categories = new ArrayList<String>(daughters.length + 1);
  categories.add(t.label().value());
  for (Tree daughter : daughters) {
    categories.add(daughter.label().value());
  }
  return categories;
}
/**
 * Prunes each tree of a list in turn, passing along the offset at which that tree's yield begins.
 *
 * @param treeList The adjacent trees to prune, in order.
 * @param start The offset of the first tree's yield.
 * @return A new list holding the result of {@code prune} for every input tree.
 */
private List<Tree> helper(List<Tree> treeList, int start) {
  List<Tree> pruned = new ArrayList<Tree>(treeList.size());
  int offset = start;
  for (Tree subtree : treeList) {
    // Measure the yield before pruning, so the next offset reflects the unpruned tree.
    int span = subtree.yield().size();
    pruned.add(prune(subtree, offset));
    offset += span;
  }
  return pruned;
}
/**
 * Copies a list of tagged words, replacing every tag by its basic category.
 *
 * @param twList The tagged words to clean; every element is cast to {@code TaggedWord}.
 * @param tlp The language pack whose {@code basicCategory} is applied to each tag.
 * @return A new list of new {@code TaggedWord}s with the same words and reduced tags.
 */
private static List<TaggedWord> cleanTags(List twList, TreebankLanguagePack tlp) {
  List<TaggedWord> cleaned = new ArrayList<TaggedWord>(twList.size());
  for (Object item : twList) {
    TaggedWord original = (TaggedWord) item;
    String word = original.word();
    String basicTag = tlp.basicCategory(original.tag());
    cleaned.add(new TaggedWord(word, basicTag));
  }
  return cleaned;
}
/** * returns list of tree nodes to root from t. Includes root and t. Returns null if tree not found * dominated by root */ public static List<Tree> pathFromRoot(Tree t, Tree root) { if (t == root) { // if (t.equals(root)) { List<Tree> l = new ArrayList<Tree>(1); l.add(t); return l; } else if (root == null) { return null; } return root.dominationPath(t); }
List<Tree> prune(List<Tree> treeList, Label label, int start, int end) { // get reference tree if (treeList.size() == 1) { return treeList; } Tree testTree = treeList.get(0).treeFactory().newTreeNode(label, treeList); int goal = Numberer.getGlobalNumberer("states").number(label.value()); Tree tempTree = parser.extractBestParse(goal, start, end); // parser.restoreUnaries(tempTree); Tree pcfgTree = debinarizer.transformTree(tempTree); Set<Constituent> pcfgConstituents = pcfgTree.constituents(new LabeledScoredConstituentFactory()); // delete child labels that are not in reference but do not cross reference List<Tree> prunedChildren = new ArrayList<Tree>(); int childStart = 0; for (int c = 0, numCh = testTree.numChildren(); c < numCh; c++) { Tree child = testTree.getChild(c); boolean isExtra = true; int childEnd = childStart + child.yield().size(); Constituent childConstituent = new LabeledScoredConstituent(childStart, childEnd, child.label(), 0); if (pcfgConstituents.contains(childConstituent)) { isExtra = false; } if (childConstituent.crosses(pcfgConstituents)) { isExtra = false; } if (child.isLeaf() || child.isPreTerminal()) { isExtra = false; } if (pcfgTree.yield().size() != testTree.yield().size()) { isExtra = false; } if (!label.value().startsWith("NP^NP")) { isExtra = false; } if (isExtra) { System.err.println( "Pruning: " + child.label() + " from " + (childStart + start) + " to " + (childEnd + start)); System.err.println("Was: " + testTree + " vs " + pcfgTree); prunedChildren.addAll(child.getChildrenAsList()); } else { prunedChildren.add(child); } childStart = childEnd; } return prunedChildren; }
/**
 * Renders a history as a string in which every element is prefixed by {@code '^'}, memoizing the
 * result in the {@code historyToString} map.
 *
 * <p>The cache stores a private copy of the history as its key. Callers hand in {@code subList}
 * views, and keeping such a view as a map key would pin the backing list and tie the key's hash
 * and equality to any later modification of that list.
 *
 * @param history The history to render; its elements are appended via their string value.
 * @return The rendered history; the empty string for an empty history.
 */
protected String historyToString(List history) {
  String cached = (String) historyToString.get(history);
  if (cached != null) {
    return cached;
  }

  StringBuilder sb = new StringBuilder();
  for (Object element : history) {
    sb.append('^');
    sb.append(element);
  }
  String rendered = sb.toString();

  // Defensive copy: the key must not alias a caller-owned list or view.
  historyToString.put(new ArrayList(history), rendered);
  return rendered;
}
/**
 * Rebuilds the text of an annotated document with coreferent tokens replaced by the words of
 * their chain's representative mention. The result is printed to standard output and returned.
 *
 * <p>A token is copied unchanged when it has no coreference chain, or when its index lies within
 * {@code [startIndex, endIndex]} of the representative mention. Any other token is replaced by
 * the words of the representative mention. If the document carries no coreference chain map at
 * all, every token is copied unchanged.
 *
 * <p>NOTE(review): the "inside the representative mention" test compares token indices only, and
 * treats {@code endIndex} as inclusive while the copy loop below treats it as exclusive. A token
 * of another sentence whose index falls in that range is therefore left unresolved. Confirm the
 * intended semantics before changing it.
 *
 * @param annotation The annotated document; expected to carry sentences, tokens and coref chains.
 * @return The resolved words, each followed by a single space.
 */
public static final String doCorefResolution(Annotation annotation) {
  Map<Integer, CorefChain> corefs = annotation.get(CorefChainAnnotation.class);
  List<CoreMap> sentences = annotation.get(CoreAnnotations.SentencesAnnotation.class);

  List<String> resolved = new ArrayList<String>();
  for (CoreMap sentence : sentences) {
    List<CoreLabel> tokens = sentence.get(CoreAnnotations.TokensAnnotation.class);
    for (CoreLabel token : tokens) {
      Integer corefClustId = token.get(CorefCoreAnnotations.CorefClusterIdAnnotation.class);
      // A missing chain map means there is nothing to resolve.
      CorefChain chain = (corefs == null) ? null : corefs.get(corefClustId);
      if (chain == null) {
        resolved.add(token.word());
        continue;
      }

      CorefMention reprMent = chain.getRepresentativeMention();
      boolean outsideRepresentative =
          token.index() < reprMent.startIndex || token.index() > reprMent.endIndex;
      if (!outsideRepresentative) {
        resolved.add(token.word());
        continue;
      }

      // sentNum and the mention indices are 1-based, hence the "- 1" adjustments.
      CoreMap corefSentence = sentences.get(reprMent.sentNum - 1);
      List<CoreLabel> corefSentenceTokens =
          corefSentence.get(CoreAnnotations.TokensAnnotation.class);
      for (int i = reprMent.startIndex; i < reprMent.endIndex; i++) {
        resolved.add(corefSentenceTokens.get(i - 1).word());
      }
    }
  }

  System.out.println();
  // Build the output in one buffer instead of re-copying the string for every word.
  StringBuilder joined = new StringBuilder();
  for (String word : resolved) {
    joined.append(word).append(' ');
  }
  String resolvedStr = joined.toString();
  System.out.println(resolvedStr);
  return resolvedStr;
}
/**
 * Replaces every occurrence of {@code node} (compared with ==) below {@code t} by {@code node1}.
 * The node {@code t} itself is never replaced, and the search does not descend into a replaced
 * position.
 *
 * @param node The node to look for.
 * @param node1 The node to put in its place.
 * @param t The tree whose descendants are rewritten in place.
 */
public static void replaceNode(Tree node, Tree node1, Tree t) {
  if (t.isLeaf()) {
    return;
  }
  Tree[] children = t.children();
  List<Tree> rebuilt = new ArrayList<Tree>(children.length);
  for (Tree child : children) {
    if (child == node) {
      rebuilt.add(node1);
      continue;
    }
    rebuilt.add(child);
    replaceNode(node, node1, child);
  }
  t.setChildren(rebuilt);
}
public Object formResult() { Set brs = new HashSet(); Set urs = new HashSet(); // scan each rule / history pair int ruleCount = 0; for (Iterator pairI = rulePairs.keySet().iterator(); pairI.hasNext(); ) { if (ruleCount % 100 == 0) { System.err.println("Rules multiplied: " + ruleCount); } ruleCount++; Pair rulePair = (Pair) pairI.next(); Rule baseRule = (Rule) rulePair.first; String baseLabel = (String) ruleToLabel.get(baseRule); List history = (List) rulePair.second; double totalProb = 0; for (int depth = 1; depth <= HISTORY_DEPTH() && depth <= history.size(); depth++) { List subHistory = history.subList(0, depth); double c_label = labelPairs.getCount(new Pair(baseLabel, subHistory)); double c_rule = rulePairs.getCount(new Pair(baseRule, subHistory)); // System.out.println("Multiplying out "+baseRule+" with history "+subHistory); // System.out.println("Count of "+baseLabel+" with "+subHistory+" is "+c_label); // System.out.println("Count of "+baseRule+" with "+subHistory+" is "+c_rule ); double prob = (1.0 / HISTORY_DEPTH()) * (c_rule) / (c_label); totalProb += prob; for (int childDepth = 0; childDepth <= Math.min(HISTORY_DEPTH() - 1, depth); childDepth++) { Rule rule = specifyRule(baseRule, subHistory, childDepth); rule.score = (float) Math.log(totalProb); // System.out.println("Created "+rule+" with score "+rule.score); if (rule instanceof UnaryRule) { urs.add(rule); } else { brs.add(rule); } } } } System.out.println("Total states: " + stateNumberer.total()); BinaryGrammar bg = new BinaryGrammar(stateNumberer.total()); UnaryGrammar ug = new UnaryGrammar(stateNumberer.total()); for (Iterator brI = brs.iterator(); brI.hasNext(); ) { BinaryRule br = (BinaryRule) brI.next(); bg.addRule(br); } for (Iterator urI = urs.iterator(); urI.hasNext(); ) { UnaryRule ur = (UnaryRule) urI.next(); ug.addRule(ur); } return new Pair(ug, bg); }
/**
 * Strips {@code aux} and {@code mark} edges from the first root when splitting off a clause. Only
 * edges whose dependent has no outgoing edges of its own are stripped, and the dependent vertex
 * is removed together with the edge.
 *
 * @param toModify The tree to strip the edges from; modified in place.
 */
private void stripAuxMark(SemanticGraph toModify) {
  // Collect first and delete afterwards, so the graph is not changed while it is being iterated.
  List<SemanticGraphEdge> doomedEdges = new ArrayList<>();
  for (SemanticGraphEdge candidate : toModify.outgoingEdgeIterable(toModify.getFirstRoot())) {
    String relation = candidate.getRelation().toString();
    boolean auxOrMark = "aux".equals(relation) || "mark".equals(relation);
    if (!auxOrMark) {
      continue;
    }
    boolean dependentHasChildren =
        toModify.outgoingEdgeIterator(candidate.getDependent()).hasNext();
    if (!dependentHasChildren) {
      doomedEdges.add(candidate);
    }
  }

  doomedEdges.forEach(
      doomed -> {
        toModify.removeEdge(doomed);
        toModify.removeVertex(doomed.getDependent());
      });
}
protected void tallyInternalNode(Tree lt, List parents) { // form base rule String label = lt.label().value(); Rule baseR = ltToRule(lt); ruleToLabel.put(baseR, label); // act on each history depth for (int depth = 0, maxDepth = Math.min(HISTORY_DEPTH(), parents.size()); depth <= maxDepth; depth++) { List history = new ArrayList(parents.subList(0, depth)); // tally each history level / rewrite pair rulePairs.incrementCount(new Pair(baseR, history), 1); labelPairs.incrementCount(new Pair(label, history), 1); } }
/**
 * Re-orders the action space according to a list of signatures. Actions whose signature matches
 * an entry of {@code order} come first, grouped in the order of the entries; all remaining
 * actions follow in their original relative order.
 *
 * @param actionSpace The actions to re-order; not modified.
 * @param order The signatures, most preferred first.
 * @return A new collection holding every action of {@code actionSpace} exactly once.
 */
private Collection<Action> orderActions(Collection<Action> actionSpace, List<String> order) {
  List<Action> remaining = new ArrayList<>(actionSpace);
  List<Action> ordered = new ArrayList<>();
  for (String signature : order) {
    int i = 0;
    while (i < remaining.size()) {
      Action candidate = remaining.get(i);
      if (candidate.signature().equals(signature)) {
        ordered.add(candidate);
        remaining.remove(i); // the next candidate slides into slot i
      } else {
        i++;
      }
    }
  }
  ordered.addAll(remaining);
  return ordered;
}
/**
 * Turns a sentence into a flat phrasal tree of the shape S -> tag*, where each tag node dominates
 * exactly one word. The tag of a {@code TaggedWord} becomes a {@code StringLabel}; any other word
 * gets the tag "WD", created by the label factory. The S node carries a {@code StringLabel}, and
 * the word leaves are labelled through the label factory.
 *
 * @param s The Sentence to make the Tree from
 * @param lf The LabelFactory with which to create the word labels and the "WD" tags
 * @return The one phrasal level Tree
 */
public static Tree toFlatTree(Sentence<?> s, LabelFactory lf) {
  List<Tree> daughters = new ArrayList<Tree>(s.length());
  for (HasWord word : s) {
    Tree leaf = new LabeledScoredTreeLeaf(lf.newLabel(word.word()));
    Tree preterminal =
        new LabeledScoredTreeNode(
            (word instanceof TaggedWord)
                ? new StringLabel(((TaggedWord) word).tag())
                : lf.newLabel("WD"),
            Collections.singletonList(leaf));
    daughters.add(preterminal);
  }
  return new LabeledScoredTreeNode(new StringLabel("S"), daughters);
}
public static ArrayList<ArrayList<TaggedWord>> getPhrases(Tree parse, int phraseSizeLimit) { ArrayList<ArrayList<TaggedWord>> newList = new ArrayList<ArrayList<TaggedWord>>(); List<Tree> leaves = parse.getLeaves(); if (leaves.size() <= phraseSizeLimit) { // ArrayList<TaggedWord> phraseElements = PreprocessPhrase(parse.taggedYield()); ArrayList<TaggedWord> phraseElements = Preprocess(parse.taggedYield()); if (phraseElements.size() > 0) newList.add(phraseElements); } else { Tree[] childrenNodes = parse.children(); for (int i = 0; i < childrenNodes.length; i++) { Tree currentParse = childrenNodes[i]; newList.addAll(getPhrases(currentParse, phraseSizeLimit)); } } return newList; }
private Distribution<Integer> getSegmentedWordLengthDistribution(Treebank tb) { // CharacterLevelTagExtender ext = new CharacterLevelTagExtender(); ClassicCounter<Integer> c = new ClassicCounter<Integer>(); for (Iterator iterator = tb.iterator(); iterator.hasNext(); ) { Tree gold = (Tree) iterator.next(); StringBuilder goldChars = new StringBuilder(); ArrayList goldYield = gold.yield(); for (Iterator wordIter = goldYield.iterator(); wordIter.hasNext(); ) { Word word = (Word) wordIter.next(); goldChars.append(word); } List<HasWord> ourWords = segment(goldChars.toString()); for (int i = 0; i < ourWords.size(); i++) { c.incrementCount(Integer.valueOf(ourWords.get(i).word().length())); } } return Distribution.getDistribution(c); }
/**
 * Appends the labels of all leaves below (or at) {@code t} to {@code l}, from left to right.
 *
 * @param t The tree to walk.
 * @param l The list receiving the leaf labels.
 */
private static void leafLabels(Tree t, List<Label> l) {
  if (t.isLeaf()) {
    l.add(t.label());
    return;
  }
  for (Tree kid : t.children()) {
    leafLabels(kid, l);
  }
}
/**
 * Appends all preterminal nodes below (or at) {@code t} to {@code l}, from left to right. The
 * walk does not descend below a preterminal.
 *
 * @param t The tree to walk.
 * @param l The list receiving the preterminal nodes.
 */
private static void preTerminals(Tree t, List<Tree> l) {
  if (t.isPreTerminal()) {
    l.add(t);
    return;
  }
  for (Tree kid : t.children()) {
    preTerminals(kid, l);
  }
}
/**
 * Collects the top clauses found by this searcher, down to a minimum probability.
 *
 * <p>Each candidate arrives with a log probability. A candidate whose probability reaches the
 * threshold is materialized, scored with that probability and collected, and the callback answers
 * true; for any other candidate the callback answers false (presumably ending the search; confirm
 * against {@code search}).
 *
 * @param thresholdProbability The threshold under which to stop returning clauses. This should be
 *     between 0 and 1.
 * @return The resulting {@link edu.stanford.nlp.naturalli.SentenceFragment} objects, representing
 *     the top clauses of the sentence.
 */
public List<SentenceFragment> topClauses(double thresholdProbability) {
  List<SentenceFragment> accepted = new ArrayList<>();
  search(
      candidate -> {
        assert candidate.first <= 0.0;
        double probability = Math.exp(candidate.first);
        assert probability <= 1.0;
        assert probability >= 0.0;
        assert !Double.isNaN(probability);

        // Test in the positive form so that a NaN on either side counts as "below threshold".
        boolean reachesThreshold = probability >= thresholdProbability;
        if (!reachesThreshold) {
          return false;
        }
        SentenceFragment fragment = candidate.third.get();
        fragment.score = probability;
        accepted.add(fragment);
        return true;
      });
  return accepted;
}
/**
 * Specializes a rule for a given history by renumbering its states with history suffixes.
 *
 * <p>The parent state is always annotated with the "top" history (the history without its first
 * element). Each child state is handled as follows: a synthetic state gets the top history, a tag
 * state is left untouched, and any other state gets the "bottom" history (the first {@code
 * childDepth} elements of the history).
 *
 * @param rule The base rule; treated as a {@code BinaryRule} unless it is a {@code UnaryRule}.
 * @param history The history to specialize for; must hold at least one element.
 * @param childDepth How many leading history elements annotate ordinary child states.
 * @return A new rule of the same arity over the annotated states.
 */
protected Rule specifyRule(Rule rule, List history, int childDepth) {
  String topHistoryStr = historyToString(history.subList(1, history.size()));
  String bottomHistoryStr = historyToString(history.subList(0, childDepth));

  if (rule instanceof UnaryRule) {
    UnaryRule specified = new UnaryRule();
    UnaryRule base = (UnaryRule) rule;
    specified.parent = stateNumberer.number(stateNumberer.object(base.parent) + topHistoryStr);
    specified.child =
        isSynthetic(base.child)
            ? stateNumberer.number(stateNumberer.object(base.child) + topHistoryStr)
            : isTag(base.child)
                ? base.child
                : stateNumberer.number(stateNumberer.object(base.child) + bottomHistoryStr);
    return specified;
  }

  BinaryRule specified = new BinaryRule();
  BinaryRule base = (BinaryRule) rule;
  specified.parent = stateNumberer.number(stateNumberer.object(base.parent) + topHistoryStr);
  specified.leftChild =
      isSynthetic(base.leftChild)
          ? stateNumberer.number(stateNumberer.object(base.leftChild) + topHistoryStr)
          : isTag(base.leftChild)
              ? base.leftChild
              : stateNumberer.number(stateNumberer.object(base.leftChild) + bottomHistoryStr);
  specified.rightChild =
      isSynthetic(base.rightChild)
          ? stateNumberer.number(stateNumberer.object(base.rightChild) + topHistoryStr)
          : isTag(base.rightChild)
              ? base.rightChild
              : stateNumberer.number(stateNumberer.object(base.rightChild) + bottomHistoryStr);
  return specified;
}
/**
 * Appends the label of every word below (or at) {@code t} to {@code l}, after storing the label
 * of its preterminal on it under {@code TagLabelAnnotation}. The label of the first child of each
 * preterminal is cast to {@code CoreLabel} and modified in place.
 *
 * @param t The tree to walk.
 * @param l The list receiving the annotated leaf labels.
 */
private static void taggedLeafLabels(Tree t, List<CoreLabel> l) {
  if (!t.isPreTerminal()) {
    for (Tree kid : t.children()) {
      taggedLeafLabels(kid, l);
    }
    return;
  }
  CoreLabel leafLabel = (CoreLabel) t.getChild(0).label();
  leafLabel.set(TagLabelAnnotation.class, t.label());
  l.add(leafLabel);
}
/**
 * Returns the node which is the lowest common ancestor of nodes t1 and t2 within the tree rooted
 * at root. Returns null if {@code pathFromRoot} finds no path to either node.
 *
 * <p>Path nodes are compared by identity (==), matching {@code pathFromRoot}. Two distinct nodes
 * that merely compare equal (for example identical sibling subtrees under a value-based {@code
 * equals}) must not be mistaken for a shared ancestor.
 *
 * <p>NOTE(review): this relies on {@code pathFromRoot} returning the node objects of the tree
 * itself rather than copies; confirm against {@code Tree.dominationPath}.
 *
 * @param t1 The first node.
 * @param t2 The second node.
 * @param root The root of the tree containing both nodes.
 * @return The deepest node lying on both root paths, or null if there is none.
 */
public static Tree getLowestCommonAncestor(Tree t1, Tree t2, Tree root) {
  List<Tree> t1Path = pathFromRoot(t1, root);
  List<Tree> t2Path = pathFromRoot(t2, root);
  if (t1Path == null || t2Path == null) {
    return null;
  }

  int shortest = Math.min(t1Path.size(), t2Path.size());
  Tree commonAncestor = null;
  // Walk down both paths in lockstep for as long as they pass through the very same node.
  for (int i = 0; i < shortest && t1Path.get(i) == t2Path.get(i); ++i) {
    commonAncestor = t1Path.get(i);
  }
  return commonAncestor;
}
/**
 * Returns the lowest common ancestor of all the nodes in the list within the tree rooted at root.
 *
 * <p>Returns null when the list is empty or when {@code pathFromRoot} finds no path to one of the
 * nodes. Path nodes are compared by identity (==), matching {@code pathFromRoot}, so distinct
 * nodes that merely compare equal are not mistaken for a shared ancestor.
 *
 * <p>NOTE(review): this relies on {@code pathFromRoot} returning the node objects of the tree
 * itself rather than copies; confirm against {@code Tree.dominationPath}.
 *
 * @param nodes The nodes whose common ancestor is wanted.
 * @param root The root of the tree containing the nodes.
 * @return The deepest node lying on every root path, or null if there is none.
 */
public static Tree getLowestCommonAncestor(List<Tree> nodes, Tree root) {
  // No nodes, no ancestor; without this guard the reference path below would not exist.
  if (nodes.isEmpty()) {
    return null;
  }

  List<List<Tree>> paths = new ArrayList<List<Tree>>();
  int shortest = Integer.MAX_VALUE;
  for (Tree node : nodes) {
    List<Tree> path = pathFromRoot(node, root);
    if (path == null) {
      return null;
    }
    shortest = Math.min(shortest, path.size());
    paths.add(path);
  }

  List<Tree> referencePath = paths.get(0);
  Tree commonAncestor = null;
  for (int depth = 0; depth < shortest; ++depth) {
    Tree candidate = referencePath.get(depth);
    for (List<Tree> path : paths) {
      if (path.get(depth) != candidate) {
        // The paths part ways here; the previous depth held the answer.
        return commonAncestor;
      }
    }
    commonAncestor = candidate;
  }
  return commonAncestor;
}
/** * Prints out all matches of a semgrex pattern on a file of dependencies. <br> * Usage:<br> * java edu.stanford.nlp.semgraph.semgrex.SemgrexPattern [args] <br> * See the help() function for a list of possible arguments to provide. */ public static void main(String[] args) throws IOException { Map<String, Integer> flagMap = Generics.newHashMap(); flagMap.put(PATTERN, 1); flagMap.put(TREE_FILE, 1); flagMap.put(MODE, 1); flagMap.put(EXTRAS, 1); flagMap.put(CONLLU_FILE, 1); flagMap.put(OUTPUT_FORMAT_OPTION, 1); Map<String, String[]> argsMap = StringUtils.argsToMap(args, flagMap); args = argsMap.get(null); // TODO: allow patterns to be extracted from a file if (!(argsMap.containsKey(PATTERN)) || argsMap.get(PATTERN).length == 0) { help(); System.exit(2); } SemgrexPattern semgrex = SemgrexPattern.compile(argsMap.get(PATTERN)[0]); String modeString = DEFAULT_MODE; if (argsMap.containsKey(MODE) && argsMap.get(MODE).length > 0) { modeString = argsMap.get(MODE)[0].toUpperCase(); } SemanticGraphFactory.Mode mode = SemanticGraphFactory.Mode.valueOf(modeString); String outputFormatString = DEFAULT_OUTPUT_FORMAT; if (argsMap.containsKey(OUTPUT_FORMAT_OPTION) && argsMap.get(OUTPUT_FORMAT_OPTION).length > 0) { outputFormatString = argsMap.get(OUTPUT_FORMAT_OPTION)[0].toUpperCase(); } OutputFormat outputFormat = OutputFormat.valueOf(outputFormatString); boolean useExtras = true; if (argsMap.containsKey(EXTRAS) && argsMap.get(EXTRAS).length > 0) { useExtras = Boolean.valueOf(argsMap.get(EXTRAS)[0]); } List<SemanticGraph> graphs = Generics.newArrayList(); // TODO: allow other sources of graphs, such as dependency files if (argsMap.containsKey(TREE_FILE) && argsMap.get(TREE_FILE).length > 0) { for (String treeFile : argsMap.get(TREE_FILE)) { System.err.println("Loading file " + treeFile); MemoryTreebank treebank = new MemoryTreebank(new TreeNormalizer()); treebank.loadPath(treeFile); for (Tree tree : treebank) { // TODO: allow other languages... 
this defaults to English SemanticGraph graph = SemanticGraphFactory.makeFromTree( tree, mode, useExtras ? GrammaticalStructure.Extras.MAXIMAL : GrammaticalStructure.Extras.NONE, true); graphs.add(graph); } } } if (argsMap.containsKey(CONLLU_FILE) && argsMap.get(CONLLU_FILE).length > 0) { CoNLLUDocumentReader reader = new CoNLLUDocumentReader(); for (String conlluFile : argsMap.get(CONLLU_FILE)) { System.err.println("Loading file " + conlluFile); Iterator<SemanticGraph> it = reader.getIterator(IOUtils.readerFromString(conlluFile)); while (it.hasNext()) { SemanticGraph graph = it.next(); graphs.add(graph); } } } for (SemanticGraph graph : graphs) { SemgrexMatcher matcher = semgrex.matcher(graph); if (!(matcher.find())) { continue; } if (outputFormat == OutputFormat.LIST) { System.err.println("Matched graph:"); System.err.println(graph.toString(SemanticGraph.OutputFormat.LIST)); boolean found = true; while (found) { System.err.println( "Matches at: " + matcher.getMatch().value() + "-" + matcher.getMatch().index()); List<String> nodeNames = Generics.newArrayList(); nodeNames.addAll(matcher.getNodeNames()); Collections.sort(nodeNames); for (String name : nodeNames) { System.err.println( " " + name + ": " + matcher.getNode(name).value() + "-" + matcher.getNode(name).index()); } System.err.println(); found = matcher.find(); } } else if (outputFormat == OutputFormat.OFFSET) { if (graph.vertexListSorted().isEmpty()) { continue; } System.out.printf( "+%d %s%n", graph.vertexListSorted().get(0).get(CoreAnnotations.LineNumberAnnotation.class), argsMap.get(CONLLU_FILE)[0]); } } }
public static void main(String[] args) { Options op = new Options(new EnglishTreebankParserParams()); // op.tlpParams may be changed to something else later, so don't use it till // after options are parsed. System.out.println("Currently " + new Date()); System.out.print("Invoked with arguments:"); for (String arg : args) { System.out.print(" " + arg); } System.out.println(); String path = "/u/nlp/stuff/corpora/Treebank3/parsed/mrg/wsj"; int trainLow = 200, trainHigh = 2199, testLow = 2200, testHigh = 2219; String serializeFile = null; int i = 0; while (i < args.length && args[i].startsWith("-")) { if (args[i].equalsIgnoreCase("-path") && (i + 1 < args.length)) { path = args[i + 1]; i += 2; } else if (args[i].equalsIgnoreCase("-train") && (i + 2 < args.length)) { trainLow = Integer.parseInt(args[i + 1]); trainHigh = Integer.parseInt(args[i + 2]); i += 3; } else if (args[i].equalsIgnoreCase("-test") && (i + 2 < args.length)) { testLow = Integer.parseInt(args[i + 1]); testHigh = Integer.parseInt(args[i + 2]); i += 3; } else if (args[i].equalsIgnoreCase("-serialize") && (i + 1 < args.length)) { serializeFile = args[i + 1]; i += 2; } else if (args[i].equalsIgnoreCase("-tLPP") && (i + 1 < args.length)) { try { op.tlpParams = (TreebankLangParserParams) Class.forName(args[i + 1]).newInstance(); } catch (ClassNotFoundException e) { System.err.println("Class not found: " + args[i + 1]); } catch (InstantiationException e) { System.err.println("Couldn't instantiate: " + args[i + 1] + ": " + e.toString()); } catch (IllegalAccessException e) { System.err.println("illegal access" + e); } i += 2; } else if (args[i].equals("-encoding")) { // sets encoding for TreebankLangParserParams op.tlpParams.setInputEncoding(args[i + 1]); op.tlpParams.setOutputEncoding(args[i + 1]); i += 2; } else { i = op.setOptionOrWarn(args, i); } } // System.out.println(tlpParams.getClass()); TreebankLanguagePack tlp = op.tlpParams.treebankLanguagePack(); Train.sisterSplitters = new 
HashSet(Arrays.asList(op.tlpParams.sisterSplitters())); // BinarizerFactory.TreeAnnotator.setTreebankLang(tlpParams); PrintWriter pw = op.tlpParams.pw(); Test.display(); Train.display(); op.display(); op.tlpParams.display(); // setup tree transforms Treebank trainTreebank = op.tlpParams.memoryTreebank(); MemoryTreebank testTreebank = op.tlpParams.testMemoryTreebank(); // Treebank blippTreebank = ((EnglishTreebankParserParams) tlpParams).diskTreebank(); // String blippPath = "/afs/ir.stanford.edu/data/linguistic-data/BLLIP-WSJ/"; // blippTreebank.loadPath(blippPath, "", true); Timing.startTime(); System.err.print("Reading trees..."); testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true)); if (Test.increasingLength) { Collections.sort(testTreebank, new TreeLengthComparator()); } trainTreebank.loadPath(path, new NumberRangeFileFilter(trainLow, trainHigh, true)); Timing.tick("done."); System.err.print("Binarizing trees..."); TreeAnnotatorAndBinarizer binarizer = null; if (!Train.leftToRight) { binarizer = new TreeAnnotatorAndBinarizer(op.tlpParams, op.forceCNF, !Train.outsideFactor(), true); } else { binarizer = new TreeAnnotatorAndBinarizer( op.tlpParams.headFinder(), new LeftHeadFinder(), op.tlpParams, op.forceCNF, !Train.outsideFactor(), true); } CollinsPuncTransformer collinsPuncTransformer = null; if (Train.collinsPunc) { collinsPuncTransformer = new CollinsPuncTransformer(tlp); } TreeTransformer debinarizer = new Debinarizer(op.forceCNF); List<Tree> binaryTrainTrees = new ArrayList<Tree>(); if (Train.selectiveSplit) { Train.splitters = ParentAnnotationStats.getSplitCategories( trainTreebank, Train.tagSelectiveSplit, 0, Train.selectiveSplitCutOff, Train.tagSelectiveSplitCutOff, op.tlpParams.treebankLanguagePack()); if (Train.deleteSplitters != null) { List<String> deleted = new ArrayList<String>(); for (String del : Train.deleteSplitters) { String baseDel = tlp.basicCategory(del); boolean checkBasic = del.equals(baseDel); for 
(Iterator<String> it = Train.splitters.iterator(); it.hasNext(); ) { String elem = it.next(); String baseElem = tlp.basicCategory(elem); boolean delStr = checkBasic && baseElem.equals(baseDel) || elem.equals(del); if (delStr) { it.remove(); deleted.add(elem); } } } System.err.println("Removed from vertical splitters: " + deleted); } } if (Train.selectivePostSplit) { TreeTransformer myTransformer = new TreeAnnotator(op.tlpParams.headFinder(), op.tlpParams); Treebank annotatedTB = trainTreebank.transform(myTransformer); Train.postSplitters = ParentAnnotationStats.getSplitCategories( annotatedTB, true, 0, Train.selectivePostSplitCutOff, Train.tagSelectivePostSplitCutOff, op.tlpParams.treebankLanguagePack()); } if (Train.hSelSplit) { binarizer.setDoSelectiveSplit(false); for (Tree tree : trainTreebank) { if (Train.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } // tree.pennPrint(tlpParams.pw()); tree = binarizer.transformTree(tree); // binaryTrainTrees.add(tree); } binarizer.setDoSelectiveSplit(true); } for (Tree tree : trainTreebank) { if (Train.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } tree = binarizer.transformTree(tree); binaryTrainTrees.add(tree); } if (Test.verbose) { binarizer.dumpStats(); } List<Tree> binaryTestTrees = new ArrayList<Tree>(); for (Tree tree : testTreebank) { if (Train.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } tree = binarizer.transformTree(tree); binaryTestTrees.add(tree); } Timing.tick("done."); // binarization BinaryGrammar bg = null; UnaryGrammar ug = null; DependencyGrammar dg = null; // DependencyGrammar dgBLIPP = null; Lexicon lex = null; // extract grammars Extractor bgExtractor = new BinaryGrammarExtractor(); // Extractor bgExtractor = new SmoothedBinaryGrammarExtractor();//new BinaryGrammarExtractor(); // Extractor lexExtractor = new LexiconExtractor(); // Extractor dgExtractor = new DependencyMemGrammarExtractor(); Extractor dgExtractor = new 
MLEDependencyGrammarExtractor(op); if (op.doPCFG) { System.err.print("Extracting PCFG..."); Pair bgug = null; if (Train.cheatPCFG) { List allTrees = new ArrayList(binaryTrainTrees); allTrees.addAll(binaryTestTrees); bgug = (Pair) bgExtractor.extract(allTrees); } else { bgug = (Pair) bgExtractor.extract(binaryTrainTrees); } bg = (BinaryGrammar) bgug.second; bg.splitRules(); ug = (UnaryGrammar) bgug.first; ug.purgeRules(); Timing.tick("done."); } System.err.print("Extracting Lexicon..."); lex = op.tlpParams.lex(op.lexOptions); lex.train(binaryTrainTrees); Timing.tick("done."); if (op.doDep) { System.err.print("Extracting Dependencies..."); binaryTrainTrees.clear(); // dgBLIPP = (DependencyGrammar) dgExtractor.extract(new // ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new // TransformTreeDependency(tlpParams,true)); DependencyGrammar dg1 = (DependencyGrammar) dgExtractor.extract( trainTreebank.iterator(), new TransformTreeDependency(op.tlpParams, true)); // dgBLIPP=(DependencyGrammar)dgExtractor.extract(blippTreebank.iterator(),new // TransformTreeDependency(tlpParams)); // dg = (DependencyGrammar) dgExtractor.extract(new // ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new // TransformTreeDependency(tlpParams)); // dg=new DependencyGrammarCombination(dg1,dgBLIPP,2); // dg = (DependencyGrammar) dgExtractor.extract(binaryTrainTrees); //uses information whether // the words are known or not, discards unknown words Timing.tick("done."); // System.out.print("Extracting Unknown Word Model..."); // UnknownWordModel uwm = (UnknownWordModel)uwmExtractor.extract(binaryTrainTrees); // Timing.tick("done."); System.out.print("Tuning Dependency Model..."); dg.tune(binaryTestTrees); // System.out.println("TUNE DEPS: "+tuneDeps); Timing.tick("done."); } BinaryGrammar boundBG = bg; UnaryGrammar boundUG = ug; GrammarProjection gp = new NullGrammarProjection(bg, ug); // serialization if (serializeFile != null) { 
System.err.print("Serializing parser..."); LexicalizedParser.saveParserDataToSerialized( new ParserData(lex, bg, ug, dg, Numberer.getNumberers(), op), serializeFile); Timing.tick("done."); } // test: pcfg-parse and output ExhaustivePCFGParser parser = null; if (op.doPCFG) { parser = new ExhaustivePCFGParser(boundBG, boundUG, lex, op); } ExhaustiveDependencyParser dparser = ((op.doDep && !Test.useFastFactored) ? new ExhaustiveDependencyParser(dg, lex, op) : null); Scorer scorer = (op.doPCFG ? new TwinScorer(new ProjectionScorer(parser, gp), dparser) : null); // Scorer scorer = parser; BiLexPCFGParser bparser = null; if (op.doPCFG && op.doDep) { bparser = (Test.useN5) ? new BiLexPCFGParser.N5BiLexPCFGParser( scorer, parser, dparser, bg, ug, dg, lex, op, gp) : new BiLexPCFGParser(scorer, parser, dparser, bg, ug, dg, lex, op, gp); } LabeledConstituentEval pcfgPE = new LabeledConstituentEval("pcfg PE", true, tlp); LabeledConstituentEval comboPE = new LabeledConstituentEval("combo PE", true, tlp); AbstractEval pcfgCB = new LabeledConstituentEval.CBEval("pcfg CB", true, tlp); AbstractEval pcfgTE = new AbstractEval.TaggingEval("pcfg TE"); AbstractEval comboTE = new AbstractEval.TaggingEval("combo TE"); AbstractEval pcfgTEnoPunct = new AbstractEval.TaggingEval("pcfg nopunct TE"); AbstractEval comboTEnoPunct = new AbstractEval.TaggingEval("combo nopunct TE"); AbstractEval depTE = new AbstractEval.TaggingEval("depnd TE"); AbstractEval depDE = new AbstractEval.DependencyEval("depnd DE", true, tlp.punctuationWordAcceptFilter()); AbstractEval comboDE = new AbstractEval.DependencyEval("combo DE", true, tlp.punctuationWordAcceptFilter()); if (Test.evalb) { EvalB.initEVALBfiles(op.tlpParams); } // int[] countByLength = new int[Test.maxLength+1]; // use a reflection ruse, so one can run this without needing the tagger // edu.stanford.nlp.process.SentenceTagger tagger = (Test.preTag ? 
new // edu.stanford.nlp.process.SentenceTagger("/u/nlp/data/tagger.params/wsj0-21.holder") : null); SentenceProcessor tagger = null; if (Test.preTag) { try { Class[] argsClass = new Class[] {String.class}; Object[] arguments = new Object[] {"/u/nlp/data/pos-tagger/wsj3t0-18-bidirectional/train-wsj-0-18.holder"}; tagger = (SentenceProcessor) Class.forName("edu.stanford.nlp.tagger.maxent.MaxentTagger") .getConstructor(argsClass) .newInstance(arguments); } catch (Exception e) { System.err.println(e); System.err.println("Warning: No pretagging of sentences will be done."); } } for (int tNum = 0, ttSize = testTreebank.size(); tNum < ttSize; tNum++) { Tree tree = testTreebank.get(tNum); int testTreeLen = tree.yield().size(); if (testTreeLen > Test.maxLength) { continue; } Tree binaryTree = binaryTestTrees.get(tNum); // countByLength[testTreeLen]++; System.out.println("-------------------------------------"); System.out.println("Number: " + (tNum + 1)); System.out.println("Length: " + testTreeLen); // tree.pennPrint(pw); // System.out.println("XXXX The binary tree is"); // binaryTree.pennPrint(pw); // System.out.println("Here are the tags in the lexicon:"); // System.out.println(lex.showTags()); // System.out.println("Here's the tagnumberer:"); // System.out.println(Numberer.getGlobalNumberer("tags").toString()); long timeMil1 = System.currentTimeMillis(); Timing.tick("Starting parse."); if (op.doPCFG) { // System.err.println(Test.forceTags); if (Test.forceTags) { if (tagger != null) { // System.out.println("Using a tagger to set tags"); // System.out.println("Tagged sentence as: " + // tagger.processSentence(cutLast(wordify(binaryTree.yield()))).toString(false)); parser.parse(addLast(tagger.processSentence(cutLast(wordify(binaryTree.yield()))))); } else { // System.out.println("Forcing tags to match input."); parser.parse(cleanTags(binaryTree.taggedYield(), tlp)); } } else { // System.out.println("XXXX Parsing " + binaryTree.yield()); parser.parse(binaryTree.yield()); } 
// Timing.tick("Done with pcfg phase."); } if (op.doDep) { dparser.parse(binaryTree.yield()); // Timing.tick("Done with dependency phase."); } boolean bothPassed = false; if (op.doPCFG && op.doDep) { bothPassed = bparser.parse(binaryTree.yield()); // Timing.tick("Done with combination phase."); } long timeMil2 = System.currentTimeMillis(); long elapsed = timeMil2 - timeMil1; System.err.println("Time: " + ((int) (elapsed / 100)) / 10.00 + " sec."); // System.out.println("PCFG Best Parse:"); Tree tree2b = null; Tree tree2 = null; // System.out.println("Got full best parse..."); if (op.doPCFG) { tree2b = parser.getBestParse(); tree2 = debinarizer.transformTree(tree2b); } // System.out.println("Debinarized parse..."); // tree2.pennPrint(); // System.out.println("DepG Best Parse:"); Tree tree3 = null; Tree tree3db = null; if (op.doDep) { tree3 = dparser.getBestParse(); // was: but wrong Tree tree3db = debinarizer.transformTree(tree2); tree3db = debinarizer.transformTree(tree3); tree3.pennPrint(pw); } // tree.pennPrint(); // ((Tree)binaryTrainTrees.get(tNum)).pennPrint(); // System.out.println("Combo Best Parse:"); Tree tree4 = null; if (op.doPCFG && op.doDep) { try { tree4 = bparser.getBestParse(); if (tree4 == null) { tree4 = tree2b; } } catch (NullPointerException e) { System.err.println("Blocked, using PCFG parse!"); tree4 = tree2b; } } if (op.doPCFG && !bothPassed) { tree4 = tree2b; } // tree4.pennPrint(); if (op.doDep) { depDE.evaluate(tree3, binaryTree, pw); depTE.evaluate(tree3db, tree, pw); } TreeTransformer tc = op.tlpParams.collinizer(); TreeTransformer tcEvalb = op.tlpParams.collinizerEvalb(); Tree tree4b = null; if (op.doPCFG) { // System.out.println("XXXX Best PCFG was: "); // tree2.pennPrint(); // System.out.println("XXXX Transformed best PCFG is: "); // tc.transformTree(tree2).pennPrint(); // System.out.println("True Best Parse:"); // tree.pennPrint(); // tc.transformTree(tree).pennPrint(); pcfgPE.evaluate(tc.transformTree(tree2), tc.transformTree(tree), 
pw); pcfgCB.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw); if (op.doDep) { comboDE.evaluate((bothPassed ? tree4 : tree3), binaryTree, pw); tree4b = tree4; tree4 = debinarizer.transformTree(tree4); if (op.nodePrune) { NodePruner np = new NodePruner(parser, debinarizer); tree4 = np.prune(tree4); } // tree4.pennPrint(); comboPE.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw); } // pcfgTE.evaluate(tree2, tree); pcfgTE.evaluate(tcEvalb.transformTree(tree2), tcEvalb.transformTree(tree), pw); pcfgTEnoPunct.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw); if (op.doDep) { comboTE.evaluate(tcEvalb.transformTree(tree4), tcEvalb.transformTree(tree), pw); comboTEnoPunct.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw); } System.out.println("PCFG only: " + parser.scoreBinarizedTree(tree2b, 0)); // tc.transformTree(tree2).pennPrint(); tree2.pennPrint(pw); if (op.doDep) { System.out.println("Combo: " + parser.scoreBinarizedTree(tree4b, 0)); // tc.transformTree(tree4).pennPrint(pw); tree4.pennPrint(pw); } System.out.println("Correct:" + parser.scoreBinarizedTree(binaryTree, 0)); /* if (parser.scoreBinarizedTree(tree2b,true) < parser.scoreBinarizedTree(binaryTree,true)) { System.out.println("SCORE INVERSION"); parser.validateBinarizedTree(binaryTree,0); } */ tree.pennPrint(pw); } // end if doPCFG if (Test.evalb) { if (op.doPCFG && op.doDep) { EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree4)); } else if (op.doPCFG) { EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree2)); } else if (op.doDep) { EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree3db)); } } } // end for each tree in test treebank if (Test.evalb) { EvalB.closeEVALBfiles(); } // Test.display(); if (op.doPCFG) { pcfgPE.display(false, pw); System.out.println("Grammar size: " + Numberer.getGlobalNumberer("states").total()); pcfgCB.display(false, pw); if (op.doDep) { comboPE.display(false, 
pw); } pcfgTE.display(false, pw); pcfgTEnoPunct.display(false, pw); if (op.doDep) { comboTE.display(false, pw); comboTEnoPunct.display(false, pw); } } if (op.doDep) { depTE.display(false, pw); depDE.display(false, pw); } if (op.doPCFG && op.doDep) { comboDE.display(false, pw); } // pcfgPE.printGoodBad(); }
/**
 * Writes the reordered output for a match: first a {@code TaggedWord("on", "IN")}, then the
 * slice of {@code words} running from the start offset of capture group 2 (inclusive) to the end
 * offset of capture group 3 (exclusive).
 *
 * @param matches matcher whose group 2 start and group 3 end are used as indices into
 *     {@code words}
 * @param words the tagged words the slice is taken from
 * @param newList the list that receives the appended words
 */
@Override
protected void appendNewOrder(Matcher matches, List<TaggedWord> words, List<TaggedWord> newList) {
  // The fixed "on"/"IN" entry goes in before the matcher offsets are consulted.
  TaggedWord leading = new TaggedWord("on", "IN");
  newList.add(leading);

  int sliceStart = matches.start(2);
  int sliceEnd = matches.end(3);
  List<TaggedWord> slice = words.subList(sliceStart, sliceEnd);
  newList.addAll(slice);
}
/** * returns a list of categories that is the path from Tree from to Tree to within Tree root. If * either from or to is not in root, returns null. Otherwise includes both from and to in the * list. */ public static List<String> pathNodeToNode(Tree from, Tree to, Tree root) { List<Tree> fromPath = pathFromRoot(from, root); // System.out.println(treeListToCatList(fromPath)); if (fromPath == null) return null; List<Tree> toPath = pathFromRoot(to, root); // System.out.println(treeListToCatList(toPath)); if (toPath == null) return null; // System.out.println(treeListToCatList(fromPath)); // System.out.println(treeListToCatList(toPath)); int last = 0; int min = fromPath.size() <= toPath.size() ? fromPath.size() : toPath.size(); Tree lastNode = null; // while((! (fromPath.isEmpty() || toPath.isEmpty())) && // fromPath.get(0).equals(toPath.get(0))) { // lastNode = (Tree) fromPath.remove(0); // toPath.remove(0); // } while (last < min && fromPath.get(last).equals(toPath.get(last))) { lastNode = fromPath.get(last); last++; } // System.out.println(treeListToCatList(fromPath)); // System.out.println(treeListToCatList(toPath)); List<String> totalPath = new ArrayList<String>(); for (int i = fromPath.size() - 1; i >= last; i--) { Tree t = fromPath.get(i); totalPath.add("up-" + t.label().value()); } if (lastNode != null) totalPath.add("up-" + lastNode.label().value()); for (Tree t : toPath) totalPath.add("down-" + t.label().value()); // for(ListIterator i = fromPath.listIterator(fromPath.size()); i.hasPrevious(); ){ // Tree t = (Tree) i.previous(); // totalPath.add("up-" + t.label().value()); // } // if(lastNode != null) // totalPath.add("up-" + lastNode.label().value()); // for(ListIterator j = toPath.listIterator(); j.hasNext(); ){ // Tree t = (Tree) j.next(); // totalPath.add("down-" + t.label().value()); // } return totalPath; }
/**
 * Decides whether a match should be kept by inspecting the word located at the start offset of
 * capture group 3: the match is valid only when that word is, in its entirety, one of
 * {@code hour}, {@code day}, {@code week} or {@code month}.
 *
 * @param matches matcher whose group 3 start offset is used as an index into {@code words}
 * @param words the tagged words being examined
 * @return {@code true} if the selected word fully matches the unit pattern
 */
@Override
protected boolean IsMatchValid(Matcher matches, List<TaggedWord> words) {
  int unitIndex = matches.start(3);
  TaggedWord unit = words.get(unitIndex);
  // Pattern.matches requires the whole input to match, not just a substring.
  return Pattern.matches("(hour|day|week|month)", unit.word());
}