Пример #1
0
  /**
   * Command-line entry point: reads two weight files and prints one number describing how far
   * apart the two weight vectors are.
   *
   * <p>An optional leading flag selects the measure. {@code -cosine} prints the value returned by
   * {@code Counters.cosine}; {@code -sse} prints the sum of squared per-feature differences. With
   * no flag, both vectors are first multiplied by the reciprocal of their {@code Counters.L1Norm}
   * and the largest absolute per-feature difference is printed.
   *
   * @param args an optional measure flag followed by exactly two weight-file paths
   * @throws Exception anything thrown by the calls made here (presumably by reading the files)
   */
  public static void main(String[] args) throws Exception {
    if (args.length == 0) {
      usage();
      System.exit(-1);
    }

    // Strip the optional measure flag so that only the two file names are left in args.
    final String firstArg = args[0];
    CompareType compareType;
    if ("-cosine".equals(firstArg)) {
      compareType = CompareType.COSINE;
      args = Arrays.copyOfRange(args, 1, args.length);
    } else if ("-sse".equals(firstArg)) {
      compareType = CompareType.SUM_SQUARE_ERROR;
      args = Arrays.copyOfRange(args, 1, args.length);
    } else {
      compareType = CompareType.MAX_COOR_ABS;
    }

    if (args.length != 2) {
      usage();
      System.exit(-1);
    }

    // Union of the feature names that occur in either file.
    Set<String> allWeights = new HashSet<String>();
    Counter<String> first = IOTools.readWeights(args[0]);
    Counter<String> second = IOTools.readWeights(args[1]);
    allWeights.addAll(first.keySet());
    allWeights.addAll(second.keySet());

    if (compareType == CompareType.COSINE) {
      System.out.println(Counters.cosine(first, second));
    } else if (compareType == CompareType.SUM_SQUARE_ERROR) {
      double sumSquares = 0;
      for (String name : allWeights) {
        double delta = first.getCount(name) - second.getCount(name);
        sumSquares += delta * delta;
      }
      System.out.println(sumSquares);
    } else if (compareType == CompareType.MAX_COOR_ABS) {
      Counters.multiplyInPlace(first, 1.0 / Counters.L1Norm(first));
      Counters.multiplyInPlace(second, 1.0 / Counters.L1Norm(second));
      double largest = Double.NEGATIVE_INFINITY;
      for (String name : allWeights) {
        double gap = Math.abs(first.getCount(name) - second.getCount(name));
        // Strict comparison on purpose: a NaN gap never replaces the running maximum.
        largest = gap > largest ? gap : largest;
      }
      System.out.println(largest);
    }
  }
Пример #2
0
 /**
  * Returns the set of all modes (most frequent items) in the Collection. If several items share
  * the highest frequency, all of them are returned. An empty Collection yields an empty set
  * instead of failing.
  *
  * @param values the items to examine
  * @return the key set of the counter built from {@code values} after the less frequent items
  *     have been filtered out
  */
 public static <T> Set<T> modes(Collection<T> values) {
   Counter<T> counter = new ClassicCounter<T>(values);
   if (counter.keySet().isEmpty()) {
     // Nothing was counted, so there is no highest count to look up.
     return counter.keySet();
   }
   // One pass finds the highest count; sorting every count just to read the last is unnecessary.
   double highestCount = Double.NEGATIVE_INFINITY;
   for (double count : counter.values()) {
     if (count > highestCount) {
       highestCount = count;
     }
   }
   // Presumably drops every entry whose count is below the threshold, keeping ties for the top.
   Counters.retainAbove(counter, highestCount);
   return counter.keySet();
 }
Пример #3
0
 /**
  * Perturbs a weights vector in place: every existing feature gets its own non-negative random
  * offset, drawn as {@code Math.random() * scale}, added to its current value.
  *
  * @param wts the weights vector to modify; only keys already present are touched
  * @param scale multiplier applied to each {@code Math.random()} draw
  */
 public static void randomizeWeightsInPlace(Counter<String> wts, double scale) {
   for (String name : wts.keySet()) {
     final double noise = Math.random() * scale;
     wts.setCount(name, wts.getCount(name) + noise);
   }
 }
 /**
  * Trains a classifier from a list of RVFDatum. The datums are assumed to contain the zeroes as
  * well, so the number of features is taken from the first example.
  *
  * <p>Each datum's feature counts are truncated to {@code int} and written into a dense
  * {@code data[example][feature]} matrix. As a side effect, {@code labelIndex} and
  * {@code featureIndex} are replaced with fresh indices.
  *
  * @param examples the training examples; must contain at least one element
  * @return the classifier produced by the array-based {@code trainClassifier} overload
  */
 @Override
 @Deprecated
 public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
   RVFDatum<L, F> d0 = examples.get(0);
   int numFeatures = d0.asFeatures().size();
   int[][] data = new int[examples.size()][numFeatures];
   int[] labels = new int[examples.size()];
   labelIndex = new HashIndex<L>();
   featureIndex = new HashIndex<F>();
   for (int d = 0; d < examples.size(); d++) {
     RVFDatum<L, F> datum = examples.get(d);
     Counter<F> c = datum.asFeaturesCounter();
     for (F feature : c.keySet()) {
       // Register the feature, then look its index up regardless of what add() returned.
       // Gating on the return value stored the value only for the datum that first introduced
       // the feature whenever add() reports "newly added", leaving later rows at zero.
       featureIndex.add(feature);
       int fNo = featureIndex.indexOf(feature);
       // Presumably indexOf is negative when the feature could not be registered; skip then.
       if (fNo >= 0) {
         data[d][fNo] = (int) c.getCount(feature);
       }
     }
     labelIndex.add(datum.label());
     labels[d] = labelIndex.indexOf(datum.label());
   }
   int numClasses = labelIndex.size();
   return trainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex);
 }
 /**
  * Lays out the classifier's per-label probabilities for one datum as a dense array.
  *
  * @param datum the datum handed to {@code classifier.probabilityOf}
  * @return an array of length {@code labeledDataset.numClasses()}; the slot at a label's position
  *     in {@code labeledDataset.labelIndex} holds that label's value, and slots for labels the
  *     classifier did not report stay at 0.0
  */
 private double[] getModelProbs(Datum<L, F> datum) {
   double[] distribution = new double[labeledDataset.numClasses()];
   Counter<L> labelProbs = classifier.probabilityOf(datum);
   for (L candidate : labelProbs.keySet()) {
     distribution[labeledDataset.labelIndex.indexOf(candidate)] = labelProbs.getCount(candidate);
   }
   return distribution;
 }
 /**
  * Tells whether any feature name in the counter begins with the prefix {@code "simple&"}.
  *
  * @param feats the features whose keys are inspected
  * @return {@code true} as soon as one key carries the prefix, {@code false} if none does
  */
 @Override
 public boolean isSimpleSplit(Counter<String> feats) {
   boolean sawSimple = false;
   for (String name : feats.keySet()) {
     if (name.startsWith("simple&")) {
       sawSimple = true;
       break;
     }
   }
   return sawSimple;
 }
Пример #7
0
  /**
   * Expands every (rule, history) pair recorded in {@code rulePairs} into history-specific rules
   * and collects them into a unary and a binary grammar.
   *
   * <p>For each prefix of the history (up to {@code HISTORY_DEPTH()} elements) the ratio of the
   * rule count to the label count, weighted by {@code 1 / HISTORY_DEPTH()}, is added to a running
   * total. Every rule produced by {@code specifyRule} for that prefix is scored with the log of
   * the total accumulated so far. Progress is reported on stderr every 100 pairs.
   *
   * @return a {@code Pair} whose first element is the unary grammar and whose second element is
   *     the binary grammar
   */
  public Object formResult() {
    Set binaryRules = new HashSet();
    Set unaryRules = new HashSet();

    int processed = 0;
    for (Object key : rulePairs.keySet()) {
      if (processed % 100 == 0) {
        System.err.println("Rules multiplied: " + processed);
      }
      processed++;

      Pair rulePair = (Pair) key;
      Rule baseRule = (Rule) rulePair.first;
      String baseLabel = (String) ruleToLabel.get(baseRule);
      List history = (List) rulePair.second;

      double accumulated = 0;
      for (int len = 1; len <= HISTORY_DEPTH() && len <= history.size(); len++) {
        List prefix = history.subList(0, len);
        double labelCount = labelPairs.getCount(new Pair(baseLabel, prefix));
        double ruleCount = rulePairs.getCount(new Pair(baseRule, prefix));

        // Each history length contributes an equally weighted relative frequency.
        accumulated += (1.0 / HISTORY_DEPTH()) * (ruleCount) / (labelCount);

        int deepestChild = Math.min(HISTORY_DEPTH() - 1, len);
        for (int childLen = 0; childLen <= deepestChild; childLen++) {
          Rule specific = specifyRule(baseRule, prefix, childLen);
          specific.score = (float) Math.log(accumulated);
          Set target = (specific instanceof UnaryRule) ? unaryRules : binaryRules;
          target.add(specific);
        }
      }
    }

    System.out.println("Total states: " + stateNumberer.total());
    BinaryGrammar bg = new BinaryGrammar(stateNumberer.total());
    UnaryGrammar ug = new UnaryGrammar(stateNumberer.total());
    for (Object rule : binaryRules) {
      bg.addRule((BinaryRule) rule);
    }
    for (Object rule : unaryRules) {
      ug.addRule((UnaryRule) rule);
    }
    return new Pair(ug, bg);
  }
Пример #8
0
 /**
  * Computes a score for every known label on the given example.
  *
  * <p>The result starts from {@code priors} (plus {@code priorZero} when {@code addZeroValued} is
  * set). For each label, the {@code weight} of every feature of the example at its observed
  * value (count truncated to {@code int}) is added; when {@code addZeroValued} is set, the
  * weight of the same feature at {@code zero} is subtracted again.
  *
  * @param example the datum to score
  * @return a counter mapping each label to its accumulated score
  */
 public ClassicCounter<L> scoresOf(RVFDatum<L, F> example) {
   ClassicCounter<L> result = new ClassicCounter<>();
   Counters.addInPlace(result, priors);
   if (addZeroValued) {
     Counters.addInPlace(result, priorZero);
   }

   for (L label : labels) {
     double labelScore = 0.0;
     Counter<F> observed = example.asFeaturesCounter();
     for (F feature : observed.keySet()) {
       int featureValue = (int) observed.getCount(feature);
       labelScore += weight(label, feature, Integer.valueOf(featureValue));
       if (addZeroValued) {
         labelScore -= weight(label, feature, zero);
       }
     }
     result.incrementCount(label, labelScore);
   }
   return result;
 }
Пример #9
0
 /**
  * Lists the feature names of a weights vector in their natural (sorted) String order.
  *
  * @param wts the weights vector whose keys are read; it is not modified
  * @return a new array holding every key of {@code wts}, sorted ascending
  */
 public static String[] getWeightNamesFromCounter(Counter<String> wts) {
   List<String> sortedNames = new ArrayList<String>(wts.keySet());
   Collections.sort(sortedNames);
   String[] result = new String[sortedNames.size()];
   return sortedNames.toArray(result);
 }
Пример #10
0
  /**
   * Scores every pattern for the current label with the measure selected by
   * {@code patternScoring}.
   *
   * <p>Three per-pattern counts are gathered first: the {@code size()} of the pattern's entry in
   * {@code patternsandWords4Label} (positive), {@code negPatternsandWords4Label} (negative) and
   * {@code unLabeledPatternsandWords4Label} (unlabeled). The RlogF variants combine a ratio of
   * those counts with the log of the positive count via {@code Counters.division} and
   * {@code Counters.product}; they differ only in the denominator. YanGarber02 and LinICML03
   * additionally restrict the result to the patterns left after {@code Counters.retainAbove} is
   * applied with 0.8 to pos / (pos + neg).
   *
   * @return a counter from pattern to its score
   * @throws RuntimeException if {@code patternScoring} is not one of the handled measures
   */
  @Override
  public Counter<E> score() {
    Counter<E> posCounts = new ClassicCounter<>();
    Counter<E> negCounts = new ClassicCounter<>();
    Counter<E> unlabCounts = new ClassicCounter<>();

    for (Entry<E, ClassicCounter<CandidatePhrase>> entry : negPatternsandWords4Label.entrySet()) {
      negCounts.setCount(entry.getKey(), entry.getValue().size());
    }
    for (Entry<E, ClassicCounter<CandidatePhrase>> entry :
        unLabeledPatternsandWords4Label.entrySet()) {
      unlabCounts.setCount(entry.getKey(), entry.getValue().size());
    }
    for (Entry<E, ClassicCounter<CandidatePhrase>> entry : patternsandWords4Label.entrySet()) {
      posCounts.setCount(entry.getKey(), entry.getValue().size());
    }

    // Denominators shared by several measures: pos + neg + unlabeled, and pos + neg.
    Counter<E> allCounts = Counters.add(posCounts, negCounts);
    allCounts.addAll(unlabCounts);
    Counter<E> posNegCounts = Counters.add(posCounts, negCounts);

    // Log of the positive counts, the common second factor of every measure.
    Counter<E> logPos = new ClassicCounter<>(posCounts);
    Counters.logInPlace(logPos);

    final double thetaPrecision = 0.8;
    Counter<E> weights;
    if (patternScoring.equals(PatternScoring.RlogF)) {
      weights = Counters.product(Counters.division(posCounts, allCounts), logPos);
    } else if (patternScoring.equals(PatternScoring.RlogFPosNeg)) {
      Redwood.log("extremePatDebug", "computing rlogfposneg");
      weights = Counters.product(Counters.division(posCounts, posNegCounts), logPos);
    } else if (patternScoring.equals(PatternScoring.RlogFUnlabNeg)) {
      Redwood.log("extremePatDebug", "computing rlogfunlabeg");
      Counter<E> negPlusUnlab = Counters.add(negCounts, unlabCounts);
      weights = Counters.product(Counters.division(posCounts, negPlusUnlab), logPos);
    } else if (patternScoring.equals(PatternScoring.RlogFNeg)) {
      Redwood.log("extremePatDebug", "computing rlogfneg");
      weights = Counters.product(Counters.division(posCounts, negCounts), logPos);
    } else if (patternScoring.equals(PatternScoring.YanGarber02)) {
      Counter<E> precision = Counters.division(posCounts, Counters.add(posCounts, negCounts));
      Counters.retainAbove(precision, thetaPrecision);
      Counter<E> confidence = Counters.product(Counters.division(posCounts, allCounts), logPos);
      weights = new ClassicCounter<>();
      for (E pattern : precision.keySet()) {
        weights.setCount(pattern, confidence.getCount(pattern));
      }
    } else if (patternScoring.equals(PatternScoring.LinICML03)) {
      Counter<E> precision = Counters.division(posCounts, Counters.add(posCounts, negCounts));
      Counters.retainAbove(precision, thetaPrecision);
      // Numerator here is pos - neg rather than pos.
      Counter<E> posMinusNeg = Counters.add(posCounts, Counters.scale(negCounts, -1));
      Counter<E> confidence = Counters.product(Counters.division(posMinusNeg, allCounts), logPos);
      weights = new ClassicCounter<>();
      for (E pattern : precision.keySet()) {
        weights.setCount(pattern, confidence.getCount(pattern));
      }
    } else {
      throw new RuntimeException("not implemented " + patternScoring + " . check spelling!");
    }
    return weights;
  }
  /**
   * Translates the stored objective, constraints and variable bounds into the array form taken
   * by {@code callILP3}, runs it, and maps the solution back to variable names.
   *
   * <p>Variables are numbered in order of first appearance, objective first and then each
   * constraint. Every variable gets a ranged bound ({@code boundkey.ra}) taken from
   * {@code variables_lowerupper} when present and from {@code default_variable_lowerupper}
   * otherwise.
   *
   * @return a map from variable name to the value at that variable's position in the array
   *     returned by {@code callILP3}
   */
  public HashMap<String, Double> run() {
    // Assign dense ids to all variable names.
    HashBiMap<String, Integer> variable2id = HashBiMap.create();
    for (Entry<String, Double> term : objective.entrySet()) {
      String name = term.getKey();
      if (!variable2id.containsKey(name)) {
        variable2id.put(name, variable2id.size());
      }
    }
    for (Counter<String> constraint : constraints) {
      for (Entry<String, Double> term : constraint.entrySet()) {
        String name = term.getKey();
        if (!variable2id.containsKey(name)) {
          variable2id.put(name, variable2id.size());
        }
      }
    }
    final int numVars = variable2id.size();
    final int numCons = constraints.size();

    // Objective coefficients, indexed by variable id.
    double[] objectiveCoeffs = new double[numVars];
    for (Entry<String, Double> term : objective.entrySet()) {
      int id = variable2id.get(term.getKey());
      objectiveCoeffs[id] = term.getValue();
    }

    // Constraint rows in sparse form: per row, the variable ids and their coefficients.
    mosek.Env.boundkey[] rowBoundKinds = new mosek.Env.boundkey[numCons];
    double[] rowLower = new double[numCons];
    double[] rowUpper = new double[numCons];
    int[][] rowIds = new int[numCons][];
    double[][] rowCoeffs = new double[numCons][];
    int numNonZeros = 0;
    for (int row = 0; row < numCons; row++) {
      rowBoundKinds[row] = constraints_bound_type_list.get(row);
      rowLower[row] = constraints_lowerbound_list.get(row);
      rowUpper[row] = constraints_upperbound_list.get(row);
      Counter<String> constraint = constraints.get(row);
      rowIds[row] = new int[constraint.keySet().size()];
      rowCoeffs[row] = new double[constraint.keySet().size()];
      int col = 0;
      for (Entry<String, Double> term : constraint.entrySet()) {
        int id = variable2id.get(term.getKey());
        rowIds[row][col] = id;
        rowCoeffs[row][col] = term.getValue();
        col++;
      }
      numNonZeros += col;
    }

    // Per-variable bounds.
    mosek.Env.boundkey[] varBoundKinds = new mosek.Env.boundkey[numVars];
    double[] varLower = new double[numVars];
    double[] varUpper = new double[numVars];
    for (Entry<String, Integer> variable : variable2id.entrySet()) {
      String name = variable.getKey();
      int id = variable.getValue();
      double[] bounds =
          variables_lowerupper.containsKey(name)
              ? variables_lowerupper.get(name)
              : default_variable_lowerupper;
      varBoundKinds[id] = mosek.Env.boundkey.ra;
      varLower[id] = bounds[0];
      varUpper[id] = bounds[1];
    }

    double[] solution =
        callILP3(
            numVars,
            numCons,
            numNonZeros,
            objectiveCoeffs,
            varBoundKinds,
            varLower,
            varUpper,
            rowIds,
            rowCoeffs,
            rowBoundKinds,
            rowLower,
            rowUpper,
            false);

    HashMap<String, Double> var2val = new HashMap<String, Double>();
    for (Entry<String, Integer> variable : variable2id.entrySet()) {
      int id = variable.getValue();
      var2val.put(variable.getKey(), solution[id]);
    }
    return var2val;
  }
  /**
   * Prints lexicon statistics for a treebank: token/type counts for words, tags, lemmas and
   * morphological tags, followed by per-word lemma diagnostics and a tag-to-reduced-tag table.
   *
   * <p>Arabic uses the {@code -arabicFactored} option and the Arabic morphological
   * specification; every other language is treated as French.
   *
   * @param args exactly three values: the {@code Language} name, the treebank path, and a
   *     comma-separated list of {@code MorphoFeatureType} names to activate
   */
  public static void main(String[] args) {
    if (args.length != 3) {
      System.err.printf(
          "Usage: java %s language filename features%n",
          TreebankFactoredLexiconStats.class.getName());
      System.exit(-1);
    }

    Language language = Language.valueOf(args[0]);
    TreebankLangParserParams tlpp = language.params;
    if (language.equals(Language.Arabic)) {
      String[] options = {"-arabicFactored"};
      tlpp.setOptionFlag(options, 0);
    } else {
      String[] options = {"-frenchFactored"};
      tlpp.setOptionFlag(options, 0);
    }
    Treebank tb = tlpp.diskTreebank();
    tb.loadPath(args[1]);

    MorphoFeatureSpecification morphoSpec =
        language.equals(Language.Arabic)
            ? new ArabicMorphoFeatureSpecification()
            : new FrenchMorphoFeatureSpecification();

    String[] features = args[2].trim().split(",");
    for (String feature : features) {
      morphoSpec.activate(MorphoFeatureType.valueOf(feature));
    }

    // Counters
    Counter<String> wordTagCounter = new ClassicCounter<>(30000);
    Counter<String> morphTagCounter = new ClassicCounter<>(500);
    Counter<String> morphCounter = new ClassicCounter<>(500);
    Counter<String> wordCounter = new ClassicCounter<>(30000);
    Counter<String> tagCounter = new ClassicCounter<>(300);

    Counter<String> lemmaCounter = new ClassicCounter<>(25000);
    Counter<String> lemmaTagCounter = new ClassicCounter<>(25000);

    Counter<String> richTagCounter = new ClassicCounter<>(1000);

    Counter<String> reducedTagCounter = new ClassicCounter<>(500);

    Counter<String> reducedTagLemmaCounter = new ClassicCounter<>(500);

    // Every lemma observed for each surface word.
    Map<String, Set<String>> wordLemmaMap = Generics.newHashMap();

    TwoDimensionalIntCounter<String, String> lemmaReducedTagCounter =
        new TwoDimensionalIntCounter<>(30000);
    TwoDimensionalIntCounter<String, String> reducedTagTagCounter =
        new TwoDimensionalIntCounter<>(500);
    TwoDimensionalIntCounter<String, String> tagReducedTagCounter =
        new TwoDimensionalIntCounter<>(300);

    int numTrees = 0;
    for (Tree tree : tb) {
      for (Tree subTree : tree) {
        if (!subTree.isLeaf()) {
          tlpp.transformTree(subTree, tree);
        }
      }
      List<Label> pretermList = tree.preTerminalYield();
      List<Label> yield = tree.yield();
      assert yield.size() == pretermList.size();

      int yieldLen = yield.size();
      for (int i = 0; i < yieldLen; ++i) {
        String tag = pretermList.get(i).value();

        String word = yield.get(i).value();
        String morph = ((CoreLabel) yield.get(i)).originalText();

        // Note: if there is no lemma, then we use the surface form.
        Pair<String, String> lemmaTag = MorphoFeatureSpecification.splitMorphString(word, morph);
        String lemma = lemmaTag.first();
        String richTag = lemmaTag.second();

        // WSGDEBUG
        if (tag.contains("MW")) lemma += "-MWE";

        lemmaCounter.incrementCount(lemma);
        lemmaTagCounter.incrementCount(lemma + tag);

        richTagCounter.incrementCount(richTag);

        String reducedTag = morphoSpec.strToFeatures(richTag).toString();
        reducedTagCounter.incrementCount(reducedTag);

        reducedTagLemmaCounter.incrementCount(reducedTag + lemma);

        wordTagCounter.incrementCount(word + tag);
        morphTagCounter.incrementCount(morph + tag);
        morphCounter.incrementCount(morph);
        wordCounter.incrementCount(word);
        tagCounter.incrementCount(tag);

        reducedTag = reducedTag.equals("") ? "NONE" : reducedTag;

        // Record the lemma on the first sighting of the word too. Previously the first
        // occurrence only created an empty set, so the first lemma was lost and words seen
        // once were later reported as having no lemmas.
        Set<String> lemmasForWord = wordLemmaMap.get(word);
        if (lemmasForWord == null) {
          lemmasForWord = Generics.newHashSet(1);
          wordLemmaMap.put(word, lemmasForWord);
        }
        lemmasForWord.add(lemma);

        lemmaReducedTagCounter.incrementCount(lemma, reducedTag);
        reducedTagTagCounter.incrementCount(lemma + reducedTag, tag);
        tagReducedTagCounter.incrementCount(tag, reducedTag);
      }
      ++numTrees;
    }

    // Barf...
    System.out.println("Language: " + language.toString());
    System.out.printf("#trees:\t%d%n", numTrees);
    System.out.printf("#tokens:\t%d%n", (int) wordCounter.totalCount());
    System.out.printf("#words:\t%d%n", wordCounter.keySet().size());
    System.out.printf("#tags:\t%d%n", tagCounter.keySet().size());
    System.out.printf("#wordTagPairs:\t%d%n", wordTagCounter.keySet().size());
    System.out.printf("#lemmas:\t%d%n", lemmaCounter.keySet().size());
    System.out.printf("#lemmaTagPairs:\t%d%n", lemmaTagCounter.keySet().size());
    System.out.printf("#feattags:\t%d%n", reducedTagCounter.keySet().size());
    System.out.printf("#feattag+lemmas:\t%d%n", reducedTagLemmaCounter.keySet().size());
    System.out.printf("#richtags:\t%d%n", richTagCounter.keySet().size());
    System.out.printf("#richtag+lemma:\t%d%n", morphCounter.keySet().size());
    System.out.printf("#richtag+lemmaTagPairs:\t%d%n", morphTagCounter.keySet().size());

    // Extra
    System.out.println("==================");
    StringBuilder sbNoLemma = new StringBuilder();
    StringBuilder sbMultLemmas = new StringBuilder();
    for (Map.Entry<String, Set<String>> wordLemmas : wordLemmaMap.entrySet()) {
      String word = wordLemmas.getKey();
      Set<String> lemmas = wordLemmas.getValue();
      if (lemmas.size() == 0) {
        sbNoLemma.append("NO LEMMAS FOR WORD: " + word + "\n");
        continue;
      }
      if (lemmas.size() > 1) {
        sbMultLemmas.append("MULTIPLE LEMMAS: " + word + " " + setToString(lemmas) + "\n");
        continue;
      }
      // Exactly one lemma: list its reduced tags when there is more than one of them.
      String lemma = lemmas.iterator().next();
      Set<String> reducedTags = lemmaReducedTagCounter.getCounter(lemma).keySet();
      if (reducedTags.size() > 1) {
        System.out.printf("%s --> %s%n", word, lemma);
        for (String reducedTag : reducedTags) {
          int count = lemmaReducedTagCounter.getCount(lemma, reducedTag);
          String posTags =
              setToString(reducedTagTagCounter.getCounter(lemma + reducedTag).keySet());
          System.out.printf("\t%s\t%d\t%s%n", reducedTag, count, posTags);
        }
        System.out.println();
      }
    }
    System.out.println("==================");
    System.out.println(sbNoLemma.toString());
    System.out.println(sbMultLemmas.toString());
    System.out.println("==================");
    List<String> tags = new ArrayList<>(tagReducedTagCounter.firstKeySet());
    Collections.sort(tags);
    for (String tag : tags) {
      System.out.println(tag);
      Set<String> reducedTags = tagReducedTagCounter.getCounter(tag).keySet();
      for (String reducedTag : reducedTags) {
        int count = tagReducedTagCounter.getCount(tag, reducedTag);
        System.out.printf("\t%s\t%d%n", reducedTag, count);
      }
      System.out.println();
    }
    System.out.println("==================");
  }