/**
  * Trains a naive Bayes classifier from a list of {@link RVFDatum}s. The datums are assumed to
  * contain the zero-valued features as well, so every datum has the same feature set; the number
  * of features is taken from the first example.
  *
  * <p>Feature values are truncated to {@code int}. The label and feature indices of this factory
  * are replaced by fresh indices built from {@code examples}.
  *
  * @param examples the training data; must be non-empty
  * @return the trained classifier
  */
 @Override
 @Deprecated
 public NaiveBayesClassifier<L, F> trainClassifier(List<RVFDatum<L, F>> examples) {
   RVFDatum<L, F> d0 = examples.get(0);
   int numFeatures = d0.asFeatures().size();
   int[][] data = new int[examples.size()][numFeatures];
   int[] labels = new int[examples.size()];
   labelIndex = new HashIndex<L>();
   featureIndex = new HashIndex<F>();
   for (int d = 0; d < examples.size(); d++) {
     RVFDatum<L, F> datum = examples.get(d);
     Counter<F> c = datum.asFeaturesCounter();
     for (F feature : c.keySet()) {
       // Store the value for every datum. add() only reports true the first time a feature is
       // indexed, so the value must not be gated on its result or later datums stay all-zero.
       featureIndex.add(feature);
       int fNo = featureIndex.indexOf(feature);
       data[d][fNo] = (int) c.getCount(feature);
     }
     labelIndex.add(datum.label());
     labels[d] = labelIndex.indexOf(datum.label());
   }
   int numClasses = labelIndex.size();
   return trainClassifier(data, labels, numFeatures, numClasses, labelIndex, featureIndex);
 }
 /**
  * Registers an ambiguity class if it is not already present.
  *
  * @param a the ambiguity class
  * @return the position of {@code a} in {@code classes}
  */
 private int add(AmbiguityClass a) {
   if (!classes.contains(a)) {
     classes.add(a);
   }
   return classes.indexOf(a);
 }
예제 #3
0
  /**
   * Measures how well the lexicon covers the words (terminals) of a collection of trees.
   *
   * <p>Every event produced by {@code treeToEvents} for the given trees is checked against the
   * lexicon. Words never seen with the null tag go into {@code missingWords}. Tags never seen with
   * the null word go into {@code missingTags}. Tagged words whose seen count is zero go into
   * {@code missingTW}. Currently unused; this probably only works if train and test happen in the
   * same run, so that the {@code tags} and {@code words} sets are initialized.
   *
   * @param trees the trees whose terminals are checked
   * @param missingWords output set receiving words unknown to the lexicon
   * @param missingTags output set receiving tags unknown to the lexicon
   * @param missingTW output set receiving tagged words never seen in training
   * @return the fraction of tagged-word events that were unseen (NaN when there are no events)
   */
  public double evaluateCoverage(
      Collection<Tree> trees,
      Set<String> missingWords,
      Set<String> missingTags,
      Set<IntTaggedWord> missingTW) {
    List<IntTaggedWord> events = new ArrayList<IntTaggedWord>();
    for (Tree tree : trees) {
      events.addAll(treeToEvents(tree));
    }

    int unseenCount = 0;
    for (IntTaggedWord event : events) {
      IntTaggedWord wordOnly = new IntTaggedWord(event.word(), nullTag);
      if (!words.contains(wordOnly)) {
        missingWords.add(wordIndex.get(event.word()));
      }
      IntTaggedWord tagOnly = new IntTaggedWord(nullWord, event.tag());
      if (!tags.contains(tagOnly)) {
        missingTags.add(tagIndex.get(event.tag()));
      }
      if (seenCounter.getCount(event) == 0.0) {
        unseenCount++;
        missingTW.add(event);
      }
    }
    int totalCount = events.size();
    return (double) unseenCount / totalCount;
  }
 /**
  * Fits naive Bayes weights on array-encoded data and wraps them in a {@link NaiveBayesClassifier}.
  *
  * @param data per-datum feature values
  * @param labels per-datum label numbers
  * @param numFeatures number of features
  * @param numClasses number of classes
  * @param labelIndex maps label numbers to labels
  * @param featureIndex maps feature numbers to features
  * @return a classifier whose weights are keyed by ((label, feature), featureValue)
  */
 private NaiveBayesClassifier<L, F> trainClassifier(
     int[][] data,
     int[] labels,
     int numFeatures,
     int numClasses,
     Index<L> labelIndex,
     Index<F> featureIndex) {
   NBWeights nbWeights = trainWeights(data, labels, numFeatures, numClasses);

   // Class priors, and the set of every label that received one.
   Set<L> labelSet = Generics.newHashSet();
   Counter<L> priors = new ClassicCounter<L>();
   double[] priorWeights = nbWeights.priors;
   for (int labelNo = 0; labelNo < priorWeights.length; labelNo++) {
     L label = labelIndex.get(labelNo);
     priors.incrementCount(label, priorWeights[labelNo]);
     labelSet.add(label);
   }

   // Conditional weights, keyed by ((label, feature), featureValue).
   Counter<Pair<Pair<L, F>, Number>> weightsCounter =
       new ClassicCounter<Pair<Pair<L, F>, Number>>();
   double[][][] conditional = nbWeights.weights;
   for (int c = 0; c < numClasses; c++) {
     L label = labelIndex.get(c);
     double[][] classWeights = conditional[c];
     for (int f = 0; f < numFeatures; f++) {
       Pair<L, F> labelFeature = new Pair<L, F>(label, featureIndex.get(f));
       double[] valueWeights = classWeights[f];
       for (int v = 0; v < valueWeights.length; v++) {
         Number value = Integer.valueOf(v);
         Pair<Pair<L, F>, Number> key = new Pair<Pair<L, F>, Number>(labelFeature, value);
         weightsCounter.incrementCount(key, valueWeights[v]);
       }
     }
   }
   return new NaiveBayesClassifier<L, F>(weightsCounter, priors, labelSet);
 }
예제 #5
0
 /**
  * Provides some testing and opportunities for exploration of the probabilities of a BaseLexicon.
  * What's here currently probably only works for the English Penn Treebank, as it uses default
  * constructors. Of the words given to test on, the first is treated as sentence initial, and the
  * rest as not sentence initial.
  *
  * <p>For a known word, the log probability of each of its lexicon taggings is printed. For an
  * unknown word, the program prints its unknown-word signature, then the score of every tag, then
  * the tags that score negative infinity.
  *
  * @param args The command line arguments: java BaseLexicon treebankPath fileRange
  *     unknownWordModel words*
  */
 public static void main(String[] args) {
   if (args.length < 3) {
     System.err.println("java BaseLexicon treebankPath fileRange unknownWordModel words*");
     return;
   }
   System.out.print("Training BaseLexicon from " + args[0] + ' ' + args[1] + " ... ");
   Treebank tb = new DiskTreebank();
   tb.loadPath(args[0], new NumberRangesFileFilter(args[1], true));
   // TODO: change this interface so the lexicon creates its own indices?
   Index<String> wordIndex = new HashIndex<String>();
   Index<String> tagIndex = new HashIndex<String>();
   BaseLexicon lex = new BaseLexicon(wordIndex, tagIndex);
   lex.getUnknownWordModel().setUnknownLevel(Integer.parseInt(args[2]));
   lex.train(tb);
   System.out.println("done.");
   System.out.println();
   NumberFormat nf = NumberFormat.getNumberInstance();
   nf.setMaximumFractionDigits(4);
   for (int i = 3; i < args.length; i++) {
     String word = args[i];
     // Sentence position of this test word: only the first one (loc 0) is sentence initial.
     int loc = i - 3;
     if (lex.isKnown(word)) {
       System.out.println(
           word + " is a known word.  Log probabilities [log P(w|t)] for its taggings are:");
       for (Iterator<IntTaggedWord> it =
               lex.ruleIteratorByWord(wordIndex.indexOf(word, true), loc, null);
           it.hasNext(); ) {
         IntTaggedWord iTW = it.next();
         System.out.println(
             StringUtils.pad(iTW, 24) + nf.format(lex.score(iTW, loc, wordIndex.get(iTW.word))));
       }
     } else {
       String sig = lex.getUnknownWordModel().getSignature(word, loc);
       System.out.println(
           word
               + " is an unknown word.  Signature with uwm "
               + lex.getUnknownWordModel().getUnknownLevel()
               + ((loc == 0) ? " init" : " non-init")
               + " is: "
               + sig);
       List<String> impos = new ArrayList<String>();
       List<String> lis = new ArrayList<String>(tagIndex.objectsList());
       Collections.sort(lis);
       for (String tStr : lis) {
         IntTaggedWord iTW = new IntTaggedWord(word, tStr, wordIndex, tagIndex);
         // Score at the word's real position so the result agrees with the signature above.
         double score = lex.score(iTW, loc, word);
         if (score == Float.NEGATIVE_INFINITY) {
           impos.add(tStr);
         } else {
           System.out.println(StringUtils.pad(iTW, 24) + nf.format(score));
         }
       }
       if (!impos.isEmpty()) {
         System.out.println(word + " impossible tags: " + impos);
       }
     }
     System.out.println();
   }
 }
 /**
  * Wraps a flat weight vector as a {@link LinearClassifier} over this factory's feature and label
  * indices.
  *
  * @param weights the flattened weights; reshaped by {@code objective.to2D} when an objective is
  *     set, otherwise by {@code ArrayUtils.to2D} with {@code featureIndex.size()} and
  *     {@code labelIndex.size()} as the dimensions
  * @return a classifier using the reshaped weights
  */
 public LinearClassifier<L, F> createLinearClassifier(double[] weights) {
   double[][] weights2D =
       (objective != null)
           ? objective.to2D(weights)
           : ArrayUtils.to2D(weights, featureIndex.size(), labelIndex.size());
   return new LinearClassifier<L, F>(weights2D, featureIndex, labelIndex);
 }
예제 #7
0
 /**
  * Builds an index containing every pair (x, y), where x ranges over the positions of {@code px}
  * and y runs from 0 up to {@code numY(x) - 1}. Pairs are added with x in the outer loop and y in
  * the inner loop.
  *
  * @return a new index of all the pairs
  */
 public Index<IntPair> createIndex() {
   Index<IntPair> pairs = new HashIndex<>();
   int x = 0;
   while (x < px.length) {
     int yLimit = numY(x);
     for (int y = 0; y < yLimit; y++) {
       pairs.add(new IntPair(x, y));
     }
     x++;
   }
   return pairs;
 }
예제 #8
0
 /**
  * Fills {@code tagsToBaseTags} so that entry i holds the tag index of
  * {@code tlp.basicCategory(tag i)}. Only tags present when the method starts get an entry.
  *
  * @param tlp language pack used to reduce each tag to its basic category
  */
 private void populateTagsToBaseTags(TreebankLanguagePack tlp) {
   // Size captured up front: indexOf(..., true) presumably adds missing base tags to tagIndex.
   final int numTags = tagIndex.size();
   tagsToBaseTags = new int[numTags];
   for (int tagNo = 0; tagNo < numTags; tagNo++) {
     String baseTag = tlp.basicCategory(tagIndex.get(tagNo));
     tagsToBaseTags[tagNo] = tagIndex.indexOf(baseTag, true);
   }
 }
예제 #9
0
 /**
  * Maps a tag to its smoothing projection. The tag's string is projected with {@code smoothTP},
  * prefixed with {@code TP_PREFIX}, and looked up (with the add flag set) in
  * {@code smoothTPIndex}. On first use, {@code smoothTPIndex} is lazily initialized from
  * {@code tagIndex}. Negative tags are returned unchanged.
  *
  * @param tag the tag number
  * @return the projected tag number, or {@code tag} itself when it is negative
  */
 private short tagProject(short tag) {
   if (smoothTPIndex == null) {
     smoothTPIndex = new HashIndex<String>(tagIndex);
   }
   if (tag >= 0) {
     String projected = TP_PREFIX + smoothTP.project(smoothTPIndex.get(tag));
     return (short) smoothTPIndex.indexOf(projected, true);
   }
   return tag;
 }
예제 #10
0
 /**
  * Generate the possible taggings for a word at a sentence position. This may either be based on a
  * strict lexicon or an expanded generous set of possible taggings.
  *
  * <p>Unknown words get the taggings recorded for {@code UNKNOWN_WORD}, re-targeted at this word.
  * Known words get their lexicon taggings when {@code flexiTag} is off. When it is on, words seen
  * more than {@code smoothInUnknownsThreshold} times also get their lexicon taggings. Rarer words
  * get every tag in {@code tags} whose score is greater than negative infinity.
  *
  * <p><i>Implementation note:</i> Expanded sets of possible taggings are calculated dynamically at
  * runtime, so as to reduce the memory used by the lexicon (a space/time tradeoff).
  *
  * @param word The word (as an int)
  * @param loc Its index in the sentence (usually only relevant for unknown words)
  * @param featureSpec not read by this implementation
  * @return An iterator over the possible taggings
  */
 public Iterator<IntTaggedWord> ruleIteratorByWord(int word, int loc, String featureSpec) {
   List<IntTaggedWord> wordTaggings;
   if (!isKnown(word)) {
     // Copy the unknown-word taggings so each one carries the actual word.
     wordTaggings = new ArrayList<IntTaggedWord>(40);
     for (IntTaggedWord unkTagging : rulesWithWord[wordIndex.indexOf(UNKNOWN_WORD)]) {
       wordTaggings.add(new IntTaggedWord(word, unkTagging.tag));
     }
   } else if (!flexiTag) {
     // Strict lexical tagging for seen items.
     wordTaggings = rulesWithWord[word];
   } else {
     IntTaggedWord wordOnly = new IntTaggedWord(word, nullTag);
     if (seenCounter.getCount(wordOnly) > smoothInUnknownsThreshold) {
       // Common word: use the lexicon as-is (this path skips the debug output below).
       return rulesWithWord[word].iterator();
     }
     // Flexible tagging, not just the lexicon: keep every tag that scores above -infinity.
     wordTaggings = new ArrayList<IntTaggedWord>(40);
     for (IntTaggedWord tagOnly : tags) {
       IntTaggedWord candidate = new IntTaggedWord(word, tagOnly.tag);
       if (score(candidate, loc, wordIndex.get(word)) > Float.NEGATIVE_INFINITY) {
         wordTaggings.add(candidate);
       }
     }
   }
   if (DEBUG_LEXICON) {
     EncodingPrintWriter.err.println(
         "Lexicon: "
             + wordIndex.get(word)
             + " ("
             + (isKnown(word) ? "known" : "unknown")
             + ", loc="
             + loc
             + ", n="
             + (isKnown(word) ? word : wordIndex.indexOf(UNKNOWN_WORD))
             + ") "
             + (flexiTag ? "flexi" : "lexicon")
             + " taggings: "
             + wordTaggings,
         "UTF-8");
   }
   return wordTaggings.iterator();
 }
예제 #11
0
 /**
  * This records how likely it is for a word with one tag to also have another tag. This won't work
  * after serialization/deserialization, but that is how it is currently called....
  *
  * <p>Words with a total tag count below 10 are skipped. For each remaining word and each tag t,
  * let share = count(word, t) / total. The share is added to {@code m_T[t]} and to
  * {@code m_TT[t2][t]} once for every tag t2 the word was seen with.
  */
 void buildPT_T() {
   int numTags = tagIndex.size();
   m_TT = new double[numTags][numTags];
   m_T = new double[numTags];
   double[] tagCounts = new double[numTags];
   for (IntTaggedWord w : words) {
     double total = 0.0;
     for (int t = 0; t < numTags; t++) {
       tagCounts[t] = seenCounter.getCount(new IntTaggedWord(w.word, t));
       total += tagCounts[t];
     }
     if (total < 10) {
       continue;
     }
     for (int t = 0; t < numTags; t++) {
       double share = tagCounts[t] / total;
       for (int t2 = 0; t2 < numTags; t2++) {
         if (tagCounts[t2] > 0.0) {
           m_T[t] += share;
           m_TT[t2][t] += share;
         }
       }
     }
   }
 }
 /**
  * Returns the current precision for one label, <tt>tp/(tp+fp)</tt>, along with the raw tp and fp
  * counts. The precision is 1.0 if tp and fp are both 0.
  *
  * @param label the label to report on
  * @return (precision, tp, fp)
  */
 public Triple<Double, Integer, Integer> getPrecisionInfo(L label) {
   int idx = labelIndex.indexOf(label);
   int tp = tpCount[idx];
   int fp = fpCount[idx];
   double precision = (tp == 0 && fp == 0) ? 1.0 : ((double) tp) / (tp + fp);
   return new Triple<Double, Integer, Integer>(precision, tp, fp);
 }
예제 #13
0
 public SimpleSequence(int[] intElements, Index<T> index) {
   elements = new Object[intElements.length];
   for (int i = 0; i < intElements.length; i++) {
     elements[i] = index.get(intElements[i]);
   }
   start = 0;
   end = intElements.length;
 }
  /**
   * Given the path to a file representing the text based serialization of a Linear Classifier,
   * reconstitutes and returns that LinearClassifier.
   *
   * <p>The file holds the following sections, in order:
   *
   * <ul>
   *   <li>the label index, then the feature index (each read by {@code HashIndex.loadFromReader});
   *   <li>one weight per line as {@code feature label value}, separated by
   *       {@code LinearClassifier.TEXT_SERIALIZATION_DELIMITER} and ended by a blank line or end of
   *       file;
   *   <li>a line giving the number of thresholds, followed by one threshold per line.
   * </ul>
   *
   * <p>TODO: Leverage Index
   *
   * @param file path of the serialized classifier
   * @return the classifier, or {@code null} if the file could not be read or parsed (the error is
   *     reported on stderr)
   */
  public static LinearClassifier<String, String> loadFromFilename(String file) {
    // try-with-resources so the reader is closed even when parsing fails part-way through.
    try (BufferedReader in = IOUtils.readerFromString(file)) {
      // Format: read indices first, weights, then thresholds
      Index<String> labelIndex = HashIndex.loadFromReader(in);
      Index<String> featureIndex = HashIndex.loadFromReader(in);
      double[][] weights = new double[featureIndex.size()][labelIndex.size()];
      int currLine = 1;
      String line = in.readLine();
      while (line != null && line.length() > 0) {
        String[] tuples = line.split(LinearClassifier.TEXT_SERIALIZATION_DELIMITER);
        if (tuples.length != 3) {
          throw new IllegalArgumentException(
              "Error: incorrect number of tokens in weight specifier, line="
                  + currLine
                  + " in file "
                  + file);
        }
        currLine++;
        int feature = Integer.parseInt(tuples[0]);
        int label = Integer.parseInt(tuples[1]);
        double value = Double.parseDouble(tuples[2]);
        weights[feature][label] = value;
        line = in.readLine();
      }

      // First line in thresholds is the number of thresholds. The constructor below does not
      // take them, but they are still parsed so that a malformed section is reported as an error.
      int numThresholds = Integer.parseInt(in.readLine());
      double[] thresholds = new double[numThresholds];
      int curr = 0;
      while ((line = in.readLine()) != null) {
        thresholds[curr++] = Double.parseDouble(line.trim());
      }
      return new LinearClassifier<String, String>(weights, featureIndex, labelIndex);
    } catch (Exception e) {
      System.err.println("Error in LinearClassifierFactory, loading from file=" + file);
      e.printStackTrace();
      return null;
    }
  }
  /**
   * Finishes training. Steps:
   *
   * <ul>
   *   <li>finalizes the lexicon;
   *   <li>records every tag in {@code POSes};
   *   <li>builds a Laplace-smoothed (parameter 0.5) distribution over the counts in
   *       {@code initial};
   *   <li>for each entry of {@code ruleCounter}, builds a Laplace-smoothed distribution of its
   *       counter, keyed by the first element of the entry's key list (presumably the previous
   *       tag).
   * </ul>
   */
  @Override
  public void finishTraining() {
    lex.finishTraining();

    int numTags = tagIndex.size();
    POSes = new HashSet<String>(tagIndex.objectsList());
    initialPOSDist = Distribution.laplaceSmoothedDistribution(initial, numTags, 0.5);
    markovPOSDists = new HashMap<String, Distribution>();
    // Raw types: the entry set is not generically typed here, so each part is cast explicitly.
    Set entries = ruleCounter.lowestLevelCounterEntrySet();
    for (Object element : entries) {
      Map.Entry entry = (Map.Entry) element;
      ClassicCounter counts = (ClassicCounter) entry.getValue();
      Distribution smoothed = Distribution.laplaceSmoothedDistribution(counts, numTags, 0.5);
      String conditioningTag = ((List<String>) entry.getKey()).get(0);
      markovPOSDists.put(conditioningTag, smoothed);
    }
  }
 /**
  * Returns the micro-averaged precision <tt>tp/(tp+fp)</tt> over all labels except the negative
  * label, along with the summed tp and fp counts. Returns 1.0 as the precision if tp and fp are both
  * 0, matching the per-label {@code getPrecisionInfo(L)} instead of yielding 0/0 = NaN.
  *
  * @return (precision, tp, fp)
  */
 public Triple<Double, Integer, Integer> getPrecisionInfo() {
   int tp = 0, fp = 0;
   for (int i = 0; i < labelIndex.size(); i++) {
     if (i == negIndex) {
       continue;
     }
     tp += tpCount[i];
     fp += fpCount[i];
   }
   double precision = (tp == 0 && fp == 0) ? 1.0 : ((double) tp) / (tp + fp);
   return new Triple<Double, Integer, Integer>(precision, tp, fp);
 }
 /**
  * Returns the micro-averaged recall <tt>tp/(tp+fn)</tt> over all labels except the negative label,
  * along with the summed tp and fn counts. Returns 1.0 as the recall if tp and fn are both 0
  * (nothing to find, nothing missed), following the precision convention instead of yielding
  * 0/0 = NaN.
  *
  * @return (recall, tp, fn)
  */
 public Triple<Double, Integer, Integer> getRecallInfo() {
   int tp = 0, fn = 0;
   for (int i = 0; i < labelIndex.size(); i++) {
     if (i == negIndex) {
       continue;
     }
     tp += tpCount[i];
     fn += fnCount[i];
   }
   double recall = (tp == 0 && fn == 0) ? 1.0 : ((double) tp) / (tp + fn);
   return new Triple<Double, Integer, Integer>(recall, tp, fn);
 }
예제 #18
0
 /**
  * Renders the tag index, followed by the open tags. When the open set is not fixed, the closed
  * tags are appended as well.
  */
 @Override
 public String toString() {
   StringBuilder sb = new StringBuilder();
   sb.append(index.toString()).append(' ');
   if (!openFixed) {
     sb.append(" open:").append(getOpenTags()).append(" CLOSED:").append(closed);
   } else {
     sb.append(" OPEN:").append(getOpenTags());
   }
   return sb.toString();
 }
예제 #19
0
 /**
  * Trains one binary classifier per label in {@code trainLabels} and bundles them into a
  * {@link OneVsAllClassifier}. For each label, the dataset is relabelled so that label becomes
  * {@code POS_LABEL} and (presumably) every other label becomes {@code NEG_LABEL}.
  *
  * @param classifierFactory factory used to train each binary classifier
  * @param dataset the multi-class training data
  * @param trainLabels the labels that get their own binary classifier
  * @return the combined one-vs-all classifier
  */
 public static <L, F> OneVsAllClassifier<L, F> train(
     ClassifierFactory<String, F, Classifier<String, F>> classifierFactory,
     GeneralDataset<L, F> dataset,
     Collection<L> trainLabels) {
   Index<L> labelIndex = dataset.labelIndex();
   Index<F> featureIndex = dataset.featureIndex();
   Map<L, Classifier<String, F>> classifiers = Generics.newHashMap();
   for (L label : trainLabels) {
     int labelNo = labelIndex.indexOf(label);
     logger.info("Training " + label + " = " + labelNo + ", posIndex = " + posIndex);
     Map<L, String> relabelling = new ArrayMap<>();
     relabelling.put(label, POS_LABEL);
     GeneralDataset<String, F> binaryDataset =
         dataset.mapDataset(dataset, binaryIndex, relabelling, NEG_LABEL);
     classifiers.put(label, classifierFactory.trainClassifier(binaryDataset));
   }
   return new OneVsAllClassifier<>(featureIndex, labelIndex, classifiers);
 }
예제 #20
0
  /**
   * Reads a tag set from {@code file}: an int count, then for each tag a UTF string and a boolean
   * saying whether the tag is closed. {@code index} is replaced by a new index of the tags read,
   * and closed tags are added to {@code closed}.
   *
   * @param file stream positioned at the start of the tag data
   * @throws RuntimeIOException if reading from {@code file} fails
   */
  protected void read(DataInputStream file) {
    try {
      int size = file.readInt();
      index = new HashIndex<String>();
      for (int i = 0; i < size; i++) {
        String tag = file.readUTF();
        boolean inClosed = file.readBoolean();
        index.add(tag);

        if (inClosed) {
          closed.add(tag);
        }
      }
    } catch (IOException e) {
      // Fail loudly, as save() does: continuing would leave a silently truncated tag set.
      throw new RuntimeIOException(e);
    }
  }
예제 #21
0
 /**
  * Writes the tag set to {@code file}: the number of tags, then for each tag its UTF string and
  * whether it is closed. When {@code learnClosedTags} is set, a tag whose token set in
  * {@code tagTokens} has fewer than {@code closedTagThreshold} entries is marked closed just
  * before its flag is written.
  *
  * @param file destination stream
  * @param tagTokens the set of tokens observed for each tag
  * @throws RuntimeIOException if writing fails
  */
 protected void save(DataOutputStream file, Map<String, Set<String>> tagTokens) {
   try {
     file.writeInt(index.size());
     for (String tag : index) {
       file.writeUTF(tag);
       boolean rare = learnClosedTags && tagTokens.get(tag).size() < closedTagThreshold;
       if (rare) {
         markClosed(tag);
       }
       file.writeBoolean(isClosed(tag));
     }
   } catch (IOException ioe) {
     throw new RuntimeIOException(ioe);
   }
 }
  /**
   * Classifies every datum in {@code data} and returns the resulting F-measure.
   *
   * <p>Per-label true-positive, false-positive and false-negative counts are tallied over a fresh
   * label index. That index covers both the dataset's labels and {@code classifier.labels()}.
   * Counts are never recorded for the negative label.
   *
   * @param classifier the classifier to evaluate
   * @param data the gold-labelled data
   * @return the value of {@code getFMeasure()} after tallying
   */
  public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {
    List<L> guesses = new ArrayList<L>();
    for (int i = 0; i < data.size(); i++) {
      Datum<L, F> datum = data.getRVFDatum(i);
      guesses.add(classifier.classOf(datum));
    }

    // Gold labels, decoded through the dataset's own index (the field is replaced just below).
    int[] goldNumbers = data.getLabelsArray();
    labelIndex = data.labelIndex;
    List<L> labels = new ArrayList<L>();
    for (int i = 0; i < data.size(); i++) {
      labels.add(labelIndex.get(goldNumbers[i]));
    }

    labelIndex = new HashIndex<L>();
    labelIndex.addAll(data.labelIndex().objectsList());
    labelIndex.addAll(classifier.labels());

    int numClasses = labelIndex.size();
    tpCount = new int[numClasses];
    fpCount = new int[numClasses];
    fnCount = new int[numClasses];

    negIndex = labelIndex.indexOf(negLabel);

    for (int i = 0; i < guesses.size(); i++) {
      int guessIndex = labelIndex.indexOf(guesses.get(i));
      int trueIndex = labelIndex.indexOf(labels.get(i));
      boolean guessIsPositive = guessIndex != negIndex;
      if (guessIndex == trueIndex) {
        if (guessIsPositive) {
          tpCount[guessIndex]++;
        }
        continue;
      }
      if (guessIsPositive) {
        fpCount[guessIndex]++;
      }
      if (trueIndex != negIndex) {
        fnCount[trueIndex]++;
      }
    }

    return getFMeasure();
  }
  /**
   * Trains on one tagged sentence. Steps:
   *
   * <ul>
   *   <li>passes the sentence to the lexicon with weight 1.0;
   *   <li>adds each tag to {@code tagIndex};
   *   <li>counts the first tag in {@code initial};
   *   <li>counts each (previous tag, tag) pair in {@code ruleCounter}.
   * </ul>
   *
   * @param sentence the tagged words of one sentence
   */
  @Override
  public void train(List<TaggedWord> sentence) {
    lex.train(sentence, 1.0);

    String previousTag = null;
    for (TaggedWord taggedWord : sentence) {
      String tag = taggedWord.tag();
      tagIndex.add(tag);
      if (previousTag != null) {
        ruleCounter.incrementCount2D(previousTag, tag);
      } else {
        initial.incrementCount(tag);
      }
      previousTag = tag;
    }
  }
 /**
  * Builds an {@link MLEDependencyGrammar} from the configured options and adds one rule per
  * dependency counted in {@code dependencyCounter}, weighted by its count.
  *
  * @return the new dependency grammar
  */
 @Override
 public DependencyGrammar formResult() {
   // Called for its side effect: with the add flag set, UNKNOWN_WORD gets an entry in wordIndex.
   wordIndex.indexOf(Lexicon.UNKNOWN_WORD, true);
   MLEDependencyGrammar grammar =
       new MLEDependencyGrammar(
           tlpParams,
           directional,
           useDistance,
           useCoarseDistance,
           basicCategoryTagsInDependencyGrammar,
           op,
           wordIndex,
           tagIndex);
   for (IntDependency dep : dependencyCounter.keySet()) {
     double count = dependencyCounter.getCount(dep);
     grammar.addRule(dep, count);
   }
   return grammar;
 }
  /**
   * Trains a linear classifier from an iterable of datums.
   *
   * <p>A first pass over the data builds the feature and label indices. The weights are then found
   * by minimizing a {@link LogConditionalObjectiveFunction} built with this factory's
   * {@code logPrior}, using the factory's minimizer and {@code TOL}.
   *
   * @param dataIterable the training data; iterated here and (presumably) again by the objective
   * @return the trained classifier
   */
  public Classifier<L, F> trainClassifier(Iterable<Datum<L, F>> dataIterable) {
    Minimizer<DiffFunction> minimizer = getMinimizer();
    Index<F> featureIndex = Generics.newIndex();
    Index<L> labelIndex = Generics.newIndex();
    for (Datum<L, F> d : dataIterable) {
      labelIndex.add(d.label());
      featureIndex.addAll(d.asFeatures()); // If there are duplicates, it doesn't add them again.
    }
    System.err.println(
        String.format(
            "Training linear classifier with %d features and %d labels",
            featureIndex.size(), labelIndex.size()));

    // Deliberately no setPrior(...) call: overriding with a fresh quadratic prior here would
    // discard the prior (type and strength) this factory was configured with.
    LogConditionalObjectiveFunction<L, F> objective =
        new LogConditionalObjectiveFunction<L, F>(dataIterable, logPrior, featureIndex, labelIndex);

    double[] initial = objective.initial();
    double[] weights = minimizer.minimize(objective, TOL, initial);

    return new LinearClassifier<L, F>(objective.to2D(weights), featureIndex, labelIndex);
  }
예제 #26
0
  public static void main(String[] args) {
    Options op = new Options(new EnglishTreebankParserParams());
    // op.tlpParams may be changed to something else later, so don't use it till
    // after options are parsed.

    System.out.println(StringUtils.toInvocationString("FactoredParser", args));

    String path = "/u/nlp/stuff/corpora/Treebank3/parsed/mrg/wsj";
    int trainLow = 200, trainHigh = 2199, testLow = 2200, testHigh = 2219;
    String serializeFile = null;

    int i = 0;
    while (i < args.length && args[i].startsWith("-")) {
      if (args[i].equalsIgnoreCase("-path") && (i + 1 < args.length)) {
        path = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-train") && (i + 2 < args.length)) {
        trainLow = Integer.parseInt(args[i + 1]);
        trainHigh = Integer.parseInt(args[i + 2]);
        i += 3;
      } else if (args[i].equalsIgnoreCase("-test") && (i + 2 < args.length)) {
        testLow = Integer.parseInt(args[i + 1]);
        testHigh = Integer.parseInt(args[i + 2]);
        i += 3;
      } else if (args[i].equalsIgnoreCase("-serialize") && (i + 1 < args.length)) {
        serializeFile = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-tLPP") && (i + 1 < args.length)) {
        try {
          op.tlpParams = (TreebankLangParserParams) Class.forName(args[i + 1]).newInstance();
        } catch (ClassNotFoundException e) {
          System.err.println("Class not found: " + args[i + 1]);
          throw new RuntimeException(e);
        } catch (InstantiationException e) {
          System.err.println("Couldn't instantiate: " + args[i + 1] + ": " + e.toString());
          throw new RuntimeException(e);
        } catch (IllegalAccessException e) {
          System.err.println("illegal access" + e);
          throw new RuntimeException(e);
        }
        i += 2;
      } else if (args[i].equals("-encoding")) {
        // sets encoding for TreebankLangParserParams
        op.tlpParams.setInputEncoding(args[i + 1]);
        op.tlpParams.setOutputEncoding(args[i + 1]);
        i += 2;
      } else {
        i = op.setOptionOrWarn(args, i);
      }
    }
    // System.out.println(tlpParams.getClass());
    TreebankLanguagePack tlp = op.tlpParams.treebankLanguagePack();

    op.trainOptions.sisterSplitters =
        new HashSet<String>(Arrays.asList(op.tlpParams.sisterSplitters()));
    //    BinarizerFactory.TreeAnnotator.setTreebankLang(tlpParams);
    PrintWriter pw = op.tlpParams.pw();

    op.testOptions.display();
    op.trainOptions.display();
    op.display();
    op.tlpParams.display();

    // setup tree transforms
    Treebank trainTreebank = op.tlpParams.memoryTreebank();
    MemoryTreebank testTreebank = op.tlpParams.testMemoryTreebank();
    // Treebank blippTreebank = ((EnglishTreebankParserParams) tlpParams).diskTreebank();
    // String blippPath = "/afs/ir.stanford.edu/data/linguistic-data/BLLIP-WSJ/";
    // blippTreebank.loadPath(blippPath, "", true);

    Timing.startTime();
    System.err.print("Reading trees...");
    testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true));
    if (op.testOptions.increasingLength) {
      Collections.sort(testTreebank, new TreeLengthComparator());
    }

    trainTreebank.loadPath(path, new NumberRangeFileFilter(trainLow, trainHigh, true));
    Timing.tick("done.");

    System.err.print("Binarizing trees...");
    TreeAnnotatorAndBinarizer binarizer;
    if (!op.trainOptions.leftToRight) {
      binarizer =
          new TreeAnnotatorAndBinarizer(
              op.tlpParams, op.forceCNF, !op.trainOptions.outsideFactor(), true, op);
    } else {
      binarizer =
          new TreeAnnotatorAndBinarizer(
              op.tlpParams.headFinder(),
              new LeftHeadFinder(),
              op.tlpParams,
              op.forceCNF,
              !op.trainOptions.outsideFactor(),
              true,
              op);
    }

    CollinsPuncTransformer collinsPuncTransformer = null;
    if (op.trainOptions.collinsPunc) {
      collinsPuncTransformer = new CollinsPuncTransformer(tlp);
    }
    TreeTransformer debinarizer = new Debinarizer(op.forceCNF);
    List<Tree> binaryTrainTrees = new ArrayList<Tree>();

    if (op.trainOptions.selectiveSplit) {
      op.trainOptions.splitters =
          ParentAnnotationStats.getSplitCategories(
              trainTreebank,
              op.trainOptions.tagSelectiveSplit,
              0,
              op.trainOptions.selectiveSplitCutOff,
              op.trainOptions.tagSelectiveSplitCutOff,
              op.tlpParams.treebankLanguagePack());
      if (op.trainOptions.deleteSplitters != null) {
        List<String> deleted = new ArrayList<String>();
        for (String del : op.trainOptions.deleteSplitters) {
          String baseDel = tlp.basicCategory(del);
          boolean checkBasic = del.equals(baseDel);
          for (Iterator<String> it = op.trainOptions.splitters.iterator(); it.hasNext(); ) {
            String elem = it.next();
            String baseElem = tlp.basicCategory(elem);
            boolean delStr = checkBasic && baseElem.equals(baseDel) || elem.equals(del);
            if (delStr) {
              it.remove();
              deleted.add(elem);
            }
          }
        }
        System.err.println("Removed from vertical splitters: " + deleted);
      }
    }
    if (op.trainOptions.selectivePostSplit) {
      TreeTransformer myTransformer =
          new TreeAnnotator(op.tlpParams.headFinder(), op.tlpParams, op);
      Treebank annotatedTB = trainTreebank.transform(myTransformer);
      op.trainOptions.postSplitters =
          ParentAnnotationStats.getSplitCategories(
              annotatedTB,
              true,
              0,
              op.trainOptions.selectivePostSplitCutOff,
              op.trainOptions.tagSelectivePostSplitCutOff,
              op.tlpParams.treebankLanguagePack());
    }

    if (op.trainOptions.hSelSplit) {
      binarizer.setDoSelectiveSplit(false);
      for (Tree tree : trainTreebank) {
        if (op.trainOptions.collinsPunc) {
          tree = collinsPuncTransformer.transformTree(tree);
        }
        // tree.pennPrint(tlpParams.pw());
        tree = binarizer.transformTree(tree);
        // binaryTrainTrees.add(tree);
      }
      binarizer.setDoSelectiveSplit(true);
    }
    for (Tree tree : trainTreebank) {
      if (op.trainOptions.collinsPunc) {
        tree = collinsPuncTransformer.transformTree(tree);
      }
      tree = binarizer.transformTree(tree);
      binaryTrainTrees.add(tree);
    }
    if (op.testOptions.verbose) {
      binarizer.dumpStats();
    }

    List<Tree> binaryTestTrees = new ArrayList<Tree>();
    for (Tree tree : testTreebank) {
      if (op.trainOptions.collinsPunc) {
        tree = collinsPuncTransformer.transformTree(tree);
      }
      tree = binarizer.transformTree(tree);
      binaryTestTrees.add(tree);
    }
    Timing.tick("done."); // binarization
    BinaryGrammar bg = null;
    UnaryGrammar ug = null;
    DependencyGrammar dg = null;
    // DependencyGrammar dgBLIPP = null;
    Lexicon lex = null;
    Index<String> stateIndex = new HashIndex<String>();

    // extract grammars
    Extractor<Pair<UnaryGrammar, BinaryGrammar>> bgExtractor =
        new BinaryGrammarExtractor(op, stateIndex);
    // Extractor bgExtractor = new SmoothedBinaryGrammarExtractor();//new BinaryGrammarExtractor();
    // Extractor lexExtractor = new LexiconExtractor();

    // Extractor dgExtractor = new DependencyMemGrammarExtractor();

    if (op.doPCFG) {
      System.err.print("Extracting PCFG...");
      Pair<UnaryGrammar, BinaryGrammar> bgug = null;
      if (op.trainOptions.cheatPCFG) {
        List<Tree> allTrees = new ArrayList<Tree>(binaryTrainTrees);
        allTrees.addAll(binaryTestTrees);
        bgug = bgExtractor.extract(allTrees);
      } else {
        bgug = bgExtractor.extract(binaryTrainTrees);
      }
      bg = bgug.second;
      bg.splitRules();
      ug = bgug.first;
      ug.purgeRules();
      Timing.tick("done.");
    }
    System.err.print("Extracting Lexicon...");
    Index<String> wordIndex = new HashIndex<String>();
    Index<String> tagIndex = new HashIndex<String>();
    lex = op.tlpParams.lex(op, wordIndex, tagIndex);
    lex.train(binaryTrainTrees);
    Timing.tick("done.");

    if (op.doDep) {
      System.err.print("Extracting Dependencies...");
      binaryTrainTrees.clear();
      Extractor<DependencyGrammar> dgExtractor =
          new MLEDependencyGrammarExtractor(op, wordIndex, tagIndex);
      // dgBLIPP = (DependencyGrammar) dgExtractor.extract(new
      // ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new
      // TransformTreeDependency(tlpParams,true));

      // DependencyGrammar dg1 = dgExtractor.extract(trainTreebank.iterator(), new
      // TransformTreeDependency(op.tlpParams, true));
      // dgBLIPP=(DependencyGrammar)dgExtractor.extract(blippTreebank.iterator(),new
      // TransformTreeDependency(tlpParams));

      // dg = (DependencyGrammar) dgExtractor.extract(new
      // ConcatenationIterator(trainTreebank.iterator(),blippTreebank.iterator()),new
      // TransformTreeDependency(tlpParams));
      // dg=new DependencyGrammarCombination(dg1,dgBLIPP,2);
      dg =
          dgExtractor.extract(
              binaryTrainTrees); // uses information whether the words are known or not, discards
      // unknown words
      Timing.tick("done.");
      // System.out.print("Extracting Unknown Word Model...");
      // UnknownWordModel uwm = (UnknownWordModel)uwmExtractor.extract(binaryTrainTrees);
      // Timing.tick("done.");
      System.out.print("Tuning Dependency Model...");
      dg.tune(binaryTestTrees);
      // System.out.println("TUNE DEPS: "+tuneDeps);
      Timing.tick("done.");
    }

    BinaryGrammar boundBG = bg;
    UnaryGrammar boundUG = ug;

    GrammarProjection gp = new NullGrammarProjection(bg, ug);

    // serialization
    if (serializeFile != null) {
      System.err.print("Serializing parser...");
      LexicalizedParser.saveParserDataToSerialized(
          new ParserData(lex, bg, ug, dg, stateIndex, wordIndex, tagIndex, op), serializeFile);
      Timing.tick("done.");
    }

    // test: pcfg-parse and output

    ExhaustivePCFGParser parser = null;
    if (op.doPCFG) {
      parser = new ExhaustivePCFGParser(boundBG, boundUG, lex, op, stateIndex, wordIndex, tagIndex);
    }

    ExhaustiveDependencyParser dparser =
        ((op.doDep && !op.testOptions.useFastFactored)
            ? new ExhaustiveDependencyParser(dg, lex, op, wordIndex, tagIndex)
            : null);

    Scorer scorer =
        (op.doPCFG ? new TwinScorer(new ProjectionScorer(parser, gp, op), dparser) : null);
    // Scorer scorer = parser;
    BiLexPCFGParser bparser = null;
    if (op.doPCFG && op.doDep) {
      bparser =
          (op.testOptions.useN5)
              ? new BiLexPCFGParser.N5BiLexPCFGParser(
                  scorer, parser, dparser, bg, ug, dg, lex, op, gp, stateIndex, wordIndex, tagIndex)
              : new BiLexPCFGParser(
                  scorer,
                  parser,
                  dparser,
                  bg,
                  ug,
                  dg,
                  lex,
                  op,
                  gp,
                  stateIndex,
                  wordIndex,
                  tagIndex);
    }

    Evalb pcfgPE = new Evalb("pcfg  PE", true);
    Evalb comboPE = new Evalb("combo PE", true);
    AbstractEval pcfgCB = new Evalb.CBEval("pcfg  CB", true);

    AbstractEval pcfgTE = new TaggingEval("pcfg  TE");
    AbstractEval comboTE = new TaggingEval("combo TE");
    AbstractEval pcfgTEnoPunct = new TaggingEval("pcfg nopunct TE");
    AbstractEval comboTEnoPunct = new TaggingEval("combo nopunct TE");
    AbstractEval depTE = new TaggingEval("depnd TE");

    AbstractEval depDE =
        new UnlabeledAttachmentEval("depnd DE", true, null, tlp.punctuationWordRejectFilter());
    AbstractEval comboDE =
        new UnlabeledAttachmentEval("combo DE", true, null, tlp.punctuationWordRejectFilter());

    if (op.testOptions.evalb) {
      EvalbFormatWriter.initEVALBfiles(op.tlpParams);
    }

    // int[] countByLength = new int[op.testOptions.maxLength+1];

    // Use a reflection ruse, so one can run this without needing the
    // tagger.  Using a function rather than a MaxentTagger means we
    // can distribute a version of the parser that doesn't include the
    // entire tagger.
    Function<List<? extends HasWord>, ArrayList<TaggedWord>> tagger = null;
    if (op.testOptions.preTag) {
      try {
        Class[] argsClass = {String.class};
        Object[] arguments = new Object[] {op.testOptions.taggerSerializedFile};
        tagger =
            (Function<List<? extends HasWord>, ArrayList<TaggedWord>>)
                Class.forName("edu.stanford.nlp.tagger.maxent.MaxentTagger")
                    .getConstructor(argsClass)
                    .newInstance(arguments);
      } catch (Exception e) {
        System.err.println(e);
        System.err.println("Warning: No pretagging of sentences will be done.");
      }
    }

    for (int tNum = 0, ttSize = testTreebank.size(); tNum < ttSize; tNum++) {
      Tree tree = testTreebank.get(tNum);
      int testTreeLen = tree.yield().size();
      if (testTreeLen > op.testOptions.maxLength) {
        continue;
      }
      Tree binaryTree = binaryTestTrees.get(tNum);
      // countByLength[testTreeLen]++;
      System.out.println("-------------------------------------");
      System.out.println("Number: " + (tNum + 1));
      System.out.println("Length: " + testTreeLen);

      // tree.pennPrint(pw);
      // System.out.println("XXXX The binary tree is");
      // binaryTree.pennPrint(pw);
      // System.out.println("Here are the tags in the lexicon:");
      // System.out.println(lex.showTags());
      // System.out.println("Here's the tagnumberer:");
      // System.out.println(Numberer.getGlobalNumberer("tags").toString());

      long timeMil1 = System.currentTimeMillis();
      Timing.tick("Starting parse.");
      if (op.doPCFG) {
        // System.err.println(op.testOptions.forceTags);
        if (op.testOptions.forceTags) {
          if (tagger != null) {
            // System.out.println("Using a tagger to set tags");
            // System.out.println("Tagged sentence as: " +
            // tagger.processSentence(cutLast(wordify(binaryTree.yield()))).toString(false));
            parser.parse(addLast(tagger.apply(cutLast(wordify(binaryTree.yield())))));
          } else {
            // System.out.println("Forcing tags to match input.");
            parser.parse(cleanTags(binaryTree.taggedYield(), tlp));
          }
        } else {
          // System.out.println("XXXX Parsing " + binaryTree.yield());
          parser.parse(binaryTree.yieldHasWord());
        }
        // Timing.tick("Done with pcfg phase.");
      }
      if (op.doDep) {
        dparser.parse(binaryTree.yieldHasWord());
        // Timing.tick("Done with dependency phase.");
      }
      boolean bothPassed = false;
      if (op.doPCFG && op.doDep) {
        bothPassed = bparser.parse(binaryTree.yieldHasWord());
        // Timing.tick("Done with combination phase.");
      }
      long timeMil2 = System.currentTimeMillis();
      long elapsed = timeMil2 - timeMil1;
      System.err.println("Time: " + ((int) (elapsed / 100)) / 10.00 + " sec.");
      // System.out.println("PCFG Best Parse:");
      Tree tree2b = null;
      Tree tree2 = null;
      // System.out.println("Got full best parse...");
      if (op.doPCFG) {
        tree2b = parser.getBestParse();
        tree2 = debinarizer.transformTree(tree2b);
      }
      // System.out.println("Debinarized parse...");
      // tree2.pennPrint();
      // System.out.println("DepG Best Parse:");
      Tree tree3 = null;
      Tree tree3db = null;
      if (op.doDep) {
        tree3 = dparser.getBestParse();
        // was: but wrong Tree tree3db = debinarizer.transformTree(tree2);
        tree3db = debinarizer.transformTree(tree3);
        tree3.pennPrint(pw);
      }
      // tree.pennPrint();
      // ((Tree)binaryTrainTrees.get(tNum)).pennPrint();
      // System.out.println("Combo Best Parse:");
      Tree tree4 = null;
      if (op.doPCFG && op.doDep) {
        try {
          tree4 = bparser.getBestParse();
          if (tree4 == null) {
            tree4 = tree2b;
          }
        } catch (NullPointerException e) {
          System.err.println("Blocked, using PCFG parse!");
          tree4 = tree2b;
        }
      }
      if (op.doPCFG && !bothPassed) {
        tree4 = tree2b;
      }
      // tree4.pennPrint();
      if (op.doDep) {
        depDE.evaluate(tree3, binaryTree, pw);
        depTE.evaluate(tree3db, tree, pw);
      }
      TreeTransformer tc = op.tlpParams.collinizer();
      TreeTransformer tcEvalb = op.tlpParams.collinizerEvalb();
      if (op.doPCFG) {
        // System.out.println("XXXX Best PCFG was: ");
        // tree2.pennPrint();
        // System.out.println("XXXX Transformed best PCFG is: ");
        // tc.transformTree(tree2).pennPrint();
        // System.out.println("True Best Parse:");
        // tree.pennPrint();
        // tc.transformTree(tree).pennPrint();
        pcfgPE.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
        pcfgCB.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
        Tree tree4b = null;
        if (op.doDep) {
          comboDE.evaluate((bothPassed ? tree4 : tree3), binaryTree, pw);
          tree4b = tree4;
          tree4 = debinarizer.transformTree(tree4);
          if (op.nodePrune) {
            NodePruner np = new NodePruner(parser, debinarizer);
            tree4 = np.prune(tree4);
          }
          // tree4.pennPrint();
          comboPE.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw);
        }
        // pcfgTE.evaluate(tree2, tree);
        pcfgTE.evaluate(tcEvalb.transformTree(tree2), tcEvalb.transformTree(tree), pw);
        pcfgTEnoPunct.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);

        if (op.doDep) {
          comboTE.evaluate(tcEvalb.transformTree(tree4), tcEvalb.transformTree(tree), pw);
          comboTEnoPunct.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw);
        }
        System.out.println("PCFG only: " + parser.scoreBinarizedTree(tree2b, 0));

        // tc.transformTree(tree2).pennPrint();
        tree2.pennPrint(pw);

        if (op.doDep) {
          System.out.println("Combo: " + parser.scoreBinarizedTree(tree4b, 0));
          // tc.transformTree(tree4).pennPrint(pw);
          tree4.pennPrint(pw);
        }
        System.out.println("Correct:" + parser.scoreBinarizedTree(binaryTree, 0));
        /*
        if (parser.scoreBinarizedTree(tree2b,true) < parser.scoreBinarizedTree(binaryTree,true)) {
          System.out.println("SCORE INVERSION");
          parser.validateBinarizedTree(binaryTree,0);
        }
        */
        tree.pennPrint(pw);
      } // end if doPCFG

      if (op.testOptions.evalb) {
        if (op.doPCFG && op.doDep) {
          EvalbFormatWriter.writeEVALBline(
              tcEvalb.transformTree(tree), tcEvalb.transformTree(tree4));
        } else if (op.doPCFG) {
          EvalbFormatWriter.writeEVALBline(
              tcEvalb.transformTree(tree), tcEvalb.transformTree(tree2));
        } else if (op.doDep) {
          EvalbFormatWriter.writeEVALBline(
              tcEvalb.transformTree(tree), tcEvalb.transformTree(tree3db));
        }
      }
    } // end for each tree in test treebank

    if (op.testOptions.evalb) {
      EvalbFormatWriter.closeEVALBfiles();
    }

    // op.testOptions.display();
    if (op.doPCFG) {
      pcfgPE.display(false, pw);
      System.out.println("Grammar size: " + stateIndex.size());
      pcfgCB.display(false, pw);
      if (op.doDep) {
        comboPE.display(false, pw);
      }
      pcfgTE.display(false, pw);
      pcfgTEnoPunct.display(false, pw);
      if (op.doDep) {
        comboTE.display(false, pw);
        comboTEnoPunct.display(false, pw);
      }
    }
    if (op.doDep) {
      depTE.display(false, pw);
      depDE.display(false, pw);
    }
    if (op.doPCFG && op.doDep) {
      comboDE.display(false, pw);
    }
    // pcfgPE.printGoodBad();
  }
예제 #27
0
  /**
   * Recursively walks a binarized tree and appends the dependencies it implies to {@code
   * depList}. Word and tag numbers come from {@code wordIndex} and {@code tagIndex}, i.e. they are
   * in terms of the original tag set, not the reduced (projected) one.
   *
   * <p>Leaves and preterminals add nothing. A unary node is handled exactly like its only child.
   * For any other node, the first two children are taken as the left and right constituents. The
   * node is left-headed when its head word equals the left child's head word. Three dependencies
   * are then appended, in this order: head to argument, the argument's left STOP, and the
   * argument's right STOP. Words not present in {@code wordIndex} are numbered as {@link
   * Lexicon#UNKNOWN_WORD}.
   *
   * @param tree subtree to process; its labels are read as {@link HasTag} and {@link HasWord}
   * @param depList list that receives the generated dependencies
   * @param loc position of the first word covered by {@code tree}
   * @param wordIndex index used to number words
   * @param tagIndex index used to number tags
   * @return an EndHead whose {@code head} is the head position of {@code tree} and whose {@code
   *     end} is one past the last position it covers
   */
  protected static EndHead treeToDependencyHelper(
      Tree tree,
      List<IntDependency> depList,
      int loc,
      Index<String> wordIndex,
      Index<String> tagIndex) {
    // Base case: a single word spans [loc, loc + 1) and heads itself.
    if (tree.isLeaf() || tree.isPreTerminal()) {
      EndHead single = new EndHead();
      single.head = loc;
      single.end = loc + 1;
      return single;
    }

    Tree[] children = tree.children();
    if (children.length == 1) {
      // A unary node covers the same span, with the same head, as its child.
      return treeToDependencyHelper(children[0], depList, loc, wordIndex, tagIndex);
    }

    Tree leftKid = children[0];
    Tree rightKid = children[1];

    EndHead leftSpan = treeToDependencyHelper(leftKid, depList, loc, wordIndex, tagIndex);
    int leftHeadPos = leftSpan.head;
    int splitPos = leftSpan.end;

    // The right child's EndHead already carries the correct end; only its head is rewritten below.
    EndHead result = treeToDependencyHelper(rightKid, depList, splitPos, wordIndex, tagIndex);
    int endPos = result.end;
    int rightHeadPos = result.head;

    String parentTag = ((HasTag) tree.label()).tag();
    String leftTag = ((HasTag) leftKid.label()).tag();
    String rightTag = ((HasTag) rightKid.label()).tag();
    String parentWord = ((HasWord) tree.label()).word();
    String leftWord = ((HasWord) leftKid.label()).word();
    String rightWord = ((HasWord) rightKid.label()).word();

    boolean headOnLeft = parentWord.equals(leftWord);

    String argTag;
    String argWord;
    int headPos;
    int argPos;
    int headDistance;
    int argLeftReach;
    int argRightReach;
    if (headOnLeft) {
      argTag = rightTag;
      argWord = rightWord;
      headPos = leftHeadPos;
      argPos = rightHeadPos;
      // Head sits in the left span, argument in the right span starting at splitPos.
      headDistance = splitPos - headPos - 1;
      argLeftReach = argPos - splitPos;
      argRightReach = endPos - argPos - 1;
    } else {
      argTag = leftTag;
      argWord = leftWord;
      headPos = rightHeadPos;
      argPos = leftHeadPos;
      // Head sits in the right span, argument in the left span starting at loc.
      headDistance = headPos - splitPos;
      argLeftReach = argPos - loc;
      argRightReach = splitPos - argPos - 1;
    }

    int headTagNum = tagIndex.indexOf(parentTag);
    int argTagNum = tagIndex.indexOf(argTag);
    int headWordNum;
    if (wordIndex.contains(parentWord)) {
      headWordNum = wordIndex.indexOf(parentWord);
    } else {
      headWordNum = wordIndex.indexOf(Lexicon.UNKNOWN_WORD);
    }
    int argWordNum;
    if (wordIndex.contains(argWord)) {
      argWordNum = wordIndex.indexOf(argWord);
    } else {
      argWordNum = wordIndex.indexOf(Lexicon.UNKNOWN_WORD);
    }

    depList.add(
        new IntDependency(
            headWordNum, headTagNum, argWordNum, argTagNum, headOnLeft, headDistance));
    depList.add(
        new IntDependency(
            argWordNum, argTagNum, STOP_WORD_INT, STOP_TAG_INT, false, argLeftReach));
    depList.add(
        new IntDependency(
            argWordNum, argTagNum, STOP_WORD_INT, STOP_TAG_INT, true, argRightReach));

    result.head = headPos;
    return result;
  }
예제 #28
0
 /**
  * Returns the labels stored in {@code labelIndex}, exactly as {@code objectsList()} provides
  * them; no copy is made.
  */
 @Override
 public Collection<L> labels() {
   Collection<L> knownLabels = labelIndex.objectsList();
   return knownLabels;
 }
예제 #29
0
 static {
   // Build the two-label index: POS_LABEL is added first, then NEG_LABEL.
   binaryIndex = new HashIndex<>();
   binaryIndex.add(POS_LABEL);
   binaryIndex.add(NEG_LABEL);
   // Record the integer id that binaryIndex assigned to POS_LABEL.
   posIndex = binaryIndex.indexOf(POS_LABEL);
 }
예제 #30
0
 /**
  * Convenience overload that trains a one-vs-all classifier over every label in {@code
  * dataset}'s label index. It delegates to {@code train(classifierFactory, dataset, labels)},
  * passing {@code dataset.labelIndex().objectsList()} as the labels.
  *
  * @param classifierFactory factory used to build the underlying binary classifiers
  * @param dataset training data whose label index supplies the labels
  * @return the classifier produced by the three-argument {@code train} overload
  */
 public static <L, F> OneVsAllClassifier<L, F> train(
     ClassifierFactory<String, F, Classifier<String, F>> classifierFactory,
     GeneralDataset<L, F> dataset) {
   return train(classifierFactory, dataset, dataset.labelIndex().objectsList());
 }