/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * <p>If the default model ({@code m_ZeroR}) is set, its prediction is returned
 * directly. Otherwise the nearest neighbours of the test instance are
 * retrieved, their distances are rescaled by the bandwidth (the distance of
 * the k-th neighbour) and passed through the configured weighting kernel. The
 * resulting weights are multiplied into the neighbours' weights, which are then
 * rescaled so that their total equals the total of the original weights. The
 * base classifier is rebuilt on the weighted neighbours and its prediction is
 * returned.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if distribution can't be computed successfully, in
 *           particular if there are no training instances or no neighbours
 */
public double[] distributionForInstance(Instance instance) throws Exception {

  // default model?
  if (m_ZeroR != null) {
    return m_ZeroR.distributionForInstance(instance);
  }

  if (m_Train.numInstances() == 0) {
    throw new Exception("No training instances!");
  }

  m_NNSearch.addInstanceInfo(instance);

  // Use all training instances unless a smaller k has been requested
  int k = m_Train.numInstances();
  if (!m_UseAllK && (m_kNN < k)) {
    k = m_kNN;
  }

  Instances neighbours = m_NNSearch.kNearestNeighbours(instance, k);
  double[] distances = m_NNSearch.getDistances();

  if (m_Debug) {
    System.out.println("Test Instance: " + instance);
    System.out.println("For " + k + " kept " + neighbours.numInstances()
      + " out of " + m_Train.numInstances() + " instances.");
  }

  // IF LinearNN has skipped so much that <k neighbours are remaining.
  if (k > distances.length) {
    k = distances.length;
  }

  if (m_Debug) {
    System.out.println("Instance Distances");
    for (int i = 0; i < distances.length; i++) {
      System.out.println("" + distances[i]);
    }
  }

  // Without at least one neighbour there is no k-th distance to use as the
  // bandwidth; fail with a clear message instead of indexing distances[-1].
  if (k <= 0) {
    throw new Exception("No neighbours available to build the local model!");
  }

  // Determine the bandwidth
  double bandwidth = distances[k - 1];

  // Check for bandwidth zero
  if (bandwidth <= 0) {
    // if the kth distance is zero then give all instances the same weight
    for (int i = 0; i < distances.length; i++) {
      distances[i] = 1;
    }
  } else {
    // Rescale the distances by the bandwidth
    for (int i = 0; i < distances.length; i++) {
      distances[i] = distances[i] / bandwidth;
    }
  }

  // Pass the distances through a weighting kernel; from here on the array
  // holds weights rather than distances
  for (int i = 0; i < distances.length; i++) {
    switch (m_WeightKernel) {
      case LINEAR:
        distances[i] = 1.0001 - distances[i];
        break;
      case EPANECHNIKOV:
        distances[i] = 3 / 4D * (1.0001 - distances[i] * distances[i]);
        break;
      case TRICUBE:
        distances[i] = Math.pow((1.0001 - Math.pow(distances[i], 3)), 3);
        break;
      case CONSTANT:
        distances[i] = 1;
        break;
      case INVERSE:
        distances[i] = 1.0 / (1.0 + distances[i]);
        break;
      case GAUSS:
        distances[i] = Math.exp(-distances[i] * distances[i]);
        break;
      default:
        // unrecognised kernel: keep the rescaled distance as it is
        break;
    }
  }

  if (m_Debug) {
    System.out.println("Instance Weights");
    for (int i = 0; i < distances.length; i++) {
      System.out.println("" + distances[i]);
    }
  }

  // Set the weights on the training data
  double sumOfWeights = 0;
  double newSumOfWeights = 0;
  for (int i = 0; i < distances.length; i++) {
    double weight = distances[i];
    Instance inst = neighbours.instance(i);
    sumOfWeights += inst.weight();
    newSumOfWeights += inst.weight() * weight;
    inst.setWeight(inst.weight() * weight);
  }

  // Rescale weights so that they sum to the original total weight
  for (int i = 0; i < neighbours.numInstances(); i++) {
    Instance inst = neighbours.instance(i);
    inst.setWeight(inst.weight() * sumOfWeights / newSumOfWeights);
  }

  // Create a weighted classifier
  m_Classifier.buildClassifier(neighbours);

  if (m_Debug) {
    System.out.println("Classifying test instance: " + instance);
    System.out.println("Built base classifier:\n" + m_Classifier.toString());
  }

  // Return the classifier's predictions
  return m_Classifier.distributionForInstance(instance);
}
/** * Calculates the class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @throws Exception if an error occurred during the prediction */ public double[] distributionForInstance(Instance instance) throws Exception { NaiveBayes nb = new NaiveBayes(); // System.out.println("number of instances "+m_Train.numInstances()); if (m_Train.numInstances() == 0) { // throw new Exception("No training instances!"); return m_defaultModel.distributionForInstance(instance); } if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) { m_kNNValid = false; boolean deletedInstance = false; while (m_Train.numInstances() > m_WindowSize) { m_Train.delete(0); } // rebuild datastructure KDTree currently can't delete if (deletedInstance == true) m_NNSearch.setInstances(m_Train); } // Select k by cross validation if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) { crossValidate(); } m_NNSearch.addInstanceInfo(instance); m_kNN = 1000; Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN); double[] distances = m_NNSearch.getDistances(); // System.out.println("--------------classify instance--------- "); // System.out.println("neighbours.numInstances"+neighbours.numInstances()); // System.out.println("distances.length"+distances.length); // System.out.println("--------------classify instance--------- "); /* for (int k = 0; k < distances.length; k++) { //System.out.println("-------"); //System.out.println("distance of "+k+" "+distances[k]); //System.out.println("instance of "+k+" "+neighbours.instance(k)); //distances[k] = distances[k]+0.1; //System.out.println("------- after add 0.1"); //System.out.println("distance of "+k+" "+distances[k]); } */ Instances instances = new Instances(m_Train); // int attrnum = instances.numAttributes(); instances.deleteWithMissingClass(); Instances newm_Train = new Instances(instances, 0, instances.numInstances()); for 
(int k = 0; k < neighbours.numInstances(); k++) { // System.out.println("-------"); // Instance in = new Instance(); Instance insk = neighbours.instance(k); // System.out.println("instance "+k+" "+neighbours.instance(k)); // System.out.println("-------"); double dis = distances[k] + 0.1; // System.out.println("dis "+dis); dis = (1 / dis) * 10; // System.out.println("1/dis "+dis); int weightnum = (int) dis; // System.out.println("weightnum "+weightnum); for (int s = 0; s < weightnum; s++) { newm_Train.add(insk); } } // System.out.println("number of instances "+newm_Train.numInstances()); /* for (int k = 0; k < newm_Train.numInstances(); k++) { System.out.println("-------"); System.out.println("instance "+k+" "+newm_Train.instance(k)); System.out.println("-------"); } /* for (int k = 0; k < distances.length; k++) { System.out.println("-------"); System.out.println("distance of "+k+" "+distances[k]); System.out.println("-------"); }*/ nb.buildClassifier(newm_Train); double[] dis = nb.distributionForInstance(instance); // double[] distribution = makeDistribution(neighbours, distances); return dis; // return distribution; }