/**
   * Selects the best value for k by hold-one-out cross-validation and stores it in {@code m_kNN}.
   *
   * <p>For every training instance the {@code m_kNNUpper} nearest neighbours are fetched once and
   * then repeatedly reduced with {@code pruneToK}, so a single search yields one prediction for
   * each candidate k from {@code m_kNNUpper} down to 1.
   *
   * <p>If the class attribute is nominal, classification error is minimised. If the class
   * attribute is numeric, absolute error is minimised, unless {@code m_MeanSquared} is set, in
   * which case squared error is minimised. When several k are equally good the smallest wins.
   *
   * <p>On success {@code m_kNNValid} is set to {@code true}. Note that {@code m_kNN} is set to
   * {@code m_kNNUpper} while the search runs and is left at that value if the method fails.
   *
   * @throws Error if the search method is a CoverTree or anything else goes wrong; the original
   *     exception is attached as the cause
   */
  protected void crossValidate() {

    try {
      if (m_NNSearch instanceof weka.core.neighboursearch.CoverTree) {
        throw new Exception(
            "CoverTree doesn't support hold-one-out "
                + "cross-validation. Use some other NN "
                + "method.");
      }

      // Index i accumulates the error for k = i + 1. Fresh Java arrays are already zero-filled,
      // so no explicit initialisation is needed.
      double[] performanceStats = new double[m_kNNUpper];
      double[] performanceStatsSq = new double[m_kNNUpper];

      m_kNN = m_kNNUpper;
      final int numTrain = m_Train.numInstances();
      final boolean numericClass = m_Train.classAttribute().isNumeric();

      for (int i = 0; i < numTrain; i++) {
        if (m_Debug && (i % 50 == 0)) {
          System.err.print("Cross validating " + i + "/" + numTrain + "\r");
        }
        Instance instance = m_Train.instance(i);
        Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
        double[] origDistances = m_NNSearch.getDistances();

        // Walk from the largest candidate k down to 1, shrinking the neighbour set as we go.
        for (int j = m_kNNUpper - 1; j >= 0; j--) {
          // Each pass gets its own copy of the distances, so whatever the helpers do to the
          // array does not carry over to the next candidate k.
          double[] convertedDistances = origDistances.clone();
          double[] distribution = makeDistribution(neighbours, convertedDistances);

          if (numericClass) {
            // For a numeric class the prediction is read from the first distribution entry.
            double err = distribution[0] - instance.classValue();
            performanceStatsSq[j] += err * err; // Squared error
            performanceStats[j] += Math.abs(err); // Absolute error
          } else {
            double thisPrediction = Utils.maxIndex(distribution);
            if (thisPrediction != instance.classValue()) {
              performanceStats[j]++; // Classification error
            }
          }

          if (j >= 1) {
            neighbours = pruneToK(neighbours, convertedDistances, j);
          }
        }
      }

      // Display the results of the cross-validation
      if (m_Debug) {
        for (int i = 0; i < m_kNNUpper; i++) {
          System.err.print("Hold-one-out performance of " + (i + 1) + " neighbors ");
          if (numericClass) {
            if (m_MeanSquared) {
              System.err.println("(RMSE) = " + Math.sqrt(performanceStatsSq[i] / numTrain));
            } else {
              System.err.println("(MAE) = " + performanceStats[i] / numTrain);
            }
          } else {
            System.err.println("(%ERR) = " + 100.0 * performanceStats[i] / numTrain);
          }
        }
      }

      // Check through the performance stats and select the best
      // k value (or the lowest k if more than one best)
      double[] searchStats = (numericClass && m_MeanSquared) ? performanceStatsSq : performanceStats;
      double bestPerformance = Double.NaN;
      int bestK = 1;
      for (int i = 0; i < m_kNNUpper; i++) {
        // Strict comparison keeps the first (smallest) k among equally good candidates.
        if (Double.isNaN(bestPerformance) || (bestPerformance > searchStats[i])) {
          bestPerformance = searchStats[i];
          bestK = i + 1;
        }
      }
      m_kNN = bestK;
      if (m_Debug) {
        System.err.println("Selected k = " + bestK);
      }

      m_kNNValid = true;
    } catch (Exception ex) {
      // Attach the original exception as the cause so its type and stack trace are not lost.
      throw new Error("Couldn't optimize by cross-validation: " + ex.getMessage(), ex);
    }
  }
  // Example #2
  // 0
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * <p>If a default model ({@code m_ZeroR}) is set, its prediction is returned directly.
   * Otherwise the nearest neighbours of the test instance are looked up, their distances are
   * rescaled by the distance of the k-th neighbour (the bandwidth) and turned into weights by the
   * configured kernel. The neighbours are then reweighted so that their total weight is
   * preserved, the base classifier is rebuilt on them, and its prediction is returned.
   *
   * <p>The array obtained from {@code m_NNSearch.getDistances()} is modified in place.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if there are no training instances, if no neighbour is available to derive
   *     a bandwidth from, or if the distribution can't be computed successfully
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    // default model?
    if (m_ZeroR != null) {
      return m_ZeroR.distributionForInstance(instance);
    }

    if (m_Train.numInstances() == 0) {
      throw new Exception("No training instances!");
    }

    m_NNSearch.addInstanceInfo(instance);

    // Use every training instance unless a smaller explicit k was requested.
    int k = m_Train.numInstances();
    if (!m_UseAllK && (m_kNN < k)) {
      k = m_kNN;
    }

    Instances neighbours = m_NNSearch.kNearestNeighbours(instance, k);
    double[] distances = m_NNSearch.getDistances();

    if (m_Debug) {
      System.out.println("Test Instance: " + instance);
      System.out.println(
          "For "
              + k
              + " kept "
              + neighbours.numInstances()
              + " out of "
              + m_Train.numInstances()
              + " instances.");
    }

    // IF LinearNN has skipped so much that <k neighbours are remaining.
    if (k > distances.length) {
      k = distances.length;
    }

    if (m_Debug) {
      System.out.println("Instance Distances");
      for (int i = 0; i < distances.length; i++) {
        System.out.println("" + distances[i]);
      }
    }

    // Without at least one neighbour there is no k-th distance to use as the bandwidth;
    // fail with a clear message instead of an out-of-bounds array access.
    if (k < 1) {
      throw new Exception("Cannot determine bandwidth: no neighbours available (k = " + k + ")");
    }

    // Determine the bandwidth
    double bandwidth = distances[k - 1];

    // Check for bandwidth zero
    if (bandwidth <= 0) {
      // if the kth distance is zero then give all instances the same weight
      for (int i = 0; i < distances.length; i++) {
        distances[i] = 1;
      }
    } else {
      // Rescale the distances by the bandwidth
      for (int i = 0; i < distances.length; i++) {
        distances[i] = distances[i] / bandwidth;
      }
    }

    // Pass the distances through a weighting kernel. The 1.0001 offset means a rescaled
    // distance of exactly 1 (the k-th neighbour) still yields a small positive weight.
    for (int i = 0; i < distances.length; i++) {
      switch (m_WeightKernel) {
        case LINEAR:
          distances[i] = 1.0001 - distances[i];
          break;
        case EPANECHNIKOV:
          distances[i] = 3 / 4D * (1.0001 - distances[i] * distances[i]);
          break;
        case TRICUBE:
          distances[i] = Math.pow((1.0001 - Math.pow(distances[i], 3)), 3);
          break;
        case CONSTANT:
          distances[i] = 1;
          break;
        case INVERSE:
          distances[i] = 1.0 / (1.0 + distances[i]);
          break;
        case GAUSS:
          distances[i] = Math.exp(-distances[i] * distances[i]);
          break;
      }
    }

    if (m_Debug) {
      System.out.println("Instance Weights");
      for (int i = 0; i < distances.length; i++) {
        System.out.println("" + distances[i]);
      }
    }

    // Set the weights on the neighbours: each instance weight is multiplied by its kernel
    // weight, while the totals before and after are tracked for the rescaling step below.
    double sumOfWeights = 0;
    double newSumOfWeights = 0;
    for (int i = 0; i < distances.length; i++) {
      double weight = distances[i];
      Instance inst = neighbours.instance(i);
      sumOfWeights += inst.weight();
      newSumOfWeights += inst.weight() * weight;
      inst.setWeight(inst.weight() * weight);
    }

    // Rescale weights so that they sum to the same total as before the kernel was applied
    for (int i = 0; i < neighbours.numInstances(); i++) {
      Instance inst = neighbours.instance(i);
      inst.setWeight(inst.weight() * sumOfWeights / newSumOfWeights);
    }

    // Create a weighted classifier
    m_Classifier.buildClassifier(neighbours);

    if (m_Debug) {
      System.out.println("Classifying test instance: " + instance);
      System.out.println("Built base classifier:\n" + m_Classifier.toString());
    }

    // Return the classifier's predictions
    return m_Classifier.distributionForInstance(instance);
  }
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * <p>Steps, in order:
   *
   * <ol>
   *   <li>With no training data, the default model's prediction is returned.
   *   <li>If a window size is set and exceeded, the oldest training instances are dropped and
   *       the neighbour search is rebuilt on the remaining data.
   *   <li>k is re-selected by cross-validation when enabled and not yet valid.
   *   <li>The nearest neighbours of the test instance are fetched and appended, replicated
   *       according to their distance, to a copy of the training data without missing-class
   *       instances. Closer neighbours are replicated more often.
   *   <li>A fresh {@code NaiveBayes} model is trained on that data and its prediction returned.
   * </ol>
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if an error occurred during the prediction
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    if (m_Train.numInstances() == 0) {
      return m_defaultModel.distributionForInstance(instance);
    }

    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
      m_kNNValid = false;
      while (m_Train.numInstances() > m_WindowSize) {
        m_Train.delete(0);
      }
      // The branch condition guarantees at least one instance was deleted, so the search
      // structure must always be rebuilt here (KDTree currently can't delete).
      m_NNSearch.setInstances(m_Train);
    }

    // Select k by cross validation
    if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
      crossValidate();
    }

    m_NNSearch.addInstanceInfo(instance);

    // NOTE(review): this hard-coded value overwrites both the configured and the
    // cross-validated k, so the work done by crossValidate() above is discarded.
    // Kept as-is to preserve behaviour; confirm whether this is intentional.
    m_kNN = 1000;
    Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
    double[] distances = m_NNSearch.getDistances();

    // Start from a copy of the training data, leaving m_Train itself untouched.
    Instances weightedTrain = new Instances(m_Train);
    weightedTrain.deleteWithMissingClass();

    for (int k = 0; k < neighbours.numInstances(); k++) {
      Instance neighbour = neighbours.instance(k);

      // Replication count is floor(10 / (distance + 0.1)): 100 copies at distance 0,
      // dropping to zero copies once the distance grows beyond roughly 9.9.
      double scaled = (1 / (distances[k] + 0.1)) * 10;
      int copies = (int) scaled;

      for (int s = 0; s < copies; s++) {
        weightedTrain.add(neighbour);
      }
    }

    NaiveBayes nb = new NaiveBayes();
    nb.buildClassifier(weightedTrain);
    return nb.distributionForInstance(instance);
  }