private FSArray addTreebankNodeChildrenToIndexes( TreebankNode parent, JCas jCas, List<CoreLabel> tokenAnns, Tree tree) { Tree[] childTrees = tree.children(); // collect all children (except leaves, which are just the words - POS tags are pre-terminals in // a Stanford tree) List<TreebankNode> childNodes = new ArrayList<TreebankNode>(); for (Tree child : childTrees) { if (!child.isLeaf()) { // set node attributes and add children (mutual recursion) TreebankNode node = new TreebankNode(jCas); node.setParent(parent); this.addTreebankNodeToIndexes(node, jCas, child, tokenAnns); childNodes.add(node); } } // convert the child list into an FSArray FSArray childNodeArray = new FSArray(jCas, childNodes.size()); for (int i = 0; i < childNodes.size(); ++i) { childNodeArray.set(i, childNodes.get(i)); } return childNodeArray; }
public Tree transformTree(Tree tree) { Label lab = tree.label(); if (tree.isLeaf()) { Tree leaf = tf.newLeaf(lab); leaf.setScore(tree.score()); return leaf; } String s = lab.value(); s = treebankLanguagePack().basicCategory(s); int numKids = tree.numChildren(); List<Tree> children = new ArrayList<Tree>(numKids); for (int cNum = 0; cNum < numKids; cNum++) { Tree child = tree.getChild(cNum); Tree newChild = transformTree(child); // cdm 2007: for just subcategory stripping, null shouldn't happen // if (newChild != null) { children.add(newChild); // } } // if (children.isEmpty()) { // return null; // } CategoryWordTag newLabel = new CategoryWordTag(lab); newLabel.setCategory(s); if (lab instanceof HasTag) { String tag = ((HasTag) lab).tag(); tag = treebankLanguagePack().basicCategory(tag); newLabel.setTag(tag); } Tree node = tf.newTreeNode(newLabel, children); node.setScore(tree.score()); return node; }
/**
 * Copies the tree while stripping both subcategory annotations and grammatical-function
 * markers from every category label (and from the tag, when present). Scores are preserved.
 */
public Tree transformTree(Tree tree) {
  Label lab = tree.label();
  if (tree.isLeaf()) {
    Tree leaf = tf.newLeaf(lab);
    leaf.setScore(tree.score());
    return leaf;
  }
  // basicCategory first, then stripGF — same order as for the tag below.
  String cat = treebankLanguagePack().stripGF(
      treebankLanguagePack().basicCategory(lab.value()));
  List<Tree> kids = new ArrayList<Tree>(tree.numChildren());
  for (Tree child : tree.children()) {
    kids.add(transformTree(child));
  }
  CategoryWordTag relabeled = new CategoryWordTag(lab);
  relabeled.setCategory(cat);
  if (lab instanceof HasTag) {
    String strippedTag = treebankLanguagePack().stripGF(
        treebankLanguagePack().basicCategory(((HasTag) lab).tag()));
    relabeled.setTag(strippedTag);
  }
  Tree out = tf.newTreeNode(relabeled, kids);
  out.setScore(tree.score());
  return out;
}
/** * Build the set of dependencies for evaluation. This set excludes all dependencies for which the * argument is a punctuation tag. */ @Override protected Set<?> makeObjects(Tree tree) { Set<Dependency<Label, Label, Object>> deps = new HashSet<Dependency<Label, Label, Object>>(); for (Tree node : tree.subTreeList()) { if (DEBUG) EncodingPrintWriter.err.println("Considering " + node.label()); // every child with a different head is an argument, as are ones with // the same head after the first one found if (node.isLeaf() || node.children().length < 2) { continue; } // System.err.println("XXX node is " + node + "; label type is " + // node.label().getClass().getName()); String head = ((HasWord) node.label()).word(); boolean seenHead = false; for (int cNum = 0; cNum < node.children().length; cNum++) { Tree child = node.children()[cNum]; String arg = ((HasWord) child.label()).word(); if (DEBUG) EncodingPrintWriter.err.println("Considering " + head + " --> " + arg); if (head.equals(arg) && !seenHead) { seenHead = true; if (DEBUG) EncodingPrintWriter.err.println(" ... is head"); } else if (!punctFilter.accept(arg)) { deps.add(new UnnamedDependency(head, arg)); if (DEBUG) EncodingPrintWriter.err.println(" ... added"); } else if (DEBUG) { if (DEBUG) EncodingPrintWriter.err.println(" ... is punct dep"); } } } if (DEBUG) { EncodingPrintWriter.err.println("Deps: " + deps); } return deps; }
Tree prune(Tree tree, int start) { if (tree.isLeaf() || tree.isPreTerminal()) { return tree; } // check each node's children for deletion List<Tree> children = helper(tree.getChildrenAsList(), start); children = prune(children, tree.label(), start, start + tree.yield().size()); return tree.treeFactory().newTreeNode(tree.label(), children); }
/**
 * Prunes the given child list against the best PCFG parse of the same span.
 *
 * <p>A child is deleted (replaced by its own children) only when ALL of the following hold:
 * it is not a constituent of the reference PCFG parse, it does not cross any reference
 * constituent, it is neither a leaf nor a pre-terminal, the yields of test and reference trees
 * agree in length, and the parent label starts with "NP^NP".
 *
 * @param treeList the candidate children
 * @param label the label of the parent node
 * @param start index of the first word spanned (inclusive)
 * @param end index one past the last word spanned
 * @return the possibly-flattened child list
 */
List<Tree> prune(List<Tree> treeList, Label label, int start, int end) {
  // get reference tree
  if (treeList.size() == 1) {
    return treeList;
  }
  Tree testTree = treeList.get(0).treeFactory().newTreeNode(label, treeList);
  // Look up the grammar state for this label and extract the best PCFG parse of the span.
  int goal = Numberer.getGlobalNumberer("states").number(label.value());
  Tree tempTree = parser.extractBestParse(goal, start, end);
  // parser.restoreUnaries(tempTree);
  Tree pcfgTree = debinarizer.transformTree(tempTree);
  Set<Constituent> pcfgConstituents =
      pcfgTree.constituents(new LabeledScoredConstituentFactory());
  // delete child labels that are not in reference but do not cross reference
  List<Tree> prunedChildren = new ArrayList<Tree>();
  int childStart = 0;
  for (int c = 0, numCh = testTree.numChildren(); c < numCh; c++) {
    Tree child = testTree.getChild(c);
    boolean isExtra = true;
    int childEnd = childStart + child.yield().size();
    Constituent childConstituent =
        new LabeledScoredConstituent(childStart, childEnd, child.label(), 0);
    // Any of the following conditions vetoes deletion of this child.
    if (pcfgConstituents.contains(childConstituent)) {
      isExtra = false;
    }
    if (childConstituent.crosses(pcfgConstituents)) {
      isExtra = false;
    }
    if (child.isLeaf() || child.isPreTerminal()) {
      isExtra = false;
    }
    if (pcfgTree.yield().size() != testTree.yield().size()) {
      isExtra = false;
    }
    if (!label.value().startsWith("NP^NP")) {
      isExtra = false;
    }
    if (isExtra) {
      System.err.println(
          "Pruning: "
              + child.label()
              + " from "
              + (childStart + start)
              + " to "
              + (childEnd + start));
      System.err.println("Was: " + testTree + " vs " + pcfgTree);
      // Flatten: splice the child's children in its place.
      prunedChildren.addAll(child.getChildrenAsList());
    } else {
      prunedChildren.add(child);
    }
    childStart = childEnd;
  }
  return prunedChildren;
}
/**
 * Renumbers the leaves of {@code t} left-to-right starting at {@code startIndex},
 * setting each leaf's CoreLabel index. Returns the next unused index.
 */
private static int reIndexLeaves(Tree t, int startIndex) {
  if (!t.isLeaf()) {
    for (Tree kid : t.children()) {
      startIndex = reIndexLeaves(kid, startIndex);
    }
  } else {
    // Leaf labels are assumed to be CoreLabels so they can carry an index.
    ((CoreLabel) t.label()).setIndex(startIndex);
    startIndex++;
  }
  return startIndex;
}
/**
 * Sets the labels on the tree to be the indices of the nodes. Starts counting at the root and
 * assigns indices in a preorder traversal (each node is numbered before its children);
 * leaves are left unnumbered.
 *
 * @param tree the tree whose internal-node labels are overwritten in place
 * @param index the first index to assign
 * @return the next unused index
 */
static int setIndexLabels(Tree tree, int index) {
  if (tree.isLeaf()) {
    return index;
  }
  tree.label().setValue(Integer.toString(index));
  index++;
  for (Tree child : tree.children()) {
    index = setIndexLabels(child, index);
  }
  return index;
}
/**
 * Recursively collects one dependency per non-head child of every internal node
 * (above the pre-terminal level), using {@code typer} to build the dependency objects.
 */
private static <E> void dependencyObjectifyHelper(
    Tree t, Tree root, HeadFinder hf, Collection<E> c, DependencyTyper<E> typer) {
  // Leaves and pre-terminals contribute no dependencies of their own.
  if (t.isLeaf() || t.isPreTerminal()) {
    return;
  }
  Tree head = hf.determineHead(t);
  for (Tree daughter : t.children()) {
    dependencyObjectifyHelper(daughter, root, hf, c, typer);
    if (daughter != head) {
      c.add(typer.makeDependency(head, daughter, root));
    }
  }
}
/**
 * Renders the local tree at this node as a one-level rule string,
 * e.g. {@code "NP -> DT NN"}. Returns the empty string for a leaf.
 */
protected static String localize(Tree tree) {
  if (tree.isLeaf()) {
    return "";
  }
  StringBuilder out = new StringBuilder();
  out.append(tree.label()).append(" ->");
  for (Tree child : tree.children()) {
    out.append(' ').append(child.label());
  }
  return out.toString();
}
/**
 * Sets the labels on the tree (except the leaves) to be the integer value of the sentiment
 * prediction. Makes it easy to print out with Tree.toString().
 *
 * @throws IllegalArgumentException if any non-leaf label is not a CoreLabel
 */
static void setSentimentLabels(Tree tree) {
  if (tree.isLeaf()) {
    return;
  }
  // Relabel children first (bottom-up), then this node.
  for (Tree kid : tree.children()) {
    setSentimentLabels(kid);
  }
  Label label = tree.label();
  if (!(label instanceof CoreLabel)) {
    throw new IllegalArgumentException("Required a tree with CoreLabels");
  }
  CoreLabel coreLabel = (CoreLabel) label;
  coreLabel.setValue(Integer.toString(RNNCoreAnnotations.getPredictedClass(tree)));
}
/**
 * Outputs the prediction scores for each internal node of the tree.
 * Nodes are numbered with the same preorder scheme as setIndexLabels.
 *
 * @return the next unused node index
 */
static int outputTreeScores(PrintStream out, Tree tree, int index) {
  if (tree.isLeaf()) {
    return index;
  }
  // One line per node: " <index>: <score0> <score1> ...".
  out.print(" " + index + ":");
  SimpleMatrix predictions = RNNCoreAnnotations.getPredictions(tree);
  int n = predictions.getNumElements();
  for (int i = 0; i < n; ++i) {
    out.print(" " + NF.format(predictions.get(i)));
  }
  out.println();
  ++index;
  for (Tree kid : tree.children()) {
    index = outputTreeScores(out, kid, index);
  }
  return index;
}
/**
 * Takes a Tree and a collinizer and returns a Collection of {@link Constituent}s for PARSEVAL
 * evaluation. Some notes on this particular parseval:
 *
 * <ul>
 *   <li>It is character-based, which allows it to be used on segmentation/parsing combination
 *       evaluation.
 *   <li>whether it gives you labeled or unlabeled bracketings depends on the value of the
 *       <code>labelConstituents</code> parameter
 * </ul>
 *
 * (Note that I haven't checked this rigorously yet with the PARSEVAL definition -- Roger.)
 */
public static Collection<Constituent> parsevalObjectify(
    Tree t, TreeTransformer collinizer, boolean labelConstituents) {
  Collection<Constituent> spans = new ArrayList<Constituent>();
  Tree collinized = collinizer.transformTree(t);
  if (collinized == null) {
    return spans;
  }
  for (Tree node : collinized) {
    // Skip leaves, pre-terminals, and detached nodes (non-root nodes with no parent).
    boolean skip =
        node.isLeaf()
            || node.isPreTerminal()
            || (node != collinized && node.parent(collinized) == null);
    if (skip) {
      continue;
    }
    int leftEdge = collinized.leftCharEdge(node);
    int rightEdge = collinized.rightCharEdge(node);
    if (labelConstituents) {
      spans.add(new LabeledConstituent(leftEdge, rightEdge, node.label()));
    } else {
      spans.add(new SimpleConstituent(leftEdge, rightEdge));
    }
  }
  return spans;
}
/**
 * Walks the tree assigning character begin/end extents (as MapLabels) while accumulating an
 * offset correction for treebank escaping artifacts (quote tokens, bracket tokens, and
 * backslash-escaped characters), so the stored extents refer back to the original text.
 *
 * <p>Note the mutation order: {@code begin} is adjusted with the offset BEFORE this node's own
 * delta and the children's deltas are applied, while {@code end} is adjusted AFTER — so the
 * span correctly absorbs corrections arising inside it.
 *
 * @param root the tree root, used to compute character edges
 * @param tree the current node (relabelled in place with a MapLabel)
 * @param offset running character-offset correction, mutated as escapes are found
 * @param leafIndex running leaf counter, incremented at each leaf
 */
protected static void updateTreeLabels(
    Tree root, Tree tree, MutableInteger offset, MutableInteger leafIndex) {
  if (tree.isLeaf()) {
    leafIndex.value++;
    return;
  }
  String labelValue = tree.label().value().toUpperCase();
  int begin = root.leftCharEdge(tree);
  int end = root.rightCharEdge(tree);
  // System.out.println(labelValue+"("+begin+","+end+")");
  int length = end - begin;
  // apply offset to begin extent
  begin += offset.value;
  // calculate offset delta based on label
  if (double_quote_lable_pattern.matcher(labelValue).matches() && length > 1) {
    // A two-character quote token ("``" or "''") maps to one source character.
    offset.value--;
    log.debug("Quotes label pattern fired: " + offset);
  } else if (bracket_label_pattern.matcher(labelValue).matches()) {
    // Bracket tokens like -LRB- are 5 chars standing in for 1 source character.
    offset.value -= 4;
    log.debug("Bracket label pattern fired: " + offset);
  } else if (tree.isPreTerminal()) {
    // Each escaped character in the word shrinks the source span by one.
    Tree leaf = tree.firstChild();
    String text = leaf.label().value();
    Matcher matcher = escaped_char_pattern.matcher(text);
    while (matcher.find()) {
      offset.value--;
    }
  }
  for (Tree child : tree.children()) updateTreeLabels(root, child, offset, leafIndex);
  // apply offset to end extent
  end += offset.value;
  // set begin and end offsets on node
  MapLabel label = new MapLabel(tree.label());
  label.put(BEGIN_KEY, begin);
  label.put(END_KEY, end);
  label.put(MapLabel.INDEX_KEY, leafIndex.value);
  tree.setLabel(label);
}
protected void tallyTree(Tree t, LinkedList<String> parents) { // traverse tree, building parent list String str = t.label().value(); boolean strIsPassive = (str.indexOf('@') == -1); if (strIsPassive) { parents.addFirst(str); } if (!t.isLeaf()) { if (!t.children()[0].isLeaf()) { tallyInternalNode(t, parents); for (int c = 0; c < t.children().length; c++) { Tree child = t.children()[c]; tallyTree(child, parents); } } else { tagNumberer.number(t.label().value()); } } if (strIsPassive) { parents.removeFirst(); } }
/** @param args */ public static void main(String[] args) { if (args.length != 3) { System.err.printf( "Usage: java %s language filename features%n", TreebankFactoredLexiconStats.class.getName()); System.exit(-1); } Language language = Language.valueOf(args[0]); TreebankLangParserParams tlpp = language.params; if (language.equals(Language.Arabic)) { String[] options = {"-arabicFactored"}; tlpp.setOptionFlag(options, 0); } else { String[] options = {"-frenchFactored"}; tlpp.setOptionFlag(options, 0); } Treebank tb = tlpp.diskTreebank(); tb.loadPath(args[1]); MorphoFeatureSpecification morphoSpec = language.equals(Language.Arabic) ? new ArabicMorphoFeatureSpecification() : new FrenchMorphoFeatureSpecification(); String[] features = args[2].trim().split(","); for (String feature : features) { morphoSpec.activate(MorphoFeatureType.valueOf(feature)); } // Counters Counter<String> wordTagCounter = new ClassicCounter<>(30000); Counter<String> morphTagCounter = new ClassicCounter<>(500); // Counter<String> signatureTagCounter = new ClassicCounter<String>(); Counter<String> morphCounter = new ClassicCounter<>(500); Counter<String> wordCounter = new ClassicCounter<>(30000); Counter<String> tagCounter = new ClassicCounter<>(300); Counter<String> lemmaCounter = new ClassicCounter<>(25000); Counter<String> lemmaTagCounter = new ClassicCounter<>(25000); Counter<String> richTagCounter = new ClassicCounter<>(1000); Counter<String> reducedTagCounter = new ClassicCounter<>(500); Counter<String> reducedTagLemmaCounter = new ClassicCounter<>(500); Map<String, Set<String>> wordLemmaMap = Generics.newHashMap(); TwoDimensionalIntCounter<String, String> lemmaReducedTagCounter = new TwoDimensionalIntCounter<>(30000); TwoDimensionalIntCounter<String, String> reducedTagTagCounter = new TwoDimensionalIntCounter<>(500); TwoDimensionalIntCounter<String, String> tagReducedTagCounter = new TwoDimensionalIntCounter<>(300); int numTrees = 0; for (Tree tree : tb) { for (Tree subTree : tree) { if 
(!subTree.isLeaf()) { tlpp.transformTree(subTree, tree); } } List<Label> pretermList = tree.preTerminalYield(); List<Label> yield = tree.yield(); assert yield.size() == pretermList.size(); int yieldLen = yield.size(); for (int i = 0; i < yieldLen; ++i) { String tag = pretermList.get(i).value(); String word = yield.get(i).value(); String morph = ((CoreLabel) yield.get(i)).originalText(); // Note: if there is no lemma, then we use the surface form. Pair<String, String> lemmaTag = MorphoFeatureSpecification.splitMorphString(word, morph); String lemma = lemmaTag.first(); String richTag = lemmaTag.second(); // WSGDEBUG if (tag.contains("MW")) lemma += "-MWE"; lemmaCounter.incrementCount(lemma); lemmaTagCounter.incrementCount(lemma + tag); richTagCounter.incrementCount(richTag); String reducedTag = morphoSpec.strToFeatures(richTag).toString(); reducedTagCounter.incrementCount(reducedTag); reducedTagLemmaCounter.incrementCount(reducedTag + lemma); wordTagCounter.incrementCount(word + tag); morphTagCounter.incrementCount(morph + tag); morphCounter.incrementCount(morph); wordCounter.incrementCount(word); tagCounter.incrementCount(tag); reducedTag = reducedTag.equals("") ? "NONE" : reducedTag; if (wordLemmaMap.containsKey(word)) { wordLemmaMap.get(word).add(lemma); } else { Set<String> lemmas = Generics.newHashSet(1); wordLemmaMap.put(word, lemmas); } lemmaReducedTagCounter.incrementCount(lemma, reducedTag); reducedTagTagCounter.incrementCount(lemma + reducedTag, tag); tagReducedTagCounter.incrementCount(tag, reducedTag); } ++numTrees; } // Barf... 
System.out.println("Language: " + language.toString()); System.out.printf("#trees:\t%d%n", numTrees); System.out.printf("#tokens:\t%d%n", (int) wordCounter.totalCount()); System.out.printf("#words:\t%d%n", wordCounter.keySet().size()); System.out.printf("#tags:\t%d%n", tagCounter.keySet().size()); System.out.printf("#wordTagPairs:\t%d%n", wordTagCounter.keySet().size()); System.out.printf("#lemmas:\t%d%n", lemmaCounter.keySet().size()); System.out.printf("#lemmaTagPairs:\t%d%n", lemmaTagCounter.keySet().size()); System.out.printf("#feattags:\t%d%n", reducedTagCounter.keySet().size()); System.out.printf("#feattag+lemmas:\t%d%n", reducedTagLemmaCounter.keySet().size()); System.out.printf("#richtags:\t%d%n", richTagCounter.keySet().size()); System.out.printf("#richtag+lemma:\t%d%n", morphCounter.keySet().size()); System.out.printf("#richtag+lemmaTagPairs:\t%d%n", morphTagCounter.keySet().size()); // Extra System.out.println("=================="); StringBuilder sbNoLemma = new StringBuilder(); StringBuilder sbMultLemmas = new StringBuilder(); for (Map.Entry<String, Set<String>> wordLemmas : wordLemmaMap.entrySet()) { String word = wordLemmas.getKey(); Set<String> lemmas = wordLemmas.getValue(); if (lemmas.size() == 0) { sbNoLemma.append("NO LEMMAS FOR WORD: " + word + "\n"); continue; } if (lemmas.size() > 1) { sbMultLemmas.append("MULTIPLE LEMMAS: " + word + " " + setToString(lemmas) + "\n"); continue; } String lemma = lemmas.iterator().next(); Set<String> reducedTags = lemmaReducedTagCounter.getCounter(lemma).keySet(); if (reducedTags.size() > 1) { System.out.printf("%s --> %s%n", word, lemma); for (String reducedTag : reducedTags) { int count = lemmaReducedTagCounter.getCount(lemma, reducedTag); String posTags = setToString(reducedTagTagCounter.getCounter(lemma + reducedTag).keySet()); System.out.printf("\t%s\t%d\t%s%n", reducedTag, count, posTags); } System.out.println(); } } System.out.println("=================="); System.out.println(sbNoLemma.toString()); 
System.out.println(sbMultLemmas.toString()); System.out.println("=================="); List<String> tags = new ArrayList<>(tagReducedTagCounter.firstKeySet()); Collections.sort(tags); for (String tag : tags) { System.out.println(tag); Set<String> reducedTags = tagReducedTagCounter.getCounter(tag).keySet(); for (String reducedTag : reducedTags) { int count = tagReducedTagCounter.getCount(tag, reducedTag); // reducedTag = reducedTag.equals("") ? "NONE" : reducedTag; System.out.printf("\t%s\t%d%n", reducedTag, count); } System.out.println(); } System.out.println("=================="); }
/**
 * Maps Tree node offsets using provided mapping.
 *
 * @param tree the Tree whose begin and end extents should be mapped.
 * @param mapping the list of RangeMap objects which defines the mapping.
 */
protected static void mapOffsets(Tree tree, List<RangeMap> mapping) {
  // if mapping is empty, then assume 1-to-1 mapping.
  if (mapping == null || mapping.size() == 0) return;
  // Nodes are visited in an order with non-decreasing begin offsets, so the
  // begin-side cursor only ever advances through the mapping list.
  int begin_map_index = 0;
  RangeMap begin_rmap = mapping.get(begin_map_index);
  TREE:
  for (Tree t : tree) {
    if (t.isLeaf()) continue;
    MapLabel label = (MapLabel) t.label();
    int begin = (Integer) label.get(BEGIN_KEY);
    // "end" must be index of last char in range
    int end = (Integer) label.get(END_KEY) - 1;
    // find the first rangemap whose end is greater than the
    // beginning of current annotation.
    // log.debug("Finding RangeMap whose extents include
    // annotation.begin");
    while (begin_rmap.end <= begin) {
      begin_map_index++;
      // Ran out of range maps: no further node can be remapped.
      if (begin_map_index >= mapping.size()) break TREE;
      begin_rmap = mapping.get(begin_map_index);
    }
    // if beginning of current rangemap is greater than end of
    // current annotation, then skip this annotation (default
    // mapping is 1-to-1).
    if (begin_rmap.begin > end) {
      // log.debug("Skipping annotation (assuming 1-to-1 offset
      // mapping)");
      continue;
    }
    // if beginning of current annotation falls within current range
    // map, then map it back to source space.
    int new_begin = begin;
    if (begin_rmap.begin <= new_begin) {
      // log.debug("Applying RangeMap to begin offset");
      new_begin = begin_rmap.map(new_begin);
    }
    // find the first rangemap whose end is greater than the end of
    // current annotation.
    // log.debug("Finding RangeMap whose extents include
    // annotation.end");
    // The end-side cursor starts from the begin-side cursor, never before it.
    int end_map_index = begin_map_index;
    RangeMap end_rmap = begin_rmap;
    END_OFFSET:
    while (end_rmap.end <= end) {
      end_map_index++;
      if (end_map_index >= mapping.size()) break END_OFFSET;
      end_rmap = mapping.get(end_map_index);
    }
    // if end of current annotation falls within "end" range map,
    // then map it back to source space.
    int new_end = end;
    if (end_rmap.begin <= end) {
      // log.debug("Applying RangeMap to end offset");
      new_end = end_rmap.map(end);
    }
    // Store the remapped extents; END_KEY is exclusive again (+1).
    label.put(BEGIN_KEY, new_begin);
    label.put(END_KEY, new_end + 1);
  }
}
/**
 * transformTree does all language-specific tree transformations. Any parameterizations should be
 * inside the specific TreebankLangParserParams class.
 */
@Override
public Tree transformTree(Tree t, Tree root) {
  if (t == null || t.isLeaf()) {
    return t;
  }
  // Compute parent/grandparent labels (empty string at or above the root).
  String parentStr;
  String grandParentStr;
  Tree parent;
  Tree grandParent;
  if (root == null || t.equals(root)) {
    parent = null;
    parentStr = "";
  } else {
    parent = t.parent(root);
    parentStr = parent.label().value();
  }
  if (parent == null || parent.equals(root)) {
    grandParent = null;
    grandParentStr = "";
  } else {
    grandParent = parent.parent(root);
    grandParentStr = grandParent.label().value();
  }
  String baseParentStr = ctlp.basicCategory(parentStr);
  String baseGrandParentStr = ctlp.basicCategory(grandParentStr);
  CoreLabel lab = (CoreLabel) t.label();
  String word = lab.word();
  String tag = lab.tag();
  String baseTag = ctlp.basicCategory(tag);
  String category = lab.value();
  String baseCategory = ctlp.basicCategory(category);
  if (t.isPreTerminal()) { // it's a POS tag
    List<String> leftAunts =
        listBasicCategories(SisterAnnotationStats.leftSisterLabels(parent, grandParent));
    List<String> rightAunts =
        listBasicCategories(SisterAnnotationStats.rightSisterLabels(parent, grandParent));
    // Chinese-specific punctuation splits
    if (chineseSplitPunct && baseTag.equals("PU")) {
      if (ChineseTreebankLanguagePack.chineseDouHaoAcceptFilter().accept(word)) {
        tag = tag + "-DOU";
        // System.out.println("Punct: Split dou hao"); // debugging
      } else if (ChineseTreebankLanguagePack.chineseCommaAcceptFilter().accept(word)) {
        tag = tag + "-COMMA";
        // System.out.println("Punct: Split comma"); // debugging
      } else if (ChineseTreebankLanguagePack.chineseColonAcceptFilter().accept(word)) {
        tag = tag + "-COLON";
        // System.out.println("Punct: Split colon"); // debugging
      } else if (ChineseTreebankLanguagePack.chineseQuoteMarkAcceptFilter().accept(word)) {
        if (chineseSplitPunctLR) {
          if (ChineseTreebankLanguagePack.chineseLeftQuoteMarkAcceptFilter().accept(word)) {
            tag += "-LQUOTE";
          } else {
            tag += "-RQUOTE";
          }
        } else {
          tag = tag + "-QUOTE";
        }
        // System.out.println("Punct: Split quote"); // debugging
      } else if (ChineseTreebankLanguagePack.chineseEndSentenceAcceptFilter().accept(word)) {
        tag = tag + "-ENDSENT";
        // System.out.println("Punct: Split end sent"); // debugging
      } else if (ChineseTreebankLanguagePack.chineseParenthesisAcceptFilter().accept(word)) {
        if (chineseSplitPunctLR) {
          if (ChineseTreebankLanguagePack.chineseLeftParenthesisAcceptFilter().accept(word)) {
            tag += "-LPAREN";
          } else {
            tag += "-RPAREN";
          }
        } else {
          tag += "-PAREN";
          // printlnErr("Just used -PAREN annotation");
          // printlnErr(word);
          // throw new RuntimeException();
        }
        // System.out.println("Punct: Split paren"); // debugging
      } else if (ChineseTreebankLanguagePack.chineseDashAcceptFilter().accept(word)) {
        tag = tag + "-DASH";
        // System.out.println("Punct: Split dash"); // debugging
      } else if (ChineseTreebankLanguagePack.chineseOtherAcceptFilter().accept(word)) {
        tag = tag + "-OTHER";
      } else {
        printlnErr("Unknown punct (you should add it to CTLP): " + tag + " |" + word + "|");
      }
    } else if (chineseSplitDouHao) { // only split DouHao
      if (ChineseTreebankLanguagePack.chineseDouHaoAcceptFilter().accept(word)
          && baseTag.equals("PU")) {
        tag = tag + "-DOU";
      }
    }
    // Chinese-specific POS tag splits (non-punctuation)
    if (tagWordSize) {
      int l = word.length();
      tag += "-" + l + "CHARS";
    }
    if (mergeNNVV && baseTag.equals("NN")) {
      tag = "VV";
    }
    if ((chineseSelectiveTagPA || chineseVerySelectiveTagPA)
        && (baseTag.equals("CC") || baseTag.equals("P"))) {
      tag += "-" + baseParentStr;
    }
    if (chineseSelectiveTagPA && (baseTag.equals("VV"))) {
      tag += "-" + baseParentStr;
    }
    if (markMultiNtag && tag.startsWith("N")) {
      // Mark N tags that have an N sister under the same parent.
      for (int i = 0; i < parent.numChildren(); i++) {
        if (parent.children()[i].label().value().startsWith("N") && parent.children()[i] != t) {
          tag += "=N";
          // System.out.println("Found multi=N rewrite");
        }
      }
    }
    if (markVVsisterIP && baseTag.equals("VV")) {
      boolean seenIP = false;
      for (int i = 0; i < parent.numChildren(); i++) {
        if (parent.children()[i].label().value().startsWith("IP")) {
          seenIP = true;
        }
      }
      if (seenIP) {
        tag += "-IP";
        // System.out.println("Found VV with IP sister"); // testing
      }
    }
    if (markPsisterIP && baseTag.equals("P")) {
      boolean seenIP = false;
      for (int i = 0; i < parent.numChildren(); i++) {
        if (parent.children()[i].label().value().startsWith("IP")) {
          seenIP = true;
        }
      }
      if (seenIP) {
        tag += "-IP";
      }
    }
    if (markADgrandchildOfIP && baseTag.equals("AD") && baseGrandParentStr.equals("IP")) {
      tag += "~IP";
      // System.out.println("Found AD with IP grandparent"); // testing
    }
    if (gpaAD && baseTag.equals("AD")) {
      tag += "~" + baseGrandParentStr;
      // System.out.println("Found AD with grandparent " + grandParentStr); // testing
    }
    if (markPostverbalP && leftAunts.contains("VV") && baseTag.equals("P")) {
      // System.out.println("Found post-verbal P");
      tag += "^=lVV";
    }
    // end Chinese-specific tag splits
    Label label = new CategoryWordTag(tag, word, tag);
    t.setLabel(label);
  } else {
    // it's a phrasal category
    Tree[] kids = t.children();
    // Chinese-specific category splits
    List<String> leftSis = listBasicCategories(SisterAnnotationStats.leftSisterLabels(t, parent));
    List<String> rightSis =
        listBasicCategories(SisterAnnotationStats.rightSisterLabels(t, parent));
    if (paRootDtr && baseParentStr.equals("ROOT")) {
      category += "^ROOT";
    }
    if (markIPsisterBA && baseCategory.equals("IP")) {
      if (leftSis.contains("BA")) {
        category += "=BA";
        // System.out.println("Found IP sister of BA");
      }
    }
    if (dominatesV && hasV(t.preTerminalYield())) {
      // mark categories containing a verb
      category += "-v";
    }
    if (markIPsisterVVorP && baseCategory.equals("IP")) {
      // todo: cdm: is just looking for "P" here selective enough??
      if (leftSis.contains("VV") || leftSis.contains("P")) {
        category += "=VVP";
      }
    }
    if (markIPsisDEC && baseCategory.equals("IP")) {
      if (rightSis.contains("DEC")) {
        category += "=DEC";
        // System.out.println("Found prenominal IP");
      }
    }
    if (baseCategory.equals("VP")) {
      // cdm 2008: this used to just check that it startsWith("VP"), but
      // I think that was bad because it also matched VPT verb compounds
      if (chineseSplitVP == 3) {
        boolean hasCC = false;
        boolean hasPU = false;
        boolean hasLexV = false;
        for (Tree kid : kids) {
          if (kid.label().value().startsWith("CC")) {
            hasCC = true;
          } else if (kid.label().value().startsWith("PU")) {
            hasPU = true;
          } else if (StringUtils.lookingAt(
              kid.label().value(), "(V[ACEV]|VCD|VCP|VNV|VPT|VRD|VSB)")) {
            hasLexV = true;
          }
        }
        if (hasCC || (hasPU && !hasLexV)) {
          category += "-CRD";
          // System.out.println("Found coordinate VP"); // testing
        } else if (hasLexV) {
          category += "-COMP";
          // System.out.println("Found complementing VP"); // testing
        } else {
          category += "-ADJT";
          // System.out.println("Found adjoining VP"); // testing
        }
      } else if (chineseSplitVP >= 1) {
        boolean hasBA = false;
        for (Tree kid : kids) {
          if (kid.label().value().startsWith("BA")) {
            hasBA = true;
          } else if (chineseSplitVP == 2
              && tlp.basicCategory(kid.label().value()).equals("VP")) {
            for (Tree kidkid : kid.children()) {
              if (kidkid.label().value().startsWith("BA")) {
                hasBA = true;
              }
            }
          }
        }
        if (hasBA) {
          category += "-BA";
        }
      }
    }
    if (markVPadjunct && baseParentStr.equals("VP")) {
      // cdm 2008: This used to use startsWith("VP") but changed to baseCat
      Tree[] sisters = parent.children();
      boolean hasVPsister = false;
      boolean hasCC = false;
      boolean hasPU = false;
      boolean hasLexV = false;
      for (Tree sister : sisters) {
        if (tlp.basicCategory(sister.label().value()).equals("VP")) {
          hasVPsister = true;
        }
        if (sister.label().value().startsWith("CC")) {
          hasCC = true;
        }
        if (sister.label().value().startsWith("PU")) {
          hasPU = true;
        }
        if (StringUtils.lookingAt(sister.label().value(), "(V[ACEV]|VCD|VCP|VNV|VPT|VRD|VSB)")) {
          hasLexV = true;
        }
      }
      if (hasVPsister && !(hasCC || hasPU || hasLexV)) {
        category += "-VPADJ";
        // System.out.println("Found adjunct of VP"); // testing
      }
    }
    if (markNPmodNP && baseCategory.equals("NP") && baseParentStr.equals("NP")) {
      if (rightSis.contains("NP")) {
        category += "=MODIFIERNP";
        // System.out.println("Found NP modifier of NP"); // testing
      }
    }
    if (markModifiedNP && baseCategory.equals("NP") && baseParentStr.equals("NP")) {
      if (rightSis.isEmpty()
          && (leftSis.contains("ADJP")
              || leftSis.contains("NP")
              || leftSis.contains("DNP")
              || leftSis.contains("QP")
              || leftSis.contains("CP")
              || leftSis.contains("PP"))) {
        category += "=MODIFIEDNP";
        // System.out.println("Found modified NP"); // testing
      }
    }
    if (markNPconj && baseCategory.equals("NP") && baseParentStr.equals("NP")) {
      if (rightSis.contains("CC")
          || rightSis.contains("PU")
          || leftSis.contains("CC")
          || leftSis.contains("PU")) {
        category += "=CONJ";
        // System.out.println("Found NP conjunct"); // testing
      }
    }
    if (markIPconj && baseCategory.equals("IP") && baseParentStr.equals("IP")) {
      Tree[] sisters = parent.children();
      boolean hasCommaSis = false;
      boolean hasIPSis = false;
      for (Tree sister : sisters) {
        if (ctlp.basicCategory(sister.label().value()).equals("PU")
            && ChineseTreebankLanguagePack.chineseCommaAcceptFilter()
                .accept(sister.children()[0].label().toString())) {
          hasCommaSis = true;
          // System.out.println("Found CommaSis"); // testing
        }
        if (ctlp.basicCategory(sister.label().value()).equals("IP") && sister != t) {
          hasIPSis = true;
        }
      }
      if (hasCommaSis && hasIPSis) {
        category += "-CONJ";
        // System.out.println("Found IP conjunct"); // testing
      }
    }
    if (unaryIP && baseCategory.equals("IP") && t.numChildren() == 1) {
      category += "-U";
      // System.out.println("Found unary IP"); //testing
    }
    if (unaryCP && baseCategory.equals("CP") && t.numChildren() == 1) {
      category += "-U";
      // System.out.println("Found unary CP"); //testing
    }
    if (splitBaseNP && baseCategory.equals("NP")) {
      if (t.isPrePreTerminal()) {
        category = category + "-B";
      }
    }
    // if (Test.verbose) printlnErr(baseCategory + " " + leftSis.toString()); //debugging
    if (markPostverbalPP && leftSis.contains("VV") && baseCategory.equals("PP")) {
      // System.out.println("Found post-verbal PP");
      category += "=lVV";
    }
    if ((markADgrandchildOfIP || gpaAD)
        && listBasicCategories(SisterAnnotationStats.kidLabels(t)).contains("AD")) {
      category += "^ADVP";
    }
    if (markCC) {
      // was: for (int i = 0; i < kids.length; i++) {
      // This second version takes an idea from Collins: don't count
      // marginal conjunctions which don't conjoin 2 things.
      for (int i = 1; i < kids.length - 1; i++) {
        String cat2 = kids[i].label().value();
        if (cat2.startsWith("CC")) {
          category += "-CC";
        }
      }
    }
    Label label = new CategoryWordTag(category, word, tag);
    t.setLabel(label);
  }
  return t;
}
/**
 * Bottom-up forward pass: computes a vector for every node of {@code tree} into
 * {@code nodeVectors} and a scalar score for every phrasal node into {@code scores}.
 *
 * @param tree the (binarized) parse tree to score
 * @param words the sentence words, used only for the optional context-word features
 * @param nodeVectors out-parameter, keyed by node identity
 * @param scores out-parameter, keyed by node identity
 * @throws NoSuchParseException if the model has no W or scoreW matrix for a node
 */
private void forwardPropagateTree(
    Tree tree,
    List<String> words,
    IdentityHashMap<Tree, SimpleMatrix> nodeVectors,
    IdentityHashMap<Tree, Double> scores) {
  if (tree.isLeaf()) {
    return;
  }
  if (tree.isPreTerminal()) {
    // Pre-terminal vector = tanh of the word's embedding.
    Tree wordNode = tree.children()[0];
    String word = wordNode.label().value();
    SimpleMatrix wordVector = dvModel.getWordVector(word);
    wordVector = NeuralUtils.elementwiseApplyTanh(wordVector);
    nodeVectors.put(tree, wordVector);
    return;
  }
  for (Tree child : tree.children()) {
    forwardPropagateTree(child, words, nodeVectors, scores);
  }
  // at this point, nodeVectors contains the vectors for all of
  // the children of tree
  SimpleMatrix childVec;
  if (tree.children().length == 2) {
    childVec =
        NeuralUtils.concatenateWithBias(
            nodeVectors.get(tree.children()[0]), nodeVectors.get(tree.children()[1]));
  } else {
    childVec = NeuralUtils.concatenateWithBias(nodeVectors.get(tree.children()[0]));
  }
  if (op.trainOptions.useContextWords) {
    childVec = concatenateContextWords(childVec, tree.getSpan(), words);
  }
  SimpleMatrix W = dvModel.getWForNode(tree);
  if (W == null) {
    String error = "Could not find W for tree " + tree;
    if (op.testOptions.verbose) {
      System.err.println(error);
    }
    throw new NoSuchParseException(error);
  }
  // Node vector = tanh(W * [children; bias]).
  SimpleMatrix currentVector = W.mult(childVec);
  currentVector = NeuralUtils.elementwiseApplyTanh(currentVector);
  nodeVectors.put(tree, currentVector);
  SimpleMatrix scoreW = dvModel.getScoreWForNode(tree);
  if (scoreW == null) {
    String error = "Could not find scoreW for tree " + tree;
    if (op.testOptions.verbose) {
      System.err.println(error);
    }
    throw new NoSuchParseException(error);
  }
  // Node score = scoreW . nodeVector (sigmoid deliberately disabled).
  double score = scoreW.dot(currentVector);
  // score = NeuralUtils.sigmoid(score);
  scores.put(tree, score);
  // System.err.print(Double.toString(score)+" ");
}
/**
 * Adds syntactic-path features (H, I, J, K, L) to the property list for an argument candidate.
 *
 * <p>The path runs from the connective's head word to the candidate head. When the two are in
 * the same sentence it is the direct tree path; otherwise it goes through each sentence's
 * Collins head, with one "SENT" element per intervening sentence boundary. Pre-terminal (POS)
 * nodes are included in {@code path} but excluded from {@code pathWithoutPOS}.
 *
 * @param pl the property list to extend (may be rebuilt; the new head is returned)
 * @param candidate pair of (sentence line of arg1, head-token position of arg1)
 * @return the extended property list
 */
private PropertyList addConstituentFeatures(
    PropertyList pl,
    Document doc,
    Pair<Integer, Integer> candidate,
    int arg2Line,
    int arg2HeadPos,
    int connStart,
    int connEnd) {
  Sentence arg2Sentence = doc.getSentence(arg2Line);
  String conn = arg2Sentence.toString(connStart, connEnd);
  int connHeadPos = connAnalyzer.getHeadWord(arg2Sentence.getParseTree(), connStart, connEnd);
  int arg1Line = candidate.first();
  Tree arg1Tree = doc.getTree(arg1Line);
  int arg1HeadPos = candidate.second();
  List<String> path = new ArrayList<String>();
  List<String> pathWithoutPOS = new ArrayList<String>();
  if (arg1Line == arg2Line) {
    // Same sentence: direct path from connective head leaf to arg1 head leaf.
    Tree root = arg1Tree;
    List<Tree> leaves = root.getLeaves();
    List<Tree> treePath = root.pathNodeToNode(leaves.get(connHeadPos), leaves.get(arg1HeadPos));
    if (treePath != null) {
      for (Tree t : treePath) {
        if (!t.isLeaf()) {
          path.add(t.value());
          if (!t.isPreTerminal()) {
            pathWithoutPOS.add(t.value());
          }
        }
      }
    }
  } else {
    // Different sentences: connective head -> arg2 sentence head, then
    // "SENT" markers, then arg1 sentence head -> arg1 head.
    Tree arg2Root = arg2Sentence.getParseTree();
    Tree mainHead = headAnalyzer.getCollinsHead(arg2Root.getChild(0));
    List<Tree> leaves = arg2Root.getLeaves();
    int mainHeadPos = treeAnalyzer.getLeafPosition(arg2Root, mainHead);
    if (mainHeadPos != -1) {
      List<Tree> treePath =
          arg2Root.pathNodeToNode(leaves.get(connHeadPos), leaves.get(mainHeadPos));
      if (treePath != null) {
        for (Tree t : treePath) {
          if (!t.isLeaf()) {
            path.add(t.value());
            if (!t.isPreTerminal()) {
              pathWithoutPOS.add(t.value());
            }
          }
        }
      }
    }
    // One SENT element per sentence boundary crossed.
    for (int i = 0; i < Math.abs(arg1Line - arg2Line); i++) {
      path.add("SENT");
      pathWithoutPOS.add("SENT");
    }
    Tree arg1Root = arg1Tree;
    mainHead = headAnalyzer.getCollinsHead(arg1Root.getChild(0));
    leaves = arg1Root.getLeaves();
    mainHeadPos = treeAnalyzer.getLeafPosition(arg1Root, mainHead);
    if (mainHeadPos != -1) {
      List<Tree> treePath =
          arg1Root.pathNodeToNode(leaves.get(mainHeadPos), leaves.get(arg1HeadPos));
      if (treePath != null) {
        for (Tree t : treePath) {
          if (!t.isLeaf()) {
            path.add(t.value());
            if (!t.isPreTerminal()) {
              pathWithoutPOS.add(t.value());
            }
          }
        }
      }
    }
  }
  // H-full path
  // L-C&H
  StringBuilder fullPath = new StringBuilder();
  for (String node : path) {
    fullPath.append(node).append(":");
  }
  pl = PropertyList.add("H=" + fullPath.toString(), 1.0, pl);
  pl = PropertyList.add("L=CONN-" + conn + "&" + "H-" + fullPath.toString(), 1.0, pl);
  // I-length of path
  pl = PropertyList.add("I=" + path.size(), 1.0, pl);
  // J-collapsed path without part of speech
  // K-collapsed path without repetitions
  fullPath = new StringBuilder();
  StringBuilder collapsedPath = new StringBuilder();
  String prev = "";
  for (String node : pathWithoutPOS) {
    fullPath.append(node).append(":");
    if (!node.equals(prev)) {
      collapsedPath.append(node).append(":");
    }
    prev = node;
  }
  pl = PropertyList.add("J=" + fullPath.toString(), 1.0, pl);
  pl = PropertyList.add("K=" + collapsedPath.toString(), 1.0, pl);
  return pl;
}
/**
 * Backpropagates the error signal {@code deltaUp} through this subtree, accumulating gradients
 * into the supplied derivative maps (which are mutated in place).
 *
 * @param tree the current (sub)tree being processed
 * @param words the sentence tokens, used for context-word features when enabled
 * @param nodeVectors forward-pass activation vector for each tree node
 * @param binaryW_dfs accumulator for binary transform-matrix gradients, keyed by child categories
 * @param unaryW_dfs accumulator for unary transform-matrix gradients, keyed by child category
 * @param binaryScoreDerivatives accumulator for binary scoring-vector gradients
 * @param unaryScoreDerivatives accumulator for unary scoring-vector gradients
 * @param wordVectorDerivatives accumulator for word-embedding gradients
 * @param deltaUp error signal arriving from the parent node
 */
public void backpropDerivative(
    Tree tree,
    List<String> words,
    IdentityHashMap<Tree, SimpleMatrix> nodeVectors,
    TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfs,
    Map<String, SimpleMatrix> unaryW_dfs,
    TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivatives,
    Map<String, SimpleMatrix> unaryScoreDerivatives,
    Map<String, SimpleMatrix> wordVectorDerivatives,
    SimpleMatrix deltaUp) {
  // Leaves carry no parameters of their own; nothing to accumulate.
  if (tree.isLeaf()) {
    return;
  }
  // Pre-terminal (POS tag over a word): optionally push the delta into the word vector gradient.
  if (tree.isPreTerminal()) {
    if (op.trainOptions.trainWordVectors) {
      String word = tree.children()[0].label().value();
      // Map to the model's vocabulary entry (handles unknown-word bucketing).
      word = dvModel.getVocabWord(word);
      // SimpleMatrix currentVector = nodeVectors.get(tree);
      // SimpleMatrix currentVectorDerivative =
      //   nonlinearityVectorToDerivative(currentVector);
      // SimpleMatrix derivative = deltaUp.elementMult(currentVectorDerivative);
      // Word vectors feed in linearly here, so the gradient is just the incoming delta.
      SimpleMatrix derivative = deltaUp;
      wordVectorDerivatives.put(word, wordVectorDerivatives.get(word).plus(derivative));
    }
    return;
  }
  SimpleMatrix currentVector = nodeVectors.get(tree);
  // d(tanh)/dx evaluated at this node's activation.
  SimpleMatrix currentVectorDerivative =
      NeuralUtils.elementwiseApplyTanhDerivative(currentVector);
  SimpleMatrix scoreW = dvModel.getScoreWForNode(tree);
  // Contribution of this node's score to the error, through the tanh nonlinearity.
  currentVectorDerivative = currentVectorDerivative.elementMult(scoreW.transpose());
  // the delta that is used at the current nodes
  SimpleMatrix deltaCurrent = deltaUp.plus(currentVectorDerivative);
  SimpleMatrix W = dvModel.getWForNode(tree);
  // W^T * delta: the error signal to be split among the children below.
  SimpleMatrix WTdelta = W.transpose().mult(deltaCurrent);
  if (tree.children().length == 2) {
    // TODO: RS: Change to the nice "getWForNode" setup?
    String leftLabel = dvModel.basicCategory(tree.children()[0].label().value());
    String rightLabel = dvModel.basicCategory(tree.children()[1].label().value());
    // Score-vector gradient for this (left, right) category pair is the node activation.
    binaryScoreDerivatives.put(
        leftLabel,
        rightLabel,
        binaryScoreDerivatives.get(leftLabel, rightLabel).plus(currentVector.transpose()));
    SimpleMatrix leftVector = nodeVectors.get(tree.children()[0]);
    SimpleMatrix rightVector = nodeVectors.get(tree.children()[1]);
    // Input to this node's transform: [left; right; 1] (bias appended).
    SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector);
    if (op.trainOptions.useContextWords) {
      childrenVector = concatenateContextWords(childrenVector, tree.getSpan(), words);
    }
    // Outer product: gradient of the transform matrix W for this category pair.
    SimpleMatrix W_df = deltaCurrent.mult(childrenVector.transpose());
    binaryW_dfs.put(leftLabel, rightLabel, binaryW_dfs.get(leftLabel, rightLabel).plus(W_df));
    // and then recurse
    SimpleMatrix leftDerivative = NeuralUtils.elementwiseApplyTanhDerivative(leftVector);
    SimpleMatrix rightDerivative = NeuralUtils.elementwiseApplyTanhDerivative(rightVector);
    // Split W^T * delta into the left-child rows [0, n) and right-child rows [n, 2n),
    // where n = deltaCurrent.numRows(); the trailing bias row is dropped.
    SimpleMatrix leftWTDelta = WTdelta.extractMatrix(0, deltaCurrent.numRows(), 0, 1);
    SimpleMatrix rightWTDelta =
        WTdelta.extractMatrix(deltaCurrent.numRows(), deltaCurrent.numRows() * 2, 0, 1);
    backpropDerivative(
        tree.children()[0],
        words,
        nodeVectors,
        binaryW_dfs,
        unaryW_dfs,
        binaryScoreDerivatives,
        unaryScoreDerivatives,
        wordVectorDerivatives,
        leftDerivative.elementMult(leftWTDelta));
    backpropDerivative(
        tree.children()[1],
        words,
        nodeVectors,
        binaryW_dfs,
        unaryW_dfs,
        binaryScoreDerivatives,
        unaryScoreDerivatives,
        wordVectorDerivatives,
        rightDerivative.elementMult(rightWTDelta));
  } else if (tree.children().length == 1) {
    String childLabel = dvModel.basicCategory(tree.children()[0].label().value());
    // Score-vector gradient for this unary category is the node activation.
    unaryScoreDerivatives.put(
        childLabel, unaryScoreDerivatives.get(childLabel).plus(currentVector.transpose()));
    SimpleMatrix childVector = nodeVectors.get(tree.children()[0]);
    // Input to this node's transform: [child; 1] (bias appended).
    SimpleMatrix childVectorWithBias = NeuralUtils.concatenateWithBias(childVector);
    if (op.trainOptions.useContextWords) {
      childVectorWithBias = concatenateContextWords(childVectorWithBias, tree.getSpan(), words);
    }
    // Outer product: gradient of the unary transform matrix for this category.
    SimpleMatrix W_df = deltaCurrent.mult(childVectorWithBias.transpose());
    // System.out.println("unary backprop derivative for " + childLabel);
    // System.out.println("Old transform:");
    // System.out.println(unaryW_dfs.get(childLabel));
    // System.out.println(" Delta:");
    // System.out.println(W_df.scale(scale));
    unaryW_dfs.put(childLabel, unaryW_dfs.get(childLabel).plus(W_df));
    // and then recurse
    SimpleMatrix childDerivative = NeuralUtils.elementwiseApplyTanhDerivative(childVector);
    // SimpleMatrix childDerivative = childVector;
    // Only the non-bias rows of W^T * delta propagate to the single child.
    SimpleMatrix childWTDelta = WTdelta.extractMatrix(0, deltaCurrent.numRows(), 0, 1);
    backpropDerivative(
        tree.children()[0],
        words,
        nodeVectors,
        binaryW_dfs,
        unaryW_dfs,
        binaryScoreDerivatives,
        unaryScoreDerivatives,
        wordVectorDerivatives,
        childDerivative.elementMult(childWTDelta));
  }
}
/** * Adds dependencies to list depList. These are in terms of the original tag set not the reduced * (projected) tag set. */ protected static EndHead treeToDependencyHelper( Tree tree, List<IntDependency> depList, int loc, Index<String> wordIndex, Index<String> tagIndex) { // try { // PrintWriter pw = new PrintWriter(new OutputStreamWriter(System.out,"GB18030"),true); // tree.pennPrint(pw); // } // catch (UnsupportedEncodingException e) {} if (tree.isLeaf() || tree.isPreTerminal()) { EndHead tempEndHead = new EndHead(); tempEndHead.head = loc; tempEndHead.end = loc + 1; return tempEndHead; } Tree[] kids = tree.children(); if (kids.length == 1) { return treeToDependencyHelper(kids[0], depList, loc, wordIndex, tagIndex); } EndHead tempEndHead = treeToDependencyHelper(kids[0], depList, loc, wordIndex, tagIndex); int lHead = tempEndHead.head; int split = tempEndHead.end; tempEndHead = treeToDependencyHelper(kids[1], depList, tempEndHead.end, wordIndex, tagIndex); int end = tempEndHead.end; int rHead = tempEndHead.head; String hTag = ((HasTag) tree.label()).tag(); String lTag = ((HasTag) kids[0].label()).tag(); String rTag = ((HasTag) kids[1].label()).tag(); String hWord = ((HasWord) tree.label()).word(); String lWord = ((HasWord) kids[0].label()).word(); String rWord = ((HasWord) kids[1].label()).word(); boolean leftHeaded = hWord.equals(lWord); String aTag = (leftHeaded ? rTag : lTag); String aWord = (leftHeaded ? rWord : lWord); int hT = tagIndex.indexOf(hTag); int aT = tagIndex.indexOf(aTag); int hW = (wordIndex.contains(hWord) ? wordIndex.indexOf(hWord) : wordIndex.indexOf(Lexicon.UNKNOWN_WORD)); int aW = (wordIndex.contains(aWord) ? wordIndex.indexOf(aWord) : wordIndex.indexOf(Lexicon.UNKNOWN_WORD)); int head = (leftHeaded ? lHead : rHead); int arg = (leftHeaded ? rHead : lHead); IntDependency dependency = new IntDependency( hW, hT, aW, aT, leftHeaded, (leftHeaded ? 
split - head - 1 : head - split)); depList.add(dependency); IntDependency stopL = new IntDependency( aW, aT, STOP_WORD_INT, STOP_TAG_INT, false, (leftHeaded ? arg - split : arg - loc)); depList.add(stopL); IntDependency stopR = new IntDependency( aW, aT, STOP_WORD_INT, STOP_TAG_INT, true, (leftHeaded ? end - arg - 1 : split - arg - 1)); depList.add(stopR); // System.out.println("Adding: "+dependency+" at "+tree.label()); tempEndHead.head = head; return tempEndHead; }