Beispiel #1
0
  /**
   * Scales every accumulated derivative and adds the L2 regularization gradient, returning the
   * regularization cost.
   *
   * <p>For each parameter matrix {@code W} in {@code currentMatrices}, the matching derivative
   * {@code D} is replaced in {@code derivatives} by {@code scale * D + regCost * W}, and
   * {@code regCost / 2 * sum(W .* W)} is added to the returned cost. The parameter matrices are
   * only read, never modified.
   *
   * @param derivatives accumulated derivatives, keyed like {@code currentMatrices}; entries are
   *     replaced with the scaled and regularized derivative
   * @param currentMatrices current parameter matrices
   * @param scale factor applied to each derivative (e.g. one over the batch size)
   * @param regCost L2 regularization strength
   * @return the summed regularization cost over all matrices
   */
  double scaleAndRegularize(
      MultiDimensionalMap<String, String, INDArray> derivatives,
      MultiDimensionalMap<String, String, INDArray> currentMatrices,
      double scale,
      double regCost) {
    double cost = 0.0; // the regularization cost
    for (MultiDimensionalMap.Entry<String, String, INDArray> entry : currentMatrices.entrySet()) {
      INDArray weights = entry.getValue();
      INDArray D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
      // mul() returns a new array, so the parameter matrix is NOT rescaled by regCost (an in-place
      // scale such as BlasWrapper.scal would overwrite the model parameters). Scaling D in place
      // is fine because D is replaced in the map anyway.
      D = D.muli(scale).addi(weights.mul(regCost));
      derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D);
      cost += weights.mul(weights).sum(Integer.MAX_VALUE).getDouble(0) * regCost / 2.0;
    }
    return cost;
  }
Beispiel #2
0
 /**
  * Applies batch scaling and L2 regularization to the tensor derivatives and returns the
  * regularization cost.
  *
  * <p>Each derivative {@code D} paired with a parameter tensor {@code W} is replaced in
  * {@code derivatives} by {@code D.scale(scale).add(W.scale(regCost))}; the returned cost
  * accumulates {@code sum(W .* W) * regCost / 2} over all tensors.
  *
  * <p>NOTE(review): if {@code FloatTensor.scale} works in place, {@code W.scale(regCost)} also
  * rescales the parameter tensor itself — confirm against the FloatTensor implementation.
  *
  * @param derivatives tensor derivatives keyed like {@code currentMatrices}; entries are replaced
  * @param currentMatrices current parameter tensors
  * @param scale factor applied to each derivative
  * @param regCost regularization strength
  * @return the accumulated regularization cost
  */
 float scaleAndRegularizeFloatTensor(
     MultiDimensionalMap<String, String, FloatTensor> derivatives,
     MultiDimensionalMap<String, String, FloatTensor> currentMatrices,
     float scale,
     float regCost) {
   float total = 0.0f;
   for (MultiDimensionalMap.Entry<String, String, FloatTensor> e : currentMatrices.entrySet()) {
     String first = e.getFirstKey();
     String second = e.getSecondKey();
     FloatTensor weights = e.getValue();

     FloatTensor derivative = derivatives.get(first, second);
     FloatTensor scaled = derivative.scale(scale);
     FloatTensor penalty = weights.scale(regCost);
     derivatives.put(first, second, scaled.add(penalty));

     // Evaluated in double and narrowed back to float by the compound assignment.
     total += weights.mul(weights).sum() * regCost / 2.0;
   }
   return total;
 }
Beispiel #3
0
 /**
  * Scales each accumulated tensor derivative and adds the L2 regularization gradient, returning
  * the regularization cost.
  *
  * <p>For each parameter tensor {@code W} in {@code currentMatrices}, the matching derivative
  * {@code D} is replaced in {@code derivatives} by {@code scale * D + regCost * W}, and
  * {@code regCost / 2 * sum(W .* W)} is added to the returned cost. The parameter tensors are
  * only read, never modified.
  *
  * @param derivatives tensor derivatives keyed like {@code currentMatrices}; entries are replaced
  * @param currentMatrices current parameter tensors
  * @param scale factor applied to each derivative (e.g. one over the batch size)
  * @param regCost L2 regularization strength
  * @return the summed regularization cost over all tensors
  */
 double scaleAndRegularizeINDArray(
     MultiDimensionalMap<String, String, INDArray> derivatives,
     MultiDimensionalMap<String, String, INDArray> currentMatrices,
     double scale,
     double regCost) {
   double cost = 0.0; // the regularization cost
   for (MultiDimensionalMap.Entry<String, String, INDArray> entry : currentMatrices.entrySet()) {
     INDArray weights = entry.getValue();
     INDArray D = derivatives.get(entry.getFirstKey(), entry.getSecondKey());
     // mul() (not muli()) so the parameter tensor itself is not rescaled by regCost; the cost
     // below must also be computed on the unmodified parameters.
     D = D.muli(scale).addi(weights.mul(regCost));
     derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D);
     cost += weights.mul(weights).sum(Integer.MAX_VALUE).getDouble(0) * regCost / 2.0;
   }
   return cost;
 }
Beispiel #4
0
  /**
   * Computes the AdaGrad-adjusted gradient of the training objective over {@code trainingTrees}.
   *
   * <p>Every training tree is copied and forward-propagated, then back-propagated to accumulate
   * derivatives for the binary transforms, binary classifiers, binary tensors (when
   * {@code useFloatTensors}), unary classifiers and word vectors. The error and derivatives are
   * scaled by one over the number of training trees (1 when there are none) and regularized; the
   * resulting cost is stored in {@code value}.
   *
   * @param iterations currently unused
   * @return the flattened derivative, multiplied element-wise by the AdaGrad learning rates
   */
  public FloatMatrix getValueGradient(int iterations) {

    // We use TreeMap for each of these so that they stay in a
    // canonical sorted order
    // TODO: factor out the initialization routines
    // binaryTD stands for Transform Derivatives
    final MultiDimensionalMap<String, String, FloatMatrix> binaryTD =
        MultiDimensionalMap.newTreeBackedMap();
    // the derivatives of the FloatTensors for the binary nodes
    final MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD =
        MultiDimensionalMap.newTreeBackedMap();
    // binaryCD stands for Classification Derivatives
    final MultiDimensionalMap<String, String, FloatMatrix> binaryCD =
        MultiDimensionalMap.newTreeBackedMap();

    // unaryCD stands for Classification Derivatives
    final Map<String, FloatMatrix> unaryCD = new TreeMap<>();

    // word vector derivatives
    final Map<String, FloatMatrix> wordVectorD = new TreeMap<>();

    // Each derivative starts as a zero matrix with the dimensions of its parameter matrix.
    for (MultiDimensionalMap.Entry<String, String, FloatMatrix> entry :
        binaryTransform.entrySet()) {
      FloatMatrix w = entry.getValue();
      binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), new FloatMatrix(w.rows, w.columns));
    }

    if (!combineClassification) {
      for (MultiDimensionalMap.Entry<String, String, FloatMatrix> entry :
          binaryClassification.entrySet()) {
        FloatMatrix w = entry.getValue();
        binaryCD.put(
            entry.getFirstKey(), entry.getSecondKey(), new FloatMatrix(w.rows, w.columns));
      }
    }

    if (useFloatTensors) {
      for (MultiDimensionalMap.Entry<String, String, FloatTensor> entry :
          binaryFloatTensors.entrySet()) {
        FloatTensor t = entry.getValue();
        binaryFloatTensorTD.put(
            entry.getFirstKey(),
            entry.getSecondKey(),
            new FloatTensor(t.rows(), t.columns, t.slices()));
      }
    }

    for (Map.Entry<String, FloatMatrix> entry : unaryClassification.entrySet()) {
      FloatMatrix w = entry.getValue();
      unaryCD.put(entry.getKey(), new FloatMatrix(w.rows, w.columns));
    }
    for (Map.Entry<String, FloatMatrix> entry : featureVectors.entrySet()) {
      FloatMatrix w = entry.getValue();
      wordVectorD.put(entry.getKey(), new FloatMatrix(w.rows, w.columns));
    }

    final List<Tree> forwardPropTrees = new CopyOnWriteArrayList<>();
    Parallelization.iterateInParallel(
        trainingTrees,
        new Parallelization.RunnableWithParams<Tree>() {

          @Override
          public void run(Tree currentItem, Object[] args) {
            Tree trainingTree = new Tree(currentItem);
            trainingTree.connect(new ArrayList<>(currentItem.children()));
            // this will attach the error vectors and the node vectors
            // to each node in the tree
            forwardPropagateTree(trainingTree);
            forwardPropTrees.add(trainingTree);
          }
        },
        rnTnActorSystem);

    // TODO: we may find a big speedup by separating the derivatives and then summing
    // NOTE(review): backprop writes into the shared TreeMaps above; confirm that Parallelization
    // does not run these updates concurrently.
    final AtomicDouble error = new AtomicDouble(0);
    if (!forwardPropTrees.isEmpty()) {
      Parallelization.iterateInParallel(
          forwardPropTrees,
          new Parallelization.RunnableWithParams<Tree>() {

            @Override
            public void run(Tree currentItem, Object[] args) {
              backpropDerivativesAndError(
                  currentItem, binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD);
              error.addAndGet(currentItem.errorSum());
            }
          },
          new Parallelization.RunnableWithParams<Tree>() {

            @Override
            public void run(Tree currentItem, Object[] args) {}
          },
          rnTnActorSystem,
          new Object[] {binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD});
    }

    // scale the error by the number of sentences so that the
    // regularization isn't drowned out for large training batches;
    // an empty batch uses 1 instead of dividing by zero (which would make value NaN)
    float scale =
        trainingTrees == null || trainingTrees.isEmpty() ? 1.0f : (1.0f / trainingTrees.size());
    value = error.floatValue() * scale;

    value += scaleAndRegularize(binaryTD, binaryTransform, scale, regTransformMatrix);
    value += scaleAndRegularize(binaryCD, binaryClassification, scale, regClassification);
    value +=
        scaleAndRegularizeFloatTensor(
            binaryFloatTensorTD, binaryFloatTensors, scale, regTransformFloatTensor);
    value += scaleAndRegularize(unaryCD, unaryClassification, scale, regClassification);
    value += scaleAndRegularize(wordVectorD, featureVectors, scale, regWordVector);

    FloatMatrix derivative =
        MatrixUtil.toFlattenedFloat(
            getNumParameters(),
            binaryTD.values().iterator(),
            binaryCD.values().iterator(),
            binaryFloatTensorTD.values().iterator(),
            unaryCD.values().iterator(),
            wordVectorD.values().iterator());

    // AdaGrad state is created lazily, sized to the flattened gradient.
    if (paramAdaGrad == null) {
      paramAdaGrad = new AdaGradFloat(1, derivative.columns);
    }

    derivative.muli(paramAdaGrad.getLearningRates(derivative));

    return derivative;
  }
Beispiel #5
0
  /**
   * Back-propagates the error of {@code tree} and its subtrees, accumulating derivatives in the
   * supplied maps.
   *
   * <p>Leaves are ignored. For every internal node, the class-weighted cross-entropy of
   * {@code tree.prediction()} against the gold label is stored with {@code tree.setError}, and the
   * classification derivative is added to {@code unaryCD} or {@code binaryCD}. Pre-terminal nodes
   * add their delta to the word-vector derivative. Binary nodes add to the transform (and, when
   * {@code useFloatTensors}, tensor) derivatives and recurse into both children with the delta
   * pushed down through the transform.
   *
   * @param tree the subtree to back-propagate; it must already carry node vectors and
   *     predictions (read via {@code tree.vector()} / {@code tree.prediction()})
   * @param binaryTD binary transform derivatives, updated in place
   * @param binaryCD binary classification derivatives, updated in place
   * @param binaryFloatTensorTD binary tensor derivatives, updated in place
   * @param unaryCD unary classification derivatives, updated in place
   * @param wordVectorD word vector derivatives, updated in place
   * @param deltaUp delta arriving from the parent node (numHidden x 1)
   */
  private void backpropDerivativesAndError(
      Tree tree,
      MultiDimensionalMap<String, String, FloatMatrix> binaryTD,
      MultiDimensionalMap<String, String, FloatMatrix> binaryCD,
      MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD,
      Map<String, FloatMatrix> unaryCD,
      Map<String, FloatMatrix> wordVectorD,
      FloatMatrix deltaUp) {
    if (tree.isLeaf()) {
      return;
    }

    FloatMatrix currentVector = tree.vector();
    String category = basicCategory(tree.label());

    // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class
    FloatMatrix goldLabel = new FloatMatrix(numOuts, 1);
    int goldClass = tree.goldLabel();
    if (goldClass >= 0) {
      goldLabel.put(goldClass, 1.0f);
    }

    Float nodeWeight = classWeights.get(goldClass);
    if (nodeWeight == null) {
      nodeWeight = 1.0f;
    }
    FloatMatrix predictions = tree.prediction();

    // If this is an unlabeled class, set deltaClass to 0.  We could
    // make this more efficient by eliminating various of the below
    // calculations, but this would be the easiest way to handle the
    // unlabeled class
    FloatMatrix deltaClass =
        goldClass >= 0
            ? SimpleBlas.scal(nodeWeight, predictions.sub(goldLabel))
            : new FloatMatrix(predictions.rows, predictions.columns);
    FloatMatrix localCD = deltaClass.mmul(appendBias(currentVector).transpose());

    float error = -(MatrixFunctions.log(predictions).muli(goldLabel).sum());
    error = error * nodeWeight;
    tree.setError(error);

    if (tree.isPreTerminal()) { // below us is a word vector
      unaryCD.put(category, unaryCD.get(category).add(localCD));

      String word = getVocabWord(tree.children().get(0).label());

      // Activation derivative (apply() would return the activation value, not its derivative).
      // NOTE(review): currentVector is the node's stored vector; confirm applyDerivative expects
      // that value as its input.
      FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector);
      FloatMatrix deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass);
      deltaFromClass =
          deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative);
      FloatMatrix deltaFull = deltaFromClass.add(deltaUp);

      // Start a fresh accumulator if this word has no derivative entry yet.
      FloatMatrix oldWordVectorD = wordVectorD.get(word);
      wordVectorD.put(word, oldWordVectorD == null ? deltaFull : oldWordVectorD.add(deltaFull));

    } else {
      // Otherwise, this must be a binary node
      Tree leftChild = tree.children().get(0);
      Tree rightChild = tree.children().get(1);
      String leftCategory = basicCategory(leftChild.label());
      String rightCategory = basicCategory(rightChild.label());
      if (combineClassification) {
        unaryCD.put("", unaryCD.get("").add(localCD));
      } else {
        binaryCD.put(
            leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD));
      }

      FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector);
      FloatMatrix deltaFromClass =
          getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass);

      FloatMatrix mult = deltaFromClass.get(interval(0, numHidden), interval(0, 1));
      deltaFromClass = mult.muli(currentVectorDerivative);
      // Total delta at this node: classification delta plus the delta from the parent.
      FloatMatrix deltaFull = deltaFromClass.add(deltaUp);

      FloatMatrix leftVector = leftChild.vector();
      FloatMatrix rightVector = rightChild.vector();

      FloatMatrix childrenVector = appendBias(leftVector, rightVector);

      // The transform gradient must use the full delta (including deltaUp); using only the
      // classification delta would drop everything propagated from the parent.
      FloatMatrix W_df = deltaFull.mmul(childrenVector.transpose());
      binaryTD.put(leftCategory, rightCategory, binaryTD.get(leftCategory, rightCategory).add(W_df));

      FloatMatrix deltaDown;
      if (useFloatTensors) {
        FloatTensor Wt_df = getFloatTensorGradient(deltaFull, leftVector, rightVector);
        binaryFloatTensorTD.put(
            leftCategory,
            rightCategory,
            binaryFloatTensorTD.get(leftCategory, rightCategory).add(Wt_df));
        deltaDown =
            computeFloatTensorDeltaDown(
                deltaFull,
                leftVector,
                rightVector,
                getBinaryTransform(leftCategory, rightCategory),
                getBinaryFloatTensor(leftCategory, rightCategory));
      } else {
        deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull);
      }

      // Split deltaDown into the left half and right half, one per child, each
      // deltaFull.rows long, and weight each by that child's activation derivative.
      FloatMatrix leftDerivative = activationFunction.applyDerivative(leftVector);
      FloatMatrix rightDerivative = activationFunction.applyDerivative(rightVector);
      FloatMatrix leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows), interval(0, 1));
      FloatMatrix rightDeltaDown =
          deltaDown.get(interval(deltaFull.rows, deltaFull.rows * 2), interval(0, 1));
      backpropDerivativesAndError(
          leftChild,
          binaryTD,
          binaryCD,
          binaryFloatTensorTD,
          unaryCD,
          wordVectorD,
          leftDerivative.mul(leftDeltaDown));
      backpropDerivativesAndError(
          rightChild,
          binaryTD,
          binaryCD,
          binaryFloatTensorTD,
          unaryCD,
          wordVectorD,
          rightDerivative.mul(rightDeltaDown));
    }
  }
Beispiel #6
0
  /**
   * Initializes the random generator (if unset) and all model parameters: the binary transforms,
   * the binary tensors (when {@code useFloatTensors}), the binary classifiers (unless
   * {@code combineClassification}), the unary classifiers and the unknown-word feature vector.
   * It also records the derived parameter-block sizes.
   *
   * @throws UnsupportedOperationException if {@code simplifiedModel} is false (not yet
   *     implemented)
   */
  private void init() {

    if (rng == null) {
      rng = new MersenneTwister(123);
    }

    // TODO: for the non-simplified model, figure out which binary productions and which unary
    // productions (preterminals only, after the collapsing) occur in the trees.
    // Note: the current sentiment training data does not actually have any constituent labels.
    if (!simplifiedModel) {
      throw new UnsupportedOperationException("Not yet implemented");
    }

    // In the simplified model every production uses the single empty-string category.
    MultiDimensionalSet<String, String> binaryProductions = MultiDimensionalSet.hashSet();
    binaryProductions.add("", "");

    Set<String> unaryProductions = new HashSet<>();
    unaryProductions.add("");

    identity = FloatMatrix.eye(numHidden);

    binaryTransform = MultiDimensionalMap.newTreeBackedMap();
    binaryFloatTensors = MultiDimensionalMap.newTreeBackedMap();
    binaryClassification = MultiDimensionalMap.newTreeBackedMap();

    // When making a flat model (no semantic untying) the
    // basicCategory function will return the same basic category for
    // all labels, so all entries will map to the same matrix
    for (Pair<String, String> binary : binaryProductions) {
      String left = basicCategory(binary.getFirst());
      String right = basicCategory(binary.getSecond());
      if (binaryTransform.contains(left, right)) {
        continue;
      }

      binaryTransform.put(left, right, randomTransformMatrix());
      if (useFloatTensors) {
        binaryFloatTensors.put(left, right, randomBinaryFloatTensor());
      }

      if (!combineClassification) {
        binaryClassification.put(left, right, randomClassificationMatrix());
      }
    }

    numBinaryMatrices = binaryTransform.size();
    // numHidden rows; 2 * numHidden + 1 columns (both children plus a bias entry)
    binaryTransformSize = numHidden * (2 * numHidden + 1);
    // a (2 * numHidden) x (2 * numHidden) x numHidden tensor when tensors are enabled
    binaryFloatTensorSize = useFloatTensors ? numHidden * numHidden * numHidden * 4 : 0;
    binaryClassificationSize = combineClassification ? 0 : numOuts * (numHidden + 1);

    unaryClassification = new TreeMap<>();

    // As above, basicCategory collapses all labels in a flat model, so only
    // one unary classification matrix is created per distinct basic category
    for (String unary : unaryProductions) {
      String basic = basicCategory(unary);
      if (!unaryClassification.containsKey(basic)) {
        unaryClassification.put(basic, randomClassificationMatrix());
      }
    }

    numUnaryMatrices = unaryClassification.size();
    unaryClassificationSize = numOuts * (numHidden + 1);

    featureVectors.put(UNKNOWN_FEATURE, randomWordVector());
    classWeights = new HashMap<>();
  }
Beispiel #7
0
  /**
   * Back-propagates the error of {@code tree} and its subtrees, accumulating derivatives in the
   * supplied maps.
   *
   * <p>Leaves are ignored. For every internal node, the class-weighted cross-entropy of
   * {@code tree.prediction()} against the gold label is stored with {@code tree.setError}, and the
   * classification derivative is added to {@code unaryCD} or {@code binaryCD}. Pre-terminal nodes
   * add their delta to the word-vector derivative. Binary nodes add to the transform (and, when
   * {@code useDoubleTensors}, tensor) derivatives and recurse into both children with the delta
   * pushed down through the transform.
   *
   * @param tree the subtree to back-propagate; it must already carry node vectors and
   *     predictions (read via {@code tree.vector()} / {@code tree.prediction()})
   * @param binaryTD binary transform derivatives, updated in place
   * @param binaryCD binary classification derivatives, updated in place
   * @param binaryINDArrayTD binary tensor derivatives, updated in place
   * @param unaryCD unary classification derivatives, updated in place
   * @param wordVectorD word vector derivatives, updated in place
   * @param deltaUp delta arriving from the parent node (numHidden x 1)
   */
  private void backpropDerivativesAndError(
      Tree tree,
      MultiDimensionalMap<String, String, INDArray> binaryTD,
      MultiDimensionalMap<String, String, INDArray> binaryCD,
      MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD,
      Map<String, INDArray> unaryCD,
      Map<String, INDArray> wordVectorD,
      INDArray deltaUp) {
    if (tree.isLeaf()) {
      return;
    }

    INDArray currentVector = tree.vector();
    String category = basicCategory(tree.label());

    // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class
    INDArray goldLabel = Nd4j.create(numOuts, 1);
    int goldClass = tree.goldLabel();
    if (goldClass >= 0) {
      // valid labels are 0 .. numOuts - 1
      assert goldClass < numOuts
          : "Tried adding a label that was >= to the number of configured outputs "
              + numOuts
              + " with label "
              + goldClass;
      goldLabel.putScalar(goldClass, 1.0f);
    }

    Double nodeWeight = classWeights.get(goldClass);
    if (nodeWeight == null) {
      nodeWeight = 1.0;
    }
    INDArray predictions = tree.prediction();

    // If this is an unlabeled class, transform deltaClass to 0.  We could
    // make this more efficient by eliminating various of the below
    // calculations, but this would be the easiest way to handle the
    // unlabeled class
    INDArray deltaClass;
    if (goldClass < 0) {
      deltaClass = Nd4j.create(predictions.rows(), predictions.columns());
    } else if (predictions.data().dataType() == DataBuffer.Type.DOUBLE) {
      deltaClass = Nd4j.getBlasWrapper().scal(nodeWeight, predictions.sub(goldLabel));
    } else {
      deltaClass =
          Nd4j.getBlasWrapper().scal((float) nodeWeight.doubleValue(), predictions.sub(goldLabel));
    }
    INDArray localCD = deltaClass.mmul(Nd4j.appendBias(currentVector).transpose());

    double error =
        -(Transforms.log(predictions).muli(goldLabel).sum(Integer.MAX_VALUE).getDouble(0));
    error = error * nodeWeight;
    tree.setError(error);

    // Activation derivatives below are computed on dup() copies: a transform op created from x
    // writes its result into x, which would overwrite the node vectors stored on the tree
    // (including the children's vectors that the recursive calls read).
    // NOTE(review): the inputs are the nodes' stored vectors; confirm the derivative op expects
    // those values as input.
    if (tree.isPreTerminal()) { // below us is a word vector
      unaryCD.put(category, unaryCD.get(category).add(localCD));

      String word = getVocabWord(tree.children().get(0).label());

      INDArray currentVectorDerivative =
          Nd4j.getExecutioner()
              .execAndReturn(
                  Nd4j.getOpFactory()
                      .createTransform(activationFunction, currentVector.dup())
                      .derivative());
      INDArray deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass);
      deltaFromClass =
          deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative);
      INDArray deltaFull = deltaFromClass.add(deltaUp);

      // Start a fresh accumulator if this word has no derivative entry yet.
      INDArray wordVector = wordVectorD.get(word);
      wordVectorD.put(word, wordVector == null ? deltaFull : wordVector.add(deltaFull));

    } else {
      // Otherwise, this must be a binary node
      Tree leftChild = tree.children().get(0);
      Tree rightChild = tree.children().get(1);
      String leftCategory = basicCategory(leftChild.label());
      String rightCategory = basicCategory(rightChild.label());
      if (combineClassification) {
        unaryCD.put("", unaryCD.get("").add(localCD));
      } else {
        binaryCD.put(
            leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD));
      }

      INDArray currentVectorDerivative =
          Nd4j.getExecutioner()
              .execAndReturn(
                  Nd4j.getOpFactory()
                      .createTransform(activationFunction, currentVector.dup())
                      .derivative());
      INDArray deltaFromClass =
          getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass);

      INDArray mult = deltaFromClass.get(interval(0, numHidden), interval(0, 1));
      deltaFromClass = mult.muli(currentVectorDerivative);
      // Total delta at this node: classification delta plus the delta from the parent.
      INDArray deltaFull = deltaFromClass.add(deltaUp);

      INDArray leftVector = leftChild.vector();
      INDArray rightVector = rightChild.vector();

      INDArray childrenVector = Nd4j.appendBias(leftVector, rightVector);

      // The transform gradient must use the full delta (including deltaUp); using only the
      // classification delta would drop everything propagated from the parent.
      INDArray W_df = deltaFull.mmul(childrenVector.transpose());
      binaryTD.put(leftCategory, rightCategory, binaryTD.get(leftCategory, rightCategory).add(W_df));

      INDArray deltaDown;
      if (useDoubleTensors) {
        INDArray Wt_df = getINDArrayGradient(deltaFull, leftVector, rightVector);
        binaryINDArrayTD.put(
            leftCategory,
            rightCategory,
            binaryINDArrayTD.get(leftCategory, rightCategory).add(Wt_df));
        deltaDown =
            computeINDArrayDeltaDown(
                deltaFull,
                leftVector,
                rightVector,
                getBinaryTransform(leftCategory, rightCategory),
                getBinaryINDArray(leftCategory, rightCategory));
      } else {
        deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull);
      }

      INDArray leftDerivative =
          Nd4j.getExecutioner()
              .execAndReturn(
                  Nd4j.getOpFactory()
                      .createTransform(activationFunction, leftVector.dup())
                      .derivative());
      INDArray rightDerivative =
          Nd4j.getExecutioner()
              .execAndReturn(
                  Nd4j.getOpFactory()
                      .createTransform(activationFunction, rightVector.dup())
                      .derivative());
      // Split deltaDown into the left half and right half, one per child, each
      // deltaFull.rows() long.
      INDArray leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows()), interval(0, 1));
      INDArray rightDeltaDown =
          deltaDown.get(interval(deltaFull.rows(), deltaFull.rows() * 2), interval(0, 1));
      backpropDerivativesAndError(
          leftChild,
          binaryTD,
          binaryCD,
          binaryINDArrayTD,
          unaryCD,
          wordVectorD,
          leftDerivative.mul(leftDeltaDown));
      backpropDerivativesAndError(
          rightChild,
          binaryTD,
          binaryCD,
          binaryINDArrayTD,
          unaryCD,
          wordVectorD,
          rightDerivative.mul(rightDeltaDown));
    }
  }
Beispiel #8
0
  /**
   * Computes the AdaGrad-adjusted gradient of the training objective for {@code trainingBatch}.
   *
   * <p>Each tree is copied and forward-propagated, then back-propagated to accumulate derivatives
   * for the binary transforms, binary classifiers, binary tensors (when {@code useDoubleTensors}),
   * unary classifiers and word vectors. Error and derivatives are scaled by one over the batch
   * size (1 for a null or empty batch) and regularized; the resulting cost is stored in
   * {@code value}.
   *
   * @param trainingBatch the trees to evaluate
   * @return the flattened derivative after {@code paramAdaGrad.getGradient}
   * @throws IllegalStateException if the flattened gradient length differs from
   *     {@code numParameters}
   */
  public INDArray getValueGradient(final List<Tree> trainingBatch) {

    // We use TreeMap for each of these so that they stay in a
    // canonical sorted order
    // TODO: factor out the initialization routines
    // binaryTD stands for Transform Derivatives
    final MultiDimensionalMap<String, String, INDArray> binaryTD =
        MultiDimensionalMap.newTreeBackedMap();
    // the derivatives of the tensors for the binary nodes
    final MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD =
        MultiDimensionalMap.newTreeBackedMap();
    // binaryCD stands for Classification Derivatives
    final MultiDimensionalMap<String, String, INDArray> binaryCD =
        MultiDimensionalMap.newTreeBackedMap();

    // unaryCD stands for Classification Derivatives
    final Map<String, INDArray> unaryCD = new TreeMap<>();

    // word vector derivatives
    final Map<String, INDArray> wordVectorD = new TreeMap<>();

    // Each derivative starts as a zero array with the dimensions of its parameter.
    for (MultiDimensionalMap.Entry<String, String, INDArray> entry : binaryTransform.entrySet()) {
      INDArray w = entry.getValue();
      binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(w.rows(), w.columns()));
    }

    if (!combineClassification) {
      for (MultiDimensionalMap.Entry<String, String, INDArray> entry :
          binaryClassification.entrySet()) {
        INDArray w = entry.getValue();
        binaryCD.put(
            entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(w.rows(), w.columns()));
      }
    }

    if (useDoubleTensors) {
      for (MultiDimensionalMap.Entry<String, String, INDArray> entry : binaryTensors.entrySet()) {
        // Use the parameter tensor's own shape: listing (size(1), size(2), slices()) orders the
        // dimensions differently from the parameter, and the derivative must be combined with
        // the parameter element-wise when it is regularized.
        binaryINDArrayTD.put(
            entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(entry.getValue().shape()));
      }
    }

    for (Map.Entry<String, INDArray> entry : unaryClassification.entrySet()) {
      INDArray w = entry.getValue();
      unaryCD.put(entry.getKey(), Nd4j.create(w.rows(), w.columns()));
    }

    for (String s : vocabCache.words()) {
      INDArray vector = featureVectors.vector(s);
      wordVectorD.put(s, Nd4j.create(vector.rows(), vector.columns()));
    }

    final List<Tree> forwardPropTrees = new CopyOnWriteArrayList<>();
    Parallelization.iterateInParallel(
        trainingBatch,
        new Parallelization.RunnableWithParams<Tree>() {

          @Override
          public void run(Tree currentItem, Object[] args) {
            Tree trainingTree = new Tree(currentItem);
            trainingTree.connect(new ArrayList<>(currentItem.children()));
            // this will attach the error vectors and the node vectors
            // to each node in the tree
            forwardPropagateTree(trainingTree);
            forwardPropTrees.add(trainingTree);
          }
        },
        rnTnActorSystem);

    // TODO: we may find a big speedup by separating the derivatives and then summing
    // NOTE(review): backprop writes into the shared TreeMaps above; confirm that Parallelization
    // does not run these updates concurrently.
    final AtomicDouble error = new AtomicDouble(0);
    if (!forwardPropTrees.isEmpty()) {
      Parallelization.iterateInParallel(
          forwardPropTrees,
          new Parallelization.RunnableWithParams<Tree>() {

            @Override
            public void run(Tree currentItem, Object[] args) {
              backpropDerivativesAndError(
                  currentItem, binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD);
              error.addAndGet(currentItem.errorSum());
            }
          },
          new Parallelization.RunnableWithParams<Tree>() {

            @Override
            public void run(Tree currentItem, Object[] args) {}
          },
          rnTnActorSystem,
          new Object[] {binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD});
    }

    // scale the error by the number of sentences so that the
    // regularization isn't drowned out for large training batches
    double scale =
        trainingBatch == null || trainingBatch.isEmpty() ? 1.0f : (1.0f / trainingBatch.size());
    value = error.doubleValue() * scale;

    value += scaleAndRegularize(binaryTD, binaryTransform, scale, regTransformMatrix);
    value += scaleAndRegularize(binaryCD, binaryClassification, scale, regClassification);
    value +=
        scaleAndRegularizeINDArray(binaryINDArrayTD, binaryTensors, scale, regTransformINDArray);
    value += scaleAndRegularize(unaryCD, unaryClassification, scale, regClassification);
    value += scaleAndRegularize(wordVectorD, featureVectors, scale, regWordVector);

    INDArray derivative =
        Nd4j.toFlattened(
            getNumParameters(),
            binaryTD.values().iterator(),
            binaryCD.values().iterator(),
            binaryINDArrayTD.values().iterator(),
            unaryCD.values().iterator(),
            wordVectorD.values().iterator());

    if (derivative.length() != numParameters) {
      throw new IllegalStateException(
          "Gradient has wrong number of parameters "
              + derivative.length()
              + " should have been "
              + numParameters);
    }

    // AdaGrad state is created lazily, sized to the flattened gradient.
    if (paramAdaGrad == null) {
      paramAdaGrad = new AdaGrad(1, derivative.columns());
    }

    derivative = paramAdaGrad.getGradient(derivative, 0);

    return derivative;
  }