Esempio n. 1
0
  /**
   * Forward-propagates {@code tree}, assigning a node vector and class predictions to every
   * non-leaf node. After calling this, each preterminal and binary node has had its vector set via
   * {@code setVector} and its class predictions set via {@code setPrediction}.
   *
   * <p>Preterminal nodes take the activation of their word's feature vector, falling back to the
   * {@code Word2Vec.UNK} vector when the word has none. Binary nodes first recurse into both
   * children, then apply the activation to {@code W * childrenVector}, adding the bilinear tensor
   * term when {@code useDoubleTensors} is set.
   *
   * @param tree the (binarized, unary-collapsed) tree to propagate; must not be a bare leaf
   * @throws AssertionError if a leaf is reached, a non-preterminal unary node is found, or a node
   *     has more than two children
   */
  public void forwardPropagateTree(Tree tree) {
    INDArray nodeVector;
    INDArray classification;

    if (tree.isLeaf()) {
      // Preterminals compute the classification for their word/tag, so the recursion should never
      // reach a leaf (unless the whole tree is a degenerate single leaf).
      throw new AssertionError("We should not have reached leaves in forwardPropagate");
    } else if (tree.isPreTerminal()) {
      classification = getUnaryClassification(tree.label());
      String word = tree.children().get(0).value();
      INDArray wordVector = getFeatureVector(word);
      if (wordVector == null) {
        wordVector = featureVectors.vector(Word2Vec.UNK);
      }
      // Transform a copy: a transform op built from a single array may write its result back into
      // that array, which would corrupt the shared feature vector.
      nodeVector =
          Nd4j.getExecutioner()
              .execAndReturn(
                  Nd4j.getOpFactory().createTransform(activationFunction, wordVector.dup()));
    } else if (tree.children().size() == 1) {
      throw new AssertionError(
          "Non-preterminal nodes of size 1 should have already been collapsed");
    } else if (tree.children().size() == 2) {
      Tree left = tree.firstChild();
      Tree right = tree.lastChild();
      forwardPropagateTree(left);
      forwardPropagateTree(right);

      String leftCategory = left.label();
      String rightCategory = right.label();
      INDArray binaryTransform = getBinaryTransform(leftCategory, rightCategory);
      classification = getBinaryClassification(leftCategory, rightCategory);

      INDArray leftVector = left.vector();
      INDArray rightVector = right.vector();

      // Both child vectors plus a bias entry, as produced by Nd4j.appendBias.
      INDArray childrenVector = Nd4j.appendBias(leftVector, rightVector);

      // mmul yields a fresh array, so the in-place addi / transform below touch no stored state.
      INDArray preActivation = binaryTransform.mmul(childrenVector);
      if (useDoubleTensors) {
        INDArray tensor = getBinaryINDArray(leftCategory, rightCategory);
        INDArray stackedChildren = Nd4j.concat(0, leftVector, rightVector);
        preActivation.addi(Nd4j.bilinearProducts(tensor, stackedChildren));
      }
      nodeVector =
          Nd4j.getExecutioner()
              .execAndReturn(
                  Nd4j.getOpFactory().createTransform(activationFunction, preActivation));
    } else {
      throw new AssertionError("Tree not correctly binarized");
    }

    INDArray inputWithBias = Nd4j.appendBias(nodeVector);
    // Orient the biased vector so it can be right-multiplied by the classification matrix.
    if (inputWithBias.rows() != classification.columns()) {
      inputWithBias = inputWithBias.transpose();
    }
    INDArray preAct = classification.mmul(inputWithBias);
    INDArray predictions =
        Nd4j.getExecutioner()
            .execAndReturn(Nd4j.getOpFactory().createTransform(outputActivation, preAct));

    tree.setPrediction(predictions);
    tree.setVector(nodeVector);
  }
Esempio n. 2
0
  /**
   * Backpropagates classification error through {@code tree}, accumulating gradients into the
   * supplied maps and recording each non-leaf node's weighted cross-entropy error on the node.
   *
   * <p>Each non-leaf node is expected to already carry a vector and a prediction from a forward
   * pass. Leaves are skipped.
   *
   * @param tree the subtree to process
   * @param binaryTD accumulated binary transform gradients, keyed by (left, right) child category
   * @param binaryCD accumulated binary classification gradients, keyed by (left, right) category;
   *     not used when {@code combineClassification} is set
   * @param binaryINDArrayTD accumulated bilinear tensor gradients, keyed by (left, right)
   *     category; only updated when {@code useDoubleTensors} is set
   * @param unaryCD accumulated unary classification gradients keyed by category; key {@code ""}
   *     receives binary-node gradients when {@code combineClassification} is set
   * @param wordVectorD accumulated word-vector gradients keyed by vocabulary word; an entry is
   *     created if absent
   * @param deltaUp error signal flowing into this node from its parent
   */
  private void backpropDerivativesAndError(
      Tree tree,
      MultiDimensionalMap<String, String, INDArray> binaryTD,
      MultiDimensionalMap<String, String, INDArray> binaryCD,
      MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD,
      Map<String, INDArray> unaryCD,
      Map<String, INDArray> wordVectorD,
      INDArray deltaUp) {
    if (tree.isLeaf()) {
      return;
    }

    INDArray currentVector = tree.vector();
    String category = basicCategory(tree.label());

    // One-hot (numOuts x 1) indicator of the gold class; stays all zeros for unlabeled nodes.
    INDArray goldLabel = Nd4j.create(numOuts, 1);
    int goldClass = tree.goldLabel();
    if (goldClass >= 0) {
      // Valid labels are 0 .. numOuts - 1.
      assert goldClass < numOuts
          : "Tried adding a label that was >= to the number of configured outputs "
              + numOuts
              + " with label "
              + goldClass;
      goldLabel.putScalar(goldClass, 1.0f);
    }

    Double nodeWeight = classWeights.get(goldClass);
    if (nodeWeight == null) {
      nodeWeight = 1.0;
    }
    INDArray predictions = tree.prediction();

    // Unlabeled nodes get a zero class delta. The downstream calculations could be skipped for
    // them, but a zero delta is the simplest way to handle the unlabeled class.
    INDArray deltaClass;
    if (goldClass < 0) {
      deltaClass = Nd4j.create(predictions.rows(), predictions.columns());
    } else if (predictions.data().dataType() == DataBuffer.Type.DOUBLE) {
      deltaClass = Nd4j.getBlasWrapper().scal(nodeWeight, predictions.sub(goldLabel));
    } else {
      deltaClass =
          Nd4j.getBlasWrapper()
              .scal((float) nodeWeight.doubleValue(), predictions.sub(goldLabel));
    }
    INDArray localCD = deltaClass.mmul(Nd4j.appendBias(currentVector).transpose());

    double error =
        -(Transforms.log(predictions).muli(goldLabel).sum(Integer.MAX_VALUE).getDouble(0));
    tree.setError(error * nodeWeight);

    // NOTE(review): the "*Derivative" values below apply activationFunction itself, not its
    // derivative (the reference RNTN uses the tanh derivative 1 - a^2 of the stored activation).
    // Left unchanged here; confirm the intended derivative before changing training dynamics.
    // Every transform works on a dup(): transform ops built from a single array may write in place,
    // which would overwrite the vectors stored on this node and its children.
    if (tree.isPreTerminal()) { // below us is a word vector
      unaryCD.put(category, unaryCD.get(category).add(localCD));

      String word = getVocabWord(tree.children().get(0).label());

      INDArray currentVectorDerivative =
          Nd4j.getExecutioner()
              .execAndReturn(
                  Nd4j.getOpFactory().createTransform(activationFunction, currentVector.dup()));
      INDArray deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass);
      deltaFromClass =
          deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative);
      INDArray deltaFull = deltaFromClass.add(deltaUp);

      // The word may not have an accumulator yet; start it at this delta instead of failing.
      INDArray accumulated = wordVectorD.get(word);
      wordVectorD.put(word, accumulated == null ? deltaFull : accumulated.add(deltaFull));
      return;
    }

    // Otherwise, this must be a binary node.
    Tree leftChild = tree.children().get(0);
    Tree rightChild = tree.children().get(1);
    String leftCategory = basicCategory(leftChild.label());
    String rightCategory = basicCategory(rightChild.label());
    if (combineClassification) {
      unaryCD.put("", unaryCD.get("").add(localCD));
    } else {
      binaryCD.put(
          leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD));
    }

    INDArray currentVectorDerivative =
        Nd4j.getExecutioner()
            .execAndReturn(
                Nd4j.getOpFactory().createTransform(activationFunction, currentVector.dup()));
    INDArray deltaFromClass =
        getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass);
    deltaFromClass =
        deltaFromClass.get(interval(0, numHidden), interval(0, 1)).muli(currentVectorDerivative);
    INDArray deltaFull = deltaFromClass.add(deltaUp);

    INDArray leftVector = leftChild.vector();
    INDArray rightVector = rightChild.vector();
    INDArray childrenVector = Nd4j.appendBias(leftVector, rightVector);

    // deltaFull is numHidden x 1; childrenVector holds both child vectors plus a bias entry
    // (presumably (2 * numHidden + 1) x 1). The W gradient must use the FULL delta (class error
    // plus the error from the parent), not the class-only delta.
    INDArray binaryTransform = getBinaryTransform(leftCategory, rightCategory);
    INDArray transformGradient = deltaFull.mmul(childrenVector.transpose());
    binaryTD.put(
        leftCategory,
        rightCategory,
        binaryTD.get(leftCategory, rightCategory).add(transformGradient));

    INDArray deltaDown;
    if (useDoubleTensors) {
      INDArray tensorGradient = getINDArrayGradient(deltaFull, leftVector, rightVector);
      binaryINDArrayTD.put(
          leftCategory,
          rightCategory,
          binaryINDArrayTD.get(leftCategory, rightCategory).add(tensorGradient));
      deltaDown =
          computeINDArrayDeltaDown(
              deltaFull,
              leftVector,
              rightVector,
              binaryTransform,
              getBinaryINDArray(leftCategory, rightCategory));
    } else {
      deltaDown = binaryTransform.transpose().mmul(deltaFull);
    }

    INDArray leftDerivative =
        Nd4j.getExecutioner()
            .execAndReturn(
                Nd4j.getOpFactory().createTransform(activationFunction, leftVector.dup()));
    INDArray rightDerivative =
        Nd4j.getExecutioner()
            .execAndReturn(
                Nd4j.getOpFactory().createTransform(activationFunction, rightVector.dup()));

    // deltaDown stacks the left-child delta on top of the right-child delta.
    int hiddenRows = deltaFull.rows();
    INDArray leftDeltaDown = deltaDown.get(interval(0, hiddenRows), interval(0, 1));
    INDArray rightDeltaDown = deltaDown.get(interval(hiddenRows, hiddenRows * 2), interval(0, 1));
    backpropDerivativesAndError(
        leftChild,
        binaryTD,
        binaryCD,
        binaryINDArrayTD,
        unaryCD,
        wordVectorD,
        leftDerivative.mul(leftDeltaDown));
    backpropDerivativesAndError(
        rightChild,
        binaryTD,
        binaryCD,
        binaryINDArrayTD,
        unaryCD,
        wordVectorD,
        rightDerivative.mul(rightDeltaDown));
  }