/**
 * Gets the current settings of IBk.
 *
 * @return an array of strings suitable for passing to setOptions()
 */
public String[] getOptions() {

  String[] options = new String[11];
  int current = 0;
  options[current++] = "-K";
  options[current++] = "" + getKNN();
  options[current++] = "-W";
  options[current++] = "" + m_WindowSize;
  if (getCrossValidate()) {
    options[current++] = "-X";
  }
  if (getMeanSquared()) {
    options[current++] = "-E";
  }
  if (m_DistanceWeighting == WEIGHT_INVERSE) {
    options[current++] = "-I";
  } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
    options[current++] = "-F";
  }
  options[current++] = "-A";
  options[current++] = m_NNSearch.getClass().getName() + " "
    + Utils.joinOptions(m_NNSearch.getOptions());

  while (current < options.length) {
    options[current++] = "";
  }

  return options;
}
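// Minimal usage sketch (not part of this class; assumes the standard Weka
// IBk API, with illustrative values): the array returned above is meant to
// round-trip through setOptions().
//
//   IBk knn = new IBk();
//   knn.setKNN(3);
//   knn.setCrossValidate(true);
//   String[] opts = knn.getOptions();   // e.g. {"-K","3","-W","0","-X","-A",...}
//   IBk copy = new IBk();
//   copy.setOptions(opts);              // reconstructs equivalent settings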
/**
 * Returns an enumeration of the additional measure names produced by the
 * neighbour search algorithm, plus the chosen K in case cross-validation is
 * enabled.
 *
 * @return an enumeration of the measure names
 */
public Enumeration enumerateMeasures() {
  if (m_CrossValidate) {
    Enumeration enm = m_NNSearch.enumerateMeasures();
    Vector measures = new Vector();
    while (enm.hasMoreElements()) {
      measures.add(enm.nextElement());
    }
    measures.add("measureKNN");
    return measures.elements();
  } else {
    return m_NNSearch.enumerateMeasures();
  }
}
/**
 * Generates the classifier.
 *
 * @param instances set of instances serving as training data
 * @throws Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances instances) throws Exception {

  if (!(m_Classifier instanceof WeightedInstancesHandler)) {
    throw new IllegalArgumentException("Classifier must be a "
      + "WeightedInstancesHandler!");
  }

  // can classifier handle the data?
  getCapabilities().testWithFail(instances);

  // remove instances with missing class
  instances = new Instances(instances);
  instances.deleteWithMissingClass();

  // only class? -> build ZeroR model
  if (instances.numAttributes() == 1) {
    System.err.println("Cannot build model (only class attribute present in data!), "
      + "using ZeroR model instead!");
    m_ZeroR = new weka.classifiers.rules.ZeroR();
    m_ZeroR.buildClassifier(instances);
    return;
  } else {
    m_ZeroR = null;
  }

  m_Train = new Instances(instances, 0, instances.numInstances());

  m_NNSearch.setInstances(m_Train);
}
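// Usage sketch (assumes the standard Weka LWL API; "trainingData" is a
// hypothetical weka.core.Instances). The base learner must implement
// WeightedInstancesHandler, otherwise buildClassifier() above throws:
//
//   LWL lwl = new LWL();
//   lwl.setClassifier(new weka.classifiers.trees.DecisionStump());
//   lwl.buildClassifier(trainingData);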
/**
 * Adds the supplied instance to the training set.
 *
 * @param instance the instance to add
 * @throws Exception if instance could not be incorporated successfully
 */
public void updateClassifier(Instance instance) throws Exception {

  if (m_Train == null) {
    throw new Exception("No training instance structure set!");
  } else if (!m_Train.equalHeaders(instance.dataset())) {
    throw new Exception("Incompatible instance types\n"
      + m_Train.equalHeadersMsg(instance.dataset()));
  }
  if (!instance.classIsMissing()) {
    m_NNSearch.update(instance);
    m_Train.add(instance);
  }
}
/**
 * Adds the supplied instance to the training set.
 *
 * @param instance the instance to add
 * @throws Exception if instance could not be incorporated successfully
 */
public void updateClassifier(Instance instance) throws Exception {

  if (!m_Train.equalHeaders(instance.dataset())) {
    throw new Exception("Incompatible instance types");
  }
  if (instance.classIsMissing()) {
    return;
  }

  m_Train.add(instance);
  m_NNSearch.update(instance);
  m_kNNValid = false;
  if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
    boolean deletedInstance = false;
    while (m_Train.numInstances() > m_WindowSize) {
      m_Train.delete(0);
      deletedInstance = true;
    }
    // rebuild the data structure from scratch; KDTree currently can't delete
    if (deletedInstance) {
      m_NNSearch.setInstances(m_Train);
    }
  }
}
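// Sketch of incremental use (assumption: IBk implements Weka's
// UpdateableClassifier contract; "header" and "stream" are hypothetical):
//
//   IBk knn = new IBk();
//   knn.buildClassifier(header);            // dataset supplying the structure
//   while (stream.hasNext()) {
//     knn.updateClassifier(stream.next());  // window trimming happens here
//   }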
/**
 * Gets the current settings of the classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String[] getOptions() {
  String[] superOptions = super.getOptions();
  String[] options = new String[superOptions.length + 6];

  int current = 0;

  options[current++] = "-U";
  options[current++] = "" + getWeightingKernel();
  if ((getKNN() == 0) && m_UseAllK) {
    options[current++] = "-K";
    options[current++] = "-1";
  } else {
    options[current++] = "-K";
    options[current++] = "" + getKNN();
  }
  options[current++] = "-A";
  options[current++] = m_NNSearch.getClass().getName() + " "
    + Utils.joinOptions(m_NNSearch.getOptions());

  System.arraycopy(superOptions, 0, options, current, superOptions.length);

  return options;
}
/**
 * Generates the classifier.
 *
 * @param instances set of instances serving as training data
 * @throws Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances instances) throws Exception {

  // can classifier handle the data?
  getCapabilities().testWithFail(instances);

  // remove instances with missing class
  instances = new Instances(instances);
  instances.deleteWithMissingClass();

  m_NumClasses = instances.numClasses();
  m_ClassType = instances.classAttribute().type();
  m_Train = new Instances(instances, 0, instances.numInstances());

  // Throw away initial instances until within the specified window size
  if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
    m_Train = new Instances(m_Train, m_Train.numInstances() - m_WindowSize,
      m_WindowSize);
  }

  m_NumAttributesUsed = 0.0;
  for (int i = 0; i < m_Train.numAttributes(); i++) {
    if ((i != m_Train.classIndex())
      && (m_Train.attribute(i).isNominal() || m_Train.attribute(i).isNumeric())) {
      m_NumAttributesUsed += 1.0;
    }
  }

  m_NNSearch.setInstances(m_Train);

  // Invalidate any currently cross-validation selected k
  m_kNNValid = false;

  m_defaultModel = new ZeroR();
  m_defaultModel.buildClassifier(instances);
}
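// Example of the window trimming above (illustrative numbers): with 120
// training instances and m_WindowSize == 100, the copy constructor call is
// new Instances(m_Train, 20, 100), which keeps instances 20..119, i.e. the
// most recent 100.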
/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if distribution can't be computed successfully
 */
public double[] distributionForInstance(Instance instance) throws Exception {

  // default model?
  if (m_ZeroR != null) {
    return m_ZeroR.distributionForInstance(instance);
  }

  if (m_Train.numInstances() == 0) {
    throw new Exception("No training instances!");
  }

  m_NNSearch.addInstanceInfo(instance);

  int k = m_Train.numInstances();
  if (!m_UseAllK && (m_kNN < k)) {
    k = m_kNN;
  }

  Instances neighbours = m_NNSearch.kNearestNeighbours(instance, k);
  double[] distances = m_NNSearch.getDistances();

  if (m_Debug) {
    System.out.println("Test Instance: " + instance);
    System.out.println("For " + k + " kept " + neighbours.numInstances()
      + " out of " + m_Train.numInstances() + " instances.");
  }

  // If the search has skipped so many instances that fewer than k
  // neighbours remain, shrink k accordingly.
  if (k > distances.length) {
    k = distances.length;
  }

  if (m_Debug) {
    System.out.println("Instance Distances");
    for (int i = 0; i < distances.length; i++) {
      System.out.println("" + distances[i]);
    }
  }

  // Determine the bandwidth
  double bandwidth = distances[k - 1];

  // Check for bandwidth zero
  if (bandwidth <= 0) {
    // if the kth distance is zero then give all instances the same weight
    for (int i = 0; i < distances.length; i++) {
      distances[i] = 1;
    }
  } else {
    // Rescale the distances by the bandwidth
    for (int i = 0; i < distances.length; i++) {
      distances[i] = distances[i] / bandwidth;
    }
  }

  // Pass the distances through a weighting kernel
  for (int i = 0; i < distances.length; i++) {
    switch (m_WeightKernel) {
      case LINEAR:
        distances[i] = 1.0001 - distances[i];
        break;
      case EPANECHNIKOV:
        distances[i] = 3 / 4D * (1.0001 - distances[i] * distances[i]);
        break;
      case TRICUBE:
        distances[i] = Math.pow((1.0001 - Math.pow(distances[i], 3)), 3);
        break;
      case CONSTANT:
        distances[i] = 1;
        break;
      case INVERSE:
        distances[i] = 1.0 / (1.0 + distances[i]);
        break;
      case GAUSS:
        distances[i] = Math.exp(-distances[i] * distances[i]);
        break;
    }
  }

  if (m_Debug) {
    System.out.println("Instance Weights");
    for (int i = 0; i < distances.length; i++) {
      System.out.println("" + distances[i]);
    }
  }

  // Set the weights on the training data
  double sumOfWeights = 0, newSumOfWeights = 0;
  for (int i = 0; i < distances.length; i++) {
    double weight = distances[i];
    Instance inst = neighbours.instance(i);
    sumOfWeights += inst.weight();
    newSumOfWeights += inst.weight() * weight;
    inst.setWeight(inst.weight() * weight);
  }

  // Rescale weights so that their total sum is preserved
  for (int i = 0; i < neighbours.numInstances(); i++) {
    Instance inst = neighbours.instance(i);
    inst.setWeight(inst.weight() * sumOfWeights / newSumOfWeights);
  }

  // Create a weighted classifier
  m_Classifier.buildClassifier(neighbours);

  if (m_Debug) {
    System.out.println("Classifying test instance: " + instance);
    System.out.println("Built base classifier:\n" + m_Classifier.toString());
  }

  // Return the classifier's predictions
  return m_Classifier.distributionForInstance(instance);
}
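// Worked example of the kernel step above (values illustrative): with
// bandwidth distances[k-1] == 2.0 and raw distances {0.0, 1.0, 2.0},
// rescaling gives {0.0, 0.5, 1.0}; the LINEAR kernel then yields weights
// {1.0001, 0.5001, 0.0001}, so the furthest kept neighbour gets an almost
// zero, but still strictly positive, weight.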
/**
 * Returns the value of the named measure from the neighbour search algorithm.
 *
 * @param additionalMeasureName the name of the measure to query for its value
 * @return the value of the named measure
 * @throws IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {
  return m_NNSearch.getMeasure(additionalMeasureName);
}
/**
 * Returns an enumeration of the additional measure names produced by the
 * neighbour search algorithm.
 *
 * @return an enumeration of the measure names
 */
public Enumeration enumerateMeasures() {
  return m_NNSearch.enumerateMeasures();
}
/**
 * Select the best value for k by hold-one-out cross-validation. If the class
 * attribute is nominal, classification error is minimised. If the class
 * attribute is numeric, mean absolute error is minimised.
 */
protected void crossValidate() {

  try {
    if (m_NNSearch instanceof weka.core.neighboursearch.CoverTree) {
      throw new Exception("CoverTree doesn't support hold-one-out "
        + "cross-validation. Use some other NN method.");
    }

    double[] performanceStats = new double[m_kNNUpper];
    double[] performanceStatsSq = new double[m_kNNUpper];

    for (int i = 0; i < m_kNNUpper; i++) {
      performanceStats[i] = 0;
      performanceStatsSq[i] = 0;
    }

    m_kNN = m_kNNUpper;
    Instance instance;
    Instances neighbours;
    double[] origDistances, convertedDistances;
    for (int i = 0; i < m_Train.numInstances(); i++) {
      if (m_Debug && (i % 50 == 0)) {
        System.err.print("Cross validating " + i + "/"
          + m_Train.numInstances() + "\r");
      }
      instance = m_Train.instance(i);
      neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
      origDistances = m_NNSearch.getDistances();

      for (int j = m_kNNUpper - 1; j >= 0; j--) {
        // Update the performance stats
        convertedDistances = new double[origDistances.length];
        System.arraycopy(origDistances, 0, convertedDistances, 0,
          origDistances.length);
        double[] distribution = makeDistribution(neighbours, convertedDistances);
        double thisPrediction = Utils.maxIndex(distribution);
        if (m_Train.classAttribute().isNumeric()) {
          thisPrediction = distribution[0];
          double err = thisPrediction - instance.classValue();
          performanceStatsSq[j] += err * err;   // Squared error
          performanceStats[j] += Math.abs(err); // Absolute error
        } else {
          if (thisPrediction != instance.classValue()) {
            performanceStats[j]++;              // Classification error
          }
        }
        if (j >= 1) {
          neighbours = pruneToK(neighbours, convertedDistances, j);
        }
      }
    }

    // Display the results of the cross-validation
    for (int i = 0; i < m_kNNUpper; i++) {
      if (m_Debug) {
        System.err.print("Hold-one-out performance of " + (i + 1)
          + " neighbours ");
      }
      if (m_Train.classAttribute().isNumeric()) {
        if (m_Debug) {
          if (m_MeanSquared) {
            System.err.println("(RMSE) = "
              + Math.sqrt(performanceStatsSq[i] / m_Train.numInstances()));
          } else {
            System.err.println("(MAE) = "
              + performanceStats[i] / m_Train.numInstances());
          }
        }
      } else {
        if (m_Debug) {
          System.err.println("(%ERR) = "
            + 100.0 * performanceStats[i] / m_Train.numInstances());
        }
      }
    }

    // Check through the performance stats and select the best
    // k value (or the lowest k if more than one best)
    double[] searchStats = performanceStats;
    if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
      searchStats = performanceStatsSq;
    }
    double bestPerformance = Double.NaN;
    int bestK = 1;
    for (int i = 0; i < m_kNNUpper; i++) {
      if (Double.isNaN(bestPerformance) || (bestPerformance > searchStats[i])) {
        bestPerformance = searchStats[i];
        bestK = i + 1;
      }
    }
    m_kNN = bestK;
    if (m_Debug) {
      System.err.println("Selected k = " + bestK);
    }

    m_kNNValid = true;
  } catch (Exception ex) {
    throw new Error("Couldn't optimize by cross-validation: " + ex.getMessage());
  }
}
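// Worked example of the inner loop above (illustrative): with
// m_kNNUpper == 3, each training instance fetches its 3 nearest neighbours
// once; the j-loop then scores k == 3, prunes the set to the 2 nearest
// (pruneToK), scores k == 2, prunes again and scores k == 1. One neighbour
// query per instance thus serves every candidate value of k.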
/**
 * Returns the value of the named measure from the neighbour search algorithm,
 * plus the chosen K in case cross-validation is enabled.
 *
 * @param additionalMeasureName the name of the measure to query for its value
 * @return the value of the named measure
 * @throws IllegalArgumentException if the named measure is not supported
 */
public double getMeasure(String additionalMeasureName) {
  if (additionalMeasureName.equals("measureKNN")) {
    return m_kNN;
  } else {
    return m_NNSearch.getMeasure(additionalMeasureName);
  }
}
/**
 * Calculates the class membership probabilities for the given test instance.
 *
 * @param instance the instance to be classified
 * @return predicted class probability distribution
 * @throws Exception if an error occurred during the prediction
 */
public double[] distributionForInstance(Instance instance) throws Exception {

  if (m_Train.numInstances() == 0) {
    return m_defaultModel.distributionForInstance(instance);
  }

  if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
    m_kNNValid = false;
    boolean deletedInstance = false;
    while (m_Train.numInstances() > m_WindowSize) {
      m_Train.delete(0);
      deletedInstance = true;
    }
    // rebuild the data structure from scratch; KDTree currently can't delete
    if (deletedInstance) {
      m_NNSearch.setInstances(m_Train);
    }
  }

  // Select k by cross-validation
  if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
    crossValidate();
  }

  m_NNSearch.addInstanceInfo(instance);

  // Note: this overrides any cross-validated k and requests up to 1000
  // neighbours (the search returns at most the number of training instances).
  m_kNN = 1000;
  Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
  double[] distances = m_NNSearch.getDistances();

  Instances instances = new Instances(m_Train);
  instances.deleteWithMissingClass();
  Instances newm_Train = new Instances(instances, 0, instances.numInstances());

  // Weight each neighbour by duplication: an instance at distance d is
  // added floor(10 / (d + 0.1)) times, so closer neighbours dominate.
  for (int k = 0; k < neighbours.numInstances(); k++) {
    Instance insk = neighbours.instance(k);
    double dis = distances[k] + 0.1;
    dis = (1 / dis) * 10;
    int weightnum = (int) dis;
    for (int s = 0; s < weightnum; s++) {
      newm_Train.add(insk);
    }
  }

  // Build a Naive Bayes model on the duplication-weighted neighbourhood
  // and return its prediction for the test instance.
  NaiveBayes nb = new NaiveBayes();
  nb.buildClassifier(newm_Train);
  return nb.distributionForInstance(instance);
}
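// Worked example of the duplication weighting above (illustrative
// distances): d == 0.0 gives floor(10 / 0.1) == 100 copies, d == 0.9 gives
// floor(10 / 1.0) == 10 copies, and d == 9.9 gives floor(10 / 10.0) == 1
// copy; once 10 / (d + 0.1) drops below 1 the neighbour is dropped entirely.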