Beispiel #1
0
 /**
  * Tests whether the given instance is covered by this antecedent. An instance is covered when its
  * value for the antecedent's attribute is not missing and, truncated to an int, equals the
  * (int-truncated) value stored in this antecedent.
  *
  * @param inst the instance in question
  * @return 1 if the instance is covered by this antecedent, 0 otherwise (including the case where
  *     the attribute value is missing)
  */
 public double covers(Instance inst) {
   // a missing value can never match
   if (inst.isMissing(att)) {
     return 0;
   }
   return ((int) inst.value(att) == (int) value) ? 1 : 0;
 }
  /**
   * Predicts the value for the given instance. If no attribute has been selected, the intercept
   * alone is returned; otherwise the prediction is intercept + slope * attribute value.
   *
   * @param inst the instance to predict
   * @return the predicted value
   * @throws Exception if the value of the selected attribute is missing in the instance
   */
  public double classifyInstance(Instance inst) throws Exception {
    // no attribute selected: constant prediction
    if (m_attribute == null) {
      return m_intercept;
    }

    final int attIndex = m_attribute.index();
    if (inst.isMissing(attIndex)) {
      throw new Exception("UnivariateLinearRegression: No missing values!");
    }
    return m_intercept + m_slope * inst.value(attIndex);
  }
  /**
   * Determines the output format based on the input format and returns this. In case the output
   * format cannot be returned immediately, i.e., hasImmediateOutputFormat() returns false, then
   * this method will be called from batchFinished().
   *
   * <p>Attributes outside the selected column range, and attributes that are not numeric, are
   * carried over unchanged. Every other attribute is replaced by a nominal attribute of the same
   * name whose labels are the sorted, distinct, non-missing values found in the input data: date
   * attributes contribute their string representation, all other values are formatted with
   * MAX_DECIMALS decimals.
   *
   * <p>Distinct numeric values that format to the same label are collapsed into a single label, so
   * that building the nominal attribute does not fail on duplicate labels.
   *
   * @param inputFormat the input format to base the output format on
   * @return the output format
   * @throws Exception in case the determination goes wrong
   * @see #hasImmediateOutputFormat()
   * @see #batchFinished()
   */
  protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
    m_Cols.setUpper(inputFormat.numAttributes() - 1);
    Instances data = new Instances(inputFormat);
    FastVector atts = new FastVector();

    for (int i = 0; i < data.numAttributes(); i++) {
      // attributes that are not converted are copied as they are
      if (!m_Cols.isInRange(i) || !data.attribute(i).isNumeric()) {
        atts.addElement(data.attribute(i));
        continue;
      }

      // date attribute?
      boolean isDate = (data.attribute(i).type() == Attribute.DATE);

      // determine all distinct, non-missing values of this attribute in the dataset
      HashSet<Object> distinct = new HashSet<Object>();
      for (int n = 0; n < data.numInstances(); n++) {
        Instance inst = data.instance(n);
        if (inst.isMissing(i)) {
          continue;
        }

        if (isDate) {
          distinct.add(inst.stringValue(i));
        } else {
          distinct.add(Double.valueOf(inst.value(i)));
        }
      }

      // sort values (natural ordering of the collected Strings resp. Doubles)
      Vector sorted = new Vector(distinct);
      Collections.sort(sorted);

      // create attribute from sorted values; two different doubles may share the same formatted
      // label, and a nominal attribute must not contain a label twice
      FastVector values = new FastVector();
      HashSet<String> seenLabels = new HashSet<String>();
      for (Object o : sorted) {
        String label =
            isDate ? o.toString() : Utils.doubleToString(((Double) o).doubleValue(), MAX_DECIMALS);
        if (seenLabels.add(label)) {
          values.addElement(label);
        }
      }
      atts.addElement(new Attribute(data.attribute(i).name(), values));
    }

    Instances result = new Instances(inputFormat.relationName(), atts, 0);
    result.setClassIndex(inputFormat.classIndex());

    return result;
  }
 /**
  * Initializes the ranges from the values of a single (the first) instance. For every attribute
  * with a value, low and high are set to that value and the width to zero. For every attribute
  * whose value is missing, low is set to positive infinity, high to negative infinity and the
  * width to positive infinity.
  *
  * @param instance the instance to initialize the ranges from
  * @param numAtt number of attributes in the model
  * @param ranges low, high and width values for all attributes
  */
 public void updateRangesFirst(Instance instance, int numAtt, double[][] ranges) {
   for (int att = 0; att < numAtt; att++) {
     if (instance.isMissing(att)) {
       // no value available: start with an "empty" (inverted) interval
       ranges[att][R_MIN] = Double.POSITIVE_INFINITY;
       ranges[att][R_MAX] = Double.NEGATIVE_INFINITY;
       ranges[att][R_WIDTH] = Double.POSITIVE_INFINITY;
       continue;
     }

     final double first = instance.value(att);
     ranges[att][R_MIN] = first;
     ranges[att][R_MAX] = first;
     ranges[att][R_WIDTH] = 0.0;
   }
 }
Beispiel #5
0
  /**
   * Gets the subset of instances that apply to a particular branch of the split. If the branch
   * index is -1, the subset will consist of those instances that don't apply to any branch, i.e.
   * those whose value for the split attribute is missing.
   *
   * @param branch the index of the branch: -1 selects the instances with a missing split attribute
   *     value, 0 selects the instances whose value is below the split point, and any other index
   *     selects the instances whose value is at or above the split point
   * @param instances the instances from which to find the subset
   * @return the set of instances that apply
   */
  public ReferenceInstances instancesDownBranch(int branch, Instances instances) {

    ReferenceInstances subset = new ReferenceInstances(instances, 1);
    Enumeration en = instances.enumerateInstances();
    while (en.hasMoreElements()) {
      Instance candidate = (Instance) en.nextElement();
      boolean missing = candidate.isMissing(attIndex);

      boolean applies;
      if (branch == -1) {
        applies = missing;
      } else if (missing) {
        // instances without a value never go down a real branch
        applies = false;
      } else if (branch == 0) {
        applies = candidate.value(attIndex) < splitPoint;
      } else {
        applies = candidate.value(attIndex) >= splitPoint;
      }

      if (applies) {
        subset.addReference(candidate);
      }
    }
    return subset;
  }
  /**
   * Tests whether an instance lies within the given ranges. Attributes whose value is missing in
   * the instance are ignored; every other attribute value has to lie between the low and the high
   * value (both inclusive) of its range.
   *
   * @param instance the instance
   * @param ranges the ranges the instance is tested to be in
   * @return true if instance is within the ranges
   */
  public boolean inRanges(Instance instance, double[][] ranges) {
    // updateRangesFirst must have been called on ranges
    for (int att = 0; att < ranges.length; att++) {
      if (instance.isMissing(att)) {
        continue;
      }

      final double v = instance.value(att);
      // written as negated <= / >= so that a NaN bound rejects the instance, as before
      if (!(v <= ranges[att][R_MAX])) {
        return false;
      }
      if (!(v >= ranges[att][R_MIN])) {
        return false;
      }
    }

    return true;
  }
  /**
   * Updates the ranges given a new instance. For every attribute with a non-missing value, the low
   * and high values are widened to include the value and the width is recomputed as high - low.
   *
   * <p>The low and high bounds are checked independently. This matters for a range that was
   * initialized from a missing value (low = +infinity, high = -infinity): the first real value is
   * then both below low and above high, and both bounds have to be set to it. Checking high only
   * when low did not change would leave high at -infinity and yield a width of -infinity.
   *
   * @param instance the new instance
   * @param ranges low, high and width values for all attributes
   * @return the updated ranges (the same array that was passed in)
   */
  public double[][] updateRanges(Instance instance, double[][] ranges) {
    // updateRangesFirst must have been called on ranges
    for (int j = 0; j < ranges.length; j++) {
      if (instance.isMissing(j)) {
        continue;
      }

      double value = instance.value(j);
      boolean changed = false;
      if (value < ranges[j][R_MIN]) {
        ranges[j][R_MIN] = value;
        changed = true;
      }
      if (value > ranges[j][R_MAX]) {
        ranges[j][R_MAX] = value;
        changed = true;
      }
      if (changed) {
        ranges[j][R_WIDTH] = ranges[j][R_MAX] - ranges[j][R_MIN];
      }
    }

    return ranges;
  }
  /**
   * Processes the given data (may change the provided dataset) and returns the modified version.
   * This method is called in batchFinished().
   *
   * <p>For every non-missing value of a numeric attribute inside the selected column range, the
   * value is replaced by the index of its label in the corresponding nominal attribute of the
   * output format. A value whose label does not exist in the output format (for example a value
   * that was not present when the output format was determined) is turned into a missing value
   * instead of being stored as the invalid index -1.
   *
   * @param instances the data to process
   * @return the modified data
   * @throws Exception in case the processing goes wrong
   * @see #batchFinished()
   */
  protected Instances process(Instances instances) throws Exception {
    // we need the complete input data!
    if (!isFirstBatchDone()) {
      setOutputFormat(determineOutputFormat(getInputFormat()));
    }

    Instances result = new Instances(getOutputFormat());

    for (int i = 0; i < instances.numInstances(); i++) {
      Instance inst = instances.instance(i);
      double[] values = inst.toDoubleArray();

      for (int n = 0; n < values.length; n++) {
        if (!m_Cols.isInRange(n) || !instances.attribute(n).isNumeric() || inst.isMissing(n)) {
          continue;
        }

        // build the label in the same way determineOutputFormat() does
        String value;
        if (instances.attribute(n).type() == Attribute.DATE) {
          value = inst.stringValue(n);
        } else {
          value = Utils.doubleToString(inst.value(n), MAX_DECIMALS);
        }

        // get index of value; an unknown label must not end up as nominal index -1
        int index = result.attribute(n).indexOfValue(value);
        values[n] = (index < 0) ? Utils.missingValue() : index;
      }

      // generate new instance
      Instance newInst;
      if (inst instanceof SparseInstance) {
        newInst = new SparseInstance(inst.weight(), values);
      } else {
        newInst = new DenseInstance(inst.weight(), values);
      }

      // copy possible string, relational values
      newInst.setDataset(getOutputFormat());
      copyValues(newInst, false, inst.dataset(), getOutputFormat());

      result.add(newInst);
    }

    return result;
  }
 /**
  * Updates the minimum, maximum and width values for all the attributes based on a new instance.
  * Attributes whose value is missing in the instance are left untouched. The minimum and the
  * maximum are checked independently, so a range that still holds its "empty" start values gets
  * both bounds set by the first non-missing value.
  *
  * @param instance the new instance
  * @param numAtt number of attributes in the model
  * @param ranges low, high and width values for all attributes
  */
 public void updateRanges(Instance instance, int numAtt, double[][] ranges) {
   // updateRangesFirst must have been called on ranges
   for (int att = 0; att < numAtt; att++) {
     if (instance.isMissing(att)) {
       continue;
     }

     final double v = instance.value(att);
     boolean widened = false;
     if (v < ranges[att][R_MIN]) {
       ranges[att][R_MIN] = v;
       widened = true;
     }
     if (v > ranges[att][R_MAX]) {
       ranges[att][R_MAX] = v;
       widened = true;
     }
     if (widened) {
       ranges[att][R_WIDTH] = ranges[att][R_MAX] - ranges[att][R_MIN];
     }
   }
 }
Beispiel #10
0
  /**
   * Implements the splitData function. The data is split into one bag per value of this
   * antecedent's attribute; instances with a missing value for the attribute go into no bag.
   *
   * <p>While splitting, the per-bag weight sums are stored in the coverage and accurate arrays
   * (accurate only counts instances whose class equals the given class label). Afterwards the
   * info gain of each bag is computed, and whenever it exceeds the current maxInfoGain the fields
   * maxInfoGain, cover, accu, accuRate and value are updated to describe that bag.
   *
   * @param data the data to be split
   * @param defAcRt the default accuracy rate for data
   * @param cl the class label to be predicted
   * @return the array of data after split
   */
  public Instances[] splitData(Instances data, double defAcRt, double cl) {
    final int numBags = att.numValues();
    final Instances[] bags = new Instances[numBags];

    // start with empty bags and cleared statistics
    for (int b = 0; b < numBags; b++) {
      bags[b] = new Instances(data, data.numInstances());
      accurate[b] = 0;
      coverage[b] = 0;
    }

    // distribute the instances over the bags
    final int total = data.numInstances();
    for (int idx = 0; idx < total; idx++) {
      Instance current = data.instance(idx);
      if (current.isMissing(att)) {
        continue;
      }

      int bag = (int) current.value(att);
      bags[bag].add(current);
      coverage[bag] += current.weight();
      if ((int) current.classValue() == (int) cl) {
        accurate[bag] += current.weight();
      }
    }

    // pick the bag with the highest info gain
    for (int b = 0; b < numBags; b++) {
      double t = coverage[b] + 1.0;
      double p = accurate[b] + 1.0;
      double gain = accurate[b] * (Utils.log2(p / t) - Utils.log2(defAcRt));

      if (gain > maxInfoGain) {
        maxInfoGain = gain;
        cover = coverage[b];
        accu = accurate[b];
        accuRate = p / t;
        value = (double) b;
      }
    }

    return bags;
  }
Beispiel #11
0
  /**
   * Tests a certain range of attributes of the given data, whether it can be processed by the
   * handler, given its capabilities. Classifiers implementing the <code>
   * MultiInstanceCapabilitiesHandler</code> interface are checked automatically for their
   * multi-instance Capabilities (if no bags, then only the bag-structure, otherwise only the first
   * bag).
   *
   * <p>The checks are performed in this order: attribute types, presence/type of the class
   * attribute, missing class values, missing attribute values, minimum number of instances and
   * finally the multi-instance structure. On the first failing check the reason is stored in
   * m_FailReason and false is returned.
   *
   * @param data the data to test
   * @param fromIndex the range of attributes - start (incl.)
   * @param toIndex the range of attributes - end (incl.)
   * @return true if all the tests succeeded (or if testing is switched off via m_InstancesTest)
   * @see MultiInstanceCapabilitiesHandler
   * @see #m_InstancesTest
   * @see #m_MissingValuesTest
   * @see #m_MissingClassValuesTest
   * @see #m_MinimumNumberInstancesTest
   */
  public boolean test(Instances data, int fromIndex, int toIndex) {
    int i;
    int n;
    int m;
    Attribute att;
    Instance inst;
    boolean testClass;
    Capabilities cap;
    boolean missing;
    Iterator iter;

    // shall we test the data?
    if (!m_InstancesTest) return true;

    // no Capabilities? -> warning (only printed to stderr, the test continues)
    if ((m_Capabilities.size() == 0)
        || ((m_Capabilities.size() == 1) && handles(Capability.NO_CLASS)))
      System.err.println(createMessage("No capabilities set!"));

    // any attributes? (an empty range has toIndex < fromIndex)
    if (toIndex - fromIndex < 0) {
      m_FailReason = new WekaException(createMessage("No attributes!"));
      return false;
    }

    // do we need to test the class attribute, i.e., is the class attribute
    // within the range of attributes?
    testClass =
        (data.classIndex() > -1)
            && (data.classIndex() >= fromIndex)
            && (data.classIndex() <= toIndex);

    // attributes
    for (i = fromIndex; i <= toIndex; i++) {
      att = data.attribute(i);

      // class is handled separately
      if (i == data.classIndex()) continue;

      // check attribute types
      if (!test(att)) return false;
    }

    // class: a class attribute is required unless NO_CLASS is handled
    if (!handles(Capability.NO_CLASS) && (data.classIndex() == -1)) {
      m_FailReason = new UnassignedClassException(createMessage("Class attribute not set!"));
      return false;
    }

    // special case: no class attribute can be handled
    // (NO_CLASS is handled, but the data has a class: fail if, apart from NO_CLASS,
    // not a single class capability is left)
    if (handles(Capability.NO_CLASS) && (data.classIndex() > -1)) {
      cap = getClassCapabilities();
      cap.disable(Capability.NO_CLASS);
      iter = cap.capabilities();
      if (!iter.hasNext()) {
        m_FailReason = new WekaException(createMessage("Cannot handle any class attribute!"));
        return false;
      }
    }

    // the class attribute itself is only checked if it lies within the tested range
    // and NO_CLASS is not handled
    if (testClass && !handles(Capability.NO_CLASS)) {
      att = data.classAttribute();
      if (!test(att, true)) return false;

      // special handling of RELATIONAL class
      // TODO: store additional Capabilities for this case

      // missing class labels
      if (m_MissingClassValuesTest) {
        if (!handles(Capability.MISSING_CLASS_VALUES)) {
          // a single instance without class value is enough to fail
          for (i = 0; i < data.numInstances(); i++) {
            if (data.instance(i).classIsMissing()) {
              m_FailReason =
                  new WekaException(createMessage("Cannot handle missing class values!"));
              return false;
            }
          }
        } else {
          if (m_MinimumNumberInstancesTest) {
            // count the instances that do have a class value
            int hasClass = 0;

            for (i = 0; i < data.numInstances(); i++) {
              if (!data.instance(i).classIsMissing()) hasClass++;
            }

            // not enough instances with class labels?
            if (hasClass < getMinimumNumberInstances()) {
              m_FailReason =
                  new WekaException(
                      createMessage(
                          "Not enough training instances with class labels (required: "
                              + getMinimumNumberInstances()
                              + ", provided: "
                              + hasClass
                              + ")!"));
              return false;
            }
          }
        }
      }
    }

    // missing values (of non-class attributes within the tested range)
    if (m_MissingValuesTest) {
      if (!handles(Capability.MISSING_VALUES)) {
        missing = false;
        for (i = 0; i < data.numInstances(); i++) {
          inst = data.instance(i);

          if (inst instanceof SparseInstance) {
            // sparse instances: only walk over the values that are actually stored
            for (m = 0; m < inst.numValues(); m++) {
              n = inst.index(m);

              // out of scope?
              if (n < fromIndex) continue;
              if (n > toIndex) break;

              // skip class
              if (n == inst.classIndex()) continue;

              if (inst.isMissing(n)) {
                missing = true;
                break;
              }
            }
          } else {
            for (n = fromIndex; n <= toIndex; n++) {
              // skip class
              if (n == inst.classIndex()) continue;

              if (inst.isMissing(n)) {
                missing = true;
                break;
              }
            }
          }

          // the first instance with a missing value fails the test
          if (missing) {
            m_FailReason =
                new NoSupportForMissingValuesException(
                    createMessage("Cannot handle missing values!"));
            return false;
          }
        }
      }
    }

    // instances
    if (m_MinimumNumberInstancesTest) {
      if (data.numInstances() < getMinimumNumberInstances()) {
        m_FailReason =
            new WekaException(
                createMessage(
                    "Not enough training instances (required: "
                        + getMinimumNumberInstances()
                        + ", provided: "
                        + data.numInstances()
                        + ")!"));
        return false;
      }
    }

    // Multi-Instance? -> check structure (regardless of attribute range!)
    if (handles(Capability.ONLY_MULTIINSTANCE)) {
      // number of attributes?
      if (data.numAttributes() != 3) {
        m_FailReason =
            new WekaException(
                createMessage("Incorrect Multi-Instance format, must be 'bag-id, bag, class'!"));
        return false;
      }

      // type of attributes and position of class?
      if (!data.attribute(0).isNominal()
          || !data.attribute(1).isRelationValued()
          || (data.classIndex() != data.numAttributes() - 1)) {
        m_FailReason =
            new WekaException(
                createMessage(
                    "Incorrect Multi-Instance format, must be 'NOMINAL att, RELATIONAL att, CLASS att'!"));
        return false;
      }

      // check data immediately: test the first bag if there is data, otherwise only
      // the header of the relational attribute
      if (getOwner() instanceof MultiInstanceCapabilitiesHandler) {
        MultiInstanceCapabilitiesHandler handler = (MultiInstanceCapabilitiesHandler) getOwner();
        cap = handler.getMultiInstanceCapabilities();
        boolean result;
        if (data.numInstances() > 0) result = cap.test(data.attribute(1).relation(0));
        else result = cap.test(data.attribute(1).relation());

        if (!result) {
          // hand the reason of the nested test on to the caller
          m_FailReason = cap.m_FailReason;
          return false;
        }
      }
    }

    // passed all tests!
    return true;
  }
Beispiel #12
0
  /**
   * Returns a Capabilities object specific for this data. The minimum number of instances is not
   * set, the check for multi-instance data is optional.
   *
   * <p>The returned object has the following capabilities enabled:
   *
   * <ul>
   *   <li>NO_CLASS if no class index is set, otherwise the class capability that matches the type
   *       of the class attribute (nominal classes are split into unary, binary and nominal by
   *       their number of values), plus MISSING_CLASS_VALUES if at least one instance has no
   *       class value
   *   <li>for every non-class attribute the capability matching its type
   *   <li>MISSING_VALUES if at least one non-class attribute value is missing
   * </ul>
   *
   * <p>If <code>multi</code> is true and the data has the structure "nominal bag-id, relational
   * bag, class as last attribute", the result is replaced by the class capabilities plus
   * NOMINAL_ATTRIBUTES, RELATIONAL_ATTRIBUTES and ONLY_MULTIINSTANCE.
   *
   * @param data the data to base the capabilities on
   * @param multi if true then the structure is checked, too
   * @return a data-specific capabilities object
   * @throws Exception in case an error occurs, e.g., an unknown attribute type
   */
  public static Capabilities forInstances(Instances data, boolean multi) throws Exception {
    Capabilities result = new Capabilities(null);

    // class
    if (data.classIndex() == -1) {
      result.enable(Capability.NO_CLASS);
    } else {
      switch (data.classAttribute().type()) {
        case Attribute.NOMINAL:
          if (data.classAttribute().numValues() == 1) {
            result.enable(Capability.UNARY_CLASS);
          } else if (data.classAttribute().numValues() == 2) {
            result.enable(Capability.BINARY_CLASS);
          } else {
            result.enable(Capability.NOMINAL_CLASS);
          }
          break;

        case Attribute.NUMERIC:
          result.enable(Capability.NUMERIC_CLASS);
          break;

        case Attribute.STRING:
          result.enable(Capability.STRING_CLASS);
          break;

        case Attribute.DATE:
          result.enable(Capability.DATE_CLASS);
          break;

        case Attribute.RELATIONAL:
          result.enable(Capability.RELATIONAL_CLASS);
          break;

        default:
          // report the offending type code, like the attribute check below does
          throw new UnsupportedAttributeTypeException(
              "Unknown class attribute type '" + data.classAttribute().type() + "'!");
      }

      // missing class values
      for (int i = 0; i < data.numInstances(); i++) {
        if (data.instance(i).classIsMissing()) {
          result.enable(Capability.MISSING_CLASS_VALUES);
          break;
        }
      }
    }

    // attributes
    for (int i = 0; i < data.numAttributes(); i++) {
      // skip class
      if (i == data.classIndex()) {
        continue;
      }

      switch (data.attribute(i).type()) {
        case Attribute.NOMINAL:
          // every nominal attribute enables UNARY_ATTRIBUTES; more values add on top of that
          result.enable(Capability.UNARY_ATTRIBUTES);
          if (data.attribute(i).numValues() == 2) {
            result.enable(Capability.BINARY_ATTRIBUTES);
          } else if (data.attribute(i).numValues() > 2) {
            result.enable(Capability.NOMINAL_ATTRIBUTES);
          }
          break;

        case Attribute.NUMERIC:
          result.enable(Capability.NUMERIC_ATTRIBUTES);
          break;

        case Attribute.DATE:
          result.enable(Capability.DATE_ATTRIBUTES);
          break;

        case Attribute.STRING:
          result.enable(Capability.STRING_ATTRIBUTES);
          break;

        case Attribute.RELATIONAL:
          result.enable(Capability.RELATIONAL_ATTRIBUTES);
          break;

        default:
          throw new UnsupportedAttributeTypeException(
              "Unknown attribute type '" + data.attribute(i).type() + "'!");
      }
    }

    // missing values: stop at the first instance that has a missing non-class value
    boolean missing = false;
    for (int i = 0; i < data.numInstances(); i++) {
      Instance inst = data.instance(i);

      if (inst instanceof SparseInstance) {
        // sparse instances: only walk over the values that are actually stored
        for (int m = 0; m < inst.numValues(); m++) {
          int n = inst.index(m);

          // skip class
          if (n == inst.classIndex()) {
            continue;
          }

          if (inst.isMissing(n)) {
            missing = true;
            break;
          }
        }
      } else {
        for (int n = 0; n < data.numAttributes(); n++) {
          // skip class
          if (n == inst.classIndex()) {
            continue;
          }

          if (inst.isMissing(n)) {
            missing = true;
            break;
          }
        }
      }

      if (missing) {
        result.enable(Capability.MISSING_VALUES);
        break;
      }
    }

    // multi-instance data?
    if (multi) {
      if ((data.numAttributes() == 3)
          && (data.attribute(0).isNominal()) // bag-id
          && (data.attribute(1).isRelationValued()) // bag
          && (data.classIndex() == data.numAttributes() - 1)) {
        Capabilities multiInstance = new Capabilities(null);
        multiInstance.or(result.getClassCapabilities());
        multiInstance.enable(Capability.NOMINAL_ATTRIBUTES);
        multiInstance.enable(Capability.RELATIONAL_ATTRIBUTES);
        multiInstance.enable(Capability.ONLY_MULTIINSTANCE);
        result.assign(multiInstance);
      }
    }

    return result;
  }
  /**
   * Evaluates an individual attribute by measuring the gain ratio of the class given the
   * attribute.
   *
   * <p>A contingency table of attribute value versus class value is built from the training
   * instances, each instance adding 1 to its cell. The last row and the last column of the table
   * collect the instances with a missing attribute value and a missing class value respectively.
   * If m_missing_merge is set (and not all values are missing), these missing counts are
   * distributed over the remaining cells in proportion to the observed frequencies before the gain
   * ratio is computed.
   *
   * @param attribute the index of the attribute to be evaluated
   * @return the gain ratio
   * @throws Exception if the attribute could not be evaluated
   */
  public double evaluateAttribute(int attribute) throws Exception {
    // one extra row and column at the end for the missing values
    final int numRows = m_trainInstances.attribute(attribute).numValues() + 1;
    final int numCols = m_numClasses + 1;
    final int missingRow = numRows - 1;
    final int missingCol = numCols - 1;

    // freshly allocated arrays are already filled with zeros
    double[][] counts = new double[numRows][numCols];
    double[] rowTotals = new double[numRows];
    double[] colTotals = new double[numCols];
    double total = 0.0;

    // Fill the contingency table
    for (int k = 0; k < m_numInstances; k++) {
      Instance inst = m_trainInstances.instance(k);
      int row = inst.isMissing(attribute) ? missingRow : (int) inst.value(attribute);
      int col = inst.isMissing(m_classIndex) ? missingCol : (int) inst.value(m_classIndex);
      counts[row][col]++;
    }

    // get the row totals (and the grand total)
    for (int r = 0; r < numRows; r++) {
      for (int c = 0; c < numCols; c++) {
        rowTotals[r] += counts[r][c];
        total += counts[r][c];
      }
    }

    // get the column totals
    for (int c = 0; c < numCols; c++) {
      for (int r = 0; r < numRows; r++) {
        colTotals[c] += counts[r][c];
      }
    }

    // distribute missing counts
    if (m_missing_merge
        && (rowTotals[missingRow] < m_numInstances)
        && (colTotals[missingCol] < m_numInstances)) {
      // snapshots of the state before anything gets redistributed
      double[] rowTotalsBefore = rowTotals.clone();
      double[] colTotalsBefore = colTotals.clone();
      double[][] countsBefore = new double[numRows][];
      for (int r = 0; r < numRows; r++) {
        countsBefore[r] = counts[r].clone();
      }
      double totalMissing =
          (rowTotals[missingRow] + colTotals[missingCol] - counts[missingRow][missingCol]);

      // do the missing i's
      if (rowTotals[missingRow] > 0.0) {
        for (int c = 0; c < missingCol; c++) {
          if (!(counts[missingRow][c] > 0.0)) {
            continue;
          }
          for (int r = 0; r < missingRow; r++) {
            double share =
                ((rowTotalsBefore[r] / (total - rowTotalsBefore[missingRow]))
                    * counts[missingRow][c]);
            counts[r][c] += share;
            rowTotals[r] += share;
          }
          counts[missingRow][c] = 0.0;
        }
      }

      rowTotals[missingRow] = 0.0;

      // do the missing j's
      if (colTotals[missingCol] > 0.0) {
        for (int r = 0; r < missingRow; r++) {
          if (!(counts[r][missingCol] > 0.0)) {
            continue;
          }
          for (int c = 0; c < missingCol; c++) {
            double share =
                ((colTotalsBefore[c] / (total - colTotalsBefore[missingCol]))
                    * counts[r][missingCol]);
            counts[r][c] += share;
            colTotals[c] += share;
          }
          counts[r][missingCol] = 0.0;
        }
      }

      colTotals[missingCol] = 0.0;

      // do the both missing
      if (counts[missingRow][missingCol] > 0.0 && totalMissing != total) {
        for (int r = 0; r < missingRow; r++) {
          for (int c = 0; c < missingCol; c++) {
            double share =
                (countsBefore[r][c] / (total - totalMissing)) * countsBefore[missingRow][missingCol];
            counts[r][c] += share;
            rowTotals[r] += share;
            colTotals[c] += share;
          }
        }

        counts[missingRow][missingCol] = 0.0;
      }
    }

    return ContingencyTables.gainRatio(counts);
  }
Beispiel #14
0
  /**
   * Gets the index of the branch that an instance applies to. Returns -1 if no branches apply,
   * which is the case when the instance's value for the split attribute is missing.
   *
   * @param inst the instance
   * @return -1 if the split attribute value is missing, 0 if it is below the split point, 1
   *     otherwise
   */
  public int branchInstanceGoesDown(Instance inst) {
    if (inst.isMissing(attIndex)) {
      return -1;
    }
    return (inst.value(attIndex) < splitPoint) ? 0 : 1;
  }
Beispiel #15
0
  /**
   * Saves an instance incrementally. Structure has to be set by using the setStructure() method or
   * setInstances() method.
   *
   * <p>On the call that finds the write mode STRUCTURE_READY, the header (class values followed by
   * one line per non-class attribute) is written to the current writer, which is then closed, and
   * the output is redirected to a file with the extension ".data". Instances are written as one
   * comma-separated line each, with the class value last. Passing <code>null</code> while in
   * write mode flushes and closes the output and resets the saver.
   *
   * @param inst the instance to save, or null to finish writing
   * @throws IOException if an instance cannot be saved incrementally, e.g. because the class is
   *     numeric, batch and incremental saving are mixed, or no output file is available
   */
  public void writeIncremental(Instance inst) throws IOException {

    int writeMode = getWriteMode();
    Instances structure = getInstances();
    PrintWriter outW = null;

    if (structure != null) {
      // fall back to the last attribute if no class has been set
      if (structure.classIndex() == -1) {
        structure.setClassIndex(structure.numAttributes() - 1);
        System.err.println("No class specified. Last attribute is used as class attribute.");
      }
      if (structure.attribute(structure.classIndex()).isNumeric())
        throw new IOException("To save in C4.5 format the class attribute cannot be numeric.");
    }
    if (getRetrieval() == BATCH || getRetrieval() == NONE)
      throw new IOException("Batch and incremental saving cannot be mixed.");
    if (retrieveFile() == null || getWriter() == null) {
      throw new IOException(
          "C4.5 format requires two files. Therefore no output to standard out can be generated.\nPlease specifiy output files using the -o option.");
    }

    // a new PrintWriter is wrapped around the current writer on every call
    outW = new PrintWriter(getWriter());

    if (writeMode == WAIT) {
      // first call: either the structure is available or writing gets cancelled
      if (structure == null) {
        setWriteMode(CANCEL);
        if (inst != null)
          System.err.println("Structure(Header Information) has to be set in advance");
      } else setWriteMode(STRUCTURE_READY);
      writeMode = getWriteMode();
    }
    if (writeMode == CANCEL) {
      // NOTE(review): there is no return here, execution continues with the mode checks below
      if (outW != null) outW.close();
      cancel();
    }
    if (writeMode == STRUCTURE_READY) {
      setWriteMode(WRITE);
      // write header: here names file
      // first line: all values of the class attribute, comma-separated, terminated by "."
      for (int i = 0; i < structure.attribute(structure.classIndex()).numValues(); i++) {
        outW.write(structure.attribute(structure.classIndex()).value(i));
        if (i < structure.attribute(structure.classIndex()).numValues() - 1) {
          outW.write(",");
        } else {
          outW.write(".\n");
        }
      }
      // then one line per non-class attribute: "name: continuous." or "name: v1,v2,...."
      for (int i = 0; i < structure.numAttributes(); i++) {
        if (i != structure.classIndex()) {
          outW.write(structure.attribute(i).name() + ": ");
          if (structure.attribute(i).isNumeric() || structure.attribute(i).isDate()) {
            outW.write("continuous.\n");
          } else {
            Attribute temp = structure.attribute(i);
            for (int j = 0; j < temp.numValues(); j++) {
              outW.write(temp.value(j));
              if (j < temp.numValues() - 1) {
                outW.write(",");
              } else {
                outW.write(".\n");
              }
            }
          }
        }
      }
      outW.flush();
      outW.close();

      // pick up the mode that was set at the top of this branch
      writeMode = getWriteMode();

      // switch the output over to the data file: same path, extension replaced
      String out = retrieveFile().getAbsolutePath();
      setFileExtension(".data");
      out = out.substring(0, out.lastIndexOf('.')) + getFileExtension();
      // despite its name, this File is built from the path with the data-file extension
      File namesFile = new File(out);
      try {
        setFile(namesFile);
      } catch (Exception ex) {
        throw new IOException("Cannot create data file, only names file created.");
      }
      if (retrieveFile() == null || getWriter() == null) {
        throw new IOException("Cannot create data file, only names file created.");
      }
      outW = new PrintWriter(getWriter());
    }
    if (writeMode == WRITE) {
      if (structure == null) throw new IOException("No instances information available.");
      if (inst != null) {
        // write instance: here data file
        // all non-class values first, each followed by a comma; "?" stands for missing
        for (int j = 0; j < inst.numAttributes(); j++) {
          if (j != structure.classIndex()) {
            if (inst.isMissing(j)) {
              outW.write("?,");
            } else if (structure.attribute(j).isNominal() || structure.attribute(j).isString()) {
              outW.write(structure.attribute(j).value((int) inst.value(j)) + ",");
            } else {
              outW.write("" + inst.value(j) + ",");
            }
          }
        }
        // write the class value
        if (inst.isMissing(structure.classIndex())) {
          outW.write("?");
        } else {
          outW.write(
              structure
                  .attribute(structure.classIndex())
                  .value((int) inst.value(structure.classIndex())));
        }
        outW.write("\n");
        // flushes once the counter exceeds 100 instances, then starts counting again
        m_incrementalCounter++;
        if (m_incrementalCounter > 100) {
          m_incrementalCounter = 0;
          outW.flush();
        }
      } else {
        // close: a null instance marks the end of the data
        if (outW != null) {
          outW.flush();
          outW.close();
        }
        // restore the extension and reset the state
        setFileExtension(".names");
        m_incrementalCounter = 0;
        resetStructure();
        outW = null;
        resetWriter();
      }
    }
  }
Beispiel #16
0
  /**
   * Scores an attribute subset by rebuilding the table of entries for that subset and then running
   * either leave-one-out (when m_CVFolds is 1) or m_CVFolds-fold cross validation over the
   * training instances.
   *
   * @param feature_set the subset to be evaluated
   * @param num_atts the number of attributes in the subset
   * @return the value of the configured evaluation measure; RMSE and MAE are returned negated
   * @throws Exception if subset can't be evaluated
   */
  protected double estimatePerformance(BitSet feature_set, int num_atts) throws Exception {

    m_evaluation = new Evaluation(m_theInstances);

    final int[] selected = new int[num_atts];
    final double[] keyValues = new double[num_atts];
    final int classIdx = m_theInstances.classIndex();

    // gather the indices of the selected attributes in ascending order
    int filled = 0;
    for (int bit = feature_set.nextSetBit(0);
        bit >= 0 && bit < m_numAttributes;
        bit = feature_set.nextSetBit(bit + 1)) {
      selected[filled++] = bit;
    }

    // start over with an empty hash table
    m_entries = new Hashtable((int) (m_theInstances.numInstances() * 1.5));

    // put every training instance into the table
    for (int row = 0; row < m_numInstances; row++) {
      Instance current = m_theInstances.instance(row);
      for (int k = 0; k < selected.length; k++) {
        int att = selected[k];
        // the class itself and genuinely missing values share the same marker
        boolean markMissing = att == classIdx || current.isMissing(att);
        keyValues[k] = markMissing ? Double.MAX_VALUE : current.value(att);
      }
      insertIntoTable(current, keyValues);
    }

    if (m_CVFolds != 1) {
      // k-fold cross validation on a shuffled, stratified copy of the ordering
      m_theInstances.randomize(m_rr);
      m_theInstances.stratify(m_CVFolds);
      for (int fold = 0; fold < m_CVFolds; fold++) {
        Instances held = m_theInstances.testCV(m_CVFolds, fold);
        evaluateFoldCV(held, selected);
      }
    } else {
      // leave-one-out
      for (int row = 0; row < m_numInstances; row++) {
        Instance current = m_theInstances.instance(row);
        for (int k = 0; k < selected.length; k++) {
          int att = selected[k];
          boolean markMissing = att == classIdx || current.isMissing(att);
          keyValues[k] = markMissing ? Double.MAX_VALUE : current.value(att);
        }
        evaluateInstanceLeaveOneOut(current, keyValues);
      }
    }

    if (m_evaluationMeasure == EVAL_DEFAULT) {
      return m_classIsNominal
          ? m_evaluation.pctCorrect()
          : -m_evaluation.rootMeanSquaredError();
    }
    if (m_evaluationMeasure == EVAL_ACCURACY) {
      return m_evaluation.pctCorrect();
    }
    if (m_evaluationMeasure == EVAL_RMSE) {
      return -m_evaluation.rootMeanSquaredError();
    }
    if (m_evaluationMeasure == EVAL_MAE) {
      return -m_evaluation.meanAbsoluteError();
    }
    if (m_evaluationMeasure == EVAL_AUC) {
      // per-class AUC weighted by the normalised class priors
      double[] priors = m_evaluation.getClassPriors();
      Utils.normalize(priors);
      double aucSum = 0;
      for (int c = 0; c < m_theInstances.classAttribute().numValues(); c++) {
        double classAuc = m_evaluation.areaUnderROC(c);
        if (Utils.isMissingValue(classAuc)) {
          System.err.println("Undefined AUC!!");
        } else {
          aucSum += (priors[c] * classAuc);
        }
      }
      return aucSum;
    }

    // no known evaluation measure matched
    return 0.0;
  }
Beispiel #17
0
  /**
   * Evaluates one test fold for the internal cross validation of feature sets.
   *
   * <p>The fold's instances are first subtracted from the counts held in the hash table (and, for
   * a nominal class, from the class prior counts). Each instance is then predicted from what is
   * left and passed to m_evaluation, and finally all counts are restored.
   *
   * @param fold set of instances to be "left out" and classified
   * @param fs currently selected feature set
   * @return always 0.0; the outcome of the fold is recorded in m_evaluation, not returned
   * @throws Exception if something goes wrong
   */
  double evaluateFoldCV(Instances fold, int[] fs) throws Exception {

    int numFold = fold.numInstances();
    int numCl = m_theInstances.classAttribute().numValues();
    int classI = m_theInstances.classIndex();
    // each row is pointed at the table entry matching the instance, so no rows are pre-allocated
    double[][] class_distribs = new double[numFold][];
    double[] instA = new double[fs.length];
    // nominal: one slot per class value; numeric: {weighted sum of targets, sum of weights}
    double[] normDist = m_classIsNominal ? new double[numCl] : new double[2];

    // first *remove* instances
    for (int i = 0; i < numFold; i++) {
      Instance inst = fold.instance(i);
      for (int j = 0; j < fs.length; j++) {
        if (fs[j] == classI || inst.isMissing(fs[j])) {
          instA[j] = Double.MAX_VALUE; // marker for "missing" (always used for the class)
        } else {
          instA[j] = inst.value(fs[j]);
        }
      }
      DecisionTableHashKey thekey = new DecisionTableHashKey(instA);
      class_distribs[i] = (double[]) m_entries.get(thekey);
      if (class_distribs[i] == null) {
        throw new Error("This should never happen!");
      }
      if (m_classIsNominal) {
        class_distribs[i][(int) inst.classValue()] -= inst.weight();
        // Prior counts are indexed by class label. A numeric target is not a valid index, so
        // the priors are only maintained when the class is nominal.
        m_classPriorCounts[(int) inst.classValue()] -= inst.weight();
      } else {
        class_distribs[i][0] -= (inst.classValue() * inst.weight());
        class_distribs[i][1] -= inst.weight();
      }
    }

    // priors without the fold; only consulted on the nominal path below
    double[] classPriors = null;
    if (m_classIsNominal) {
      classPriors = m_classPriorCounts.clone();
      Utils.normalize(classPriors);
    }

    // now classify instances
    for (int i = 0; i < numFold; i++) {
      Instance inst = fold.instance(i);
      System.arraycopy(class_distribs[i], 0, normDist, 0, normDist.length);
      if (m_classIsNominal) {
        boolean ok = false;
        for (int j = 0; j < normDist.length; j++) {
          if (Utils.gr(normDist[j], 1.0)) {
            ok = true;
            break;
          }
        }

        if (!ok) { // no class has a count above 1.0: fall back to the class priors
          normDist = classPriors.clone();
        }

        Utils.normalize(normDist);
        if (m_evaluationMeasure == EVAL_AUC) {
          m_evaluation.evaluateModelOnceAndRecordPrediction(normDist, inst);
        } else {
          m_evaluation.evaluateModelOnce(normDist, inst);
        }
      } else {
        // no weight left for this entry: predict m_majority, otherwise the weighted mean
        double[] temp = new double[1];
        temp[0] = Utils.eq(normDist[1], 0.0) ? m_majority : normDist[0] / normDist[1];
        m_evaluation.evaluateModelOnce(temp, inst);
      }
    }

    // now re-insert instances
    for (int i = 0; i < numFold; i++) {
      Instance inst = fold.instance(i);
      if (m_classIsNominal) {
        m_classPriorCounts[(int) inst.classValue()] += inst.weight();
        class_distribs[i][(int) inst.classValue()] += inst.weight();
      } else {
        class_distribs[i][0] += (inst.classValue() * inst.weight());
        class_distribs[i][1] += inst.weight();
      }
    }
    return 0.0;
  }
Beispiel #18
0
  /**
   * Saves all instances at once in C4.5 format: first the names file (class values followed by
   * one line per non-class attribute), then the data file with one comma-separated line per
   * instance and the class value in the last position.
   *
   * @throws IOException if there are no instances, the class attribute is numeric, incremental
   *     saving has already been selected, or an output file/writer is not available
   */
  public void writeBatch() throws IOException {

    Instances instances = getInstances();
    if (instances == null) {
      throw new IOException("No instances to save");
    }

    if (instances.classIndex() == -1) {
      instances.setClassIndex(instances.numAttributes() - 1);
      System.err.println("No class specified. Last attribute is used as class attribute.");
    }
    final int classIdx = instances.classIndex();

    if (instances.attribute(classIdx).isNumeric()) {
      throw new IOException("To save in C4.5 format the class attribute cannot be numeric.");
    }
    if (getRetrieval() == INCREMENTAL) {
      throw new IOException("Batch and incremental saving cannot be mixed.");
    }

    setRetrieval(BATCH);
    if (retrieveFile() == null || getWriter() == null) {
      throw new IOException(
          "C4.5 format requires two files. Therefore no output to standard out can be generated.\nPlease specifiy output files using the -o option.");
    }
    setWriteMode(WRITE);

    // ---- names file ----
    setFileExtension(".names");
    PrintWriter namesOut = new PrintWriter(getWriter());

    // first line: the class values, comma separated and terminated by a full stop
    Attribute classAtt = instances.attribute(classIdx);
    int lastClassValue = classAtt.numValues() - 1;
    for (int v = 0; v <= lastClassValue; v++) {
      namesOut.write(classAtt.value(v));
      namesOut.write(v < lastClassValue ? "," : ".\n");
    }

    // then one line for every attribute except the class
    for (int a = 0; a < instances.numAttributes(); a++) {
      if (a == classIdx) {
        continue;
      }
      Attribute att = instances.attribute(a);
      namesOut.write(att.name() + ": ");
      if (att.isNumeric() || att.isDate()) {
        namesOut.write("continuous.\n");
        continue;
      }
      int lastValue = att.numValues() - 1;
      for (int v = 0; v <= lastValue; v++) {
        namesOut.write(att.value(v));
        namesOut.write(v < lastValue ? "," : ".\n");
      }
    }
    namesOut.flush();
    namesOut.close();

    // ---- data file ----
    // same path as the current file, with its extension swapped for ".data"
    String dataPath = retrieveFile().getAbsolutePath();
    setFileExtension(".data");
    dataPath = dataPath.substring(0, dataPath.lastIndexOf('.')) + getFileExtension();
    File dataFile = new File(dataPath);
    try {
      setFile(dataFile);
    } catch (Exception ex) {
      throw new IOException(
          "Cannot create data file, only names file created (Reason: " + ex.toString() + ").");
    }
    if (retrieveFile() == null || getWriter() == null) {
      throw new IOException("Cannot create data file, only names file created.");
    }

    PrintWriter dataOut = new PrintWriter(getWriter());
    for (int r = 0; r < instances.numInstances(); r++) {
      Instance row = instances.instance(r);

      // all non-class values first, each followed by a comma
      for (int a = 0; a < row.numAttributes(); a++) {
        if (a == classIdx) {
          continue;
        }
        String cell;
        if (row.isMissing(a)) {
          cell = "?";
        } else if (instances.attribute(a).isNominal() || instances.attribute(a).isString()) {
          cell = instances.attribute(a).value((int) row.value(a));
        } else {
          cell = "" + row.value(a);
        }
        dataOut.write(cell + ",");
      }

      // the class value closes the line
      if (row.isMissing(classIdx)) {
        dataOut.write("?");
      } else {
        dataOut.write(instances.attribute(classIdx).value((int) row.value(classIdx)));
      }
      dataOut.write("\n");
    }
    dataOut.flush();
    dataOut.close();

    // put the saver back into its initial state
    setFileExtension(".names");
    setWriteMode(WAIT);
    resetWriter();
    setWriteMode(CANCEL);
  }
  /**
   * Fits a regression line on a single attribute. A least-squares slope and intercept are
   * computed for every non-class attribute, and the attribute whose fit leaves the smallest sum of
   * squared errors is kept. When no attribute yields a usable fit, the model has no attribute and
   * predicts the class mean.
   *
   * <p>Candidate fits are held in local variables and the model fields are written only after the
   * search has finished, so a failure part-way through does not leave a half-computed model.
   *
   * @param insts the training data; every non-class attribute must be numeric
   * @throws Exception if a non-class attribute is not numeric
   */
  public void buildClassifier(Instances insts) throws Exception {

    // Compute mean of target value
    double yMean = insts.meanOrMode(insts.classIndex());

    // Choose best attribute
    double minMsq = Double.MAX_VALUE;
    int chosen = -1;
    double chosenSlope = Double.NaN;
    double chosenIntercept = Double.NaN;
    for (int i = 0; i < insts.numAttributes(); i++) {
      if (i == insts.classIndex()) {
        continue;
      }
      if (!insts.attribute(i).isNumeric()) {
        throw new Exception("UnivariateLinearRegression: Only numeric attributes!");
      }

      // Accumulate the weighted sums needed for slope and error
      double xMean = insts.meanOrMode(i);
      double sumWeightedXYDiff = 0;
      double sumWeightedXDiffSquared = 0;
      double sumWeightedYDiffSquared = 0;
      for (int j = 0; j < insts.numInstances(); j++) {
        Instance inst = insts.instance(j);
        if (!inst.isMissing(i) && !inst.classIsMissing()) {
          double xDiff = inst.value(i) - xMean;
          double yDiff = inst.classValue() - yMean;
          double weightedXDiff = inst.weight() * xDiff;
          double weightedYDiff = inst.weight() * yDiff;
          sumWeightedXYDiff += weightedXDiff * yDiff;
          sumWeightedXDiffSquared += weightedXDiff * xDiff;
          sumWeightedYDiffSquared += weightedYDiff * yDiff;
        }
      }

      // Skip attribute if not useful (also avoids dividing by zero below)
      if (sumWeightedXDiffSquared == 0) {
        continue;
      }
      double slope = sumWeightedXYDiff / sumWeightedXDiffSquared;
      double intercept = yMean - slope * xMean;

      // Compute sum of squared errors
      double msq = sumWeightedYDiffSquared - slope * sumWeightedXYDiff;

      // Check whether this is the best attribute
      if (msq < minMsq) {
        minMsq = msq;
        chosen = i;
        chosenSlope = slope;
        chosenIntercept = intercept;
      }
    }

    // Set parameters
    if (chosen == -1) {
      System.err.println("----- no useful attribute found");
      m_attribute = null;
      m_slope = 0;
      m_intercept = yMean;
    } else {
      m_attribute = insts.attribute(chosen);
      m_slope = chosenSlope;
      m_intercept = chosenIntercept;
    }
  }