/**
 * Wraps a single set of pre-trained weights (for example, weights trained with the original Matlab
 * code) in a {@link SentimentModel}.
 *
 * <p>Only the simplified model with combined classification is supported. That model has exactly
 * one (empty-string) category, so each weight is stored under the {@code ""} key and no separate
 * binary classification matrices exist.
 *
 * @param W the binary transform matrix, stored under the ("", "") key
 * @param Wcat the classification matrix, stored as the unary classifier under the "" key
 * @param Wt the binary tensor, stored under the ("", "") key
 * @param wordVectors word vectors, handed to the model unchanged
 * @param op model options; both {@code combineClassification} and {@code simplifiedModel} must be on
 * @return a model built from the supplied matrices
 * @throws IllegalArgumentException if {@code combineClassification} or {@code simplifiedModel} is off
 */
static SentimentModel modelFromMatrices(
    SimpleMatrix W,
    SimpleMatrix Wcat,
    SimpleTensor Wt,
    Map<String, SimpleMatrix> wordVectors,
    RNNOptions op) {
  boolean supportedConfiguration = op.combineClassification && op.simplifiedModel;
  if (!supportedConfiguration) {
    throw new IllegalArgumentException(
        "Can only create a model using this method if combineClassification and simplifiedModel are turned on");
  }

  TwoDimensionalMap<String, String, SimpleMatrix> transforms = TwoDimensionalMap.treeMap();
  transforms.put("", "", W);

  TwoDimensionalMap<String, String, SimpleTensor> tensors = TwoDimensionalMap.treeMap();
  tensors.put("", "", Wt);

  // With combined classification there are no per-production binary classifiers: leave it empty.
  TwoDimensionalMap<String, String, SimpleMatrix> binaryClassifiers = TwoDimensionalMap.treeMap();

  Map<String, SimpleMatrix> unaryClassifiers = Generics.newTreeMap();
  unaryClassifiers.put("", Wcat);

  return new SentimentModel(
      transforms, tensors, binaryClassifiers, unaryClassifiers, wordVectors, op);
}
/**
 * Copies the positive acronym-match entries of cluster {@code from} over to cluster {@code to}.
 *
 * <p>For every pair in {@code acronymCache} whose value is {@code true} and which involves
 * {@code from.clusterID}, let {@code other} be the other id in the pair. The pair
 * ({@code min(other, to)}, {@code max(other, to)}) is then marked {@code true}. Pairs where
 * {@code other} is {@code to.clusterID} itself are skipped. Existing entries, including those
 * involving {@code from}, are left in place.
 *
 * @param to the cluster that receives the acronym matches
 * @param from the cluster whose acronym matches are copied
 */
public void mergeAcronymCache(CorefCluster to, CorefCluster from) {
  // New pairs are gathered first and written afterwards, presumably so that acronymCache is not
  // modified while it is being iterated.
  TwoDimensionalSet<Integer, Integer> pending = TwoDimensionalSet.hashSet();
  for (Integer rowId : acronymCache.firstKeySet()) {
    for (Integer colId : acronymCache.get(first(rowId)).keySet()) {
      if (!acronymCache.get(rowId, colId)) {
        continue;
      }
      Integer partner;
      if (rowId == from.clusterID) {
        partner = colId;
      } else if (colId == from.clusterID) {
        partner = rowId;
      } else {
        partner = null;
      }
      if (partner == null || partner == to.clusterID) {
        continue;
      }
      pending.add(Math.min(partner, to.clusterID), Math.max(partner, to.clusterID));
    }
  }

  for (Integer low : pending.firstKeySet()) {
    for (Integer high : pending.secondKeySet(low)) {
      acronymCache.put(low, high, true);
    }
  }
}

/** Identity helper kept private to this method's use; returns its argument unchanged. */
private static Integer first(Integer id) {
  return id;
}
public void backpropDerivative( Tree tree, List<String> words, IdentityHashMap<Tree, SimpleMatrix> nodeVectors, TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfs, Map<String, SimpleMatrix> unaryW_dfs, TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivatives, Map<String, SimpleMatrix> unaryScoreDerivatives, Map<String, SimpleMatrix> wordVectorDerivatives, SimpleMatrix deltaUp) { if (tree.isLeaf()) { return; } if (tree.isPreTerminal()) { if (op.trainOptions.trainWordVectors) { String word = tree.children()[0].label().value(); word = dvModel.getVocabWord(word); // SimpleMatrix currentVector = nodeVectors.get(tree); // SimpleMatrix currentVectorDerivative = // nonlinearityVectorToDerivative(currentVector); // SimpleMatrix derivative = deltaUp.elementMult(currentVectorDerivative); SimpleMatrix derivative = deltaUp; wordVectorDerivatives.put(word, wordVectorDerivatives.get(word).plus(derivative)); } return; } SimpleMatrix currentVector = nodeVectors.get(tree); SimpleMatrix currentVectorDerivative = NeuralUtils.elementwiseApplyTanhDerivative(currentVector); SimpleMatrix scoreW = dvModel.getScoreWForNode(tree); currentVectorDerivative = currentVectorDerivative.elementMult(scoreW.transpose()); // the delta that is used at the current nodes SimpleMatrix deltaCurrent = deltaUp.plus(currentVectorDerivative); SimpleMatrix W = dvModel.getWForNode(tree); SimpleMatrix WTdelta = W.transpose().mult(deltaCurrent); if (tree.children().length == 2) { // TODO: RS: Change to the nice "getWForNode" setup? 
String leftLabel = dvModel.basicCategory(tree.children()[0].label().value()); String rightLabel = dvModel.basicCategory(tree.children()[1].label().value()); binaryScoreDerivatives.put( leftLabel, rightLabel, binaryScoreDerivatives.get(leftLabel, rightLabel).plus(currentVector.transpose())); SimpleMatrix leftVector = nodeVectors.get(tree.children()[0]); SimpleMatrix rightVector = nodeVectors.get(tree.children()[1]); SimpleMatrix childrenVector = NeuralUtils.concatenateWithBias(leftVector, rightVector); if (op.trainOptions.useContextWords) { childrenVector = concatenateContextWords(childrenVector, tree.getSpan(), words); } SimpleMatrix W_df = deltaCurrent.mult(childrenVector.transpose()); binaryW_dfs.put(leftLabel, rightLabel, binaryW_dfs.get(leftLabel, rightLabel).plus(W_df)); // and then recurse SimpleMatrix leftDerivative = NeuralUtils.elementwiseApplyTanhDerivative(leftVector); SimpleMatrix rightDerivative = NeuralUtils.elementwiseApplyTanhDerivative(rightVector); SimpleMatrix leftWTDelta = WTdelta.extractMatrix(0, deltaCurrent.numRows(), 0, 1); SimpleMatrix rightWTDelta = WTdelta.extractMatrix(deltaCurrent.numRows(), deltaCurrent.numRows() * 2, 0, 1); backpropDerivative( tree.children()[0], words, nodeVectors, binaryW_dfs, unaryW_dfs, binaryScoreDerivatives, unaryScoreDerivatives, wordVectorDerivatives, leftDerivative.elementMult(leftWTDelta)); backpropDerivative( tree.children()[1], words, nodeVectors, binaryW_dfs, unaryW_dfs, binaryScoreDerivatives, unaryScoreDerivatives, wordVectorDerivatives, rightDerivative.elementMult(rightWTDelta)); } else if (tree.children().length == 1) { String childLabel = dvModel.basicCategory(tree.children()[0].label().value()); unaryScoreDerivatives.put( childLabel, unaryScoreDerivatives.get(childLabel).plus(currentVector.transpose())); SimpleMatrix childVector = nodeVectors.get(tree.children()[0]); SimpleMatrix childVectorWithBias = NeuralUtils.concatenateWithBias(childVector); if (op.trainOptions.useContextWords) { 
childVectorWithBias = concatenateContextWords(childVectorWithBias, tree.getSpan(), words); } SimpleMatrix W_df = deltaCurrent.mult(childVectorWithBias.transpose()); // System.out.println("unary backprop derivative for " + childLabel); // System.out.println("Old transform:"); // System.out.println(unaryW_dfs.get(childLabel)); // System.out.println(" Delta:"); // System.out.println(W_df.scale(scale)); unaryW_dfs.put(childLabel, unaryW_dfs.get(childLabel).plus(W_df)); // and then recurse SimpleMatrix childDerivative = NeuralUtils.elementwiseApplyTanhDerivative(childVector); // SimpleMatrix childDerivative = childVector; SimpleMatrix childWTDelta = WTdelta.extractMatrix(0, deltaCurrent.numRows(), 0, 1); backpropDerivative( tree.children()[0], words, nodeVectors, binaryW_dfs, unaryW_dfs, binaryScoreDerivatives, unaryScoreDerivatives, wordVectorDerivatives, childDerivative.elementMult(childWTDelta)); } }
// fill value & derivative public void calculate(double[] theta) { dvModel.vectorToParams(theta); double localValue = 0.0; double[] localDerivative = new double[theta.length]; TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfsG, binaryW_dfsB; binaryW_dfsG = TwoDimensionalMap.treeMap(); binaryW_dfsB = TwoDimensionalMap.treeMap(); TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivativesG, binaryScoreDerivativesB; binaryScoreDerivativesG = TwoDimensionalMap.treeMap(); binaryScoreDerivativesB = TwoDimensionalMap.treeMap(); Map<String, SimpleMatrix> unaryW_dfsG, unaryW_dfsB; unaryW_dfsG = new TreeMap<>(); unaryW_dfsB = new TreeMap<>(); Map<String, SimpleMatrix> unaryScoreDerivativesG, unaryScoreDerivativesB; unaryScoreDerivativesG = new TreeMap<>(); unaryScoreDerivativesB = new TreeMap<>(); Map<String, SimpleMatrix> wordVectorDerivativesG = new TreeMap<>(); Map<String, SimpleMatrix> wordVectorDerivativesB = new TreeMap<>(); for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : dvModel.binaryTransform) { int numRows = entry.getValue().numRows(); int numCols = entry.getValue().numCols(); binaryW_dfsG.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols)); binaryW_dfsB.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols)); binaryScoreDerivativesG.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, numRows)); binaryScoreDerivativesB.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, numRows)); } for (Map.Entry<String, SimpleMatrix> entry : dvModel.unaryTransform.entrySet()) { int numRows = entry.getValue().numRows(); int numCols = entry.getValue().numCols(); unaryW_dfsG.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); unaryW_dfsB.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); unaryScoreDerivativesG.put(entry.getKey(), new SimpleMatrix(1, numRows)); unaryScoreDerivativesB.put(entry.getKey(), new SimpleMatrix(1, numRows)); } if 
(op.trainOptions.trainWordVectors) { for (Map.Entry<String, SimpleMatrix> entry : dvModel.wordVectors.entrySet()) { int numRows = entry.getValue().numRows(); int numCols = entry.getValue().numCols(); wordVectorDerivativesG.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); wordVectorDerivativesB.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); } } // Some optimization methods prints out a line without an end, so our // debugging statements are misaligned Timing scoreTiming = new Timing(); scoreTiming.doing("Scoring trees"); int treeNum = 0; MulticoreWrapper<Tree, Pair<DeepTree, DeepTree>> wrapper = new MulticoreWrapper<>(op.trainOptions.trainingThreads, new ScoringProcessor()); for (Tree tree : trainingBatch) { wrapper.put(tree); } wrapper.join(); scoreTiming.done(); while (wrapper.peek()) { Pair<DeepTree, DeepTree> result = wrapper.poll(); DeepTree goldTree = result.first; DeepTree bestTree = result.second; StringBuilder treeDebugLine = new StringBuilder(); Formatter formatter = new Formatter(treeDebugLine); boolean isDone = (Math.abs(bestTree.getScore() - goldTree.getScore()) <= 0.00001 || goldTree.getScore() > bestTree.getScore()); String done = isDone ? 
"done" : ""; formatter.format( "Tree %6d Highest tree: %12.4f Correct tree: %12.4f %s", treeNum, bestTree.getScore(), goldTree.getScore(), done); System.err.println(treeDebugLine.toString()); if (!isDone) { // if the gold tree is better than the best hypothesis tree by // a large enough margin, then the score difference will be 0 // and we ignore the tree double valueDelta = bestTree.getScore() - goldTree.getScore(); // double valueDelta = Math.max(0.0, - scoreGold + bestScore); localValue += valueDelta; // get the context words for this tree - should be the same // for either goldTree or bestTree List<String> words = getContextWords(goldTree.getTree()); // The derivatives affected by this tree are only based on the // nodes present in this tree, eg not all matrix derivatives // will be affected by this tree backpropDerivative( goldTree.getTree(), words, goldTree.getVectors(), binaryW_dfsG, unaryW_dfsG, binaryScoreDerivativesG, unaryScoreDerivativesG, wordVectorDerivativesG); backpropDerivative( bestTree.getTree(), words, bestTree.getVectors(), binaryW_dfsB, unaryW_dfsB, binaryScoreDerivativesB, unaryScoreDerivativesB, wordVectorDerivativesB); } ++treeNum; } double[] localDerivativeGood; double[] localDerivativeB; if (op.trainOptions.trainWordVectors) { localDerivativeGood = NeuralUtils.paramsToVector( theta.length, binaryW_dfsG.valueIterator(), unaryW_dfsG.values().iterator(), binaryScoreDerivativesG.valueIterator(), unaryScoreDerivativesG.values().iterator(), wordVectorDerivativesG.values().iterator()); localDerivativeB = NeuralUtils.paramsToVector( theta.length, binaryW_dfsB.valueIterator(), unaryW_dfsB.values().iterator(), binaryScoreDerivativesB.valueIterator(), unaryScoreDerivativesB.values().iterator(), wordVectorDerivativesB.values().iterator()); } else { localDerivativeGood = NeuralUtils.paramsToVector( theta.length, binaryW_dfsG.valueIterator(), unaryW_dfsG.values().iterator(), binaryScoreDerivativesG.valueIterator(), 
unaryScoreDerivativesG.values().iterator()); localDerivativeB = NeuralUtils.paramsToVector( theta.length, binaryW_dfsB.valueIterator(), unaryW_dfsB.values().iterator(), binaryScoreDerivativesB.valueIterator(), unaryScoreDerivativesB.values().iterator()); } // correct - highest for (int i = 0; i < localDerivativeGood.length; i++) { localDerivative[i] = localDerivativeB[i] - localDerivativeGood[i]; } // TODO: this is where we would combine multiple costs if we had parallelized the calculation value = localValue; derivative = localDerivative; // normalizing by training batch size value = (1.0 / trainingBatch.size()) * value; ArrayMath.multiplyInPlace(derivative, (1.0 / trainingBatch.size())); // add regularization to cost: double[] currentParams = dvModel.paramsToVector(); double regCost = 0; for (double currentParam : currentParams) { regCost += currentParam * currentParam; } regCost = op.trainOptions.regCost * 0.5 * regCost; value += regCost; // add regularization to gradient ArrayMath.multiplyInPlace(currentParams, op.trainOptions.regCost); ArrayMath.pairwiseAddInPlace(derivative, currentParams); }
/** The traditional way of initializing an empty model suitable for training. */ public SentimentModel(RNNOptions op, List<Tree> trainingTrees) { this.op = op; rand = new Random(op.randomSeed); if (op.randomWordVectors) { initRandomWordVectors(trainingTrees); } else { readWordVectors(); } if (op.numHid > 0) { this.numHid = op.numHid; } else { int size = 0; for (SimpleMatrix vector : wordVectors.values()) { size = vector.getNumElements(); break; } this.numHid = size; } TwoDimensionalSet<String, String> binaryProductions = TwoDimensionalSet.hashSet(); if (op.simplifiedModel) { binaryProductions.add("", ""); } else { // TODO // figure out what binary productions we have in these trees // Note: the current sentiment training data does not actually // have any constituent labels throw new UnsupportedOperationException("Not yet implemented"); } Set<String> unaryProductions = Generics.newHashSet(); if (op.simplifiedModel) { unaryProductions.add(""); } else { // TODO // figure out what unary productions we have in these trees (preterminals only, after the // collapsing) throw new UnsupportedOperationException("Not yet implemented"); } this.numClasses = op.numClasses; identity = SimpleMatrix.identity(numHid); binaryTransform = TwoDimensionalMap.treeMap(); binaryTensors = TwoDimensionalMap.treeMap(); binaryClassification = TwoDimensionalMap.treeMap(); // When making a flat model (no symantic untying) the // basicCategory function will return the same basic category for // all labels, so all entries will map to the same matrix for (Pair<String, String> binary : binaryProductions) { String left = basicCategory(binary.first); String right = basicCategory(binary.second); if (binaryTransform.contains(left, right)) { continue; } binaryTransform.put(left, right, randomTransformMatrix()); if (op.useTensors) { binaryTensors.put(left, right, randomBinaryTensor()); } if (!op.combineClassification) { binaryClassification.put(left, right, randomClassificationMatrix()); } } 
numBinaryMatrices = binaryTransform.size(); binaryTransformSize = numHid * (2 * numHid + 1); if (op.useTensors) { binaryTensorSize = numHid * numHid * numHid * 4; } else { binaryTensorSize = 0; } binaryClassificationSize = (op.combineClassification) ? 0 : numClasses * (numHid + 1); unaryClassification = Generics.newTreeMap(); // When making a flat model (no symantic untying) the // basicCategory function will return the same basic category for // all labels, so all entries will map to the same matrix for (String unary : unaryProductions) { unary = basicCategory(unary); if (unaryClassification.containsKey(unary)) { continue; } unaryClassification.put(unary, randomClassificationMatrix()); } numUnaryMatrices = unaryClassification.size(); unaryClassificationSize = numClasses * (numHid + 1); // System.err.println("Binary transform matrices:"); // System.err.println(binaryTransform); // System.err.println("Binary classification matrices:"); // System.err.println(binaryClassification); // System.err.println("Unary classification matrices:"); // System.err.println(unaryClassification); }