float scaleAndRegularize( Map<String, FloatMatrix> derivatives, Map<String, FloatMatrix> currentMatrices, float scale, float regCost) { float cost = 0.0f; // the regularization cost for (Map.Entry<String, FloatMatrix> entry : currentMatrices.entrySet()) { FloatMatrix D = derivatives.get(entry.getKey()); D = SimpleBlas.scal(scale, D).add(SimpleBlas.scal(regCost, entry.getValue())); derivatives.put(entry.getKey(), D); cost += entry.getValue().mul(entry.getValue()).sum() * regCost / 2.0; } return cost; }
/** Returns matrices of the right size for either binary or unary (terminal) classification */ FloatMatrix randomClassificationMatrix() { // Leave the bias column with 0 values float range = 1.0f / (float) (Math.sqrt((float) numHidden)); FloatMatrix ret = FloatMatrix.zeros(numOuts, numHidden + 1); FloatMatrix insert = MatrixUtil.rand(numOuts, numHidden, -range, range, rng); ret.put(interval(0, numOuts), interval(0, numHidden), insert); return SimpleBlas.scal(scalingForInit, ret); }
/**
 * Predicts the top-level label of each tree.
 *
 * <p>Each tree is forward-propagated first. Its predicted label is the index of the
 * largest-magnitude entry of its prediction vector.
 *
 * @param trees the trees to predict
 * @return one predicted label per tree, in the same order as {@code trees}
 */
public List<Integer> predict(List<Tree> trees) {
  List<Integer> labels = new ArrayList<>(trees.size());
  for (Tree tree : trees) {
    forwardPropagateTree(tree);
    int best = SimpleBlas.iamax(tree.prediction());
    labels.add(best);
  }
  return labels;
}
/**
 * Computes the gradient of the binary composition tensor at one node.
 *
 * <p>Slice {@code k} of the result is {@code deltaFull[k] * v * v^T}, where {@code v} is the
 * column vector formed by stacking {@code leftVector} on top of {@code rightVector}. Each slice
 * is therefore {@code (2 * size) x (2 * size)}, matching the tensor allocated below.
 *
 * @param deltaFull error signal at this node (treated as a column vector of length {@code size})
 * @param leftVector left child vector (column vector)
 * @param rightVector right child vector (column vector)
 * @return a tensor of {@code size} slices, each {@code (2 * size) x (2 * size)}
 */
private FloatTensor getFloatTensorGradient(
    FloatMatrix deltaFull, FloatMatrix leftVector, FloatMatrix rightVector) {
  int size = deltaFull.length;
  FloatTensor Wt_df = new FloatTensor(size * 2, size * 2, size);
  // Stack the children vertically into one (2 * size) x 1 column. concatHorizontally would
  // produce size x 2 and an outer product of the wrong shape.
  FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
  // (d * v) v^T == d * (v v^T): compute the outer product once. mul(float) then yields a fresh
  // copy per slice, unlike SimpleBlas.scal, which would rescale the shared matrix in place.
  FloatMatrix outer = fullVector.mmul(fullVector.transpose());
  for (int slice = 0; slice < size; ++slice) {
    Wt_df.setSlice(slice, outer.mul(deltaFull.get(slice)));
  }
  return Wt_df;
}
FloatMatrix randomTransformMatrix() { FloatMatrix binary = new FloatMatrix(numHidden, numHidden * 2 + 1); // bias column values are initialized zero FloatMatrix block = randomTransformBlock(); binary.put(interval(0, block.rows), interval(0, block.columns), block); binary.put( interval(0, block.rows), interval(numHidden, numHidden + block.columns), randomTransformBlock()); return SimpleBlas.scal(scalingForInit, binary); }
/**
 * Computes the error signal passed down to both children of a node that uses tensor composition.
 *
 * <p>The result is the non-bias part of {@code W^T * deltaFull}, plus the sum over slices
 * {@code k} of {@code (Wt[k] + Wt[k]^T) * (deltaFull[k] * v)}. Here {@code v} is
 * {@code leftVector} stacked on top of {@code rightVector}.
 *
 * @param deltaFull error signal at this node (column vector of length {@code size})
 * @param leftVector left child vector (column vector)
 * @param rightVector right child vector (column vector)
 * @param W binary transform; presumably {@code size x (2 * size + 1)} with a trailing bias column
 * @param Wt binary tensor with one {@code (2 * size) x (2 * size)} slice per output unit
 * @return a {@code (2 * size) x 1} column: the left child's delta on top, the right child's below
 */
private FloatMatrix computeFloatTensorDeltaDown(
    FloatMatrix deltaFull,
    FloatMatrix leftVector,
    FloatMatrix rightVector,
    FloatMatrix W,
    FloatTensor Wt) {
  int size = deltaFull.length;
  int childRows = deltaFull.rows * 2;
  FloatMatrix WTDelta = W.transpose().mmul(deltaFull);
  // WTDelta is a column vector. Keep its first 2 * size rows (the children's part) and drop the
  // trailing bias row. Rows come first in get(rows, cols).
  FloatMatrix WTDeltaNoBias = WTDelta.get(interval(0, childRows), interval(0, 1));
  FloatMatrix deltaFloatTensor = new FloatMatrix(size * 2, 1);
  // Stack the children vertically so each slice product yields a (2 * size) x 1 column.
  FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
  for (int slice = 0; slice < size; ++slice) {
    // mul(float) copies. SimpleBlas.scal would rescale fullVector in place and compound the
    // scaling across slices.
    FloatMatrix scaledFullVector = fullVector.mul(deltaFull.get(slice));
    FloatMatrix sliceMatrix = Wt.getSlice(slice);
    deltaFloatTensor.addi(sliceMatrix.add(sliceMatrix.transpose()).mmul(scaledFullVector));
  }
  return deltaFloatTensor.addi(WTDeltaNoBias);
}
private void backpropDerivativesAndError( Tree tree, MultiDimensionalMap<String, String, FloatMatrix> binaryTD, MultiDimensionalMap<String, String, FloatMatrix> binaryCD, MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD, Map<String, FloatMatrix> unaryCD, Map<String, FloatMatrix> wordVectorD, FloatMatrix deltaUp) { if (tree.isLeaf()) { return; } FloatMatrix currentVector = tree.vector(); String category = tree.label(); category = basicCategory(category); // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class FloatMatrix goldLabel = new FloatMatrix(numOuts, 1); int goldClass = tree.goldLabel(); if (goldClass >= 0) { goldLabel.put(goldClass, 1.0f); } Float nodeWeight = classWeights.get(goldClass); if (nodeWeight == null) nodeWeight = 1.0f; FloatMatrix predictions = tree.prediction(); // If this is an unlabeled class, set deltaClass to 0. We could // make this more efficient by eliminating various of the below // calculations, but this would be the easiest way to handle the // unlabeled class FloatMatrix deltaClass = goldClass >= 0 ? 
SimpleBlas.scal(nodeWeight, predictions.sub(goldLabel)) : new FloatMatrix(predictions.rows, predictions.columns); FloatMatrix localCD = deltaClass.mmul(appendBias(currentVector).transpose()); float error = -(MatrixFunctions.log(predictions).muli(goldLabel).sum()); error = error * nodeWeight; tree.setError(error); if (tree.isPreTerminal()) { // below us is a word vector unaryCD.put(category, unaryCD.get(category).add(localCD)); String word = tree.children().get(0).label(); word = getVocabWord(word); FloatMatrix currentVectorDerivative = activationFunction.apply(currentVector); FloatMatrix deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass); deltaFromClass = deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative); FloatMatrix deltaFull = deltaFromClass.add(deltaUp); wordVectorD.put(word, wordVectorD.get(word).add(deltaFull)); } else { // Otherwise, this must be a binary node String leftCategory = basicCategory(tree.children().get(0).label()); String rightCategory = basicCategory(tree.children().get(1).label()); if (combineClassification) { unaryCD.put("", unaryCD.get("").add(localCD)); } else { binaryCD.put( leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD)); } FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector); FloatMatrix deltaFromClass = getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass); FloatMatrix mult = deltaFromClass.get(interval(0, numHidden), interval(0, 1)); deltaFromClass = mult.muli(currentVectorDerivative); FloatMatrix deltaFull = deltaFromClass.add(deltaUp); FloatMatrix leftVector = tree.children().get(0).vector(); FloatMatrix rightVector = tree.children().get(1).vector(); FloatMatrix childrenVector = appendBias(leftVector, rightVector); // deltaFull 50 x 1, childrenVector: 50 x 2 FloatMatrix add = binaryTD.get(leftCategory, rightCategory); FloatMatrix W_df = 
deltaFromClass.mmul(childrenVector.transpose()); binaryTD.put(leftCategory, rightCategory, add.add(W_df)); FloatMatrix deltaDown; if (useFloatTensors) { FloatTensor Wt_df = getFloatTensorGradient(deltaFull, leftVector, rightVector); binaryFloatTensorTD.put( leftCategory, rightCategory, binaryFloatTensorTD.get(leftCategory, rightCategory).add(Wt_df)); deltaDown = computeFloatTensorDeltaDown( deltaFull, leftVector, rightVector, getBinaryTransform(leftCategory, rightCategory), getBinaryFloatTensor(leftCategory, rightCategory)); } else { deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull); } FloatMatrix leftDerivative = activationFunction.apply(leftVector); FloatMatrix rightDerivative = activationFunction.apply(rightVector); FloatMatrix leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows), interval(0, 1)); FloatMatrix rightDeltaDown = deltaDown.get(interval(deltaFull.rows, deltaFull.rows * 2), interval(0, 1)); backpropDerivativesAndError( tree.children().get(0), binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD, leftDerivative.mul(leftDeltaDown)); backpropDerivativesAndError( tree.children().get(1), binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD, rightDerivative.mul(rightDeltaDown)); } }
/**
 * Returns the label index of {@code data}: the position of the largest-magnitude entry in its
 * second (label) matrix.
 *
 * @param data the data set whose label is read
 * @return index of the largest absolute value in {@code data.getSecond()}
 */
private int getLabel(FloatDataSet data) {
  int labelIndex = SimpleBlas.iamax(data.getSecond());
  return labelIndex;
}
/**
 * Returns the outcome of a single-example data set: the index of the largest-magnitude entry in
 * its label matrix.
 *
 * @return the outcome index
 * @throws IllegalStateException if the data set holds more than one example
 */
public int outcome() {
  if (numExamples() <= 1) {
    return SimpleBlas.iamax(getSecond());
  }
  throw new IllegalStateException("Unable to derive outcome for dataset greater than one row");
}