/**
   * Prepares the neighbour search. IBk cross-validates the number of neighbours "K" on the
   * training set, the class labels of the test set are blanked out, and the neighbours of every
   * test instance are computed.
   *
   * @throws Exception if initialization fails
   */
  protected void initialize() throws Exception {
    // 1) let IBk pick K via cross-validation on the training data
    if (getVerbose()) {
      System.out.println("\nOriginal KNN = " + m_KNN);
    }
    IBk knn = (IBk) m_Classifier;
    knn.setKNN(m_KNN);
    knn.setCrossValidate(true);
    knn.buildClassifier(m_TrainsetNew);
    knn.toString(); // required: this call is what makes IBk cross-validate K
    knn.setCrossValidate(false);
    m_KNNdetermined = knn.getKNN();
    if (getVerbose()) {
      System.out.println("Determined KNN = " + m_KNNdetermined);
    }

    // 2) mark every class label of the test set as missing
    int testCount = m_TestsetNew.numInstances();
    for (int idx = 0; idx < testCount; idx++) {
      m_TestsetNew.instance(idx).setClassMissing();
    }

    // 3) run m_Missing over copies of both sets (input format taken from the training copy)
    Instances filteredTrain = new Instances(m_TrainsetNew);
    Instances filteredTest = new Instances(m_TestsetNew);
    m_Missing.setInputFormat(filteredTrain);
    filteredTrain = Filter.useFilter(filteredTrain, m_Missing);
    filteredTest = Filter.useFilter(filteredTest, m_Missing);

    // 4) compute the neighbours of each test instance, timing the whole loop
    m_NeighborsTestset = new Neighbors[m_TestsetNew.numInstances()];
    double startMillis = System.currentTimeMillis();
    for (int idx = 0; idx < filteredTest.numInstances(); idx++) {
      Neighbors neighbors =
          new Neighbors(
              filteredTest.instance(idx),
              m_TestsetNew.instance(idx),
              m_KNNdetermined,
              filteredTrain,
              filteredTest);
      m_NeighborsTestset[idx] = neighbors;
      neighbors.setVerbose(getVerbose() || getDebug());
      neighbors.setUseNaiveSearch(getUseNaiveSearch());
      neighbors.find();
    }
    double endMillis = System.currentTimeMillis();

    if (getVerbose()) {
      double seconds = (endMillis - startMillis) / 1000.0;
      System.out.println("Time for finding neighbors: " + Utils.doubleToString(seconds, 3));
    }
  }
Example #2
0
  /**
   * Generates the classifier.
   *
   * <p>Works on a copy of {@code data} from which instances with a missing class value have been
   * removed. For a nominal class, smoothed class priors are computed first. The attributes are
   * then discretized: supervised discretization for a nominal class, or 10-bin unsupervised
   * discretization of every non-class attribute for a numeric class. The configured search
   * selects a feature subset, and the instances reduced to those features (plus the class) are
   * inserted into the decision table. If {@code m_useIBk} is set, an IBk model is also built on
   * the discretized data.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    m_theInstances = new Instances(data);
    m_theInstances.deleteWithMissingClass();

    m_rr = new Random(1);

    if (m_theInstances.classAttribute().isNominal()) {
      // Set up class priors. Every class starts at a count of 1.0 (add-one smoothing).
      // Count over m_theInstances rather than the raw data. Otherwise an instance whose class is
      // missing would have its class value cast with (int) to index 0 and inflate the first
      // class's prior.
      m_classPriorCounts = new double[m_theInstances.classAttribute().numValues()];
      Arrays.fill(m_classPriorCounts, 1.0);
      for (int i = 0; i < m_theInstances.numInstances(); i++) {
        Instance curr = m_theInstances.instance(i);
        m_classPriorCounts[(int) curr.classValue()] += curr.weight();
      }
      m_classPriors = m_classPriorCounts.clone();
      Utils.normalize(m_classPriors);
    }

    setUpEvaluator();

    if (m_theInstances.classAttribute().isNumeric()) {
      // Numeric class: 10-bin unsupervised discretization of all attributes EXCEPT the class.
      weka.filters.unsupervised.attribute.Discretize unsupervised =
          new weka.filters.unsupervised.attribute.Discretize();
      unsupervised.setBins(10);
      // Only the class column is listed, and the selection is inverted, so it is left untouched.
      unsupervised.setInvertSelection(true);
      // Weka range strings are 1-based, while classIndex() is 0-based.
      unsupervised.setAttributeIndices(Integer.toString(m_theInstances.classIndex() + 1));
      m_disTransform = unsupervised;
      m_classIsNominal = false;
    } else {
      weka.filters.supervised.attribute.Discretize supervised =
          new weka.filters.supervised.attribute.Discretize();
      supervised.setUseBetterEncoding(true);
      m_disTransform = supervised;
      m_classIsNominal = true;
    }

    m_disTransform.setInputFormat(m_theInstances);
    m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);

    m_numAttributes = m_theInstances.numAttributes();
    m_numInstances = m_theInstances.numInstances();
    m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());

    // Perform the search
    int[] selected = m_search.search(m_evaluator, m_theInstances);

    // decision features = selected attributes followed by the class attribute
    m_decisionFeatures = Arrays.copyOf(selected, selected.length + 1);
    m_decisionFeatures[selected.length] = m_theInstances.classIndex();

    // reduce instances to selected features (inverted Remove keeps only the listed indices)
    m_delTransform = new Remove();
    m_delTransform.setInvertSelection(true);
    m_delTransform.setAttributeIndicesArray(m_decisionFeatures);
    m_delTransform.setInputFormat(m_theInstances);
    m_dtInstances = Filter.useFilter(m_theInstances, m_delTransform);

    // reset the number of attributes
    m_numAttributes = m_dtInstances.numAttributes();

    // create hash table, presized to 1.5x the number of instances
    m_entries = new Hashtable((int) (m_dtInstances.numInstances() * 1.5));

    // insert instances into the hash table
    for (int i = 0; i < m_dtInstances.numInstances(); i++) {
      Instance inst = m_dtInstances.instance(i);
      insertIntoTable(inst, null);
    }

    // Replace the global table majority with nearest neighbour?
    if (m_useIBk) {
      m_ibk = new IBk();
      m_ibk.buildClassifier(m_theInstances);
    }

    // Save memory: keep only the (empty) headers of the working data sets
    if (m_saveMemory) {
      m_theInstances = new Instances(m_theInstances, 0);
      m_dtInstances = new Instances(m_dtInstances, 0);
    }
    m_evaluation = null;
  }