예제 #1
0
 /**
  * Picks the tag that the lexicon scores highest for the given word.
  *
  * <p>The first tag seen is always accepted, whatever its score; afterwards a tag replaces the
  * current winner only when its score is strictly greater, so the earliest tag wins ties.
  *
  * @param word the word to tag
  * @return the highest-scoring tag, or {@code null} when the lexicon yields no tags
  */
 private String getBestTag(String word) {
   String winner = null;
   double winnerScore = Double.NEGATIVE_INFINITY;
   for (String candidate : lexicon.getAllTags()) {
     double candidateScore = lexicon.scoreTagging(word, candidate);
     boolean improves = (winner == null) || (candidateScore > winnerScore);
     if (!improves) {
       continue;
     }
     winner = candidate;
     winnerScore = candidateScore;
   }
   return winner;
 }
예제 #2
0
  /**
   * Runs one E-step over the training trees and returns the training log likelihood.
   *
   * <p>For every tree the inside/outside scores are computed with a parser built from the
   * <em>previous</em> grammar and lexicon. Trees whose log likelihood is finite contribute their
   * counts to {@code lexicon} (and to {@code grammar} unless {@code updateOnlyLexicon} is set) and
   * their log likelihood to the returned sum. Trees with an infinite or NaN log likelihood are
   * skipped (and reported when {@code VERBOSE} is on).
   *
   * @param previousGrammar grammar used to build the parser and passed to
   *     {@code grammar.tallyStateSetTree}
   * @param previousLexicon lexicon used to build the parser and passed to {@code lexicon.trainTree}
   * @param grammar grammar that accumulates the expected counts; it is not dereferenced when
   *     {@code updateOnlyLexicon} is {@code true}, so it may be {@code null} in that case
   * @param lexicon lexicon that accumulates the expected counts
   * @param trainStateSetTrees the training trees to iterate over
   * @param updateOnlyLexicon when {@code true}, only {@code lexicon} is updated
   * @param unkThreshold threshold forwarded to {@code lexicon.trainTree} and
   *     {@code lexicon.tieRareWordStats}
   * @return the sum of the log likelihoods of all trees whose log likelihood was finite
   */
  public static double doOneEStep(
      Grammar previousGrammar,
      Lexicon previousLexicon,
      Grammar grammar,
      Lexicon lexicon,
      StateSetTreeList trainStateSetTrees,
      boolean updateOnlyLexicon,
      int unkThreshold) {
    final boolean noSmoothing = true;
    final boolean debugOutput = false;
    ArrayParser parser = new ArrayParser(previousGrammar, previousLexicon);
    double trainingLikelihood = 0;
    int n = 0;
    int nTrees = trainStateSetTrees.size();
    for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
      // Trees past the midpoint of the list are flagged for the lexicon.
      boolean secondHalf = (n++ > nTrees / 2.0);
      parser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput); // E Step

      // Log likelihood of the tree: log of the root inside score plus 100 times the root scale
      // (presumably the scores are kept scaled by powers of e^100 -- TODO confirm in ArrayParser).
      StateSet root = stateSetTree.getLabel();
      double ll = Math.log(root.getIScore(0)) + (100 * root.getIScale());

      if (Double.isInfinite(ll) || Double.isNaN(ll)) {
        // Some sentences cannot be parsed; they contribute neither counts nor likelihood.
        if (VERBOSE) {
          System.out.println("Training sentence " + n + " is given " + ll + " log likelihood!");
          System.out.println(
              "Root iScore " + root.getIScore(0) + " scale " + root.getIScale());
        }
        continue;
      }

      lexicon.trainTree(stateSetTree, -1, previousLexicon, secondHalf, noSmoothing, unkThreshold);
      if (!updateOnlyLexicon) {
        grammar.tallyStateSetTree(stateSetTree, previousGrammar); // E Step
      }
      trainingLikelihood += ll;
    }
    lexicon.tieRareWordStats(unkThreshold);

    // SSIE: the maxent overwrite exists only on SophisticatedLexicon. The previous unconditional
    // cast threw a ClassCastException whenever a different Lexicon implementation was passed in.
    if (lexicon instanceof SophisticatedLexicon) {
      ((SophisticatedLexicon) lexicon).overwriteWithMaxent();
    }

    return trainingLikelihood;
  }
예제 #3
0
  /**
   * Command-line entry point: trains a grammar and lexicon with latent substates from a treebank
   * and saves them to disk.
   *
   * <p>Outline, in execution order:
   *
   * <ol>
   *   <li>Parse the options, echo the chosen settings and require an output file name (the JVM
   *       exits with status -1 when it is missing).
   *   <li>Load the corpus, binarize and filter the training and validation trees, optionally add
   *       the validation trees to the training set and optionally lowercase the words.
   *   <li>Either load an existing grammar ({@code opts.inFile}) or build an initial grammar and
   *       lexicon from two passes over the training trees.
   *   <li>Run {@code numSplits * 3} rounds. Round index modulo 3 selects the operation: 0 splits
   *       all states, 1 merges a fraction of the splits, 2 installs a smoother. Each round is
   *       followed by EM iterations that stop once the validation likelihood has failed to improve
   *       for {@code opts.di} iterations, the iteration cap is reached, or baseline mode is on. An
   *       intermediate grammar file is written after every round.
   *   <li>Evaluate the last grammar once more, save the best grammar/lexicon pair to the output
   *       file and exit with status 0.
   * </ol>
   *
   * @param args command-line arguments, parsed into {@code Options} by {@code OptionParser}
   */
  public static void main(String[] args) {
    OptionParser optParser = new OptionParser(Options.class);
    Options opts = (Options) optParser.parse(args, true);
    // provide feedback on command-line arguments
    System.out.println("Calling with " + optParser.getPassedInOptions());

    String path = opts.path;
    System.out.println("Loading trees from " + path + " and using language " + opts.treebank);

    double trainingFractionToKeep = opts.trainingFractionToKeep;

    int maxSentenceLength = opts.maxSentenceLength;
    System.out.println("Will remove sentences with more than " + maxSentenceLength + " words.");

    HORIZONTAL_MARKOVIZATION = opts.horizontalMarkovization;
    VERTICAL_MARKOVIZATION = opts.verticalMarkovization;
    System.out.println(
        "Using horizontal="
            + HORIZONTAL_MARKOVIZATION
            + " and vertical="
            + VERTICAL_MARKOVIZATION
            + " markovization.");

    Binarization binarization = opts.binarization;
    System.out.println("Using " + binarization.name() + " binarization.");

    double randomness = opts.randomization;
    System.out.println("Using a randomness value of " + randomness);

    String outFileName = opts.outFileName;
    if (outFileName == null) {
      System.out.println("Output File name is required.");
      System.exit(-1);
    } else {
      System.out.println("Using grammar output file " + outFileName + ".");
    }

    VERBOSE = opts.verbose;
    RANDOM = new Random(opts.randSeed);
    System.out.println("Random number generator seeded at " + opts.randSeed + ".");

    boolean manualAnnotation = false;
    boolean baseline = opts.baseline;
    boolean noSplit = opts.noSplit;
    int numSplitTimes = opts.numSplits;
    if (baseline) {
      // Baseline training never enters the split/merge/smooth loop.
      numSplitTimes = 0;
    }
    String splitGrammarFile = opts.inFile;
    int allowedDroppingIters = opts.di;

    int maxIterations = opts.splitMaxIterations;
    int minIterations = opts.splitMinIterations;
    if (minIterations > 0) {
      System.out.println("I will do at least " + minIterations + " iterations.");
    }

    double[] smoothParams = {opts.smoothingParameter1, opts.smoothingParameter2};
    System.out.println("Using smoothing parameters " + smoothParams[0] + " and " + smoothParams[1]);

    boolean allowMoreSubstatesThanCounts = false;
    boolean findClosedUnaryPaths = opts.findClosedUnaryPaths;

    Corpus corpus =
        new Corpus(
            path,
            opts.treebank,
            trainingFractionToKeep,
            false,
            opts.skipSection,
            opts.skipBilingual);
    List<Tree<String>> trainTrees =
        Corpus.binarizeAndFilterTrees(
            corpus.getTrainTrees(),
            VERTICAL_MARKOVIZATION,
            HORIZONTAL_MARKOVIZATION,
            maxSentenceLength,
            binarization,
            manualAnnotation,
            VERBOSE);
    List<Tree<String>> validationTrees =
        Corpus.binarizeAndFilterTrees(
            corpus.getValidationTrees(),
            VERTICAL_MARKOVIZATION,
            HORIZONTAL_MARKOVIZATION,
            maxSentenceLength,
            binarization,
            manualAnnotation,
            VERBOSE);
    Numberer tagNumberer = Numberer.getGlobalNumberer("tags");

    if (opts.trainOnDevSet) {
      System.out.println("Adding devSet to training data.");
      trainTrees.addAll(validationTrees);
    }

    if (opts.lowercase) {
      System.out.println("Lowercasing the treebank.");
      Corpus.lowercaseWords(trainTrees);
      Corpus.lowercaseWords(validationTrees);
    }

    int nTrees = trainTrees.size();

    System.out.println("There are " + nTrees + " trees in the training set.");

    double filter = opts.filter;
    if (filter > 0) {
      System.out.println(
          "Will remove rules with prob under "
              + filter
              + ".\nEven though only unlikely rules are pruned the training LL is not guaranteed to increase in every round anymore "
              + "(especially when we are close to converging)."
              + "\nFurthermore it increases the variance because 'good' rules can be pruned away in early stages.");
    }

    short nSubstates = opts.nSubStates;
    short[] numSubStatesArray =
        initializeSubStateArray(trainTrees, validationTrees, tagNumberer, nSubstates);
    if (baseline) {
      short one = 1;
      Arrays.fill(numSubStatesArray, one);
      System.out.println("Training just the baseline grammar (1 substate for all states)");
      randomness = 0.0f;
    }

    if (VERBOSE) {
      for (int i = 0; i < numSubStatesArray.length; i++) {
        System.out.println("Tag " + (String) tagNumberer.object(i) + " " + i);
      }
    }

    System.out.println("There are " + numSubStatesArray.length + " observed categories.");

    // initialize lexicon and grammar
    Lexicon lexicon = null, maxLexicon = null, previousLexicon = null;
    Grammar grammar = null, maxGrammar = null, previousGrammar = null;
    double maxLikelihood = Double.NEGATIVE_INFINITY;

    // EM bookkeeping: iter counts EM iterations within a round, droppingIter counts consecutive
    // iterations whose validation likelihood did not reach the best one seen so far.
    int iter = 0;
    int droppingIter = 0;

    //  If we are splitting, we load the old grammar and start off by splitting.
    int startSplit = 0;
    if (splitGrammarFile != null) {
      System.out.println("Loading old grammar from " + splitGrammarFile);
      startSplit = 1; // we've already trained the grammar
      ParserData pData = ParserData.Load(splitGrammarFile);
      maxGrammar = pData.gr;
      maxLexicon = pData.lex;
      numSubStatesArray = maxGrammar.numSubStates;
      previousGrammar = grammar = maxGrammar;
      previousLexicon = lexicon = maxLexicon;
      Numberer.setNumberers(pData.getNumbs());
      tagNumberer = Numberer.getGlobalNumberer("tags");
      System.out.println("Loading old grammar complete.");
      if (noSplit) {
        System.out.println("Will NOT split the loaded grammar.");
        startSplit = 0;
      }
    }

    double mergingPercentage = opts.mergingPercentage;
    boolean separateMergingThreshold = opts.separateMergingThreshold;
    if (mergingPercentage > 0) {
      System.out.println(
          "Will merge " + (int) (mergingPercentage * 100) + "% of the splits in each round.");
      System.out.println(
          "The threshold for merging lexical and phrasal categories will be set separately: "
              + separateMergingThreshold);
    }

    StateSetTreeList trainStateSetTrees =
        new StateSetTreeList(trainTrees, numSubStatesArray, false, tagNumberer);
    StateSetTreeList validationStateSetTrees =
        new StateSetTreeList(validationTrees, numSubStatesArray, false, tagNumberer);

    // get rid of the old trees
    trainTrees = null;
    validationTrees = null;
    corpus = null;
    System.gc();

    if (opts.simpleLexicon) {
      // NOTE(review): the message hard-codes "5" while the value actually passed is opts.rare.
      System.out.println(
          "Replacing words which have been seen less than 5 times with their signature.");
      Corpus.replaceRareWords(
          trainStateSetTrees, new SimpleLexicon(numSubStatesArray, -1), opts.rare);
    }

    // If we're training without loading a split grammar, then we run once without splitting.
    if (splitGrammarFile == null) {
      grammar =
          new Grammar(numSubStatesArray, findClosedUnaryPaths, new NoSmoothing(), null, filter);
      // First pass: a throw-away lexicon that is only used as the "previous" lexicon of the
      // second pass.
      Lexicon tmp_lexicon =
          (opts.simpleLexicon)
              ? new SimpleLexicon(
                  numSubStatesArray,
                  -1,
                  smoothParams,
                  new NoSmoothing(),
                  filter,
                  trainStateSetTrees)
              : new SophisticatedLexicon(
                  numSubStatesArray,
                  SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF,
                  smoothParams,
                  new NoSmoothing(),
                  filter);
      int n = 0;
      boolean secondHalf = false;
      for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
        secondHalf = (n++ > nTrees / 2.0);
        tmp_lexicon.trainTree(stateSetTree, randomness, null, secondHalf, false, opts.rare);
      }
      lexicon =
          (opts.simpleLexicon)
              ? new SimpleLexicon(
                  numSubStatesArray,
                  -1,
                  smoothParams,
                  new NoSmoothing(),
                  filter,
                  trainStateSetTrees)
              : new SophisticatedLexicon(
                  numSubStatesArray,
                  SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF,
                  smoothParams,
                  new NoSmoothing(),
                  filter);
      // NOTE(review): n is not reset before this second pass, so secondHalf is true for every
      // tree here. Left unchanged because altering it would change the trained model -- confirm
      // whether this is intended.
      for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
        secondHalf = (n++ > nTrees / 2.0);
        lexicon.trainTree(stateSetTree, randomness, tmp_lexicon, secondHalf, false, opts.rare);
        grammar.tallyUninitializedStateSetTree(stateSetTree);
      }
      lexicon.tieRareWordStats(opts.rare);
      lexicon.optimize();

      // SSIE: the maxent overwrite exists only on SophisticatedLexicon. The previous unconditional
      // cast threw a ClassCastException here whenever opts.simpleLexicon was set.
      if (lexicon instanceof SophisticatedLexicon) {
        ((SophisticatedLexicon) lexicon).overwriteWithMaxent();
      }

      grammar.optimize(randomness);
      previousGrammar = maxGrammar = grammar; // needed for baseline - when there is no EM loop
      previousLexicon = maxLexicon = lexicon;
    }

    // the main loop: split and train the grammar
    for (int splitIndex = startSplit; splitIndex < numSplitTimes * 3; splitIndex++) {

      // Each round performs one operation chosen by splitIndex % 3:
      // 0 = split, 1 = merge, 2 = smooth. A round whose operation is disabled is skipped.
      String opString = "";
      if (splitIndex % 3 == 2) {
        // the case where we smooth
        if (opts.smooth.equals("NoSmoothing")) {
          continue;
        }
        System.out.println("Setting smoother for grammar and lexicon.");
        Smoother grSmoother = new SmoothAcrossParentBits(0.01, maxGrammar.splitTrees);
        Smoother lexSmoother = new SmoothAcrossParentBits(0.1, maxGrammar.splitTrees);
        maxGrammar.setSmoother(grSmoother);
        maxLexicon.setSmoother(lexSmoother);
        minIterations = maxIterations = opts.smoothMaxIterations;
        opString = "smoothing";
      } else if (splitIndex % 3 == 0) {
        // the case where we split
        if (opts.noSplit) {
          continue;
        }
        System.out.println(
            "Before splitting, we have a total of " + maxGrammar.totalSubStates() + " substates.");
        CorpusStatistics corpusStatistics = new CorpusStatistics(tagNumberer, trainStateSetTrees);
        int[] counts = corpusStatistics.getSymbolCounts();

        maxGrammar = maxGrammar.splitAllStates(randomness, counts, allowMoreSubstatesThanCounts, 0);
        maxLexicon = maxLexicon.splitAllStates(counts, allowMoreSubstatesThanCounts, 0);
        Smoother grSmoother = new NoSmoothing();
        Smoother lexSmoother = new NoSmoothing();
        maxGrammar.setSmoother(grSmoother);
        maxLexicon.setSmoother(lexSmoother);
        System.out.println(
            "After splitting, we have a total of " + maxGrammar.totalSubStates() + " substates.");
        System.out.println(
            "Rule probabilities are NOT normalized in the split, therefore the training LL is not guaranteed to improve between iteration 0 and 1!");
        opString = "splitting";
        maxIterations = opts.splitMaxIterations;
        minIterations = opts.splitMinIterations;
      } else {
        // the case where we merge
        if (mergingPercentage == 0) {
          continue;
        }
        double[][] mergeWeights =
            GrammarMerger.computeMergeWeights(maxGrammar, maxLexicon, trainStateSetTrees);
        double[][][] deltas =
            GrammarMerger.computeDeltas(maxGrammar, maxLexicon, mergeWeights, trainStateSetTrees);
        boolean[][][] mergeThesePairs =
            GrammarMerger.determineMergePairs(
                deltas, separateMergingThreshold, mergingPercentage, maxGrammar);

        grammar = GrammarMerger.doTheMerges(maxGrammar, maxLexicon, mergeThesePairs, mergeWeights);
        short[] newNumSubStatesArray = grammar.numSubStates;
        trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, newNumSubStatesArray, false);
        validationStateSetTrees =
            new StateSetTreeList(validationStateSetTrees, newNumSubStatesArray, false);

        // retrain lexicon to finish the lexicon merge (updates the unknown words model)...
        lexicon =
            (opts.simpleLexicon)
                ? new SimpleLexicon(
                    newNumSubStatesArray,
                    -1,
                    smoothParams,
                    maxLexicon.getSmoother(),
                    filter,
                    trainStateSetTrees)
                : new SophisticatedLexicon(
                    newNumSubStatesArray,
                    SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF,
                    maxLexicon.getSmoothingParams(),
                    maxLexicon.getSmoother(),
                    maxLexicon.getPruningThreshold());
        // Only the lexicon is re-estimated here, hence the null target grammar.
        boolean updateOnlyLexicon = true;
        double trainingLikelihood =
            GrammarTrainer.doOneEStep(
                grammar,
                maxLexicon,
                null,
                lexicon,
                trainStateSetTrees,
                updateOnlyLexicon,
                opts.rare);
        lexicon.optimize(); // M Step

        GrammarMerger.printMergingStatistics(maxGrammar, grammar);
        opString = "merging";
        maxGrammar = grammar;
        maxLexicon = lexicon;
        maxIterations = opts.mergeMaxIterations;
        minIterations = opts.mergeMinIterations;
      }
      // update the substate dependent objects
      previousGrammar = grammar = maxGrammar;
      previousLexicon = lexicon = maxLexicon;
      droppingIter = 0;
      numSubStatesArray = grammar.numSubStates;
      trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, numSubStatesArray, false);
      validationStateSetTrees =
          new StateSetTreeList(validationStateSetTrees, numSubStatesArray, false);
      maxLikelihood = calculateLogLikelihood(maxGrammar, maxLexicon, validationStateSetTrees);
      System.out.println(
          "After "
              + opString
              + " in the "
              + (splitIndex / 3 + 1)
              + "th round, we get a validation likelihood of "
              + maxLikelihood);
      iter = 0;

      // the inner loop: train the grammar via EM until validation likelihood reliably drops
      do {
        iter += 1;
        System.out.println("Beginning iteration " + (iter - 1) + ":");

        // 1) Compute the validation likelihood of the previous iteration
        System.out.print("Calculating validation likelihood...");
        double validationLikelihood =
            calculateLogLikelihood(
                previousGrammar,
                previousLexicon,
                validationStateSetTrees); // The validation LL of previousGrammar/previousLexicon
        System.out.println("done: " + validationLikelihood);

        // 2) Perform the E step while computing the training likelihood of the previous iteration
        System.out.print("Calculating training likelihood...");
        grammar =
            new Grammar(
                grammar.numSubStates,
                grammar.findClosedPaths,
                grammar.smoother,
                grammar,
                grammar.threshold);
        lexicon =
            (opts.simpleLexicon)
                ? new SimpleLexicon(
                    grammar.numSubStates,
                    -1,
                    smoothParams,
                    lexicon.getSmoother(),
                    filter,
                    trainStateSetTrees)
                : new SophisticatedLexicon(
                    grammar.numSubStates,
                    SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF,
                    lexicon.getSmoothingParams(),
                    lexicon.getSmoother(),
                    lexicon.getPruningThreshold());
        boolean updateOnlyLexicon = false;
        double trainingLikelihood =
            doOneEStep(
                previousGrammar,
                previousLexicon,
                grammar,
                lexicon,
                trainStateSetTrees,
                updateOnlyLexicon,
                opts.rare); // The training LL of previousGrammar/previousLexicon
        System.out.println("done: " + trainingLikelihood);

        // 3) Perform the M-Step
        lexicon.optimize(); // M Step
        grammar.optimize(0); // M Step

        // 4) Check whether previousGrammar/previousLexicon was in fact better than the best.
        //    While fewer than minIterations have run, the previous pair is accepted regardless.
        if (iter < minIterations || validationLikelihood >= maxLikelihood) {
          maxLikelihood = validationLikelihood;
          maxGrammar = previousGrammar;
          maxLexicon = previousLexicon;
          droppingIter = 0;
        } else {
          droppingIter++;
        }

        // 5) advance the 'pointers'
        previousGrammar = grammar;
        previousLexicon = lexicon;
      } while ((droppingIter < allowedDroppingIters) && (!baseline) && (iter < maxIterations));

      // Dump a grammar file to disk after every round
      ParserData pData =
          new ParserData(
              maxLexicon,
              maxGrammar,
              null,
              Numberer.getNumberers(),
              numSubStatesArray,
              VERTICAL_MARKOVIZATION,
              HORIZONTAL_MARKOVIZATION,
              binarization);
      String outTmpName = outFileName + "_" + (splitIndex / 3 + 1) + "_" + opString + ".gr";
      System.out.println("Saving grammar to " + outTmpName + ".");
      if (pData.Save(outTmpName)) {
        System.out.println("Saving successful.");
      } else {
        System.out.println("Saving failed!");
      }
    }

    // The last grammar/lexicon has not yet been evaluated. Even though the validation likelihood
    // has been dropping in the past few iteration, there is still a chance that the last one was in
    // fact the best so just in case we evaluate it.
    System.out.print("Calculating last validation likelihood...");
    double validationLikelihood = calculateLogLikelihood(grammar, lexicon, validationStateSetTrees);
    System.out.println(
        "done.\n  Iteration "
            + iter
            + " (final) gives validation likelihood "
            + validationLikelihood);
    if (validationLikelihood > maxLikelihood) {
      // previousGrammar/previousLexicon reference the same objects as grammar/lexicon here: every
      // path above assigns them together.
      maxLikelihood = validationLikelihood;
      maxGrammar = previousGrammar;
      maxLexicon = previousLexicon;
    }

    ParserData pData =
        new ParserData(
            maxLexicon,
            maxGrammar,
            null,
            Numberer.getNumberers(),
            numSubStatesArray,
            VERTICAL_MARKOVIZATION,
            HORIZONTAL_MARKOVIZATION,
            binarization);

    System.out.println("Saving grammar to " + outFileName + ".");
    System.out.println("It gives a validation data log likelihood of: " + maxLikelihood);
    if (pData.Save(outFileName)) {
      System.out.println("Saving successful.");
    } else {
      System.out.println("Saving failed!");
    }

    System.exit(0);
  }
예제 #4
0
    /**
     * Parses a sentence bottom-up with a chart and returns the best tree found.
     *
     * <p>The chart cell for a span {@code [i, j)} maps a rule object to the best log score found
     * for it over that span. A parallel chart stores, per rule, a back pointer
     * {@code (split, left, right)}: {@code null} for a lexical tagging, {@code split == -1} with
     * only the left entry set for a unary step. Single-word cells are seeded from the lexicon,
     * longer spans are filled from binary rules over every split point, and each cell is then
     * closed under unary rules until no score improves.
     *
     * <p>The top cell entry whose parent is {@code "ROOT"} with the highest score is chosen; when
     * there is none, the highest-scoring entry of any kind is used and the resulting tree is
     * wrapped in a {@code "ROOT"} node. Progress information is printed to standard output.
     *
     * @param sentence the words of the sentence, in order
     * @return the chosen tree after {@code TreeAnnotations.unAnnotateTree} has been applied
     */
    public Tree<String> getBestParse(List<String> sentence) {
      final int length = sentence.size();

      // Allocate both charts: (length + 1) x (length + 1) cells, each starting as an empty map.
      List<List<Map<Object, Double>>> chart =
          new ArrayList<List<Map<Object, Double>>>(length + 1);
      List<List<Map<Object, Triplet<Integer, Object, Object>>>> backPointers =
          new ArrayList<List<Map<Object, Triplet<Integer, Object, Object>>>>(length + 1);
      for (int i = 0; i <= length; i++) {
        List<Map<Object, Double>> scoreRow = new ArrayList<Map<Object, Double>>(length + 1);
        List<Map<Object, Triplet<Integer, Object, Object>>> backRow =
            new ArrayList<Map<Object, Triplet<Integer, Object, Object>>>(length + 1);
        for (int j = 0; j <= length; j++) {
          scoreRow.add(new HashMap<Object, Double>());
          backRow.add(new HashMap<Object, Triplet<Integer, Object, Object>>());
        }
        chart.add(scoreRow);
        backPointers.add(backRow);
      }

      // Single-word spans: seed from the lexicon, then close under unary rules.
      for (int pos = 0; pos < length; pos++) {
        String word = sentence.get(pos);
        Map<Object, Double> cellScores = chart.get(pos).get(pos + 1);
        Map<Object, Triplet<Integer, Object, Object>> cellBacks =
            backPointers.get(pos).get(pos + 1);

        for (String tag : lexicon.getAllTags()) {
          UnaryRule tagging = new UnaryRule(tag, word);
          tagging.setScore(Math.log(lexicon.scoreTagging(word, tag)));
          cellScores.put(tagging, tagging.getScore());
          cellBacks.put(tagging, null);
        }

        boolean changed;
        do {
          changed = false;
          // Iterate over a snapshot of the keys because the cell is modified inside the loop.
          List<Object> snapshot = copyKeys(cellScores);
          for (Object childKey : snapshot) {
            UnaryRule child = (UnaryRule) childKey;
            for (UnaryRule parentRule : grammar.getUnaryRulesByChild(child.getParent())) {
              double candidate = Math.log(parentRule.getScore()) + cellScores.get(child);
              // Very unlikely derivations are dropped at the lexical level.
              if (!(candidate > -1000.0)) {
                continue;
              }
              Double known = cellScores.get(parentRule);
              if (known == null || candidate > known) {
                cellScores.put(parentRule, candidate);
                cellBacks.put(parentRule, new Triplet<Integer, Object, Object>(-1, child, null));
                changed = true;
              }
            }
          }
        } while (changed);
      }

      // Timestamp taken before filling the longer spans; nothing below reads it.
      long startTime = new Date().getTime();

      // Longer spans. Naming follows a binary rule A -> B C: "left" is B, "right" is C.
      for (int width = 2; width <= length; width++) {
        for (int start = 0; start + width <= length; start++) {
          int stop = start + width;

          Map<Object, Double> spanScores = chart.get(start).get(stop);
          Map<Object, Triplet<Integer, Object, Object>> spanBacks =
              backPointers.get(start).get(stop);

          for (int mid = start + 1; mid < stop; mid++) {
            Map<Object, Double> leftScores = chart.get(start).get(mid);
            Map<Object, Double> rightScores = chart.get(mid).get(stop);

            List<Object> leftKeys = new ArrayList<Object>(leftScores.keySet());
            List<Object> rightKeys = new ArrayList<Object>(rightScores.keySet());

            // Index the right-hand entries by parent label so that each binary rule can look up
            // its compatible right children directly instead of scanning all of them.
            Map<String, List<Object>> rightByParent = new HashMap<String, List<Object>>();
            for (Object right : rightKeys) {
              String label = getParent(right);
              List<Object> bucket = rightByParent.get(label);
              if (bucket == null) {
                bucket = new ArrayList<Object>();
                rightByParent.put(label, bucket);
              }
              bucket.add(right);
            }

            for (Object left : leftKeys) {
              for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(getParent(left))) {
                List<Object> partners = rightByParent.get(rule.getRightChild());
                if (partners == null) {
                  continue;
                }
                for (Object right : partners) {
                  // rule has left as its left child and right as its right child
                  double candidate =
                      Math.log(rule.getScore()) + leftScores.get(left) + rightScores.get(right);
                  Double known = spanScores.get(rule);
                  if (known == null || candidate > known) {
                    spanScores.put(rule, candidate);
                    spanBacks.put(rule, new Triplet<Integer, Object, Object>(mid, left, right));
                  }
                }
              }
            }
          }

          // Unary closure over the finished cell: A -> B
          boolean changed;
          do {
            changed = false;
            // Iterate over a snapshot of the keys because the cell is modified inside the loop.
            List<Object> snapshot = copyKeys(spanScores);
            for (Object childKey : snapshot) {
              for (UnaryRule parentRule : grammar.getUnaryRulesByChild(getParent(childKey))) {
                double candidate = Math.log(parentRule.getScore()) + spanScores.get(childKey);
                Double known = spanScores.get(parentRule);
                if (known == null || candidate > known) {
                  spanScores.put(parentRule, candidate);
                  spanBacks.put(
                      parentRule, new Triplet<Integer, Object, Object>(-1, childKey, null));
                  changed = true;
                }
              }
            }
          } while (changed);
        }
      }

      Map<Object, Double> topOfChart = chart.get(0).get(length);

      System.out.println("topOfChart: " + topOfChart.size());

      // Prefer the best entry whose parent is "ROOT"; remember the best entry of any kind as a
      // fallback. Ties go to the entry seen later.
      Object rootKey = null;
      Object fallbackKey = null;
      double rootScore = Double.NEGATIVE_INFINITY;
      double fallbackScore = Double.NEGATIVE_INFINITY;
      for (Map.Entry<Object, Double> entry : topOfChart.entrySet()) {
        Object key = entry.getKey();
        double score = entry.getValue();
        if (fallbackKey == null || score >= fallbackScore) {
          fallbackKey = key;
          fallbackScore = score;
        }
        if ("ROOT".equals(getParent(key)) && (rootKey == null || score >= rootScore)) {
          rootKey = key;
          rootScore = score;
        }
      }

      Object bestKey = rootKey;
      if (bestKey == null) {
        bestKey = fallbackKey;
        System.out.println("secondBestKey=" + fallbackKey);
      }
      if (bestKey == null) {
        for (Map.Entry<Object, Double> entry : topOfChart.entrySet()) {
          System.out.println("val=" + entry.getValue() + ", key=" + entry.getKey());
        }
      }
      System.out.println("bestKey=" + bestKey + ", log(prob)=" + topOfChart.get(bestKey));

      Tree<String> result = makeTree(backPointers, 0, length, bestKey);
      if (!"ROOT".equals(result.getLabel())) {
        // Every returned parse is rooted at "ROOT", so wrap a tree that is not.
        List<Tree<String>> children = new ArrayList<Tree<String>>();
        children.add(result);
        result = new Tree<String>("ROOT", children);
      }

      return TreeAnnotations.unAnnotateTree(result);
    }
예제 #5
0
  /**
   * Trains a parser on one range of treebank files and evaluates it on another, printing
   * per-sentence parses and aggregate evaluation figures.
   *
   * <p>Leading arguments starting with {@code -} are interpreted as options:
   *
   * <ul>
   *   <li>{@code -path dir} - treebank directory to load from
   *   <li>{@code -train low high} - file-number range loaded into the training treebank
   *   <li>{@code -test low high} - file-number range loaded into the test treebank
   *   <li>{@code -serialize file} - if given, the trained parser data is saved to this file
   *   <li>{@code -tLPP class} - class name of the {@code TreebankLangParserParams} to instantiate
   *   <li>{@code -encoding enc} - input and output encoding set on the language params
   * </ul>
   *
   * Any other option is handed to {@code op.setOptionOrWarn}.
   *
   * @param args command-line arguments as described above
   */
  public static void main(String[] args) {
    Options op = new Options(new EnglishTreebankParserParams());
    // op.tlpParams may be changed to something else later, so don't use it till
    // after options are parsed.

    System.out.println("Currently " + new Date());
    System.out.print("Invoked with arguments:");
    for (String arg : args) {
      System.out.print(" " + arg);
    }
    System.out.println();

    String path = "/u/nlp/stuff/corpora/Treebank3/parsed/mrg/wsj";
    int trainLow = 200, trainHigh = 2199, testLow = 2200, testHigh = 2219;
    String serializeFile = null;

    int i = 0;
    while (i < args.length && args[i].startsWith("-")) {
      if (args[i].equalsIgnoreCase("-path") && (i + 1 < args.length)) {
        path = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-train") && (i + 2 < args.length)) {
        trainLow = Integer.parseInt(args[i + 1]);
        trainHigh = Integer.parseInt(args[i + 2]);
        i += 3;
      } else if (args[i].equalsIgnoreCase("-test") && (i + 2 < args.length)) {
        testLow = Integer.parseInt(args[i + 1]);
        testHigh = Integer.parseInt(args[i + 2]);
        i += 3;
      } else if (args[i].equalsIgnoreCase("-serialize") && (i + 1 < args.length)) {
        serializeFile = args[i + 1];
        i += 2;
      } else if (args[i].equalsIgnoreCase("-tLPP") && (i + 1 < args.length)) {
        try {
          op.tlpParams = (TreebankLangParserParams) Class.forName(args[i + 1]).newInstance();
        } catch (ClassNotFoundException e) {
          System.err.println("Class not found: " + args[i + 1]);
        } catch (InstantiationException e) {
          System.err.println("Couldn't instantiate: " + args[i + 1] + ": " + e.toString());
        } catch (IllegalAccessException e) {
          System.err.println("illegal access" + e);
        }
        i += 2;
      } else if (args[i].equals("-encoding") && (i + 1 < args.length)) {
        // sets encoding for TreebankLangParserParams; the bounds check keeps a trailing
        // "-encoding" with no value from indexing past the end of args (it falls through to
        // setOptionOrWarn like any other unrecognized option).
        op.tlpParams.setInputEncoding(args[i + 1]);
        op.tlpParams.setOutputEncoding(args[i + 1]);
        i += 2;
      } else {
        i = op.setOptionOrWarn(args, i);
      }
    }
    TreebankLanguagePack tlp = op.tlpParams.treebankLanguagePack();

    Train.sisterSplitters = new HashSet(Arrays.asList(op.tlpParams.sisterSplitters()));
    PrintWriter pw = op.tlpParams.pw();

    Test.display();
    Train.display();
    op.display();
    op.tlpParams.display();

    // setup tree transforms
    Treebank trainTreebank = op.tlpParams.memoryTreebank();
    MemoryTreebank testTreebank = op.tlpParams.testMemoryTreebank();

    Timing.startTime();
    System.err.print("Reading trees...");
    testTreebank.loadPath(path, new NumberRangeFileFilter(testLow, testHigh, true));
    if (Test.increasingLength) {
      Collections.sort(testTreebank, new TreeLengthComparator());
    }

    trainTreebank.loadPath(path, new NumberRangeFileFilter(trainLow, trainHigh, true));
    Timing.tick("done.");
    System.err.print("Binarizing trees...");
    TreeAnnotatorAndBinarizer binarizer = null;
    if (!Train.leftToRight) {
      binarizer =
          new TreeAnnotatorAndBinarizer(op.tlpParams, op.forceCNF, !Train.outsideFactor(), true);
    } else {
      binarizer =
          new TreeAnnotatorAndBinarizer(
              op.tlpParams.headFinder(),
              new LeftHeadFinder(),
              op.tlpParams,
              op.forceCNF,
              !Train.outsideFactor(),
              true);
    }
    CollinsPuncTransformer collinsPuncTransformer = null;
    if (Train.collinsPunc) {
      collinsPuncTransformer = new CollinsPuncTransformer(tlp);
    }
    TreeTransformer debinarizer = new Debinarizer(op.forceCNF);
    List<Tree> binaryTrainTrees = new ArrayList<Tree>();

    if (Train.selectiveSplit) {
      Train.splitters =
          ParentAnnotationStats.getSplitCategories(
              trainTreebank,
              Train.tagSelectiveSplit,
              0,
              Train.selectiveSplitCutOff,
              Train.tagSelectiveSplitCutOff,
              op.tlpParams.treebankLanguagePack());
      if (Train.deleteSplitters != null) {
        List<String> deleted = new ArrayList<String>();
        for (String del : Train.deleteSplitters) {
          String baseDel = tlp.basicCategory(del);
          // A bare basic category deletes every splitter sharing that basic category;
          // otherwise only an exact match is deleted.
          boolean checkBasic = del.equals(baseDel);
          for (Iterator<String> it = Train.splitters.iterator(); it.hasNext(); ) {
            String elem = it.next();
            String baseElem = tlp.basicCategory(elem);
            boolean delStr = checkBasic && baseElem.equals(baseDel) || elem.equals(del);
            if (delStr) {
              it.remove();
              deleted.add(elem);
            }
          }
        }
        System.err.println("Removed from vertical splitters: " + deleted);
      }
    }
    if (Train.selectivePostSplit) {
      TreeTransformer myTransformer = new TreeAnnotator(op.tlpParams.headFinder(), op.tlpParams);
      Treebank annotatedTB = trainTreebank.transform(myTransformer);
      Train.postSplitters =
          ParentAnnotationStats.getSplitCategories(
              annotatedTB,
              true,
              0,
              Train.selectivePostSplitCutOff,
              Train.tagSelectivePostSplitCutOff,
              op.tlpParams.treebankLanguagePack());
    }

    if (Train.hSelSplit) {
      // First pass with selective splitting off; the transformed trees are discarded.
      // NOTE(review): presumably this pass only lets the binarizer gather statistics - confirm.
      binarizer.setDoSelectiveSplit(false);
      for (Tree tree : trainTreebank) {
        if (Train.collinsPunc) {
          tree = collinsPuncTransformer.transformTree(tree);
        }
        tree = binarizer.transformTree(tree);
      }
      binarizer.setDoSelectiveSplit(true);
    }
    for (Tree tree : trainTreebank) {
      if (Train.collinsPunc) {
        tree = collinsPuncTransformer.transformTree(tree);
      }
      tree = binarizer.transformTree(tree);
      binaryTrainTrees.add(tree);
    }
    if (Test.verbose) {
      binarizer.dumpStats();
    }

    List<Tree> binaryTestTrees = new ArrayList<Tree>();
    for (Tree tree : testTreebank) {
      if (Train.collinsPunc) {
        tree = collinsPuncTransformer.transformTree(tree);
      }
      tree = binarizer.transformTree(tree);
      binaryTestTrees.add(tree);
    }
    Timing.tick("done."); // binarization
    BinaryGrammar bg = null;
    UnaryGrammar ug = null;
    DependencyGrammar dg = null;
    Lexicon lex = null;
    // extract grammars
    Extractor bgExtractor = new BinaryGrammarExtractor();
    Extractor dgExtractor = new MLEDependencyGrammarExtractor(op);
    if (op.doPCFG) {
      System.err.print("Extracting PCFG...");
      Pair bgug = null;
      if (Train.cheatPCFG) {
        // "cheat" mode: the grammar is extracted from training AND test trees
        List allTrees = new ArrayList(binaryTrainTrees);
        allTrees.addAll(binaryTestTrees);
        bgug = (Pair) bgExtractor.extract(allTrees);
      } else {
        bgug = (Pair) bgExtractor.extract(binaryTrainTrees);
      }
      bg = (BinaryGrammar) bgug.second;
      bg.splitRules();
      ug = (UnaryGrammar) bgug.first;
      ug.purgeRules();
      Timing.tick("done.");
    }
    System.err.print("Extracting Lexicon...");
    lex = op.tlpParams.lex(op.lexOptions);
    lex.train(binaryTrainTrees);
    Timing.tick("done.");

    if (op.doDep) {
      System.err.print("Extracting Dependencies...");
      binaryTrainTrees.clear();
      // The extracted grammar must be stored in dg: it is tuned right below and later handed to
      // the serializer and the parsers. It used to be kept only in an unused local, leaving dg
      // null and making dg.tune(...) throw a NullPointerException.
      dg =
          (DependencyGrammar)
              dgExtractor.extract(
                  trainTreebank.iterator(), new TransformTreeDependency(op.tlpParams, true));
      Timing.tick("done.");
      System.out.print("Tuning Dependency Model...");
      dg.tune(binaryTestTrees);
      Timing.tick("done.");
    }

    BinaryGrammar boundBG = bg;
    UnaryGrammar boundUG = ug;

    GrammarProjection gp = new NullGrammarProjection(bg, ug);

    // serialization
    if (serializeFile != null) {
      System.err.print("Serializing parser...");
      LexicalizedParser.saveParserDataToSerialized(
          new ParserData(lex, bg, ug, dg, Numberer.getNumberers(), op), serializeFile);
      Timing.tick("done.");
    }

    // test: pcfg-parse and output

    ExhaustivePCFGParser parser = null;
    if (op.doPCFG) {
      parser = new ExhaustivePCFGParser(boundBG, boundUG, lex, op);
    }

    // NOTE(review): dparser stays null when Test.useFastFactored is set, yet it is dereferenced
    // in the test loop whenever op.doDep is true - confirm that combination is never used.
    ExhaustiveDependencyParser dparser =
        ((op.doDep && !Test.useFastFactored) ? new ExhaustiveDependencyParser(dg, lex, op) : null);

    Scorer scorer = (op.doPCFG ? new TwinScorer(new ProjectionScorer(parser, gp), dparser) : null);
    BiLexPCFGParser bparser = null;
    if (op.doPCFG && op.doDep) {
      bparser =
          (Test.useN5)
              ? new BiLexPCFGParser.N5BiLexPCFGParser(
                  scorer, parser, dparser, bg, ug, dg, lex, op, gp)
              : new BiLexPCFGParser(scorer, parser, dparser, bg, ug, dg, lex, op, gp);
    }

    LabeledConstituentEval pcfgPE = new LabeledConstituentEval("pcfg  PE", true, tlp);
    LabeledConstituentEval comboPE = new LabeledConstituentEval("combo PE", true, tlp);
    AbstractEval pcfgCB = new LabeledConstituentEval.CBEval("pcfg  CB", true, tlp);

    AbstractEval pcfgTE = new AbstractEval.TaggingEval("pcfg  TE");
    AbstractEval comboTE = new AbstractEval.TaggingEval("combo TE");
    AbstractEval pcfgTEnoPunct = new AbstractEval.TaggingEval("pcfg nopunct TE");
    AbstractEval comboTEnoPunct = new AbstractEval.TaggingEval("combo nopunct TE");
    AbstractEval depTE = new AbstractEval.TaggingEval("depnd TE");

    AbstractEval depDE =
        new AbstractEval.DependencyEval("depnd DE", true, tlp.punctuationWordAcceptFilter());
    AbstractEval comboDE =
        new AbstractEval.DependencyEval("combo DE", true, tlp.punctuationWordAcceptFilter());

    if (Test.evalb) {
      EvalB.initEVALBfiles(op.tlpParams);
    }

    // use a reflection ruse, so one can run this without needing the tagger
    SentenceProcessor tagger = null;
    if (Test.preTag) {
      try {
        Class[] argsClass = new Class[] {String.class};
        Object[] arguments =
            new Object[] {"/u/nlp/data/pos-tagger/wsj3t0-18-bidirectional/train-wsj-0-18.holder"};
        tagger =
            (SentenceProcessor)
                Class.forName("edu.stanford.nlp.tagger.maxent.MaxentTagger")
                    .getConstructor(argsClass)
                    .newInstance(arguments);
      } catch (Exception e) {
        // Best effort: without the tagger, parsing simply proceeds with no pretagging.
        System.err.println(e);
        System.err.println("Warning: No pretagging of sentences will be done.");
      }
    }

    for (int tNum = 0, ttSize = testTreebank.size(); tNum < ttSize; tNum++) {
      Tree tree = testTreebank.get(tNum);
      int testTreeLen = tree.yield().size();
      if (testTreeLen > Test.maxLength) {
        continue;
      }
      Tree binaryTree = binaryTestTrees.get(tNum);
      System.out.println("-------------------------------------");
      System.out.println("Number: " + (tNum + 1));
      System.out.println("Length: " + testTreeLen);

      long timeMil1 = System.currentTimeMillis();
      Timing.tick("Starting parse.");
      if (op.doPCFG) {
        if (Test.forceTags) {
          if (tagger != null) {
            // tags come from the external tagger
            parser.parse(addLast(tagger.processSentence(cutLast(wordify(binaryTree.yield())))));
          } else {
            // tags are forced to match the input tree
            parser.parse(cleanTags(binaryTree.taggedYield(), tlp));
          }
        } else {
          parser.parse(binaryTree.yield());
        }
      }
      if (op.doDep) {
        dparser.parse(binaryTree.yield());
      }
      boolean bothPassed = false;
      if (op.doPCFG && op.doDep) {
        bothPassed = bparser.parse(binaryTree.yield());
      }
      long timeMil2 = System.currentTimeMillis();
      long elapsed = timeMil2 - timeMil1;
      System.err.println("Time: " + ((int) (elapsed / 100)) / 10.00 + " sec.");
      // tree2b: best PCFG parse (binarized); tree2: the same parse debinarized
      Tree tree2b = null;
      Tree tree2 = null;
      if (op.doPCFG) {
        tree2b = parser.getBestParse();
        tree2 = debinarizer.transformTree(tree2b);
      }
      // tree3: best dependency parse; tree3db: the same parse debinarized
      Tree tree3 = null;
      Tree tree3db = null;
      if (op.doDep) {
        tree3 = dparser.getBestParse();
        // was: but wrong Tree tree3db = debinarizer.transformTree(tree2);
        tree3db = debinarizer.transformTree(tree3);
        tree3.pennPrint(pw);
      }
      // tree4: best combined parse, falling back to the PCFG parse when unavailable
      Tree tree4 = null;
      if (op.doPCFG && op.doDep) {
        try {
          tree4 = bparser.getBestParse();
          if (tree4 == null) {
            tree4 = tree2b;
          }
        } catch (NullPointerException e) {
          System.err.println("Blocked, using PCFG parse!");
          tree4 = tree2b;
        }
      }
      if (op.doPCFG && !bothPassed) {
        tree4 = tree2b;
      }
      if (op.doDep) {
        depDE.evaluate(tree3, binaryTree, pw);
        depTE.evaluate(tree3db, tree, pw);
      }
      TreeTransformer tc = op.tlpParams.collinizer();
      TreeTransformer tcEvalb = op.tlpParams.collinizerEvalb();
      Tree tree4b = null;
      if (op.doPCFG) {
        pcfgPE.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
        pcfgCB.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);
        if (op.doDep) {
          comboDE.evaluate((bothPassed ? tree4 : tree3), binaryTree, pw);
          // keep the binarized combo parse for scoring before tree4 is debinarized
          tree4b = tree4;
          tree4 = debinarizer.transformTree(tree4);
          if (op.nodePrune) {
            NodePruner np = new NodePruner(parser, debinarizer);
            tree4 = np.prune(tree4);
          }
          comboPE.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw);
        }
        pcfgTE.evaluate(tcEvalb.transformTree(tree2), tcEvalb.transformTree(tree), pw);
        pcfgTEnoPunct.evaluate(tc.transformTree(tree2), tc.transformTree(tree), pw);

        if (op.doDep) {
          comboTE.evaluate(tcEvalb.transformTree(tree4), tcEvalb.transformTree(tree), pw);
          comboTEnoPunct.evaluate(tc.transformTree(tree4), tc.transformTree(tree), pw);
        }
        System.out.println("PCFG only: " + parser.scoreBinarizedTree(tree2b, 0));

        tree2.pennPrint(pw);

        if (op.doDep) {
          System.out.println("Combo: " + parser.scoreBinarizedTree(tree4b, 0));
          tree4.pennPrint(pw);
        }
        System.out.println("Correct:" + parser.scoreBinarizedTree(binaryTree, 0));
        tree.pennPrint(pw);
      } // end if doPCFG

      if (Test.evalb) {
        if (op.doPCFG && op.doDep) {
          EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree4));
        } else if (op.doPCFG) {
          EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree2));
        } else if (op.doDep) {
          EvalB.writeEVALBline(tcEvalb.transformTree(tree), tcEvalb.transformTree(tree3db));
        }
      }
    } // end for each tree in test treebank

    if (Test.evalb) {
      EvalB.closeEVALBfiles();
    }

    if (op.doPCFG) {
      pcfgPE.display(false, pw);
      System.out.println("Grammar size: " + Numberer.getGlobalNumberer("states").total());
      pcfgCB.display(false, pw);
      if (op.doDep) {
        comboPE.display(false, pw);
      }
      pcfgTE.display(false, pw);
      pcfgTEnoPunct.display(false, pw);
      if (op.doDep) {
        comboTE.display(false, pw);
        comboTEnoPunct.display(false, pw);
      }
    }
    if (op.doDep) {
      depTE.display(false, pw);
      depDE.display(false, pw);
    }
    if (op.doPCFG && op.doDep) {
      comboDE.display(false, pw);
    }
  }
예제 #6
0
  /**
   * Trains the parser's parameters by stochastic gradient updates, making {@code EPOCHS} passes
   * over every data set in {@code trainData}.
   *
   * <p>For each training example (sentence and semantics) that has at most {@code maxSentLen}
   * tokens, the method adds candidate lexical entries, parses the sentence unconstrained and then
   * constrained to the correct semantics, and applies the difference of the two normalized
   * expected feature vectors, scaled by {@code alpha_0 / (1 + c * numUpdates)}, as the update.
   *
   * @param parser the parser whose lexicon and parameters are updated in place
   * @param testEachRound if true, {@code test(parser, false)} is called after every epoch
   */
  public void stocGradTrain(Parser parser, boolean testEachRound) {

    int numUpdates = 0;

    // Snapshot of the lexicon before any training entries are added; these survive pruning.
    List<LexEntry> fixedEntries = new LinkedList<LexEntry>();
    fixedEntries.addAll(parser.returnLex().getLexicon());

    // add all sentential lexical entries.
    for (int l = 0; l < trainData.size(); l++) {
      parser.addLexEntries(trainData.getDataSet(l).makeSentEntries());
    }
    parser.setGlobals();

    DataSet data = null;
    // for each pass over the data
    for (int j = 0; j < EPOCHS; j++) {
      System.out.println("Training, iteration " + j);
      int total = 0, correct = 0, wrong = 0;
      for (int l = 0; l < trainData.size(); l++) {

        // the variables to hold the current training example
        String words = null;
        Exp sem = null;

        data = trainData.getDataSet(l);
        if (verbose) System.out.println("---------------------");
        String filename = trainData.getFilename(l);
        if (verbose) System.out.println("DataSet: " + filename);
        if (verbose) System.out.println("---------------------");

        // loop through the training examples
        // try to create lexical entries for each training example
        for (int i = 0; i < data.size(); i++) {
          // print running stats
          if (verbose) {
            if (total != 0) {
              double r = (double) correct / total;
              double p = (double) correct / (correct + wrong);
              System.out.print(i + ": =========== r:" + r + " p:" + p);
              System.out.println(" (epoch:" + j + " file:" + l + " " + filename + ")");
            } else System.out.println(i + ": ===========");
          }

          // get the training example
          words = data.sent(i);
          sem = data.sem(i);
          if (verbose) {
            System.out.println(words);
            System.out.println(sem);
          }

          List<String> tokens = Parser.tokenize(words);

          // over-long sentences are skipped and not counted in the statistics
          if (tokens.size() > maxSentLen) continue;
          total++;

          String mes = null;

          // first, get all possible lexical entries from
          // a manipulation of the best parse.
          List<LexEntry> lex = makeLexEntriesChart(words, sem, parser);

          if (verbose) {
            System.out.println("Adding:");
            for (LexEntry le : lex) {
              System.out.println(le + " : " + LexiconFeatSet.initialWeight(le));
            }
          }

          parser.addLexEntries(lex);

          if (verbose) System.out.println("Lex Size: " + parser.returnLex().size());

          // first parse to see if we are currently correct
          if (verbose) mes = "First";
          parser.parseTimed(words, null, mes);

          Chart firstChart = parser.getChart();
          Exp best = parser.bestSem();

          // this just collates and outputs the training
          // accuracy.
          if (sem.equals(best)) {
            if (verbose) {
              System.out.println("CORRECT:" + best);
              lex = parser.getMaxLexEntriesFor(sem);
              System.out.println("Using:");
              printLex(lex);
              if (lex.size() == 0) {
                System.out.println("ERROR: empty lex");
              }
            }
            correct++;
          } else {
            if (verbose) {
              System.out.println("WRONG: " + best);
              lex = parser.getMaxLexEntriesFor(best);
              System.out.println("Using:");
              printLex(lex);
              if (best != null && lex.size() == 0) {
                System.out.println("ERROR: empty lex");
              }
            }
            wrong++;
          }

          // compute first half of parameter update:
          // subtract the expectation of parameters
          // under the distribution that is conditioned
          // on the sentence alone.
          double norm = firstChart.computeNorm();
          // a zero normalizer cannot be divided by, so the example yields no update
          if (norm == 0.0) continue;
          HashVector update = new HashVector();
          HashVector firstfeats = firstChart.computeExpFeatVals();
          firstfeats.divideBy(norm);
          firstfeats.dropSmallEntries();
          firstfeats.addTimesInto(-1.0, update);
          firstChart = null;

          if (verbose) mes = "Second";
          parser.parseTimed(words, sem, mes);

          // compute second half of parameter update:
          // add the expectation of parameters
          // under the distribution that is conditioned
          // on the sentence and correct logical form.
          if (!parser.hasParseFor(sem)) continue;
          Chart secondChart = parser.getChart();
          double secnorm = secondChart.computeNorm(sem);
          // Guard on the normalizer that is actually used as the divisor below. The previous
          // check re-tested norm, which is already known to be non-zero here, so a zero
          // secnorm slipped through to divideBy.
          if (secnorm == 0.0) continue;
          HashVector secondfeats = secondChart.computeExpFeatVals(sem);
          secondfeats.divideBy(secnorm);
          secondfeats.dropSmallEntries();
          secondfeats.addTimesInto(1.0, update);
          lex = parser.getMaxLexEntriesFor(sem);
          data.setBestLex(i, lex);
          if (verbose) {
            System.out.println("Best LexEntries:");
            printLex(lex);
            if (lex.size() == 0) {
              System.out.println("ERROR: empty lex");
            }
          }

          // now do the update; the step size decays with the number of updates made so far
          double scale = alpha_0 / (1.0 + c * numUpdates);
          if (verbose) System.out.println("Scale: " + scale);
          update.multiplyBy(scale);
          update.dropSmallEntries();

          numUpdates++;
          if (verbose) {
            System.out.println("Update:");
            System.out.println(update);
          }
          if (!update.isBad()) {
            if (!update.valuesInRange(-100, 100)) {
              System.out.println("WARNING: large update");
              System.out.println("first feats: " + firstfeats);
              System.out.println("second feats: " + secondfeats);
            }
            parser.updateParams(update);
          } else {
            System.out.println(
                "ERROR: Bad Update: " + update + " -- norm: " + norm + " -- feats: ");
            parser.getParams().printValues(update);
            System.out.println();
          }
        } // end for each training example
      } // end for each data set

      // we can prune the lexical items that were not used
      // in a max scoring parse.
      if (pruneLex) {
        Lexicon cur = new Lexicon();
        cur.addLexEntries(fixedEntries);
        // data is null when trainData holds no data sets.
        // NOTE(review): only the LAST data set's best entries are kept here; with several
        // data sets this may drop entries the others rely on - confirm this is intended.
        if (data != null) {
          cur.addLexEntries(data.getBestLex());
        }
        parser.setLexicon(cur);
      }

      if (testEachRound) {
        System.out.println("Testing");
        test(parser, false);
      }
    } // end epochs loop
  }
예제 #7
0
    /**
     * Finds a parse for {@code sentence} with a CKY-style chart and returns it with annotations
     * removed.
     *
     * <p>Chart cells are keyed by the string {@code "begin end"}; each cell maps a label to the
     * best score found for it. Scores combine by multiplication. Back-pointers are stored under
     * {@code "begin end label"} as a triplet {@code (split, leftChild, rightChild)}, with
     * {@code split == -1} and a null right child marking a unary rule.
     *
     * @param sentence the words to parse
     * @return the tree rebuilt from the back-pointers for label {@code "ROOT"} over the whole
     *     sentence, passed through {@code TreeAnnotations.unAnnotateTree}
     */
    public Tree<String> getBestParseOld(List<String> sentence) {
      CounterMap<String, String> parseScores = new CounterMap<String, String>();

      System.out.println(sentence.toString());
      // First deal with the lexicons
      int index = 0;
      int span = 1; // All spans are 1 at the lexicon level
      for (String word : sentence) {
        String cell = index + " " + (index + span);
        for (String tag : lexicon.getAllTags()) {
          double score = lexicon.scoreTagging(word, tag);
          if (score >= 0.0) { // This lexicon may generate this word
            // We use a counter map in order to store the scores for this sentence parse.
            parseScores.setCount(cell, tag, score);
          }
        }
        index = index + 1;
      }

      // hashmap to store back pointers
      HashMap<String, Triplet<Integer, String, String>> backHash =
          new HashMap<String, Triplet<Integer, String, String>>();

      // handle unary rules now: sweep all width-1 cells until no score improves any more
      boolean added = true;
      while (added) {
        added = false;
        for (index = 0; index < sentence.size(); index++) {
          String cell = index + " " + (index + span);
          Counter<String> count = parseScores.getCounter(cell);
          PriorityQueue<String> countAsPQ = count.asPriorityQueue();
          while (countAsPQ.hasNext()) {
            String entry = countAsPQ.next();
            List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
            for (UnaryRule rule : unaryRules) {
              // These are the unary rules which might give rise to the above preterminal
              double prob = rule.getScore() * parseScores.getCount(cell, entry);
              if (prob > parseScores.getCount(cell, rule.parent)) {
                parseScores.setCount(cell, rule.parent, prob);
                backHash.put(
                    cell + " " + rule.parent,
                    new Triplet<Integer, String, String>(-1, entry, null));
                added = true;
              }
            }
          }
        }
      }

      // Now work with the grammar to produce higher level probabilities
      for (span = 2; span <= sentence.size(); span++) {
        for (int begin = 0; begin <= (sentence.size() - span); begin++) {
          int end = begin + span;
          String cell = begin + " " + end;
          for (int split = begin + 1; split <= end - 1; split++) {
            String leftCell = begin + " " + split;
            String rightCell = split + " " + end;
            Counter<String> countLeft = parseScores.getCounter(leftCell);
            Counter<String> countRight = parseScores.getCounter(rightCell);
            // Binary rules reachable from a label in the left / right sub-cell, keyed by the
            // rule's hash code so the two sets can be intersected below.
            // NOTE(review): rules are matched by hashCode() alone; two distinct rules with the
            // same hash code would be conflated - confirm BinaryRule.hashCode makes this safe.
            HashMap<Integer, BinaryRule> leftMap = new HashMap<Integer, BinaryRule>();
            HashMap<Integer, BinaryRule> rightMap = new HashMap<Integer, BinaryRule>();

            for (String entry : countLeft.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(entry)) {
                if (!leftMap.containsKey(rule.hashCode())) {
                  leftMap.put(rule.hashCode(), rule);
                }
              }
            }

            for (String entry : countRight.keySet()) {
              for (BinaryRule rule : grammar.getBinaryRulesByRightChild(entry)) {
                if (!rightMap.containsKey(rule.hashCode())) {
                  rightMap.put(rule.hashCode(), rule);
                }
              }
            }

            // Only rules present on both sides can combine the two sub-cells.
            for (Integer ruleHash : leftMap.keySet()) {
              if (rightMap.containsKey(ruleHash)) {
                BinaryRule ruleRight = rightMap.get(ruleHash);
                double prob =
                    ruleRight.getScore()
                        * parseScores.getCount(leftCell, ruleRight.leftChild)
                        * parseScores.getCount(rightCell, ruleRight.rightChild);
                if (prob > parseScores.getCount(cell, ruleRight.parent)) {
                  parseScores.setCount(cell, ruleRight.getParent(), prob);
                  backHash.put(
                      cell + " " + ruleRight.parent,
                      new Triplet<Integer, String, String>(
                          split, ruleRight.leftChild, ruleRight.rightChild));
                }
              }
            }
          }

          // Now handle unary rules for this cell, repeating until no score improves
          added = true;
          while (added) {
            added = false;
            Counter<String> count = parseScores.getCounter(cell);
            PriorityQueue<String> countAsPriorityQueue = count.asPriorityQueue();
            while (countAsPriorityQueue.hasNext()) {
              String entry = countAsPriorityQueue.next();
              List<UnaryRule> unaryRules = grammar.getUnaryRulesByChild(entry);
              for (UnaryRule rule : unaryRules) {
                double prob = rule.getScore() * parseScores.getCount(cell, entry);
                if (prob > parseScores.getCount(cell, rule.parent)) {
                  parseScores.setCount(cell, rule.parent, prob);

                  backHash.put(
                      cell + " " + rule.parent,
                      new Triplet<Integer, String, String>(-1, entry, null));
                  added = true;
                }
              }
            }
          }
        }
      }

      // Create and return the parse tree
      String rootCell = 0 + " " + sentence.size();
      // Diagnostic only: report when the full-sentence cell has no best label at all.
      if (parseScores.getCounter(rootCell).argMax() == null) {
        System.out.println(parseScores.getCounter(rootCell).toString());
        System.out.println("THIS IS WEIRD");
      }
      // The tree is always rebuilt from the "ROOT" label, whatever the cell's best label is.
      Tree<String> parseTree = getParseTreeOld(sentence, backHash, 0, sentence.size(), "ROOT");
      return TreeAnnotations.unAnnotateTree(parseTree);
    }
예제 #8
0
    /**
     * Parses {@code sentence} with a probabilistic CKY chart and returns the reconstructed tree.
     *
     * <p>Chart layout: {@code parseScores[i][j]} holds, per label, the best score found so far for
     * the words {@code i..j} (both inclusive). Back pointers are keyed by
     * {@code (begin, endExclusive, label)} and map to {@code (split, leftChild, rightChild)} for a
     * binary rule, or to {@code (-1, child, null)} for a unary rule.
     *
     * <p>The tree is always rebuilt from the label {@code "ROOT"} spanning the whole sentence,
     * regardless of which label scores highest in the top cell. The sentence is also echoed to
     * standard output before parsing starts.
     *
     * @param sentence the words to parse, in order
     * @return the result of {@code TreeAnnotations.unAnnotateTree} applied to the tree that
     *     {@code getParseTree} builds from the back pointers
     */
    public Tree<String> getBestParse(List<String> sentence) {
      int nEntries = sentence.size();

      // Back pointers used by getParseTree to rebuild the best derivation.
      HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>> backHash =
          new HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>>();

      // Generic array creation is not allowed in Java, hence the raw array and the cast.
      @SuppressWarnings("unchecked")
      Counter<String>[][] parseScores = (Counter<String>[][]) (new Counter[nEntries][nEntries]);
      for (int i = 0; i < nEntries; i++) {
        for (int j = 0; j < nEntries; j++) {
          parseScores[i][j] = new Counter<String>();
        }
      }

      // Trace output kept from the original implementation.
      System.out.println(sentence.toString());

      // Phase 1: seed the single-word cells from the lexicon, then close each under unary rules.
      int position = 0;
      for (String word : sentence) {
        Counter<String> cell = parseScores[position][position];
        for (String tag : lexicon.getAllTags()) {
          double score = lexicon.scoreTagging(word, tag);
          // Keeps every non-negative score (zero included); the comparison also rejects NaN.
          if (score >= 0.0) {
            cell.setCount(tag, score);
          }
        }
        position++;
      }
      for (int i = 0; i < nEntries; i++) {
        closeUnariesInCell(parseScores[i][i], i, i + 1, backHash);
      }

      // Phase 2: combine adjacent sub-spans with binary rules, shortest spans first.
      for (int span = 2; span <= nEntries; span++) {
        for (int begin = 0; begin <= nEntries - span; begin++) {
          int end = begin + span; // exclusive end of the span
          Counter<String> cell = parseScores[begin][end - 1];

          for (int split = begin + 1; split < end; split++) {
            Counter<String> leftCell = parseScores[begin][split - 1];
            Counter<String> rightCell = parseScores[split][end - 1];

            // Walk the rules indexed by each left-cell label and test the right child directly.
            // The rules themselves are compared by their children, never by hashCode(), so two
            // distinct rules that share a hash code can no longer shadow each other.
            for (String leftLabel : leftCell.keySet()) {
              double leftScore = leftCell.getCount(leftLabel);
              for (BinaryRule rule : grammar.getBinaryRulesByLeftChild(leftLabel)) {
                if (!rightCell.keySet().contains(rule.rightChild)) {
                  continue;
                }
                double prob = rule.getScore() * leftScore * rightCell.getCount(rule.rightChild);
                if (prob > cell.getCount(rule.parent)) {
                  cell.setCount(rule.getParent(), prob);
                  backHash.put(
                      new Triplet<Integer, Integer, String>(begin, end, rule.parent),
                      new Triplet<Integer, String, String>(
                          split, rule.leftChild, rule.rightChild));
                }
              }
            }
          }

          // All splits for this span are done; now let unary rules build on the results.
          closeUnariesInCell(cell, begin, end, backHash);
        }
      }

      // Sentences are meant to have ROOT as their root node, so reconstruction starts there.
      Tree<String> parseTree = getParseTree(sentence, backHash, 0, nEntries, "ROOT");
      return TreeAnnotations.unAnnotateTree(parseTree);
    }

    /**
     * Applies unary rules to one chart cell over and over until a full pass improves no score.
     *
     * <p>Whenever a rule raises the score of its parent label, the new score is stored in
     * {@code cell} and a unary back pointer {@code (-1, child, null)} is recorded under
     * {@code (begin, end, parent)}.
     *
     * @param cell the chart cell to update in place
     * @param begin index of the first word covered by the cell
     * @param end index one past the last word covered by the cell
     * @param backHash back-pointer table to update
     */
    private void closeUnariesInCell(
        Counter<String> cell,
        int begin,
        int end,
        HashMap<Triplet<Integer, Integer, String>, Triplet<Integer, String, String>> backHash) {
      boolean improved = true;
      while (improved) {
        improved = false;
        PriorityQueue<String> pending = cell.asPriorityQueue();
        while (pending.hasNext()) {
          String child = pending.next();
          for (UnaryRule rule : grammar.getUnaryRulesByChild(child)) {
            // The child's score is re-read per rule because an earlier rule may have raised it.
            double prob = rule.getScore() * cell.getCount(child);
            if (prob > cell.getCount(rule.parent)) {
              cell.setCount(rule.parent, prob);
              backHash.put(
                  new Triplet<Integer, Integer, String>(begin, end, rule.parent),
                  new Triplet<Integer, String, String>(-1, child, null));
              improved = true;
            }
          }
        }
      }
    }