Esempio n. 1
0
  /**
   * Demonstrates that a classifier still works after a Java-serialization round trip.
   *
   * <p>A classifier is built from {@code examples/cheeseDisease.train}, written to an in-memory
   * byte array, and read back. Every line of {@code examples/cheeseDisease.test} is then
   * classified by both the original and the deserialized classifier, and both results are
   * printed to standard output.
   *
   * @throws IOException if writing or reading the serialized form fails
   * @throws ClassNotFoundException if a class of the serialized object cannot be resolved
   *     during deserialization
   */
  public static void demonstrateSerialization() throws IOException, ClassNotFoundException {
    System.out.println("Demonstrating working with a serialized classifier");
    ColumnDataClassifier cdc = new ColumnDataClassifier("examples/cheese2007.prop");
    Classifier<String, String> cl =
        cdc.makeClassifier(cdc.readTrainingExamples("examples/cheeseDisease.train"));

    // Exhibit serialization and deserialization working. Serialized to bytes in memory for
    // simplicity
    System.out.println();
    System.out.println();
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    // try-with-resources closes (and thereby flushes) the stream even if writeObject throws.
    try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
      oos.writeObject(cl);
    }
    // Only read the bytes after the object stream is closed, so buffered data is included.
    byte[] object = baos.toByteArray();

    LinearClassifier<String, String> lc;
    try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(object))) {
      lc = ErasureUtils.uncheckedCast(ois.readObject());
    }
    ColumnDataClassifier cdc2 = new ColumnDataClassifier("examples/cheese2007.prop");

    // We compare the output of the deserialized classifier lc versus the original one cl
    // For both we use a ColumnDataClassifier to convert text lines to examples
    for (String line : ObjectBank.getLineIterator("examples/cheeseDisease.test", "utf-8")) {
      Datum<String, String> d = cdc.makeDatumFromLine(line);
      Datum<String, String> d2 = cdc2.makeDatumFromLine(line);
      System.out.println(line + "  =origi=>  " + cl.classOf(d));
      System.out.println(line + "  =deser=>  " + lc.classOf(d2));
    }
  }
Esempio n. 2
0
  /**
   * Builds a classifier from {@code examples/cheeseDisease.train}, prints the predicted class
   * for every line of {@code examples/cheeseDisease.test}, and finally calls
   * {@code demonstrateSerialization()}.
   *
   * @param args command-line arguments; not read by this method
   * @throws Exception if any of the invoked operations fails
   */
  public static void main(String[] args) throws Exception {
    ColumnDataClassifier columnClassifier = new ColumnDataClassifier("examples/cheese2007.prop");
    Classifier<String, String> trained =
        columnClassifier.makeClassifier(
            columnClassifier.readTrainingExamples("examples/cheeseDisease.train"));

    for (String testLine : ObjectBank.getLineIterator("examples/cheeseDisease.test", "utf-8")) {
      // If the individual elements are already available, makeDatumFromStrings(String[])
      // can be used in place of makeDatumFromLine.
      Datum<String, String> example = columnClassifier.makeDatumFromLine(testLine);
      String predicted = trained.classOf(example);
      System.out.println(testLine + "  ==>  " + predicted);
    }

    demonstrateSerialization();
  }
  /**
   * Classifies every example in {@code data}, tallies per-label counts against the dataset's
   * own labels, and returns the value of {@code getFMeasure()}.
   *
   * <p>Side effects: the fields {@code labelIndex}, {@code tpCount}, {@code fpCount},
   * {@code fnCount} and {@code negIndex} are all overwritten. Agreements and disagreements
   * that involve the label at {@code negIndex} are not counted for that label.
   *
   * @param classifier the classifier whose predictions are evaluated
   * @param data the dataset providing the examples and their labels
   * @param <F> the feature type
   * @return the result of {@code getFMeasure()} after the counts have been filled in
   */
  public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {

    // Pass 1: record the classifier's prediction for each example, in dataset order.
    List<L> predicted = new ArrayList<L>();
    List<L> gold = new ArrayList<L>();
    for (int row = 0; row < data.size(); row++) {
      Datum<L, F> example = data.getRVFDatum(row);
      predicted.add(classifier.classOf(example));
    }

    // Pass 2: decode the dataset's integer label ids through the dataset's own index.
    // The field is pointed at that index only temporarily; it is replaced just below.
    int[] goldIds = data.getLabelsArray();
    labelIndex = data.labelIndex;
    for (int row = 0; row < data.size(); row++) {
      L goldLabel = labelIndex.get(goldIds[row]);
      gold.add(goldLabel);
    }

    // Fresh index holding the dataset's labels first, then the classifier's labels.
    labelIndex = new HashIndex<L>();
    labelIndex.addAll(data.labelIndex().objectsList());
    labelIndex.addAll(classifier.labels());

    int labelCount = labelIndex.size();
    tpCount = new int[labelCount];
    fpCount = new int[labelCount];
    fnCount = new int[labelCount];

    negIndex = labelIndex.indexOf(negLabel);

    for (int row = 0; row < predicted.size(); ++row) {
      int guessIndex = labelIndex.indexOf(predicted.get(row));
      int trueIndex = labelIndex.indexOf(gold.get(row));

      if (guessIndex == trueIndex) {
        // Agreement: counted for the label unless it is the negative label.
        if (guessIndex != negIndex) {
          tpCount[guessIndex]++;
        }
        continue;
      }

      // Disagreement: charge the guessed label and the true label separately,
      // again skipping the negative label on either side.
      if (guessIndex != negIndex) {
        fpCount[guessIndex]++;
      }
      if (trueIndex != negIndex) {
        fnCount[trueIndex]++;
      }
    }

    return getFMeasure();
  }
  /**
   * Classifies every example in {@code data}, reports each (guess, dataset label) pair via
   * {@code addGuess}, and returns the value of {@code getFMeasure()}.
   *
   * <p>Side effects: the field {@code labelIndex} is replaced by a new index containing the
   * classifier's labels followed by the dataset's labels, and {@code clearCounts()} is called
   * before and {@code finalizeCounts()} after the examples are processed.
   *
   * @param classifier the classifier whose predictions are evaluated
   * @param data the dataset providing the examples and their labels
   * @param <F> the feature type
   * @return the result of {@code getFMeasure()} once all guesses have been added
   */
  public <F> double score(Classifier<L, F> classifier, GeneralDataset<L, F> data) {
    // Rebuild the label index: classifier labels first, then the dataset's labels.
    labelIndex = new HashIndex<L>();
    labelIndex.addAll(classifier.labels());
    labelIndex.addAll(data.labelIndex.objectsList());

    clearCounts();

    int[] goldIds = data.getLabelsArray();
    for (int row = 0; row < data.size(); row++) {
      Datum<L, F> example = data.getRVFDatum(row);
      L predicted = classifier.classOf(example);
      // The dataset's integer label id is decoded through the rebuilt index.
      L goldLabel = labelIndex.get(goldIds[row]);
      addGuess(predicted, goldLabel);
    }

    finalizeCounts();

    return getFMeasure();
  }