Example #1
  public static void main(String[] args) throws Exception {
    // Parse command line flags and arguments
    Map<String, String> argMap = CommandLineUtils.simpleCommandLineParser(args);

    // Set up default parameters and settings
    String basePath = ".";
    boolean verbose = false;

    // Update defaults using command line specifications

    // The path to the assignment data
    if (argMap.containsKey("-path")) {
      basePath = argMap.get("-path");
    }
    System.out.println("Using base path: " + basePath);

    // Whether or not to print the individual errors.
    if (argMap.containsKey("-verbose")) {
      verbose = true;
    }

    // Read in data
    System.out.print("Loading training sentences...");
    List<TaggedSentence> trainTaggedSentences =
        readTaggedSentences(basePath + "/en-wsj-train.pos", true);
    Set<String> trainingVocabulary = extractVocabulary(trainTaggedSentences);
    System.out.println("done.");
    System.out.print("Loading in-domain dev sentences...");
    List<TaggedSentence> devInTaggedSentences =
        readTaggedSentences(basePath + "/en-wsj-dev.pos", true);
    System.out.println("done.");
    System.out.print("Loading out-of-domain dev sentences...");
    List<TaggedSentence> devOutTaggedSentences =
        readTaggedSentences(basePath + "/en-web-weblogs-dev.pos", true);
    System.out.println("done.");
    System.out.print("Loading out-of-domain blind test sentences...");
    List<TaggedSentence> testSentences =
        readTaggedSentences(basePath + "/en-web-test.blind", false);
    System.out.println("done.");

    // Construct tagger components
    // TODO : improve on the MostFrequentTagScorer
    LocalTrigramScorer localTrigramScorer = new MostFrequentTagScorer(false);
    // TODO : improve on the GreedyDecoder
    TrellisDecoder<State> trellisDecoder = new GreedyDecoder<State>();

    // Train tagger
    POSTagger posTagger = new POSTagger(localTrigramScorer, trellisDecoder);
    posTagger.train(trainTaggedSentences);

    // Optionally tune hyperparameters on dev data
    posTagger.validate(devInTaggedSentences);

    // Test tagger
    System.out.println("Evaluating on in-domain data:.");
    evaluateTagger(posTagger, devInTaggedSentences, trainingVocabulary, verbose);
    System.out.println("Evaluating on out-of-domain data:.");
    evaluateTagger(posTagger, devOutTaggedSentences, trainingVocabulary, verbose);
    labelTestSet(posTagger, testSentences, basePath + "/en-web-test.tagged");
  }
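
All of these mains parse flags with CommandLineUtils.simpleCommandLineParser, whose
implementation ships with the assignment code and is not shown here. A minimal
sketch of the convention the snippets rely on (each "-flag" maps to the token that
follows it, or to "true" when it is a bare switch such as -verbose) might look like
this; the class and method names are illustrative only:

  import java.util.HashMap;
  import java.util.Map;

  final class SimpleArgsSketch {
    // Hypothetical stand-in for CommandLineUtils.simpleCommandLineParser.
    static Map<String, String> parse(String[] args) {
      Map<String, String> argMap = new HashMap<>();
      for (int i = 0; i < args.length; i++) {
        if (!args[i].startsWith("-")) continue; // values are consumed below
        boolean hasValue = i + 1 < args.length && !args[i + 1].startsWith("-");
        argMap.put(args[i], hasValue ? args[++i] : "true"); // bare switch -> "true"
      }
      return argMap;
    }
  }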
Example #2
  public static void main(String[] args) throws IOException {
    // Parse command line flags and arguments
    final Map<String, String> argMap = CommandLineUtils.simpleCommandLineParser(args);

    // Set up default parameters and settings
    String basePath = ".";
    String model = "baseline";
    boolean verbose = false;

    // Update defaults using command line specifications

    // The path to the assignment data
    if (argMap.containsKey("-path")) {
      basePath = argMap.get("-path");
    }
    System.out.println("Using base path: " + basePath);

    // A string descriptor of the model to use
    if (argMap.containsKey("-model")) {
      model = argMap.get("-model");
    }
    System.out.println("Using model: " + model);

    // Whether or not to print the individual speech errors.
    if (argMap.containsKey("-verbose")) {
      verbose = true;
    }
    if (argMap.containsKey("-quiet")) {
      verbose = false;
    }

    // Read in all the assignment data
    final String trainingSentencesFile = "/treebank-sentences-spoken-train.txt";
    final String speechNBestListsPath = "/wsj_n_bst";
    final Collection<List<String>> trainingSentenceCollection =
        SentenceCollection.Reader.readSentenceCollection(basePath + trainingSentencesFile);
    final Set<String> trainingVocabulary = extractVocabulary(trainingSentenceCollection);
    final List<SpeechNBestList> speechNBestLists =
        SpeechNBestList.Reader.readSpeechNBestLists(
            basePath + speechNBestListsPath, trainingVocabulary);

    // String validationSentencesFile =
    // "/treebank-sentences-spoken-validate.txt";
    // Collection<List<String>> validationSentenceCollection =
    // SentenceCollection.Reader.readSentenceCollection(basePath +
    // validationSentencesFile);

    // String testSentencesFile = "/treebank-sentences-spoken-test.txt";
    // Collection<List<String>> testSentenceCollection =
    // SentenceCollection.Reader.readSentenceCollection(basePath +
    // testSentencesFile);

    // Build the language model
    LanguageModel languageModel = null;
    if (model.equalsIgnoreCase("baseline")) {
      languageModel = new EmpiricalUnigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("sri")) {
      languageModel = new SriLanguageModel(argMap.get("-sri"));
    } else if (model.equalsIgnoreCase("bigram")) {
      languageModel = new EmpiricalBigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("trigram")) {
      languageModel = new EmpiricalTrigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("katz-bigram")) {
      languageModel = new KatzBigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("katz-bigram-pp")) {
      languageModel = new KatzPPBigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("katz-trigram")) {
      throw new IllegalStateException(
          "Katz trigram model not fully implemented -- remove exception and uncomment next line if implemented");
      // languageModel = new KatzTrigramLanguageModel(
      // trainingSentenceCollection);
    } else {
      throw new RuntimeException("Unknown model descriptor: " + model);
    }

    // Evaluate the language model
    // final double wsjPerplexity = calculatePerplexity(languageModel,
    // testSentenceCollection);
    final double hubPerplexity =
        calculatePerplexity(languageModel, extractCorrectSentenceList(speechNBestLists));
    // System.out.println("WSJ Perplexity: " + wsjPerplexity);
    System.out.println("HUB Perplexity:  " + hubPerplexity);
    System.out.println("WER Baselines:");
    System.out.println("  Best Path:  " + calculateWordErrorRateLowerBound(speechNBestLists));
    System.out.println("  Worst Path: " + calculateWordErrorRateUpperBound(speechNBestLists));
    System.out.println("  Avg Path:   " + calculateWordErrorRateRandomChoice(speechNBestLists));
    final double wordErrorRate = calculateWordErrorRate(languageModel, speechNBestLists, verbose);
    System.out.println("HUB Word Error Rate: " + wordErrorRate);
    System.out.println("Generated Sentences:");
    for (int i = 0; i < 10; i++) System.out.println(" " + languageModel.generateSentence());
  }
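
Example #2 evaluates via calculatePerplexity, which is defined elsewhere in the
tester. Assuming LanguageModel exposes getSentenceProbability(List<String>) and
that the count includes a stop token, the standard computation it presumably
performs (perplexity as two to the minus average per-word log2 probability) can
be sketched as:

  // Sketch only; names follow the snippets above but are not guaranteed to
  // match the assignment's actual helper.
  static double calculatePerplexitySketch(
      LanguageModel languageModel, Collection<List<String>> sentences) {
    double logProbability = 0.0;
    double numSymbols = 0.0;
    for (List<String> sentence : sentences) {
      logProbability +=
          Math.log(languageModel.getSentenceProbability(sentence)) / Math.log(2.0);
      numSymbols += sentence.size() + 1; // + 1 for the stop symbol
    }
    double avgLogProbability = logProbability / numSymbols;
    return Math.pow(0.5, avgLogProbability); // 2^(-average log2 probability)
  }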
Example #3
  public static void main(String[] args) throws IOException {
    // Parse command line flags and arguments
    Map<String, String> argMap = CommandLineUtils.simpleCommandLineParser(args);

    // Set up default parameters and settings
    String basePath = ".";
    int maxTrainingSentences = 0;
    int maxIterations = 20;
    boolean verbose = false;
    boolean initialize = false;
    String dataset = "mini";
    String model = "baseline";

    // Update defaults using command line specifications
    if (argMap.containsKey("-path")) {
      basePath = argMap.get("-path");
      System.out.println("Using base path: "+basePath);
    }
    if (argMap.containsKey("-sentences")) {
      maxTrainingSentences = Integer.parseInt(argMap.get("-sentences"));
      System.out.println("Using an additional "+maxTrainingSentences+" training sentences.");
    }
    if (argMap.containsKey("-data")) {
      dataset = argMap.get("-data");
      System.out.println("Running with data: "+dataset);
    } else {
      System.out.println("No data set specified.  Use -data [miniTest, validate].");
    }
    if (argMap.containsKey("-model")) {
      model = argMap.get("-model");
      System.out.println("Running with model: "+model);
    } else {
      System.out.println("No model specified.  Use -model modelname.");
    }
    if (argMap.containsKey("-verbose")) {
      verbose = true;
    }
    if (argMap.containsKey("-iterations")) {
    	maxIterations = Integer.parseInt(argMap.get("-iterations"));
    }
    if (argMap.containsKey("-initialize")) {
    	initialize = true;
    }

    // Read appropriate training and testing sets.
    List<SentencePair> trainingSentencePairs = new ArrayList<SentencePair>();
    if (!(dataset.equals("miniTest") || dataset.equals("mini")) && maxTrainingSentences > 0) {
      trainingSentencePairs = readSentencePairs(basePath + "/training", maxTrainingSentences);
    }
    List<SentencePair> testSentencePairs = new ArrayList<SentencePair>();
    Map<Integer, Alignment> testAlignments = new HashMap<Integer, Alignment>();
    if (dataset.equalsIgnoreCase("validate")) {
      testSentencePairs = readSentencePairs(basePath + "/trial", Integer.MAX_VALUE);
      testAlignments = readAlignments(basePath + "/trial/trial.wa");
    } else if (dataset.equals("miniTest") || dataset.equals("mini")) {
      testSentencePairs = readSentencePairs(basePath + "/mini", Integer.MAX_VALUE);
      testAlignments = readAlignments(basePath + "/mini/mini.wa");
    } else {
      throw new RuntimeException("Bad data set mode: " + dataset + ", use validate or miniTest.");
    }
    // The unsupervised aligners can also learn from the (unannotated) test pairs.
    trainingSentencePairs.addAll(testSentencePairs);

    // Build model
    WordAligner wordAligner = null;
    if (model.equalsIgnoreCase("baseline")) {
      wordAligner = new BaselineWordAligner();
    } else if (model.equalsIgnoreCase("heuristic")) {
      wordAligner = new HeuristicWordAligner(trainingSentencePairs);
    } else if (model.equalsIgnoreCase("dice")) {
      wordAligner = new DiceWordAligner(trainingSentencePairs);
    } else if (model.equalsIgnoreCase("ibm1") || model.equalsIgnoreCase("ibmModel1")) {
      wordAligner = new IBMmodel1WordAligner(trainingSentencePairs, maxIterations, initialize);
    } else if (model.equalsIgnoreCase("ibm2") || model.equalsIgnoreCase("ibmModel2")) {
      wordAligner = new IBMmodel2WordAligner(trainingSentencePairs, maxIterations, initialize);
    } else {
      throw new RuntimeException("Unknown model descriptor: " + model);
    }

    // Test model
    test(wordAligner, testSentencePairs, testAlignments, verbose);
    
    // Generate file for submission (comment out if not ready for submission)
    testSentencePairs = readSentencePairs(basePath + "/test", Integer.MAX_VALUE);
    predict(wordAligner, testSentencePairs, basePath + "/" + model + ".out");
  }
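
The -model dice branch builds a DiceWordAligner from the sentence pairs. Although
the aligner's code is not shown, the Dice coefficient itself is standard: a word
pair (e, f) is scored by how often the two words co-occur in aligned sentence
pairs relative to how often each occurs overall. A hedged sketch:

  // Sketch, not the assignment's DiceWordAligner. jointCount is the number of
  // sentence pairs containing both e and f; countE and countF count the pairs
  // containing each word on its own side.
  static double diceCoefficient(double jointCount, double countE, double countF) {
    return (2.0 * jointCount) / (countE + countF);
  }

Such an aligner would then link each target word to the source word in the same
sentence pair that maximizes this score.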
Example #4
  public static void main(String[] args) throws Exception {
    // Parse command line flags and arguments.
    Map<String, String> argMap = CommandLineUtils.simpleCommandLineParser(args);

    // Read commandline parameters.
    String embeddingPath = "";
    if (!argMap.containsKey("-embeddings")) {
      System.out.println("-embeddings flag required.");
      System.exit(0);
    } else {
      embeddingPath = argMap.get("-embeddings");
    }

    String wordSimPath = "";
    if (!argMap.containsKey("-wordsim")) {
      System.out.println("-wordsim flag required.");
      System.exit(0);
    } else {
      wordSimPath = argMap.get("-wordsim");
    }

    // Read in the labeled similarities and generate the target vocabulary.
    System.out.println("Loading wordsim353 ...");
    List<Pair<Pair<String, String>, Float>> wordSimPairs = readWordSimPairs(wordSimPath);
    Set<String> targetVocab = getWordSimVocab(wordSimPath);

    // It is likely that you will want to generate your embeddings
    // elsewhere. But this supports the option to generate the embeddings
    // and evaluate them in a single run.
    HashMap<String, float[]> embeddings;
    if (argMap.containsKey("-trainandeval")) {
      // Get some training data.
      String dataPath = "";
      if (!argMap.containsKey("-trainingdata")) {
        System.out.println("-trainingdata flag required with -trainandeval");
        System.exit(0);
      } else {
        dataPath = argMap.get("-trainingdata");
      }

      // Since this simple approach does not do dimensionality reduction
      // on the co-occurrence vectors, we instead control the size of the
      // vectors by only counting co-occurrence with core WordNet senses.
      String wordNetPath = "";
      if (!argMap.containsKey("-wordnetdata")) {
        System.out.println("-wordnetdata flag required with -trainandeval");
        System.exit(0);
      } else {
        wordNetPath = argMap.get("-wordnetdata");
      }
      // HashMap<String, Integer> contentWordVocab = getWordNetVocab(wordNetPath);

      System.out.println("Training embeddings on " + dataPath + " ...");
      // embeddings = getEmbeddings(dataPath, contentWordVocab, targetVocab);
      int kSamples = 5;
      int dimensions = 100;
      int contextSize = 2;

      WordSim skipgram = new WordSim(dataPath, kSamples, dimensions, contextSize);
      embeddings = skipgram.getEmbeddings(targetVocab);

      // Keep only the words that are needed.
      System.out.println("Writing embeddings to " + embeddingPath + " ...");
      // embeddings = reduceVocabulary(embeddings, targetVocab);
      // writeEmbeddings(embeddings, embeddingPath, contentVocab.size());
      writeEmbeddings(embeddings, embeddingPath, dimensions);
    } else {
      // Read in embeddings.
      System.out.println("Loading embeddings ...");
      embeddings = readEmbeddings(embeddingPath);

      // Keep only the words that are needed.
      System.out.println(
          "Writing reduced vocabulary embeddings to " + embeddingPath + ".reduced ...");
      embeddings = reduceVocabulary(embeddings, targetVocab);
      writeEmbeddings(
          embeddings, embeddingPath + ".reduced", embeddings.values().iterator().next().length);
    }

    // Keep only the words that are needed before scoring.
    embeddings = reduceVocabulary(embeddings, targetVocab);

    double score = spearmansScore(wordSimPairs, embeddings);
    System.out.println("Score is " + score);
  }
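
spearmansScore compares the human wordsim353 ratings against similarities taken
from the embeddings; the model-side similarity is conventionally cosine. A minimal
sketch of that piece (a hypothetical helper, not the assignment's own code):

  // Cosine similarity between two embedding vectors, the usual score over
  // which the Spearman rank correlation is computed.
  static double cosine(float[] u, float[] v) {
    double dot = 0.0, normU = 0.0, normV = 0.0;
    for (int i = 0; i < u.length; i++) {
      dot += u[i] * v[i];
      normU += u[i] * u[i];
      normV += v[i] * v[i];
    }
    return dot / (Math.sqrt(normU) * Math.sqrt(normV));
  }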
Example #5
  public static void main(String[] args) throws IOException {
    // Parse command line flags and arguments
    Map<String, String> argMap = CommandLineUtils.simpleCommandLineParser(args);

    // Set up default parameters and settings
    String basePath = ".";
    String model = "baseline";
    boolean verbose = false;

    // Update defaults using command line specifications

    // The path to the assignment data
    if (argMap.containsKey("-path")) {
      basePath = argMap.get("-path");
    }
    System.out.println("Using base path: " + basePath);

    // A string descriptor of the model to use
    if (argMap.containsKey("-model")) {
      model = argMap.get("-model");
    }
    System.out.println("Using model: " + model);

    // Whether or not to print the individual speech errors.
    if (argMap.containsKey("-verbose")) {
      verbose = true;
    }
    if (argMap.containsKey("-quiet")) {
      verbose = false;
    }

    // Read in all the assignment data
    String trainingSentencesFile = "/treebank-sentences-spoken-train.txt";
    String validationSentencesFile = "/treebank-sentences-spoken-validate.txt";
    String testSentencesFile = "/treebank-sentences-spoken-test.txt";
    String speechNBestListsPath = "/wsj_n_bst";
    Collection<List<String>> trainingSentenceCollection =
        SentenceCollection.Reader.readSentenceCollection(basePath + trainingSentencesFile);
    Collection<List<String>> validationSentenceCollection =
        SentenceCollection.Reader.readSentenceCollection(basePath + validationSentencesFile);
    Collection<List<String>> testSentenceCollection =
        SentenceCollection.Reader.readSentenceCollection(basePath + testSentencesFile);
    Set<String> trainingVocabulary = extractVocabulary(trainingSentenceCollection);
    List<SpeechNBestList> speechNBestLists =
        SpeechNBestList.Reader.readSpeechNBestLists(
            basePath + speechNBestListsPath, trainingVocabulary);

    // Build the language model
    LanguageModel languageModel = null;
    if (model.equalsIgnoreCase("baseline")) {
      languageModel = new EmpiricalUnigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("bigram")) {
      languageModel = new EmpiricalBigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("trigram")) {
      languageModel = new EmpiricalTrigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("KN")) {
      languageModel = new KNBigramLanguageModel(trainingSentenceCollection);
    } else {
      throw new RuntimeException("Unknown model descriptor: " + model);
    }

    // Set to false to evaluate with the fixed, hand-tuned weights below.
    boolean tuning = true;
    if (tuning) {
      // Evaluate the language model
      File f = new File("3-gram result2.txt");
      BufferedWriter output = new BufferedWriter(new FileWriter(f));
      double min = 1;
      double minR1 = 1;
      double minR2 = 1;
      double minR3 = 1;
      double minR4 = 1;
      double minR5 = 1;
      double minK = 1;
      double minPerp = 1000000;
      double k = 0.1;
      double step = 0.1;
      double r1 = 0;
      double r2 = 0;
      double r3 = 0;
      double r4 = 0;
      double r5 = 0;
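      // Grid search over the interpolation weights: r1 + r2 + r3 = 1 for one
      // mixture and r4 + r5 = 1 for a second. The static fields set inside the
      // loop are presumably consulted by the model at query time, which is why
      // the model does not need to be rebuilt on each iteration.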
      for (r1 = 0; r1 < 1; r1 += 0.2) {
        for (r2 = 0; (r1 + r2) < 1; r2 += 0.1) {
          r3 = 1 - (r1 + r2);
          for (r4 = 0; r4 < 1; r4 += 0.1) {
            r5 = 1 - r4;
            EmpiricalTrigramLanguageModel.r1 = r1;
            EmpiricalTrigramLanguageModel.r2 = r2;
            EmpiricalTrigramLanguageModel.r3 = r3;
            EmpiricalTrigramLanguageModel.r4 = r4;
            EmpiricalTrigramLanguageModel.r5 = r5;

            double wordErrorRate = calculateWordErrorRate(languageModel, speechNBestLists, verbose);
            double wsjPerplexity = calculatePerplexity(languageModel, validationSentenceCollection);
            double hubPerplexity =
                calculatePerplexity(languageModel, extractCorrectSentenceList(speechNBestLists));
            if (minPerp > wsjPerplexity) {
              minPerp = wsjPerplexity;
              // minK = k;
              minR1 = r1;
              minR2 = r2;
              minR3 = r3;
              minR4 = r4;
              minR5 = r5;
            }
            if (min > wordErrorRate) min = wordErrorRate;
            System.out.println(
                "r1:" + r1 + "\t" + "r2:" + r2 + "\t" + "r3:" + r3 + "\t" + "r4:" + r4 + "\t"
                    + "r5:" + r5 + "\t");
            System.out.println("HUB Word Error Rate: " + wordErrorRate);
            System.out.println("Min Error Rate till now: " + min);
            // System.out.println("minK=" + minK);

            System.out.println("WSJ Perplexity:  " + wsjPerplexity);
            System.out.println("HUB Perplexity:  " + hubPerplexity);
            System.out.println();

            BigDecimal big_r1 = new BigDecimal(r1);
            BigDecimal big_r2 = new BigDecimal(r2);
            BigDecimal big_r3 = new BigDecimal(r3);
            BigDecimal big_r4 = new BigDecimal(r4);
            BigDecimal big_r5 = new BigDecimal(r5);
            BigDecimal big_wsjPerplexity = new BigDecimal(wsjPerplexity);
            BigDecimal big_hubPerplexity = new BigDecimal(hubPerplexity);
            BigDecimal big_wordErrorRate = new BigDecimal(wordErrorRate);

            big_r1 = big_r1.setScale(1, RoundingMode.HALF_UP);
            big_r2 = big_r2.setScale(1, RoundingMode.HALF_UP);
            big_r3 = big_r3.setScale(1, RoundingMode.HALF_UP);
            big_r4 = big_r4.setScale(1, RoundingMode.HALF_UP);
            big_r5 = big_r5.setScale(1, RoundingMode.HALF_UP);
            big_wsjPerplexity = big_wsjPerplexity.setScale(2, RoundingMode.HALF_UP);
            big_hubPerplexity = big_hubPerplexity.setScale(2, RoundingMode.HALF_UP);
            big_wordErrorRate = big_wordErrorRate.setScale(4, RoundingMode.HALF_UP);

            output.write(
                big_r1
                    + "\t\t\t"
                    + big_r2
                    + "\t\t\t"
                    + big_r3
                    + "\t\t\t"
                    + big_r4
                    + "\t\t\t"
                    + big_r5
                    + "\t\t\t"
                    + big_wsjPerplexity
                    + "\t\t\t"
                    + big_hubPerplexity
                    + "\t\t\t"
                    + big_wordErrorRate);
            output.write("\n");
          }
        }
      }
      output.write("\n");
      output.write("min WER:" + min + "\n");
      output.write("min Perp:" + minPerp + "\n");
      output.write(
          "minR1:" + "\t\t\t" + minR1 + "\t\t\t" + "minR2:" + "\t\t\t" + minR2 + "\t\t\t" + "minR3:"
              + "\t\t\t" + minR3 + "\t\t\t" + "minR4:" + "\t\t\t" + minR4 + "\t\t\t" + "minR5:"
              + "\t\t\t" + minR5 + "\n");
      output.close();
    } else {
      EmpiricalTrigramLanguageModel.k = 0.1;
      EmpiricalTrigramLanguageModel.r1 = 0.7;
      EmpiricalTrigramLanguageModel.r2 = 0.2;
      EmpiricalTrigramLanguageModel.r3 = 0.1;
      double wordErrorRate = calculateWordErrorRate(languageModel, speechNBestLists, verbose);
      double wsjPerplexity = calculatePerplexity(languageModel, testSentenceCollection);
      double hubPerplexity =
          calculatePerplexity(languageModel, extractCorrectSentenceList(speechNBestLists));
      System.out.println("HUB Word Error Rate: " + wordErrorRate);
      System.out.println("WSJ Perplexity:  " + wsjPerplexity);
      System.out.println("HUB Perplexity:  " + hubPerplexity);
    }
  }
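
The fixed weights in the non-tuning branch (r1 = 0.7, r2 = 0.2, r3 = 0.1) suggest
that EmpiricalTrigramLanguageModel linearly interpolates trigram, bigram, and
unigram estimates. A hedged sketch of that combination, with assumed names:

  // Sketch of linear interpolation with r1 + r2 + r3 = 1:
  // p(w | w1, w2) = r1 * p_trigram + r2 * p_bigram + r3 * p_unigram
  static double interpolatedProbability(
      double trigramEstimate, double bigramEstimate, double unigramEstimate) {
    return EmpiricalTrigramLanguageModel.r1 * trigramEstimate
        + EmpiricalTrigramLanguageModel.r2 * bigramEstimate
        + EmpiricalTrigramLanguageModel.r3 * unigramEstimate;
  }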