/**
 * Trains a logistic regression classifier on the samples read from a data file and prints a
 * report of how it classifies those same samples.
 *
 * <p>Steps, in order: read and print the samples, train the classifier, print the number of
 * training iterations and the classifier's features, count how many samples are classified
 * consistently with their label using a 0.5 threshold, print the counts and the accuracy, print
 * the training description, then call {@code setToZeroNegativeParametersBut} and print the
 * training description again.
 *
 * @param args command-line style arguments; {@code args[0]} is the path of the data file
 * @throws IOException declared by this method; presumably raised while reading the data file
 * @throws ClassifierException declared by this method; presumably raised by the classifier
 */
public static void f(String[] args) throws IOException, ClassifierException {
    File dataFile = new File(args[0]);
    Vector<LabeledSample> trainingSet = readDataFile(dataFile);
    printSamples(trainingSet);

    LogisticRegressionClassifier model =
            new LogisticRegressionClassifier(LogisticRegressionClassifier.DEFAULT_LEARNING_RATE, 1.0);
    model.train(trainingSet);
    System.out.println("Trained with " + model.getNumberOfIterations() + " iterations.");
    System.out.println(model.getFeatures());

    // Evaluate on the very samples used for training.
    int total = 0;
    int correct = 0;
    for (LabeledSample example : trainingSet) {
        double score = model.classify(example.getFeatures());
        boolean matchesLabel;
        if (score < 0.5) {
            // Below the threshold: correct only when the label is false.
            matchesLabel = !example.getLabel();
        } else if (score >= 0.5) {
            // At or above the threshold: correct only when the label is true.
            matchesLabel = example.getLabel();
        } else {
            // Neither comparison holds (a NaN score): never counted as correct.
            matchesLabel = false;
        }
        if (matchesLabel) {
            correct++;
        }
        total++;
    }

    System.out.println("correct: " + correct + ". All=" + total);
    double accuracy = ((double) correct) / ((double) total);
    System.out.println("accuracy: " + String.format("%4.4f", accuracy));

    System.out.println(StringUtil.generateStringOfCharacter('-', 50));
    System.out.println(model.descriptionOfTraining());
    System.out.println(StringUtil.generateStringOfCharacter('-', 50));

    // NOTE(review): presumably index 2 is exempted from whatever
    // setToZeroNegativeParametersBut zeroes -- confirm against the classifier.
    Set<Integer> keptIndices = new LinkedHashSet<Integer>();
    keptIndices.add(2);
    model.setToZeroNegativeParametersBut(keptIndices);
    System.out.println(model.descriptionOfTraining());
}
/**
 * Prints every sample on its own line of standard output.
 *
 * <p>Each line starts with the sample's label and a space, followed by one
 * {@code index:value, } entry per key of the sample's features, with the value formatted using
 * {@code %3.3f}.
 *
 * @param samples the samples to print
 */
public static void printSamples(Vector<LabeledSample> samples) {
    for (LabeledSample current : samples) {
        System.out.print(current.getLabel() + " ");

        for (Integer featureIndex : current.getFeatures().keySet()) {
            String key = featureIndex.toString();
            String value = String.format("%3.3f", current.getFeatures().get(featureIndex));
            System.out.print(key + ":" + value + ", ");
        }

        // Terminate the line for this sample.
        System.out.println();
    }
}