Exemplo n.º 1
0
  /**
   * Runs one E step over the training trees.
   *
   * <p>A parser is built from {@code previousGrammar}/{@code previousLexicon} and used to compute
   * inside/outside scores for every training tree. Each tree whose log likelihood is finite is then
   * tallied into {@code lexicon} and, unless {@code updateOnlyLexicon} is set, into {@code
   * grammar}. Trees with an infinite or NaN log likelihood are skipped (and reported when {@code
   * VERBOSE} is on).
   *
   * @param previousGrammar grammar used to score the trees
   * @param previousLexicon lexicon used to score the trees
   * @param grammar grammar that receives the tallies; not touched when {@code updateOnlyLexicon}
   *     is {@code true}, so it may be {@code null} in that case
   * @param lexicon lexicon that receives the tallies
   * @param trainStateSetTrees the training trees
   * @param updateOnlyLexicon if {@code true}, only {@code lexicon} is updated
   * @param unkThreshold threshold forwarded to {@code Lexicon.trainTree} and {@code
   *     Lexicon.tieRareWordStats}
   * @return the sum of the log likelihoods of all trees that had a finite log likelihood
   */
  public static double doOneEStep(
      Grammar previousGrammar,
      Lexicon previousLexicon,
      Grammar grammar,
      Lexicon lexicon,
      StateSetTreeList trainStateSetTrees,
      boolean updateOnlyLexicon,
      int unkThreshold) {
    boolean secondHalf = false;
    ArrayParser parser = new ArrayParser(previousGrammar, previousLexicon);
    double trainingLikelihood = 0;
    int n = 0;
    int nTrees = trainStateSetTrees.size();
    for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
      // True once more than half of the trees have been visited.
      secondHalf = (n++ > nTrees / 2.0);
      boolean noSmoothing = true, debugOutput = false;
      parser.doInsideOutsideScores(stateSetTree, noSmoothing, debugOutput); // E Step
      double ll = stateSetTree.getLabel().getIScore(0);
      // Undo the scaling of the root inside score.
      // NOTE(review): 100 is presumably the log of the scaling base used by the parser - confirm.
      ll = Math.log(ll) + (100 * stateSetTree.getLabel().getIScale());
      if ((Double.isInfinite(ll) || Double.isNaN(ll))) {
        // The tree could not be scored; it contributes neither counts nor likelihood.
        if (VERBOSE) {
          System.out.println("Training sentence " + n + " is given " + ll + " log likelihood!");
          System.out.println(
              "Root iScore "
                  + stateSetTree.getLabel().getIScore(0)
                  + " scale "
                  + stateSetTree.getLabel().getIScale());
        }
      } else {
        lexicon.trainTree(stateSetTree, -1, previousLexicon, secondHalf, noSmoothing, unkThreshold);
        if (!updateOnlyLexicon) grammar.tallyStateSetTree(stateSetTree, previousGrammar); // E Step
        trainingLikelihood += ll; // there are for some reason some sentences that are unparsable
      }
    }
    lexicon.tieRareWordStats(unkThreshold);

    // SSIE
    // Only a SophisticatedLexicon offers overwriteWithMaxent(); callers may also pass a
    // SimpleLexicon, for which an unconditional cast would throw ClassCastException.
    if (lexicon instanceof SophisticatedLexicon) {
      ((SophisticatedLexicon) lexicon).overwriteWithMaxent();
    }

    return trainingLikelihood;
  }
Exemplo n.º 2
0
  /**
   * Command-line entry point: trains a grammar and lexicon from a treebank and writes them to
   * disk.
   *
   * <p>Outline:
   *
   * <ol>
   *   <li>Parse the command-line options and echo the effective settings.
   *   <li>Load the corpus, binarize/filter the training and validation trees, and convert them to
   *       state-set trees.
   *   <li>Either load a previously trained grammar ({@code opts.inFile}) or estimate an initial
   *       grammar and lexicon from the training trees.
   *   <li>Run the main loop, where every round consists of three steps selected by {@code
   *       splitIndex % 3}: split (0), merge (1) and smooth (2). Each step that is not skipped is
   *       followed by EM iterations and an intermediate grammar dump.
   *   <li>Evaluate the last grammar once more, save the best grammar to {@code opts.outFileName}
   *       and exit.
   * </ol>
   *
   * <p>The process is terminated with {@code System.exit(-1)} if no output file name is given and
   * with {@code System.exit(0)} after the final grammar has been saved.
   *
   * @param args command-line arguments, parsed into {@code Options}
   */
  public static void main(String[] args) {
    OptionParser optParser = new OptionParser(Options.class);
    Options opts = (Options) optParser.parse(args, true);
    // provide feedback on command-line arguments
    System.out.println("Calling with " + optParser.getPassedInOptions());

    String path = opts.path;
    //    int lang = opts.lang;
    System.out.println("Loading trees from " + path + " and using language " + opts.treebank);

    double trainingFractionToKeep = opts.trainingFractionToKeep;

    int maxSentenceLength = opts.maxSentenceLength;
    System.out.println("Will remove sentences with more than " + maxSentenceLength + " words.");

    HORIZONTAL_MARKOVIZATION = opts.horizontalMarkovization;
    VERTICAL_MARKOVIZATION = opts.verticalMarkovization;
    System.out.println(
        "Using horizontal="
            + HORIZONTAL_MARKOVIZATION
            + " and vertical="
            + VERTICAL_MARKOVIZATION
            + " markovization.");

    Binarization binarization = opts.binarization;
    System.out.println(
        "Using " + binarization.name() + " binarization."); // and "+annotateString+".");

    double randomness = opts.randomization;
    System.out.println("Using a randomness value of " + randomness);

    String outFileName = opts.outFileName;
    if (outFileName == null) {
      System.out.println("Output File name is required.");
      System.exit(-1);
    } else System.out.println("Using grammar output file " + outFileName + ".");

    VERBOSE = opts.verbose;
    RANDOM = new Random(opts.randSeed);
    System.out.println("Random number generator seeded at " + opts.randSeed + ".");

    boolean manualAnnotation = false;
    boolean baseline = opts.baseline;
    boolean noSplit = opts.noSplit;
    int numSplitTimes = opts.numSplits;
    if (baseline) numSplitTimes = 0; // the baseline grammar is never split
    String splitGrammarFile = opts.inFile;
    int allowedDroppingIters = opts.di;

    int maxIterations = opts.splitMaxIterations;
    int minIterations = opts.splitMinIterations;
    if (minIterations > 0)
      System.out.println("I will do at least " + minIterations + " iterations.");

    double[] smoothParams = {opts.smoothingParameter1, opts.smoothingParameter2};
    System.out.println("Using smoothing parameters " + smoothParams[0] + " and " + smoothParams[1]);

    boolean allowMoreSubstatesThanCounts = false;
    boolean findClosedUnaryPaths = opts.findClosedUnaryPaths;

    Corpus corpus =
        new Corpus(
            path,
            opts.treebank,
            trainingFractionToKeep,
            false,
            opts.skipSection,
            opts.skipBilingual);
    List<Tree<String>> trainTrees =
        Corpus.binarizeAndFilterTrees(
            corpus.getTrainTrees(),
            VERTICAL_MARKOVIZATION,
            HORIZONTAL_MARKOVIZATION,
            maxSentenceLength,
            binarization,
            manualAnnotation,
            VERBOSE);
    List<Tree<String>> validationTrees =
        Corpus.binarizeAndFilterTrees(
            corpus.getValidationTrees(),
            VERTICAL_MARKOVIZATION,
            HORIZONTAL_MARKOVIZATION,
            maxSentenceLength,
            binarization,
            manualAnnotation,
            VERBOSE);
    Numberer tagNumberer = Numberer.getGlobalNumberer("tags");

    //		for (Tree<String> t : trainTrees){
    //				System.out.println(t);
    //			}

    if (opts.trainOnDevSet) {
      System.out.println("Adding devSet to training data.");
      trainTrees.addAll(validationTrees);
    }

    if (opts.lowercase) {
      System.out.println("Lowercasing the treebank.");
      Corpus.lowercaseWords(trainTrees);
      Corpus.lowercaseWords(validationTrees);
    }

    int nTrees = trainTrees.size();

    System.out.println("There are " + nTrees + " trees in the training set.");

    double filter = opts.filter;
    if (filter > 0)
      System.out.println(
          "Will remove rules with prob under "
              + filter
              + ".\nEven though only unlikely rules are pruned the training LL is not guaranteed to increase in every round anymore "
              + "(especially when we are close to converging)."
              + "\nFurthermore it increases the variance because 'good' rules can be pruned away in early stages.");

    short nSubstates = opts.nSubStates;
    short[] numSubStatesArray =
        initializeSubStateArray(trainTrees, validationTrees, tagNumberer, nSubstates);
    if (baseline) {
      short one = 1;
      Arrays.fill(numSubStatesArray, one);
      System.out.println("Training just the baseline grammar (1 substate for all states)");
      randomness = 0.0f;
    }

    if (VERBOSE) {
      for (int i = 0; i < numSubStatesArray.length; i++) {
        System.out.println("Tag " + (String) tagNumberer.object(i) + " " + i);
      }
    }

    System.out.println("There are " + numSubStatesArray.length + " observed categories.");

    // initialize lexicon and grammar
    Lexicon lexicon = null, maxLexicon = null, previousLexicon = null;
    Grammar grammar = null, maxGrammar = null, previousGrammar = null;
    double maxLikelihood = Double.NEGATIVE_INFINITY;

    //    String smootherStr = opts.smooth;
    //    Smoother lexiconSmoother = null;
    //    Smoother grammarSmoother = null;
    //    if (splitGrammarFile!=null){
    //    	lexiconSmoother = maxLexicon.smoother;
    //    	grammarSmoother = maxGrammar.smoother;
    //    	System.out.println("Using smoother from input grammar.");
    //    }
    //    else if (smootherStr.equals("NoSmoothing"))
    //    	lexiconSmoother = grammarSmoother = new NoSmoothing();
    //    else if (smootherStr.equals("SmoothAcrossParentBits")) {
    //    	lexiconSmoother = grammarSmoother  = new SmoothAcrossParentBits(grammarSmoothing,
    // maxGrammar.splitTrees);
    //    }
    //    else
    //    	throw new Error("I didn't understand the type of smoother '"+smootherStr+"'");
    //    System.out.println("Using smoother "+smootherStr);

    // EM bookkeeping: iter counts the EM iterations of the current round, droppingIter counts the
    // consecutive iterations in which the validation likelihood was below the best one so far.
    // EM stops once droppingIter reaches allowedDroppingIters (or iter reaches maxIterations).
    int iter = 0;
    int droppingIter = 0;

    //  If we are splitting, we load the old grammar and start off by splitting.
    int startSplit = 0;
    if (splitGrammarFile != null) {
      System.out.println("Loading old grammar from " + splitGrammarFile);
      startSplit = 1; // we've already trained the grammar
      ParserData pData = ParserData.Load(splitGrammarFile);
      maxGrammar = pData.gr;
      maxLexicon = pData.lex;
      numSubStatesArray = maxGrammar.numSubStates;
      previousGrammar = grammar = maxGrammar;
      previousLexicon = lexicon = maxLexicon;
      Numberer.setNumberers(pData.getNumbs());
      tagNumberer = Numberer.getGlobalNumberer("tags");
      System.out.println("Loading old grammar complete.");
      if (noSplit) {
        System.out.println("Will NOT split the loaded grammar.");
        startSplit = 0;
      }
    }

    double mergingPercentage = opts.mergingPercentage;
    boolean separateMergingThreshold = opts.separateMergingThreshold;
    if (mergingPercentage > 0) {
      System.out.println(
          "Will merge " + (int) (mergingPercentage * 100) + "% of the splits in each round.");
      System.out.println(
          "The threshold for merging lexical and phrasal categories will be set separately: "
              + separateMergingThreshold);
    }

    StateSetTreeList trainStateSetTrees =
        new StateSetTreeList(trainTrees, numSubStatesArray, false, tagNumberer);
    StateSetTreeList validationStateSetTrees =
        new StateSetTreeList(validationTrees, numSubStatesArray, false, tagNumberer); // deletePC);

    // get rid of the old trees
    trainTrees = null;
    validationTrees = null;
    corpus = null;
    System.gc();

    if (opts.simpleLexicon) {
      // The actual rarity threshold passed below is opts.rare.
      System.out.println(
          "Replacing words which have been seen less than 5 times with their signature.");
      Corpus.replaceRareWords(
          trainStateSetTrees, new SimpleLexicon(numSubStatesArray, -1), opts.rare);
    }

    // If we're training without loading a split grammar, then we run once without splitting.
    if (splitGrammarFile == null) {
      grammar =
          new Grammar(numSubStatesArray, findClosedUnaryPaths, new NoSmoothing(), null, filter);
      // First pass: train a temporary lexicon that the second pass uses as the "old" lexicon.
      Lexicon tmp_lexicon =
          (opts.simpleLexicon)
              ? new SimpleLexicon(
                  numSubStatesArray,
                  -1,
                  smoothParams,
                  new NoSmoothing(),
                  filter,
                  trainStateSetTrees)
              : new SophisticatedLexicon(
                  numSubStatesArray,
                  SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF,
                  smoothParams,
                  new NoSmoothing(),
                  filter);
      int n = 0;
      boolean secondHalf = false;
      for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
        secondHalf = (n++ > nTrees / 2.0);
        tmp_lexicon.trainTree(stateSetTree, randomness, null, secondHalf, false, opts.rare);
      }
      lexicon =
          (opts.simpleLexicon)
              ? new SimpleLexicon(
                  numSubStatesArray,
                  -1,
                  smoothParams,
                  new NoSmoothing(),
                  filter,
                  trainStateSetTrees)
              : new SophisticatedLexicon(
                  numSubStatesArray,
                  SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF,
                  smoothParams,
                  new NoSmoothing(),
                  filter);
      // NOTE(review): n is not reset after the first pass, so it keeps counting upwards here and
      // secondHalf is presumably true for every tree of this second pass - confirm whether that
      // is intended before changing it, since it affects the trained lexicon.
      for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
        secondHalf = (n++ > nTrees / 2.0);
        lexicon.trainTree(stateSetTree, randomness, tmp_lexicon, secondHalf, false, opts.rare);
        grammar.tallyUninitializedStateSetTree(stateSetTree);
      }
      lexicon.tieRareWordStats(opts.rare);
      lexicon.optimize();

      // SSIE
      // lexicon is a SimpleLexicon when opts.simpleLexicon is set; an unconditional cast would
      // throw ClassCastException in that configuration.
      if (lexicon instanceof SophisticatedLexicon) {
        ((SophisticatedLexicon) lexicon).overwriteWithMaxent();
      }

      grammar.optimize(randomness);
      // System.out.println(grammar);
      previousGrammar = maxGrammar = grammar; // needed for baseline - when there is no EM loop
      previousLexicon = maxLexicon = lexicon;
    }

    // the main loop: split and train the grammar
    for (int splitIndex = startSplit; splitIndex < numSplitTimes * 3; splitIndex++) {

      // Every round consists of three steps, selected by splitIndex % 3:
      //   0 -> split, 1 -> merge, 2 -> smooth.
      // A step that is disabled by the options is skipped entirely (no EM, no grammar dump).
      String opString = "";
      if (splitIndex % 3 == 2) { // (splitIndex==numSplitTimes*2){
        if (opts.smooth.equals("NoSmoothing")) continue;
        System.out.println("Setting smoother for grammar and lexicon.");
        Smoother grSmoother = new SmoothAcrossParentBits(0.01, maxGrammar.splitTrees);
        Smoother lexSmoother = new SmoothAcrossParentBits(0.1, maxGrammar.splitTrees);
        //        Smoother grSmoother = new SmoothAcrossParentSubstate(0.01);
        //        Smoother lexSmoother = new SmoothAcrossParentSubstate(0.1);
        maxGrammar.setSmoother(grSmoother);
        maxLexicon.setSmoother(lexSmoother);
        minIterations = maxIterations = opts.smoothMaxIterations;
        opString = "smoothing";
      } else if (splitIndex % 3 == 0) {
        // the case where we split
        if (opts.noSplit) continue;
        System.out.println(
            "Before splitting, we have a total of " + maxGrammar.totalSubStates() + " substates.");
        CorpusStatistics corpusStatistics = new CorpusStatistics(tagNumberer, trainStateSetTrees);
        int[] counts = corpusStatistics.getSymbolCounts();

        maxGrammar = maxGrammar.splitAllStates(randomness, counts, allowMoreSubstatesThanCounts, 0);
        maxLexicon = maxLexicon.splitAllStates(counts, allowMoreSubstatesThanCounts, 0);
        Smoother grSmoother = new NoSmoothing();
        Smoother lexSmoother = new NoSmoothing();
        maxGrammar.setSmoother(grSmoother);
        maxLexicon.setSmoother(lexSmoother);
        System.out.println(
            "After splitting, we have a total of " + maxGrammar.totalSubStates() + " substates.");
        System.out.println(
            "Rule probabilities are NOT normalized in the split, therefore the training LL is not guaranteed to improve between iteration 0 and 1!");
        opString = "splitting";
        maxIterations = opts.splitMaxIterations;
        minIterations = opts.splitMinIterations;
      } else {
        if (mergingPercentage == 0) continue;
        // the case where we merge
        double[][] mergeWeights =
            GrammarMerger.computeMergeWeights(maxGrammar, maxLexicon, trainStateSetTrees);
        double[][][] deltas =
            GrammarMerger.computeDeltas(maxGrammar, maxLexicon, mergeWeights, trainStateSetTrees);
        boolean[][][] mergeThesePairs =
            GrammarMerger.determineMergePairs(
                deltas, separateMergingThreshold, mergingPercentage, maxGrammar);

        grammar = GrammarMerger.doTheMerges(maxGrammar, maxLexicon, mergeThesePairs, mergeWeights);
        short[] newNumSubStatesArray = grammar.numSubStates;
        trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, newNumSubStatesArray, false);
        validationStateSetTrees =
            new StateSetTreeList(validationStateSetTrees, newNumSubStatesArray, false);

        // retrain lexicon to finish the lexicon merge (updates the unknown words model)...
        lexicon =
            (opts.simpleLexicon)
                ? new SimpleLexicon(
                    newNumSubStatesArray,
                    -1,
                    smoothParams,
                    maxLexicon.getSmoother(),
                    filter,
                    trainStateSetTrees)
                : new SophisticatedLexicon(
                    newNumSubStatesArray,
                    SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF,
                    maxLexicon.getSmoothingParams(),
                    maxLexicon.getSmoother(),
                    maxLexicon.getPruningThreshold());
        // Only the lexicon is re-estimated here, so no target grammar is passed (null). The
        // returned training likelihood is not used.
        boolean updateOnlyLexicon = true;
        double trainingLikelihood =
            GrammarTrainer.doOneEStep(
                grammar,
                maxLexicon,
                null,
                lexicon,
                trainStateSetTrees,
                updateOnlyLexicon,
                opts.rare);
        //    		System.out.println("The training LL is "+trainingLikelihood);
        lexicon.optimize(); // M Step

        GrammarMerger.printMergingStatistics(maxGrammar, grammar);
        opString = "merging";
        maxGrammar = grammar;
        maxLexicon = lexicon;
        maxIterations = opts.mergeMaxIterations;
        minIterations = opts.mergeMinIterations;
      }
      // update the substate dependent objects
      previousGrammar = grammar = maxGrammar;
      previousLexicon = lexicon = maxLexicon;
      droppingIter = 0;
      numSubStatesArray = grammar.numSubStates;
      trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, numSubStatesArray, false);
      validationStateSetTrees =
          new StateSetTreeList(validationStateSetTrees, numSubStatesArray, false);
      maxLikelihood = calculateLogLikelihood(maxGrammar, maxLexicon, validationStateSetTrees);
      System.out.println(
          "After "
              + opString
              + " in the "
              + (splitIndex / 3 + 1)
              + "th round, we get a validation likelihood of "
              + maxLikelihood);
      iter = 0;

      // the inner loop: train the grammar via EM until validation likelihood reliably drops
      do {
        iter += 1;
        System.out.println("Beginning iteration " + (iter - 1) + ":");

        // 1) Compute the validation likelihood of the previous iteration
        System.out.print("Calculating validation likelihood...");
        double validationLikelihood =
            calculateLogLikelihood(
                previousGrammar,
                previousLexicon,
                validationStateSetTrees); // The validation LL of previousGrammar/previousLexicon
        System.out.println("done: " + validationLikelihood);

        // 2) Perform the E step while computing the training likelihood of the previous iteration
        System.out.print("Calculating training likelihood...");
        grammar =
            new Grammar(
                grammar.numSubStates,
                grammar.findClosedPaths,
                grammar.smoother,
                grammar,
                grammar.threshold);
        lexicon =
            (opts.simpleLexicon)
                ? new SimpleLexicon(
                    grammar.numSubStates,
                    -1,
                    smoothParams,
                    lexicon.getSmoother(),
                    filter,
                    trainStateSetTrees)
                : new SophisticatedLexicon(
                    grammar.numSubStates,
                    SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF,
                    lexicon.getSmoothingParams(),
                    lexicon.getSmoother(),
                    lexicon.getPruningThreshold());
        boolean updateOnlyLexicon = false;
        double trainingLikelihood =
            doOneEStep(
                previousGrammar,
                previousLexicon,
                grammar,
                lexicon,
                trainStateSetTrees,
                updateOnlyLexicon,
                opts.rare); // The training LL of previousGrammar/previousLexicon
        System.out.println("done: " + trainingLikelihood);

        // 3) Perform the M-Step
        lexicon.optimize(); // M Step
        grammar.optimize(0); // M Step

        // 4) Check whether previousGrammar/previousLexicon was in fact better than the best
        if (iter < minIterations || validationLikelihood >= maxLikelihood) {
          maxLikelihood = validationLikelihood;
          maxGrammar = previousGrammar;
          maxLexicon = previousLexicon;
          droppingIter = 0;
        } else {
          droppingIter++;
        }

        // 5) advance the 'pointers'
        previousGrammar = grammar;
        previousLexicon = lexicon;
      } while ((droppingIter < allowedDroppingIters) && (!baseline) && (iter < maxIterations));

      // Dump a grammar file to disk from time to time
      ParserData pData =
          new ParserData(
              maxLexicon,
              maxGrammar,
              null,
              Numberer.getNumberers(),
              numSubStatesArray,
              VERTICAL_MARKOVIZATION,
              HORIZONTAL_MARKOVIZATION,
              binarization);
      String outTmpName = outFileName + "_" + (splitIndex / 3 + 1) + "_" + opString + ".gr";
      System.out.println("Saving grammar to " + outTmpName + ".");
      if (pData.Save(outTmpName)) System.out.println("Saving successful.");
      else System.out.println("Saving failed!");
      pData = null;
    }

    // The last grammar/lexicon has not yet been evaluated. Even though the validation likelihood
    // has been dropping in the past few iteration, there is still a chance that the last one was in
    // fact the best so just in case we evaluate it.
    System.out.print("Calculating last validation likelihood...");
    double validationLikelihood = calculateLogLikelihood(grammar, lexicon, validationStateSetTrees);
    System.out.println(
        "done.\n  Iteration "
            + iter
            + " (final) gives validation likelihood "
            + validationLikelihood);
    if (validationLikelihood > maxLikelihood) {
      // At this point previousGrammar/previousLexicon refer to the same objects as
      // grammar/lexicon, i.e. to the pair that was just evaluated.
      maxLikelihood = validationLikelihood;
      maxGrammar = previousGrammar;
      maxLexicon = previousLexicon;
    }

    ParserData pData =
        new ParserData(
            maxLexicon,
            maxGrammar,
            null,
            Numberer.getNumberers(),
            numSubStatesArray,
            VERTICAL_MARKOVIZATION,
            HORIZONTAL_MARKOVIZATION,
            binarization);

    System.out.println("Saving grammar to " + outFileName + ".");
    System.out.println("It gives a validation data log likelihood of: " + maxLikelihood);
    if (pData.Save(outFileName)) System.out.println("Saving successful.");
    else System.out.println("Saving failed!");

    System.exit(0);
  }