Пример #1
0
  /**
   * Computes the minimal data description length of the ruleset assuming the rule at the given
   * position is NOT deleted.<br>
   * The result is the data DL of the ruleset minus the "potential", i.e. the sum of the DL savings
   * that potential() reports for the rules after {@code index}.
   *
   * @param index the index of the rule in question
   * @param expFPRate expected FP/(FP+FN), used in dataDL calculation
   * @param checkErr whether to check if the error rate is >= 0.5
   * @return the minimal data DL if the rule exists
   */
  public double minDataDLIfExists(int index, double expFPRate, boolean checkErr) {
    // Ruleset-level stats: entries 0, 2 and 4 are summed over every rule, while entries 1, 3 and
    // 5 are taken from the last rule alone.
    int lastRule = m_SimpleStats.size() - 1;
    double[] rulesetStat = new double[6];
    for (int r = 0; r <= lastRule; r++) {
      double[] stat = (double[]) m_SimpleStats.elementAt(r);
      rulesetStat[0] += stat[0];
      rulesetStat[2] += stat[2];
      rulesetStat[4] += stat[4];
      if (r == lastRule) {
        rulesetStat[1] = stat[1];
        rulesetStat[3] = stat[3];
        rulesetStat[5] = stat[5];
      }
    }

    // Sum the savings of every later rule for which potential() returns a number (not NaN).
    // potential() is handed rulesetStat and may update it to reflect such deletions.
    double totalPotential = 0;
    for (int k = index + 1; k < m_SimpleStats.size(); k++) {
      double[] ruleStat = (double[]) getSimpleStats(k);
      double saving = potential(k, expFPRate, rulesetStat, ruleStat, checkErr);
      if (!Double.isNaN(saving)) {
        totalPotential += saving;
      }
    }

    // Data DL of the ruleset with the rule kept, based on the (possibly updated) stats
    double dataDLWith =
        dataDL(expFPRate, rulesetStat[0], rulesetStat[1], rulesetStat[4], rulesetStat[5]);
    return dataDLWith - totalPotential;
  }
  /**
   * Appends the prediction intervals as extra attributes at the end of m_PlotInstances. Each
   * interval contributes three numeric columns (lower boundary, upper boundary and width). Since
   * classifiers can return a varying number of intervals per instance, instances with fewer
   * intervals than the maximum get missing values in the unused columns.
   */
  protected void addPredictionIntervals() {
    FastVector preds = m_Evaluation.predictions();

    // the longest interval list over all predictions determines the number of new columns
    int maxNum = 0;
    for (int p = 0; p < preds.size(); p++) {
      int count = ((NumericPrediction) preds.elementAt(p)).predictionIntervals().length;
      maxNum = Math.max(maxNum, count);
    }

    // header = old attributes followed by three columns per interval
    int numOldAtts = m_PlotInstances.numAttributes();
    FastVector atts = new FastVector();
    for (int a = 0; a < numOldAtts; a++) {
      atts.addElement(m_PlotInstances.attribute(a));
    }
    for (int n = 0; n < maxNum; n++) {
      String prefix = "predictionInterval_" + (n + 1);
      atts.addElement(new Attribute(prefix + "-lowerBoundary"));
      atts.addElement(new Attribute(prefix + "-upperBoundary"));
      atts.addElement(new Attribute(prefix + "-width"));
    }
    Instances data =
        new Instances(m_PlotInstances.relationName(), atts, m_PlotInstances.numInstances());
    data.setClassIndex(m_PlotInstances.classIndex());

    // copy every instance and append its interval values (prediction i belongs to instance i)
    for (int i = 0; i < m_PlotInstances.numInstances(); i++) {
      Instance inst = m_PlotInstances.instance(i);
      double[] values = new double[data.numAttributes()];
      System.arraycopy(inst.toDoubleArray(), 0, values, 0, inst.numAttributes());

      double[][] predInt = ((NumericPrediction) preds.elementAt(i)).predictionIntervals();
      for (int n = 0; n < maxNum; n++) {
        int base = numOldAtts + n * 3;
        boolean present = n < predInt.length;
        values[base] = present ? predInt[n][0] : Utils.missingValue();
        values[base + 1] = present ? predInt[n][1] : Utils.missingValue();
        values[base + 2] = present ? predInt[n][1] - predInt[n][0] : Utils.missingValue();
      }
      data.add(new DenseInstance(inst.weight(), values));
    }

    m_PlotInstances = data;
  }
Пример #3
0
  /**
   * Gets the list of labels for nominal attribute creation as one comma-separated string.
   *
   * @return the labels joined with "," (the empty string if there are no labels)
   */
  public String getNominalLabels() {
    // StringBuilder avoids the quadratic cost of repeated String concatenation inside the loop
    StringBuilder labelList = new StringBuilder();
    for (int i = 0; i < m_Labels.size(); i++) {
      if (i > 0) {
        labelList.append(',');
      }
      labelList.append((String) m_Labels.elementAt(i));
    }
    return labelList.toString();
  }
Пример #4
0
  /**
   * Splits the dataset into the instances covered and not covered by the rule at the given index.
   * The corresponding simple statistics (and, optionally, the class distribution of the covered
   * instances) are accumulated into the supplied arrays, which can later be obtained through
   * getSimpleStats() and getDistributions().<br>
   * All statistics are weighted. Layout of {@code stats}: [0] coverage, [1] weight not covered,
   * [2] true positives, [3] true negatives, [4] false positives, [5] false negatives.
   *
   * @param index the given index, assumed to be correct
   * @param insts the dataset to be covered by the rule
   * @param stats the array accumulating the stats (side-effected)
   * @param dist the array accumulating the class distribution of the covered instances
   *     (side-effected); may be null if the distribution is not needed
   * @return element 0 holds the covered instances, element 1 the instances not covered
   */
  private Instances[] computeSimpleStats(
      int index, Instances insts, double[] stats, double[] dist) {
    Rule rule = (Rule) m_Ruleset.elementAt(index);
    int total = insts.numInstances();

    Instances covered = new Instances(insts, total);
    Instances uncovered = new Instances(insts, total);

    for (int i = 0; i < total; i++) {
      Instance datum = insts.instance(i);
      double weight = datum.weight();

      if (!rule.covers(datum)) {
        uncovered.add(datum);
        stats[1] += weight;
        boolean positive = (int) datum.classValue() == (int) rule.getConsequent();
        stats[positive ? 5 : 3] += weight; // false negative : true negative
        continue;
      }

      covered.add(datum);
      stats[0] += weight;
      boolean positive = (int) datum.classValue() == (int) rule.getConsequent();
      stats[positive ? 2 : 4] += weight; // true positive : false positive
      if (dist != null) {
        dist[(int) datum.classValue()] += weight;
      }
    }

    return new Instances[] {covered, uncovered};
  }
Пример #5
0
  /**
   * Gets the class distribution predicted by the rule at the given position.
   *
   * @param index the position index of the rule
   * @return the stored class distribution, or null if no distributions are stored or the index is
   *     not below the number of stored distributions
   */
  public double[] getDistributions(int index) {
    boolean available = m_Distributions != null && index < m_Distributions.size();
    return available ? (double[]) m_Distributions.elementAt(index) : null;
  }
Пример #6
0
  /**
   * Gets the data after filtering with the given rule.
   *
   * @param index the index of the rule
   * @return the stored covered/uncovered data for the rule, or null if no filtered data is stored
   *     or the index is not below the number of stored entries
   */
  public Instances[] getFiltered(int index) {
    boolean available = m_Filtered != null && index < m_Filtered.size();
    return available ? (Instances[]) m_Filtered.elementAt(index) : null;
  }
  /**
   * Creates the options array to pass to the classifier. The cross-validated parameters (taken
   * from m_CVParams, rounded if requested, formatted with 4 decimals) are placed either at the
   * front or at the end of the array. The static options from m_ClassifierOptions fill the
   * slots starting right after the front parameters.
   *
   * @return the options array
   */
  protected String[] createOptions() {
    String[] options = new String[m_ClassifierOptions.length + 2 * m_CVParams.size()];
    int front = 0;
    int back = options.length;

    for (int p = 0; p < m_CVParams.size(); p++) {
      CVParameter cvParam = (CVParameter) m_CVParams.elementAt(p);
      double value =
          cvParam.m_RoundParam ? Math.rint(cvParam.m_ParamValue) : cvParam.m_ParamValue;
      String flag = "-" + cvParam.m_ParamChar;
      String formatted = "" + Utils.doubleToString(value, 4);
      if (cvParam.m_AddAtEnd) {
        // filled from the back: the value goes last, its flag right before it
        options[--back] = formatted;
        options[--back] = flag;
      } else {
        options[front++] = flag;
        options[front++] = formatted;
      }
    }

    // the static parameters follow the front-placed cross-validation parameters
    System.arraycopy(m_ClassifierOptions, 0, options, front, m_ClassifierOptions.length);
    return options;
  }
  /**
   * Gets the scheme parameter with the given index.
   *
   * @param index the index for the parameter
   * @return the string form of the scheme parameter, or "" if the index is not below the number
   *     of parameters
   */
  public String getCVParameter(int index) {
    if (index < m_CVParams.size()) {
      CVParameter cvParam = (CVParameter) m_CVParams.elementAt(index);
      return cvParam.toString();
    }
    return "";
  }
  /**
   * Handles the various button clicking type activities.
   *
   * <p>A click on the configure button triggers selectProperty(). A change of the status box first
   * notifies every registered listener with an "Editor status change" event and then toggles the
   * custom property iterator:
   *
   * <ul>
   *   <li>index 0 disables it;
   *   <li>any other index enables it, calling selectProperty() first if no property array is set.
   *       If there is still none afterwards, the box is reset to index 0.
   * </ul>
   *
   * @param e a value of type 'ActionEvent'
   */
  public void actionPerformed(ActionEvent e) {
    Object source = e.getSource();
    if (source == m_ConfigureBut) {
      selectProperty();
      return;
    }
    if (source != m_StatusBox) {
      return;
    }

    // notify any listeners
    for (int i = 0; i < m_Listeners.size(); i++) {
      ActionListener listener = (ActionListener) m_Listeners.elementAt(i);
      listener.actionPerformed(
          new ActionEvent(this, ActionEvent.ACTION_PERFORMED, "Editor status change"));
    }

    // Toggles whether the custom property is used
    if (m_StatusBox.getSelectedIndex() != 0) {
      if (m_Exp.getPropertyArray() == null) {
        selectProperty();
      }
      if (m_Exp.getPropertyArray() != null) {
        m_Exp.setUsePropertyIterator(true);
        m_ConfigureBut.setEnabled(true);
        m_ArrayEditor.setEnabled(true);
      } else {
        m_StatusBox.setSelectedIndex(0);
      }
    } else {
      m_Exp.setUsePropertyIterator(false);
      m_ConfigureBut.setEnabled(false);
      m_ArrayEditor.setEnabled(false);
      m_ArrayEditor.setValue(null);
    }
    validate();
  }
  /**
   * Finds the best parameter combination (recursive for each parameter being optimised).
   *
   * <p>At a depth below m_CVParams.size(), the parameter at that depth is swept over its grid of
   * values between its lower and upper bound, recursing once per value. Once every parameter
   * is fixed, the classifier is configured via createOptions() and cross-validated on {@code
   * trainData}. The options are remembered in m_BestClassifierOptions when the error rate beats
   * the best seen so far.
   *
   * @param depth the index of the parameter to be optimised at this level
   * @param trainData the data the search is based on
   * @param random a random number generator (passed down the recursion; the cross-validation
   *     folds themselves use a fixed seed)
   * @throws Exception if an error occurs
   */
  protected void findParamsByCrossValidation(int depth, Instances trainData, Random random)
      throws Exception {

    if (depth < m_CVParams.size()) {
      CVParameter cvParam = (CVParameter) m_CVParams.elementAt(depth);

      // A (rounded) lower - upper of 1 selects the number of attributes as the upper bound, and
      // one of 2 selects the training fold size; otherwise the given upper bound is used.
      double upper;
      switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
        case 1:
          upper = m_NumAttributes;
          break;
        case 2:
          upper = m_TrainFoldSize;
          break;
        default:
          upper = cvParam.m_Upper;
          break;
      }
      double lower = cvParam.m_Lower;
      double increment = (upper - lower) / (cvParam.m_Steps - 1);

      if (!(increment > 0) || Double.isInfinite(increment)) {
        // Degenerate grid (a single step, lower == upper, or a step count below one): only the
        // lower bound can be evaluated. A zero increment would otherwise never advance.
        if (lower <= upper) {
          cvParam.m_ParamValue = lower;
          findParamsByCrossValidation(depth + 1, trainData, random);
        }
      } else {
        // Compute each grid point from its index rather than by repeated addition: accumulated
        // rounding error could push the last point just above 'upper' and silently skip it.
        double slack = increment * 1e-6;
        for (int step = 0; ; step++) {
          double value = lower + step * increment;
          if (value > upper + slack) {
            break;
          }
          cvParam.m_ParamValue = Math.min(value, upper);
          findParamsByCrossValidation(depth + 1, trainData, random);
        }
      }
    } else {

      Evaluation evaluation = new Evaluation(trainData);

      // Set the classifier options
      String[] options = createOptions();
      if (m_Debug) {
        System.err.print("Setting options for " + m_Classifier.getClass().getName() + ":");
        for (int i = 0; i < options.length; i++) {
          System.err.print(" " + options[i]);
        }
        System.err.println("");
      }
      ((OptionHandler) m_Classifier).setOptions(options);
      for (int j = 0; j < m_NumFolds; j++) {

        // We want to randomize the data the same way for every
        // learning scheme.
        Instances train = trainData.trainCV(m_NumFolds, j, new Random(1));
        Instances test = trainData.testCV(m_NumFolds, j);
        m_Classifier.buildClassifier(train);
        evaluation.setPriors(train);
        evaluation.evaluateModel(m_Classifier, test);
      }
      double error = evaluation.errorRate();
      if (m_Debug) {
        System.err.println("Cross-validated error rate: " + Utils.doubleToString(error, 6, 4));
      }
      // -99 marks "no combination evaluated yet"
      if ((m_BestPerformance == -99) || (error < m_BestPerformance)) {

        m_BestPerformance = error;
        m_BestClassifierOptions = createOptions();
      }
    }
  }
Пример #11
0
  /**
   * Calculates the performance stats for the default class and returns the results as a set of
   * Instances. The default class is the last class index, i.e. the length of the first
   * prediction's distribution minus one. The structure of these Instances is as follows:
   *
   * <p>
   *
   * <ul>
   *   <li><b>True Positives </b>
   *   <li><b>False Negatives</b>
   *   <li><b>False Positives</b>
   *   <li><b>True Negatives</b>
   *   <li><b>False Positive Rate</b>
   *   <li><b>True Positive Rate</b>
   *   <li><b>Precision</b>
   *   <li><b>Recall</b>
   *   <li><b>Fallout</b>
   *   <li><b>Threshold</b> contains the probability threshold that gives rise to the previous
   *       performance values.
   * </ul>
   *
   * <p>For the definitions of these measures, see TwoClassStats
   *
   * <p>
   *
   * @see TwoClassStats
   * @param predictions the predictions to base the curve on
   * @return datapoints as a set of instances, null if no predictions have been made.
   */
  public Instances getCurve(FastVector predictions) {
    if (predictions.size() > 0) {
      NominalPrediction first = (NominalPrediction) predictions.elementAt(0);
      int lastClassIndex = first.distribution().length - 1;
      return getCurve(predictions, lastClassIndex);
    }
    return null;
  }
Пример #12
0
  /**
   * Tries to reduce the DL of the ruleset by testing the removal of the rules one by one in
   * reverse order, and updates all the stats.
   *
   * <p>Rule k is removed whenever potential() returns a non-NaN value for it. If any rule other
   * than the current last one is removed, the cached filtered data and simple stats are dropped
   * and recomputed via countData() at the end.
   *
   * @param expFPRate expected FP/(FP+FN), used in dataDL calculation
   * @param checkErr whether to check if the error rate is >= 0.5
   */
  public void reduceDL(double expFPRate, boolean checkErr) {

    // true once a rule that is not the current last one has been removed
    boolean needUpdate = false;
    double[] rulesetStat = new double[6];
    for (int j = 0; j < m_SimpleStats.size(); j++) {
      // Covered stats are cumulative
      rulesetStat[0] += ((double[]) m_SimpleStats.elementAt(j))[0];
      rulesetStat[2] += ((double[]) m_SimpleStats.elementAt(j))[2];
      rulesetStat[4] += ((double[]) m_SimpleStats.elementAt(j))[4];
      if (j == m_SimpleStats.size() - 1) { // Last rule: uncovered stats come from it alone
        rulesetStat[1] = ((double[]) m_SimpleStats.elementAt(j))[1];
        rulesetStat[3] = ((double[]) m_SimpleStats.elementAt(j))[3];
        rulesetStat[5] = ((double[]) m_SimpleStats.elementAt(j))[5];
      }
    }

    // Potential: walk the rules from last to first
    for (int k = m_SimpleStats.size() - 1; k >= 0; k--) {

      double[] ruleStat = (double[]) m_SimpleStats.elementAt(k);

      // rulesetStat updated (potential() receives it and may modify it in place)
      double ifDeleted = potential(k, expFPRate, rulesetStat, ruleStat, checkErr);
      if (!Double.isNaN(ifDeleted)) {
        /*System.err.println("!!!deleted ("+k+"): save "+ifDeleted
          +" | "+rulesetStat[0]
          +" | "+rulesetStat[1]
          +" | "+rulesetStat[4]
          +" | "+rulesetStat[5]);
        */

        // The size is re-read here on purpose. NOTE(review): removeLast() presumably also drops
        // the last entry of m_SimpleStats, keeping it in sync with m_Ruleset — confirm.
        if (k == (m_SimpleStats.size() - 1)) removeLast();
        else {
          // Only the rule is removed here; the stats are rebuilt after the loop
          m_Ruleset.removeElementAt(k);
          needUpdate = true;
        }
      }
    }

    if (needUpdate) {
      // Invalidate the cached per-rule data and recompute it for the reduced ruleset
      m_Filtered = null;
      m_SimpleStats = null;
      countData();
    }
  }
  /**
   * Adds the statistics encapsulated in the supplied Evaluation object into this one. Does not
   * perform any checks for compatibility between the supplied Evaluation object and this one.
   *
   * <p>Counts, error sums, entropy sums, coverage, the confusion matrix and the margin counts are
   * summed, and predictions are appended. The class priors and their sum are copied from {@code
   * evaluation} rather than summed. Array-valued statistics are only merged when both objects
   * hold them, so a {@code null} array on this object no longer causes a NullPointerException.
   *
   * @param evaluation the evaluation object to aggregate
   */
  public void aggregate(Evaluation evaluation) {
    m_Incorrect += evaluation.incorrect();
    m_Correct += evaluation.correct();
    m_Unclassified += evaluation.unclassified();
    m_MissingClass += evaluation.m_MissingClass;
    m_WithClass += evaluation.m_WithClass;

    if (m_ConfusionMatrix != null && evaluation.m_ConfusionMatrix != null) {
      double[][] newMatrix = evaluation.confusionMatrix();
      if (newMatrix != null) {
        for (int i = 0; i < m_ConfusionMatrix.length; i++) {
          for (int j = 0; j < m_ConfusionMatrix[i].length; j++) {
            m_ConfusionMatrix[i][j] += newMatrix[i][j];
          }
        }
      }
    }

    // Priors are copied, not summed. NOTE(review): presumably both evaluations were set up with
    // the same training priors — confirm.
    double[] newClassPriors = evaluation.m_ClassPriors;
    if (newClassPriors != null && m_ClassPriors != null) {
      for (int i = 0; i < m_ClassPriors.length; i++) {
        m_ClassPriors[i] = newClassPriors[i];
      }
    }
    m_ClassPriorsSum = evaluation.m_ClassPriorsSum;
    m_TotalCost += evaluation.totalCost();
    m_SumErr += evaluation.m_SumErr;
    m_SumAbsErr += evaluation.m_SumAbsErr;
    m_SumSqrErr += evaluation.m_SumSqrErr;
    m_SumClass += evaluation.m_SumClass;
    m_SumSqrClass += evaluation.m_SumSqrClass;
    m_SumPredicted += evaluation.m_SumPredicted;
    m_SumSqrPredicted += evaluation.m_SumSqrPredicted;
    m_SumClassPredicted += evaluation.m_SumClassPredicted;
    m_SumPriorAbsErr += evaluation.m_SumPriorAbsErr;
    m_SumPriorSqrErr += evaluation.m_SumPriorSqrErr;
    m_SumKBInfo += evaluation.m_SumKBInfo;

    double[] newMarginCounts = evaluation.m_MarginCounts;
    if (newMarginCounts != null && m_MarginCounts != null) {
      for (int i = 0; i < m_MarginCounts.length; i++) {
        m_MarginCounts[i] += newMarginCounts[i];
      }
    }
    m_SumPriorEntropy += evaluation.m_SumPriorEntropy;
    m_SumSchemeEntropy += evaluation.m_SumSchemeEntropy;
    m_TotalSizeOfRegions += evaluation.m_TotalSizeOfRegions;
    m_TotalCoverage += evaluation.m_TotalCoverage;

    FastVector predsToAdd = evaluation.m_Predictions;
    if (predsToAdd != null) {
      if (m_Predictions == null) {
        m_Predictions = new FastVector();
      }
      for (int i = 0; i < predsToAdd.size(); i++) {
        m_Predictions.addElement(predsToAdd.elementAt(i));
      }
    }
  }
Пример #14
0
  /**
   * Extracts, for every prediction, the predicted probability of the given class. The returned
   * array follows the order of {@code predictions}; it is not sorted.
   *
   * @param predictions the predictions to use (each one a NominalPrediction)
   * @param classIndex the class index
   * @return the probabilities
   */
  private double[] getProbabilities(FastVector predictions, int classIndex) {
    int count = predictions.size();
    double[] probs = new double[count];
    for (int i = 0; i < count; i++) {
      double[] distribution = ((NominalPrediction) predictions.elementAt(i)).distribution();
      probs[i] = distribution[classIndex];
    }
    return probs;
  }
  /**
   * Scales numeric class predictions into shape sizes for plotting in the visualize panel.
   *
   * <p>Each non-null entry of m_PlotSizes is read as an error. Its absolute value is mapped
   * linearly from [min error, max error] onto [m_MinimumPlotSizeNumeric,
   * m_MaximumPlotSizeNumeric], and the entry is replaced by the resulting Integer size. Null
   * entries, and every entry when all errors are equal, get the minimum size.
   */
  protected void scaleNumericPredictions() {
    double maxErr = Double.NEGATIVE_INFINITY;
    double minErr = Double.POSITIVE_INFINITY;

    // find min/max absolute errors
    for (int i = 0; i < m_PlotSizes.size(); i++) {
      Double errd = (Double) m_PlotSizes.elementAt(i);
      if (errd != null) {
        double err = Math.abs(errd.doubleValue());
        if (err < minErr) {
          minErr = err;
        }
        if (err > maxErr) {
          maxErr = err;
        }
      }
    }

    // scale errors
    int numSizes = m_MaximumPlotSizeNumeric - m_MinimumPlotSizeNumeric + 1;
    for (int i = 0; i < m_PlotSizes.size(); i++) {
      Double errd = (Double) m_PlotSizes.elementAt(i);
      int size = m_MinimumPlotSizeNumeric;
      if (errd != null && maxErr - minErr > 0) {
        double err = Math.abs(errd.doubleValue());
        // Spread the errors over numSizes equal-width buckets. The maximum error falls exactly on
        // bucket numSizes (one past the last), so clamp it to the largest allowed size.
        int bucket = (int) (((err - minErr) / (maxErr - minErr)) * numSizes);
        size = Math.min(m_MinimumPlotSizeNumeric + bucket, m_MaximumPlotSizeNumeric);
      }
      m_PlotSizes.setElementAt(Integer.valueOf(size), i);
    }
  }
Пример #16
0
  /**
   * The description length of the theory for a given rule. Computed as:<br>
   * 0.5* [||k||+ S(t, k, k/t)]<br>
   * where k is the number of antecedents of the rule and t is the total number of possible
   * antecedents that could appear in a rule. ||k|| is the universal prior for k, log2*(k), which
   * is approximated here by log2(k) + 2*log2(log2(k)) for k > 1. S(t,k,p) =
   * -k*log2(p)-(t-k)*log2(1-p) is the subset encoding length. The weighting factor applied in
   * code is MDL_THEORY_WEIGHT * REDUNDANCY_FACTOR.
   *
   * <p>For details see Quinlan: "MDL and categorical theories (Continued)", ML95
   *
   * @param index the index of the given rule (assumed to be correct)
   * @return the theory DL, weighted if weight != 1.0 (0 for a rule without antecedents)
   */
  public double theoryDL(int index) {
    double k = ((Rule) m_Ruleset.elementAt(index)).size();
    if (k == 0) {
      return 0.0;
    }

    double logK = Utils.log2(k);
    // approximation of log2 star
    double universalPrior = (k > 1) ? logK + 2.0 * Utils.log2(logK) : logK;
    double tdl = universalPrior + subsetDL(m_Total, k, k / m_Total);
    return MDL_THEORY_WEIGHT * REDUNDANCY_FACTOR * tdl;
  }
Пример #17
0
  /**
   * Writes the titles of the documents grouped by DBScan cluster to the file "topictrackresult".
   * Objects labelled as noise are left out.
   *
   * @param algo the clusterer whose result is reported
   * @param data the clustered instances; attribute 0 of each instance is read as the document ID
   * @throws Exception if looking up the results or writing the file fails
   */
  void analyseResult(DBScan algo, Instances data) throws Exception {
    ArrayList<ArrayList<String>> resultset = new ArrayList<ArrayList<String>>();
    for (int i = 0; i < algo.numberOfClusters(); i++) {
      resultset.add(new ArrayList<String>());
    }

    // assign each non-noise object's document title to its cluster
    for (int i = 0; i < algo.database.size(); i++) {
      DataObject dataObject = algo.database.getDataObject(Integer.toString(i));
      int rowIndex = Integer.parseInt(dataObject.getKey());
      int docId = (int) data.instance(rowIndex).value(0);
      String title = findFileNameWithID(docId);
      if (DataObject.NOISE != dataObject.getClusterLabel()) {
        int label = dataObject.getClusterLabel();
        resultset.get(label).add(title);
      }
    }

    StringBuilder sb = new StringBuilder();
    sb.append("The generated Clusters are:\n");
    for (int i = 0; i < algo.numberOfClusters(); i++) {
      sb.append("\r\n cluster").append(i).append("\n");
      for (String title : resultset.get(i)) {
        sb.append(title).append("\n");
      }
    }

    // Open the file only once the report is complete, and always close it, even if write() fails.
    // Closing the BufferedWriter also closes the wrapped FileWriter.
    BufferedWriter bw = new BufferedWriter(new FileWriter(new File("topictrackresult")));
    try {
      bw.write(sb.toString());
    } finally {
      bw.close();
    }
  }
Пример #18
0
  /**
   * Static utility function that removes from the data every instance covered by any of the rules
   * after the given index. It returns the data not covered by the successive rules.
   *
   * @param data the data to be processed
   * @param rules the ruleset
   * @param index the given index; only rules at positions greater than it are checked
   * @return a new dataset containing the instances not covered by any successive rule
   */
  public static Instances rmCoveredBySuccessives(Instances data, FastVector rules, int index) {
    Instances remaining = new Instances(data, 0);

    nextInstance:
    for (int i = 0; i < data.numInstances(); i++) {
      Instance datum = data.instance(i);
      for (int j = index + 1; j < rules.size(); j++) {
        if (((Rule) rules.elementAt(j)).covers(datum)) {
          continue nextInstance; // covered by a later rule: drop it
        }
      }
      remaining.add(datum);
    }
    return remaining;
  }
  /**
   * Generates the classifier.
   *
   * <p>Instances with a missing class are removed and the data is randomized with m_Seed. If
   * there are no parameters to optimise, the base classifier is built with its initial options.
   * Otherwise the best parameter combination is searched via findParamsByCrossValidation() and
   * the base classifier is rebuilt on all the training data with the best options found.
   *
   * @param instances set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // remove instances with missing class
    Instances trainData = new Instances(instances);
    trainData.deleteWithMissingClass();

    if (!(m_Classifier instanceof OptionHandler)) {
      throw new IllegalArgumentException("Base classifier should be OptionHandler.");
    }
    m_InitOptions = ((OptionHandler) m_Classifier).getOptions();
    m_BestPerformance = -99; // sentinel: no combination evaluated yet
    m_NumAttributes = trainData.numAttributes();
    Random random = new Random(m_Seed);
    trainData.randomize(random);
    m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();

    // Check whether there are any parameters to optimize
    if (m_CVParams.size() == 0) {
      m_Classifier.buildClassifier(trainData);
      m_BestClassifierOptions = m_InitOptions;
      return;
    }

    if (trainData.classAttribute().isNominal()) {
      trainData.stratify(m_NumFolds);
    }
    m_BestClassifierOptions = null;

    // Set up m_ClassifierOptions -- take getOptions() and remove
    // those being optimised.
    m_ClassifierOptions = ((OptionHandler) m_Classifier).getOptions();
    for (int i = 0; i < m_CVParams.size(); i++) {
      Utils.getOption(((CVParameter) m_CVParams.elementAt(i)).m_ParamChar, m_ClassifierOptions);
    }
    findParamsByCrossValidation(0, trainData, random);

    // If no combination was evaluated (e.g. a parameter range is empty), fail with a clear
    // message instead of a NullPointerException on clone().
    if (m_BestClassifierOptions == null) {
      throw new IllegalStateException(
          "No parameter combination was evaluated; check the ranges of the CV parameters.");
    }
    String[] options = (String[]) m_BestClassifierOptions.clone();
    ((OptionHandler) m_Classifier).setOptions(options);
    m_Classifier.buildClassifier(trainData);
  }
Пример #20
0
  /**
   * Parses a given list of options. Valid options are:
   *
   * <p>-D <br>
   * Turn on debugging output.
   *
   * <p>-S seed <br>
   * Random number seed (default 1).
   *
   * <p>-B classifierstring <br>
   * Classifierstring should contain the full class name of a scheme included for selection followed
   * by options to the classifier (required, option should be used once for each classifier).
   *
   * <p>-X num_folds <br>
   * Use cross validation error as the basis for classifier selection. (default 0, is to use error
   * on the training data instead)
   *
   * <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported, a classifier specification is empty, or
   *     fewer than two classifiers are given
   */
  public void setOptions(String[] options) throws Exception {
    setDebug(Utils.getFlag('D', options));

    String numFoldsString = Utils.getOption('X', options);
    setNumFolds(numFoldsString.length() == 0 ? 0 : Integer.parseInt(numFoldsString));

    String randomString = Utils.getOption('S', options);
    setSeed(randomString.length() == 0 ? 1 : Integer.parseInt(randomString));

    // Consume -B options until none is left; each one yields one classifier
    FastVector classifiers = new FastVector();
    for (String spec = Utils.getOption('B', options);
        spec.length() != 0;
        spec = Utils.getOption('B', options)) {
      String[] classifierSpec = Utils.splitOptions(spec);
      if (classifierSpec.length == 0) {
        throw new Exception("Invalid classifier specification string");
      }
      String classifierName = classifierSpec[0];
      classifierSpec[0] = ""; // the remaining entries are the classifier's own options
      classifiers.addElement(Classifier.forName(classifierName, classifierSpec));
    }

    if (classifiers.size() <= 1) {
      throw new Exception("At least two classifiers must be specified" + " with the -B option.");
    }
    Classifier[] classifiersArray = new Classifier[classifiers.size()];
    for (int c = 0; c < classifiersArray.length; c++) {
      classifiersArray[c] = (Classifier) classifiers.elementAt(c);
    }
    setClassifiers(classifiersArray);
  }
  /**
   * Returns a description of the cross-validated classifier: the base classifier, each optimised
   * parameter with its range and number of steps, the chosen options and the base model.
   *
   * @return description of the cross-validated classifier as a string
   */
  public String toString() {
    if (m_InitOptions == null) {
      return "CVParameterSelection: No model built yet.";
    }

    StringBuilder text = new StringBuilder();
    text.append("Cross-validated Parameter selection.\n")
        .append("Classifier: ")
        .append(m_Classifier.getClass().getName())
        .append("\n");
    try {
      for (int i = 0; i < m_CVParams.size(); i++) {
        CVParameter cvParam = (CVParameter) m_CVParams.elementAt(i);
        String line =
            "Cross-validation Parameter: '-"
                + cvParam.m_ParamChar
                + "'"
                + " ranged from "
                + cvParam.m_Lower
                + " to ";
        // same upper-bound markers as used during the search
        String upperText;
        switch ((int) (cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
          case 1:
            upperText = "" + m_NumAttributes;
            break;
          case 2:
            upperText = "" + m_TrainFoldSize;
            break;
          default:
            upperText = "" + cvParam.m_Upper;
            break;
        }
        text.append(line).append(upperText);
        text.append(" with " + cvParam.m_Steps + " steps\n");
      }
    } catch (Exception ex) {
      text.append(ex.getMessage());
    }
    text.append("Classifier Options: ")
        .append(Utils.joinOptions(m_BestClassifierOptions))
        .append("\n\n")
        .append(m_Classifier.toString());
    return text.toString();
  }
Пример #22
0
  /**
   * Method that finds all association rules.
   *
   * <p>Rules are generated from every item set in m_Ls, with the running list of best rules kept
   * in m_best. After each item set, m_expectation is set to the accuracy of m_best.first() (0 when
   * m_best is empty). This is the minimum expected predictive accuracy needed to get into the n
   * best, and it is passed to the next generateRules() call.
   *
   * @exception Exception if rule generation fails (previously documented as "if an attribute is
   *     numeric"; that is not checked in this method itself)
   */
  private void findRulesQuickly() throws Exception {

    // Build rules
    for (int j = 0; j < m_Ls.size(); j++) {
      FastVector currentItemSets = (FastVector) m_Ls.elementAt(j);
      Enumeration enumItemSets = currentItemSets.elements();
      while (enumItemSets.hasMoreElements()) {
        RuleGeneration currentItemSet =
            new RuleGeneration((ItemSet) enumItemSets.nextElement());
        m_best =
            currentItemSet.generateRules(
                m_numRules, m_midPoints, m_priors, m_expectation, m_instances, m_best, m_count);

        m_count = currentItemSet.m_count;
        if (!m_bestChanged && currentItemSet.m_change) {
          m_bestChanged = true;
        }
        // update minimum expected predictive accuracy to get into the n best
        if (m_best.size() > 0) {
          m_expectation = ((RuleItem) m_best.first()).accuracy();
        } else {
          m_expectation = 0;
        }
      }
    }
  }
  /**
   * Checks that, with range correction set to RANGE_NONE, the predicted probabilities for class
   * index 0 stay within [0.25, 1.0].
   *
   * @throws Exception if building or using the classifier fails
   */
  public void testRangeNone() throws Exception {

    int cind = 0;
    ThresholdSelector selector = (ThresholdSelector) m_Classifier;
    selector.setDesignatedClass(
        new SelectedTag(ThresholdSelector.OPTIMIZE_0, ThresholdSelector.TAGS_OPTIMIZE));
    selector.setRangeCorrection(
        new SelectedTag(ThresholdSelector.RANGE_NONE, ThresholdSelector.TAGS_RANGE));
    m_Instances.setClassIndex(1);
    FastVector result = useClassifier();
    assertTrue(result.size() != 0);

    // track the smallest and largest predicted probability of class cind
    double minp = 0;
    double maxp = 0;
    for (int i = 0; i < result.size(); i++) {
      NominalPrediction p = (NominalPrediction) result.elementAt(i);
      double prob = p.distribution()[cind];
      if ((i == 0) || (prob < minp)) {
        minp = prob;
      }
      if ((i == 0) || (prob > maxp)) {
        maxp = prob;
      }
    }
    assertTrue("Upper limit shouldn't increase", maxp <= 1.0);
    assertTrue("Lower limit shouldn't decrease", minp >= 0.25);
  }
Пример #24
0
  /**
   * Calculates the performance stats for the desired class and return results as a set of
   * Instances.
   *
   * <p>Predictions are sorted by their predicted probability for {@code classIndex}. A curve point
   * is emitted at each distinct probability, which is used as the threshold. If the sweep did not
   * end with every instance predicted negative, a final point with that configuration is appended.
   * Predictions with a missing actual class or a negative weight are skipped, with a message on
   * stderr.
   *
   * @param predictions the predictions to base the curve on
   * @param classIndex index of the class of interest.
   * @return datapoints as a set of instances, or {@code null} if {@code predictions} is empty or
   *     {@code classIndex} is not a valid index into the predicted class distribution
   */
  public Instances getCurve(FastVector predictions, int classIndex) {

    // Check emptiness on its own. The diagnostic below inspects element 0, which must exist.
    // Previously an empty vector crashed here instead of returning null.
    if (predictions.size() == 0) {
      return null;
    }
    int numClasses = ((NominalPrediction) predictions.elementAt(0)).distribution().length;
    if ((classIndex < 0) || (numClasses <= classIndex)) {
      System.out.println(
          "Foooobared " + predictions.size() + " " + numClasses + " " + classIndex);
      return null;
    }

    double totPos = 0, totNeg = 0;
    double[] probs = getProbabilities(predictions, classIndex);

    // Get distribution of positive/negatives
    for (int i = 0; i < probs.length; i++) {
      NominalPrediction pred = (NominalPrediction) predictions.elementAt(i);
      if (!isUsablePrediction(pred)) {
        continue;
      }
      if (pred.actual() == classIndex) {
        totPos += pred.weight();
      } else {
        totNeg += pred.weight();
      }
    }

    Instances insts = makeHeader();
    int[] sorted = Utils.sort(probs);
    // Start with all weight on the predicted-positive side: TP = totPos, FP = totNeg.
    TwoClassStats tc = new TwoClassStats(totPos, totNeg, 0, 0);
    double threshold = 0;
    // Positive/negative weight seen at the current threshold. It is moved to the
    // predicted-negative side when the next, higher threshold is reached.
    double cumulativePos = 0;
    double cumulativeNeg = 0;

    for (int i = 0; i < sorted.length; i++) {

      if ((i == 0) || (probs[sorted[i]] > threshold)) {
        tc.setTruePositive(tc.getTruePositive() - cumulativePos);
        tc.setFalseNegative(tc.getFalseNegative() + cumulativePos);
        tc.setFalsePositive(tc.getFalsePositive() - cumulativeNeg);
        tc.setTrueNegative(tc.getTrueNegative() + cumulativeNeg);
        threshold = probs[sorted[i]];
        insts.add(makeInstance(tc, threshold));
        cumulativePos = 0;
        cumulativeNeg = 0;
        // The remaining weight is accounted for by the all-negative point added after the loop.
        if (i == sorted.length - 1) {
          break;
        }
      }

      NominalPrediction pred = (NominalPrediction) predictions.elementAt(sorted[i]);
      if (!isUsablePrediction(pred)) {
        continue;
      }
      if (pred.actual() == classIndex) {
        cumulativePos += pred.weight();
      } else {
        cumulativeNeg += pred.weight();
      }
    }

    // make sure a zero point gets into the curve
    // (FN = totPos and TN = totNeg, i.e. everything predicted negative)
    if (tc.getFalseNegative() != totPos || tc.getTrueNegative() != totNeg) {
      tc = new TwoClassStats(0, 0, totNeg, totPos);
      threshold = probs[sorted[sorted.length - 1]] + 10e-6;
      insts.add(makeInstance(tc, threshold));
    }

    return insts;
  }

  /**
   * Returns whether a prediction may contribute to the curve. When it is rejected, a message is
   * printed to stderr: either its actual class is missing or its weight is negative.
   *
   * @param pred the prediction to check
   * @return true if the prediction should be counted
   */
  private boolean isUsablePrediction(NominalPrediction pred) {
    if (pred.actual() == Prediction.MISSING_VALUE) {
      System.err.println(getClass().getName() + " Skipping prediction with missing class value");
      return false;
    }
    if (pred.weight() < 0) {
      System.err.println(getClass().getName() + " Skipping prediction with negative weight");
      return false;
    }
    return true;
  }
Пример #25
0
  /**
   * Looks up the six simple statistics recorded for one rule. The entries are: 0 coverage,
   * 1 uncoverage, 2 true positives, 3 true negatives, 4 false positives, 5 false negatives.
   *
   * @param index position of the rule whose stats are wanted
   * @return the stored stats array itself (not a copy), or {@code null} when no stats have been
   *     recorded or {@code index} is not below the number of recorded entries
   */
  public double[] getSimpleStats(int index) {
    if (m_SimpleStats == null || index >= m_SimpleStats.size()) {
      return null;
    }
    return (double[]) m_SimpleStats.elementAt(index);
  }
Пример #26
0
  /**
   * Compute the minimal data description length of the ruleset if the rule in the given position is
   * deleted.<br>
   * The min_data_DL_if_deleted = data_DL_if_deleted - potential
   *
   * <p>Rules before {@code index} keep their cached simple stats. Rules after it are re-evaluated
   * with {@code computeSimpleStats}, starting from the data that reached rule {@code index}. Stats
   * arrays use the layout 0 coverage, 1 uncoverage, 2 TP, 3 TN, 4 FP, 5 FN.
   *
   * @param index the index of the rule in question
   * @param expFPRate expected FP/(FP+FN), used in dataDL calculation
   * @param checkErr whether check if error rate >= 0.5
   * @return the minDataDL
   */
  public double minDataDLIfDeleted(int index, double expFPRate, boolean checkErr) {
    // System.out.println("!!!Enter without: ");
    double[] rulesetStat = new double[6]; // Stats of ruleset if deleted
    int more = m_Ruleset.size() - 1 - index; // How many rules after?
    FastVector indexPlus = new FastVector(more); // Their stats

    // 0...(index-1) are OK
    // (only the covered entries 0, 2 and 4 are summed here; uncovered ones are set below)
    for (int j = 0; j < index; j++) {
      // Covered stats are cumulative
      rulesetStat[0] += ((double[]) m_SimpleStats.elementAt(j))[0];
      rulesetStat[2] += ((double[]) m_SimpleStats.elementAt(j))[2];
      rulesetStat[4] += ((double[]) m_SimpleStats.elementAt(j))[4];
    }

    // Recount data from index+1
    // `data` starts as the data reaching rule `index`. For index 0 that is m_Data; otherwise it is
    // element [1] of m_Filtered at index-1. NOTE(review): presumably [1] holds the instances left
    // uncovered by rule index-1, mirroring `data = split[1]` below. Confirm against
    // computeSimpleStats.
    Instances data = (index == 0) ? m_Data : ((Instances[]) m_Filtered.elementAt(index - 1))[1];
    // System.out.println("!!!without: " + data.sumOfWeights());

    for (int j = (index + 1); j < m_Ruleset.size(); j++) {
      double[] stats = new double[6];
      // Recompute rule j's stats on the data it would now receive; split[1] feeds the next rule.
      Instances[] split = computeSimpleStats(j, data, stats, null);
      indexPlus.addElement(stats);
      rulesetStat[0] += stats[0];
      rulesetStat[2] += stats[2];
      rulesetStat[4] += stats[4];
      data = split[1];
    }
    // Uncovered stats are those of the last rule
    if (more > 0) {
      // The last remaining rule is the last recomputed one.
      rulesetStat[1] = ((double[]) indexPlus.lastElement())[1];
      rulesetStat[3] = ((double[]) indexPlus.lastElement())[3];
      rulesetStat[5] = ((double[]) indexPlus.lastElement())[5];
    } else if (index > 0) {
      // The deleted rule was the last one, so rule index-1 becomes the last rule.
      rulesetStat[1] = ((double[]) m_SimpleStats.elementAt(index - 1))[1];
      rulesetStat[3] = ((double[]) m_SimpleStats.elementAt(index - 1))[3];
      rulesetStat[5] = ((double[]) m_SimpleStats.elementAt(index - 1))[5];
    } else { // Null coverage
      // The deleted rule was the only rule (index 0, none after), so nothing is covered:
      // uncoverage = coverage + uncoverage, uncovered negatives = TN + FP,
      // uncovered positives (FN) = TP + FN, all taken from rule 0's stats.
      rulesetStat[1] =
          ((double[]) m_SimpleStats.elementAt(0))[0] + ((double[]) m_SimpleStats.elementAt(0))[1];
      rulesetStat[3] =
          ((double[]) m_SimpleStats.elementAt(0))[3] + ((double[]) m_SimpleStats.elementAt(0))[4];
      rulesetStat[5] =
          ((double[]) m_SimpleStats.elementAt(0))[2] + ((double[]) m_SimpleStats.elementAt(0))[5];
    }

    // Potential
    // Sum potential(k, ...) over each rule k after the deleted one, passing rule k's recomputed
    // stats (indexPlus entry k-index-1). NaN results are ignored.
    double potential = 0;
    for (int k = index + 1; k < m_Ruleset.size(); k++) {
      double[] ruleStat = (double[]) indexPlus.elementAt(k - index - 1);
      double ifDeleted = potential(k, expFPRate, rulesetStat, ruleStat, checkErr);
      if (!Double.isNaN(ifDeleted)) potential += ifDeleted;
    }

    // Data DL of the ruleset without the rule
    // Note that ruleset stats has already been updated to reflect
    // deletion if any potential (i.e. potential() is expected to modify rulesetStat in place;
    // not verifiable from this method alone)
    double dataDLWithout =
        dataDL(expFPRate, rulesetStat[0], rulesetStat[1], rulesetStat[4], rulesetStat[5]);
    // System.out.println("!!!without: "+dataDLWithout + " |potential: "+
    //		   potential);
    // NOTE(review): it is unclear why the potential is subtracted again after rulesetStat was
    // already updated. The original author guessed it reflects the change of theory DL. Confirm
    // before changing.
    return (dataDLWithout - potential);
  }
Пример #27
0
  /**
   * Makes a database query to convert a table into a set of instances.
   *
   * <p>Each result column is mapped to an attribute according to {@code translateDBColumnType}:
   *
   * <ul>
   *   <li>STRING columns become nominal attributes.
   *   <li>BOOL columns become nominal {false, true} attributes.
   *   <li>DOUBLE, BYTE, SHORT, INTEGER, LONG and FLOAT columns become numeric attributes.
   *   <li>DATE and TIME columns become date attributes.
   *   <li>TEXT and unrecognised columns become string attributes.
   * </ul>
   *
   * <p>SQL NULLs, and all values of unrecognised columns, become missing values. The database
   * connection is left open; the result set is closed before returning, even on error.
   *
   * @param query the query to convert to instances
   * @return the instances contained in the result of the query, NULL if the SQL query doesn't
   *     return a ResultSet, e.g., DELETE/INSERT/UPDATE
   * @throws Exception if an error occurs
   */
  public Instances retrieveInstances(String query) throws Exception {

    if (m_Debug) System.err.println("Executing query: " + query);
    connectToDatabase();
    if (!execute(query)) {
      if (m_PreparedStatement.getUpdateCount() == -1) {
        // Release resources before failing, as the "rows affected" path below does.
        close();
        throw new Exception("Query didn't produce results");
      }
      if (m_Debug) System.err.println(m_PreparedStatement.getUpdateCount() + " rows affected.");
      close();
      return null;
    }

    ResultSet rs = getResultSet();
    try {
      if (m_Debug) System.err.println("Getting metadata...");
      ResultSetMetaData md = rs.getMetaData();
      if (m_Debug) System.err.println("Completed getting metadata...");

      // Determine structure of the instances. The translated type of each column is computed
      // once here and reused for every row, rather than recomputed per cell.
      int numAttributes = md.getColumnCount();
      int[] columnTypes = new int[numAttributes];
      int[] attributeTypes = new int[numAttributes];
      Hashtable[] nominalIndexes = new Hashtable[numAttributes];
      FastVector[] nominalStrings = new FastVector[numAttributes];
      for (int i = 0; i < numAttributes; i++) {
        columnTypes[i] = translateDBColumnType(md.getColumnTypeName(i + 1));
        switch (columnTypes[i]) {
          case STRING:
            attributeTypes[i] = Attribute.NOMINAL;
            nominalIndexes[i] = new Hashtable();
            nominalStrings[i] = new FastVector();
            break;
          case BOOL:
            attributeTypes[i] = Attribute.NOMINAL;
            nominalIndexes[i] = new Hashtable();
            nominalIndexes[i].put("false", Double.valueOf(0));
            nominalIndexes[i].put("true", Double.valueOf(1));
            nominalStrings[i] = new FastVector();
            nominalStrings[i].addElement("false");
            nominalStrings[i].addElement("true");
            break;
          case DOUBLE:
          case BYTE:
          case SHORT:
          case INTEGER:
          case LONG:
          case FLOAT:
            attributeTypes[i] = Attribute.NUMERIC;
            break;
          case DATE:
          case TIME:
            attributeTypes[i] = Attribute.DATE;
            break;
          case TEXT:
          default:
            // String attribute. Unrecognised types need the (empty) value tables too: the header
            // code below iterates nominalStrings for every STRING attribute. Leaving the entry
            // null used to cause a NullPointerException.
            attributeTypes[i] = Attribute.STRING;
            nominalIndexes[i] = new Hashtable();
            nominalStrings[i] = new FastVector();
            break;
        }
      }

      // For sqlite
      // cache column names because the last while(rs.next()) { iteration for
      // the tuples below will close the md object:
      Vector<String> columnNames = new Vector<String>();
      for (int i = 0; i < numAttributes; i++) {
        columnNames.add(md.getColumnName(i + 1));
      }

      // Step through the tuples
      if (m_Debug) System.err.println("Creating instances...");
      FastVector instances = new FastVector();
      int rowCount = 0;
      while (rs.next()) {
        if (m_Debug && (rowCount % 100 == 0)) {
          System.err.print("read " + rowCount + " instances \r");
          System.err.flush();
        }
        double[] vals = new double[numAttributes];
        for (int i = 0; i < numAttributes; i++) {
          vals[i] =
              readColumnValue(rs, i + 1, columnTypes[i], nominalIndexes[i], nominalStrings[i]);
        }
        Instance newInst;
        if (m_CreateSparseData) {
          newInst = new SparseInstance(1.0, vals);
        } else {
          newInst = new Instance(1.0, vals);
        }
        instances.addElement(newInst);
        rowCount++;
      }
      // The database is deliberately not disconnected here: other queries might be made.

      // Create the header and add the instances to the dataset
      if (m_Debug) System.err.println("Creating header...");
      FastVector attribInfo = new FastVector();
      for (int i = 0; i < numAttributes; i++) {
        /* Fix for databases that uppercase column names */
        String attribName = attributeCaseFix(columnNames.get(i));
        switch (attributeTypes[i]) {
          case Attribute.NOMINAL:
            attribInfo.addElement(new Attribute(attribName, nominalStrings[i]));
            break;
          case Attribute.NUMERIC:
            attribInfo.addElement(new Attribute(attribName));
            break;
          case Attribute.STRING:
            Attribute att = new Attribute(attribName, (FastVector) null);
            attribInfo.addElement(att);
            for (int n = 0; n < nominalStrings[i].size(); n++) {
              att.addStringValue((String) nominalStrings[i].elementAt(n));
            }
            break;
          case Attribute.DATE:
            attribInfo.addElement(new Attribute(attribName, (String) null));
            break;
          default:
            throw new Exception("Unknown attribute type");
        }
      }
      Instances result = new Instances("QueryResult", attribInfo, instances.size());
      for (int i = 0; i < instances.size(); i++) {
        result.add((Instance) instances.elementAt(i));
      }
      return result;
    } finally {
      // Close the result set even if reading the rows or building the header failed.
      close(rs);
    }
  }

  /**
   * Reads one column of the current row as an attribute value. SQL NULL becomes a missing value,
   * as does any value of an unrecognised column type. STRING and TEXT values are mapped to their
   * index in {@code strings}; values not seen before are appended.
   *
   * @param rs the result set positioned on the row to read
   * @param col the 1-based JDBC column number
   * @param columnType the translated column type from {@code translateDBColumnType}
   * @param indexes value-to-index table for STRING/TEXT columns
   * @param strings ordered distinct values for STRING/TEXT columns
   * @return the attribute value for the cell
   * @throws Exception if the JDBC read fails
   */
  private double readColumnValue(
      ResultSet rs, int col, int columnType, Hashtable indexes, FastVector strings)
      throws Exception {
    switch (columnType) {
      case STRING:
      case TEXT:
        {
          String str = rs.getString(col);
          return rs.wasNull() ? Instance.missingValue() : lookupStringIndex(str, indexes, strings);
        }
      case BOOL:
        {
          boolean boo = rs.getBoolean(col);
          return rs.wasNull() ? Instance.missingValue() : (boo ? 1.0 : 0.0);
        }
      case DOUBLE:
        {
          double dd = rs.getDouble(col);
          return rs.wasNull() ? Instance.missingValue() : dd;
        }
      case BYTE:
        {
          byte by = rs.getByte(col);
          return rs.wasNull() ? Instance.missingValue() : (double) by;
        }
      case SHORT:
        {
          short sh = rs.getShort(col);
          return rs.wasNull() ? Instance.missingValue() : (double) sh;
        }
      case INTEGER:
        {
          int in = rs.getInt(col);
          return rs.wasNull() ? Instance.missingValue() : (double) in;
        }
      case LONG:
        {
          long lo = rs.getLong(col);
          return rs.wasNull() ? Instance.missingValue() : (double) lo;
        }
      case FLOAT:
        {
          float fl = rs.getFloat(col);
          return rs.wasNull() ? Instance.missingValue() : (double) fl;
        }
      case DATE:
        {
          Date date = rs.getDate(col);
          // TODO: Do a value check here.
          return rs.wasNull() ? Instance.missingValue() : (double) date.getTime();
        }
      case TIME:
        {
          Time time = rs.getTime(col);
          // TODO: Do a value check here.
          return rs.wasNull() ? Instance.missingValue() : (double) time.getTime();
        }
      default:
        return Instance.missingValue();
    }
  }

  /**
   * Returns the index of {@code value} in {@code strings}. If the value has not been seen before,
   * it is appended first.
   *
   * @param value the cell value
   * @param indexes value-to-index table, updated when a new value is seen
   * @param strings ordered distinct values, appended to when a new value is seen
   * @return the index of the value, as a double
   */
  private double lookupStringIndex(String value, Hashtable indexes, FastVector strings) {
    Double index = (Double) indexes.get(value);
    if (index == null) {
      index = Double.valueOf(strings.size());
      indexes.put(value, index);
      strings.addElement(value);
    }
    return index.doubleValue();
  }
Пример #28
0
 /**
  * Yields the element under the cursor and moves the cursor forward. If the following position is
  * the excluded "special" element, that position is skipped as well.
  *
  * @return the next element to be enumerated
  */
 public final Object nextElement() {
   Object current = m_Vector.elementAt(m_Counter);
   m_Counter += (m_Counter + 1 == m_SpecialElement) ? 2 : 1;
   return current;
 }