/**
 * Evaluates the regularized, masked squared-error objective for a packed parameter vector.
 *
 * <p>The first {@code rows * features} entries of {@code params} become the field {@code x}
 * (reshaped to {@code rows x features}); the remaining entries become {@code theta} (reshaped
 * to {@code columns x features}). Only positions where {@code r} equals 1 contribute to the
 * error term; the {@code lambda} penalty covers every entry of {@code x} and {@code theta}.
 *
 * <p>Side effects: overwrites the fields {@code x}, and, depending on {@code flag}, {@code cost}
 * and {@code gradient}.
 *
 * @param params packed values of {@code x} followed by {@code theta}
 * @param flag 1 computes the cost, 2 computes the gradient, 3 computes both
 * @return the {@code cost} field when {@code flag == 1}; otherwise the {@code gradient} field
 */
@Override
public Object compute(FloatMatrix params, int flag) {
  boolean needCost = flag == 1 || flag == 3;
  boolean needGradient = flag == 2 || flag == 3;

  // Split the flat parameter vector into its two matrices.
  int xLength = rows * features;
  x = params.getRange(0, xLength);
  FloatMatrix theta = params.getRange(xLength, params.length);
  x = x.reshape(rows, features);
  theta = theta.reshape(columns, features);

  if (needCost) {
    FloatMatrix squaredError = MatrixFunctions.pow(x.mmul(theta.transpose()).sub(y), 2);
    // Multiplying by r zeroes out the entries that are not marked as observed.
    this.cost = squaredError.mul(r).columnSums().rowSums().get(0) / 2;
    if (lambda != 0) {
      float thetaSquares = MatrixFunctions.pow(theta, 2).columnSums().rowSums().get(0);
      float xSquares = MatrixFunctions.pow(x, 2).columnSums().rowSums().get(0);
      float penalty = (lambda / 2) * (thetaSquares + xSquares);
      this.cost += penalty;
    }
  }

  if (needGradient) {
    FloatMatrix xGrad = FloatMatrix.zeros(x.rows, x.columns);
    FloatMatrix thetaGrad = FloatMatrix.zeros(theta.rows, theta.columns);

    // Gradient w.r.t. x: row `row` only sees the columns where r(row, :) == 1.
    for (int row = 0; row < rows; row++) {
      int[] observed = r.getRow(row).eq(1).findIndices();
      if (observed.length > 0) {
        FloatMatrix thetaObserved = theta.getRows(observed);
        FloatMatrix yObserved = y.getRow(row).get(observed);
        FloatMatrix rowError = x.getRow(row).mmul(thetaObserved.transpose()).sub(yObserved);
        xGrad.putRow(row, rowError.mmul(thetaObserved));
      }
    }
    xGrad = xGrad.add(x.mmul(lambda));

    // Gradient w.r.t. theta: column `col` only sees the rows where r(:, col) == 1.
    for (int col = 0; col < columns; col++) {
      int[] observed = r.getColumn(col).eq(1).findIndices();
      if (observed.length > 0) {
        FloatMatrix xObserved = x.getRows(observed);
        FloatMatrix yObserved = y.getColumn(col).get(observed);
        FloatMatrix colError = xObserved.mmul(theta.getRow(col).transpose()).sub(yObserved);
        thetaGrad.putRow(col, colError.transpose().mmul(xObserved));
      }
    }
    thetaGrad = thetaGrad.add(theta.mmul(lambda));

    this.gradient = MatrixUtil.merge(xGrad.data, thetaGrad.data);
  }

  if (flag == 1) {
    return cost;
  }
  return gradient;
}
/**
 * Builds the gradient of the bilinear tensor term for one binary node.
 *
 * <p>Slice {@code k} of the result is {@code deltaFull[k] * v * v^T}, where {@code v} is the
 * vertical concatenation of {@code leftVector} and {@code rightVector}. The two vectors are
 * stacked vertically so that {@code v * v^T} is {@code 2n x 2n}, matching the slice size the
 * tensor is constructed with.
 *
 * @param deltaFull error signal at this node; its length {@code n} is the number of slices
 * @param leftVector left child vector (assumed to be an {@code n x 1} column — confirm at callers)
 * @param rightVector right child vector (assumed to be an {@code n x 1} column — confirm at callers)
 * @return a tensor created as {@code new FloatTensor(2n, 2n, n)} with every slice filled in
 */
private FloatTensor getFloatTensorGradient(
    FloatMatrix deltaFull, FloatMatrix leftVector, FloatMatrix rightVector) {
  int size = deltaFull.length;
  FloatTensor tensorGradient = new FloatTensor(size * 2, size * 2, size);
  FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
  // v * v^T does not depend on the slice; only the scalar weight deltaFull[k] changes.
  FloatMatrix outer = fullVector.mmul(fullVector.transpose());
  for (int slice = 0; slice < size; ++slice) {
    // mul() allocates a fresh matrix; SimpleBlas.scal() would rescale `outer` in place and
    // compound the weights across slices.
    tensorGradient.setSlice(slice, outer.mul(deltaFull.get(slice)));
  }
  return tensorGradient;
}
/**
 * Computes the error signal passed from a tensor-composed node down to its two children.
 *
 * <p>The result is the first {@code 2n} rows of {@code W^T * deltaFull} (any trailing rows are
 * dropped — presumably the bias entry added by {@code appendBias}; confirm) plus
 * {@code sum_k (V_k + V_k^T) * (deltaFull[k] * v)}. Here {@code V_k} is slice {@code k} of
 * {@code Wt} and {@code v} is the vertical concatenation of the two child vectors.
 *
 * @param deltaFull error signal at this node ({@code n x 1})
 * @param leftVector left child vector (assumed {@code n x 1})
 * @param rightVector right child vector (assumed {@code n x 1})
 * @param W the node's linear composition matrix; {@code W^T * deltaFull} must have at least
 *     {@code 2n} rows
 * @param Wt the node's composition tensor, with {@code n} slices of size {@code 2n x 2n}
 * @return a {@code 2n x 1} column: rows {@code [0, n)} go to the left child, {@code [n, 2n)}
 *     to the right child
 */
private FloatMatrix computeFloatTensorDeltaDown(
    FloatMatrix deltaFull,
    FloatMatrix leftVector,
    FloatMatrix rightVector,
    FloatMatrix W,
    FloatTensor Wt) {
  FloatMatrix WTDelta = W.transpose().mmul(deltaFull);
  // WTDelta is a column vector: select rows [0, 2n) of column 0 (row/column ranges in that order).
  FloatMatrix WTDeltaNoBias =
      WTDelta.get(interval(0, deltaFull.rows * 2), interval(0, 1));
  int size = deltaFull.length;
  FloatMatrix deltaFloatTensor = new FloatMatrix(size * 2, 1);
  FloatMatrix fullVector = FloatMatrix.concatVertically(leftVector, rightVector);
  for (int slice = 0; slice < size; ++slice) {
    FloatMatrix sliceMatrix = Wt.getSlice(slice);
    // mul() allocates; SimpleBlas.scal() would rescale fullVector in place for later slices.
    FloatMatrix scaledFullVector = fullVector.mul(deltaFull.get(slice));
    deltaFloatTensor.addi(sliceMatrix.add(sliceMatrix.transpose()).mmul(scaledFullVector));
  }
  return deltaFloatTensor.addi(WTDeltaNoBias);
}
private void backpropDerivativesAndError( Tree tree, MultiDimensionalMap<String, String, FloatMatrix> binaryTD, MultiDimensionalMap<String, String, FloatMatrix> binaryCD, MultiDimensionalMap<String, String, FloatTensor> binaryFloatTensorTD, Map<String, FloatMatrix> unaryCD, Map<String, FloatMatrix> wordVectorD, FloatMatrix deltaUp) { if (tree.isLeaf()) { return; } FloatMatrix currentVector = tree.vector(); String category = tree.label(); category = basicCategory(category); // Build a vector that looks like 0,0,1,0,0 with an indicator for the correct class FloatMatrix goldLabel = new FloatMatrix(numOuts, 1); int goldClass = tree.goldLabel(); if (goldClass >= 0) { goldLabel.put(goldClass, 1.0f); } Float nodeWeight = classWeights.get(goldClass); if (nodeWeight == null) nodeWeight = 1.0f; FloatMatrix predictions = tree.prediction(); // If this is an unlabeled class, set deltaClass to 0. We could // make this more efficient by eliminating various of the below // calculations, but this would be the easiest way to handle the // unlabeled class FloatMatrix deltaClass = goldClass >= 0 ? 
SimpleBlas.scal(nodeWeight, predictions.sub(goldLabel)) : new FloatMatrix(predictions.rows, predictions.columns); FloatMatrix localCD = deltaClass.mmul(appendBias(currentVector).transpose()); float error = -(MatrixFunctions.log(predictions).muli(goldLabel).sum()); error = error * nodeWeight; tree.setError(error); if (tree.isPreTerminal()) { // below us is a word vector unaryCD.put(category, unaryCD.get(category).add(localCD)); String word = tree.children().get(0).label(); word = getVocabWord(word); FloatMatrix currentVectorDerivative = activationFunction.apply(currentVector); FloatMatrix deltaFromClass = getUnaryClassification(category).transpose().mmul(deltaClass); deltaFromClass = deltaFromClass.get(interval(0, numHidden), interval(0, 1)).mul(currentVectorDerivative); FloatMatrix deltaFull = deltaFromClass.add(deltaUp); wordVectorD.put(word, wordVectorD.get(word).add(deltaFull)); } else { // Otherwise, this must be a binary node String leftCategory = basicCategory(tree.children().get(0).label()); String rightCategory = basicCategory(tree.children().get(1).label()); if (combineClassification) { unaryCD.put("", unaryCD.get("").add(localCD)); } else { binaryCD.put( leftCategory, rightCategory, binaryCD.get(leftCategory, rightCategory).add(localCD)); } FloatMatrix currentVectorDerivative = activationFunction.applyDerivative(currentVector); FloatMatrix deltaFromClass = getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass); FloatMatrix mult = deltaFromClass.get(interval(0, numHidden), interval(0, 1)); deltaFromClass = mult.muli(currentVectorDerivative); FloatMatrix deltaFull = deltaFromClass.add(deltaUp); FloatMatrix leftVector = tree.children().get(0).vector(); FloatMatrix rightVector = tree.children().get(1).vector(); FloatMatrix childrenVector = appendBias(leftVector, rightVector); // deltaFull 50 x 1, childrenVector: 50 x 2 FloatMatrix add = binaryTD.get(leftCategory, rightCategory); FloatMatrix W_df = 
deltaFromClass.mmul(childrenVector.transpose()); binaryTD.put(leftCategory, rightCategory, add.add(W_df)); FloatMatrix deltaDown; if (useFloatTensors) { FloatTensor Wt_df = getFloatTensorGradient(deltaFull, leftVector, rightVector); binaryFloatTensorTD.put( leftCategory, rightCategory, binaryFloatTensorTD.get(leftCategory, rightCategory).add(Wt_df)); deltaDown = computeFloatTensorDeltaDown( deltaFull, leftVector, rightVector, getBinaryTransform(leftCategory, rightCategory), getBinaryFloatTensor(leftCategory, rightCategory)); } else { deltaDown = getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull); } FloatMatrix leftDerivative = activationFunction.apply(leftVector); FloatMatrix rightDerivative = activationFunction.apply(rightVector); FloatMatrix leftDeltaDown = deltaDown.get(interval(0, deltaFull.rows), interval(0, 1)); FloatMatrix rightDeltaDown = deltaDown.get(interval(deltaFull.rows, deltaFull.rows * 2), interval(0, 1)); backpropDerivativesAndError( tree.children().get(0), binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD, leftDerivative.mul(leftDeltaDown)); backpropDerivativesAndError( tree.children().get(1), binaryTD, binaryCD, binaryFloatTensorTD, unaryCD, wordVectorD, rightDerivative.mul(rightDeltaDown)); } }