/**
 * Summarizes which runtime classes make up a tree.
 *
 * <p>The report names the short class names of the root node, of its tree factory, of its label
 * and of that label's factory. Every node visited when iterating over {@code t} is then compared
 * against those four names, and any differing class name is collected and listed at the end.
 *
 * @param t The tree to examine.
 * @return A human-readable String
 */
public static String toDebugStructureString(Tree t) {
  String rootTreeClass = StringUtils.getShortClassName(t);
  String rootFactoryClass = StringUtils.getShortClassName(t.treeFactory());
  String rootLabelClass = StringUtils.getShortClassName(t.label());
  String rootLabelFactoryClass = StringUtils.getShortClassName(t.label().labelFactory());

  // Class names seen on some node that differ from the corresponding root-level name.
  Set<String> deviatingClasses = new HashSet<String>();
  for (Tree node : t) {
    String nodeTreeClass = StringUtils.getShortClassName(node);
    String nodeFactoryClass = StringUtils.getShortClassName(node.treeFactory());
    String nodeLabelClass = StringUtils.getShortClassName(node.label());
    String nodeLabelFactoryClass = StringUtils.getShortClassName(node.label().labelFactory());

    if (!rootTreeClass.equals(nodeTreeClass)) {
      deviatingClasses.add(nodeTreeClass);
    }
    if (!rootFactoryClass.equals(nodeFactoryClass)) {
      deviatingClasses.add(nodeFactoryClass);
    }
    if (!rootLabelClass.equals(nodeLabelClass)) {
      deviatingClasses.add(nodeLabelClass);
    }
    if (!rootLabelFactoryClass.equals(nodeLabelFactoryClass)) {
      deviatingClasses.add(nodeLabelFactoryClass);
    }
  }

  StringBuilder report = new StringBuilder();
  report.append("Tree with root of class ");
  report.append(rootTreeClass);
  report.append(" and factory ");
  report.append(rootFactoryClass);
  report.append(" with label class ");
  report.append(rootLabelClass);
  report.append(" and factory ");
  report.append(rootLabelFactoryClass);
  if (!deviatingClasses.isEmpty()) {
    report.append(" with the following classes also found within the tree: ");
    report.append(deviatingClasses);
  }
  return report.toString();
}
Tree prune(Tree tree, int start) { if (tree.isLeaf() || tree.isPreTerminal()) { return tree; } // check each node's children for deletion List<Tree> children = helper(tree.getChildrenAsList(), start); children = prune(children, tree.label(), start, start + tree.yield().size()); return tree.treeFactory().newTreeNode(tree.label(), children); }
/** * Called by determineHead and may be overridden in subclasses if special treatment is necessary * for particular categories. */ protected Tree determineNonTrivialHead(Tree t, Tree parent) { Tree theHead = null; String motherCat = tlp.basicCategory(t.label().value()); if (DEBUG) { System.err.println( "Looking for head of " + t.label() + "; value is |" + t.label().value() + "|, " + " baseCat is |" + motherCat + '|'); } // We know we have nonterminals underneath // (a bit of a Penn Treebank assumption, but). // Look at label. // a total special case.... // first look for POS tag at end // this appears to be redundant in the Collins case since the rule already would do that // Tree lastDtr = t.lastChild(); // if (tlp.basicCategory(lastDtr.label().value()).equals("POS")) { // theHead = lastDtr; // } else { String[][] how = nonTerminalInfo.get(motherCat); if (how == null) { if (DEBUG) { System.err.println( "Warning: No rule found for " + motherCat + " (first char: " + motherCat.charAt(0) + ')'); System.err.println("Known nonterms are: " + nonTerminalInfo.keySet()); } if (defaultRule != null) { if (DEBUG) { System.err.println(" Using defaultRule"); } return traverseLocate(t.children(), defaultRule, true); } else { return null; } } for (int i = 0; i < how.length; i++) { boolean deflt = (i == how.length - 1); theHead = traverseLocate(t.children(), how[i], deflt); if (theHead != null) { break; } } if (DEBUG) { System.err.println(" Chose " + theHead.label()); } return theHead; }
/**
 * Normalizes every label of a tree in place and then normalizes the tree as a whole.
 *
 * <p>Each node visited when iterating over {@code tree} has its label value replaced: leaves go
 * through {@code tn.normalizeTerminal}, all other nodes through {@code tn.normalizeNonterminal}.
 * The relabelled tree is finally handed to {@code tn.normalizeWholeTree}.
 *
 * @param tree The tree whose labels are rewritten (modified in place)
 * @param tn The normalizer supplying the per-label and whole-tree normalization
 * @param tf The tree factory forwarded to {@code normalizeWholeTree}
 * @return Whatever {@code tn.normalizeWholeTree(tree, tf)} returns
 */
public static Tree normalizeTree(Tree tree, TreeNormalizer tn, TreeFactory tf) {
  for (Tree node : tree) {
    String normalized =
        node.isLeaf()
            ? tn.normalizeTerminal(node.label().value())
            : tn.normalizeNonterminal(node.label().value());
    node.label().setValue(normalized);
  }
  return tn.normalizeWholeTree(tree, tf);
}
/** * returns a list of categories that is the path from Tree from to Tree to within Tree root. If * either from or to is not in root, returns null. Otherwise includes both from and to in the * list. */ public static List<String> pathNodeToNode(Tree from, Tree to, Tree root) { List<Tree> fromPath = pathFromRoot(from, root); // System.out.println(treeListToCatList(fromPath)); if (fromPath == null) return null; List<Tree> toPath = pathFromRoot(to, root); // System.out.println(treeListToCatList(toPath)); if (toPath == null) return null; // System.out.println(treeListToCatList(fromPath)); // System.out.println(treeListToCatList(toPath)); int last = 0; int min = fromPath.size() <= toPath.size() ? fromPath.size() : toPath.size(); Tree lastNode = null; // while((! (fromPath.isEmpty() || toPath.isEmpty())) && // fromPath.get(0).equals(toPath.get(0))) { // lastNode = (Tree) fromPath.remove(0); // toPath.remove(0); // } while (last < min && fromPath.get(last).equals(toPath.get(last))) { lastNode = fromPath.get(last); last++; } // System.out.println(treeListToCatList(fromPath)); // System.out.println(treeListToCatList(toPath)); List<String> totalPath = new ArrayList<String>(); for (int i = fromPath.size() - 1; i >= last; i--) { Tree t = fromPath.get(i); totalPath.add("up-" + t.label().value()); } if (lastNode != null) totalPath.add("up-" + lastNode.label().value()); for (Tree t : toPath) totalPath.add("down-" + t.label().value()); // for(ListIterator i = fromPath.listIterator(fromPath.size()); i.hasPrevious(); ){ // Tree t = (Tree) i.previous(); // totalPath.add("up-" + t.label().value()); // } // if(lastNode != null) // totalPath.add("up-" + lastNode.label().value()); // for(ListIterator j = toPath.listIterator(); j.hasNext(); ){ // Tree t = (Tree) j.next(); // totalPath.add("down-" + t.label().value()); // } return totalPath; }
List<Tree> prune(List<Tree> treeList, Label label, int start, int end) { // get reference tree if (treeList.size() == 1) { return treeList; } Tree testTree = treeList.get(0).treeFactory().newTreeNode(label, treeList); int goal = Numberer.getGlobalNumberer("states").number(label.value()); Tree tempTree = parser.extractBestParse(goal, start, end); // parser.restoreUnaries(tempTree); Tree pcfgTree = debinarizer.transformTree(tempTree); Set<Constituent> pcfgConstituents = pcfgTree.constituents(new LabeledScoredConstituentFactory()); // delete child labels that are not in reference but do not cross reference List<Tree> prunedChildren = new ArrayList<Tree>(); int childStart = 0; for (int c = 0, numCh = testTree.numChildren(); c < numCh; c++) { Tree child = testTree.getChild(c); boolean isExtra = true; int childEnd = childStart + child.yield().size(); Constituent childConstituent = new LabeledScoredConstituent(childStart, childEnd, child.label(), 0); if (pcfgConstituents.contains(childConstituent)) { isExtra = false; } if (childConstituent.crosses(pcfgConstituents)) { isExtra = false; } if (child.isLeaf() || child.isPreTerminal()) { isExtra = false; } if (pcfgTree.yield().size() != testTree.yield().size()) { isExtra = false; } if (!label.value().startsWith("NP^NP")) { isExtra = false; } if (isExtra) { System.err.println( "Pruning: " + child.label() + " from " + (childStart + start) + " to " + (childEnd + start)); System.err.println("Was: " + testTree + " vs " + pcfgTree); prunedChildren.addAll(child.getChildrenAsList()); } else { prunedChildren.add(child); } childStart = childEnd; } return prunedChildren; }
/**
 * Returns a deep copy of {@code tree} in which every node of depth two or more has the suffix
 * {@code "-t3"} appended to its label.
 *
 * <p>Nodes whose {@code depth()} is below 2 are left as copied. The new label is installed with
 * {@code setFromString} on the string form of the existing label.
 *
 * @param tree The tree to annotate; it is not modified
 * @return The annotated copy
 */
public Tree transformTree(Tree tree) {
  Tree treeCopy = tree.deepCopy();
  for (Tree subtree : treeCopy) {
    if (subtree.depth() >= 2) {
      Label label = subtree.label();
      label.setFromString(label.toString() + "-t3");
    }
  }
  return treeCopy;
}
/**
 * Converts a local tree into a grammar rule over numbered states.
 *
 * <p>A node with exactly one child becomes a {@code UnaryRule}; any other node is treated as
 * binary and its first two children are used. The parent label is numbered before the child
 * labels.
 *
 * @param lt The local tree (a parent and its immediate children)
 * @return A {@code UnaryRule} or {@code BinaryRule} describing the rewrite
 */
protected Rule ltToRule(Tree lt) {
  Tree[] kids = lt.children();
  int parentState = stateNumberer.number(lt.label().value());

  if (kids.length == 1) {
    UnaryRule unary = new UnaryRule();
    unary.parent = parentState;
    unary.child = stateNumberer.number(kids[0].label().value());
    return unary;
  }

  BinaryRule binary = new BinaryRule();
  binary.parent = parentState;
  binary.leftChild = stateNumberer.number(kids[0].label().value());
  binary.rightChild = stateNumberer.number(kids[1].label().value());
  return binary;
}
/**
 * Recursively rebuilds a tree, reducing each non-leaf label to its basic category with
 * grammatical-function marking stripped.
 *
 * <p>Leaves are recreated through {@code tf.newLeaf} with their existing label. For every other
 * node a {@code CategoryWordTag} label is created from the old label; its category is the old
 * value after {@code basicCategory} and {@code stripGF}, and when the old label implements
 * {@code HasTag} its tag is reduced the same way. Scores are carried over to the new nodes.
 *
 * @param tree The tree to transform
 * @return A newly built tree with simplified labels
 */
public Tree transformTree(Tree tree) {
  Label lab = tree.label();
  if (tree.isLeaf()) {
    Tree leaf = tf.newLeaf(lab);
    leaf.setScore(tree.score());
    return leaf;
  }

  String category =
      treebankLanguagePack().stripGF(treebankLanguagePack().basicCategory(lab.value()));

  List<Tree> children = new ArrayList<Tree>(tree.numChildren());
  for (Tree child : tree.children()) {
    children.add(transformTree(child));
  }

  CategoryWordTag newLabel = new CategoryWordTag(lab);
  newLabel.setCategory(category);
  if (lab instanceof HasTag) {
    String tag = ((HasTag) lab).tag();
    newLabel.setTag(treebankLanguagePack().stripGF(treebankLanguagePack().basicCategory(tag)));
  }

  Tree node = tf.newTreeNode(newLabel, children);
  node.setScore(tree.score());
  return node;
}
public Tree transformTree(Tree tree) { Label lab = tree.label(); if (tree.isLeaf()) { Tree leaf = tf.newLeaf(lab); leaf.setScore(tree.score()); return leaf; } String s = lab.value(); s = treebankLanguagePack().basicCategory(s); int numKids = tree.numChildren(); List<Tree> children = new ArrayList<Tree>(numKids); for (int cNum = 0; cNum < numKids; cNum++) { Tree child = tree.getChild(cNum); Tree newChild = transformTree(child); // cdm 2007: for just subcategory stripping, null shouldn't happen // if (newChild != null) { children.add(newChild); // } } // if (children.isEmpty()) { // return null; // } CategoryWordTag newLabel = new CategoryWordTag(lab); newLabel.setCategory(s); if (lab instanceof HasTag) { String tag = ((HasTag) lab).tag(); tag = treebankLanguagePack().basicCategory(tag); newLabel.setTag(tag); } Tree node = tf.newTreeNode(newLabel, children); node.setScore(tree.score()); return node; }
/**
 * Returns the categories of a local tree as a list: the label value of the mother first,
 * followed by the label values of her daughters in order.
 *
 * @param t The local tree
 * @return A new list holding the mother's category and then each daughter's category
 */
public static List<String> localTreeAsCatList(Tree t) {
  Tree[] daughters = t.children();
  List<String> categories = new ArrayList<String>(daughters.length + 1);
  categories.add(t.label().value());
  for (Tree daughter : daughters) {
    categories.add(daughter.label().value());
  }
  return categories;
}
/**
 * Recursively copies a tree while recording the correspondence between original and copied
 * nodes in both directions.
 *
 * <p>Children are copied first. A node without children is recreated with
 * {@code TreeFactory.newLeaf}, any other node with {@code TreeFactory.newTreeNode}; in both cases
 * the label object of the original is passed to the factory. For every node, the pair is
 * registered as {@code newToOld[copy] = original} and {@code oldToNew[original] = copy}.
 *
 * <p>Leaves used to be stored in {@code oldToNew} with key and value reversed (copy to
 * original), so looking up the copy of an original leaf failed; both branches now register the
 * mapping the same way.
 *
 * @param t The tree to copy
 * @param newToOld Receives a mapping from each copied node to its original
 * @param oldToNew Receives a mapping from each original node to its copy
 * @return The copy of {@code t}
 */
public static Tree copyHelper(Tree t, Map<Tree, Tree> newToOld, Map<Tree, Tree> oldToNew) {
  Tree[] kids = t.children();
  Tree[] newKids = new Tree[kids.length];
  for (int i = 0, n = kids.length; i < n; i++) {
    newKids[i] = copyHelper(kids[i], newToOld, oldToNew);
  }

  TreeFactory tf = t.treeFactory();
  Tree copy =
      (kids.length == 0)
          ? tf.newLeaf(t.label())
          : tf.newTreeNode(t.label(), Arrays.asList(newKids));

  newToOld.put(copy, t);
  oldToNew.put(t, copy);
  return copy;
}
/**
 * Appends the labels of all leaves below {@code t} to {@code l}, in left-to-right order.
 *
 * @param t The tree to walk
 * @param l The list that receives the leaf labels
 */
private static void leafLabels(Tree t, List<Label> l) {
  if (t.isLeaf()) {
    l.add(t.label());
    return;
  }
  for (Tree kid : t.children()) {
    leafLabels(kid, l);
  }
}
/** * This is the method to call for assigning labels and node vectors to the Tree. After calling * this, each of the non-leaf nodes will have the node vector and the predictions of their classes * assigned to that subtree's node. */ public void forwardPropagateTree(Tree tree) { FloatMatrix nodeVector; FloatMatrix classification; if (tree.isLeaf()) { // We do nothing for the leaves. The preterminals will // calculate the classification for this word/tag. In fact, the // recursion should not have gotten here (unless there are // degenerate trees of just one leaf) throw new AssertionError("We should not have reached leaves in forwardPropagate"); } else if (tree.isPreTerminal()) { classification = getUnaryClassification(tree.label()); String word = tree.children().get(0).value(); FloatMatrix wordVector = getFeatureVector(word); if (wordVector == null) { wordVector = featureVectors.get(UNKNOWN_FEATURE); } nodeVector = activationFunction.apply(wordVector); } else if (tree.children().size() == 1) { throw new AssertionError( "Non-preterminal nodes of size 1 should have already been collapsed"); } else if (tree.children().size() == 2) { Tree left = tree.firstChild(), right = tree.lastChild(); forwardPropagateTree(left); forwardPropagateTree(right); String leftCategory = tree.children().get(0).label(); String rightCategory = tree.children().get(1).label(); FloatMatrix W = getBinaryTransform(leftCategory, rightCategory); classification = getBinaryClassification(leftCategory, rightCategory); FloatMatrix leftVector = tree.children().get(0).vector(); FloatMatrix rightVector = tree.children().get(1).vector(); FloatMatrix childrenVector = appendBias(leftVector, rightVector); if (useFloatTensors) { FloatTensor floatT = getBinaryFloatTensor(leftCategory, rightCategory); FloatMatrix floatTensorIn = FloatMatrix.concatHorizontally(leftVector, rightVector); FloatMatrix floatTensorOut = floatT.bilinearProducts(floatTensorIn); nodeVector = 
activationFunction.apply(W.mmul(childrenVector).add(floatTensorOut)); } else nodeVector = activationFunction.apply(W.mmul(childrenVector)); } else { throw new AssertionError("Tree not correctly binarized"); } FloatMatrix inputWithBias = appendBias(nodeVector); FloatMatrix preAct = classification.mmul(inputWithBias); FloatMatrix predictions = outputActivation.apply(preAct); tree.setPrediction(predictions); tree.setVector(nodeVector); }
/** * Load a collection of parse trees from a Reader. Each tree may optionally be encased in parens * to allow for Penn Treebank style trees. * * @param r The reader to read trees from. (If you want it buffered, you should already have * buffered it!) * @param id An ID for where these files come from (arbitrary, but something like a filename. Can * be <code>null</code> for none. */ public void load(Reader r, String id) { try { // could throw an IO exception? TreeReader tr = treeReaderFactory().newTreeReader(r); int sentIndex = 0; for (Tree pt; (pt = tr.readTree()) != null; ) { if (pt.label() instanceof HasIndex) { // so we can trace where this tree came from HasIndex hi = (HasIndex) pt.label(); if (id != null) { hi.setDocID(id); } hi.setSentIndex(sentIndex); } parseTrees.add(pt); sentIndex++; } } catch (IOException e) { System.err.println("load IO Exception: " + e); } }
/**
 * Restructures {@code t} into a three-way coordination: a new node labelled {@code "NP"} over
 * {@code left}, the conjunction, and a new node labelled {@code "NP"} over {@code right}.
 *
 * <p>The new nodes and labels are created with the tree factory and label factory of {@code t},
 * and the children of {@code t} are replaced in place.
 *
 * @param t The node whose children are replaced
 * @param left The trees grouped under the first new node
 * @param conj The conjunction placed between the two new nodes
 * @param right The trees grouped under the second new node
 */
private static void transformCC(Tree t, List<Tree> left, Tree conj, List<Tree> right) {
  TreeFactory tf = t.treeFactory();
  LabelFactory lf = t.label().labelFactory();

  Tree leftConjunct = tf.newTreeNode(lf.newLabel("NP"), left);
  Tree rightConjunct = tf.newTreeNode(lf.newLabel("NP"), right);

  List<Tree> regrouped = new ArrayList<Tree>();
  regrouped.add(leftConjunct);
  regrouped.add(conj);
  regrouped.add(rightConjunct);
  t.setChildren(regrouped);
}
/**
 * Collects the labels of the leaves below {@code t}, annotating each one with the label of its
 * preterminal parent.
 *
 * <p>For every preterminal reached, the label of its first child is cast to {@code CoreLabel},
 * the preterminal's label is stored on it under {@code TagLabelAnnotation}, and the leaf label
 * is appended to {@code l}. The walk does not descend below a preterminal.
 *
 * @param t The tree to walk
 * @param l The list that receives the annotated leaf labels
 */
private static void taggedLeafLabels(Tree t, List<CoreLabel> l) {
  if (!t.isPreTerminal()) {
    for (Tree kid : t.children()) {
      taggedLeafLabels(kid, l);
    }
    return;
  }
  CoreLabel leafLabel = (CoreLabel) t.getChild(0).label();
  leafLabel.set(TagLabelAnnotation.class, t.label());
  l.add(leafLabel);
}
protected void tallyTree(Tree t, LinkedList<String> parents) { // traverse tree, building parent list String str = t.label().value(); boolean strIsPassive = (str.indexOf('@') == -1); if (strIsPassive) { parents.addFirst(str); } if (!t.isLeaf()) { if (!t.children()[0].isLeaf()) { tallyInternalNode(t, parents); for (int c = 0; c < t.children().length; c++) { Tree child = t.children()[c]; tallyTree(child, parents); } } else { tagNumberer.number(t.label().value()); } } if (strIsPassive) { parents.removeFirst(); } }
/** * Determine which daughter of the current parse tree is the head. * * @param t The parse tree to examine the daughters of. If this is a leaf, <code>null</code> is * returned * @param parent The parent of t * @return The daughter parse tree that is the head of <code>t</code>. Returns null for leaf * nodes. * @see Tree#percolateHeads(HeadFinder) for a routine to call this and spread heads throughout a * tree */ public Tree determineHead(Tree t, Tree parent) { if (nonTerminalInfo == null) { throw new RuntimeException( "Classes derived from AbstractCollinsHeadFinder must" + " create and fill HashMap nonTerminalInfo."); } if (DEBUG) { System.err.println("determineHead for " + t.value()); } if (t.isLeaf()) { return null; } Tree[] kids = t.children(); Tree theHead; // first check if subclass found explicitly marked head if ((theHead = findMarkedHead(t)) != null) { if (DEBUG) { System.err.println( "Find marked head method returned " + theHead.label() + " as head of " + t.label()); } return theHead; } // if the node is a unary, then that kid must be the head // it used to special case preterminal and ROOT/TOP case // but that seemed bad (especially hardcoding string "ROOT") if (kids.length == 1) { if (DEBUG) { System.err.println( "Only one child determines " + kids[0].label() + " as head of " + t.label()); } return kids[0]; } return determineNonTrivialHead(t, parent); }
protected void tallyInternalNode(Tree lt, List parents) { // form base rule String label = lt.label().value(); Rule baseR = ltToRule(lt); ruleToLabel.put(baseR, label); // act on each history depth for (int depth = 0, maxDepth = Math.min(HISTORY_DEPTH(), parents.size()); depth <= maxDepth; depth++) { List history = new ArrayList(parents.subList(0, depth)); // tally each history level / rewrite pair rulePairs.incrementCount(new Pair(baseR, history), 1); labelPairs.incrementCount(new Pair(label, history), 1); } }
/* Checks whether the tree t is an existential constituent * There are two cases: * -- affirmative sentences in which "there" is a left sister of the VP * -- questions in which "there" is a daughter of the SQ. * */ private boolean isExistential(Tree t, Tree parent) { if (DEBUG) { System.err.println("isExistential: " + t + ' ' + parent); } boolean toReturn = false; String motherCat = tlp.basicCategory(t.label().value()); // affirmative case if (motherCat.equals("VP") && parent != null) { // take t and the sisters Tree[] kids = parent.children(); // iterate over the sisters before t and checks if existential for (Tree kid : kids) { if (!kid.value().equals("VP")) { List<Label> tags = kid.preTerminalYield(); for (Label tag : tags) { if (tag.value().equals("EX")) { toReturn = true; } } } else { break; } } } // question case else if (motherCat.startsWith("SQ") && parent != null) { // take the daughters Tree[] kids = parent.children(); // iterate over the daughters and checks if existential for (Tree kid : kids) { if (!kid.value().startsWith("VB")) { // not necessary to look into the verb List<Label> tags = kid.preTerminalYield(); for (Label tag : tags) { if (tag.value().equals("EX")) { toReturn = true; } } } } } if (DEBUG) { System.err.println("decision " + toReturn); } return toReturn; }
/**
 * Takes a Tree and a collinizer and returns a Collection of {@link Constituent}s for PARSEVAL
 * evaluation. Some notes on this particular parseval:
 *
 * <ul>
 *   <li>It is character-based, which allows it to be used on segmentation/parsing combination
 *       evaluation.
 *   <li>Whether it gives you labeled or unlabeled bracketings depends on the value of the
 *       <code>labelConstituents</code> parameter.
 * </ul>
 *
 * <p>Leaves, preterminals and nodes for which no parent can be found in the transformed tree
 * (other than its root) do not produce a bracket. If the collinizer returns null, the result is
 * empty.
 *
 * <p>(Note that this has not been checked rigorously against the PARSEVAL definition -- Roger.)
 *
 * @param t The tree to convert
 * @param collinizer The transformer applied to {@code t} before brackets are extracted
 * @param labelConstituents Whether the brackets carry the node label
 * @return The brackets of the transformed tree, with character-based edges
 */
public static Collection<Constituent> parsevalObjectify(
    Tree t, TreeTransformer collinizer, boolean labelConstituents) {
  Collection<Constituent> spans = new ArrayList<Constituent>();
  Tree t1 = collinizer.transformTree(t);
  if (t1 == null) {
    return spans;
  }

  for (Tree node : t1) {
    boolean yieldsBracket =
        !node.isLeaf()
            && !node.isPreTerminal()
            && (node == t1 || node.parent(t1) != null);
    if (yieldsBracket) {
      int leftEdge = t1.leftCharEdge(node);
      int rightEdge = t1.rightCharEdge(node);
      Constituent span;
      if (labelConstituents) {
        span = new LabeledConstituent(leftEdge, rightEdge, node.label());
      } else {
        span = new SimpleConstituent(leftEdge, rightEdge);
      }
      spans.add(span);
    }
  }
  return spans;
}
/**
 * This looks to see whether any of the children is a preterminal headed by a word which is within
 * the set verbalSet (which in practice is either auxiliary or copula verbs). It only returns true
 * if it's a preterminal head, since you don't want to pick things up in phrasal daughters. That
 * is an error.
 *
 * <p>The tag is taken from the child's label when it implements {@code HasTag}, otherwise from
 * the child's value; the word is taken from the leaf label when it implements {@code HasWord},
 * otherwise from the leaf label's value. The word is lower-cased with {@code Locale.ROOT} before
 * the lookup, so that the match does not depend on the JVM's default locale (for example, under
 * a Turkish locale "IS" would otherwise lower-case to a dotless-i form and be missed).
 *
 * @param kids The child trees
 * @param verbalSet The set of words, expected in lower case
 * @return Returns true if one of the child trees is a preterminal verb headed by a word in
 *     verbalSet
 */
private boolean hasVerbalAuxiliary(Tree[] kids, HashSet<String> verbalSet) {
  if (DEBUG) {
    System.err.println("Checking for verbal auxiliary");
  }
  for (Tree kid : kids) {
    if (DEBUG) {
      System.err.println(" checking in " + kid);
    }
    if (kid.isPreTerminal()) {
      Label kidLabel = kid.label();
      String tag = null;
      if (kidLabel instanceof HasTag) {
        tag = ((HasTag) kidLabel).tag();
      }
      if (tag == null) {
        tag = kid.value();
      }

      Label wordLabel = kid.firstChild().label();
      String word = null;
      if (wordLabel instanceof HasWord) {
        word = ((HasWord) wordLabel).word();
      }
      if (word == null) {
        word = wordLabel.value();
      }

      if (DEBUG) {
        System.err.println("Checking " + kid.value() + " head is " + word + '/' + tag);
      }
      // Locale-independent case folding; fully qualified so no extra import is needed.
      String lcWord = word.toLowerCase(java.util.Locale.ROOT);
      if (verbalTags.contains(tag) && verbalSet.contains(lcWord)) {
        if (DEBUG) {
          System.err.println("hasVerbalAuxiliary returns true");
        }
        return true;
      }
    }
  }
  if (DEBUG) {
    System.err.println("hasVerbalAuxiliary returns false");
  }
  return false;
}
private static void transformQP(Tree t) { List<Tree> children = t.getChildrenAsList(); TreeFactory tf = t.treeFactory(); LabelFactory lf = t.label().labelFactory(); // create the new XS having the first two children of the QP Tree left = tf.newTreeNode(lf.newLabel("XS"), null); for (int i = 0; i < 2; i++) { left.addChild(children.get(i)); } // remove all the two first children of t before for (int i = 0; i < 2; i++) { t.removeChild(0); } // add XS as the first child t.addChild(0, left); }
/**
 * Rebuilds a tree so that every word is split into one preterminal per character.
 *
 * <p>A preterminal {@code (TAG word)} becomes {@code (TAG (TAG_x c1) (TAG_x c2) ...)}, where the
 * suffix encodes the position of the character in the word. With {@code useTwoCharTags} the
 * first character gets {@code _S} and all others {@code _M}; otherwise a single-character word
 * gets {@code _S}, and longer words use {@code _B} for the first, {@code _E} for the last and
 * {@code _M} for the characters in between. All other nodes are recreated with the same label
 * value over their transformed children.
 *
 * <p>Characters are Unicode code points, so a supplementary-plane character (encoded as a
 * surrogate pair) yields a single leaf rather than two leaves holding half a character each.
 * Words consisting only of BMP characters are split exactly as before.
 *
 * @param tree The tree to transform
 * @return A newly built character-level tree
 */
public Tree transformTree(Tree tree) {
  TreeFactory tf = tree.treeFactory();
  String tag = tree.label().value();

  if (tree.isPreTerminal()) {
    String word = tree.firstChild().label().value();
    int numChars = word.codePointCount(0, word.length());
    List<Tree> newPreterms = new ArrayList<>(numChars);

    int charIndex = 0;
    for (int offset = 0; offset < word.length(); charIndex++) {
      int codePoint = word.codePointAt(offset);
      offset += Character.charCount(codePoint);

      String singleCharLabel = new String(Character.toChars(codePoint));
      Tree newLeaf = tf.newLeaf(singleCharLabel);

      String suffix;
      if (useTwoCharTags) {
        suffix = (numChars == 1 || charIndex == 0) ? "_S" : "_M";
      } else if (numChars == 1) {
        suffix = "_S";
      } else if (charIndex == 0) {
        suffix = "_B";
      } else if (charIndex == numChars - 1) {
        suffix = "_E";
      } else {
        suffix = "_M";
      }
      newPreterms.add(tf.newTreeNode(tag + suffix, Collections.<Tree>singletonList(newLeaf)));
    }
    return tf.newTreeNode(tag, newPreterms);
  }

  List<Tree> newChildren = new ArrayList<>();
  for (Tree child : tree.children()) {
    newChildren.add(transformTree(child));
  }
  return tf.newTreeNode(tag, newChildren);
}
/**
 * Emits the LaTeX node declaration for {@code t} and, recursively, for its descendants.
 *
 * <p>Node declarations go to {@code h}: a newline, one space per indent level, then
 * {@code {\tnode{zN}{label}} for a leaf or {@code {\ntnode{zN}{label}} for any other node, with
 * the children's declarations nested inside, each preceded by {@code ", "}. For every child a
 * {@code \nodeconnect{zParent}{zChild}} line is appended to {@code c}.
 *
 * @param t The subtree to render
 * @param c Receives the node-connection commands
 * @param h Receives the nested node declarations
 * @param n The number identifying {@code t}
 * @param nextN The next unused node number
 * @param indent The nesting level used for indentation
 * @return The next unused node number after rendering this subtree
 */
private static int treeToLatexHelper(
    Tree t, StringBuilder c, StringBuilder h, int n, int nextN, int indent) {
  h.append('\n');
  for (int level = 0; level < indent; level++) {
    h.append(" ");
  }

  String macro = t.isLeaf() ? "tnode" : "ntnode";
  h.append("{\\").append(macro).append("{z").append(n).append("}{").append(t.label()).append('}');

  if (!t.isLeaf()) {
    for (Tree child : t.children()) {
      h.append(", ");
      c.append("\\nodeconnect{z");
      c.append(n);
      c.append("}{z");
      c.append(nextN);
      c.append("}\n");
      // The child takes number nextN; numbering continues after its whole subtree.
      nextN = treeToLatexHelper(child, c, h, nextN, nextN + 1, indent + 1);
    }
  }
  h.append('}');
  return nextN;
}
/**
 * Emits the LaTeX node declarations for {@code t} and its descendants, padding shallow leaves
 * so that all leaves end up at the same depth.
 *
 * <p>Works like the plain helper, except that every node is declared with {@code \ntnode} and a
 * node of depth 0 that sits above {@code maxDepth} is wrapped in
 * {@code maxDepth - curDepth} empty {@code {\ntnode{pad}{}, ...}} declarations, which are closed
 * again after the node itself.
 *
 * @param t The subtree to render
 * @param c Receives the node-connection commands
 * @param h Receives the nested node declarations
 * @param n The number identifying {@code t}
 * @param nextN The next unused node number
 * @param indent The nesting level used for indentation
 * @param curDepth The depth of {@code t} below the root of the rendered tree
 * @param maxDepth The depth at which all leaves should be placed
 * @return The next unused node number after rendering this subtree
 */
private static int treeToLatexEvenHelper(
    Tree t,
    StringBuilder c,
    StringBuilder h,
    int n,
    int nextN,
    int indent,
    int curDepth,
    int maxDepth) {
  h.append('\n');
  for (int level = 0; level < indent; level++) {
    h.append(" ");
  }

  // Only depth-0 nodes that are not yet at maxDepth are pushed down with pad nodes.
  int tDepth = t.depth();
  int padCount = (tDepth == 0 && curDepth < maxDepth) ? maxDepth - curDepth : 0;

  for (int pad = 0; pad < padCount; pad++) {
    h.append("{\\ntnode{pad}{}, ");
  }

  h.append("{\\ntnode{z").append(n).append("}{").append(t.label()).append('}');

  if (!t.isLeaf()) {
    for (Tree child : t.children()) {
      h.append(", ");
      c.append("\\nodeconnect{z");
      c.append(n);
      c.append("}{z");
      c.append(nextN);
      c.append("}\n");
      nextN =
          treeToLatexEvenHelper(
              child, c, h, nextN, nextN + 1, indent + 1, curDepth + 1, maxDepth);
    }
  }

  // Close the pad wrappers opened above, then the node itself.
  for (int pad = 0; pad < padCount; pad++) {
    h.append('}');
  }
  h.append('}');
  return nextN;
}
/**
 * Tells whether one of the immediate children of {@code t} is a preterminal tagged VBN, VBG or
 * VBD.
 *
 * <p>The tag is read from the child's label when it implements {@code HasTag}; if that gives
 * null, or the label has no tag, the child's value is used instead. Phrasal children are not
 * searched.
 *
 * @param t The node whose children are inspected
 * @return true as soon as a matching preterminal child is found, false otherwise
 */
private static boolean vpContainsParticiple(Tree t) {
  for (Tree kid : t.children()) {
    if (DEBUG) {
      System.err.println("vpContainsParticiple examining " + kid);
    }
    if (!kid.isPreTerminal()) {
      continue;
    }

    Label kidLabel = kid.label();
    String tag = (kidLabel instanceof HasTag) ? ((HasTag) kidLabel).tag() : null;
    if (tag == null) {
      tag = kid.value();
    }

    boolean isParticipleTag = "VBN".equals(tag) || "VBG".equals(tag) || "VBD".equals(tag);
    if (isParticipleTag) {
      if (DEBUG) {
        System.err.println("vpContainsParticiple found VBN/VBG/VBD VP");
      }
      return true;
    }
  }
  return false;
}
/** * Determine which daughter of the current parse tree is the head. It assumes that the daughters * already have had their heads determined. Uses special rule for VPheads * * @param t The parse tree to examine the daughters of. This is assumed to never be a leaf * @return The parse tree that is the head */ protected Tree determineNonTrivialHead(Tree t, Tree parent) { String motherCat = tlp.basicCategory(t.label().value()); if (DEBUG) { System.err.println("My parent is " + parent); } // do VPs with auxiliary as special case if (motherCat.equals("VP") || motherCat.equals("SQ") || motherCat.equals("SINV")) { Tree[] kids = t.children(); // try to find if there is an auxiliary verb if (DEBUG) { System.err.println("Semantic head finder: at VP"); System.err.println("Class is " + t.getClass().getName()); t.pennPrint(System.err); System.err.println("hasVerbalAuxiliary = " + hasVerbalAuxiliary(kids, verbalAuxiliaries)); } // looks for auxiliaries if (hasVerbalAuxiliary(kids, verbalAuxiliaries)) { // String[] how = new String[] {"left", "VP", "ADJP", "NP"}; // Including NP etc seems okay for copular sentences but is // problematic for other auxiliaries, like 'he has an answer' // But maybe doing ADJP is fine! 
String[] how = new String[] {"left", "VP", "ADJP"}; Tree pti = traverseLocate(kids, how, false); if (DEBUG) { System.err.println("Determined head is: " + pti); } if (pti != null) { return pti; } else { // System.err.println("------"); // System.err.println("SemanticHeadFinder failed to reassign head for"); // t.pennPrint(System.err); // System.err.println("------"); } } // looks for copular verbs if (hasVerbalAuxiliary(kids, copulars) && !isExistential(t, parent) && !isWHQ(t, parent)) { String[] how; if (motherCat.equals("SQ")) { how = new String[] {"right", "VP", "ADJP", "NP", "WHADJP", "WHNP"}; } else { how = new String[] {"left", "VP", "ADJP", "NP", "WHADJP", "WHNP"}; } Tree pti = traverseLocate(kids, how, false); if (DEBUG) { System.err.println("Determined head is: " + pti); } if (pti != null) { return pti; } else { if (DEBUG) { System.err.println("------"); System.err.println("SemanticHeadFinder failed to reassign head for"); t.pennPrint(System.err); System.err.println("------"); } } } } return super.determineNonTrivialHead(t, parent); }
private void backpropDerivativesAndError( Tree tree, MultiDimensionalMap<String, String, FloatMatrix> binaryTD, MultiDimensionalMap<String, String, FloatMatrix> binaryCD, MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD, Map<String, FloatMatrix> unaryCD, Map<String, FloatMatrix> wordVectorD, FloatMatrix deltaUp) { if (tree.isLeaf()) { return; } FloatMatrix currentVector = tree.vector(); String category = tree.label(); category = basicCategory(category); // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class FloatMatrix goldLabel = new FloatMatrix(numOuts, 1); int goldClass = tree.goldLabel(); if (goldClass >= 0) { goldLabel.put(goldClass, 1.0f); } Float nodeWeight = classWeights.get(goldClass); if (nodeWeight == null) nodeWeight = 1.0f; FloatMatrix predictions = tree.prediction(); // If this is an unlabeled class, set deltaClass to 0. We could // make this more efficient by eliminating various of the below // calculations, but this would be the easiest way to handle the // unlabeled class FloatMatrix deltaClass = goldClass >= 0 ? 
SimpleBlas.scal(nodeWeight, predictions.sub(goldLabel)) : new FloatMatrix(predictions.rows, predictions.columns); FloatMatrix localCD = deltaClass.mmul(appendBias(currentVector).transpose()); float error = -(MatrixFunctions.log(predictions).muli(goldLabel).sum()); error = error * nodeWeight; tree.setError(error); if (tree.isPreTerminal()) { // below us is a word vector unaryCD.put(category, unaryCD.get(category).add(localCD)); String word = tree.children().get(0).label(); word = getVocabWord(word); FloatMatrix currentVectorDerivative = activationFunction.apply(currentVector); FloatMatrix deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass); deltaFromClass = deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative); FloatMatrix deltaFull = deltaFromClass.add(deltaUp); wordVectorD.put(word, wordVectorD.get(word).add(deltaFull)); } else { // Otherwise, this must be a binary node String leftCategory = basicCategory(tree.children().get(0).label()); String rightCategory = basicCategory(tree.children().get(1).label()); if (combineClassification) { unaryCD.put("", unaryCD.get("").add(localCD)); } else { binaryCD.put( leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD)); } FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector); FloatMatrix deltaFromClass = getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass); FloatMatrix mult = deltaFromClass.get(interval(0, numHidden), interval(0, 1)); deltaFromClass = mult.muli(currentVectorDerivative); FloatMatrix deltaFull = deltaFromClass.add(deltaUp); FloatMatrix leftVector = tree.children().get(0).vector(); FloatMatrix rightVector = tree.children().get(1).vector(); FloatMatrix childrenVector = appendBias(leftVector, rightVector); // deltaFull 50 x 1, childrenVector: 50 x 2 FloatMatrix add = binaryTD.get(leftCategory, rightCategory); FloatMatrix W_df = 
deltaFromClass.mmul(childrenVector.transpose()); binaryTD.put(leftCategory, rightCategory, add.add(W_df)); FloatMatrix deltaDown; if (useFloatTensors) { FloatTensor Wt_df = getFloatTensorGradient(deltaFull, leftVector, rightVector); binaryFloatTensorTD.put( leftCategory, rightCategory, binaryFloatTensorTD.get(leftCategory, rightCategory).add(Wt_df)); deltaDown = computeFloatTensorDeltaDown( deltaFull, leftVector, rightVector, getBinaryTransform(leftCategory, rightCategory), getBinaryFloatTensor(leftCategory, rightCategory)); } else { deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull); } FloatMatrix leftDerivative = activationFunction.apply(leftVector); FloatMatrix rightDerivative = activationFunction.apply(rightVector); FloatMatrix leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows), interval(0, 1)); FloatMatrix rightDeltaDown = deltaDown.get(interval(deltaFull.rows, deltaFull.rows * 2), interval(0, 1)); backpropDerivativesAndError( tree.children().get(0), binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD, leftDerivative.mul(leftDeltaDown)); backpropDerivativesAndError( tree.children().get(1), binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD, rightDerivative.mul(rightDeltaDown)); } }