Esempio n. 1
0
  /**
   * Describe the runtime classes that make up a Tree. It reports the short class names of the root
   * node, the root's tree factory, the root's label and that label's factory. Any node in the tree
   * whose corresponding class name differs from the root's contributes its name to a trailing set
   * of "other" classes.
   *
   * @param t The tree to examine.
   * @return A human-readable String
   */
  public static String toDebugStructureString(Tree t) {
    String[] rootNames = debugClassNames(t);
    Set<String> otherClasses = new HashSet<String>();
    for (Tree node : t) {
      String[] nodeNames = debugClassNames(node);
      for (int k = 0; k < nodeNames.length; k++) {
        if (!rootNames[k].equals(nodeNames[k])) {
          otherClasses.add(nodeNames[k]);
        }
      }
    }
    StringBuilder out = new StringBuilder();
    out.append("Tree with root of class ").append(rootNames[0])
        .append(" and factory ").append(rootNames[1])
        .append(" with label class ").append(rootNames[2])
        .append(" and factory ").append(rootNames[3]);
    if (!otherClasses.isEmpty()) {
      out.append(" with the following classes also found within the tree: ").append(otherClasses);
    }
    return out.toString();
  }

  /**
   * Short class names of a node, its tree factory, its label and its label factory, in that order.
   */
  private static String[] debugClassNames(Tree node) {
    return new String[] {
      StringUtils.getShortClassName(node),
      StringUtils.getShortClassName(node.treeFactory()),
      StringUtils.getShortClassName(node.label()),
      StringUtils.getShortClassName(node.label().labelFactory())
    };
  }
Esempio n. 2
0
 /**
  * Recursively prunes the subtree {@code tree}. Leaves and preterminals are returned unchanged.
  * Otherwise the children are first passed through {@code helper} and then filtered by the
  * list-based {@code prune} over the span {@code [start, start + yield size)}. A new node with
  * the same label is then built from the surviving children.
  *
  * @param tree the subtree to prune
  * @param start presumably the sentence offset of the first word of {@code tree}'s yield
  * @return {@code tree} itself, or a newly built node holding the pruned children
  */
 Tree prune(Tree tree, int start) {
   boolean nothingBelow = tree.isLeaf() || tree.isPreTerminal();
   if (nothingBelow) {
     return tree;
   }
   List<Tree> processed = helper(tree.getChildrenAsList(), start);
   int spanEnd = start + tree.yield().size();
   List<Tree> survivors = prune(processed, tree.label(), start, spanEnd);
   return tree.treeFactory().newTreeNode(tree.label(), survivors);
 }
Esempio n. 3
0
 /**
  * Normalizes the labels of every node of {@code tree} in place, then applies whole-tree
  * normalization. Leaf values go through {@code tn.normalizeTerminal}; all other nodes go through
  * {@code tn.normalizeNonterminal}.
  *
  * @param tree the tree whose labels are rewritten in place
  * @param tn the normalizer supplying terminal, nonterminal and whole-tree normalization
  * @param tf factory passed on to {@code tn.normalizeWholeTree}
  * @return the result of {@code tn.normalizeWholeTree(tree, tf)}
  */
 public static Tree normalizeTree(Tree tree, TreeNormalizer tn, TreeFactory tf) {
   for (Tree node : tree) {
     String raw = node.label().value();
     String normalized =
         node.isLeaf() ? tn.normalizeTerminal(raw) : tn.normalizeNonterminal(raw);
     node.label().setValue(normalized);
   }
   return tn.normalizeWholeTree(tree, tf);
 }
Esempio n. 4
0
  /**
   * Returns the list of categories on the path from Tree {@code from} to Tree {@code to} within
   * Tree {@code root}. Both {@code from} and {@code to} are included in the list.
   *
   * <ul>
   *   <li>Nodes from {@code from} up to, but excluding, their lowest common ancestor are emitted
   *       bottom-up with the prefix {@code "up-"}.
   *   <li>The common ancestor itself is emitted once, also with {@code "up-"}.
   *   <li>Nodes strictly below the ancestor down to {@code to} are emitted top-down with
   *       {@code "down-"}.
   * </ul>
   *
   * @return the path, or {@code null} if either {@code from} or {@code to} is not in {@code root}
   */
  public static List<String> pathNodeToNode(Tree from, Tree to, Tree root) {
    List<Tree> fromPath = pathFromRoot(from, root);
    if (fromPath == null) {
      return null;
    }
    List<Tree> toPath = pathFromRoot(to, root);
    if (toPath == null) {
      return null;
    }

    // Find the length of the shared prefix of the two root paths; the last shared node is the
    // lowest common ancestor. Compare by identity: both paths come from the same root, so shared
    // nodes are the same objects. Tree.equals may compare structure, and would then treat two
    // distinct but identical-looking siblings as a shared ancestor.
    int min = Math.min(fromPath.size(), toPath.size());
    int last = 0;
    Tree lastNode = null;
    while (last < min && fromPath.get(last) == toPath.get(last)) {
      lastNode = fromPath.get(last);
      last++;
    }

    List<String> totalPath = new ArrayList<String>();

    // Climb from 'from' up to (not including) the common ancestor.
    for (int i = fromPath.size() - 1; i >= last; i--) {
      totalPath.add("up-" + fromPath.get(i).label().value());
    }

    if (lastNode != null) {
      totalPath.add("up-" + lastNode.label().value());
    }

    // Descend from just below the common ancestor to 'to'. The shared prefix must be skipped here
    // too, otherwise root..ancestor would be emitted a second time as "down-" steps.
    for (int i = last; i < toPath.size(); i++) {
      totalPath.add("down-" + toPath.get(i).label().value());
    }

    return totalPath;
  }
Esempio n. 5
0
 /**
  * Filters a list of sibling subtrees against the best PCFG parse of the span they cover.
  *
  * <p>A child is spliced out (replaced by its own children) only when all of these hold:
  *
  * <ul>
  *   <li>the parent label starts with {@code "NP^NP"};
  *   <li>the PCFG parse has the same yield length as the siblings;
  *   <li>the child is neither a leaf nor a preterminal;
  *   <li>its labeled span is absent from the PCFG constituents;
  *   <li>its span does not cross any PCFG constituent.
  * </ul>
  *
  * <p>Each pruning is logged to {@code System.err}.
  *
  * @param treeList the sibling subtrees; a single-element list is returned unchanged
  * @param label the label of the (virtual) parent of {@code treeList}
  * @param start sentence offset where the siblings' span begins
  * @param end sentence offset where the siblings' span ends
  * @return the possibly-flattened list of children
  */
 List<Tree> prune(List<Tree> treeList, Label label, int start, int end) {
   if (treeList.size() == 1) {
     return treeList;
   }
   // Build a reference node over the siblings and get the PCFG's best parse of the same span.
   Tree testTree = treeList.get(0).treeFactory().newTreeNode(label, treeList);
   int goal = Numberer.getGlobalNumberer("states").number(label.value());
   Tree tempTree = parser.extractBestParse(goal, start, end);
   Tree pcfgTree = debinarizer.transformTree(tempTree);
   Set<Constituent> pcfgConstituents =
       pcfgTree.constituents(new LabeledScoredConstituentFactory());

   // These conditions do not depend on the child being examined. Evaluate them once rather than
   // recomputing two full tree yields for every child.
   boolean mayPrune =
       label.value().startsWith("NP^NP")
           && pcfgTree.yield().size() == testTree.yield().size();

   // Delete child labels that are not in the reference but do not cross the reference.
   List<Tree> prunedChildren = new ArrayList<Tree>();
   int childStart = 0;
   for (int c = 0, numCh = testTree.numChildren(); c < numCh; c++) {
     Tree child = testTree.getChild(c);
     int childEnd = childStart + child.yield().size();
     Constituent childConstituent =
         new LabeledScoredConstituent(childStart, childEnd, child.label(), 0);
     boolean isExtra =
         mayPrune
             && !child.isLeaf()
             && !child.isPreTerminal()
             && !pcfgConstituents.contains(childConstituent)
             && !childConstituent.crosses(pcfgConstituents);
     if (isExtra) {
       System.err.println(
           "Pruning: "
               + child.label()
               + " from "
               + (childStart + start)
               + " to "
               + (childEnd + start));
       System.err.println("Was: " + testTree + " vs " + pcfgTree);
       prunedChildren.addAll(child.getChildrenAsList());
     } else {
       prunedChildren.add(child);
     }
     childStart = childEnd;
   }
   return prunedChildren;
 }
Esempio n. 6
0
 /**
  * Converts a local tree (a node plus its immediate children) into a grammar {@code Rule}.
  * Category labels are mapped to state numbers through {@code stateNumberer}. A node with exactly
  * one child becomes a {@code UnaryRule}. Any other node becomes a {@code BinaryRule} built from
  * its first two children, so it must have at least two.
  *
  * @param lt the local tree to convert
  * @return the corresponding unary or binary rule
  */
 protected Rule ltToRule(Tree lt) {
   if (lt.children().length != 1) {
     BinaryRule binary = new BinaryRule();
     // Keep the numbering order parent, left, right: the numberer assigns ids on first sight.
     binary.parent = stateNumberer.number(lt.label().value());
     Tree[] kids = lt.children();
     binary.leftChild = stateNumberer.number(kids[0].label().value());
     binary.rightChild = stateNumberer.number(kids[1].label().value());
     return binary;
   }
   UnaryRule unary = new UnaryRule();
   unary.parent = stateNumberer.number(lt.label().value());
   unary.child = stateNumberer.number(lt.children()[0].label().value());
   return unary;
 }
Esempio n. 7
0
 /**
  * Returns the local tree rooted at {@code t} as a list of category labels. The first element is
  * the mother's label value, followed by each daughter's label value in order.
  *
  * @param t the mother node
  * @return a new list of size {@code 1 + number of daughters}
  */
 public static List<String> localTreeAsCatList(Tree t) {
   Tree[] daughters = t.children();
   List<String> cats = new ArrayList<String>(daughters.length + 1);
   cats.add(t.label().value());
   for (Tree daughter : daughters) {
     cats.add(daughter.label().value());
   }
   return cats;
 }
Esempio n. 8
0
 /**
  * Recursively copies {@code t} using its own tree factory. It records, for every node, the
  * correspondence between the copy and the original in both directions. Labels are passed to the
  * factory as-is, not cloned here.
  *
  * <p>NOTE(review): Tree equality/hashing may be structural, so callers probably want identity
  * maps (e.g. {@code IdentityHashMap}) here. Confirm.
  *
  * @param t the node to copy
  * @param newToOld receives one entry per created node, mapping the copy to its original
  * @param oldToNew receives one entry per original node, mapping the original to its copy
  * @return the copy of {@code t}
  */
 public static Tree copyHelper(Tree t, Map<Tree, Tree> newToOld, Map<Tree, Tree> oldToNew) {
   Tree[] kids = t.children();
   Tree[] newKids = new Tree[kids.length];
   for (int i = 0, n = kids.length; i < n; i++) {
     newKids[i] = copyHelper(kids[i], newToOld, oldToNew);
   }
   TreeFactory tf = t.treeFactory();
   Tree copy;
   if (kids.length == 0) {
     copy = tf.newLeaf(t.label());
   } else {
     copy = tf.newTreeNode(t.label(), Arrays.asList(newKids));
   }
   newToOld.put(copy, t);
   // Key by the original. The leaf branch used to insert (copy -> original) here, so leaves were
   // never findable in oldToNew.
   oldToNew.put(t, copy);
   return copy;
 }
Esempio n. 9
0
 /**
  * Appends the label of every leaf under {@code t} to {@code l}, in left-to-right order.
  *
  * @param t the subtree to scan
  * @param l the list that receives leaf labels (mutated)
  */
 private static void leafLabels(Tree t, List<Label> l) {
   if (t.isLeaf()) {
     l.add(t.label());
     return;
   }
   for (Tree kid : t.children()) {
     leafLabels(kid, l);
   }
 }
Esempio n. 10
0
  /**
   * Assigns labels and node vectors to {@code tree}; this is the method to call for doing so.
   * Afterwards every non-leaf node carries its node vector ({@code setVector}) and its class
   * predictions ({@code setPrediction}).
   *
   * <ul>
   *   <li>Preterminals use the word's feature vector, falling back to {@code UNKNOWN_FEATURE}.
   *   <li>Binary nodes combine their two children through the category-specific transform. When
   *       {@code useFloatTensors} is set, the tensor term is added as well.
   * </ul>
   *
   * @throws AssertionError if a leaf is reached, or if the tree is not binarized with unaries
   *     collapsed
   */
  public void forwardPropagateTree(Tree tree) {
    if (tree.isLeaf()) {
      // Leaves are handled by their preterminal parent; getting here means a degenerate tree
      // consisting of a single leaf.
      throw new AssertionError("We should not have reached leaves in forwardPropagate");
    }

    FloatMatrix classifier;
    FloatMatrix hidden;

    if (tree.isPreTerminal()) {
      classifier = getUnaryClassification(tree.label());
      FloatMatrix embedding = getFeatureVector(tree.children().get(0).value());
      if (embedding == null) {
        embedding = featureVectors.get(UNKNOWN_FEATURE);
      }
      hidden = activationFunction.apply(embedding);
    } else {
      int arity = tree.children().size();
      if (arity == 1) {
        throw new AssertionError(
            "Non-preterminal nodes of size 1 should have already been collapsed");
      }
      if (arity != 2) {
        throw new AssertionError("Tree not correctly binarized");
      }

      forwardPropagateTree(tree.firstChild());
      forwardPropagateTree(tree.lastChild());

      Tree leftKid = tree.children().get(0);
      Tree rightKid = tree.children().get(1);
      String leftCat = leftKid.label();
      String rightCat = rightKid.label();
      FloatMatrix transform = getBinaryTransform(leftCat, rightCat);
      classifier = getBinaryClassification(leftCat, rightCat);

      FloatMatrix leftVec = leftKid.vector();
      FloatMatrix rightVec = rightKid.vector();
      FloatMatrix stacked = appendBias(leftVec, rightVec);

      if (!useFloatTensors) {
        hidden = activationFunction.apply(transform.mmul(stacked));
      } else {
        FloatTensor tensor = getBinaryFloatTensor(leftCat, rightCat);
        FloatMatrix bilinear =
            tensor.bilinearProducts(FloatMatrix.concatHorizontally(leftVec, rightVec));
        hidden = activationFunction.apply(transform.mmul(stacked).add(bilinear));
      }
    }

    FloatMatrix scores = classifier.mmul(appendBias(hidden));
    tree.setPrediction(outputActivation.apply(scores));
    tree.setVector(hidden);
  }
Esempio n. 11
0
 /**
  * Collects the word label under every preterminal in {@code t}, in left-to-right order. Each
  * word label is tagged in place with its preterminal's label under {@code TagLabelAnnotation}.
  * Word labels are cast to {@code CoreLabel}, so they must actually be {@code CoreLabel}s.
  *
  * @param t the subtree to scan
  * @param l the list that receives the (mutated) word labels
  */
 private static void taggedLeafLabels(Tree t, List<CoreLabel> l) {
   if (!t.isPreTerminal()) {
     for (Tree kid : t.children()) {
       taggedLeafLabels(kid, l);
     }
     return;
   }
   CoreLabel word = (CoreLabel) t.getChild(0).label();
   word.set(TagLabelAnnotation.class, t.label());
   l.add(word);
 }
Esempio n. 12
0
 /**
  * Walks {@code t} depth-first, keeping {@code parents} as a most-recent-first stack of ancestor
  * labels.
  *
  * <ul>
  *   <li>Labels containing {@code '@'} are not pushed onto the stack.
  *   <li>A node whose first child is not a leaf is tallied via {@code tallyInternalNode} and then
  *       its children are visited.
  *   <li>A node whose first child is a leaf only has its label registered with
  *       {@code tagNumberer}.
  * </ul>
  *
  * <p>{@code parents} is restored to its original contents before returning.
  */
 protected void tallyTree(Tree t, LinkedList<String> parents) {
   String category = t.label().value();
   // Presumably '@' marks intermediate (binarized) states, which are kept out of the history.
   boolean pushed = category.indexOf('@') < 0;
   if (pushed) {
     parents.addFirst(category);
   }
   if (!t.isLeaf()) {
     if (t.children()[0].isLeaf()) {
       tagNumberer.number(t.label().value());
     } else {
       tallyInternalNode(t, parents);
       for (Tree child : t.children()) {
         tallyTree(child, parents);
       }
     }
   }
   if (pushed) {
     parents.removeFirst();
   }
 }
Esempio n. 13
0
 /**
  * Records the rule for local tree {@code lt} in {@code ruleToLabel}. It then increments
  * {@code rulePairs} and {@code labelPairs} once for each history depth from 0 up to
  * {@code min(HISTORY_DEPTH(), parents.size())}, inclusive. The history at depth d is a copy of
  * the first d entries of {@code parents}.
  *
  * @param lt the local tree to tally
  * @param parents ancestor labels, most recent first
  */
 protected void tallyInternalNode(Tree lt, List parents) {
   String parentLabel = lt.label().value();
   Rule rule = ltToRule(lt);
   ruleToLabel.put(rule, parentLabel);
   int deepest = Math.min(HISTORY_DEPTH(), parents.size());
   int depth = 0;
   while (depth <= deepest) {
     List history = new ArrayList(parents.subList(0, depth));
     rulePairs.incrementCount(new Pair(rule, history), 1);
     labelPairs.incrementCount(new Pair(parentLabel, history), 1);
     depth++;
   }
 }
Esempio n. 14
0
 /**
  * Recursive worker for rendering a tree as LaTeX {@code \tnode}/{@code \ntnode} markup.
  * Node markup goes into {@code h}, indented two spaces per level. A
  * {@code \nodeconnect{z<parent>}{z<child>}} line goes into {@code c} for each parent–child
  * edge. Nodes are named {@code z<number>}.
  *
  * @param n the number assigned to {@code t}
  * @param nextN the next unused node number
  * @return the next unused node number after numbering all of {@code t}'s descendants
  */
 private static int treeToLatexHelper(
     Tree t, StringBuilder c, StringBuilder h, int n, int nextN, int indent) {
   h.append('\n');
   for (int level = indent; level > 0; level--) {
     h.append("  ");
   }
   boolean leaf = t.isLeaf();
   h.append("{\\").append(leaf ? "" : "n").append("tnode{z").append(n).append("}{");
   h.append(t.label()).append('}');
   int next = nextN;
   if (!leaf) {
     for (Tree kid : t.children()) {
       h.append(", ");
       c.append("\\nodeconnect{z").append(n).append("}{z").append(next).append("}\n");
       next = treeToLatexHelper(kid, c, h, next, next + 1, indent + 1);
     }
   }
   h.append('}');
   return next;
 }
Esempio n. 15
0
 /**
  * Recursive worker for an "even" LaTeX rendering. It behaves like the plain helper, except that
  * a node whose {@code depth()} is 0 (presumably a leaf) sitting at {@code curDepth < maxDepth}
  * gets extra wrapping. It is wrapped in {@code maxDepth - curDepth} empty
  * {@code \ntnode{pad}{}} nodes so such nodes line up at {@code maxDepth}.
  *
  * @param n the number assigned to {@code t}
  * @param nextN the next unused node number
  * @param curDepth depth of {@code t} in the tree being rendered
  * @param maxDepth depth at which padded nodes should end up
  * @return the next unused node number after numbering all of {@code t}'s descendants
  */
 private static int treeToLatexEvenHelper(
     Tree t,
     StringBuilder c,
     StringBuilder h,
     int n,
     int nextN,
     int indent,
     int curDepth,
     int maxDepth) {
   h.append('\n');
   for (int level = indent; level > 0; level--) {
     h.append("  ");
   }
   // Non-positive for anything that needs no padding, so the loops below simply do not run.
   int padding = (t.depth() == 0) ? maxDepth - curDepth : 0;
   for (int p = 0; p < padding; p++) {
     h.append("{\\ntnode{pad}{}, ");
   }
   h.append("{\\ntnode{z").append(n).append("}{").append(t.label()).append('}');
   int next = nextN;
   if (!t.isLeaf()) {
     for (Tree kid : t.children()) {
       h.append(", ");
       c.append("\\nodeconnect{z").append(n).append("}{z").append(next).append("}\n");
       next = treeToLatexEvenHelper(kid, c, h, next, next + 1, indent + 1, curDepth + 1, maxDepth);
     }
   }
   for (int p = 0; p < padding; p++) {
     h.append('}');
   }
   h.append('}');
   return next;
 }
  /**
   * Normalize a whole tree -- one can assume that this is the root. The steps are applied in this
   * order:
   *
   * <ol>
   *   <li>A tree whose root label is "S" is wrapped in a new "ROOT" node.
   *   <li>Phrasal nodes labeled "VB" are relabeled "VP" (Switchboard / PTB3 repair).
   *   <li>{@code transformer1}: if {@code doSGappedStuff}, "S" nodes for which
   *       {@code includesEmptyNPSubj} holds get "-G" appended to their label.
   *   <li>{@code subtreeFilter}, via {@code tree.prune}: subtrees labeled RS, RM, IP or CODE are
   *       rejected, as are empty elements (a '-NONE-' node with a single leaf child).
   *   <li>{@code nodeFilter}, via {@code tree.spliceOut}: non-preterminal unary nodes whose label
   *       equals their only child's label are rejected.
   *   <li>{@code transformer2}: the temporal (-TMP) annotation selected by
   *       {@code temporalAnnotation}, plus the adverbial (-ADV) annotation if
   *       {@code doAdverbialNP}.
   * </ol>
   *
   * <p>Several steps change labels of existing nodes in place rather than building new nodes.
   *
   * @param tree the tree to normalize (its labels may be mutated)
   * @param tf factory used for new nodes and passed to prune/spliceOut/transform
   * @return the normalized tree, or {@code null} if a transform/prune/splice step returned null
   */
  @Override
  public Tree normalizeWholeTree(Tree tree, TreeFactory tf) {
    // Optional S-gap marking; returns the (possibly relabeled) node itself.
    TreeTransformer transformer1 =
        new TreeTransformer() {
          @Override
          public Tree transformTree(Tree t) {
            if (doSGappedStuff) {
              String lab = t.label().value();
              if (lab.equals("S") && includesEmptyNPSubj(t)) {
                LabelFactory lf = t.label().labelFactory();
                // Note: this changes the tree label, rather than
                // creating a new tree node.  Beware!
                t.setLabel(lf.newLabel(t.label().value() + "-G"));
              }
            }
            return t;
          }
        };
    // Subtrees rejected here are removed by tree.prune(...) below.
    Filter<Tree> subtreeFilter =
        new Filter<Tree>() {

          private static final long serialVersionUID = -7250433816896327901L;

          @Override
          public boolean accept(Tree t) {
            Tree[] kids = t.children();
            Label l = t.label();
            // The special Switchboard non-terminals clause.
            // Note that it deletes IP which other Treebanks might use!
            if ("RS".equals(t.label().value())
                || "RM".equals(t.label().value())
                || "IP".equals(t.label().value())
                || "CODE".equals(t.label().value())) {
              return false;
            }
            if ((l != null)
                && l.value() != null
                && (l.value().equals("-NONE-"))
                && !t.isLeaf()
                && kids.length == 1
                && kids[0].isLeaf()) {
              // Delete empty/trace nodes (ones marked '-NONE-')
              return false;
            }
            return true;
          }
        };
    // Nodes rejected here are handled by tree.spliceOut(...) below (presumably removed with their
    // children reattached to the parent).
    Filter<Tree> nodeFilter =
        new Filter<Tree>() {

          private static final long serialVersionUID = 9000955019205336311L;

          @Override
          public boolean accept(Tree t) {
            if (t.isLeaf() || t.isPreTerminal()) {
              return true;
            }
            // The special switchboard non-terminals clause. Try keeping EDITED for now....
            // if ("EDITED".equals(t.label().value())) {
            //   return false;
            // }
            if (t.numChildren() != 1) {
              return true;
            }
            // Reject X -> X unary chains (same label on parent and only child).
            if (t.label() != null
                && t.label().value() != null
                && t.label().value().equals(t.children()[0].label().value())) {
              return false;
            }
            return true;
          }
        };
    // Function-tag annotation pass. The branch is chosen by temporalAnnotation; each branch
    // follows the head chain (via headFinder) from a matching node and appends "-TMP" to labels
    // in place. doAdverbialNP is independent of temporalAnnotation and appends "-ADV".
    TreeTransformer transformer2 =
        new TreeTransformer() {
          @Override
          public Tree transformTree(Tree t) {
            if (temporalAnnotation == TEMPORAL_ANY_TMP_PERCOLATED) {
              String lab = t.label().value();
              if (TmpPattern.matcher(lab).matches()) {
                Tree oldT = t;
                Tree ht;
                // Mark every head down to (and including) the preterminal.
                do {
                  ht = headFinder.determineHead(oldT);
                  // special fix for possessives! -- make noun before head
                  if (ht.label().value().equals("POS")) {
                    int j = oldT.objectIndexOf(ht);
                    if (j > 0) {
                      ht = oldT.getChild(j - 1);
                    }
                  }
                  LabelFactory lf = ht.label().labelFactory();
                  // Note: this changes the tree label, rather than
                  // creating a new tree node.  Beware!
                  ht.setLabel(lf.newLabel(ht.label().value() + "-TMP"));
                  oldT = ht;
                } while (!ht.isPreTerminal());
                if (lab.startsWith("PP")) {
                  ht = headFinder.determineHead(t);
                  // look to right
                  int j = t.objectIndexOf(ht);
                  int sz = t.children().length;
                  if (j + 1 < sz) {
                    ht = t.getChild(j + 1);
                  }
                  if (ht.label().value().startsWith("NP")) {
                    while (!ht.isLeaf()) {
                      LabelFactory lf = ht.label().labelFactory();
                      // Note: this changes the tree label, rather than
                      // creating a new tree node.  Beware!
                      ht.setLabel(lf.newLabel(ht.label().value() + "-TMP"));
                      ht = headFinder.determineHead(ht);
                    }
                  }
                }
              }
            } else if (temporalAnnotation == TEMPORAL_ALL_TERMINALS) {
              String lab = t.label().value();
              if (NPTmpPattern.matcher(lab).matches()) {
                Tree ht;
                ht = headFinder.determineHead(t);
                if (ht.isPreTerminal()) {
                  // change all tags to -TMP
                  LabelFactory lf = ht.label().labelFactory();
                  Tree[] kids = t.children();
                  for (Tree kid : kids) {
                    if (kid.isPreTerminal()) {
                      // Note: this changes the tree label, rather
                      // than creating a new tree node.  Beware!
                      kid.setLabel(lf.newLabel(kid.value() + "-TMP"));
                    }
                  }
                } else {
                  // Only the preterminal at the bottom of the head chain is marked.
                  Tree oldT = t;
                  do {
                    ht = headFinder.determineHead(oldT);
                    oldT = ht;
                  } while (!ht.isPreTerminal());
                  LabelFactory lf = ht.label().labelFactory();
                  // Note: this changes the tree label, rather than
                  // creating a new tree node.  Beware!
                  ht.setLabel(lf.newLabel(ht.label().value() + "-TMP"));
                }
              }
            } else if (temporalAnnotation == TEMPORAL_ALL_NP) {
              String lab = t.label().value();
              if (NPTmpPattern.matcher(lab).matches()) {
                Tree oldT = t;
                Tree ht;
                // Mark the head chain while it stays within NPs (and the preterminal it reaches).
                do {
                  ht = headFinder.determineHead(oldT);
                  // special fix for possessives! -- make noun before head
                  if (ht.label().value().equals("POS")) {
                    int j = oldT.objectIndexOf(ht);
                    if (j > 0) {
                      ht = oldT.getChild(j - 1);
                    }
                  }
                  if (ht.isPreTerminal() || ht.value().startsWith("NP")) {
                    LabelFactory lf = ht.labelFactory();
                    // Note: this changes the tree label, rather than
                    // creating a new tree node.  Beware!
                    ht.setLabel(lf.newLabel(ht.label().value() + "-TMP"));
                    oldT = ht;
                  }
                } while (ht.value().startsWith("NP"));
              }
            } else if (temporalAnnotation == TEMPORAL_ALL_NP_AND_PP
                || temporalAnnotation == TEMPORAL_NP_AND_PP_WITH_NP_HEAD
                || temporalAnnotation == TEMPORAL_ALL_NP_EVEN_UNDER_PP) {
              // also allow chain to start with PP
              String lab = t.value();
              if (NPTmpPattern.matcher(lab).matches() || PPTmpPattern.matcher(lab).matches()) {
                Tree oldT = t;
                do {
                  Tree ht = headFinder.determineHead(oldT);
                  // special fix for possessives! -- make noun before head
                  if (ht.value().equals("POS")) {
                    int j = oldT.objectIndexOf(ht);
                    if (j > 0) {
                      ht = oldT.getChild(j - 1);
                    }
                  } else if ((temporalAnnotation == TEMPORAL_NP_AND_PP_WITH_NP_HEAD
                          || temporalAnnotation == TEMPORAL_ALL_NP_EVEN_UNDER_PP)
                      && (ht.value().equals("IN") || ht.value().equals("TO"))) {
                    // change the head to be NP if possible
                    Tree[] kidlets = oldT.children();
                    // Scans right-to-left and keeps overwriting, so this ends with the leftmost
                    // NP child at index >= 1 (index 0 is never considered).
                    for (int k = kidlets.length - 1; k > 0; k--) {
                      if (kidlets[k].value().startsWith("NP")) {
                        ht = kidlets[k];
                      }
                    }
                  }
                  LabelFactory lf = ht.labelFactory();
                  // Note: this next bit changes the tree label, rather
                  // than creating a new tree node.  Beware!
                  if (ht.isPreTerminal() || ht.value().startsWith("NP")) {
                    ht.setLabel(lf.newLabel(ht.value() + "-TMP"));
                  }
                  // In this mode the PP's own label is reduced to its basic category.
                  if (temporalAnnotation == TEMPORAL_ALL_NP_EVEN_UNDER_PP
                      && oldT.value().startsWith("PP")) {
                    oldT.setLabel(lf.newLabel(tlp.basicCategory(oldT.value())));
                  }
                  oldT = ht;
                } while (oldT.value().startsWith("NP") || oldT.value().startsWith("PP"));
              }
            } else if (temporalAnnotation == TEMPORAL_ALL_NP_PP_ADVP) {
              // also allow chain to start with PP or ADVP
              String lab = t.value();
              if (NPTmpPattern.matcher(lab).matches()
                  || PPTmpPattern.matcher(lab).matches()
                  || ADVPTmpPattern.matcher(lab).matches()) {
                Tree oldT = t;
                do {
                  Tree ht = headFinder.determineHead(oldT);
                  // special fix for possessives! -- make noun before head
                  if (ht.value().equals("POS")) {
                    int j = oldT.objectIndexOf(ht);
                    if (j > 0) {
                      ht = oldT.getChild(j - 1);
                    }
                  }
                  // Note: this next bit changes the tree label, rather
                  // than creating a new tree node.  Beware!
                  if (ht.isPreTerminal() || ht.value().startsWith("NP")) {
                    LabelFactory lf = ht.labelFactory();
                    ht.setLabel(lf.newLabel(ht.value() + "-TMP"));
                  }
                  oldT = ht;
                } while (oldT.value().startsWith("NP"));
              }
            } else if (temporalAnnotation == TEMPORAL_9) {
              // also allow chain to start with PP or ADVP
              String lab = t.value();
              if (NPTmpPattern.matcher(lab).matches()
                  || PPTmpPattern.matcher(lab).matches()
                  || ADVPTmpPattern.matcher(lab).matches()) {
                // System.err.println("TMP: Annotating " + t);
                addTMP9(t);
              }
            } else if (temporalAnnotation == TEMPORAL_ACL03PCFG) {
              String lab = t.label().value();
              if (lab != null && NPTmpPattern.matcher(lab).matches()) {
                Tree oldT = t;
                Tree ht;
                // Descend the head chain to the preterminal without marking intermediate nodes.
                do {
                  ht = headFinder.determineHead(oldT);
                  // special fix for possessives! -- make noun before head
                  if (ht.label().value().equals("POS")) {
                    int j = oldT.objectIndexOf(ht);
                    if (j > 0) {
                      ht = oldT.getChild(j - 1);
                    }
                  }
                  oldT = ht;
                } while (!ht.isPreTerminal());
                if (!onlyTagAnnotateNstar || ht.label().value().startsWith("N")) {
                  LabelFactory lf = ht.label().labelFactory();
                  // Note: this changes the tree label, rather than
                  // creating a new tree node.  Beware!
                  ht.setLabel(lf.newLabel(ht.label().value() + "-TMP"));
                }
              }
            }
            if (doAdverbialNP) {
              String lab = t.value();
              if (NPAdvPattern.matcher(lab).matches()) {
                Tree oldT = t;
                Tree ht;
                do {
                  ht = headFinder.determineHead(oldT);
                  // special fix for possessives! -- make noun before head
                  if (ht.label().value().equals("POS")) {
                    int j = oldT.objectIndexOf(ht);
                    if (j > 0) {
                      ht = oldT.getChild(j - 1);
                    }
                  }
                  if (ht.isPreTerminal() || ht.value().startsWith("NP")) {
                    LabelFactory lf = ht.labelFactory();
                    // Note: this changes the tree label, rather than
                    // creating a new tree node.  Beware!
                    ht.setLabel(lf.newLabel(ht.label().value() + "-ADV"));
                    oldT = ht;
                  }
                } while (ht.value().startsWith("NP"));
              }
            }
            return t;
          }
        };
    // if there wasn't an empty nonterminal at the top, but an S, wrap it.
    if (tree.label().value().equals("S")) {
      tree = tf.newTreeNode("ROOT", Collections.singletonList(tree));
    }
    // repair for the phrasal VB in Switchboard (PTB version 3) that should be a VP
    for (Tree subtree : tree) {
      if (subtree.isPhrasal() && "VB".equals(subtree.label().value())) {
        subtree.setValue("VP");
      }
    }
    // Apply the passes in order; any step that yields null aborts normalization.
    tree = tree.transform(transformer1);
    if (tree == null) {
      return null;
    }
    tree = tree.prune(subtreeFilter, tf);
    if (tree == null) {
      return null;
    }
    tree = tree.spliceOut(nodeFilter, tf);
    if (tree == null) {
      return null;
    }
    return tree.transform(transformer2, tf);
  }
Esempio n. 17
0
  /**
   * Backpropagates the classification error through {@code tree}. Gradients are accumulated into
   * the supplied maps, and each non-leaf node's weighted cross-entropy error is stored via
   * {@code tree.setError}.
   *
   * <ul>
   *   <li>Preterminals add to {@code unaryCD} and to the word-vector gradient.
   *   <li>Binary nodes add to the classification gradient: {@code unaryCD[""]} if
   *       {@code combineClassification}, otherwise {@code binaryCD}.
   *   <li>Binary nodes also add to the transform gradient {@code binaryTD} and, when
   *       {@code useFloatTensors}, to the tensor gradient. They then recurse into both children.
   * </ul>
   *
   * <p>Nodes with a negative gold label are treated as unlabeled and contribute no class error
   * signal.
   *
   * @param deltaUp error signal coming from the parent. For children it is computed below as the
   *     activation derivative times the relevant slice of {@code deltaDown}.
   */
  private void backpropDerivativesAndError(
      Tree tree,
      MultiDimensionalMap<String, String, FloatMatrix> binaryTD,
      MultiDimensionalMap<String, String, FloatMatrix> binaryCD,
      MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD,
      Map<String, FloatMatrix> unaryCD,
      Map<String, FloatMatrix> wordVectorD,
      FloatMatrix deltaUp) {
    if (tree.isLeaf()) {
      return;
    }

    FloatMatrix currentVector = tree.vector();
    String category = tree.label();
    category = basicCategory(category);

    // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class
    FloatMatrix goldLabel = new FloatMatrix(numOuts, 1);
    int goldClass = tree.goldLabel();
    if (goldClass >= 0) {
      goldLabel.put(goldClass, 1.0f);
    }

    Float nodeWeight = classWeights.get(goldClass);
    if (nodeWeight == null) nodeWeight = 1.0f;
    FloatMatrix predictions = tree.prediction();

    // If this is an unlabeled class, set deltaClass to 0.  We could
    // make this more efficient by eliminating various of the below
    // calculations, but this would be the easiest way to handle the
    // unlabeled class
    FloatMatrix deltaClass =
        goldClass >= 0
            ? SimpleBlas.scal(nodeWeight, predictions.sub(goldLabel))
            : new FloatMatrix(predictions.rows, predictions.columns);
    FloatMatrix localCD = deltaClass.mmul(appendBias(currentVector).transpose());

    float error = -(MatrixFunctions.log(predictions).muli(goldLabel).sum());
    error = error * nodeWeight;
    tree.setError(error);

    if (tree.isPreTerminal()) { // below us is a word vector
      unaryCD.put(category, unaryCD.get(category).add(localCD));

      String word = tree.children().get(0).label();
      word = getVocabWord(word);

      // The chain rule needs the activation's derivative here, not the activation itself
      // (consistent with the binary branch below).
      // NOTE(review): currentVector is the post-activation value; confirm applyDerivative
      // expects that, as the binary branch already assumes.
      FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector);
      FloatMatrix deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass);
      deltaFromClass =
          deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative);
      FloatMatrix deltaFull = deltaFromClass.add(deltaUp);
      wordVectorD.put(word, wordVectorD.get(word).add(deltaFull));

    } else {
      // Otherwise, this must be a binary node
      String leftCategory = basicCategory(tree.children().get(0).label());
      String rightCategory = basicCategory(tree.children().get(1).label());
      if (combineClassification) {
        unaryCD.put("", unaryCD.get("").add(localCD));
      } else {
        binaryCD.put(
            leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD));
      }

      FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector);
      FloatMatrix deltaFromClass =
          getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass);

      FloatMatrix mult = deltaFromClass.get(interval(0, numHidden), interval(0, 1));
      deltaFromClass = mult.muli(currentVectorDerivative);
      FloatMatrix deltaFull = deltaFromClass.add(deltaUp);

      FloatMatrix leftVector = tree.children().get(0).vector();
      FloatMatrix rightVector = tree.children().get(1).vector();

      FloatMatrix childrenVector = appendBias(leftVector, rightVector);

      // deltaFull: numHidden x 1; childrenVector: (2 * numHidden + 1) x 1
      FloatMatrix add = binaryTD.get(leftCategory, rightCategory);

      // The transform gradient uses the full error signal, including deltaUp from the parent,
      // exactly like the tensor gradient and deltaDown below.
      FloatMatrix W_df = deltaFull.mmul(childrenVector.transpose());
      binaryTD.put(leftCategory, rightCategory, add.add(W_df));

      FloatMatrix deltaDown;
      if (useFloatTensors) {
        FloatTensor Wt_df = getFloatTensorGradient(deltaFull, leftVector, rightVector);
        binaryFloatTensorTD.put(
            leftCategory,
            rightCategory,
            binaryFloatTensorTD.get(leftCategory, rightCategory).add(Wt_df));
        deltaDown =
            computeFloatTensorDeltaDown(
                deltaFull,
                leftVector,
                rightVector,
                getBinaryTransform(leftCategory, rightCategory),
                getBinaryFloatTensor(leftCategory, rightCategory));
      } else {
        deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull);
      }

      // Children receive deltaDown scaled by the derivative of their own activation.
      FloatMatrix leftDerivative = activationFunction.applyDerivative(leftVector);
      FloatMatrix rightDerivative = activationFunction.applyDerivative(rightVector);
      FloatMatrix leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows), interval(0, 1));
      FloatMatrix rightDeltaDown =
          deltaDown.get(interval(deltaFull.rows, deltaFull.rows * 2), interval(0, 1));
      backpropDerivativesAndError(
          tree.children().get(0),
          binaryTD,
          binaryCD,
          binaryFloatTensorTD,
          unaryCD,
          wordVectorD,
          leftDerivative.mul(leftDeltaDown));
      backpropDerivativesAndError(
          tree.children().get(1),
          binaryTD,
          binaryCD,
          binaryFloatTensorTD,
          unaryCD,
          wordVectorD,
          rightDerivative.mul(rightDeltaDown));
    }
  }
  /**
   * transformTree does all language-specific tree transformations. Any parameterizations should be
   * inside the specific TreebankLangParserParams class.
   *
   * <p>The node {@code t} is annotated in place: its label is replaced by a {@link CategoryWordTag}.
   * For a preterminal, the tag carries the feature suffixes enabled by this object's flags. For a
   * phrasal node, the category carries them. Suffixes are appended in a fixed order and that order
   * is part of the resulting string, so the checks below must not be reordered.
   *
   * @param t the node to transform; returned unchanged if it is {@code null} or a leaf
   * @param root the root of the tree containing {@code t}, used to look up the parent and
   *     grandparent of {@code t}; may be {@code null}, in which case {@code t} is treated as having
   *     no parent (and therefore no sisters)
   * @return {@code t} itself, with its label replaced unless it was {@code null} or a leaf
   */
  @Override
  public Tree transformTree(Tree t, Tree root) {
    if (t == null || t.isLeaf()) {
      return t;
    }

    String parentStr;
    String grandParentStr;
    Tree parent;
    Tree grandParent;
    if (root == null || t.equals(root)) {
      parent = null;
      parentStr = "";
    } else {
      parent = t.parent(root);
      parentStr = parent.label().value();
    }
    if (parent == null || parent.equals(root)) {
      grandParent = null;
      grandParentStr = "";
    } else {
      grandParent = parent.parent(root);
      grandParentStr = grandParent.label().value();
    }

    String baseParentStr = ctlp.basicCategory(parentStr);
    String baseGrandParentStr = ctlp.basicCategory(grandParentStr);

    CoreLabel lab = (CoreLabel) t.label();
    String word = lab.word();
    String tag = lab.tag();
    String baseTag = ctlp.basicCategory(tag);
    String category = lab.value();
    String baseCategory = ctlp.basicCategory(category);

    if (t.isPreTerminal()) { // it's a POS tag
      List<String> leftAunts =
          listBasicCategories(SisterAnnotationStats.leftSisterLabels(parent, grandParent));

      // Chinese-specific punctuation splits
      if (chineseSplitPunct && baseTag.equals("PU")) {
        if (ChineseTreebankLanguagePack.chineseDouHaoAcceptFilter().accept(word)) {
          tag = tag + "-DOU";
        } else if (ChineseTreebankLanguagePack.chineseCommaAcceptFilter().accept(word)) {
          tag = tag + "-COMMA";
        } else if (ChineseTreebankLanguagePack.chineseColonAcceptFilter().accept(word)) {
          tag = tag + "-COLON";
        } else if (ChineseTreebankLanguagePack.chineseQuoteMarkAcceptFilter().accept(word)) {
          if (chineseSplitPunctLR) {
            if (ChineseTreebankLanguagePack.chineseLeftQuoteMarkAcceptFilter().accept(word)) {
              tag += "-LQUOTE";
            } else {
              tag += "-RQUOTE";
            }
          } else {
            tag = tag + "-QUOTE";
          }
        } else if (ChineseTreebankLanguagePack.chineseEndSentenceAcceptFilter().accept(word)) {
          tag = tag + "-ENDSENT";
        } else if (ChineseTreebankLanguagePack.chineseParenthesisAcceptFilter().accept(word)) {
          if (chineseSplitPunctLR) {
            if (ChineseTreebankLanguagePack.chineseLeftParenthesisAcceptFilter().accept(word)) {
              tag += "-LPAREN";
            } else {
              tag += "-RPAREN";
            }
          } else {
            tag += "-PAREN";
          }
        } else if (ChineseTreebankLanguagePack.chineseDashAcceptFilter().accept(word)) {
          tag = tag + "-DASH";
        } else if (ChineseTreebankLanguagePack.chineseOtherAcceptFilter().accept(word)) {
          tag = tag + "-OTHER";
        } else {
          printlnErr("Unknown punct (you should add it to CTLP): " + tag + " |" + word + "|");
        }
      } else if (chineseSplitDouHao) { // only split DouHao
        if (ChineseTreebankLanguagePack.chineseDouHaoAcceptFilter().accept(word)
            && baseTag.equals("PU")) {
          tag = tag + "-DOU";
        }
      }

      // Chinese-specific POS tag splits (non-punctuation)

      if (tagWordSize) {
        tag += "-" + word.length() + "CHARS";
      }

      if (mergeNNVV && baseTag.equals("NN")) {
        // Replaces the whole tag, discarding any suffix added above. baseTag still holds the
        // original basic category ("NN"), so the VV-specific checks below do not fire for it.
        tag = "VV";
      }

      if ((chineseSelectiveTagPA || chineseVerySelectiveTagPA)
          && (baseTag.equals("CC") || baseTag.equals("P"))) {
        tag += "-" + baseParentStr;
      }
      if (chineseSelectiveTagPA && (baseTag.equals("VV"))) {
        tag += "-" + baseParentStr;
      }

      // The sister-based checks below need a parent; a preterminal at the root (or any node
      // when root is null) has none, so skip them instead of dereferencing null.
      if (markMultiNtag && tag.startsWith("N") && parent != null) {
        // "=N" is appended once per other N-initial sister, so k such sisters yield k copies.
        for (Tree sister : parent.children()) {
          if (sister.label().value().startsWith("N") && sister != t) {
            tag += "=N";
          }
        }
      }

      if (markVVsisterIP && baseTag.equals("VV") && parent != null) {
        for (Tree sister : parent.children()) {
          if (sister.label().value().startsWith("IP")) {
            tag += "-IP";
            break;
          }
        }
      }

      if (markPsisterIP && baseTag.equals("P") && parent != null) {
        for (Tree sister : parent.children()) {
          if (sister.label().value().startsWith("IP")) {
            tag += "-IP";
            break;
          }
        }
      }

      if (markADgrandchildOfIP && baseTag.equals("AD") && baseGrandParentStr.equals("IP")) {
        tag += "~IP";
      }

      if (gpaAD && baseTag.equals("AD")) {
        tag += "~" + baseGrandParentStr;
      }

      if (markPostverbalP && leftAunts.contains("VV") && baseTag.equals("P")) {
        tag += "^=lVV";
      }

      // end Chinese-specific tag splits

      Label label = new CategoryWordTag(tag, word, tag);
      t.setLabel(label);
    } else {
      // it's a phrasal category
      Tree[] kids = t.children();

      // Chinese-specific category splits
      List<String> leftSis = listBasicCategories(SisterAnnotationStats.leftSisterLabels(t, parent));
      List<String> rightSis =
          listBasicCategories(SisterAnnotationStats.rightSisterLabels(t, parent));

      if (paRootDtr && baseParentStr.equals("ROOT")) {
        category += "^ROOT";
      }

      if (markIPsisterBA && baseCategory.equals("IP")) {
        if (leftSis.contains("BA")) {
          category += "=BA";
        }
      }

      if (dominatesV && hasV(t.preTerminalYield())) {
        // mark categories containing a verb
        category += "-v";
      }

      if (markIPsisterVVorP && baseCategory.equals("IP")) {
        // todo: cdm: is just looking for "P" here selective enough??
        if (leftSis.contains("VV") || leftSis.contains("P")) {
          category += "=VVP";
        }
      }

      if (markIPsisDEC && baseCategory.equals("IP")) {
        if (rightSis.contains("DEC")) {
          category += "=DEC";
        }
      }

      if (baseCategory.equals("VP")) {
        // cdm 2008: this used to just check that it startsWith("VP"), but
        // I think that was bad because it also matched VPT verb compounds
        if (chineseSplitVP == 3) {
          boolean hasCC = false;
          boolean hasPU = false;
          boolean hasLexV = false;
          for (Tree kid : kids) {
            if (kid.label().value().startsWith("CC")) {
              hasCC = true;
            } else if (kid.label().value().startsWith("PU")) {
              hasPU = true;
            } else if (StringUtils.lookingAt(
                kid.label().value(), "(V[ACEV]|VCD|VCP|VNV|VPT|VRD|VSB)")) {
              hasLexV = true;
            }
          }
          if (hasCC || (hasPU && !hasLexV)) {
            category += "-CRD"; // coordinate VP
          } else if (hasLexV) {
            category += "-COMP"; // complementing VP
          } else {
            category += "-ADJT"; // adjoining VP
          }
        } else if (chineseSplitVP >= 1) {
          boolean hasBA = false;
          for (Tree kid : kids) {
            if (kid.label().value().startsWith("BA")) {
              hasBA = true;
            } else if (chineseSplitVP == 2 && tlp.basicCategory(kid.label().value()).equals("VP")) {
              for (Tree kidkid : kid.children()) {
                if (kidkid.label().value().startsWith("BA")) {
                  hasBA = true;
                }
              }
            }
          }
          if (hasBA) {
            category += "-BA";
          }
        }
      }

      if (markVPadjunct && baseParentStr.equals("VP")) {
        // cdm 2008: This used to use startsWith("VP") but changed to baseCat
        Tree[] sisters = parent.children();
        boolean hasVPsister = false;
        boolean hasCC = false;
        boolean hasPU = false;
        boolean hasLexV = false;
        for (Tree sister : sisters) {
          if (tlp.basicCategory(sister.label().value()).equals("VP")) {
            hasVPsister = true;
          }
          if (sister.label().value().startsWith("CC")) {
            hasCC = true;
          }
          if (sister.label().value().startsWith("PU")) {
            hasPU = true;
          }
          if (StringUtils.lookingAt(sister.label().value(), "(V[ACEV]|VCD|VCP|VNV|VPT|VRD|VSB)")) {
            hasLexV = true;
          }
        }
        if (hasVPsister && !(hasCC || hasPU || hasLexV)) {
          category += "-VPADJ"; // adjunct of VP
        }
      }

      if (markNPmodNP && baseCategory.equals("NP") && baseParentStr.equals("NP")) {
        if (rightSis.contains("NP")) {
          category += "=MODIFIERNP";
        }
      }

      if (markModifiedNP && baseCategory.equals("NP") && baseParentStr.equals("NP")) {
        if (rightSis.isEmpty()
            && (leftSis.contains("ADJP")
                || leftSis.contains("NP")
                || leftSis.contains("DNP")
                || leftSis.contains("QP")
                || leftSis.contains("CP")
                || leftSis.contains("PP"))) {
          category += "=MODIFIEDNP";
        }
      }

      if (markNPconj && baseCategory.equals("NP") && baseParentStr.equals("NP")) {
        if (rightSis.contains("CC")
            || rightSis.contains("PU")
            || leftSis.contains("CC")
            || leftSis.contains("PU")) {
          category += "=CONJ";
        }
      }

      if (markIPconj && baseCategory.equals("IP") && baseParentStr.equals("IP")) {
        Tree[] sisters = parent.children();
        boolean hasCommaSis = false;
        boolean hasIPSis = false;
        for (Tree sister : sisters) {
          // Read the punctuation token with value(), like every other label read in this method.
          // toString() may render more than the bare token, which would defeat the filter.
          if (ctlp.basicCategory(sister.label().value()).equals("PU")
              && ChineseTreebankLanguagePack.chineseCommaAcceptFilter()
                  .accept(sister.children()[0].label().value())) {
            hasCommaSis = true;
          }
          if (ctlp.basicCategory(sister.label().value()).equals("IP") && sister != t) {
            hasIPSis = true;
          }
        }
        if (hasCommaSis && hasIPSis) {
          category += "-CONJ";
        }
      }

      if (unaryIP && baseCategory.equals("IP") && t.numChildren() == 1) {
        category += "-U";
      }
      if (unaryCP && baseCategory.equals("CP") && t.numChildren() == 1) {
        category += "-U";
      }

      if (splitBaseNP && baseCategory.equals("NP")) {
        if (t.isPrePreTerminal()) {
          category = category + "-B";
        }
      }

      if (markPostverbalPP && leftSis.contains("VV") && baseCategory.equals("PP")) {
        category += "=lVV";
      }

      if ((markADgrandchildOfIP || gpaAD)
          && listBasicCategories(SisterAnnotationStats.kidLabels(t)).contains("AD")) {
        category += "^ADVP";
      }

      if (markCC) {
        // was: for (int i = 0; i < kids.length; i++) {
        // This second version takes an idea from Collins: don't count
        // marginal conjunctions which don't conjoin 2 things.
        // "-CC" is appended once per internal CC child.
        for (int i = 1; i < kids.length - 1; i++) {
          String cat2 = kids[i].label().value();
          if (cat2.startsWith("CC")) {
            category += "-CC";
          }
        }
      }

      Label label = new CategoryWordTag(category, word, tag);
      t.setLabel(label);
    }
    return t;
  }