// fill value & derivative public void calculate(double[] theta) { dvModel.vectorToParams(theta); double localValue = 0.0; double[] localDerivative = new double[theta.length]; TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfsG, binaryW_dfsB; binaryW_dfsG = TwoDimensionalMap.treeMap(); binaryW_dfsB = TwoDimensionalMap.treeMap(); TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivativesG, binaryScoreDerivativesB; binaryScoreDerivativesG = TwoDimensionalMap.treeMap(); binaryScoreDerivativesB = TwoDimensionalMap.treeMap(); Map<String, SimpleMatrix> unaryW_dfsG, unaryW_dfsB; unaryW_dfsG = new TreeMap<>(); unaryW_dfsB = new TreeMap<>(); Map<String, SimpleMatrix> unaryScoreDerivativesG, unaryScoreDerivativesB; unaryScoreDerivativesG = new TreeMap<>(); unaryScoreDerivativesB = new TreeMap<>(); Map<String, SimpleMatrix> wordVectorDerivativesG = new TreeMap<>(); Map<String, SimpleMatrix> wordVectorDerivativesB = new TreeMap<>(); for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : dvModel.binaryTransform) { int numRows = entry.getValue().numRows(); int numCols = entry.getValue().numCols(); binaryW_dfsG.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols)); binaryW_dfsB.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols)); binaryScoreDerivativesG.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, numRows)); binaryScoreDerivativesB.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, numRows)); } for (Map.Entry<String, SimpleMatrix> entry : dvModel.unaryTransform.entrySet()) { int numRows = entry.getValue().numRows(); int numCols = entry.getValue().numCols(); unaryW_dfsG.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); unaryW_dfsB.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); unaryScoreDerivativesG.put(entry.getKey(), new SimpleMatrix(1, numRows)); unaryScoreDerivativesB.put(entry.getKey(), new SimpleMatrix(1, numRows)); } if 
(op.trainOptions.trainWordVectors) { for (Map.Entry<String, SimpleMatrix> entry : dvModel.wordVectors.entrySet()) { int numRows = entry.getValue().numRows(); int numCols = entry.getValue().numCols(); wordVectorDerivativesG.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); wordVectorDerivativesB.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); } } // Some optimization methods prints out a line without an end, so our // debugging statements are misaligned Timing scoreTiming = new Timing(); scoreTiming.doing("Scoring trees"); int treeNum = 0; MulticoreWrapper<Tree, Pair<DeepTree, DeepTree>> wrapper = new MulticoreWrapper<>(op.trainOptions.trainingThreads, new ScoringProcessor()); for (Tree tree : trainingBatch) { wrapper.put(tree); } wrapper.join(); scoreTiming.done(); while (wrapper.peek()) { Pair<DeepTree, DeepTree> result = wrapper.poll(); DeepTree goldTree = result.first; DeepTree bestTree = result.second; StringBuilder treeDebugLine = new StringBuilder(); Formatter formatter = new Formatter(treeDebugLine); boolean isDone = (Math.abs(bestTree.getScore() - goldTree.getScore()) <= 0.00001 || goldTree.getScore() > bestTree.getScore()); String done = isDone ? 
"done" : ""; formatter.format( "Tree %6d Highest tree: %12.4f Correct tree: %12.4f %s", treeNum, bestTree.getScore(), goldTree.getScore(), done); System.err.println(treeDebugLine.toString()); if (!isDone) { // if the gold tree is better than the best hypothesis tree by // a large enough margin, then the score difference will be 0 // and we ignore the tree double valueDelta = bestTree.getScore() - goldTree.getScore(); // double valueDelta = Math.max(0.0, - scoreGold + bestScore); localValue += valueDelta; // get the context words for this tree - should be the same // for either goldTree or bestTree List<String> words = getContextWords(goldTree.getTree()); // The derivatives affected by this tree are only based on the // nodes present in this tree, eg not all matrix derivatives // will be affected by this tree backpropDerivative( goldTree.getTree(), words, goldTree.getVectors(), binaryW_dfsG, unaryW_dfsG, binaryScoreDerivativesG, unaryScoreDerivativesG, wordVectorDerivativesG); backpropDerivative( bestTree.getTree(), words, bestTree.getVectors(), binaryW_dfsB, unaryW_dfsB, binaryScoreDerivativesB, unaryScoreDerivativesB, wordVectorDerivativesB); } ++treeNum; } double[] localDerivativeGood; double[] localDerivativeB; if (op.trainOptions.trainWordVectors) { localDerivativeGood = NeuralUtils.paramsToVector( theta.length, binaryW_dfsG.valueIterator(), unaryW_dfsG.values().iterator(), binaryScoreDerivativesG.valueIterator(), unaryScoreDerivativesG.values().iterator(), wordVectorDerivativesG.values().iterator()); localDerivativeB = NeuralUtils.paramsToVector( theta.length, binaryW_dfsB.valueIterator(), unaryW_dfsB.values().iterator(), binaryScoreDerivativesB.valueIterator(), unaryScoreDerivativesB.values().iterator(), wordVectorDerivativesB.values().iterator()); } else { localDerivativeGood = NeuralUtils.paramsToVector( theta.length, binaryW_dfsG.valueIterator(), unaryW_dfsG.values().iterator(), binaryScoreDerivativesG.valueIterator(), 
unaryScoreDerivativesG.values().iterator()); localDerivativeB = NeuralUtils.paramsToVector( theta.length, binaryW_dfsB.valueIterator(), unaryW_dfsB.values().iterator(), binaryScoreDerivativesB.valueIterator(), unaryScoreDerivativesB.values().iterator()); } // correct - highest for (int i = 0; i < localDerivativeGood.length; i++) { localDerivative[i] = localDerivativeB[i] - localDerivativeGood[i]; } // TODO: this is where we would combine multiple costs if we had parallelized the calculation value = localValue; derivative = localDerivative; // normalizing by training batch size value = (1.0 / trainingBatch.size()) * value; ArrayMath.multiplyInPlace(derivative, (1.0 / trainingBatch.size())); // add regularization to cost: double[] currentParams = dvModel.paramsToVector(); double regCost = 0; for (double currentParam : currentParams) { regCost += currentParam * currentParam; } regCost = op.trainOptions.regCost * 0.5 * regCost; value += regCost; // add regularization to gradient ArrayMath.multiplyInPlace(currentParams, op.trainOptions.regCost); ArrayMath.pairwiseAddInPlace(derivative, currentParams); }
/**
 * Prints to stderr which parameter block contains element {@code index} of the flattened
 * parameter vector, and that element's offset within the block.
 *
 * <p>Blocks are walked in the order binaryTransform, binaryClassification, binaryTensors,
 * unaryClassification, wordVectors. This order presumably matches the layout used when the
 * parameters are flattened into a vector (confirm against paramsToVector). If no block contains
 * the index, the total parameter size is printed instead.
 *
 * @param index position in the flattened parameter vector; negative values are reported as
 *     invalid rather than as "beyond the length"
 */
public void printParamInformation(int index) {
  if (index < 0) {
    System.err.println("Index " + index + " is negative; parameter indices start at 0");
    return;
  }

  // Invariant: curIndex <= index at the top of every iteration, because we return as soon as
  // index < curIndex + size. Only the upper bound therefore needs checking.
  int curIndex = 0;
  for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : binaryTransform) {
    int size = entry.getValue().getNumElements();
    if (index < curIndex + size) {
      printParamLocation(
          index, curIndex, "binaryTransform", entry.getFirstKey() + "," + entry.getSecondKey());
      return;
    }
    curIndex += size;
  }
  for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : binaryClassification) {
    int size = entry.getValue().getNumElements();
    if (index < curIndex + size) {
      printParamLocation(
          index,
          curIndex,
          "binaryClassification",
          entry.getFirstKey() + "," + entry.getSecondKey());
      return;
    }
    curIndex += size;
  }
  for (TwoDimensionalMap.Entry<String, String, SimpleTensor> entry : binaryTensors) {
    int size = entry.getValue().getNumElements();
    if (index < curIndex + size) {
      printParamLocation(
          index, curIndex, "binaryTensor", entry.getFirstKey() + "," + entry.getSecondKey());
      return;
    }
    curIndex += size;
  }
  for (Map.Entry<String, SimpleMatrix> entry : unaryClassification.entrySet()) {
    int size = entry.getValue().getNumElements();
    if (index < curIndex + size) {
      printParamLocation(index, curIndex, "unaryClassification", entry.getKey());
      return;
    }
    curIndex += size;
  }
  for (Map.Entry<String, SimpleMatrix> entry : wordVectors.entrySet()) {
    int size = entry.getValue().getNumElements();
    if (index < curIndex + size) {
      printParamLocation(index, curIndex, "wordVector", entry.getKey());
      return;
    }
    curIndex += size;
  }
  System.err.println(
      "Index " + index
          + " is beyond the length of the parameters; total parameter space was "
          + totalParamSize());
}

/**
 * Prints that {@code index} is element {@code index - blockStart} of the named parameter block.
 */
private static void printParamLocation(int index, int blockStart, String blockType, String key) {
  System.err.println(
      "Index " + index + " is element " + (index - blockStart)
          + " of " + blockType + " \"" + key + "\"");
}