コード例 #1
0
  /**
   * Runs 10-fold cross-validation of the k-nearest-neighbour classifier on the data set loaded
   * from {@code ip} and prints accuracy, precision, recall and F-measure.
   *
   * @param ip reader supplying the full data set
   */
  public static void testCV(DataFileReader ip) {
    testCV(ip, 10);
  }

  /**
   * Runs {@code crossValidation}-fold cross-validation of the k-nearest-neighbour classifier on
   * the data set loaded from {@code ip} and prints accuracy, precision, recall and F-measure.
   *
   * <p>The confusion-matrix counts of every fold are summed into one matrix. The printed
   * "Average" metrics are therefore computed from the pooled counts of all folds.
   *
   * <p>If loading the data throws an {@link IOException}, the stack trace is printed and the
   * method returns without printing any metrics.
   *
   * @param ip reader supplying the full data set
   * @param crossValidation number of folds; must be at least 2
   * @throws IllegalArgumentException if {@code crossValidation} is less than 2
   */
  public static void testCV(DataFileReader ip, int crossValidation) {
    if (crossValidation < 2) {
      throw new IllegalArgumentException("crossValidation must be >= 2, got " + crossValidation);
    }
    try {
      Instances dataSet = ip.loadDataFromFile();
      MinMaxNormalizer minMaxNorm = new MinMaxNormalizer();
      ConfusionMatrix averageConfusionMatrix = new ConfusionMatrix();
      for (int i = 1; i <= crossValidation; i++) {
        Instances train = dataSet.getTrainingForCrossValidation(crossValidation, i);
        // NOTE(review): getTestForCrossValidation() takes no fold index, so it presumably returns
        // the fold held out by the preceding getTrainingForCrossValidation call. Confirm in
        // Instances.
        Instances test = dataSet.getTestForCrossValidation();
        // NOTE(review): one normalizer instance is applied to test first, then train. Confirm
        // which set MinMaxNormalizer takes its min/max from.
        test = test.normalize(minMaxNorm);
        train = train.normalize(minMaxNorm);
        KNearestNeighbor kNNClassifier = new KNearestNeighbor(train);
        kNNClassifier.kNNClassify(test);

        ConfusionMatrix cm = new ConfusionMatrix();
        for (int j = 0; j < test.getDataSetSize(); j++) {
          Instance t = test.getInstance(j);
          // Assumes getClassValue() is the true (ground-truth) label and 1 is the positive class.
          boolean actualPositive = t.getClassValue() == 1;
          if (t.isCorrectClassified()) {
            if (actualPositive) {
              cm.incrementTruePositive();
            } else {
              cm.incrementTrueNegative();
            }
          } else if (actualPositive) {
            // A misclassified positive was predicted negative: false negative.
            cm.incrementFalseNegative();
          } else {
            // A misclassified negative was predicted positive: false positive.
            cm.incrementFalsePositive();
          }
        }
        averageConfusionMatrix.addToTruePositive(cm.getTruePositive());
        averageConfusionMatrix.addToFalsePositive(cm.getFalsePositive());
        averageConfusionMatrix.addToTrueNegative(cm.getTrueNegative());
        averageConfusionMatrix.addToFalseNegative(cm.getFalseNegative());
      }
      System.out.format("\t Average Accuracy %f \n", averageConfusionMatrix.getAccuracy());
      System.out.format("\t Average Precision %f \n", averageConfusionMatrix.getPrecision());
      System.out.format("\t Average Recall %f \n", averageConfusionMatrix.getRecall());
      System.out.format("\t Average F-Measure %f \n", averageConfusionMatrix.getFmeasure());

    } catch (IOException e) {
      // Best-effort driver: report the load failure and skip the metrics instead of propagating.
      e.printStackTrace();
    }
  }
コード例 #2
0
  /**
   * Trains the k-nearest-neighbour classifier on the data loaded from {@code ip}, classifies the
   * test file {@code "dataset3"} (tab-delimited), and prints accuracy, precision, recall and
   * F-measure.
   *
   * @param ip reader supplying the training data
   */
  public static void testD3(DataFileReader ip) {
    testD3(ip, "dataset3", "\t");
  }

  /**
   * Trains the k-nearest-neighbour classifier on the data loaded from {@code ip}, classifies the
   * test data read by a {@code TestFileReader} for {@code stemFileName}/{@code delimiter}, and
   * prints accuracy, precision, recall and F-measure.
   *
   * <p>The test file is parsed with the training set's header. If loading either file throws an
   * {@link IOException}, the stack trace is printed and the method returns without printing any
   * metrics.
   *
   * @param ip reader supplying the training data
   * @param stemFileName file name stem passed to {@code TestFileReader}
   * @param delimiter field delimiter passed to {@code TestFileReader}
   */
  public static void testD3(DataFileReader ip, String stemFileName, String delimiter) {
    try {
      Instances train = ip.loadDataFromFile();
      MinMaxNormalizer minMaxNorm = new MinMaxNormalizer();
      Instances test =
          new TestFileReader(stemFileName, delimiter).loadDataFromFile(train.getHeader());
      // NOTE(review): one normalizer instance is applied to test first, then train. Confirm
      // which set MinMaxNormalizer takes its min/max from.
      test = test.normalize(minMaxNorm);
      train = train.normalize(minMaxNorm);
      KNearestNeighbor kNNClassifier = new KNearestNeighbor(train);
      kNNClassifier.kNNClassify(test);

      ConfusionMatrix cm = new ConfusionMatrix();
      for (int j = 0; j < test.getDataSetSize(); j++) {
        Instance t = test.getInstance(j);
        // Assumes getClassValue() is the true (ground-truth) label and 1 is the positive class.
        boolean actualPositive = t.getClassValue() == 1;
        if (t.isCorrectClassified()) {
          if (actualPositive) {
            cm.incrementTruePositive();
          } else {
            cm.incrementTrueNegative();
          }
        } else if (actualPositive) {
          // A misclassified positive was predicted negative: false negative.
          cm.incrementFalseNegative();
        } else {
          // A misclassified negative was predicted positive: false positive.
          cm.incrementFalsePositive();
        }
      }
      System.out.format("\t Accuracy %f \n", cm.getAccuracy());
      System.out.format("\t Precision %f \n", cm.getPrecision());
      System.out.format("\t Recall %f \n", cm.getRecall());
      System.out.format("\t F-Measure %f \n", cm.getFmeasure());
    } catch (IOException e) {
      // Best-effort driver: report the load failure and skip the metrics instead of propagating.
      e.printStackTrace();
    }
  }