/**
 * The examples are assumed to be a list of RVFDatum. Each datum is assumed to contain every
 * feature, including the zero-valued ones.
 */
@Override
@Deprecated
public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
  RVFDatum<L, F> d0 = examples.get(0);
  int numFeatures = d0.asFeatures().size();
  int[][] data = new int[examples.size()][numFeatures];
  int[] labels = new int[examples.size()];
  labelIndex = new HashIndex<L>();
  featureIndex = new HashIndex<F>();
  for (int d = 0; d < examples.size(); d++) {
    RVFDatum<L, F> datum = examples.get(d);
    Counter<F> c = datum.asFeaturesCounter();
    for (F feature : c.keySet()) {
      // add() is a no-op for a feature that is already indexed; record the value either way,
      // otherwise only the first datum to introduce a feature would ever have its count stored
      featureIndex.add(feature);
      int fNo = featureIndex.indexOf(feature);
      data[d][fNo] = (int) c.getCount(feature);
    }
    labelIndex.add(datum.label());
    labels[d] = labelIndex.indexOf(datum.label());
  }
  int numClasses = labelIndex.size();
  return trainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex);
}
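// Usage sketch (not part of the original source): training and applying the deprecated
// RVFDatum-based Naive Bayes trainer above. Assumes the usual Stanford classify/stats classes
// (RVFDatum, ClassicCounter, NaiveBayesClassifierFactory with a no-argument constructor);
// the feature and label strings are made up for illustration.
private static void naiveBayesUsageSketch() {
  List<RVFDatum<String, String>> examples = new ArrayList<RVFDatum<String, String>>();
  Counter<String> tall = new ClassicCounter<String>();
  tall.setCount("height", 1);
  tall.setCount("width", 0); // zero-valued features are included, per the method contract
  examples.add(new RVFDatum<String, String>(tall, "tall"));
  Counter<String> wide = new ClassicCounter<String>();
  wide.setCount("height", 0);
  wide.setCount("width", 1);
  examples.add(new RVFDatum<String, String>(wide, "wide"));
  NaiveBayesClassifierFactory<String, String> factory =
      new NaiveBayesClassifierFactory<String, String>();
  NaiveBayesClassifier<String, String> nb = factory.trainClassifier(examples);
  System.out.println(nb.classOf(new RVFDatum<String, String>(tall, null)));
}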
private int add(AmbiguityClass a) { if (classes.contains(a)) { return classes.indexOf(a); } classes.add(a); return classes.indexOf(a); }
/**
 * Evaluates how many words (= terminals) in a collection of trees are covered by the lexicon.
 * The first argument is the collection of trees; the second through fourth arguments collect the
 * results. Currently unused; this probably only works when training and testing happen in the
 * same run, so that the tags and words variables are initialized.
 */
public double evaluateCoverage(
    Collection<Tree> trees,
    Set<String> missingWords,
    Set<String> missingTags,
    Set<IntTaggedWord> missingTW) {
  List<IntTaggedWord> iTW1 = new ArrayList<IntTaggedWord>();
  for (Tree t : trees) {
    iTW1.addAll(treeToEvents(t));
  }
  int total = 0;
  int unseen = 0;
  for (IntTaggedWord itw : iTW1) {
    total++;
    if (!words.contains(new IntTaggedWord(itw.word(), nullTag))) {
      missingWords.add(wordIndex.get(itw.word()));
    }
    if (!tags.contains(new IntTaggedWord(nullWord, itw.tag()))) {
      missingTags.add(tagIndex.get(itw.tag()));
    }
    // if (!rules.contains(itw)) {
    if (seenCounter.getCount(itw) == 0.0) {
      unseen++;
      missingTW.add(itw);
    }
  }
  return (double) unseen / total;
}
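// Illustrative only (not in the original source): measuring lexicon coverage on a held-out set
// of trees in the same run in which the lexicon was trained, so that its words/tags sets are
// populated. The "lex" and "heldOutTrees" variables are assumptions for this sketch.
Set<String> missingWords = new HashSet<String>();
Set<String> missingTags = new HashSet<String>();
Set<IntTaggedWord> missingTW = new HashSet<IntTaggedWord>();
double unseenRate = lex.evaluateCoverage(heldOutTrees, missingWords, missingTags, missingTW);
System.out.printf("Unseen (word,tag) rate: %.3f (%d missing words, %d missing tags)%n",
    unseenRate, missingWords.size(), missingTags.size());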
private NaiveBayesClassifier<L, F> trainClassifier( int[][] data, int[] labels, int numFeatures, int numClasses, Index<L> labelIndex, Index<F> featureIndex) { Set<L> labelSet = Generics.newHashSet(); NBWeights nbWeights = trainWeights(data, labels, numFeatures, numClasses); Counter<L> priors = new ClassicCounter<L>(); double[] pr = nbWeights.priors; for (int i = 0; i < pr.length; i++) { priors.incrementCount(labelIndex.get(i), pr[i]); labelSet.add(labelIndex.get(i)); } Counter<Pair<Pair<L, F>, Number>> weightsCounter = new ClassicCounter<Pair<Pair<L, F>, Number>>(); double[][][] wts = nbWeights.weights; for (int c = 0; c < numClasses; c++) { L label = labelIndex.get(c); for (int f = 0; f < numFeatures; f++) { F feature = featureIndex.get(f); Pair<L, F> p = new Pair<L, F>(label, feature); for (int val = 0; val < wts[c][f].length; val++) { Pair<Pair<L, F>, Number> key = new Pair<Pair<L, F>, Number>(p, Integer.valueOf(val)); weightsCounter.incrementCount(key, wts[c][f][val]); } } } return new NaiveBayesClassifier<L, F>(weightsCounter, priors, labelSet); }
/**
 * Provides some testing and opportunities for exploration of the probabilities of a BaseLexicon.
 * What's here currently probably only works for the English Penn Treebank, as it uses default
 * constructors. Of the words given to test on, the first is treated as sentence initial, and the
 * rest as not sentence initial.
 *
 * @param args The command line arguments: java BaseLexicon treebankPath fileRange
 *     unknownWordModel words*
 */
public static void main(String[] args) {
  if (args.length < 3) {
    System.err.println("java BaseLexicon treebankPath fileRange unknownWordModel words*");
    return;
  }
  System.out.print("Training BaseLexicon from " + args[0] + ' ' + args[1] + " ... ");
  Treebank tb = new DiskTreebank();
  tb.loadPath(args[0], new NumberRangesFileFilter(args[1], true));
  // TODO: change this interface so the lexicon creates its own indices?
  Index<String> wordIndex = new HashIndex<String>();
  Index<String> tagIndex = new HashIndex<String>();
  BaseLexicon lex = new BaseLexicon(wordIndex, tagIndex);
  lex.getUnknownWordModel().setUnknownLevel(Integer.parseInt(args[2]));
  lex.train(tb);
  System.out.println("done.");
  System.out.println();
  NumberFormat nf = NumberFormat.getNumberInstance();
  nf.setMaximumFractionDigits(4);
  List<String> impos = new ArrayList<String>();
  for (int i = 3; i < args.length; i++) {
    if (lex.isKnown(args[i])) {
      System.out.println(
          args[i] + " is a known word. Log probabilities [log P(w|t)] for its taggings are:");
      for (Iterator<IntTaggedWord> it =
              lex.ruleIteratorByWord(wordIndex.indexOf(args[i], true), i - 3, null);
          it.hasNext(); ) {
        IntTaggedWord iTW = it.next();
        System.out.println(
            StringUtils.pad(iTW, 24) + nf.format(lex.score(iTW, i - 3, wordIndex.get(iTW.word))));
      }
    } else {
      String sig = lex.getUnknownWordModel().getSignature(args[i], i - 3);
      System.out.println(
          args[i]
              + " is an unknown word. Signature with uwm "
              + lex.getUnknownWordModel().getUnknownLevel()
              + ((i == 3) ? " init" : " non-init")
              + " is: "
              + sig);
      impos.clear();
      List<String> lis = new ArrayList<String>(tagIndex.objectsList());
      Collections.sort(lis);
      for (String tStr : lis) {
        IntTaggedWord iTW = new IntTaggedWord(args[i], tStr, wordIndex, tagIndex);
        double score = lex.score(iTW, 1, args[i]);
        if (score == Float.NEGATIVE_INFINITY) {
          impos.add(tStr);
        } else {
          System.out.println(StringUtils.pad(iTW, 24) + nf.format(score));
        }
      }
      if (impos.size() > 0) {
        System.out.println(args[i] + " impossible tags: " + impos);
      }
    }
    System.out.println();
  }
}
public LinearClassifier createLinearClassifier(double[] weights) { double[][] weights2D; if (objective != null) { weights2D = objective.to2D(weights); } else { weights2D = ArrayUtils.to2D(weights, featureIndex.size(), labelIndex.size()); } return new LinearClassifier<L, F>(weights2D, featureIndex, labelIndex); }
public Index<IntPair> createIndex() { Index<IntPair> index = new HashIndex<>(); for (int x = 0; x < px.length; x++) { int numberY = numY(x); for (int y = 0; y < numberY; y++) { index.add(new IntPair(x, y)); } } return index; }
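// Sketch (not in the original source): the flattened index created above lets a two-dimensional
// (x, y) coordinate be stored as a single dense int and recovered later.
Index<IntPair> pairIndex = createIndex();
int flat = pairIndex.indexOf(new IntPair(2, 3)); // dense position of (x=2, y=3), or -1 if absent
IntPair restored = pairIndex.get(flat);          // recovers the original pair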
private void populateTagsToBaseTags(TreebankLanguagePack tlp) { int total = tagIndex.size(); tagsToBaseTags = new int[total]; for (int i = 0; i < total; i++) { String tag = tagIndex.get(i); String baseTag = tlp.basicCategory(tag); int j = tagIndex.indexOf(baseTag, true); tagsToBaseTags[i] = j; } }
private short tagProject(short tag) { if (smoothTPIndex == null) { smoothTPIndex = new HashIndex<String>(tagIndex); } if (tag < 0) { return tag; } else { String tagStr = smoothTPIndex.get(tag); String binStr = TP_PREFIX + smoothTP.project(tagStr); return (short) smoothTPIndex.indexOf(binStr, true); } }
/** * Generate the possible taggings for a word at a sentence position. This may either be based on a * strict lexicon or an expanded generous set of possible taggings. * * <p><i>Implementation note:</i> Expanded sets of possible taggings are calculated dynamically at * runtime, so as to reduce the memory used by the lexicon (a space/time tradeoff). * * @param word The word (as an int) * @param loc Its index in the sentence (usually only relevant for unknown words) * @return A list of possible taggings */ public Iterator<IntTaggedWord> ruleIteratorByWord(int word, int loc, String featureSpec) { // if (rulesWithWord == null) { // tested in isKnown already // initRulesWithWord(); // } List<IntTaggedWord> wordTaggings; if (isKnown(word)) { if (!flexiTag) { // Strict lexical tagging for seen items wordTaggings = rulesWithWord[word]; } else { /* Allow all tags with same basicCategory */ /* Allow all scored taggings, unless very common */ IntTaggedWord iW = new IntTaggedWord(word, nullTag); if (seenCounter.getCount(iW) > smoothInUnknownsThreshold) { return rulesWithWord[word].iterator(); } else { // give it flexible tagging not just lexicon wordTaggings = new ArrayList<IntTaggedWord>(40); for (IntTaggedWord iTW2 : tags) { IntTaggedWord iTW = new IntTaggedWord(word, iTW2.tag); if (score(iTW, loc, wordIndex.get(word)) > Float.NEGATIVE_INFINITY) { wordTaggings.add(iTW); } } } } } else { // we copy list so we can insert correct word in each item wordTaggings = new ArrayList<IntTaggedWord>(40); for (IntTaggedWord iTW : rulesWithWord[wordIndex.indexOf(UNKNOWN_WORD)]) { wordTaggings.add(new IntTaggedWord(word, iTW.tag)); } } if (DEBUG_LEXICON) { EncodingPrintWriter.err.println( "Lexicon: " + wordIndex.get(word) + " (" + (isKnown(word) ? "known" : "unknown") + ", loc=" + loc + ", n=" + (isKnown(word) ? word : wordIndex.indexOf(UNKNOWN_WORD)) + ") " + (flexiTag ? "flexi" : "lexicon") + " taggings: " + wordTaggings, "UTF-8"); } return wordTaggings.iterator(); }
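// Usage sketch (not part of the original source): enumerate the lexicon's possible taggings for
// a word at sentence position 0. Assumes a trained lexicon "lex" sharing the "wordIndex" and
// "tagIndex" indices used above; the word string is illustrative.
int wordId = wordIndex.indexOf("run");
if (wordId >= 0) {
  for (Iterator<IntTaggedWord> it = lex.ruleIteratorByWord(wordId, 0, null); it.hasNext(); ) {
    IntTaggedWord itw = it.next();
    System.out.println(tagIndex.get(itw.tag()) + "\t" + lex.score(itw, 0, "run"));
  }
}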
/** * This records how likely it is for a word with one tag to also have another tag. This won't work * after serialization/deserialization, but that is how it is currently called.... */ void buildPT_T() { int numTags = tagIndex.size(); m_TT = new double[numTags][numTags]; m_T = new double[numTags]; double[] tmp = new double[numTags]; for (IntTaggedWord word : words) { double tot = 0.0; for (int t = 0; t < numTags; t++) { IntTaggedWord iTW = new IntTaggedWord(word.word, t); tmp[t] = seenCounter.getCount(iTW); tot += tmp[t]; } if (tot < 10) { continue; } for (int t = 0; t < numTags; t++) { for (int t2 = 0; t2 < numTags; t2++) { if (tmp[t2] > 0.0) { double c = tmp[t] / tot; m_T[t] += c; m_TT[t2][t] += c; } } } } }
/**
 * Returns precision information for the given label as a Triple of (precision, tp, fp), where
 * precision is <tt>tp/(tp+fp)</tt> and is taken to be 1.0 if tp and fp are both 0.
 */
public Triple<Double, Integer, Integer> getPrecisionInfo(L label) {
  int i = labelIndex.indexOf(label);
  if (tpCount[i] == 0 && fpCount[i] == 0) {
    return new Triple<Double, Integer, Integer>(1.0, tpCount[i], fpCount[i]);
  }
  return new Triple<Double, Integer, Integer>(
      (((double) tpCount[i]) / (tpCount[i] + fpCount[i])), tpCount[i], fpCount[i]);
}
public SimpleSequence(int[] intElements, Index<T> index) { elements = new Object[intElements.length]; for (int i = 0; i < intElements.length; i++) { elements[i] = index.get(intElements[i]); } start = 0; end = intElements.length; }
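// Sketch (not in the original source): the constructor above decodes an int-encoded sequence
// back into objects through an Index. The tag strings here are illustrative.
Index<String> tagIdx = new HashIndex<String>();
tagIdx.add("DT");
tagIdx.add("NN");
tagIdx.add("VBZ");
// seq now holds "DT", "NN", "VBZ" at positions 0..2
SimpleSequence<String> seq = new SimpleSequence<String>(new int[] {0, 1, 2}, tagIdx);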
/** * Given the path to a file representing the text based serialization of a Linear Classifier, * reconstitutes and returns that LinearClassifier. * * <p>TODO: Leverage Index */ public static LinearClassifier<String, String> loadFromFilename(String file) { try { BufferedReader in = IOUtils.readerFromString(file); // Format: read indices first, weights, then thresholds Index<String> labelIndex = HashIndex.loadFromReader(in); Index<String> featureIndex = HashIndex.loadFromReader(in); double[][] weights = new double[featureIndex.size()][labelIndex.size()]; int currLine = 1; String line = in.readLine(); while (line != null && line.length() > 0) { String[] tuples = line.split(LinearClassifier.TEXT_SERIALIZATION_DELIMITER); if (tuples.length != 3) { throw new Exception( "Error: incorrect number of tokens in weight specifier, line=" + currLine + " in file " + file); } currLine++; int feature = Integer.valueOf(tuples[0]); int label = Integer.valueOf(tuples[1]); double value = Double.valueOf(tuples[2]); weights[feature][label] = value; line = in.readLine(); } // First line in thresholds is the number of thresholds int numThresholds = Integer.valueOf(in.readLine()); double[] thresholds = new double[numThresholds]; int curr = 0; while ((line = in.readLine()) != null) { double tval = Double.valueOf(line.trim()); thresholds[curr++] = tval; } in.close(); LinearClassifier<String, String> classifier = new LinearClassifier<String, String>(weights, featureIndex, labelIndex); return classifier; } catch (Exception e) { System.err.println("Error in LinearClassifierFactory, loading from file=" + file); e.printStackTrace(); return null; } }
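// Hypothetical round trip (not part of the original source): load a text-serialized classifier
// and classify a simple Datum. The file path and feature strings are made up, and the sketch
// assumes this static method lives in LinearClassifierFactory, as its error message suggests.
LinearClassifier<String, String> lc =
    LinearClassifierFactory.loadFromFilename("/path/to/classifier.txt");
if (lc != null) { // loadFromFilename returns null on error
  Datum<String, String> d = new BasicDatum<String, String>(Arrays.asList("featA", "featB"));
  System.out.println("Predicted label: " + lc.classOf(d));
}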
@Override public void finishTraining() { lex.finishTraining(); int numTags = tagIndex.size(); POSes = new HashSet<String>(tagIndex.objectsList()); initialPOSDist = Distribution.laplaceSmoothedDistribution(initial, numTags, 0.5); markovPOSDists = new HashMap<String, Distribution>(); Set entries = ruleCounter.lowestLevelCounterEntrySet(); for (Iterator iter = entries.iterator(); iter.hasNext(); ) { Map.Entry entry = (Map.Entry) iter.next(); // Map.Entry<List<String>, Counter> entry = (Map.Entry<List<String>, Counter>) // iter.next(); Distribution d = Distribution.laplaceSmoothedDistribution((ClassicCounter) entry.getValue(), numTags, 0.5); markovPOSDists.put(((List<String>) entry.getKey()).get(0), d); } }
public Triple<Double, Integer, Integer> getPrecisionInfo() { int tp = 0, fp = 0; for (int i = 0; i < labelIndex.size(); i++) { if (i == negIndex) { continue; } tp += tpCount[i]; fp += fpCount[i]; } return new Triple<Double, Integer, Integer>((((double) tp) / (tp + fp)), tp, fp); }
public Triple<Double, Integer, Integer> getRecallInfo() { int tp = 0, fn = 0; for (int i = 0; i < labelIndex.size(); i++) { if (i == negIndex) { continue; } tp += tpCount[i]; fn += fnCount[i]; } return new Triple<Double, Integer, Integer>((((double) tp) / (tp + fn)), tp, fn); }
@Override public String toString() { StringBuilder s = new StringBuilder(); s.append(index.toString()); s.append(' '); if (openFixed) { s.append(" OPEN:").append(getOpenTags()); } else { s.append(" open:").append(getOpenTags()).append(" CLOSED:").append(closed); } return s.toString(); }
public static <L, F> OneVsAllClassifier<L, F> train( ClassifierFactory<String, F, Classifier<String, F>> classifierFactory, GeneralDataset<L, F> dataset, Collection<L> trainLabels) { Index<L> labelIndex = dataset.labelIndex(); Index<F> featureIndex = dataset.featureIndex(); Map<L, Classifier<String, F>> classifiers = Generics.newHashMap(); for (L label : trainLabels) { int i = labelIndex.indexOf(label); logger.info("Training " + label + " = " + i + ", posIndex = " + posIndex); // Create training data for training this classifier Map<L, String> posLabelMap = new ArrayMap<>(); posLabelMap.put(label, POS_LABEL); GeneralDataset<String, F> binaryDataset = dataset.mapDataset(dataset, binaryIndex, posLabelMap, NEG_LABEL); Classifier<String, F> binaryClassifier = classifierFactory.trainClassifier(binaryDataset); classifiers.put(label, binaryClassifier); } OneVsAllClassifier<L, F> classifier = new OneVsAllClassifier<>(featureIndex, labelIndex, classifiers); return classifier; }
protected void read(DataInputStream file) { try { int size = file.readInt(); index = new HashIndex<String>(); for (int i = 0; i < size; i++) { String tag = file.readUTF(); boolean inClosed = file.readBoolean(); index.add(tag); if (inClosed) closed.add(tag); } } catch (IOException e) { e.printStackTrace(); } }
protected void save(DataOutputStream file, Map<String, Set<String>> tagTokens) { try { file.writeInt(index.size()); for (String item : index) { file.writeUTF(item); if (learnClosedTags) { if (tagTokens.get(item).size() < closedTagThreshold) { markClosed(item); } } file.writeBoolean(isClosed(item)); } } catch (IOException e) { throw new RuntimeIOException(e); } }
public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) { List<L> guesses = new ArrayList<L>(); List<L> labels = new ArrayList<L>(); for (int i = 0; i < data.size(); i++) { Datum<L, F> d = data.getRVFDatum(i); L guess = classifier.classOf(d); guesses.add(guess); } int[] labelsArr = data.getLabelsArray(); labelIndex = data.labelIndex; for (int i = 0; i < data.size(); i++) { labels.add(labelIndex.get(labelsArr[i])); } labelIndex = new HashIndex<L>(); labelIndex.addAll(data.labelIndex().objectsList()); labelIndex.addAll(classifier.labels()); int numClasses = labelIndex.size(); tpCount = new int[numClasses]; fpCount = new int[numClasses]; fnCount = new int[numClasses]; negIndex = labelIndex.indexOf(negLabel); for (int i = 0; i < guesses.size(); ++i) { L guess = guesses.get(i); int guessIndex = labelIndex.indexOf(guess); L label = labels.get(i); int trueIndex = labelIndex.indexOf(label); if (guessIndex == trueIndex) { if (guessIndex != negIndex) { tpCount[guessIndex]++; } } else { if (guessIndex != negIndex) { fpCount[guessIndex]++; } if (trueIndex != negIndex) { fnCount[trueIndex]++; } } } return getFMeasure(); }
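// Evaluation sketch (not in the original source): score a trained classifier against a held-out
// dataset with this scorer, then read off micro-averaged precision, recall, and F-measure. The
// enclosing scorer class is not shown in this section; the class name
// MultiClassPrecisionRecallStats and its negative-label constructor are assumptions, as are the
// "trainedClassifier" and "heldOutData" variables.
MultiClassPrecisionRecallStats<String> scorer = new MultiClassPrecisionRecallStats<String>("O");
double f = scorer.score(trainedClassifier, heldOutData);
Triple<Double, Integer, Integer> p = scorer.getPrecisionInfo();
Triple<Double, Integer, Integer> r = scorer.getRecallInfo();
System.out.printf("P=%.3f (tp=%d, fp=%d)  R=%.3f (tp=%d, fn=%d)  F=%.3f%n",
    p.first(), p.second(), p.third(), r.first(), r.second(), r.third(), f);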
@Override public void train(List<TaggedWord> sentence) { lex.train(sentence, 1.0); String last = null; for (TaggedWord tagLabel : sentence) { String tag = tagLabel.tag(); tagIndex.add(tag); if (last == null) { initial.incrementCount(tag); } else { ruleCounter.incrementCount2D(last, tag); } last = tag; } }
@Override public DependencyGrammar formResult() { wordIndex.indexOf(Lexicon.UNKNOWN_WORD, true); MLEDependencyGrammar dg = new MLEDependencyGrammar( tlpParams, directional, useDistance, useCoarseDistance, basicCategoryTagsInDependencyGrammar, op, wordIndex, tagIndex); for (IntDependency dependency : dependencyCounter.keySet()) { dg.addRule(dependency, dependencyCounter.getCount(dependency)); } return dg; }
public Classifier<L, F> trainClassifier(Iterable<Datum<L, F>> dataIterable) { Minimizer<DiffFunction> minimizer = getMinimizer(); Index<F> featureIndex = Generics.newIndex(); Index<L> labelIndex = Generics.newIndex(); for (Datum<L, F> d : dataIterable) { labelIndex.add(d.label()); featureIndex.addAll(d.asFeatures()); // If there are duplicates, it doesn't add them again. } System.err.println( String.format( "Training linear classifier with %d features and %d labels", featureIndex.size(), labelIndex.size())); LogConditionalObjectiveFunction<L, F> objective = new LogConditionalObjectiveFunction<L, F>(dataIterable, logPrior, featureIndex, labelIndex); objective.setPrior(new LogPrior(LogPrior.LogPriorType.QUADRATIC)); double[] initial = objective.initial(); double[] weights = minimizer.minimize(objective, TOL, initial); LinearClassifier<L, F> classifier = new LinearClassifier<L, F>(objective.to2D(weights), featureIndex, labelIndex); return classifier; }
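// Minimal sketch (not part of the original source): training a linear classifier straight from
// an Iterable of Datums with the method above, using its default minimizer and quadratic prior.
// Assumes the method belongs to LinearClassifierFactory, consistent with the surrounding code;
// the feature and label strings are made up.
List<Datum<String, String>> data = new ArrayList<Datum<String, String>>();
data.add(new BasicDatum<String, String>(Arrays.asList("sunny", "warm"), "outdoors"));
data.add(new BasicDatum<String, String>(Arrays.asList("rainy", "cold"), "indoors"));
LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<String, String>();
Classifier<String, String> clf = factory.trainClassifier(data);
System.out.println(clf.classOf(new BasicDatum<String, String>(Arrays.asList("sunny", "cold"))));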
public static void main(String[] args) { Options op = new Options(new EnglishTreebankParserParams()); // op.tlpParams may be changed to something else later, so don't use it till // after options are parsed. System.out.println(StringUtils.toInvocationString("FactoredParser", args)); String path = "/u/nlp/stuff/corpora/Treebank3/parsed/mrg/wsj"; int trainLow = 200, trainHigh = 2199, testLow = 2200, testHigh = 2219; String serializeFile = null; int i = 0; while (i < args.length && args[i].startsWith("-")) { if (args[i].equalsIgnoreCase("-path") && (i + 1 < args.length)) { path = args[i + 1]; i += 2; } else if (args[i].equalsIgnoreCase("-train") && (i + 2 < args.length)) { trainLow = Integer.parseInt(args[i + 1]); trainHigh = Integer.parseInt(args[i + 2]); i += 3; } else if (args[i].equalsIgnoreCase("-test") && (i + 2 < args.length)) { testLow = Integer.parseInt(args[i + 1]); testHigh = Integer.parseInt(args[i + 2]); i += 3; } else if (args[i].equalsIgnoreCase("-serialize") && (i + 1 < args.length)) { serializeFile = args[i + 1]; i += 2; } else if (args[i].equalsIgnoreCase("-tLPP") && (i + 1 < args.length)) { try { op.tlpParams = (TreebankLangParserParams) Class.forName(args[i + 1]).newInstance(); } catch (ClassNotFoundException e) { System.err.println("Class not found: " + args[i + 1]); throw new RuntimeException(e); } catch (InstantiationException e) { System.err.println("Couldn't instantiate: " + args[i + 1] + ": " + e.toString()); throw new RuntimeException(e); } catch (IllegalAccessException e) { System.err.println("illegal access" + e); throw new RuntimeException(e); } i += 2; } else if (args[i].equals("-encoding")) { // sets encoding for TreebankLangParserParams op.tlpParams.setInputEncoding(args[i + 1]); op.tlpParams.setOutputEncoding(args[i + 1]); i += 2; } else { i = op.setOptionOrWarn(args, i); } } // System.out.println(tlpParams.getClass()); TreebankLanguagePack tlp = op.tlpParams.treebankLanguagePack(); op.trainOptions.sisterSplitters = new HashSet<String>(Arrays.asList(op.tlpParams.sisterSplitters())); // BinarizerFactory.TreeAnnotator.setTreebankLang(tlpParams); PrintWriter pw = op.tlpParams.pw(); op.testOptions.display(); op.trainOptions.display(); op.display(); op.tlpParams.display(); // setup tree transforms Treebank trainTreebank = op.tlpParams.memoryTreebank(); MemoryTreebank testTreebank = op.tlpParams.testMemoryTreebank(); // Treebank blippTreebank = ((EnglishTreebankParserParams) tlpParams).diskTreebank(); // String blippPath = "/afs/ir.stanford.edu/data/linguistic-data/BLLIP-WSJ/"; // blippTreebank.loadPath(blippPath, "", true); Timing.startTime(); System.err.print("Reading trees..."); testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true)); if (op.testOptions.increasingLength) { Collections.sort(testTreebank, new TreeLengthComparator()); } trainTreebank.loadPath(path, new NumberRangeFileFilter(trainLow, trainHigh, true)); Timing.tick("done."); System.err.print("Binarizing trees..."); TreeAnnotatorAndBinarizer binarizer; if (!op.trainOptions.leftToRight) { binarizer = new TreeAnnotatorAndBinarizer( op.tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), true, op); } else { binarizer = new TreeAnnotatorAndBinarizer( op.tlpParams.headFinder(), new LeftHeadFinder(), op.tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), true, op); } CollinsPuncTransformer collinsPuncTransformer = null; if (op.trainOptions.collinsPunc) { collinsPuncTransformer = new CollinsPuncTransformer(tlp); } TreeTransformer debinarizer = new Debinarizer(op.forceCNF); 
List<Tree> binaryTrainTrees = new ArrayList<Tree>(); if (op.trainOptions.selectiveSplit) { op.trainOptions.splitters = ParentAnnotationStats.getSplitCategories( trainTreebank, op.trainOptions.tagSelectiveSplit, 0, op.trainOptions.selectiveSplitCutOff, op.trainOptions.tagSelectiveSplitCutOff, op.tlpParams.treebankLanguagePack()); if (op.trainOptions.deleteSplitters != null) { List<String> deleted = new ArrayList<String>(); for (String del : op.trainOptions.deleteSplitters) { String baseDel = tlp.basicCategory(del); boolean checkBasic = del.equals(baseDel); for (Iterator<String> it = op.trainOptions.splitters.iterator(); it.hasNext(); ) { String elem = it.next(); String baseElem = tlp.basicCategory(elem); boolean delStr = checkBasic && baseElem.equals(baseDel) || elem.equals(del); if (delStr) { it.remove(); deleted.add(elem); } } } System.err.println("Removed from vertical splitters: " + deleted); } } if (op.trainOptions.selectivePostSplit) { TreeTransformer myTransformer = new TreeAnnotator(op.tlpParams.headFinder(), op.tlpParams, op); Treebank annotatedTB = trainTreebank.transform(myTransformer); op.trainOptions.postSplitters = ParentAnnotationStats.getSplitCategories( annotatedTB, true, 0, op.trainOptions.selectivePostSplitCutOff, op.trainOptions.tagSelectivePostSplitCutOff, op.tlpParams.treebankLanguagePack()); } if (op.trainOptions.hSelSplit) { binarizer.setDoSelectiveSplit(false); for (Tree tree : trainTreebank) { if (op.trainOptions.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } // tree.pennPrint(tlpParams.pw()); tree = binarizer.transformTree(tree); // binaryTrainTrees.add(tree); } binarizer.setDoSelectiveSplit(true); } for (Tree tree : trainTreebank) { if (op.trainOptions.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } tree = binarizer.transformTree(tree); binaryTrainTrees.add(tree); } if (op.testOptions.verbose) { binarizer.dumpStats(); } List<Tree> binaryTestTrees = new ArrayList<Tree>(); for (Tree tree : testTreebank) { if (op.trainOptions.collinsPunc) { tree = collinsPuncTransformer.transformTree(tree); } tree = binarizer.transformTree(tree); binaryTestTrees.add(tree); } Timing.tick("done."); // binarization BinaryGrammar bg = null; UnaryGrammar ug = null; DependencyGrammar dg = null; // DependencyGrammar dgBLIPP = null; Lexicon lex = null; Index<String> stateIndex = new HashIndex<String>(); // extract grammars Extractor<Pair<UnaryGrammar, BinaryGrammar>> bgExtractor = new BinaryGrammarExtractor(op, stateIndex); // Extractor bgExtractor = new SmoothedBinaryGrammarExtractor();//new BinaryGrammarExtractor(); // Extractor lexExtractor = new LexiconExtractor(); // Extractor dgExtractor = new DependencyMemGrammarExtractor(); if (op.doPCFG) { System.err.print("Extracting PCFG..."); Pair<UnaryGrammar, BinaryGrammar> bgug = null; if (op.trainOptions.cheatPCFG) { List<Tree> allTrees = new ArrayList<Tree>(binaryTrainTrees); allTrees.addAll(binaryTestTrees); bgug = bgExtractor.extract(allTrees); } else { bgug = bgExtractor.extract(binaryTrainTrees); } bg = bgug.second; bg.splitRules(); ug = bgug.first; ug.purgeRules(); Timing.tick("done."); } System.err.print("Extracting Lexicon..."); Index<String> wordIndex = new HashIndex<String>(); Index<String> tagIndex = new HashIndex<String>(); lex = op.tlpParams.lex(op, wordIndex, tagIndex); lex.train(binaryTrainTrees); Timing.tick("done."); if (op.doDep) { System.err.print("Extracting Dependencies..."); binaryTrainTrees.clear(); Extractor<DependencyGrammar> dgExtractor = new MLEDependencyGrammarExtractor(op, 
wordIndex, tagIndex); // dgBLIPP = (DependencyGrammar) dgExtractor.extract(new // ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new // TransformTreeDependency(tlpParams,true)); // DependencyGrammar dg1 = dgExtractor.extract(trainTreebank.iterator(), new // TransformTreeDependency(op.tlpParams, true)); // dgBLIPP=(DependencyGrammar)dgExtractor.extract(blippTreebank.iterator(),new // TransformTreeDependency(tlpParams)); // dg = (DependencyGrammar) dgExtractor.extract(new // ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new // TransformTreeDependency(tlpParams)); // dg=new DependencyGrammarCombination(dg1,dgBLIPP,2); dg = dgExtractor.extract( binaryTrainTrees); // uses information whether the words are known or not, discards // unknown words Timing.tick("done."); // System.out.print("Extracting Unknown Word Model..."); // UnknownWordModel uwm = (UnknownWordModel)uwmExtractor.extract(binaryTrainTrees); // Timing.tick("done."); System.out.print("Tuning Dependency Model..."); dg.tune(binaryTestTrees); // System.out.println("TUNE DEPS: "+tuneDeps); Timing.tick("done."); } BinaryGrammar boundBG = bg; UnaryGrammar boundUG = ug; GrammarProjection gp = new NullGrammarProjection(bg, ug); // serialization if (serializeFile != null) { System.err.print("Serializing parser..."); LexicalizedParser.saveParserDataToSerialized( new ParserData(lex, bg, ug, dg, stateIndex, wordIndex, tagIndex, op), serializeFile); Timing.tick("done."); } // test: pcfg-parse and output ExhaustivePCFGParser parser = null; if (op.doPCFG) { parser = new ExhaustivePCFGParser(boundBG, boundUG, lex, op, stateIndex, wordIndex, tagIndex); } ExhaustiveDependencyParser dparser = ((op.doDep && !op.testOptions.useFastFactored) ? new ExhaustiveDependencyParser(dg, lex, op, wordIndex, tagIndex) : null); Scorer scorer = (op.doPCFG ? new TwinScorer(new ProjectionScorer(parser, gp, op), dparser) : null); // Scorer scorer = parser; BiLexPCFGParser bparser = null; if (op.doPCFG && op.doDep) { bparser = (op.testOptions.useN5) ? new BiLexPCFGParser.N5BiLexPCFGParser( scorer, parser, dparser, bg, ug, dg, lex, op, gp, stateIndex, wordIndex, tagIndex) : new BiLexPCFGParser( scorer, parser, dparser, bg, ug, dg, lex, op, gp, stateIndex, wordIndex, tagIndex); } Evalb pcfgPE = new Evalb("pcfg PE", true); Evalb comboPE = new Evalb("combo PE", true); AbstractEval pcfgCB = new Evalb.CBEval("pcfg CB", true); AbstractEval pcfgTE = new TaggingEval("pcfg TE"); AbstractEval comboTE = new TaggingEval("combo TE"); AbstractEval pcfgTEnoPunct = new TaggingEval("pcfg nopunct TE"); AbstractEval comboTEnoPunct = new TaggingEval("combo nopunct TE"); AbstractEval depTE = new TaggingEval("depnd TE"); AbstractEval depDE = new UnlabeledAttachmentEval("depnd DE", true, null, tlp.punctuationWordRejectFilter()); AbstractEval comboDE = new UnlabeledAttachmentEval("combo DE", true, null, tlp.punctuationWordRejectFilter()); if (op.testOptions.evalb) { EvalbFormatWriter.initEVALBfiles(op.tlpParams); } // int[] countByLength = new int[op.testOptions.maxLength+1]; // Use a reflection ruse, so one can run this without needing the // tagger. Using a function rather than a MaxentTagger means we // can distribute a version of the parser that doesn't include the // entire tagger. Function<List<? extends HasWord>, ArrayList<TaggedWord>> tagger = null; if (op.testOptions.preTag) { try { Class[] argsClass = {String.class}; Object[] arguments = new Object[] {op.testOptions.taggerSerializedFile}; tagger = (Function<List<? 
extends HasWord>, ArrayList<TaggedWord>>) Class.forName("edu.stanford.nlp.tagger.maxent.MaxentTagger") .getConstructor(argsClass) .newInstance(arguments); } catch (Exception e) { System.err.println(e); System.err.println("Warning: No pretagging of sentences will be done."); } } for (int tNum = 0, ttSize = testTreebank.size(); tNum < ttSize; tNum++) { Tree tree = testTreebank.get(tNum); int testTreeLen = tree.yield().size(); if (testTreeLen > op.testOptions.maxLength) { continue; } Tree binaryTree = binaryTestTrees.get(tNum); // countByLength[testTreeLen]++; System.out.println("-------------------------------------"); System.out.println("Number: " + (tNum + 1)); System.out.println("Length: " + testTreeLen); // tree.pennPrint(pw); // System.out.println("XXXX The binary tree is"); // binaryTree.pennPrint(pw); // System.out.println("Here are the tags in the lexicon:"); // System.out.println(lex.showTags()); // System.out.println("Here's the tagnumberer:"); // System.out.println(Numberer.getGlobalNumberer("tags").toString()); long timeMil1 = System.currentTimeMillis(); Timing.tick("Starting parse."); if (op.doPCFG) { // System.err.println(op.testOptions.forceTags); if (op.testOptions.forceTags) { if (tagger != null) { // System.out.println("Using a tagger to set tags"); // System.out.println("Tagged sentence as: " + // tagger.processSentence(cutLast(wordify(binaryTree.yield()))).toString(false)); parser.parse(addLast(tagger.apply(cutLast(wordify(binaryTree.yield()))))); } else { // System.out.println("Forcing tags to match input."); parser.parse(cleanTags(binaryTree.taggedYield(), tlp)); } } else { // System.out.println("XXXX Parsing " + binaryTree.yield()); parser.parse(binaryTree.yieldHasWord()); } // Timing.tick("Done with pcfg phase."); } if (op.doDep) { dparser.parse(binaryTree.yieldHasWord()); // Timing.tick("Done with dependency phase."); } boolean bothPassed = false; if (op.doPCFG && op.doDep) { bothPassed = bparser.parse(binaryTree.yieldHasWord()); // Timing.tick("Done with combination phase."); } long timeMil2 = System.currentTimeMillis(); long elapsed = timeMil2 - timeMil1; System.err.println("Time: " + ((int) (elapsed / 100)) / 10.00 + " sec."); // System.out.println("PCFG Best Parse:"); Tree tree2b = null; Tree tree2 = null; // System.out.println("Got full best parse..."); if (op.doPCFG) { tree2b = parser.getBestParse(); tree2 = debinarizer.transformTree(tree2b); } // System.out.println("Debinarized parse..."); // tree2.pennPrint(); // System.out.println("DepG Best Parse:"); Tree tree3 = null; Tree tree3db = null; if (op.doDep) { tree3 = dparser.getBestParse(); // was: but wrong Tree tree3db = debinarizer.transformTree(tree2); tree3db = debinarizer.transformTree(tree3); tree3.pennPrint(pw); } // tree.pennPrint(); // ((Tree)binaryTrainTrees.get(tNum)).pennPrint(); // System.out.println("Combo Best Parse:"); Tree tree4 = null; if (op.doPCFG && op.doDep) { try { tree4 = bparser.getBestParse(); if (tree4 == null) { tree4 = tree2b; } } catch (NullPointerException e) { System.err.println("Blocked, using PCFG parse!"); tree4 = tree2b; } } if (op.doPCFG && !bothPassed) { tree4 = tree2b; } // tree4.pennPrint(); if (op.doDep) { depDE.evaluate(tree3, binaryTree, pw); depTE.evaluate(tree3db, tree, pw); } TreeTransformer tc = op.tlpParams.collinizer(); TreeTransformer tcEvalb = op.tlpParams.collinizerEvalb(); if (op.doPCFG) { // System.out.println("XXXX Best PCFG was: "); // tree2.pennPrint(); // System.out.println("XXXX Transformed best PCFG is: "); // tc.transformTree(tree2).pennPrint(); // 
System.out.println("True Best Parse:"); // tree.pennPrint(); // tc.transformTree(tree).pennPrint(); pcfgPE.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw); pcfgCB.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw); Tree tree4b = null; if (op.doDep) { comboDE.evaluate((bothPassed ? tree4 : tree3), binaryTree, pw); tree4b = tree4; tree4 = debinarizer.transformTree(tree4); if (op.nodePrune) { NodePruner np = new NodePruner(parser, debinarizer); tree4 = np.prune(tree4); } // tree4.pennPrint(); comboPE.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw); } // pcfgTE.evaluate(tree2, tree); pcfgTE.evaluate(tcEvalb.transformTree(tree2), tcEvalb.transformTree(tree), pw); pcfgTEnoPunct.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw); if (op.doDep) { comboTE.evaluate(tcEvalb.transformTree(tree4), tcEvalb.transformTree(tree), pw); comboTEnoPunct.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw); } System.out.println("PCFG only: " + parser.scoreBinarizedTree(tree2b, 0)); // tc.transformTree(tree2).pennPrint(); tree2.pennPrint(pw); if (op.doDep) { System.out.println("Combo: " + parser.scoreBinarizedTree(tree4b, 0)); // tc.transformTree(tree4).pennPrint(pw); tree4.pennPrint(pw); } System.out.println("Correct:" + parser.scoreBinarizedTree(binaryTree, 0)); /* if (parser.scoreBinarizedTree(tree2b,true) < parser.scoreBinarizedTree(binaryTree,true)) { System.out.println("SCORE INVERSION"); parser.validateBinarizedTree(binaryTree,0); } */ tree.pennPrint(pw); } // end if doPCFG if (op.testOptions.evalb) { if (op.doPCFG && op.doDep) { EvalbFormatWriter.writeEVALBline( tcEvalb.transformTree(tree), tcEvalb.transformTree(tree4)); } else if (op.doPCFG) { EvalbFormatWriter.writeEVALBline( tcEvalb.transformTree(tree), tcEvalb.transformTree(tree2)); } else if (op.doDep) { EvalbFormatWriter.writeEVALBline( tcEvalb.transformTree(tree), tcEvalb.transformTree(tree3db)); } } } // end for each tree in test treebank if (op.testOptions.evalb) { EvalbFormatWriter.closeEVALBfiles(); } // op.testOptions.display(); if (op.doPCFG) { pcfgPE.display(false, pw); System.out.println("Grammar size: " + stateIndex.size()); pcfgCB.display(false, pw); if (op.doDep) { comboPE.display(false, pw); } pcfgTE.display(false, pw); pcfgTEnoPunct.display(false, pw); if (op.doDep) { comboTE.display(false, pw); comboTEnoPunct.display(false, pw); } } if (op.doDep) { depTE.display(false, pw); depDE.display(false, pw); } if (op.doPCFG && op.doDep) { comboDE.display(false, pw); } // pcfgPE.printGoodBad(); }
/** * Adds dependencies to list depList. These are in terms of the original tag set not the reduced * (projected) tag set. */ protected static EndHead treeToDependencyHelper( Tree tree, List<IntDependency> depList, int loc, Index<String> wordIndex, Index<String> tagIndex) { // try { // PrintWriter pw = new PrintWriter(new OutputStreamWriter(System.out,"GB18030"),true); // tree.pennPrint(pw); // } // catch (UnsupportedEncodingException e) {} if (tree.isLeaf() || tree.isPreTerminal()) { EndHead tempEndHead = new EndHead(); tempEndHead.head = loc; tempEndHead.end = loc + 1; return tempEndHead; } Tree[] kids = tree.children(); if (kids.length == 1) { return treeToDependencyHelper(kids[0], depList, loc, wordIndex, tagIndex); } EndHead tempEndHead = treeToDependencyHelper(kids[0], depList, loc, wordIndex, tagIndex); int lHead = tempEndHead.head; int split = tempEndHead.end; tempEndHead = treeToDependencyHelper(kids[1], depList, tempEndHead.end, wordIndex, tagIndex); int end = tempEndHead.end; int rHead = tempEndHead.head; String hTag = ((HasTag) tree.label()).tag(); String lTag = ((HasTag) kids[0].label()).tag(); String rTag = ((HasTag) kids[1].label()).tag(); String hWord = ((HasWord) tree.label()).word(); String lWord = ((HasWord) kids[0].label()).word(); String rWord = ((HasWord) kids[1].label()).word(); boolean leftHeaded = hWord.equals(lWord); String aTag = (leftHeaded ? rTag : lTag); String aWord = (leftHeaded ? rWord : lWord); int hT = tagIndex.indexOf(hTag); int aT = tagIndex.indexOf(aTag); int hW = (wordIndex.contains(hWord) ? wordIndex.indexOf(hWord) : wordIndex.indexOf(Lexicon.UNKNOWN_WORD)); int aW = (wordIndex.contains(aWord) ? wordIndex.indexOf(aWord) : wordIndex.indexOf(Lexicon.UNKNOWN_WORD)); int head = (leftHeaded ? lHead : rHead); int arg = (leftHeaded ? rHead : lHead); IntDependency dependency = new IntDependency( hW, hT, aW, aT, leftHeaded, (leftHeaded ? split - head - 1 : head - split)); depList.add(dependency); IntDependency stopL = new IntDependency( aW, aT, STOP_WORD_INT, STOP_TAG_INT, false, (leftHeaded ? arg - split : arg - loc)); depList.add(stopL); IntDependency stopR = new IntDependency( aW, aT, STOP_WORD_INT, STOP_TAG_INT, true, (leftHeaded ? end - arg - 1 : split - arg - 1)); depList.add(stopR); // System.out.println("Adding: "+dependency+" at "+tree.label()); tempEndHead.head = head; return tempEndHead; }
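// Sketch (not in the original source): collecting IntDependency events from a binarized,
// head-annotated tree with the helper above. Assumes the tree's labels implement HasWord and
// HasTag (as they do after annotation/binarization) and that "binaryTree", "wordIndex", and
// "tagIndex" come from the same training run.
List<IntDependency> deps = new ArrayList<IntDependency>();
treeToDependencyHelper(binaryTree, deps, 0, wordIndex, tagIndex);
for (IntDependency dep : deps) {
  System.out.println(dep);
}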
@Override public Collection<L> labels() { return labelIndex.objectsList(); }
static { binaryIndex = new HashIndex<>(); binaryIndex.add(POS_LABEL); binaryIndex.add(NEG_LABEL); posIndex = binaryIndex.indexOf(POS_LABEL); }
public static <L, F> OneVsAllClassifier<L, F> train( ClassifierFactory<String, F, Classifier<String, F>> classifierFactory, GeneralDataset<L, F> dataset) { Index<L> labelIndex = dataset.labelIndex(); return train(classifierFactory, dataset, labelIndex.objectsList()); }
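// End-to-end sketch (not part of the original source): one-vs-all training over a small
// three-class dataset using the convenience overload above. Assumes LinearClassifierFactory
// implements ClassifierFactory<String, F, Classifier<String, F>>, as in the Stanford classify
// package; the feature and label strings are made up.
GeneralDataset<String, String> dataset = new Dataset<String, String>();
dataset.add(new BasicDatum<String, String>(Arrays.asList("f1", "f2"), "A"));
dataset.add(new BasicDatum<String, String>(Arrays.asList("f2", "f3"), "B"));
dataset.add(new BasicDatum<String, String>(Arrays.asList("f1", "f3"), "C"));
ClassifierFactory<String, String, Classifier<String, String>> binaryFactory =
    new LinearClassifierFactory<String, String>();
OneVsAllClassifier<String, String> ova = OneVsAllClassifier.train(binaryFactory, dataset);
System.out.println(ova.classOf(new BasicDatum<String, String>(Arrays.asList("f1", "f2"))));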