コード例 #1
0
        /**
         * Fills the input and output feature counters for one mention / clustered-mention pair.
         *
         * <p>Every entry of {@code ACTIVE_FEATURES} is handled according to its runtime type: a
         * {@code Class} yields a single feature, a {@code Pair} of classes yields a conjoined
         * {@code Feature.PairFeature}; entries of any other type are ignored. Each call to
         * {@code feature(...)} receives an {@code Option} initialised to 1.0, and the value that
         * option holds afterwards is used as the feature's count. Features whose count (or, for
         * pairs, product of counts) is not strictly positive are skipped.
         *
         * @param input the mention and the clustered mention being compared
         * @param inFeatures counter receiving the input features
         * @param output the gold decision, or {@code null} when it is unknown
         * @param outFeatures counter receiving the output indicator when {@code output} is non-null
         */
        @SuppressWarnings({"unchecked"})
        @Override
        protected void fillFeatures(
            Pair<Mention, ClusteredMention> input,
            Counter<Feature> inFeatures,
            Boolean output,
            Counter<Feature> outFeatures) {
          // --Input Features
          for (Object spec : ACTIVE_FEATURES) {
            if (spec instanceof Class) {
              // (case: singleton feature)
              Option<Double> weight = new Option<Double>(1.0);
              Feature single = feature((Class) spec, input, weight);
              double value = weight.get();
              if (value > 0.0) {
                inFeatures.incrementCount(single, value);
              }
              continue;
            }
            if (!(spec instanceof Pair)) {
              // (case: unrecognised entry -- nothing to add)
              continue;
            }
            // (case: pair of features)
            Pair<Class, Class> classes = (Pair<Class, Class>) spec;
            Option<Double> weightFirst = new Option<Double>(1.0);
            Option<Double> weightSecond = new Option<Double>(1.0);
            Feature first = feature(classes.getFirst(), input, weightFirst);
            Feature second = feature(classes.getSecond(), input, weightSecond);
            double product = weightFirst.get() * weightSecond.get();
            if (product > 0.0) {
              inFeatures.incrementCount(new Feature.PairFeature(first, second), product);
            }
          }

          // --Output Features
          if (output == null) {
            return;
          }
          outFeatures.incrementCount(new Feature.CoreferentIndicator(output), 1.0);
        }
コード例 #2
0
 /**
  * Adds the absolute log-ratio between the values of the first three expression arguments.
  *
  * <p>With at least two arguments the feature {@code abs-numeric-distance-12} is added; with at
  * least three, {@code abs-numeric-distance-13} and {@code abs-numeric-distance-23} are added as
  * well. Nothing is added for fewer than two arguments.
  *
  * @param features counter that receives the features
  * @param e the mention/expression pair whose arguments are compared
  */
 public void addCrossNumericProximity(Counter<String> features, final NumericMentionExpression e) {
   List<NumericTuple> tuples = e.expression.arguments();
   int count = tuples.size();
   if (count <= 1) {
     return;
   }
   double ratio12 = tuples.get(0).val / tuples.get(1).val;
   features.incrementCount("abs-numeric-distance-12", Math.abs(Math.log(ratio12)));
   if (count <= 2) {
     return;
   }
   double ratio13 = tuples.get(0).val / tuples.get(2).val;
   double ratio23 = tuples.get(1).val / tuples.get(2).val;
   features.incrementCount("abs-numeric-distance-13", Math.abs(Math.log(ratio13)));
   features.incrementCount("abs-numeric-distance-23", Math.abs(Math.log(ratio23)));
 }
コード例 #3
0
  /**
   * Runs the trained classifier over a test set and tallies per-label outcomes.
   *
   * <p>For every sentence a copy is built that carries only the word annotation and the
   * part-of-speech annotation (so the gold answer is hidden), the copy is classified, and each
   * predicted answer is compared with the gold answer at the same position. True positives, false
   * positives and false negatives are counted per label and each such token is printed to standard
   * output.
   *
   * <p>The counters are local to this method and nothing is returned; as visible here, precision,
   * recall and F1 are not computed or reported.
   *
   * @param testSet gold-annotated sentences to evaluate on
   */
  public void runTestSet(List<List<CoreLabel>> testSet) {
    Counter<String> truePositives = new DefaultCounter<>();
    Counter<String> falsePositives = new DefaultCounter<>();
    Counter<String> falseNegatives = new DefaultCounter<>();

    Counter<String> goldCounts = new DefaultCounter<>();

    for (List<CoreLabel> sentence : testSet) {
      // Copy each token without its answer annotation.
      List<CoreLabel> stripped = new ArrayList<>();
      for (CoreLabel goldToken : sentence) {
        CoreLabel copy = new CoreLabel();
        copy.set(annotationForWord, goldToken.get(annotationForWord));
        copy.set(PartOfSpeechAnnotation.class, goldToken.get(PartOfSpeechAnnotation.class));
        stripped.add(copy);
      }

      List<CoreLabel> predicted = this.classifier.classify(stripped);

      for (int position = 0; position < sentence.size(); position++) {
        CoreLabel goldToken = sentence.get(position);
        CoreLabel predictedToken = predicted.get(position);
        String answer = predictedToken.get(AnswerAnnotation.class);
        String expectedAnswer = goldToken.get(AnswerAnnotation.class);

        goldCounts.incrementCount(expectedAnswer);

        boolean goldIsBackground =
            SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL.equals(expectedAnswer);
        boolean predictedIsBackground = SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL.equals(answer);

        if (!goldIsBackground && expectedAnswer.equals(answer)) {
          // Correct non-background prediction.
          truePositives.incrementCount(answer);
          System.out.println("True Positive:" + predictedToken);
        } else if (!predictedIsBackground) {
          // A non-background label was predicted but it does not match the gold label.
          falsePositives.incrementCount(answer);
          System.out.println("False Positive:" + predictedToken);
        } else if (!goldIsBackground) {
          // Background was predicted where the gold label is non-background.
          falseNegatives.incrementCount(expectedAnswer);
          System.out.println("False Negative:" + goldToken);
        }
        // Otherwise both are background: a true negative, which is not tallied.
      }
    }

    goldCounts.remove(SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL);
  }
コード例 #4
0
ファイル: FactoredParser.java プロジェクト: renaud/maven_repo
 /**
  * Records the rewrite rule of an internal tree node against every prefix of its history.
  *
  * <p>The node's rule is mapped to its label in {@code ruleToLabel}. Then, for each depth from 0
  * up to the smaller of {@code HISTORY_DEPTH()} and {@code parents.size()} (inclusive), the first
  * {@code depth} entries of {@code parents} are copied and paired with the rule and with the
  * label, and each pair's count is incremented by one.
  *
  * @param lt the internal node being tallied
  * @param parents the node's history list
  */
 protected void tallyInternalNode(Tree lt, List parents) {
   // form base rule
   String nodeLabel = lt.label().value();
   Rule rule = ltToRule(lt);
   ruleToLabel.put(rule, nodeLabel);

   // act on each history depth
   int deepest = Math.min(HISTORY_DEPTH(), parents.size());
   int depth = 0;
   while (depth <= deepest) {
     // copy the prefix so that the key does not alias the caller's list
     List context = new ArrayList(parents.subList(0, depth));
     // tally each history level / rewrite pair
     rulePairs.incrementCount(new Pair(rule, context), 1);
     labelPairs.incrementCount(new Pair(nodeLabel, context), 1);
     depth++;
   }
 }
コード例 #5
0
 /**
  * Adds the pairwise semantic similarity between the subject sentences of the first three
  * arguments, plus the largest and smallest of those similarities.
  *
  * <p>The pairwise features themselves are emitted by {@code semanticSimilarityVW} under the
  * prefixes {@code 12-}, {@code 13-} and {@code 23-}. The extrema are taken over the similarities
  * that were actually computed. Previously both were seeded with 0, which capped the minimum at 0
  * and floored the maximum at 0 regardless of the real values. When fewer than two arguments
  * exist no similarity is computed and both extrema are reported as 0, as before.
  *
  * @param features counter that receives the features
  * @param e the mention/expression pair whose arguments are compared
  */
 public void addCrossArgumentSemanticSimilarity(
     Counter<String> features, final NumericMentionExpression e) {
   List<NumericTuple> args = e.expression.arguments();
   int n = args.size();
   double minSim = 0.;
   double maxSim = 0.;
   if (n > 1) {
     double sim12 =
         semanticSimilarityVW(features, "12-", args.get(0).subjSentence, args.get(1).subjSentence);
     // The first similarity seeds both extrema, so 0 never leaks into the result.
     maxSim = sim12;
     minSim = sim12;
   }
   if (n > 2) {
     double sim13 =
         semanticSimilarityVW(features, "13-", args.get(0).subjSentence, args.get(2).subjSentence);
     double sim23 =
         semanticSimilarityVW(features, "23-", args.get(1).subjSentence, args.get(2).subjSentence);
     maxSim = Math.max(maxSim, Math.max(sim13, sim23));
     minSim = Math.min(minSim, Math.min(sim13, sim23));
   }
   features.incrementCount("max-cross-semantic-similarity", maxSim);
   features.incrementCount("min-cross-semantic-similarity", minSim);
 }
コード例 #6
0
 /**
  * Adds one feature per embedding dimension holding the difference between the embedding of a
  * mention's token span and the embedding of a sentence.
  *
  * <p>Feature names are {@code prefix + "wv-" + dimension}. The loop runs over the length of the
  * mention vector.
  *
  * @param features counter that receives the features
  * @param prefix string prepended to every feature name
  * @param mention mention whose sentence and token span are embedded
  * @param str2 sentence whose embedding is subtracted
  */
 protected void semanticCrossFeatures(
     Counter<String> features, String prefix, NumericMention mention, Sentence str2) {
   double[] mentionVec =
       embeddings.get(mention.sentence.get(), mention.token_begin, mention.token_end);
   double[] sentenceVec = embeddings.get(str2);
   String keyStem = prefix + "wv-";
   for (int dim = 0; dim < mentionVec.length; dim++) {
     double delta = mentionVec[dim] - sentenceVec[dim];
     features.incrementCount(keyStem + dim, delta);
   }
 }
コード例 #7
0
  /**
   * Returns the features whose precision is at least {@code minPrecision}, as ordered by
   * {@code Counters.toSortedList} on their total frequency.
   *
   * <p>Precision of a feature is the frequency of its majority label divided by the feature's
   * total frequency over all labels. If more than {@code maxNumFeatures} features qualify, only
   * the top {@code maxNumFeatures} (per {@code Counters.retainTop}) are kept.
   *
   * @param dataset dataset supplying the feature/label co-occurrences
   * @param minPrecision minimum precision a feature must reach to be kept
   * @param maxNumFeatures maximum number of features to return
   * @return list of high precision features.
   */
  private List<F> getHighPrecisionFeatures(
      GeneralDataset<L, F> dataset, double minPrecision, int maxNumFeatures) {
    final int featureCount = dataset.numFeatures();
    // Freshly allocated int arrays are already zero-filled, so no explicit reset is needed.
    int[][] labelCountsByFeature = new int[featureCount][dataset.numClasses()];

    int[][] data = dataset.data;
    int[] labels = dataset.labels;
    for (int d = 0; d < data.length; d++) {
      int label = labels[d];
      int[] activeFeatures = data[d];
      if (activeFeatures == null) {
        continue;
      }
      for (int featureId : activeFeatures) {
        labelCountsByFeature[featureId][label]++;
      }
    }

    Counter<F> frequencyByFeature = new ClassicCounter<F>();
    for (int f = 0; f < featureCount; f++) {
      int[] perLabel = labelCountsByFeature[f];
      int majority = ArrayMath.max(perLabel);
      int total = ArrayMath.sum(perLabel);
      // A feature that never occurs yields 0/0 = NaN, which fails the comparison below.
      double precision = ((double) majority) / total;
      F feature = dataset.featureIndex.get(f);
      if (precision >= minPrecision) {
        frequencyByFeature.incrementCount(feature, total);
      }
    }

    if (frequencyByFeature.size() > maxNumFeatures) {
      Counters.retainTop(frequencyByFeature, maxNumFeatures);
    }
    return Counters.toSortedList(frequencyByFeature);
  }
コード例 #8
0
 /**
  * Adds one feature per embedding dimension holding the difference between the embeddings of two
  * sentences.
  *
  * <p>Feature names are {@code prefix + "wv-" + dimension}. The loop runs over the length of the
  * first sentence's vector.
  *
  * @param features counter that receives the features
  * @param prefix string prepended to every feature name
  * @param str1 sentence whose embedding is the minuend
  * @param str2 sentence whose embedding is subtracted
  */
 protected void semanticCrossFeatures(
     Counter<String> features, String prefix, Sentence str1, Sentence str2) {
   double[] left = embeddings.get(str1);
   double[] right = embeddings.get(str2);
   String keyStem = prefix + "wv-";
   for (int dim = 0; dim < left.length; dim++) {
     double delta = left[dim] - right[dim];
     features.incrementCount(keyStem + dim, delta);
   }
 }
コード例 #9
0
ファイル: Benchmarks.java プロジェクト: jayantam/CoreNLP
  /**
   * Benchmarks training a linear classifier on a synthetic real-valued-feature dataset.
   *
   * <p>Builds 10000 data with 1000 features each, trains with a default
   * {@code LinearClassifierFactory}, and prints the wall-clock training time.
   *
   * <p>The random generator is created once, outside the datum loop. Previously a new
   * {@code Random(42)} was created for every datum, so each datum drew the same sequence and
   * therefore the same label, leaving the dataset with a single class. The seed stays fixed so
   * runs remain reproducible.
   *
   * <p>Historical measurements, taken before the generator was hoisted:
   *
   * <p>67% of time spent in LogConditionalObjectiveFunction.rvfcalculate() 29% of time spent in
   * dataset construction (11% in RVFDataset.addFeatures(), 7% rvf incrementCount(), 11% rest)
   *
   * <p>Single threaded, 4700 ms Multi threaded, 700 ms
   *
   * <p>With same data, seed 42, 245 ms With reordered accesses for cacheing, 195 ms Down to 80% of
   * the time, not huge but a win nonetheless
   *
   * <p>with 8 cpus, a 6.7x speedup -- almost, but not quite linear, pretty good
   */
  public static void benchmarkRVFLogisticRegression() {
    RVFDataset<String, String> data = new RVFDataset<>();
    // One generator for the whole dataset: seeded for reproducibility, but not re-seeded per datum.
    Random r = new Random(42);
    for (int i = 0; i < 10000; i++) {
      Counter<String> features = new ClassicCounter<>();

      boolean cl = r.nextBoolean();

      for (int j = 0; j < 1000; j++) {
        double value;
        if (cl && i % 2 == 0) {
          // Shifted towards positive values for positive-class data at even indices.
          value = (r.nextDouble() * 2.0) - 0.6;
        } else {
          value = (r.nextDouble() * 2.0) - 1.4;
        }
        features.incrementCount("f" + j, value);
      }

      data.add(new RVFDatum<>(features, "target:" + cl));
    }

    LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();

    long msStart = System.currentTimeMillis();
    factory.trainClassifier(data);
    long delay = System.currentTimeMillis() - msStart;
    System.out.println("Training took " + delay + " ms");
  }
コード例 #10
0
 /**
  * Trains naive Bayes weights on index-encoded data and wraps them in a classifier.
  *
  * <p>The priors and per-(class, feature, value) weights returned by {@code trainWeights} are
  * copied into counters keyed by the label and feature objects looked up in the given indices.
  *
  * @param data per-datum feature data passed to {@code trainWeights}
  * @param labels per-datum label indices passed to {@code trainWeights}
  * @param numFeatures number of distinct features
  * @param numClasses number of distinct classes
  * @param labelIndex maps class indices to label objects
  * @param featureIndex maps feature indices to feature objects
  * @return a classifier built from the weight counter, the priors and the label set
  */
 private NaiveBayesClassifier<L, F> trainClassifier(
     int[][] data,
     int[] labels,
     int numFeatures,
     int numClasses,
     Index<L> labelIndex,
     Index<F> featureIndex) {
   NBWeights trained = trainWeights(data, labels, numFeatures, numClasses);

   // Priors, and the set of labels they cover.
   Set<L> knownLabels = Generics.newHashSet();
   Counter<L> priorCounter = new ClassicCounter<L>();
   double[] priorValues = trained.priors;
   for (int i = 0; i < priorValues.length; i++) {
     L current = labelIndex.get(i);
     priorCounter.incrementCount(current, priorValues[i]);
     knownLabels.add(current);
   }

   // Weights, keyed by ((label, feature), value index).
   Counter<Pair<Pair<L, F>, Number>> weightCounter =
       new ClassicCounter<Pair<Pair<L, F>, Number>>();
   double[][][] weightValues = trained.weights;
   for (int c = 0; c < numClasses; c++) {
     L classLabel = labelIndex.get(c);
     for (int f = 0; f < numFeatures; f++) {
       Pair<L, F> labelFeature = new Pair<L, F>(classLabel, featureIndex.get(f));
       double[] perValue = weightValues[c][f];
       for (int v = 0; v < perValue.length; v++) {
         weightCounter.incrementCount(
             new Pair<Pair<L, F>, Number>(labelFeature, Integer.valueOf(v)), perValue[v]);
       }
     }
   }
   return new NaiveBayesClassifier<L, F>(weightCounter, priorCounter, knownLabels);
 }
コード例 #11
0
ファイル: Compressor.java プロジェクト: wayzou/CoreNLP
 /**
  * Expands a compressed feature vector back into a counter.
  *
  * <p>Each compressed key is translated through {@code inverse} and its count is incremented by
  * the value stored at the same position.
  *
  * @param cvf the compressed vector, read through its parallel {@code keys} and {@code values}
  * @return a new counter holding the expanded features
  */
 public Counter<K> uncompress(CompressedFeatureVector cvf) {
   Counter<K> expanded = new ClassicCounter<>();
   int entries = cvf.keys.size();
   for (int position = 0; position < entries; position++) {
     K original = inverse.get(cvf.keys.get(position));
     expanded.incrementCount(original, cvf.values.get(position));
   }
   return expanded;
 }
コード例 #12
0
 /**
  * Builds a small program over the variables {@code x0}..{@code x3} (one objective, three
  * constraints, per-variable bounds), runs it, and prints the result via {@code D.p}.
  *
  * @param args unused
  */
 public static void main(String[] args) {
   IntegerLinearProgramming ilp = new IntegerLinearProgramming();

   // Objective coefficients.
   Counter<String> obj = new ClassicCounter<String>();
   obj.incrementCount("x0", 3);
   obj.incrementCount("x1", 1);
   obj.incrementCount("x2", 5);
   obj.incrementCount("x3", 1);
   ilp.setObjective(obj, true);

   // First constraint: coefficients over x0..x2.
   Counter<String> c1 = new ClassicCounter<String>();
   c1.incrementCount("x0", 3);
   c1.incrementCount("x1", 1);
   c1.incrementCount("x2", 2);
   ilp.addConstraint(c1, true, 30, true, 30);

   // Second constraint: coefficients over x0..x3.
   Counter<String> c2 = new ClassicCounter<String>();
   c2.incrementCount("x0", 2);
   c2.incrementCount("x1", 1);
   c2.incrementCount("x2", 3);
   c2.incrementCount("x3", 1);
   ilp.addConstraint(c2, true, 15, false, 0);

   // Third constraint: coefficients over x1 and x3.
   Counter<String> c3 = new ClassicCounter<String>();
   c3.incrementCount("x1", 2);
   c3.incrementCount("x3", 3);
   ilp.addConstraint(c3, false, 0, true, 25);

   // Variable bounds.
   ilp.setupVariableLowerUpper("x0", 0, 10000);
   ilp.setupVariableLowerUpper("x1", 0, 10);
   ilp.setupVariableLowerUpper("x2", 0, 10000);
   ilp.setupVariableLowerUpper("x3", 0, 10000);

   D.p(ilp.run());
 }
コード例 #13
0
ファイル: OptimizerUtils.java プロジェクト: wentaouc/phrasal
  /**
   * Sums a collection of feature values into a counter keyed by feature name.
   *
   * <p>Values of entries that share a name are added together.
   *
   * @param c the feature values to accumulate
   * @param <T> the feature name type
   * @return a new counter mapping each feature name to the sum of its values
   */
  public static <T> Counter<T> featureValueCollectionToCounter(Collection<FeatureValue<T>> c) {
    final Counter<T> totals = new ClassicCounter<T>();
    for (FeatureValue<T> entry : c) {
      totals.incrementCount(entry.name, entry.value);
    }
    return totals;
  }
コード例 #14
0
ファイル: Util.java プロジェクト: foxlf823/ade
  /**
   * Builds a dictionary of the strings that occur at least {@code cutOff} times.
   *
   * <p>The result follows the order produced by {@code Counters.toSortedList(freq, false)}.
   *
   * @param str the strings to count (duplicates expected)
   * @param cutOff minimum number of occurrences for a string to be kept
   * @return the retained strings
   */
  public static List<String> generateDict(List<String> str, int cutOff) {
    Counter<String> occurrences = new IntCounter<>();
    for (String token : str) {
      occurrences.incrementCount(token);
    }

    List<String> retained = new ArrayList<>();
    for (String candidate : Counters.toSortedList(occurrences, false)) {
      if (occurrences.getCount(candidate) < cutOff) {
        continue;
      }
      retained.add(candidate);
    }
    return retained;
  }
コード例 #15
0
 /**
  * Adds one indicator feature per expression argument, named {@code "fact-" + subject}.
  *
  * <p>Every character of the subject outside {@code [a-zA-Z0-9]} is replaced by an underscore so
  * that the feature name contains no whitespace or punctuation. This uses {@code replaceAll}
  * because {@code String.replace} treats its first argument as literal text, so the previous
  * call never sanitised anything.
  *
  * @param features counter that receives the features
  * @param e the mention/expression pair whose arguments are featurised
  */
 public void addFactFeatures(Counter<String> features, final NumericMentionExpression e) {
   List<NumericTuple> args = e.expression.arguments();
   for (NumericTuple tuple : args) {
     String id1 = tuple.subj.replaceAll("[^a-zA-Z0-9]", "_");
     features.incrementCount("fact-" + id1);
   }
 }
コード例 #16
0
ファイル: Util.java プロジェクト: Eagles2F/CoreNLP
 /**
  * Guesses a single NER label for a span of tokens by voting over the tokens' NER tags.
  *
  * <p>Votes for {@code "O"} and for {@code null} are discarded. The most frequent remaining tag
  * is returned if its count is at least {@code span.size() / 2}; otherwise {@code "O"} is
  * returned.
  *
  * @param tokens the tokens of the sentence
  * @param span the token indices to vote over
  * @return the winning NER tag, or {@code "O"} if no tag is frequent enough
  */
 public static String guessNER(List<CoreLabel> tokens, Span span) {
   Counter<String> votes = new ClassicCounter<>();
   for (int index : span) {
     votes.incrementCount(tokens.get(index).ner());
   }
   // Background and missing tags do not count as candidates.
   votes.remove("O");
   votes.remove(null);

   boolean hasCandidate = votes.size() > 0;
   // NOTE(review): if span.size() returns an int this is integer division, so the threshold is
   // rounded down (e.g. 1 of 3 tokens suffices) -- confirm that this is intended.
   if (hasCandidate && Counters.max(votes) >= span.size() / 2) {
     return Counters.argmax(votes);
   }
   return "O";
 }
コード例 #17
0
 /**
  * Adds features conjoining each argument's subject with the non-stopword lemmas of the mention
  * sentence, as unigrams and as bigrams of adjacent non-stopwords.
  *
  * <p>Feature names are {@code "fact-" + subject + "-" + word} and
  * {@code "fact-" + subject + "-" + word + "-" + nextWord}.
  *
  * <p>Fixes relative to the earlier version:
  *
  * <ul>
  *   <li>the subject is sanitised with {@code replaceAll}; {@code String.replace} treats the
  *       pattern as literal text and therefore changed nothing;
  *   <li>the bigram bound is {@code size - 1}, so the final adjacent pair is no longer skipped;
  *   <li>the filtered lemma list does not depend on the argument and is computed once.
  * </ul>
  *
  * @param features counter that receives the features
  * @param e the mention/expression pair being featurised
  */
 public void addFactMentionFeatures(Counter<String> features, final NumericMentionExpression e) {
   List<NumericTuple> args = e.expression.arguments();
   // Non-stopword lemmas of the mention sentence; identical for every argument.
   List<String> words =
       e.mention_sentence
           .lemmas()
           .stream()
           .filter(w -> !stopwords.contains(w))
           .collect(Collectors.toList());
   for (NumericTuple tuple : args) {
     String id1 = tuple.subj.replaceAll("[^a-zA-Z0-9]", "_");
     for (int i = 0; i < words.size(); i++) {
       String word = words.get(i);
       features.incrementCount("fact-" + id1 + "-" + word);
       // Add non-stopword bigrams: every position except the last has a successor.
       if (i < words.size() - 1) {
         String word_ = words.get(i + 1);
         features.incrementCount("fact-" + id1 + "-" + word + "-" + word_);
       }
     }
   }
 }
コード例 #18
0
  /**
   * Adds a feature for every ordered pair of distinct expression arguments, named after the two
   * sanitised subjects.
   *
   * <p>The lexicographically larger subject always comes first in the feature name, so the pair
   * (a, b) and the pair (b, a) increment the same feature; each unordered pair is therefore
   * counted twice.
   *
   * <p>Subjects are sanitised with {@code replaceAll}, replacing every character outside
   * {@code [a-zA-Z0-9]} with an underscore. The previous {@code String.replace} call treated the
   * pattern as literal text and changed nothing.
   *
   * @param features counter that receives the features
   * @param e the mention/expression pair whose arguments are featurised
   */
  public void addFactCrossFeatures(Counter<String> features, final NumericMentionExpression e) {
    List<NumericTuple> args = e.expression.arguments();
    for (NumericTuple tuple : args) {
      String id1 = tuple.subj.replaceAll("[^a-zA-Z0-9]", "_");

      for (NumericTuple tuple_ : args) {
        // Do not pair an argument with itself (or with an equal argument).
        if (tuple.equals(tuple_)) {
          continue;
        }
        String id2 = tuple_.subj.replaceAll("[^a-zA-Z0-9]", "_");

        if (id1.compareTo(id2) > 0) {
          features.incrementCount("fact-" + id1 + "-" + id2);
        } else {
          features.incrementCount("fact-" + id2 + "-" + id1);
        }
      }
    }
  }
コード例 #19
0
ファイル: OptimizerUtils.java プロジェクト: wentaouc/phrasal
 /**
  * Builds a feature whitelist from the number of n-best lists (segments) each feature occurs in.
  *
  * <p>A feature is counted at most once per n-best list, however many translations in that list
  * carry it. The result is {@code Counters.keysAbove} applied to these counts with threshold
  * {@code minSegmentCount - 1}.
  *
  * @param nbest the n-best lists to scan
  * @param minSegmentCount segment-count cut-off (see above for how it is applied)
  * @return the selected feature names
  */
 public static Set<String> featureWhiteList(FlatNBestList nbest, int minSegmentCount) {
   Counter<String> segmentsPerFeature = new ClassicCounter<String>();
   for (List<ScoredFeaturizedTranslation<IString, String>> segment : nbest.nbestLists()) {
     // Names already credited for this segment.
     Set<String> seenInSegment = new HashSet<String>();
     for (ScoredFeaturizedTranslation<IString, String> translation : segment) {
       for (FeatureValue<String> fv : translation.features) {
         if (seenInSegment.add(fv.name)) {
           segmentsPerFeature.incrementCount(fv.name);
         }
       }
     }
   }
   return Counters.keysAbove(segmentsPerFeature, minSegmentCount - 1);
 }
コード例 #20
0
ファイル: OptimizerUtils.java プロジェクト: wentaouc/phrasal
 /**
  * Updates an existing feature whitelist counter from additional n-best lists and returns the
  * features it now selects.
  *
  * <p>For every n-best list, each distinct feature name occurring in it increments its entry in
  * {@code featureWhitelist} exactly once. The counter is modified in place.
  *
  * @param featureWhitelist running per-feature segment counts; updated by this call
  * @param nbestlists the n-best lists to add
  * @param minSegmentCount segment-count cut-off
  * @return the result of {@code Counters.keysAbove(featureWhitelist, minSegmentCount - 1)}
  */
 public static Set<String> updatefeatureWhiteList(
     Counter<String> featureWhitelist,
     List<List<RichTranslation<IString, String>>> nbestlists,
     int minSegmentCount) {
   for (List<RichTranslation<IString, String>> segment : nbestlists) {
     Set<String> seenInSegment = new HashSet<String>(1000);
     for (RichTranslation<IString, String> translation : segment) {
       for (FeatureValue<String> fv : translation.features) {
         // Set.add reports whether the name is new for this segment.
         if (seenInSegment.add(fv.name)) {
           featureWhitelist.incrementCount(fv.name);
         }
       }
     }
   }
   return Counters.keysAbove(featureWhitelist, minSegmentCount - 1);
 }
コード例 #21
0
 /**
  * Adds a bag-of-lemmas cosine similarity feature between two sentences.
  *
  * <p>Both sentences are reduced to counters over their lower-cased lemmas, stopwords are
  * removed, and the dot product is divided by the product of the two norms. The feature
  * {@code prefix + "semantic-similarity"} is incremented by {@code 2 * sim - 1}.
  *
  * <p>Previously only the first sentence was lower-cased, so lemmas differing only in case never
  * matched across the two sentences, and capitalised entries of the second sentence were
  * presumably not matched by the stopword removal (this assumes {@code stopwords} holds
  * lower-case words -- TODO confirm).
  *
  * @param features counter that receives the feature
  * @param prefix string prepended to the feature name
  * @param str1 first sentence
  * @param str2 second sentence
  */
 @Deprecated
 protected void semanticSimilarity(
     Counter<String> features, String prefix, Sentence str1, Sentence str2) {
   Counter<String> v1 =
       new ClassicCounter<>(
           str1.lemmas().stream().map(String::toLowerCase).collect(Collectors.toList()));
   // Normalise case on both sides so that the two bags are comparable.
   Counter<String> v2 =
       new ClassicCounter<>(
           str2.lemmas().stream().map(String::toLowerCase).collect(Collectors.toList()));
   // Remove any stopwords.
   for (String word : stopwords) {
     v1.remove(word);
     v2.remove(word);
   }
   // take inner product.
   double sim =
       Counters.dotProduct(v1, v2) / (Counters.saferL2Norm(v1) * Counters.saferL2Norm(v2));
   // Counts are non-negative, so sim lies in [0, 1]; this rescales it to [-1, 1].
   features.incrementCount(prefix + "semantic-similarity", 2 * sim - 1);
 }
コード例 #22
0
 /**
  * Averages the statistics of all member classifiers into a single {@code SentenceStatistics}.
  *
  * <p>The relation distributions are summed key by key and, when there is more than one member,
  * divided by the number of members. All confidence values present are averaged. If no member
  * carries a confidence, the result is built without one.
  *
  * @return the averaged statistics
  * @throws IllegalStateException if the averaged relation distribution does not sum to 1 within
  *     a tolerance of 1e-5
  */
 public SentenceStatistics mean() {
   Counter<String> meanDistribution =
       new ClassicCounter<>(MapFactory.<String, MutableDouble>linkedHashMapFactory());
   double confidenceTotal = 0;
   int confidenceSamples = 0;

   // Sum
   for (SentenceStatistics member : this.statisticsForClassifiers) {
     for (Double value : member.confidence) {
       confidenceTotal += value;
       confidenceSamples += 1;
     }
     assert Math.abs(member.relationDistribution.totalCount() - 1.0) < 1e-5;
     for (Map.Entry<String, Double> cell : member.relationDistribution.entrySet()) {
       assert cell.getValue() >= 0.0;
       assert cell.getValue() == member.relationDistribution.getCount(cell.getKey());
       meanDistribution.incrementCount(cell.getKey(), cell.getValue());
       assert member.relationDistribution.getCount(cell.getKey())
           == member.relationDistribution.getCount(cell.getKey());
     }
   }

   // Normalize
   double meanConfidence = confidenceTotal / ((double) confidenceSamples);
   if (this.statisticsForClassifiers.size() > 1) {
     Counters.divideInPlace(meanDistribution, (double) this.statisticsForClassifiers.size());
   }

   // Validate
   if (Math.abs(meanDistribution.totalCount() - 1.0) > 1e-5) {
     throw new IllegalStateException("Mean relation distribution is not a distribution!");
   }
   // With exactly one member the mean must equal that member's distribution.
   assert this.statisticsForClassifiers.size() > 1
       || this.statisticsForClassifiers.size() == 0
       || Counters.equals(
           meanDistribution,
           statisticsForClassifiers.iterator().next().relationDistribution,
           1e-5);

   // Return
   if (confidenceSamples > 0) {
     return new SentenceStatistics(meanDistribution, meanConfidence);
   }
   return new SentenceStatistics(meanDistribution);
 }
コード例 #23
0
 /**
  * Adds, for each of the first three expression arguments, the absolute value and the sign of
  * the log-ratio between the mention's normalised value and that argument's value.
  *
  * <p>Feature names are {@code abs-numeric-distance-k} and {@code sign-numeric-distance-k} with
  * {@code k} in 1..3. Arguments beyond the third are ignored. An empty argument list now adds
  * nothing; previously the unconditional access to the first argument threw
  * {@code IndexOutOfBoundsException}.
  *
  * @param features counter that receives the features
  * @param e the mention/expression pair being featurised
  */
 public void addPerLengthNumericProximity(
     Counter<String> features, final NumericMentionExpression e) {
   List<NumericTuple> args = e.expression.arguments();
   int limit = Math.min(args.size(), 3);
   for (int i = 0; i < limit; i++) {
     double logRatio = Math.log(e.mention_normalized_value / args.get(i).val);
     // Feature names are 1-based.
     String suffix = "-numeric-distance-" + (i + 1);
     features.incrementCount("abs" + suffix, Math.abs(logRatio));
     features.incrementCount("sign" + suffix, Math.signum(logRatio));
   }
 }
コード例 #24
0
  /**
   * Prints factored-lexicon statistics for a treebank.
   *
   * <p>Loads the treebank, transforms every non-leaf subtree with the language's parser
   * parameters, and counts words, tags, lemmas, rich (morphological) tags and the reduced tags
   * obtained by activating the requested morphological features. It then prints summary counts,
   * words whose single lemma maps to several reduced tags, words with no or multiple lemmas, and
   * the reduced tags observed for each POS tag.
   *
   * <p>Bug fix: the first lemma seen for a word is now stored. Previously a new word received an
   * empty lemma set, so its first lemma was lost and words seen only once were reported under "NO
   * LEMMAS FOR WORD".
   *
   * @param args language name, treebank path, and a comma-separated list of morphological feature
   *     types to activate
   */
  public static void main(String[] args) {
    if (args.length != 3) {
      System.err.printf(
          "Usage: java %s language filename features%n",
          TreebankFactoredLexiconStats.class.getName());
      System.exit(-1);
    }

    Language language = Language.valueOf(args[0]);
    TreebankLangParserParams tlpp = language.params;
    // Any language other than Arabic is handled with the French settings.
    if (language.equals(Language.Arabic)) {
      String[] options = {"-arabicFactored"};
      tlpp.setOptionFlag(options, 0);
    } else {
      String[] options = {"-frenchFactored"};
      tlpp.setOptionFlag(options, 0);
    }
    Treebank tb = tlpp.diskTreebank();
    tb.loadPath(args[1]);

    MorphoFeatureSpecification morphoSpec =
        language.equals(Language.Arabic)
            ? new ArabicMorphoFeatureSpecification()
            : new FrenchMorphoFeatureSpecification();

    String[] features = args[2].trim().split(",");
    for (String feature : features) {
      morphoSpec.activate(MorphoFeatureType.valueOf(feature));
    }

    // Counters
    Counter<String> wordTagCounter = new ClassicCounter<>(30000);
    Counter<String> morphTagCounter = new ClassicCounter<>(500);
    Counter<String> morphCounter = new ClassicCounter<>(500);
    Counter<String> wordCounter = new ClassicCounter<>(30000);
    Counter<String> tagCounter = new ClassicCounter<>(300);

    Counter<String> lemmaCounter = new ClassicCounter<>(25000);
    Counter<String> lemmaTagCounter = new ClassicCounter<>(25000);

    Counter<String> richTagCounter = new ClassicCounter<>(1000);

    Counter<String> reducedTagCounter = new ClassicCounter<>(500);

    Counter<String> reducedTagLemmaCounter = new ClassicCounter<>(500);

    Map<String, Set<String>> wordLemmaMap = Generics.newHashMap();

    TwoDimensionalIntCounter<String, String> lemmaReducedTagCounter =
        new TwoDimensionalIntCounter<>(30000);
    TwoDimensionalIntCounter<String, String> reducedTagTagCounter =
        new TwoDimensionalIntCounter<>(500);
    TwoDimensionalIntCounter<String, String> tagReducedTagCounter =
        new TwoDimensionalIntCounter<>(300);

    int numTrees = 0;
    for (Tree tree : tb) {
      for (Tree subTree : tree) {
        if (!subTree.isLeaf()) {
          tlpp.transformTree(subTree, tree);
        }
      }
      List<Label> pretermList = tree.preTerminalYield();
      List<Label> yield = tree.yield();
      assert yield.size() == pretermList.size();

      int yieldLen = yield.size();
      for (int i = 0; i < yieldLen; ++i) {
        String tag = pretermList.get(i).value();

        String word = yield.get(i).value();
        String morph = ((CoreLabel) yield.get(i)).originalText();

        // Note: if there is no lemma, then we use the surface form.
        Pair<String, String> lemmaTag = MorphoFeatureSpecification.splitMorphString(word, morph);
        String lemma = lemmaTag.first();
        String richTag = lemmaTag.second();

        // WSGDEBUG
        if (tag.contains("MW")) lemma += "-MWE";

        lemmaCounter.incrementCount(lemma);
        lemmaTagCounter.incrementCount(lemma + tag);

        richTagCounter.incrementCount(richTag);

        String reducedTag = morphoSpec.strToFeatures(richTag).toString();
        reducedTagCounter.incrementCount(reducedTag);

        reducedTagLemmaCounter.incrementCount(reducedTag + lemma);

        wordTagCounter.incrementCount(word + tag);
        morphTagCounter.incrementCount(morph + tag);
        morphCounter.incrementCount(morph);
        wordCounter.incrementCount(word);
        tagCounter.incrementCount(tag);

        // The two-dimensional counters below use "NONE" for an empty reduced tag; the
        // one-dimensional counters above were fed the raw (possibly empty) string.
        reducedTag = reducedTag.equals("") ? "NONE" : reducedTag;

        // Record the lemma for this word, creating the set on first sight of the word.
        Set<String> lemmasForWord = wordLemmaMap.get(word);
        if (lemmasForWord == null) {
          lemmasForWord = Generics.newHashSet(1);
          wordLemmaMap.put(word, lemmasForWord);
        }
        lemmasForWord.add(lemma);

        lemmaReducedTagCounter.incrementCount(lemma, reducedTag);
        reducedTagTagCounter.incrementCount(lemma + reducedTag, tag);
        tagReducedTagCounter.incrementCount(tag, reducedTag);
      }
      ++numTrees;
    }

    // Barf...
    System.out.println("Language: " + language.toString());
    System.out.printf("#trees:\t%d%n", numTrees);
    System.out.printf("#tokens:\t%d%n", (int) wordCounter.totalCount());
    System.out.printf("#words:\t%d%n", wordCounter.keySet().size());
    System.out.printf("#tags:\t%d%n", tagCounter.keySet().size());
    System.out.printf("#wordTagPairs:\t%d%n", wordTagCounter.keySet().size());
    System.out.printf("#lemmas:\t%d%n", lemmaCounter.keySet().size());
    System.out.printf("#lemmaTagPairs:\t%d%n", lemmaTagCounter.keySet().size());
    System.out.printf("#feattags:\t%d%n", reducedTagCounter.keySet().size());
    System.out.printf("#feattag+lemmas:\t%d%n", reducedTagLemmaCounter.keySet().size());
    System.out.printf("#richtags:\t%d%n", richTagCounter.keySet().size());
    System.out.printf("#richtag+lemma:\t%d%n", morphCounter.keySet().size());
    System.out.printf("#richtag+lemmaTagPairs:\t%d%n", morphTagCounter.keySet().size());

    // Extra
    System.out.println("==================");
    StringBuilder sbNoLemma = new StringBuilder();
    StringBuilder sbMultLemmas = new StringBuilder();
    for (Map.Entry<String, Set<String>> wordLemmas : wordLemmaMap.entrySet()) {
      String word = wordLemmas.getKey();
      Set<String> lemmas = wordLemmas.getValue();
      if (lemmas.size() == 0) {
        sbNoLemma.append("NO LEMMAS FOR WORD: " + word + "\n");
        continue;
      }
      if (lemmas.size() > 1) {
        sbMultLemmas.append("MULTIPLE LEMMAS: " + word + " " + setToString(lemmas) + "\n");
        continue;
      }
      // Exactly one lemma: report it if that lemma is seen with several reduced tags.
      String lemma = lemmas.iterator().next();
      Set<String> reducedTags = lemmaReducedTagCounter.getCounter(lemma).keySet();
      if (reducedTags.size() > 1) {
        System.out.printf("%s --> %s%n", word, lemma);
        for (String reducedTag : reducedTags) {
          int count = lemmaReducedTagCounter.getCount(lemma, reducedTag);
          String posTags =
              setToString(reducedTagTagCounter.getCounter(lemma + reducedTag).keySet());
          System.out.printf("\t%s\t%d\t%s%n", reducedTag, count, posTags);
        }
        System.out.println();
      }
    }
    System.out.println("==================");
    System.out.println(sbNoLemma.toString());
    System.out.println(sbMultLemmas.toString());
    System.out.println("==================");
    List<String> tags = new ArrayList<>(tagReducedTagCounter.firstKeySet());
    Collections.sort(tags);
    for (String tag : tags) {
      System.out.println(tag);
      Set<String> reducedTags = tagReducedTagCounter.getCounter(tag).keySet();
      for (String reducedTag : reducedTags) {
        int count = tagReducedTagCounter.getCount(tag, reducedTag);
        System.out.printf("\t%s\t%d%n", reducedTag, count);
      }
      System.out.println();
    }
    System.out.println("==================");
  }
コード例 #25
0
 /**
  * Adds an indicator feature named after the number of expression arguments, e.g.
  * {@code "2-narguments"}.
  *
  * @param features counter that receives the feature
  * @param e the mention/expression pair being featurised
  */
 public void addLengthBias(Counter<String> features, final NumericMentionExpression e) {
   int argumentCount = e.expression.arguments().size();
   String key = argumentCount + "-narguments";
   features.incrementCount(key);
 }
コード例 #26
0
 /**
  * Adds the log-ratio between the mention's normalised value and the expression's value, along
  * with its absolute value and its sign.
  *
  * @param features counter that receives the three features
  * @param e the mention/expression pair being featurised
  */
 public void addNumericProximity(Counter<String> features, final NumericMentionExpression e) {
   double logRatio = Math.log(e.mention_normalized_value / e.expression_value);
   features.incrementCount("numeric-distance", logRatio);
   features.incrementCount("abs-numeric-distance", Math.abs(logRatio));
   features.incrementCount("sign-numeric-distance", Math.signum(logRatio));
 }
コード例 #27
0
 /**
  * Adds one feature per dimension of a sentence's embedding, named
  * {@code prefix + "wv-" + dimension}.
  *
  * @param features counter that receives the features
  * @param prefix string prepended to every feature name
  * @param str the sentence to embed
  */
 protected void semanticFeatures(Counter<String> features, String prefix, Sentence str) {
   double[] embedding = embeddings.get(str);
   String keyStem = prefix + "wv-";
   for (int dim = 0; dim < embedding.length; dim++) {
     features.incrementCount(keyStem + dim, embedding[dim]);
   }
 }
コード例 #28
0
 /**
  * Adds one feature per dimension of the embedding of a mention's token span, named
  * {@code prefix + "wv-" + dimension}.
  *
  * @param features counter that receives the features
  * @param prefix string prepended to every feature name
  * @param mention mention whose sentence and token span are embedded
  */
 protected void semanticFeatures(Counter<String> features, String prefix, NumericMention mention) {
   double[] embedding =
       embeddings.get(mention.sentence.get(), mention.token_begin, mention.token_end);
   String keyStem = prefix + "wv-";
   for (int dim = 0; dim < embedding.length; dim++) {
     features.incrementCount(keyStem + dim, embedding[dim]);
   }
 }
コード例 #29
0
 /**
  * Adds the constant {@code "bias"} feature with an increment of 1.
  *
  * @param features counter that receives the feature
  * @param e unused
  */
 public void addBias(final Counter<String> features, final NumericMentionExpression e) {
   final String biasKey = "bias";
   features.incrementCount(biasKey, 1);
 }
コード例 #30
0
        /**
         * Featurizes one search transition (source state, action, target state).
         *
         * <p>Every feature name starts with the action's signature. The features cover the edge
         * taken, whether the source state is at the root (and the last edge taken otherwise) and,
         * when the target state has an edge, the other edges at its governor and dependent, the
         * number of edges at the dependent (bucketed as 0, 1, 2 or {@code >2}), whether subject or
         * object relations occur among those neighbors, and the POS tags of governor and dependent.
         *
         * @param triple the source state, the action taken, and the resulting state
         * @return a new counter of indicator features
         */
        @Override
        public Counter<String> apply(Triple<State, Action, State> triple) {
          // Variables
          State from = triple.first;
          Action action = triple.second;
          State to = triple.third;
          String signature = action.signature();

          String edgeRelTaken;
          String edgeRelShort;
          if (to.edge == null) {
            edgeRelTaken = "root";
            edgeRelShort = "root";
          } else {
            edgeRelTaken = to.edge.getRelation().toString();
            edgeRelShort = to.edge.getRelation().getShortName();
          }
          // Keep only the part of the short name before the first underscore.
          int underscoreAt = edgeRelShort.indexOf("_");
          if (underscoreAt >= 0) {
            edgeRelShort = edgeRelShort.substring(0, underscoreAt);
          }
          String typedSignature = signature + "&edge_type:" + edgeRelShort;

          // -- Featurize --
          Counter<String> feats = new ClassicCounter<>();

          // 1. edge taken
          feats.incrementCount(signature + "&edge:" + edgeRelTaken);
          feats.incrementCount(typedSignature);

          // 2. last edge taken
          if (from.edge == null) {
            assert to.edge == null || to.originalTree().getRoots().contains(to.edge.getGovernor());
            feats.incrementCount(signature + "&at_root");
            feats.incrementCount(
                signature + "&at_root&root_pos:" + to.originalTree().getFirstRoot().tag());
          } else {
            feats.incrementCount(signature + "&not_root");
            String lastRelShort = from.edge.getRelation().getShortName();
            int lastUnderscoreAt = lastRelShort.indexOf("_");
            if (lastUnderscoreAt >= 0) {
              lastRelShort = lastRelShort.substring(0, lastUnderscoreAt);
            }
            feats.incrementCount(signature + "&last_edge:" + lastRelShort);
          }

          // The remaining features all describe the target edge.
          if (to.edge == null) {
            return feats;
          }

          // 3. other edges at parent
          boolean parentHasSubj = false;
          boolean parentHasObj = false;
          for (SemanticGraphEdge parentNeighbor :
              from.originalTree().outgoingEdgeIterable(to.edge.getGovernor())) {
            if (parentNeighbor == to.edge) {
              // Skip the edge being taken itself (identity comparison).
              continue;
            }
            String parentNeighborRel = parentNeighbor.getRelation().toString();
            if (parentNeighborRel.contains("subj")) {
              parentHasSubj = true;
            }
            if (parentNeighborRel.contains("obj")) {
              parentHasObj = true;
            }
            // (add feature)
            feats.incrementCount(signature + "&parent_neighbor:" + parentNeighborRel);
            feats.incrementCount(typedSignature + "&parent_neighbor:" + parentNeighborRel);
          }

          // 4. Other edges at child
          boolean childHasSubj = false;
          boolean childHasObj = false;
          int childNeighborCount = 0;
          for (SemanticGraphEdge childNeighbor :
              from.originalTree().outgoingEdgeIterable(to.edge.getDependent())) {
            String childNeighborRel = childNeighbor.getRelation().toString();
            if (childNeighborRel.contains("subj")) {
              childHasSubj = true;
            }
            if (childNeighborRel.contains("obj")) {
              childHasObj = true;
            }
            childNeighborCount += 1;
            // (add feature)
            feats.incrementCount(signature + "&child_neighbor:" + childNeighborRel);
            feats.incrementCount(typedSignature + "&child_neighbor:" + childNeighborRel);
          }

          // 4.1 Number of other edges at child
          String countBucket =
              childNeighborCount < 3 ? String.valueOf(childNeighborCount) : ">2";
          feats.incrementCount(signature + "&child_neighbor_count:" + countBucket);
          feats.incrementCount(typedSignature + "&child_neighbor_count:" + countBucket);

          // 5. Subject/Object stats
          feats.incrementCount(signature + "&parent_neighbor_subj:" + parentHasSubj);
          feats.incrementCount(signature + "&parent_neighbor_obj:" + parentHasObj);
          feats.incrementCount(signature + "&child_neighbor_subj:" + childHasSubj);
          feats.incrementCount(signature + "&child_neighbor_obj:" + childHasObj);

          // 6. POS tag info
          feats.incrementCount(signature + "&parent_pos:" + to.edge.getGovernor().tag());
          feats.incrementCount(signature + "&child_pos:" + to.edge.getDependent().tag());
          String posSignature = to.edge.getGovernor().tag() + "_" + to.edge.getDependent().tag();
          feats.incrementCount(signature + "&pos_signature:" + posSignature);
          feats.incrementCount(typedSignature + "&pos_signature:" + posSignature);

          return feats;
        }