/**
 * Returns a list of features thresholded by minPrecision and sorted by their frequency of
 * occurrence. Precision here is defined as the frequency of the majority label over the total
 * frequency for that feature.
 *
 * @return list of high-precision features.
 */
private List<F> getHighPrecisionFeatures(
    GeneralDataset<L, F> dataset, double minPrecision, int maxNumFeatures) {
  int[][] feature2label = new int[dataset.numFeatures()][dataset.numClasses()];
  for (int f = 0; f < dataset.numFeatures(); f++) {
    Arrays.fill(feature2label[f], 0);
  }

  int[][] data = dataset.data;
  int[] labels = dataset.labels;
  for (int d = 0; d < data.length; d++) {
    int label = labels[d];
    // System.out.println("datum id:"+d+" label id: "+label);
    if (data[d] != null) {
      // System.out.println(" number of features:"+data[d].length);
      for (int n = 0; n < data[d].length; n++) {
        feature2label[data[d][n]][label]++;
      }
    }
  }

  Counter<F> feature2freq = new ClassicCounter<F>();
  for (int f = 0; f < dataset.numFeatures(); f++) {
    int maxF = ArrayMath.max(feature2label[f]);
    int total = ArrayMath.sum(feature2label[f]);
    double precision = ((double) maxF) / total;
    F feature = dataset.featureIndex.get(f);
    if (precision >= minPrecision) {
      feature2freq.incrementCount(feature, total);
    }
  }

  if (feature2freq.size() > maxNumFeatures) {
    Counters.retainTop(feature2freq, maxNumFeatures);
  }

  // for (F feature : feature2freq.keySet())
  //   System.out.println(feature + " " + feature2freq.getCount(feature));
  // System.exit(0);
  return Counters.toSortedList(feature2freq);
}
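// Illustrative stand-alone sketch (not part of the original class) of the precision test used
// above: for one feature, precision is the majority-label count divided by the total count across
// labels, and the feature is kept only when that ratio reaches minPrecision. All names below are
// hypothetical.
static boolean isHighPrecision(int[] labelCounts, double minPrecision) {
  int max = 0;
  int total = 0;
  for (int c : labelCounts) {
    if (c > max) {
      max = c;
    }
    total += c;
  }
  if (total == 0) {
    return false; // unseen feature: no evidence either way
  }
  return ((double) max) / total >= minPrecision;
}
// e.g. labelCounts = {8, 1, 1} gives precision 0.8, so the feature is kept for any
// minPrecision <= 0.8.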
private void computeDir(double[] dir, double[] fg) throws SQNMinimizer.SurpriseConvergence {
  System.arraycopy(fg, 0, dir, 0, fg.length);

  int mmm = sList.size();
  double[] as = new double[mmm];
  double[] factors = new double[dir.length];

  for (int i = mmm - 1; i >= 0; i--) {
    as[i] = roList.get(i) * ArrayMath.innerProduct(sList.get(i), dir);
    plusAndConstMult(dir, yList.get(i), -as[i], dir);
  }

  // multiply by hessian approximation
  if (mmm != 0) {
    double[] y = yList.get(mmm - 1);
    double yDotY = ArrayMath.innerProduct(y, y);
    if (yDotY == 0) {
      throw new SQNMinimizer.SurpriseConvergence("Y is 0!!");
    }
    double gamma = ArrayMath.innerProduct(sList.get(mmm - 1), y) / yDotY;
    ArrayMath.multiplyInPlace(dir, gamma);
  } else if (mmm == 0) {
    // This is a safety feature preventing too large of an initial step (see Yu Schraudolph
    // Gunter)
    ArrayMath.multiplyInPlace(dir, epsilon);
  }

  for (int i = 0; i < mmm; i++) {
    double b = roList.get(i) * ArrayMath.innerProduct(yList.get(i), dir);
    plusAndConstMult(dir, sList.get(i), cPosDef * as[i] - b, dir);
    plusAndConstMult(ArrayMath.pairwiseMultiply(yList.get(i), sList.get(i)), factors, 1, factors);
  }

  ArrayMath.multiplyInPlace(dir, -1);
}
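// Minimal self-contained sketch of the standard L-BFGS two-loop recursion that the method above
// follows (ignoring the cPosDef correction and the locally accumulated "factors" array). s[i] and
// y[i] are stored parameter/gradient differences and rho[i] = 1 / (y[i] . s[i]); all names here
// are hypothetical, not the class's own fields.
static double[] twoLoopDirection(double[] grad, double[][] s, double[][] y, double[] rho) {
  int m = s.length;
  double[] dir = grad.clone();
  double[] alpha = new double[m];
  for (int i = m - 1; i >= 0; i--) {
    alpha[i] = rho[i] * dot(s[i], dir);
    axpy(dir, y[i], -alpha[i]); // dir -= alpha[i] * y[i]
  }
  if (m > 0) {
    // initial Hessian approximation: scale by gamma = (s . y) / (y . y)
    double gamma = dot(s[m - 1], y[m - 1]) / dot(y[m - 1], y[m - 1]);
    scale(dir, gamma);
  }
  for (int i = 0; i < m; i++) {
    double beta = rho[i] * dot(y[i], dir);
    axpy(dir, s[i], alpha[i] - beta); // dir += (alpha[i] - beta) * s[i]
  }
  scale(dir, -1); // turn the quasi-Newton step into a descent direction
  return dir;
}

static double dot(double[] a, double[] b) {
  double r = 0;
  for (int i = 0; i < a.length; i++) r += a[i] * b[i];
  return r;
}

static void axpy(double[] x, double[] v, double c) {
  for (int i = 0; i < x.length; i++) x[i] += c * v[i];
}

static void scale(double[] x, double c) {
  for (int i = 0; i < x.length; i++) x[i] *= c;
}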
@Override
public double[] derivativeAt(double[] flatCoefs) {
  double[] g = new double[model.flatIDsize()];
  model.setCoefsFromFlat(flatCoefs);
  for (ModelSentence s : mSentences) {
    model.computeGradient(s, g);
  }
  ArrayMath.multiplyInPlace(g, -1);
  addL2regularizerGradient(g, flatCoefs);
  return g;
}
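// Sketch of the pattern above, under the assumption that the per-sentence gradients are gradients
// of a log-likelihood: sum them, negate so the optimizer can minimize, then add the gradient of an
// L2 penalty, lambda * theta. "perExampleGrads" and "lambda" are hypothetical names.
static double[] negLogLikGradientWithL2(double[][] perExampleGrads, double[] theta, double lambda) {
  double[] g = new double[theta.length];
  for (double[] exampleGrad : perExampleGrads) {
    for (int i = 0; i < g.length; i++) {
      g[i] += exampleGrad[i];
    }
  }
  for (int i = 0; i < g.length; i++) {
    // negate, then add d/dtheta of (lambda / 2) * ||theta||^2
    g[i] = -g[i] + lambda * theta[i];
  }
  return g;
}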
private void computeEmpiricalStatistics(List<F> geFeatures) {
  // allocate memory to the containers and initialize them
  geFeature2EmpiricalDist = new double[geFeatures.size()][labeledDataset.labelIndex.size()];
  geFeature2DatumList = new ArrayList<List<Integer>>(geFeatures.size());
  Map<F, Integer> geFeatureMap = Generics.newHashMap();
  Set<Integer> activeUnlabeledExamples = Generics.newHashSet();
  for (int n = 0; n < geFeatures.size(); n++) {
    F geFeature = geFeatures.get(n);
    geFeature2DatumList.add(new ArrayList<Integer>());
    Arrays.fill(geFeature2EmpiricalDist[n], 0);
    geFeatureMap.put(geFeature, n);
  }

  // compute the empirical label distribution for each GE feature
  for (int i = 0; i < labeledDataset.size(); i++) {
    Datum<L, F> datum = labeledDataset.getDatum(i);
    int labelID = labeledDataset.labelIndex.indexOf(datum.label());
    for (F feature : datum.asFeatures()) {
      if (geFeatureMap.containsKey(feature)) {
        int geFnum = geFeatureMap.get(feature);
        geFeature2EmpiricalDist[geFnum][labelID]++;
      }
    }
  }

  // now normalize and smooth the label distribution for each feature.
  for (int n = 0; n < geFeatures.size(); n++) {
    ArrayMath.normalize(geFeature2EmpiricalDist[n]);
    smoothDistribution(geFeature2EmpiricalDist[n]);
  }

  // now build the inverted index from each GE feature to unlabeled datums that contain it.
  for (int i = 0; i < unlabeledDataList.size(); i++) {
    Datum<L, F> datum = unlabeledDataList.get(i);
    for (F feature : datum.asFeatures()) {
      if (geFeatureMap.containsKey(feature)) {
        int geFnum = geFeatureMap.get(feature);
        geFeature2DatumList.get(geFnum).add(i);
        activeUnlabeledExamples.add(i);
      }
    }
  }

  System.out.println("Number of active unlabeled examples:" + activeUnlabeledExamples.size());
}
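// Stand-alone sketch of the empirical-statistics step above, using plain arrays instead of the
// dataset/Index classes: for each feature, count how often it co-occurs with each label, then
// normalize each row into a distribution (smoothing is shown separately with smoothDistribution
// below). All parameter names are hypothetical.
static double[][] empiricalLabelDist(
    int[][] datumFeatures, int[] datumLabels, int numFeatures, int numLabels) {
  double[][] dist = new double[numFeatures][numLabels];
  for (int d = 0; d < datumFeatures.length; d++) {
    for (int f : datumFeatures[d]) {
      dist[f][datumLabels[d]]++;
    }
  }
  for (double[] row : dist) {
    double total = 0;
    for (double v : row) {
      total += v;
    }
    if (total > 0) {
      for (int l = 0; l < row.length; l++) {
        row[l] /= total;
      }
    }
  }
  return dist;
}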
/**
 * This method is for comparing the speed of the ArrayCoreMap family and HashMap. It tests random
 * access speed for a fixed number of accesses, i, for both a CoreLabel (can be swapped out for an
 * ArrayCoreMap) and a HashMap. Switching the order of testing (CoreLabel first or second) shows
 * that there is a slight advantage to whichever loop runs second, especially noticeable for small
 * i; this is due to JVM warm-up and JIT effects, so we now run 50% each way.
 */
@SuppressWarnings({"StringEquality"})
public static void main(String[] args) {
  @SuppressWarnings("unchecked")
  Class<CoreAnnotation<String>>[] allKeys =
      new Class[] {
        CoreAnnotations.TextAnnotation.class,
        CoreAnnotations.LemmaAnnotation.class,
        CoreAnnotations.PartOfSpeechAnnotation.class,
        CoreAnnotations.ShapeAnnotation.class,
        CoreAnnotations.NamedEntityTagAnnotation.class,
        CoreAnnotations.DocIDAnnotation.class,
        CoreAnnotations.ValueAnnotation.class,
        CoreAnnotations.CategoryAnnotation.class,
        CoreAnnotations.BeforeAnnotation.class,
        CoreAnnotations.AfterAnnotation.class,
        CoreAnnotations.OriginalTextAnnotation.class,
        CoreAnnotations.ArgumentAnnotation.class,
        CoreAnnotations.MarkingAnnotation.class
      };

  // how many iterations
  final int numBurnRounds = 10;
  final int numGoodRounds = 60;
  final int numIterations = 2000000;
  final int maxNumKeys = 12;
  double gains = 0.0;

  for (int numKeys = 1; numKeys <= maxNumKeys; numKeys++) {
    // the HashMap instance
    HashMap<String, String> hashmap = new HashMap<String, String>(numKeys);

    // the CoreMap instance
    CoreMap coremap = new ArrayCoreMap(numKeys);

    // the set of keys to use
    String[] hashKeys = new String[numKeys];
    @SuppressWarnings("unchecked")
    Class<CoreAnnotation<String>>[] coreKeys = new Class[numKeys];
    for (int key = 0; key < numKeys; key++) {
      hashKeys[key] = allKeys[key].getSimpleName();
      coreKeys[key] = allKeys[key];
    }

    // initialize with default values
    for (int i = 0; i < numKeys; i++) {
      coremap.set(coreKeys[i], String.valueOf(i));
      hashmap.put(hashKeys[i], String.valueOf(i));
    }
    assert coremap.size() == numKeys;
    assert hashmap.size() == numKeys;

    // for storing results
    double[] hashTimings = new double[numGoodRounds];
    double[] coreTimings = new double[numGoodRounds];

    final Random rand = new Random(0);
    boolean foundEqual = false;
    for (int round = 0; round < numBurnRounds + numGoodRounds; round++) {
      System.err.print(".");

      if (round % 2 == 0) {
        // test timings on hashmap first
        final long hashStart = System.nanoTime();
        final int length = hashKeys.length;
        String last = null;
        for (int i = 0; i < numIterations; i++) {
          int key = rand.nextInt(length);
          String val = hashmap.get(hashKeys[key]);
          if (val == last) {
            foundEqual = true;
          }
          last = val;
        }
        if (round >= numBurnRounds) {
          hashTimings[round - numBurnRounds] = (System.nanoTime() - hashStart) / 1000000000.0;
        }
      }

      {
        // test timings on coremap
        final long coreStart = System.nanoTime();
        final int length = coreKeys.length;
        String last = null;
        for (int i = 0; i < numIterations; i++) {
          int key = rand.nextInt(length);
          String val = coremap.get(coreKeys[key]);
          if (val == last) {
            foundEqual = true;
          }
          last = val;
        }
        if (round >= numBurnRounds) {
          coreTimings[round - numBurnRounds] = (System.nanoTime() - coreStart) / 1000000000.0;
        }
      }

      if (round % 2 == 1) {
        // test timings on hashmap second
        final long hashStart = System.nanoTime();
        final int length = hashKeys.length;
        String last = null;
        for (int i = 0; i < numIterations; i++) {
          int key = rand.nextInt(length);
          String val = hashmap.get(hashKeys[key]);
          if (val == last) {
            foundEqual = true;
          }
          last = val;
        }
        if (round >= numBurnRounds) {
          hashTimings[round - numBurnRounds] = (System.nanoTime() - hashStart) / 1000000000.0;
        }
      }
    }
    if (foundEqual) {
      System.err.print(" [found equal]");
    }
    System.err.println();

    double hashMean = ArrayMath.mean(hashTimings);
    double coreMean = ArrayMath.mean(coreTimings);
    double percentDiff = (hashMean - coreMean) / hashMean * 100.0;
    NumberFormat nf = new DecimalFormat("0.00");
    System.out.println("HashMap @ " + numKeys + " keys: " + hashMean + " secs/2million gets");
    System.out.println(
        "CoreMap @ "
            + numKeys
            + " keys: "
            + coreMean
            + " secs/2million gets ("
            + nf.format(Math.abs(percentDiff))
            + "% "
            + (percentDiff >= 0.0 ? "faster" : "slower")
            + ")");
    gains += percentDiff;
  }
  System.out.println();
  gains = gains / maxNumKeys;
  System.out.println(
      "Average: " + Math.abs(gains) + "% " + (gains >= 0.0 ? "faster" : "slower") + ".");
}
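// Minimal sketch of the warm-up-then-measure pattern the benchmark above relies on: discard the
// first burnRounds timings so JIT compilation and class loading do not distort the result, then
// average the remaining rounds. Everything here is illustrative, not part of the benchmark class.
static double timeAverageSeconds(Runnable work, int burnRounds, int measuredRounds) {
  double totalSeconds = 0.0;
  for (int round = 0; round < burnRounds + measuredRounds; round++) {
    long start = System.nanoTime();
    work.run();
    if (round >= burnRounds) {
      totalSeconds += (System.nanoTime() - start) / 1e9;
    }
  }
  return totalSeconds / measuredRounds;
}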
public static void trainLibLinear() throws Exception {
  System.out.println("train SVMs.");

  Indexer<String> featureIndexer = IOUtils.readIndexer(WORD_INDEXER_FILE);
  List<SparseVector> trainData = SparseVector.readList(TRAIN_DATA_FILE);
  List<SparseVector> testData = SparseVector.readList(TEST_DATA_FILE);

  Collections.shuffle(trainData);
  Collections.shuffle(testData);

  // List[] lists = new List[] { trainData, testData };
  //
  // for (int i = 0; i < lists.length; i++) {
  //   List<SparseVector> list = lists[i];
  //   for (int j = 0; j < list.size(); j++) {
  //     SparseVector sv = list.get(j);
  //     if (sv.label() > 0) {
  //       sv.setLabel(1);
  //     }
  //   }
  // }

  Problem prob = new Problem();
  prob.l = trainData.size();
  prob.n = featureIndexer.size() + 1;
  prob.y = new double[prob.l];
  prob.x = new Feature[prob.l][];
  prob.bias = -1;

  if (prob.bias >= 0) {
    prob.n++;
  }

  for (int i = 0; i < trainData.size(); i++) {
    SparseVector x = trainData.get(i);
    Feature[] input = new Feature[prob.bias > 0 ? x.size() + 1 : x.size()];
    for (int j = 0; j < x.size(); j++) {
      // liblinear expects strictly positive (1-based) feature indices, so the sparse vector's
      // indices are shifted here
      int index = x.indexAtLoc(j) + 1;
      double value = x.valueAtLoc(j);
      assert index >= 0;
      input[j] = new FeatureNode(index + 1, value);
    }
    if (prob.bias >= 0) {
      input[input.length - 1] = new FeatureNode(prob.n, prob.bias);
    }
    prob.x[i] = input;
    prob.y[i] = x.label();
  }

  Model model = Linear.train(prob, getSVMParamter());

  CounterMap<Integer, Integer> cm = new CounterMap<Integer, Integer>();
  for (int i = 0; i < testData.size(); i++) {
    SparseVector sv = testData.get(i);
    Feature[] input = new Feature[sv.size()];
    for (int j = 0; j < sv.size(); j++) {
      // the same index shift as in the training loop, so train and test feature ids line up
      int index = sv.indexAtLoc(j) + 1;
      double value = sv.valueAtLoc(j);
      input[j] = new FeatureNode(index + 1, value);
    }
    double[] dec_values = new double[model.getNrClass()];
    Linear.predictValues(model, input, dec_values);

    int max_id = ArrayMath.argmax(dec_values);
    int pred = model.getLabels()[max_id];
    int answer = sv.label();

    cm.incrementCount(answer, pred, 1);
  }

  System.out.println(cm);
  model.save(new File(MODEL_FILE));
}
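// Hedged sketch of a minimal liblinear setup corresponding to the method above; the original
// delegates the solver configuration to getSVMParamter(), so the Parameter/SolverType choice here
// is an assumption about the liblinear-java API rather than the project's actual settings.
// Feature indices inside each row must be 1-based and no larger than prob.n.
static Model trainTinySvm(Feature[][] rows, double[] labels, int numFeatures) {
  Problem prob = new Problem();
  prob.l = rows.length; // number of training examples
  prob.n = numFeatures; // highest feature index
  prob.x = rows;
  prob.y = labels;
  prob.bias = -1; // no bias feature appended
  Parameter param = new Parameter(SolverType.L2R_L2LOSS_SVC, 1.0, 0.01); // assumed defaults
  return Linear.train(prob, param);
}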
// fill value & derivative public void calculate(double[] theta) { dvModel.vectorToParams(theta); double localValue = 0.0; double[] localDerivative = new double[theta.length]; TwoDimensionalMap<String, String, SimpleMatrix> binaryW_dfsG, binaryW_dfsB; binaryW_dfsG = TwoDimensionalMap.treeMap(); binaryW_dfsB = TwoDimensionalMap.treeMap(); TwoDimensionalMap<String, String, SimpleMatrix> binaryScoreDerivativesG, binaryScoreDerivativesB; binaryScoreDerivativesG = TwoDimensionalMap.treeMap(); binaryScoreDerivativesB = TwoDimensionalMap.treeMap(); Map<String, SimpleMatrix> unaryW_dfsG, unaryW_dfsB; unaryW_dfsG = new TreeMap<>(); unaryW_dfsB = new TreeMap<>(); Map<String, SimpleMatrix> unaryScoreDerivativesG, unaryScoreDerivativesB; unaryScoreDerivativesG = new TreeMap<>(); unaryScoreDerivativesB = new TreeMap<>(); Map<String, SimpleMatrix> wordVectorDerivativesG = new TreeMap<>(); Map<String, SimpleMatrix> wordVectorDerivativesB = new TreeMap<>(); for (TwoDimensionalMap.Entry<String, String, SimpleMatrix> entry : dvModel.binaryTransform) { int numRows = entry.getValue().numRows(); int numCols = entry.getValue().numCols(); binaryW_dfsG.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols)); binaryW_dfsB.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(numRows, numCols)); binaryScoreDerivativesG.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, numRows)); binaryScoreDerivativesB.put( entry.getFirstKey(), entry.getSecondKey(), new SimpleMatrix(1, numRows)); } for (Map.Entry<String, SimpleMatrix> entry : dvModel.unaryTransform.entrySet()) { int numRows = entry.getValue().numRows(); int numCols = entry.getValue().numCols(); unaryW_dfsG.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); unaryW_dfsB.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); unaryScoreDerivativesG.put(entry.getKey(), new SimpleMatrix(1, numRows)); unaryScoreDerivativesB.put(entry.getKey(), new SimpleMatrix(1, numRows)); } if (op.trainOptions.trainWordVectors) { for (Map.Entry<String, SimpleMatrix> entry : dvModel.wordVectors.entrySet()) { int numRows = entry.getValue().numRows(); int numCols = entry.getValue().numCols(); wordVectorDerivativesG.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); wordVectorDerivativesB.put(entry.getKey(), new SimpleMatrix(numRows, numCols)); } } // Some optimization methods prints out a line without an end, so our // debugging statements are misaligned Timing scoreTiming = new Timing(); scoreTiming.doing("Scoring trees"); int treeNum = 0; MulticoreWrapper<Tree, Pair<DeepTree, DeepTree>> wrapper = new MulticoreWrapper<>(op.trainOptions.trainingThreads, new ScoringProcessor()); for (Tree tree : trainingBatch) { wrapper.put(tree); } wrapper.join(); scoreTiming.done(); while (wrapper.peek()) { Pair<DeepTree, DeepTree> result = wrapper.poll(); DeepTree goldTree = result.first; DeepTree bestTree = result.second; StringBuilder treeDebugLine = new StringBuilder(); Formatter formatter = new Formatter(treeDebugLine); boolean isDone = (Math.abs(bestTree.getScore() - goldTree.getScore()) <= 0.00001 || goldTree.getScore() > bestTree.getScore()); String done = isDone ? 
"done" : ""; formatter.format( "Tree %6d Highest tree: %12.4f Correct tree: %12.4f %s", treeNum, bestTree.getScore(), goldTree.getScore(), done); System.err.println(treeDebugLine.toString()); if (!isDone) { // if the gold tree is better than the best hypothesis tree by // a large enough margin, then the score difference will be 0 // and we ignore the tree double valueDelta = bestTree.getScore() - goldTree.getScore(); // double valueDelta = Math.max(0.0, - scoreGold + bestScore); localValue += valueDelta; // get the context words for this tree - should be the same // for either goldTree or bestTree List<String> words = getContextWords(goldTree.getTree()); // The derivatives affected by this tree are only based on the // nodes present in this tree, eg not all matrix derivatives // will be affected by this tree backpropDerivative( goldTree.getTree(), words, goldTree.getVectors(), binaryW_dfsG, unaryW_dfsG, binaryScoreDerivativesG, unaryScoreDerivativesG, wordVectorDerivativesG); backpropDerivative( bestTree.getTree(), words, bestTree.getVectors(), binaryW_dfsB, unaryW_dfsB, binaryScoreDerivativesB, unaryScoreDerivativesB, wordVectorDerivativesB); } ++treeNum; } double[] localDerivativeGood; double[] localDerivativeB; if (op.trainOptions.trainWordVectors) { localDerivativeGood = NeuralUtils.paramsToVector( theta.length, binaryW_dfsG.valueIterator(), unaryW_dfsG.values().iterator(), binaryScoreDerivativesG.valueIterator(), unaryScoreDerivativesG.values().iterator(), wordVectorDerivativesG.values().iterator()); localDerivativeB = NeuralUtils.paramsToVector( theta.length, binaryW_dfsB.valueIterator(), unaryW_dfsB.values().iterator(), binaryScoreDerivativesB.valueIterator(), unaryScoreDerivativesB.values().iterator(), wordVectorDerivativesB.values().iterator()); } else { localDerivativeGood = NeuralUtils.paramsToVector( theta.length, binaryW_dfsG.valueIterator(), unaryW_dfsG.values().iterator(), binaryScoreDerivativesG.valueIterator(), unaryScoreDerivativesG.values().iterator()); localDerivativeB = NeuralUtils.paramsToVector( theta.length, binaryW_dfsB.valueIterator(), unaryW_dfsB.values().iterator(), binaryScoreDerivativesB.valueIterator(), unaryScoreDerivativesB.values().iterator()); } // correct - highest for (int i = 0; i < localDerivativeGood.length; i++) { localDerivative[i] = localDerivativeB[i] - localDerivativeGood[i]; } // TODO: this is where we would combine multiple costs if we had parallelized the calculation value = localValue; derivative = localDerivative; // normalizing by training batch size value = (1.0 / trainingBatch.size()) * value; ArrayMath.multiplyInPlace(derivative, (1.0 / trainingBatch.size())); // add regularization to cost: double[] currentParams = dvModel.paramsToVector(); double regCost = 0; for (double currentParam : currentParams) { regCost += currentParam * currentParam; } regCost = op.trainOptions.regCost * 0.5 * regCost; value += regCost; // add regularization to gradient ArrayMath.multiplyInPlace(currentParams, op.trainOptions.regCost); ArrayMath.pairwiseAddInPlace(derivative, currentParams); }
@Override
public double[] minimize(
    Function f, double functionTolerance, double[] initial, int maxIterations) {
  if (!(f instanceof AbstractStochasticCachingDiffUpdateFunction)) {
    throw new UnsupportedOperationException();
  }
  AbstractStochasticCachingDiffUpdateFunction function =
      (AbstractStochasticCachingDiffUpdateFunction) f;
  if (function instanceof LogConditionalObjectiveFunction) {
    if (((LogConditionalObjectiveFunction) function).parallelGradientCalculation) {
      System.err.println(
          "\n*********\nNoting that HogWild optimization requested.\nSetting batch size = data size to minimize thread creation overhead.\nResults *should* be identical on sparse problems.\nDisable parallelGradientComputation flag in LogConditionalObjectiveFunction, or run with -threads 1 to disable.\nAlso can use another Minimizer if parallel computation is desired, but HogWild isn't delivering good results.\n*********\n");
      bSize = function.dataDimension();
    }
  }

  int totalSamples = function.dataDimension();
  int tuneSampleSize = Math.min(totalSamples, tuningSamples);
  if (tuneSampleSize < tuningSamples) {
    System.err.println(
        "WARNING: Total number of samples="
            + totalSamples
            + " is smaller than requested tuning sample size="
            + tuningSamples
            + "!!!");
  }

  lambda = 1.0 / (sigma * totalSamples);
  sayln("Using sigma=" + sigma + " lambda=" + lambda + " tuning sample size " + tuneSampleSize);
  // tune(function, initial, tuneSampleSize, 0.1);
  t0 = (int) (1 / (0.1 * lambda));

  x = new double[initial.length];
  System.arraycopy(initial, 0, x, 0, x.length);
  xscale = 1;
  xnorm = getNorm(x);
  int numBatches = totalSamples / bSize;

  init(function);

  boolean have_max = (maxIterations > 0 || numPasses > 0);

  if (!have_max) {
    throw new UnsupportedOperationException("No maximum number of iterations has been specified.");
  } else {
    maxIterations = Math.max(maxIterations, numPasses) * numBatches;
  }

  sayln(" Batch size of: " + bSize);
  sayln(" Data dimension of: " + totalSamples);
  sayln(" Batches per pass through data: " + numBatches);
  sayln(" Number of passes is = " + numPasses);
  sayln(" Max iterations is = " + maxIterations);

  // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
  // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
  // Loop
  // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
  // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
  Timing total = new Timing();
  Timing current = new Timing();
  total.start();
  current.start();
  int t = t0;
  int iters = 0;
  for (int pass = 0; pass < numPasses; pass++) {
    boolean doEval = (pass > 0 && evaluateIters > 0 && pass % evaluateIters == 0);
    if (doEval) {
      rescale();
      doEvaluation(x);
    }

    double totalValue = 0;
    double lastValue = 0;
    say("Iter: " + iters + " pass " + pass + " batch 1 ... ");

    for (int batch = 0; batch < numBatches; batch++) {
      iters++;

      // Get the next X
      double eta = 1 / (lambda * t);
      double gain = eta / xscale;
      lastValue = function.calculateStochasticUpdate(x, xscale, bSize, gain);
      totalValue += lastValue;

      // weight decay (for L2 regularization)
      xscale *= (1 - eta * lambda * bSize);
      t += bSize;
    }

    if (xscale < 1e-6) {
      rescale();
    }

    try {
      ArrayMath.assertFinite(x, "x");
    } catch (ArrayMath.InvalidElementException e) {
      System.err.println(e.toString());
      for (int i = 0; i < x.length; i++) {
        x[i] = Double.NaN;
      }
      break;
    }

    xnorm = getNorm(x) * xscale * xscale;
    // Calculate loss based on L2 regularization
    double loss = totalValue + 0.5 * xnorm * lambda * totalSamples;

    say(String.valueOf(numBatches));
    say("[" + (total.report()) / 1000.0 + " s ");
    say("{" + (current.restart() / 1000.0) + " s}] ");
    sayln(" " + lastValue + ' ' + totalValue + ' ' + loss);

    if (iters >= maxIterations) {
      sayln("Stochastic Optimization complete. Stopped after max iterations");
      break;
    }

    if (total.report() >= maxTime) {
      sayln("Stochastic Optimization complete. Stopped after max time");
      break;
    }
  }

  rescale();

  if (evaluateIters > 0) {
    // do final evaluation
    doEvaluation(x);
  }

  sayln("Completed in: " + Timing.toSecondsString(total.report()) + " s");

  return x;
}
private static void smoothDistribution(double[] dist) {
  // perform additive (Laplace-style) smoothing with a small epsilon, then renormalize
  double epsilon = 1e-6;
  for (int i = 0; i < dist.length; i++) {
    dist[i] += epsilon;
  }
  ArrayMath.normalize(dist);
}
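// Worked example of the smoothing above (illustrative only): for an empirical distribution
// {0.5, 0.0, 0.5}, adding epsilon = 1e-6 to each entry and renormalizing turns the zero entry
// into epsilon / (1 + 3 * epsilon) > 0, so no label ever gets exactly zero probability.
static void smoothingDemo() {
  double[] dist = {0.5, 0.0, 0.5};
  double epsilon = 1e-6;
  double total = 0.0;
  for (int i = 0; i < dist.length; i++) {
    dist[i] += epsilon;
    total += dist[i];
  }
  for (int i = 0; i < dist.length; i++) {
    dist[i] /= total; // the same normalization step ArrayMath.normalize performs
  }
  System.out.println(java.util.Arrays.toString(dist)); // middle entry is ~1e-6, not 0
}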
/**
 * Do max language model markov segmentation. Note that this algorithm inherently tags words as it
 * goes, but we throw away the tags in the final result so that the segmented words are untagged.
 * (Note: for a couple of years, until Aug 2007, a tagged result was returned, but this messed up
 * the parser, which could then use no tagging other than the given tagging, and that tagging often
 * wasn't very good. In particular, it was a subcategorized tagging, which never worked with the
 * current forceTags option, which assumes that gold taggings are basic taggings.)
 *
 * @param s A String to segment
 * @return The list of segmented words.
 */
private ArrayList<HasWord> segmentWordsWithMarkov(String s) {
  int length = s.length();
  // Set<String> POSes = (Set<String>) POSDistribution.keySet(); // 1.5
  int numTags = POSes.size();
  // score of span with initial word of this tag
  double[][][] scores = new double[length][length + 1][numTags];
  // best (length of) first word for this span with this tag
  int[][][] splitBacktrace = new int[length][length + 1][numTags];
  // best tag for second word over this span, if first is this tag
  int[][][] POSbacktrace = new int[length][length + 1][numTags];
  for (int i = 0; i < length; i++) {
    for (int j = 0; j < length + 1; j++) {
      Arrays.fill(scores[i][j], Double.NEGATIVE_INFINITY);
    }
  }
  // first fill in word probabilities
  for (int diff = 1; diff <= 10; diff++) {
    for (int start = 0; start + diff <= length; start++) {
      int end = start + diff;
      StringBuilder wordBuf = new StringBuilder();
      for (int pos = start; pos < end; pos++) {
        wordBuf.append(s.charAt(pos));
      }
      String word = wordBuf.toString();
      for (String tag : POSes) {
        IntTaggedWord itw = new IntTaggedWord(word, tag, wordIndex, tagIndex);
        double score = lex.score(itw, 0, word, null);
        if (start == 0) {
          score += Math.log(initialPOSDist.probabilityOf(tag));
        }
        scores[start][end][itw.tag()] = score;
        splitBacktrace[start][end][itw.tag()] = end;
      }
    }
  }
  // now fill in word combination probabilities
  for (int diff = 2; diff <= length; diff++) {
    for (int start = 0; start + diff <= length; start++) {
      int end = start + diff;
      for (int split = start + 1; split < end && split - start <= 10; split++) {
        for (String tag : POSes) {
          int tagNum = tagIndex.indexOf(tag, true);
          if (splitBacktrace[start][split][tagNum] != split) {
            continue;
          }
          Distribution<String> rTagDist = markovPOSDists.get(tag);
          if (rTagDist == null) {
            continue; // this happens with "*" POS
          }
          for (String rTag : POSes) {
            int rTagNum = tagIndex.indexOf(rTag, true);
            double newScore =
                scores[start][split][tagNum]
                    + scores[split][end][rTagNum]
                    + Math.log(rTagDist.probabilityOf(rTag));
            if (newScore > scores[start][end][tagNum]) {
              scores[start][end][tagNum] = newScore;
              splitBacktrace[start][end][tagNum] = split;
              POSbacktrace[start][end][tagNum] = rTagNum;
            }
          }
        }
      }
    }
  }
  int nextPOS = ArrayMath.argmax(scores[0][length]);
  ArrayList<HasWord> words = new ArrayList<HasWord>();
  int start = 0;
  while (start < length) {
    int split = splitBacktrace[start][length][nextPOS];
    StringBuilder wordBuf = new StringBuilder();
    for (int i = start; i < split; i++) {
      wordBuf.append(s.charAt(i));
    }
    String word = wordBuf.toString();
    // String tag = tagIndex.get(nextPOS);
    // words.add(new TaggedWord(word, tag));
    words.add(new Word(word));
    if (split < length) {
      nextPOS = POSbacktrace[start][length][nextPOS];
    }
    start = split;
  }
  return words;
}
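// Simplified stand-alone sketch of the dynamic program above, with the POS-tag dimension dropped:
// best[i] is the best score of any segmentation of s.substring(0, i), candidate words are at most
// maxLen characters long, and back[i] records where the last word starts. wordScore is a
// hypothetical log-probability function assumed to return a finite value for any candidate word;
// the tag lattice and markov tag transitions of the real method are intentionally omitted.
static java.util.List<String> segment(
    String s, int maxLen, java.util.function.ToDoubleFunction<String> wordScore) {
  int n = s.length();
  double[] best = new double[n + 1];
  int[] back = new int[n + 1];
  java.util.Arrays.fill(best, Double.NEGATIVE_INFINITY);
  best[0] = 0.0;
  for (int end = 1; end <= n; end++) {
    for (int start = Math.max(0, end - maxLen); start < end; start++) {
      if (best[start] == Double.NEGATIVE_INFINITY) {
        continue; // no segmentation reaches this start position
      }
      double score = best[start] + wordScore.applyAsDouble(s.substring(start, end));
      if (score > best[end]) {
        best[end] = score;
        back[end] = start;
      }
    }
  }
  // reconstruct the best segmentation by walking the backpointers from the end
  java.util.LinkedList<String> words = new java.util.LinkedList<>();
  for (int end = n; end > 0; end = back[end]) {
    words.addFirst(s.substring(back[end], end));
  }
  return words;
}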