public FloatMatrix getValueGradient(int iterations) { // We use TreeMap for each of these so that they stay in a // canonical sorted order // TODO: factor out the initialization routines // binaryTD stands for Transform Derivatives final MultiDimensionalMap<String, String, FloatMatrix> binaryTD = MultiDimensionalMap.newTreeBackedMap(); // the derivatives of the FloatTensors for the binary nodes final MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD = MultiDimensionalMap.newTreeBackedMap(); // binaryCD stands for Classification Derivatives final MultiDimensionalMap<String, String, FloatMatrix> binaryCD = MultiDimensionalMap.newTreeBackedMap(); // unaryCD stands for Classification Derivatives final Map<String, FloatMatrix> unaryCD = new TreeMap<>(); // word vector derivatives final Map<String, FloatMatrix> wordVectorD = new TreeMap<>(); for (MultiDimensionalMap.Entry<String, String, FloatMatrix> entry : binaryTransform.entrySet()) { int numRows = entry.getValue().rows; int numCols = entry.getValue().columns; binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), new FloatMatrix(numRows, numCols)); } if (!combineClassification) { for (MultiDimensionalMap.Entry<String, String, FloatMatrix> entry : binaryClassification.entrySet()) { int numRows = entry.getValue().rows; int numCols = entry.getValue().columns; binaryCD.put(entry.getFirstKey(), entry.getSecondKey(), new FloatMatrix(numRows, numCols)); } } if (useFloatTensors) { for (MultiDimensionalMap.Entry<String, String, FloatTensor> entry : binaryFloatTensors.entrySet()) { int numRows = entry.getValue().rows(); int numCols = entry.getValue().columns; int numSlices = entry.getValue().slices(); binaryFloatTensorTD.put( entry.getFirstKey(), entry.getSecondKey(), new FloatTensor(numRows, numCols, numSlices)); } } for (Map.Entry<String, FloatMatrix> entry : unaryClassification.entrySet()) { int numRows = entry.getValue().rows; int numCols = entry.getValue().columns; unaryCD.put(entry.getKey(), new FloatMatrix(numRows, numCols)); } for (Map.Entry<String, FloatMatrix> entry : featureVectors.entrySet()) { int numRows = entry.getValue().rows; int numCols = entry.getValue().columns; wordVectorD.put(entry.getKey(), new FloatMatrix(numRows, numCols)); } final List<Tree> forwardPropTrees = new CopyOnWriteArrayList<>(); Parallelization.iterateInParallel( trainingTrees, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) { Tree trainingTree = new Tree(currentItem); trainingTree.connect(new ArrayList<>(currentItem.children())); // this will attach the error vectors and the node vectors // to each node in the tree forwardPropagateTree(trainingTree); forwardPropTrees.add(trainingTree); } }, rnTnActorSystem); // TODO: we may find a big speedup by separating the derivatives and then summing final AtomicDouble error = new AtomicDouble(0); Parallelization.iterateInParallel( forwardPropTrees, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) { backpropDerivativesAndError( currentItem, binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD); error.addAndGet(currentItem.errorSum()); } }, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) {} }, rnTnActorSystem, new Object[] {binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD}); // scale the error by the number of sentences so that the // regularization isn't drowned out for large training batchs float scale = (1.0f / trainingTrees.size()); value = error.floatValue() * scale; value += scaleAndRegularize(binaryTD, binaryTransform, scale, regTransformMatrix); value += scaleAndRegularize(binaryCD, binaryClassification, scale, regClassification); value += scaleAndRegularizeFloatTensor( binaryFloatTensorTD, binaryFloatTensors, scale, regTransformFloatTensor); value += scaleAndRegularize(unaryCD, unaryClassification, scale, regClassification); value += scaleAndRegularize(wordVectorD, featureVectors, scale, regWordVector); FloatMatrix derivative = MatrixUtil.toFlattenedFloat( getNumParameters(), binaryTD.values().iterator(), binaryCD.values().iterator(), binaryFloatTensorTD.values().iterator(), unaryCD.values().iterator(), wordVectorD.values().iterator()); if (paramAdaGrad == null) paramAdaGrad = new AdaGradFloat(1, derivative.columns); derivative.muli(paramAdaGrad.getLearningRates(derivative)); return derivative; }
public INDArray getValueGradient(final List<Tree> trainingBatch) { // We use TreeMap for each of these so that they stay in a // canonical sorted order // TODO: factor out the initialization routines // binaryTD stands for Transform Derivatives final MultiDimensionalMap<String, String, INDArray> binaryTD = MultiDimensionalMap.newTreeBackedMap(); // the derivatives of the INd4j for the binary nodes final MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD = MultiDimensionalMap.newTreeBackedMap(); // binaryCD stands for Classification Derivatives final MultiDimensionalMap<String, String, INDArray> binaryCD = MultiDimensionalMap.newTreeBackedMap(); // unaryCD stands for Classification Derivatives final Map<String, INDArray> unaryCD = new TreeMap<>(); // word vector derivatives final Map<String, INDArray> wordVectorD = new TreeMap<>(); for (MultiDimensionalMap.Entry<String, String, INDArray> entry : binaryTransform.entrySet()) { int numRows = entry.getValue().rows(); int numCols = entry.getValue().columns(); binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(numRows, numCols)); } if (!combineClassification) { for (MultiDimensionalMap.Entry<String, String, INDArray> entry : binaryClassification.entrySet()) { int numRows = entry.getValue().rows(); int numCols = entry.getValue().columns(); binaryCD.put(entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(numRows, numCols)); } } if (useDoubleTensors) { for (MultiDimensionalMap.Entry<String, String, INDArray> entry : binaryTensors.entrySet()) { int numRows = entry.getValue().size(1); int numCols = entry.getValue().size(2); int numSlices = entry.getValue().slices(); binaryINDArrayTD.put( entry.getFirstKey(), entry.getSecondKey(), Nd4j.create(numRows, numCols, numSlices)); } } for (Map.Entry<String, INDArray> entry : unaryClassification.entrySet()) { int numRows = entry.getValue().rows(); int numCols = entry.getValue().columns(); unaryCD.put(entry.getKey(), Nd4j.create(numRows, numCols)); } for (String s : vocabCache.words()) { INDArray vector = featureVectors.vector(s); int numRows = vector.rows(); int numCols = vector.columns(); wordVectorD.put(s, Nd4j.create(numRows, numCols)); } final List<Tree> forwardPropTrees = new CopyOnWriteArrayList<>(); // if(!forwardPropTrees.isEmpty()) Parallelization.iterateInParallel( trainingBatch, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) { Tree trainingTree = new Tree(currentItem); trainingTree.connect(new ArrayList<>(currentItem.children())); // this will attach the error vectors and the node vectors // to each node in the tree forwardPropagateTree(trainingTree); forwardPropTrees.add(trainingTree); } }, rnTnActorSystem); // TODO: we may find a big speedup by separating the derivatives and then summing final AtomicDouble error = new AtomicDouble(0); if (!forwardPropTrees.isEmpty()) Parallelization.iterateInParallel( forwardPropTrees, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) { backpropDerivativesAndError( currentItem, binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD); error.addAndGet(currentItem.errorSum()); } }, new Parallelization.RunnableWithParams<Tree>() { public void run(Tree currentItem, Object[] args) {} }, rnTnActorSystem, new Object[] {binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD}); // scale the error by the number of sentences so that the // regularization isn't drowned out for large training batchs double scale = trainingBatch == null || trainingBatch.isEmpty() ? 1.0f : (1.0f / trainingBatch.size()); value = error.doubleValue() * scale; value += scaleAndRegularize(binaryTD, binaryTransform, scale, regTransformMatrix); value += scaleAndRegularize(binaryCD, binaryClassification, scale, regClassification); value += scaleAndRegularizeINDArray(binaryINDArrayTD, binaryTensors, scale, regTransformINDArray); value += scaleAndRegularize(unaryCD, unaryClassification, scale, regClassification); value += scaleAndRegularize(wordVectorD, featureVectors, scale, regWordVector); INDArray derivative = Nd4j.toFlattened( getNumParameters(), binaryTD.values().iterator(), binaryCD.values().iterator(), binaryINDArrayTD.values().iterator(), unaryCD.values().iterator(), wordVectorD.values().iterator()); if (derivative.length() != numParameters) throw new IllegalStateException( "Gradient has wrong number of parameters " + derivative.length() + " should have been " + numParameters); if (paramAdaGrad == null) paramAdaGrad = new AdaGrad(1, derivative.columns()); derivative = paramAdaGrad.getGradient(derivative, 0); return derivative; }