示例#1
0
  /**
   * Inserts an instance into the hash table, either creating the class distribution stored under
   * its key or updating the one that is already there.
   *
   * <p>For a nominal class an entry holds one count per class value; every count starts at 1.0
   * (Laplace estimation) and the weight of each inserted instance is added to the count of its
   * class. For a numeric class an entry is a two-element array: the weighted sum of the class
   * values and the sum of the weights.
   *
   * @param inst instance to be inserted
   * @param instA values to create the hash key from; if null the key is created from {@code inst}
   * @throws Exception if the instance can't be inserted
   */
  private void insertIntoTable(Instance inst, double[] instA) throws Exception {

    DecisionTableHashKey thekey;
    if (instA != null) {
      thekey = new DecisionTableHashKey(instA);
    } else {
      thekey = new DecisionTableHashKey(inst, inst.numAttributes(), false);
    }

    // see if this one is already in the table
    double[] dist = (double[]) m_entries.get(thekey);
    boolean isNewEntry = (dist == null);
    if (isNewEntry) {
      if (m_classIsNominal) {
        // Laplace estimation: every class starts with a count of one
        dist = new double[m_theInstances.classAttribute().numValues()];
        Arrays.fill(dist, 1.0);
      } else {
        dist = new double[2];
      }
    }

    // Add this instance's contribution. The weight is always added on top of the initial count,
    // never assigned over it, so that the first instance of an entry keeps the Laplace count of
    // its class (otherwise a weight below 1.0 would rank the observed class below unseen ones).
    if (m_classIsNominal) {
      dist[(int) inst.classValue()] += inst.weight();
    } else {
      dist[0] += inst.classValue() * inst.weight();
      dist[1] += inst.weight();
    }

    // An existing entry was updated in place; only a new one has to be added to the table
    if (isNewEntry) {
      m_entries.put(thekey, dist);
    }
  }
示例#2
0
  /**
   * Evaluates a test fold for internal cross validation of feature sets.
   *
   * <p>The instances of the fold are first subtracted from their table entries (and, for a nominal
   * class, from the class prior counts). Each instance is then predicted from what is left and the
   * prediction is handed to {@code m_evaluation}. Finally the subtracted counts are added back;
   * this also happens when the evaluation fails part-way, so the table is not left in a
   * down-dated state.
   *
   * @param fold set of instances to be "left out" and classified
   * @param fs currently selected feature set
   * @return always 0.0; the performance on the fold is accumulated in {@code m_evaluation}
   * @throws Exception if something goes wrong
   */
  double evaluateFoldCV(Instances fold, int[] fs) throws Exception {

    int numFold = fold.numInstances();
    int classI = m_theInstances.classIndex();
    // table entry of each fold instance (instances with equal keys share the same array)
    double[][] class_distribs = new double[numFold][];
    double[] instA = new double[fs.length];
    // number of fold instances whose counts are currently subtracted
    int removed = 0;

    try {
      // first *remove* instances
      for (int i = 0; i < numFold; i++) {
        Instance inst = fold.instance(i);
        for (int j = 0; j < fs.length; j++) {
          if (fs[j] == classI || inst.isMissing(fs[j])) {
            // the class and missing values are both keyed as Double.MAX_VALUE
            instA[j] = Double.MAX_VALUE;
          } else {
            instA[j] = inst.value(fs[j]);
          }
        }
        double[] dist = (double[]) m_entries.get(new DecisionTableHashKey(instA));
        if (dist == null) {
          throw new Error("This should never happen!");
        }
        class_distribs[i] = dist;
        adjustFoldCounts(dist, inst, -1.0);
        removed++;
      }

      // class priors exist only for a nominal class, so they must not be touched otherwise
      double[] classPriors = null;
      if (m_classIsNominal) {
        classPriors = m_classPriorCounts.clone();
        Utils.normalize(classPriors);
      }

      // now classify instances
      for (int i = 0; i < numFold; i++) {
        Instance inst = fold.instance(i);
        double[] remaining = class_distribs[i];
        if (m_classIsNominal) {
          // the entry is used only if some class count still exceeds 1.0
          boolean ok = false;
          for (int j = 0; j < remaining.length; j++) {
            if (Utils.gr(remaining[j], 1.0)) {
              ok = true;
              break;
            }
          }

          // otherwise fall back to the class priors; always predict from a copy
          double[] normDist = ok ? remaining.clone() : classPriors.clone();
          Utils.normalize(normDist);
          if (m_evaluationMeasure == EVAL_AUC) {
            m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst);
          } else {
            m_evaluation.evaluateModelOnce(normDist, inst);
          }
        } else {
          // weighted mean of the remaining class values, or m_majority if no weight is left
          double[] temp = new double[1];
          temp[0] = Utils.eq(remaining[1], 0.0) ? m_majority : remaining[0] / remaining[1];
          m_evaluation.evaluateModelOnce(temp, inst);
        }
      }
    } finally {
      // now re-insert instances (only those that were actually removed)
      for (int i = 0; i < removed; i++) {
        adjustFoldCounts(class_distribs[i], fold.instance(i), 1.0);
      }
    }
    return 0.0;
  }

  /**
   * Adds ({@code sign} = 1.0) or subtracts ({@code sign} = -1.0) the contribution of an instance
   * to a table entry and, for a nominal class, to the class prior counts.
   *
   * @param dist the table entry the instance belongs to
   * @param inst the instance whose contribution is applied
   * @param sign 1.0 to add the contribution, -1.0 to subtract it
   */
  private void adjustFoldCounts(double[] dist, Instance inst, double sign) {
    double signedWeight = sign * inst.weight();
    if (m_classIsNominal) {
      int classValue = (int) inst.classValue();
      dist[classValue] += signedWeight;
      m_classPriorCounts[classValue] += signedWeight;
    } else {
      dist[0] += inst.classValue() * signedWeight;
      dist[1] += signedWeight;
    }
  }
示例#3
0
  /**
   * Predicts a single training instance as if it had been left out of the table, for internal
   * leave one out cross validation of feature sets. The prediction is also passed to
   * {@code m_evaluation}.
   *
   * @param instance instance to be "left out" and classified
   * @param instA feature values of the selected features for the instance
   * @return the predicted class index (nominal class) or the predicted value (numeric class)
   * @throws Exception if something goes wrong
   */
  double evaluateInstanceLeaveOneOut(Instance instance, double[] instA) throws Exception {

    DecisionTableHashKey key = new DecisionTableHashKey(instA);
    double[] stored = (double[]) m_entries.get(key);

    // the instance is expected to have an entry in the table already
    if (stored == null) {
      throw new Error("This should never happen!");
    }

    // private copy of the entry, from which this instance's contribution is taken out
    double[] heldOut = stored.clone();

    if (m_classIsNominal) {
      int classIdx = (int) instance.classValue();
      heldOut[classIdx] -= instance.weight();

      // the entry is used only if some class count still exceeds 1.0
      boolean usable = false;
      for (int k = 0; k < heldOut.length && !usable; k++) {
        usable = Utils.gr(heldOut[k], 1.0);
      }

      // class priors without this instance: downdate, take a snapshot, then restore
      m_classPriorCounts[classIdx] -= instance.weight();
      double[] priors = m_classPriorCounts.clone();
      Utils.normalize(priors);
      m_classPriorCounts[classIdx] += instance.weight();

      // fall back to the priors when the entry is not usable
      double[] prediction = usable ? heldOut : priors;
      Utils.normalize(prediction);
      if (m_evaluationMeasure == EVAL_AUC) {
        m_evaluation.evaluateModelOnceAndRecordPrediction(prediction, instance);
      } else {
        m_evaluation.evaluateModelOnce(prediction, instance);
      }
      return Utils.maxIndex(prediction);
    }

    // numeric class: entry is {weighted sum of class values, sum of weights}
    heldOut[0] -= (instance.classValue() * instance.weight());
    heldOut[1] -= instance.weight();

    double[] prediction = new double[1];
    if (Utils.eq(heldOut[1], 0.0)) {
      // no weight left in this entry, use m_majority instead
      prediction[0] = m_majority;
    } else {
      prediction[0] = heldOut[0] / heldOut[1];
    }
    m_evaluation.evaluateModelOnce(prediction, instance);
    return prediction[0];
  }
示例#4
0
  /**
   * Generates the classifier: discretizes the data, searches for a feature subset, reduces the
   * instances to the selected features plus the class and fills the hash table with them.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    m_theInstances = new Instances(data);
    m_theInstances.deleteWithMissingClass();

    m_rr = new Random(1);

    if (m_theInstances.classAttribute().isNominal()) {
      // Set up class priors: a count of 1.0 per class plus the instance weights. The counts are
      // taken from m_theInstances rather than data, because instances with a missing class have
      // no valid class index and are not part of the training data used below.
      m_classPriorCounts = new double[m_theInstances.classAttribute().numValues()];
      Arrays.fill(m_classPriorCounts, 1.0);
      for (int i = 0; i < m_theInstances.numInstances(); i++) {
        Instance curr = m_theInstances.instance(i);
        m_classPriorCounts[(int) curr.classValue()] += curr.weight();
      }
      m_classPriors = m_classPriorCounts.clone();
      Utils.normalize(m_classPriors);
    }

    setUpEvaluator();

    if (m_theInstances.classAttribute().isNumeric()) {
      m_classIsNominal = false;

      // use binned discretisation if the class is numeric
      weka.filters.unsupervised.attribute.Discretize binning =
          new weka.filters.unsupervised.attribute.Discretize();
      binning.setBins(10);
      binning.setInvertSelection(true);

      // Discretize all attributes EXCEPT the class (the range list is 1-based)
      binning.setAttributeIndices(String.valueOf(m_theInstances.classIndex() + 1));
      m_disTransform = binning;
    } else {
      m_classIsNominal = true;

      weka.filters.supervised.attribute.Discretize supervised =
          new weka.filters.supervised.attribute.Discretize();
      supervised.setUseBetterEncoding(true);
      m_disTransform = supervised;
    }

    m_disTransform.setInputFormat(m_theInstances);
    m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);

    m_numAttributes = m_theInstances.numAttributes();
    m_numInstances = m_theInstances.numInstances();
    m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());

    // Perform the search
    int[] selected = m_search.search(m_evaluator, m_theInstances);

    // the decision features are the selected attributes followed by the class
    m_decisionFeatures = new int[selected.length + 1];
    System.arraycopy(selected, 0, m_decisionFeatures, 0, selected.length);
    m_decisionFeatures[m_decisionFeatures.length - 1] = m_theInstances.classIndex();

    // reduce instances to selected features
    m_delTransform = new Remove();
    m_delTransform.setInvertSelection(true);

    // set features to keep
    m_delTransform.setAttributeIndicesArray(m_decisionFeatures);
    m_delTransform.setInputFormat(m_theInstances);
    m_dtInstances = Filter.useFilter(m_theInstances, m_delTransform);

    // reset the number of attributes
    m_numAttributes = m_dtInstances.numAttributes();

    // create hash table
    m_entries = new Hashtable((int) (m_dtInstances.numInstances() * 1.5));

    // insert instances into the hash table (bounded by the data set that is actually indexed)
    for (int i = 0; i < m_dtInstances.numInstances(); i++) {
      insertIntoTable(m_dtInstances.instance(i), null);
    }

    // Replace the global table majority with nearest neighbour?
    if (m_useIBk) {
      m_ibk = new IBk();
      m_ibk.buildClassifier(m_theInstances);
    }

    // Save memory: keep only the headers of the data sets
    if (m_saveMemory) {
      m_theInstances = new Instances(m_theInstances, 0);
      m_dtInstances = new Instances(m_dtInstances, 0);
    }
    m_evaluation = null;
  }