Ejemplo n.º 1
0
  /**
   * Recursively walks {@code tree} bottom-up, recording a vector for every
   * non-leaf node in {@code nodeVectors} and a score for every node above the
   * preterminal level in {@code scores}.
   *
   * <ul>
   *   <li>Leaves are skipped entirely: nothing is stored for them.
   *   <li>Preterminals get the tanh of the word vector that {@code dvModel}
   *       returns for the label value of their first child; no score is stored.
   *   <li>All other nodes are processed after their children. The child vectors
   *       are concatenated with a bias (both children when there are exactly
   *       two, otherwise only the first child), optionally extended with context
   *       words, multiplied by the node's W matrix and passed through tanh. The
   *       score is the dot product of the node's scoreW matrix with that vector.
   * </ul>
   *
   * @param tree the subtree to process
   * @param words the sentence words; only forwarded to
   *     {@code concatenateContextWords} when {@code op.trainOptions.useContextWords} is set
   * @param nodeVectors output map (identity-keyed) from node to its computed vector
   * @param scores output map (identity-keyed) from node to its computed score
   * @throws NoSuchParseException if {@code dvModel} returns {@code null} for the W or
   *     scoreW matrix of a node; in the scoreW case the node's vector has already been
   *     stored in {@code nodeVectors}
   */
  private void forwardPropagateTree(
      Tree tree,
      List<String> words,
      IdentityHashMap<Tree, SimpleMatrix> nodeVectors,
      IdentityHashMap<Tree, Double> scores) {
    // Leaves carry no vector of their own.
    if (tree.isLeaf()) {
      return;
    }

    final Tree[] kids = tree.children();

    // Preterminal: squash the looked-up word vector and stop; no score here.
    if (tree.isPreTerminal()) {
      final String token = kids[0].label().value();
      nodeVectors.put(tree, NeuralUtils.elementwiseApplyTanh(dvModel.getWordVector(token)));
      return;
    }

    // Post-order traversal: every child must have its vector in nodeVectors
    // before this node's input can be assembled.
    for (Tree kid : kids) {
      forwardPropagateTree(kid, words, nodeVectors, scores);
    }

    // Exactly two children -> concatenate both; anything else -> first child only.
    SimpleMatrix input =
        (kids.length == 2)
            ? NeuralUtils.concatenateWithBias(nodeVectors.get(kids[0]), nodeVectors.get(kids[1]))
            : NeuralUtils.concatenateWithBias(nodeVectors.get(kids[0]));
    if (op.trainOptions.useContextWords) {
      input = concatenateContextWords(input, tree.getSpan(), words);
    }

    final SimpleMatrix transform = dvModel.getWForNode(tree);
    if (transform == null) {
      final String message = "Could not find W for tree " + tree;
      if (op.testOptions.verbose) {
        System.err.println(message);
      }
      throw new NoSuchParseException(message);
    }
    final SimpleMatrix nodeVector = NeuralUtils.elementwiseApplyTanh(transform.mult(input));
    nodeVectors.put(tree, nodeVector);

    final SimpleMatrix scoreTransform = dvModel.getScoreWForNode(tree);
    if (scoreTransform == null) {
      final String message = "Could not find scoreW for tree " + tree;
      if (op.testOptions.verbose) {
        System.err.println(message);
      }
      throw new NoSuchParseException(message);
    }
    // The raw dot product is stored as the score; no squashing is applied to it.
    scores.put(tree, scoreTransform.dot(nodeVector));
  }