double scaleAndRegularize( MultiDimensionalMap<String, String, INDArray> derivatives, MultiDimensionalMap<String, String, INDArray> currentMatrices, double scale, double regCost) { double cost = 0.0f; // the regularization cost for (MultiDimensionalMap.Entry<String, String, INDArray> entry : currentMatrices.entrySet()) { INDArray D = derivatives.get(entry.getFirstKey(), entry.getSecondKey()); if (D.data().dataType() == DataBuffer.Type.DOUBLE) D = Nd4j.getBlasWrapper() .scal(scale, D) .addi(Nd4j.getBlasWrapper().scal(regCost, entry.getValue())); else D = Nd4j.getBlasWrapper() .scal((float) scale, D) .addi(Nd4j.getBlasWrapper().scal((float) regCost, entry.getValue())); derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D); cost += entry.getValue().mul(entry.getValue()).sum(Integer.MAX_VALUE).getDouble(0) * regCost / 2.0; } return cost; }
/**
 * Flattens every trainable parameter of the model into one array.
 *
 * <p>Parameters are laid out in this order: binary transforms, binary classification matrices,
 * binary tensors, unary classification matrices, then the feature (word) vectors.
 *
 * @return the flattened parameters, sized by {@link #getNumParameters()}
 */
public INDArray getParameters() {
  int parameterCount = getNumParameters();
  return Nd4j.toFlattened(
      parameterCount,
      binaryTransform.values().iterator(),
      binaryClassification.values().iterator(),
      binaryTensors.values().iterator(),
      unaryClassification.values().iterator(),
      featureVectors.vectors());
}
/**
 * Flattens all trainable parameters into a single float matrix.
 *
 * <p>Order: binary transforms, binary classification matrices, binary float tensors, unary
 * classification matrices, feature vectors.
 *
 * @return the flattened parameters, sized by {@link #getNumParameters()}
 */
public FloatMatrix getParameters() {
  int expectedLength = getNumParameters();
  return MatrixUtil.toFlattenedFloat(
      expectedLength,
      binaryTransform.values().iterator(),
      binaryClassification.values().iterator(),
      binaryFloatTensors.values().iterator(),
      unaryClassification.values().iterator(),
      featureVectors.values().iterator());
}
/**
 * Copies a flat parameter vector back into the model's matrices.
 *
 * <p>The destination iterators are supplied in the same order used when flattening: binary
 * transforms, binary classification matrices, binary float tensors, unary classification
 * matrices, feature vectors.
 *
 * @param params flattened parameter values
 */
public void setParameters(FloatMatrix params) {
  setParams(
      params,
      binaryTransform.values().iterator(),
      binaryClassification.values().iterator(),
      binaryFloatTensors.values().iterator(),
      unaryClassification.values().iterator(),
      featureVectors.values().iterator());
}
float scaleAndRegularizeFloatTensor( MultiDimensionalMap<String, String, FloatTensor> derivatives, MultiDimensionalMap<String, String, FloatTensor> currentMatrices, float scale, float regCost) { float cost = 0.0f; // the regularization cost for (MultiDimensionalMap.Entry<String, String, FloatTensor> entry : currentMatrices.entrySet()) { FloatTensor D = derivatives.get(entry.getFirstKey(), entry.getSecondKey()); D = D.scale(scale).add(entry.getValue().scale(regCost)); derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D); cost += entry.getValue().mul(entry.getValue()).sum() * regCost / 2.0; } return cost; }
/**
 * Loads a flat parameter vector into the model and then recomputes gradient and score.
 *
 * <p>Values are distributed, in order, to the binary transforms, binary classification
 * matrices, binary tensors, unary classification matrices and feature vectors.
 *
 * @param params flattened parameters; its length must equal {@link #getNumParameters()}
 * @throws IllegalStateException if {@code params} has the wrong length
 */
public void setParams(INDArray params) {
  boolean lengthMatches = params.length() == getNumParameters();
  if (!lengthMatches) {
    throw new IllegalStateException(
        "Unable to set parameters of length "
            + params.length()
            + " must be of length "
            + numParameters);
  }
  Nd4j.setParams(
      params,
      binaryTransform.values().iterator(),
      binaryClassification.values().iterator(),
      binaryTensors.values().iterator(),
      unaryClassification.values().iterator(),
      featureVectors.vectors());
  computeGradientAndScore();
}
double scaleAndRegularizeINDArray( MultiDimensionalMap<String, String, INDArray> derivatives, MultiDimensionalMap<String, String, INDArray> currentMatrices, double scale, double regCost) { double cost = 0.0f; // the regularization cost for (MultiDimensionalMap.Entry<String, String, INDArray> entry : currentMatrices.entrySet()) { INDArray D = derivatives.get(entry.getFirstKey(), entry.getSecondKey()); D = D.muli(scale).add(entry.getValue().muli(regCost)); derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D); cost += entry.getValue().mul(entry.getValue()).sum(Integer.MAX_VALUE).getDouble(0) * regCost / 2.0f; } return cost; }
/**
 * Returns the classification matrix for a binary node with the given child labels.
 *
 * <p>When {@code combineClassification} is set, all nodes share the unary matrix stored under
 * the empty key; otherwise the labels are reduced with {@code basicCategory} and used as keys.
 *
 * @param left label of the left child
 * @param right label of the right child
 * @return the matching classification matrix, or {@code null} if none is stored
 */
public FloatMatrix getBinaryClassification(String left, String right) {
  if (!combineClassification) {
    return binaryClassification.get(basicCategory(left), basicCategory(right));
  }
  return unaryClassification.get("");
}
public int getNumParameters() { int totalSize; // binaryFloatTensorSize was set to 0 if useFloatTensors=false totalSize = numBinaryMatrices * (binaryTransform.size() + binaryClassificationSize) + binaryFloatTensorSize; totalSize += numUnaryMatrices * unaryClassification.size(); totalSize += featureVectors.size() * numHidden; return totalSize; }
/**
 * Returns the number of trainable scalars, summing the lengths of every parameter array.
 *
 * <p>The count is computed on first use and cached in {@code numParameters}; a negative cached
 * value means "not yet computed".
 *
 * @return the total parameter count
 */
public int getNumParameters() {
  if (numParameters >= 0) {
    return numParameters;
  }
  List<Iterator<INDArray>> sources =
      Arrays.asList(
          binaryTransform.values().iterator(),
          binaryClassification.values().iterator(),
          binaryTensors.values().iterator(),
          unaryClassification.values().iterator(),
          featureVectors.vectors());
  int count = 0;
  for (int i = 0; i < sources.size(); i++) {
    Iterator<INDArray> arrays = sources.get(i);
    while (arrays.hasNext()) {
      INDArray array = arrays.next();
      count += array.length();
    }
  }
  numParameters = count;
  return numParameters;
}
/**
 * Returns the binary transform matrix for {@code node}, keyed by its children's basic categories.
 *
 * @param node a tree node with exactly two children
 * @return the transform matrix for the node's child categories
 * @throws AssertionError if the node has one child (no unary transforms exist) or any other
 *     number of children
 */
public FloatMatrix getWForNode(Tree node) {
  switch (node.children().size()) {
    case 2:
      {
        String leftBasic = basicCategory(node.children().get(0).value());
        String rightBasic = basicCategory(node.children().get(1).value());
        return binaryTransform.get(leftBasic, rightBasic);
      }
    case 1:
      throw new AssertionError("No unary transform matrices, only unary classification");
    default:
      throw new AssertionError("Unexpected tree children size of " + node.children().size());
  }
}
/**
 * Returns the binary tensor for {@code node}, keyed by its children's basic categories.
 *
 * @param node a tree node with exactly two children
 * @return the tensor for the node's child categories
 * @throws AssertionError if tensors are disabled, or the node does not have exactly two children
 */
public INDArray getINDArrayForNode(Tree node) {
  if (!useDoubleTensors) {
    throw new AssertionError("Not using tensors");
  }
  int arity = node.children().size();
  if (arity == 1) {
    throw new AssertionError("No unary transform matrices, only unary classification");
  }
  if (arity != 2) {
    throw new AssertionError("Unexpected tree children size of " + arity);
  }
  String leftBasic = basicCategory(node.children().get(0).value());
  String rightBasic = basicCategory(node.children().get(1).value());
  return binaryTensors.get(leftBasic, rightBasic);
}
/**
 * Returns the classification matrix to use for {@code node}.
 *
 * <p>With {@code combineClassification}, every node shares the unary matrix under the empty key.
 * Otherwise binary nodes use the binary matrix for their children's basic categories and unary
 * nodes use the unary matrix for their child's basic category.
 *
 * @param node a tree node with one or two children
 * @return the classification matrix for the node
 * @throws AssertionError if the node has neither one nor two children
 */
public FloatMatrix getClassWForNode(Tree node) {
  if (combineClassification) {
    return unaryClassification.get("");
  }
  switch (node.children().size()) {
    case 2:
      return binaryClassification.get(
          basicCategory(node.children().get(0).value()),
          basicCategory(node.children().get(1).value()));
    case 1:
      return unaryClassification.get(basicCategory(node.children().get(0).value()));
    default:
      throw new AssertionError("Unexpected tree children size of " + node.children().size());
  }
}
/**
 * Looks up the binary float tensor for the given child labels after reducing each label with
 * {@code basicCategory}.
 *
 * @param left label of the left child
 * @param right label of the right child
 * @return the stored tensor, or {@code null} if none exists for that pair
 */
public FloatTensor getBinaryFloatTensor(String left, String right) {
  String leftKey = basicCategory(left);
  String rightKey = basicCategory(right);
  return binaryFloatTensors.get(leftKey, rightKey);
}
/**
 * Looks up the binary transform matrix for the given child labels after reducing each label
 * with {@code basicCategory}.
 *
 * @param left label of the left child
 * @param right label of the right child
 * @return the stored transform matrix, or {@code null} if none exists for that pair
 */
public FloatMatrix getBinaryTransform(String left, String right) {
  String leftKey = basicCategory(left);
  String rightKey = basicCategory(right);
  return binaryTransform.get(leftKey, rightKey);
}
private void backpropDerivativesAndError( Tree tree, MultiDimensionalMap<String, String, INDArray> binaryTD, MultiDimensionalMap<String, String, INDArray> binaryCD, MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD, Map<String, INDArray> unaryCD, Map<String, INDArray> wordVectorD, INDArray deltaUp) { if (tree.isLeaf()) { return; } INDArray currentVector = tree.vector(); String category = tree.label(); category = basicCategory(category); // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class INDArray goldLabel = Nd4j.create(numOuts, 1); int goldClass = tree.goldLabel(); if (goldClass >= 0) { assert goldClass <= numOuts : "Tried adding a label that was >= to the number of configured outputs " + numOuts + " with label " + goldClass; goldLabel.putScalar(goldClass, 1.0f); } Double nodeWeight = classWeights.get(goldClass); if (nodeWeight == null) nodeWeight = 1.0; INDArray predictions = tree.prediction(); // If this is an unlabeled class, transform deltaClass to 0. We could // make this more efficient by eliminating various of the below // calculations, but this would be the easiest way to handle the // unlabeled class INDArray deltaClass = null; if (predictions.data().dataType() == DataBuffer.Type.DOUBLE) { deltaClass = goldClass >= 0 ? Nd4j.getBlasWrapper().scal(nodeWeight, predictions.sub(goldLabel)) : Nd4j.create(predictions.rows(), predictions.columns()); } else { deltaClass = goldClass >= 0 ? 
Nd4j.getBlasWrapper() .scal((float) nodeWeight.doubleValue(), predictions.sub(goldLabel)) : Nd4j.create(predictions.rows(), predictions.columns()); } INDArray localCD = deltaClass.mmul(Nd4j.appendBias(currentVector).transpose()); double error = -(Transforms.log(predictions).muli(goldLabel).sum(Integer.MAX_VALUE).getDouble(0)); error = error * nodeWeight; tree.setError(error); if (tree.isPreTerminal()) { // below us is a word vector unaryCD.put(category, unaryCD.get(category).add(localCD)); String word = tree.children().get(0).label(); word = getVocabWord(word); INDArray currentVectorDerivative = Nd4j.getExecutioner() .execAndReturn( Nd4j.getOpFactory().createTransform(activationFunction, currentVector)); INDArray deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass); deltaFromClass = deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative); INDArray deltaFull = deltaFromClass.add(deltaUp); INDArray wordVector = wordVectorD.get(word); wordVectorD.put(word, wordVector.add(deltaFull)); } else { // Otherwise, this must be a binary node String leftCategory = basicCategory(tree.children().get(0).label()); String rightCategory = basicCategory(tree.children().get(1).label()); if (combineClassification) { unaryCD.put("", unaryCD.get("").add(localCD)); } else { binaryCD.put( leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD)); } INDArray currentVectorDerivative = Nd4j.getExecutioner() .execAndReturn( Nd4j.getOpFactory().createTransform(activationFunction, currentVector)); INDArray deltaFromClass = getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass); INDArray mult = deltaFromClass.get(interval(0, numHidden), interval(0, 1)); deltaFromClass = mult.muli(currentVectorDerivative); INDArray deltaFull = deltaFromClass.add(deltaUp); INDArray leftVector = tree.children().get(0).vector(); INDArray rightVector = tree.children().get(1).vector(); INDArray childrenVector 
= Nd4j.appendBias(leftVector, rightVector); // deltaFull 50 x 1, childrenVector: 50 x 2 INDArray add = binaryTD.get(leftCategory, rightCategory); INDArray W_df = deltaFromClass.mmul(childrenVector.transpose()); binaryTD.put(leftCategory, rightCategory, add.add(W_df)); INDArray deltaDown; if (useDoubleTensors) { INDArray Wt_df = getINDArrayGradient(deltaFull, leftVector, rightVector); binaryINDArrayTD.put( leftCategory, rightCategory, binaryINDArrayTD.get(leftCategory, rightCategory).add(Wt_df)); deltaDown = computeINDArrayDeltaDown( deltaFull, leftVector, rightVector, getBinaryTransform(leftCategory, rightCategory), getBinaryINDArray(leftCategory, rightCategory)); } else { deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull); } INDArray leftDerivative = Nd4j.getExecutioner() .execAndReturn(Nd4j.getOpFactory().createTransform(activationFunction, leftVector)); INDArray rightDerivative = Nd4j.getExecutioner() .execAndReturn(Nd4j.getOpFactory().createTransform(activationFunction, rightVector)); INDArray leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows()), interval(0, 1)); INDArray rightDeltaDown = deltaDown.get(interval(deltaFull.rows(), deltaFull.rows() * 2), interval(0, 1)); backpropDerivativesAndError( tree.children().get(0), binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD, leftDerivative.mul(leftDeltaDown)); backpropDerivativesAndError( tree.children().get(1), binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD, rightDerivative.mul(rightDeltaDown)); } }
/**
 * Returns the binary transform matrix for a pair of child labels, each first reduced by
 * {@code basicCategory}.
 *
 * @param left label of the left child
 * @param right label of the right child
 * @return the stored transform, or {@code null} if none exists for that pair
 */
public INDArray getBinaryTransform(String left, String right) {
  return binaryTransform.get(basicCategory(left), basicCategory(right));
}
/**
 * Returns the binary tensor for a pair of child labels, each first reduced by
 * {@code basicCategory}.
 *
 * @param left label of the left child
 * @param right label of the right child
 * @return the stored tensor, or {@code null} if none exists for that pair
 */
public INDArray getBinaryINDArray(String left, String right) {
  return binaryTensors.get(basicCategory(left), basicCategory(right));
}
private void backpropDerivativesAndError( Tree tree, MultiDimensionalMap<String, String, FloatMatrix> binaryTD, MultiDimensionalMap<String, String, FloatMatrix> binaryCD, MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD, Map<String, FloatMatrix> unaryCD, Map<String, FloatMatrix> wordVectorD, FloatMatrix deltaUp) { if (tree.isLeaf()) { return; } FloatMatrix currentVector = tree.vector(); String category = tree.label(); category = basicCategory(category); // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class FloatMatrix goldLabel = new FloatMatrix(numOuts, 1); int goldClass = tree.goldLabel(); if (goldClass >= 0) { goldLabel.put(goldClass, 1.0f); } Float nodeWeight = classWeights.get(goldClass); if (nodeWeight == null) nodeWeight = 1.0f; FloatMatrix predictions = tree.prediction(); // If this is an unlabeled class, set deltaClass to 0. We could // make this more efficient by eliminating various of the below // calculations, but this would be the easiest way to handle the // unlabeled class FloatMatrix deltaClass = goldClass >= 0 ? 
SimpleBlas.scal(nodeWeight, predictions.sub(goldLabel)) : new FloatMatrix(predictions.rows, predictions.columns); FloatMatrix localCD = deltaClass.mmul(appendBias(currentVector).transpose()); float error = -(MatrixFunctions.log(predictions).muli(goldLabel).sum()); error = error * nodeWeight; tree.setError(error); if (tree.isPreTerminal()) { // below us is a word vector unaryCD.put(category, unaryCD.get(category).add(localCD)); String word = tree.children().get(0).label(); word = getVocabWord(word); FloatMatrix currentVectorDerivative = activationFunction.apply(currentVector); FloatMatrix deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass); deltaFromClass = deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative); FloatMatrix deltaFull = deltaFromClass.add(deltaUp); wordVectorD.put(word, wordVectorD.get(word).add(deltaFull)); } else { // Otherwise, this must be a binary node String leftCategory = basicCategory(tree.children().get(0).label()); String rightCategory = basicCategory(tree.children().get(1).label()); if (combineClassification) { unaryCD.put("", unaryCD.get("").add(localCD)); } else { binaryCD.put( leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD)); } FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector); FloatMatrix deltaFromClass = getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass); FloatMatrix mult = deltaFromClass.get(interval(0, numHidden), interval(0, 1)); deltaFromClass = mult.muli(currentVectorDerivative); FloatMatrix deltaFull = deltaFromClass.add(deltaUp); FloatMatrix leftVector = tree.children().get(0).vector(); FloatMatrix rightVector = tree.children().get(1).vector(); FloatMatrix childrenVector = appendBias(leftVector, rightVector); // deltaFull 50 x 1, childrenVector: 50 x 2 FloatMatrix add = binaryTD.get(leftCategory, rightCategory); FloatMatrix W_df = 
deltaFromClass.mmul(childrenVector.transpose()); binaryTD.put(leftCategory, rightCategory, add.add(W_df)); FloatMatrix deltaDown; if (useFloatTensors) { FloatTensor Wt_df = getFloatTensorGradient(deltaFull, leftVector, rightVector); binaryFloatTensorTD.put( leftCategory, rightCategory, binaryFloatTensorTD.get(leftCategory, rightCategory).add(Wt_df)); deltaDown = computeFloatTensorDeltaDown( deltaFull, leftVector, rightVector, getBinaryTransform(leftCategory, rightCategory), getBinaryFloatTensor(leftCategory, rightCategory)); } else { deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull); } FloatMatrix leftDerivative = activationFunction.apply(leftVector); FloatMatrix rightDerivative = activationFunction.apply(rightVector); FloatMatrix leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows), interval(0, 1)); FloatMatrix rightDeltaDown = deltaDown.get(interval(deltaFull.rows, deltaFull.rows * 2), interval(0, 1)); backpropDerivativesAndError( tree.children().get(0), binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD, leftDerivative.mul(leftDeltaDown)); backpropDerivativesAndError( tree.children().get(1), binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD, rightDerivative.mul(rightDeltaDown)); } }
private void init() { if (rng == null) rng = new MersenneTwister(123); MultiDimensionalSet<String, String> binaryProductions = MultiDimensionalSet.hashSet(); if (simplifiedModel) { binaryProductions.add("", ""); } else { // TODO // figure out what binary productions we have in these trees // Note: the current sentiment training data does not actually // have any constituent labels throw new UnsupportedOperationException("Not yet implemented"); } Set<String> unaryProductions = new HashSet<>(); if (simplifiedModel) { unaryProductions.add(""); } else { // TODO // figure out what unary productions we have in these trees (preterminals only, after the // collapsing) throw new UnsupportedOperationException("Not yet implemented"); } identity = FloatMatrix.eye(numHidden); binaryTransform = MultiDimensionalMap.newTreeBackedMap(); binaryFloatTensors = MultiDimensionalMap.newTreeBackedMap(); binaryClassification = MultiDimensionalMap.newTreeBackedMap(); // When making a flat model (no semantic untying) the // basicCategory function will return the same basic category for // all labels, so all entries will map to the same matrix for (Pair<String, String> binary : binaryProductions) { String left = basicCategory(binary.getFirst()); String right = basicCategory(binary.getSecond()); if (binaryTransform.contains(left, right)) { continue; } binaryTransform.put(left, right, randomTransformMatrix()); if (useFloatTensors) { binaryFloatTensors.put(left, right, randomBinaryFloatTensor()); } if (!combineClassification) { binaryClassification.put(left, right, randomClassificationMatrix()); } } numBinaryMatrices = binaryTransform.size(); binaryTransformSize = numHidden * (2 * numHidden + 1); if (useFloatTensors) { binaryFloatTensorSize = numHidden * numHidden * numHidden * 4; } else { binaryFloatTensorSize = 0; } binaryClassificationSize = (combineClassification) ? 
0 : numOuts * (numHidden + 1); unaryClassification = new TreeMap<>(); // When making a flat model (no semantic untying) the // basicCategory function will return the same basic category for // all labels, so all entries will map to the same matrix for (String unary : unaryProductions) { unary = basicCategory(unary); if (unaryClassification.containsKey(unary)) { continue; } unaryClassification.put(unary, randomClassificationMatrix()); } binaryClassificationSize = (combineClassification) ? 0 : numOuts * (numHidden + 1); numUnaryMatrices = unaryClassification.size(); unaryClassificationSize = numOuts * (numHidden + 1); featureVectors.put(UNKNOWN_FEATURE, randomWordVector()); numUnaryMatrices = unaryClassification.size(); unaryClassificationSize = numOuts * (numHidden + 1); classWeights = new HashMap<>(); }
public FloatMatrix getValueGradient(int iterations) { // We use TreeMap for each of these so that they stay in a // canonical sorted order // TODO: factor out the initialization routines // binaryTD stands for Transform Derivatives final MultiDimensionalMap<String, String, FloatMatrix> binaryTD = MultiDimensionalMap.newTreeBackedMap(); // the derivatives of the FloatTensors for the binary nodes final MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD = MultiDimensionalMap.newTreeBackedMap(); // binaryCD stands for Classification Derivatives final MultiDimensionalMap<String, String, FloatMatrix> binaryCD = MultiDimensionalMap.newTreeBackedMap(); // unaryCD stands for Classification Derivatives final Map<String, FloatMatrix> unaryCD = new TreeMap<>(); // word vector derivatives final Map<String, FloatMatrix> wordVectorD = new TreeMap<>(); for (MultiDimensionalMap.Entry<String, String, FloatMatrix> entry : binaryTransform.entrySet()) { int numRows = entry.getValue().rows; int numCols = entry.getValue().columns; binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), new FloatMatrix(numRows, numCols)); } if (!combineClassification) { for (MultiDimensionalMap.Entry<String, String, FloatMatrix> entry : binaryClassification.entrySet()) { int numRows = entry.getValue().rows; int numCols = entry.getValue().columns; binaryCD.put(entry.getFirstKey(), entry.getSecondKey(), new FloatMatrix(numRows, numCols)); } } if (useFloatTensors) { for (MultiDimensionalMap.Entry<String, String, FloatTensor> entry : binaryFloatTensors.entrySet()) { int numRows = entry.getValue().rows(); int numCols = entry.getValue().columns; int numSlices = entry.getValue().slices(); binaryFloatTensorTD.put( entry.getFirstKey(), entry.getSecondKey(), new FloatTensor(numRows, numCols, numSlices)); } } for (Map.Entry<String, FloatMatrix> entry : unaryClassification.entrySet()) { int numRows = entry.getValue().rows; int numCols = entry.getValue().columns; unaryCD.put(entry.getKey(), new 
FloatMatrix(numRows, numCols)); } for (Map.Entry<String, FloatMatrix> entry : featureVectors.entrySet()) { int numRows = entry.getValue().rows; int numCols = entry.getValue().columns; wordVectorD.put(entry.getKey(), new FloatMatrix(numRows, numCols)); } final List<Tree> forwardPropTrees = new CopyOnWriteArrayList<>(); Parallelization.iterateInParallel( trainingTrees, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) { Tree trainingTree = new Tree(currentItem); trainingTree.connect(new ArrayList<>(currentItem.children())); // this will attach the error vectors and the node vectors // to each node in the tree forwardPropagateTree(trainingTree); forwardPropTrees.add(trainingTree); } }, rnTnActorSystem); // TODO: we may find a big speedup by separating the derivatives and then summing final AtomicDouble error = new AtomicDouble(0); Parallelization.iterateInParallel( forwardPropTrees, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) { backpropDerivativesAndError( currentItem, binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD); error.addAndGet(currentItem.errorSum()); } }, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) {} }, rnTnActorSystem, new Object[] {binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD}); // scale the error by the number of sentences so that the // regularization isn't drowned out for large training batchs float scale = (1.0f / trainingTrees.size()); value = error.floatValue() * scale; value += scaleAndRegularize(binaryTD, binaryTransform, scale, regTransformMatrix); value += scaleAndRegularize(binaryCD, binaryClassification, scale, regClassification); value += scaleAndRegularizeFloatTensor( binaryFloatTensorTD, binaryFloatTensors, scale, regTransformFloatTensor); value += scaleAndRegularize(unaryCD, unaryClassification, scale, regClassification); value += 
scaleAndRegularize(wordVectorD, featureVectors, scale, regWordVector); FloatMatrix derivative = MatrixUtil.toFlattenedFloat( getNumParameters(), binaryTD.values().iterator(), binaryCD.values().iterator(), binaryFloatTensorTD.values().iterator(), unaryCD.values().iterator(), wordVectorD.values().iterator()); if (paramAdaGrad == null) paramAdaGrad = new AdaGradFloat(1, derivative.columns); derivative.muli(paramAdaGrad.getLearningRates(derivative)); return derivative; }
public INDArray getValueGradient(final List<Tree> trainingBatch) { // We use TreeMap for each of these so that they stay in a // canonical sorted order // TODO: factor out the initialization routines // binaryTD stands for Transform Derivatives final MultiDimensionalMap<String, String, INDArray> binaryTD = MultiDimensionalMap.newTreeBackedMap(); // the derivatives of the INd4j for the binary nodes final MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD = MultiDimensionalMap.newTreeBackedMap(); // binaryCD stands for Classification Derivatives final MultiDimensionalMap<String, String, INDArray> binaryCD = MultiDimensionalMap.newTreeBackedMap(); // unaryCD stands for Classification Derivatives final Map<String, INDArray> unaryCD = new TreeMap<>(); // word vector derivatives final Map<String, INDArray> wordVectorD = new TreeMap<>(); for (MultiDimensionalMap.Entry<String, String, INDArray> entry : binaryTransform.entrySet()) { int numRows = entry.getValue().rows(); int numCols = entry.getValue().columns(); binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(numRows, numCols)); } if (!combineClassification) { for (MultiDimensionalMap.Entry<String, String, INDArray> entry : binaryClassification.entrySet()) { int numRows = entry.getValue().rows(); int numCols = entry.getValue().columns(); binaryCD.put(entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(numRows, numCols)); } } if (useDoubleTensors) { for (MultiDimensionalMap.Entry<String, String, INDArray> entry : binaryTensors.entrySet()) { int numRows = entry.getValue().size(1); int numCols = entry.getValue().size(2); int numSlices = entry.getValue().slices(); binaryINDArrayTD.put( entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(numRows, numCols, numSlices)); } } for (Map.Entry<String, INDArray> entry : unaryClassification.entrySet()) { int numRows = entry.getValue().rows(); int numCols = entry.getValue().columns(); unaryCD.put(entry.getKey(), Nd4j.create(numRows, numCols)); } for 
(String s : vocabCache.words()) { INDArray vector = featureVectors.vector(s); int numRows = vector.rows(); int numCols = vector.columns(); wordVectorD.put(s, Nd4j.create(numRows, numCols)); } final List<Tree> forwardPropTrees = new CopyOnWriteArrayList<>(); // if(!forwardPropTrees.isEmpty()) Parallelization.iterateInParallel( trainingBatch, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) { Tree trainingTree = new Tree(currentItem); trainingTree.connect(new ArrayList<>(currentItem.children())); // this will attach the error vectors and the node vectors // to each node in the tree forwardPropagateTree(trainingTree); forwardPropTrees.add(trainingTree); } }, rnTnActorSystem); // TODO: we may find a big speedup by separating the derivatives and then summing final AtomicDouble error = new AtomicDouble(0); if (!forwardPropTrees.isEmpty()) Parallelization.iterateInParallel( forwardPropTrees, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) { backpropDerivativesAndError( currentItem, binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD); error.addAndGet(currentItem.errorSum()); } }, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) {} }, rnTnActorSystem, new Object[] {binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD}); // scale the error by the number of sentences so that the // regularization isn't drowned out for large training batchs double scale = trainingBatch == null || trainingBatch.isEmpty() ? 
1.0f : (1.0f / trainingBatch.size()); value = error.doubleValue() * scale; value += scaleAndRegularize(binaryTD, binaryTransform, scale, regTransformMatrix); value += scaleAndRegularize(binaryCD, binaryClassification, scale, regClassification); value += scaleAndRegularizeINDArray(binaryINDArrayTD, binaryTensors, scale, regTransformINDArray); value += scaleAndRegularize(unaryCD, unaryClassification, scale, regClassification); value += scaleAndRegularize(wordVectorD, featureVectors, scale, regWordVector); INDArray derivative = Nd4j.toFlattened( getNumParameters(), binaryTD.values().iterator(), binaryCD.values().iterator(), binaryINDArrayTD.values().iterator(), unaryCD.values().iterator(), wordVectorD.values().iterator()); if (derivative.length() != numParameters) throw new IllegalStateException( "Gradient has wrong number of parameters " + derivative.length() + " should have been " + numParameters); if (paramAdaGrad == null) paramAdaGrad = new AdaGrad(1, derivative.columns()); derivative = paramAdaGrad.getGradient(derivative, 0); return derivative; }