/**
 * Builds a weight vector that puts all of its weight on the largest entry of
 * {@code probabilities}.
 *
 * @param probabilities the values to pick the maximum from
 * @return a sparse vector created with {@code probabilities.size()} whose only explicitly set
 *     entry is {@code 1.0} at the index reported by {@code probabilities.maxValueIndex()}
 */
@Override
 public Vector select(Vector probabilities) {
   final int winner = probabilities.maxValueIndex();
   final Vector oneHot = new SequentialAccessSparseVector(probabilities.size());
   // Only the winning index is written; every other entry stays at the sparse default.
   oneHot.set(winner, 1.0);
   return oneHot;
 }
示例#2
0
  /**
   * Scores every record of the input CSV file with the best learner of a stored adaptive logistic
   * regression model and writes the scores to the output file.
   *
   * <p>The output file starts with the header {@code <idColumn>,target,score}. For each input
   * record it then contains either a single line for the top-scoring label (when
   * {@code maxScoreOnly} is set) or one line per label. A progress message is printed to
   * {@code output} every 100 records, followed by a final total.
   *
   * <p>The input reader and the output writer are both closed before this method returns,
   * including when reading, scoring or writing throws.
   *
   * @param args command-line arguments, interpreted by {@code parseArgs}; nothing is done when
   *     {@code parseArgs} returns {@code false}
   * @param output sink for progress and diagnostic messages
   * @throws Exception any failure raised while loading the model, processing the CSV input or
   *     writing the results
   */
  static void mainToOutput(String[] args, PrintWriter output) throws Exception {
    if (!parseArgs(args)) {
      return;
    }
    AdaptiveLogisticModelParameters lmp =
        AdaptiveLogisticModelParameters.loadFromFile(new File(modelFile));

    CsvRecordFactory csv = lmp.getCsvRecordFactory();
    csv.setIdName(idColumn);

    AdaptiveLogisticRegression lr = lmp.createAdaptiveLogisticRegression();

    State<Wrapper, CrossFoldLearner> best = lr.getBest();
    if (best == null) {
      output.println("AdaptiveLogisticRegression has not be trained probably.");
      return;
    }
    CrossFoldLearner learner = best.getPayload().getLearner();

    int k = 0;
    BufferedReader in = TrainAdaptiveLogistic.open(inputFile);
    try {
      BufferedWriter out =
          new BufferedWriter(
              new OutputStreamWriter(new FileOutputStream(outputFile), Charsets.UTF_8));
      try {
        out.write(idColumn + ",target,score");
        out.newLine();

        // The first input line is the CSV header; data records start on the second line.
        String line = in.readLine();
        csv.firstLine(line);
        line = in.readLine();

        // Reused for every record: label -> score for the labels that will be reported.
        Map<String, Double> results = new HashMap<String, Double>();
        while (line != null) {
          Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
          csv.processLine(line, v, false);
          Vector scores = learner.classifyFull(v);
          results.clear();
          if (maxScoreOnly) {
            results.put(csv.getTargetLabel(scores.maxValueIndex()), scores.maxValue());
          } else {
            for (int i = 0; i < scores.size(); i++) {
              results.put(csv.getTargetLabel(i), scores.get(i));
            }
          }

          for (Map.Entry<String, Double> entry : results.entrySet()) {
            out.write(csv.getIdString(line) + ',' + entry.getKey() + ',' + entry.getValue());
            out.newLine();
          }
          k++;
          if (k % 100 == 0) {
            output.println(k + " records processed");
          }
          line = in.readLine();
        }
        out.flush();
      } finally {
        // Always release the output file, even if scoring or writing failed part-way through.
        out.close();
      }
    } finally {
      // The reader was previously never closed; release it on every exit path.
      in.close();
    }
    output.println(k + " records processed totally.");
  }
  /**
   * Classifies every vector with {@code learningAlgorithm}, tallies the predicted label in
   * {@code counts} and periodically prints a progress line.
   *
   * <p>When {@code train} is set, the list is shuffled in place first, and for each vector the
   * running averages {@code averageLL} and {@code averageCorrect} are updated before the learner
   * is trained on the vector. When {@code writer} is non-null, one record holding the vector name,
   * the predicted label and its score is appended per vector.
   *
   * @param vectors pairs whose key is used as the actual label and whose value is the vector to
   *     classify; reordered in place when {@code train} is {@code true}
   * @param train whether to train on the vectors in addition to classifying them
   * @param writer destination for per-vector prediction records, or {@code null} to write none
   * @throws IOException if appending a record to {@code writer} fails
   */
  public void processVectors(
      List<Pair<Integer, NamedVector>> vectors, boolean train, SequenceFile.Writer writer)
      throws IOException {

    final BytesWritable emptyKey = new BytesWritable("".getBytes());
    final Text record = new Text();

    if (train) {
      Collections.shuffle(vectors);
    }

    // Placeholder values; these are only given real values while training.
    int trueLabel = -1;
    double window = -1.0;
    double logLikelihood = -1.0;

    for (Pair<Integer, NamedVector> labelled : vectors) {
      NamedVector features = labelled.getValue();

      if (train) {
        trueLabel = labelled.getKey();
        // Averaging window grows with the number of processed vectors, capped at 200.
        window = Math.min(this.k + 1, 200);
        logLikelihood = learningAlgorithm.logLikelihood(trueLabel, features);
        this.averageLL += (logLikelihood - this.averageLL) / window;
      }

      Vector scores = new DenseVector(LABELS);
      learningAlgorithm.classifyFull(scores, features);
      int predicted = scores.maxValueIndex();
      this.counts[predicted]++;

      if (writer != null) {
        String formatted =
            String.format(
                "%s%c%d%c01%f",
                features.getName(),
                SEQUENCE_FILE_FIELD_SEPARATOR,
                predicted,
                SEQUENCE_FILE_FIELD_SEPARATOR,
                scores.get(predicted));
        record.set(formatted);
        writer.append(emptyKey, record);
      }

      if (train) {
        int hit;
        if (predicted == trueLabel) {
          hit = 1;
        } else {
          hit = 0;
        }
        this.averageCorrect += (hit - this.averageCorrect) / window;

        learningAlgorithm.train(trueLabel, features);
        learningAlgorithm.close();
      }

      this.k++;
      int bump = this.BUMPS[(int) Math.floor(this.step) % this.BUMPS.length];
      int scale = (int) Math.pow(10, Math.floor(this.step / this.BUMPS.length));
      if (this.k % Math.min(MAX_STEP, bump * scale) != 0) {
        // Not at a reporting point yet.
        continue;
      }

      // Reporting point reached: advance the step and print a progress line.
      this.step += 0.25;
      if (train) {
        System.out.printf(
            "%10d %10.3f %10.3f %10.2f %d\n",
            this.k, logLikelihood, this.averageLL, this.averageCorrect * 100, predicted);
      } else {
        System.out.printf("%c%10d, per label: %s\n", CR, this.k, Arrays.toString(this.counts));
      }
    }
  }