@Test public void testLuceneEncoding() throws Exception { LuceneTextValueEncoder enc = new LuceneTextValueEncoder("text"); enc.setAnalyzer(new WhitespaceAnalyzer(Version.LUCENE_43)); Vector v1 = new DenseVector(200); enc.addToVector("test1 and more", v1); enc.flush(1, v1); // should be the same as text test above, since we are splitting on whitespace // should set 6 distinct locations to 1 assertEquals(6.0, v1.norm(1), 0); assertEquals(1.0, v1.maxValue(), 0); v1 = new DenseVector(200); enc.addToVector("", v1); enc.flush(1, v1); assertEquals(0.0, v1.norm(1), 0); assertEquals(0.0, v1.maxValue(), 0); v1 = new DenseVector(200); StringBuilder builder = new StringBuilder(5000); for (int i = 0; i < 1000; i++) { // lucene's internal buffer length request is 4096, so let's make sure we can handle // larger size builder.append("token_").append(i).append(' '); } enc.addToVector(builder.toString(), v1); enc.flush(1, v1); // System.out.println(v1); assertEquals(2000.0, v1.norm(1), 0); assertEquals(19.0, v1.maxValue(), 0); }
@Test public void testAddToVector() { TextValueEncoder enc = new TextValueEncoder("text"); Vector v1 = new DenseVector(200); enc.addToVector("test1 and more", v1); enc.flush(1, v1); // should set 6 distinct locations to 1 assertEquals(6.0, v1.norm(1), 0); assertEquals(1.0, v1.maxValue(), 0); // now some fancy weighting StaticWordValueEncoder w = new StaticWordValueEncoder("text"); w.setDictionary(ImmutableMap.<String, Double>of("word1", 3.0, "word2", 1.5)); enc.setWordEncoder(w); // should set 6 locations to something Vector v2 = new DenseVector(200); enc.addToVector("test1 and more", v2); enc.flush(1, v2); // this should set the same 6 locations to the same values Vector v3 = new DenseVector(200); w.addToVector("test1", v3); w.addToVector("and", v3); w.addToVector("more", v3); assertEquals(0, v3.minus(v2).norm(1), 0); // moreover, the locations set in the unweighted case should be the same as in the weighted case assertEquals(v3.zSum(), v3.dot(v1), 0); }
/**
 * Scores every record of the input file with a previously trained adaptive logistic
 * regression model and writes the results as CSV lines of the form
 * {@code <id>,<target>,<score>} to the output file.
 *
 * <p>The model file, input file, output file, id column name and the "max score only"
 * switch are read from fields of this class (presumably populated by
 * {@code parseArgs}; confirm against its implementation). If {@code parseArgs} returns
 * {@code false} the method returns without doing anything.
 *
 * <p>Both the input reader and the output writer are closed before this method
 * returns, including when an exception is thrown while scoring.
 *
 * @param args   command-line arguments handed to {@code parseArgs}
 * @param output receives progress messages and the final record count
 * @throws Exception if the model cannot be loaded or reading/writing a file fails
 */
static void mainToOutput(String[] args, PrintWriter output) throws Exception {
  if (!parseArgs(args)) {
    return;
  }

  AdaptiveLogisticModelParameters lmp =
      AdaptiveLogisticModelParameters.loadFromFile(new File(modelFile));

  CsvRecordFactory csv = lmp.getCsvRecordFactory();
  csv.setIdName(idColumn);

  AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();

  State<Wrapper, CrossFoldLearner> best = lr.getBest();
  if (best == null) {
    output.println("AdaptiveLogisticRegression has not be trained probably.");
    return;
  }
  CrossFoldLearner learner = best.getPayload().getLearner();

  int k = 0;
  BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
  try {
    BufferedWriter out = new BufferedWriter(
        new OutputStreamWriter(new FileOutputStream(outputFile), Charsets.UTF_8));
    try {
      out.write(idColumn + ",target,score");
      out.newLine();

      // The first input line is handed to the record factory (not scored);
      // scoring starts with the second line.
      String line = in.readLine();
      csv.firstLine(line);
      line = in.readLine();

      Map<String, Double> results = new HashMap<String, Double>();
      while (line != null) {
        Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
        csv.processLine(line, v, false);
        Vector scores = learner.classifyFull(v);

        results.clear();
        if (maxScoreOnly) {
          // Report only the best-scoring target.
          results.put(csv.getTargetLabel(scores.maxValueIndex()), scores.maxValue());
        } else {
          // Report the score of every target.
          for (int i = 0; i < scores.size(); i++) {
            results.put(csv.getTargetLabel(i), scores.get(i));
          }
        }

        // The id depends only on the line, so compute it once per record.
        String id = csv.getIdString(line);
        for (Map.Entry<String, Double> entry : results.entrySet()) {
          out.write(id + ',' + entry.getKey() + ',' + entry.getValue());
          out.newLine();
        }

        k++;
        if (k % 100 == 0) {
          output.println(k + " records processed");
        }
        line = in.readLine();
      }
      out.flush();
    } finally {
      // Release the output file even if scoring failed part-way through.
      out.close();
    }
  } finally {
    // The reader was never closed before; always release it.
    in.close();
  }
  output.println(k + " records processed totally.");
}
/**
 * Verifies {@link ContinuousValueEncoder}: numeric strings are added to the vector with
 * their sign preserved, the probe count controls how many locations receive the value,
 * and a non-numeric string is rejected with a {@link NumberFormatException}.
 */
@Test
public void testAddToVector() {
  FeatureVectorEncoder encoder = new ContinuousValueEncoder("foo");

  // A negative value lands in a single location with its sign intact.
  Vector negative = new DenseVector(20);
  encoder.addToVector("-123", negative);
  Assert.assertEquals(-123, negative.minValue(), 0);
  Assert.assertEquals(0, negative.maxValue(), 0);
  Assert.assertEquals(123, negative.norm(1), 0);

  // The same magnitude, positive this time.
  Vector positive = new DenseVector(20);
  encoder.addToVector("123", positive);
  Assert.assertEquals(123, positive.maxValue(), 0);
  Assert.assertEquals(0, positive.minValue(), 0);
  Assert.assertEquals(123, positive.norm(1), 0);

  // With two probes the value shows up in two locations.
  Vector twoProbes = new DenseVector(20);
  encoder.setProbes(2);
  encoder.addToVector("123", twoProbes);
  Assert.assertEquals(123, twoProbes.maxValue(), 0);
  Assert.assertEquals(2 * 123, twoProbes.norm(1), 0);

  // Subtracting the one-probe encoding leaves exactly one location holding 123.
  Vector extraProbe = twoProbes.minus(positive);
  Assert.assertEquals(123, extraProbe.maxValue(), 0);
  Assert.assertEquals(123, extraProbe.norm(1), 0);

  // Two encodings made with two probes each differ only by the value difference.
  Vector hundred = new DenseVector(20);
  encoder.setProbes(2);
  encoder.addToVector("100", hundred);
  Vector remainder = twoProbes.minus(hundred);
  Assert.assertEquals(23, remainder.maxValue(), 0);
  Assert.assertEquals(2 * 23, remainder.norm(1), 0);

  // Adding into a non-empty vector accumulates at the same two locations.
  encoder.addToVector("7", remainder);
  Assert.assertEquals(30, remainder.maxValue(), 0);
  Assert.assertEquals(2 * 30, remainder.norm(1), 0);
  Assert.assertEquals(30, remainder.get(10), 0);
  Assert.assertEquals(30, remainder.get(18), 0);

  // Input that is not a number must be rejected.
  try {
    encoder.addToVector("foobar", remainder);
    Assert.fail("Should have noticed bad numeric format");
  } catch (NumberFormatException expected) {
    Assert.assertEquals("For input string: \"foobar\"", expected.getMessage());
  }
}