/**
 * Gets the current settings of the Classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String[] getOptions() {

  String[] superOptions;

  if (m_InitOptions != null) {
    try {
      // Temporarily restore the initial options so super.getOptions() reports
      // the un-tuned configuration, then re-apply the best found options.
      ((OptionHandler) m_Classifier).setOptions((String[]) m_InitOptions.clone());
      superOptions = super.getOptions();
      ((OptionHandler) m_Classifier).setOptions((String[]) m_BestClassifierOptions.clone());
    } catch (Exception e) {
      // Preserve the underlying exception as the cause instead of discarding it
      throw new RuntimeException(
          "CVParameterSelection: could not set options " + "in getOptions().", e);
    }
  } else {
    superOptions = super.getOptions();
  }

  // Layout: -P <spec> per CV parameter, -X <folds>, then the super options
  String[] options = new String[superOptions.length + m_CVParams.size() * 2 + 2];

  int current = 0;
  for (int i = 0; i < m_CVParams.size(); i++) {
    options[current++] = "-P";
    options[current++] = "" + getCVParameter(i);
  }
  options[current++] = "-X";
  options[current++] = "" + getNumFolds();

  System.arraycopy(superOptions, 0, options, current, superOptions.length);
  return options;
}
public weka.core.Instances toWekaInstances() { // attributes FastVector wattrs = new FastVector(); Iterator itr = attributes.iterator(); while (itr.hasNext()) { Attribute attr = (Attribute) itr.next(); wattrs.addElement(attr.toWekaAttribute()); } // data instances weka.core.Instances winsts = new weka.core.Instances(name, wattrs, instances.size()); itr = instances.iterator(); while (itr.hasNext()) { Instance inst = (Instance) itr.next(); Iterator itrval = inst.getValues().iterator(); Iterator itrmis = inst.getMissing().iterator(); double[] vals = new double[wattrs.size()]; for (int i = 0; i < wattrs.size(); i++) { double val = (Double) itrval.next(); if ((Boolean) itrmis.next()) { vals[i] = weka.core.Instance.missingValue(); } else { vals[i] = val; } } weka.core.Instance winst = new weka.core.Instance(1, vals); winst.setDataset(winsts); winsts.add(winst); } winsts.setClassIndex(this.class_index); return winsts; }
/**
 * Compute the minimal data description length of the ruleset if the rule in the given position is
 * NOT deleted.<br>
 * The min_data_DL_if_n_deleted = data_DL_if_n_deleted - potential
 *
 * @param index the index of the rule in question
 * @param expFPRate expected FP/(FP+FN), used in dataDL calculation
 * @param checkErr whether check if error rate >= 0.5
 * @return the minDataDL
 */
public double minDataDLIfExists(int index, double expFPRate, boolean checkErr) {
  // System.out.println("!!!Enter with: ");

  // Aggregate the per-rule statistics into stats for the whole ruleset.
  double[] rulesetStat = new double[6]; // Stats of ruleset if rule exists
  for (int j = 0; j < m_SimpleStats.size(); j++) {
    // Covered stats are cumulative
    rulesetStat[0] += ((double[]) m_SimpleStats.elementAt(j))[0];
    rulesetStat[2] += ((double[]) m_SimpleStats.elementAt(j))[2];
    rulesetStat[4] += ((double[]) m_SimpleStats.elementAt(j))[4];
    if (j == m_SimpleStats.size() - 1) { // Last rule
      rulesetStat[1] = ((double[]) m_SimpleStats.elementAt(j))[1];
      rulesetStat[3] = ((double[]) m_SimpleStats.elementAt(j))[3];
      rulesetStat[5] = ((double[]) m_SimpleStats.elementAt(j))[5];
    }
  }

  // Potential: the DL saving achievable by deleting each of the rules after
  // 'index' (potential() returns NaN when deletion does not help).
  double potential = 0;
  for (int k = index + 1; k < m_SimpleStats.size(); k++) {
    double[] ruleStat = (double[]) getSimpleStats(k);
    double ifDeleted = potential(k, expFPRate, rulesetStat, ruleStat, checkErr);
    if (!Double.isNaN(ifDeleted)) potential += ifDeleted;
  }

  // Data DL of the ruleset without the rule
  // Note that ruleset stats has already been updated to reflect deletion
  // if any potential
  double dataDLWith =
      dataDL(expFPRate, rulesetStat[0], rulesetStat[1], rulesetStat[4], rulesetStat[5]);

  // System.out.println("!!!with: "+dataDLWith + " |potential: "+
  // potential);
  return (dataDLWith - potential);
}
/** * Create the options array to pass to the classifier. The parameter values and positions are * taken from m_ClassifierOptions and m_CVParams. * * @return the options array */ protected String[] createOptions() { String[] options = new String[m_ClassifierOptions.length + 2 * m_CVParams.size()]; int start = 0, end = options.length; // Add the cross-validation parameters and their values for (int i = 0; i < m_CVParams.size(); i++) { CVParameter cvParam = (CVParameter) m_CVParams.elementAt(i); double paramValue = cvParam.m_ParamValue; if (cvParam.m_RoundParam) { // paramValue = (double)((int) (paramValue + 0.5)); paramValue = Math.rint(paramValue); } if (cvParam.m_AddAtEnd) { options[--end] = "" + Utils.doubleToString(paramValue, 4); options[--end] = "-" + cvParam.m_ParamChar; } else { options[start++] = "-" + cvParam.m_ParamChar; options[start++] = "" + Utils.doubleToString(paramValue, 4); } } // Add the static parameters System.arraycopy(m_ClassifierOptions, 0, options, start, m_ClassifierOptions.length); return options; }
/**
 * Method that finds all large itemsets of the given size for the current set of instances.
 *
 * @param index the size of the itemsets to search for (1 for singletons; larger
 *     values extend the previously found itemsets by one item)
 * @exception Exception if an attribute is numeric
 */
private void findLargeItemSets(int index) throws Exception {

  FastVector kMinusOneSets, kSets = new FastVector();
  Hashtable hashtable;
  int i = 0;

  // Find large itemsets

  // of length 1
  if (index == 1) {
    kSets = ItemSet.singletons(m_instances);
    ItemSet.upDateCounters(kSets, m_instances);
    // keep only itemsets whose support reaches the premise count
    kSets = ItemSet.deleteItemSets(kSets, m_premiseCount, Integer.MAX_VALUE);
    if (kSets.size() == 0) return;
    m_Ls.addElement(kSets);
  }

  // of length > 1
  if (index > 1) {
    if (m_Ls.size() > 0) kSets = (FastVector) m_Ls.lastElement();
    m_Ls.removeAllElements();
    i = index - 2;
    kMinusOneSets = kSets;
    // candidate generation: merge the (k-1)-itemsets into k-itemsets
    kSets = ItemSet.mergeAllItemSets(kMinusOneSets, i, m_instances.numInstances());
    // hash the (k-1)-itemsets so candidates with infrequent subsets can be pruned
    hashtable = ItemSet.getHashtable(kMinusOneSets, kMinusOneSets.size());
    m_hashtables.addElement(hashtable);
    kSets = ItemSet.pruneItemSets(kSets, hashtable);
    ItemSet.upDateCounters(kSets, m_instances);
    kSets = ItemSet.deleteItemSets(kSets, m_premiseCount, Integer.MAX_VALUE);
    if (kSets.size() == 0) return;
    m_Ls.addElement(kSets);
  }
}
/**
 * Gets the scheme parameter with the given index.
 *
 * @param index the index for the parameter
 * @return the scheme parameter, or the empty string if the index is out of range
 */
public String getCVParameter(int index) {
  if (index >= m_CVParams.size()) {
    return "";
  }
  return ((CVParameter) m_CVParams.elementAt(index)).toString();
}
/**
 * Get the class distribution predicted by the rule in given position
 *
 * @param index the position index of the rule
 * @return the class distribution, or null if unavailable
 */
public double[] getDistributions(int index) {
  boolean available = (m_Distributions != null) && (index < m_Distributions.size());
  return available ? (double[]) m_Distributions.elementAt(index) : null;
}
/**
 * Remove the last rule in the ruleset as well as its cached stats. Useful when
 * the last rule was added for testing purposes and the test failed.
 */
public void removeLast() {
  int lastIndex = m_Ruleset.size() - 1;
  m_Ruleset.removeElementAt(lastIndex);
  m_Filtered.removeElementAt(lastIndex);
  m_SimpleStats.removeElementAt(lastIndex);
  if (m_Distributions != null) {
    m_Distributions.removeElementAt(lastIndex);
  }
}
/**
 * Get the data after filtering the given rule
 *
 * @param index the index of the rule
 * @return the data covered and uncovered by the rule, or null if unavailable
 */
public Instances[] getFiltered(int index) {
  boolean available = (m_Filtered != null) && (index < m_Filtered.size());
  return available ? (Instances[]) m_Filtered.elementAt(index) : null;
}
/**
 * Handles the various button clicking type activities.
 *
 * @param e a value of type 'ActionEvent'
 */
public void actionPerformed(ActionEvent e) {
  if (e.getSource() == m_ConfigureBut) {
    selectProperty();
  } else if (e.getSource() == m_StatusBox) {
    // notify any listeners
    for (int i = 0; i < m_Listeners.size(); i++) {
      ActionListener temp = ((ActionListener) m_Listeners.elementAt(i));
      temp.actionPerformed(
          new ActionEvent(this, ActionEvent.ACTION_PERFORMED, "Editor status change"));
    }

    // Toggles whether the custom property is used
    if (m_StatusBox.getSelectedIndex() == 0) {
      // "Disabled" entry selected: turn the property iterator off and clear
      // the array editor state.
      m_Exp.setUsePropertyIterator(false);
      m_ConfigureBut.setEnabled(false);
      m_ArrayEditor.setEnabled(false);
      m_ArrayEditor.setValue(null);
      validate();
    } else {
      // "Enabled" entry selected: ensure a property has been chosen,
      // prompting the user if necessary.
      if (m_Exp.getPropertyArray() == null) {
        selectProperty();
      }
      if (m_Exp.getPropertyArray() == null) {
        // No property was selected -- fall back to the disabled entry.
        m_StatusBox.setSelectedIndex(0);
      } else {
        m_Exp.setUsePropertyIterator(true);
        m_ConfigureBut.setEnabled(true);
        m_ArrayEditor.setEnabled(true);
      }
      validate();
    }
  }
}
/** * Set the labels for nominal attribute creation. * * @param labelList a comma separated list of labels * @throws IllegalArgumentException if the labelList was invalid */ public void setNominalLabels(String labelList) { FastVector labels = new FastVector(10); // Split the labelList up into the vector int commaLoc; while ((commaLoc = labelList.indexOf(',')) >= 0) { String label = labelList.substring(0, commaLoc).trim(); if (!label.equals("")) { labels.addElement(label); } else { throw new IllegalArgumentException( "Invalid label list at " + labelList.substring(commaLoc)); } labelList = labelList.substring(commaLoc + 1); } String label = labelList.trim(); if (!label.equals("")) { labels.addElement(label); } // If everything is OK, make the type change m_Labels = labels; if (labels.size() == 0) { m_AttributeType = Attribute.NUMERIC; } else { m_AttributeType = Attribute.NOMINAL; } }
/**
 * Finds the best parameter combination. (recursive for each parameter being optimised).
 *
 * @param depth the index of the parameter to be optimised at this level
 * @param trainData the data the search is based on
 * @param random a random number generator
 * @throws Exception if an error occurs
 */
protected void findParamsByCrossValidation(int depth, Instances trainData, Random random)
    throws Exception {

  if (depth < m_CVParams.size()) {
    CVParameter cvParam = (CVParameter) m_CVParams.elementAt(depth);

    double upper;
    // Coded upper bound: the (lower - upper) difference selects a
    // data-dependent maximum -- 1 means "number of attributes", 2 means
    // "training fold size"; otherwise the explicit upper value is used.
    switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
      case 1:
        upper = m_NumAttributes;
        break;
      case 2:
        upper = m_TrainFoldSize;
        break;
      default:
        upper = cvParam.m_Upper;
        break;
    }
    // Step through this parameter's range, recursing to fix the remaining
    // parameters before evaluating.
    double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1);
    for (cvParam.m_ParamValue = cvParam.m_Lower;
        cvParam.m_ParamValue <= upper;
        cvParam.m_ParamValue += increment) {
      findParamsByCrossValidation(depth + 1, trainData, random);
    }
  } else {
    // All parameters fixed: cross-validate the classifier with this setting.
    Evaluation evaluation = new Evaluation(trainData);

    // Set the classifier options
    String[] options = createOptions();
    if (m_Debug) {
      System.err.print("Setting options for " + m_Classifier.getClass().getName() + ":");
      for (int i = 0; i < options.length; i++) {
        System.err.print(" " + options[i]);
      }
      System.err.println("");
    }
    ((OptionHandler) m_Classifier).setOptions(options);
    for (int j = 0; j < m_NumFolds; j++) {

      // We want to randomize the data the same way for every
      // learning scheme.
      Instances train = trainData.trainCV(m_NumFolds, j, new Random(1));
      Instances test = trainData.testCV(m_NumFolds, j);
      m_Classifier.buildClassifier(train);
      evaluation.setPriors(train);
      evaluation.evaluateModel(m_Classifier, test);
    }
    double error = evaluation.errorRate();
    if (m_Debug) {
      System.err.println("Cross-validated error rate: " + Utils.doubleToString(error, 6, 4));
    }
    // Keep the options of the lowest-error combination seen so far;
    // -99 is the sentinel meaning "no evaluation done yet".
    if ((m_BestPerformance == -99) || (error < m_BestPerformance)) {
      m_BestPerformance = error;
      m_BestClassifierOptions = createOptions();
    }
  }
}
/**
 * Adds the prediction intervals as additional attributes at the end. Since classifiers can
 * returns varying number of intervals per instance, the dataset is filled with missing values for
 * non-existing intervals.
 */
protected void addPredictionIntervals() {
  int maxNum;
  int num;
  int i;
  int n;
  FastVector preds;
  FastVector atts;
  Instances data;
  Instance inst;
  Instance newInst;
  double[] values;
  double[][] predInt;

  // determine the maximum number of intervals
  maxNum = 0;
  preds = m_Evaluation.predictions();
  for (i = 0; i < preds.size(); i++) {
    num = ((NumericPrediction) preds.elementAt(i)).predictionIntervals().length;
    if (num > maxNum) maxNum = num;
  }

  // create new header: the original attributes plus lower/upper/width
  // columns for each possible interval
  atts = new FastVector();
  for (i = 0; i < m_PlotInstances.numAttributes(); i++)
    atts.addElement(m_PlotInstances.attribute(i));
  for (i = 0; i < maxNum; i++) {
    atts.addElement(new Attribute("predictionInterval_" + (i + 1) + "-lowerBoundary"));
    atts.addElement(new Attribute("predictionInterval_" + (i + 1) + "-upperBoundary"));
    atts.addElement(new Attribute("predictionInterval_" + (i + 1) + "-width"));
  }
  data = new Instances(m_PlotInstances.relationName(), atts, m_PlotInstances.numInstances());
  data.setClassIndex(m_PlotInstances.classIndex());

  // update data
  for (i = 0; i < m_PlotInstances.numInstances(); i++) {
    inst = m_PlotInstances.instance(i);
    // copy old values
    values = new double[data.numAttributes()];
    System.arraycopy(inst.toDoubleArray(), 0, values, 0, inst.numAttributes());
    // add interval data; instances with fewer intervals than maxNum get
    // missing values in the remaining interval columns
    predInt = ((NumericPrediction) preds.elementAt(i)).predictionIntervals();
    for (n = 0; n < maxNum; n++) {
      if (n < predInt.length) {
        values[m_PlotInstances.numAttributes() + n * 3 + 0] = predInt[n][0];
        values[m_PlotInstances.numAttributes() + n * 3 + 1] = predInt[n][1];
        values[m_PlotInstances.numAttributes() + n * 3 + 2] = predInt[n][1] - predInt[n][0];
      } else {
        values[m_PlotInstances.numAttributes() + n * 3 + 0] = Utils.missingValue();
        values[m_PlotInstances.numAttributes() + n * 3 + 1] = Utils.missingValue();
        values[m_PlotInstances.numAttributes() + n * 3 + 2] = Utils.missingValue();
      }
    }
    // create new Instance
    newInst = new DenseInstance(inst.weight(), values);
    data.add(newInst);
  }

  m_PlotInstances = data;
}
/**
 * Calculates the performance stats for the default class and return results as a set of
 * Instances. The structure of these Instances is as follows:
 *
 * <p>
 *
 * <ul>
 *   <li><b>True Positives</b>
 *   <li><b>False Negatives</b>
 *   <li><b>False Positives</b>
 *   <li><b>True Negatives</b>
 *   <li><b>False Positive Rate</b>
 *   <li><b>True Positive Rate</b>
 *   <li><b>Precision</b>
 *   <li><b>Recall</b>
 *   <li><b>Fallout</b>
 *   <li><b>Threshold</b> contains the probability threshold that gives rise to the previous
 *       performance values.
 * </ul>
 *
 * <p>For the definitions of these measures, see TwoClassStats
 *
 * <p>
 *
 * @see TwoClassStats
 * @param predictions the predictions to base the curve on
 * @return datapoints as a set of instances, null if no predictions have been made.
 */
public Instances getCurve(FastVector predictions) {
  if (predictions.size() == 0) {
    return null;
  }
  // Use the last class of the first prediction's distribution as the
  // default class for the curve.
  int lastClassIndex = ((NominalPrediction) predictions.elementAt(0)).distribution().length - 1;
  return getCurve(predictions, lastClassIndex);
}
/**
 * Try to reduce the DL of the ruleset by testing removing the rules one by one in reverse order
 * and update all the stats
 *
 * @param expFPRate expected FP/(FP+FN), used in dataDL calculation
 * @param checkErr whether check if error rate >= 0.5
 */
public void reduceDL(double expFPRate, boolean checkErr) {

  boolean needUpdate = false;
  // Aggregate the per-rule statistics into whole-ruleset stats.
  double[] rulesetStat = new double[6];
  for (int j = 0; j < m_SimpleStats.size(); j++) {
    // Covered stats are cumulative
    rulesetStat[0] += ((double[]) m_SimpleStats.elementAt(j))[0];
    rulesetStat[2] += ((double[]) m_SimpleStats.elementAt(j))[2];
    rulesetStat[4] += ((double[]) m_SimpleStats.elementAt(j))[4];
    if (j == m_SimpleStats.size() - 1) { // Last rule
      rulesetStat[1] = ((double[]) m_SimpleStats.elementAt(j))[1];
      rulesetStat[3] = ((double[]) m_SimpleStats.elementAt(j))[3];
      rulesetStat[5] = ((double[]) m_SimpleStats.elementAt(j))[5];
    }
  }

  // Potential: examine rules in reverse order and delete any whose removal
  // saves description length (potential() returns NaN when it does not).
  for (int k = m_SimpleStats.size() - 1; k >= 0; k--) {

    double[] ruleStat = (double[]) m_SimpleStats.elementAt(k);

    // rulesetStat updated
    double ifDeleted = potential(k, expFPRate, rulesetStat, ruleStat, checkErr);
    if (!Double.isNaN(ifDeleted)) {
      /*System.err.println("!!!deleted ("+k+"): save "+ifDeleted
        +" | "+rulesetStat[0]
        +" | "+rulesetStat[1]
        +" | "+rulesetStat[4]
        +" | "+rulesetStat[5]);
      */

      // Removing the last rule keeps cached stats consistent via
      // removeLast(); removing an interior rule invalidates the caches,
      // which are recomputed below via countData().
      if (k == (m_SimpleStats.size() - 1)) removeLast();
      else {
        m_Ruleset.removeElementAt(k);
        needUpdate = true;
      }
    }
  }

  if (needUpdate) {
    m_Filtered = null;
    m_SimpleStats = null;
    countData();
  }
}
/**
 * Adds the statistics encapsulated in the supplied Evaluation object into this one. Does not
 * perform any checks for compatibility between the supplied Evaluation object and this one.
 *
 * @param evaluation the evaluation object to aggregate
 */
public void aggregate(Evaluation evaluation) {
  // Classification counts
  m_Incorrect += evaluation.incorrect();
  m_Correct += evaluation.correct();
  m_Unclassified += evaluation.unclassified();
  m_MissingClass += evaluation.m_MissingClass;
  m_WithClass += evaluation.m_WithClass;

  // Confusion matrix: element-wise sum (assumes matching dimensions)
  if (evaluation.m_ConfusionMatrix != null) {
    double[][] newMatrix = evaluation.confusionMatrix();
    if (newMatrix != null) {
      for (int i = 0; i < m_ConfusionMatrix.length; i++) {
        for (int j = 0; j < m_ConfusionMatrix[i].length; j++) {
          m_ConfusionMatrix[i][j] += newMatrix[i][j];
        }
      }
    }
  }
  // Class priors are copied over from the other evaluation, not summed
  double[] newClassPriors = evaluation.m_ClassPriors;
  if (newClassPriors != null) {
    for (int i = 0; i < this.m_ClassPriors.length; i++) {
      m_ClassPriors[i] = newClassPriors[i];
    }
  }
  // NOTE(review): plain assignment rather than '+=' -- each aggregate call
  // overwrites the priors sum; confirm this is intended.
  m_ClassPriorsSum = evaluation.m_ClassPriorsSum;

  // Accumulated error/cost statistics
  m_TotalCost += evaluation.totalCost();
  m_SumErr += evaluation.m_SumErr;
  m_SumAbsErr += evaluation.m_SumAbsErr;
  m_SumSqrErr += evaluation.m_SumSqrErr;
  m_SumClass += evaluation.m_SumClass;
  m_SumSqrClass += evaluation.m_SumSqrClass;
  m_SumPredicted += evaluation.m_SumPredicted;
  m_SumSqrPredicted += evaluation.m_SumSqrPredicted;
  m_SumClassPredicted += evaluation.m_SumClassPredicted;
  m_SumPriorAbsErr += evaluation.m_SumPriorAbsErr;
  m_SumPriorSqrErr += evaluation.m_SumPriorSqrErr;
  m_SumKBInfo += evaluation.m_SumKBInfo;

  // Margin histogram counts
  double[] newMarginCounts = evaluation.m_MarginCounts;
  if (newMarginCounts != null) {
    for (int i = 0; i < m_MarginCounts.length; i++) {
      m_MarginCounts[i] += newMarginCounts[i];
    }
  }
  m_SumPriorEntropy += evaluation.m_SumPriorEntropy;
  m_SumSchemeEntropy += evaluation.m_SumSchemeEntropy;
  m_TotalSizeOfRegions += evaluation.m_TotalSizeOfRegions;
  m_TotalCoverage += evaluation.m_TotalCoverage;

  // Append the other evaluation's stored predictions, if any
  FastVector predsToAdd = evaluation.m_Predictions;
  if (predsToAdd != null) {
    if (m_Predictions == null) {
      m_Predictions = new FastVector();
    }
    for (int i = 0; i < predsToAdd.size(); i++) {
      m_Predictions.addElement(predsToAdd.elementAt(i));
    }
  }
}
/** * @param predictions the predictions to use * @param classIndex the class index * @return the probabilities */ private double[] getProbabilities(FastVector predictions, int classIndex) { // sort by predicted probability of the desired class. double[] probs = new double[predictions.size()]; for (int i = 0; i < probs.length; i++) { NominalPrediction pred = (NominalPrediction) predictions.elementAt(i); probs[i] = pred.distribution()[classIndex]; } return probs; }
/**
 * Generates the classifier.
 *
 * @param instances set of instances serving as training data
 * @throws Exception if the classifier has not been generated successfully
 */
public void buildClassifier(Instances instances) throws Exception {

  // can classifier handle the data?
  getCapabilities().testWithFail(instances);

  // remove instances with missing class
  Instances trainData = new Instances(instances);
  trainData.deleteWithMissingClass();

  if (!(m_Classifier instanceof OptionHandler)) {
    throw new IllegalArgumentException("Base classifier should be OptionHandler.");
  }
  // Remember the classifier's initial options so they can be restored later
  m_InitOptions = ((OptionHandler) m_Classifier).getOptions();
  m_BestPerformance = -99; // sentinel: no evaluation done yet
  m_NumAttributes = trainData.numAttributes();
  Random random = new Random(m_Seed);
  trainData.randomize(random);
  m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();

  // Check whether there are any parameters to optimize
  if (m_CVParams.size() == 0) {
    m_Classifier.buildClassifier(trainData);
    m_BestClassifierOptions = m_InitOptions;
    return;
  }

  if (trainData.classAttribute().isNominal()) {
    trainData.stratify(m_NumFolds);
  }
  m_BestClassifierOptions = null;

  // Set up m_ClassifierOptions -- take getOptions() and remove
  // those being optimised.
  m_ClassifierOptions = ((OptionHandler) m_Classifier).getOptions();
  for (int i = 0; i < m_CVParams.size(); i++) {
    Utils.getOption(((CVParameter) m_CVParams.elementAt(i)).m_ParamChar, m_ClassifierOptions);
  }
  findParamsByCrossValidation(0, trainData, random);

  // Rebuild the classifier on the full training data using the best options
  String[] options = (String[]) m_BestClassifierOptions.clone();
  ((OptionHandler) m_Classifier).setOptions(options);
  m_Classifier.buildClassifier(trainData);
}
/**
 * Parses a given list of options. Valid options are:
 *
 * <p>-D <br>
 * Turn on debugging output.
 *
 * <p>-S seed <br>
 * Random number seed (default 1).
 *
 * <p>-B classifierstring <br>
 * Classifierstring should contain the full class name of a scheme included for selection followed
 * by options to the classifier (required, option should be used once for each classifier).
 *
 * <p>-X num_folds <br>
 * Use cross validation error as the basis for classifier selection. (default 0, is to use error
 * on the training data instead)
 *
 * <p>
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {

  setDebug(Utils.getFlag('D', options));

  String numFoldsString = Utils.getOption('X', options);
  if (numFoldsString.length() != 0) {
    setNumFolds(Integer.parseInt(numFoldsString));
  } else {
    // default 0: use the error on the training data instead of CV
    setNumFolds(0);
  }

  String randomString = Utils.getOption('S', options);
  if (randomString.length() != 0) {
    setSeed(Integer.parseInt(randomString));
  } else {
    setSeed(1);
  }

  // Iterate through the schemes: each -B option carries a class name
  // followed by that classifier's own options.
  FastVector classifiers = new FastVector();
  while (true) {
    String classifierString = Utils.getOption('B', options);
    if (classifierString.length() == 0) {
      break;
    }
    String[] classifierSpec = Utils.splitOptions(classifierString);
    if (classifierSpec.length == 0) {
      throw new Exception("Invalid classifier specification string");
    }
    String classifierName = classifierSpec[0];
    // Blank out the class name so the rest of the spec can be passed on
    // as the classifier's option array.
    classifierSpec[0] = "";
    classifiers.addElement(Classifier.forName(classifierName, classifierSpec));
  }

  if (classifiers.size() <= 1) {
    throw new Exception("At least two classifiers must be specified" + " with the -B option.");
  } else {
    Classifier[] classifiersArray = new Classifier[classifiers.size()];
    for (int i = 0; i < classifiersArray.length; i++) {
      classifiersArray[i] = (Classifier) classifiers.elementAt(i);
    }
    setClassifiers(classifiersArray);
  }
}
/**
 * Tests that the classifier produces predictions under every available
 * evaluation mode.
 *
 * @throws Exception if something goes wrong during classification
 */
public void testEvaluationMode() throws Exception {
  // (removed unused local 'cind')
  for (int i = 0; i < ThresholdSelector.TAGS_EVAL.length; i++) {
    ((ThresholdSelector) m_Classifier)
        .setEvaluationMode(
            new SelectedTag(ThresholdSelector.TAGS_EVAL[i].getID(), ThresholdSelector.TAGS_EVAL));
    m_Instances.setClassIndex(1);
    FastVector result = useClassifier();
    assertTrue(result.size() != 0);
  }
}
/**
 * Get the list of labels for nominal attribute creation.
 *
 * @return the comma-separated list of labels for nominal attribute creation
 */
public String getNominalLabels() {
  // Use a StringBuilder to avoid quadratic string concatenation in the loop
  StringBuilder labelList = new StringBuilder();
  for (int i = 0; i < m_Labels.size(); i++) {
    if (i > 0) {
      labelList.append(',');
    }
    labelList.append((String) m_Labels.elementAt(i));
  }
  return labelList.toString();
}
/**
 * Tests that the classifier produces predictions for every available
 * designated-class optimization setting.
 *
 * @throws Exception if something goes wrong during classification
 */
public void testDesignatedClass() throws Exception {
  // (removed unused local 'cind')
  for (int i = 0; i < ThresholdSelector.TAGS_OPTIMIZE.length; i++) {
    ((ThresholdSelector) m_Classifier)
        .setDesignatedClass(
            new SelectedTag(
                ThresholdSelector.TAGS_OPTIMIZE[i].getID(), ThresholdSelector.TAGS_OPTIMIZE));
    m_Instances.setClassIndex(1);
    FastVector result = useClassifier();
    assertTrue(result.size() != 0);
  }
}
/** Scales numeric class predictions into shape sizes for plotting in the visualize panel. */ protected void scaleNumericPredictions() { double maxErr; double minErr; double err; int i; Double errd; double temp; maxErr = Double.NEGATIVE_INFINITY; minErr = Double.POSITIVE_INFINITY; // find min/max errors for (i = 0; i < m_PlotSizes.size(); i++) { errd = (Double) m_PlotSizes.elementAt(i); if (errd != null) { err = Math.abs(errd.doubleValue()); if (err < minErr) minErr = err; if (err > maxErr) maxErr = err; } } // scale errors for (i = 0; i < m_PlotSizes.size(); i++) { errd = (Double) m_PlotSizes.elementAt(i); if (errd != null) { err = Math.abs(errd.doubleValue()); if (maxErr - minErr > 0) { temp = (((err - minErr) / (maxErr - minErr)) * (m_MaximumPlotSizeNumeric - m_MinimumPlotSizeNumeric + 1)); m_PlotSizes.setElementAt(new Integer((int) temp) + m_MinimumPlotSizeNumeric, i); } else { m_PlotSizes.setElementAt(new Integer(m_MinimumPlotSizeNumeric), i); } } else { m_PlotSizes.setElementAt(new Integer(m_MinimumPlotSizeNumeric), i); } } }
/**
 * Tests that with no range correction the predicted probabilities of the
 * designated class stay within the expected bounds.
 *
 * @throws Exception if something goes wrong during classification
 */
public void testRangeNone() throws Exception {
  int classValueIndex = 0;
  ((ThresholdSelector) m_Classifier)
      .setDesignatedClass(
          new SelectedTag(ThresholdSelector.OPTIMIZE_0, ThresholdSelector.TAGS_OPTIMIZE));
  ((ThresholdSelector) m_Classifier)
      .setRangeCorrection(
          new SelectedTag(ThresholdSelector.RANGE_NONE, ThresholdSelector.TAGS_RANGE));
  m_Instances.setClassIndex(1);
  FastVector result = useClassifier();
  assertTrue(result.size() != 0);

  // Track the smallest and largest predicted probability of the class.
  double minProb = 0;
  double maxProb = 0;
  for (int i = 0; i < result.size(); i++) {
    NominalPrediction prediction = (NominalPrediction) result.elementAt(i);
    double prob = prediction.distribution()[classValueIndex];
    if ((i == 0) || (prob < minProb)) {
      minProb = prob;
    }
    if ((i == 0) || (prob > maxProb)) {
      maxProb = prob;
    }
  }
  assertTrue("Upper limit shouldn't increase", maxProb <= 1.0);
  assertTrue("Lower limit shouldn'd decrease", minProb >= 0.25);
}
public void testNumXValFolds() throws Exception { try { ((ThresholdSelector) m_Classifier).setNumXValFolds(0); fail("Expected IllegalArgumentException"); } catch (IllegalArgumentException e) { // OK } int cind = 0; for (int i = 2; i < 20; i += 2) { ((ThresholdSelector) m_Classifier).setNumXValFolds(i); m_Instances.setClassIndex(1); FastVector result = useClassifier(); assertTrue(result.size() != 0); } }
/**
 * Add a rule to the ruleset and update the stats
 *
 * @param lastRule the rule to be added
 */
public void addAndUpdate(Rule lastRule) {
  if (m_Ruleset == null) m_Ruleset = new FastVector();
  m_Ruleset.addElement(lastRule);

  // The new rule is evaluated only on the data left uncovered by the
  // previous rules (element [1] of the last filtered pair).
  Instances data = (m_Filtered == null) ? m_Data : ((Instances[]) m_Filtered.lastElement())[1];
  double[] stats = new double[6];
  double[] classCounts = new double[m_Data.classAttribute().numValues()];
  Instances[] filtered = computeSimpleStats(m_Ruleset.size() - 1, data, stats, classCounts);

  // Cache the filtered data, simple stats and class distribution of the rule
  if (m_Filtered == null) m_Filtered = new FastVector();
  m_Filtered.addElement(filtered);

  if (m_SimpleStats == null) m_SimpleStats = new FastVector();
  m_SimpleStats.addElement(stats);

  if (m_Distributions == null) m_Distributions = new FastVector();
  m_Distributions.addElement(classCounts);
}
/**
 * Filter the data according to the ruleset and compute the basic stats:
 * coverage/uncoverage, true/false positive/negatives of each rule
 */
public void countData() {
  // Nothing to do if the stats are already cached or no data/rules exist
  if ((m_Filtered != null) || (m_Ruleset == null) || (m_Data == null)) return;

  int size = m_Ruleset.size();
  m_Filtered = new FastVector(size);
  m_SimpleStats = new FastVector(size);
  m_Distributions = new FastVector(size);
  Instances data = new Instances(m_Data);

  // Each rule is evaluated on the data left uncovered by its predecessors
  for (int i = 0; i < size; i++) {
    double[] stats = new double[6]; // 6 statistics parameters
    double[] classCounts = new double[m_Data.classAttribute().numValues()];
    Instances[] filtered = computeSimpleStats(i, data, stats, classCounts);
    m_Filtered.addElement(filtered);
    m_SimpleStats.addElement(stats);
    m_Distributions.addElement(classCounts);
    data = filtered[1]; // Data not covered
  }
}
/**
 * Static utility function to count the data covered by the rules after the given index in the
 * given rules, and then remove them. It returns the data not covered by the successive rules.
 *
 * @param data the data to be processed
 * @param rules the ruleset
 * @param index the given index
 * @return the data after processing
 */
public static Instances rmCoveredBySuccessives(Instances data, FastVector rules, int index) {
  Instances uncovered = new Instances(data, 0);

  for (int i = 0; i < data.numInstances(); i++) {
    Instance datum = data.instance(i);

    // Keep the instance only if none of the rules after 'index' covers it.
    boolean coveredBySuccessive = false;
    for (int j = index + 1; j < rules.size() && !coveredBySuccessive; j++) {
      if (((Rule) rules.elementAt(j)).covers(datum)) {
        coveredBySuccessive = true;
      }
    }

    if (!coveredBySuccessive) {
      uncovered.add(datum);
    }
  }
  return uncovered;
}
/**
 * Returns description of the cross-validated classifier.
 *
 * @return description of the cross-validated classifier as a string
 */
public String toString() {

  if (m_InitOptions == null) return "CVParameterSelection: No model built yet.";

  // Build with StringBuilder to avoid repeated string reallocation in the loop
  StringBuilder result =
      new StringBuilder(
          "Cross-validated Parameter selection.\n"
              + "Classifier: "
              + m_Classifier.getClass().getName()
              + "\n");
  try {
    for (int i = 0; i < m_CVParams.size(); i++) {
      CVParameter cvParam = (CVParameter) m_CVParams.elementAt(i);
      result
          .append("Cross-validation Parameter: '-")
          .append(cvParam.m_ParamChar)
          .append("'")
          .append(" ranged from ")
          .append(cvParam.m_Lower)
          .append(" to ");
      // Coded upper bound: the (lower - upper) difference selects a
      // data-dependent maximum (1 -> #attributes, 2 -> fold size)
      switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
        case 1:
          result.append(m_NumAttributes);
          break;
        case 2:
          result.append(m_TrainFoldSize);
          break;
        default:
          result.append(cvParam.m_Upper);
          break;
      }
      result.append(" with ").append(cvParam.m_Steps).append(" steps\n");
    }
  } catch (Exception ex) {
    result.append(ex.getMessage());
  }
  result
      .append("Classifier Options: ")
      .append(Utils.joinOptions(m_BestClassifierOptions))
      .append("\n\n")
      .append(m_Classifier.toString());
  return result.toString();
}
/** * Method that finds all association rules. * * @exception Exception if an attribute is numeric */ private void findRulesQuickly() throws Exception { FastVector[] rules; RuleGeneration currentItemSet; // Build rules for (int j = 0; j < m_Ls.size(); j++) { FastVector currentItemSets = (FastVector) m_Ls.elementAt(j); Enumeration enumItemSets = currentItemSets.elements(); while (enumItemSets.hasMoreElements()) { currentItemSet = new RuleGeneration((ItemSet) enumItemSets.nextElement()); m_best = currentItemSet.generateRules( m_numRules, m_midPoints, m_priors, m_expectation, m_instances, m_best, m_count); m_count = currentItemSet.m_count; if (!m_bestChanged && currentItemSet.m_change) m_bestChanged = true; // update minimum expected predictive accuracy to get into the n best if (m_best.size() > 0) m_expectation = ((RuleItem) m_best.first()).accuracy(); else m_expectation = 0; } } }