예제 #1
0
  /**
   * Predicts the class of a test instance from the result of distributionForInstance(). The
   * instance has to belong to a dataset, because its class attribute decides how the distribution
   * is interpreted. Note that a classifier MUST implement either this or distributionForInstance().
   *
   * @param instance the instance to be classified
   * @return for a nominal class, the index of the largest probability (the first one on ties), or
   *     Utils.missingValue() if no probability is above zero; for a numeric or date class, the
   *     first element of the distribution; Utils.missingValue() for any other class type
   * @exception Exception if the predicted distribution is null or the prediction itself fails
   */
  @Override
  public double classifyInstance(Instance instance) throws Exception {
    final double[] probs = distributionForInstance(instance);
    if (probs == null) {
      throw new Exception("Null distribution predicted");
    }

    final int classType = instance.classAttribute().type();
    if (classType == Attribute.NUMERIC || classType == Attribute.DATE) {
      return probs[0];
    }
    if (classType != Attribute.NOMINAL) {
      return Utils.missingValue();
    }

    // Nominal class: look for the largest strictly positive probability (-1 = none found).
    int best = -1;
    double bestProb = 0;
    for (int c = 0; c < probs.length; c++) {
      if (probs[c] > bestProb) {
        bestProb = probs[c];
        best = c;
      }
    }
    return (best >= 0) ? best : Utils.missingValue();
  }
예제 #2
0
  /**
   * Converts one input instance to the output format and appends it to the output queue. Numeric
   * attributes inside the discretization range are replaced by a bin index or, in binary mode, by
   * one 0/1 indicator per cut point; all other attributes are copied as they are.
   *
   * @param instance the instance to convert
   */
  protected void convertInstance(Instance instance) {

    final double[] converted = new double[outputFormatPeek().numAttributes()];
    int out = 0;

    for (int att = 0; att < getInputFormat().numAttributes(); att++) {
      final boolean discretize =
          m_DiscretizeCols.isInRange(att) && getInputFormat().attribute(att).isNumeric();
      if (!discretize) {
        // Not subject to discretization: copy the value straight through.
        converted[out++] = instance.value(att);
        continue;
      }

      final double raw = instance.value(att);
      final boolean missing = instance.isMissing(att);
      final double[] cuts = m_CutPoints[att];

      if (cuts == null) {
        // No cut points for this attribute: everything falls into a single bin.
        converted[out++] = missing ? Utils.missingValue() : 0;
      } else if (!m_MakeBinary) {
        if (missing) {
          converted[out] = Utils.missingValue();
        } else {
          // The bin index is the position of the first cut point that is >= the value.
          int bin = 0;
          while (bin < cuts.length && !(raw <= cuts[bin])) {
            bin++;
          }
          converted[out] = bin;
        }
        out++;
      } else {
        // Binary mode: one output attribute per cut point.
        for (int c = 0; c < cuts.length; c++) {
          if (missing) {
            converted[out] = Utils.missingValue();
          } else if (raw <= cuts[c]) {
            converted[out] = 0;
          } else {
            converted[out] = 1;
          }
          out++;
        }
      }
    }

    // Keep the sparse/dense representation of the incoming instance.
    final Instance result;
    if (instance instanceof SparseInstance) {
      result = new SparseInstance(instance.weight(), converted);
    } else {
      result = new DenseInstance(instance.weight(), converted);
    }
    result.setDataset(getOutputFormat());
    copyValues(result, false, instance.dataset(), getOutputFormat());
    result.setDataset(getOutputFormat());
    push(result);
  }
  /**
   * Extends the plot data with prediction-interval columns appended after the existing attributes.
   * Every interval contributes three numeric attributes (lower boundary, upper boundary, width).
   * Since classifiers can return a varying number of intervals per instance, entries for intervals
   * an instance does not have are set to missing.
   */
  protected void addPredictionIntervals() {
    final FastVector predictions = m_Evaluation.predictions();

    // largest number of intervals carried by any single prediction
    int maxIntervals = 0;
    for (int p = 0; p < predictions.size(); p++) {
      final int count =
          ((NumericPrediction) predictions.elementAt(p)).predictionIntervals().length;
      maxIntervals = Math.max(maxIntervals, count);
    }

    // build the new header: old attributes first, then three attributes per interval
    final int baseAtts = m_PlotInstances.numAttributes();
    final FastVector header = new FastVector();
    for (int a = 0; a < baseAtts; a++) {
      header.addElement(m_PlotInstances.attribute(a));
    }
    for (int k = 1; k <= maxIntervals; k++) {
      header.addElement(new Attribute("predictionInterval_" + k + "-lowerBoundary"));
      header.addElement(new Attribute("predictionInterval_" + k + "-upperBoundary"));
      header.addElement(new Attribute("predictionInterval_" + k + "-width"));
    }
    final Instances extended =
        new Instances(m_PlotInstances.relationName(), header, m_PlotInstances.numInstances());
    extended.setClassIndex(m_PlotInstances.classIndex());

    // fill the new dataset row by row
    final int totalAtts = extended.numAttributes();
    for (int row = 0; row < m_PlotInstances.numInstances(); row++) {
      final Instance source = m_PlotInstances.instance(row);

      final double[] rowValues = new double[totalAtts];
      System.arraycopy(source.toDoubleArray(), 0, rowValues, 0, source.numAttributes());

      final double[][] intervals =
          ((NumericPrediction) predictions.elementAt(row)).predictionIntervals();
      int col = baseAtts;
      for (int k = 0; k < maxIntervals; k++) {
        if (k < intervals.length) {
          final double lower = intervals[k][0];
          final double upper = intervals[k][1];
          rowValues[col++] = lower;
          rowValues[col++] = upper;
          rowValues[col++] = upper - lower;
        } else {
          rowValues[col++] = Utils.missingValue();
          rowValues[col++] = Utils.missingValue();
          rowValues[col++] = Utils.missingValue();
        }
      }

      extended.add(new DenseInstance(source.weight(), rowValues));
    }

    m_PlotInstances = extended;
  }
예제 #4
0
  /**
   * Builds the instance that is handed to the model for scoring. Each model attribute takes its
   * value from the incoming attribute it is mapped to; unmapped attributes, missing incoming
   * values and nominal labels unknown to the model become missing. The class value, if the model
   * header has a class, is always set to missing.
   *
   * @param incoming the incoming instance
   * @return an instance with values mapped to be consistent with what the model is expecting
   */
  protected Instance mapIncomingFieldsToModelFields(Instance incoming) {
    final Instances header = m_model.getHeader();
    final int numModelAtts = header.numAttributes();
    final double[] mapped = new double[numModelAtts];

    for (int m = 0; m < numModelAtts; m++) {
      final int src = m_attributeMap[m];
      if (src < 0) {
        // no usable incoming field (missing or type mismatch)
        mapped[m] = Utils.missingValue();
        continue;
      }

      final Attribute modelAtt = header.attribute(m);
      final Attribute incomingAtt = incoming.dataset().attribute(src);

      if (incoming.isMissing(incomingAtt.index())) {
        mapped[m] = Utils.missingValue();
      } else if (modelAtt.isNumeric()) {
        mapped[m] = incoming.value(src);
      } else if (modelAtt.isNominal()) {
        // translate the label into the model's own index for it
        final int label = modelAtt.indexOfValue(incoming.stringValue(src));
        mapped[m] = (label < 0) ? Utils.missingValue() : label;
      } else if (modelAtt.isString()) {
        mapped[m] = 0;
        modelAtt.setStringValue(incoming.stringValue(src));
      }
    }

    final int classIdx = header.classIndex();
    if (classIdx >= 0) {
      // the class is what gets predicted, so it must not carry a value
      mapped[classIdx] = Utils.missingValue();
    }

    // keep the sparse/dense representation of the incoming instance
    final Instance scoringInst;
    if (incoming instanceof SparseInstance) {
      scoringInst = new SparseInstance(incoming.weight(), mapped);
    } else {
      scoringInst = new DenseInstance(incoming.weight(), mapped);
    }
    scoringInst.setDataset(header);
    return scoringInst;
  }
  /**
   * Transforms an input instance. Attributes listed in m_unchanged are copied; every other
   * non-class attribute is expanded into one value per class label, obtained from that attribute's
   * per-class estimators (or missing if the attribute value is missing). The class value is
   * stored in the last position of the result.
   *
   * @param current the input instance to convert
   * @return a transformed instance
   * @throws Exception if a problem occurs
   */
  protected Instance convertInstance(Instance current) throws Exception {
    final double[] converted = new double[getOutputFormat().numAttributes()];
    int pos = 0;

    for (int att = 0; att < current.numAttributes(); att++) {
      if (att == current.classIndex()) {
        continue; // the class value is written after the loop
      }

      final String attName = current.attribute(att).name();
      if (m_unchanged != null && m_unchanged.attribute(attName) != null) {
        converted[pos++] = current.value(att);
        continue;
      }

      final Estimator[] perClass = m_estimatorLookup.get(attName);
      final boolean missing = current.isMissing(att);
      for (int cls = 0; cls < current.classAttribute().numValues(); cls++) {
        converted[pos++] =
            missing ? Utils.missingValue() : perClass[cls].getProbability(current.value(att));
      }
    }

    converted[converted.length - 1] = current.classValue();
    return new DenseInstance(current.weight(), converted);
  }
예제 #6
0
  /**
   * Builds a 13-value instance from the supplied two-class statistics and threshold probability.
   *
   * @param tc the statistics
   * @param prob the probability
   * @return the generated instance (weight 1.0, not attached to a dataset)
   */
  private Instance makeInstance(TwoClassStats tc, double prob) {

    // (TP + FP) relative to all four counts
    final double predictedPositiveFraction =
        (tc.getTruePositive() + tc.getFalsePositive())
            / (tc.getTruePositive()
                + tc.getFalsePositive()
                + tc.getTrueNegative()
                + tc.getFalseNegative());

    // TP relative to the number expected by chance; undefined when that expectation is below 1
    final double expectedByChance =
        (predictedPositiveFraction * (tc.getTruePositive() + tc.getFalseNegative()));
    final double ratioToChance =
        (expectedByChance < 1) ? Utils.missingValue() : tc.getTruePositive() / expectedByChance;

    final double[] vals = {
      tc.getTruePositive(),
      tc.getFalseNegative(),
      tc.getFalsePositive(),
      tc.getTrueNegative(),
      tc.getFalsePositiveRate(),
      tc.getTruePositiveRate(),
      tc.getPrecision(),
      tc.getRecall(),
      tc.getFallout(),
      tc.getFMeasure(),
      predictedPositiveFraction,
      ratioToChance,
      prob
    };
    return new DenseInstance(1.0, vals);
  }
예제 #7
0
  /**
   * Calculates the area under the precision-recall curve (AUPRC) by walking the curve from its
   * last row to its first and summing precision times the change in recall.
   *
   * @param tcurve a previously extracted threshold curve Instances.
   * @return the PRC area; Double.NaN if you don't pass in a ThresholdCurve generated Instances or
   *     it is empty; Utils.missingValue() if the computed area is exactly zero.
   */
  public static double getPRCArea(Instances tcurve) {
    final int n = tcurve.numInstances();
    final boolean isThresholdCurve = RELATION_NAME.equals(tcurve.relationName());
    if (!isThresholdCurve || n == 0) {
      return Double.NaN;
    }

    final int precisionIdx = tcurve.attribute(PRECISION_NAME).index();
    final int recallIdx = tcurve.attribute(RECALL_NAME).index();
    final double[] precision = tcurve.attributeToDoubleArray(precisionIdx);
    final double[] recall = tcurve.attributeToDoubleArray(recallIdx);

    // the last row only provides the starting recall; summation starts one row before it
    // (the first real p/r pair, not the artificial zero point)
    int row = n - 1;
    double prevRecall = recall[row];
    double area = 0;
    while (--row >= 0) {
      area += (precision[row] * (recall[row] - prevRecall));
      prevRecall = recall[row];
    }

    return (area == 0) ? Utils.missingValue() : area;
  }
  /**
   * Input an instance for filtering. Ordinarily the instance is processed and made available for
   * output immediately. Some filters require all instances be read before producing output.
   *
   * <p>For attributes in the selected range, the current nominal label is looked up in the rename
   * map (lower-cased first when case is ignored) and, if a replacement exists, the value is
   * changed to the replacement's index in the output format. Missing values and values outside
   * the selected range are copied unchanged.
   *
   * @param instance the input instance
   * @return true if the filtered instance may now be collected with output(); false if the output
   *     format has no attributes.
   * @throws IllegalStateException if no input structure has been defined.
   */
  @Override
  public boolean input(Instance instance) {

    if (getInputFormat() == null) {
      throw new IllegalStateException("No input instance format defined");
    }
    if (m_NewBatch) {
      resetQueue();
      m_NewBatch = false;
    }

    if (getOutputFormat().numAttributes() == 0) {
      return false;
    }

    if (m_selectedAttributes.length == 0) {
      // nothing is selected for renaming: pass the instance through untouched
      push(instance);
      return true;
    }

    double[] vals = new double[getOutputFormat().numAttributes()];
    for (int i = 0; i < instance.numAttributes(); i++) {
      double currentV = instance.value(i);
      vals[i] = currentV;

      // Utils.missingValue() is NaN in Weka and NaN never compares equal with ==, so missing
      // values have to be detected with Utils.isMissingValue(). Otherwise a missing value would
      // be cast to label index 0 and could be renamed into a real label.
      if (!m_selectedCols.isInRange(i) || Utils.isMissingValue(currentV)) {
        continue;
      }

      String currentS = instance.attribute(i).value((int) currentV);
      String replace =
          m_ignoreCase ? m_renameMap.get(currentS.toLowerCase()) : m_renameMap.get(currentS);
      if (replace != null) {
        vals[i] = getOutputFormat().attribute(i).indexOfValue(replace);
      }
    }

    // keep the sparse/dense representation of the incoming instance
    Instance inst;
    if (instance instanceof SparseInstance) {
      inst = new SparseInstance(instance.weight(), vals);
    } else {
      inst = new DenseInstance(instance.weight(), vals);
    }
    inst.setDataset(getOutputFormat());
    copyValues(inst, false, instance.dataset(), getOutputFormat());
    inst.setDataset(getOutputFormat());
    push(inst);

    return true;
  }
예제 #9
0
파일: Id3.java 프로젝트: alishakiba/jDenetX
  /**
   * Recursively grows an Id3 tree for the given data. The attribute with the largest information
   * gain is chosen; if that gain is zero the node becomes a leaf holding the normalized class
   * distribution, otherwise one successor is built per value of the chosen attribute.
   *
   * @param data the training data
   * @exception Exception if decision tree can't be built successfully
   */
  private void makeTree(Instances data) throws Exception {

    // No instances reached this node: leaf without a prediction.
    if (data.numInstances() == 0) {
      m_Attribute = null;
      m_ClassValue = Utils.missingValue();
      m_Distribution = new double[data.numClasses()];
      return;
    }

    // Information gain of every enumerated attribute, indexed by attribute position.
    final double[] gains = new double[data.numAttributes()];
    for (Enumeration atts = data.enumerateAttributes(); atts.hasMoreElements(); ) {
      final Attribute candidate = (Attribute) atts.nextElement();
      gains[candidate.index()] = computeInfoGain(data, candidate);
    }
    m_Attribute = data.attribute(Utils.maxIndex(gains));

    // Positive gain: split on the chosen attribute and recurse into each partition.
    if (!Utils.eq(gains[m_Attribute.index()], 0)) {
      final Instances[] partitions = splitData(data, m_Attribute);
      final int branches = m_Attribute.numValues();
      m_Successors = new Id3[branches];
      for (int b = 0; b < branches; b++) {
        m_Successors[b] = new Id3();
        m_Successors[b].makeTree(partitions[b]);
      }
      return;
    }

    // Zero gain: make a leaf from the class counts of the remaining instances.
    m_Attribute = null;
    m_Distribution = new double[data.numClasses()];
    for (Enumeration insts = data.enumerateInstances(); insts.hasMoreElements(); ) {
      final Instance example = (Instance) insts.nextElement();
      m_Distribution[(int) example.classValue()]++;
    }
    Utils.normalize(m_Distribution);
    m_ClassValue = Utils.maxIndex(m_Distribution);
    m_ClassAttribute = data.classAttribute();
  }
예제 #10
0
  /**
   * Return the full data set. If the structure hasn't yet been determined by a call to getStructure
   * then this method does so before processing the rest of the data set.
   *
   * <p>All rows are read first; attribute types are then derived from the values seen (no recorded
   * labels: numeric; inside the string-attribute range: string; otherwise nominal), and finally
   * the rows are converted to instances. The source reader is closed when the method finishes,
   * whether it succeeds or fails.
   *
   * @return the complete data set read from the source
   * @exception IOException if there is no source or parsing fails
   */
  @Override
  public Instances getDataSet() throws IOException {
    if ((m_sourceFile == null) && (m_sourceReader == null)) {
      throw new IOException("No source has been specified");
    }

    try {
      if (m_structure == null) {
        getStructure();
      }

      if (m_st == null) {
        m_st = new StreamTokenizer(m_sourceReader);
        initTokenizer(m_st);
      }

      m_st.ordinaryChar(m_FieldSeparator.charAt(0));

      // one value -> label-index table per attribute
      m_cumulativeStructure =
          new ArrayList<Hashtable<Object, Integer>>(m_structure.numAttributes());
      for (int i = 0; i < m_structure.numAttributes(); i++) {
        m_cumulativeStructure.add(new Hashtable<Object, Integer>());
      }

      // read every row before deciding on attribute types
      m_cumulativeInstances = new ArrayList<ArrayList<Object>>();
      ArrayList<Object> current;
      while ((current = getInstance(m_st)) != null) {
        m_cumulativeInstances.add(current);
      }

      ArrayList<Attribute> atts = new ArrayList<Attribute>(m_structure.numAttributes());
      for (int i = 0; i < m_structure.numAttributes(); i++) {
        String attname = m_structure.attribute(i).name();
        Hashtable<Object, Integer> tempHash = m_cumulativeStructure.get(i);
        if (tempHash.size() == 0) {
          atts.add(new Attribute(attname));
        } else if (m_StringAttributes.isInRange(i)) {
          atts.add(new Attribute(attname, (ArrayList<String>) null));
        } else {
          ArrayList<String> values = new ArrayList<String>(tempHash.size());
          // add dummy objects so that every label can be stored at its recorded index
          for (int z = 0; z < tempHash.size(); z++) {
            values.add("dummy");
          }
          for (Object ob : tempHash.keySet()) {
            int index = tempHash.get(ob).intValue();
            String s = ob.toString();
            // strip the surrounding quote characters; the length guard avoids an
            // out-of-bounds substring on a label that is a single quote character
            if (s.length() >= 2 && (s.startsWith("'") || s.startsWith("\""))) {
              s = s.substring(1, s.length() - 1);
            }
            values.set(index, s);
          }
          atts.add(new Attribute(attname, values));
        }
      }

      // make the instances
      String relationName;
      if (m_sourceFile != null) {
        relationName = (m_sourceFile.getName()).replaceAll("\\.[cC][sS][vV]$", "");
      } else {
        relationName = "stream";
      }
      Instances dataSet = new Instances(relationName, atts, m_cumulativeInstances.size());

      for (int i = 0; i < m_cumulativeInstances.size(); i++) {
        current = m_cumulativeInstances.get(i);
        double[] vals = new double[dataSet.numAttributes()];
        for (int j = 0; j < current.size(); j++) {
          Object cval = current.get(j);
          if (cval instanceof String) {
            if (((String) cval).compareTo(m_MissingValue) == 0) {
              vals[j] = Utils.missingValue();
            } else {
              if (dataSet.attribute(j).isString()) {
                vals[j] = dataSet.attribute(j).addStringValue((String) cval);
              } else if (dataSet.attribute(j).isNominal()) {
                // find correct index
                Hashtable<Object, Integer> lookup = m_cumulativeStructure.get(j);
                vals[j] = lookup.get(cval).intValue();
              } else {
                throw new IllegalStateException(
                    "Wrong attribute type at position " + (i + 1) + "!!!");
              }
            }
          } else if (dataSet.attribute(j).isNominal()) {
            // find correct index
            Hashtable<Object, Integer> lookup = m_cumulativeStructure.get(j);
            vals[j] = lookup.get(cval).intValue();
          } else if (dataSet.attribute(j).isString()) {
            vals[j] = dataSet.attribute(j).addStringValue("" + cval);
          } else {
            vals[j] = ((Double) cval).doubleValue();
          }
        }
        dataSet.add(new DenseInstance(1.0, vals));
      }
      m_structure = new Instances(dataSet, 0);
      setRetrieval(BATCH);
      m_cumulativeStructure = null; // conserve memory

      return dataSet;
    } finally {
      // close the stream, also when reading or parsing failed
      if (m_sourceReader != null) {
        m_sourceReader.close();
      }
    }
  }
예제 #11
0
  /**
   * Builds an instance from the values of the current row (m_current), converting each value
   * according to the type of the corresponding attribute in m_structure. A value whose text is
   * "?" becomes a missing value.
   *
   * @return the new instance, attached to m_structure, or null if there is no current row
   * @throws IOException if a date or number cannot be parsed, or a nominal value is not among the
   *     attribute's known labels
   */
  protected Instance makeInstance() throws IOException {

    if (m_current == null) {
      return null;
    }

    double[] vals = new double[m_structure.numAttributes()];
    for (int i = 0; i < m_structure.numAttributes(); i++) {
      Object val = m_current.get(i);
      if (val.toString().equals("?")) {
        vals[i] = Utils.missingValue();
      } else if (m_structure.attribute(i).isString()) {
        // NOTE(review): index 0 presumably refers to the single value kept by setStringValue()
        // - confirm against Attribute.setStringValue
        vals[i] = 0;
        m_structure.attribute(i).setStringValue(Utils.unquote(val.toString()));
      } else if (m_structure.attribute(i).isDate()) {
        String format = m_structure.attribute(i).getDateFormat();
        SimpleDateFormat sdf = new SimpleDateFormat(format);
        String dateVal = Utils.unquote(val.toString());
        try {
          vals[i] = sdf.parse(dateVal).getTime();
        } catch (ParseException e) {
          throw new IOException(
              "Unable to parse date value "
                  + dateVal
                  + " using date format "
                  + format
                  + " for date attribute "
                  + m_structure.attribute(i)
                  + " (line: "
                  + m_rowCount
                  + ")");
        }
      } else if (m_structure.attribute(i).isNumeric()) {
        try {
          vals[i] = Double.parseDouble(val.toString());
        } catch (NumberFormatException ex) {
          throw new IOException(
              "Was expecting a number for attribute "
                  + m_structure.attribute(i).name()
                  + " but read "
                  + val.toString()
                  + " instead. (line: "
                  + m_rowCount
                  + ")");
        }
      } else {
        // nominal
        int index = m_structure.attribute(i).indexOfValue(Utils.unquote(val.toString()));
        if (index < 0) {
          throw new IOException(
              "Read unknown nominal value "
                  + val.toString()
                  + " for attribute "
                  + m_structure.attribute(i).name()
                  + " (line: "
                  + m_rowCount
                  + "). Try increasing the size of the memory buffer"
                  + " (-B option) or explicitly specify legal nominal values with "
                  + "the -L option.");
        }
        vals[i] = index;
      }
    }

    DenseInstance inst = new DenseInstance(1.0, vals);
    inst.setDataset(m_structure);

    return inst;
  }
예제 #12
0
  /**
   * Convert a single instance over. The converted instance is added to the end of the output queue.
   *
   * <p>The expression is applied to every attribute that is in the selected range, numeric, not
   * missing and not the class. All other values are carried over unchanged, for sparse and dense
   * instances alike. A result that is NaN or infinite is replaced by a missing value and a warning
   * is printed to System.err.
   *
   * @param instance the instance to convert
   * @throws Exception if instance cannot be converted
   */
  private void convertInstance(Instance instance) throws Exception {

    Instance inst = null;
    // raw type kept because it is passed on to eval(), whose parameter type is not visible here
    HashMap symbols = new HashMap(5);
    if (instance instanceof SparseInstance) {
      double[] newVals = new double[instance.numAttributes()];
      int[] newIndices = new int[instance.numAttributes()];
      double[] vals = instance.toDoubleArray();
      int ind = 0;
      for (int j = 0; j < instance.numAttributes(); j++) {
        // Start from the original value so that selected attributes which are not transformed
        // (non-numeric, missing, or the class) are kept instead of silently becoming zero.
        double value = vals[j];
        if (m_SelectCols.isInRange(j)
            && instance.attribute(j).isNumeric()
            && (!Utils.isMissingValue(vals[j]))
            && (getInputFormat().classIndex() != j)) {
          value = evaluateForAttribute(symbols, j, vals[j]);
          if (Double.isNaN(value) || Double.isInfinite(value)) {
            System.err.println("WARNING:Error in evaluating the expression: missing value set");
            value = Utils.missingValue();
          }
        }
        // only non-zero entries are stored in the sparse representation
        if (value != 0.0) {
          newVals[ind] = value;
          newIndices[ind] = j;
          ind++;
        }
      }
      double[] tempVals = new double[ind];
      int[] tempInd = new int[ind];
      System.arraycopy(newVals, 0, tempVals, 0, ind);
      System.arraycopy(newIndices, 0, tempInd, 0, ind);
      inst = new SparseInstance(instance.weight(), tempVals, tempInd, instance.numAttributes());
    } else {
      double[] vals = instance.toDoubleArray();
      for (int j = 0; j < getInputFormat().numAttributes(); j++) {
        if (m_SelectCols.isInRange(j)
            && instance.attribute(j).isNumeric()
            && (!Utils.isMissingValue(vals[j]))
            && (getInputFormat().classIndex() != j)) {
          vals[j] = evaluateForAttribute(symbols, j, vals[j]);
          if (Double.isNaN(vals[j]) || Double.isInfinite(vals[j])) {
            System.err.println("WARNING:Error in Evaluation the Expression: missing value set");
            vals[j] = Utils.missingValue();
          }
        }
      }
      inst = new DenseInstance(instance.weight(), vals);
    }
    inst.setDataset(instance.dataset());
    push(inst);
  }

  /**
   * Loads the attribute value ("A") and the attribute's numeric statistics into the symbol table
   * and evaluates the expression against it.
   *
   * @param symbols the symbol table to (re)fill
   * @param attIndex index of the attribute whose statistics are used
   * @param value the attribute value of the current instance
   * @return the result of eval() for these symbols
   * @throws Exception if the evaluation fails
   */
  private double evaluateForAttribute(HashMap symbols, int attIndex, double value)
      throws Exception {
    symbols.put("A", Double.valueOf(value));
    symbols.put("MAX", Double.valueOf(m_attStats[attIndex].numericStats.max));
    symbols.put("MIN", Double.valueOf(m_attStats[attIndex].numericStats.min));
    symbols.put("MEAN", Double.valueOf(m_attStats[attIndex].numericStats.mean));
    symbols.put("SD", Double.valueOf(m_attStats[attIndex].numericStats.stdDev));
    symbols.put("COUNT", Double.valueOf(m_attStats[attIndex].numericStats.count));
    symbols.put("SUM", Double.valueOf(m_attStats[attIndex].numericStats.sum));
    symbols.put("SUMSQUARED", Double.valueOf(m_attStats[attIndex].numericStats.sumSq));
    return eval(symbols);
  }
예제 #13
0
  /**
   * Passes the output of the finished job on to the registered listeners: text listeners receive
   * the classifier's textual description, classifier listeners receive the classifier together
   * with its training header. Nothing is sent if the job produced no classifier; classifier
   * listeners are only notified if a training header is available.
   */
  @Override
  protected void notifyJobOutputListeners() {
    weka.distributed.spark.WekaClassifierSparkJob job =
        (weka.distributed.spark.WekaClassifierSparkJob) m_runningJob;
    weka.classifiers.Classifier finalClassifier = job.getClassifier();
    Instances modelHeader = job.getTrainingHeader();
    String classAtt = job.getClassAttribute();

    // Only touch the header if there is one; a missing header is reported separately below
    // instead of surfacing as an exception from setClassIndex().
    if (modelHeader != null) {
      try {
        weka.distributed.spark.WekaClassifierSparkJob.setClassIndex(classAtt, modelHeader, true);
      } catch (Exception ex) {
        if (m_log != null) {
          m_log.logMessage(statusMessagePrefix() + ex.getMessage());
        }
        ex.printStackTrace();
      }
    }

    if (finalClassifier == null && m_log != null) {
      m_log.logMessage(statusMessagePrefix() + "No classifier produced!");
    }

    if (modelHeader == null && m_log != null) {
      m_log.logMessage(statusMessagePrefix() + "No training header available for the model!");
    }

    if (finalClassifier == null) {
      return;
    }

    if (m_textListeners.size() > 0) {
      String textual = finalClassifier.toString();

      // title: "Spark: " followed by the classifier class name and, if available, its options
      String title = "Spark: ";
      String classifierSpec = finalClassifier.getClass().getName();
      if (finalClassifier instanceof OptionHandler) {
        classifierSpec += " " + Utils.joinOptions(((OptionHandler) finalClassifier).getOptions());
      }
      title += classifierSpec;
      TextEvent te = new TextEvent(this, textual, title);
      for (TextListener t : m_textListeners) {
        t.acceptText(te);
      }
    }

    if (modelHeader != null) {
      // have to add a single bogus instance to the header to trick
      // the SerializedModelSaver into saving it (since it ignores
      // structure only DataSetEvents) :-)
      // NOTE(review): the instance is added to the object returned by getTrainingHeader();
      // confirm that this is a copy and not the job's own header.
      double[] vals = new double[modelHeader.numAttributes()];
      for (int i = 0; i < vals.length; i++) {
        vals[i] = Utils.missingValue();
      }
      Instance tempI = new DenseInstance(1.0, vals);
      modelHeader.add(tempI);
      DataSetEvent dse = new DataSetEvent(this, modelHeader);
      BatchClassifierEvent be = new BatchClassifierEvent(this, finalClassifier, dse, dse, 1, 1);
      for (BatchClassifierListener b : m_classifierListeners) {
        b.acceptClassifier(be);
      }
    }
  }
예제 #14
0
  /**
   * Gets the results for the supplied train and test datasets. Now performs a deep copy of the
   * classifier before it is built and evaluated (just in case the classifier is not initialized
   * properly in buildClassifier()).
   *
   * <p>The result array holds, in order: instance counts, error statistics, entropy statistics,
   * wall-clock and CPU timings, serialized sizes, prediction interval statistics, the classifier
   * summary, the additional measures and finally the plugin metric statistics.
   *
   * @param train the training Instances.
   * @param test the testing Instances.
   * @return the results stored in an array. The objects stored in the array may be Strings,
   *     Doubles, or null (for the missing value, or for an additional measure that could not be
   *     obtained).
   * @throws Exception if a problem occurs while getting the results
   */
  public Object[] getResult(Instances train, Instances test) throws Exception {

    if (train.classAttribute().type() != Attribute.NUMERIC) {
      throw new Exception("Class attribute is not numeric!");
    }
    if (m_Template == null) {
      throw new Exception("No classifier has been specified");
    }
    ThreadMXBean thMonitor = ManagementFactory.getThreadMXBean();
    boolean canMeasureCPUTime = thMonitor.isThreadCpuTimeSupported();
    if (canMeasureCPUTime && !thMonitor.isThreadCpuTimeEnabled()) {
      thMonitor.setThreadCpuTimeEnabled(true);
    }

    int addm = (m_AdditionalMeasures != null) ? m_AdditionalMeasures.length : 0;
    Object[] result = new Object[RESULT_SIZE + addm + m_numPluginStatistics];
    long thID = Thread.currentThread().getId();
    long CPUStartTime = -1;
    long trainCPUTimeElapsed = -1;
    long testCPUTimeElapsed = -1;
    Evaluation eval = new Evaluation(train);
    m_Classifier = AbstractClassifier.makeCopy(m_Template);

    // time the training phase
    long trainTimeStart = System.currentTimeMillis();
    if (canMeasureCPUTime) {
      CPUStartTime = thMonitor.getThreadUserTime(thID);
    }
    m_Classifier.buildClassifier(train);
    if (canMeasureCPUTime) {
      trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    }
    long trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;

    // time the testing phase
    long testTimeStart = System.currentTimeMillis();
    if (canMeasureCPUTime) {
      CPUStartTime = thMonitor.getThreadUserTime(thID);
    }
    eval.evaluateModel(m_Classifier, test);
    if (canMeasureCPUTime) {
      testCPUTimeElapsed = thMonitor.getThreadUserTime(thID) - CPUStartTime;
    }
    long testTimeElapsed = System.currentTimeMillis() - testTimeStart;

    m_result = eval.toSummaryString();
    // The results stored are all per instance -- can be multiplied by the
    // number of instances to get absolute numbers
    int current = 0;
    result[current++] = Double.valueOf(train.numInstances());
    result[current++] = Double.valueOf(eval.numInstances());

    result[current++] = Double.valueOf(eval.meanAbsoluteError());
    result[current++] = Double.valueOf(eval.rootMeanSquaredError());
    result[current++] = Double.valueOf(eval.relativeAbsoluteError());
    result[current++] = Double.valueOf(eval.rootRelativeSquaredError());
    result[current++] = Double.valueOf(eval.correlationCoefficient());

    result[current++] = Double.valueOf(eval.SFPriorEntropy());
    result[current++] = Double.valueOf(eval.SFSchemeEntropy());
    result[current++] = Double.valueOf(eval.SFEntropyGain());
    result[current++] = Double.valueOf(eval.SFMeanPriorEntropy());
    result[current++] = Double.valueOf(eval.SFMeanSchemeEntropy());
    result[current++] = Double.valueOf(eval.SFMeanEntropyGain());

    // Timing stats (wall-clock in seconds; CPU time converted from nanoseconds to seconds)
    result[current++] = Double.valueOf(trainTimeElapsed / 1000.0);
    result[current++] = Double.valueOf(testTimeElapsed / 1000.0);
    if (canMeasureCPUTime) {
      result[current++] = Double.valueOf((trainCPUTimeElapsed / 1000000.0) / 1000.0);
      result[current++] = Double.valueOf((testCPUTimeElapsed / 1000000.0) / 1000.0);
    } else {
      result[current++] = Double.valueOf(Utils.missingValue());
      result[current++] = Double.valueOf(Utils.missingValue());
    }

    // sizes (serialized size of classifier, training data and test data)
    if (m_NoSizeDetermination) {
      result[current++] = -1.0;
      result[current++] = -1.0;
      result[current++] = -1.0;
    } else {
      ByteArrayOutputStream bastream = new ByteArrayOutputStream();
      ObjectOutputStream oostream = new ObjectOutputStream(bastream);
      oostream.writeObject(m_Classifier);
      result[current++] = Double.valueOf(bastream.size());
      bastream = new ByteArrayOutputStream();
      oostream = new ObjectOutputStream(bastream);
      oostream.writeObject(train);
      result[current++] = Double.valueOf(bastream.size());
      bastream = new ByteArrayOutputStream();
      oostream = new ObjectOutputStream(bastream);
      oostream.writeObject(test);
      result[current++] = Double.valueOf(bastream.size());
    }

    // Prediction interval statistics
    result[current++] = Double.valueOf(eval.coverageOfTestCasesByPredictedRegions());
    result[current++] = Double.valueOf(eval.sizeOfPredictedRegions());

    if (m_Classifier instanceof Summarizable) {
      result[current++] = ((Summarizable) m_Classifier).toSummaryString();
    } else {
      result[current++] = null;
    }

    // Additional measures: every measure occupies exactly one slot. The slot is also consumed
    // (as null) when the measure cannot be obtained; otherwise all following results would be
    // shifted and the size check at the end would fail.
    for (int i = 0; i < addm; i++) {
      Object measure = null;
      if (m_doesProduce[i]) {
        try {
          double dv =
              ((AdditionalMeasureProducer) m_Classifier).getMeasure(m_AdditionalMeasures[i]);
          if (!Utils.isMissingValue(dv)) {
            measure = Double.valueOf(dv);
          }
        } catch (Exception ex) {
          System.err.println(ex);
        }
      }
      result[current++] = measure;
    }

    // get the actual metrics from the evaluation object
    List<AbstractEvaluationMetric> metrics = eval.getPluginMetrics();
    if (metrics != null) {
      for (AbstractEvaluationMetric m : metrics) {
        if (m.appliesToNumericClass()) {
          List<String> statNames = m.getStatisticNames();
          for (String s : statNames) {
            result[current++] = Double.valueOf(m.getStatistic(s));
          }
        }
      }
    }

    if (current != RESULT_SIZE + addm + m_numPluginStatistics) {
      throw new Error("Results didn't fit RESULT_SIZE");
    }
    return result;
  }
예제 #15
0
 /**
  * Predicts the class of the given instance by delegating to {@code m_GroovyObject}.
  *
  * @param instance the instance to be classified
  * @return whatever {@code m_GroovyObject.classifyInstance(instance)} returns, or
  *     {@code Utils.missingValue()} when {@code m_GroovyObject} is {@code null}
  * @throws Exception if an error occurred during the prediction
  */
 public double classifyInstance(Instance instance) throws Exception {
   // Without a backing Groovy object there is nothing to delegate to.
   if (m_GroovyObject == null) {
     return Utils.missingValue();
   }
   return m_GroovyObject.classifyInstance(instance);
 }
예제 #16
0
  /**
   * Converts the rows of a JDBC result set into a dataset named "QueryResult".
   *
   * <p>The method works in three passes: it first derives an attribute type for every column from
   * {@code adapter.translateDBColumnType(md.getColumnTypeName(i))}, then reads every row into a
   * {@code double[]} (SQL NULLs become {@code Utils.missingValue()}), and finally builds the
   * attribute header and adds the collected instances to the result.
   *
   * <p>Columns whose translated type matches none of the known cases become string attributes
   * without any values; every cell of such a column is stored as a missing value.
   *
   * @param adapter supplies the column type translation, the attribute-name case fix, the debug
   *     flag and the sparse-data flag
   * @param rs the result set to read; it is iterated to the end but not closed by this method
   * @return the dataset holding one instance per row of {@code rs}
   * @throws Exception if accessing the result set fails or an attribute type is unknown when the
   *     header is built
   */
  public static Instances retrieveInstances(InstanceQueryAdapter adapter, ResultSet rs)
      throws Exception {
    if (adapter.getDebug()) {
      System.err.println("Getting metadata...");
    }
    ResultSetMetaData md = rs.getMetaData();
    if (adapter.getDebug()) {
      System.err.println("Completed getting metadata...");
    }

    // Determine structure of the instances. JDBC columns are 1-based, the arrays are 0-based.
    int numAttributes = md.getColumnCount();
    int[] attributeTypes = new int[numAttributes];
    Hashtable[] nominalIndexes = new Hashtable[numAttributes];
    FastVector[] nominalStrings = new FastVector[numAttributes];
    for (int i = 1; i <= numAttributes; i++) {
      switch (adapter.translateDBColumnType(md.getColumnTypeName(i))) {
        case STRING:
          // String --> nominal
          attributeTypes[i - 1] = Attribute.NOMINAL;
          nominalIndexes[i - 1] = new Hashtable();
          nominalStrings[i - 1] = new FastVector();
          break;
        case TEXT:
          // Text --> string
          attributeTypes[i - 1] = Attribute.STRING;
          nominalIndexes[i - 1] = new Hashtable();
          nominalStrings[i - 1] = new FastVector();
          break;
        case BOOL:
          // boolean --> nominal with the fixed labels "false" (0) and "true" (1)
          attributeTypes[i - 1] = Attribute.NOMINAL;
          nominalIndexes[i - 1] = new Hashtable();
          nominalIndexes[i - 1].put("false", Double.valueOf(0));
          nominalIndexes[i - 1].put("true", Double.valueOf(1));
          nominalStrings[i - 1] = new FastVector();
          nominalStrings[i - 1].addElement("false");
          nominalStrings[i - 1].addElement("true");
          break;
        case DOUBLE:
        case BYTE:
        case SHORT:
        case INTEGER:
        case LONG:
        case FLOAT:
          // every numeric database type --> numeric
          attributeTypes[i - 1] = Attribute.NUMERIC;
          break;
        case DATE:
        case TIME:
          attributeTypes[i - 1] = Attribute.DATE;
          break;
        default:
          // Unknown column type --> string. The (empty) lookup structures must exist because
          // the header construction below reads nominalStrings for every STRING attribute;
          // leaving them null caused a NullPointerException there.
          attributeTypes[i - 1] = Attribute.STRING;
          nominalIndexes[i - 1] = new Hashtable();
          nominalStrings[i - 1] = new FastVector();
      }
    }

    // For sqlite
    // cache column names because the last while(rs.next()) { iteration for
    // the tuples below will close the md object:
    Vector<String> columnNames = new Vector<String>();
    for (int i = 0; i < numAttributes; i++) {
      columnNames.add(md.getColumnLabel(i + 1));
    }

    // Step through the tuples
    if (adapter.getDebug()) {
      System.err.println("Creating instances...");
    }
    FastVector instances = new FastVector();
    int rowCount = 0;
    while (rs.next()) {
      // Progress report every 100 rows (debug mode only).
      if (rowCount % 100 == 0 && adapter.getDebug()) {
        System.err.print("read " + rowCount + " instances \r");
        System.err.flush();
      }
      double[] vals = new double[numAttributes];
      for (int i = 1; i <= numAttributes; i++) {
        switch (adapter.translateDBColumnType(md.getColumnTypeName(i))) {
          case STRING:
          case TEXT:
            // Both kinds store the index of the value in the per-column value list,
            // registering values that have not been seen before.
            String str = rs.getString(i);
            if (rs.wasNull()) {
              vals[i - 1] = Utils.missingValue();
            } else {
              Double index = (Double) nominalIndexes[i - 1].get(str);
              if (index == null) {
                index = Double.valueOf(nominalStrings[i - 1].size());
                nominalIndexes[i - 1].put(str, index);
                nominalStrings[i - 1].addElement(str);
              }
              vals[i - 1] = index.doubleValue();
            }
            break;
          case BOOL:
            boolean boo = rs.getBoolean(i);
            if (rs.wasNull()) {
              vals[i - 1] = Utils.missingValue();
            } else {
              vals[i - 1] = (boo ? 1.0 : 0.0);
            }
            break;
          case DOUBLE:
            double dd = rs.getDouble(i);
            if (rs.wasNull()) {
              vals[i - 1] = Utils.missingValue();
            } else {
              vals[i - 1] = dd;
            }
            break;
          case BYTE:
            byte by = rs.getByte(i);
            if (rs.wasNull()) {
              vals[i - 1] = Utils.missingValue();
            } else {
              vals[i - 1] = (double) by;
            }
            break;
          case SHORT:
            short sh = rs.getShort(i);
            if (rs.wasNull()) {
              vals[i - 1] = Utils.missingValue();
            } else {
              vals[i - 1] = (double) sh;
            }
            break;
          case INTEGER:
            int in = rs.getInt(i);
            if (rs.wasNull()) {
              vals[i - 1] = Utils.missingValue();
            } else {
              vals[i - 1] = (double) in;
            }
            break;
          case LONG:
            long lo = rs.getLong(i);
            if (rs.wasNull()) {
              vals[i - 1] = Utils.missingValue();
            } else {
              vals[i - 1] = (double) lo;
            }
            break;
          case FLOAT:
            float fl = rs.getFloat(i);
            if (rs.wasNull()) {
              vals[i - 1] = Utils.missingValue();
            } else {
              vals[i - 1] = (double) fl;
            }
            break;
          case DATE:
            Date date = rs.getDate(i);
            if (rs.wasNull()) {
              vals[i - 1] = Utils.missingValue();
            } else {
              // TODO: Do a value check here.
              vals[i - 1] = (double) date.getTime();
            }
            break;
          case TIME:
            Time time = rs.getTime(i);
            if (rs.wasNull()) {
              vals[i - 1] = Utils.missingValue();
            } else {
              // TODO: Do a value check here.
              vals[i - 1] = (double) time.getTime();
            }
            break;
          default:
            // Unknown column types are never read; the cell is always missing.
            vals[i - 1] = Utils.missingValue();
        }
      }
      Instance newInst;
      if (adapter.getSparseData()) {
        newInst = new SparseInstance(1.0, vals);
      } else {
        newInst = new DenseInstance(1.0, vals);
      }
      instances.addElement(newInst);
      rowCount++;
    }
    // The database connection is deliberately left open (perhaps other queries might be made).

    // Create the header and add the instances to the dataset
    if (adapter.getDebug()) {
      System.err.println("Creating header...");
    }
    FastVector attribInfo = new FastVector();
    for (int i = 0; i < numAttributes; i++) {
      /* Fix for databases that uppercase column names */
      String attribName = adapter.attributeCaseFix(columnNames.get(i));
      switch (attributeTypes[i]) {
        case Attribute.NOMINAL:
          attribInfo.addElement(new Attribute(attribName, nominalStrings[i]));
          break;
        case Attribute.NUMERIC:
          attribInfo.addElement(new Attribute(attribName));
          break;
        case Attribute.STRING:
          Attribute att = new Attribute(attribName, (FastVector) null);
          attribInfo.addElement(att);
          // Register the collected values in the order they were first seen so that the
          // indices stored in the instances stay valid.
          for (int n = 0; n < nominalStrings[i].size(); n++) {
            att.addStringValue((String) nominalStrings[i].elementAt(n));
          }
          break;
        case Attribute.DATE:
          attribInfo.addElement(new Attribute(attribName, (String) null));
          break;
        default:
          throw new Exception("Unknown attribute type");
      }
    }
    Instances result = new Instances("QueryResult", attribInfo, instances.size());
    for (int i = 0; i < instances.size(); i++) {
      result.add((Instance) instances.elementAt(i));
    }

    return result;
  }