/**
   * Returns a list of features thresholded by minPrecision and sorted by their frequency of
   * occurrence. Precision, in this case, is defined as the frequency of the majority label over
   * the total frequency for that feature.
   *
   * @return list of high precision features.
   */
  private List<F> getHighPrecisionFeatures(
      GeneralDataset<L, F> dataset, double minPrecision, int maxNumFeatures) {
    int[][] feature2label = new int[dataset.numFeatures()][dataset.numClasses()]; // zero-initialized by Java

    int[][] data = dataset.data;
    int[] labels = dataset.labels;
    for (int d = 0; d < data.length; d++) {
      int label = labels[d];
      // System.out.println("datum id:"+d+" label id: "+label);
      if (data[d] != null) {
        // System.out.println(" number of features:"+data[d].length);
        for (int n = 0; n < data[d].length; n++) {
          feature2label[data[d][n]][label]++;
        }
      }
    }
    Counter<F> feature2freq = new ClassicCounter<F>();
    for (int f = 0; f < dataset.numFeatures(); f++) {
      int maxF = ArrayMath.max(feature2label[f]);
      int total = ArrayMath.sum(feature2label[f]);
      double precision = ((double) maxF) / total;
      F feature = dataset.featureIndex.get(f);
      if (precision >= minPrecision) {
        feature2freq.incrementCount(feature, total);
      }
    }
    if (feature2freq.size() > maxNumFeatures) {
      Counters.retainTop(feature2freq, maxNumFeatures);
    }
    // for(F feature : feature2freq.keySet())
    // System.out.println(feature+" "+feature2freq.getCount(feature));
    // System.exit(0);
    return Counters.toSortedList(feature2freq);
  }
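For intuition, a feature's precision here is just its majority-label count over its total count; a toy illustration with invented counts, reusing the same ArrayMath helpers the method calls:

    // Hypothetical counts: a feature seen 8 times with label A and 2 times with label B.
    int[] labelCounts = {8, 2};
    double precision = (double) ArrayMath.max(labelCounts) / ArrayMath.sum(labelCounts); // 0.8
    // With minPrecision = 0.75 this feature is retained, weighted by its total frequency of 10.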
Example #2
 /**
  * Guesses the NER tag of a token span by majority vote over the tokens' NER labels, ignoring "O"
  * and missing tags.
  *
  * @param tokens The tokens of the sentence.
  * @param span The span of token indices to guess a tag for.
  * @return The majority NER tag if it covers at least half the span; otherwise "O".
  */
 public static String guessNER(List<CoreLabel> tokens, Span span) {
   Counter<String> nerGuesses = new ClassicCounter<>();
   for (int i : span) {
     nerGuesses.incrementCount(tokens.get(i).ner());
   }
   nerGuesses.remove("O");
   nerGuesses.remove(null);
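   // Majority vote: the top tag must cover at least half the span (note the integer division).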
   if (nerGuesses.size() > 0 && Counters.max(nerGuesses) >= span.size() / 2) {
     return Counters.argmax(nerGuesses);
   } else {
     return "O";
   }
 }
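A minimal usage sketch with hand-built tokens; this assumes the Span is the machinereading-package Span, which iterates over token indices:

   CoreLabel barack = new CoreLabel();
   barack.setNER("PERSON");
   CoreLabel obama = new CoreLabel();
   obama.setNER("PERSON");
   String tag = guessNER(Arrays.asList(barack, obama), new Span(0, 2)); // "PERSON"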
 /** Run some sanity checks on the training statistics, to make sure they look valid. */
 public void validate() {
   for (Map<SentenceKey, EnsembleStatistics> map : impl) {
     for (EnsembleStatistics stats : map.values()) {
       for (SentenceStatistics component : stats.statisticsForClassifiers) {
         assert !Counters.isUniformDistribution(component.relationDistribution, 1e-5);
         Counters.normalize(
             component.relationDistribution); // TODO(gabor) this shouldn't be necessary
         assert (Math.abs(component.relationDistribution.totalCount() - 1.0)) < 1e-5;
       }
       assert (Math.abs(stats.mean().relationDistribution.totalCount() - 1.0)) < 1e-5;
       assert !Counters.isUniformDistribution(stats.mean().relationDistribution, 1e-5);
     }
   }
 }
Example #4
  public static void main(String[] args) throws Exception {
    CompareType compareType = CompareType.MAX_COOR_ABS;

    if (args.length == 0) {
      usage();
      System.exit(-1);
    }

    if ("-cosine".equals(args[0])) {
      compareType = CompareType.COSINE;
      args = Arrays.copyOfRange(args, 1, args.length);
    } else if ("-sse".equals(args[0])) {
      compareType = CompareType.SUM_SQUARE_ERROR;
      args = Arrays.copyOfRange(args, 1, args.length);
    }

    if (args.length != 2) {
      usage();
      System.exit(-1);
    }

    Set<String> allWeights = new HashSet<String>();
    Counter<String> wts1 = IOTools.readWeights(args[0]);
    Counter<String> wts2 = IOTools.readWeights(args[1]);
    allWeights.addAll(wts1.keySet());
    allWeights.addAll(wts2.keySet());

    if (compareType == CompareType.MAX_COOR_ABS) {
      double maxDiff = Double.NEGATIVE_INFINITY;
      Counters.multiplyInPlace(wts1, 1.0 / Counters.L1Norm(wts1));
      Counters.multiplyInPlace(wts2, 1.0 / Counters.L1Norm(wts2));
      for (String wt : allWeights) {
        double absDiff = Math.abs(wts1.getCount(wt) - wts2.getCount(wt));
        if (absDiff > maxDiff) maxDiff = absDiff;
      }
      System.out.println(maxDiff);
    } else if (compareType == CompareType.COSINE) {
      double cosineSim = Counters.cosine(wts1, wts2);
      System.out.println(cosineSim);
    } else if (compareType == CompareType.SUM_SQUARE_ERROR) {
      double sse = 0;
      for (String wt : allWeights) {
        double diff = wts1.getCount(wt) - wts2.getCount(wt);
        sse += diff * diff;
      }
      System.out.println(sse);
    }
  }
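The MAX_COOR_ABS branch first L1-normalizes both weight vectors, so the printed number is the largest per-coordinate gap between two distributions; a hedged toy check of the normalization step:

    Counter<String> wts = new ClassicCounter<String>();
    wts.incrementCount("f1", 3.0);
    wts.incrementCount("f2", 1.0);
    Counters.multiplyInPlace(wts, 1.0 / Counters.L1Norm(wts)); // f1 -> 0.75, f2 -> 0.25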
Example #5
 private static <T> void display(ClassicCounter<T> c, PrintWriter pw) {
   List<T> cats = new ArrayList<>(c.keySet());
   Collections.sort(cats, Counters.toComparatorDescending(c));
   for (T ob : cats) {
     pw.println(ob + " " + c.getCount(ob));
   }
 }
Example #6
 /**
  * Returns a list of all modes in the Collection. (If the Collection has multiple items with the
  * highest frequency, all of them will be returned.)
  */
 public static <T> Set<T> modes(Collection<T> values) {
   Counter<T> counter = new ClassicCounter<T>(values);
   List<Double> sortedCounts = CollectionUtils.sorted(counter.values());
   Double highestCount = sortedCounts.get(sortedCounts.size() - 1);
   Counters.retainAbove(counter, highestCount);
   return counter.keySet();
 }
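Because retainAbove keeps every key whose count is at least the threshold, ties for the highest count all survive; a quick usage sketch:

   Set<String> m = modes(Arrays.asList("b", "b", "c", "c", "a")); // {"b", "c"}, each with count 2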
 @Deprecated
 protected void semanticSimilarity(
     Counter<String> features, String prefix, Sentence str1, Sentence str2) {
   Counter<String> v1 =
       new ClassicCounter<>(
           str1.lemmas().stream().map(String::toLowerCase).collect(Collectors.toList()));
   Counter<String> v2 = new ClassicCounter<>(str2.lemmas());
   // Remove any stopwords.
   for (String word : stopwords) {
     v1.remove(word);
     v2.remove(word);
   }
    // Cosine similarity of the two lemma vectors.
    double sim =
        Counters.dotProduct(v1, v2) / (Counters.saferL2Norm(v1) * Counters.saferL2Norm(v2));
    features.incrementCount(
        prefix + "semantic-similarity", 2 * sim - 1); // map cosine from [0, 1] to [-1, 1].
 }
Example #8
 @Override
 public L classOf(Datum<L, F> example) {
   Counter<L> scores = scoresOf(example);
   if (scores != null) {
     return Counters.argmax(scores);
   } else {
     return defaultLabel;
   }
 }
Example #9
  public static List<String> generateDict(List<String> str, int cutOff) {
    Counter<String> freq = new IntCounter<>();
    for (String aStr : str) freq.incrementCount(aStr);

    List<String> keys = Counters.toSortedList(freq, false);
    List<String> dict = new ArrayList<>();
    for (String word : keys) {
      if (freq.getCount(word) >= cutOff) dict.add(word);
    }
    return dict;
  }
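A usage sketch with toy data; the boolean flag to Counters.toSortedList is taken here to request descending order, so the dictionary comes back most-frequent-first:

    List<String> dict = generateDict(Arrays.asList("the", "cat", "the", "sat", "the", "cat"), 2);
    // "the" (3) and "cat" (2) survive the cutoff of 2; "sat" (1) is dropped.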
Example #10
 private static <T> void display(ClassicCounter<T> c, int num, PrintWriter pw) {
   List<T> rules = new ArrayList<>(c.keySet());
   Collections.sort(rules, Counters.toComparatorDescending(c));
   int rSize = rules.size();
   if (num > rSize) {
     num = rSize;
   }
   for (int i = 0; i < num; i++) {
     pw.println(rules.get(i) + " " + c.getCount(rules.get(i)));
   }
 }
Example #11
 public ClassicCounter<L> scoresOf(RVFDatum<L, F> example) {
   ClassicCounter<L> scores = new ClassicCounter<>();
   Counters.addInPlace(scores, priors);
   if (addZeroValued) {
     Counters.addInPlace(scores, priorZero);
   }
   for (L l : labels) {
     double score = 0.0;
     Counter<F> features = example.asFeaturesCounter();
     for (F f : features.keySet()) {
       int value = (int) features.getCount(f);
       score += weight(l, f, Integer.valueOf(value));
       if (addZeroValued) {
         score -= weight(l, f, zero);
       }
     }
     scores.incrementCount(l, score);
   }
   return scores;
 }
 public SentenceStatistics mean() {
   double sumConfidence = 0;
   int countWithConfidence = 0;
   Counter<String> avePredictions =
       new ClassicCounter<>(MapFactory.<String, MutableDouble>linkedHashMapFactory());
   // Sum
   for (SentenceStatistics stat : this.statisticsForClassifiers) {
     for (Double confidence : stat.confidence) {
       sumConfidence += confidence;
       countWithConfidence += 1;
     }
     assert Math.abs(stat.relationDistribution.totalCount() - 1.0) < 1e-5;
     for (Map.Entry<String, Double> entry : stat.relationDistribution.entrySet()) {
       assert entry.getValue() >= 0.0;
       assert entry.getValue() == stat.relationDistribution.getCount(entry.getKey());
       avePredictions.incrementCount(entry.getKey(), entry.getValue());
     }
   }
   // Normalize
   double aveConfidence = sumConfidence / ((double) countWithConfidence);
   // Return
   if (this.statisticsForClassifiers.size() > 1) {
     Counters.divideInPlace(avePredictions, (double) this.statisticsForClassifiers.size());
   }
   if (Math.abs(avePredictions.totalCount() - 1.0) > 1e-5) {
     throw new IllegalStateException("Mean relation distribution is not a distribution!");
   }
   assert this.statisticsForClassifiers.size() > 1
       || this.statisticsForClassifiers.size() == 0
       || Counters.equals(
           avePredictions,
           statisticsForClassifiers.iterator().next().relationDistribution,
           1e-5);
   return countWithConfidence > 0
       ? new SentenceStatistics(avePredictions, aveConfidence)
       : new SentenceStatistics(avePredictions);
 }
 String probsToString() {
   List<Pair<String, Double>> sorted =
       Counters.toDescendingMagnitudeSortedListWithCounts(typeProbabilities);
    StringBuilder os = new StringBuilder();
   os.append("{");
   boolean first = true;
   for (Pair<String, Double> lv : sorted) {
     if (!first) os.append("; ");
     os.append(lv.first + ", " + lv.second);
     first = false;
   }
   os.append("}");
   return os.toString();
 }
Example #14
 public static Set<String> featureWhiteList(FlatNBestList nbest, int minSegmentCount) {
   List<List<ScoredFeaturizedTranslation<IString, String>>> nbestlists = nbest.nbestLists();
   Counter<String> featureSegmentCounts = new ClassicCounter<String>();
   for (List<ScoredFeaturizedTranslation<IString, String>> nbestlist : nbestlists) {
     Set<String> segmentFeatureSet = new HashSet<String>();
     for (ScoredFeaturizedTranslation<IString, String> trans : nbestlist) {
       for (FeatureValue<String> feature : trans.features) {
         segmentFeatureSet.add(feature.name);
       }
     }
     for (String featureName : segmentFeatureSet) {
       featureSegmentCounts.incrementCount(featureName);
     }
   }
   return Counters.keysAbove(featureSegmentCounts, minSegmentCount - 1);
 }
Example #15
  /**
   * Update an existing feature whitelist according to nbestlists. Then return the features that
   * appear more than minSegmentCount times.
   *
   * @param featureWhitelist Counter of the number of segments in which each feature has appeared;
   *     updated in place.
   * @param nbestlists The n-best lists to scan for features.
   * @param minSegmentCount The minimum number of segments in which a feature must appear.
   * @return features that appear more than minSegmentCount times
   */
 public static Set<String> updatefeatureWhiteList(
     Counter<String> featureWhitelist,
     List<List<RichTranslation<IString, String>>> nbestlists,
     int minSegmentCount) {
   for (List<RichTranslation<IString, String>> nbestlist : nbestlists) {
     Set<String> segmentFeatureSet = new HashSet<String>(1000);
     for (RichTranslation<IString, String> trans : nbestlist) {
       for (FeatureValue<String> feature : trans.features) {
         if (!segmentFeatureSet.contains(feature.name)) {
           segmentFeatureSet.add(feature.name);
           featureWhitelist.incrementCount(feature.name);
         }
       }
     }
   }
   return Counters.keysAbove(featureWhitelist, minSegmentCount - 1);
 }
  /**
   * Returns true if it's worth saving/printing this object. This happens in two cases:
   *
   * <ol>
   *   <li>The type of the object is not nilLabel.
   *   <li>The type of the object is nilLabel, but the second-ranked label is within the given
   *       beam (0 -- 100) of the first choice.
   * </ol>
   *
   * @param beam The beam width, as a percentage (0 -- 100).
   * @param nilLabel The label for the nil (no relation) class.
   */
  public boolean printableObject(double beam, String nilLabel) {
    if (typeProbabilities == null) {
      return false;
    }
    List<Pair<String, Double>> sorted =
        Counters.toDescendingMagnitudeSortedListWithCounts(typeProbabilities);

    // first choice not nil
    if (sorted.size() > 0 && !sorted.get(0).first.equals(nilLabel)) {
      return true;
    }

    // first choice is nil, but second is within beam
    if (sorted.size() > 1
        && sorted.get(0).first.equals(nilLabel)
        && beam > 0
        && 100.0 * (sorted.get(0).second - sorted.get(1).second) < beam) {
      return true;
    }

    return false;
  }
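A worked check of the beam test with hypothetical probabilities, where nil ranks first at 0.52 and the runner-up scores 0.49, with beam = 5:

     double margin = 100.0 * (0.52 - 0.49); // 3.0 < beam, so the object is still printable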
 public double averageKLFromMean() {
   Counter<String> mean = this.mean().relationDistribution;
   double sumKL = 0;
   for (SentenceStatistics stats : this.statisticsForClassifiers) {
     double kl = Counters.klDivergence(stats.relationDistribution, mean);
     if (kl < 0.0 && kl > -1e-12) {
       kl = 0.0;
     } // floating point error.
     assert kl >= 0.0;
     sumKL += kl;
   }
   double val = sumKL / ((double) this.statisticsForClassifiers.size());
   if (Double.isInfinite(val) || Double.isNaN(val) || val < 0.0) {
     throw new AssertionError("Invalid average KL value: " + val);
   }
   assert val >= 0.0; // KL lower bound
   assert this.statisticsForClassifiers.size() > 1 || val < 1e-5;
   if (val < 1e-10) {
     val = 0.0;
   } // floating point error
   return val;
 }
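For reference, a hedged toy computation with Counters.klDivergence; the distributions are invented, and (as the asserts in mean() require) both counters must already be normalized:

    Counter<String> p = new ClassicCounter<String>();
    p.setCount("R1", 0.5);
    p.setCount("R2", 0.5);
    Counter<String> q = new ClassicCounter<String>();
    q.setCount("R1", 0.9);
    q.setCount("R2", 0.1);
    double kl = Counters.klDivergence(p, q); // strictly positive, since p differs from q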
Example #18
 public L classOf(RVFDatum<L, F> example) {
   Counter<L> scores = scoresOf(example);
   return Counters.argmax(scores);
 }
 public List<String> selectKeys(ActiveLearningSelectionCriterion criterion) {
   Counter<String> weights = uncertainty(criterion);
   return Counters.toSortedList(weights);
 }
  /**
   * The core implementation of the search.
   *
   * @param root The root word to search from. Traditionally, this is the root of the sentence.
   * @param candidateFragments The callback for the resulting sentence fragments. This is a
   *     predicate of a triple of values. The return value of the predicate determines whether we
   *     should continue searching. The triple is a triple of
   *     <ol>
   *       <li>The log probability of the sentence fragment, according to the featurizer and the
   *           weights
   *       <li>The features along the path to this fragment. The last element of this is the
   *           features from the most recent step.
   *       <li>The sentence fragment. Because it is relatively expensive to compute the resulting
   *           tree, this is returned as a lazy {@link Supplier}.
   *     </ol>
   *
   * @param classifier The classifier for whether an arc should be on the path to a clause split, a
   *     clause split itself, or neither.
   * @param featurizer The featurizer to use. Make sure this matches the weights!
   * @param actionSpace The action space we are allowed to take. Each action defines a means of
   *     splitting a clause on a dependency boundary.
   */
  protected void search(
      // The root to search from
      IndexedWord root,
      // The output specs
      final Predicate<Triple<Double, List<Counter<String>>, Supplier<SentenceFragment>>>
          candidateFragments,
      // The learning specs
      final Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
      Map<String, ? extends List<String>> hardCodedSplits,
      final Function<Triple<State, Action, State>, Counter<String>> featurizer,
      final Collection<Action> actionSpace,
      final int maxTicks) {
    // (the fringe)
    PriorityQueue<Pair<State, List<Counter<String>>>> fringe = new FixedPrioritiesPriorityQueue<>();
    // (avoid duplicate work)
    Set<IndexedWord> seenWords = new HashSet<>();

    State firstState =
        new State(null, null, -9000, null, x -> {}, true); // First state is implicitly "done"
    fringe.add(Pair.makePair(firstState, new ArrayList<>(0)), -0.0);
    int ticks = 0;

    while (!fringe.isEmpty()) {
      if (++ticks > maxTicks) {
        //        System.err.println("WARNING! Timed out on search with " + ticks + " ticks");
        return;
      }
      // Useful variables
      double logProbSoFar = fringe.getPriority();
      assert logProbSoFar <= 0.0;
      Pair<State, List<Counter<String>>> lastStatePair = fringe.removeFirst();
      State lastState = lastStatePair.first;
      List<Counter<String>> featuresSoFar = lastStatePair.second;
      IndexedWord rootWord = lastState.edge == null ? root : lastState.edge.getDependent();

      // Register thunk
      if (lastState.isDone) {
        if (!candidateFragments.test(
            Triple.makeTriple(
                logProbSoFar,
                featuresSoFar,
                () -> {
                  SemanticGraph copy = new SemanticGraph(tree);
                  lastState
                      .thunk
                      .andThen(
                          x -> {
                            // Add the extra edges back in, if they don't break the tree-ness of the
                            // extraction
                            for (IndexedWord newTreeRoot : x.getRoots()) {
                              if (newTreeRoot != null) { // what a strange thing to have happen...
                                for (SemanticGraphEdge extraEdge :
                                    extraEdgesByGovernor.get(newTreeRoot)) {
                                  assert Util.isTree(x);
                                  //noinspection unchecked
                                  addSubtree(
                                      x,
                                      newTreeRoot,
                                      extraEdge.getRelation().toString(),
                                      tree,
                                      extraEdge.getDependent(),
                                      tree.getIncomingEdgesSorted(newTreeRoot));
                                  assert Util.isTree(x);
                                }
                              }
                            }
                          })
                      .accept(copy);
                  return new SentenceFragment(copy, assumedTruth, false);
                }))) {
          break;
        }
      }

      // Find relevant auxiliary terms
      SemanticGraphEdge subjOrNull = null;
      SemanticGraphEdge objOrNull = null;
      for (SemanticGraphEdge auxEdge : tree.outgoingEdgeIterable(rootWord)) {
        String relString = auxEdge.getRelation().toString();
        if (relString.contains("obj")) {
          objOrNull = auxEdge;
        } else if (relString.contains("subj")) {
          subjOrNull = auxEdge;
        }
      }

      // Iterate over children
      // For each outgoing edge...
      for (SemanticGraphEdge outgoingEdge : tree.outgoingEdgeIterable(rootWord)) {
        // Prohibit indirect speech verbs from splitting off clauses
        // (e.g., 'said', 'think')
        // This fires if the governor is an indirect speech verb, and the outgoing edge is a ccomp
        if (outgoingEdge.getRelation().toString().equals("ccomp")
            && ((outgoingEdge.getGovernor().lemma() != null
                    && INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().lemma()))
                || INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().word()))) {
          continue;
        }
        // Get some variables
        String outgoingEdgeRelation = outgoingEdge.getRelation().toString();
        List<String> forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation);
        if (forcedArcOrder == null && outgoingEdgeRelation.contains(":")) {
          forcedArcOrder =
              hardCodedSplits.get(
                  outgoingEdgeRelation.substring(0, outgoingEdgeRelation.indexOf(":")) + ":*");
        }
        boolean doneForcedArc = false;
        // For each action...
        for (Action action :
            (forcedArcOrder == null ? actionSpace : orderActions(actionSpace, forcedArcOrder))) {
          // Check the prerequisite
          if (!action.prerequisitesMet(tree, outgoingEdge)) {
            continue;
          }
          if (forcedArcOrder != null && doneForcedArc) {
            break;
          }
          // 1. Compute the child state
          Optional<State> candidate =
              action.applyTo(tree, lastState, outgoingEdge, subjOrNull, objOrNull);
          if (candidate.isPresent()) {
            double logProbability;
            ClauseClassifierLabel bestLabel;
            Counter<String> features =
                featurizer.apply(Triple.makeTriple(lastState, action, candidate.get()));
            if (forcedArcOrder != null && !doneForcedArc) {
              logProbability = 0.0;
              bestLabel = ClauseClassifierLabel.CLAUSE_SPLIT;
              doneForcedArc = true;
            } else if (features.containsKey("__undocumented_junit_no_classifier")) {
              logProbability = Double.NEGATIVE_INFINITY;
              bestLabel = ClauseClassifierLabel.CLAUSE_INTERM;
            } else {
              Counter<ClauseClassifierLabel> scores = classifier.scoresOf(new RVFDatum<>(features));
              if (scores.size() > 0) {
                Counters.logNormalizeInPlace(scores);
              }
              String rel = outgoingEdge.getRelation().toString();
              if ("nsubj".equals(rel) || "dobj".equals(rel)) {
                scores.remove(
                    ClauseClassifierLabel.NOT_A_CLAUSE); // Always at least yield on nsubj and dobj
              }
              logProbability = Counters.max(scores, Double.NEGATIVE_INFINITY);
              bestLabel = Counters.argmax(scores, (x, y) -> 0, ClauseClassifierLabel.CLAUSE_SPLIT);
            }

            if (bestLabel != ClauseClassifierLabel.NOT_A_CLAUSE) {
              Pair<State, List<Counter<String>>> childState =
                  Pair.makePair(
                      candidate.get().withIsDone(bestLabel),
                      new ArrayList<Counter<String>>(featuresSoFar) {
                        {
                          add(features);
                        }
                      });
              // 2. Register the child state
              if (!seenWords.contains(childState.first.edge.getDependent())) {
                //            System.err.println("  pushing " + action.signature() + " with " +
                // argmax.first.edge);
                fringe.add(childState, logProbability);
              }
            }
          }
        }
      }

      seenWords.add(rootWord);
    }
    //    System.err.println("Search finished in " + ticks + " ticks and " + classifierEvals + "
    // classifier evaluations.");
  }
Example #21
  @Override
  public Counter<E> score() {

    Counter<E> currentPatternWeights4Label = new ClassicCounter<>();

    Counter<E> pos_i = new ClassicCounter<>();
    Counter<E> neg_i = new ClassicCounter<>();
    Counter<E> unlab_i = new ClassicCounter<>();

    for (Entry<E, ClassicCounter<CandidatePhrase>> en : negPatternsandWords4Label.entrySet()) {
      neg_i.setCount(en.getKey(), en.getValue().size());
    }

    for (Entry<E, ClassicCounter<CandidatePhrase>> en :
        unLabeledPatternsandWords4Label.entrySet()) {
      unlab_i.setCount(en.getKey(), en.getValue().size());
    }

    for (Entry<E, ClassicCounter<CandidatePhrase>> en : patternsandWords4Label.entrySet()) {
      pos_i.setCount(en.getKey(), en.getValue().size());
    }

    Counter<E> all_i = Counters.add(pos_i, neg_i);
    all_i.addAll(unlab_i);
    //    for (Entry<Integer, ClassicCounter<String>> en : allPatternsandWords4Label
    //        .entrySet()) {
    //      all_i.setCount(en.getKey(), en.getValue().size());
    //    }

    Counter<E> posneg_i = Counters.add(pos_i, neg_i);
    Counter<E> logFi = new ClassicCounter<>(pos_i);
    Counters.logInPlace(logFi);

    if (patternScoring.equals(PatternScoring.RlogF)) {
      currentPatternWeights4Label = Counters.product(Counters.division(pos_i, all_i), logFi);
    } else if (patternScoring.equals(PatternScoring.RlogFPosNeg)) {
      Redwood.log("extremePatDebug", "computing rlogfposneg");

      currentPatternWeights4Label = Counters.product(Counters.division(pos_i, posneg_i), logFi);

    } else if (patternScoring.equals(PatternScoring.RlogFUnlabNeg)) {
      Redwood.log("extremePatDebug", "computing rlogfunlabneg");

      currentPatternWeights4Label =
          Counters.product(Counters.division(pos_i, Counters.add(neg_i, unlab_i)), logFi);
    } else if (patternScoring.equals(PatternScoring.RlogFNeg)) {
      Redwood.log("extremePatDebug", "computing rlogfneg");

      currentPatternWeights4Label = Counters.product(Counters.division(pos_i, neg_i), logFi);
    } else if (patternScoring.equals(PatternScoring.YanGarber02)) {

      Counter<E> acc = Counters.division(pos_i, Counters.add(pos_i, neg_i));
      double thetaPrecision = 0.8;
      Counters.retainAbove(acc, thetaPrecision);
      Counter<E> conf = Counters.product(Counters.division(pos_i, all_i), logFi);
      for (E p : acc.keySet()) {
        currentPatternWeights4Label.setCount(p, conf.getCount(p));
      }
    } else if (patternScoring.equals(PatternScoring.LinICML03)) {

      Counter<E> acc = Counters.division(pos_i, Counters.add(pos_i, neg_i));
      double thetaPrecision = 0.8;
      Counters.retainAbove(acc, thetaPrecision);
      Counter<E> conf =
          Counters.product(
              Counters.division(Counters.add(pos_i, Counters.scale(neg_i, -1)), all_i), logFi);
      for (E p : acc.keySet()) {
        currentPatternWeights4Label.setCount(p, conf.getCount(p));
      }
    } else {
      throw new RuntimeException("not implemented " + patternScoring + " . check spelling!");
    }
    return currentPatternWeights4Label;
  }
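For intuition, the RlogF score of a pattern is (positive matches / all matches) * log(positive matches); a toy computation with invented counts (assuming Counters.logInPlace takes natural logs):

     double pos = 8, all = 10;
     double rlogf = (pos / all) * Math.log(pos); // 0.8 * ln(8) ≈ 1.66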
  public List<Pair<String, Double>> selectWeightedKeysWithSampling(
      ActiveLearningSelectionCriterion criterion, int numSamples, int seed) {
    List<Pair<String, Double>> result = new ArrayList<>();
    forceTrack("Sampling Keys");
    log("" + numSamples + " to collect");

    // Get uncertainty
    forceTrack("Computing Uncertainties");
    Counter<String> weightCounter = uncertainty(criterion);
    assert weightCounter.equals(uncertainty(criterion));
    endTrack("Computing Uncertainties");
    // Compute some statistics
    startTrack("Uncertainty Histogram");
    //    log(new Histogram(weightCounter, 50).toString());  // removed to make the release easier
    // (Histogram isn't in CoreNLP)
    endTrack("Uncertainty Histogram");
    double totalCount = weightCounter.totalCount();
    Random random = new Random(seed);

    // Flatten counter
    List<String> keys = new LinkedList<>();
    List<Double> weights = new LinkedList<>();
    List<String> zeroUncertaintyKeys = new LinkedList<>();
    for (Pair<String, Double> elem :
        Counters.toSortedListWithCounts(
            weightCounter,
            (o1, o2) -> {
              int value = o1.compareTo(o2);
              if (value == 0) {
                return o1.first.compareTo(o2.first);
              } else {
                return value;
              }
            })) {
      if (elem.second != 0.0
          || weightCounter.totalCount() == 0.0
          || weightCounter.size() <= numSamples) { // ignore 0 probability weights
        keys.add(elem.first);
        weights.add(elem.second);
      } else {
        zeroUncertaintyKeys.add(elem.first);
      }
    }

    // Error check
    if (Utils.assertionsEnabled()) {
      for (Double elem : weights) {
        if (!(elem >= 0 && !Double.isInfinite(elem) && !Double.isNaN(elem))) {
          throw new IllegalArgumentException("Invalid weight: " + elem);
        }
      }
    }

    // Sample
    SAMPLE_ITER:
    for (int i = 1; i <= numSamples; ++i) { // For each sample
      if (i % 1000 == 0) {
        // Debug log
        log("sampled " + (i / 1000) + "k keys");
        // Recompute total count to mitigate floating point errors
        totalCount = 0.0;
        for (double val : weights) {
          totalCount += val;
        }
      }
      if (weights.size() == 0) {
        continue;
      }
      assert totalCount >= 0.0;
      assert weights.size() == keys.size();
      double target = random.nextDouble() * totalCount;
      Iterator<String> keyIter = keys.iterator();
      Iterator<Double> weightIter = weights.iterator();
      double runningTotal = 0.0;
      while (keyIter.hasNext()) { // For each candidate
        String key = keyIter.next();
        double weight = weightIter.next();
        runningTotal += weight;
        if (target <= runningTotal) { // Select that sample
          result.add(Pair.makePair(key, weight));
          keyIter.remove();
          weightIter.remove();
          totalCount -= weight;
          continue SAMPLE_ITER; // continue sampling
        }
      }
      // We should get here only if the keys list is empty
      warn(
          "No more uncertain samples left to draw from! (target="
              + target
              + " totalCount="
              + totalCount
              + " size="
              + keys.size()
              + ")");
      assert keys.size() == 0;
      if (zeroUncertaintyKeys.size() > 0) {
        result.add(Pair.makePair(zeroUncertaintyKeys.remove(0), 0.0));
      } else {
        break;
      }
    }

    endTrack("Sampling Keys");
    return result;
  }
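The sampling loop above draws keys without replacement, in proportion to their uncertainty weights: pick a uniform target in [0, totalCount), walk the running sum until it passes the target, then remove that key and shrink the total. A standalone sketch of one such draw (hypothetical helper, not part of the original class):

  private static String drawOne(List<String> keys, List<Double> weights, double total, Random rng) {
    double target = rng.nextDouble() * total;
    double running = 0.0;
    for (int i = 0; i < keys.size(); i++) {
      running += weights.get(i);
      if (target <= running) {
        weights.remove(i); // remove both entries so later draws exclude this key
        return keys.remove(i);
      }
    }
    return null; // unreachable when total equals the sum of the weights
  }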
Example #23
  /** Print some statistics about this lexicon. */
  public void printLexStats() {
    System.out.println("BaseLexicon statistics");
    System.out.println("unknownLevel is " + getUnknownWordModel().getUnknownLevel());
    // System.out.println("Rules size: " + rules.size());
    System.out.println("Sum of rulesWithWord: " + numRules());
    System.out.println("Tags size: " + tags.size());
    int wsize = words.size();
    System.out.println("Words size: " + wsize);
    // System.out.println("Unseen Sigs size: " + sigs.size() +
    // " [number of unknown equivalence classes]");
    System.out.println(
        "rulesWithWord length: "
            + rulesWithWord.length
            + " [should be sum of words + unknown sigs]");
    int[] lengths = new int[STATS_BINS];
    ArrayList<String>[] wArr = new ArrayList[STATS_BINS];
    for (int j = 0; j < STATS_BINS; j++) {
      wArr[j] = new ArrayList<String>();
    }
    for (int i = 0; i < rulesWithWord.length; i++) {
      int num = rulesWithWord[i].size();
      if (num > STATS_BINS - 1) {
        num = STATS_BINS - 1;
      }
      lengths[num]++;
      if (wsize <= 20 || num >= STATS_BINS / 2) {
        wArr[num].add(wordIndex.get(i));
      }
    }
    System.out.println("Stats on how many taggings for how many words");
    for (int j = 0; j < STATS_BINS; j++) {
      System.out.print(j + " taggings: " + lengths[j] + " words ");
      if (wsize <= 20 || j >= STATS_BINS / 2) {
        System.out.print(wArr[j]);
      }
      System.out.println();
    }
    NumberFormat nf = NumberFormat.getNumberInstance();
    nf.setMaximumFractionDigits(0);
    System.out.println("Unseen counter: " + Counters.toString(uwModel.unSeenCounter(), nf));

    if (wsize < 50 && tags.size() < 10) {
      nf.setMaximumFractionDigits(3);
      StringWriter sw = new StringWriter();
      PrintWriter pw = new PrintWriter(sw);
      pw.println("Tagging probabilities log P(word|tag)");
      for (int t = 0; t < tags.size(); t++) {
        pw.print('\t');
        pw.print(tagIndex.get(t));
      }
      pw.println();
      for (int w = 0; w < wsize; w++) {
        pw.print(wordIndex.get(w));
        pw.print('\t');
        for (int t = 0; t < tags.size(); t++) {
          IntTaggedWord iTW = new IntTaggedWord(w, t);
          pw.print(nf.format(score(iTW, 1, wordIndex.get(w))));
          if (t == tags.size() - 1) {
            pw.println();
          } else pw.print('\t');
        }
      }
      pw.close();
      System.out.println(sw.toString());
    }
  }
 public List<Pair<String, Double>> selectWeightedKeys(ActiveLearningSelectionCriterion criterion) {
   Counter<String> weights = uncertainty(criterion);
   return Counters.toSortedListWithCounts(weights);
 }