private NaiveBayesClassifier<L, F> trainClassifier(
    int[][] data,
    int[] labels,
    int numFeatures,
    int numClasses,
    Index<L> labelIndex,
    Index<F> featureIndex) {
  // Estimate model parameters (class priors and per-feature-value weights).
  NBWeights nbWeights = trainWeights(data, labels, numFeatures, numClasses);

  // Copy the prior array into a Counter keyed by label, remembering each label seen.
  Set<L> labelSet = Generics.newHashSet();
  Counter<L> priors = new ClassicCounter<L>();
  double[] priorArray = nbWeights.priors;
  for (int labelNo = 0; labelNo < priorArray.length; labelNo++) {
    L label = labelIndex.get(labelNo);
    priors.incrementCount(label, priorArray[labelNo]);
    labelSet.add(label);
  }

  // Flatten the (class, feature, value) weight cube into a counter keyed by
  // ((label, feature), value), which is the representation NaiveBayesClassifier expects.
  double[][][] weightCube = nbWeights.weights;
  Counter<Pair<Pair<L, F>, Number>> weightsCounter =
      new ClassicCounter<Pair<Pair<L, F>, Number>>();
  for (int c = 0; c < numClasses; c++) {
    for (int f = 0; f < numFeatures; f++) {
      Pair<L, F> labelFeature = new Pair<L, F>(labelIndex.get(c), featureIndex.get(f));
      double[] valueWeights = weightCube[c][f];
      for (int val = 0; val < valueWeights.length; val++) {
        weightsCounter.incrementCount(
            new Pair<Pair<L, F>, Number>(labelFeature, Integer.valueOf(val)), valueWeights[val]);
      }
    }
  }
  return new NaiveBayesClassifier<L, F>(weightsCounter, priors, labelSet);
}
Пример #2
0
 /**
  * Returns the set of all modes in the Collection, i.e. every item whose frequency equals the
  * highest observed frequency. (If the Collection has multiple items tied for the highest
  * frequency, all of them will be returned.)
  *
  * @param values the collection to scan; may be empty
  * @return the set of most frequent items; empty if {@code values} is empty
  */
 public static <T> Set<T> modes(Collection<T> values) {
   Counter<T> counter = new ClassicCounter<T>(values);
   if (counter.size() == 0) {
     // An empty input has no mode; the previous sort-based lookup threw
     // IndexOutOfBoundsException here.
     return java.util.Collections.emptySet();
   }
   // A single max scan suffices; no need to sort every count.
   Counters.retainAbove(counter, Counters.max(counter));
   return counter.keySet();
 }
 /**
  * Trains a classifier from RVF data. The examples are assumed to be a list of RVFDatum. The
  * datums are assumed to contain the zeroes as well, i.e. every datum carries a (possibly zero)
  * count for every feature, so the full feature set is visible in the first datum.
  *
  * @param examples non-empty list of training datums with integer-valued feature counts
  * @return a trained NaiveBayesClassifier over the observed labels and features
  */
 @Override
 @Deprecated
 public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
   RVFDatum<L, F> d0 = examples.get(0);
   int numFeatures = d0.asFeatures().size();
   int[][] data = new int[examples.size()][numFeatures];
   int[] labels = new int[examples.size()];
   labelIndex = new HashIndex<L>();
   featureIndex = new HashIndex<F>();
   for (int d = 0; d < examples.size(); d++) {
     RVFDatum<L, F> datum = examples.get(d);
     Counter<F> c = datum.asFeaturesCounter();
     for (F feature : c.keySet()) {
       // Index the feature (no-op if already known), then ALWAYS record its value.
       // Previously the value was stored only when add() reported a NEW feature, so
       // every datum after the first ended up with an all-zero data row.
       featureIndex.add(feature);
       int fNo = featureIndex.indexOf(feature);
       if (fNo < numFeatures) { // guard: a feature unseen in d0 would overrun the row
         data[d][fNo] = (int) c.getCount(feature);
       }
     }
     labelIndex.add(datum.label());
     labels[d] = labelIndex.indexOf(datum.label());
   }
   int numClasses = labelIndex.size();
   return trainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex);
 }
  /**
   * Returns a list of features thresholded by minPrecision and sorted by their frequency of
   * occurrence. Precision, in this case, is defined as the frequency of the majority label over
   * the total frequency for that feature.
   *
   * @param dataset dataset whose feature/label co-occurrences are tallied
   * @param minPrecision minimum majority-label precision a feature must reach to be kept
   * @param maxNumFeatures upper bound on the number of features returned
   * @return list of high precision features, most frequent first
   */
  private List<F> getHighPrecisionFeatures(
      GeneralDataset<L, F> dataset, double minPrecision, int maxNumFeatures) {
    // feature2label[f][c] counts how often feature f occurs in a datum labeled c.
    // (Java zero-initializes int arrays, so the old explicit Arrays.fill(.., 0) was redundant.)
    int[][] feature2label = new int[dataset.numFeatures()][dataset.numClasses()];

    int[][] data = dataset.data;
    int[] labels = dataset.labels;
    for (int d = 0; d < data.length; d++) {
      int label = labels[d];
      if (data[d] != null) {
        for (int n = 0; n < data[d].length; n++) {
          feature2label[data[d][n]][label]++;
        }
      }
    }

    // Keep features whose majority label covers at least minPrecision of their occurrences,
    // weighted by total frequency.
    Counter<F> feature2freq = new ClassicCounter<F>();
    for (int f = 0; f < dataset.numFeatures(); f++) {
      int maxF = ArrayMath.max(feature2label[f]);
      int total = ArrayMath.sum(feature2label[f]);
      if (total == 0) {
        // Unseen feature: precision is 0/0. The old code relied on NaN >= minPrecision
        // being false; skip explicitly instead.
        continue;
      }
      double precision = ((double) maxF) / total;
      if (precision >= minPrecision) {
        feature2freq.incrementCount(dataset.featureIndex.get(f), total);
      }
    }
    if (feature2freq.size() > maxNumFeatures) {
      Counters.retainTop(feature2freq, maxNumFeatures);
    }
    return Counters.toSortedList(feature2freq);
  }
Пример #5
0
 /** Packs parallel arrays of weight names and weight values into a Counter. */
 public static Counter<String> getWeightCounterFromArray(String[] weightNames, double[] wtsArr) {
   Counter<String> weightCounter = new ClassicCounter<String>();
   int n = weightNames.length;
   for (int idx = 0; idx < n; idx++) {
     weightCounter.setCount(weightNames[idx], wtsArr[idx]);
   }
   return weightCounter;
 }
Пример #6
0
 /**
  * Add a scaled (positive) random vector to a weights vector, perturbing every weight in place
  * by an independent uniform epsilon in [0, scale).
  *
  * @param wts weights counter to perturb in place
  * @param scale exclusive upper bound of the random increment added to each weight
  */
 public static void randomizeWeightsInPlace(Counter<String> wts, double scale) {
   for (String feature : wts.keySet()) {
     // Updating an existing key's count is not a structural modification of the key set,
     // so mutating while iterating is safe here.
     wts.setCount(feature, wts.getCount(feature) + Math.random() * scale);
   }
 }
Пример #7
0
  /**
   * 67% of time spent in LogConditionalObjectiveFunction.rvfcalculate() 29% of time spent in
   * dataset construction (11% in RVFDataset.addFeatures(), 7% rvf incrementCount(), 11% rest)
   *
   * <p>Single threaded, 4700 ms Multi threaded, 700 ms
   *
   * <p>With same data, seed 42, 245 ms With reordered accesses for cacheing, 195 ms Down to 80% of
   * the time, not huge but a win nonetheless
   *
   * <p>with 8 cpus, a 6.7x speedup -- almost, but not quite linear, pretty good
   */
  public static void benchmarkRVFLogisticRegression() {
    RVFDataset<String, String> data = new RVFDataset<>();
    for (int i = 0; i < 10000; i++) {
      // NOTE(review): the Random is re-seeded with 42 on EVERY iteration, so all datums draw
      // the same value sequence (only i % 2 varies the class-conditional shift below). The
      // javadoc's "With same data, seed 42" suggests this is deliberate for a reproducible
      // benchmark -- confirm before changing.
      Random r = new Random(42);
      Counter<String> features = new ClassicCounter<>();

      boolean cl = r.nextBoolean();

      // 1000 dense features per datum; half the datums get an upward-shifted distribution
      // (-0.6) and the rest a downward one (-1.4), giving the classifier signal to learn.
      for (int j = 0; j < 1000; j++) {
        double value;
        if (cl && i % 2 == 0) {
          value = (r.nextDouble() * 2.0) - 0.6;
        } else {
          value = (r.nextDouble() * 2.0) - 1.4;
        }
        features.incrementCount("f" + j, value);
      }

      data.add(new RVFDatum<>(features, "target:" + cl));
    }

    LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();

    // Time only the training call, not dataset construction.
    long msStart = System.currentTimeMillis();
    factory.trainClassifier(data);
    long delay = System.currentTimeMillis() - msStart;
    System.out.println("Training took " + delay + " ms");
  }
Пример #8
0
 /**
  * Expands a compressed sparse feature vector back into a Counter, mapping each stored key id
  * through the inverse index to recover the original key.
  */
 public Counter<K> uncompress(CompressedFeatureVector cvf) {
   Counter<K> counter = new ClassicCounter<>();
   int size = cvf.keys.size();
   for (int idx = 0; idx < size; idx++) {
     K key = inverse.get(cvf.keys.get(idx));
     counter.incrementCount(key, cvf.values.get(idx));
   }
   return counter;
 }
        @SuppressWarnings({"unchecked"})
        @Override
        protected void fillFeatures(
            Pair<Mention, ClusteredMention> input,
            Counter<Feature> inFeatures,
            Boolean output,
            Counter<Feature> outFeatures) {
          // --Input Features
          // ACTIVE_FEATURES mixes two kinds of entries: a bare Class (one feature template)
          // or a Pair of Classes (a conjunction of two templates).
          for (Object template : ACTIVE_FEATURES) {
            if (template instanceof Class) {
              // (case: singleton feature)
              Option<Double> count = new Option<Double>(1.0);
              Feature feat = feature((Class) template, input, count);
              if (count.get() > 0.0) {
                inFeatures.incrementCount(feat, count.get());
              }
            } else if (template instanceof Pair) {
              // (case: pair of features)
              Pair<Class, Class> conjunction = (Pair<Class, Class>) template;
              Option<Double> countA = new Option<Double>(1.0);
              Option<Double> countB = new Option<Double>(1.0);
              Feature featA = feature(conjunction.getFirst(), input, countA);
              Feature featB = feature(conjunction.getSecond(), input, countB);
              double joint = countA.get() * countB.get();
              // Fire the conjunction only when both component features fired.
              if (joint > 0.0) {
                inFeatures.incrementCount(new Feature.PairFeature(featA, featB), joint);
              }
            }
          }

          // --Output Features
          if (output != null) {
            outFeatures.incrementCount(new Feature.CoreferentIndicator(output), 1.0);
          }
        }
Пример #10
0
  /** Converts a collection of FeatureValue entries into a Counter keyed by feature name. */
  public static <T> Counter<T> featureValueCollectionToCounter(Collection<FeatureValue<T>> c) {
    Counter<T> counter = new ClassicCounter<T>();
    for (FeatureValue<T> featureValue : c) {
      counter.incrementCount(featureValue.name, featureValue.value);
    }
    return counter;
  }
 /**
  * Returns the classifier's conditional label distribution for a datum as a dense array indexed
  * by the label ids of labeledDataset.labelIndex. Labels the classifier assigns no mass to keep
  * the default value 0.
  */
 private double[] getModelProbs(Datum<L, F> datum) {
   double[] condDist = new double[labeledDataset.numClasses()];
   Counter<L> probCounter = classifier.probabilityOf(datum);
   for (L label : probCounter.keySet()) {
     condDist[labeledDataset.labelIndex.indexOf(label)] = probCounter.getCount(label);
   }
   return condDist;
 }
Пример #12
0
 /** Checks prefix/suffix token-match features for a pattern against a sentence. */
 @Test
 public void tokenMatch() {
   List<String> text =
       Arrays.asList("what", "tv", "program", "have", "hugh", "laurie", "create");
   List<String> pattern = Arrays.asList("program", "create");
   Counter<String> match = TokenLevelMatchFeatures.extractTokenMatchFeatures(text, pattern, true);
   // Both pattern tokens occur in the text, one at each end of the pattern.
   assertEquals(0.5, match.getCount("prefix"), 0.00001);
   assertEquals(0.5, match.getCount("suffix"), 0.00001);
 }
  /** Assigns every known sentence a uniform weight of 1.0. */
  private Counter<String> uniformRandom() {
    Counter<String> uniformRandom =
        new ClassicCounter<>(MapFactory.<String, MutableDouble>linkedHashMapFactory());
    for (Map<SentenceKey, EnsembleStatistics> ensemble : this.impl) {
      // Only the keys matter here; every sentence gets the same weight.
      for (SentenceKey key : ensemble.keySet()) {
        uniformRandom.setCount(key.sentenceHash, 1.0);
      }
    }
    return uniformRandom;
  }
Пример #14
0
  /**
   * Builds a dictionary of the strings occurring at least cutOff times, ordered by descending
   * frequency.
   */
  public static List<String> generateDict(List<String> str, int cutOff) {
    // Tally raw frequencies.
    Counter<String> freq = new IntCounter<>();
    for (String token : str) {
      freq.incrementCount(token);
    }

    // Walk the vocabulary in frequency order, keeping only sufficiently common words.
    List<String> dict = new ArrayList<>();
    for (String word : Counters.toSortedList(freq, false)) {
      if (freq.getCount(word) >= cutOff) {
        dict.add(word);
      }
    }
    return dict;
  }
  /**
   * Scores each sentence by the average KL divergence of its ensemble members from the ensemble
   * mean; higher scores indicate more disagreement.
   */
  private Counter<String> highKLFromMean() {
    // Get confidences
    Counter<String> kl =
        new ClassicCounter<>(MapFactory.<String, MutableDouble>linkedHashMapFactory());
    for (Map<SentenceKey, EnsembleStatistics> ensemble : this.impl) {
      for (Map.Entry<SentenceKey, EnsembleStatistics> entry : ensemble.entrySet()) {
        SentenceKey key = entry.getKey();
        EnsembleStatistics stats = entry.getValue();
        kl.setCount(key.sentenceHash, stats.averageKLFromMean());
      }
    }
    return kl;
  }
 /**
  * Adds pairwise log-ratio distance features between the numeric arguments of an expression:
  * the 1-2 feature when there are at least two arguments, plus the 1-3 and 2-3 features when
  * there are at least three.
  */
 public void addCrossNumericProximity(Counter<String> features, final NumericMentionExpression e) {
   List<NumericTuple> args = e.expression.arguments();
   if (args.size() > 1) {
     double ratio12 = args.get(0).val / args.get(1).val;
     features.incrementCount("abs-numeric-distance-12", Math.abs(Math.log(ratio12)));
   }
   if (args.size() > 2) {
     double ratio13 = args.get(0).val / args.get(2).val;
     double ratio23 = args.get(1).val / args.get(2).val;
     features.incrementCount("abs-numeric-distance-13", Math.abs(Math.log(ratio13)));
     features.incrementCount("abs-numeric-distance-23", Math.abs(Math.log(ratio23)));
   }
 }
Пример #17
0
 /**
  * Guesses an NER tag for a token span by majority vote over the tokens' NER tags.
  *
  * @param tokens the sentence tokens (must cover all indices in {@code span})
  * @param span the token-index span to label
  * @return the most common non-"O" NER tag in the span, provided it covers at least half of the
  *     span's tokens; otherwise "O"
  */
 public static String guessNER(List<CoreLabel> tokens, Span span) {
   Counter<String> nerGuesses = new ClassicCounter<>();
   for (int i : span) {
     nerGuesses.incrementCount(tokens.get(i).ner());
   }
   // The background tag and missing tags never win the vote.
   nerGuesses.remove("O");
   nerGuesses.remove(null);
   // Require the winning tag to cover at least half the span. Use floating-point division:
   // with the old integer division, an odd-length span (e.g. size 3) only required
   // span.size()/2 == 1 matching token instead of 1.5.
   if (nerGuesses.size() > 0 && Counters.max(nerGuesses) >= span.size() / 2.0) {
     return Counters.argmax(nerGuesses);
   } else {
     return "O";
   }
 }
Пример #18
0
 /**
  * Scores a datum against every label by running the corresponding one-vs-all binary classifier
  * and reading off the score of the positive class.
  */
 @Override
 public Counter<L> scoresOf(Datum<L, F> example) {
   Counter<L> scores = new ClassicCounter<>();
   for (L label : labelIndex) {
     // Project the multiclass datum onto a binary problem for this label.
     Map<L, String> posLabelMap = new ArrayMap<>();
     posLabelMap.put(label, POS_LABEL);
     Datum<String, F> binDatum = GeneralDataset.mapDatum(example, posLabelMap, NEG_LABEL);
     Counter<String> binScores = getBinaryClassifier(label).scoresOf(binDatum);
     scores.setCount(label, binScores.getCount(POS_LABEL));
   }
   return scores;
 }
 /**
  * Scores each sentence by 1 minus a confidence taken from its mean ensemble statistics.
  *
  * <p>NOTE(review): the inner loop calls setCount once per confidence value, so only the LAST
  * entry of average.confidence survives for each sentence -- confirm whether a min or average
  * over confidences was intended. Behavior preserved as-is.
  */
 private Counter<String> lowAverageConfidence() {
   // Get confidences
   Counter<String> lowConfidence =
       new ClassicCounter<>(MapFactory.<String, MutableDouble>linkedHashMapFactory());
   for (Map<SentenceKey, EnsembleStatistics> ensemble : this.impl) {
     for (Map.Entry<SentenceKey, EnsembleStatistics> entry : ensemble.entrySet()) {
       SentenceStatistics average = entry.getValue().mean();
       String sentenceHash = entry.getKey().sentenceHash;
       for (double confidence : average.confidence) {
         lowConfidence.setCount(sentenceHash, 1 - confidence);
       }
     }
   }
   return lowConfidence;
 }
Пример #20
0
  /**
   * Builds the final (unary, binary) grammar pair from the accumulated rule / history counts.
   * For every (rule, history) pair seen in training, this interpolates the relative-frequency
   * estimate of the rule given its label over all history depths up to HISTORY_DEPTH(), writes
   * the running log-probability onto each depth-specified rule, and files the rule into the
   * unary or binary grammar.
   *
   * @return a Pair of (UnaryGrammar, BinaryGrammar) over stateNumberer's states
   */
  public Object formResult() {
    Set brs = new HashSet();
    Set urs = new HashSet();
    // scan each rule / history pair
    int ruleCount = 0;
    for (Iterator pairI = rulePairs.keySet().iterator(); pairI.hasNext(); ) {
      if (ruleCount % 100 == 0) {
        System.err.println("Rules multiplied: " + ruleCount);
      }
      ruleCount++;
      Pair rulePair = (Pair) pairI.next();
      Rule baseRule = (Rule) rulePair.first;
      String baseLabel = (String) ruleToLabel.get(baseRule);
      List history = (List) rulePair.second;
      double totalProb = 0;
      for (int depth = 1; depth <= HISTORY_DEPTH() && depth <= history.size(); depth++) {
        List subHistory = history.subList(0, depth);
        double c_label = labelPairs.getCount(new Pair(baseLabel, subHistory));
        double c_rule = rulePairs.getCount(new Pair(baseRule, subHistory));
        // System.out.println("Multiplying out "+baseRule+" with history "+subHistory);
        // System.out.println("Count of "+baseLabel+" with "+subHistory+" is "+c_label);
        // System.out.println("Count of "+baseRule+" with "+subHistory+" is "+c_rule );

        // Uniformly interpolated relative-frequency estimate of P(rule | label, subHistory);
        // totalProb accumulates over increasing depths.
        double prob = (1.0 / HISTORY_DEPTH()) * (c_rule) / (c_label);
        totalProb += prob;
        for (int childDepth = 0; childDepth <= Math.min(HISTORY_DEPTH() - 1, depth); childDepth++) {
          // Each childDepth yields a differently-specified variant of the base rule; the score
          // is the log of the probability accumulated so far at this depth.
          Rule rule = specifyRule(baseRule, subHistory, childDepth);
          rule.score = (float) Math.log(totalProb);
          // System.out.println("Created  "+rule+" with score "+rule.score);
          if (rule instanceof UnaryRule) {
            urs.add(rule);
          } else {
            brs.add(rule);
          }
        }
      }
    }
    System.out.println("Total states: " + stateNumberer.total());
    // File every collected rule into its grammar.
    BinaryGrammar bg = new BinaryGrammar(stateNumberer.total());
    UnaryGrammar ug = new UnaryGrammar(stateNumberer.total());
    for (Iterator brI = brs.iterator(); brI.hasNext(); ) {
      BinaryRule br = (BinaryRule) brI.next();
      bg.addRule(br);
    }
    for (Iterator urI = urs.iterator(); urI.hasNext(); ) {
      UnaryRule ur = (UnaryRule) urI.next();
      ug.addRule(ur);
    }
    return new Pair(ug, bg);
  }
Пример #21
0
 /**
  * Tallies one internal tree node: records its base rule, then for every history depth from 0 up
  * to min(HISTORY_DEPTH(), parents.size()) counts the (rule, history) and (label, history) pairs
  * used later to estimate history-conditioned rule probabilities.
  */
 protected void tallyInternalNode(Tree lt, List parents) {
   // form base rule
   String label = lt.label().value();
   Rule baseR = ltToRule(lt);
   ruleToLabel.put(baseR, label);
   // act on each history depth
   int maxDepth = Math.min(HISTORY_DEPTH(), parents.size());
   for (int depth = 0; depth <= maxDepth; depth++) {
     List history = new ArrayList(parents.subList(0, depth));
     // tally each history level / rewrite pair
     rulePairs.incrementCount(new Pair(baseR, history), 1);
     labelPairs.incrementCount(new Pair(label, history), 1);
   }
 }
 /**
  * Adds per-dimension embedding-difference features (vec(str1) - vec(str2)) under the given
  * prefix.
  */
 protected void semanticCrossFeatures(
     Counter<String> features, String prefix, Sentence str1, Sentence str2) {
   double[] vec1 = embeddings.get(str1);
   double[] vec2 = embeddings.get(str2);
   for (int dim = 0; dim < vec1.length; dim++) {
     features.incrementCount(prefix + "wv-" + dim, vec1[dim] - vec2[dim]);
   }
 }
 /**
  * Adds per-dimension embedding-difference features between a numeric mention's token span and
  * a sentence, under the given prefix.
  */
 protected void semanticCrossFeatures(
     Counter<String> features, String prefix, NumericMention mention, Sentence str2) {
   double[] vec1 = embeddings.get(mention.sentence.get(), mention.token_begin, mention.token_end);
   double[] vec2 = embeddings.get(str2);
   for (int dim = 0; dim < vec1.length; dim++) {
     features.incrementCount(prefix + "wv-" + dim, vec1[dim] - vec2[dim]);
   }
 }
Пример #24
0
 /**
  * Collects the features that occur in at least minSegmentCount of the n-best lists' segments.
  * A feature is counted once per segment, regardless of how many translations in that segment's
  * n-best list fire it.
  */
 public static Set<String> featureWhiteList(FlatNBestList nbest, int minSegmentCount) {
   Counter<String> segmentCountsByFeature = new ClassicCounter<String>();
   for (List<ScoredFeaturizedTranslation<IString, String>> segment : nbest.nbestLists()) {
     // De-duplicate features within this segment before counting.
     Set<String> featuresInSegment = new HashSet<String>();
     for (ScoredFeaturizedTranslation<IString, String> translation : segment) {
       for (FeatureValue<String> feature : translation.features) {
         featuresInSegment.add(feature.name);
       }
     }
     for (String name : featuresInSegment) {
       segmentCountsByFeature.incrementCount(name);
     }
   }
   return Counters.keysAbove(segmentCountsByFeature, minSegmentCount - 1);
 }
 /**
  * Adds the max and min pairwise semantic similarity between the expression's argument sentences
  * (pairs 1-2, 1-3, 2-3, as available) as features; semanticSimilarityVW also adds its own
  * per-pair features as a side effect.
  *
  * <p>NOTE(review): minSim and maxSim start at 0, so with all-positive similarities the reported
  * min stays 0 (and symmetrically for max with negatives) -- confirm this is intended.
  */
 public void addCrossArgumentSemanticSimilarity(
     Counter<String> features, final NumericMentionExpression e) {
   List<NumericTuple> args = e.expression.arguments();
   double minSim = 0.;
   double maxSim = 0.;
   if (args.size() > 1) {
     double sim =
         semanticSimilarityVW(features, "12-", args.get(0).subjSentence, args.get(1).subjSentence);
     maxSim = Math.max(maxSim, sim);
     minSim = Math.min(minSim, sim);
   }
   if (args.size() > 2) {
     double sim =
         semanticSimilarityVW(features, "13-", args.get(0).subjSentence, args.get(2).subjSentence);
     maxSim = Math.max(maxSim, sim);
     minSim = Math.min(minSim, sim);
     sim =
         semanticSimilarityVW(features, "23-", args.get(1).subjSentence, args.get(2).subjSentence);
     maxSim = Math.max(maxSim, sim);
     minSim = Math.min(minSim, sim);
   }
   features.incrementCount("max-cross-semantic-similarity", maxSim);
   features.incrementCount("min-cross-semantic-similarity", minSim);
 }
Пример #26
0
 /**
  * Unpacks a weight counter into a dense array aligned with weightNames; names absent from the
  * counter get weight 0.
  */
 public static double[] getWeightArrayFromCounter(String[] weightNames, Counter<String> wts) {
   double[] weights = new double[weightNames.length];
   for (int idx = 0; idx < weights.length; idx++) {
     weights[idx] = wts.getCount(weightNames[idx]);
   }
   return weights;
 }
Пример #27
0
 /** Computes the dot product of a translation's feature values with the weight vector. */
 public static double scoreTranslation(
     Counter<String> wts, ScoredFeaturizedTranslation<IString, String> trans) {
   double score = 0;
   for (FeatureValue<String> featureValue : trans.features) {
     score += featureValue.value * wts.getCount(featureValue.name);
   }
   return score;
 }
Пример #28
0
  /**
   * This should be called after the classifier has been trained and parseAndTrain has been called
   * to accumulate test set
   *
   * <p>This will return precision,recall and F1 measure
   *
   * <p>NOTE(review): as written the method is void and the tp/fp/fn counters are locals that are
   * only incremented and printed per-instance -- no precision/recall/F1 summary is actually
   * computed or returned. Confirm whether a final aggregation step was dropped.
   */
  public void runTestSet(List<List<CoreLabel>> testSet) {
    // Per-label true-positive / false-positive / false-negative tallies.
    Counter<String> tp = new DefaultCounter<>();
    Counter<String> fp = new DefaultCounter<>();
    Counter<String> fn = new DefaultCounter<>();

    // Gold-label frequencies over the whole test set.
    Counter<String> actual = new DefaultCounter<>();

    for (List<CoreLabel> labels : testSet) {
      List<CoreLabel> unannotatedLabels = new ArrayList<>();
      // create a new label without answer annotation
      for (CoreLabel label : labels) {
        CoreLabel newLabel = new CoreLabel();
        newLabel.set(annotationForWord, label.get(annotationForWord));
        newLabel.set(PartOfSpeechAnnotation.class, label.get(PartOfSpeechAnnotation.class));
        unannotatedLabels.add(newLabel);
      }

      List<CoreLabel> annotatedLabels = this.classifier.classify(unannotatedLabels);

      // Compare the classifier's answers token-by-token against the gold labels.
      int ind = 0;
      for (CoreLabel expectedLabel : labels) {

        CoreLabel annotatedLabel = annotatedLabels.get(ind);
        String answer = annotatedLabel.get(AnswerAnnotation.class);
        String expectedAnswer = expectedLabel.get(AnswerAnnotation.class);

        actual.incrementCount(expectedAnswer);

        // match only non background symbols
        if (!SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL.equals(expectedAnswer)
            && expectedAnswer.equals(answer)) {
          // true positives
          tp.incrementCount(answer);
          System.out.println("True Positive:" + annotatedLabel);
        } else if (!SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL.equals(answer)) {
          // false positives
          fp.incrementCount(answer);
          System.out.println("False Positive:" + annotatedLabel);
        } else if (!SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL.equals(expectedAnswer)) {
          // false negatives
          fn.incrementCount(expectedAnswer);
          System.out.println("False Negative:" + expectedLabel);
        } // else true negatives

        ind++;
      }
    }

    // The background symbol is not an entity class, so drop it from the gold tallies.
    actual.remove(SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL);
  }
 /**
  * Reports whether any feature key is marked as a simple split, i.e. starts with the "simple&amp;"
  * prefix.
  */
 @Override
 public boolean isSimpleSplit(Counter<String> feats) {
   return feats.keySet().stream().anyMatch(key -> key.startsWith("simple&"));
 }
 /**
  * Adds a bag-of-lemmas cosine-similarity feature between two sentences, linearly rescaled from
  * [0, 1] to [-1, 1] via 2*sim - 1. (The old inline comment claiming the result is "between 0
  * and 1" was wrong.)
  *
  * <p>NOTE(review): v1 is built from lower-cased lemmas while v2 uses str2's lemmas as-is, so the
  * two vectors are cased inconsistently -- confirm whether str2 should be lower-cased too.
  * Behavior preserved as-is.
  */
 @Deprecated
 protected void semanticSimilarity(
     Counter<String> features, String prefix, Sentence str1, Sentence str2) {
   Counter<String> v1 =
       new ClassicCounter<>(
           str1.lemmas().stream().map(String::toLowerCase).collect(Collectors.toList()));
   Counter<String> v2 = new ClassicCounter<>(str2.lemmas());
   // Drop stopwords from both bags before comparing.
   for (String word : stopwords) {
     v1.remove(word);
     v2.remove(word);
   }
   // Cosine similarity: inner product over the product of (safe) L2 norms.
   double sim =
       Counters.dotProduct(v1, v2) / (Counters.saferL2Norm(v1) * Counters.saferL2Norm(v2));
   features.incrementCount(prefix + "semantic-similarity", 2 * sim - 1);
 }