/**
 * Summarizes which runtime classes make up a tree, for debugging.
 *
 * <p>The summary names the short class names of the root node, its tree factory, its label and
 * that label's factory. While iterating over the tree, any class name that differs from the
 * corresponding root-level name is collected and appended to the summary.
 *
 * @param t The tree to examine.
 * @return A human-readable String
 */
public static String toDebugStructureString(Tree t) {
  final String rootClass = StringUtils.getShortClassName(t);
  final String rootFactoryClass = StringUtils.getShortClassName(t.treeFactory());
  final String rootLabelClass = StringUtils.getShortClassName(t.label());
  final String rootLabelFactoryClass = StringUtils.getShortClassName(t.label().labelFactory());
  final String[] expected = {
    rootClass, rootFactoryClass, rootLabelClass, rootLabelFactoryClass
  };

  Set<String> strays = new HashSet<String>();
  for (Tree node : t) {
    String[] observed = {
      StringUtils.getShortClassName(node),
      StringUtils.getShortClassName(node.treeFactory()),
      StringUtils.getShortClassName(node.label()),
      StringUtils.getShortClassName(node.label().labelFactory())
    };
    // Compare tree / tree factory / label / label factory, in that order.
    for (int i = 0; i < expected.length; i++) {
      if (!expected[i].equals(observed[i])) {
        strays.add(observed[i]);
      }
    }
  }

  StringBuilder out = new StringBuilder("Tree with root of class ");
  out.append(rootClass).append(" and factory ").append(rootFactoryClass);
  out.append(" with label class ").append(rootLabelClass);
  out.append(" and factory ").append(rootLabelFactoryClass);
  if (!strays.isEmpty()) {
    out.append(" with the following classes also found within the tree: ").append(strays);
  }
  return out.toString();
}
Tree prune(Tree tree, int start) { if (tree.isLeaf() || tree.isPreTerminal()) { return tree; } // check each node's children for deletion List<Tree> children = helper(tree.getChildrenAsList(), start); children = prune(children, tree.label(), start, start + tree.yield().size()); return tree.treeFactory().newTreeNode(tree.label(), children); }
/**
 * Normalizes every node label of {@code tree} in place and then normalizes the tree as a whole.
 *
 * <p>Leaf label values go through {@code tn.normalizeTerminal}; all other label values go through
 * {@code tn.normalizeNonterminal}. The result of {@code tn.normalizeWholeTree} is returned.
 *
 * @param tree the tree whose labels are rewritten
 * @param tn the normalizer to apply
 * @param tf the factory handed to {@code normalizeWholeTree}
 * @return whatever {@code tn.normalizeWholeTree(tree, tf)} returns
 */
public static Tree normalizeTree(Tree tree, TreeNormalizer tn, TreeFactory tf) {
  for (Tree node : tree) {
    boolean terminal = node.isLeaf();
    String raw = node.label().value();
    String normalized = terminal ? tn.normalizeTerminal(raw) : tn.normalizeNonterminal(raw);
    node.label().setValue(normalized);
  }
  return tn.normalizeWholeTree(tree, tf);
}
/** * returns a list of categories that is the path from Tree from to Tree to within Tree root. If * either from or to is not in root, returns null. Otherwise includes both from and to in the * list. */ public static List<String> pathNodeToNode(Tree from, Tree to, Tree root) { List<Tree> fromPath = pathFromRoot(from, root); // System.out.println(treeListToCatList(fromPath)); if (fromPath == null) return null; List<Tree> toPath = pathFromRoot(to, root); // System.out.println(treeListToCatList(toPath)); if (toPath == null) return null; // System.out.println(treeListToCatList(fromPath)); // System.out.println(treeListToCatList(toPath)); int last = 0; int min = fromPath.size() <= toPath.size() ? fromPath.size() : toPath.size(); Tree lastNode = null; // while((! (fromPath.isEmpty() || toPath.isEmpty())) && // fromPath.get(0).equals(toPath.get(0))) { // lastNode = (Tree) fromPath.remove(0); // toPath.remove(0); // } while (last < min && fromPath.get(last).equals(toPath.get(last))) { lastNode = fromPath.get(last); last++; } // System.out.println(treeListToCatList(fromPath)); // System.out.println(treeListToCatList(toPath)); List<String> totalPath = new ArrayList<String>(); for (int i = fromPath.size() - 1; i >= last; i--) { Tree t = fromPath.get(i); totalPath.add("up-" + t.label().value()); } if (lastNode != null) totalPath.add("up-" + lastNode.label().value()); for (Tree t : toPath) totalPath.add("down-" + t.label().value()); // for(ListIterator i = fromPath.listIterator(fromPath.size()); i.hasPrevious(); ){ // Tree t = (Tree) i.previous(); // totalPath.add("up-" + t.label().value()); // } // if(lastNode != null) // totalPath.add("up-" + lastNode.label().value()); // for(ListIterator j = toPath.listIterator(); j.hasNext(); ){ // Tree t = (Tree) j.next(); // totalPath.add("down-" + t.label().value()); // } return totalPath; }
List<Tree> prune(List<Tree> treeList, Label label, int start, int end) { // get reference tree if (treeList.size() == 1) { return treeList; } Tree testTree = treeList.get(0).treeFactory().newTreeNode(label, treeList); int goal = Numberer.getGlobalNumberer("states").number(label.value()); Tree tempTree = parser.extractBestParse(goal, start, end); // parser.restoreUnaries(tempTree); Tree pcfgTree = debinarizer.transformTree(tempTree); Set<Constituent> pcfgConstituents = pcfgTree.constituents(new LabeledScoredConstituentFactory()); // delete child labels that are not in reference but do not cross reference List<Tree> prunedChildren = new ArrayList<Tree>(); int childStart = 0; for (int c = 0, numCh = testTree.numChildren(); c < numCh; c++) { Tree child = testTree.getChild(c); boolean isExtra = true; int childEnd = childStart + child.yield().size(); Constituent childConstituent = new LabeledScoredConstituent(childStart, childEnd, child.label(), 0); if (pcfgConstituents.contains(childConstituent)) { isExtra = false; } if (childConstituent.crosses(pcfgConstituents)) { isExtra = false; } if (child.isLeaf() || child.isPreTerminal()) { isExtra = false; } if (pcfgTree.yield().size() != testTree.yield().size()) { isExtra = false; } if (!label.value().startsWith("NP^NP")) { isExtra = false; } if (isExtra) { System.err.println( "Pruning: " + child.label() + " from " + (childStart + start) + " to " + (childEnd + start)); System.err.println("Was: " + testTree + " vs " + pcfgTree); prunedChildren.addAll(child.getChildrenAsList()); } else { prunedChildren.add(child); } childStart = childEnd; } return prunedChildren; }
/**
 * Converts a local tree into a rule over numbered states.
 *
 * <p>A node with exactly one child yields a {@code UnaryRule}; any other node is treated as
 * binary and yields a {@code BinaryRule} built from its first two children. Label values are
 * numbered through {@code stateNumberer}, parent first.
 *
 * @param lt the local tree (a node and its immediate children)
 * @return the corresponding unary or binary rule
 */
protected Rule ltToRule(Tree lt) {
  Tree[] kids = lt.children();
  if (kids.length == 1) {
    UnaryRule unary = new UnaryRule();
    unary.parent = stateNumberer.number(lt.label().value());
    unary.child = stateNumberer.number(kids[0].label().value());
    return unary;
  }
  BinaryRule binary = new BinaryRule();
  binary.parent = stateNumberer.number(lt.label().value());
  binary.leftChild = stateNumberer.number(kids[0].label().value());
  binary.rightChild = stateNumberer.number(kids[1].label().value());
  return binary;
}
/**
 * Lists the categories of a local tree: the label value of the mother first, followed by the
 * label value of each daughter in order.
 *
 * @param t the local tree
 * @return the mother's category followed by the daughters' categories
 */
public static List<String> localTreeAsCatList(Tree t) {
  Tree[] daughters = t.children();
  List<String> categories = new ArrayList<String>(daughters.length + 1);
  categories.add(t.label().value());
  for (Tree daughter : daughters) {
    categories.add(daughter.label().value());
  }
  return categories;
}
/**
 * Recursively copies {@code t}, recording the correspondence between original and copied nodes.
 *
 * <p>Children are copied first; the copy of each node is created through that node's own tree
 * factory and reuses the original label object. For every node, leaves included, {@code newToOld}
 * maps the copy to the original and {@code oldToNew} maps the original to the copy. (The leaf case
 * previously stored the pair in {@code oldToNew} the wrong way round, so leaves could not be
 * looked up by their original node.)
 *
 * @param t the subtree to copy
 * @param newToOld filled with copy -> original entries
 * @param oldToNew filled with original -> copy entries
 * @return the copy of {@code t}
 */
public static Tree copyHelper(Tree t, Map<Tree, Tree> newToOld, Map<Tree, Tree> oldToNew) {
  Tree[] kids = t.children();
  Tree[] newKids = new Tree[kids.length];
  for (int i = 0, n = kids.length; i < n; i++) {
    newKids[i] = copyHelper(kids[i], newToOld, oldToNew);
  }
  TreeFactory tf = t.treeFactory();
  Tree copy =
      (kids.length == 0)
          ? tf.newLeaf(t.label())
          : tf.newTreeNode(t.label(), Arrays.asList(newKids));
  newToOld.put(copy, t);
  oldToNew.put(t, copy);
  return copy;
}
/**
 * Appends the label of every leaf under {@code t} to {@code l}, visiting children in order.
 *
 * @param t the subtree to walk
 * @param l the list that receives the leaf labels
 */
private static void leafLabels(Tree t, List<Label> l) {
  if (t.isLeaf()) {
    l.add(t.label());
    return;
  }
  for (Tree child : t.children()) {
    leafLabels(child, l);
  }
}
/** * This is the method to call for assigning labels and node vectors to the Tree. After calling * this, each of the non-leaf nodes will have the node vector and the predictions of their classes * assigned to that subtree's node. */ public void forwardPropagateTree(Tree tree) { FloatMatrix nodeVector; FloatMatrix classification; if (tree.isLeaf()) { // We do nothing for the leaves. The preterminals will // calculate the classification for this word/tag. In fact, the // recursion should not have gotten here (unless there are // degenerate trees of just one leaf) throw new AssertionError("We should not have reached leaves in forwardPropagate"); } else if (tree.isPreTerminal()) { classification = getUnaryClassification(tree.label()); String word = tree.children().get(0).value(); FloatMatrix wordVector = getFeatureVector(word); if (wordVector == null) { wordVector = featureVectors.get(UNKNOWN_FEATURE); } nodeVector = activationFunction.apply(wordVector); } else if (tree.children().size() == 1) { throw new AssertionError( "Non-preterminal nodes of size 1 should have already been collapsed"); } else if (tree.children().size() == 2) { Tree left = tree.firstChild(), right = tree.lastChild(); forwardPropagateTree(left); forwardPropagateTree(right); String leftCategory = tree.children().get(0).label(); String rightCategory = tree.children().get(1).label(); FloatMatrix W = getBinaryTransform(leftCategory, rightCategory); classification = getBinaryClassification(leftCategory, rightCategory); FloatMatrix leftVector = tree.children().get(0).vector(); FloatMatrix rightVector = tree.children().get(1).vector(); FloatMatrix childrenVector = appendBias(leftVector, rightVector); if (useFloatTensors) { FloatTensor floatT = getBinaryFloatTensor(leftCategory, rightCategory); FloatMatrix floatTensorIn = FloatMatrix.concatHorizontally(leftVector, rightVector); FloatMatrix floatTensorOut = floatT.bilinearProducts(floatTensorIn); nodeVector = 
activationFunction.apply(W.mmul(childrenVector).add(floatTensorOut)); } else nodeVector = activationFunction.apply(W.mmul(childrenVector)); } else { throw new AssertionError("Tree not correctly binarized"); } FloatMatrix inputWithBias = appendBias(nodeVector); FloatMatrix preAct = classification.mmul(inputWithBias); FloatMatrix predictions = outputActivation.apply(preAct); tree.setPrediction(predictions); tree.setVector(nodeVector); }
/**
 * Collects the label of the word under every preterminal of {@code t}, after storing the
 * preterminal's own label on it under {@code TagLabelAnnotation}.
 *
 * <p>Note that the word labels are modified in place, and that the first child's label is cast to
 * {@code CoreLabel}.
 *
 * @param t the subtree to walk
 * @param l the list that receives the annotated word labels
 */
private static void taggedLeafLabels(Tree t, List<CoreLabel> l) {
  if (!t.isPreTerminal()) {
    for (Tree child : t.children()) {
      taggedLeafLabels(child, l);
    }
    return;
  }
  CoreLabel wordLabel = (CoreLabel) t.getChild(0).label();
  wordLabel.set(TagLabelAnnotation.class, t.label());
  l.add(wordLabel);
}
protected void tallyTree(Tree t, LinkedList<String> parents) { // traverse tree, building parent list String str = t.label().value(); boolean strIsPassive = (str.indexOf('@') == -1); if (strIsPassive) { parents.addFirst(str); } if (!t.isLeaf()) { if (!t.children()[0].isLeaf()) { tallyInternalNode(t, parents); for (int c = 0; c < t.children().length; c++) { Tree child = t.children()[c]; tallyTree(child, parents); } } else { tagNumberer.number(t.label().value()); } } if (strIsPassive) { parents.removeFirst(); } }
protected void tallyInternalNode(Tree lt, List parents) { // form base rule String label = lt.label().value(); Rule baseR = ltToRule(lt); ruleToLabel.put(baseR, label); // act on each history depth for (int depth = 0, maxDepth = Math.min(HISTORY_DEPTH(), parents.size()); depth <= maxDepth; depth++) { List history = new ArrayList(parents.subList(0, depth)); // tally each history level / rewrite pair rulePairs.incrementCount(new Pair(baseR, history), 1); labelPairs.incrementCount(new Pair(label, history), 1); } }
/**
 * Emits the LaTeX node macro for {@code t} and, recursively, for its descendants.
 *
 * <p>Node definitions go to {@code h} (leaves as {@code \tnode}, other nodes as {@code \ntnode});
 * one {@code \nodeconnect} line per parent/child edge goes to {@code c}.
 *
 * @param t the subtree to render
 * @param c receives the connection commands
 * @param h receives the nested node definitions
 * @param n the id of this node
 * @param nextN the next unused node id
 * @param indent number of single-space indents to put after the line break
 * @return the next unused node id after rendering this subtree
 */
private static int treeToLatexHelper(
    Tree t, StringBuilder c, StringBuilder h, int n, int nextN, int indent) {
  StringBuilder padding = new StringBuilder();
  int remaining = indent;
  while (remaining > 0) {
    padding.append(" ");
    remaining--;
  }
  h.append('\n').append(padding);

  String opener = t.isLeaf() ? "{\\tnode{z" : "{\\ntnode{z";
  h.append(opener).append(n).append("}{").append(t.label()).append('}');

  if (!t.isLeaf()) {
    for (Tree kid : t.children()) {
      h.append(", ");
      c.append("\\nodeconnect{z").append(n).append("}{z").append(nextN).append("}\n");
      nextN = treeToLatexHelper(kid, c, h, nextN, nextN + 1, indent + 1);
    }
  }
  h.append('}');
  return nextN;
}
/**
 * Variant of the LaTeX helper that pads shallow leaves so that all leaves end at the same depth.
 *
 * <p>When {@code t.depth()} is 0 and the current depth is less than {@code maxDepth}, the node is
 * wrapped in {@code maxDepth - curDepth} empty {@code \ntnode{pad}} nodes. Every node, leaves
 * included, is emitted as {@code \ntnode}.
 *
 * @param t the subtree to render
 * @param c receives the connection commands
 * @param h receives the nested node definitions
 * @param n the id of this node
 * @param nextN the next unused node id
 * @param indent number of single-space indents to put after the line break
 * @param curDepth depth of {@code t} below the root being rendered
 * @param maxDepth depth that all leaves are padded to
 * @return the next unused node id after rendering this subtree
 */
private static int treeToLatexEvenHelper(
    Tree t,
    StringBuilder c,
    StringBuilder h,
    int n,
    int nextN,
    int indent,
    int curDepth,
    int maxDepth) {
  StringBuilder padding = new StringBuilder();
  int remaining = indent;
  while (remaining > 0) {
    padding.append(" ");
    remaining--;
  }
  h.append('\n').append(padding);

  int tDepth = t.depth();
  boolean needsPadding = tDepth == 0 && tDepth + curDepth < maxDepth;
  int padCount = needsPadding ? maxDepth - tDepth - curDepth : 0;

  for (int i = 0; i < padCount; i++) {
    h.append("{\\ntnode{pad}{}, ");
  }
  h.append("{\\ntnode{z").append(n).append("}{").append(t.label()).append('}');

  if (!t.isLeaf()) {
    for (Tree kid : t.children()) {
      h.append(", ");
      c.append("\\nodeconnect{z").append(n).append("}{z").append(nextN).append("}\n");
      nextN =
          treeToLatexEvenHelper(
              kid, c, h, nextN, nextN + 1, indent + 1, curDepth + 1, maxDepth);
    }
  }

  // Close the padding wrappers opened above, then the node itself.
  for (int i = 0; i < padCount; i++) {
    h.append('}');
  }
  h.append('}');
  return nextN;
}
/**
 * Normalize a whole tree -- one can assume that this is the root. This implementation deletes
 * empty elements (ones with nonterminal tag label '-NONE-') from the tree.
 *
 * <p>Steps, in the order they are executed at the bottom of this method:
 *
 * <ol>
 *   <li>If the root's label value is "S", wrap the tree in a new "ROOT" node.
 *   <li>Relabel every phrasal node whose label value is "VB" as "VP".
 *   <li>Apply {@code transformer1}: when {@code doSGappedStuff} is set, "S" nodes for which
 *       {@code includesEmptyNPSubj} is true get "-G" appended to their label.
 *   <li>Call {@code prune} with {@code subtreeFilter}, which rejects nodes labelled RS, RM, IP or
 *       CODE and '-NONE-' nodes whose single child is a leaf.
 *   <li>Call {@code spliceOut} with {@code nodeFilter}, which rejects unary, non-preterminal nodes
 *       whose label value equals that of their only child.
 *   <li>Apply {@code transformer2}, which appends "-TMP" (and, with {@code doAdverbialNP}, "-ADV")
 *       to labels along head chains, depending on {@code temporalAnnotation}.
 * </ol>
 *
 * <p>The transformers change labels of existing nodes in place rather than building new nodes.
 *
 * @param tree the tree to normalize
 * @param tf the factory used for the new ROOT node and passed to prune/spliceOut/transform
 * @return the normalized tree, or {@code null} if any intermediate step returns {@code null}
 */
@Override
public Tree normalizeWholeTree(Tree tree, TreeFactory tf) {
  // Marks "gapped" S nodes (those with an empty NP subject) as S-G.
  TreeTransformer transformer1 =
      new TreeTransformer() {
        @Override
        public Tree transformTree(Tree t) {
          if (doSGappedStuff) {
            String lab = t.label().value();
            if (lab.equals("S") && includesEmptyNPSubj(t)) {
              LabelFactory lf = t.label().labelFactory();
              // Note: this changes the tree label, rather than
              // creating a new tree node. Beware!
              t.setLabel(lf.newLabel(t.label().value() + "-G"));
            }
          }
          return t;
        }
      };
  // Rejects (returns false for) whole subtrees that should be removed.
  Filter<Tree> subtreeFilter =
      new Filter<Tree>() {
        private static final long serialVersionUID = -7250433816896327901L;

        @Override
        public boolean accept(Tree t) {
          Tree[] kids = t.children();
          Label l = t.label();
          // The special Switchboard non-terminals clause.
          // Note that it deletes IP which other Treebanks might use!
          // NOTE(review): t.label() is dereferenced here before the (l != null) check below.
          if ("RS".equals(t.label().value())
              || "RM".equals(t.label().value())
              || "IP".equals(t.label().value())
              || "CODE".equals(t.label().value())) {
            return false;
          }
          if ((l != null)
              && l.value() != null
              && (l.value().equals("-NONE-"))
              && !t.isLeaf()
              && kids.length == 1
              && kids[0].isLeaf()) {
            // Delete empty/trace nodes (ones marked '-NONE-')
            return false;
          }
          return true;
        }
      };
  // Rejects single nodes that should be spliced out (X over X unary chains).
  Filter<Tree> nodeFilter =
      new Filter<Tree>() {
        private static final long serialVersionUID = 9000955019205336311L;

        @Override
        public boolean accept(Tree t) {
          if (t.isLeaf() || t.isPreTerminal()) {
            return true;
          }
          // The special switchboard non-terminals clause. Try keeping EDITED for now....
          // if ("EDITED".equals(t.label().value())) {
          //   return false;
          // }
          if (t.numChildren() != 1) {
            return true;
          }
          if (t.label() != null
              && t.label().value() != null
              && t.label().value().equals(t.children()[0].label().value())) {
            return false;
          }
          return true;
        }
      };
  // Percolates -TMP (and optionally -ADV) annotations down head chains. Which nodes are
  // marked depends on the temporalAnnotation mode; each branch below handles one mode.
  TreeTransformer transformer2 =
      new TreeTransformer() {
        @Override
        public Tree transformTree(Tree t) {
          if (temporalAnnotation == TEMPORAL_ANY_TMP_PERCOLATED) {
            String lab = t.label().value();
            if (TmpPattern.matcher(lab).matches()) {
              Tree oldT = t;
              Tree ht;
              // Mark every head on the chain down to (and including) the preterminal.
              do {
                ht = headFinder.determineHead(oldT);
                // special fix for possessives! -- make noun before head
                if (ht.label().value().equals("POS")) {
                  int j = oldT.objectIndexOf(ht);
                  if (j > 0) {
                    ht = oldT.getChild(j - 1);
                  }
                }
                LabelFactory lf = ht.label().labelFactory();
                // Note: this changes the tree label, rather than
                // creating a new tree node. Beware!
                ht.setLabel(lf.newLabel(ht.label().value() + "-TMP"));
                oldT = ht;
              } while (!ht.isPreTerminal());
              if (lab.startsWith("PP")) {
                ht = headFinder.determineHead(t);
                // look to right
                int j = t.objectIndexOf(ht);
                int sz = t.children().length;
                if (j + 1 < sz) {
                  ht = t.getChild(j + 1);
                }
                if (ht.label().value().startsWith("NP")) {
                  while (!ht.isLeaf()) {
                    LabelFactory lf = ht.label().labelFactory();
                    // Note: this changes the tree label, rather than
                    // creating a new tree node. Beware!
                    ht.setLabel(lf.newLabel(ht.label().value() + "-TMP"));
                    ht = headFinder.determineHead(ht);
                  }
                }
              }
            }
          } else if (temporalAnnotation == TEMPORAL_ALL_TERMINALS) {
            String lab = t.label().value();
            if (NPTmpPattern.matcher(lab).matches()) {
              Tree ht;
              ht = headFinder.determineHead(t);
              if (ht.isPreTerminal()) {
                // change all tags to -TMP
                LabelFactory lf = ht.label().labelFactory();
                Tree[] kids = t.children();
                for (Tree kid : kids) {
                  if (kid.isPreTerminal()) {
                    // Note: this changes the tree label, rather
                    // than creating a new tree node. Beware!
                    kid.setLabel(lf.newLabel(kid.value() + "-TMP"));
                  }
                }
              } else {
                // Follow the head chain to its preterminal and mark only that tag.
                Tree oldT = t;
                do {
                  ht = headFinder.determineHead(oldT);
                  oldT = ht;
                } while (!ht.isPreTerminal());
                LabelFactory lf = ht.label().labelFactory();
                // Note: this changes the tree label, rather than
                // creating a new tree node. Beware!
                ht.setLabel(lf.newLabel(ht.label().value() + "-TMP"));
              }
            }
          } else if (temporalAnnotation == TEMPORAL_ALL_NP) {
            String lab = t.label().value();
            if (NPTmpPattern.matcher(lab).matches()) {
              Tree oldT = t;
              Tree ht;
              // Continue down the head chain for as long as the head is an NP.
              do {
                ht = headFinder.determineHead(oldT);
                // special fix for possessives! -- make noun before head
                if (ht.label().value().equals("POS")) {
                  int j = oldT.objectIndexOf(ht);
                  if (j > 0) {
                    ht = oldT.getChild(j - 1);
                  }
                }
                if (ht.isPreTerminal() || ht.value().startsWith("NP")) {
                  LabelFactory lf = ht.labelFactory();
                  // Note: this changes the tree label, rather than
                  // creating a new tree node. Beware!
                  ht.setLabel(lf.newLabel(ht.label().value() + "-TMP"));
                  oldT = ht;
                }
              } while (ht.value().startsWith("NP"));
            }
          } else if (temporalAnnotation == TEMPORAL_ALL_NP_AND_PP
              || temporalAnnotation == TEMPORAL_NP_AND_PP_WITH_NP_HEAD
              || temporalAnnotation == TEMPORAL_ALL_NP_EVEN_UNDER_PP) {
            // also allow chain to start with PP
            String lab = t.value();
            if (NPTmpPattern.matcher(lab).matches() || PPTmpPattern.matcher(lab).matches()) {
              Tree oldT = t;
              do {
                Tree ht = headFinder.determineHead(oldT);
                // special fix for possessives! -- make noun before head
                if (ht.value().equals("POS")) {
                  int j = oldT.objectIndexOf(ht);
                  if (j > 0) {
                    ht = oldT.getChild(j - 1);
                  }
                } else if ((temporalAnnotation == TEMPORAL_NP_AND_PP_WITH_NP_HEAD
                        || temporalAnnotation == TEMPORAL_ALL_NP_EVEN_UNDER_PP)
                    && (ht.value().equals("IN") || ht.value().equals("TO"))) {
                  // change the head to be NP if possible
                  // (scans right-to-left and skips index 0, so the leftmost NP after the
                  // first child wins)
                  Tree[] kidlets = oldT.children();
                  for (int k = kidlets.length - 1; k > 0; k--) {
                    if (kidlets[k].value().startsWith("NP")) {
                      ht = kidlets[k];
                    }
                  }
                }
                LabelFactory lf = ht.labelFactory();
                // Note: this next bit changes the tree label, rather
                // than creating a new tree node. Beware!
                if (ht.isPreTerminal() || ht.value().startsWith("NP")) {
                  ht.setLabel(lf.newLabel(ht.value() + "-TMP"));
                }
                // In this mode a PP on the chain is reset to its basic category.
                if (temporalAnnotation == TEMPORAL_ALL_NP_EVEN_UNDER_PP
                    && oldT.value().startsWith("PP")) {
                  oldT.setLabel(lf.newLabel(tlp.basicCategory(oldT.value())));
                }
                oldT = ht;
              } while (oldT.value().startsWith("NP") || oldT.value().startsWith("PP"));
            }
          } else if (temporalAnnotation == TEMPORAL_ALL_NP_PP_ADVP) {
            // also allow chain to start with PP or ADVP
            String lab = t.value();
            if (NPTmpPattern.matcher(lab).matches()
                || PPTmpPattern.matcher(lab).matches()
                || ADVPTmpPattern.matcher(lab).matches()) {
              Tree oldT = t;
              do {
                Tree ht = headFinder.determineHead(oldT);
                // special fix for possessives! -- make noun before head
                if (ht.value().equals("POS")) {
                  int j = oldT.objectIndexOf(ht);
                  if (j > 0) {
                    ht = oldT.getChild(j - 1);
                  }
                }
                // Note: this next bit changes the tree label, rather
                // than creating a new tree node. Beware!
                if (ht.isPreTerminal() || ht.value().startsWith("NP")) {
                  LabelFactory lf = ht.labelFactory();
                  ht.setLabel(lf.newLabel(ht.value() + "-TMP"));
                }
                oldT = ht;
              } while (oldT.value().startsWith("NP"));
            }
          } else if (temporalAnnotation == TEMPORAL_9) {
            // also allow chain to start with PP or ADVP
            String lab = t.value();
            if (NPTmpPattern.matcher(lab).matches()
                || PPTmpPattern.matcher(lab).matches()
                || ADVPTmpPattern.matcher(lab).matches()) {
              // System.err.println("TMP: Annotating " + t);
              addTMP9(t);
            }
          } else if (temporalAnnotation == TEMPORAL_ACL03PCFG) {
            String lab = t.label().value();
            if (lab != null && NPTmpPattern.matcher(lab).matches()) {
              Tree oldT = t;
              Tree ht;
              // Find the preterminal at the end of the head chain; only it may be marked.
              do {
                ht = headFinder.determineHead(oldT);
                // special fix for possessives! -- make noun before head
                if (ht.label().value().equals("POS")) {
                  int j = oldT.objectIndexOf(ht);
                  if (j > 0) {
                    ht = oldT.getChild(j - 1);
                  }
                }
                oldT = ht;
              } while (!ht.isPreTerminal());
              if (!onlyTagAnnotateNstar || ht.label().value().startsWith("N")) {
                LabelFactory lf = ht.label().labelFactory();
                // Note: this changes the tree label, rather than
                // creating a new tree node. Beware!
                ht.setLabel(lf.newLabel(ht.label().value() + "-TMP"));
              }
            }
          }
          // Independent of the temporal mode: percolate -ADV down adverbial NP head chains.
          if (doAdverbialNP) {
            String lab = t.value();
            if (NPAdvPattern.matcher(lab).matches()) {
              Tree oldT = t;
              Tree ht;
              do {
                ht = headFinder.determineHead(oldT);
                // special fix for possessives! -- make noun before head
                if (ht.label().value().equals("POS")) {
                  int j = oldT.objectIndexOf(ht);
                  if (j > 0) {
                    ht = oldT.getChild(j - 1);
                  }
                }
                if (ht.isPreTerminal() || ht.value().startsWith("NP")) {
                  LabelFactory lf = ht.labelFactory();
                  // Note: this changes the tree label, rather than
                  // creating a new tree node. Beware!
                  ht.setLabel(lf.newLabel(ht.label().value() + "-ADV"));
                  oldT = ht;
                }
              } while (ht.value().startsWith("NP"));
            }
          }
          return t;
        }
      };
  // if there wasn't an empty nonterminal at the top, but an S, wrap it.
  if (tree.label().value().equals("S")) {
    tree = tf.newTreeNode("ROOT", Collections.singletonList(tree));
  }
  // repair for the phrasal VB in Switchboard (PTB version 3) that should be a VP
  for (Tree subtree : tree) {
    if (subtree.isPhrasal() && "VB".equals(subtree.label().value())) {
      subtree.setValue("VP");
    }
  }
  // Each stage may return null; stop as soon as one does.
  tree = tree.transform(transformer1);
  if (tree == null) {
    return null;
  }
  tree = tree.prune(subtreeFilter, tf);
  if (tree == null) {
    return null;
  }
  tree = tree.spliceOut(nodeFilter, tf);
  if (tree == null) {
    return null;
  }
  return tree.transform(transformer2, tf);
}
private void backpropDerivativesAndError( Tree tree, MultiDimensionalMap<String, String, FloatMatrix> binaryTD, MultiDimensionalMap<String, String, FloatMatrix> binaryCD, MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD, Map<String, FloatMatrix> unaryCD, Map<String, FloatMatrix> wordVectorD, FloatMatrix deltaUp) { if (tree.isLeaf()) { return; } FloatMatrix currentVector = tree.vector(); String category = tree.label(); category = basicCategory(category); // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class FloatMatrix goldLabel = new FloatMatrix(numOuts, 1); int goldClass = tree.goldLabel(); if (goldClass >= 0) { goldLabel.put(goldClass, 1.0f); } Float nodeWeight = classWeights.get(goldClass); if (nodeWeight == null) nodeWeight = 1.0f; FloatMatrix predictions = tree.prediction(); // If this is an unlabeled class, set deltaClass to 0. We could // make this more efficient by eliminating various of the below // calculations, but this would be the easiest way to handle the // unlabeled class FloatMatrix deltaClass = goldClass >= 0 ? 
SimpleBlas.scal(nodeWeight, predictions.sub(goldLabel)) : new FloatMatrix(predictions.rows, predictions.columns); FloatMatrix localCD = deltaClass.mmul(appendBias(currentVector).transpose()); float error = -(MatrixFunctions.log(predictions).muli(goldLabel).sum()); error = error * nodeWeight; tree.setError(error); if (tree.isPreTerminal()) { // below us is a word vector unaryCD.put(category, unaryCD.get(category).add(localCD)); String word = tree.children().get(0).label(); word = getVocabWord(word); FloatMatrix currentVectorDerivative = activationFunction.apply(currentVector); FloatMatrix deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass); deltaFromClass = deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative); FloatMatrix deltaFull = deltaFromClass.add(deltaUp); wordVectorD.put(word, wordVectorD.get(word).add(deltaFull)); } else { // Otherwise, this must be a binary node String leftCategory = basicCategory(tree.children().get(0).label()); String rightCategory = basicCategory(tree.children().get(1).label()); if (combineClassification) { unaryCD.put("", unaryCD.get("").add(localCD)); } else { binaryCD.put( leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD)); } FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector); FloatMatrix deltaFromClass = getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass); FloatMatrix mult = deltaFromClass.get(interval(0, numHidden), interval(0, 1)); deltaFromClass = mult.muli(currentVectorDerivative); FloatMatrix deltaFull = deltaFromClass.add(deltaUp); FloatMatrix leftVector = tree.children().get(0).vector(); FloatMatrix rightVector = tree.children().get(1).vector(); FloatMatrix childrenVector = appendBias(leftVector, rightVector); // deltaFull 50 x 1, childrenVector: 50 x 2 FloatMatrix add = binaryTD.get(leftCategory, rightCategory); FloatMatrix W_df = 
deltaFromClass.mmul(childrenVector.transpose()); binaryTD.put(leftCategory, rightCategory, add.add(W_df)); FloatMatrix deltaDown; if (useFloatTensors) { FloatTensor Wt_df = getFloatTensorGradient(deltaFull, leftVector, rightVector); binaryFloatTensorTD.put( leftCategory, rightCategory, binaryFloatTensorTD.get(leftCategory, rightCategory).add(Wt_df)); deltaDown = computeFloatTensorDeltaDown( deltaFull, leftVector, rightVector, getBinaryTransform(leftCategory, rightCategory), getBinaryFloatTensor(leftCategory, rightCategory)); } else { deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull); } FloatMatrix leftDerivative = activationFunction.apply(leftVector); FloatMatrix rightDerivative = activationFunction.apply(rightVector); FloatMatrix leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows), interval(0, 1)); FloatMatrix rightDeltaDown = deltaDown.get(interval(deltaFull.rows, deltaFull.rows * 2), interval(0, 1)); backpropDerivativesAndError( tree.children().get(0), binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD, leftDerivative.mul(leftDeltaDown)); backpropDerivativesAndError( tree.children().get(1), binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD, rightDerivative.mul(rightDeltaDown)); } }
/**
 * transformTree does all language-specific tree transformations. Any parameterizations should be
 * inside the specific TreebankLangParserParams class.
 *
 * <p>The node's label is replaced by a new {@code CategoryWordTag} whose tag (for preterminals) or
 * category (for phrasal nodes) carries the annotation suffixes selected by the boolean/int option
 * fields read below. Leaves and {@code null} are returned unchanged.
 *
 * @param t the node to annotate; its label is cast to {@code CoreLabel} and then replaced
 * @param root the root of the tree containing {@code t}, used to look up parent and grandparent
 * @return {@code t} itself, with its label replaced
 */
@Override
public Tree transformTree(Tree t, Tree root) {
  if (t == null || t.isLeaf()) {
    return t;
  }
  String parentStr;
  String grandParentStr;
  Tree parent;
  Tree grandParent;
  // Parent/grandparent default to null with an empty category string at the top of the tree.
  if (root == null || t.equals(root)) {
    parent = null;
    parentStr = "";
  } else {
    parent = t.parent(root);
    parentStr = parent.label().value();
  }
  if (parent == null || parent.equals(root)) {
    grandParent = null;
    grandParentStr = "";
  } else {
    grandParent = parent.parent(root);
    grandParentStr = grandParent.label().value();
  }
  String baseParentStr = ctlp.basicCategory(parentStr);
  String baseGrandParentStr = ctlp.basicCategory(grandParentStr);
  CoreLabel lab = (CoreLabel) t.label();
  String word = lab.word();
  String tag = lab.tag();
  String baseTag = ctlp.basicCategory(tag);
  String category = lab.value();
  String baseCategory = ctlp.basicCategory(category);
  if (t.isPreTerminal()) { // it's a POS tag
    List<String> leftAunts =
        listBasicCategories(SisterAnnotationStats.leftSisterLabels(parent, grandParent));
    // NOTE(review): rightAunts is computed but not read anywhere in this branch.
    List<String> rightAunts =
        listBasicCategories(SisterAnnotationStats.rightSisterLabels(parent, grandParent));
    // Chinese-specific punctuation splits
    // The first accepting filter decides the suffix; unknown punctuation is only reported.
    if (chineseSplitPunct && baseTag.equals("PU")) {
      if (ChineseTreebankLanguagePack.chineseDouHaoAcceptFilter().accept(word)) {
        tag = tag + "-DOU";
      } else if (ChineseTreebankLanguagePack.chineseCommaAcceptFilter().accept(word)) {
        tag = tag + "-COMMA";
      } else if (ChineseTreebankLanguagePack.chineseColonAcceptFilter().accept(word)) {
        tag = tag + "-COLON";
      } else if (ChineseTreebankLanguagePack.chineseQuoteMarkAcceptFilter().accept(word)) {
        if (chineseSplitPunctLR) {
          if (ChineseTreebankLanguagePack.chineseLeftQuoteMarkAcceptFilter().accept(word)) {
            tag += "-LQUOTE";
          } else {
            tag += "-RQUOTE";
          }
        } else {
          tag = tag + "-QUOTE";
        }
      } else if (ChineseTreebankLanguagePack.chineseEndSentenceAcceptFilter().accept(word)) {
        tag = tag + "-ENDSENT";
      } else if (ChineseTreebankLanguagePack.chineseParenthesisAcceptFilter().accept(word)) {
        if (chineseSplitPunctLR) {
          if (ChineseTreebankLanguagePack.chineseLeftParenthesisAcceptFilter().accept(word)) {
            tag += "-LPAREN";
          } else {
            tag += "-RPAREN";
          }
        } else {
          tag += "-PAREN";
        }
      } else if (ChineseTreebankLanguagePack.chineseDashAcceptFilter().accept(word)) {
        tag = tag + "-DASH";
      } else if (ChineseTreebankLanguagePack.chineseOtherAcceptFilter().accept(word)) {
        tag = tag + "-OTHER";
      } else {
        printlnErr("Unknown punct (you should add it to CTLP): " + tag + " |" + word + "|");
      }
    } else if (chineseSplitDouHao) { // only split DouHao
      if (ChineseTreebankLanguagePack.chineseDouHaoAcceptFilter().accept(word)
          && baseTag.equals("PU")) {
        tag = tag + "-DOU";
      }
    }
    // Chinese-specific POS tag splits (non-punctuation)
    if (tagWordSize) {
      int l = word.length();
      tag += "-" + l + "CHARS";
    }
    // Note: this replaces the whole tag, discarding any suffix added above.
    if (mergeNNVV && baseTag.equals("NN")) {
      tag = "VV";
    }
    if ((chineseSelectiveTagPA || chineseVerySelectiveTagPA)
        && (baseTag.equals("CC") || baseTag.equals("P"))) {
      tag += "-" + baseParentStr;
    }
    if (chineseSelectiveTagPA && (baseTag.equals("VV"))) {
      tag += "-" + baseParentStr;
    }
    // NOTE(review): the three sister scans below dereference parent, which is null when
    // root is null or t equals root -- confirm callers never pass a preterminal root.
    if (markMultiNtag && tag.startsWith("N")) {
      // Appends "=N" once per sister whose label starts with "N".
      for (int i = 0; i < parent.numChildren(); i++) {
        if (parent.children()[i].label().value().startsWith("N") && parent.children()[i] != t) {
          tag += "=N";
        }
      }
    }
    if (markVVsisterIP && baseTag.equals("VV")) {
      boolean seenIP = false;
      for (int i = 0; i < parent.numChildren(); i++) {
        if (parent.children()[i].label().value().startsWith("IP")) {
          seenIP = true;
        }
      }
      if (seenIP) {
        tag += "-IP";
      }
    }
    if (markPsisterIP && baseTag.equals("P")) {
      boolean seenIP = false;
      for (int i = 0; i < parent.numChildren(); i++) {
        if (parent.children()[i].label().value().startsWith("IP")) {
          seenIP = true;
        }
      }
      if (seenIP) {
        tag += "-IP";
      }
    }
    if (markADgrandchildOfIP && baseTag.equals("AD") && baseGrandParentStr.equals("IP")) {
      tag += "~IP";
    }
    if (gpaAD && baseTag.equals("AD")) {
      tag += "~" + baseGrandParentStr;
    }
    if (markPostverbalP && leftAunts.contains("VV") && baseTag.equals("P")) {
      tag += "^=lVV";
    }
    // end Chinese-specific tag splits
    // For preterminals the annotated tag is used as both category and tag.
    Label label = new CategoryWordTag(tag, word, tag);
    t.setLabel(label);
  } else { // it's a phrasal category
    Tree[] kids = t.children();
    // Chinese-specific category splits
    List<String> leftSis = listBasicCategories(SisterAnnotationStats.leftSisterLabels(t, parent));
    List<String> rightSis = listBasicCategories(SisterAnnotationStats.rightSisterLabels(t, parent));
    if (paRootDtr && baseParentStr.equals("ROOT")) {
      category += "^ROOT";
    }
    if (markIPsisterBA && baseCategory.equals("IP")) {
      if (leftSis.contains("BA")) {
        category += "=BA";
      }
    }
    if (dominatesV && hasV(t.preTerminalYield())) {
      // mark categories containing a verb
      category += "-v";
    }
    if (markIPsisterVVorP && baseCategory.equals("IP")) {
      // todo: cdm: is just looking for "P" here selective enough??
      if (leftSis.contains("VV") || leftSis.contains("P")) {
        category += "=VVP";
      }
    }
    if (markIPsisDEC && baseCategory.equals("IP")) {
      if (rightSis.contains("DEC")) {
        category += "=DEC";
      }
    }
    if (baseCategory.equals("VP")) {
      // cdm 2008: this used to just check that it startsWith("VP"), but
      // I think that was bad because it also matched VPT verb compounds
      if (chineseSplitVP == 3) {
        // Three-way split: coordinated (-CRD), complementing (-COMP) or adjoining (-ADJT) VP.
        boolean hasCC = false;
        boolean hasPU = false;
        boolean hasLexV = false;
        for (Tree kid : kids) {
          if (kid.label().value().startsWith("CC")) {
            hasCC = true;
          } else if (kid.label().value().startsWith("PU")) {
            hasPU = true;
          } else if (StringUtils.lookingAt(
              kid.label().value(), "(V[ACEV]|VCD|VCP|VNV|VPT|VRD|VSB)")) {
            hasLexV = true;
          }
        }
        if (hasCC || (hasPU && !hasLexV)) {
          category += "-CRD";
        } else if (hasLexV) {
          category += "-COMP";
        } else {
          category += "-ADJT";
        }
      } else if (chineseSplitVP >= 1) {
        // Mark VPs containing BA as a child (level 2 also looks inside VP children).
        boolean hasBA = false;
        for (Tree kid : kids) {
          if (kid.label().value().startsWith("BA")) {
            hasBA = true;
          } else if (chineseSplitVP == 2 && tlp.basicCategory(kid.label().value()).equals("VP")) {
            for (Tree kidkid : kid.children()) {
              if (kidkid.label().value().startsWith("BA")) {
                hasBA = true;
              }
            }
          }
        }
        if (hasBA) {
          category += "-BA";
        }
      }
    }
    if (markVPadjunct && baseParentStr.equals("VP")) {
      // cdm 2008: This used to use startsWith("VP") but changed to baseCat
      Tree[] sisters = parent.children();
      boolean hasVPsister = false;
      boolean hasCC = false;
      boolean hasPU = false;
      boolean hasLexV = false;
      for (Tree sister : sisters) {
        if (tlp.basicCategory(sister.label().value()).equals("VP")) {
          hasVPsister = true;
        }
        if (sister.label().value().startsWith("CC")) {
          hasCC = true;
        }
        if (sister.label().value().startsWith("PU")) {
          hasPU = true;
        }
        if (StringUtils.lookingAt(sister.label().value(), "(V[ACEV]|VCD|VCP|VNV|VPT|VRD|VSB)")) {
          hasLexV = true;
        }
      }
      if (hasVPsister && !(hasCC || hasPU || hasLexV)) {
        category += "-VPADJ";
      }
    }
    if (markNPmodNP && baseCategory.equals("NP") && baseParentStr.equals("NP")) {
      if (rightSis.contains("NP")) {
        category += "=MODIFIERNP";
      }
    }
    if (markModifiedNP && baseCategory.equals("NP") && baseParentStr.equals("NP")) {
      if (rightSis.isEmpty()
          && (leftSis.contains("ADJP")
              || leftSis.contains("NP")
              || leftSis.contains("DNP")
              || leftSis.contains("QP")
              || leftSis.contains("CP")
              || leftSis.contains("PP"))) {
        category += "=MODIFIEDNP";
      }
    }
    if (markNPconj && baseCategory.equals("NP") && baseParentStr.equals("NP")) {
      if (rightSis.contains("CC")
          || rightSis.contains("PU")
          || leftSis.contains("CC")
          || leftSis.contains("PU")) {
        category += "=CONJ";
      }
    }
    if (markIPconj && baseCategory.equals("IP") && baseParentStr.equals("IP")) {
      // An IP counts as a conjunct when it has both a comma sister and another IP sister.
      Tree[] sisters = parent.children();
      boolean hasCommaSis = false;
      boolean hasIPSis = false;
      for (Tree sister : sisters) {
        if (ctlp.basicCategory(sister.label().value()).equals("PU")
            && ChineseTreebankLanguagePack.chineseCommaAcceptFilter()
                .accept(sister.children()[0].label().toString())) {
          hasCommaSis = true;
        }
        if (ctlp.basicCategory(sister.label().value()).equals("IP") && sister != t) {
          hasIPSis = true;
        }
      }
      if (hasCommaSis && hasIPSis) {
        category += "-CONJ";
      }
    }
    if (unaryIP && baseCategory.equals("IP") && t.numChildren() == 1) {
      category += "-U";
    }
    if (unaryCP && baseCategory.equals("CP") && t.numChildren() == 1) {
      category += "-U";
    }
    if (splitBaseNP && baseCategory.equals("NP")) {
      if (t.isPrePreTerminal()) {
        category = category + "-B";
      }
    }
    // if (Test.verbose) printlnErr(baseCategory + " " + leftSis.toString()); //debugging
    if (markPostverbalPP && leftSis.contains("VV") && baseCategory.equals("PP")) {
      category += "=lVV";
    }
    if ((markADgrandchildOfIP || gpaAD)
        && listBasicCategories(SisterAnnotationStats.kidLabels(t)).contains("AD")) {
      category += "^ADVP";
    }
    if (markCC) {
      // was: for (int i = 0; i < kids.length; i++) {
      // This second version takes an idea from Collins: don't count
      // marginal conjunctions which don't conjoin 2 things.
      // (Appends "-CC" once per non-edge child whose label starts with "CC".)
      for (int i = 1; i < kids.length - 1; i++) {
        String cat2 = kids[i].label().value();
        if (cat2.startsWith("CC")) {
          category += "-CC";
        }
      }
    }
    // Phrasal nodes keep the word and tag read from the original label.
    Label label = new CategoryWordTag(category, word, tag);
    t.setLabel(label);
  }
  return t;
}