/**
 * Evaluates a classifier on every instance of a data set and tallies a
 * one-vs-rest confusion matrix for each class label.
 *
 * <p>For every instance, each label from {@code data.classes()} is treated in
 * turn as the positive class:
 * <ul>
 *   <li>correct prediction: the true label gets a true positive and every
 *       other label a true negative;</li>
 *   <li>wrong prediction: the predicted label gets a false positive, the true
 *       label a false negative, and every remaining label a true negative.</li>
 * </ul>
 *
 * @param cls the classifier to test
 * @param data the data set to test on
 * @return a map from each label in {@code data.classes()} to its tp/fp/tn/fn counts
 */
public static Map<Object, PerformanceMeasure> testDataset(Classifier cls, Dataset data) {
  Map<Object, PerformanceMeasure> perLabel = new HashMap<Object, PerformanceMeasure>();
  for (Object label : data.classes()) {
    perLabel.put(label, new PerformanceMeasure());
  }
  for (Instance instance : data) {
    Object predicted = cls.classify(instance);
    Object actual = instance.classValue();
    boolean hit = actual.equals(predicted);
    for (Map.Entry<Object, PerformanceMeasure> entry : perLabel.entrySet()) {
      Object label = entry.getKey();
      PerformanceMeasure counts = entry.getValue();
      if (hit) {
        // Correct prediction: only the true label is a positive.
        if (label.equals(actual)) {
          counts.tp++;
        } else {
          counts.tn++;
        }
      } else if (predicted.equals(label)) {
        // Wrong prediction, and this label is the one that was (wrongly) predicted.
        counts.fp++;
      } else if (label.equals(actual)) {
        // Wrong prediction, and this label is the true one that was missed.
        counts.fn++;
      } else {
        // Neither predicted nor actual: correctly rejected for this label.
        counts.tn++;
      }
    }
  }
  return perLabel;
}
// Code example #2
// 0
 @Override
 public double measure(Instance x, Instance y) {
   if (x.noAttributes() != y.noAttributes())
     throw new RuntimeException("Both instances should contain the same number of values.");
   double totalMax = 0.0;
   for (int i = 0; i < x.noAttributes(); i++) {
     totalMax = Math.max(totalMax, Math.abs(y.value(i) - x.value(i)));
   }
   return totalMax;
 }
 /**
  * Shows how to construct a SparseInstance.
  *
  * <p>Creates an instance with room for 10 attributes and explicitly assigns
  * values only to the attributes at index 1, 3 and 7.
  *
  * @param args unused
  */
 public static void main(String[] args) {
   Instance sparse = new SparseInstance(10);
   // Parallel arrays: positions[k] receives values[k].
   int[] positions = {1, 3, 7};
   double[] values = {1.0, 2.0, 4.0};
   for (int k = 0; k < positions.length; k++) {
     sparse.put(positions[k], values[k]);
   }
 }
  /**
   * Copies the attribute values of an instance, in iteration order, into a
   * primitive double array.
   *
   * <p>The array is sized by {@code instnc.noAttributes()}. If the iterator
   * yields more elements than that, an ArrayIndexOutOfBoundsException is
   * thrown. If it yields fewer, the remaining slots stay 0.0.
   *
   * @param instnc instance to convert; its iterator is expected to yield Double
   *     values (each element is cast to Double)
   * @return array of attribute values
   */
  private double[] convertInstanceToDoubleArray(Instance instnc) {
    // Wildcard instead of the raw Iterator type: same runtime behaviour, no unchecked use.
    Iterator<?> attributeIterator = instnc.iterator();

    double[] item = new double[instnc.noAttributes()];
    int index = 0;

    while (attributeIterator.hasNext()) {
      Double attrValue = (Double) attributeIterator.next();
      item[index] = attrValue.doubleValue();
      index++;
    }

    return item;
  }
 /**
  * Euclidean distance between two instances. Attributes that are NaN
  * (missing) in either instance are skipped.
  *
  * @param x first instance
  * @param y second instance
  * @return square root of the sum of squared differences over the attributes
  *     present in both instances
  * @throws RuntimeException if the instances have a different number of attributes
  */
 public double calculateDistance(Instance x, Instance y) {
   // Compare x against y; comparing x with itself would make this guard dead code.
   if (x.noAttributes() != y.noAttributes()) {
     throw new RuntimeException("Both instances should contain the same number of values.");
   }
   double sum = 0;
   for (int i = 0; i < x.noAttributes(); i++) {
     double yi = y.value(i);
     double xi = x.value(i);
     // ignore missing values
     if (!Double.isNaN(yi) && !Double.isNaN(xi)) {
       double diff = yi - xi;
       sum += diff * diff;
     }
   }
   return Math.sqrt(sum);
 }
// Code example #6
// 0
// File: JavaMlUtil.java  Project: valiki/odzrecog
 /**
  * Converts the values returned by {@code instance.values()} into an int
  * array, narrowing each double with a Java {@code (int)} cast (truncation
  * toward zero).
  *
  * <p>NOTE(review): if Instance is map-backed (e.g. a sparse instance),
  * {@code values()} may only cover the stored entries and not every
  * attribute. Confirm this is intended.
  *
  * @param instance instance whose values are converted
  * @return one int per element of {@code instance.values()}, in iteration order
  */
 public static int[] instanceToArray(Instance instance) {
   Collection<Double> attributeValues = instance.values();
   int[] ints = new int[attributeValues.size()];
   int pos = 0;
   for (Double v : attributeValues) {
     ints[pos++] = (int) v.doubleValue();
   }
   return ints;
 }
 /**
  * Replaces, in place, every NaN or infinite attribute value of the instance
  * with the replacement value {@code d} (declared outside this view).
  *
  * @param inst instance to clean; modified in place
  */
 public void filter(Instance inst) {
   int i = 0;
   while (i < inst.noAttributes()) {
     double current = inst.value(i);
     boolean nonFinite = Double.isNaN(current) || Double.isInfinite(current);
     if (nonFinite) {
       inst.put(i, d);
     }
     i++;
   }
 }
// Code example #8
// 0
  /**
   * Shows the default usage of the KNN algorithm.
   *
   * <p>The demo runs these steps in order:
   * <ul>
   *   <li>trains a 5-nearest-neighbour classifier on {@code DATASET};</li>
   *   <li>counts correct and wrong predictions;</li>
   *   <li>prints per-class performance;</li>
   *   <li>runs 5-fold cross-validation;</li>
   *   <li>evaluates a Weka SMO classifier through the bridge;</li>
   *   <li>prints GainRatio attribute scores.</li>
   * </ul>
   *
   * @param args unused
   * @throws Exception propagated from data loading, training or evaluation
   */
  public static void main(String[] args) throws Exception {

    /* Load a data set */
    Dataset data = FileHandler.loadDataset(new File(DATASET), 4, ",");
    /*
     * Construct a KNN classifier that uses 5 neighbors to make a decision.
     */
    Classifier knn = new KNearestNeighbors(5);
    knn.buildClassifier(data);

    // NOTE(review): kdtKnn is trained but never used below; evaluate it next to knn or remove it.
    Classifier kdtKnn = new KDtreeKNN(5);
    kdtKnn.buildClassifier(data);

    /*
     * Load a data set for evaluation, this can be a different one, but we
     * will use the same one.
     */
    Dataset dataForClassification = FileHandler.loadDataset(new File(DATASET), 4, ",");
    /* Counters for correct and wrong predictions. */
    int correct = 0, wrong = 0;
    /* Classify all instances and check with the correct class values */
    for (Instance inst : dataForClassification) {
      Object predictedClassValue = knn.classify(inst);
      Object realClassValue = inst.classValue();
      if (predictedClassValue.equals(realClassValue)) {
        correct++;
      } else {
        wrong++;
      }
    }
    System.out.println("Correct predictions  " + correct);
    System.out.println("Wrong predictions " + wrong);

    /* Per-class performance on the evaluation set */
    System.out.println("Performance ...");
    Map<Object, PerformanceMeasure> pm = EvaluateDataset.testDataset(knn, dataForClassification);
    printPerfMeasure(pm);
    /*
     * Cross validation
     */
    System.out.println("Cross validation ...");
    /* Construct new cross validation instance with the KNN classifier, */
    CrossValidation cv = new CrossValidation(knn);
    /*
     * 5-fold CV with fixed random generators. p0 and p1 share seed 1
     * (presumably to show the result is reproducible); p2 uses seed 25.
     */
    Map<Object, PerformanceMeasure> p0 = cv.crossValidation(data, 5, new Random(1));
    Map<Object, PerformanceMeasure> p1 = cv.crossValidation(data, 5, new Random(1));
    Map<Object, PerformanceMeasure> p2 = cv.crossValidation(data, 5, new Random(25));
    printPerfMeasure(p0);
    printPerfMeasure(p1);
    printPerfMeasure(p2);

    /*
     * Create Weka classifier
     */
    System.out.println("Weka classifier ...");
    SMO smo = new SMO();
    /* Wrap Weka classifier in bridge */
    Classifier javamlsmo = new WekaClassifier(smo);
    /* Initialize cross-validation */
    CrossValidation wekaCV = new CrossValidation(javamlsmo);
    /* Perform cross-validation */
    Map<Object, PerformanceMeasure> wekaPm = wekaCV.crossValidation(data);
    /* Output results
     * see http://en.wikipedia.org/wiki/Precision_and_recall
     */
    // Print the reference link on its own line rather than gluing the raw map onto the URL.
    System.out.println("see http://en.wikipedia.org/wiki/Precision_and_recall");
    System.out.println(wekaPm);
    printPerfMeasure(wekaPm);

    /*
     * Feature scoring
     */
    System.out.println("Feature scoring ");
    GainRatio ga = new GainRatio();
    /* Apply the algorithm to the data set */
    ga.build(data);
    /* Print out the score of each attribute */
    for (int i = 0; i < ga.noAttributes(); i++) {
      System.out.println("Attribute[" + i + "] relevance: " + ga.score(i));
    }
  }