/**
 * Selects the best value for k by hold-one-out cross-validation. If the class attribute is
 * nominal, classification error is minimised; if the class attribute is numeric, mean absolute
 * error is minimised.
 */
protected void crossValidate() {

  try {
    if (m_NNSearch instanceof weka.core.neighboursearch.CoverTree)
      throw new Exception("CoverTree doesn't support hold-one-out "
          + "cross-validation. Use some other NN method.");

    double[] performanceStats = new double[m_kNNUpper];
    double[] performanceStatsSq = new double[m_kNNUpper];

    for (int i = 0; i < m_kNNUpper; i++) {
      performanceStats[i] = 0;
      performanceStatsSq[i] = 0;
    }

    m_kNN = m_kNNUpper;
    Instance instance;
    Instances neighbours;
    double[] origDistances, convertedDistances;
    for (int i = 0; i < m_Train.numInstances(); i++) {
      if (m_Debug && (i % 50 == 0)) {
        System.err.print("Cross validating " + i + "/" + m_Train.numInstances() + "\r");
      }
      instance = m_Train.instance(i);
      neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
      origDistances = m_NNSearch.getDistances();

      // Update the performance stats for each candidate k, from m_kNNUpper down to 1
      for (int j = m_kNNUpper - 1; j >= 0; j--) {
        convertedDistances = new double[origDistances.length];
        System.arraycopy(origDistances, 0, convertedDistances, 0, origDistances.length);
        double[] distribution = makeDistribution(neighbours, convertedDistances);
        double thisPrediction = Utils.maxIndex(distribution);
        if (m_Train.classAttribute().isNumeric()) {
          thisPrediction = distribution[0];
          double err = thisPrediction - instance.classValue();
          performanceStatsSq[j] += err * err;   // Squared error
          performanceStats[j] += Math.abs(err); // Absolute error
        } else {
          if (thisPrediction != instance.classValue()) {
            performanceStats[j]++; // Classification error
          }
        }
        if (j >= 1) {
          neighbours = pruneToK(neighbours, convertedDistances, j);
        }
      }
    }

    // Display the results of the cross-validation
    for (int i = 0; i < m_kNNUpper; i++) {
      if (m_Debug) {
        System.err.print("Hold-one-out performance of " + (i + 1) + " neighbours ");
      }
      if (m_Train.classAttribute().isNumeric()) {
        if (m_Debug) {
          if (m_MeanSquared) {
            System.err.println("(RMSE) = "
                + Math.sqrt(performanceStatsSq[i] / m_Train.numInstances()));
          } else {
            System.err.println("(MAE) = " + performanceStats[i] / m_Train.numInstances());
          }
        }
      } else {
        if (m_Debug) {
          System.err.println("(%ERR) = "
              + 100.0 * performanceStats[i] / m_Train.numInstances());
        }
      }
    }

    // Check through the performance stats and select the best
    // k value (or the lowest k if more than one best)
    double[] searchStats = performanceStats;
    if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
      searchStats = performanceStatsSq;
    }
    double bestPerformance = Double.NaN;
    int bestK = 1;
    for (int i = 0; i < m_kNNUpper; i++) {
      if (Double.isNaN(bestPerformance) || (bestPerformance > searchStats[i])) {
        bestPerformance = searchStats[i];
        bestK = i + 1;
      }
    }
    m_kNN = bestK;
    if (m_Debug) {
      System.err.println("Selected k = " + bestK);
    }

    m_kNNValid = true;
  } catch (Exception ex) {
    throw new Error("Couldn't optimize by cross-validation: " + ex.getMessage());
  }
}
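/*
 * A minimal usage sketch, not part of the original source: the hold-one-out selection above is
 * normally exercised through IBk's public cross-validation option. This assumes Weka 3.x on the
 * classpath and an ARFF file whose class is the last attribute; the file name is a hypothetical
 * placeholder. Note that IBk defers the k selection until the first prediction is requested.
 */
class CrossValidateKDemo {
  public static void main(String[] args) throws Exception {
    weka.core.Instances data =
        new weka.core.converters.ConverterUtils.DataSource("iris.arff").getDataSet();
    data.setClassIndex(data.numAttributes() - 1);
    weka.classifiers.lazy.IBk knn = new weka.classifiers.lazy.IBk();
    knn.setKNN(10);             // upper bound for the k search (m_kNNUpper)
    knn.setCrossValidate(true); // enable hold-one-out selection of k
    knn.buildClassifier(data);
    knn.distributionForInstance(data.instance(0)); // first prediction triggers crossValidate()
    System.out.println("Selected k = " + knn.getKNN());
  }
}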
/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if the distribution can't be computed successfully
 */
public double[] distributionForInstance(Instance instance) throws Exception {

  // default model?
  if (m_ZeroR != null) {
    return m_ZeroR.distributionForInstance(instance);
  }

  if (m_Train.numInstances() == 0) {
    throw new Exception("No training instances!");
  }

  m_NNSearch.addInstanceInfo(instance);

  int k = m_Train.numInstances();
  if (!m_UseAllK && (m_kNN < k)) {
    k = m_kNN;
  }

  Instances neighbours = m_NNSearch.kNearestNeighbours(instance, k);
  double[] distances = m_NNSearch.getDistances();

  if (m_Debug) {
    System.out.println("Test Instance: " + instance);
    System.out.println("For " + k + " kept " + neighbours.numInstances() + " out of "
        + m_Train.numInstances() + " instances.");
  }

  // If the linear NN search has skipped instances, fewer than k neighbours may remain
  if (k > distances.length) {
    k = distances.length;
  }

  if (m_Debug) {
    System.out.println("Instance Distances");
    for (int i = 0; i < distances.length; i++) {
      System.out.println("" + distances[i]);
    }
  }

  // Determine the bandwidth: the distance to the kth neighbour
  double bandwidth = distances[k - 1];

  // Check for bandwidth zero
  if (bandwidth <= 0) {
    // If the kth distance is zero, then give all instances the same weight
    for (int i = 0; i < distances.length; i++) {
      distances[i] = 1;
    }
  } else {
    // Rescale the distances by the bandwidth
    for (int i = 0; i < distances.length; i++) {
      distances[i] = distances[i] / bandwidth;
    }
  }

  // Pass the distances through a weighting kernel
  for (int i = 0; i < distances.length; i++) {
    switch (m_WeightKernel) {
      case LINEAR:
        distances[i] = 1.0001 - distances[i];
        break;
      case EPANECHNIKOV:
        distances[i] = 3 / 4D * (1.0001 - distances[i] * distances[i]);
        break;
      case TRICUBE:
        distances[i] = Math.pow((1.0001 - Math.pow(distances[i], 3)), 3);
        break;
      case CONSTANT:
        distances[i] = 1;
        break;
      case INVERSE:
        distances[i] = 1.0 / (1.0 + distances[i]);
        break;
      case GAUSS:
        distances[i] = Math.exp(-distances[i] * distances[i]);
        break;
    }
  }

  if (m_Debug) {
    System.out.println("Instance Weights");
    for (int i = 0; i < distances.length; i++) {
      System.out.println("" + distances[i]);
    }
  }

  // Set the weights on the training data
  double sumOfWeights = 0, newSumOfWeights = 0;
  for (int i = 0; i < distances.length; i++) {
    double weight = distances[i];
    Instance inst = neighbours.instance(i);
    sumOfWeights += inst.weight();
    newSumOfWeights += inst.weight() * weight;
    inst.setWeight(inst.weight() * weight);
  }

  // Rescale the weights so that their sum is unchanged
  for (int i = 0; i < neighbours.numInstances(); i++) {
    Instance inst = neighbours.instance(i);
    inst.setWeight(inst.weight() * sumOfWeights / newSumOfWeights);
  }

  // Create a weighted classifier
  m_Classifier.buildClassifier(neighbours);

  if (m_Debug) {
    System.out.println("Classifying test instance: " + instance);
    System.out.println("Built base classifier:\n" + m_Classifier.toString());
  }

  // Return the classifier's predictions
  return m_Classifier.distributionForInstance(instance);
}
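/*
 * A minimal standalone sketch, not part of the original source: the kernel weighting step
 * above in isolation. Distances are assumed to be already rescaled by the bandwidth (the
 * distance to the kth neighbour), so they fall in [0, 1]; the 1.0001 offset keeps the kth
 * neighbour's weight strictly positive under the compact kernels. The constants below are
 * local placeholders mirroring the m_WeightKernel cases, not Weka's actual values.
 */
class KernelWeightSketch {
  static final int LINEAR = 0, EPANECHNIKOV = 1, TRICUBE = 2, INVERSE = 3, GAUSS = 4,
      CONSTANT = 5;

  /** Maps bandwidth-scaled distances to kernel weights. */
  static double[] weights(double[] scaled, int kernel) {
    double[] w = new double[scaled.length];
    for (int i = 0; i < scaled.length; i++) {
      double d = scaled[i];
      switch (kernel) {
        case LINEAR:       w[i] = 1.0001 - d; break;
        case EPANECHNIKOV: w[i] = 3 / 4D * (1.0001 - d * d); break;
        case TRICUBE:      w[i] = Math.pow(1.0001 - Math.pow(d, 3), 3); break;
        case INVERSE:      w[i] = 1.0 / (1.0 + d); break;
        case GAUSS:        w[i] = Math.exp(-d * d); break;
        default:           w[i] = 1; break; // CONSTANT
      }
    }
    return w;
  }

  public static void main(String[] args) {
    double[] scaled = {0.0, 0.5, 1.0}; // nearest, middle, and kth neighbour
    for (double w : weights(scaled, EPANECHNIKOV)) {
      System.out.println(w); // ~0.75, ~0.56, ~7.5e-5: weight decays with distance
    }
  }
}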
/**
 * Calculates the class membership probabilities for the given test instance. A naive Bayes
 * model is built on the training data plus the nearest neighbours, where each neighbour is
 * added repeatedly in rough inverse proportion to its distance from the test instance.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if an error occurred during the prediction
 */
public double[] distributionForInstance(Instance instance) throws Exception {

  if (m_Train.numInstances() == 0) {
    return m_defaultModel.distributionForInstance(instance);
  }

  // Enforce the window size by deleting the oldest training instances
  if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
    m_kNNValid = false;
    boolean deletedInstance = false;
    while (m_Train.numInstances() > m_WindowSize) {
      m_Train.delete(0);
      deletedInstance = true;
    }
    // Rebuild the data structure; KDTree currently can't delete
    if (deletedInstance) {
      m_NNSearch.setInstances(m_Train);
    }
  }

  // Select k by cross-validation
  if (!m_kNNValid && m_CrossValidate && (m_kNNUpper >= 1)) {
    crossValidate();
  }

  m_NNSearch.addInstanceInfo(instance);

  // Request up to 1000 neighbours (effectively all instances for small training sets)
  m_kNN = 1000;
  Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
  double[] distances = m_NNSearch.getDistances();

  // Start from a copy of the training data without missing-class instances ...
  Instances instances = new Instances(m_Train);
  instances.deleteWithMissingClass();
  Instances newTrain = new Instances(instances, 0, instances.numInstances());

  // ... then add each neighbour floor(10 / (distance + 0.1)) times, so that closer
  // neighbours are represented more heavily in the training set
  for (int k = 0; k < neighbours.numInstances(); k++) {
    Instance insk = neighbours.instance(k);
    double dis = distances[k] + 0.1;
    dis = (1 / dis) * 10;
    int weightnum = (int) dis;
    for (int s = 0; s < weightnum; s++) {
      newTrain.add(insk);
    }
  }

  // Build a naive Bayes model on the reweighted data and return its distribution
  NaiveBayes nb = new NaiveBayes();
  nb.buildClassifier(newTrain);
  return nb.distributionForInstance(instance);
}
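/*
 * A minimal alternative sketch, not part of the original source: the replication loop above
 * approximates a weight of 10 / (distance + 0.1) by integer truncation, which inflates the
 * training set and discards the fractional part. Since weka.classifiers.bayes.NaiveBayes
 * implements WeightedInstancesHandler, a similar effect can be had by setting fractional
 * instance weights directly on the neighbours. Assumes Weka 3.x on the classpath; the helper
 * name is hypothetical.
 */
class WeightedNeighboursSketch {
  /** Builds a naive Bayes model on neighbours weighted by inverse distance. */
  static weka.classifiers.bayes.NaiveBayes buildLocalModel(
      weka.core.Instances neighbours, double[] distances) throws Exception {
    weka.core.Instances weighted = new weka.core.Instances(neighbours); // copies instances
    for (int i = 0; i < weighted.numInstances(); i++) {
      // Same weighting scheme as the replication loop above, but without truncation
      weighted.instance(i).setWeight(10.0 / (distances[i] + 0.1));
    }
    weka.classifiers.bayes.NaiveBayes nb = new weka.classifiers.bayes.NaiveBayes();
    nb.buildClassifier(weighted);
    return nb;
  }
}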