/**
   * Assigns node vectors and class predictions to every non-leaf node of {@code tree}, working
   * bottom-up. After this call each non-leaf node has had {@code setVector} called with its node
   * vector and {@code setPrediction} called with the classifier output for that vector.
   *
   * <p>Pre-terminals use the feature vector of their single child word. If the word has no feature
   * vector, they fall back to the {@code UNKNOWN_FEATURE} vector. Binary nodes first recurse into
   * both children. They then combine the children's vectors through the binary transform for the
   * children's label pair, plus the bilinear tensor term when {@code useFloatTensors} is set.
   *
   * @param tree a binarized tree in which single-child non-pre-terminal nodes have been collapsed
   * @throws AssertionError if a leaf is reached, a non-pre-terminal node has exactly one child, or a
   *     node has more than two children
   */
  public void forwardPropagateTree(Tree tree) {
    FloatMatrix nodeVector;
    FloatMatrix classification;

    if (tree.isLeaf()) {
      // Leaves are handled by their pre-terminal parent, so recursion should never reach them
      // (unless the input is a degenerate tree consisting of a single leaf).
      throw new AssertionError("We should not have reached leaves in forwardPropagate");
    } else if (tree.isPreTerminal()) {
      classification = getUnaryClassification(tree.label());
      String word = tree.children().get(0).value();
      FloatMatrix wordVector = getFeatureVector(word);
      if (wordVector == null) {
        wordVector = featureVectors.get(UNKNOWN_FEATURE);
      }
      nodeVector = activationFunction.apply(wordVector);
    } else if (tree.children().size() == 1) {
      throw new AssertionError(
          "Non-preterminal nodes of size 1 should have already been collapsed");
    } else if (tree.children().size() == 2) {
      Tree left = tree.firstChild();
      Tree right = tree.lastChild();
      forwardPropagateTree(left);
      forwardPropagateTree(right);

      String leftCategory = left.label();
      String rightCategory = right.label();
      FloatMatrix W = getBinaryTransform(leftCategory, rightCategory);
      classification = getBinaryClassification(leftCategory, rightCategory);

      FloatMatrix leftVector = left.vector();
      FloatMatrix rightVector = right.vector();
      FloatMatrix childrenVector = appendBias(leftVector, rightVector);
      FloatMatrix preActivation = W.mmul(childrenVector);

      if (useFloatTensors) {
        FloatTensor floatT = getBinaryFloatTensor(leftCategory, rightCategory);
        // Node vectors are column vectors here (they feed W.mmul / classification.mmul). The
        // tensor input is therefore the two children stacked vertically into one
        // (2 * size) x 1 column. Concatenating horizontally would give a size x 2 matrix.
        FloatMatrix floatTensorIn = FloatMatrix.concatVertically(leftVector, rightVector);
        FloatMatrix floatTensorOut = floatT.bilinearProducts(floatTensorIn);
        preActivation = preActivation.add(floatTensorOut);
      }
      nodeVector = activationFunction.apply(preActivation);
    } else {
      throw new AssertionError("Tree not correctly binarized");
    }

    // Classify the node from its vector (with a bias term appended).
    FloatMatrix inputWithBias = appendBias(nodeVector);
    FloatMatrix preAct = classification.mmul(inputWithBias);
    FloatMatrix predictions = outputActivation.apply(preAct);

    tree.setPrediction(predictions);
    tree.setVector(nodeVector);
  }
 /**
  * Computes the gradient of the bilinear tensor term with respect to the tensor weights.
  *
  * <p>Slice {@code k} of the result is {@code deltaFull[k] * v * v^T}. Here {@code v} is the
  * column formed by stacking {@code leftVector} on top of {@code rightVector}.
  *
  * @param deltaFull per-output error signal; its length determines {@code size}
  * @param leftVector the left child's node vector (column)
  * @param rightVector the right child's node vector (column)
  * @return a tensor built as {@code FloatTensor(2 * size, 2 * size, size)}, one slice per element of
  *     {@code deltaFull}
  */
 private FloatTensor getFloatTensorGradient(
     FloatMatrix deltaFull, FloatMatrix leftVector, FloatMatrix rightVector) {
   int size = deltaFull.length;
   FloatTensor Wt_df = new FloatTensor(size * 2, size * 2, size);
   // Stack the two column vectors so that v * v^T is (2 * size) x (2 * size), matching a slice.
   FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
   // The outer product is identical for every slice; only its scale factor differs.
   FloatMatrix outerProduct = fullVector.mmul(fullVector.transpose());
   for (int slice = 0; slice < size; ++slice) {
     // mul(float) returns a new matrix. SimpleBlas.scal would scale its argument in place, which
     // squared the factor within a slice and compounded it across slices.
     Wt_df.setSlice(slice, outerProduct.mul(deltaFull.get(slice)));
   }
   return Wt_df;
 }
 /**
  * Computes the error signal sent down to the children through both the linear transform
  * {@code W} and the bilinear tensor {@code Wt}.
  *
  * <p>The result is {@code (W^T deltaFull)[0 .. 2*size) + sum_k deltaFull[k] * (Wt_k + Wt_k^T) v}.
  * Here {@code v} stacks {@code leftVector} on top of {@code rightVector}, and only the first
  * {@code 2 * size} rows of {@code W^T deltaFull} are kept.
  *
  * @param deltaFull per-output error signal (column); its length determines {@code size}
  * @param leftVector the left child's node vector (column)
  * @param rightVector the right child's node vector (column)
  * @param W the binary transform used in the forward pass
  * @param Wt the bilinear tensor used in the forward pass, one slice per element of deltaFull
  * @return a {@code (2 * size) x 1} column
  */
 private FloatMatrix computeFloatTensorDeltaDown(
     FloatMatrix deltaFull,
     FloatMatrix leftVector,
     FloatMatrix rightVector,
     FloatMatrix W,
     FloatTensor Wt) {
   int size = deltaFull.length;
   FloatMatrix WTDelta = W.transpose().mmul(deltaFull);
   // WTDelta is a single column: keep rows [0, 2 * size) of column 0 and drop what remains
   // (presumably the bias row added by appendBias). Range order is (rows, columns).
   FloatMatrix WTDeltaNoBias = WTDelta.get(interval(0, size * 2), interval(0, 1));
   FloatMatrix deltaFloatTensor = new FloatMatrix(size * 2, 1);
   FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
   for (int slice = 0; slice < size; ++slice) {
     // mul(float) leaves fullVector untouched (SimpleBlas.scal would scale it in place and the
     // scaling would accumulate across iterations).
     FloatMatrix scaledFullVector = fullVector.mul(deltaFull.get(slice));
     FloatMatrix sliceMatrix = Wt.getSlice(slice);
     // deltaFloatTensor is local, so accumulating in place is safe.
     deltaFloatTensor.addi(sliceMatrix.add(sliceMatrix.transpose()).mmul(scaledFullVector));
   }
   return deltaFloatTensor.addi(WTDeltaNoBias);
 }
 /**
  * Appends {@code feature} after the existing features of a single example (row).
  *
  * <p>The example's current row is concatenated horizontally with {@code feature}, and the result
  * is written back into that row of the matrix returned by {@code getFirst()}.
  *
  * <p>NOTE(review): the concatenated row is wider than the matrix it is written into, and this
  * method never resizes that matrix. It is therefore unclear how the appended values can be stored.
  * Depending on {@code putRow}'s behavior, they may be silently dropped or an exception may be
  * thrown. Confirm the intended semantics.
  *
  * @param feature the feature values to append to the example's row
  * @param example the index of the example (row) to append to
  */
 public void addFeatureVector(FloatMatrix feature, int example) {
   FloatMatrix features = getFirst();
   FloatMatrix currentRow = features.getRow(example);
   FloatMatrix extendedRow = FloatMatrix.concatHorizontally(currentRow, feature);
   features.putRow(example, extendedRow);
 }
 /**
  * Widens the feature matrix by appending the columns of {@code toAdd} to the right of the
  * current features, then stores the widened matrix via {@code setFirst}.
  *
  * @param toAdd the columns to append; presumably one row per example, since
  *     {@code concatHorizontally} requires matching row counts
  */
 public void addFeatureVector(FloatMatrix toAdd) {
   FloatMatrix existing = getFirst();
   FloatMatrix widened = FloatMatrix.concatHorizontally(existing, toAdd);
   setFirst(widened);
 }