Beispiel #1
0
  /**
   * Assigns node vectors and class predictions to the nodes of {@code tree}, bottom-up.
   *
   * <p>This is the method to call for labelling a tree. A pre-terminal takes its vector from the
   * activated feature vector of its word. If the word has no feature vector, it uses the one stored
   * under {@code UNKNOWN_FEATURE}. A binary node is computed from its two children:
   * {@code f(W * [left; right; bias])}. When {@code useFloatTensors} is set, the bilinear tensor
   * term is added before the activation. After the call, every pre-terminal and binary node in the
   * tree has its vector and its prediction (the output activation of its classification matrix)
   * set.
   *
   * @param tree a binarized tree; every node above the pre-terminals must have exactly two children
   * @throws AssertionError if recursion reaches a leaf, a non-pre-terminal unary node, or a node
   *     with more than two children
   */
  public void forwardPropagateTree(Tree tree) {
    FloatMatrix nodeVector;
    FloatMatrix classification;

    if (tree.isLeaf()) {
      // We do nothing for the leaves. The preterminals will
      // calculate the classification for this word/tag. In fact, the
      // recursion should not have gotten here (unless there are
      // degenerate trees of just one leaf)
      throw new AssertionError("We should not have reached leaves in forwardPropagate");
    } else if (tree.isPreTerminal()) {
      classification = getUnaryClassification(tree.label());
      String word = tree.children().get(0).value();
      FloatMatrix wordVector = getFeatureVector(word);
      if (wordVector == null) {
        // Out-of-vocabulary words share the unknown-word vector.
        wordVector = featureVectors.get(UNKNOWN_FEATURE);
      }
      nodeVector = activationFunction.apply(wordVector);
    } else if (tree.children().size() == 1) {
      throw new AssertionError(
          "Non-preterminal nodes of size 1 should have already been collapsed");
    } else if (tree.children().size() == 2) {
      Tree left = tree.children().get(0);
      Tree right = tree.children().get(1);
      // Children must be labelled first: their vectors are the inputs to this node.
      forwardPropagateTree(left);
      forwardPropagateTree(right);

      String leftCategory = left.label();
      String rightCategory = right.label();
      FloatMatrix W = getBinaryTransform(leftCategory, rightCategory);
      classification = getBinaryClassification(leftCategory, rightCategory);

      FloatMatrix leftVector = left.vector();
      FloatMatrix rightVector = right.vector();
      FloatMatrix childrenVector = appendBias(leftVector, rightVector);

      FloatMatrix preActivation = W.mmul(childrenVector);
      if (useFloatTensors) {
        FloatTensor floatT = getBinaryFloatTensor(leftCategory, rightCategory);
        // The tensor slices are (2n x 2n), so the two child column vectors (assumed n x 1) must be
        // stacked into a single 2n x 1 column, not placed side by side.
        FloatMatrix floatTensorIn = FloatMatrix.concatVertically(leftVector, rightVector);
        FloatMatrix floatTensorOut = floatT.bilinearProducts(floatTensorIn);
        preActivation = preActivation.add(floatTensorOut);
      }
      nodeVector = activationFunction.apply(preActivation);
    } else {
      throw new AssertionError("Tree not correctly binarized");
    }

    FloatMatrix inputWithBias = appendBias(nodeVector);
    FloatMatrix preAct = classification.mmul(inputWithBias);
    FloatMatrix predictions = outputActivation.apply(preAct);

    tree.setPrediction(predictions);
    tree.setVector(nodeVector);
  }
Beispiel #2
0
 /**
  * Computes the gradient of the bilinear tensor term with respect to the tensor.
  *
  * <p>Slice {@code k} of the result is {@code deltaFull[k] * x * x^T}. Here {@code x} is
  * {@code leftVector} stacked on top of {@code rightVector} as a single column.
  *
  * @param deltaFull error signal at the parent node; its length {@code n} is the number of slices
  * @param leftVector left child vector (assumed to be an n x 1 column vector — TODO confirm)
  * @param rightVector right child vector (assumed to be an n x 1 column vector — TODO confirm)
  * @return a tensor of shape (2n, 2n, n) holding the per-slice gradients
  */
 private FloatTensor getFloatTensorGradient(
     FloatMatrix deltaFull, FloatMatrix leftVector, FloatMatrix rightVector) {
   int size = deltaFull.length;
   FloatTensor Wt_df = new FloatTensor(size * 2, size * 2, size);
   // Stack vertically so that x is 2n x 1 and x * x^T is 2n x 2n, matching the slice shape.
   FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
   // x * x^T does not depend on the slice, so compute it once outside the loop.
   FloatMatrix outerProduct = fullVector.mmul(fullVector.transpose());
   for (int slice = 0; slice < size; ++slice) {
     // mul(float) returns a fresh matrix and leaves outerProduct untouched for the next slice.
     // An in-place scale (e.g. SimpleBlas.scal) would compound the deltas across slices.
     Wt_df.setSlice(slice, outerProduct.mul(deltaFull.get(slice)));
   }
   return Wt_df;
 }
Beispiel #3
0
 /**
  * Scales every tensor gradient and adds the L2 regularization gradient to it.
  *
  * <p>For each (first, second) key pair in {@code currentMatrices}, with current tensor {@code T}
  * and stored derivative {@code D}, this stores {@code D.scale(scale).add(T.scale(regCost))} back
  * into {@code derivatives} under the same keys.
  *
  * @param derivatives gradient tensors, updated in place; presumably holds an entry for every key
  *     pair of {@code currentMatrices} — TODO confirm with callers
  * @param currentMatrices current parameter tensors
  * @param scale factor applied to each stored derivative
  * @param regCost L2 regularization strength
  * @return the regularization cost, accumulated as {@code sum(T .* T) * regCost / 2} per tensor
  */
 float scaleAndRegularizeFloatTensor(
     MultiDimensionalMap<String, String, FloatTensor> derivatives,
     MultiDimensionalMap<String, String, FloatTensor> currentMatrices,
     float scale,
     float regCost) {
   float regularizationCost = 0.0f;
   for (MultiDimensionalMap.Entry<String, String, FloatTensor> current :
       currentMatrices.entrySet()) {
     String first = current.getFirstKey();
     String second = current.getSecondKey();
     FloatTensor params = current.getValue();

     FloatTensor scaledGradient = derivatives.get(first, second).scale(scale);
     FloatTensor penaltyGradient = params.scale(regCost);
     derivatives.put(first, second, scaledGradient.add(penaltyGradient));

     regularizationCost += params.mul(params).sum() * regCost / 2.0;
   }
   return regularizationCost;
 }
Beispiel #4
0
 /**
  * Computes the error that a binary node passes down to its two children.
  *
  * <p>The result is the sum of two parts:
  *
  * <ul>
  *   <li>a linear part, {@code W^T * deltaFull} with its trailing bias row removed;
  *   <li>a tensor part, {@code sum_k deltaFull[k] * (Wt_k + Wt_k^T) * x}.
  * </ul>
  *
  * <p>Here {@code x} is {@code leftVector} stacked on top of {@code rightVector} as one column.
  *
  * @param deltaFull error at the parent node (assumed n x 1 column vector)
  * @param leftVector left child vector (assumed n x 1 column vector — TODO confirm)
  * @param rightVector right child vector (assumed n x 1 column vector — TODO confirm)
  * @param W binary transform; presumably n x (2n + 1) with the bias column last — confirm
  * @param Wt binary tensor with n slices of size 2n x 2n
  * @return a 2n x 1 column; given the stacking order, rows [0, n) presumably belong to the left
  *     child and rows [n, 2n) to the right child
  */
 private FloatMatrix computeFloatTensorDeltaDown(
     FloatMatrix deltaFull,
     FloatMatrix leftVector,
     FloatMatrix rightVector,
     FloatMatrix W,
     FloatTensor Wt) {
   FloatMatrix WTDelta = W.transpose().mmul(deltaFull);
   // W^T * delta is a single column. Keep rows [0, 2n) and drop the trailing bias row.
   // FloatMatrix.get(Range, Range) takes the row range first and the column range second.
   FloatMatrix WTDeltaNoBias = WTDelta.get(interval(0, deltaFull.rows * 2), interval(0, 1));
   int size = deltaFull.length;
   FloatMatrix deltaFloatTensor = new FloatMatrix(size * 2, 1);
   // Stack vertically so that x is 2n x 1, matching the 2n x 2n tensor slices.
   FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
   for (int slice = 0; slice < size; ++slice) {
     FloatMatrix sliceMatrix = Wt.getSlice(slice);
     // mul(float) allocates a new vector, so fullVector keeps its original values for each slice.
     FloatMatrix scaledFullVector = fullVector.mul(deltaFull.get(slice));
     deltaFloatTensor.addi(sliceMatrix.add(sliceMatrix.transpose()).mmul(scaledFullVector));
   }
   return deltaFloatTensor.add(WTDeltaNoBias);
 }
Beispiel #5
0
 /**
  * Creates a randomly initialised binary tensor of shape (2 * numHidden, 2 * numHidden, numHidden).
  *
  * <p>Entries are drawn with {@code FloatTensor.rand} between {@code -r} and {@code r} using
  * {@code rng}, where {@code r = 1 / (4 * numHidden)}. The distribution is presumably uniform;
  * confirm against {@code FloatTensor.rand}. The tensor is then scaled by {@code scalingForInit}.
  */
 FloatTensor randomBinaryFloatTensor() {
   float bound = 1.0f / (4.0f * numHidden);
   int doubledHidden = numHidden * 2;
   FloatTensor initial =
       FloatTensor.rand(doubledHidden, doubledHidden, numHidden, -bound, bound, rng);
   return initial.scale(scalingForInit);
 }