/**
 * Checks that the percentages reported for the followers of "big" track how often each follower
 * has been seen, and that they come back in the expected order as training accumulates.
 */
public void testWordFrequency() {
    WordPredictor predictor = new WordPredictor();
    String[] initialSentences = {"a big brown bear", "a big brown bench", "a big yellow banana"};
    for (String sentence : initialSentences) {
        predictor.train(sentence.split(" "));
    }

    // "yellow" has followed "big" only once so far; only "brown" is expected.
    List<Prediction> afterThree = predictor.predictWord("big");
    assertEquals(1, afterThree.size());
    assertPrediction("brown", 100d, afterThree.get(0));

    predictor.train("a big yellow duck".split(" "));
    List<Prediction> afterFour = predictor.predictWord("big");
    assertEquals(2, afterFour.size());
    assertPrediction("brown", 50d, afterFour.get(0));
    assertPrediction("yellow", 50d, afterFour.get(1));

    predictor.train("a big yellow daisy".split(" "));
    List<Prediction> afterFive = predictor.predictWord("big");
    assertEquals(2, afterFive.size());
    assertPrediction("yellow", (3d / 5) * 100, afterFive.get(0));
    assertPrediction("brown", (2d / 5) * 100, afterFive.get(1));
}
public void testWordPredictor() { WordPredictor p = new WordPredictor(); { // Should be empty before training assertEquals(Collections.emptyList(), p.predictWord("a")); } p.train("a big brown bear".split(" ")); p.train("a big brown bench".split(" ")); { List<Prediction> pns = p.predictWord("a"); assertEquals(1, pns.size()); assertPrediction("big", 100d, pns.get(0)); } { List<Prediction> pns = p.predictWord("big"); assertEquals(1, pns.size()); assertPrediction("brown", 100d, pns.get(0)); } { // Word must occur more than once to be considered interesting. assertEquals(Collections.emptyList(), p.predictWord("brown")); assertEquals(Collections.emptyList(), p.predictWord("bear")); assertEquals(Collections.emptyList(), p.predictWord("foo")); } }
/**
 * Checks that reset() discards everything learned through train().
 */
public void testReset() {
    WordPredictor predictor = new WordPredictor();
    // Train the same sentence twice so that "big" follows "a" more than once.
    for (int round = 0; round < 2; round++) {
        predictor.train("a big brown bath".split(" "));
    }

    List<Prediction> beforeReset = predictor.predictWord("a");
    assertEquals(1, beforeReset.size());
    assertPrediction("big", 100d, beforeReset.get(0));

    predictor.reset();
    assertEquals(Collections.emptyList(), predictor.predictWord("a"));
}
public void testEmptyInput() { WordPredictor p = new WordPredictor(); { try { p.train(null); fail("should throw NPE on null input"); } catch (NullPointerException e) { // OK } } p.train("a big brown bear".split(" ")); p.train("a big brown bench".split(" ")); { // Empty string should give empty predictions even after training assertEquals(Collections.emptyList(), p.predictWord("")); } { try { p.predictWord(null); fail("Should throw NPE on null input"); } catch (NullPointerException e) { // OK } } }
// Test main method public static void main(String[] args) { // Train a model on the first bit of Moby-Dick WordPredictor wp = new WordPredictor(); System.out.println("bad1 = " + wp.getBest("the")); wp.train("moby_start.txt"); System.out.println("training words = " + wp.getTrainingCount()); // Try and crash things on bad input System.out.println("bad2 = " + wp.getBest("the")); wp.train("thisfiledoesnotexist.txt"); System.out.println("training words = " + wp.getTrainingCount() + "\n"); String[] words = {"the", "me", "zebra", "ishmael", "savage"}; for (String s : words) System.out.println("count, " + s + " = " + wp.getWordCount(s)); System.out.println(); wp.train("moby_end.txt"); // Check the counts again after training on the end of the book for (String s : words) System.out.println("count, " + s + " = " + wp.getWordCount(s)); System.out.println(); // Get the object ready to start looking things up wp.build(); // Do some prefix lookups String[] test = {"a", "ab", "b", "be", "t", "th", "archang"}; for (String prefix : test) System.out.println(prefix + " -> " + wp.getBest(prefix).getList().get(0).getWord()); System.out.println("training words = " + wp.getTrainingCount() + "\n"); // Add two individual words to the training data wp.trainWord("beefeater"); wp.trainWord("BEEFEATER!"); wp.trainWord("Pneumonoultramicroscopicsilicovolcanoconiosis"); // The change should have no effect for prefix lookup until we build() System.out.println("before, b -> " + wp.getBest("b").getList().get(0).getWord()); System.out.println("before, pn -> " + wp.getBest("pn")); wp.build(); System.out.println("after, b -> " + wp.getBest("b").getList().get(0).getWord()); System.out.println("after, pn -> " + wp.getBest("pn").getList().get(0).getWord()); System.out.println("training words = " + wp.getTrainingCount() + "\n"); // Test out training on a big file, timing the training as well Stats stats1 = new Stats(); wp.train("mobydick.txt"); wp.build(); for (String prefix : test) 
System.out.println(prefix + " -> " + wp.getBest(prefix).getList().get(0).getWord()); System.out.println("training words = " + wp.getTrainingCount()); System.out.println(stats1); // Test lookup using random prefixes between 1-6 characters System.out.println("\nRandom load test:"); Stats stats2 = new Stats(); final String VALID = "abcdefghijklmnopqrstuvwxyz'"; final long TEST_NUM = 10000000; long hits = 0; for (long i = 0; i < TEST_NUM; i++) { String prefix = ""; for (int j = 0; j <= (int) (Math.random() * 6); j++) prefix += VALID.charAt((int) (Math.random() * VALID.length())); // Word word = wp.getBest(prefix).getList().get(0); // if (word != null) // hits++; } System.out.println(stats2); System.out.println("Hit % = " + ((double) hits / TEST_NUM * 100.0)); }
public void testMinFrequency() { WordPredictor p = new WordPredictor(); for (int i = 0; i < 2; i++) { p.train("a big brown bear".split(" ")); p.train("a big brown bench".split(" ")); p.train("a big brown bazooka".split(" ")); p.train("a big brown bazinga".split(" ")); p.train("a big brown balloon".split(" ")); p.train("a big brown boulder".split(" ")); p.train("a big brown blanket".split(" ")); p.train("a big brown balcony".split(" ")); p.train("a big brown binder".split(" ")); p.train("a big brown book".split(" ")); } { List<Prediction> pns = p.predictWord("brown"); assertEquals(10, pns.size()); assertPrediction("balcony", 10d, pns.get(0)); assertPrediction("boulder", 10d, pns.get(9)); } // Add an eleventh distinct word after "brown", making none of the // predictions higher than the minimum threshold for being considered // interesting. p.train("a big brown bath".split(" ")); p.train("a big brown bath".split(" ")); { assertEquals(Collections.emptyList(), p.predictWord("brown")); } { List<Prediction> pns = p.predictWord("a"); assertEquals(1, pns.size()); assertPrediction("big", 100d, pns.get(0)); } }
public static void main(String[] args) { // GUI Variables int width = 1024; int height = 256; int pauseTime = 0; Font typeFont = new Font("Consolas", 1, 24); Font sentenceFont = new Font("Consolas", 0, 16); boolean multipleWords = false; // Strings for typing word, prediction, and sentence. StringBuilder typeString = new StringBuilder(); String prediction; StringBuilder sentence = new StringBuilder(); // Set up window StdDraw.setCanvasSize(width, height); StdDraw.setXscale(0.0, (double) width); StdDraw.setYscale(0.0, (double) height); // Train on files. WordPredictor predictor; if (args.length > 0) { try { int number = Integer.parseInt(args[args.length - 1]); if (number > 0 && number <= 9) { predictor = new WordPredictor(number); multipleWords = true; } else { predictor = new WordPredictor(); } } catch (NumberFormatException e) { predictor = new WordPredictor(); predictor.train(args[args.length - 1]); } for (int i = 0; i < args.length - 1; i++) { predictor.train(args[i]); } predictor.build(); } else { predictor = new WordPredictor(); } // Main loop for GUI and typing. while (true) { // Clear screen and reset variables. StdDraw.clear(); ArrayList<Word> best = null; if (predictor.getBest(typeString.toString()) != null) { best = predictor.getBest(typeString.toString()).getList(); } // If next key has been typed then act. if (StdDraw.hasNextKeyTyped()) { char key = StdDraw.nextKeyTyped(); // If key is a-z or apostrophe add it to string. if ((key >= 'a' && key <= 'z') || key == '\'') { typeString.append(key); } else if (key == ' ') { // If space move word to sentence, train into dict and reset if (typeString.length() > 0) { sentence.append(typeString + " "); predictor.trainWord(typeString.toString()); predictor.build(); typeString = null; typeString = new StringBuilder(); } } else if (key == '\b') { // Remove last letter if possible. 
if (typeString.length() > 0) { typeString.deleteCharAt(typeString.length() - 1); } } else if (key == '\n') { if (best != null) { // If enter was pressed and there is a prediction // Use the prediction and reset current word. sentence.append(best.get(0).getWord() + " "); typeString = null; typeString = new StringBuilder(); predictor.trainWord(best.get(0).getWord()); predictor.build(); } } else if (multipleWords) { // If there are multiple words available listen for numbers // These will allow for adding from suggestions. if (key >= '1' && key <= '9' && best != null) { int predNum = Integer.parseInt(key + ""); if (best.size() >= predNum) { sentence.append(best.get(predNum - 1).getWord() + " "); typeString = null; typeString = new StringBuilder(); predictor.trainWord(best.get(predNum - 1).getWord()); predictor.build(); } } } } // Start redrawing onto screen. // Typed word set up and print. StdDraw.setPenColor(Color.BLACK); StdDraw.setFont(typeFont); StdDraw.text(width / 2, height / 4 * 3, typeString.toString()); // Show prediction if it exists. if (best != null) { StdDraw.setPenColor(Color.BLUE); for (int i = 0; i < best.size(); i++) { StdDraw.text( width * ((i + .5) / best.size()), height / 2 + typeFont.getSize() / 2, best.get(i).getWord()); if (multipleWords) { StdDraw.text( width * ((i + .5) / best.size()), height / 2 - typeFont.getSize() / 2, "" + (i + 1)); } } StdDraw.setPenColor(Color.BLACK); } // Print out the sentence. StdDraw.setFont(sentenceFont); StdDraw.textRight(width - 20, height / 4, sentence.toString()); // Pause the loop based on time set at beginning. Set to 0. StdDraw.show(pauseTime); } }