public static void main(String[] args) throws Exception {
    // Parse command line flags and arguments
    Map<String, String> argMap = CommandLineUtils.simpleCommandLineParser(args);

    // Set up default parameters and settings
    String basePath = ".";
    boolean verbose = false;

    // Update defaults using command line specifications
    // The path to the assignment data
    if (argMap.containsKey("-path")) {
        basePath = argMap.get("-path");
    }
    System.out.println("Using base path: " + basePath);

    // Whether or not to print the individual errors.
    if (argMap.containsKey("-verbose")) {
        verbose = true;
    }

    // Read in data
    System.out.print("Loading training sentences...");
    List<TaggedSentence> trainTaggedSentences = readTaggedSentences(basePath + "/en-wsj-train.pos", true);
    Set<String> trainingVocabulary = extractVocabulary(trainTaggedSentences);
    System.out.println("done.");
    System.out.print("Loading in-domain dev sentences...");
    List<TaggedSentence> devInTaggedSentences = readTaggedSentences(basePath + "/en-wsj-dev.pos", true);
    System.out.println("done.");
    System.out.print("Loading out-of-domain dev sentences...");
    List<TaggedSentence> devOutTaggedSentences = readTaggedSentences(basePath + "/en-web-weblogs-dev.pos", true);
    System.out.println("done.");
    System.out.print("Loading out-of-domain blind test sentences...");
    List<TaggedSentence> testSentences = readTaggedSentences(basePath + "/en-web-test.blind", false);
    System.out.println("done.");

    // Construct tagger components
    // TODO: improve on the MostFrequentTagScorer
    LocalTrigramScorer localTrigramScorer = new MostFrequentTagScorer(false);
    // TODO: improve on the GreedyDecoder
    TrellisDecoder<State> trellisDecoder = new GreedyDecoder<State>();

    // Train tagger
    POSTagger posTagger = new POSTagger(localTrigramScorer, trellisDecoder);
    posTagger.train(trainTaggedSentences);
    // Optionally tune hyperparameters on dev data
    posTagger.validate(devInTaggedSentences);

    // Test tagger
    System.out.println("Evaluating on in-domain data:");
    evaluateTagger(posTagger, devInTaggedSentences, trainingVocabulary, verbose);
    System.out.println("Evaluating on out-of-domain data:");
    evaluateTagger(posTagger, devOutTaggedSentences, trainingVocabulary, verbose);
    labelTestSet(posTagger, testSentences, basePath + "/en-web-test.tagged");
}
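// The TODO above suggests replacing the GreedyDecoder with an exact decoder.
// Below is a minimal, self-contained Viterbi sketch over an abstract state
// graph, meant only to illustrate the dynamic program a TrellisDecoder<State>
// implementation would run. It does not use the assignment's Trellis API;
// the map-based transition structure, state names, and main() demo here are
// assumptions for illustration, not the actual interface.
import java.util.*;

class ViterbiSketch {

    /**
     * Returns the highest-scoring path from start to end, where transition
     * scores are log-probabilities and are summed along a path.
     */
    static List<String> viterbi(String start, String end,
                                Map<String, Map<String, Double>> transitions) {
        Map<String, Double> bestScore = new HashMap<>();
        Map<String, String> backPointer = new HashMap<>();
        bestScore.put(start, 0.0);
        // Relax outgoing edges; re-adding a state whenever its score improves
        // keeps the result correct on an acyclic graph even if states are not
        // visited in strict left-to-right (topological) order.
        Deque<String> agenda = new ArrayDeque<>();
        agenda.add(start);
        while (!agenda.isEmpty()) {
            String state = agenda.poll();
            double scoreSoFar = bestScore.get(state);
            for (Map.Entry<String, Double> edge :
                    transitions.getOrDefault(state, Map.of()).entrySet()) {
                double candidate = scoreSoFar + edge.getValue();
                if (candidate > bestScore.getOrDefault(edge.getKey(), Double.NEGATIVE_INFINITY)) {
                    bestScore.put(edge.getKey(), candidate);
                    backPointer.put(edge.getKey(), state);
                    agenda.add(edge.getKey());
                }
            }
        }
        // Follow back-pointers from the end state to recover the best path.
        LinkedList<String> path = new LinkedList<>();
        for (String s = end; s != null; s = backPointer.get(s)) path.addFirst(s);
        return path;
    }

    public static void main(String[] args) {
        // Toy trellis with two tag paths; scores are log-probabilities.
        Map<String, Map<String, Double>> t = new HashMap<>();
        t.put("START", Map.of("DT", -0.1, "NN", -2.0));
        t.put("DT", Map.of("NN", -0.2));
        t.put("NN", Map.of("END", -0.1));
        System.out.println(viterbi("START", "END", t)); // [START, DT, NN, END]
    }
}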
public static void main(String[] args) throws IOException {
    // Parse command line flags and arguments
    final Map<String, String> argMap = CommandLineUtils.simpleCommandLineParser(args);

    // Set up default parameters and settings
    String basePath = ".";
    String model = "baseline";
    boolean verbose = false;

    // Update defaults using command line specifications
    // The path to the assignment data
    if (argMap.containsKey("-path")) {
        basePath = argMap.get("-path");
    }
    System.out.println("Using base path: " + basePath);
    // A string descriptor of the model to use
    if (argMap.containsKey("-model")) {
        model = argMap.get("-model");
    }
    System.out.println("Using model: " + model);
    // Whether or not to print the individual speech errors.
    if (argMap.containsKey("-verbose")) {
        verbose = true;
    }
    if (argMap.containsKey("-quiet")) {
        verbose = false;
    }

    // Read in all the assignment data
    final String trainingSentencesFile = "/treebank-sentences-spoken-train.txt";
    final String speechNBestListsPath = "/wsj_n_bst";
    final Collection<List<String>> trainingSentenceCollection =
            SentenceCollection.Reader.readSentenceCollection(basePath + trainingSentencesFile);
    final Set<String> trainingVocabulary = extractVocabulary(trainingSentenceCollection);
    final List<SpeechNBestList> speechNBestLists =
            SpeechNBestList.Reader.readSpeechNBestLists(basePath + speechNBestListsPath, trainingVocabulary);

    // String validationSentencesFile = "/treebank-sentences-spoken-validate.txt";
    // Collection<List<String>> validationSentenceCollection =
    //         SentenceCollection.Reader.readSentenceCollection(basePath + validationSentencesFile);
    // String testSentencesFile = "/treebank-sentences-spoken-test.txt";
    // Collection<List<String>> testSentenceCollection =
    //         SentenceCollection.Reader.readSentenceCollection(basePath + testSentencesFile);

    // Build the language model
    LanguageModel languageModel = null;
    if (model.equalsIgnoreCase("baseline")) {
        languageModel = new EmpiricalUnigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("sri")) {
        languageModel = new SriLanguageModel(argMap.get("-sri"));
    } else if (model.equalsIgnoreCase("bigram")) {
        languageModel = new EmpiricalBigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("trigram")) {
        languageModel = new EmpiricalTrigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("katz-bigram")) {
        languageModel = new KatzBigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("katz-bigram-pp")) {
        languageModel = new KatzPPBigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("katz-trigram")) {
        throw new IllegalStateException(
                "Katz trigram model not fully implemented -- remove exception and uncomment next line if implemented");
        // languageModel = new KatzTrigramLanguageModel(trainingSentenceCollection);
    } else {
        throw new RuntimeException("Unknown model descriptor: " + model);
    }

    // Evaluate the language model
    // final double wsjPerplexity = calculatePerplexity(languageModel, testSentenceCollection);
    final double hubPerplexity =
            calculatePerplexity(languageModel, extractCorrectSentenceList(speechNBestLists));
    // System.out.println("WSJ Perplexity: " + wsjPerplexity);
    System.out.println("HUB Perplexity: " + hubPerplexity);
    System.out.println("WER Baselines:");
    System.out.println("  Best Path:  " + calculateWordErrorRateLowerBound(speechNBestLists));
    System.out.println("  Worst Path: " + calculateWordErrorRateUpperBound(speechNBestLists));
    System.out.println("  Avg Path:   " + calculateWordErrorRateRandomChoice(speechNBestLists));
    final double wordErrorRate = calculateWordErrorRate(languageModel, speechNBestLists, verbose);
    System.out.println("HUB Word Error Rate: " + wordErrorRate);
    System.out.println("Generated Sentences:");
    for (int i = 0; i < 10; i++)
        System.out.println("  " + languageModel.generateSentence());
}
public static void main(String[] args) throws IOException {
    // Parse command line flags and arguments
    Map<String, String> argMap = CommandLineUtils.simpleCommandLineParser(args);

    // Set up default parameters and settings
    String basePath = ".";
    int maxTrainingSentences = 0;
    int maxIterations = 20;
    boolean verbose = false;
    boolean initialize = false;
    String dataset = "mini";
    String model = "baseline";

    // Update defaults using command line specifications
    if (argMap.containsKey("-path")) {
        basePath = argMap.get("-path");
        System.out.println("Using base path: " + basePath);
    }
    if (argMap.containsKey("-sentences")) {
        maxTrainingSentences = Integer.parseInt(argMap.get("-sentences"));
        System.out.println("Using an additional " + maxTrainingSentences + " training sentences.");
    }
    if (argMap.containsKey("-data")) {
        dataset = argMap.get("-data");
        System.out.println("Running with data: " + dataset);
    } else {
        System.out.println("No data set specified. Use -data [miniTest, validate].");
    }
    if (argMap.containsKey("-model")) {
        model = argMap.get("-model");
        System.out.println("Running with model: " + model);
    } else {
        System.out.println("No model specified. Use -model modelname.");
    }
    if (argMap.containsKey("-verbose")) {
        verbose = true;
    }
    if (argMap.containsKey("-iterations")) {
        maxIterations = Integer.parseInt(argMap.get("-iterations"));
    }
    if (argMap.containsKey("-initialize")) {
        initialize = true;
    }

    // Read appropriate training and testing sets.
    List<SentencePair> trainingSentencePairs = new ArrayList<SentencePair>();
    if (!(dataset.equals("miniTest") || dataset.equals("mini")) && maxTrainingSentences > 0)
        trainingSentencePairs = readSentencePairs(basePath + "/training", maxTrainingSentences);
    List<SentencePair> testSentencePairs = new ArrayList<SentencePair>();
    Map<Integer, Alignment> testAlignments = new HashMap<Integer, Alignment>();
    if (dataset.equalsIgnoreCase("validate")) {
        testSentencePairs = readSentencePairs(basePath + "/trial", Integer.MAX_VALUE);
        testAlignments = readAlignments(basePath + "/trial/trial.wa");
    } else if (dataset.equals("miniTest") || dataset.equals("mini")) {
        testSentencePairs = readSentencePairs(basePath + "/mini", Integer.MAX_VALUE);
        testAlignments = readAlignments(basePath + "/mini/mini.wa");
    } else {
        throw new RuntimeException("Bad data set mode: " + dataset + ", use validate or miniTest.");
    }
    trainingSentencePairs.addAll(testSentencePairs);

    // Build model
    WordAligner wordAligner = null;
    if (model.equalsIgnoreCase("baseline")) {
        wordAligner = new BaselineWordAligner();
    }
    // TODO: build other alignment models
    else if (model.equalsIgnoreCase("heuristic")) {
        wordAligner = new HeuristicWordAligner(trainingSentencePairs);
    } else if (model.equalsIgnoreCase("dice")) {
        wordAligner = new DiceWordAligner(trainingSentencePairs);
    } else if (model.equalsIgnoreCase("ibm1") || model.equalsIgnoreCase("ibmModel1")) {
        wordAligner = new IBMmodel1WordAligner(trainingSentencePairs, maxIterations, initialize);
    } else if (model.equalsIgnoreCase("ibm2") || model.equalsIgnoreCase("ibmModel2")) {
        wordAligner = new IBMmodel2WordAligner(trainingSentencePairs, maxIterations, initialize);
    } else {
        // Fail fast rather than passing a null aligner to test() below.
        throw new RuntimeException("Unknown model descriptor: " + model);
    }

    // Test model
    test(wordAligner, testSentencePairs, testAlignments, verbose);

    // Generate file for submission; can comment out if not ready for submission
    testSentencePairs = readSentencePairs(basePath + "/test", Integer.MAX_VALUE);
    predict(wordAligner, testSentencePairs, basePath + "/" + model + ".out");
}
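// A small, self-contained sketch of the statistic a Dice-based aligner such as
// the "dice" branch above typically ranks word pairs with:
//   dice(e, f) = 2 * C(e, f) / (C(e) + C(f)),
// where C(e, f) counts sentence pairs in which e and f co-occur. The real
// aligner's data structures (SentencePair, Alignment) are not reproduced here;
// the nested-map layout and method names below are illustrative assumptions.
import java.util.*;

class DiceSketch {
    private final Map<String, Integer> englishCounts = new HashMap<>();
    private final Map<String, Integer> frenchCounts = new HashMap<>();
    private final Map<String, Map<String, Integer>> jointCounts = new HashMap<>();

    /** Accumulate counts from one (english, french) sentence pair. */
    void observe(List<String> english, List<String> french) {
        Set<String> e = new HashSet<>(english);
        Set<String> f = new HashSet<>(french);
        for (String ew : e) englishCounts.merge(ew, 1, Integer::sum);
        for (String fw : f) frenchCounts.merge(fw, 1, Integer::sum);
        for (String ew : e)
            for (String fw : f)
                jointCounts.computeIfAbsent(ew, k -> new HashMap<>()).merge(fw, 1, Integer::sum);
    }

    /** Dice coefficient of an (english, french) word pair; 0 if never seen. */
    double dice(String englishWord, String frenchWord) {
        int joint = jointCounts.getOrDefault(englishWord, Collections.emptyMap())
                               .getOrDefault(frenchWord, 0);
        int denom = englishCounts.getOrDefault(englishWord, 0)
                + frenchCounts.getOrDefault(frenchWord, 0);
        return denom == 0 ? 0.0 : 2.0 * joint / denom;
    }
}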
public static void main(String[] args) throws Exception {
    // Parse command line flags and arguments.
    Map<String, String> argMap = CommandLineUtils.simpleCommandLineParser(args);

    // Read command-line parameters.
    String embeddingPath = "";
    if (!argMap.containsKey("-embeddings")) {
        System.out.println("-embeddings flag required.");
        System.exit(0);
    } else {
        embeddingPath = argMap.get("-embeddings");
    }
    String wordSimPath = "";
    if (!argMap.containsKey("-wordsim")) {
        System.out.println("-wordsim flag required.");
        System.exit(0);
    } else {
        wordSimPath = argMap.get("-wordsim");
    }

    // Read in the labeled similarities and generate the target vocabulary.
    System.out.println("Loading wordsim353 ...");
    List<Pair<Pair<String, String>, Float>> wordSimPairs = readWordSimPairs(wordSimPath);
    Set<String> targetVocab = getWordSimVocab(wordSimPath);

    // It is likely that you will want to generate your embeddings elsewhere,
    // but this supports the option to generate the embeddings and evaluate
    // them in a single run.
    HashMap<String, float[]> embeddings;
    if (argMap.containsKey("-trainandeval")) {
        // Get some training data.
        String dataPath = "";
        if (!argMap.containsKey("-trainingdata")) {
            System.out.println("-trainingdata flag required with -trainandeval");
            System.exit(0);
        } else {
            dataPath = argMap.get("-trainingdata");
        }

        // Since this simple approach does not do dimensionality reduction on
        // the co-occurrence vectors, we instead control the size of the
        // vectors by only counting co-occurrence with core WordNet senses.
        String wordNetPath = "";
        if (!argMap.containsKey("-wordnetdata")) {
            System.out.println("-wordnetdata flag required with -trainandeval");
            System.exit(0);
        } else {
            wordNetPath = argMap.get("-wordnetdata");
        }
        // HashMap<String, Integer> contentWordVocab = getWordNetVocab(wordNetPath);

        System.out.println("Training embeddings on " + dataPath + " ...");
        // embeddings = getEmbeddings(dataPath, contentWordVocab, targetVocab);
        int kSamples = 5;
        int dimensions = 100;
        int contextSize = 2;
        WordSim skipgram = new WordSim(dataPath, kSamples, dimensions, contextSize);
        embeddings = skipgram.getEmbeddings(targetVocab);

        // Keep only the words that are needed.
        System.out.println("Writing embeddings to " + embeddingPath + " ...");
        // embeddings = reduceVocabulary(embeddings, targetVocab);
        // writeEmbeddings(embeddings, embeddingPath, contentVocab.size());
        writeEmbeddings(embeddings, embeddingPath, dimensions);
    } else {
        // Read in embeddings.
        System.out.println("Loading embeddings ...");
        embeddings = readEmbeddings(embeddingPath);

        // Keep only the words that are needed.
        System.out.println("Writing reduced vocabulary embeddings to " + embeddingPath + ".reduced ...");
        embeddings = reduceVocabulary(embeddings, targetVocab);
        writeEmbeddings(embeddings, embeddingPath + ".reduced",
                embeddings.values().iterator().next().length);
    }

    // Restrict the embeddings to the evaluation vocabulary before scoring
    // (the original discarded this call's return value; assigning it matches
    // the usage above).
    embeddings = reduceVocabulary(embeddings, targetVocab);

    double score = spearmansScore(wordSimPairs, embeddings);
    System.out.println("Score is " + score);
}
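// A rough sketch of how embeddings are typically scored against wordsim353-style
// data: cosine similarity between word vectors, then Spearman rank correlation
// between model similarities and human ratings. The actual spearmansScore and
// readWordSimPairs used above may handle ties and out-of-vocabulary words
// differently; this version assumes no tied ranks and is purely illustrative.
import java.util.*;

class SimilarityEvalSketch {

    /** Cosine similarity between two embedding vectors. */
    static double cosine(float[] a, float[] b) {
        double dot = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }
        return dot / (Math.sqrt(na) * Math.sqrt(nb));
    }

    /** Spearman correlation via 1 - 6 * sum(d^2) / (n * (n^2 - 1)); assumes no ties. */
    static double spearman(double[] x, double[] y) {
        double[] rx = ranks(x);
        double[] ry = ranks(y);
        double sumSq = 0;
        int n = x.length;
        for (int i = 0; i < n; i++) {
            double d = rx[i] - ry[i];
            sumSq += d * d;
        }
        return 1.0 - 6.0 * sumSq / (n * ((double) n * n - 1));
    }

    /** 1-based ranks of the values, smallest value gets rank 1. */
    private static double[] ranks(double[] values) {
        Integer[] order = new Integer[values.length];
        for (int i = 0; i < order.length; i++) order[i] = i;
        Arrays.sort(order, Comparator.comparingDouble(i -> values[i]));
        double[] ranks = new double[values.length];
        for (int rank = 0; rank < order.length; rank++) ranks[order[rank]] = rank + 1;
        return ranks;
    }
}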
public static void main(String[] args) throws IOException {
    // Parse command line flags and arguments
    Map<String, String> argMap = CommandLineUtils.simpleCommandLineParser(args);

    // Set up default parameters and settings
    String basePath = ".";
    String model = "baseline";
    boolean verbose = false;

    // Update defaults using command line specifications
    // The path to the assignment data
    if (argMap.containsKey("-path")) {
        basePath = argMap.get("-path");
    }
    System.out.println("Using base path: " + basePath);
    // A string descriptor of the model to use
    if (argMap.containsKey("-model")) {
        model = argMap.get("-model");
    }
    System.out.println("Using model: " + model);
    // Whether or not to print the individual speech errors.
    if (argMap.containsKey("-verbose")) {
        verbose = true;
    }
    if (argMap.containsKey("-quiet")) {
        verbose = false;
    }

    // Read in all the assignment data
    String trainingSentencesFile = "/treebank-sentences-spoken-train.txt";
    String validationSentencesFile = "/treebank-sentences-spoken-validate.txt";
    String testSentencesFile = "/treebank-sentences-spoken-test.txt";
    String speechNBestListsPath = "/wsj_n_bst";
    Collection<List<String>> trainingSentenceCollection =
            SentenceCollection.Reader.readSentenceCollection(basePath + trainingSentencesFile);
    Collection<List<String>> validationSentenceCollection =
            SentenceCollection.Reader.readSentenceCollection(basePath + validationSentencesFile);
    Collection<List<String>> testSentenceCollection =
            SentenceCollection.Reader.readSentenceCollection(basePath + testSentencesFile);
    Set<String> trainingVocabulary = extractVocabulary(trainingSentenceCollection);
    List<SpeechNBestList> speechNBestLists =
            SpeechNBestList.Reader.readSpeechNBestLists(basePath + speechNBestListsPath, trainingVocabulary);

    // Build the language model
    LanguageModel languageModel = null;
    if (model.equalsIgnoreCase("baseline")) {
        languageModel = new EmpiricalUnigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("bigram")) {
        languageModel = new EmpiricalBigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("trigram")) {
        languageModel = new EmpiricalTrigramLanguageModel(trainingSentenceCollection);
    } else if (model.equalsIgnoreCase("KN")) {
        languageModel = new KNBigramLanguageModel(trainingSentenceCollection);
    } else {
        throw new RuntimeException("Unknown model descriptor: " + model);
    }

    boolean tuning = true;
    if (tuning) {
        // Grid-search the trigram interpolation weights, logging each setting's
        // perplexity and word error rate.
        File f = new File("3-gram result2.txt");
        BufferedWriter output = new BufferedWriter(new FileWriter(f));
        double min = 1;
        double minR1 = 1;
        double minR2 = 1;
        double minR3 = 1;
        double minR4 = 1;
        double minR5 = 1;
        double minK = 1;
        double minPerp = 1000000;
        double k = 0.1;
        double step = 0.1;
        double r1 = 0;
        double r2 = 0;
        double r3 = 0;
        double r4 = 0;
        double r5 = 0;
        for (r1 = 0; r1 < 1; r1 += 0.2) {
            for (r2 = 0; (r1 + r2) < 1; r2 += 0.1) {
                r3 = 1 - (r1 + r2);
                for (r4 = 0; r4 < 1; r4 += 0.1) {
                    r5 = 1 - r4;
                    EmpiricalTrigramLanguageModel.r1 = r1;
                    EmpiricalTrigramLanguageModel.r2 = r2;
                    EmpiricalTrigramLanguageModel.r3 = r3;
                    EmpiricalTrigramLanguageModel.r4 = r4;
                    EmpiricalTrigramLanguageModel.r5 = r5;
                    double wordErrorRate = calculateWordErrorRate(languageModel, speechNBestLists, verbose);
                    double wsjPerplexity = calculatePerplexity(languageModel, validationSentenceCollection);
                    double hubPerplexity =
                            calculatePerplexity(languageModel, extractCorrectSentenceList(speechNBestLists));
                    if (minPerp > wsjPerplexity) {
                        minPerp = wsjPerplexity;
                        // minK = k;
                        minR1 = r1;
                        minR2 = r2;
                        minR3 = r3;
                        minR4 = r4;
                        minR5 = r5;
                    }
                    if (min > wordErrorRate)
                        min = wordErrorRate;
                    System.out.println("r1:" + r1 + "\t" + "r2:" + r2 + "\t" + "r3:" + r3 + "\t"
                            + "r4:" + r4 + "\t" + "r5:" + r5 + "\t");
                    System.out.println("HUB Word Error Rate: " + wordErrorRate);
                    System.out.println("Min Error Rate till now: " + min);
                    // System.out.println("minK=" + minK);
                    System.out.println("WSJ Perplexity: " + wsjPerplexity);
                    System.out.println("HUB Perplexity: " + hubPerplexity);
                    System.out.println();
                    BigDecimal big_r1 = new BigDecimal(r1);
                    BigDecimal big_r2 = new BigDecimal(r2);
                    BigDecimal big_r3 = new BigDecimal(r3);
                    BigDecimal big_r4 = new BigDecimal(r4);
                    BigDecimal big_r5 = new BigDecimal(r5);
                    BigDecimal big_wsjPerplexity = new BigDecimal(wsjPerplexity);
                    BigDecimal big_hubPerplexity = new BigDecimal(hubPerplexity);
                    BigDecimal big_wordErrorRate = new BigDecimal(wordErrorRate);
                    big_r1 = big_r1.setScale(1, BigDecimal.ROUND_HALF_UP);
                    big_r2 = big_r2.setScale(1, BigDecimal.ROUND_HALF_UP);
                    big_r3 = big_r3.setScale(1, BigDecimal.ROUND_HALF_UP);
                    big_r4 = big_r4.setScale(1, BigDecimal.ROUND_HALF_UP);
                    big_r5 = big_r5.setScale(1, BigDecimal.ROUND_HALF_UP);
                    big_wsjPerplexity = big_wsjPerplexity.setScale(2, BigDecimal.ROUND_HALF_UP);
                    big_hubPerplexity = big_hubPerplexity.setScale(2, BigDecimal.ROUND_HALF_UP);
                    big_wordErrorRate = big_wordErrorRate.setScale(4, BigDecimal.ROUND_HALF_UP);
                    output.write(big_r1 + "\t\t\t" + big_r2 + "\t\t\t" + big_r3 + "\t\t\t" + big_r4 + "\t\t\t"
                            + big_r5 + "\t\t\t" + big_wsjPerplexity + "\t\t\t" + big_hubPerplexity + "\t\t\t"
                            + big_wordErrorRate);
                    output.write("\n");
                }
            }
        }
        output.write("\n");
        output.write("min WER:" + min + "\n");
        output.write("min Perp:" + minPerp + "\n");
        output.write("minR1:" + "\t\t\t" + minR1 + "\t\t\t" + "minR2:" + "\t\t\t" + minR2 + "\t\t\t"
                + "minR3:" + "\t\t\t" + minR3 + "\t\t\t" + "minR4:" + "\t\t\t" + minR4 + "\t\t\t"
                + "minR5:" + "\t\t\t" + minR5 + "\n");
        output.close();
    } else {
        EmpiricalTrigramLanguageModel.k = 0.1;
        EmpiricalTrigramLanguageModel.r1 = 0.7;
        EmpiricalTrigramLanguageModel.r2 = 0.2;
        EmpiricalTrigramLanguageModel.r3 = 0.1;
        double wordErrorRate = calculateWordErrorRate(languageModel, speechNBestLists, verbose);
        double wsjPerplexity = calculatePerplexity(languageModel, testSentenceCollection);
        double hubPerplexity = calculatePerplexity(languageModel, extractCorrectSentenceList(speechNBestLists));
        System.out.println("HUB Word Error Rate: " + wordErrorRate);
        System.out.println("WSJ Perplexity: " + wsjPerplexity);
        System.out.println("HUB Perplexity: " + hubPerplexity);
    }
}
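// For reference, a sketch of the quantity the grid search above is minimizing.
// Perplexity is commonly computed as 2 raised to the negative mean
// log2-probability of the evaluation tokens; the assignment's
// calculatePerplexity may differ in details such as how sentence boundaries
// are counted, and the SentenceScorer interface here is a stand-in for
// whatever LanguageModel actually exposes.
import java.util.*;

class PerplexitySketch {

    interface SentenceScorer {
        /** Probability of the whole sentence under the model, in [0, 1]. */
        double getSentenceProbability(List<String> sentence);
    }

    static double perplexity(SentenceScorer model, Collection<List<String>> sentences) {
        double logProbSum = 0.0;  // sum of log2 sentence probabilities
        double tokenCount = 0.0;  // total predicted tokens (incl. one stop symbol per sentence)
        for (List<String> sentence : sentences) {
            logProbSum += Math.log(model.getSentenceProbability(sentence)) / Math.log(2.0);
            tokenCount += sentence.size() + 1;
        }
        return Math.pow(2.0, -logProbSum / tokenCount);
    }
}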