public static void main(String[] args) throws Exception {
  CompareType compareType = CompareType.MAX_COOR_ABS;

  if (args.length == 0) {
    usage();
    System.exit(-1);
  }

  if ("-cosine".equals(args[0])) {
    compareType = CompareType.COSINE;
    args = Arrays.copyOfRange(args, 1, args.length);
  } else if ("-sse".equals(args[0])) {
    compareType = CompareType.SUM_SQUARE_ERROR;
    args = Arrays.copyOfRange(args, 1, args.length);
  }

  if (args.length != 2) {
    usage();
    System.exit(-1);
  }

  Set<String> allWeights = new HashSet<String>();
  Counter<String> wts1 = IOTools.readWeights(args[0]);
  Counter<String> wts2 = IOTools.readWeights(args[1]);
  allWeights.addAll(wts1.keySet());
  allWeights.addAll(wts2.keySet());

  if (compareType == CompareType.MAX_COOR_ABS) {
    double maxDiff = Double.NEGATIVE_INFINITY;
    Counters.multiplyInPlace(wts1, 1.0 / Counters.L1Norm(wts1));
    Counters.multiplyInPlace(wts2, 1.0 / Counters.L1Norm(wts2));
    for (String wt : allWeights) {
      double absDiff = Math.abs(wts1.getCount(wt) - wts2.getCount(wt));
      if (absDiff > maxDiff) maxDiff = absDiff;
    }
    System.out.println(maxDiff);
  } else if (compareType == CompareType.COSINE) {
    double dotProd = Counters.cosine(wts1, wts2);
    System.out.println(dotProd);
  } else if (compareType == CompareType.SUM_SQUARE_ERROR) {
    double sse = 0;
    for (String wt : allWeights) {
      double diff = wts1.getCount(wt) - wts2.getCount(wt);
      sse += diff * diff;
    }
    System.out.println(sse);
  }
}
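// Hedged usage sketch (not part of the original source): the same three comparison modes applied
// to two tiny in-memory weight vectors instead of files read via IOTools.readWeights. Only the
// Counter/Counters calls already used above are assumed to exist.
Counter<String> a = new ClassicCounter<>();
a.setCount("f1", 2.0);
a.setCount("f2", 1.0);
Counter<String> b = new ClassicCounter<>();
b.setCount("f1", 1.0);
b.setCount("f3", 3.0);

// COSINE: similarity of the two weight vectors
System.out.println(Counters.cosine(a, b));

// MAX_COOR_ABS: largest per-coordinate gap after L1 normalization
Counters.multiplyInPlace(a, 1.0 / Counters.L1Norm(a));
Counters.multiplyInPlace(b, 1.0 / Counters.L1Norm(b));
Set<String> union = new HashSet<>(a.keySet());
union.addAll(b.keySet());
double maxDiff = 0.0;
double sse = 0.0;
for (String f : union) {
  double diff = a.getCount(f) - b.getCount(f);
  maxDiff = Math.max(maxDiff, Math.abs(diff));
  sse += diff * diff; // SUM_SQUARE_ERROR (computed here on the normalized vectors, unlike above)
}
System.out.println(maxDiff);
System.out.println(sse);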
/**
 * Returns a set of all modes in the Collection. (If the Collection has multiple items with the
 * highest frequency, all of them will be returned.)
 */
public static <T> Set<T> modes(Collection<T> values) {
  Counter<T> counter = new ClassicCounter<T>(values);
  List<Double> sortedCounts = CollectionUtils.sorted(counter.values());
  Double highestCount = sortedCounts.get(sortedCounts.size() - 1);
  Counters.retainAbove(counter, highestCount);
  return counter.keySet();
}
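// Hedged usage sketch (assumes the usual java.util imports; not part of the original class):
// "a" and "b" both occur twice, which is the highest frequency, so both are returned.
Set<String> result = modes(Arrays.asList("a", "a", "b", "b", "c"));
System.out.println(result); // e.g. [a, b]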
/**
 * Add a scaled (positive) random vector to a weights vector.
 *
 * @param wts
 * @param scale
 */
public static void randomizeWeightsInPlace(Counter<String> wts, double scale) {
  for (String feature : wts.keySet()) {
    double epsilon = Math.random() * scale;
    double newValue = wts.getCount(feature) + epsilon;
    wts.setCount(feature, newValue);
  }
}
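// Hedged usage sketch: every weight is nudged upward by an independent uniform draw in [0, scale).
Counter<String> wts = new ClassicCounter<>();
wts.setCount("bias", 1.0);
wts.setCount("lm", 0.5);
randomizeWeightsInPlace(wts, 0.1);
// "bias" now lies in [1.0, 1.1) and "lm" in [0.5, 0.6).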
/**
 * The examples are assumed to be a list of RVFDatum. The datums are assumed to contain the zeroes
 * as well.
 */
@Override
@Deprecated
public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
  RVFDatum<L, F> d0 = examples.get(0);
  int numFeatures = d0.asFeatures().size();
  int[][] data = new int[examples.size()][numFeatures];
  int[] labels = new int[examples.size()];
  labelIndex = new HashIndex<L>();
  featureIndex = new HashIndex<F>();
  for (int d = 0; d < examples.size(); d++) {
    RVFDatum<L, F> datum = examples.get(d);
    Counter<F> c = datum.asFeaturesCounter();
    for (F feature : c.keySet()) {
      // Index the feature if it is new, but record its value for every datum
      // (the original guard only stored a value the first time a feature was seen).
      featureIndex.add(feature);
      int fNo = featureIndex.indexOf(feature);
      int value = (int) c.getCount(feature);
      data[d][fNo] = value;
    }
    labelIndex.add(datum.label());
    labels[d] = labelIndex.indexOf(datum.label());
  }
  int numClasses = labelIndex.size();
  return trainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex);
}
private double[] getModelProbs(Datum<L, F> datum) {
  double[] condDist = new double[labeledDataset.numClasses()];
  Counter<L> probCounter = classifier.probabilityOf(datum);
  for (L label : probCounter.keySet()) {
    int labelID = labeledDataset.labelIndex.indexOf(label);
    condDist[labelID] = probCounter.getCount(label);
  }
  return condDist;
}
@Override
public boolean isSimpleSplit(Counter<String> feats) {
  for (String key : feats.keySet()) {
    if (key.startsWith("simple&")) {
      return true;
    }
  }
  return false;
}
public Object formResult() {
  Set brs = new HashSet();
  Set urs = new HashSet();
  // scan each rule / history pair
  int ruleCount = 0;
  for (Iterator pairI = rulePairs.keySet().iterator(); pairI.hasNext(); ) {
    if (ruleCount % 100 == 0) {
      System.err.println("Rules multiplied: " + ruleCount);
    }
    ruleCount++;
    Pair rulePair = (Pair) pairI.next();
    Rule baseRule = (Rule) rulePair.first;
    String baseLabel = (String) ruleToLabel.get(baseRule);
    List history = (List) rulePair.second;
    double totalProb = 0;
    for (int depth = 1; depth <= HISTORY_DEPTH() && depth <= history.size(); depth++) {
      List subHistory = history.subList(0, depth);
      double c_label = labelPairs.getCount(new Pair(baseLabel, subHistory));
      double c_rule = rulePairs.getCount(new Pair(baseRule, subHistory));
      // System.out.println("Multiplying out " + baseRule + " with history " + subHistory);
      // System.out.println("Count of " + baseLabel + " with " + subHistory + " is " + c_label);
      // System.out.println("Count of " + baseRule + " with " + subHistory + " is " + c_rule);
      double prob = (1.0 / HISTORY_DEPTH()) * (c_rule) / (c_label);
      totalProb += prob;
      for (int childDepth = 0; childDepth <= Math.min(HISTORY_DEPTH() - 1, depth); childDepth++) {
        Rule rule = specifyRule(baseRule, subHistory, childDepth);
        rule.score = (float) Math.log(totalProb);
        // System.out.println("Created " + rule + " with score " + rule.score);
        if (rule instanceof UnaryRule) {
          urs.add(rule);
        } else {
          brs.add(rule);
        }
      }
    }
  }
  System.out.println("Total states: " + stateNumberer.total());
  BinaryGrammar bg = new BinaryGrammar(stateNumberer.total());
  UnaryGrammar ug = new UnaryGrammar(stateNumberer.total());
  for (Iterator brI = brs.iterator(); brI.hasNext(); ) {
    BinaryRule br = (BinaryRule) brI.next();
    bg.addRule(br);
  }
  for (Iterator urI = urs.iterator(); urI.hasNext(); ) {
    UnaryRule ur = (UnaryRule) urI.next();
    ug.addRule(ur);
  }
  return new Pair(ug, bg);
}
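// Hedged worked example of the rule scoring above (counts are made up): totalProb accumulates
//   (1 / HISTORY_DEPTH()) * c_rule / c_label
// over history depths, and each rule specified at a given depth stores the log of the running sum
// at that depth. With HISTORY_DEPTH() = 2, c_rule/c_label = 4/8 at depth 1 and 3/4 at depth 2:
//   totalProb = 0.5 * 0.5 + 0.5 * 0.75 = 0.625,  so the depth-2 rules get score = ln(0.625) ≈ -0.47.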
public ClassicCounter<L> scoresOf(RVFDatum<L, F> example) {
  ClassicCounter<L> scores = new ClassicCounter<>();
  Counters.addInPlace(scores, priors);
  if (addZeroValued) {
    Counters.addInPlace(scores, priorZero);
  }
  for (L l : labels) {
    double score = 0.0;
    Counter<F> features = example.asFeaturesCounter();
    for (F f : features.keySet()) {
      int value = (int) features.getCount(f);
      score += weight(l, f, Integer.valueOf(value));
      if (addZeroValued) {
        score -= weight(l, f, zero);
      }
    }
    scores.incrementCount(l, score);
  }
  return scores;
}
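// Hedged summary of the score computed above, for one label l (bracketed terms apply only when
// addZeroValued is true):
//   score(l) = prior(l) [+ priorZero(l)] + Σ_f ( weight(l, f, v_f) [ - weight(l, f, zero) ] )
// Illustrative arithmetic with made-up numbers: prior(l) = -0.7, a single feature with
// weight(l, f, 2) = 1.3 and weight(l, f, zero) = 0.4 gives score(l) = -0.7 + (1.3 - 0.4) = 0.2.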
public static String[] getWeightNamesFromCounter(Counter<String> wts) {
  List<String> names = new ArrayList<String>(wts.keySet());
  Collections.sort(names);
  return names.toArray(new String[0]);
}
@Override
public Counter<E> score() {
  Counter<E> currentPatternWeights4Label = new ClassicCounter<>();
  Counter<E> pos_i = new ClassicCounter<>();
  Counter<E> neg_i = new ClassicCounter<>();
  Counter<E> unlab_i = new ClassicCounter<>();

  for (Entry<E, ClassicCounter<CandidatePhrase>> en : negPatternsandWords4Label.entrySet()) {
    neg_i.setCount(en.getKey(), en.getValue().size());
  }
  for (Entry<E, ClassicCounter<CandidatePhrase>> en : unLabeledPatternsandWords4Label.entrySet()) {
    unlab_i.setCount(en.getKey(), en.getValue().size());
  }
  for (Entry<E, ClassicCounter<CandidatePhrase>> en : patternsandWords4Label.entrySet()) {
    pos_i.setCount(en.getKey(), en.getValue().size());
  }

  Counter<E> all_i = Counters.add(pos_i, neg_i);
  all_i.addAll(unlab_i);
  // for (Entry<Integer, ClassicCounter<String>> en : allPatternsandWords4Label.entrySet()) {
  //   all_i.setCount(en.getKey(), en.getValue().size());
  // }

  Counter<E> posneg_i = Counters.add(pos_i, neg_i);
  Counter<E> logFi = new ClassicCounter<>(pos_i);
  Counters.logInPlace(logFi);

  if (patternScoring.equals(PatternScoring.RlogF)) {
    currentPatternWeights4Label = Counters.product(Counters.division(pos_i, all_i), logFi);
  } else if (patternScoring.equals(PatternScoring.RlogFPosNeg)) {
    Redwood.log("extremePatDebug", "computing rlogfposneg");
    currentPatternWeights4Label = Counters.product(Counters.division(pos_i, posneg_i), logFi);
  } else if (patternScoring.equals(PatternScoring.RlogFUnlabNeg)) {
    Redwood.log("extremePatDebug", "computing rlogfunlabneg");
    currentPatternWeights4Label =
        Counters.product(Counters.division(pos_i, Counters.add(neg_i, unlab_i)), logFi);
  } else if (patternScoring.equals(PatternScoring.RlogFNeg)) {
    Redwood.log("extremePatDebug", "computing rlogfneg");
    currentPatternWeights4Label = Counters.product(Counters.division(pos_i, neg_i), logFi);
  } else if (patternScoring.equals(PatternScoring.YanGarber02)) {
    Counter<E> acc = Counters.division(pos_i, Counters.add(pos_i, neg_i));
    double thetaPrecision = 0.8;
    Counters.retainAbove(acc, thetaPrecision);
    Counter<E> conf = Counters.product(Counters.division(pos_i, all_i), logFi);
    for (E p : acc.keySet()) {
      currentPatternWeights4Label.setCount(p, conf.getCount(p));
    }
  } else if (patternScoring.equals(PatternScoring.LinICML03)) {
    Counter<E> acc = Counters.division(pos_i, Counters.add(pos_i, neg_i));
    double thetaPrecision = 0.8;
    Counters.retainAbove(acc, thetaPrecision);
    Counter<E> conf =
        Counters.product(
            Counters.division(Counters.add(pos_i, Counters.scale(neg_i, -1)), all_i), logFi);
    for (E p : acc.keySet()) {
      currentPatternWeights4Label.setCount(p, conf.getCount(p));
    }
  } else {
    throw new RuntimeException("not implemented " + patternScoring + ". Check spelling!");
  }
  return currentPatternWeights4Label;
}
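// Hedged worked example of the default RlogF branch above, for a single pattern p:
//   score(p) = (pos_i(p) / all_i(p)) * log(pos_i(p))
// i.e. the pattern's precision over all of its matches, scaled by the log of its positive count.
// The counts below are illustrative, not from any real run.
double pos = 8, neg = 1, unlab = 1;
double all = pos + neg + unlab;              // 10
double rlogf = (pos / all) * Math.log(pos);  // 0.8 * ln(8) ≈ 1.66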
public HashMap<String, Double> run() {
  HashBiMap<String, Integer> variable2id = HashBiMap.create();
  for (Entry<String, Double> e : objective.entrySet()) {
    if (!variable2id.containsKey(e.getKey())) {
      variable2id.put(e.getKey(), variable2id.size());
    }
  }
  for (Counter<String> counter : constraints) {
    for (Entry<String, Double> e : counter.entrySet()) {
      if (!variable2id.containsKey(e.getKey())) {
        variable2id.put(e.getKey(), variable2id.size());
      }
    }
  }

  int NUMVAR = variable2id.size();
  int NUMCON = constraints.size();
  int NUMANZ = 0;

  // set up objective function
  double[] c = new double[NUMVAR];
  for (Entry<String, Double> e : objective.entrySet()) {
    int vid = variable2id.get(e.getKey());
    c[vid] = e.getValue();
  }

  // set up constraints
  mosek.Env.boundkey[] bkc = new mosek.Env.boundkey[NUMCON];
  double[] blc = new double[NUMCON];
  double[] buc = new double[NUMCON];
  int[][] asub = new int[NUMCON][];
  double[][] aval = new double[NUMCON][];
  for (int i = 0; i < NUMCON; i++) {
    bkc[i] = constraints_bound_type_list.get(i);
    blc[i] = constraints_lowerbound_list.get(i);
    buc[i] = constraints_upperbound_list.get(i);
    Counter<String> counter = constraints.get(i);
    asub[i] = new int[counter.keySet().size()];
    aval[i] = new double[counter.keySet().size()];
    int k = 0;
    for (Entry<String, Double> e : counter.entrySet()) {
      int vid = variable2id.get(e.getKey());
      asub[i][k] = vid;
      aval[i][k] = e.getValue();
      k++;
      NUMANZ++;
    }
  }

  // set up variable constraints
  mosek.Env.boundkey[] bkx = new mosek.Env.boundkey[NUMVAR];
  double[] blx = new double[NUMVAR];
  double[] bux = new double[NUMVAR];
  for (String v : variable2id.keySet()) {
    int vid = variable2id.get(v);
    double[] lowerupper = default_variable_lowerupper;
    if (variables_lowerupper.containsKey(v)) {
      lowerupper = variables_lowerupper.get(v);
    }
    bkx[vid] = mosek.Env.boundkey.ra;
    blx[vid] = lowerupper[0];
    bux[vid] = lowerupper[1];
  }

  HashMap<String, Double> var2val = new HashMap<String, Double>();
  double[] xx =
      callILP3(NUMVAR, NUMCON, NUMANZ, c, bkx, blx, bux, asub, aval, bkc, blc, buc, false);
  for (String v : variable2id.keySet()) {
    int vid = variable2id.get(v);
    var2val.put(v, xx[vid]);
  }
  return var2val;
}
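// Hedged illustration of the sparse row layout built above (values are made up): for a constraint
// with coefficients 2x + 3y and a variable2id mapping {x -> 0, y -> 1}, row i is stored as
//   asub[i] = {0, 1}      // column indices taken from variable2id
//   aval[i] = {2.0, 3.0}  // the matching coefficients
// with its bound type and bounds read from constraints_bound_type_list,
// constraints_lowerbound_list and constraints_upperbound_list at the same index i.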
/** @param args */
public static void main(String[] args) {
  if (args.length != 3) {
    System.err.printf(
        "Usage: java %s language filename features%n",
        TreebankFactoredLexiconStats.class.getName());
    System.exit(-1);
  }
  Language language = Language.valueOf(args[0]);
  TreebankLangParserParams tlpp = language.params;
  if (language.equals(Language.Arabic)) {
    String[] options = {"-arabicFactored"};
    tlpp.setOptionFlag(options, 0);
  } else {
    String[] options = {"-frenchFactored"};
    tlpp.setOptionFlag(options, 0);
  }
  Treebank tb = tlpp.diskTreebank();
  tb.loadPath(args[1]);

  MorphoFeatureSpecification morphoSpec =
      language.equals(Language.Arabic)
          ? new ArabicMorphoFeatureSpecification()
          : new FrenchMorphoFeatureSpecification();
  String[] features = args[2].trim().split(",");
  for (String feature : features) {
    morphoSpec.activate(MorphoFeatureType.valueOf(feature));
  }

  // Counters
  Counter<String> wordTagCounter = new ClassicCounter<>(30000);
  Counter<String> morphTagCounter = new ClassicCounter<>(500);
  // Counter<String> signatureTagCounter = new ClassicCounter<String>();
  Counter<String> morphCounter = new ClassicCounter<>(500);
  Counter<String> wordCounter = new ClassicCounter<>(30000);
  Counter<String> tagCounter = new ClassicCounter<>(300);
  Counter<String> lemmaCounter = new ClassicCounter<>(25000);
  Counter<String> lemmaTagCounter = new ClassicCounter<>(25000);
  Counter<String> richTagCounter = new ClassicCounter<>(1000);
  Counter<String> reducedTagCounter = new ClassicCounter<>(500);
  Counter<String> reducedTagLemmaCounter = new ClassicCounter<>(500);
  Map<String, Set<String>> wordLemmaMap = Generics.newHashMap();
  TwoDimensionalIntCounter<String, String> lemmaReducedTagCounter =
      new TwoDimensionalIntCounter<>(30000);
  TwoDimensionalIntCounter<String, String> reducedTagTagCounter =
      new TwoDimensionalIntCounter<>(500);
  TwoDimensionalIntCounter<String, String> tagReducedTagCounter =
      new TwoDimensionalIntCounter<>(300);

  int numTrees = 0;
  for (Tree tree : tb) {
    for (Tree subTree : tree) {
      if (!subTree.isLeaf()) {
        tlpp.transformTree(subTree, tree);
      }
    }
    List<Label> pretermList = tree.preTerminalYield();
    List<Label> yield = tree.yield();
    assert yield.size() == pretermList.size();

    int yieldLen = yield.size();
    for (int i = 0; i < yieldLen; ++i) {
      String tag = pretermList.get(i).value();
      String word = yield.get(i).value();
      String morph = ((CoreLabel) yield.get(i)).originalText();
      // Note: if there is no lemma, then we use the surface form.
      Pair<String, String> lemmaTag = MorphoFeatureSpecification.splitMorphString(word, morph);
      String lemma = lemmaTag.first();
      String richTag = lemmaTag.second();

      // WSGDEBUG
      if (tag.contains("MW")) lemma += "-MWE";

      lemmaCounter.incrementCount(lemma);
      lemmaTagCounter.incrementCount(lemma + tag);

      richTagCounter.incrementCount(richTag);
      String reducedTag = morphoSpec.strToFeatures(richTag).toString();
      reducedTagCounter.incrementCount(reducedTag);
      reducedTagLemmaCounter.incrementCount(reducedTag + lemma);

      wordTagCounter.incrementCount(word + tag);
      morphTagCounter.incrementCount(morph + tag);
      morphCounter.incrementCount(morph);
      wordCounter.incrementCount(word);
      tagCounter.incrementCount(tag);

      reducedTag = reducedTag.equals("") ? "NONE" : reducedTag;
      if (wordLemmaMap.containsKey(word)) {
        wordLemmaMap.get(word).add(lemma);
      } else {
        Set<String> lemmas = Generics.newHashSet(1);
        lemmas.add(lemma); // record the first lemma too (the original left the new set empty)
        wordLemmaMap.put(word, lemmas);
      }
      lemmaReducedTagCounter.incrementCount(lemma, reducedTag);
      reducedTagTagCounter.incrementCount(lemma + reducedTag, tag);
      tagReducedTagCounter.incrementCount(tag, reducedTag);
    }
    ++numTrees;
  }

  // Barf...
  System.out.println("Language: " + language.toString());
  System.out.printf("#trees:\t%d%n", numTrees);
  System.out.printf("#tokens:\t%d%n", (int) wordCounter.totalCount());
  System.out.printf("#words:\t%d%n", wordCounter.keySet().size());
  System.out.printf("#tags:\t%d%n", tagCounter.keySet().size());
  System.out.printf("#wordTagPairs:\t%d%n", wordTagCounter.keySet().size());
  System.out.printf("#lemmas:\t%d%n", lemmaCounter.keySet().size());
  System.out.printf("#lemmaTagPairs:\t%d%n", lemmaTagCounter.keySet().size());
  System.out.printf("#feattags:\t%d%n", reducedTagCounter.keySet().size());
  System.out.printf("#feattag+lemmas:\t%d%n", reducedTagLemmaCounter.keySet().size());
  System.out.printf("#richtags:\t%d%n", richTagCounter.keySet().size());
  System.out.printf("#richtag+lemma:\t%d%n", morphCounter.keySet().size());
  System.out.printf("#richtag+lemmaTagPairs:\t%d%n", morphTagCounter.keySet().size());

  // Extra
  System.out.println("==================");
  StringBuilder sbNoLemma = new StringBuilder();
  StringBuilder sbMultLemmas = new StringBuilder();
  for (Map.Entry<String, Set<String>> wordLemmas : wordLemmaMap.entrySet()) {
    String word = wordLemmas.getKey();
    Set<String> lemmas = wordLemmas.getValue();
    if (lemmas.size() == 0) {
      sbNoLemma.append("NO LEMMAS FOR WORD: " + word + "\n");
      continue;
    }
    if (lemmas.size() > 1) {
      sbMultLemmas.append("MULTIPLE LEMMAS: " + word + " " + setToString(lemmas) + "\n");
      continue;
    }
    String lemma = lemmas.iterator().next();
    Set<String> reducedTags = lemmaReducedTagCounter.getCounter(lemma).keySet();
    if (reducedTags.size() > 1) {
      System.out.printf("%s --> %s%n", word, lemma);
      for (String reducedTag : reducedTags) {
        int count = lemmaReducedTagCounter.getCount(lemma, reducedTag);
        String posTags = setToString(reducedTagTagCounter.getCounter(lemma + reducedTag).keySet());
        System.out.printf("\t%s\t%d\t%s%n", reducedTag, count, posTags);
      }
      System.out.println();
    }
  }
  System.out.println("==================");
  System.out.println(sbNoLemma.toString());
  System.out.println(sbMultLemmas.toString());
  System.out.println("==================");

  List<String> tags = new ArrayList<>(tagReducedTagCounter.firstKeySet());
  Collections.sort(tags);
  for (String tag : tags) {
    System.out.println(tag);
    Set<String> reducedTags = tagReducedTagCounter.getCounter(tag).keySet();
    for (String reducedTag : reducedTags) {
      int count = tagReducedTagCounter.getCount(tag, reducedTag);
      // reducedTag = reducedTag.equals("") ? "NONE" : reducedTag;
      System.out.printf("\t%s\t%d%n", reducedTag, count);
    }
    System.out.println();
  }
  System.out.println("==================");
}