/**
   * Reports the current configuration of IBk as a command-line style option array.
   *
   * <p>The returned array always has a fixed length; slots that are not needed for the current
   * configuration are filled with empty strings.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions() {

    String[] result = new String[11];
    int next = 0;

    // Options that are always present: neighbour count and window size.
    result[next++] = "-K";
    result[next++] = String.valueOf(getKNN());
    result[next++] = "-W";
    result[next++] = String.valueOf(m_WindowSize);

    // Optional flags.
    if (getCrossValidate()) {
      result[next++] = "-X";
    }
    if (getMeanSquared()) {
      result[next++] = "-E";
    }
    if (m_DistanceWeighting == WEIGHT_INVERSE) {
      result[next++] = "-I";
    } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
      result[next++] = "-F";
    }

    // Neighbour search specification: class name followed by its own options.
    String searchSpec =
        m_NNSearch.getClass().getName() + " " + Utils.joinOptions(m_NNSearch.getOptions());
    result[next++] = "-A";
    result[next++] = searchSpec;

    // Pad the remaining slots so that no entry is null.
    for (int i = next; i < result.length; i++) {
      result[i] = "";
    }

    return result;
  }
 /**
  * Lists the names of the additional measures that can be queried. These are the measures of
  * the neighbour search algorithm and, when cross-validation is enabled, additionally
  * "measureKNN" (the chosen K).
  *
  * @return an enumeration of the measure names
  */
 public Enumeration enumerateMeasures() {
   Enumeration searchMeasures = m_NNSearch.enumerateMeasures();
   if (!m_CrossValidate) {
     // Nothing to add: hand back the search algorithm's own enumeration.
     return searchMeasures;
   }

   // Copy the search measures and append the extra one contributed by this classifier.
   Vector names = new Vector();
   while (searchMeasures.hasMoreElements()) {
     names.add(searchMeasures.nextElement());
   }
   names.add("measureKNN");
   return names.elements();
 }
示例#3
0
  /**
   * Builds the classifier from the supplied training data.
   *
   * <p>The base classifier must be a WeightedInstancesHandler. Instances with a missing class
   * are dropped from a private copy of the data. If only the class attribute is present, a
   * ZeroR model is built instead; otherwise the data is stored and handed to the neighbour
   * search.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    boolean handlesWeights = m_Classifier instanceof WeightedInstancesHandler;
    if (!handlesWeights) {
      throw new IllegalArgumentException("Classifier must be a " + "WeightedInstancesHandler!");
    }

    // Fails with an exception if the data cannot be handled.
    getCapabilities().testWithFail(instances);

    // Work on a copy so that the caller's data set is left untouched.
    Instances data = new Instances(instances);
    data.deleteWithMissingClass();

    if (data.numAttributes() != 1) {
      // Regular case: keep the training data and initialise the neighbour search with it.
      m_ZeroR = null;
      m_Train = new Instances(data, 0, data.numInstances());
      m_NNSearch.setInstances(m_Train);
      return;
    }

    // Only the class attribute is present: fall back to a ZeroR model.
    System.err.println(
        "Cannot build model (only class attribute present in data!), "
            + "using ZeroR model instead!");
    m_ZeroR = new weka.classifiers.rules.ZeroR();
    m_ZeroR.buildClassifier(data);
  }
示例#4
0
  /**
   * Incorporates a single new instance into the training data. Instances whose class value is
   * missing are silently ignored.
   *
   * @param instance the instance to add
   * @throws Exception if no training structure has been set yet, or if the instance's header
   *     does not match the training data
   */
  public void updateClassifier(Instance instance) throws Exception {

    if (m_Train == null) {
      throw new Exception("No training instance structure set!");
    }
    if (!m_Train.equalHeaders(instance.dataset())) {
      throw new Exception(
          "Incompatible instance types\n" + m_Train.equalHeadersMsg(instance.dataset()));
    }
    if (instance.classIsMissing()) {
      return;
    }

    // Inform the neighbour search first, then store the instance.
    m_NNSearch.update(instance);
    m_Train.add(instance);
  }
  /**
   * Incorporates a single new instance into the training data and, if a window size is set,
   * drops the oldest instances until the training set fits into the window again. Instances
   * whose class value is missing are silently ignored.
   *
   * @param instance the instance to add
   * @throws Exception if instance could not be incorporated successfully
   */
  public void updateClassifier(Instance instance) throws Exception {

    if (!m_Train.equalHeaders(instance.dataset())) {
      throw new Exception("Incompatible instance types");
    }

    if (!instance.classIsMissing()) {
      m_Train.add(instance);
      m_NNSearch.update(instance);
      // Any previously selected k is no longer guaranteed to be the best one.
      m_kNNValid = false;

      boolean windowExceeded = (m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize);
      if (windowExceeded) {
        // At least one instance has to go; remove from the front (oldest first).
        do {
          m_Train.delete(0);
        } while (m_Train.numInstances() > m_WindowSize);

        // rebuild datastructure KDTree currently can't delete
        m_NNSearch.setInstances(m_Train);
      }
    }
  }
示例#6
0
  /**
   * Reports the current configuration of the classifier as an option array. This classifier's
   * own options come first, followed by the options reported by the superclass.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {

    String[] inherited = super.getOptions();
    // Six slots for this class: -U <kernel> -K <k> -A <search spec>.
    String[] result = new String[inherited.length + 6];
    int pos = 0;

    result[pos++] = "-U";
    result[pos++] = "" + getWeightingKernel();

    // "-1" stands for "use all neighbours".
    result[pos++] = "-K";
    if ((getKNN() == 0) && m_UseAllK) {
      result[pos++] = "-1";
    } else {
      result[pos++] = "" + getKNN();
    }

    String searchSpec =
        m_NNSearch.getClass().getName() + " " + Utils.joinOptions(m_NNSearch.getOptions());
    result[pos++] = "-A";
    result[pos++] = searchSpec;

    System.arraycopy(inherited, 0, result, pos, inherited.length);

    return result;
  }
  /**
   * Builds the classifier from the supplied training data.
   *
   * <p>Instances with a missing class are removed from a private copy of the data. If a window
   * size is set, only the most recent instances that fit into the window are kept. A ZeroR
   * default model is built on the full (class-complete) data as well.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // Fails with an exception if the data cannot be handled.
    getCapabilities().testWithFail(instances);

    // Work on a copy so that the caller's data set is left untouched.
    Instances data = new Instances(instances);
    data.deleteWithMissingClass();

    m_NumClasses = data.numClasses();
    m_ClassType = data.classAttribute().type();
    m_Train = new Instances(data, 0, data.numInstances());

    // Keep only the trailing m_WindowSize instances when the data exceeds the window.
    boolean exceedsWindow = (m_WindowSize > 0) && (data.numInstances() > m_WindowSize);
    if (exceedsWindow) {
      int firstKept = m_Train.numInstances() - m_WindowSize;
      m_Train = new Instances(m_Train, firstKept, m_WindowSize);
    }

    // Count the nominal and numeric attributes, excluding the class attribute.
    m_NumAttributesUsed = 0.0;
    for (int att = 0; att < m_Train.numAttributes(); att++) {
      if (att == m_Train.classIndex()) {
        continue;
      }
      if (m_Train.attribute(att).isNominal() || m_Train.attribute(att).isNumeric()) {
        m_NumAttributesUsed += 1.0;
      }
    }

    m_NNSearch.setInstances(m_Train);

    // A k chosen by an earlier cross-validation no longer applies to the new data.
    m_kNNValid = false;

    // Default model, built on all class-complete data (not only on the window).
    m_defaultModel = new ZeroR();
    m_defaultModel.buildClassifier(data);
    m_defaultModel.setOptions(getOptions());
  }
示例#8
0
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * <p>The method retrieves the nearest neighbours of the test instance from the neighbour
   * search, turns their distances into weights via the configured weighting kernel, multiplies
   * those weights into the neighbours' instance weights, rebuilds the base classifier on the
   * weighted neighbours and returns that classifier's prediction. If a ZeroR fallback model is
   * set, its prediction is returned instead.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if distribution can't be computed successfully, in particular when there
   *     are no training instances
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    // default model?
    if (m_ZeroR != null) {
      return m_ZeroR.distributionForInstance(instance);
    }

    if (m_Train.numInstances() == 0) {
      throw new Exception("No training instances!");
    }

    m_NNSearch.addInstanceInfo(instance);

    // Start from "all training instances" and narrow down to m_kNN unless all neighbours
    // are requested or m_kNN is not smaller than the training set.
    int k = m_Train.numInstances();
    if ((!m_UseAllK && (m_kNN < k)) /*&&
       !(m_WeightKernel==INVERSE ||
         m_WeightKernel==GAUSS)*/) {
      k = m_kNN;
    }

    // getDistances() is read right after the search; presumably distances[i] belongs to
    // neighbours.instance(i) -- the weighting loop further down relies on that.
    Instances neighbours = m_NNSearch.kNearestNeighbours(instance, k);
    double distances[] = m_NNSearch.getDistances();

    if (m_Debug) {
      System.out.println("Test Instance: " + instance);
      System.out.println(
          "For "
              + k
              + " kept "
              + neighbours.numInstances()
              + " out of "
              + m_Train.numInstances()
              + " instances.");
    }

    // IF LinearNN has skipped so much that <k neighbours are remaining.
    if (k > distances.length) k = distances.length;

    if (m_Debug) {
      System.out.println("Instance Distances");
      for (int i = 0; i < distances.length; i++) {
        System.out.println("" + distances[i]);
      }
    }

    // Determine the bandwidth: the distance stored at position k - 1.
    // NOTE(review): presumably the distances are sorted ascending, making this the distance
    // to the k-th nearest neighbour -- confirm against the neighbour search implementation.
    // NOTE(review): if no distances were returned, k is 0 here and this index is -1.
    double bandwidth = distances[k - 1];

    // Check for bandwidth zero
    if (bandwidth <= 0) {
      // if the kth distance is zero then give all instances the same weight
      for (int i = 0; i < distances.length; i++) distances[i] = 1;
    } else {
      // Rescale the distances by the bandwidth
      for (int i = 0; i < distances.length; i++) distances[i] = distances[i] / bandwidth;
    }

    // Pass the distances through a weighting kernel; from here on the array holds weights.
    // The 1.0001 constants keep the weight at a rescaled distance of exactly 1 slightly
    // above zero.
    for (int i = 0; i < distances.length; i++) {
      switch (m_WeightKernel) {
        case LINEAR:
          distances[i] = 1.0001 - distances[i];
          break;
        case EPANECHNIKOV:
          // 3 / 4D is a double division, i.e. 0.75
          distances[i] = 3 / 4D * (1.0001 - distances[i] * distances[i]);
          break;
        case TRICUBE:
          distances[i] = Math.pow((1.0001 - Math.pow(distances[i], 3)), 3);
          break;
        case CONSTANT:
          // System.err.println("using constant kernel");
          distances[i] = 1;
          break;
        case INVERSE:
          distances[i] = 1.0 / (1.0 + distances[i]);
          break;
        case GAUSS:
          distances[i] = Math.exp(-distances[i] * distances[i]);
          break;
      }
    }

    if (m_Debug) {
      System.out.println("Instance Weights");
      for (int i = 0; i < distances.length; i++) {
        System.out.println("" + distances[i]);
      }
    }

    // Set the weights on the training data: multiply each neighbour's weight by its kernel
    // weight while tracking the total weight before and after.
    double sumOfWeights = 0, newSumOfWeights = 0;
    for (int i = 0; i < distances.length; i++) {
      double weight = distances[i];
      Instance inst = (Instance) neighbours.instance(i);
      sumOfWeights += inst.weight();
      newSumOfWeights += inst.weight() * weight;
      inst.setWeight(inst.weight() * weight);
      // weightedTrain.add(newInst);
    }

    // Rescale weights so that their total equals the total before kernel weighting
    for (int i = 0; i < neighbours.numInstances(); i++) {
      Instance inst = neighbours.instance(i);
      inst.setWeight(inst.weight() * sumOfWeights / newSumOfWeights);
    }

    // Create a weighted classifier (the base classifier is rebuilt on every call)
    m_Classifier.buildClassifier(neighbours);

    if (m_Debug) {
      System.out.println("Classifying test instance: " + instance);
      System.out.println("Built base classifier:\n" + m_Classifier.toString());
    }

    // Return the classifier's predictions
    return m_Classifier.distributionForInstance(instance);
  }
示例#9
0
 /**
  * Looks up the value of a named measure by delegating to the neighbour search algorithm.
  *
  * @param additionalMeasureName the name of the measure to query for its value
  * @return the value of the named measure
  * @throws IllegalArgumentException if the named measure is not supported
  */
 public double getMeasure(String additionalMeasureName) {
   double value = m_NNSearch.getMeasure(additionalMeasureName);
   return value;
 }
示例#10
0
 /**
  * Lists the names of the additional measures that can be queried; all of them come from the
  * neighbour search algorithm.
  *
  * @return an enumeration of the measure names
  */
 public Enumeration enumerateMeasures() {
   Enumeration searchMeasures = m_NNSearch.enumerateMeasures();
   return searchMeasures;
 }
示例#11
0
  /**
   * Select the best value for k by hold-one-out cross-validation. If the class attribute is
   * nominal, classification error is minimised. If the class attribute is numeric, mean absolute
   * error is minimised, or squared error when the mean-squared option is set.
   *
   * <p>On success, m_kNN holds the selected k (the lowest k among equally good ones) and
   * m_kNNValid is set to true.
   *
   * @throws Error if the cross-validation fails for any reason; the original exception is
   *     attached as the cause
   */
  protected void crossValidate() {

    try {
      if (m_NNSearch instanceof weka.core.neighboursearch.CoverTree) {
        throw new Exception(
            "CoverTree doesn't support hold-one-out "
                + "cross-validation. Use some other NN "
                + "method.");
      }

      // Slot j accumulates the error obtained with k = j + 1 neighbours.
      // New double arrays are zero-initialised, so no explicit reset is needed.
      double[] performanceStats = new double[m_kNNUpper];
      double[] performanceStatsSq = new double[m_kNNUpper];

      // Search with the largest k once per instance; smaller k are derived by pruning.
      m_kNN = m_kNNUpper;
      Instance instance;
      Instances neighbours;
      double[] origDistances, convertedDistances;
      for (int i = 0; i < m_Train.numInstances(); i++) {
        if (m_Debug && (i % 50 == 0)) {
          System.err.print("Cross validating " + i + "/" + m_Train.numInstances() + "\r");
        }
        instance = m_Train.instance(i);
        // NOTE(review): the query instance is itself part of the training data; hold-one-out
        // presumably relies on the neighbour search not returning it -- confirm.
        neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
        origDistances = m_NNSearch.getDistances();

        for (int j = m_kNNUpper - 1; j >= 0; j--) {
          // Update the performance stats. makeDistribution receives a copy of the distances
          // so that the originals stay available for the next round.
          convertedDistances = new double[origDistances.length];
          System.arraycopy(origDistances, 0, convertedDistances, 0, origDistances.length);
          double[] distribution = makeDistribution(neighbours, convertedDistances);
          double thisPrediction = Utils.maxIndex(distribution);
          if (m_Train.classAttribute().isNumeric()) {
            thisPrediction = distribution[0];
            double err = thisPrediction - instance.classValue();
            performanceStatsSq[j] += err * err; // Squared error
            performanceStats[j] += Math.abs(err); // Absolute error
          } else {
            if (thisPrediction != instance.classValue()) {
              performanceStats[j]++; // Classification error
            }
          }
          // Shrink the neighbourhood to j neighbours for the next (smaller) k.
          if (j >= 1) {
            neighbours = pruneToK(neighbours, convertedDistances, j);
          }
        }
      }

      // Display the results of the cross-validation
      if (m_Debug) {
        for (int i = 0; i < m_kNNUpper; i++) {
          System.err.print("Hold-one-out performance of " + (i + 1) + " neighbors ");
          if (m_Train.classAttribute().isNumeric()) {
            if (m_MeanSquared) {
              System.err.println(
                  "(RMSE) = " + Math.sqrt(performanceStatsSq[i] / m_Train.numInstances()));
            } else {
              System.err.println("(MAE) = " + performanceStats[i] / m_Train.numInstances());
            }
          } else {
            System.err.println("(%ERR) = " + 100.0 * performanceStats[i] / m_Train.numInstances());
          }
        }
      }

      // Check through the performance stats and select the best
      // k value (or the lowest k if more than one best)
      double[] searchStats = performanceStats;
      if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
        searchStats = performanceStatsSq;
      }
      double bestPerformance = Double.NaN;
      int bestK = 1;
      for (int i = 0; i < m_kNNUpper; i++) {
        // Strict comparison keeps the lowest k among ties.
        if (Double.isNaN(bestPerformance) || (bestPerformance > searchStats[i])) {
          bestPerformance = searchStats[i];
          bestK = i + 1;
        }
      }
      m_kNN = bestK;
      if (m_Debug) {
        System.err.println("Selected k = " + bestK);
      }

      m_kNNValid = true;
    } catch (Exception ex) {
      // Keep the original exception as the cause so that its stack trace is not lost.
      throw new Error("Couldn't optimize by cross-validation: " + ex.getMessage(), ex);
    }
  }
示例#12
0
 /**
  * Looks up the value of a named measure. "measureKNN" is answered by this classifier with the
  * current value of k; every other name is passed on to the neighbour search algorithm.
  *
  * @param additionalMeasureName the name of the measure to query for its value
  * @return the value of the named measure
  * @throws IllegalArgumentException if the named measure is not supported
  */
 public double getMeasure(String additionalMeasureName) {
   if (!additionalMeasureName.equals("measureKNN")) {
     return m_NNSearch.getMeasure(additionalMeasureName);
   }
   return m_kNN;
 }
示例#13
0
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * <p>The prediction comes from a NaiveBayes model that is trained per call on a copy of the
   * training data (instances with a missing class removed) plus replicated copies of the test
   * instance's nearest neighbours. A neighbour at distance d is added
   * {@code (int) (10 / (d + 0.1))} times, so closer neighbours have more influence. If there
   * are no training instances, the default model's prediction is returned.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if an error occurred during the prediction
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    if (m_Train.numInstances() == 0) {
      // No data to search: fall back to the default model instead of failing.
      return m_defaultModel.distributionForInstance(instance);
    }

    // Drop the oldest instances if the training set has outgrown the window.
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
      m_kNNValid = false;
      while (m_Train.numInstances() > m_WindowSize) {
        m_Train.delete(0);
      }
      // rebuild datastructure KDTree currently can't delete
      m_NNSearch.setInstances(m_Train);
    }

    // Select k by cross validation
    if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
      crossValidate();
    }

    m_NNSearch.addInstanceInfo(instance);
    // NOTE(review): this hard-coded value overwrites the configured k and any k selected by
    // crossValidate() above -- confirm that this is intended.
    m_kNN = 1000;
    Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
    double[] distances = m_NNSearch.getDistances();

    // Base of the per-call training set: all training instances with a known class.
    Instances classComplete = new Instances(m_Train);
    classComplete.deleteWithMissingClass();
    Instances weightedTrain = new Instances(classComplete, 0, classComplete.numInstances());

    // Replicate each neighbour according to its closeness; the 0.1 offset avoids a division
    // by zero for neighbours at distance 0.
    for (int k = 0; k < neighbours.numInstances(); k++) {
      Instance neighbour = neighbours.instance(k);
      double closeness = (1 / (distances[k] + 0.1)) * 10;
      int copies = (int) closeness;
      for (int s = 0; s < copies; s++) {
        weightedTrain.add(neighbour);
      }
    }

    NaiveBayes nb = new NaiveBayes();
    nb.buildClassifier(weightedTrain);
    return nb.distributionForInstance(instance);
  }