예제 #1
0
  /**
   * Builds a unary and a binary grammar from the collected (rule, history) statistics.
   *
   * <p>For every (rule, history) key of {@code rulePairs}, and for each history prefix length
   * {@code depth} from 1 to {@code min(HISTORY_DEPTH(), history.size())}, this adds
   * {@code (1 / HISTORY_DEPTH()) * count(rule, prefix) / count(label, prefix)} to a running total.
   * For each {@code childDepth} in {@code [0, min(HISTORY_DEPTH() - 1, depth)]} it then creates a
   * specialized rule via {@code specifyRule} whose score is {@code log(runningTotal)}.
   *
   * <p>Because every depth is weighted by {@code 1 / HISTORY_DEPTH()}, histories shorter than
   * {@code HISTORY_DEPTH()} accumulate a total of less than the full interpolation weight.
   *
   * <p>Rules are gathered into hash sets and then added to new grammars sized by
   * {@code stateNumberer.total()}.
   *
   * <p>NOTE(review): if {@code count(label, prefix)} is 0, the probability becomes
   * infinite/NaN. Presumably every rule's label was counted with the same histories — confirm.
   *
   * <p>NOTE(review): if {@code specifyRule} returns rules that compare equal across depths, the
   * HashSet keeps the first instance added. Confirm that this is the intended score.
   *
   * @return a {@code Pair} of (UnaryGrammar, BinaryGrammar)
   */
  public Object formResult() {
    Set brs = new HashSet();
    Set urs = new HashSet();
    // scan each rule / history pair
    int ruleCount = 0;
    for (Iterator pairI = rulePairs.keySet().iterator(); pairI.hasNext(); ) {
      // progress report to stderr every 100 rules (printed before the counter is incremented)
      if (ruleCount % 100 == 0) {
        System.err.println("Rules multiplied: " + ruleCount);
      }
      ruleCount++;
      Pair rulePair = (Pair) pairI.next();
      Rule baseRule = (Rule) rulePair.first;
      String baseLabel = (String) ruleToLabel.get(baseRule);
      List history = (List) rulePair.second;
      // running interpolated probability over increasing history prefixes
      double totalProb = 0;
      for (int depth = 1; depth <= HISTORY_DEPTH() && depth <= history.size(); depth++) {
        List subHistory = history.subList(0, depth);
        double c_label = labelPairs.getCount(new Pair(baseLabel, subHistory));
        double c_rule = rulePairs.getCount(new Pair(baseRule, subHistory));
        // System.out.println("Multiplying out "+baseRule+" with history "+subHistory);
        // System.out.println("Count of "+baseLabel+" with "+subHistory+" is "+c_label);
        // System.out.println("Count of "+baseRule+" with "+subHistory+" is "+c_rule );

        // relative frequency of the rule given its label at this prefix, weighted 1/HISTORY_DEPTH()
        double prob = (1.0 / HISTORY_DEPTH()) * (c_rule) / (c_label);
        totalProb += prob;
        for (int childDepth = 0; childDepth <= Math.min(HISTORY_DEPTH() - 1, depth); childDepth++) {
          Rule rule = specifyRule(baseRule, subHistory, childDepth);
          // score is the log of the running total, not of this depth's term alone
          rule.score = (float) Math.log(totalProb);
          // System.out.println("Created  "+rule+" with score "+rule.score);
          // anything that is not a UnaryRule is treated as a BinaryRule (cast below)
          if (rule instanceof UnaryRule) {
            urs.add(rule);
          } else {
            brs.add(rule);
          }
        }
      }
    }
    System.out.println("Total states: " + stateNumberer.total());
    BinaryGrammar bg = new BinaryGrammar(stateNumberer.total());
    UnaryGrammar ug = new UnaryGrammar(stateNumberer.total());
    for (Iterator brI = brs.iterator(); brI.hasNext(); ) {
      BinaryRule br = (BinaryRule) brI.next();
      bg.addRule(br);
    }
    for (Iterator urI = urs.iterator(); urI.hasNext(); ) {
      UnaryRule ur = (UnaryRule) urI.next();
      ug.addRule(ur);
    }
    return new Pair(ug, bg);
  }
예제 #2
0
  /**
   * Compares two weight files and prints a single number describing how far apart they are.
   *
   * <p>Usage: {@code [-cosine | -sse] weightsFile1 weightsFile2}. Without a flag, both vectors
   * are L1-normalized and the largest absolute per-feature difference is printed.
   * {@code -cosine} prints their cosine similarity, and {@code -sse} prints the sum of squared
   * per-feature differences of the raw (unnormalized) weights.
   *
   * @param args optional metric flag followed by exactly two weight-file paths
   * @throws Exception if a weight file cannot be read
   */
  public static void main(String[] args) throws Exception {
    if (args.length == 0) {
      usage();
      System.exit(-1);
    }

    // An optional leading flag picks the metric and is consumed from the argument list.
    CompareType compareType = CompareType.MAX_COOR_ABS;
    String[] files = args;
    String flag = args[0];
    if ("-cosine".equals(flag)) {
      compareType = CompareType.COSINE;
      files = Arrays.copyOfRange(args, 1, args.length);
    } else if ("-sse".equals(flag)) {
      compareType = CompareType.SUM_SQUARE_ERROR;
      files = Arrays.copyOfRange(args, 1, args.length);
    }

    if (files.length != 2) {
      usage();
      System.exit(-1);
    }

    Counter<String> wts1 = IOTools.readWeights(files[0]);
    Counter<String> wts2 = IOTools.readWeights(files[1]);
    // Union of feature names from both files, so features present on only one side still count.
    Set<String> allWeights = new HashSet<String>();
    allWeights.addAll(wts1.keySet());
    allWeights.addAll(wts2.keySet());

    if (compareType == CompareType.COSINE) {
      System.out.println(Counters.cosine(wts1, wts2));
    } else if (compareType == CompareType.SUM_SQUARE_ERROR) {
      double sumSq = 0;
      for (String name : allWeights) {
        double delta = wts1.getCount(name) - wts2.getCount(name);
        sumSq += delta * delta;
      }
      System.out.println(sumSq);
    } else if (compareType == CompareType.MAX_COOR_ABS) {
      Counters.multiplyInPlace(wts1, 1.0 / Counters.L1Norm(wts1));
      Counters.multiplyInPlace(wts2, 1.0 / Counters.L1Norm(wts2));
      double largest = Double.NEGATIVE_INFINITY;
      for (String name : allWeights) {
        double gap = Math.abs(wts1.getCount(name) - wts2.getCount(name));
        if (gap > largest) {
          largest = gap;
        }
      }
      System.out.println(largest);
    }
  }
예제 #3
0
 /**
  * Perturbs every weight already present in {@code wts}, in place.
  *
  * <p>Each weight gets an independent offset of {@code Math.random() * scale}. For a
  * non-negative {@code scale}, that offset lies in {@code [0, scale)}.
  *
  * @param wts weights to perturb; modified in place (no keys are added or removed)
  * @param scale multiplier applied to each uniform random draw
  */
 public static void randomizeWeightsInPlace(Counter<String> wts, double scale) {
   for (String name : wts.keySet()) {
     double noise = Math.random() * scale;
     double perturbed = wts.getCount(name) + noise;
     wts.setCount(name, perturbed);
   }
 }
예제 #4
0
 /**
  * Copies the weights named in {@code weightNames} into a dense array, in the same order.
  *
  * <p>Names absent from {@code wts} receive whatever {@code wts.getCount} returns for them.
  *
  * @param weightNames feature names defining the array positions
  * @param wts source of the weight values
  * @return array whose i-th entry is {@code wts.getCount(weightNames[i])}
  */
 public static double[] getWeightArrayFromCounter(String[] weightNames, Counter<String> wts) {
   double[] dense = new double[weightNames.length];
   int pos = 0;
   for (String name : weightNames) {
     dense[pos++] = wts.getCount(name);
   }
   return dense;
 }
 /**
  * Trains a Naive Bayes classifier from real-valued datums.
  *
  * <p>Each feature value is truncated to an {@code int} and packed into a dense data matrix.
  * The examples are assumed to be a list of {@link RVFDatum}. The datums are assumed to contain
  * the zeroes as well, i.e. every datum carries the same feature set as the first one. The
  * row width of the data matrix is taken from the first example, so a feature first seen in a
  * later datum beyond that width would cause an {@link ArrayIndexOutOfBoundsException}.
  *
  * <p>This method replaces the {@code labelIndex} and {@code featureIndex} fields with freshly
  * built indices.
  *
  * @param examples training data; must be non-empty ({@code examples.get(0)} is read)
  * @return the trained classifier
  */
 @Override
 @Deprecated
 public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
   RVFDatum<L, F> d0 = examples.get(0);
   int numFeatures = d0.asFeatures().size();
   int[][] data = new int[examples.size()][numFeatures];
   int[] labels = new int[examples.size()];
   labelIndex = new HashIndex<L>();
   featureIndex = new HashIndex<F>();
   for (int d = 0; d < examples.size(); d++) {
     RVFDatum<L, F> datum = examples.get(d);
     Counter<F> c = datum.asFeaturesCounter();
     for (F feature : c.keySet()) {
       // Index the feature if it is new, but record its value for EVERY datum.
       // Gating the store on add() returning true would drop the values of features
       // that earlier datums had already put into the index.
       featureIndex.add(feature);
       int fNo = featureIndex.indexOf(feature);
       data[d][fNo] = (int) c.getCount(feature);
     }
     labelIndex.add(datum.label());
     labels[d] = labelIndex.indexOf(datum.label());
   }
   int numClasses = labelIndex.size();
   return trainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex);
 }
예제 #6
0
 /**
  * Computes the linear model score of a translation.
  *
  * <p>The score is the sum over its features of {@code value * weight}, where each weight is
  * looked up by feature name in {@code wts}.
  *
  * @param wts feature weights
  * @param trans translation whose {@code features} are scored
  * @return the dot product of the translation's feature values with {@code wts}
  */
 public static double scoreTranslation(
     Counter<String> wts, ScoredFeaturizedTranslation<IString, String> trans) {
   double total = 0;
   for (FeatureValue<String> feat : trans.features) {
     double weight = wts.getCount(feat.name);
     total += feat.value * weight;
   }
   return total;
 }
 /**
  * Returns the classifier's label distribution for {@code datum} as a dense array.
  *
  * <p>The array is indexed by {@code labeledDataset.labelIndex}. Labels the classifier does
  * not report keep the value 0.0.
  *
  * @param datum the datum to classify
  * @return array of length {@code labeledDataset.numClasses()}
  */
 private double[] getModelProbs(Datum<L, F> datum) {
   double[] dist = new double[labeledDataset.numClasses()];
   Counter<L> labelProbs = classifier.probabilityOf(datum);
   for (L lbl : labelProbs.keySet()) {
     dist[labeledDataset.labelIndex.indexOf(lbl)] = labelProbs.getCount(lbl);
   }
   return dist;
 }
예제 #8
0
 /**
  * Checks that the two pattern tokens, matched against the sentence, produce prefix and suffix
  * match features of 0.5 each.
  */
 @Test
 public void tokenMatch() {
   String[] sentence = {"what", "tv", "program", "have", "hugh", "laurie", "create"};
   String[] query = {"program", "create"};
   final double tolerance = 0.00001;
   Counter<String> features =
       TokenLevelMatchFeatures.extractTokenMatchFeatures(
           Arrays.asList(sentence), Arrays.asList(query), true);
   assertEquals(0.5, features.getCount("prefix"), tolerance);
   assertEquals(0.5, features.getCount("suffix"), tolerance);
 }
예제 #9
0
파일: Util.java 프로젝트: foxlf823/ade
  /**
   * Builds a vocabulary of the distinct strings in {@code str} that occur at least
   * {@code cutOff} times.
   *
   * <p>The result keeps the order produced by {@code Counters.toSortedList(freq, false)}
   * (presumably highest count first — confirm).
   *
   * @param str tokens to count
   * @param cutOff minimum frequency for a token to be kept
   * @return the retained tokens
   */
  public static List<String> generateDict(List<String> str, int cutOff) {
    Counter<String> counts = new IntCounter<>();
    for (String token : str) {
      counts.incrementCount(token);
    }

    List<String> ranked = Counters.toSortedList(counts, false);
    List<String> vocabulary = new ArrayList<>();
    for (String token : ranked) {
      boolean frequentEnough = counts.getCount(token) >= cutOff;
      if (frequentEnough) {
        vocabulary.add(token);
      }
    }
    return vocabulary;
  }
예제 #10
0
 /**
  * Scores every known label using its dedicated one-vs-rest binary classifier.
  *
  * <p>For each label, the datum is relabeled so that this label maps to {@code POS_LABEL},
  * with {@code NEG_LABEL} as the fallback. The label's score is the binary classifier's score
  * for {@code POS_LABEL}.
  *
  * @param example the datum to score
  * @return label → positive-class score
  */
 @Override
 public Counter<L> scoresOf(Datum<L, F> example) {
   Counter<L> result = new ClassicCounter<>();
   for (L label : labelIndex) {
     Map<L, String> toBinary = new ArrayMap<>();
     toBinary.put(label, POS_LABEL);
     Datum<String, F> relabeled = GeneralDataset.mapDatum(example, toBinary, NEG_LABEL);
     Classifier<String, F> oneVsRest = getBinaryClassifier(label);
     Counter<String> binaryScores = oneVsRest.scoresOf(relabeled);
     result.setCount(label, binaryScores.getCount(POS_LABEL));
   }
   return result;
 }
예제 #11
0
 /**
  * Scores each label for a real-valued datum.
  *
  * <p>Each label's score starts from {@code priors}, plus {@code priorZero} when
  * {@code addZeroValued} is set. It then adds {@code weight(l, f, value)} for every feature of
  * the datum, with the feature value truncated to an {@code int}. When {@code addZeroValued} is
  * set, {@code weight(l, f, zero)} is subtracted for each such feature. Presumably
  * {@code priorZero} already accounts for every feature having the zero value — confirm.
  *
  * @param example datum whose features are scored
  * @return label → unnormalized score
  */
 public ClassicCounter<L> scoresOf(RVFDatum<L, F> example) {
   ClassicCounter<L> scores = new ClassicCounter<>();
   Counters.addInPlace(scores, priors);
   if (addZeroValued) {
     Counters.addInPlace(scores, priorZero);
   }
   // The datum's features do not depend on the label, so fetch them once.
   Counter<F> features = example.asFeaturesCounter();
   for (L l : labels) {
     double score = 0.0;
     for (F f : features.keySet()) {
       int value = (int) features.getCount(f);
       score += weight(l, f, Integer.valueOf(value));
       if (addZeroValued) {
         score -= weight(l, f, zero);
       }
     }
     scores.incrementCount(l, score);
   }
   return scores;
 }
예제 #12
0
  /**
   * Scores every pattern for the current label according to {@code patternScoring}.
   *
   * <p>Per-pattern counts are the {@code size()} of the pattern's phrase counter in the
   * positive, negative and unlabeled maps. Every variant multiplies a ratio by
   * {@code log(pos)}:
   *
   * <ul>
   *   <li>RlogF: {@code pos / (pos + neg + unlab)}
   *   <li>RlogFPosNeg: {@code pos / (pos + neg)}
   *   <li>RlogFUnlabNeg: {@code pos / (neg + unlab)}
   *   <li>RlogFNeg: {@code pos / neg}
   * </ul>
   *
   * <p>YanGarber02 and LinICML03 keep only patterns whose accuracy {@code pos / (pos + neg)}
   * survives {@code Counters.retainAbove} at 0.8. They score those patterns with
   * {@code pos / all * log(pos)} and {@code (pos - neg) / all * log(pos)} respectively.
   *
   * @return pattern → score
   * @throws RuntimeException if {@code patternScoring} is not a supported variant
   */
  @Override
  public Counter<E> score() {
    Counter<E> pos_i = new ClassicCounter<>();
    Counter<E> neg_i = new ClassicCounter<>();
    Counter<E> unlab_i = new ClassicCounter<>();

    for (Entry<E, ClassicCounter<CandidatePhrase>> en : negPatternsandWords4Label.entrySet()) {
      neg_i.setCount(en.getKey(), en.getValue().size());
    }

    for (Entry<E, ClassicCounter<CandidatePhrase>> en :
        unLabeledPatternsandWords4Label.entrySet()) {
      unlab_i.setCount(en.getKey(), en.getValue().size());
    }

    for (Entry<E, ClassicCounter<CandidatePhrase>> en : patternsandWords4Label.entrySet()) {
      pos_i.setCount(en.getKey(), en.getValue().size());
    }

    // all_i = pos + neg + unlab; posneg_i = pos + neg (separate counters; all_i is mutated).
    Counter<E> all_i = Counters.add(pos_i, neg_i);
    all_i.addAll(unlab_i);
    Counter<E> posneg_i = Counters.add(pos_i, neg_i);

    // log(pos): the factor shared by every scoring variant.
    Counter<E> logFi = new ClassicCounter<>(pos_i);
    Counters.logInPlace(logFi);

    if (patternScoring.equals(PatternScoring.RlogF)) {
      return Counters.product(Counters.division(pos_i, all_i), logFi);
    } else if (patternScoring.equals(PatternScoring.RlogFPosNeg)) {
      Redwood.log("extremePatDebug", "computing rlogfposneg");
      return Counters.product(Counters.division(pos_i, posneg_i), logFi);
    } else if (patternScoring.equals(PatternScoring.RlogFUnlabNeg)) {
      Redwood.log("extremePatDebug", "computing rlogfunlabeg");
      return Counters.product(Counters.division(pos_i, Counters.add(neg_i, unlab_i)), logFi);
    } else if (patternScoring.equals(PatternScoring.RlogFNeg)) {
      Redwood.log("extremePatDebug", "computing rlogfneg");
      return Counters.product(Counters.division(pos_i, neg_i), logFi);
    } else if (patternScoring.equals(PatternScoring.YanGarber02)) {
      Counter<E> conf = Counters.product(Counters.division(pos_i, all_i), logFi);
      return scoreAccuratePatterns(pos_i, posneg_i, conf, 0.8);
    } else if (patternScoring.equals(PatternScoring.LinICML03)) {
      Counter<E> conf =
          Counters.product(
              Counters.division(Counters.add(pos_i, Counters.scale(neg_i, -1)), all_i), logFi);
      return scoreAccuratePatterns(pos_i, posneg_i, conf, 0.8);
    } else {
      throw new RuntimeException("not implemented " + patternScoring + " . check spelling!");
    }
  }

  /**
   * Keeps the patterns whose accuracy {@code pos / (pos + neg)} survives
   * {@code Counters.retainAbove(acc, thetaPrecision)}, scoring each by its value in {@code conf}.
   */
  private Counter<E> scoreAccuratePatterns(
      Counter<E> pos_i, Counter<E> posneg_i, Counter<E> conf, double thetaPrecision) {
    Counter<E> acc = Counters.division(pos_i, posneg_i);
    Counters.retainAbove(acc, thetaPrecision);
    Counter<E> weights = new ClassicCounter<>();
    for (E p : acc.keySet()) {
      weights.setCount(p, conf.getCount(p));
    }
    return weights;
  }