private NaiveBayesClassifier<L, F> trainClassifier(
    int[][] data, int[] labels, int numFeatures, int numClasses,
    Index<L> labelIndex, Index<F> featureIndex) {
  Set<L> labelSet = Generics.newHashSet();
  NBWeights nbWeights = trainWeights(data, labels, numFeatures, numClasses);

  // Convert the estimated class priors into a Counter keyed by label.
  Counter<L> priors = new ClassicCounter<L>();
  double[] pr = nbWeights.priors;
  for (int i = 0; i < pr.length; i++) {
    priors.incrementCount(labelIndex.get(i), pr[i]);
    labelSet.add(labelIndex.get(i));
  }

  // Convert the per-(class, feature, value) weights into a Counter keyed by
  // ((label, feature), value).
  Counter<Pair<Pair<L, F>, Number>> weightsCounter =
      new ClassicCounter<Pair<Pair<L, F>, Number>>();
  double[][][] wts = nbWeights.weights;
  for (int c = 0; c < numClasses; c++) {
    L label = labelIndex.get(c);
    for (int f = 0; f < numFeatures; f++) {
      F feature = featureIndex.get(f);
      Pair<L, F> p = new Pair<L, F>(label, feature);
      for (int val = 0; val < wts[c][f].length; val++) {
        Pair<Pair<L, F>, Number> key =
            new Pair<Pair<L, F>, Number>(p, Integer.valueOf(val));
        weightsCounter.incrementCount(key, wts[c][f][val]);
      }
    }
  }
  return new NaiveBayesClassifier<L, F>(weightsCounter, priors, labelSet);
}
/**
 * Returns the set of all modes in the Collection. (If the Collection has multiple
 * items tied for the highest frequency, all of them are returned.)
 */
public static <T> Set<T> modes(Collection<T> values) {
  Counter<T> counter = new ClassicCounter<T>(values);
  List<Double> sortedCounts = CollectionUtils.sorted(counter.values());
  Double highestCount = sortedCounts.get(sortedCounts.size() - 1);
  Counters.retainAbove(counter, highestCount);
  return counter.keySet();
}
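// A minimal usage sketch for modes(), assuming java.util.Arrays is imported and
// the method above is in scope. With two items tied for the highest frequency,
// both are returned. The list contents are made up for illustration.
public static void modesExample() {
  Set<String> result = modes(Arrays.asList("a", "b", "b", "c", "c"));
  System.out.println(result); // contains "b" and "c": both occur twice, the highest frequency
}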
/**
 * The examples are assumed to be a list of RVFDatum. The datums are assumed to
 * contain the zero-valued features explicitly, so every datum has the same
 * feature set.
 */
@Override
@Deprecated
public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
  RVFDatum<L, F> d0 = examples.get(0);
  int numFeatures = d0.asFeatures().size();
  int[][] data = new int[examples.size()][numFeatures];
  int[] labels = new int[examples.size()];
  labelIndex = new HashIndex<L>();
  featureIndex = new HashIndex<F>();
  for (int d = 0; d < examples.size(); d++) {
    RVFDatum<L, F> datum = examples.get(d);
    Counter<F> c = datum.asFeaturesCounter();
    for (F feature : c.keySet()) {
      // Index the feature and record its value for every datum, not only the
      // datum that first introduced it. (Guarding on Index.add(), which returns
      // false for already-indexed features, would leave all later rows zero.)
      featureIndex.add(feature);
      int fNo = featureIndex.indexOf(feature);
      data[d][fNo] = (int) c.getCount(feature);
    }
    labelIndex.add(datum.label());
    labels[d] = labelIndex.indexOf(datum.label());
  }
  int numClasses = labelIndex.size();
  return trainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex);
}
/**
 * Returns a list of features thresholded by minPrecision and sorted by their
 * frequency of occurrence. Precision, in this case, is defined as the frequency
 * of the majority label divided by the total frequency of that feature. For
 * example, a feature occurring 8 times with its majority label out of 10 total
 * occurrences has precision 0.8.
 *
 * @return list of high-precision features.
 */
private List<F> getHighPrecisionFeatures(
    GeneralDataset<L, F> dataset, double minPrecision, int maxNumFeatures) {
  // feature2label[f][c] counts how often feature f occurs with class c.
  int[][] feature2label = new int[dataset.numFeatures()][dataset.numClasses()];
  for (int f = 0; f < dataset.numFeatures(); f++) {
    Arrays.fill(feature2label[f], 0);
  }
  int[][] data = dataset.data;
  int[] labels = dataset.labels;
  for (int d = 0; d < data.length; d++) {
    int label = labels[d];
    // System.out.println("datum id:" + d + " label id: " + label);
    if (data[d] != null) {
      // System.out.println(" number of features:" + data[d].length);
      for (int n = 0; n < data[d].length; n++) {
        feature2label[data[d][n]][label]++;
      }
    }
  }
  Counter<F> feature2freq = new ClassicCounter<F>();
  for (int f = 0; f < dataset.numFeatures(); f++) {
    int maxF = ArrayMath.max(feature2label[f]);
    int total = ArrayMath.sum(feature2label[f]);
    double precision = ((double) maxF) / total;
    F feature = dataset.featureIndex.get(f);
    if (precision >= minPrecision) {
      feature2freq.incrementCount(feature, total);
    }
  }
  if (feature2freq.size() > maxNumFeatures) {
    Counters.retainTop(feature2freq, maxNumFeatures);
  }
  // for (F feature : feature2freq.keySet())
  //   System.out.println(feature + " " + feature2freq.getCount(feature));
  // System.exit(0);
  return Counters.toSortedList(feature2freq);
}
public static Counter<String> getWeightCounterFromArray(String[] weightNames, double[] wtsArr) {
  Counter<String> wts = new ClassicCounter<String>();
  for (int i = 0; i < weightNames.length; i++) {
    wts.setCount(weightNames[i], wtsArr[i]);
  }
  return wts;
}
/**
 * Add a scaled (positive) random vector to a weights vector, in place.
 *
 * @param wts the weights Counter to perturb
 * @param scale upper bound on the random increment added to each weight
 */
public static void randomizeWeightsInPlace(Counter<String> wts, double scale) {
  for (String feature : wts.keySet()) {
    double epsilon = Math.random() * scale;
    double newValue = wts.getCount(feature) + epsilon;
    wts.setCount(feature, newValue);
  }
}
/**
 * Profile notes: 67% of time spent in LogConditionalObjectiveFunction.rvfcalculate(),
 * 29% of time spent in dataset construction (11% in RVFDataset.addFeatures(),
 * 7% in RVF incrementCount(), 11% in the rest).
 *
 * <p>Single-threaded: 4700 ms. Multi-threaded: 700 ms.
 *
 * <p>With the same data, seed 42: 245 ms. With reordered accesses for caching:
 * 195 ms. Down to 80% of the time; not huge, but a win nonetheless.
 *
 * <p>With 8 CPUs, a 6.7x speedup -- almost, but not quite, linear. Pretty good.
 */
public static void benchmarkRVFLogisticRegression() {
  RVFDataset<String, String> data = new RVFDataset<>();
  // Seed fixed for reproducibility; created once, outside the loop, so that the
  // examples differ from each other (re-seeding each iteration would give every
  // datum the same random sequence and the same label).
  Random r = new Random(42);
  for (int i = 0; i < 10000; i++) {
    Counter<String> features = new ClassicCounter<>();
    boolean cl = r.nextBoolean();
    for (int j = 0; j < 1000; j++) {
      double value;
      if (cl && i % 2 == 0) {
        value = (r.nextDouble() * 2.0) - 0.6;
      } else {
        value = (r.nextDouble() * 2.0) - 1.4;
      }
      features.incrementCount("f" + j, value);
    }
    data.add(new RVFDatum<>(features, "target:" + cl));
  }
  LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();
  long msStart = System.currentTimeMillis();
  factory.trainClassifier(data);
  long delay = System.currentTimeMillis() - msStart;
  System.out.println("Training took " + delay + " ms");
}
public Counter<K> uncompress(CompressedFeatureVector cvf) {
  Counter<K> c = new ClassicCounter<>();
  for (int i = 0; i < cvf.keys.size(); i++) {
    c.incrementCount(inverse.get(cvf.keys.get(i)), cvf.values.get(i));
  }
  return c;
}
@SuppressWarnings({"unchecked"}) @Override protected void fillFeatures( Pair<Mention, ClusteredMention> input, Counter<Feature> inFeatures, Boolean output, Counter<Feature> outFeatures) { // --Input Features for (Object o : ACTIVE_FEATURES) { if (o instanceof Class) { // (case: singleton feature) Option<Double> count = new Option<Double>(1.0); Feature feat = feature((Class) o, input, count); if (count.get() > 0.0) { inFeatures.incrementCount(feat, count.get()); } } else if (o instanceof Pair) { // (case: pair of features) Pair<Class, Class> pair = (Pair<Class, Class>) o; Option<Double> countA = new Option<Double>(1.0); Option<Double> countB = new Option<Double>(1.0); Feature featA = feature(pair.getFirst(), input, countA); Feature featB = feature(pair.getSecond(), input, countB); if (countA.get() * countB.get() > 0.0) { inFeatures.incrementCount( new Feature.PairFeature(featA, featB), countA.get() * countB.get()); } } } // --Output Features if (output != null) { outFeatures.incrementCount(new Feature.CoreferentIndicator(output), 1.0); } }
public static <T> Counter<T> featureValueCollectionToCounter(Collection<FeatureValue<T>> c) {
  Counter<T> counter = new ClassicCounter<T>();
  for (FeatureValue<T> fv : c) {
    counter.incrementCount(fv.name, fv.value);
  }
  return counter;
}
private double[] getModelProbs(Datum<L, F> datum) {
  double[] condDist = new double[labeledDataset.numClasses()];
  Counter<L> probCounter = classifier.probabilityOf(datum);
  for (L label : probCounter.keySet()) {
    int labelID = labeledDataset.labelIndex.indexOf(label);
    condDist[labelID] = probCounter.getCount(label);
  }
  return condDist;
}
@Test
public void tokenMatch() {
  String[] text = new String[] {"what", "tv", "program", "have", "hugh", "laurie", "create"};
  String[] pattern = new String[] {"program", "create"};
  Counter<String> match =
      TokenLevelMatchFeatures.extractTokenMatchFeatures(
          Arrays.asList(text), Arrays.asList(pattern), true);
  assertEquals(0.5, match.getCount("prefix"), 0.00001);
  assertEquals(0.5, match.getCount("suffix"), 0.00001);
}
private Counter<String> uniformRandom() {
  Counter<String> uniformRandom =
      new ClassicCounter<>(MapFactory.<String, MutableDouble>linkedHashMapFactory());
  for (Map<SentenceKey, EnsembleStatistics> impl : this.impl) {
    for (Map.Entry<SentenceKey, EnsembleStatistics> entry : impl.entrySet()) {
      uniformRandom.setCount(entry.getKey().sentenceHash, 1.0);
    }
  }
  return uniformRandom;
}
public static List<String> generateDict(List<String> str, int cutOff) {
  Counter<String> freq = new IntCounter<>();
  for (String aStr : str) {
    freq.incrementCount(aStr);
  }
  // Keep words whose frequency meets the cutoff, in frequency-sorted order.
  List<String> keys = Counters.toSortedList(freq, false);
  List<String> dict = new ArrayList<>();
  for (String word : keys) {
    if (freq.getCount(word) >= cutOff) {
      dict.add(word);
    }
  }
  return dict;
}
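// A minimal usage sketch for generateDict(), assuming java.util.Arrays is
// imported: words occurring fewer than cutOff times are excluded from the
// dictionary. The token list is made up for illustration.
public static void generateDictExample() {
  List<String> tokens = Arrays.asList("the", "the", "the", "cat", "cat", "sat");
  List<String> dict = generateDict(tokens, 2);
  System.out.println(dict); // [the, cat]; "sat" falls below the cutoff of 2
}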
private Counter<String> highKLFromMean() {
  // Get the average KL divergence from the mean for each sentence
  Counter<String> kl =
      new ClassicCounter<>(MapFactory.<String, MutableDouble>linkedHashMapFactory());
  for (Map<SentenceKey, EnsembleStatistics> impl : this.impl) {
    for (Map.Entry<SentenceKey, EnsembleStatistics> entry : impl.entrySet()) {
      kl.setCount(entry.getKey().sentenceHash, entry.getValue().averageKLFromMean());
    }
  }
  return kl;
}
public void addCrossNumericProximity(Counter<String> features, final NumericMentionExpression e) {
  List<NumericTuple> args = e.expression.arguments();
  if (args.size() > 1) {
    double multiplier = args.get(0).val / args.get(1).val;
    features.incrementCount("abs-numeric-distance-12", Math.abs(Math.log(multiplier)));
  }
  if (args.size() > 2) {
    double multiplier13 = args.get(0).val / args.get(2).val;
    double multiplier23 = args.get(1).val / args.get(2).val;
    features.incrementCount("abs-numeric-distance-13", Math.abs(Math.log(multiplier13)));
    features.incrementCount("abs-numeric-distance-23", Math.abs(Math.log(multiplier23)));
  }
}
/**
 * Guesses an NER tag for a span of tokens by majority vote: returns the most
 * frequent non-"O" NER tag among the tokens, provided it covers at least half
 * of the span; otherwise returns "O".
 *
 * @param tokens the tokens of the sentence
 * @param span the span of tokens to guess an NER tag for
 * @return the guessed NER tag, or "O" if no tag is frequent enough
 */
public static String guessNER(List<CoreLabel> tokens, Span span) {
  Counter<String> nerGuesses = new ClassicCounter<>();
  for (int i : span) {
    nerGuesses.incrementCount(tokens.get(i).ner());
  }
  nerGuesses.remove("O");
  nerGuesses.remove(null);
  // Note: span.size() / 2 uses integer division, so "half" is rounded down.
  if (nerGuesses.size() > 0 && Counters.max(nerGuesses) >= span.size() / 2) {
    return Counters.argmax(nerGuesses);
  } else {
    return "O";
  }
}
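// A hedged usage sketch for guessNER() with hypothetical tokens. The Span
// construction (Span.fromValues) is assumed from the surrounding codebase and
// may differ; the NER tags here are made up for illustration.
public static void guessNERExample() {
  List<CoreLabel> tokens = new ArrayList<>();
  for (String ner : new String[] {"PERSON", "PERSON", "O"}) {
    CoreLabel token = new CoreLabel();
    token.setNER(ner);
    tokens.add(token);
  }
  // Two of the three tokens are PERSON, which covers at least half the span.
  System.out.println(guessNER(tokens, Span.fromValues(0, 3))); // PERSON
}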
@Override
public Counter<L> scoresOf(Datum<L, F> example) {
  Counter<L> scores = new ClassicCounter<>();
  for (L label : labelIndex) {
    // Score each label with its one-vs-all binary classifier, mapping the datum
    // to a binary (positive vs. negative) version for that label.
    Map<L, String> posLabelMap = new ArrayMap<>();
    posLabelMap.put(label, POS_LABEL);
    Datum<String, F> binDatum = GeneralDataset.mapDatum(example, posLabelMap, NEG_LABEL);
    Classifier<String, F> binaryClassifier = getBinaryClassifier(label);
    Counter<String> binScores = binaryClassifier.scoresOf(binDatum);
    double score = binScores.getCount(POS_LABEL);
    scores.setCount(label, score);
  }
  return scores;
}
private Counter<String> lowAverageConfidence() {
  // Get confidences
  Counter<String> lowConfidence =
      new ClassicCounter<>(MapFactory.<String, MutableDouble>linkedHashMapFactory());
  for (Map<SentenceKey, EnsembleStatistics> impl : this.impl) {
    for (Map.Entry<SentenceKey, EnsembleStatistics> entry : impl.entrySet()) {
      SentenceStatistics average = entry.getValue().mean();
      for (double confidence : average.confidence) {
        lowConfidence.setCount(entry.getKey().sentenceHash, 1 - confidence);
      }
    }
  }
  return lowConfidence;
}
public Object formResult() {
  Set brs = new HashSet();
  Set urs = new HashSet();
  // scan each rule / history pair
  int ruleCount = 0;
  for (Iterator pairI = rulePairs.keySet().iterator(); pairI.hasNext(); ) {
    if (ruleCount % 100 == 0) {
      System.err.println("Rules multiplied: " + ruleCount);
    }
    ruleCount++;
    Pair rulePair = (Pair) pairI.next();
    Rule baseRule = (Rule) rulePair.first;
    String baseLabel = (String) ruleToLabel.get(baseRule);
    List history = (List) rulePair.second;
    double totalProb = 0;
    for (int depth = 1; depth <= HISTORY_DEPTH() && depth <= history.size(); depth++) {
      List subHistory = history.subList(0, depth);
      double c_label = labelPairs.getCount(new Pair(baseLabel, subHistory));
      double c_rule = rulePairs.getCount(new Pair(baseRule, subHistory));
      // System.out.println("Multiplying out " + baseRule + " with history " + subHistory);
      // System.out.println("Count of " + baseLabel + " with " + subHistory + " is " + c_label);
      // System.out.println("Count of " + baseRule + " with " + subHistory + " is " + c_rule);
      double prob = (1.0 / HISTORY_DEPTH()) * (c_rule) / (c_label);
      totalProb += prob;
      for (int childDepth = 0; childDepth <= Math.min(HISTORY_DEPTH() - 1, depth); childDepth++) {
        Rule rule = specifyRule(baseRule, subHistory, childDepth);
        rule.score = (float) Math.log(totalProb);
        // System.out.println("Created " + rule + " with score " + rule.score);
        if (rule instanceof UnaryRule) {
          urs.add(rule);
        } else {
          brs.add(rule);
        }
      }
    }
  }
  System.out.println("Total states: " + stateNumberer.total());
  BinaryGrammar bg = new BinaryGrammar(stateNumberer.total());
  UnaryGrammar ug = new UnaryGrammar(stateNumberer.total());
  for (Iterator brI = brs.iterator(); brI.hasNext(); ) {
    BinaryRule br = (BinaryRule) brI.next();
    bg.addRule(br);
  }
  for (Iterator urI = urs.iterator(); urI.hasNext(); ) {
    UnaryRule ur = (UnaryRule) urI.next();
    ug.addRule(ur);
  }
  return new Pair(ug, bg);
}
protected void tallyInternalNode(Tree lt, List parents) {
  // form base rule
  String label = lt.label().value();
  Rule baseR = ltToRule(lt);
  ruleToLabel.put(baseR, label);
  // act on each history depth
  for (int depth = 0, maxDepth = Math.min(HISTORY_DEPTH(), parents.size());
      depth <= maxDepth;
      depth++) {
    List history = new ArrayList(parents.subList(0, depth));
    // tally each history level / rewrite pair
    rulePairs.incrementCount(new Pair(baseR, history), 1);
    labelPairs.incrementCount(new Pair(label, history), 1);
  }
}
protected void semanticCrossFeatures(
    Counter<String> features, String prefix, Sentence str1, Sentence str2) {
  double[] vec1 = embeddings.get(str1);
  double[] vec2 = embeddings.get(str2);
  for (int i = 0; i < vec1.length; i++) {
    features.incrementCount(prefix + "wv-" + i, vec1[i] - vec2[i]);
  }
}
protected void semanticCrossFeatures(
    Counter<String> features, String prefix, NumericMention mention, Sentence str2) {
  double[] vec1 = embeddings.get(mention.sentence.get(), mention.token_begin, mention.token_end);
  double[] vec2 = embeddings.get(str2);
  for (int i = 0; i < vec1.length; i++) {
    features.incrementCount(prefix + "wv-" + i, vec1[i] - vec2[i]);
  }
}
public static Set<String> featureWhiteList(FlatNBestList nbest, int minSegmentCount) {
  List<List<ScoredFeaturizedTranslation<IString, String>>> nbestlists = nbest.nbestLists();
  Counter<String> featureSegmentCounts = new ClassicCounter<String>();
  for (List<ScoredFeaturizedTranslation<IString, String>> nbestlist : nbestlists) {
    // Collect the features that fire anywhere in this segment's n-best list...
    Set<String> segmentFeatureSet = new HashSet<String>();
    for (ScoredFeaturizedTranslation<IString, String> trans : nbestlist) {
      for (FeatureValue<String> feature : trans.features) {
        segmentFeatureSet.add(feature.name);
      }
    }
    // ...and count each feature once per segment.
    for (String featureName : segmentFeatureSet) {
      featureSegmentCounts.incrementCount(featureName);
    }
  }
  // Keep only features that fire in sufficiently many segments.
  return Counters.keysAbove(featureSegmentCounts, minSegmentCount - 1);
}
public void addCrossArgumentSemanticSimilarity(
    Counter<String> features, final NumericMentionExpression e) {
  List<NumericTuple> args = e.expression.arguments();
  double sim = 0.;
  double minSim = 0.;
  double maxSim = 0.;
  if (args.size() > 1) {
    sim = semanticSimilarityVW(features, "12-", args.get(0).subjSentence, args.get(1).subjSentence);
    maxSim = Math.max(maxSim, sim);
    minSim = Math.min(minSim, sim);
  }
  if (args.size() > 2) {
    sim = semanticSimilarityVW(features, "13-", args.get(0).subjSentence, args.get(2).subjSentence);
    maxSim = Math.max(maxSim, sim);
    minSim = Math.min(minSim, sim);
    sim = semanticSimilarityVW(features, "23-", args.get(1).subjSentence, args.get(2).subjSentence);
    maxSim = Math.max(maxSim, sim);
    minSim = Math.min(minSim, sim);
  }
  features.incrementCount("max-cross-semantic-similarity", maxSim);
  features.incrementCount("min-cross-semantic-similarity", minSim);
}
public static double[] getWeightArrayFromCounter(String[] weightNames, Counter<String> wts) {
  double[] wtsArr = new double[weightNames.length];
  for (int i = 0; i < wtsArr.length; i++) {
    wtsArr[i] = wts.getCount(weightNames[i]);
  }
  return wtsArr;
}
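// A round-trip sketch: getWeightCounterFromArray() and getWeightArrayFromCounter()
// invert each other over the named weights, which is convenient when shuttling
// weights between optimizer arrays and Counter-based scoring code. This assumes
// both helpers are visible and java.util.Arrays is imported; the weight names
// are made up for illustration.
public static void weightRoundTripExample() {
  String[] names = {"lm", "tm", "wordPenalty"};
  double[] arr = {0.5, 0.3, -0.2};
  Counter<String> asCounter = getWeightCounterFromArray(names, arr);
  double[] back = getWeightArrayFromCounter(names, asCounter);
  System.out.println(Arrays.toString(back)); // [0.5, 0.3, -0.2]
}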
public static double scoreTranslation(
    Counter<String> wts, ScoredFeaturizedTranslation<IString, String> trans) {
  // Dot product of the translation's feature values with the weight vector.
  double s = 0;
  for (FeatureValue<String> fv : trans.features) {
    s += fv.value * wts.getCount(fv.name);
  }
  return s;
}
/**
 * This should be called after the classifier has been trained and parseAndTrain
 * has been called to accumulate the test set.
 *
 * <p>It tallies, per label, the true positives, false positives, false negatives,
 * and actual counts from which precision, recall, and F1 can be computed.
 */
public void runTestSet(List<List<CoreLabel>> testSet) {
  Counter<String> tp = new DefaultCounter<>();
  Counter<String> fp = new DefaultCounter<>();
  Counter<String> fn = new DefaultCounter<>();
  Counter<String> actual = new DefaultCounter<>();
  for (List<CoreLabel> labels : testSet) {
    List<CoreLabel> unannotatedLabels = new ArrayList<>();
    // create new labels without the answer annotation
    for (CoreLabel label : labels) {
      CoreLabel newLabel = new CoreLabel();
      newLabel.set(annotationForWord, label.get(annotationForWord));
      newLabel.set(PartOfSpeechAnnotation.class, label.get(PartOfSpeechAnnotation.class));
      unannotatedLabels.add(newLabel);
    }
    List<CoreLabel> annotatedLabels = this.classifier.classify(unannotatedLabels);
    int ind = 0;
    for (CoreLabel expectedLabel : labels) {
      CoreLabel annotatedLabel = annotatedLabels.get(ind);
      String answer = annotatedLabel.get(AnswerAnnotation.class);
      String expectedAnswer = expectedLabel.get(AnswerAnnotation.class);
      actual.incrementCount(expectedAnswer);
      // match only non-background symbols
      if (!SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL.equals(expectedAnswer)
          && expectedAnswer.equals(answer)) {
        // true positives
        tp.incrementCount(answer);
        System.out.println("True Positive:" + annotatedLabel);
      } else if (!SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL.equals(answer)) {
        // false positives
        fp.incrementCount(answer);
        System.out.println("False Positive:" + annotatedLabel);
      } else if (!SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL.equals(expectedAnswer)) {
        // false negatives
        fn.incrementCount(expectedAnswer);
        System.out.println("False Negative:" + expectedLabel);
      } // else true negatives
      ind++;
    }
  }
  actual.remove(SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL);
}
@Override
public boolean isSimpleSplit(Counter<String> feats) {
  for (String key : feats.keySet()) {
    if (key.startsWith("simple&")) {
      return true;
    }
  }
  return false;
}
@Deprecated
protected void semanticSimilarity(
    Counter<String> features, String prefix, Sentence str1, Sentence str2) {
  // Bag-of-lemmas vectors; both sides lowercased so that tokens (and stopwords) match.
  Counter<String> v1 =
      new ClassicCounter<>(
          str1.lemmas().stream().map(String::toLowerCase).collect(Collectors.toList()));
  Counter<String> v2 =
      new ClassicCounter<>(
          str2.lemmas().stream().map(String::toLowerCase).collect(Collectors.toList()));

  // Remove any stopwords.
  for (String word : stopwords) {
    v1.remove(word);
    v2.remove(word);
  }

  // Take the cosine similarity (inner product over norms), which lies in [0, 1]
  // for nonnegative count vectors.
  double sim =
      Counters.dotProduct(v1, v2) / (Counters.saferL2Norm(v1) * Counters.saferL2Norm(v2));
  features.incrementCount(prefix + "semantic-similarity", 2 * sim - 1); // map [0, 1] to [-1, 1]
}
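// A worked sketch of the similarity computation above, with made-up lemmas,
// assuming java.util.Arrays is imported. Two of the three unit-count lemmas
// overlap, so the cosine similarity is 2/3, and the emitted feature value is
// 2 * (2/3) - 1 = 1/3.
public static void semanticSimilarityExample() {
  Counter<String> v1 = new ClassicCounter<>(Arrays.asList("cat", "sat", "mat"));
  Counter<String> v2 = new ClassicCounter<>(Arrays.asList("cat", "sat", "hat"));
  double sim =
      Counters.dotProduct(v1, v2) / (Counters.saferL2Norm(v1) * Counters.saferL2Norm(v2));
  System.out.println(sim);         // ~0.667
  System.out.println(2 * sim - 1); // ~0.333, the value added as a feature
}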