Exemplo n.º 1
0
  /**
   * Inserts an instance into the hash table, adding its class information to the entry for its
   * key and creating that entry first if the key has not been seen yet.
   *
   * <p>For a nominal class an entry holds one count per class value. Every count starts at 1.0
   * (Laplace estimation) and instance weights are added on top. For a numeric class an entry
   * holds {@code {sum of weight * class value, sum of weights}}.
   *
   * @param inst instance to be inserted
   * @param instA values to create the hash key from; if {@code null}, the key is built from all
   *     attributes of {@code inst}
   * @throws Exception if the instance can't be inserted
   */
  private void insertIntoTable(Instance inst, double[] instA) throws Exception {

    DecisionTableHashKey thekey;
    if (instA != null) {
      thekey = new DecisionTableHashKey(instA);
    } else {
      thekey = new DecisionTableHashKey(inst, inst.numAttributes(), false);
    }

    // see if this one is already in the table
    double[] dist = (double[]) m_entries.get(thekey);
    if (dist == null) {
      if (m_classIsNominal) {
        dist = new double[m_theInstances.classAttribute().numValues()];
        // Laplace estimation: every class starts with a pseudo-count of 1.0. The instance's
        // weight is added to (not substituted for) this count below, so all classes are
        // smoothed equally regardless of the order in which instances are inserted.
        java.util.Arrays.fill(dist, 1.0);
      } else {
        dist = new double[2];
      }
    }

    // add this instance's contribution to the distribution
    if (m_classIsNominal) {
      dist[(int) inst.classValue()] += inst.weight();
    } else {
      dist[0] += inst.classValue() * inst.weight();
      dist[1] += inst.weight();
    }

    // the array is updated in place; putting it handles new and existing entries alike
    m_entries.put(thekey, dist);
  }
Exemplo n.º 2
0
  /**
   * Computes the class membership probabilities for a test instance.
   *
   * <p>The instance is first passed through the discretization filter and then the
   * attribute-removal filter. The filtered instance is looked up in the decision table.
   *
   * <ul>
   *   <li>On a hit with a nominal class, the stored counts are normalized into a
   *       distribution.</li>
   *   <li>On a hit with a numeric class, the single-element result is the stored weighted sum
   *       divided by the stored total weight.</li>
   *   <li>On a miss, the answer comes from IBk when it is enabled. Otherwise a nominal class
   *       returns a copy of the class priors, and a numeric class returns {@code m_majority}.</li>
   * </ul>
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution (or single predicted value for a numeric
   *     class)
   * @throws Exception if distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    m_disTransform.input(instance);
    m_disTransform.batchFinished();
    Instance filtered = m_disTransform.output();

    m_delTransform.input(filtered);
    m_delTransform.batchFinished();
    filtered = m_delTransform.output();

    double[] stored =
        (double[]) m_entries.get(
            new DecisionTableHashKey(filtered, filtered.numAttributes(), false));

    if (stored != null) {
      if (m_classIsNominal) {
        // copy before normalising so the table entry keeps its raw counts
        double[] probabilities = stored.clone();
        Utils.normalize(probabilities);
        return probabilities;
      }
      return new double[] {stored[0] / stored[1]};
    }

    // no matching rule in the table
    if (m_useIBk) {
      return m_ibk.distributionForInstance(filtered);
    }
    return m_classIsNominal ? m_classPriors.clone() : new double[] {m_majority};
  }
Exemplo n.º 3
0
  /**
   * Evaluates a candidate feature set on one fold of the internal cross validation.
   *
   * <p>Each instance of the fold is temporarily subtracted from its table entry, so it is
   * classified without its own training evidence. Each prediction is passed to
   * {@code m_evaluation}. The table entries are restored before the method returns, including
   * when an exception interrupts the evaluation. For a nominal class, the fallback class priors
   * are computed from a private copy of {@code m_classPriorCounts}, so that field is never
   * modified.
   *
   * @param fold set of instances to be "left out" and classified
   * @param fs currently selected feature set (attribute indices used to build the table keys)
   * @return always 0.0; the fold's results are accumulated in {@code m_evaluation}
   * @throws Exception if something goes wrong
   */
  double evaluateFoldCV(Instances fold, int[] fs) throws Exception {

    int numFold = fold.numInstances();
    int classI = m_theInstances.classIndex();
    // nominal entries hold one count per class value; numeric entries hold
    // {sum of weight * class value, sum of weights}
    int distLength = m_classIsNominal ? m_theInstances.classAttribute().numValues() : 2;
    double[][] classDistribs = new double[numFold][];
    double[] instA = new double[fs.length];

    // class prior counts with the fold removed; only meaningful for a nominal class
    double[] priorCounts = m_classIsNominal ? m_classPriorCounts.clone() : null;

    int removed = 0;
    try {
      // first *remove* the fold's instances from their table entries
      for (; removed < numFold; removed++) {
        Instance inst = fold.instance(removed);
        for (int j = 0; j < fs.length; j++) {
          if (fs[j] == classI || inst.isMissing(fs[j])) {
            instA[j] = Double.MAX_VALUE; // the class and missing values are keyed as missing
          } else {
            instA[j] = inst.value(fs[j]);
          }
        }
        double[] entry = (double[]) m_entries.get(new DecisionTableHashKey(instA));
        if (entry == null) {
          throw new Error("This should never happen!");
        }
        classDistribs[removed] = entry;
        if (m_classIsNominal) {
          priorCounts[(int) inst.classValue()] -= inst.weight();
          entry[(int) inst.classValue()] -= inst.weight();
        } else {
          entry[0] -= (inst.classValue() * inst.weight());
          entry[1] -= inst.weight();
        }
      }

      double[] classPriors = priorCounts;
      if (m_classIsNominal) {
        Utils.normalize(classPriors);
      }

      // now classify the fold's instances against the reduced table
      for (int i = 0; i < numFold; i++) {
        Instance inst = fold.instance(i);
        // fresh array per instance, because it is handed to m_evaluation
        double[] normDist = new double[distLength];
        System.arraycopy(classDistribs[i], 0, normDist, 0, distLength);

        if (m_classIsNominal) {
          // an entry with no class count above 1.0 (the Laplace-estimation initial value)
          // is treated as having no evidence left, so the class priors are used instead
          boolean ok = false;
          for (double count : normDist) {
            if (Utils.gr(count, 1.0)) {
              ok = true;
              break;
            }
          }
          if (!ok) {
            normDist = classPriors.clone();
          }

          Utils.normalize(normDist);
          if (m_evaluationMeasure == EVAL_AUC) {
            m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst);
          } else {
            m_evaluation.evaluateModelOnce(normDist, inst);
          }
        } else {
          // no remaining weight in the entry: fall back to m_majority
          double[] temp = new double[1];
          temp[0] = Utils.eq(normDist[1], 0.0) ? m_majority : normDist[0] / normDist[1];
          m_evaluation.evaluateModelOnce(temp, inst);
        }
      }
    } finally {
      // re-insert every instance that was removed, restoring the table
      for (int i = 0; i < removed; i++) {
        Instance inst = fold.instance(i);
        if (m_classIsNominal) {
          classDistribs[i][(int) inst.classValue()] += inst.weight();
        } else {
          classDistribs[i][0] += (inst.classValue() * inst.weight());
          classDistribs[i][1] += inst.weight();
        }
      }
    }
    return 0.0;
  }
Exemplo n.º 4
0
  /**
   * Classifies an instance for the internal leave-one-out cross validation of feature sets.
   *
   * <p>The instance's own contribution is subtracted from a copy of its table entry. The
   * resulting prediction is passed to {@code m_evaluation}, and the predicted value is returned.
   * Neither the table entry nor {@code m_classPriorCounts} is modified, so no shared state can be
   * left altered if normalisation or evaluation throws.
   *
   * @param instance instance to be "left out" and classified
   * @param instA feature values of the selected features for the instance (its table key)
   * @return for a nominal class, the index of the predicted class; for a numeric class, the
   *     predicted value
   * @throws Exception if something goes wrong
   */
  double evaluateInstanceLeaveOneOut(Instance instance, double[] instA) throws Exception {

    double[] tempDist = (double[]) m_entries.get(new DecisionTableHashKey(instA));
    if (tempDist == null) {
      // the instance's own entry is expected to be present in the table
      throw new Error("This should never happen!");
    }

    // work on a copy: the table entry itself must stay intact
    double[] normDist = tempDist.clone();

    if (m_classIsNominal) {
      int classIdx = (int) instance.classValue();
      normDist[classIdx] -= instance.weight();

      // an entry with no class count above 1.0 (the Laplace-estimation initial value)
      // is treated as having no evidence left, so the class priors are used instead
      boolean ok = false;
      for (double count : normDist) {
        if (Utils.gr(count, 1.0)) {
          ok = true;
          break;
        }
      }
      if (!ok) {
        // class priors with this instance left out, computed on a private copy so that
        // m_classPriorCounts is never modified
        double[] classPriors = m_classPriorCounts.clone();
        classPriors[classIdx] -= instance.weight();
        Utils.normalize(classPriors);
        normDist = classPriors;
      }

      Utils.normalize(normDist);
      if (m_evaluationMeasure == EVAL_AUC) {
        m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, instance);
      } else {
        m_evaluation.evaluateModelOnce(normDist, instance);
      }
      return Utils.maxIndex(normDist);
    }

    // numeric class: entry is {sum of weight * class value, sum of weights}
    normDist[0] -= (instance.classValue() * instance.weight());
    normDist[1] -= instance.weight();
    // no remaining weight in the entry: fall back to m_majority
    double prediction = Utils.eq(normDist[1], 0.0) ? m_majority : normDist[0] / normDist[1];
    double[] temp = new double[1];
    temp[0] = prediction;
    m_evaluation.evaluateModelOnce(temp, instance);
    return prediction;
  }
Exemplo n.º 5
0
  /**
   * Builds a human-readable description of the decision table.
   *
   * <p>The description lists the number of training instances and rules, and how non-matching
   * instances are handled. It then appends the search description, the evaluation scheme and
   * the selected features. When {@code m_displayRules} is set, it also appends every rule as a
   * padded row of feature values followed by the predicted class. For a nominal class the
   * prediction is the class with the largest stored count; for a numeric class it is the
   * weighted mean.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {

    if (m_entries == null) {
      return "Decision Table: No model built yet.";
    }

    StringBuffer text = new StringBuffer();
    text.append("Decision Table:"
        + "\n\nNumber of training instances: " + m_numInstances
        + "\nNumber of Rules : " + m_entries.size()
        + "\n");
    text.append(m_useIBk
        ? "Non matches covered by IB1.\n"
        : "Non matches covered by Majority class.\n");
    text.append(m_search.toString());
    text.append("Evaluation (for feature selection): CV ");
    text.append(m_CVFolds > 1 ? "(" + m_CVFolds + " fold) " : "(leave one out) ");
    text.append("\nFeature set: " + printFeatures());

    if (!m_displayRules) {
      return text.toString();
    }

    int numAttributes = m_dtInstances.numAttributes();
    int classIndex = m_dtInstances.classIndex();

    // column width = longest attribute name or value label
    int maxColWidth = 0;
    for (int att = 0; att < numAttributes; att++) {
      maxColWidth = Math.max(maxColWidth, m_dtInstances.attribute(att).name().length());
      if (m_classIsNominal || att != classIndex) {
        for (Enumeration vals = m_dtInstances.attribute(att).enumerateValues();
            vals.hasMoreElements(); ) {
          maxColWidth = Math.max(maxColWidth, ((String) vals.nextElement()).length());
        }
      }
    }

    // header row: each feature name padded to the column width plus one space,
    // followed by the class attribute name
    StringBuffer header = new StringBuffer();
    for (int att = 0; att < numAttributes; att++) {
      if (att == classIndex) {
        continue;
      }
      String name = m_dtInstances.attribute(att).name();
      header.append(name);
      for (int pad = name.length(); pad <= maxColWidth; pad++) {
        header.append(" ");
      }
    }
    header.append(m_dtInstances.attribute(classIndex).name() + "  ");

    StringBuffer rule = new StringBuffer();
    for (int k = header.length() + 10; k > 0; k--) {
      rule.append("=");
    }
    String separator = rule.toString();

    text.append("\n\nRules:\n");
    text.append(separator).append("\n");
    text.append(header).append("\n");
    text.append(separator).append("\n");

    for (Enumeration keys = m_entries.keys(); keys.hasMoreElements(); ) {
      DecisionTableHashKey key = (DecisionTableHashKey) keys.nextElement();
      text.append(key.toString(m_dtInstances, maxColWidth));
      double[] dist = (double[]) m_entries.get(key);

      if (!m_classIsNominal) {
        text.append((dist[0] / dist[1]) + "\n");
        continue;
      }

      int best = Utils.maxIndex(dist);
      try {
        text.append(m_dtInstances.classAttribute().value(best) + "\n");
      } catch (Exception ee) {
        System.out.println(ee.getMessage());
      }
    }

    text.append(separator).append("\n").append("\n");
    return text.toString();
  }