Ejemplo n.º 1
0
  /**
   * Rescales every training instance weight so that the total weight returns to a previously
   * recorded value, keeping the relative weights unchanged, ready for the next iteration.
   *
   * <p>Each weight {@code w} becomes {@code w * oldSumOfWeights / currentSum}, where {@code
   * currentSum} is {@code training.sumOfWeights()} captured before any weight is modified.
   *
   * @param training the training instances; their weights are updated in place
   * @param oldSumOfWeights the total weight the instances should (up to rounding) sum to afterwards
   * @throws Exception if something goes wrong
   */
  protected void normalizeWeights(Instances training, double oldSumOfWeights) throws Exception {
    // Snapshot the current mass first: the loop below changes the weights it is derived from.
    final double currentTotal = training.sumOfWeights();
    for (Instance member : training) {
      final double rescaled = member.weight() * oldSumOfWeights / currentTotal;
      member.setWeight(rescaled);
    }
  }
Ejemplo n.º 2
0
  /**
   * Trains each ensemble member on the incoming instance in sequence, adjusting the per-member
   * weight multiplier {@code lambda} (starting at 1.0) as it goes.
   *
   * <p>For member {@code m}, the training multiplier {@code k} is {@code lambda} itself when
   * {@code pureBoostOption} is set, and {@code MiscUtils.poisson(lambda, classifierRandom)}
   * otherwise. When {@code k > 0} the member is trained on a copy of {@code inst} whose weight is
   * {@code inst.weight() * k}. Afterwards {@code lambda} is added to the member's {@code scms}
   * tally if it classifies {@code inst} correctly (its {@code swms} tally otherwise). Then
   * {@code lambda} is multiplied by {@code trainingWeightSeenByModel / (2 * thatTally)} before
   * moving on to the next member.
   *
   * @param inst the instance to learn from; members are trained on reweighted copies of it
   */
  @Override
  public void trainOnInstanceImpl(Instance inst) {
    double lambda = 1.0;
    for (int member = 0; member < this.ensemble.length; member++) {
      double k;
      if (this.pureBoostOption.isSet()) {
        k = lambda;
      } else {
        k = MiscUtils.poisson(lambda, this.classifierRandom);
      }

      if (k > 0.0) {
        Instance reweighted = (Instance) inst.copy();
        reweighted.setWeight(inst.weight() * k);
        this.ensemble[member].trainOnInstance(reweighted);
      }

      boolean correct = this.ensemble[member].correctlyClassifies(inst);
      if (correct) {
        this.scms[member] += lambda;
        lambda *= this.trainingWeightSeenByModel / (2 * this.scms[member]);
      } else {
        this.swms[member] += lambda;
        lambda *= this.trainingWeightSeenByModel / (2 * this.swms[member]);
      }
    }
  }
Ejemplo n.º 3
0
  /**
   * Reweights the training instances in place for the next iteration.
   *
   * <p>Let {@code p} be the chosen model's probability for the first class value (index 0). Each
   * weight is multiplied by {@code exp(s * 0.5 * (log p - log(1 - p)))} when the instance's class
   * value is 1, and by {@code exp(s * 0.5 * (log(1 - p) - log p))} otherwise. When {@code
   * iteration == -1} the model is {@code m_ZeroR} and {@code s = 1.0}. Otherwise the model is
   * {@code m_Classifiers[iteration]}, {@code s = m_Shrinkage}, and {@code p} is first smoothed as
   * {@code (m_SumOfWeights * p + 1) / (m_SumOfWeights + 2)}.
   *
   * @param training the training instances whose weights are updated
   * @param iteration index of the model in {@code m_Classifiers} to use, or {@code -1} for the
   *     {@code m_ZeroR} model
   * @throws Exception if something goes wrong
   */
  protected void setWeights(Instances training, int iteration) throws Exception {
    final boolean useZeroR = (iteration == -1);
    for (Instance current : training) {
      final double shrinkage;
      final double p0;
      if (useZeroR) {
        p0 = m_ZeroR.distributionForInstance(current)[0];
        shrinkage = 1.0;
      } else {
        final double raw = m_Classifiers[iteration].distributionForInstance(current)[0];
        // Ad-hoc smoothing so the probability is never exactly 0 or 1 (keeps the logs finite).
        p0 = (m_SumOfWeights * raw + 1) / (m_SumOfWeights + 2);
        shrinkage = m_Shrinkage;
      }

      final double logP = Math.log(p0);
      final double logQ = Math.log(1 - p0);
      final double exponent =
          (current.classValue() == 1)
              ? shrinkage * 0.5 * (logP - logQ)
              : shrinkage * 0.5 * (logQ - logP);
      current.setWeight(current.weight() * Math.exp(exponent));
    }
  }
  /**
   * Accepts and processes a classifier encapsulated in an incremental classifier event.
   *
   * <p>A {@code NEW_BATCH} event resets the evaluation state. When {@code m_windowSize > 0} it
   * also resets the sliding-window evaluation. Any other event evaluates the event's classifier
   * on the current instance, updates the chart data point and notifies chart listeners. When the
   * event is {@code BATCH_FINISHED}, or carries no current instance, a performance summary is sent
   * to any registered text listeners.
   *
   * <p>Windowed evaluation "forgets" the oldest windowed prediction by re-evaluating it with a
   * negated weight once the window holds more than {@code m_windowSize} instances.
   *
   * <p>Any exception is logged (when a logger is set), printed, and processing is stopped.
   *
   * @param ce an <code>IncrementalClassifierEvent</code> value
   */
  @Override
  public void acceptClassifier(final IncrementalClassifierEvent ce) {
    try {
      if (ce.getStatus() == IncrementalClassifierEvent.NEW_BATCH) {
        m_throughput = new StreamThroughput(statusMessagePrefix());
        m_throughput.setSamplePeriod(m_statusFrequency);

        m_eval = new Evaluation(ce.getStructure());
        m_eval.useNoPriors();

        m_dataLegend = new Vector();
        m_reset = true;
        m_dataPoint = new double[0];
        m_instanceCount = 0;

        if (m_windowSize > 0) {
          m_window = new LinkedList<Instance>();
          m_windowEval = new Evaluation(ce.getStructure());
          m_windowEval.useNoPriors();
          m_windowedPreds = new LinkedList<double[]>();

          if (m_logger != null) {
            m_logger.logMessage(
                statusMessagePrefix()
                    + "[IncrementalClassifierEvaluator] Chart output using windowed "
                    + "evaluation over "
                    + m_windowSize
                    + " instances");
          }
        }
      } else {
        Instance inst = ce.getCurrentInstance();
        if (inst != null) {
          m_throughput.updateStart();
          m_instanceCount++;
          double[] dist = ce.getClassifier().distributionForInstance(inst);
          int classIndex = inst.classIndex();
          boolean classMissing = inst.isMissing(classIndex);
          double pred = 0;
          if (!classMissing) {
            if (m_outputInfoRetrievalStats) {
              // store predictions so AUC etc can be output.
              m_eval.evaluateModelOnceAndRecordPrediction(dist, inst);
            } else {
              m_eval.evaluateModelOnce(dist, inst);
            }

            if (m_windowSize > 0) {
              m_windowEval.evaluateModelOnce(dist, inst);
              m_window.addFirst(inst);
              m_windowedPreds.addFirst(dist);

              // Evict based on how many instances are actually in the window. Instances with a
              // missing class are counted by m_instanceCount but never enter the window, so
              // comparing against the count would shrink the window below m_windowSize (or even
              // evict the instance that was just added).
              if (m_window.size() > m_windowSize) {
                // "forget" the oldest prediction
                Instance oldest = m_window.removeLast();
                double[] oldDist = m_windowedPreds.removeLast();
                oldest.setWeight(-oldest.weight());
                m_windowEval.evaluateModelOnce(oldDist, oldest);
                oldest.setWeight(-oldest.weight());
              }
            }
          } else {
            pred = ce.getClassifier().classifyInstance(inst);
          }

          if (classIndex >= 0) {
            if (inst.attribute(classIndex).isNominal()) {
              if (!classMissing) {
                if (m_dataPoint.length < 2) {
                  m_dataPoint = new double[3];
                  m_dataLegend.addElement("Accuracy");
                  m_dataLegend.addElement("RMSE (prob)");
                  m_dataLegend.addElement("Kappa");
                }
                if (m_windowSize > 0) {
                  m_dataPoint[1] = m_windowEval.rootMeanSquaredError();
                  m_dataPoint[2] = m_windowEval.kappa();
                } else {
                  m_dataPoint[1] = m_eval.rootMeanSquaredError();
                  m_dataPoint[2] = m_eval.kappa();
                }
              } else {
                if (m_dataPoint.length < 1) {
                  m_dataPoint = new double[1];
                  m_dataLegend.addElement("Confidence");
                }
              }

              double primaryMeasure;
              if (!classMissing) {
                if (m_windowSize > 0) {
                  primaryMeasure = 1.0 - m_windowEval.errorRate();
                } else {
                  primaryMeasure = 1.0 - m_eval.errorRate();
                }
              } else {
                // record confidence as the primary measure
                // (another possibility would be entropy of
                // the distribution, or perhaps average
                // confidence)
                primaryMeasure = dist[Utils.maxIndex(dist)];
              }
              m_dataPoint[0] = primaryMeasure;

              m_ce.setLegendText(m_dataLegend);
              m_ce.setMin(0);
              m_ce.setMax(1);
              m_ce.setDataPoint(m_dataPoint);
              m_ce.setReset(m_reset);
              m_reset = false;
            } else {
              // numeric class
              if (m_dataPoint.length < 1) {
                m_dataPoint = new double[1];
                if (classMissing) {
                  m_dataLegend.addElement("Prediction");
                } else {
                  m_dataLegend.addElement("RMSE");
                }
              }

              // Labelled instances plot the running RMSE; unlabelled ones plot the prediction.
              double update;
              if (!classMissing) {
                if (m_windowSize > 0) {
                  update = m_windowEval.rootMeanSquaredError();
                } else {
                  update = m_eval.rootMeanSquaredError();
                }
              } else {
                update = pred;
              }
              m_dataPoint[0] = update;
              if (update > m_max) {
                m_max = update;
              }
              if (update < m_min) {
                m_min = update;
              }

              m_ce.setLegendText(m_dataLegend);
              m_ce.setMin((classMissing ? m_min : 0));
              m_ce.setMax(m_max);
              m_ce.setDataPoint(m_dataPoint);
              m_ce.setReset(m_reset);
              m_reset = false;
            }
            notifyChartListeners(m_ce);
          }
          m_throughput.updateEnd(m_logger);
        }

        if (ce.getStatus() == IncrementalClassifierEvent.BATCH_FINISHED || inst == null) {
          if (m_logger != null) {
            m_logger.logMessage(
                "[IncrementalClassifierEvaluator]"
                    + statusMessagePrefix()
                    + " Finished processing.");
          }
          m_throughput.finished(m_logger);

          // save memory if using windowed evaluation for charting
          m_windowEval = null;
          m_window = null;
          m_windowedPreds = null;

          if (m_textListeners.size() > 0) {
            String textTitle = ce.getClassifier().getClass().getName();
            textTitle = textTitle.substring(textTitle.lastIndexOf('.') + 1, textTitle.length());
            String results =
                "=== Performance information ===\n\n"
                    + "Scheme:   "
                    + textTitle
                    + "\n"
                    + "Relation: "
                    + m_eval.getHeader().relationName()
                    + "\n\n"
                    + m_eval.toSummaryString();
            if (m_eval.getHeader().classIndex() >= 0
                && m_eval.getHeader().classAttribute().isNominal()
                && (m_outputInfoRetrievalStats)) {
              results += "\n" + m_eval.toClassDetailsString();
            }

            if (m_eval.getHeader().classIndex() >= 0
                && m_eval.getHeader().classAttribute().isNominal()) {
              results += "\n" + m_eval.toMatrixString();
            }
            textTitle = "Results: " + textTitle;
            TextEvent te = new TextEvent(this, results, textTitle);
            notifyTextListeners(te);
          }
        }
      }
    } catch (Exception ex) {
      if (m_logger != null) {
        m_logger.logMessage(
            "[IncrementalClassifierEvaluator]"
                + statusMessagePrefix()
                + " Error processing prediction "
                + ex.getMessage());
        m_logger.statusMessage(
            statusMessagePrefix() + "ERROR: problem processing prediction (see log for details)");
      }
      ex.printStackTrace();
      stop();
    }
  }
Ejemplo n.º 5
0
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * <p>If {@code m_ZeroR} is set, its distribution is returned directly. Otherwise, the nearest
   * training neighbours are retrieved. All of them are used when {@code m_UseAllK} is set or
   * fewer than {@code m_kNN} training instances exist; otherwise {@code m_kNN} neighbours are
   * used. The neighbour distances are divided by the k-th distance (the bandwidth) and passed
   * through the {@code m_WeightKernel} weighting kernel. The resulting values multiply the
   * neighbour weights, which are then rescaled to keep their original total. Finally {@code
   * m_Classifier} is built on the weighted neighbours and its distribution for {@code instance}
   * is returned.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if there are no training instances, or the distribution can't be computed
   *     successfully
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    // default model?
    if (m_ZeroR != null) {
      return m_ZeroR.distributionForInstance(instance);
    }

    if (m_Train.numInstances() == 0) {
      throw new Exception("No training instances!");
    }

    m_NNSearch.addInstanceInfo(instance);

    final int available = m_Train.numInstances();
    int k = (m_UseAllK || m_kNN >= available) ? available : m_kNN;

    Instances neighbours = m_NNSearch.kNearestNeighbours(instance, k);
    double[] distances = m_NNSearch.getDistances();

    if (m_Debug) {
      System.out.println("Test Instance: " + instance);
      System.out.println(
          "For "
              + k
              + " kept "
              + neighbours.numInstances()
              + " out of "
              + m_Train.numInstances()
              + " instances.");
    }

    // The search may have skipped so much that fewer than k neighbours remain.
    k = Math.min(k, distances.length);

    if (m_Debug) {
      System.out.println("Instance Distances");
      for (double d : distances) {
        System.out.println("" + d);
      }
    }

    final double bandwidth = distances[k - 1];

    // Rescale each distance by the bandwidth (a non-positive bandwidth gives every neighbour the
    // same distance), then turn it into a weight with the configured kernel.
    for (int i = 0; i < distances.length; i++) {
      double value = (bandwidth <= 0) ? 1 : distances[i] / bandwidth;
      switch (m_WeightKernel) {
        case LINEAR:
          value = 1.0001 - value;
          break;
        case EPANECHNIKOV:
          value = 3 / 4D * (1.0001 - value * value);
          break;
        case TRICUBE:
          value = Math.pow((1.0001 - Math.pow(value, 3)), 3);
          break;
        case CONSTANT:
          value = 1;
          break;
        case INVERSE:
          value = 1.0 / (1.0 + value);
          break;
        case GAUSS:
          value = Math.exp(-value * value);
          break;
      }
      distances[i] = value;
    }

    if (m_Debug) {
      System.out.println("Instance Weights");
      for (double w : distances) {
        System.out.println("" + w);
      }
    }

    // Apply the kernel weights to the neighbours, tracking totals before and after.
    double originalTotal = 0;
    double weightedTotal = 0;
    for (int i = 0; i < distances.length; i++) {
      Instance neighbour = neighbours.instance(i);
      double before = neighbour.weight();
      originalTotal += before;
      weightedTotal += before * distances[i];
      neighbour.setWeight(before * distances[i]);
    }

    // Restore the original total weight while keeping the kernel-induced proportions.
    for (int i = 0; i < neighbours.numInstances(); i++) {
      Instance neighbour = neighbours.instance(i);
      neighbour.setWeight(neighbour.weight() * originalTotal / weightedTotal);
    }

    // Create a weighted classifier
    m_Classifier.buildClassifier(neighbours);

    if (m_Debug) {
      System.out.println("Classifying test instance: " + instance);
      System.out.println("Built base classifier:\n" + m_Classifier.toString());
    }

    return m_Classifier.distributionForInstance(instance);
  }