/**
 * This is the method to call for assigning labels and node vectors to the Tree. After calling
 * this, each non-leaf node will have its node vector and class predictions assigned.
 */
public void forwardPropagateTree(Tree tree) {
  INDArray nodeVector;
  INDArray classification;

  if (tree.isLeaf()) {
    // We do nothing for the leaves. The preterminals will
    // calculate the classification for this word/tag. In fact, the
    // recursion should not have gotten here (unless there are
    // degenerate trees of just one leaf)
    throw new AssertionError("We should not have reached leaves in forwardPropagate");
  } else if (tree.isPreTerminal()) {
    classification = getUnaryClassification(tree.label());
    String word = tree.children().get(0).value();
    INDArray wordVector = getFeatureVector(word);
    if (wordVector == null) {
      wordVector = featureVectors.vector(Word2Vec.UNK);
    }
    nodeVector =
        Nd4j.getExecutioner()
            .execAndReturn(Nd4j.getOpFactory().createTransform(activationFunction, wordVector));
  } else if (tree.children().size() == 1) {
    throw new AssertionError(
        "Non-preterminal nodes of size 1 should have already been collapsed");
  } else if (tree.children().size() == 2) {
    Tree left = tree.firstChild();
    Tree right = tree.lastChild();
    forwardPropagateTree(left);
    forwardPropagateTree(right);

    String leftCategory = tree.children().get(0).label();
    String rightCategory = tree.children().get(1).label();

    INDArray W = getBinaryTransform(leftCategory, rightCategory);
    classification = getBinaryClassification(leftCategory, rightCategory);

    INDArray leftVector = tree.children().get(0).vector();
    INDArray rightVector = tree.children().get(1).vector();
    INDArray childrenVector = Nd4j.appendBias(leftVector, rightVector);

    if (useDoubleTensors) {
      INDArray doubleT = getBinaryINDArray(leftCategory, rightCategory);
      INDArray INDArrayIn = Nd4j.concat(0, leftVector, rightVector);
      INDArray INDArrayOut = Nd4j.bilinearProducts(doubleT, INDArrayIn);
      nodeVector =
          Nd4j.getExecutioner()
              .execAndReturn(
                  Nd4j.getOpFactory()
                      .createTransform(
                          activationFunction, W.mmul(childrenVector).addi(INDArrayOut)));
    } else {
      nodeVector =
          Nd4j.getExecutioner()
              .execAndReturn(
                  Nd4j.getOpFactory()
                      .createTransform(activationFunction, W.mmul(childrenVector)));
    }
  } else {
    throw new AssertionError("Tree not correctly binarized");
  }

  INDArray inputWithBias = Nd4j.appendBias(nodeVector);
  if (inputWithBias.rows() != classification.columns()) {
    inputWithBias = inputWithBias.transpose();
  }
  INDArray preAct = classification.mmul(inputWithBias);
  INDArray predictions =
      Nd4j.getExecutioner()
          .execAndReturn(Nd4j.getOpFactory().createTransform(outputActivation, preAct));

  tree.setPrediction(predictions);
  tree.setVector(nodeVector);
}
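/*
 * A minimal usage sketch, not part of the original class: it assumes the
 * input trees come from a binarized parser (every internal node has exactly
 * two children, or is a preterminal over a single word leaf), which is the
 * shape forwardPropagateTree requires. The method name forwardPropagate and
 * the java.util.List import are assumptions for illustration only.
 */
private void forwardPropagate(List<Tree> trees) {
  for (Tree tree : trees) {
    // Pass only roots: forwardPropagateTree throws on bare leaves.
    forwardPropagateTree(tree);
    // After the call, tree.vector() and tree.prediction() are populated
    // for every non-leaf node of this tree.
  }
}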
private void backpropDerivativesAndError(
    Tree tree,
    MultiDimensionalMap<String, String, INDArray> binaryTD,
    MultiDimensionalMap<String, String, INDArray> binaryCD,
    MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD,
    Map<String, INDArray> unaryCD,
    Map<String, INDArray> wordVectorD,
    INDArray deltaUp) {
  if (tree.isLeaf()) {
    return;
  }

  INDArray currentVector = tree.vector();
  String category = tree.label();
  category = basicCategory(category);

  // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class
  INDArray goldLabel = Nd4j.create(numOuts, 1);
  int goldClass = tree.goldLabel();
  if (goldClass >= 0) {
    assert goldClass < numOuts
        : "Tried adding a label that was >= the number of configured outputs "
            + numOuts
            + " with label "
            + goldClass;
    goldLabel.putScalar(goldClass, 1.0f);
  }

  Double nodeWeight = classWeights.get(goldClass);
  if (nodeWeight == null) {
    nodeWeight = 1.0;
  }

  INDArray predictions = tree.prediction();

  // If this is an unlabeled class, transform deltaClass to 0. We could
  // make this more efficient by eliminating various of the below
  // calculations, but this would be the easiest way to handle the
  // unlabeled class
  INDArray deltaClass;
  if (predictions.data().dataType() == DataBuffer.Type.DOUBLE) {
    deltaClass =
        goldClass >= 0
            ? Nd4j.getBlasWrapper().scal(nodeWeight, predictions.sub(goldLabel))
            : Nd4j.create(predictions.rows(), predictions.columns());
  } else {
    deltaClass =
        goldClass >= 0
            ? Nd4j.getBlasWrapper()
                .scal((float) nodeWeight.doubleValue(), predictions.sub(goldLabel))
            : Nd4j.create(predictions.rows(), predictions.columns());
  }
  INDArray localCD = deltaClass.mmul(Nd4j.appendBias(currentVector).transpose());

  double error =
      -(Transforms.log(predictions).muli(goldLabel).sum(Integer.MAX_VALUE).getDouble(0));
  error = error * nodeWeight;
  tree.setError(error);

  if (tree.isPreTerminal()) {
    // below us is a word vector
    unaryCD.put(category, unaryCD.get(category).add(localCD));

    String word = tree.children().get(0).label();
    word = getVocabWord(word);

    // Elementwise derivative of the activation, needed for the chain rule
    INDArray currentVectorDerivative =
        Nd4j.getExecutioner()
            .execAndReturn(
                Nd4j.getOpFactory()
                    .createTransform(activationFunction, currentVector)
                    .derivative());
    INDArray deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass);
    deltaFromClass =
        deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative);
    INDArray deltaFull = deltaFromClass.add(deltaUp);
    INDArray wordVector = wordVectorD.get(word);
    wordVectorD.put(word, wordVector.add(deltaFull));
  } else {
    // Otherwise, this must be a binary node
    String leftCategory = basicCategory(tree.children().get(0).label());
    String rightCategory = basicCategory(tree.children().get(1).label());

    if (combineClassification) {
      unaryCD.put("", unaryCD.get("").add(localCD));
    } else {
      binaryCD.put(
          leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD));
    }

    INDArray currentVectorDerivative =
        Nd4j.getExecutioner()
            .execAndReturn(
                Nd4j.getOpFactory()
                    .createTransform(activationFunction, currentVector)
                    .derivative());
    INDArray deltaFromClass =
        getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass);

    INDArray mult = deltaFromClass.get(interval(0, numHidden), interval(0, 1));
    deltaFromClass = mult.muli(currentVectorDerivative);
    INDArray deltaFull = deltaFromClass.add(deltaUp);

    INDArray leftVector = tree.children().get(0).vector();
    INDArray rightVector = tree.children().get(1).vector();
    INDArray childrenVector = Nd4j.appendBias(leftVector, rightVector);

    // deltaFull: numHidden x 1, childrenVector: (2 * numHidden + 1) x 1
    INDArray add = binaryTD.get(leftCategory, rightCategory);
    // The gradient for the binary transform W is the full delta (class delta
    // plus the delta from above) times the biased children vector
    INDArray W_df = deltaFull.mmul(childrenVector.transpose());
    binaryTD.put(leftCategory, rightCategory, add.add(W_df));

    INDArray deltaDown;
    if (useDoubleTensors) {
      INDArray Wt_df = getINDArrayGradient(deltaFull, leftVector, rightVector);
      binaryINDArrayTD.put(
          leftCategory,
          rightCategory,
          binaryINDArrayTD.get(leftCategory, rightCategory).add(Wt_df));
      deltaDown =
          computeINDArrayDeltaDown(
              deltaFull,
              leftVector,
              rightVector,
              getBinaryTransform(leftCategory, rightCategory),
              getBinaryINDArray(leftCategory, rightCategory));
    } else {
      deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull);
    }

    INDArray leftDerivative =
        Nd4j.getExecutioner()
            .execAndReturn(
                Nd4j.getOpFactory().createTransform(activationFunction, leftVector).derivative());
    INDArray rightDerivative =
        Nd4j.getExecutioner()
            .execAndReturn(
                Nd4j.getOpFactory().createTransform(activationFunction, rightVector).derivative());
    INDArray leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows()), interval(0, 1));
    INDArray rightDeltaDown =
        deltaDown.get(interval(deltaFull.rows(), deltaFull.rows() * 2), interval(0, 1));

    backpropDerivativesAndError(
        tree.children().get(0),
        binaryTD,
        binaryCD,
        binaryINDArrayTD,
        unaryCD,
        wordVectorD,
        leftDerivative.mul(leftDeltaDown));
    backpropDerivativesAndError(
        tree.children().get(1),
        binaryTD,
        binaryCD,
        binaryINDArrayTD,
        unaryCD,
        wordVectorD,
        rightDerivative.mul(rightDeltaDown));
  }
}
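/*
 * A hypothetical driver sketch, not in the original source: it shows one way
 * the gradient accumulators consumed by backpropDerivativesAndError could be
 * set up for a single tree. It assumes forwardPropagateTree has already run
 * on the tree, that java.util.HashMap is imported, and that the maps are
 * pre-seeded elsewhere with zero matrices shaped like the corresponding
 * parameters (seedZeroGradients is a stand-in name, not a real method).
 * MultiDimensionalMap.newHashBackedMap() is assumed to be available as in
 * the deeplearning4j utilities.
 */
private void backpropTree(Tree tree) {
  MultiDimensionalMap<String, String, INDArray> binaryTD = MultiDimensionalMap.newHashBackedMap();
  MultiDimensionalMap<String, String, INDArray> binaryCD = MultiDimensionalMap.newHashBackedMap();
  MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD =
      MultiDimensionalMap.newHashBackedMap();
  Map<String, INDArray> unaryCD = new HashMap<>();
  Map<String, INDArray> wordVectorD = new HashMap<>();
  // Stand-in: fill each accumulator with zero matrices matching the
  // shapes of the transforms, classifiers, tensors, and word vectors.
  seedZeroGradients(binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD);

  // The root has no error coming from above, so deltaUp starts as a zero
  // column vector with one entry per hidden unit.
  INDArray deltaUp = Nd4j.create(numHidden, 1);
  backpropDerivativesAndError(
      tree, binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD, deltaUp);
}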