コード例 #1
0
 /**
  * Trains a small {@code LinearClassifier} on hand-built stop-light datums, dumps its learned
  * weights, and reports the predicted class (plus {@code justificationOf} output) for one
  * (GREEN, RED) and one (RED, RED) test datum.
  *
  * @param args unused
  */
 public static void main(String[] args) {
   // Training data, in order: three (GREEN, RED), three (RED, GREEN), one (RED, RED).
   List<Datum<String, String>> trainingData = new ArrayList<>();
   for (int copy = 0; copy < 3; copy++) {
     trainingData.add(makeStopLights(GREEN, RED));
   }
   for (int copy = 0; copy < 3; copy++) {
     trainingData.add(makeStopLights(RED, GREEN));
   }
   trainingData.add(makeStopLights(RED, RED));

   // Held-out datums to classify after training.
   Datum<String, String> workingLights = makeStopLights(GREEN, RED);
   Datum<String, String> brokenLights = makeStopLights(RED, RED);

   // Conjugate-gradient optimizer, verbose per-iteration progress, and sigma = 10.0
   // (a large sigma means only light smoothing).
   LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();
   factory.useConjugateGradientAscent();
   factory.setVerbose(true);
   factory.setSigma(10.0);

   LinearClassifier<String, String> classifier = factory.trainClassifier(trainingData);
   classifier.dump();

   String workingGuess = classifier.classOf(workingLights);
   System.out.println("Working instance got: " + workingGuess);
   classifier.justificationOf(workingLights);

   String brokenGuess = classifier.classOf(brokenLights);
   System.out.println("Broken instance got: " + brokenGuess);
   classifier.justificationOf(brokenLights);
 }
コード例 #2
0
ファイル: Benchmarks.java プロジェクト: jayantam/CoreNLP
  /**
   * Times {@code LinearClassifierFactory.trainClassifier} on a synthetic {@code Dataset} of 10,000
   * {@code BasicDatum}s with 1,000 indicator features each, and prints the elapsed time in
   * milliseconds.
   *
   * <p>The data comes from a single {@link Random} seeded with 42. It is therefore reproducible
   * across runs but still varies from datum to datum.
   *
   * <p>Profiling notes:
   *
   * <ul>
   *   <li>57% of time spent in LogConditionalObjectiveFunction.calculateCLBatch(); 22% spent
   *       constructing datums (expensive).
   *   <li>Single threaded, 4100 ms; multi threaded, 600 ms.
   *   <li>With the same data (seed 42), 52 ms; with reordered accesses for caching, 38 ms, down to
   *       73% of the time.
   *   <li>With 8 CPUs, a 6.8x speedup -- basically the same as with RVFDatum.
   * </ul>
   *
   * <p>NOTE(review): the seed-42 timings may predate moving the {@code Random} out of the datum
   * loop. Re-measure before relying on them.
   */
  public static void benchmarkLogisticRegression() {
    Dataset<String, String> data = new Dataset<>();
    // One generator for the whole dataset. Re-seeding it per datum would make every datum draw
    // the same label and the same feature values.
    Random r = new Random(42);
    for (int i = 0; i < 10000; i++) {
      Set<String> features = new HashSet<>();

      boolean cl = r.nextBoolean();
      // Only datums with cl == true at even indices can get ":true" features.
      // All other datums get only ":false" features.
      boolean informative = cl && i % 2 == 0;

      for (int j = 0; j < 1000; j++) {
        // Draw for every feature, informative or not, so the random stream advances the same way
        // regardless of the datum's label.
        boolean on = r.nextDouble() > 0.3;
        features.add("f:" + j + ":" + (informative && on));
      }

      data.add(new BasicDatum<String, String>(features, "target:" + cl));
    }

    LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();

    // nanoTime is monotonic, unlike currentTimeMillis, so it is the right clock for intervals.
    long start = System.nanoTime();
    factory.trainClassifier(data);
    long delay = (System.nanoTime() - start) / 1_000_000L;
    System.out.println("Training took " + delay + " ms");
  }
コード例 #3
0
ファイル: Benchmarks.java プロジェクト: jayantam/CoreNLP
  /**
   * Times {@code LinearClassifierFactory.trainClassifier} on a synthetic {@code RVFDataset} of
   * 10,000 {@code RVFDatum}s with 1,000 real-valued features each, and prints the elapsed time in
   * milliseconds.
   *
   * <p>The data comes from a single {@link Random} seeded with 42. It is therefore reproducible
   * across runs but still varies from datum to datum.
   *
   * <p>Profiling notes:
   *
   * <ul>
   *   <li>67% of time spent in LogConditionalObjectiveFunction.rvfcalculate(); 29% spent in
   *       dataset construction (11% in RVFDataset.addFeatures(), 7% in rvf incrementCount(), 11%
   *       rest).
   *   <li>Single threaded, 4700 ms; multi threaded, 700 ms.
   *   <li>With the same data (seed 42), 245 ms; with reordered accesses for caching, 195 ms, down
   *       to 80% of the time -- not huge, but a win nonetheless.
   *   <li>With 8 CPUs, a 6.7x speedup -- almost, but not quite, linear; pretty good.
   * </ul>
   *
   * <p>NOTE(review): the seed-42 timings may predate moving the {@code Random} out of the datum
   * loop. Re-measure before relying on them.
   */
  public static void benchmarkRVFLogisticRegression() {
    RVFDataset<String, String> data = new RVFDataset<>();
    // One generator for the whole dataset. Re-seeding it per datum would make every datum draw
    // the same label and the same feature values.
    Random r = new Random(42);
    for (int i = 0; i < 10000; i++) {
      Counter<String> features = new ClassicCounter<>();

      boolean cl = r.nextBoolean();
      // Datums with cl == true at even indices get values in [-0.6, 1.4).
      // All other datums get values in [-1.4, 0.6).
      double offset = (cl && i % 2 == 0) ? 0.6 : 1.4;

      for (int j = 0; j < 1000; j++) {
        double value = (r.nextDouble() * 2.0) - offset;
        features.incrementCount("f" + j, value);
      }

      data.add(new RVFDatum<>(features, "target:" + cl));
    }

    LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();

    // nanoTime is monotonic, unlike currentTimeMillis, so it is the right clock for intervals.
    long start = System.nanoTime();
    factory.trainClassifier(data);
    long delay = (System.nanoTime() - start) / 1_000_000L;
    System.out.println("Training took " + delay + " ms");
  }
コード例 #4
0
ファイル: Benchmarks.java プロジェクト: jayantam/CoreNLP
  /**
   * Times {@code LinearClassifierFactory.trainClassifier} with an {@code SGDMinimizer} on a
   * synthetic {@code Dataset} of 10,000 {@code BasicDatum}s with 1,000 indicator features each,
   * and prints the elapsed time in milliseconds.
   *
   * <p>The data comes from a single {@link Random} seeded with 42. It is therefore reproducible
   * across runs but still varies from datum to datum.
   */
  public static void benchmarkSGD() {
    Dataset<String, String> data = new Dataset<>();
    // One generator for the whole dataset. Re-seeding it per datum would make every datum draw
    // the same label and the same feature values.
    Random r = new Random(42);
    for (int i = 0; i < 10000; i++) {
      Set<String> features = new HashSet<>();

      boolean cl = r.nextBoolean();
      // Only datums with cl == true at even indices can get ":true" features.
      // All other datums get only ":false" features.
      boolean informative = cl && i % 2 == 0;

      for (int j = 0; j < 1000; j++) {
        // Draw for every feature, informative or not, so the random stream advances the same way
        // regardless of the datum's label.
        boolean on = r.nextDouble() > 0.3;
        features.add("f:" + j + ":" + (informative && on));
      }

      data.add(new BasicDatum<String, String>(features, "target:" + cl));
    }

    LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();
    factory.setMinimizerCreator(
        new Factory<Minimizer<DiffFunction>>() {
          @Override
          public Minimizer<DiffFunction> create() {
            // NOTE(review): arguments are presumably (sigma, numPasses, tuneSampleSize, batchSize).
            // Confirm against SGDMinimizer's constructor.
            return new SGDMinimizer<DiffFunction>(0.1, 100, 0, 1000);
          }
        });

    // nanoTime is monotonic, unlike currentTimeMillis, so it is the right clock for intervals.
    long start = System.nanoTime();
    factory.trainClassifier(data);
    long delay = (System.nanoTime() - start) / 1_000_000L;
    System.out.println("Training took " + delay + " ms");
  }
コード例 #5
0
 /**
  * Trains {@code this.classifier} on mention/candidate-antecedent pairs extracted from
  * gold-annotated documents.
  *
  * <p>For each mention, the preceding mentions are visited nearest-first. Each visited pair yields
  * one {@code RVFDatum} of extracted features. It is labelled {@code true} iff the candidate maps
  * to the same gold {@code Entity} as the mention. The walk for a mention stops right after that
  * first same-entity candidate. The dataset is then passed to a {@code LinearClassifierFactory}.
  * Finally, the top features for the {@code true} label are retrieved inside a "Features" track.
  *
  * @param trainingData pairs of (document, gold entity clusters for that document)
  * @throws IllegalArgumentException if any mention in a document has no gold entity
  */
 public void train(Collection<Pair<Document, List<Entity>>> trainingData) {
   startTrack("Training");
   // --Variables
   RVFDataset<Boolean, Feature> dataset = new RVFDataset<Boolean, Feature>();
   LinearClassifierFactory<Boolean, Feature> fact =
       new LinearClassifierFactory<Boolean, Feature>();
   // --Feature Extraction
   startTrack("Feature Extraction");
   for (Pair<Document, List<Entity>> datum : trainingData) {
     // (document variables)
     Document doc = datum.getFirst();
     List<Entity> goldClusters = datum.getSecond();
     List<Mention> mentions = doc.getMentions();
     Map<Mention, Entity> goldEntities = Entity.mentionToEntityMap(goldClusters);
     startTrack("Document " + doc.id);
     // (for each mention...)
     for (int i = 0; i < mentions.size(); i++) {
       // (get the mention and its cluster)
       Mention onPrix = mentions.get(i);
       Entity source = goldEntities.get(onPrix);
       if (source == null) {
         throw new IllegalArgumentException("Mention has no gold entity: " + onPrix);
       }
       // (for each previous mention, nearest first...)
       for (int j = i - 1; j >= 0; j--) {
         // (get previous mention and its cluster)
         Mention cand = mentions.get(j);
         Entity target = goldEntities.get(cand);
         if (target == null) {
           throw new IllegalArgumentException("Mention has no gold entity: " + cand);
         }
         // (extract features)
         Counter<Feature> feats =
             extractor.extractFeatures(Pair.make(onPrix, cand.markCoreferent(target)));
         // (add datum)
         // NOTE(review): labels compare Entity references. This assumes mentionToEntityMap maps
         // co-clustered mentions to the same Entity instance. Confirm.
         dataset.add(new RVFDatum<Boolean, Feature>(feats, target == source));
         // (stop once the nearest same-entity antecedent has been added)
         if (target == source) {
           break;
         }
       }
     }
     endTrack("Document " + doc.id);
   }
   endTrack("Feature Extraction");
   // --Train Classifier
   startTrack("Minimizer");
   this.classifier = fact.trainClassifier(dataset);
   endTrack("Minimizer");
   // --Dump Weights
   startTrack("Features");
   // (get labels to print)
   Set<Boolean> labels = new HashSet<Boolean>();
   labels.add(true);
   // (print features)
   // NOTE(review): the log call is commented out, so the loop body only unpacks each triple.
   for (Triple<Feature, Boolean, Double> featureInfo :
       this.classifier.getTopFeatures(labels, 0.0, true, 100, true)) {
     Feature feature = featureInfo.first();
     Boolean label = featureInfo.second();
     Double magnitude = featureInfo.third();
     // log(FORCE,new DecimalFormat("0.000").format(magnitude) + " [" + label + "] " + feature);
   }
   endTrack("Features");
   endTrack("Training");
 }