Esempio n. 1
0
  /**
   * Computes the per-token perplexity of a language model over a collection of sentences.
   *
   * <p>The base-2 logarithms of the sentence probabilities are summed and divided by the total
   * number of tokens (the sum of {@code sentence.size()} over all sentences); the result is
   * {@code 0.5} raised to that average, i.e. {@code 2^(-average)}.
   *
   * <p>Edge cases follow from IEEE arithmetic and are not special-cased: if no tokens are seen
   * (for example an empty collection) the average is {@code 0/0} and the result is {@code NaN};
   * a sentence probability of zero contributes negative infinity to the sum, which normally
   * yields positive infinity.
   *
   * @param languageModel model queried once per sentence via {@code getSentenceProbability}
   * @param sentenceCollection sentences to evaluate, each a list of tokens
   * @return the perplexity, possibly {@code NaN} or infinite as described above
   */
  static double calculatePerplexity(
      LanguageModel languageModel, Collection<List<String>> sentenceCollection) {
    // Loop-invariant divisor used to convert natural logs to base 2.
    final double log2 = Math.log(2.0);
    double logProbability = 0.0;
    double numSymbols = 0.0;
    for (List<String> sentence : sentenceCollection) {
      // Query the model exactly once; a non-positive probability is not retried because asking
      // again with the same sentence only repeats the same work.
      final double sentenceProbability = languageModel.getSentenceProbability(sentence);
      logProbability += Math.log(sentenceProbability) / log2;
      numSymbols += sentence.size();
    }
    final double avgLogProbability = logProbability / numSymbols;
    return Math.pow(0.5, avgLogProbability);
  }
 /**
  * Prints a single tab-separated line describing how one hypothesis is scored.
  *
  * <p>The line holds the prefix, the acoustic score divided by 16.0 ("AM"), the natural log of
  * the language-model sentence probability ("LM"), their sum ("Total"), and the hypothesis
  * itself. Numbers are rendered with the class-level formatter {@code nf}.
  *
  * @param prefix label printed at the start of the line
  * @param guess hypothesis sentence to score and print
  * @param speechNBestList n-best list supplying the acoustic score for {@code guess}
  * @param languageModel model supplying the sentence probability for {@code guess}
  */
 private static void displayHypothesis(
     String prefix,
     List<String> guess,
     SpeechNBestList speechNBestList,
     LanguageModel languageModel) {
   final double amScore = speechNBestList.getAcousticScore(guess) / 16.0;
   final double lmScore = Math.log(languageModel.getSentenceProbability(guess));

   // Start from an empty builder so that a null prefix prints as "null" rather than throwing.
   final StringBuilder line = new StringBuilder();
   line.append(prefix);
   line.append("\tAM: ").append(nf.format(amScore));
   line.append("\tLM: ").append(nf.format(lmScore));
   line.append("\tTotal: ").append(nf.format(amScore + lmScore));
   line.append("\t").append(guess);
   System.out.println(line.toString());
 }
Esempio n. 3
0
 /**
  * Computes the word error rate obtained when the language model picks among n-best hypotheses.
  *
  * <p>Each hypothesis is scored as the natural log of its language-model probability plus its
  * acoustic score divided by 16.0. For every n-best list the edit distance to the correct
  * sentence is averaged over all hypotheses that share the highest score, and those averages
  * are summed. The sum is divided by the total number of words in the correct sentences.
  *
  * <p>When {@code verbose} is set, every list whose top-scoring distance is positive prints the
  * first top-scoring hypothesis followed by the correct sentence.
  *
  * <p>A list with no hypotheses contributes {@code 0/0}, which makes the result {@code NaN}.
  *
  * @param languageModel model used to score each hypothesis
  * @param speechNBestLists n-best lists to evaluate
  * @param verbose whether to print the hypotheses that contain errors
  * @return accumulated edit distance divided by the number of reference words
  */
 static double calculateWordErrorRate(
     LanguageModel languageModel, List<SpeechNBestList> speechNBestLists, boolean verbose) {
   final EditDistance editDistance = new EditDistance();
   double distanceSum = 0.0;
   double wordCount = 0.0;

   for (SpeechNBestList nBestList : speechNBestLists) {
     final List<String> reference = nBestList.getCorrectSentence();
     List<String> topGuess = null;
     double topScore = Double.NEGATIVE_INFINITY;
     double tiedCount = 0.0;
     double tiedDistance = 0.0;

     for (List<String> candidate : nBestList.getNBestSentences()) {
       final double lmLogProbability =
           Math.log(languageModel.getSentenceProbability(candidate));
       final double scaledAcoustic = nBestList.getAcousticScore(candidate) / 16.0;
       final double score = lmLogProbability + scaledAcoustic;
       final double distance = editDistance.getDistance(reference, candidate);

       if (topGuess == null || score > topScore) {
         // New leader (or the very first candidate): restart the tie statistics.
         topScore = score;
         topGuess = candidate;
         tiedDistance = distance;
         tiedCount = 1.0;
       } else if (score == topScore) {
         // Exact tie with the current leader: fold it into the average.
         tiedCount += 1.0;
         tiedDistance += distance;
       }
     }

     distanceSum += tiedDistance / tiedCount;
     wordCount += reference.size();

     if (verbose && tiedDistance > 0.0) {
       System.out.println();
       displayHypothesis("GUESS:", topGuess, nBestList, languageModel);
       displayHypothesis("GOLD: ", reference, nBestList, languageModel);
     }
   }
   return distanceSum / wordCount;
 }
  /**
   * Command-line entry point: builds a language model and reports its evaluation figures.
   *
   * <p>Recognised flags: {@code -path} (base directory of the data, default {@code "."}),
   * {@code -model} (model descriptor, default {@code "baseline"}), {@code -sri} (argument handed
   * to the SRI model), {@code -verbose} (print individual speech errors) and {@code -quiet}
   * (overrides {@code -verbose}).
   *
   * <p>Prints the perplexity on the correct sentences of the n-best lists, the best/worst/average
   * path word error rates, the model's word error rate and ten generated sentences.
   *
   * @param args command-line arguments
   * @throws IOException declared for the data-reading calls
   */
  public static void main(String[] args) throws IOException {
    final Map<String, String> options = CommandLineUtils.simpleCommandLineParser(args);

    // Resolve each setting from the command line, falling back to its default.
    final String basePath = options.containsKey("-path") ? options.get("-path") : ".";
    System.out.println("Using base path: " + basePath);

    final String model = options.containsKey("-model") ? options.get("-model") : "baseline";
    System.out.println("Using model: " + model);

    // -quiet always wins over -verbose.
    final boolean verbose = options.containsKey("-verbose") && !options.containsKey("-quiet");

    // Load the training sentences and the speech n-best lists.
    final String trainingPath = basePath + "/treebank-sentences-spoken-train.txt";
    final String nBestPath = basePath + "/wsj_n_bst";
    final Collection<List<String>> trainingSentenceCollection =
        SentenceCollection.Reader.readSentenceCollection(trainingPath);
    final Set<String> trainingVocabulary = extractVocabulary(trainingSentenceCollection);
    final List<SpeechNBestList> speechNBestLists =
        SpeechNBestList.Reader.readSpeechNBestLists(nBestPath, trainingVocabulary);

    // Instantiate the model named by the (case-insensitive) descriptor.
    final LanguageModel languageModel;
    if (model.equalsIgnoreCase("baseline")) {
      languageModel = new EmpiricalUnigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("sri")) {
      languageModel = new SriLanguageModel(options.get("-sri"));
    } else if (model.equalsIgnoreCase("bigram")) {
      languageModel = new EmpiricalBigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("trigram")) {
      languageModel = new EmpiricalTrigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("katz-bigram")) {
      languageModel = new KatzBigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("katz-bigram-pp")) {
      languageModel = new KatzPPBigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("katz-trigram")) {
      // Once implemented, replace the throw with:
      // languageModel = new KatzTrigramLanguageModel(trainingSentenceCollection);
      throw new IllegalStateException(
          "Katz trigram model not fully implemented -- remove exception and uncomment next line if implemented");
    } else {
      throw new RuntimeException("Unknown model descriptor: " + model);
    }

    // Perplexity is measured on the correct sentences of the n-best lists.
    final double hubPerplexity =
        calculatePerplexity(languageModel, extractCorrectSentenceList(speechNBestLists));
    System.out.println("HUB Perplexity:  " + hubPerplexity);

    System.out.println("WER Baselines:");
    System.out.println("  Best Path:  " + calculateWordErrorRateLowerBound(speechNBestLists));
    System.out.println("  Worst Path: " + calculateWordErrorRateUpperBound(speechNBestLists));
    System.out.println("  Avg Path:   " + calculateWordErrorRateRandomChoice(speechNBestLists));

    // Evaluated only after the baselines are printed so that verbose output appears below them.
    final double wordErrorRate = calculateWordErrorRate(languageModel, speechNBestLists, verbose);
    System.out.println("HUB Word Error Rate: " + wordErrorRate);

    System.out.println("Generated Sentences:");
    for (int sample = 0; sample < 10; sample++) {
      System.out.println(" " + languageModel.generateSentence());
    }
  }