/**
 * Returns a list of features thresholded by minPrecision and sorted by their frequency of
 * occurrence. Precision, in this case, is defined as the frequency of the majority label over the
 * total frequency for that feature.
 *
 * @return list of high precision features.
 */
private List<F> getHighPrecisionFeatures(
    GeneralDataset<L, F> dataset, double minPrecision, int maxNumFeatures) {
  int[][] feature2label = new int[dataset.numFeatures()][dataset.numClasses()];
  for (int f = 0; f < dataset.numFeatures(); f++) {
    Arrays.fill(feature2label[f], 0);
  }

  int[][] data = dataset.data;
  int[] labels = dataset.labels;
  for (int d = 0; d < data.length; d++) {
    int label = labels[d];
    // System.out.println("datum id:"+d+" label id: "+label);
    if (data[d] != null) {
      // System.out.println(" number of features:"+data[d].length);
      for (int n = 0; n < data[d].length; n++) {
        feature2label[data[d][n]][label]++;
      }
    }
  }

  Counter<F> feature2freq = new ClassicCounter<F>();
  for (int f = 0; f < dataset.numFeatures(); f++) {
    int maxF = ArrayMath.max(feature2label[f]);
    int total = ArrayMath.sum(feature2label[f]);
    double precision = ((double) maxF) / total;
    F feature = dataset.featureIndex.get(f);
    if (precision >= minPrecision) {
      feature2freq.incrementCount(feature, total);
    }
  }

  if (feature2freq.size() > maxNumFeatures) {
    Counters.retainTop(feature2freq, maxNumFeatures);
  }

  // for(F feature : feature2freq.keySet())
  //   System.out.println(feature+" "+feature2freq.getCount(feature));
  // System.exit(0);
  return Counters.toSortedList(feature2freq);
}
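// A minimal, hypothetical sketch of the per-feature precision computation above, using the same
// edu.stanford.nlp.math.ArrayMath helpers on a hand-built label histogram. The counts and the
// 0.75 threshold below are invented for illustration only.
private static void highPrecisionFeatureExample() {
  // Label histogram for one feature: it co-occurred 8 times with label 0, twice with label 1.
  int[] labelCounts = {8, 2};
  int majority = ArrayMath.max(labelCounts); // 8
  int total = ArrayMath.sum(labelCounts); // 10
  double precision = ((double) majority) / total; // 0.8
  // With minPrecision = 0.75, this feature would be retained with frequency weight 10.
  System.out.println("precision = " + precision + ", kept = " + (precision >= 0.75));
}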
/**
 * Guesses an NER tag for a span of tokens by majority vote over the tokens' individual NER tags,
 * ignoring the "O" (outside) tag. Returns "O" if no tag covers at least half of the span.
 *
 * @param tokens The tokens of the sentence containing the span.
 * @param span The span of token indices to guess an NER tag for.
 * @return The majority NER tag over the span, or "O" if no tag is sufficiently frequent.
 */
public static String guessNER(List<CoreLabel> tokens, Span span) {
  Counter<String> nerGuesses = new ClassicCounter<>();
  for (int i : span) {
    nerGuesses.incrementCount(tokens.get(i).ner());
  }
  nerGuesses.remove("O");
  nerGuesses.remove(null);
  if (nerGuesses.size() > 0 && Counters.max(nerGuesses) >= span.size() / 2) {
    return Counters.argmax(nerGuesses);
  } else {
    return "O";
  }
}
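// A minimal sketch of the majority-vote logic above, on a hand-built tag counter rather than
// real CoreLabels; the tags and counts are invented for illustration.
// (Assumes edu.stanford.nlp.stats.{ClassicCounter, Counter, Counters} are imported.)
private static void guessNERExample() {
  Counter<String> nerGuesses = new ClassicCounter<>();
  nerGuesses.incrementCount("PERSON", 3); // e.g., three tokens tagged PERSON
  nerGuesses.incrementCount("O", 2); // two outside tokens
  nerGuesses.remove("O"); // outside tags never win the vote
  int spanSize = 5;
  String guess =
      (nerGuesses.size() > 0 && Counters.max(nerGuesses) >= spanSize / 2)
          ? Counters.argmax(nerGuesses)
          : "O";
  System.out.println(guess); // PERSON
}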
/** Run some sanity checks on the training statistics, to make sure they look valid. */
public void validate() {
  for (Map<SentenceKey, EnsembleStatistics> map : impl) {
    for (EnsembleStatistics stats : map.values()) {
      for (SentenceStatistics component : stats.statisticsForClassifiers) {
        assert !Counters.isUniformDistribution(component.relationDistribution, 1e-5);
        Counters.normalize(component.relationDistribution); // TODO(gabor) this shouldn't be necessary
        assert (Math.abs(component.relationDistribution.totalCount() - 1.0)) < 1e-5;
      }
      assert (Math.abs(stats.mean().relationDistribution.totalCount() - 1.0)) < 1e-5;
      assert !Counters.isUniformDistribution(stats.mean().relationDistribution, 1e-5);
    }
  }
}
public static void main(String[] args) throws Exception {
  CompareType compareType = CompareType.MAX_COOR_ABS;

  if (args.length == 0) {
    usage();
    System.exit(-1);
  }
  if ("-cosine".equals(args[0])) {
    compareType = CompareType.COSINE;
    args = Arrays.copyOfRange(args, 1, args.length);
  } else if ("-sse".equals(args[0])) {
    compareType = CompareType.SUM_SQUARE_ERROR;
    args = Arrays.copyOfRange(args, 1, args.length);
  }
  if (args.length != 2) {
    usage();
    System.exit(-1);
  }

  Set<String> allWeights = new HashSet<String>();
  Counter<String> wts1 = IOTools.readWeights(args[0]);
  Counter<String> wts2 = IOTools.readWeights(args[1]);
  allWeights.addAll(wts1.keySet());
  allWeights.addAll(wts2.keySet());

  if (compareType == CompareType.MAX_COOR_ABS) {
    double maxDiff = Double.NEGATIVE_INFINITY;
    Counters.multiplyInPlace(wts1, 1.0 / Counters.L1Norm(wts1));
    Counters.multiplyInPlace(wts2, 1.0 / Counters.L1Norm(wts2));
    for (String wt : allWeights) {
      double absDiff = Math.abs(wts1.getCount(wt) - wts2.getCount(wt));
      if (absDiff > maxDiff) maxDiff = absDiff;
    }
    System.out.println(maxDiff);
  } else if (compareType == CompareType.COSINE) {
    double cosineSim = Counters.cosine(wts1, wts2);
    System.out.println(cosineSim);
  } else if (compareType == CompareType.SUM_SQUARE_ERROR) {
    double sse = 0;
    for (String wt : allWeights) {
      double diff = wts1.getCount(wt) - wts2.getCount(wt);
      sse += diff * diff;
    }
    System.out.println(sse);
  }
}
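// A minimal sketch of the MAX_COOR_ABS comparison above on two tiny hand-built weight vectors
// (the feature names and values are invented). Counters.L1Norm and Counters.multiplyInPlace are
// the same helpers main() uses.
private static void compareWeightsExample() {
  Counter<String> wts1 = new ClassicCounter<>();
  wts1.setCount("lm", 2.0);
  wts1.setCount("tm", 2.0);
  Counter<String> wts2 = new ClassicCounter<>();
  wts2.setCount("lm", 1.0);
  wts2.setCount("tm", 3.0);
  Counters.multiplyInPlace(wts1, 1.0 / Counters.L1Norm(wts1)); // -> {lm: 0.5, tm: 0.5}
  Counters.multiplyInPlace(wts2, 1.0 / Counters.L1Norm(wts2)); // -> {lm: 0.25, tm: 0.75}
  double maxDiff = 0.0;
  for (String wt : wts1.keySet()) {
    maxDiff = Math.max(maxDiff, Math.abs(wts1.getCount(wt) - wts2.getCount(wt)));
  }
  System.out.println(maxDiff); // 0.25
}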
private static <T> void display(ClassicCounter<T> c, PrintWriter pw) {
  List<T> cats = new ArrayList<>(c.keySet());
  Collections.sort(cats, Counters.toComparatorDescending(c));
  for (T ob : cats) {
    pw.println(ob + " " + c.getCount(ob));
  }
}
/**
 * Returns the set of all modes in the Collection. (If the Collection has multiple items with the
 * highest frequency, all of them will be returned.)
 */
public static <T> Set<T> modes(Collection<T> values) {
  Counter<T> counter = new ClassicCounter<T>(values);
  List<Double> sortedCounts = CollectionUtils.sorted(counter.values());
  Double highestCount = sortedCounts.get(sortedCounts.size() - 1);
  Counters.retainAbove(counter, highestCount);
  return counter.keySet();
}
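// A minimal, hypothetical usage sketch of modes(...) above; the input list is invented.
// (Assumes java.util.Arrays is imported.)
private static void modesExample() {
  // "b" and "c" both occur twice, so both are modes.
  Set<String> result = modes(Arrays.asList("a", "b", "b", "c", "c"));
  System.out.println(result); // [b, c] (in some iteration order)
}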
@Deprecated
protected void semanticSimilarity(
    Counter<String> features, String prefix, Sentence str1, Sentence str2) {
  Counter<String> v1 =
      new ClassicCounter<>(
          str1.lemmas().stream().map(String::toLowerCase).collect(Collectors.toList()));
  // Note: unlike v1, these lemmas are not lowercased.
  Counter<String> v2 = new ClassicCounter<>(str2.lemmas());

  // Remove any stopwords.
  for (String word : stopwords) {
    v1.remove(word);
    v2.remove(word);
  }

  // Take the cosine of the two vectors.
  double sim =
      Counters.dotProduct(v1, v2) / (Counters.saferL2Norm(v1) * Counters.saferL2Norm(v2));
  features.incrementCount(
      prefix + "semantic-similarity", 2 * sim - 1); // map [0, 1] to [-1, 1]
}
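// A minimal sketch of the cosine computation above on two hand-built bags of lemmas (the words
// are invented). Counters.dotProduct and Counters.saferL2Norm are the same helpers used above;
// saferL2Norm guards against zero-length vectors.
private static void semanticSimilarityExample() {
  Counter<String> v1 = new ClassicCounter<>(Arrays.asList("cat", "sat", "mat"));
  Counter<String> v2 = new ClassicCounter<>(Arrays.asList("cat", "sat", "rug"));
  double sim =
      Counters.dotProduct(v1, v2) / (Counters.saferL2Norm(v1) * Counters.saferL2Norm(v2));
  System.out.println(sim); // 2/3: two shared words, three words in each vector
}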
@Override
public L classOf(Datum<L, F> example) {
  Counter<L> scores = scoresOf(example);
  if (scores != null) {
    return Counters.argmax(scores);
  } else {
    return defaultLabel;
  }
}
public static List<String> generateDict(List<String> str, int cutOff) {
  Counter<String> freq = new IntCounter<>();
  for (String aStr : str) {
    freq.incrementCount(aStr);
  }

  List<String> keys = Counters.toSortedList(freq, false);
  List<String> dict = new ArrayList<>();
  for (String word : keys) {
    if (freq.getCount(word) >= cutOff) {
      dict.add(word);
    }
  }
  return dict;
}
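// A minimal, hypothetical usage sketch of generateDict(...) above; the token list and cutoff
// are invented. (Assumes java.util.Arrays is imported.)
private static void generateDictExample() {
  List<String> tokens = Arrays.asList("the", "the", "the", "cat", "cat", "sat");
  // Keep only words seen at least twice, most frequent first.
  System.out.println(generateDict(tokens, 2)); // [the, cat]
}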
private static <T> void display(ClassicCounter<T> c, int num, PrintWriter pw) {
  List<T> rules = new ArrayList<>(c.keySet());
  Collections.sort(rules, Counters.toComparatorDescending(c));
  int rSize = rules.size();
  if (num > rSize) {
    num = rSize;
  }
  for (int i = 0; i < num; i++) {
    pw.println(rules.get(i) + " " + c.getCount(rules.get(i)));
  }
}
public ClassicCounter<L> scoresOf(RVFDatum<L, F> example) {
  ClassicCounter<L> scores = new ClassicCounter<>();
  Counters.addInPlace(scores, priors);
  if (addZeroValued) {
    Counters.addInPlace(scores, priorZero);
  }
  Counter<F> features = example.asFeaturesCounter();
  for (L l : labels) {
    double score = 0.0;
    for (F f : features.keySet()) {
      int value = (int) features.getCount(f);
      score += weight(l, f, Integer.valueOf(value));
      if (addZeroValued) {
        score -= weight(l, f, zero);
      }
    }
    scores.incrementCount(l, score);
  }
  return scores;
}
public SentenceStatistics mean() {
  double sumConfidence = 0;
  int countWithConfidence = 0;
  Counter<String> avePredictions =
      new ClassicCounter<>(MapFactory.<String, MutableDouble>linkedHashMapFactory());

  // Sum
  for (SentenceStatistics stat : this.statisticsForClassifiers) {
    for (Double confidence : stat.confidence) {
      sumConfidence += confidence;
      countWithConfidence += 1;
    }
    assert Math.abs(stat.relationDistribution.totalCount() - 1.0) < 1e-5;
    for (Map.Entry<String, Double> entry : stat.relationDistribution.entrySet()) {
      assert entry.getValue() >= 0.0;
      assert entry.getValue() == stat.relationDistribution.getCount(entry.getKey());
      avePredictions.incrementCount(entry.getKey(), entry.getValue());
    }
  }

  // Normalize
  double aveConfidence = sumConfidence / ((double) countWithConfidence);

  // Return
  if (this.statisticsForClassifiers.size() > 1) {
    Counters.divideInPlace(avePredictions, (double) this.statisticsForClassifiers.size());
  }
  if (Math.abs(avePredictions.totalCount() - 1.0) > 1e-5) {
    throw new IllegalStateException("Mean relation distribution is not a distribution!");
  }
  assert this.statisticsForClassifiers.size() > 1
      || this.statisticsForClassifiers.size() == 0
      || Counters.equals(
          avePredictions, statisticsForClassifiers.iterator().next().relationDistribution, 1e-5);
  return countWithConfidence > 0
      ? new SentenceStatistics(avePredictions, aveConfidence)
      : new SentenceStatistics(avePredictions);
}
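// A minimal sketch of the distribution averaging above, on two hand-built distributions (the
// labels and probabilities are invented). Counters.add and Counters.divideInPlace are the same
// helpers mean() uses.
private static void meanDistributionExample() {
  Counter<String> d1 = new ClassicCounter<>();
  d1.setCount("per:title", 0.6);
  d1.setCount("no_relation", 0.4);
  Counter<String> d2 = new ClassicCounter<>();
  d2.setCount("per:title", 0.2);
  d2.setCount("no_relation", 0.8);
  Counter<String> mean = Counters.add(d1, d2); // sum the component distributions
  Counters.divideInPlace(mean, 2.0); // divide by the number of classifiers
  System.out.println(mean); // {per:title=0.4, no_relation=0.6}, still sums to 1
}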
String probsToString() {
  List<Pair<String, Double>> sorted =
      Counters.toDescendingMagnitudeSortedListWithCounts(typeProbabilities);
  StringBuilder os = new StringBuilder();
  os.append('{');
  boolean first = true;
  for (Pair<String, Double> lv : sorted) {
    if (!first) {
      os.append("; ");
    }
    os.append(lv.first).append(", ").append(lv.second);
    first = false;
  }
  os.append('}');
  return os.toString();
}
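// A minimal, hypothetical sketch of the formatting above on an invented probability counter,
// using the same Counters.toDescendingMagnitudeSortedListWithCounts helper to order labels by
// the magnitude of their probability. (Assumes edu.stanford.nlp.util.Pair is imported.)
private static void probsToStringExample() {
  Counter<String> probs = new ClassicCounter<>();
  probs.setCount("PERSON", 0.7);
  probs.setCount("ORGANIZATION", 0.3);
  StringBuilder os = new StringBuilder("{");
  boolean first = true;
  for (Pair<String, Double> lv : Counters.toDescendingMagnitudeSortedListWithCounts(probs)) {
    if (!first) os.append("; ");
    os.append(lv.first).append(", ").append(lv.second);
    first = false;
  }
  System.out.println(os.append('}')); // {PERSON, 0.7; ORGANIZATION, 0.3}
}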
public static Set<String> featureWhiteList(FlatNBestList nbest, int minSegmentCount) {
  List<List<ScoredFeaturizedTranslation<IString, String>>> nbestlists = nbest.nbestLists();
  Counter<String> featureSegmentCounts = new ClassicCounter<String>();
  for (List<ScoredFeaturizedTranslation<IString, String>> nbestlist : nbestlists) {
    Set<String> segmentFeatureSet = new HashSet<String>();
    for (ScoredFeaturizedTranslation<IString, String> trans : nbestlist) {
      for (FeatureValue<String> feature : trans.features) {
        segmentFeatureSet.add(feature.name);
      }
    }
    for (String featureName : segmentFeatureSet) {
      featureSegmentCounts.incrementCount(featureName);
    }
  }
  return Counters.keysAbove(featureSegmentCounts, minSegmentCount - 1);
}
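// A minimal sketch of the thresholding above: Counters.keysAbove(c, t) keeps the keys whose
// counts clear the threshold t, so passing minSegmentCount - 1 whitelists features seen in (at
// least) minSegmentCount segments. The feature names and counts below are invented.
private static void keysAboveExample() {
  Counter<String> featureSegmentCounts = new ClassicCounter<>();
  featureSegmentCounts.setCount("lm.score", 10);
  featureSegmentCounts.setCount("rare.feature", 1);
  int minSegmentCount = 3;
  Set<String> whitelist = Counters.keysAbove(featureSegmentCounts, minSegmentCount - 1);
  System.out.println(whitelist); // [lm.score]
}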
/**
 * Update an existing feature whitelist according to nbestlists. Then return the features that
 * appear in at least minSegmentCount segments.
 *
 * @param featureWhitelist Running count of the number of segments each feature has appeared in.
 * @param nbestlists The nbest lists for the current batch of segments.
 * @param minSegmentCount Minimum number of segments a feature must appear in.
 * @return features that appear in at least minSegmentCount segments
 */
public static Set<String> updatefeatureWhiteList(
    Counter<String> featureWhitelist,
    List<List<RichTranslation<IString, String>>> nbestlists,
    int minSegmentCount) {
  for (List<RichTranslation<IString, String>> nbestlist : nbestlists) {
    Set<String> segmentFeatureSet = new HashSet<String>(1000);
    for (RichTranslation<IString, String> trans : nbestlist) {
      for (FeatureValue<String> feature : trans.features) {
        if (!segmentFeatureSet.contains(feature.name)) {
          segmentFeatureSet.add(feature.name);
          featureWhitelist.incrementCount(feature.name);
        }
      }
    }
  }
  return Counters.keysAbove(featureWhitelist, minSegmentCount - 1);
}
/**
 * Returns true if it's worth saving/printing this object. This happens in one of two cases:
 *
 * <ol>
 *   <li>The type of the object is not nilLabel.
 *   <li>The type of the object is nilLabel, but the second-ranked label is within the given beam
 *       (0 -- 100) of the first choice.
 * </ol>
 *
 * @param beam The beam width, as a percentage (0 -- 100).
 * @param nilLabel The label used for nil objects.
 */
public boolean printableObject(double beam, String nilLabel) {
  if (typeProbabilities == null) {
    return false;
  }
  List<Pair<String, Double>> sorted =
      Counters.toDescendingMagnitudeSortedListWithCounts(typeProbabilities);

  // first choice not nil
  if (sorted.size() > 0 && !sorted.get(0).first.equals(nilLabel)) {
    return true;
  }

  // first choice is nil, but second is within beam
  if (sorted.size() > 1
      && sorted.get(0).first.equals(nilLabel)
      && beam > 0
      && 100.0 * (sorted.get(0).second - sorted.get(1).second) < beam) {
    return true;
  }

  return false;
}
public double averageKLFromMean() {
  Counter<String> mean = this.mean().relationDistribution;
  double sumKL = 0;
  for (SentenceStatistics stats : this.statisticsForClassifiers) {
    double kl = Counters.klDivergence(stats.relationDistribution, mean);
    if (kl < 0.0 && kl > -1e-12) {
      kl = 0.0; // floating point error
    }
    assert kl >= 0.0;
    sumKL += kl;
  }
  double val = sumKL / ((double) this.statisticsForClassifiers.size());
  if (Double.isInfinite(val) || Double.isNaN(val) || val < 0.0) {
    throw new AssertionError("Invalid average KL value: " + val);
  }
  assert val >= 0.0; // KL lower bound
  assert this.statisticsForClassifiers.size() > 1 || val < 1e-5;
  if (val < 1e-10) {
    val = 0.0; // floating point error
  }
  return val;
}
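// A minimal sketch of the KL computation above on two invented distributions, using the same
// Counters.klDivergence helper; KL(p || q) is 0 when p equals q and grows as they diverge.
private static void klDivergenceExample() {
  Counter<String> p = new ClassicCounter<>();
  p.setCount("a", 0.9);
  p.setCount("b", 0.1);
  Counter<String> q = new ClassicCounter<>();
  q.setCount("a", 0.5);
  q.setCount("b", 0.5);
  System.out.println(Counters.klDivergence(p, p)); // 0.0
  System.out.println(Counters.klDivergence(p, q)); // > 0.0
}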
public L classOf(RVFDatum<L, F> example) {
  Counter<L> scores = scoresOf(example);
  return Counters.argmax(scores);
}
public List<String> selectKeys(ActiveLearningSelectionCriterion criterion) {
  Counter<String> weights = uncertainty(criterion);
  return Counters.toSortedList(weights);
}
/**
 * The core implementation of the search.
 *
 * @param root The root word to search from. Traditionally, this is the root of the sentence.
 * @param candidateFragments The callback for the resulting sentence fragments. This is a
 *     predicate of a triple of values. The return value of the predicate determines whether we
 *     should continue searching. The triple is a triple of
 *     <ol>
 *       <li>The log probability of the sentence fragment, according to the featurizer and the
 *           weights
 *       <li>The features along the path to this fragment. The last element of this is the
 *           features from the most recent step.
 *       <li>The sentence fragment. Because it is relatively expensive to compute the resulting
 *           tree, this is returned as a lazy {@link Supplier}.
 *     </ol>
 *
 * @param classifier The classifier for whether an arc should be on the path to a clause split, a
 *     clause split itself, or neither.
 * @param hardCodedSplits A map from dependency relations to a forced order of actions to apply
 *     when splitting on that relation.
 * @param featurizer The featurizer to use. Make sure this matches the weights!
 * @param actionSpace The action space we are allowed to take. Each action defines a means of
 *     splitting a clause on a dependency boundary.
 * @param maxTicks The maximum number of search ticks to take before timing out.
 */
protected void search(
    // The root to search from
    IndexedWord root,
    // The output specs
    final Predicate<Triple<Double, List<Counter<String>>, Supplier<SentenceFragment>>>
        candidateFragments,
    // The learning specs
    final Classifier<ClauseSplitter.ClauseClassifierLabel, String> classifier,
    Map<String, ? extends List<String>> hardCodedSplits,
    final Function<Triple<State, Action, State>, Counter<String>> featurizer,
    final Collection<Action> actionSpace,
    final int maxTicks) {
  // (the fringe)
  PriorityQueue<Pair<State, List<Counter<String>>>> fringe = new FixedPrioritiesPriorityQueue<>();
  // (avoid duplicate work)
  Set<IndexedWord> seenWords = new HashSet<>();

  State firstState =
      new State(null, null, -9000, null, x -> {}, true); // First state is implicitly "done"
  fringe.add(Pair.makePair(firstState, new ArrayList<>(0)), -0.0);
  int ticks = 0;

  while (!fringe.isEmpty()) {
    if (++ticks > maxTicks) {
      // System.err.println("WARNING! Timed out on search with " + ticks + " ticks");
      return;
    }
    // Useful variables
    double logProbSoFar = fringe.getPriority();
    assert logProbSoFar <= 0.0;
    Pair<State, List<Counter<String>>> lastStatePair = fringe.removeFirst();
    State lastState = lastStatePair.first;
    List<Counter<String>> featuresSoFar = lastStatePair.second;
    IndexedWord rootWord = lastState.edge == null ? root : lastState.edge.getDependent();

    // Register thunk
    if (lastState.isDone) {
      if (!candidateFragments.test(
          Triple.makeTriple(
              logProbSoFar,
              featuresSoFar,
              () -> {
                SemanticGraph copy = new SemanticGraph(tree);
                lastState
                    .thunk
                    .andThen(
                        x -> {
                          // Add the extra edges back in, if they don't break the tree-ness of the
                          // extraction
                          for (IndexedWord newTreeRoot : x.getRoots()) {
                            if (newTreeRoot != null) { // what a strange thing to have happen...
                              for (SemanticGraphEdge extraEdge :
                                  extraEdgesByGovernor.get(newTreeRoot)) {
                                assert Util.isTree(x);
                                //noinspection unchecked
                                addSubtree(
                                    x,
                                    newTreeRoot,
                                    extraEdge.getRelation().toString(),
                                    tree,
                                    extraEdge.getDependent(),
                                    tree.getIncomingEdgesSorted(newTreeRoot));
                                assert Util.isTree(x);
                              }
                            }
                          }
                        })
                    .accept(copy);
                return new SentenceFragment(copy, assumedTruth, false);
              }))) {
        break;
      }
    }

    // Find relevant auxiliary terms
    SemanticGraphEdge subjOrNull = null;
    SemanticGraphEdge objOrNull = null;
    for (SemanticGraphEdge auxEdge : tree.outgoingEdgeIterable(rootWord)) {
      String relString = auxEdge.getRelation().toString();
      if (relString.contains("obj")) {
        objOrNull = auxEdge;
      } else if (relString.contains("subj")) {
        subjOrNull = auxEdge;
      }
    }

    // Iterate over children
    // For each outgoing edge...
    for (SemanticGraphEdge outgoingEdge : tree.outgoingEdgeIterable(rootWord)) {
      // Prohibit indirect speech verbs from splitting off clauses
      // (e.g., 'said', 'think')
      // This fires if the governor is an indirect speech verb, and the outgoing edge is a ccomp
      if (outgoingEdge.getRelation().toString().equals("ccomp")
          && ((outgoingEdge.getGovernor().lemma() != null
                  && INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().lemma()))
              || INDIRECT_SPEECH_LEMMAS.contains(outgoingEdge.getGovernor().word()))) {
        continue;
      }
      // Get some variables
      String outgoingEdgeRelation = outgoingEdge.getRelation().toString();
      List<String> forcedArcOrder = hardCodedSplits.get(outgoingEdgeRelation);
      if (forcedArcOrder == null && outgoingEdgeRelation.contains(":")) {
        forcedArcOrder =
            hardCodedSplits.get(
                outgoingEdgeRelation.substring(0, outgoingEdgeRelation.indexOf(":")) + ":*");
      }
      boolean doneForcedArc = false;
      // For each action...
      for (Action action :
          (forcedArcOrder == null ? actionSpace : orderActions(actionSpace, forcedArcOrder))) {
        // Check the prerequisite
        if (!action.prerequisitesMet(tree, outgoingEdge)) {
          continue;
        }
        if (forcedArcOrder != null && doneForcedArc) {
          break;
        }
        // 1. Compute the child state
        Optional<State> candidate =
            action.applyTo(tree, lastState, outgoingEdge, subjOrNull, objOrNull);
        if (candidate.isPresent()) {
          double logProbability;
          ClauseClassifierLabel bestLabel;
          Counter<String> features =
              featurizer.apply(Triple.makeTriple(lastState, action, candidate.get()));
          if (forcedArcOrder != null && !doneForcedArc) {
            logProbability = 0.0;
            bestLabel = ClauseClassifierLabel.CLAUSE_SPLIT;
            doneForcedArc = true;
          } else if (features.containsKey("__undocumented_junit_no_classifier")) {
            logProbability = Double.NEGATIVE_INFINITY;
            bestLabel = ClauseClassifierLabel.CLAUSE_INTERM;
          } else {
            Counter<ClauseClassifierLabel> scores = classifier.scoresOf(new RVFDatum<>(features));
            if (scores.size() > 0) {
              Counters.logNormalizeInPlace(scores);
            }
            String rel = outgoingEdge.getRelation().toString();
            if ("nsubj".equals(rel) || "dobj".equals(rel)) {
              // Always at least yield on nsubj and dobj
              scores.remove(ClauseClassifierLabel.NOT_A_CLAUSE);
            }
            logProbability = Counters.max(scores, Double.NEGATIVE_INFINITY);
            bestLabel = Counters.argmax(scores, (x, y) -> 0, ClauseClassifierLabel.CLAUSE_SPLIT);
          }
          if (bestLabel != ClauseClassifierLabel.NOT_A_CLAUSE) {
            Pair<State, List<Counter<String>>> childState =
                Pair.makePair(
                    candidate.get().withIsDone(bestLabel),
                    new ArrayList<Counter<String>>(featuresSoFar) {
                      {
                        add(features);
                      }
                    });
            // 2. Register the child state
            if (!seenWords.contains(childState.first.edge.getDependent())) {
              // System.err.println("  pushing " + action.signature() + " with " +
              //     argmax.first.edge);
              fringe.add(childState, logProbability);
            }
          }
        }
      }
    }

    seenWords.add(rootWord);
  }
  // System.err.println("Search finished in " + ticks + " ticks and " + classifierEvals +
  //     " classifier evaluations.");
}
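// A minimal sketch of the best-first fringe behavior that drives the search above:
// FixedPrioritiesPriorityQueue pops the highest-priority element first, so states with higher
// log probability are explored before less likely ones. The states and scores are invented.
// (Assumes edu.stanford.nlp.util.FixedPrioritiesPriorityQueue is imported.)
private static void fringeExample() {
  FixedPrioritiesPriorityQueue<String> fringe = new FixedPrioritiesPriorityQueue<>();
  fringe.add("likely split", -0.1); // log probabilities are <= 0
  fringe.add("unlikely split", -3.0);
  System.out.println(fringe.getPriority()); // -0.1, the priority of the head
  System.out.println(fringe.removeFirst()); // likely split
}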
@Override
public Counter<E> score() {
  Counter<E> currentPatternWeights4Label = new ClassicCounter<>();

  Counter<E> pos_i = new ClassicCounter<>();
  Counter<E> neg_i = new ClassicCounter<>();
  Counter<E> unlab_i = new ClassicCounter<>();

  for (Entry<E, ClassicCounter<CandidatePhrase>> en : negPatternsandWords4Label.entrySet()) {
    neg_i.setCount(en.getKey(), en.getValue().size());
  }

  for (Entry<E, ClassicCounter<CandidatePhrase>> en : unLabeledPatternsandWords4Label.entrySet()) {
    unlab_i.setCount(en.getKey(), en.getValue().size());
  }

  for (Entry<E, ClassicCounter<CandidatePhrase>> en : patternsandWords4Label.entrySet()) {
    pos_i.setCount(en.getKey(), en.getValue().size());
  }

  Counter<E> all_i = Counters.add(pos_i, neg_i);
  all_i.addAll(unlab_i);
  // for (Entry<Integer, ClassicCounter<String>> en : allPatternsandWords4Label
  //     .entrySet()) {
  //   all_i.setCount(en.getKey(), en.getValue().size());
  // }

  Counter<E> posneg_i = Counters.add(pos_i, neg_i);

  Counter<E> logFi = new ClassicCounter<>(pos_i);
  Counters.logInPlace(logFi);

  if (patternScoring.equals(PatternScoring.RlogF)) {
    currentPatternWeights4Label = Counters.product(Counters.division(pos_i, all_i), logFi);
  } else if (patternScoring.equals(PatternScoring.RlogFPosNeg)) {
    Redwood.log("extremePatDebug", "computing rlogfposneg");
    currentPatternWeights4Label = Counters.product(Counters.division(pos_i, posneg_i), logFi);
  } else if (patternScoring.equals(PatternScoring.RlogFUnlabNeg)) {
    Redwood.log("extremePatDebug", "computing rlogfunlabneg");
    currentPatternWeights4Label =
        Counters.product(Counters.division(pos_i, Counters.add(neg_i, unlab_i)), logFi);
  } else if (patternScoring.equals(PatternScoring.RlogFNeg)) {
    Redwood.log("extremePatDebug", "computing rlogfneg");
    currentPatternWeights4Label = Counters.product(Counters.division(pos_i, neg_i), logFi);
  } else if (patternScoring.equals(PatternScoring.YanGarber02)) {
    Counter<E> acc = Counters.division(pos_i, Counters.add(pos_i, neg_i));
    double thetaPrecision = 0.8;
    Counters.retainAbove(acc, thetaPrecision);
    Counter<E> conf = Counters.product(Counters.division(pos_i, all_i), logFi);
    for (E p : acc.keySet()) {
      currentPatternWeights4Label.setCount(p, conf.getCount(p));
    }
  } else if (patternScoring.equals(PatternScoring.LinICML03)) {
    Counter<E> acc = Counters.division(pos_i, Counters.add(pos_i, neg_i));
    double thetaPrecision = 0.8;
    Counters.retainAbove(acc, thetaPrecision);
    Counter<E> conf =
        Counters.product(
            Counters.division(Counters.add(pos_i, Counters.scale(neg_i, -1)), all_i), logFi);
    for (E p : acc.keySet()) {
      currentPatternWeights4Label.setCount(p, conf.getCount(p));
    }
  } else {
    throw new RuntimeException("not implemented " + patternScoring + " . check spelling!");
  }
  return currentPatternWeights4Label;
}
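// A minimal sketch of the RlogF pattern score above on invented counts: a pattern that matched
// 8 positive phrases out of 10 total scores (8/10) * log(8). Counters.division, Counters.product,
// and Counters.logInPlace are the same helpers score() uses.
private static void rlogFExample() {
  Counter<String> pos = new ClassicCounter<>();
  pos.setCount("X was born in Y", 8.0);
  Counter<String> all = new ClassicCounter<>();
  all.setCount("X was born in Y", 10.0);
  Counter<String> logF = new ClassicCounter<>(pos);
  Counters.logInPlace(logF);
  Counter<String> rlogF = Counters.product(Counters.division(pos, all), logF);
  System.out.println(rlogF.getCount("X was born in Y")); // 0.8 * ln(8) ~= 1.66
}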
public List<Pair<String, Double>> selectWeightedKeysWithSampling(
    ActiveLearningSelectionCriterion criterion, int numSamples, int seed) {
  List<Pair<String, Double>> result = new ArrayList<>();
  forceTrack("Sampling Keys");
  log("" + numSamples + " to collect");

  // Get uncertainty
  forceTrack("Computing Uncertainties");
  Counter<String> weightCounter = uncertainty(criterion);
  assert weightCounter.equals(uncertainty(criterion));
  endTrack("Computing Uncertainties");

  // Compute some statistics
  startTrack("Uncertainty Histogram");
  // log(new Histogram(weightCounter, 50).toString()); // removed to make the release easier
  // (Histogram isn't in CoreNLP)
  endTrack("Uncertainty Histogram");
  double totalCount = weightCounter.totalCount();
  Random random = new Random(seed);

  // Flatten counter
  List<String> keys = new LinkedList<>();
  List<Double> weights = new LinkedList<>();
  List<String> zeroUncertaintyKeys = new LinkedList<>();
  for (Pair<String, Double> elem :
      Counters.toSortedListWithCounts(
          weightCounter,
          (o1, o2) -> {
            int value = o1.compareTo(o2);
            if (value == 0) {
              return o1.first.compareTo(o2.first);
            } else {
              return value;
            }
          })) {
    if (elem.second != 0.0
        || weightCounter.totalCount() == 0.0
        || weightCounter.size() <= numSamples) { // ignore 0 probability weights
      keys.add(elem.first);
      weights.add(elem.second);
    } else {
      zeroUncertaintyKeys.add(elem.first);
    }
  }

  // Error check
  if (Utils.assertionsEnabled()) {
    for (Double elem : weights) {
      if (!(elem >= 0 && !Double.isInfinite(elem) && !Double.isNaN(elem))) {
        throw new IllegalArgumentException("Invalid weight: " + elem);
      }
    }
  }

  // Sample
  SAMPLE_ITER:
  for (int i = 1; i <= numSamples; ++i) { // For each sample
    if (i % 1000 == 0) {
      // Debug log
      log("sampled " + (i / 1000) + "k keys");
      // Recompute total count to mitigate floating point errors
      totalCount = 0.0;
      for (double val : weights) {
        totalCount += val;
      }
    }
    if (weights.size() == 0) {
      continue;
    }
    assert totalCount >= 0.0;
    assert weights.size() == keys.size();
    double target = random.nextDouble() * totalCount;
    Iterator<String> keyIter = keys.iterator();
    Iterator<Double> weightIter = weights.iterator();
    double runningTotal = 0.0;
    while (keyIter.hasNext()) { // For each candidate
      String key = keyIter.next();
      double weight = weightIter.next();
      runningTotal += weight;
      if (target <= runningTotal) { // Select that sample
        result.add(Pair.makePair(key, weight));
        keyIter.remove();
        weightIter.remove();
        totalCount -= weight;
        continue SAMPLE_ITER; // continue sampling
      }
    }
    // We should get here only if the keys list is empty
    warn(
        "No more uncertain samples left to draw from! (target="
            + target
            + " totalCount="
            + totalCount
            + " size="
            + keys.size());
    assert keys.size() == 0;
    if (zeroUncertaintyKeys.size() > 0) {
      result.add(Pair.makePair(zeroUncertaintyKeys.remove(0), 0.0));
    } else {
      break;
    }
  }

  endTrack("Sampling Keys");
  return result;
}
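// A minimal, self-contained sketch of the weighted sampling-without-replacement loop above:
// draw a target uniformly in [0, totalWeight), walk the cumulative weights until the target is
// covered, and remove the selected key so it cannot be drawn again. The keys and weights are
// invented. (Assumes java.util.{ArrayList, Arrays, List, Random} are imported.)
private static void weightedSamplingExample() {
  List<String> keys = new ArrayList<>(Arrays.asList("low", "medium", "high"));
  List<Double> weights = new ArrayList<>(Arrays.asList(1.0, 3.0, 6.0));
  Random random = new Random(42);
  double totalWeight = 10.0;

  double target = random.nextDouble() * totalWeight;
  double runningTotal = 0.0;
  for (int j = 0; j < keys.size(); j++) {
    runningTotal += weights.get(j);
    if (target <= runningTotal) {
      // "high" is selected with probability 0.6, since it carries 6 of the 10 weight units.
      System.out.println("sampled " + keys.get(j));
      keys.remove(j); // sample without replacement
      weights.remove(j);
      break;
    }
  }
}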
/** Print some statistics about this lexicon. */
public void printLexStats() {
  System.out.println("BaseLexicon statistics");
  System.out.println("unknownLevel is " + getUnknownWordModel().getUnknownLevel());
  // System.out.println("Rules size: " + rules.size());
  System.out.println("Sum of rulesWithWord: " + numRules());
  System.out.println("Tags size: " + tags.size());
  int wsize = words.size();
  System.out.println("Words size: " + wsize);
  // System.out.println("Unseen Sigs size: " + sigs.size() +
  //     " [number of unknown equivalence classes]");
  System.out.println(
      "rulesWithWord length: " + rulesWithWord.length + " [should be sum of words + unknown sigs]");

  int[] lengths = new int[STATS_BINS];
  ArrayList<String>[] wArr = new ArrayList[STATS_BINS];
  for (int j = 0; j < STATS_BINS; j++) {
    wArr[j] = new ArrayList<String>();
  }
  for (int i = 0; i < rulesWithWord.length; i++) {
    int num = rulesWithWord[i].size();
    if (num > STATS_BINS - 1) {
      num = STATS_BINS - 1;
    }
    lengths[num]++;
    if (wsize <= 20 || num >= STATS_BINS / 2) {
      wArr[num].add(wordIndex.get(i));
    }
  }
  System.out.println("Stats on how many taggings for how many words");
  for (int j = 0; j < STATS_BINS; j++) {
    System.out.print(j + " taggings: " + lengths[j] + " words ");
    if (wsize <= 20 || j >= STATS_BINS / 2) {
      System.out.print(wArr[j]);
    }
    System.out.println();
  }

  NumberFormat nf = NumberFormat.getNumberInstance();
  nf.setMaximumFractionDigits(0);
  System.out.println("Unseen counter: " + Counters.toString(uwModel.unSeenCounter(), nf));

  if (wsize < 50 && tags.size() < 10) {
    nf.setMaximumFractionDigits(3);
    StringWriter sw = new StringWriter();
    PrintWriter pw = new PrintWriter(sw);
    pw.println("Tagging probabilities log P(word|tag)");
    for (int t = 0; t < tags.size(); t++) {
      pw.print('\t');
      pw.print(tagIndex.get(t));
    }
    pw.println();
    for (int w = 0; w < wsize; w++) {
      pw.print(wordIndex.get(w));
      pw.print('\t');
      for (int t = 0; t < tags.size(); t++) {
        IntTaggedWord iTW = new IntTaggedWord(w, t);
        pw.print(nf.format(score(iTW, 1, wordIndex.get(w))));
        if (t == tags.size() - 1) {
          pw.println();
        } else {
          pw.print('\t');
        }
      }
    }
    pw.close();
    System.out.println(sw.toString());
  }
}
public List<Pair<String, Double>> selectWeightedKeys(ActiveLearningSelectionCriterion criterion) {
  Counter<String> weights = uncertainty(criterion);
  return Counters.toSortedListWithCounts(weights);
}