예제 #1
0
  /**
   * Computes the raw output (decision value) of the SVM for one instance.
   *
   * <p>If a weight vector is available (linear machine), the output is the dot product of the
   * weights with the values actually stored in {@code inst}, ignoring the class attribute.
   * Otherwise the output is the kernel expansion over the support vectors, each term weighted by
   * {@code m_alpha[i] - m_alphaStar[i]}. In both cases the bias {@code m_b} is subtracted.
   *
   * @param inst the instance to evaluate
   * @return the SVM output for {@code inst}
   * @throws Exception if evaluating the kernel throws
   */
  public double SVMOutput(Instance inst) throws Exception {
    double output = -m_b;

    if (m_weights == null) {
      // non-linear machine: accumulate the kernel contribution of every support vector
      for (int sv = m_supportVectors.getNext(-1); sv != -1; sv = m_supportVectors.getNext(sv)) {
        output += (m_alpha[sv] - m_alphaStar[sv]) * m_kernel.eval(-1, sv, inst);
      }
      return output;
    }

    // linear machine: only the values stored in the instance can contribute
    final int stored = inst.numValues();
    for (int k = 0; k < stored; k++) {
      final int attIndex = inst.index(k);
      if (attIndex == m_classIndex) {
        continue;
      }
      output += m_weights[attIndex] * inst.valueSparse(k);
    }
    return output;
  }
예제 #2
0
  /**
   * Calculates the distance between two instances.
   *
   * <p>The stored (sparse) values of both instances are walked in attribute-index order, skipping
   * the class attribute. An attribute stored in only one of the instances is compared against 0
   * for the other. The per-attribute differences (from {@code difference}) are squared and summed.
   * The sum is divided by the number of attributes, and the square root is returned.
   *
   * @param first the first instance
   * @param second the second instance
   * @return the distance between the two given instances; between 0 and 1 provided that
   *     {@code difference} returns values in [-1, 1]
   */
  protected double distance(Instance first, Instance second) {

    double distance = 0;
    int firstI, secondI;
    // Loop invariants, read once instead of on every iteration.
    final int firstNumValues = first.numValues();
    final int secondNumValues = second.numValues();
    final int numAttributes = m_instances.numAttributes();
    final int classIndex = m_instances.classIndex();

    // Merge-walk both sparse value lists. An exhausted instance reports the sentinel index
    // numAttributes, which sorts after every real attribute index, so the remaining values of the
    // other instance are still consumed.
    for (int p1 = 0, p2 = 0; p1 < firstNumValues || p2 < secondNumValues; ) {
      firstI = (p1 >= firstNumValues) ? numAttributes : first.index(p1);
      secondI = (p2 >= secondNumValues) ? numAttributes : second.index(p2);

      // the class attribute does not contribute to the distance
      if (firstI == classIndex) {
        p1++;
        continue;
      }
      if (secondI == classIndex) {
        p2++;
        continue;
      }

      double diff;
      if (firstI == secondI) {
        // attribute stored in both instances
        diff = difference(firstI, first.valueSparse(p1), second.valueSparse(p2));
        p1++;
        p2++;
      } else if (firstI > secondI) {
        // attribute stored only in the second instance; the first one is implicitly 0
        diff = difference(secondI, 0, second.valueSparse(p2));
        p2++;
      } else {
        // attribute stored only in the first instance; the second one is implicitly 0
        diff = difference(firstI, first.valueSparse(p1), 0);
        p1++;
      }
      distance += diff * diff;
    }

    // normalize by the number of attributes before taking the root
    return Math.sqrt(distance / numAttributes);
  }
예제 #3
0
  /**
   * Concatenates this instance with another one and returns the result as a new sparse instance.
   *
   * <p>The attributes of {@code inst} are placed after this instance's attributes, so their
   * indices are shifted by {@code numAttributes()}. The result has weight 1.0, and its dataset is
   * not set (null).
   *
   * @param inst the instance to append after this one
   * @return a new {@link SparseInstance} holding the stored values of both instances
   */
  public Instance mergeInstance(Instance inst) {
    final int ownCount = numValues();
    final int otherCount = inst.numValues();
    final int offset = numAttributes();

    final double[] mergedValues = new double[ownCount + otherCount];
    final int[] mergedIndices = new int[ownCount + otherCount];

    // this instance's values keep their original indices
    for (int k = 0; k < ownCount; k++) {
      mergedValues[k] = valueSparse(k);
      mergedIndices[k] = index(k);
    }
    // the other instance's values follow, with indices shifted past this instance's attributes
    for (int k = 0; k < otherCount; k++) {
      mergedValues[ownCount + k] = inst.valueSparse(k);
      mergedIndices[ownCount + k] = offset + inst.index(k);
    }

    return new SparseInstance(1.0, mergedValues, mergedIndices, offset + inst.numAttributes());
  }
  /**
   * Concatenates this instance with another one and returns the result as a new binary sparse
   * instance.
   *
   * <p>The attributes of {@code inst} are placed after this instance's attributes, so their
   * indices are shifted by this instance's {@code numAttributes()}. Only the non-zero values of
   * {@code inst} are kept. The result is built from indices alone, so each kept value is
   * presumably treated as 1. The result has weight 1.0, and its dataset is not set (null).
   *
   * @param inst the instance to be merged with this one
   * @return the merged instance
   */
  public Instance mergeInstance(Instance inst) {

    int[] indices = new int[numValues() + inst.numValues()];

    int m = 0;
    for (int j = 0; j < numValues(); j++) {
      indices[m++] = index(j);
    }
    // The merged header is [this attributes..., inst attributes...], so inst's indices must be
    // shifted by THIS instance's attribute count. Shifting by inst.numAttributes() gives wrong
    // or out-of-range indices whenever the two instances have different attribute counts.
    final int offset = numAttributes();
    for (int j = 0; j < inst.numValues(); j++) {
      if (inst.valueSparse(j) != 0) {
        indices[m++] = offset + inst.index(j);
      }
    }

    if (m != indices.length) {
      // Zero values of inst were dropped: shrink the index array to the slots actually used.
      int[] newInd = new int[m];
      System.arraycopy(indices, 0, newInd, 0, m);
      indices = newInd;
    }
    return new BinarySparseInstance((float) 1.0, indices, numAttributes() + inst.numAttributes());
  }
예제 #5
0
  /**
   * Tests a certain range of attributes of the given data, whether it can be processed by the
   * handler, given its capabilities. Classifiers implementing the <code>
   * MultiInstanceCapabilitiesHandler</code> interface are checked automatically for their
   * multi-instance Capabilities (if no bags, then only the bag-structure, otherwise only the first
   * bag).
   *
   * <p>When a test fails, the reason is stored in {@code m_FailReason} and false is returned. If no
   * capabilities are set (or only NO_CLASS), a warning is printed to stderr, but testing
   * continues.
   *
   * @param data the data to test
   * @param fromIndex the range of attributes - start (incl.)
   * @param toIndex the range of attributes - end (incl.)
   * @return true if all the tests succeeded, false otherwise (see {@code m_FailReason})
   * @see MultiInstanceCapabilitiesHandler
   * @see #m_InstancesTest
   * @see #m_MissingValuesTest
   * @see #m_MissingClassValuesTest
   * @see #m_MinimumNumberInstancesTest
   */
  public boolean test(Instances data, int fromIndex, int toIndex) {
    int i;
    int n;
    int m;
    Attribute att;
    Instance inst;
    boolean testClass;
    Capabilities cap;
    boolean missing;
    Iterator iter;

    // shall we test the data?
    if (!m_InstancesTest) return true;

    // no Capabilities? -> warning
    if ((m_Capabilities.size() == 0)
        || ((m_Capabilities.size() == 1) && handles(Capability.NO_CLASS)))
      System.err.println(createMessage("No capabilities set!"));

    // any attributes? (direct comparison; "toIndex - fromIndex" could overflow)
    if (toIndex < fromIndex) {
      m_FailReason = new WekaException(createMessage("No attributes!"));
      return false;
    }

    // do we need to test the class attribute, i.e., is the class attribute
    // within the range of attributes?
    testClass =
        (data.classIndex() > -1)
            && (data.classIndex() >= fromIndex)
            && (data.classIndex() <= toIndex);

    // attributes
    for (i = fromIndex; i <= toIndex; i++) {
      att = data.attribute(i);

      // class is handled separately
      if (i == data.classIndex()) continue;

      // check attribute types
      if (!test(att)) return false;
    }

    // class
    if (!handles(Capability.NO_CLASS) && (data.classIndex() == -1)) {
      m_FailReason = new UnassignedClassException(createMessage("Class attribute not set!"));
      return false;
    }

    // special case: no class attribute can be handled
    if (handles(Capability.NO_CLASS) && (data.classIndex() > -1)) {
      cap = getClassCapabilities();
      cap.disable(Capability.NO_CLASS);
      iter = cap.capabilities();
      if (!iter.hasNext()) {
        m_FailReason = new WekaException(createMessage("Cannot handle any class attribute!"));
        return false;
      }
    }

    if (testClass && !handles(Capability.NO_CLASS)) {
      att = data.classAttribute();
      if (!test(att, true)) return false;

      // special handling of RELATIONAL class
      // TODO: store additional Capabilities for this case

      // missing class labels
      if (m_MissingClassValuesTest) {
        if (!handles(Capability.MISSING_CLASS_VALUES)) {
          for (i = 0; i < data.numInstances(); i++) {
            if (data.instance(i).classIsMissing()) {
              m_FailReason =
                  new WekaException(createMessage("Cannot handle missing class values!"));
              return false;
            }
          }
        } else {
          if (m_MinimumNumberInstancesTest) {
            int hasClass = 0;

            for (i = 0; i < data.numInstances(); i++) {
              if (!data.instance(i).classIsMissing()) hasClass++;
            }

            // not enough instances with class labels?
            if (hasClass < getMinimumNumberInstances()) {
              m_FailReason =
                  new WekaException(
                      createMessage(
                          "Not enough training instances with class labels (required: "
                              + getMinimumNumberInstances()
                              + ", provided: "
                              + hasClass
                              + ")!"));
              return false;
            }
          }
        }
      }
    }

    // missing values (only attributes within [fromIndex, toIndex], class excluded)
    if (m_MissingValuesTest) {
      if (!handles(Capability.MISSING_VALUES)) {
        missing = false;
        for (i = 0; i < data.numInstances(); i++) {
          inst = data.instance(i);

          if (inst instanceof SparseInstance) {
            // only the stored values need checking; the loop relies on index(m) being
            // ascending so it can stop once past toIndex
            for (m = 0; m < inst.numValues(); m++) {
              n = inst.index(m);

              // out of scope?
              if (n < fromIndex) continue;
              if (n > toIndex) break;

              // skip class
              if (n == inst.classIndex()) continue;

              if (inst.isMissing(n)) {
                missing = true;
                break;
              }
            }
          } else {
            for (n = fromIndex; n <= toIndex; n++) {
              // skip class
              if (n == inst.classIndex()) continue;

              if (inst.isMissing(n)) {
                missing = true;
                break;
              }
            }
          }

          if (missing) {
            m_FailReason =
                new NoSupportForMissingValuesException(
                    createMessage("Cannot handle missing values!"));
            return false;
          }
        }
      }
    }

    // instances
    if (m_MinimumNumberInstancesTest) {
      if (data.numInstances() < getMinimumNumberInstances()) {
        m_FailReason =
            new WekaException(
                createMessage(
                    "Not enough training instances (required: "
                        + getMinimumNumberInstances()
                        + ", provided: "
                        + data.numInstances()
                        + ")!"));
        return false;
      }
    }

    // Multi-Instance? -> check structure (regardless of attribute range!)
    if (handles(Capability.ONLY_MULTIINSTANCE)) {
      // number of attributes?
      if (data.numAttributes() != 3) {
        m_FailReason =
            new WekaException(
                createMessage("Incorrect Multi-Instance format, must be 'bag-id, bag, class'!"));
        return false;
      }

      // type of attributes and position of class?
      if (!data.attribute(0).isNominal()
          || !data.attribute(1).isRelationValued()
          || (data.classIndex() != data.numAttributes() - 1)) {
        m_FailReason =
            new WekaException(
                createMessage(
                    "Incorrect Multi-Instance format, must be 'NOMINAL att, RELATIONAL att, CLASS att'!"));
        return false;
      }

      // check data immediately
      if (getOwner() instanceof MultiInstanceCapabilitiesHandler) {
        MultiInstanceCapabilitiesHandler handler = (MultiInstanceCapabilitiesHandler) getOwner();
        cap = handler.getMultiInstanceCapabilities();
        boolean result;
        // Test the first bag only if the relational attribute actually stores one. With
        // instances present but no stored bag, relation(0) would index an empty value list
        // (presumably throwing); fall back to the bag structure in that case.
        if (data.numInstances() > 0 && data.attribute(1).numValues() > 0)
          result = cap.test(data.attribute(1).relation(0));
        else result = cap.test(data.attribute(1).relation());

        if (!result) {
          m_FailReason = cap.m_FailReason;
          return false;
        }
      }
    }

    // passed all tests!
    return true;
  }
예제 #6
0
  /**
   * Builds a Capabilities object that describes exactly what the given data contains.
   *
   * <p>The description covers:
   *
   * <ul>
   *   <li>the class type, or NO_CLASS;
   *   <li>the types of all non-class attributes;
   *   <li>whether class values are missing;
   *   <li>whether any non-class attribute value is missing.
   * </ul>
   *
   * <p>The minimum number of instances is not set. When {@code multi} is true and the data has the
   * multi-instance layout (3 attributes: nominal bag-id, relational bag, class last), the result
   * is replaced by a multi-instance description built from its class capabilities.
   *
   * @param data the data to base the capabilities on
   * @param multi if true then the structure is checked, too
   * @return a data-specific capabilities object
   * @throws Exception in case an error occurs, e.g., an unknown attribute type
   */
  public static Capabilities forInstances(Instances data, boolean multi) throws Exception {
    Capabilities result = new Capabilities(null);
    final int classIndex = data.classIndex();

    // ---- class attribute ----
    if (classIndex == -1) {
      result.enable(Capability.NO_CLASS);
    } else {
      Attribute classAtt = data.classAttribute();
      switch (classAtt.type()) {
        case Attribute.NOMINAL:
          {
            int numLabels = classAtt.numValues();
            if (numLabels == 1) {
              result.enable(Capability.UNARY_CLASS);
            } else if (numLabels == 2) {
              result.enable(Capability.BINARY_CLASS);
            } else {
              result.enable(Capability.NOMINAL_CLASS);
            }
          }
          break;
        case Attribute.NUMERIC:
          result.enable(Capability.NUMERIC_CLASS);
          break;
        case Attribute.STRING:
          result.enable(Capability.STRING_CLASS);
          break;
        case Attribute.DATE:
          result.enable(Capability.DATE_CLASS);
          break;
        case Attribute.RELATIONAL:
          result.enable(Capability.RELATIONAL_CLASS);
          break;
        default:
          throw new UnsupportedAttributeTypeException(
              "Unknown class attribute type '" + classAtt + "'!");
      }

      // any instance without a class value?
      boolean classMissing = false;
      for (int row = 0; row < data.numInstances() && !classMissing; row++) {
        classMissing = data.instance(row).classIsMissing();
      }
      if (classMissing) {
        result.enable(Capability.MISSING_CLASS_VALUES);
      }
    }

    // ---- non-class attributes ----
    for (int col = 0; col < data.numAttributes(); col++) {
      if (col == classIndex) {
        continue;
      }

      Attribute current = data.attribute(col);
      switch (current.type()) {
        case Attribute.NOMINAL:
          {
            result.enable(Capability.UNARY_ATTRIBUTES);
            int numLabels = current.numValues();
            if (numLabels == 2) {
              result.enable(Capability.BINARY_ATTRIBUTES);
            } else if (numLabels > 2) {
              result.enable(Capability.NOMINAL_ATTRIBUTES);
            }
          }
          break;
        case Attribute.NUMERIC:
          result.enable(Capability.NUMERIC_ATTRIBUTES);
          break;
        case Attribute.DATE:
          result.enable(Capability.DATE_ATTRIBUTES);
          break;
        case Attribute.STRING:
          result.enable(Capability.STRING_ATTRIBUTES);
          break;
        case Attribute.RELATIONAL:
          result.enable(Capability.RELATIONAL_ATTRIBUTES);
          break;
        default:
          throw new UnsupportedAttributeTypeException(
              "Unknown attribute type '" + current.type() + "'!");
      }
    }

    // ---- missing (non-class) values ----
    // Sparse instances are checked over their stored values only; dense instances over every
    // attribute of the dataset. The scan stops at the first missing value found.
    scan:
    for (int row = 0; row < data.numInstances(); row++) {
      Instance current = data.instance(row);
      boolean sparse = current instanceof SparseInstance;
      int count = sparse ? current.numValues() : data.numAttributes();
      for (int pos = 0; pos < count; pos++) {
        int attIndex = sparse ? current.index(pos) : pos;
        if (attIndex == current.classIndex()) {
          continue;
        }
        if (current.isMissing(attIndex)) {
          result.enable(Capability.MISSING_VALUES);
          break scan;
        }
      }
    }

    // ---- optional multi-instance layout: bag-id, bag, class ----
    if (multi
        && data.numAttributes() == 3
        && data.attribute(0).isNominal()
        && data.attribute(1).isRelationValued()
        && classIndex == data.numAttributes() - 1) {
      Capabilities multiInstance = new Capabilities(null);
      multiInstance.or(result.getClassCapabilities());
      multiInstance.enable(Capability.NOMINAL_ATTRIBUTES);
      multiInstance.enable(Capability.RELATIONAL_ATTRIBUTES);
      multiInstance.enable(Capability.ONLY_MULTIINSTANCE);
      result.assign(multiInstance);
    }

    return result;
  }
  /**
   * Calculates the distance between two instances and gives up early once the accumulated
   * distance exceeds {@code cutOffValue}. This can speed up nearest-neighbour search.
   *
   * <p>The stored (sparse) values of both instances are merged in attribute-index order. The
   * class attribute and attributes not flagged in {@code m_ActiveIndices} are skipped. An
   * attribute stored in only one instance is compared against 0. Each difference is folded into
   * the running total via {@code updateDistance}, and {@code stats} (when non-null) is
   * incremented once per compared coordinate. Depending on the distance function class, the
   * result may still need post-processing with {@code postProcessDistances(double[])}.
   *
   * @param first the first instance
   * @param second the second instance
   * @param cutOffValue If the distance being calculated becomes larger than cutOffValue then the
   *     rest of the calculation is discarded.
   * @param stats the performance stats object; may be null
   * @return the distance between the two given instances or Double.POSITIVE_INFINITY if the
   *     distance being calculated becomes larger than cutOffValue.
   */
  @Override
  public double distance(
      Instance first, Instance second, double cutOffValue, PerformanceStats stats) {
    final int n1 = first.numValues();
    final int n2 = second.numValues();
    // an exhausted instance reports this sentinel index, which sorts after every real attribute
    final int sentinel = m_Data.numAttributes();
    final int cls = m_Data.classIndex();

    validate();

    double total = 0;
    int p1 = 0;
    int p2 = 0;
    while (p1 < n1 || p2 < n2) {
      final int idx1 = (p1 < n1) ? first.index(p1) : sentinel;
      final int idx2 = (p2 < n2) ? second.index(p2) : sentinel;

      // skip the class attribute and inactive attributes, first instance before second
      if (idx1 == cls || (idx1 < sentinel && !m_ActiveIndices[idx1])) {
        p1++;
        continue;
      }
      if (idx2 == cls || (idx2 < sentinel && !m_ActiveIndices[idx2])) {
        p2++;
        continue;
      }

      final double diff;
      if (idx1 == idx2) {
        // attribute present in both instances
        diff = difference(idx1, first.valueSparse(p1), second.valueSparse(p2));
        p1++;
        p2++;
      } else if (idx1 < idx2) {
        // attribute present only in the first instance
        diff = difference(idx1, first.valueSparse(p1), 0);
        p1++;
      } else {
        // attribute present only in the second instance
        diff = difference(idx2, 0, second.valueSparse(p2));
        p2++;
      }

      if (stats != null) {
        stats.incrCoordCount();
      }

      total = updateDistance(total, diff);
      if (total > cutOffValue) {
        return Double.POSITIVE_INFINITY;
      }
    }

    return total;
  }