@Test
  public void testLuceneEncoding() throws Exception {
    // Text encoder whose tokenization is delegated to a Lucene whitespace analyzer.
    LuceneTextValueEncoder encoder = new LuceneTextValueEncoder("text");
    encoder.setAnalyzer(new WhitespaceAnalyzer(Version.LUCENE_43));

    // Splitting on whitespace makes this match the plain-text encoder case:
    // 6 distinct locations, each set to 1.
    Vector vector = new DenseVector(200);
    encoder.addToVector("test1 and more", vector);
    encoder.flush(1, vector);
    assertEquals(6.0, vector.norm(1), 0);
    assertEquals(1.0, vector.maxValue(), 0);

    // An empty document must leave every entry at zero.
    vector = new DenseVector(200);
    encoder.addToVector("", vector);
    encoder.flush(1, vector);
    assertEquals(0.0, vector.norm(1), 0);
    assertEquals(0.0, vector.maxValue(), 0);

    // Lucene's internal buffer length request is 4096, so feed a document that is
    // longer than that to make sure larger inputs are handled.
    StringBuilder longText = new StringBuilder(5000);
    for (int tokenIndex = 0; tokenIndex < 1000; tokenIndex++) {
      longText.append("token_").append(tokenIndex).append(' ');
    }
    vector = new DenseVector(200);
    encoder.addToVector(longText.toString(), vector);
    encoder.flush(1, vector);
    assertEquals(2000.0, vector.norm(1), 0);
    assertEquals(19.0, vector.maxValue(), 0);
  }
  @Test
  public void testAddToVector() {
    TextValueEncoder encoder = new TextValueEncoder("text");

    // Unweighted encoding: 6 distinct locations, each set to 1.
    Vector unweighted = new DenseVector(200);
    encoder.addToVector("test1 and more", unweighted);
    encoder.flush(1, unweighted);
    assertEquals(6.0, unweighted.norm(1), 0);
    assertEquals(1.0, unweighted.maxValue(), 0);

    // Switch the text encoder over to a word encoder carrying per-word weights.
    StaticWordValueEncoder wordEncoder = new StaticWordValueEncoder("text");
    wordEncoder.setDictionary(ImmutableMap.<String, Double>of("word1", 3.0, "word2", 1.5));
    encoder.setWordEncoder(wordEncoder);

    // Weighted encoding of the same text through the text encoder...
    Vector weighted = new DenseVector(200);
    encoder.addToVector("test1 and more", weighted);
    encoder.flush(1, weighted);

    // ...must equal encoding each word directly with the word encoder.
    Vector direct = new DenseVector(200);
    for (String word : new String[] {"test1", "and", "more"}) {
      wordEncoder.addToVector(word, direct);
    }
    assertEquals(0, direct.minus(weighted).norm(1), 0);

    // The weighted locations must coincide with the unweighted ones.
    assertEquals(direct.zSum(), direct.dot(unweighted), 0);
  }
Esempio n. 3
0
  /**
   * Scores every record of {@code inputFile} with the best learner of the model stored in
   * {@code modelFile} and writes the scores to {@code outputFile} as UTF-8 CSV.
   *
   * <p>The output starts with the header {@code <idColumn>,target,score}. When
   * {@code maxScoreOnly} is set, one row per record is written holding the highest-scoring
   * target; otherwise one row per target category is written, in category-index order.
   * NOTE(review): the static fields used here are presumably populated by {@code parseArgs}.
   *
   * @param args command-line arguments; nothing is done if {@code parseArgs} rejects them
   * @param output sink for progress messages; it is not closed by this method
   * @throws Exception if the model or input cannot be read or the output cannot be written
   */
  static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (!parseArgs(args)) {
      return;
    }
    AdaptiveLogisticModelParameters lmp =
        AdaptiveLogisticModelParameters.loadFromFile(new File(modelFile));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    csv.setIdName(idColumn);

    AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();

    State<Wrapper, CrossFoldLearner> best = lr.getBest();
    if (best == null) {
      output.println("AdaptiveLogisticRegression has not be trained probably.");
      return;
    }
    CrossFoldLearner learner = best.getPayload().getLearner();

    int k = 0;
    // try/finally guarantees both streams are released even if scoring fails part-way.
    BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
    try {
      BufferedWriter out =
          new BufferedWriter(
              new OutputStreamWriter(new FileOutputStream(outputFile), Charsets.UTF_8));
      try {
        out.write(idColumn + ",target,score");
        out.newLine();

        // The first input line is the CSV header. An empty input produces a header-only output
        // instead of handing null to the record factory.
        String line = in.readLine();
        if (line != null) {
          csv.firstLine(line);
          line = in.readLine();
        }
        while (line != null) {
          Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
          csv.processLine(line, v, false);
          Vector scores = learner.classifyFull(v);
          // The id is the same for every row of this record, so compute it once.
          String id = csv.getIdString(line);
          if (maxScoreOnly) {
            out.write(
                id + ',' + csv.getTargetLabel(scores.maxValueIndex()) + ',' + scores.maxValue());
            out.newLine();
          } else {
            for (int i = 0; i < scores.size(); i++) {
              out.write(id + ',' + csv.getTargetLabel(i) + ',' + scores.get(i));
              out.newLine();
            }
          }
          k++;
          if (k % 100 == 0) {
            output.println(k + " records processed");
          }
          line = in.readLine();
        }
        out.flush();
      } finally {
        out.close();
      }
    } finally {
      in.close();
    }
    output.println(k + " records processed totally.");
  }
  @Test
  public void testAddToVector() {
    FeatureVectorEncoder encoder = new ContinuousValueEncoder("foo");

    // Before setProbes is called, a negative value shows up in a single location, unscaled.
    Vector oneProbe = new DenseVector(20);
    encoder.addToVector("-123", oneProbe);
    Assert.assertEquals(-123, oneProbe.minValue(), 0);
    Assert.assertEquals(0, oneProbe.maxValue(), 0);
    Assert.assertEquals(123, oneProbe.norm(1), 0);

    // The same holds for a positive value.
    oneProbe = new DenseVector(20);
    encoder.addToVector("123", oneProbe);
    Assert.assertEquals(123, oneProbe.maxValue(), 0);
    Assert.assertEquals(0, oneProbe.minValue(), 0);
    Assert.assertEquals(123, oneProbe.norm(1), 0);

    // With two probes, the L1 norm doubles while the peak value stays the same.
    Vector twoProbes123 = new DenseVector(20);
    encoder.setProbes(2);
    encoder.addToVector("123", twoProbes123);
    Assert.assertEquals(123, twoProbes123.maxValue(), 0);
    Assert.assertEquals(2 * 123, twoProbes123.norm(1), 0);

    // Subtracting the single-probe encoding leaves exactly one probe's worth behind.
    Vector diff = twoProbes123.minus(oneProbe);
    Assert.assertEquals(123, diff.maxValue(), 0);
    Assert.assertEquals(123, diff.norm(1), 0);

    // Two-probe encodings of different values differ by the value gap at both probes.
    Vector twoProbes100 = new DenseVector(20);
    encoder.setProbes(2);
    encoder.addToVector("100", twoProbes100);
    Vector residual = twoProbes123.minus(twoProbes100);
    Assert.assertEquals(23, residual.maxValue(), 0);
    Assert.assertEquals(2 * 23, residual.norm(1), 0);

    // Adding 7 on top lifts both probe locations (indices 10 and 18) to 30.
    encoder.addToVector("7", residual);
    Assert.assertEquals(30, residual.maxValue(), 0);
    Assert.assertEquals(2 * 30, residual.norm(1), 0);
    Assert.assertEquals(30, residual.get(10), 0);
    Assert.assertEquals(30, residual.get(18), 0);

    // Non-numeric input must be rejected with a NumberFormatException naming the input.
    try {
      encoder.addToVector("foobar", residual);
      Assert.fail("Should have noticed bad numeric format");
    } catch (NumberFormatException expected) {
      Assert.assertEquals("For input string: \"foobar\"", expected.getMessage());
    }
  }