/** * This is the method to call for assigning labels and node vectors to the Tree. After calling * this, each of the non-leaf nodes will have the node vector and the predictions of their classes * assigned to that subtree's node. */ public void forwardPropagateTree(Tree tree) { FloatMatrix nodeVector; FloatMatrix classification; if (tree.isLeaf()) { // We do nothing for the leaves. The preterminals will // calculate the classification for this word/tag. In fact, the // recursion should not have gotten here (unless there are // degenerate trees of just one leaf) throw new AssertionError("We should not have reached leaves in forwardPropagate"); } else if (tree.isPreTerminal()) { classification = getUnaryClassification(tree.label()); String word = tree.children().get(0).value(); FloatMatrix wordVector = getFeatureVector(word); if (wordVector == null) { wordVector = featureVectors.get(UNKNOWN_FEATURE); } nodeVector = activationFunction.apply(wordVector); } else if (tree.children().size() == 1) { throw new AssertionError( "Non-preterminal nodes of size 1 should have already been collapsed"); } else if (tree.children().size() == 2) { Tree left = tree.firstChild(), right = tree.lastChild(); forwardPropagateTree(left); forwardPropagateTree(right); String leftCategory = tree.children().get(0).label(); String rightCategory = tree.children().get(1).label(); FloatMatrix W = getBinaryTransform(leftCategory, rightCategory); classification = getBinaryClassification(leftCategory, rightCategory); FloatMatrix leftVector = tree.children().get(0).vector(); FloatMatrix rightVector = tree.children().get(1).vector(); FloatMatrix childrenVector = appendBias(leftVector, rightVector); if (useFloatTensors) { FloatTensor floatT = getBinaryFloatTensor(leftCategory, rightCategory); FloatMatrix floatTensorIn = FloatMatrix.concatHorizontally(leftVector, rightVector); FloatMatrix floatTensorOut = floatT.bilinearProducts(floatTensorIn); nodeVector = 
activationFunction.apply(W.mmul(childrenVector).add(floatTensorOut)); } else nodeVector = activationFunction.apply(W.mmul(childrenVector)); } else { throw new AssertionError("Tree not correctly binarized"); } FloatMatrix inputWithBias = appendBias(nodeVector); FloatMatrix preAct = classification.mmul(inputWithBias); FloatMatrix predictions = outputActivation.apply(preAct); tree.setPrediction(predictions); tree.setVector(nodeVector); }
/**
 * Computes the gradient of a binary node's bilinear tensor term with respect to the tensor.
 *
 * <p>Slice {@code k} of the result is {@code deltaFull[k] * v * v^T}, where {@code v} is the
 * column obtained by stacking {@code leftVector} on top of {@code rightVector}.
 *
 * @param deltaFull error signal at the node; its length {@code n} is the number of slices
 * @param leftVector left child's node vector (presumably an n x 1 column -- confirm)
 * @param rightVector right child's node vector (presumably an n x 1 column -- confirm)
 * @return a tensor created as {@code new FloatTensor(2n, 2n, n)} holding one outer product per
 *     slice
 */
private FloatTensor getFloatTensorGradient(
    FloatMatrix deltaFull, FloatMatrix leftVector, FloatMatrix rightVector) {
  int size = deltaFull.length;
  FloatTensor Wt_df = new FloatTensor(size * 2, size * 2, size);
  // A 2n x 1 column is required for v * v^T to be 2n x 2n, so stack the children vertically.
  FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
  // fullVector is never modified below, so its transpose is loop-invariant.
  FloatMatrix fullVectorT = fullVector.transpose();
  for (int slice = 0; slice < size; ++slice) {
    // mul(float) returns a new matrix. SimpleBlas.scal would scale fullVector in place, which
    // squares the factor within a slice and compounds it across slices.
    Wt_df.setSlice(slice, fullVector.mul(deltaFull.get(slice)).mmul(fullVectorT));
  }
  return Wt_df;
}
/**
 * Computes the error passed down to the children of a binary node that uses the tensor term.
 *
 * <p>The result is {@code sum_k (Wt_k + Wt_k^T) * (deltaFull[k] * v)}, where {@code v} is the
 * stacked column [leftVector; rightVector]. To this is added the first {@code 2n} rows of
 * {@code W^T * deltaFull}, i.e. the linear-transform contribution without the bias row.
 *
 * @param deltaFull error signal at the node (length {@code n})
 * @param leftVector left child's node vector (presumably an n x 1 column -- confirm)
 * @param rightVector right child's node vector (presumably an n x 1 column -- confirm)
 * @param W the node's binary transform, applied as {@code W^T * deltaFull}
 * @param Wt the node's tensor; each slice is multiplied by a {@code 2n x 1} column
 * @return a {@code 2n x 1} column of errors for the stacked children
 */
private FloatMatrix computeFloatTensorDeltaDown(
    FloatMatrix deltaFull,
    FloatMatrix leftVector,
    FloatMatrix rightVector,
    FloatMatrix W,
    FloatTensor Wt) {
  int size = deltaFull.length;
  FloatMatrix WTDelta = W.transpose().mmul(deltaFull);
  // W^T * delta is a column vector: keep rows [0, 2n) of column 0. Everything after row 2n
  // is excluded, which the NoBias name indicates is the bias row.
  FloatMatrix WTDeltaNoBias = WTDelta.get(interval(0, size * 2), interval(0, 1));
  FloatMatrix deltaFloatTensor = new FloatMatrix(size * 2, 1);
  // Stack vertically so the vector is the 2n x 1 column the 2n x 2n slices multiply.
  FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
  for (int slice = 0; slice < size; ++slice) {
    FloatMatrix sliceMatrix = Wt.getSlice(slice);
    // mul(float) allocates a new matrix. The in-place SimpleBlas.scal would keep rescaling
    // fullVector, so every later slice would see a compounded factor.
    FloatMatrix scaledFullVector = fullVector.mul(deltaFull.get(slice));
    deltaFloatTensor =
        deltaFloatTensor.add(sliceMatrix.add(sliceMatrix.transpose()).mmul(scaledFullVector));
  }
  return deltaFloatTensor.add(WTDeltaNoBias);
}
/**
 * Appends extra feature values to a single example (row) of the feature matrix.
 *
 * <p>A dense matrix cannot have rows of different widths, so the feature matrix is widened by
 * {@code feature.length} columns for every example. Row {@code example} receives the values of
 * {@code feature} in the new columns; every other row receives zeros there. An empty
 * {@code feature} leaves the matrix unchanged.
 *
 * @param feature the values to append, read in storage order (a row or a column vector works)
 * @param example the row index of the example to append to
 * @throws IndexOutOfBoundsException if {@code example} is not a valid row index
 */
public void addFeatureVector(FloatMatrix feature, int example) {
  FloatMatrix current = getFirst();
  if (example < 0 || example >= current.rows) {
    throw new IndexOutOfBoundsException(
        "example " + example + " out of range for " + current.rows + " rows");
  }
  if (feature.length == 0) {
    return;
  }
  // Writing a longer row back with putRow cannot widen the matrix: it copies only `columns`
  // entries, so the appended values would be dropped. Build the new columns separately
  // (zero-filled) and concatenate them onto the right instead.
  FloatMatrix extension = new FloatMatrix(current.rows, feature.length);
  extension.putRow(example, feature);
  setFirst(FloatMatrix.concatHorizontally(current, extension));
}
/**
 * Widens the feature matrix by appending the columns of {@code toAdd} on its right-hand side.
 *
 * <p>Each row of {@code toAdd} supplies the new feature values for the matching example (row)
 * of the current feature matrix. The result replaces the current feature matrix.
 *
 * @param toAdd the columns to append; presumably it has one row per example -- confirm
 */
public void addFeatureVector(FloatMatrix toAdd) {
  FloatMatrix existing = getFirst();
  FloatMatrix widened = FloatMatrix.concatHorizontally(existing, toAdd);
  setFirst(widened);
}