/** * This is the method to call for assigning labels and node vectors to the Tree. After calling * this, each of the non-leaf nodes will have the node vector and the predictions of their classes * assigned to that subtree's node. */ public void forwardPropagateTree(Tree tree) { FloatMatrix nodeVector; FloatMatrix classification; if (tree.isLeaf()) { // We do nothing for the leaves. The preterminals will // calculate the classification for this word/tag. In fact, the // recursion should not have gotten here (unless there are // degenerate trees of just one leaf) throw new AssertionError("We should not have reached leaves in forwardPropagate"); } else if (tree.isPreTerminal()) { classification = getUnaryClassification(tree.label()); String word = tree.children().get(0).value(); FloatMatrix wordVector = getFeatureVector(word); if (wordVector == null) { wordVector = featureVectors.get(UNKNOWN_FEATURE); } nodeVector = activationFunction.apply(wordVector); } else if (tree.children().size() == 1) { throw new AssertionError( "Non-preterminal nodes of size 1 should have already been collapsed"); } else if (tree.children().size() == 2) { Tree left = tree.firstChild(), right = tree.lastChild(); forwardPropagateTree(left); forwardPropagateTree(right); String leftCategory = tree.children().get(0).label(); String rightCategory = tree.children().get(1).label(); FloatMatrix W = getBinaryTransform(leftCategory, rightCategory); classification = getBinaryClassification(leftCategory, rightCategory); FloatMatrix leftVector = tree.children().get(0).vector(); FloatMatrix rightVector = tree.children().get(1).vector(); FloatMatrix childrenVector = appendBias(leftVector, rightVector); if (useFloatTensors) { FloatTensor floatT = getBinaryFloatTensor(leftCategory, rightCategory); FloatMatrix floatTensorIn = FloatMatrix.concatHorizontally(leftVector, rightVector); FloatMatrix floatTensorOut = floatT.bilinearProducts(floatTensorIn); nodeVector = 
activationFunction.apply(W.mmul(childrenVector).add(floatTensorOut)); } else nodeVector = activationFunction.apply(W.mmul(childrenVector)); } else { throw new AssertionError("Tree not correctly binarized"); } FloatMatrix inputWithBias = appendBias(nodeVector); FloatMatrix preAct = classification.mmul(inputWithBias); FloatMatrix predictions = outputActivation.apply(preAct); tree.setPrediction(predictions); tree.setVector(nodeVector); }
/**
 * Computes the gradient of the cost with respect to a binary node's tensor.
 *
 * <p>With {@code v = [leftVector; rightVector]}, slice {@code k} of the result is
 * {@code deltaFull[k] * v * v^T}.
 *
 * @param deltaFull error signal at the node (assumed to be an n x 1 column)
 * @param leftVector node vector of the left child (assumed n x 1)
 * @param rightVector node vector of the right child (assumed n x 1)
 * @return a new (2n x 2n x n) tensor holding one gradient slice per output dimension
 */
private FloatTensor getFloatTensorGradient(
    FloatMatrix deltaFull, FloatMatrix leftVector, FloatMatrix rightVector) {
  int size = deltaFull.length;
  FloatTensor tensorGradient = new FloatTensor(size * 2, size * 2, size);
  // Stack the children into one (2n x 1) column. concatHorizontally would give (n x 2), and
  // its outer product would be (n x n), which does not fit a (2n x 2n) slice.
  FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
  // v * v^T is the same for every slice; only its scale factor changes.
  FloatMatrix outerProduct = fullVector.mmul(fullVector.transpose());
  for (int slice = 0; slice < size; ++slice) {
    // mul(float) returns a fresh matrix. SimpleBlas.scal scales its argument in place, which
    // squared the factor within a slice and compounded it across slices.
    tensorGradient.setSlice(slice, outerProduct.mul(deltaFull.get(slice)));
  }
  return tensorGradient;
}
float scaleAndRegularizeFloatTensor( MultiDimensionalMap<String, String, FloatTensor> derivatives, MultiDimensionalMap<String, String, FloatTensor> currentMatrices, float scale, float regCost) { float cost = 0.0f; // the regularization cost for (MultiDimensionalMap.Entry<String, String, FloatTensor> entry : currentMatrices.entrySet()) { FloatTensor D = derivatives.get(entry.getFirstKey(), entry.getSecondKey()); D = D.scale(scale).add(entry.getValue().scale(regCost)); derivatives.put(entry.getFirstKey(), entry.getSecondKey(), D); cost += entry.getValue().mul(entry.getValue()).sum() * regCost / 2.0; } return cost; }
/**
 * Computes the error signal passed down to the two children of a binary node that uses a
 * tensor term.
 *
 * <p>With {@code v = [leftVector; rightVector]}, the result is the first 2n rows of
 * {@code W^T * deltaFull} (dropping the bias row) plus, summed over slices {@code k},
 * {@code (Wt[k] + Wt[k]^T) * (deltaFull[k] * v)}.
 *
 * @param deltaFull error signal at the node (assumed to be an n x 1 column)
 * @param leftVector node vector of the left child (assumed n x 1)
 * @param rightVector node vector of the right child (assumed n x 1)
 * @param W the node's binary transform; presumably n x (2n + 1), since the forward pass
 *     multiplies it with {@code appendBias(left, right)}
 * @param Wt the node's (2n x 2n x n) binary tensor
 * @return a new (2n x 1) column whose rows line up with {@code [leftVector; rightVector]}
 */
private FloatMatrix computeFloatTensorDeltaDown(
    FloatMatrix deltaFull,
    FloatMatrix leftVector,
    FloatMatrix rightVector,
    FloatMatrix W,
    FloatTensor Wt) {
  int size = deltaFull.length;
  FloatMatrix WTDelta = W.transpose().mmul(deltaFull);
  // WTDelta is a single column. Keep its first 2n rows and drop the trailing bias row.
  // get(...) takes the row range first, then the column range.
  FloatMatrix WTDeltaNoBias = WTDelta.get(interval(0, size * 2), interval(0, 1));

  FloatMatrix deltaFloatTensor = new FloatMatrix(size * 2, 1);
  // Stacked (2n x 1) column so it can be multiplied by the (2n x 2n) slices.
  FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
  for (int slice = 0; slice < size; ++slice) {
    FloatMatrix sliceMatrix = Wt.getSlice(slice);
    // mul(float) allocates a new matrix. SimpleBlas.scal would rescale fullVector in place,
    // compounding every earlier slice's factor.
    FloatMatrix scaledFullVector = fullVector.mul(deltaFull.get(slice));
    deltaFloatTensor.addi(sliceMatrix.add(sliceMatrix.transpose()).mmul(scaledFullVector));
  }
  return deltaFloatTensor.addi(WTDeltaNoBias);
}
/**
 * Creates a randomly initialized binary tensor of dimensions
 * {@code (2 * numHidden) x (2 * numHidden) x numHidden}.
 *
 * <p>The values come from {@code FloatTensor.rand} with bounds {@code -r} and {@code r}, where
 * {@code r = 1 / (4 * numHidden)}, drawn from {@code rng}. The distribution is presumably
 * uniform; confirm against {@code FloatTensor.rand}. The result is then multiplied by
 * {@code scalingForInit}.
 *
 * @return a new randomly initialized tensor
 */
FloatTensor randomBinaryFloatTensor() {
  float bound = 1.0f / (4.0f * numHidden);
  return FloatTensor.rand(2 * numHidden, 2 * numHidden, numHidden, -bound, bound, rng)
      .scale(scalingForInit);
}