public static void demonstrateSerialization() throws IOException, ClassNotFoundException { System.out.println("Demonstrating working with a serialized classifier"); ColumnDataClassifier cdc = new ColumnDataClassifier("examples/cheese2007.prop"); Classifier<String, String> cl = cdc.makeClassifier(cdc.readTrainingExamples("examples/cheeseDisease.train")); // Exhibit serialization and deserialization working. Serialized to bytes in memory for // simplicity System.out.println(); System.out.println(); ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos); oos.writeObject(cl); oos.close(); byte[] object = baos.toByteArray(); ByteArrayInputStream bais = new ByteArrayInputStream(object); ObjectInputStream ois = new ObjectInputStream(bais); LinearClassifier<String, String> lc = ErasureUtils.uncheckedCast(ois.readObject()); ois.close(); ColumnDataClassifier cdc2 = new ColumnDataClassifier("examples/cheese2007.prop"); // We compare the output of the deserialized classifier lc versus the original one cl // For both we use a ColumnDataClassifier to convert text lines to examples for (String line : ObjectBank.getLineIterator("examples/cheeseDisease.test", "utf-8")) { Datum<String, String> d = cdc.makeDatumFromLine(line); Datum<String, String> d2 = cdc2.makeDatumFromLine(line); System.out.println(line + " =origi=> " + cl.classOf(d)); System.out.println(line + " =deser=> " + lc.classOf(d2)); } }
public static void main(String[] args) throws Exception { ColumnDataClassifier cdc = new ColumnDataClassifier("examples/cheese2007.prop"); Classifier<String, String> cl = cdc.makeClassifier(cdc.readTrainingExamples("examples/cheeseDisease.train")); for (String line : ObjectBank.getLineIterator("examples/cheeseDisease.test", "utf-8")) { // instead of the method in the line below, if you have the individual elements // already you can use cdc.makeDatumFromStrings(String[]) Datum<String, String> d = cdc.makeDatumFromLine(line); System.out.println(line + " ==> " + cl.classOf(d)); } demonstrateSerialization(); }
/**
 * Evaluates {@code classifier} on {@code data} and returns the resulting F-measure.
 *
 * <p>Every datum is classified first. After that, the scorer's {@code labelIndex} is rebuilt as
 * the union of the dataset's labels followed by the classifier's labels. The per-label
 * {@code tpCount}, {@code fpCount} and {@code fnCount} arrays are sized and indexed by that union.
 * The label {@code negLabel} is treated as the negative class: matches on it are not counted as
 * true positives, and mistakes involving it are not counted as a false positive (when guessed) or
 * a false negative (when gold).
 *
 * @param classifier the classifier under evaluation
 * @param data the gold-labelled dataset to score against
 * @param <F> the feature type
 * @return the value of {@code getFMeasure()} after the counts are filled in
 */
public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {
  // Classify everything before any scorer state is modified.
  List<L> guesses = new ArrayList<L>(data.size());
  for (int i = 0; i < data.size(); i++) {
    Datum<L, F> d = data.getRVFDatum(i);
    guesses.add(classifier.classOf(d));
  }

  // Union of gold and predicted labels; the count arrays are indexed by this index.
  labelIndex = new HashIndex<L>();
  labelIndex.addAll(data.labelIndex().objectsList());
  labelIndex.addAll(classifier.labels());
  int numClasses = labelIndex.size();
  tpCount = new int[numClasses];
  fpCount = new int[numClasses];
  fnCount = new int[numClasses];
  negIndex = labelIndex.indexOf(negLabel);

  // getLabelsArray() holds positions in the dataset's own label index, so gold labels are
  // resolved through data.labelIndex directly (the labelIndex field is not used as scratch).
  int[] labelsArr = data.getLabelsArray();
  for (int i = 0; i < guesses.size(); ++i) {
    int guessIndex = labelIndex.indexOf(guesses.get(i));
    int trueIndex = labelIndex.indexOf(data.labelIndex.get(labelsArr[i]));
    if (guessIndex == trueIndex) {
      if (guessIndex != negIndex) {
        tpCount[guessIndex]++;
      }
    } else {
      if (guessIndex != negIndex) {
        fpCount[guessIndex]++;
      }
      if (trueIndex != negIndex) {
        fnCount[trueIndex]++;
      }
    }
  }
  return getFMeasure();
}
/**
 * Evaluates {@code classifier} on {@code data} and returns the resulting F-measure.
 *
 * <p>The scorer's {@code labelIndex} is rebuilt as the union of the classifier's labels followed
 * by the dataset's labels. Counts are then reset via {@code clearCounts()}. Each datum's predicted
 * label and gold label are recorded with {@code addGuess(guess, gold)}. Finally
 * {@code finalizeCounts()} is called before the F-measure is read.
 *
 * @param classifier the classifier under evaluation
 * @param data the gold-labelled dataset to score against
 * @param <F> the feature type
 * @return the value of {@code getFMeasure()} after all guesses have been recorded
 */
public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {
  labelIndex = new HashIndex<L>();
  labelIndex.addAll(classifier.labels());
  labelIndex.addAll(data.labelIndex.objectsList());
  clearCounts();
  int[] labelsArr = data.getLabelsArray();
  for (int i = 0; i < data.size(); i++) {
    Datum<L, F> d = data.getRVFDatum(i);
    L guess = classifier.classOf(d);
    // labelsArr holds positions in data.labelIndex. The merged labelIndex above is seeded with
    // the classifier's labels first, so its positions generally differ and must not be used to
    // decode gold labels.
    L gold = data.labelIndex.get(labelsArr[i]);
    addGuess(guess, gold);
  }
  finalizeCounts();
  return getFMeasure();
}