Example #1
  /**
   * Return the probability (as a real number between 0 and 1) of stopping rather than generating
   * another argument at this position.
   *
   * @param dependency The dependency used as the basis for the stop decision. Tags are assumed to
   *     be in the TagProjection space.
   * @return The probability of generating a stop at this position
   */
  protected double getStopProb(IntDependency dependency) {
    short binDistance = distanceBin(dependency.distance);
    IntTaggedWord unknownHead = new IntTaggedWord(-1, dependency.head.tag);
    IntTaggedWord anyHead = new IntTaggedWord(ANY_WORD_INT, dependency.head.tag);

    IntDependency temp =
        new IntDependency(dependency.head, stopTW, dependency.leftHeaded, binDistance);
    double c_stop_hTWds = stopCounter.getCount(temp);
    temp = new IntDependency(unknownHead, stopTW, dependency.leftHeaded, binDistance);
    double c_stop_hTds = stopCounter.getCount(temp);
    temp = new IntDependency(dependency.head, wildTW, dependency.leftHeaded, binDistance);
    double c_hTWds = stopCounter.getCount(temp);
    temp = new IntDependency(anyHead, wildTW, dependency.leftHeaded, binDistance);
    double c_hTds = stopCounter.getCount(temp);

    double p_stop_hTds = (c_hTds > 0.0 ? c_stop_hTds / c_hTds : 1.0);

    double pb_stop_hTWds = (c_stop_hTWds + smooth_stop * p_stop_hTds) / (c_hTWds + smooth_stop);

    if (verbose) {
      System.out.println(
          "  c_stop_hTWds: "
              + c_stop_hTWds
              + "; c_hTWds: "
              + c_hTWds
              + "; c_stop_hTds: "
              + c_stop_hTds
              + "; c_hTds: "
              + c_hTds);
      System.out.println("  Generate STOP prob: " + pb_stop_hTWds);
    }
    return pb_stop_hTWds;
  }
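
The estimate above is a standard MAP backoff: the sparse word-conditioned stop count is pulled
toward the denser tag-conditioned MLE p_stop_hTds, with smooth_stop acting as a pseudo-count
mass. The following self-contained sketch replays that arithmetic with made-up counts; the
variable names mirror the ones above, and the smoothing constant is illustrative, not taken
from this class.

  // Minimal sketch of the stop-probability backoff, with invented counts.
  public class StopSmoothingSketch {
    public static void main(String[] args) {
      double smoothStop = 8.0;   // hypothetical smoothing strength
      double cStopHTWds = 3.0;   // count(stop | head word+tag, direction, distance)
      double cHTWds = 4.0;       // count(any outcome | head word+tag, direction, distance)
      double cStopHTds = 300.0;  // count(stop | head tag, direction, distance)
      double cHTds = 500.0;      // count(any outcome | head tag, direction, distance)

      // MLE backoff conditioned on the tag only; defaults to 1.0 (always stop)
      // when the tag context was never observed, exactly as in getStopProb.
      double pStopHTds = (cHTds > 0.0) ? cStopHTds / cHTds : 1.0;

      // MAP estimate: word-conditioned counts smoothed toward the tag backoff.
      double pbStopHTWds = (cStopHTWds + smoothStop * pStopHTds) / (cHTWds + smoothStop);

      System.out.println(pStopHTds);   // 0.6
      System.out.println(pbStopHTWds); // (3 + 8 * 0.6) / (4 + 8) = 0.65
    }
  }
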
Example #2
 public static <T> Counter<T> toCounter(List<FeatureValue<T>> featureValues) {
   ClassicCounter<T> counter = new ClassicCounter<T>();
   for (FeatureValue<T> fv : featureValues) {
     counter.incrementCount(fv.name, fv.value);
   }
   return counter;
 }
 private static <L, F> BasicDatum<L, F> newDatum(L label, F[] features, Double[] counts) {
   ClassicCounter<F> counter = new ClassicCounter<F>();
   for (int i = 0; i < features.length; i++) {
     counter.setCount(features[i], counts[i]);
   }
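   // note: BasicDatum stores only a feature collection, so the counts gathered
   // above serve only to deduplicate features via keySet(); the values are dropped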
   return new BasicDatum<L, F>(counter.keySet(), label);
 }
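
A hedged usage sketch for toCounter: repeated feature names accumulate, since incrementCount
adds rather than overwrites. It assumes only the FeatureValue (name, value) constructor that
the combine(...) snippet in Example #9 below also uses.

  static void toCounterDemo() {
    List<FeatureValue<String>> fvs = new ArrayList<FeatureValue<String>>();
    fvs.add(new FeatureValue<String>("lex", 1.0));
    fvs.add(new FeatureValue<String>("lex", 2.5)); // same feature name twice
    fvs.add(new FeatureValue<String>("pos", 1.0));
    Counter<String> c = toCounter(fvs);
    System.out.println(c.getCount("lex")); // 3.5: duplicates are summed
  }
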
Example #4
  /**
   * Trains this lexicon on the Collection of trees. Also trains the unknown word model pointed to
   * by this lexicon.
   */
  public void train(Collection<Tree> trees, double weight, boolean keepTagsAsLabels) {
    getUnknownWordModel().train(trees);

    // scan data
    for (Tree tree : trees) {
      List<IntTaggedWord> taggedWords = treeToEvents(tree, keepTagsAsLabels);
      for (int w = 0, sz = taggedWords.size(); w < sz; w++) {
        IntTaggedWord iTW = taggedWords.get(w);
        seenCounter.incrementCount(iTW, weight);
        IntTaggedWord iT = new IntTaggedWord(nullWord, iTW.tag);
        seenCounter.incrementCount(iT, weight);
        IntTaggedWord iW = new IntTaggedWord(iTW.word, nullTag);
        seenCounter.incrementCount(iW, weight);
        IntTaggedWord i = new IntTaggedWord(nullWord, nullTag);
        seenCounter.incrementCount(i, weight);
        // rules.add(iTW);
        tags.add(iT);
        words.add(iW);
      }
    }

    tune(trees);

    // index the possible tags for each word
    initRulesWithWord();

    if (DEBUG_LEXICON) {
      printLexStats();
    }
  }
Example #5
 private static <T> void display(ClassicCounter<T> c, PrintWriter pw) {
   List<T> cats = new ArrayList<>(c.keySet());
   Collections.sort(cats, Counters.toComparatorDescending(c));
   for (T ob : cats) {
     pw.println(ob + " " + c.getCount(ob));
   }
 }
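
A small usage sketch for display: Counters.toComparatorDescending orders the keys by count,
so the dump comes out most-frequent first.

  static void displayDemo() {
    ClassicCounter<String> counts = new ClassicCounter<String>();
    counts.incrementCount("NP", 7.0);
    counts.incrementCount("PP", 5.0);
    counts.incrementCount("VP", 3.0);
    display(counts, new PrintWriter(System.out, true)); // NP 7.0, PP 5.0, VP 3.0
  }
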
Example #6
  /** Writes out data from this object to the PrintWriter out. */
  @Override
  public void writeData(PrintWriter out) throws IOException {
    // one rule per line

    for (IntDependency dependency : argCounter.keySet()) {
      if (dependency.head != wildTW
          && dependency.arg != wildTW
          && dependency.head.word != -1
          && dependency.arg.word != -1) {
        double count = argCounter.getCount(dependency);
        out.println(dependency.toString(wordIndex, tagIndex) + " " + count);
      }
    }

    out.println("BEGIN_STOP");

    for (IntDependency dependency : stopCounter.keySet()) {
      if (dependency.head.word != -1) {
        double count = stopCounter.getCount(dependency);
        out.println(dependency.toString(wordIndex, tagIndex) + " " + count);
      }
    }

    out.flush();
  }
 /** Makes a deep copy of the array of counters (new array, new counters, same counts). */
 public ClassicCounter<Integer>[] cloneCounter(ClassicCounter<Integer>[] counter) {
   ClassicCounter<Integer>[] newcount =
       ErasureUtils.<ClassicCounter<Integer>>mkTArray(ClassicCounter.class, counter.length);
   for (int xx = 0; xx < counter.length; xx++) {
     ClassicCounter<Integer> cc = new ClassicCounter<Integer>();
     newcount[xx] = cc;
     for (Integer key : counter[xx].keySet()) cc.incrementCount(key, counter[xx].getCount(key));
   }
   return newcount;
 }
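
A usage sketch showing that cloneCounter is a deep copy: mutating the clone leaves the
original counters untouched. It reuses the same ErasureUtils.mkTArray idiom as above.

  void cloneCounterDemo() {
    ClassicCounter<Integer>[] original =
        ErasureUtils.<ClassicCounter<Integer>>mkTArray(ClassicCounter.class, 1);
    original[0] = new ClassicCounter<Integer>();
    original[0].incrementCount(42, 2.0);
    ClassicCounter<Integer>[] copy = cloneCounter(original);
    copy[0].incrementCount(42, 1.0);
    System.out.println(original[0].getCount(42)); // 2.0 (unchanged)
    System.out.println(copy[0].getCount(42));     // 3.0
  }
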
Example #8
 private static <T> void display(ClassicCounter<T> c, int num, PrintWriter pw) {
   List<T> rules = new ArrayList<>(c.keySet());
   Collections.sort(rules, Counters.toComparatorDescending(c));
   int rSize = rules.size();
   if (num > rSize) {
     num = rSize;
   }
   for (int i = 0; i < num; i++) {
     pw.println(rules.get(i) + " " + c.getCount(rules.get(i)));
   }
 }
Example #9
 /** @param <T> The feature type */
 public static <T> List<FeatureValue<T>> combine(Collection<FeatureValue<T>> featureValues) {
   ClassicCounter<T> counter = new ClassicCounter<T>();
   for (FeatureValue<T> fv : featureValues) {
     counter.incrementCount(fv.name, fv.value);
   }
   Set<T> keys = new TreeSet<T>(counter.keySet());
   List<FeatureValue<T>> featureList = new ArrayList<FeatureValue<T>>(keys.size());
   for (T key : keys) {
     featureList.add(new FeatureValue<T>(key, counter.getCount(key)));
   }
   return featureList;
 }
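
Usage sketch for combine: duplicate names are summed (as in toCounter) and the result comes
back deduplicated and sorted by feature name, because the keys pass through a TreeSet. Note
this requires T to be Comparable, or the TreeSet will fail at runtime.

  static void combineDemo() {
    List<FeatureValue<String>> fvs = new ArrayList<FeatureValue<String>>();
    fvs.add(new FeatureValue<String>("b", 1.0));
    fvs.add(new FeatureValue<String>("a", 2.0));
    fvs.add(new FeatureValue<String>("b", 0.5));
    List<FeatureValue<String>> merged = combine(fvs);
    // merged holds one entry per name, in name order: a=2.0, then b=1.5
  }
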
Example #10
 @Override
 public void evaluate(Tree t1, Tree t2, PrintWriter pw) {
   Set<String> s1 = makeObjects(t1);
   Set<String> s2 = makeObjects(t2);
   for (String o1 : s1) {
     if (!s2.contains(o1)) {
       over.incrementCount(o1);
     }
   }
   for (String o2 : s2) {
     if (!s1.contains(o2)) {
       under.incrementCount(o2);
     }
   }
 }
Example #11
  /**
   * Evaluates how many words (= terminals) in a collection of trees are covered by the lexicon.
   * First arg is the collection of trees; the second through fourth args collect the results.
   * Currently unused; this probably only works when training and testing happen in the same run,
   * so that the tags and words variables are initialized.
   */
  public double evaluateCoverage(
      Collection<Tree> trees,
      Set<String> missingWords,
      Set<String> missingTags,
      Set<IntTaggedWord> missingTW) {

    List<IntTaggedWord> iTW1 = new ArrayList<IntTaggedWord>();
    for (Tree t : trees) {
      iTW1.addAll(treeToEvents(t));
    }

    int total = 0;
    int unseen = 0;

    for (IntTaggedWord itw : iTW1) {
      total++;
      if (!words.contains(new IntTaggedWord(itw.word(), nullTag))) {
        missingWords.add(wordIndex.get(itw.word()));
      }
      if (!tags.contains(new IntTaggedWord(nullWord, itw.tag()))) {
        missingTags.add(tagIndex.get(itw.tag()));
      }
      // if (!rules.contains(itw)) {
      if (seenCounter.getCount(itw) == 0.0) {
        unseen++;
        missingTW.add(itw);
      }
    }
    return (double) unseen / total;
  }
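
A usage sketch: pass in empty sets and let evaluateCoverage fill them with the missing words,
tags, and (word, tag) pairs, returning the fraction of unseen events.

  void coverageDemo(Collection<Tree> heldOut) {
    Set<String> missingWords = new HashSet<String>();
    Set<String> missingTags = new HashSet<String>();
    Set<IntTaggedWord> missingTW = new HashSet<IntTaggedWord>();
    double unseenRate = evaluateCoverage(heldOut, missingWords, missingTags, missingTW);
    System.out.println("unseen (word, tag) fraction: " + unseenRate);
  }
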
Example #12
 /**
  * This records how likely it is for a word with one tag to also have another tag. This won't work
  * after serialization/deserialization, but that is how it is currently called.
  */
 void buildPT_T() {
   int numTags = tagIndex.size();
   m_TT = new double[numTags][numTags];
   m_T = new double[numTags];
   double[] tmp = new double[numTags];
   for (IntTaggedWord word : words) {
     double tot = 0.0;
     for (int t = 0; t < numTags; t++) {
       IntTaggedWord iTW = new IntTaggedWord(word.word, t);
       tmp[t] = seenCounter.getCount(iTW);
       tot += tmp[t];
     }
     if (tot < 10) {
       continue;
     }
     for (int t = 0; t < numTags; t++) {
       for (int t2 = 0; t2 < numTags; t2++) {
         if (tmp[t2] > 0.0) {
           double c = tmp[t] / tot;
           m_T[t] += c;
           m_TT[t2][t] += c;
         }
       }
     }
   }
 }
Example #13
  /**
   * Writes out data from this object to the Writer w. Rules are separated by newlines, and rule
   * elements are delimited by tabs (\t).
   */
  public void writeData(Writer w) throws IOException {
    PrintWriter out = new PrintWriter(w);

    for (IntTaggedWord itw : seenCounter.keySet()) {
      out.println(itw.toLexicalEntry(wordIndex, tagIndex) + " SEEN " + seenCounter.getCount(itw));
    }
    for (IntTaggedWord itw : getUnknownWordModel().unSeenCounter().keySet()) {
      out.println(
          itw.toLexicalEntry(wordIndex, tagIndex)
              + " UNSEEN "
              + getUnknownWordModel().unSeenCounter().getCount(itw));
    }
    for (int i = 0; i < smooth.length; i++) {
      out.println("smooth[" + i + "] = " + smooth[i]);
    }
    out.flush();
  }
 @Override
 protected void tallyRoot(Tree lt, double weight) {
   // this list is in full (not reduced) tag space
   List<IntDependency> deps = MLEDependencyGrammar.treeToDependencyList(lt, wordIndex, tagIndex);
   for (IntDependency dependency : deps) {
     dependencyCounter.incrementCount(dependency, weight);
   }
 }
 private Distribution<Integer> getSegmentedWordLengthDistribution(Treebank tb) {
   // CharacterLevelTagExtender ext = new CharacterLevelTagExtender();
   ClassicCounter<Integer> c = new ClassicCounter<Integer>();
   for (Tree gold : tb) {
     StringBuilder goldChars = new StringBuilder();
     for (Object word : gold.yield()) {
       goldChars.append((Word) word);
     }
     List<HasWord> ourWords = segment(goldChars.toString());
     for (HasWord ourWord : ourWords) {
       c.incrementCount(ourWord.word().length());
     }
   }
   return Distribution.getDistribution(c);
 }
 @Override
 public DependencyGrammar formResult() {
   wordIndex.indexOf(Lexicon.UNKNOWN_WORD, true);
   MLEDependencyGrammar dg =
       new MLEDependencyGrammar(
           tlpParams,
           directional,
           useDistance,
           useCoarseDistance,
           basicCategoryTagsInDependencyGrammar,
           op,
           wordIndex,
           tagIndex);
   for (IntDependency dependency : dependencyCounter.keySet()) {
     dg.addRule(dependency, dependencyCounter.getCount(dependency));
   }
   return dg;
 }
Example #17
 @Override
 public void evaluate(Tree t1, Tree t2, PrintWriter pw) {
   List<String> s1 = myMakeObjects(t1);
   List<String> s2 = myMakeObjects(t2);
   List<String> del2 = new LinkedList<>(s2);
    // we delete items as we find them so that we correctly score a category
    // that occurs with a certain cardinality in a tree.
   for (String o1 : s1) {
     if (!del2.remove(o1)) {
       over.incrementCount(o1);
     }
   }
   for (String o2 : s2) {
     if (!s1.remove(o2)) {
       under.incrementCount(o2);
     }
   }
 }
Example #18
  private void expandStop(
      IntDependency dependency, short distBinDist, double count, boolean wildForStop) {
    IntTaggedWord headT = getCachedITW(dependency.head.tag);
    IntTaggedWord head =
        new IntTaggedWord(dependency.head.word, tagBin(dependency.head.tag)); // dependency.head;
    IntTaggedWord arg =
        new IntTaggedWord(dependency.arg.word, tagBin(dependency.arg.tag)); // dependency.arg;

    boolean leftHeaded = dependency.leftHeaded;

    if (arg.word == STOP_WORD_INT) {
      stopCounter.incrementCount(intern(head, arg, leftHeaded, distBinDist), count);
      stopCounter.incrementCount(intern(headT, arg, leftHeaded, distBinDist), count);
    }
    if (wildForStop || arg.word != STOP_WORD_INT) {
      stopCounter.incrementCount(intern(head, wildTW, leftHeaded, distBinDist), count);
      stopCounter.incrementCount(intern(headT, wildTW, leftHeaded, distBinDist), count);
    }
  }
Example #19
 public ClassicCounter<L> scoresOf(RVFDatum<L, F> example) {
   ClassicCounter<L> scores = new ClassicCounter<>();
   Counters.addInPlace(scores, priors);
   if (addZeroValued) {
     Counters.addInPlace(scores, priorZero);
   }
   for (L l : labels) {
     double score = 0.0;
     Counter<F> features = example.asFeaturesCounter();
     for (F f : features.keySet()) {
       int value = (int) features.getCount(f);
       score += weight(l, f, Integer.valueOf(value));
       if (addZeroValued) {
         score -= weight(l, f, zero);
       }
     }
     scores.incrementCount(l, score);
   }
   return scores;
 }
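
The returned Counter holds one unnormalized score per label (priors plus summed feature
weights), so classification is just an argmax over it. A sketch, assuming the Counters.argmax
utility from the same stats package:

  L predict(RVFDatum<L, F> example) {
    ClassicCounter<L> scores = scoresOf(example);
    return Counters.argmax(scores); // label with the highest score
  }
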
Example #20
  public double countHistory(IntDependency dependency) {
    IntDependency temp =
        new IntDependency(
            dependency.head.word,
            tagBin(dependency.head.tag),
            wildTW.word,
            wildTW.tag,
            dependency.leftHeaded,
            valenceBin(dependency.distance));

    return argCounter.getCount(temp);
  }
Example #21
  private void readObject(ObjectInputStream stream) throws IOException, ClassNotFoundException {
    stream.defaultReadObject();
    //    System.err.println("Before decompression:");
    //    System.err.println("arg size: " + argCounter.size() + "  total: " +
    // argCounter.totalCount());
    //    System.err.println("stop size: " + stopCounter.size() + "  total: " +
    // stopCounter.totalCount());

    ClassicCounter<IntDependency> compressedArgC = argCounter;
    argCounter = new ClassicCounter<IntDependency>();
    ClassicCounter<IntDependency> compressedStopC = stopCounter;
    stopCounter = new ClassicCounter<IntDependency>();
    for (IntDependency d : compressedArgC.keySet()) {
      double count = compressedArgC.getCount(d);
      expandArg(d, d.distance, count);
    }

    for (IntDependency d : compressedStopC.keySet()) {
      double count = compressedStopC.getCount(d);
      expandStop(d, d.distance, count, false);
    }

    //    System.err.println("After decompression:");
    //    System.err.println("arg size: " + argCounter.size() + "  total: " +
    // argCounter.totalCount());
    //    System.err.println("stop size: " + stopCounter.size() + "  total: " +
    // stopCounter.totalCount());

    expandDependencyMap = null;
  }
  public UnknownWordModel finishTraining() {
    // make sure the unseen counter isn't empty!  If it is, put in
    // a uniform unseen over tags
    if (unSeenCounter.isEmpty()) {
      int numTags = tagIndex.size();
      for (int tt = 0; tt < numTags; tt++) {
        if (!Lexicon.BOUNDARY_TAG.equals(tagIndex.get(tt))) {
          IntTaggedWord iT = new IntTaggedWord(nullWord, tt);
          IntTaggedWord i = NULL_ITW;
          unSeenCounter.incrementCount(iT);
          unSeenCounter.incrementCount(i);
        }
      }
    }

    // index the possible tags for each word
    // numWords = wordIndex.size();
    // unknownWordIndex = wordIndex.indexOf(Lexicon.UNKNOWN_WORD, true);
    // initRulesWithWord();

    return model;
  }
Example #23
 /**
  * Generate the possible taggings for a word at a sentence position. This may be based either on
  * a strict lexicon or on an expanded, generous set of possible taggings.
  *
  * <p><i>Implementation note:</i> Expanded sets of possible taggings are calculated dynamically at
  * runtime, so as to reduce the memory used by the lexicon (a space/time tradeoff).
  *
  * @param word The word (as an int)
  * @param loc Its index in the sentence (usually only relevant for unknown words)
  * @return A list of possible taggings
  */
 public Iterator<IntTaggedWord> ruleIteratorByWord(int word, int loc, String featureSpec) {
   // if (rulesWithWord == null) { // tested in isKnown already
   // initRulesWithWord();
   // }
   List<IntTaggedWord> wordTaggings;
   if (isKnown(word)) {
     if (!flexiTag) {
       // Strict lexical tagging for seen items
       wordTaggings = rulesWithWord[word];
     } else {
       /* Allow all tags with same basicCategory */
       /* Allow all scored taggings, unless very common */
       IntTaggedWord iW = new IntTaggedWord(word, nullTag);
       if (seenCounter.getCount(iW) > smoothInUnknownsThreshold) {
         return rulesWithWord[word].iterator();
       } else {
         // give it flexible tagging not just lexicon
         wordTaggings = new ArrayList<IntTaggedWord>(40);
         for (IntTaggedWord iTW2 : tags) {
           IntTaggedWord iTW = new IntTaggedWord(word, iTW2.tag);
           if (score(iTW, loc, wordIndex.get(word)) > Float.NEGATIVE_INFINITY) {
             wordTaggings.add(iTW);
           }
         }
       }
     }
   } else {
      // copy the list so we can insert the correct word in each item
     wordTaggings = new ArrayList<IntTaggedWord>(40);
     for (IntTaggedWord iTW : rulesWithWord[wordIndex.indexOf(UNKNOWN_WORD)]) {
       wordTaggings.add(new IntTaggedWord(word, iTW.tag));
     }
   }
   if (DEBUG_LEXICON) {
     EncodingPrintWriter.err.println(
         "Lexicon: "
             + wordIndex.get(word)
             + " ("
             + (isKnown(word) ? "known" : "unknown")
             + ", loc="
             + loc
             + ", n="
             + (isKnown(word) ? word : wordIndex.indexOf(UNKNOWN_WORD))
             + ") "
             + (flexiTag ? "flexi" : "lexicon")
             + " taggings: "
             + wordTaggings,
         "UTF-8");
   }
   return wordTaggings.iterator();
 }
  /** Trains this UWM on a single tagged word at sentence position loc, with the given weight. */
  public void train(TaggedWord tw, int loc, double weight) {
    IntTaggedWord iTW = new IntTaggedWord(tw.word(), tw.tag(), wordIndex, tagIndex);
    IntTaggedWord iT = new IntTaggedWord(nullWord, iTW.tag);
    IntTaggedWord iW = new IntTaggedWord(iTW.word, nullTag);
    seenCounter.incrementCount(iW, weight);
    IntTaggedWord i = NULL_ITW;

    if (treesRead > indexToStartUnkCounting) {
      // start doing this once we are some way through the trees;
      // treesRead uses 1-based counting
      if (seenCounter.getCount(iW) < 1.5) {
        // it's an entirely unknown word
        int s = model.getSignatureIndex(iTW.word, loc, wordIndex.get(iTW.word));
        if (DOCUMENT_UNKNOWNS) {
          String wStr = wordIndex.get(iTW.word);
          String tStr = tagIndex.get(iTW.tag);
          String sStr = wordIndex.get(s);
          EncodingPrintWriter.err.println(
              "Unknown word/tag/sig:\t" + wStr + '\t' + tStr + '\t' + sStr, "UTF-8");
        }
        IntTaggedWord iTS = new IntTaggedWord(s, iTW.tag);
        IntTaggedWord iS = new IntTaggedWord(s, nullTag);
        unSeenCounter.incrementCount(iTS, weight);
        unSeenCounter.incrementCount(iT, weight);
        unSeenCounter.incrementCount(iS, weight);
        unSeenCounter.incrementCount(i, weight);
        // rules.add(iTS);
        // sigs.add(iS);
      } // else {
      // if (seenCounter.getCount(iTW) < 2) {
      // it's a new tag for a known word
      // do nothing for now
      // }
      // }
    }
  }
  @Override
  public void train(List<TaggedWord> sentence) {
    lex.train(sentence, 1.0);

    String last = null;
    for (TaggedWord tagLabel : sentence) {
      String tag = tagLabel.tag();
      tagIndex.add(tag);
      if (last == null) {
        initial.incrementCount(tag);
      } else {
        ruleCounter.incrementCount2D(last, tag);
      }
      last = tag;
    }
  }
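
These counts are exactly what an MLE tag-bigram model needs: P(tag | previous tag) is the
bigram count divided by the row total. A self-contained sketch using plain maps, since the
concrete type of ruleCounter is not shown in this snippet:

  static double transitionProb(
      Map<String, Map<String, Double>> bigram, String prev, String tag) {
    Map<String, Double> row = bigram.get(prev);
    if (row == null) return 0.0;
    double rowTotal = 0.0;
    for (double v : row.values()) rowTotal += v; // normalizing constant for this row
    Double joint = row.get(tag);
    return (joint == null || rowTotal == 0.0) ? 0.0 : joint / rowTotal;
  }
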
Example #26
 /** Adds the tagging with count to the data structures in this Lexicon. */
 protected void addTagging(boolean seen, IntTaggedWord itw, double count) {
   if (seen) {
     seenCounter.incrementCount(itw, count);
     if (itw.tag() == nullTag) {
       words.add(itw);
     } else if (itw.word() == nullWord) {
       tags.add(itw);
     } else {
       // rules.add(itw);
     }
   } else {
     uwModel.addTagging(seen, itw, count);
     // if (itw.tag() == nullTag) {
     // sigs.add(itw);
     // }
   }
 }
Example #27
  private void writeObject(ObjectOutputStream stream) throws IOException {
    //    System.err.println("\nBefore compression:");
    //    System.err.println("arg size: " + argCounter.size() + "  total: " +
    // argCounter.totalCount());
    //    System.err.println("stop size: " + stopCounter.size() + "  total: " +
    // stopCounter.totalCount());

    ClassicCounter<IntDependency> fullArgCounter = argCounter;
    argCounter = new ClassicCounter<IntDependency>();
    for (IntDependency dependency : fullArgCounter.keySet()) {
      if (dependency.head != wildTW
          && dependency.arg != wildTW
          && dependency.head.word != -1
          && dependency.arg.word != -1) {
        argCounter.incrementCount(dependency, fullArgCounter.getCount(dependency));
      }
    }

    ClassicCounter<IntDependency> fullStopCounter = stopCounter;
    stopCounter = new ClassicCounter<IntDependency>();
    for (IntDependency dependency : fullStopCounter.keySet()) {
      if (dependency.head.word != -1) {
        stopCounter.incrementCount(dependency, fullStopCounter.getCount(dependency));
      }
    }

    //    System.err.println("After compression:");
    //    System.err.println("arg size: " + argCounter.size() + "  total: " +
    // argCounter.totalCount());
    //    System.err.println("stop size: " + stopCounter.size() + "  total: " +
    // stopCounter.totalCount());

    stream.defaultWriteObject();

    argCounter = fullArgCounter;
    stopCounter = fullStopCounter;
  }
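
The pattern here is swap, filter, serialize, restore: the wildcarded entries are dropped
before defaultWriteObject because readObject (Example #21) can regenerate them via expandArg
and expandStop. Reduced to its skeleton, with a hypothetical filterCompressible method
standing in for the loops above, the idiom looks like this; a try/finally guarantees the full
counter is restored even if serialization throws:

  private void writeObjectSketch(ObjectOutputStream stream) throws IOException {
    ClassicCounter<IntDependency> full = argCounter;
    argCounter = filterCompressible(full); // hypothetical filter, as in the loops above
    try {
      stream.defaultWriteObject();
    } finally {
      argCounter = full; // always restore the in-memory state
    }
  }
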
Example #28
  /**
   * Collect counts for a non-STOP dependent. The dependency arg is still in the full tag space.
   *
   * @param dependency A non-stop dependency
   * @param valBinDist A binned distance
   * @param count The weight with which to add this dependency
   */
  private void expandArg(IntDependency dependency, short valBinDist, double count) {
    IntTaggedWord headT = getCachedITW(dependency.head.tag);
    IntTaggedWord argT = getCachedITW(dependency.arg.tag);
    IntTaggedWord head =
        new IntTaggedWord(dependency.head.word, tagBin(dependency.head.tag)); // dependency.head;
    IntTaggedWord arg =
        new IntTaggedWord(dependency.arg.word, tagBin(dependency.arg.tag)); // dependency.arg;
    boolean leftHeaded = dependency.leftHeaded;

    // argCounter stores counts in both the original and the reduced tag space
    argCounter.incrementCount(intern(head, arg, leftHeaded, valBinDist), count);
    argCounter.incrementCount(intern(headT, arg, leftHeaded, valBinDist), count);
    argCounter.incrementCount(intern(head, argT, leftHeaded, valBinDist), count);
    argCounter.incrementCount(intern(headT, argT, leftHeaded, valBinDist), count);

    argCounter.incrementCount(intern(head, wildTW, leftHeaded, valBinDist), count);
    argCounter.incrementCount(intern(headT, wildTW, leftHeaded, valBinDist), count);

    // the WILD head stats are always directionless and not useDistance!
    argCounter.incrementCount(intern(wildTW, arg, false, (short) -1), count);
    argCounter.incrementCount(intern(wildTW, argT, false, (short) -1), count);

    if (useSmoothTagProjection) {
      // added stuff to do more smoothing.  CDM Jan 2007
      IntTaggedWord headP =
          new IntTaggedWord(dependency.head.word, tagProject(dependency.head.tag));
      IntTaggedWord headTP = new IntTaggedWord(ANY_WORD_INT, tagProject(dependency.head.tag));
      IntTaggedWord argP = new IntTaggedWord(dependency.arg.word, tagProject(dependency.arg.tag));
      IntTaggedWord argTP = new IntTaggedWord(ANY_WORD_INT, tagProject(dependency.arg.tag));

      argCounter.incrementCount(intern(headP, argP, leftHeaded, valBinDist), count);
      argCounter.incrementCount(intern(headTP, argP, leftHeaded, valBinDist), count);
      argCounter.incrementCount(intern(headP, argTP, leftHeaded, valBinDist), count);
      argCounter.incrementCount(intern(headTP, argTP, leftHeaded, valBinDist), count);

      argCounter.incrementCount(intern(headP, wildTW, leftHeaded, valBinDist), count);
      argCounter.incrementCount(intern(headTP, wildTW, leftHeaded, valBinDist), count);

      // the WILD head stats are always directionless and not useDistance!
      argCounter.incrementCount(intern(wildTW, argP, false, (short) -1), count);
      argCounter.incrementCount(intern(wildTW, argTP, false, (short) -1), count);
      argCounter.incrementCount(
          intern(wildTW, new IntTaggedWord(dependency.head.word, ANY_TAG_INT), false, (short) -1),
          count);
    }
    numWordTokens++;
  }
Example #29
  /**
   * Calculate the probability of a dependency as a real probability between 0 and 1 inclusive.
   *
   * @param dependency The dependency for which the probability is to be calculated. The tags in
   *     this dependency are in the reduced TagProjection space.
   * @return The probability of the dependency
   */
  protected double probTB(IntDependency dependency) {
    if (verbose) {
      // System.out.println("tagIndex: " + tagIndex);
      System.err.println("Generating " + dependency);
    }

    boolean leftHeaded = dependency.leftHeaded && directional;

    int hW = dependency.head.word;
    int aW = dependency.arg.word;
    short hT = dependency.head.tag;
    short aT = dependency.arg.tag;

    IntTaggedWord aTW = dependency.arg;
    IntTaggedWord hTW = dependency.head;

    boolean isRoot = rootTW(dependency.head);
    double pb_stop_hTWds;
    if (isRoot) {
      pb_stop_hTWds = 0.0;
    } else {
      pb_stop_hTWds = getStopProb(dependency);
    }

    if (dependency.arg.word == STOP_WORD_INT) {
      // did we generate stop?
      return pb_stop_hTWds;
    }

    double pb_go_hTWds = 1.0 - pb_stop_hTWds;

    // generate the argument

    short binDistance = valenceBin(dependency.distance);

    // KEY:
    // c_     count of (read as joint count of first and second)
    // p_     MLE prob of (or MAP if useSmoothTagProjection)
    // pb_    MAP prob of (read as prob of first given second thing)
    // a      arg
    // h      head
    // T      tag
    // PT     projected tag
    // W      word
    // d      direction
    // ds     distance (implicit: there when direction is mentioned!)

    IntTaggedWord anyHead = new IntTaggedWord(ANY_WORD_INT, dependency.head.tag);
    IntTaggedWord anyArg = new IntTaggedWord(ANY_WORD_INT, dependency.arg.tag);
    IntTaggedWord anyTagArg = new IntTaggedWord(dependency.arg.word, ANY_TAG_INT);

    IntDependency temp =
        new IntDependency(dependency.head, dependency.arg, leftHeaded, binDistance);
    double c_aTW_hTWd = argCounter.getCount(temp);
    temp = new IntDependency(dependency.head, anyArg, leftHeaded, binDistance);
    double c_aT_hTWd = argCounter.getCount(temp);
    temp = new IntDependency(dependency.head, wildTW, leftHeaded, binDistance);
    double c_hTWd = argCounter.getCount(temp);

    temp = new IntDependency(anyHead, dependency.arg, leftHeaded, binDistance);
    double c_aTW_hTd = argCounter.getCount(temp);
    temp = new IntDependency(anyHead, anyArg, leftHeaded, binDistance);
    double c_aT_hTd = argCounter.getCount(temp);
    temp = new IntDependency(anyHead, wildTW, leftHeaded, binDistance);
    double c_hTd = argCounter.getCount(temp);

    // for smooth tag projection
    short aPT = Short.MIN_VALUE;
    double c_aPTW_hPTd = Double.NaN;
    double c_aPT_hPTd = Double.NaN;
    double c_hPTd = Double.NaN;
    double c_aPTW_aPT = Double.NaN;
    double c_aPT = Double.NaN;

    if (useSmoothTagProjection) {
      aPT = tagProject(dependency.arg.tag);
      short hPT = tagProject(dependency.head.tag);

      IntTaggedWord projectedArg = new IntTaggedWord(dependency.arg.word, aPT);
      IntTaggedWord projectedAnyHead = new IntTaggedWord(ANY_WORD_INT, hPT);
      IntTaggedWord projectedAnyArg = new IntTaggedWord(ANY_WORD_INT, aPT);

      temp = new IntDependency(projectedAnyHead, projectedArg, leftHeaded, binDistance);
      c_aPTW_hPTd = argCounter.getCount(temp);
      temp = new IntDependency(projectedAnyHead, projectedAnyArg, leftHeaded, binDistance);
      c_aPT_hPTd = argCounter.getCount(temp);
      temp = new IntDependency(projectedAnyHead, wildTW, leftHeaded, binDistance);
      c_hPTd = argCounter.getCount(temp);

      temp = new IntDependency(wildTW, projectedArg, false, ANY_DISTANCE_INT);
      c_aPTW_aPT = argCounter.getCount(temp);
      temp = new IntDependency(wildTW, projectedAnyArg, false, ANY_DISTANCE_INT);
      c_aPT = argCounter.getCount(temp);
    }

    // wild head is always directionless and does not use distance
    temp = new IntDependency(wildTW, dependency.arg, false, ANY_DISTANCE_INT);
    double c_aTW = argCounter.getCount(temp);
    temp = new IntDependency(wildTW, anyArg, false, ANY_DISTANCE_INT);
    double c_aT = argCounter.getCount(temp);
    temp = new IntDependency(wildTW, anyTagArg, false, ANY_DISTANCE_INT);
    double c_aW = argCounter.getCount(temp);

    // do the Bayesian magic
    // MLE probs
    double p_aTW_hTd;
    double p_aT_hTd;
    double p_aTW_aT;
    double p_aW;
    double p_aPTW_aPT;
    double p_aPTW_hPTd;
    double p_aPT_hPTd;

    // backoffs are either MLE or themselves Bayesian-smoothed, depending on useSmoothTagProjection
    if (useSmoothTagProjection) {
      if (useUnigramWordSmoothing) {
        p_aW = c_aW > 0.0 ? (c_aW / numWordTokens) : 1.0; // NEED this 1.0 for unknown words!!!
        p_aPTW_aPT = (c_aPTW_aPT + smooth_aPTW_aPT * p_aW) / (c_aPT + smooth_aPTW_aPT);
      } else {
        p_aPTW_aPT =
            c_aPTW_aPT > 0.0 ? (c_aPTW_aPT / c_aPT) : 1.0; // NEED this 1.0 for unknown words!!!
      }
      p_aTW_aT = (c_aTW + smooth_aTW_aT * p_aPTW_aPT) / (c_aT + smooth_aTW_aT);

      p_aPTW_hPTd = c_hPTd > 0.0 ? (c_aPTW_hPTd / c_hPTd) : 0.0;
      p_aTW_hTd = (c_aTW_hTd + smooth_aTW_hTd * p_aPTW_hPTd) / (c_hTd + smooth_aTW_hTd);

      p_aPT_hPTd = c_hPTd > 0.0 ? (c_aPT_hPTd / c_hPTd) : 0.0;
      p_aT_hTd = (c_aT_hTd + smooth_aT_hTd * p_aPT_hPTd) / (c_hTd + smooth_aT_hTd);
    } else {
      // here word generation isn't smoothed, so we can't get a previously unseen word with a tag.  Ugh.
      if (op.testOptions.useLexiconToScoreDependencyPwGt) {
        // We don't know the position.  Now -1 means average over 0 and 1.
        p_aTW_aT =
            dependency.leftHeaded
                ? Math.exp(lex.score(dependency.arg, 1, wordIndex.get(dependency.arg.word)))
                : Math.exp(lex.score(dependency.arg, -1, wordIndex.get(dependency.arg.word)));
        // double oldScore = c_aTW > 0.0 ? (c_aTW / c_aT) : 1.0;
        // if (oldScore == 1.0) {
        //  System.err.println("#### arg=" + dependency.arg + " score=" + p_aTW_aT +
        //                      " oldScore=" + oldScore + " c_aTW=" + c_aTW + " c_aW=" + c_aW);
        // }
      } else {
        p_aTW_aT = c_aTW > 0.0 ? (c_aTW / c_aT) : 1.0;
      }
      p_aTW_hTd = c_hTd > 0.0 ? (c_aTW_hTd / c_hTd) : 0.0;
      p_aT_hTd = c_hTd > 0.0 ? (c_aT_hTd / c_hTd) : 0.0;
    }

    double pb_aTW_hTWd = (c_aTW_hTWd + smooth_aTW_hTWd * p_aTW_hTd) / (c_hTWd + smooth_aTW_hTWd);
    double pb_aT_hTWd = (c_aT_hTWd + smooth_aT_hTWd * p_aT_hTd) / (c_hTWd + smooth_aT_hTWd);

    double score = (interp * pb_aTW_hTWd + (1.0 - interp) * p_aTW_aT * pb_aT_hTWd) * pb_go_hTWds;

    if (verbose) {
      NumberFormat nf = NumberFormat.getNumberInstance();
      nf.setMaximumFractionDigits(2);
      if (useSmoothTagProjection) {
        if (useUnigramWordSmoothing) {
          System.err.println(
              "  c_aW=" + c_aW + ", numWordTokens=" + numWordTokens + ", p(aW)=" + nf.format(p_aW));
        }
        System.err.println(
            "  c_aPTW_aPT="
                + c_aPTW_aPT
                + ", c_aPT="
                + c_aPT
                + ", smooth_aPTW_aPT="
                + smooth_aPTW_aPT
                + ", p(aPTW|aPT)="
                + nf.format(p_aPTW_aPT));
      }
      System.err.println(
          "  c_aTW="
              + c_aTW
              + ", c_aT="
              + c_aT
              + ", smooth_aTW_aT="
              + smooth_aTW_aT
              + ", ## p(aTW|aT)="
              + nf.format(p_aTW_aT));

      if (useSmoothTagProjection) {
        System.err.println(
            "  c_aPTW_hPTd="
                + c_aPTW_hPTd
                + ", c_hPTd="
                + c_hPTd
                + ", p(aPTW|hPTd)="
                + nf.format(p_aPTW_hPTd));
      }
      System.err.println(
          "  c_aTW_hTd="
              + c_aTW_hTd
              + ", c_hTd="
              + c_hTd
              + ", smooth_aTW_hTd="
              + smooth_aTW_hTd
              + ", p(aTW|hTd)="
              + nf.format(p_aTW_hTd));

      if (useSmoothTagProjection) {
        System.err.println(
            "  c_aPT_hPTd="
                + c_aPT_hPTd
                + ", c_hPTd="
                + c_hPTd
                + ", p(aPT|hPTd)="
                + nf.format(p_aPT_hPTd));
      }
      System.err.println(
          "  c_aT_hTd="
              + c_aT_hTd
              + ", c_hTd="
              + c_hTd
              + ", smooth_aT_hTd="
              + smooth_aT_hTd
              + ", p(aT|hTd)="
              + nf.format(p_aT_hTd));

      System.err.println(
          "  c_aTW_hTWd="
              + c_aTW_hTWd
              + ", c_hTWd="
              + c_hTWd
              + ", smooth_aTW_hTWd="
              + smooth_aTW_hTWd
              + ", ## p(aTW|hTWd)="
              + nf.format(pb_aTW_hTWd));
      System.err.println(
          "  c_aT_hTWd="
              + c_aT_hTWd
              + ", c_hTWd="
              + c_hTWd
              + ", smooth_aT_hTWd="
              + smooth_aT_hTWd
              + ", ## p(aT|hTWd)="
              + nf.format(pb_aT_hTWd));

      System.err.println(
          "  interp="
              + interp
              + ", prescore="
              + nf.format(interp * pb_aTW_hTWd + (1.0 - interp) * p_aTW_aT * pb_aT_hTWd)
              + ", P(go|hTWds)="
              + nf.format(pb_go_hTWds)
              + ", score="
              + nf.format(score));
    }

    if (op.testOptions.prunePunc && pruneTW(aTW)) {
      return 1.0;
    }

    if (Double.isNaN(score)) {
      score = 0.0;
    }

    // if (op.testOptions.rightBonus && ! dependency.leftHeaded)
    //  score -= 0.2;

    if (score < MIN_PROBABILITY) {
      score = 0.0;
    }

    return score;
  }
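
The final score interpolates a fully lexicalized estimate with a factored backoff and then
multiplies in the probability of generating anything at all (1 - stop). A standalone replay
of that combination step with invented numbers:

  // score = (interp * p(aTW|hTWd) + (1 - interp) * p(aTW|aT) * p(aT|hTWd)) * p(go|hTWds)
  public class ProbTBInterpSketch {
    public static void main(String[] args) {
      double interp = 0.6;      // illustrative interpolation weight
      double pbATWhTWd = 0.02;  // lexicalized: arg word+tag given head word+tag+direction
      double pATWaT = 0.1;      // arg word given arg tag
      double pbAThTWd = 0.15;   // arg tag given head word+tag+direction
      double pbGo = 0.8;        // 1 - stop probability
      double score = (interp * pbATWhTWd + (1.0 - interp) * pATWaT * pbAThTWd) * pbGo;
      System.out.println(score); // (0.012 + 0.006) * 0.8 = 0.0144
    }
  }
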
Example #30
 public void dumpSizes() {
   //    System.out.println("core dep " + coreDependencies.size());
   System.out.println("arg counter " + argCounter.size());
   System.out.println("stop counter " + stopCounter.size());
 }