Esempio n. 1
0
  /**
   * Computes the area under the ROC curve using the Wilcoxon-Mann-Whitney statistic.
   *
   * <p>The true-positive and false-positive columns of the curve are walked row by row. For each
   * row the drop to the following row is used as the step size; the final row contributes its own
   * values. The accumulated sum is normalised by the product of the first row's two values.
   *
   * @param tcurve a previously extracted threshold curve Instances.
   * @return the ROC area, or Double.NaN if the relation name is not the threshold-curve relation
   *     name or the curve contains no instances.
   */
  public static double getROCArea(Instances tcurve) {

    final int numPoints = tcurve.numInstances();
    if ((numPoints == 0) || !RELATION_NAME.equals(tcurve.relationName())) {
      return Double.NaN;
    }

    final int truePosIndex = tcurve.attribute(TRUE_POS_NAME).index();
    final int falsePosIndex = tcurve.attribute(FALSE_POS_NAME).index();
    final double[] truePos = tcurve.attributeToDoubleArray(truePosIndex);
    final double[] falsePos = tcurve.attributeToDoubleArray(falsePosIndex);

    final int last = numPoints - 1;
    double sum = 0.0;
    double negSoFar = 0.0;

    // every row except the last: step size is the difference to the next row
    for (int row = 0; row < last; row++) {
      final double posStep = truePos[row] - truePos[row + 1];
      final double negStep = falsePos[row] - falsePos[row + 1];
      sum += posStep * (negSoFar + (0.5 * negStep));
      negSoFar += negStep;
    }

    // last row: there is no successor, so its own values are the step
    sum += truePos[last] * (negSoFar + (0.5 * falsePos[last]));

    // normalise by (first false-positive value * first true-positive value)
    return sum / (falsePos[0] * truePos[0]);
  }
Esempio n. 2
0
  /**
   * Computes the area under the precision-recall curve (AUPRC).
   *
   * <p>The curve is traversed from the second-to-last row back to the first; each row adds its
   * precision multiplied by the recall difference to the previously visited row.
   *
   * @param tcurve a previously extracted threshold curve Instances.
   * @return the PRC area; Double.NaN if the relation name is not the threshold-curve relation name
   *     or the curve is empty; the missing value if the accumulated area is exactly zero.
   */
  public static double getPRCArea(Instances tcurve) {
    final int numPoints = tcurve.numInstances();
    if ((numPoints == 0) || !RELATION_NAME.equals(tcurve.relationName())) {
      return Double.NaN;
    }

    final int precisionIndex = tcurve.attribute(PRECISION_NAME).index();
    final int recallIndex = tcurve.attribute(RECALL_NAME).index();
    final double[] precision = tcurve.attributeToDoubleArray(precisionIndex);
    final double[] recall = tcurve.attributeToDoubleArray(recallIndex);

    double sum = 0;
    // the last row only serves as the starting recall value
    double previousRecall = recall[numPoints - 1];

    int row = numPoints - 2;
    while (row >= 0) {
      sum += precision[row] * (recall[row] - previousRecall);
      previousRecall = recall[row];
      row--;
    }

    return (sum == 0) ? Utils.missingValue() : sum;
  }
  /**
   * Builds the model of the base learner.
   *
   * <p>The data is capability-checked, copied, and stripped of instances with a missing class.
   * A cost matrix is then obtained: in on-demand mode it is read from a file named after the
   * relation in the on-demand directory; otherwise, if no matrix is set yet, an old-format cost
   * file is read. Unless expected-cost minimisation is enabled, the cost matrix is applied to the
   * training data before the base classifier is built.
   *
   * @param data the training data
   * @throws Exception if the classifier could not be built successfully, no base classifier is
   *     set, or the on-demand cost file does not exist
   */
  public void buildClassifier(Instances data) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class (work on a copy, the caller's data stays untouched)
    data = new Instances(data);
    data.deleteWithMissingClass();

    if (m_Classifier == null) {
      throw new Exception("No base classifier has been set!");
    }
    if (m_MatrixSource == MATRIX_ON_DEMAND) {
      String costName = data.relationName() + CostMatrix.FILE_EXTENSION;
      File costFile = new File(getOnDemandDirectory(), costName);
      if (!costFile.exists()) {
        throw new Exception("On-demand cost file doesn't exist: " + costFile);
      }
      // NOTE(review): assumes the CostMatrix constructor consumes the reader before returning,
      // so the file handle can be released here instead of being leaked.
      BufferedReader reader = new BufferedReader(new FileReader(costFile));
      try {
        setCostMatrix(new CostMatrix(reader));
      } finally {
        reader.close();
      }
    } else if (m_CostMatrix == null) {
      // try loading an old format cost file
      m_CostMatrix = new CostMatrix(data.numClasses());
      BufferedReader reader = new BufferedReader(new FileReader(m_CostFile));
      try {
        m_CostMatrix.readOldFormat(reader);
      } finally {
        // always release the file handle, even if parsing fails
        reader.close();
      }
    }

    if (!m_MinimizeExpectedCost) {
      // a random generator is only supplied when the base classifier does not handle weights
      Random random = null;
      if (!(m_Classifier instanceof WeightedInstancesHandler)) {
        random = new Random(m_Seed);
      }
      data = m_CostMatrix.applyCostMatrix(data, random);
    }
    m_Classifier.buildClassifier(data);
  }
Esempio n. 4
0
  /**
   * Determines the output format based on the input format and returns this. In case the output
   * format cannot be returned immediately, i.e., hasImmediateOutputFormat() returns false, then
   * this method will be called from batchFinished().
   *
   * <p>Every numeric attribute inside the configured column range is replaced by a nominal
   * attribute whose labels are the sorted distinct non-missing values found in the data; all other
   * attributes are carried over unchanged.
   *
   * @param inputFormat the input format to base the output format on
   * @return the output format
   * @throws Exception in case the determination goes wrong
   * @see #hasImmediateOutputFormat()
   * @see #batchFinished()
   */
  protected Instances determineOutputFormat(Instances inputFormat) throws Exception {

    m_Cols.setUpper(inputFormat.numAttributes() - 1);

    Instances data = new Instances(inputFormat);
    FastVector atts = new FastVector();

    for (int attIndex = 0; attIndex < data.numAttributes(); attIndex++) {
      Attribute current = data.attribute(attIndex);

      // attributes outside the range, or non-numeric ones, pass through untouched
      boolean convert = m_Cols.isInRange(attIndex) && current.isNumeric();
      if (!convert) {
        atts.addElement(current);
        continue;
      }

      // dates are collected by their string representation, other numerics by value
      boolean isDate = (current.type() == Attribute.DATE);

      // collect the distinct non-missing values of this attribute
      HashSet distinct = new HashSet();
      for (int row = 0; row < data.numInstances(); row++) {
        Instance inst = data.instance(row);
        if (!inst.isMissing(attIndex)) {
          if (isDate) {
            distinct.add(inst.stringValue(attIndex));
          } else {
            distinct.add(Double.valueOf(inst.value(attIndex)));
          }
        }
      }

      // bring the values into their natural order
      Vector ordered = new Vector(distinct);
      Collections.sort(ordered);

      // turn the ordered values into nominal labels
      FastVector labels = new FastVector();
      for (Object value : ordered) {
        String label =
            isDate
                ? value.toString()
                : Utils.doubleToString(((Double) value).doubleValue(), MAX_DECIMALS);
        labels.addElement(label);
      }
      atts.addElement(new Attribute(current.name(), labels));
    }

    Instances result = new Instances(inputFormat.relationName(), atts, 0);
    result.setClassIndex(inputFormat.classIndex());

    return result;
  }
Esempio n. 5
0
  /**
   * Determines the output format based on the input format and returns this. In case the output
   * format cannot be returned immediately, i.e., hasImmediateOutputFormat() returns false, then
   * this method will be called from batchFinished() after the call of preprocess(Instances), in
   * which, e.g., statistics for the actual processing step can be gathered.
   *
   * <p>A seeded random subset of the non-class attributes is drawn, sorted by index, and the class
   * attribute (if set) is appended as the last attribute. The chosen indices are stored in
   * m_Indices.
   *
   * @param inputFormat the input format to base the output format on
   * @return the output format
   * @throws Exception in case the determination goes wrong
   */
  protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
    final int classIdx = inputFormat.classIndex();
    final boolean hasClass = classIdx > -1;

    // how many attributes to keep: values below 1 are a fraction, otherwise an absolute cap
    int numAtts = inputFormat.numAttributes();
    if (hasClass) {
      numAtts--;
    }
    if (m_NumAttributes < 1) {
      numAtts = (int) Math.round((double) numAtts * m_NumAttributes);
    } else if (m_NumAttributes < numAtts) {
      numAtts = (int) m_NumAttributes;
    }
    if (getDebug()) {
      System.out.println("# of atts: " + numAtts);
    }

    // pool of candidate indices: everything except the class attribute
    Vector<Integer> candidates = new Vector<Integer>();
    for (int att = 0; att < inputFormat.numAttributes(); att++) {
      if (att != classIdx) {
        candidates.add(att);
      }
    }

    // draw without replacement
    Vector<Integer> chosen = new Vector<Integer>();
    Random rand = new Random(m_Seed);
    for (int drawn = 0; drawn < numAtts; drawn++) {
      int pick = rand.nextInt(candidates.size());
      chosen.add(candidates.remove(pick));
    }
    Collections.sort(chosen);
    if (hasClass) {
      chosen.add(classIdx);
    }
    if (getDebug()) {
      System.out.println("indices: " + chosen);
    }

    // generate output format
    FastVector atts = new FastVector();
    m_Indices = new int[chosen.size()];
    int slot = 0;
    for (int attIdx : chosen) {
      atts.addElement(inputFormat.attribute(attIdx));
      m_Indices[slot] = attIdx;
      slot++;
    }
    Instances result = new Instances(inputFormat.relationName(), atts, 0);
    if (hasClass) {
      result.setClassIndex(result.numAttributes() - 1);
    }

    return result;
  }
Esempio n. 6
0
  /**
   * Finds the index of the instance whose threshold value is closest to the desired target.
   *
   * <p>The threshold is read from the last attribute of the curve.
   *
   * @param tcurve a set of instances that have been generated by this class
   * @param threshold the target threshold
   * @return the index of the instance that has threshold closest to the target, or -1 if this could
   *     not be found (i.e. no data, wrong relation name, or a target outside [0, 1])
   */
  public static int getThresholdInstance(Instances tcurve, double threshold) {

    final int numPoints = tcurve.numInstances();

    if (!RELATION_NAME.equals(tcurve.relationName())) {
      return -1;
    }
    if (numPoints == 0) {
      return -1;
    }
    if ((threshold < 0) || (threshold > 1.0)) {
      return -1;
    }

    // a single point is trivially the closest one
    if (numPoints == 1) {
      return 0;
    }

    double[] thresholds = tcurve.attributeToDoubleArray(tcurve.numAttributes() - 1);
    int[] order = Utils.sort(thresholds);
    return binarySearch(order, thresholds, threshold);
  }
Esempio n. 7
0
  /**
   * Determines the output format based on the input format and returns this.
   *
   * <p>For every nominal attribute inside the configured range the labels are sorted with
   * m_Comparator and a new attribute with the sorted labels is created; m_NewOrder[i][n] receives
   * the position of original label n in the sorted list. Other attributes are copied and get an
   * empty mapping.
   *
   * @param inputFormat the input format to base the output format on
   * @return the output format
   * @throws Exception in case the determination goes wrong
   */
  protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
    final int numAtts = inputFormat.numAttributes();

    m_AttributeIndices.setUpper(numAtts - 1);

    FastVector atts = new FastVector();
    m_NewOrder = new int[numAtts][];

    for (int attIdx = 0; attIdx < numAtts; attIdx++) {
      Attribute att = inputFormat.attribute(attIdx);

      boolean reorder = att.isNominal() && m_AttributeIndices.isInRange(attIdx);
      if (!reorder) {
        // untouched attribute: no label mapping needed
        m_NewOrder[attIdx] = new int[0];
        atts.addElement(att.copy());
        continue;
      }

      final int numLabels = att.numValues();

      // sort labels
      Vector<String> sorted = new Vector<String>();
      for (int label = 0; label < numLabels; label++) {
        sorted.add(att.value(label));
      }
      Collections.sort(sorted, m_Comparator);

      // old label index -> new label index
      int[] mapping = new int[numLabels];
      for (int label = 0; label < numLabels; label++) {
        mapping[label] = sorted.indexOf(att.value(label));
      }
      m_NewOrder[attIdx] = mapping;

      // label list of the replacement attribute, in sorted order
      FastVector values = new FastVector();
      for (int label = 0; label < numLabels; label++) {
        values.addElement(sorted.get(label));
      }

      Attribute attSorted = new Attribute(att.name(), values);
      attSorted.setWeight(att.weight());
      atts.addElement(attSorted);
    }

    // generate new header
    Instances result = new Instances(inputFormat.relationName(), atts, 0);
    result.setClassIndex(inputFormat.classIndex());

    return result;
  }
Esempio n. 8
0
  /**
   * Calculates the n point precision result, which is the precision averaged over n evenly spaced
   * (w.r.t recall) samples of the curve.
   *
   * <p>For each target recall the closest curve point is located via binarySearch. Unless that
   * point is the first in recall order, the precision is linearly interpolated towards the next
   * later point that has a different recall value (if one exists).
   *
   * @param tcurve a previously extracted threshold curve Instances.
   * @param n the number of points to average over.
   * @return the n-point precision, or Double.NaN if the relation name is not the threshold-curve
   *     relation name or the curve is empty.
   */
  public static double getNPointPrecision(Instances tcurve, int n) {

    if ((tcurve.numInstances() == 0) || !RELATION_NAME.equals(tcurve.relationName())) {
      return Double.NaN;
    }

    int recallInd = tcurve.attribute(RECALL_NAME).index();
    int precisInd = tcurve.attribute(PRECISION_NAME).index();
    double[] recallVals = tcurve.attributeToDoubleArray(recallInd);
    int[] order = Utils.sort(recallVals);

    // spacing between consecutive target recall values
    double step = 1.0 / (n - 1);
    double total = 0;

    for (int sample = 0; sample < n; sample++) {
      final double target = sample * step;
      final int start = binarySearch(order, recallVals, target);
      final double recall = recallVals[order[start]];
      double precis = tcurve.instance(order[start]).value(precisInd);

      // interpolate figures for non-endpoints
      if (start != 0) {
        for (int next = start + 1; next < order.length; next++) {
          double recall2 = recallVals[order[next]];
          if (recall2 == recall) {
            // same recall: no slope can be computed, keep looking
            continue;
          }
          double precis2 = tcurve.instance(order[next]).value(precisInd);
          double slope = (precis2 - precis) / (recall2 - recall);
          double offset = precis - recall * slope;
          precis = target * slope + offset;
          break;
        }
      }
      total += precis;
    }
    return total / n;
  }
Esempio n. 9
0
 /**
  * Builds the output format from the buffered input and pushes every buffered instance.
  *
  * <p>For each input attribute a nominal attribute with the same name is created whose labels are
  * the distinct values of that attribute (rendered with String.valueOf), in order of first
  * appearance.
  *
  * @return the result of the superclass implementation
  * @throws Exception if something goes wrong
  */
 public boolean batchFinished() throws Exception {
   Instances source = getInputFormat();
   Instances header = new Instances(source.relationName());
   final int attCount = source.numAttributes();
   final int rowCount = source.numInstances();

   for (int att = 0; att < attCount; att++) {
     // distinct string labels for this attribute, first-seen order
     FastVector labels = new FastVector();
     for (int row = 0; row < rowCount; row++) {
       String label = String.valueOf(source.instance(row).value(att));
       if (labels.indexOf(label) < 0) {
         labels.addElement(label);
       }
     }
     header.appendAttribute(new Attribute(source.attribute(att).name(), labels));
   }
   setOutputFormat(header);

   for (int row = 0; row < rowCount; row++) {
     push(source.instance(row));
   }
   return super.batchFinished();
 }
Esempio n. 10
0
  /**
   * Sets the format of output instances. The derived class should use this method once it has
   * determined the output format. The output queue is replaced by a fresh, empty one.
   *
   * <p>For a non-null format, the stored structure is the string-free structure of the given
   * format, and its relation name is extended with "-", this filter's class name and, if the
   * filter is an OptionHandler, all of its trimmed options.
   *
   * @param outputFormat the new output format
   */
  protected void setOutputFormat(Instances outputFormat) {

    if (outputFormat == null) {
      m_OutputFormat = null;
    } else {
      m_OutputFormat = outputFormat.stringFreeStructure();
      initOutputLocators(m_OutputFormat, null);

      // Rename the relation
      StringBuilder renamed = new StringBuilder();
      renamed.append(outputFormat.relationName());
      renamed.append("-");
      renamed.append(this.getClass().getName());
      if (this instanceof OptionHandler) {
        for (String option : ((OptionHandler) this).getOptions()) {
          renamed.append(option.trim());
        }
      }
      m_OutputFormat.setRelationName(renamed.toString());
    }
    m_OutputQueue = new Queue();
  }
Esempio n. 11
0
  /**
   * Sets the values of all attributes listed in the ATTRIBS parameter to missing in every
   * instance of the data set. Names are separated by whitespace; a name that does not match any
   * attribute only produces a warning.
   *
   * @param data the data set to modify in place
   * @throws TaskException declared by the overridden method
   */
  @Override
  protected void manipulateAttributes(Instances data) throws TaskException {

    for (String attrName : this.getParameterVal(ATTRIBS).split("\\s+")) {
      if (data.attribute(attrName) == null) {
        String warning = "Attribute " + attrName + " not found in data set " + data.relationName();
        Logger.getInstance().message(warning, Logger.V_WARNING);
        continue;
      }

      int attrIndex = data.attribute(attrName).index();
      for (Enumeration<Instance> rows = data.enumerateInstances(); rows.hasMoreElements(); ) {
        rows.nextElement().setMissing(attrIndex);
      }
    }
  }
Esempio n. 12
0
  /**
   * Fills this object from a Weka data set: name, attributes, instances and class index.
   *
   * <p>Each Weka instance is converted into a list of values and a parallel list of missing
   * flags, one entry per attribute. A missing value is stored as 0.0 with its flag set to true.
   *
   * @param winsts the Weka instances to copy from
   */
  public void fillWekaInstances(weka.core.Instances winsts) {
    // set name
    setName(winsts.relationName());

    // set attributes
    List onto_attrs = new ArrayList();
    for (int i = 0; i < winsts.numAttributes(); i++) {
      Attribute a = new Attribute();
      a.fillWekaAttribute(winsts.attribute(i));
      onto_attrs.add(a);
    }
    setAttributes(onto_attrs);

    // set instances
    List onto_insts = new ArrayList();
    for (int i = 0; i < winsts.numInstances(); i++) {
      Instance inst = new Instance();
      weka.core.Instance winst = winsts.instance(i);

      List instvalues = new ArrayList();
      List instmis = new ArrayList();
      // Iterate over attribute indices, not stored values: isMissing(j) and value(j) are
      // addressed by attribute index, and numValues() is smaller than numAttributes() for
      // sparse instances, which would silently drop trailing attributes.
      for (int j = 0; j < winst.numAttributes(); j++) {
        if (winst.isMissing(j)) {
          instvalues.add(Double.valueOf(0.0));
          instmis.add(Boolean.TRUE);
        } else {
          instvalues.add(Double.valueOf(winst.value(j)));
          instmis.add(Boolean.FALSE);
        }
      }

      inst.setValues(instvalues);
      inst.setMissing(instmis);
      onto_insts.add(inst);
    }
    setInstances(onto_insts);
    setClass_index(winsts.classIndex());
  }
Esempio n. 13
0
  /**
   * Determines the output format based on the input format and returns this. In case the output
   * format cannot be returned immediately, i.e., hasImmediateOutputFormat() returns false, then
   * this method will be called from batchFinished() after the call of preprocess(Instances), in
   * which, e.g., statistics for the actual processing step can be gathered.
   *
   * <p>Attributes inside the configured range are copied under a new name obtained by applying the
   * find/replace regular expression (all matches or only the first, depending on m_ReplaceAll);
   * all other attributes are copied unchanged.
   *
   * @param inputFormat the input format to base the output format on
   * @return the output format
   * @throws Exception in case the determination goes wrong
   */
  protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
    final int numAtts = inputFormat.numAttributes();

    m_AttributeIndices.setUpper(numAtts - 1);

    // generate new header
    ArrayList<Attribute> atts = new ArrayList<Attribute>();
    for (int attIdx = 0; attIdx < numAtts; attIdx++) {
      Attribute att = inputFormat.attribute(attIdx);

      if (!m_AttributeIndices.isInRange(attIdx)) {
        atts.add((Attribute) att.copy());
        continue;
      }

      String oldName = att.name();
      String newName =
          m_ReplaceAll
              ? oldName.replaceAll(m_Find, m_Replace)
              : oldName.replaceFirst(m_Find, m_Replace);
      atts.add(att.copy(newName));
    }

    Instances result = new Instances(inputFormat.relationName(), atts, 0);
    result.setClassIndex(inputFormat.classIndex());

    return result;
  }
  /**
   * Sets the format of the input instances and derives the output format from it.
   *
   * <p>The rename specification (comma-separated {@code old:new} pairs) is parsed into the rename
   * map. The attribute selection is first interpreted as a numeric range; if that fails it is
   * interpreted as a comma-separated list of attribute names. The invert setting is applied in
   * both cases. Selected nominal attributes get their labels replaced according to the map; all
   * other attributes are copied.
   *
   * @param instanceInfo an Instances object containing the input instance structure (any instances
   *     contained in the object are ignored - only the structure is required).
   * @return true if the outputFormat may be collected immediately
   * @throws Exception if the format couldn't be set successfully, a replacement pair is malformed,
   *     or a named attribute cannot be found
   */
  @Override
  public boolean setInputFormat(Instances instanceInfo) throws Exception {

    super.setInputFormat(instanceInfo);

    int classIndex = instanceInfo.classIndex();

    // setup the map
    if (m_renameVals != null && m_renameVals.length() > 0) {
      String[] vals = m_renameVals.split(",");

      for (String val : vals) {
        String[] parts = val.split(":");
        // each entry must be exactly "old:new" with both sides non-empty
        if (parts.length != 2) {
          throw new WekaException("Invalid replacement string: " + val);
        }

        if (parts[0].length() == 0 || parts[1].length() == 0) {
          throw new WekaException("Invalid replacement string: " + val);
        }

        // keys are stored lower-cased when matching is case-insensitive
        m_renameMap.put(
            m_ignoreCase ? parts[0].toLowerCase().trim() : parts[0].trim(), parts[1].trim());
      }
    }

    // try selected atts as a numeric range first
    Range tempRange = new Range();
    tempRange.setInvert(m_invert);
    if (m_selectedColsString == null) {
      m_selectedColsString = "";
    }

    try {
      tempRange.setRanges(m_selectedColsString);
      tempRange.setUpper(instanceInfo.numAttributes() - 1);
      m_selectedAttributes = tempRange.getSelection();
      m_selectedCols = tempRange;
    } catch (Exception r) {
      // OK, now try as named attributes: translate the names into a 1-based index list
      StringBuilder indexes = new StringBuilder();
      String[] attNames = m_selectedColsString.split(",");
      for (String n : attNames) {
        n = n.trim();
        Attribute found = instanceInfo.attribute(n);
        if (found == null) {
          throw new WekaException(
              "Unable to find attribute '" + n + "' in the incoming instances'");
        }
        if (indexes.length() > 0) {
          indexes.append(',');
        }
        indexes.append(found.index() + 1);
      }

      tempRange = new Range();
      // honour the invert setting here as well; previously it was only applied to the
      // numeric-range path and silently ignored for named attributes
      tempRange.setInvert(m_invert);
      tempRange.setRanges(indexes.toString());
      tempRange.setUpper(instanceInfo.numAttributes() - 1);
      m_selectedAttributes = tempRange.getSelection();
      m_selectedCols = tempRange;
    }

    ArrayList<Attribute> attributes = new ArrayList<Attribute>();
    for (int i = 0; i < instanceInfo.numAttributes(); i++) {
      if (m_selectedCols.isInRange(i)) {
        if (instanceInfo.attribute(i).isNominal()) {
          List<String> valsForAtt = new ArrayList<String>();
          for (int j = 0; j < instanceInfo.attribute(i).numValues(); j++) {
            String origV = instanceInfo.attribute(i).value(j);

            String replace =
                m_ignoreCase ? m_renameMap.get(origV.toLowerCase()) : m_renameMap.get(origV);
            // only use the replacement if it does not duplicate a label already added;
            // otherwise the original label is kept
            if (replace != null && !valsForAtt.contains(replace)) {
              valsForAtt.add(replace);
            } else {
              valsForAtt.add(origV);
            }
          }
          Attribute newAtt = new Attribute(instanceInfo.attribute(i).name(), valsForAtt);
          attributes.add(newAtt);
        } else {
          // ignore any selected attributes that are not nominal
          Attribute att = (Attribute) instanceInfo.attribute(i).copy();
          attributes.add(att);
        }
      } else {
        Attribute att = (Attribute) instanceInfo.attribute(i).copy();
        attributes.add(att);
      }
    }

    Instances outputFormat = new Instances(instanceInfo.relationName(), attributes, 0);
    outputFormat.setClassIndex(classIndex);
    setOutputFormat(outputFormat);

    return true;
  }
  /**
   * Determines the output format for the given input format.
   *
   * <p>On the first call (m_remove not yet set) the attributes to exclude from the transformation
   * are determined: numeric ones if numeric attributes are excluded, and nominal ones if nominal
   * attributes are excluded or have fewer values than the nominal conversion threshold. If any
   * are found, a Remove filter for them and the header of the retained (unchanged) attributes are
   * set up. In the output, unchanged attributes are copied; every other non-class attribute is
   * replaced by one numeric attribute per class value named {@code pr_<att>|<class value>}. The
   * class attribute is appended last.
   *
   * @param inputFormat the input format to base the output format on
   * @return the output format
   * @throws Exception if both nominal and numeric attributes are excluded, or something else
   *     goes wrong
   */
  @Override
  protected Instances determineOutputFormat(Instances inputFormat) throws Exception {

    if (m_excludeNominalAttributes && m_excludeNumericAttributes) {
      throw new Exception(
          "No transformation will be done if both nominal and "
              + "numeric attributes are excluded!");
    }

    final int numAtts = inputFormat.numAttributes();
    final int classIdx = inputFormat.classIndex();

    if (m_remove == null) {
      List<Integer> excluded = new ArrayList<Integer>();

      // numeric attributes first ...
      if (m_excludeNumericAttributes) {
        for (int i = 0; i < numAtts; i++) {
          if (inputFormat.attribute(i).isNumeric() && i != classIdx) {
            excluded.add(i);
          }
        }
      }

      // ... then nominal ones (all of them, or only those below the threshold)
      if (m_excludeNominalAttributes || m_nominalConversionThreshold > 1) {
        for (int i = 0; i < numAtts; i++) {
          if (!inputFormat.attribute(i).isNominal() || i == classIdx) {
            continue;
          }
          if (m_excludeNominalAttributes
              || inputFormat.attribute(i).numValues() < m_nominalConversionThreshold) {
            excluded.add(i);
          }
        }
      }

      if (!excluded.isEmpty()) {
        int[] indices = new int[excluded.size()];
        int slot = 0;
        for (int idx : excluded) {
          indices[slot++] = idx;
        }

        m_remove = new Remove();
        m_remove.setAttributeIndicesArray(indices);
        m_remove.setInputFormat(inputFormat);

        // inverse filter: yields the header of the attributes that stay as they are
        Remove keepOnly = new Remove();
        keepOnly.setAttributeIndicesArray(indices);
        keepOnly.setInvertSelection(true);
        keepOnly.setInputFormat(inputFormat);
        m_unchanged = Filter.useFilter(inputFormat, keepOnly);
      }
    }

    ArrayList<Attribute> atts = new ArrayList<Attribute>();
    for (int i = 0; i < numAtts; i++) {
      if (i == classIdx) {
        continue;
      }

      String attName = inputFormat.attribute(i).name();
      Attribute kept = (m_unchanged == null) ? null : m_unchanged.attribute(attName);
      if (kept != null) {
        atts.add((Attribute) kept.copy());
        continue;
      }

      // one numeric attribute per class value
      for (int j = 0; j < inputFormat.classAttribute().numValues(); j++) {
        atts.add(new Attribute("pr_" + attName + "|" + inputFormat.classAttribute().value(j)));
      }
    }

    atts.add((Attribute) inputFormat.classAttribute().copy());
    Instances data = new Instances(inputFormat.relationName(), atts, 0);
    data.setClassIndex(data.numAttributes() - 1);

    return data;
  }
Esempio n. 16
0
  /**
   * generates source code from the filter
   *
   * @param filter the filter to output as source
   * @param className the name of the generated class
   * @param input the input data the header is generated for
   * @param output the output data the header is generated for
   * @return the generated source code
   * @throws Exception if source code cannot be generated
   */
  public static String wekaStaticWrapper(
      Sourcable filter, String className, Instances input, Instances output) throws Exception {

    StringBuffer result;
    int i;
    int n;

    result = new StringBuffer();

    result.append("// Generated with Weka " + Version.VERSION + "\n");
    result.append("//\n");
    result.append("// This code is public domain and comes with no warranty.\n");
    result.append("//\n");
    result.append("// Timestamp: " + new Date() + "\n");
    result.append("// Relation: " + input.relationName() + "\n");
    result.append("\n");

    result.append("package weka.filters;\n");
    result.append("\n");
    result.append("import weka.core.Attribute;\n");
    result.append("import weka.core.Capabilities;\n");
    result.append("import weka.core.Capabilities.Capability;\n");
    result.append("import weka.core.FastVector;\n");
    result.append("import weka.core.Instance;\n");
    result.append("import weka.core.Instances;\n");
    result.append("import weka.filters.Filter;\n");
    result.append("\n");
    result.append("public class WekaWrapper\n");
    result.append("  extends Filter {\n");

    // globalInfo
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Returns only the toString() method.\n");
    result.append("   *\n");
    result.append("   * @return a string describing the filter\n");
    result.append("   */\n");
    result.append("  public String globalInfo() {\n");
    result.append("    return toString();\n");
    result.append("  }\n");

    // getCapabilities
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Returns the capabilities of this filter.\n");
    result.append("   *\n");
    result.append("   * @return the capabilities\n");
    result.append("   */\n");
    result.append("  public Capabilities getCapabilities() {\n");
    result.append(((Filter) filter).getCapabilities().toSource("result", 4));
    result.append("    return result;\n");
    result.append("  }\n");

    // objectsToInstance
    result.append("\n");
    result.append("  /**\n");
    result.append("   * turns array of Objects into an Instance object\n");
    result.append("   *\n");
    result.append("   * @param obj	the Object array to turn into an Instance\n");
    result.append("   * @param format	the data format to use\n");
    result.append("   * @return		the generated Instance object\n");
    result.append("   */\n");
    result.append("  protected Instance objectsToInstance(Object[] obj, Instances format) {\n");
    result.append("    Instance		result;\n");
    result.append("    double[]		values;\n");
    result.append("    int		i;\n");
    result.append("\n");
    result.append("    values = new double[obj.length];\n");
    result.append("\n");
    result.append("    for (i = 0 ; i < obj.length; i++) {\n");
    result.append("      if (obj[i] == null)\n");
    result.append("        values[i] = Instance.missingValue();\n");
    result.append("      else if (format.attribute(i).isNumeric())\n");
    result.append("        values[i] = (Double) obj[i];\n");
    result.append("      else if (format.attribute(i).isNominal())\n");
    result.append("        values[i] = format.attribute(i).indexOfValue((String) obj[i]);\n");
    result.append("    }\n");
    result.append("\n");
    result.append("    // create new instance\n");
    result.append("    result = new Instance(1.0, values);\n");
    result.append("    result.setDataset(format);\n");
    result.append("\n");
    result.append("    return result;\n");
    result.append("  }\n");

    // instanceToObjects
    result.append("\n");
    result.append("  /**\n");
    result.append("   * turns the Instance object into an array of Objects\n");
    result.append("   *\n");
    result.append("   * @param inst	the instance to turn into an array\n");
    result.append("   * @return		the Object array representing the instance\n");
    result.append("   */\n");
    result.append("  protected Object[] instanceToObjects(Instance inst) {\n");
    result.append("    Object[]	result;\n");
    result.append("    int		i;\n");
    result.append("\n");
    result.append("    result = new Object[inst.numAttributes()];\n");
    result.append("\n");
    result.append("    for (i = 0 ; i < inst.numAttributes(); i++) {\n");
    result.append("      if (inst.isMissing(i))\n");
    result.append("  	result[i] = null;\n");
    result.append("      else if (inst.attribute(i).isNumeric())\n");
    result.append("  	result[i] = inst.value(i);\n");
    result.append("      else\n");
    result.append("  	result[i] = inst.stringValue(i);\n");
    result.append("    }\n");
    result.append("\n");
    result.append("    return result;\n");
    result.append("  }\n");

    // instancesToObjects
    result.append("\n");
    result.append("  /**\n");
    result.append("   * turns the Instances object into an array of Objects\n");
    result.append("   *\n");
    result.append("   * @param data	the instances to turn into an array\n");
    result.append("   * @return		the Object array representing the instances\n");
    result.append("   */\n");
    result.append("  protected Object[][] instancesToObjects(Instances data) {\n");
    result.append("    Object[][]	result;\n");
    result.append("    int		i;\n");
    result.append("\n");
    result.append("    result = new Object[data.numInstances()][];\n");
    result.append("\n");
    result.append("    for (i = 0; i < data.numInstances(); i++)\n");
    result.append("      result[i] = instanceToObjects(data.instance(i));\n");
    result.append("\n");
    result.append("    return result;\n");
    result.append("  }\n");

    // setInputFormat
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Only tests the input data.\n");
    result.append("   *\n");
    result.append("   * @param instanceInfo the format of the data to convert\n");
    result.append("   * @return always true, to indicate that the output format can \n");
    result.append("   *         be collected immediately.\n");
    result.append("   */\n");
    result.append("  public boolean setInputFormat(Instances instanceInfo) throws Exception {\n");
    result.append("    super.setInputFormat(instanceInfo);\n");
    result.append("    \n");
    result.append("    // generate output format\n");
    result.append("    FastVector atts = new FastVector();\n");
    result.append("    FastVector attValues;\n");
    for (i = 0; i < output.numAttributes(); i++) {
      result.append("    // " + output.attribute(i).name() + "\n");
      if (output.attribute(i).isNumeric()) {
        result.append(
            "    atts.addElement(new Attribute(\"" + output.attribute(i).name() + "\"));\n");
      } else if (output.attribute(i).isNominal()) {
        result.append("    attValues = new FastVector();\n");
        for (n = 0; n < output.attribute(i).numValues(); n++) {
          result.append("    attValues.addElement(\"" + output.attribute(i).value(n) + "\");\n");
        }
        result.append(
            "    atts.addElement(new Attribute(\""
                + output.attribute(i).name()
                + "\", attValues));\n");
      } else {
        throw new UnsupportedAttributeTypeException(
            "Attribute type '"
                + output.attribute(i).type()
                + "' (position "
                + (i + 1)
                + ") is not supported!");
      }
    }
    result.append("    \n");
    result.append(
        "    Instances format = new Instances(\"" + output.relationName() + "\", atts, 0);\n");
    result.append("    format.setClassIndex(" + output.classIndex() + ");\n");
    result.append("    setOutputFormat(format);\n");
    result.append("    \n");
    result.append("    return true;\n");
    result.append("  }\n");

    // input
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Directly filters the instance.\n");
    result.append("   *\n");
    result.append("   * @param instance the instance to convert\n");
    result.append("   * @return always true, to indicate that the output can \n");
    result.append("   *         be collected immediately.\n");
    result.append("   */\n");
    result.append("  public boolean input(Instance instance) throws Exception {\n");
    result.append(
        "    Object[] filtered = " + className + ".filter(instanceToObjects(instance));\n");
    result.append("    push(objectsToInstance(filtered, getOutputFormat()));\n");
    result.append("    return true;\n");
    result.append("  }\n");

    // batchFinished
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Performs a batch filtering of the buffered data, if any available.\n");
    result.append("   *\n");
    result.append("   * @return true if instances were filtered otherwise false\n");
    result.append("   */\n");
    result.append("  public boolean batchFinished() throws Exception {\n");
    result.append("    if (getInputFormat() == null)\n");
    result.append("      throw new NullPointerException(\"No input instance format defined\");;\n");
    result.append("\n");
    result.append("    Instances inst = getInputFormat();\n");
    result.append("    if (inst.numInstances() > 0) {\n");
    result.append(
        "      Object[][] filtered = " + className + ".filter(instancesToObjects(inst));\n");
    result.append("      for (int i = 0; i < filtered.length; i++) {\n");
    result.append("        push(objectsToInstance(filtered[i], getOutputFormat()));\n");
    result.append("      }\n");
    result.append("    }\n");
    result.append("\n");
    result.append("    flushInput();\n");
    result.append("    m_NewBatch = true;\n");
    result.append("    m_FirstBatchDone = true;\n");
    result.append("\n");
    result.append("    return (inst.numInstances() > 0);\n");
    result.append("  }\n");

    // toString
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Returns only the classnames and what filter it is based on.\n");
    result.append("   *\n");
    result.append("   * @return a short description\n");
    result.append("   */\n");
    result.append("  public String toString() {\n");
    result.append(
        "    return \"Auto-generated filter wrapper, based on "
            + filter.getClass().getName()
            + " (generated with Weka "
            + Version.VERSION
            + ").\\n"
            + "\" + this.getClass().getName() + \"/"
            + className
            + "\";\n");
    result.append("  }\n");

    // main
    result.append("\n");
    result.append("  /**\n");
    result.append("   * Runs the filter from commandline.\n");
    result.append("   *\n");
    result.append("   * @param args the commandline arguments\n");
    result.append("   */\n");
    result.append("  public static void main(String args[]) {\n");
    result.append("    runFilter(new WekaWrapper(), args);\n");
    result.append("  }\n");
    result.append("}\n");

    // actual filter code
    result.append("\n");
    result.append(filter.toSource(className, input));

    return result.toString();
  }
  /**
   * Takes an evaluation object from a task and aggregates it with the overall one.
   *
   * <p>Besides the evaluation itself, the plot instances, shapes and sizes of the task are
   * appended to the aggregated plot data. Once the number of aggregated sets equals {@code
   * maxSetNum}, the textual result is sent to the text listeners, the aggregated plot data to the
   * visualizable error listeners (if any) and, for a nominal class, the threshold curve data to the
   * threshold listeners (if any). Afterwards the aggregation state is reset.
   *
   * @param eval the evaluation object to aggregate
   * @param classifier the classifier used by the task
   * @param testData the testData from the task
   * @param plotInstances the ClassifierErrorsPlotInstances object from the task
   * @param setNum the set number processed by the task (not consulted here; completion is
   *     determined by counting the aggregated sets)
   * @param maxSetNum the maximum number of sets in this batch
   */
  protected synchronized void aggregateEvalTask(
      Evaluation eval,
      Classifier classifier,
      Instances testData,
      ClassifierErrorsPlotInstances plotInstances,
      int setNum,
      int maxSetNum) {

    m_eval.aggregate(eval);

    if (m_aggregatedPlotInstances == null) {
      // first finished task: its plot data becomes the starting point of the aggregate
      m_aggregatedPlotInstances = new Instances(plotInstances.getPlotInstances());
      m_aggregatedPlotShapes = plotInstances.getPlotShapes();
      m_aggregatedPlotSizes = plotInstances.getPlotSizes();
    } else {
      // subsequent tasks: append instance, shape and size row by row
      Instances temp = plotInstances.getPlotInstances();
      for (int i = 0; i < temp.numInstances(); i++) {
        m_aggregatedPlotInstances.add(temp.get(i));
        m_aggregatedPlotShapes.addElement(plotInstances.getPlotShapes().get(i));
        m_aggregatedPlotSizes.addElement(plotInstances.getPlotSizes().get(i));
      }
    }
    m_setsComplete++;

    // all sets of the batch have been aggregated -> publish the results
    if (m_setsComplete == maxSetNum) {
      try {
        String textTitle = classifier.getClass().getName();
        String textOptions = "";
        if (classifier instanceof OptionHandler) {
          textOptions = Utils.joinOptions(((OptionHandler) classifier).getOptions());
        }
        // strip the package prefix from the class name
        textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1, textTitle.length());
        String resultT =
            "=== Evaluation result ===\n\n"
                + "Scheme: "
                + textTitle
                + "\n"
                + ((textOptions.length() > 0) ? "Options: " + textOptions + "\n" : "")
                + "Relation: "
                + testData.relationName()
                + "\n\n"
                + m_eval.toSummaryString();

        if (testData.classAttribute().isNominal()) {
          resultT += "\n" + m_eval.toClassDetailsString() + "\n" + m_eval.toMatrixString();
        }

        TextEvent te = new TextEvent(ClassifierPerformanceEvaluator.this, resultT, textTitle);
        notifyTextListeners(te);

        // set up visualizable errors
        if (m_visualizableErrorListeners.size() > 0) {
          PlotData2D errorD = new PlotData2D(m_aggregatedPlotInstances);
          errorD.setShapeSize(m_aggregatedPlotSizes);
          errorD.setShapeType(m_aggregatedPlotShapes);
          errorD.setPlotName(textTitle + " " + textOptions);

          VisualizableErrorEvent vel =
              new VisualizableErrorEvent(ClassifierPerformanceEvaluator.this, errorD);
          notifyVisualizableErrorListeners(vel);
          m_PlotInstances.cleanUp();
        }

        // threshold curve data is only generated for nominal classes
        if (testData.classAttribute().isNominal() && m_thresholdListeners.size() > 0) {
          ThresholdCurve tc = new ThresholdCurve();
          Instances result = tc.getCurve(m_eval.predictions(), 0);
          result.setRelationName(testData.relationName());
          PlotData2D pd = new PlotData2D(result);
          String htmlTitle = "<html><font size=-2>" + textTitle;

          // put every option flag on its own line in the HTML title
          StringBuilder newOptions = new StringBuilder();
          if (classifier instanceof OptionHandler) {
            for (String option : ((OptionHandler) classifier).getOptions()) {
              if (option.length() == 0) {
                continue;
              }
              // a leading '-' followed by a digit is a negative number, not a flag;
              // the length check keeps a lone "-" from reading past the end of the string
              boolean negativeNumber =
                  option.length() > 1 && option.charAt(1) >= '0' && option.charAt(1) <= '9';
              if (option.charAt(0) == '-' && !negativeNumber) {
                newOptions.append("<br>");
              }
              newOptions.append(option);
            }
          }

          htmlTitle +=
              " "
                  + newOptions
                  + "<br>"
                  + " (class: "
                  + testData.classAttribute().value(0)
                  + ")"
                  + "</font></html>";
          pd.setPlotName(textTitle + " (class: " + testData.classAttribute().value(0) + ")");
          pd.setPlotNameHTML(htmlTitle);

          // connect every point to its predecessor; the first point has none
          boolean[] connectPoints = new boolean[result.numInstances()];
          for (int jj = 1; jj < connectPoints.length; jj++) {
            connectPoints[jj] = true;
          }

          pd.setConnectPoints(connectPoints);

          ThresholdDataEvent rde =
              new ThresholdDataEvent(
                  ClassifierPerformanceEvaluator.this, pd, testData.classAttribute());
          notifyThresholdListeners(rde);
        }
        if (m_logger != null) {
          m_logger.statusMessage(statusMessagePrefix() + "Finished.");
        }

      } catch (Exception ex) {
        if (m_logger != null) {
          m_logger.logMessage(
              "[ClassifierPerformanceEvaluator] "
                  + statusMessagePrefix()
                  + " problem constructing evaluation results. "
                  + ex.getMessage());
        }
        ex.printStackTrace();
      } finally {
        m_visual.setStatic();
        // save memory
        m_PlotInstances = null;
        m_setsComplete = 0;
        m_tasks = null;
        m_aggregatedPlotInstances = null;
      }
    }
  }
Esempio n. 18
0
  /**
   * Builds the output header from the input header: all input attributes are kept and nominal
   * "no"/"yes" indicator attributes for outliers and extreme values are appended, either once for
   * the whole dataset or once per processed attribute. In case the output format cannot be
   * returned immediately, i.e., hasImmediateOutputFormat() returns false, then this method will be
   * called from batchFinished() after the call of preprocess(Instances), in which, e.g.,
   * statistics for the actual processing step can be gathered.
   *
   * @param inputFormat the input format to base the output format on
   * @return the output format
   * @throws Exception in case the determination goes wrong
   * @see #hasImmediateOutputFormat()
   * @see #batchFinished()
   */
  protected Instances determineOutputFormat(Instances inputFormat) throws Exception {
    // resolve the selected attribute range against the actual number of attributes
    m_Attributes.setUpper(inputFormat.numAttributes() - 1);
    m_AttributeIndices = m_Attributes.getSelection();

    // the class attribute and every non-numeric attribute are flagged as not to be processed
    for (int pos = 0; pos < m_AttributeIndices.length; pos++) {
      final int selected = m_AttributeIndices[pos];
      if (selected == inputFormat.classIndex() || !inputFormat.attribute(selected).isNumeric()) {
        m_AttributeIndices[pos] = NON_NUMERIC;
      }
    }

    // start out with a copy of the existing attribute list
    final FastVector outputAtts = new FastVector();
    final int inputCount = inputFormat.numAttributes();
    for (int att = 0; att < inputCount; att++) {
      outputAtts.addElement(inputFormat.attribute(att));
    }

    if (getDetectionPerAttribute()) {
      // one Outlier/ExtremeValue pair (plus optional Offset) per processed attribute
      m_OutlierAttributePosition = new int[m_AttributeIndices.length];

      for (int pos = 0; pos < m_AttributeIndices.length; pos++) {
        if (m_AttributeIndices[pos] != NON_NUMERIC) {
          // remember where the indicator attributes of this attribute start
          m_OutlierAttributePosition[pos] = outputAtts.size();

          final String baseName = inputFormat.attribute(m_AttributeIndices[pos]).name();

          FastVector outlierLabels = new FastVector();
          outlierLabels.addElement("no");
          outlierLabels.addElement("yes");
          outputAtts.addElement(new Attribute(baseName + "_Outlier", outlierLabels));

          FastVector extremeLabels = new FastVector();
          extremeLabels.addElement("no");
          extremeLabels.addElement("yes");
          outputAtts.addElement(new Attribute(baseName + "_ExtremeValue", extremeLabels));

          if (getOutputOffsetMultiplier()) {
            outputAtts.addElement(new Attribute(baseName + "_Offset"));
          }
        }
      }
    } else {
      // a single Outlier/ExtremeValue pair for the whole dataset
      m_OutlierAttributePosition = new int[] {outputAtts.size()};

      FastVector outlierLabels = new FastVector();
      outlierLabels.addElement("no");
      outlierLabels.addElement("yes");
      outputAtts.addElement(new Attribute("Outlier", outlierLabels));

      FastVector extremeLabels = new FastVector();
      extremeLabels.addElement("no");
      extremeLabels.addElement("yes");
      outputAtts.addElement(new Attribute("ExtremeValue", extremeLabels));
    }

    // empty header carrying over relation name and class index of the input
    final Instances header = new Instances(inputFormat.relationName(), outputAtts, 0);
    header.setClassIndex(inputFormat.classIndex());

    return header;
  }
Esempio n. 19
0
  /**
   * Pads the data to conform to the necessary number of attributes, i.e., the number of non-class
   * attributes is brought up to the value returned by nextPowerOf2. Afterwards all attributes
   * apart from the class are renamed to {@code <algorithm>_<position>} and the rows are copied
   * into a new dataset as dense instances with weight 1.0.
   *
   * @param data the data to pad
   * @return the padded data
   * @throws IllegalStateException if the configured padding type is not handled
   */
  protected Instances pad(Instances data) {
    Instances result;
    int i;
    int n;
    String prefix;
    int numAtts;
    boolean isLast;
    int index;
    int[] indices;
    FastVector atts;

    // determine number of padding attributes
    switch (m_Padding) {
      case PADDING_ZERO:
        if (data.classIndex() > -1) {
          // the class attribute does not count towards the power of 2
          numAtts = (nextPowerOf2(data.numAttributes() - 1) + 1) - data.numAttributes();
        } else {
          numAtts = nextPowerOf2(data.numAttributes()) - data.numAttributes();
        }
        break;

      default:
        // report the padding id itself: looking it up as a tag could fail for an unknown id
        throw new IllegalStateException("Padding " + m_Padding + " not implemented!");
    }

    result = new Instances(data);
    prefix = getAlgorithm().getSelectedTag().getReadable();

    // any padding necessary?
    if (numAtts > 0) {
      // add padding attributes; a class attribute in last position stays last
      isLast = (data.classIndex() == data.numAttributes() - 1);
      indices = new int[numAtts];
      for (i = 0; i < numAtts; i++) {
        if (isLast) {
          index = result.numAttributes() - 1;
        } else {
          index = result.numAttributes();
        }

        result.insertAttributeAt(new Attribute(prefix + "_padding_" + (i + 1)), index);

        // record index of the padding attribute
        indices[i] = index;
      }

      // fill the padding attributes
      switch (m_Padding) {
        case PADDING_ZERO:
          for (i = 0; i < result.numInstances(); i++) {
            for (n = 0; n < indices.length; n++) {
              result.instance(i).setValue(indices[n], 0);
            }
          }
          break;
      }
    }

    // rename all attributes apart from class
    data = result;
    atts = new FastVector();
    n = 0;
    for (i = 0; i < data.numAttributes(); i++) {
      // the counter also advances for the class, so names reflect the 1-based position
      n++;
      if (i == data.classIndex()) {
        atts.addElement((Attribute) data.attribute(i).copy());
      } else {
        atts.addElement(new Attribute(prefix + "_" + n));
      }
    }

    // create new dataset
    result = new Instances(data.relationName(), atts, data.numInstances());
    result.setClassIndex(data.classIndex());
    for (i = 0; i < data.numInstances(); i++) {
      result.add(new DenseInstance(1.0, data.instance(i).toDoubleArray()));
    }

    return result;
  }