public static void main(String[] args) { // Create a training set List<Datum<String, String>> trainingData = new ArrayList<>(); trainingData.add(makeStopLights(GREEN, RED)); trainingData.add(makeStopLights(GREEN, RED)); trainingData.add(makeStopLights(GREEN, RED)); trainingData.add(makeStopLights(RED, GREEN)); trainingData.add(makeStopLights(RED, GREEN)); trainingData.add(makeStopLights(RED, GREEN)); trainingData.add(makeStopLights(RED, RED)); // Create a test set Datum<String, String> workingLights = makeStopLights(GREEN, RED); Datum<String, String> brokenLights = makeStopLights(RED, RED); // Build a classifier factory LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>(); factory.useConjugateGradientAscent(); // Turn on per-iteration convergence updates factory.setVerbose(true); // Small amount of smoothing factory.setSigma(10.0); // Build a classifier LinearClassifier<String, String> classifier = factory.trainClassifier(trainingData); // Check out the learned weights classifier.dump(); // Test the classifier System.out.println("Working instance got: " + classifier.classOf(workingLights)); classifier.justificationOf(workingLights); System.out.println("Broken instance got: " + classifier.classOf(brokenLights)); classifier.justificationOf(brokenLights); }
/**
 * Benchmarks logistic-regression training on 10,000 synthetic boolean-feature datums.
 *
 * <p>Profiling notes: 57% of time spent in LogConditionalObjectiveFunction.calculateCLBatch(),
 * 22% spent in constructing datums (expensive).
 *
 * <p>Single threaded, 4100 ms. Multi threaded, 600 ms.
 *
 * <p>With same data, seed 42, 52 ms. With reordered accesses for cacheing, 38 ms — down to 73% of
 * the time.
 *
 * <p>With 8 cpus, a 6.8x speedup — basically the same as with RVFDatum.
 */
public static void benchmarkLogisticRegression() {
  Dataset<String, String> data = new Dataset<>();
  for (int i = 0; i < 10000; i++) {
    // Re-seeding inside the loop makes every datum identical; the cache-reorder
    // timings in the javadoc above rely on this ("with same data, seed 42").
    Random r = new Random(42);
    Set<String> features = new HashSet<>();
    boolean cl = r.nextBoolean();
    for (int j = 0; j < 1000; j++) {
      if (cl && i % 2 == 0) {
        // Positive class: mostly-true features with 30% noise.
        if (r.nextDouble() > 0.3) {
          features.add("f:" + j + ":true");
        } else {
          features.add("f:" + j + ":false");
        }
      } else {
        // FIX: both arms previously added ":false", leaving the conditional dead
        // (an apparent copy-paste slip). Mirror the positive branch so the
        // negative class is mostly-false features with 30% noise.
        if (r.nextDouble() > 0.3) {
          features.add("f:" + j + ":false");
        } else {
          features.add("f:" + j + ":true");
        }
      }
    }
    data.add(new BasicDatum<String, String>(features, "target:" + cl));
  }
  LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();
  long msStart = System.currentTimeMillis();
  factory.trainClassifier(data);
  long delay = System.currentTimeMillis() - msStart;
  System.out.println("Training took " + delay + " ms");
}
/**
 * Benchmarks logistic-regression training on 10,000 synthetic real-valued-feature datums.
 *
 * <p>Profiling notes: 67% of time spent in LogConditionalObjectiveFunction.rvfcalculate(),
 * 29% of time spent in dataset construction (11% in RVFDataset.addFeatures(), 7% rvf
 * incrementCount(), 11% rest).
 *
 * <p>Single threaded, 4700 ms. Multi threaded, 700 ms.
 *
 * <p>With same data, seed 42, 245 ms. With reordered accesses for cacheing, 195 ms — down to 80%
 * of the time; not huge but a win nonetheless.
 *
 * <p>With 8 cpus, a 6.7x speedup — almost, but not quite linear; pretty good.
 */
public static void benchmarkRVFLogisticRegression() {
  RVFDataset<String, String> dataset = new RVFDataset<>();
  for (int row = 0; row < 10000; row++) {
    // Re-seeding per row makes every datum identical ("with same data, seed 42").
    Random rng = new Random(42);
    Counter<String> featureValues = new ClassicCounter<>();
    boolean targetClass = rng.nextBoolean();
    // Positive rows get a higher-centered feature distribution than negative rows.
    boolean positive = targetClass && row % 2 == 0;
    for (int f = 0; f < 1000; f++) {
      double shift = positive ? -0.6 : -1.4;
      featureValues.incrementCount("f" + f, (rng.nextDouble() * 2.0) + shift);
    }
    dataset.add(new RVFDatum<>(featureValues, "target:" + targetClass));
  }
  LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();
  long startMs = System.currentTimeMillis();
  factory.trainClassifier(dataset);
  long elapsedMs = System.currentTimeMillis() - startMs;
  System.out.println("Training took " + elapsedMs + " ms");
}
/**
 * Benchmarks logistic-regression training with an SGD minimizer on 10,000 synthetic
 * boolean-feature datums (same data generator as {@code benchmarkLogisticRegression}).
 */
public static void benchmarkSGD() {
  Dataset<String, String> data = new Dataset<>();
  for (int i = 0; i < 10000; i++) {
    // Re-seeding inside the loop makes every datum identical (deterministic benchmark).
    Random r = new Random(42);
    Set<String> features = new HashSet<>();
    boolean cl = r.nextBoolean();
    for (int j = 0; j < 1000; j++) {
      if (cl && i % 2 == 0) {
        // Positive class: mostly-true features with 30% noise.
        if (r.nextDouble() > 0.3) {
          features.add("f:" + j + ":true");
        } else {
          features.add("f:" + j + ":false");
        }
      } else {
        // FIX: both arms previously added ":false", leaving the conditional dead
        // (an apparent copy-paste slip). Mirror the positive branch so the
        // negative class is mostly-false features with 30% noise.
        if (r.nextDouble() > 0.3) {
          features.add("f:" + j + ":false");
        } else {
          features.add("f:" + j + ":true");
        }
      }
    }
    data.add(new BasicDatum<String, String>(features, "target:" + cl));
  }
  LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>();
  // Swap the default minimizer for SGD: learning rate 0.1, 100 passes,
  // no tuning samples, batch size 1000.
  factory.setMinimizerCreator(
      new Factory<Minimizer<DiffFunction>>() {
        @Override
        public Minimizer<DiffFunction> create() {
          return new SGDMinimizer<DiffFunction>(0.1, 100, 0, 1000);
        }
      });
  long msStart = System.currentTimeMillis();
  factory.trainClassifier(data);
  long delay = System.currentTimeMillis() - msStart;
  System.out.println("Training took " + delay + " ms");
}
public void train(Collection<Pair<Document, List<Entity>>> trainingData) { startTrack("Training"); // --Variables RVFDataset<Boolean, Feature> dataset = new RVFDataset<Boolean, Feature>(); LinearClassifierFactory<Boolean, Feature> fact = new LinearClassifierFactory<Boolean, Feature>(); // --Feature Extraction startTrack("Feature Extraction"); for (Pair<Document, List<Entity>> datum : trainingData) { // (document variables) Document doc = datum.getFirst(); List<Entity> goldClusters = datum.getSecond(); List<Mention> mentions = doc.getMentions(); Map<Mention, Entity> goldEntities = Entity.mentionToEntityMap(goldClusters); startTrack("Document " + doc.id); // (for each mention...) for (int i = 0; i < mentions.size(); i++) { // (get the mention and its cluster) Mention onPrix = mentions.get(i); Entity source = goldEntities.get(onPrix); if (source == null) { throw new IllegalArgumentException("Mention has no gold entity: " + onPrix); } // (for each previous mention...) int oldSize = dataset.size(); for (int j = i - 1; j >= 0; j--) { // (get previous mention and its cluster) Mention cand = mentions.get(j); Entity target = goldEntities.get(cand); if (target == null) { throw new IllegalArgumentException("Mention has no gold entity: " + cand); } // (extract features) Counter<Feature> feats = extractor.extractFeatures(Pair.make(onPrix, cand.markCoreferent(target))); // (add datum) dataset.add(new RVFDatum<Boolean, Feature>(feats, target == source)); // (stop if if (target == source) { break; } } // logf("Mention %s (%d datums)", onPrix.toString(), dataset.size() - oldSize); } endTrack("Document " + doc.id); } endTrack("Feature Extraction"); // --Train Classifier startTrack("Minimizer"); this.classifier = fact.trainClassifier(dataset); endTrack("Minimizer"); // --Dump Weights startTrack("Features"); // (get labels to print) Set<Boolean> labels = new HashSet<Boolean>(); labels.add(true); // (print features) for (Triple<Feature, Boolean, Double> featureInfo : 
this.classifier.getTopFeatures(labels, 0.0, true, 100, true)) { Feature feature = featureInfo.first(); Boolean label = featureInfo.second(); Double magnitude = featureInfo.third(); // log(FORCE,new DecimalFormat("0.000").format(magnitude) + " [" + label + "] " + feature); } end_Track("Features"); endTrack("Training"); }