/** * determines the "K" for the neighbors from the training set, initializes the labels of the test * set to "missing" and generates the neighbors for all instances in the test set * * @throws Exception if initialization fails */ protected void initialize() throws Exception { int i; double timeStart; double timeEnd; Instances trainingNew; Instances testNew; // determine K if (getVerbose()) System.out.println("\nOriginal KNN = " + m_KNN); ((IBk) m_Classifier).setKNN(m_KNN); ((IBk) m_Classifier).setCrossValidate(true); m_Classifier.buildClassifier(m_TrainsetNew); m_Classifier.toString(); // necessary to crossvalidate IBk! ((IBk) m_Classifier).setCrossValidate(false); m_KNNdetermined = ((IBk) m_Classifier).getKNN(); if (getVerbose()) System.out.println("Determined KNN = " + m_KNNdetermined); // set class labels in test set to "missing" for (i = 0; i < m_TestsetNew.numInstances(); i++) m_TestsetNew.instance(i).setClassMissing(); // copy data trainingNew = new Instances(m_TrainsetNew); testNew = new Instances(m_TestsetNew); // filter data m_Missing.setInputFormat(trainingNew); trainingNew = Filter.useFilter(trainingNew, m_Missing); testNew = Filter.useFilter(testNew, m_Missing); // create the list of neighbors for the instances in the test set m_NeighborsTestset = new Neighbors[m_TestsetNew.numInstances()]; timeStart = System.currentTimeMillis(); for (i = 0; i < testNew.numInstances(); i++) { m_NeighborsTestset[i] = new Neighbors( testNew.instance(i), m_TestsetNew.instance(i), m_KNNdetermined, trainingNew, testNew); m_NeighborsTestset[i].setVerbose(getVerbose() || getDebug()); m_NeighborsTestset[i].setUseNaiveSearch(getUseNaiveSearch()); m_NeighborsTestset[i].find(); } timeEnd = System.currentTimeMillis(); if (getVerbose()) System.out.println( "Time for finding neighbors: " + Utils.doubleToString((timeEnd - timeStart) / 1000.0, 3)); }
/** * Generates the classifier. * * @param data set of instances serving as training data * @throws Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances data) throws Exception { // can classifier handle the data? getCapabilities().testWithFail(data); // remove instances with missing class m_theInstances = new Instances(data); m_theInstances.deleteWithMissingClass(); m_rr = new Random(1); if (m_theInstances.classAttribute().isNominal()) { // Set up class priors m_classPriorCounts = new double[data.classAttribute().numValues()]; Arrays.fill(m_classPriorCounts, 1.0); for (int i = 0; i < data.numInstances(); i++) { Instance curr = data.instance(i); m_classPriorCounts[(int) curr.classValue()] += curr.weight(); } m_classPriors = m_classPriorCounts.clone(); Utils.normalize(m_classPriors); } setUpEvaluator(); if (m_theInstances.classAttribute().isNumeric()) { m_disTransform = new weka.filters.unsupervised.attribute.Discretize(); m_classIsNominal = false; // use binned discretisation if the class is numeric ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).setBins(10); ((weka.filters.unsupervised.attribute.Discretize) m_disTransform).setInvertSelection(true); // Discretize all attributes EXCEPT the class String rangeList = ""; rangeList += (m_theInstances.classIndex() + 1); // System.out.println("The class col: "+m_theInstances.classIndex()); ((weka.filters.unsupervised.attribute.Discretize) m_disTransform) .setAttributeIndices(rangeList); } else { m_disTransform = new weka.filters.supervised.attribute.Discretize(); ((weka.filters.supervised.attribute.Discretize) m_disTransform).setUseBetterEncoding(true); m_classIsNominal = true; } m_disTransform.setInputFormat(m_theInstances); m_theInstances = Filter.useFilter(m_theInstances, m_disTransform); m_numAttributes = m_theInstances.numAttributes(); m_numInstances = m_theInstances.numInstances(); m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute()); // Perform the search int[] selected = m_search.search(m_evaluator, m_theInstances); m_decisionFeatures = new int[selected.length + 1]; System.arraycopy(selected, 0, m_decisionFeatures, 0, selected.length); m_decisionFeatures[m_decisionFeatures.length - 1] = m_theInstances.classIndex(); // reduce instances to selected features m_delTransform = new Remove(); m_delTransform.setInvertSelection(true); // set features to keep m_delTransform.setAttributeIndicesArray(m_decisionFeatures); m_delTransform.setInputFormat(m_theInstances); m_dtInstances = Filter.useFilter(m_theInstances, m_delTransform); // reset the number of attributes m_numAttributes = m_dtInstances.numAttributes(); // create hash table m_entries = new Hashtable((int) (m_dtInstances.numInstances() * 1.5)); // insert instances into the hash table for (int i = 0; i < m_numInstances; i++) { Instance inst = m_dtInstances.instance(i); insertIntoTable(inst, null); } // Replace the global table majority with nearest neighbour? if (m_useIBk) { m_ibk = new IBk(); m_ibk.buildClassifier(m_theInstances); } // Save memory if (m_saveMemory) { m_theInstances = new Instances(m_theInstances, 0); m_dtInstances = new Instances(m_dtInstances, 0); } m_evaluation = null; }