示例#1
0
 /**
  * Returns the value stored for {@code token} in the underlying counter {@code lm}.
  *
  * @param token a token that must already be present in {@code lm}
  * @return {@code lm.getCount(token)}
  * @throws IllegalArgumentException if {@code token} is not a key of {@code lm}
  */
 public double getCount(K token) {
   if (!lm.containsKey(token)) {
     // Put the diagnostic context into the exception itself rather than printing it to
     // stderr, so whoever handles the failure actually sees it.
     throw new IllegalArgumentException(
         "token not in keyset: " + token + " (vocab size " + lm.keySet().size() + ")");
   }
   return lm.getCount(token);
 }
示例#2
0
  /**
   * Applies Good-Turing (Katz) smoothing with least-squares interpolation to the low-count
   * entries of {@code lm}. This follows the procedure in Jurafsky and Martin sect. 4.5.3.
   *
   * <p>Side effects:
   * <ul>
   *   <li>sets {@code UNK_PROB} to {@code N[1] / lm.totalCount()}, computed before any count
   *       is modified, where {@code N[1]} is the count-of-counts entry for 1;
   *   <li>adds every token whose count is {@code <= unkCutoff} to {@code unkTokens};
   *   <li>overwrites the count of every token whose count is {@code <= kCutoff} with the
   *       value returned by {@code katzEstimate}.
   * </ul>
   *
   * <p>Despite the name, no explicit normalization step is performed here; the original code
   * relies on the counter keeping itself normalized.
   */
  public void smoothAndNormalize() {
    // Count-of-counts table: key = (int) raw count, value = how many tokens have that count
    // (presumably, assuming incrementCount(key) adds 1 — confirm against Counter).
    Counter<Integer> cntCounter = new Counter<Integer>();
    for (K tok : lm.keySet()) {
      int cnt = (int) lm.getCount(tok);
      cntCounter.incrementCount(cnt);
    }

    final double[] coeffs = runLogSpaceRegression(cntCounter);

    // Mass reserved for unseen tokens, taken from the raw (pre-smoothing) counts.
    UNK_PROB = cntCounter.getCount(1) / lm.totalCount();

    for (K tok : lm.keySet()) {
      double tokCnt = lm.getCount(tok);
      if (tokCnt <= unkCutoff) { // Treat as unknown
        unkTokens.add(tok);
      }
      if (tokCnt <= kCutoff) { // Smooth
        double cSmooth = katzEstimate(cntCounter, tokCnt, coeffs);
        lm.setCount(tok, cSmooth);
      }
    }

    // No explicit normalization: the counter is expected to stay normalized on its own.
    // NOTE(review): if lm really holds normalized values in [0, 1], the (int) cast in the
    // count-of-counts loop maps (almost) every count to 0 — confirm the counter's semantics.
  }
示例#3
0
  /**
   * Fits the least-squares line {@code log(N[c]) = a + b * c} over the count-of-counts table.
   *
   * @param cntCounter count-of-counts table (key: raw count {@code c}, value: {@code N[c]})
   * @return a two-element array {@code {a, b}}, i.e. {intercept, slope}
   */
  private double[] runLogSpaceRegression(Counter<Integer> cntCounter) {
    final SimpleRegression regression = new SimpleRegression();
    for (int count : cntCounter.keySet()) {
      final double logFrequency = Math.log(cntCounter.getCount(count));
      regression.addData(count, logFrequency);
    }
    // NOTE(review): x is the raw count c (not log c). katzEstimate extrapolates with the same
    // semi-log form exp(a + b * x), so the two must change together if a log-log fit is wanted.
    // With too few distinct counts the fit is presumably degenerate — confirm SimpleRegression docs.
    return new double[] {regression.getIntercept(), regression.getSlope()};
  }
示例#4
0
  /**
   * Computes the Katz-discounted (Good-Turing) count for a token seen {@code c} times:
   *
   * <pre>
   *   c* = ((c+1) * N[c+1] / N[c] - c * (k+1) * N[k+1] / N[1]) / (1 - (k+1) * N[k+1] / N[1])
   * </pre>
   *
   * where {@code N[x]} is {@code cnt.getCount(x)} and {@code k} is {@code kCutoff}. A zero
   * {@code N[c+1]} or {@code N[k+1]} is replaced by the regression estimate
   * {@code exp(coeffs[0] + coeffs[1] * x)}.
   *
   * @param cnt count-of-counts table, indexed by integer count
   * @param c the token's raw count (truncated to {@code int} for table lookups)
   * @param coeffs {@code {intercept, slope}} as produced by {@code runLogSpaceRegression}
   * @return the discounted count, or {@code c} unchanged when the formula is undefined
   *     ({@code N[c]} or {@code N[1]} is zero, the denominator is zero, or the result is NaN
   *     or infinite)
   */
  private double katzEstimate(Counter<Integer> cnt, double c, double[] coeffs) {
    double nC = cnt.getCount((int) c);
    double nC1 = cnt.getCount(((int) c) + 1);
    if (nC1 == 0.0) {
      // N[c+1] unobserved: extrapolate from the log-space fit.
      nC1 = Math.exp(coeffs[0] + (coeffs[1] * (c + 1.0)));
    }

    double n1 = cnt.getCount(1);
    double nK1 = cnt.getCount(((int) kCutoff) + 1);
    if (nK1 == 0.0) {
      nK1 = Math.exp(coeffs[0] + (coeffs[1] * (kCutoff + 1.0)));
    }

    // Both ratios divide by a count-of-counts entry; without one the estimate is undefined,
    // so keep the raw count (tokens above kCutoff are likewise left unsmoothed).
    if (nC == 0.0 || n1 == 0.0) {
      return c;
    }

    double kTerm = (kCutoff + 1.0) * (nK1 / n1);
    double cTerm = (c + 1.0) * (nC1 / nC);
    double denominator = 1.0 - kTerm;
    if (denominator == 0.0) {
      return c;
    }

    double cSmooth = (cTerm - (c * kTerm)) / denominator;
    if (Double.isNaN(cSmooth) || Double.isInfinite(cSmooth)) {
      // e.g. if the regression coefficients are NaN; never write NaN/Infinity back into lm.
      return c;
    }
    return cSmooth;
  }
示例#5
0
 /**
  * Exposes the vocabulary (the keys of the underlying counter) without allowing mutation.
  *
  * @return {@code lm.keySet()} wrapped by {@link Collections#unmodifiableSet}
  */
 public Set<K> getVocab() {
   return Collections.unmodifiableSet(
       lm.keySet());
 }
示例#6
0
 /**
  * Reports whether {@code token} is a key of the underlying counter.
  *
  * @param token the token to look up
  * @return the result of {@code lm.containsKey(token)}
  */
 public boolean contains(K token) {
   final boolean present = lm.containsKey(token);
   return present;
 }
示例#7
0
 /**
  * Returns the number of distinct tokens held by the underlying counter.
  *
  * @return the size of {@code lm.keySet()}
  */
 public int vocabSize() {
   final int size = lm.keySet().size();
   return size;
 }
示例#8
0
 /**
  * Returns the total mass of the underlying counter.
  *
  * @return {@code lm.totalCount()}
  */
 public double totalMass() {
   final double total = lm.totalCount();
   return total;
 }
示例#9
0
 /**
  * Returns the probability assigned to {@code token}.
  *
  * <p>Tokens marked unknown (in {@code unkTokens}) and tokens absent from the counter both
  * receive {@code UNK_PROB}. Every other token gets its stored counter value, which is
  * presumably a probability only if the counter is normalized — confirm.
  *
  * @param token the token to score
  * @return {@code UNK_PROB} or {@code lm.getCount(token)}
  */
 public double getProb(K token) {
   // Same evaluation order as before: the unknown-token set is consulted first.
   final boolean treatAsUnknown = unkTokens.contains(token) || !lm.containsKey(token);
   return treatAsUnknown ? UNK_PROB : lm.getCount(token);
 }
示例#10
0
 /**
  * Records one more occurrence of {@code token} by delegating to
  * {@code lm.incrementCount(token)}.
  *
  * <p>NOTE(review): this does not touch {@code unkTokens} or {@code UNK_PROB}. Presumably
  * {@code smoothAndNormalize()} must be re-run after adding counts — confirm with callers.
  *
  * @param token the observed token
  */
 public void incrementCount(K token) {
   lm.incrementCount(token);
 }