/**
 * Returns the value stored for {@code token} in the underlying counter.
 *
 * @param token the token to look up
 * @return the counter's value for {@code token}
 * @throws IllegalArgumentException if {@code token} is not in the vocabulary
 */
public double getCount(K token) {
  if (!lm.containsKey(token)) {
    // Carry the diagnostic (token and vocabulary size) in the exception itself instead of
    // printing it to stderr, where it is detached from the failure.
    throw new IllegalArgumentException(
        "token not in keyset: " + token + " (vocabulary size " + lm.keySet().size() + ")");
  }
  return lm.getCount(token);
}
/** * GT smoothing with least squares interpolation. This follows the procedure in Jurafsky and * Martin sect. 4.5.3. */ public void smoothAndNormalize() { Counter<Integer> cntCounter = new Counter<Integer>(); for (K tok : lm.keySet()) { int cnt = (int) lm.getCount(tok); cntCounter.incrementCount(cnt); } final double[] coeffs = runLogSpaceRegression(cntCounter); UNK_PROB = cntCounter.getCount(1) / lm.totalCount(); for (K tok : lm.keySet()) { double tokCnt = lm.getCount(tok); if (tokCnt <= unkCutoff) // Treat as unknown unkTokens.add(tok); if (tokCnt <= kCutoff) { // Smooth double cSmooth = katzEstimate(cntCounter, tokCnt, coeffs); lm.setCount(tok, cSmooth); } } // Normalize // Counters.normalize(lm); // MY COUNTER IS ALWAYS NORMALIZED AND AWESOME }
private double[] runLogSpaceRegression(Counter<Integer> cntCounter) { SimpleRegression reg = new SimpleRegression(); for (int cnt : cntCounter.keySet()) { reg.addData(cnt, Math.log(cntCounter.getCount(cnt))); } // System.out.println(reg.getIntercept()); // System.out.println(reg.getSlope()); // System.out.println(regression.getSlopeStdErr()); double[] coeffs = new double[] {reg.getIntercept(), reg.getSlope()}; return coeffs; }
/**
 * Computes the Katz-discounted Good-Turing count for a raw count {@code c}:
 *
 * <pre>
 *   c* = [ (c+1) N_{c+1}/N_c - c (k+1) N_{k+1}/N_1 ] / [ 1 - (k+1) N_{k+1}/N_1 ]
 * </pre>
 *
 * where N_x is the number of token types seen exactly x times and k is {@code kCutoff}. When the
 * N_{c+1} or N_{k+1} bucket is empty, its value is taken from the log-linear fit
 * {@code exp(coeffs[0] + coeffs[1] * x)}.
 *
 * @param cnt count-of-counts histogram (N_x keyed by x)
 * @param c raw count to discount; truncated to an int to index {@code cnt}
 * @param coeffs {intercept, slope} from the log-space regression
 * @return the discounted count, or {@code c} unchanged when the formula is undefined
 */
private double katzEstimate(Counter<Integer> cnt, double c, double[] coeffs) {
  final int cIndex = (int) c;
  final int kIndex = (int) kCutoff;

  double nC = cnt.getCount(cIndex);
  double nC1 = cnt.getCount(cIndex + 1);
  if (nC1 == 0.0) {
    nC1 = Math.exp(coeffs[0] + (coeffs[1] * (c + 1.0)));
  }
  double n1 = cnt.getCount(1);
  double nK1 = cnt.getCount(kIndex + 1);
  if (nK1 == 0.0) {
    nK1 = Math.exp(coeffs[0] + (coeffs[1] * (kCutoff + 1.0)));
  }

  // Without singletons (N_1 == 0) or an observed N_c the ratios below divide by zero; keep the
  // raw count rather than writing NaN/Infinity back into the model.
  if (n1 == 0.0 || nC == 0.0) {
    return c;
  }
  double kTerm = (kCutoff + 1.0) * (nK1 / n1);
  double denominator = 1.0 - kTerm;
  if (denominator == 0.0) {
    return c;
  }
  double cTerm = (c + 1.0) * (nC1 / nC);
  double cSmooth = (cTerm - (c * kTerm)) / denominator;
  // A NaN regression fit (e.g. fewer than two distinct counts) also propagates here.
  if (Double.isNaN(cSmooth) || Double.isInfinite(cSmooth)) {
    return c;
  }
  return cSmooth;
}
/**
 * Returns the tokens currently in the model.
 *
 * @return an unmodifiable view over the underlying counter's key set
 */
public Set<K> getVocab() {
  final Set<K> view = Collections.unmodifiableSet(this.lm.keySet());
  return view;
}
/**
 * Reports whether {@code token} has an entry in the underlying counter.
 *
 * @param token token to look up
 * @return {@code true} if the counter has a key for {@code token}
 */
public boolean contains(K token) {
  final boolean present = this.lm.containsKey(token);
  return present;
}
/**
 * Returns the number of distinct tokens in the model.
 *
 * @return the size of the underlying counter's key set
 */
public int vocabSize() {
  final int size = this.lm.keySet().size();
  return size;
}
/**
 * Returns the total mass held by the underlying counter.
 *
 * @return the counter's {@code totalCount()}
 */
public double totalMass() {
  final double mass = this.lm.totalCount();
  return mass;
}
/**
 * Returns the model's value for {@code token}, or the shared unknown-token probability.
 *
 * @param token token to look up
 * @return {@code UNK_PROB} if the token is in {@code unkTokens} or absent from the model;
 *     otherwise the value stored in the underlying counter
 */
public double getProb(K token) {
  final boolean known = !unkTokens.contains(token) && lm.containsKey(token);
  return known ? lm.getCount(token) : UNK_PROB;
}
/**
 * Increments the count of {@code token} in the underlying counter (presumably by one; see
 * {@code Counter.incrementCount}).
 *
 * @param token token whose count is incremented
 */
public void incrementCount(K token) {
  this.lm.incrementCount(token);
}