Ejemplo n.º 1
0
  /**
   * Collects, for each input attribute being learned, the best split suggestion its observer
   * offers under the given criterion.
   *
   * <p>Attributes that are masked out, or that have no observer registered, are skipped. When an
   * observer returns no suggestion, a placeholder suggestion with merit {@code -Double.MAX_VALUE}
   * is added instead, so that every observed attribute contributes exactly one entry.
   *
   * @param criterion split criterion handed to every attribute observer
   * @return the gathered suggestions, in the order the attributes appear in {@code inputsToLearn}
   */
  private AttributeExpansionSuggestion[] getBestSplitSuggestions(
      MultiLabelSplitCriterion criterion) {
    List<AttributeExpansionSuggestion> suggestions =
        new LinkedList<AttributeExpansionSuggestion>();

    for (int attributeIndex : this.inputsToLearn) {
      // Expected to hold for every learned input (see trainOnInstance()); kept as a safeguard.
      if (!attributesMask[attributeIndex]) {
        continue;
      }

      AttributeStatisticsObserver observer = this.attributeObservers.get(attributeIndex);
      if (observer == null) {
        continue;
      }

      AttributeExpansionSuggestion candidate =
          observer.getBestEvaluatedSplitSuggestion(criterion, literalStatistics, attributeIndex);
      if (candidate == null) {
        // Every observed attribute must yield a suggestion: fall back to a dummy one carrying
        // the lowest possible merit.
        NumericRulePredicate placeholder = new NumericRulePredicate(attributeIndex, 0, true);
        candidate = new AttributeExpansionSuggestion(placeholder, null, -Double.MAX_VALUE);
      }
      suggestions.add(candidate);
    }

    return suggestions.toArray(new AttributeExpansionSuggestion[0]);
  }
Ejemplo n.º 2
0
  /**
   * Updates this object with one training instance.
   *
   * <p>On the first call (while {@code hasStarted} is false) the method seeds the wrapped learner
   * if it is randomizable, fills in {@code outputsToLearn} and {@code inputsToLearn} when they
   * are still unset, allocates the per-output statistics, and records the first target values as
   * the variance shift.
   *
   * <p>On every call it then:
   *
   * <ol>
   *   <li>builds per-output statistics {@code {weight, sum, squaredSum, sumShifted,
   *       squaredSumShifted}} for the instance and adds them to {@code literalStatistics};
   *   <li>feeds those statistics to the observer of each learned input attribute, creating the
   *       observer on first use;
   *   <li>transforms the instance, asks the wrapped learner for a prediction, records it in
   *       {@code errorMeasurer} when available, and only then trains the learner on it;
   *   <li>adds the instance weight to {@code weightSeen}.
   * </ol>
   *
   * @param instance the training example
   */
  @Override
  public void trainOnInstance(MultiLabelInstance instance) {
    int numInputs = 0;
    boolean maskCreatedHere = attributesMask == null;
    if (maskCreatedHere) {
      numInputs = initializeAttibutesMask(instance);
    }

    // learn for all output attributes if not specified at construction time
    int numOutputs = instance.numberOutputTargets();

    if (!hasStarted) {
      if (this.learner.isRandomizable()) {
        this.learner.setRandomSeed(this.randomGenerator.nextInt());
      }
      if (outputsToLearn == null) {
        outputsToLearn = new int[numOutputs];
        for (int i = 0; i < numOutputs; i++) {
          outputsToLearn[i] = i;
        }
      }
      if (inputsToLearn == null) {
        if (!maskCreatedHere) {
          // The mask already existed, so numInputs was never computed above. Count the selected
          // attributes here; otherwise the array below would have length 0 and overflow.
          for (int i = 0; i < instance.numInputAttributes(); i++) {
            if (attributesMask[i]) {
              numInputs++;
            }
          }
        }
        inputsToLearn = new int[numInputs];
        int ct = 0;
        for (int i = 0; i < instance.numInputAttributes(); i++) {
          if (attributesMask[i]) {
            inputsToLearn[ct] = i;
            ct++;
          }
        }
      }

      literalStatistics = new DoubleVector[outputsToLearn.length];
      varianceShift = new double[outputsToLearn.length];
      for (int i = 0; i < outputsToLearn.length; i++) {
        literalStatistics[i] = new DoubleVector(new double[5]);
        // The first observed target value is used as the shift for the shifted sums below.
        varianceShift[i] = instance.valueOutputAttribute(outputsToLearn[i]);
      }
      instanceHeader = (InstancesHeader) instance.dataset();
      hasStarted = true;
    }

    double weight = instance.weight();
    DoubleVector[] exampleStatistics = new DoubleVector[outputsToLearn.length];
    for (int i = 0; i < outputsToLearn.length; i++) {
      double target = instance.valueOutputAttribute(outputsToLearn[i]);
      double shifted = target - varianceShift[i];
      double sum = weight * target;
      double squaredSum = weight * target * target;
      // The weight must scale the shifted value, consistently with squaredSumShifted;
      // "weight * target - shift" is only equivalent when weight == 1.
      double sumShifted = weight * shifted;
      double squaredSumShifted = weight * shifted * shifted;
      exampleStatistics[i] =
          new DoubleVector(new double[] {weight, sum, squaredSum, sumShifted, squaredSumShifted});
      literalStatistics[i].addValues(exampleStatistics[i].getArrayRef());
    }

    if (this.attributeObservers == null) {
      this.attributeObservers = new AutoExpandVector<AttributeStatisticsObserver>();
    }
    for (int i = 0; i < inputsToLearn.length; i++) {
      if (attributesMask[inputsToLearn[i]]) { // this is checked above. Remove?
        AttributeStatisticsObserver obs = this.attributeObservers.get(inputsToLearn[i]);
        if (obs == null) {
          if (instance.attribute(inputsToLearn[i]).isNumeric()) {
            obs = ((NumericStatisticsObserver) numericStatisticsObserver.copy());
          } else if (instance.attribute(inputsToLearn[i]).isNominal()) {
            // Explicit nominal check: other attribute kinds (e.g. ordinal) may exist in future.
            obs = ((NominalStatisticsObserver) nominalStatisticsObserver.copy());
          }
          if (obs == null) {
            // Neither numeric nor nominal: no observer can be created, so skip this attribute
            // instead of dereferencing null. getBestSplitSuggestions() tolerates the gap.
            continue;
          }
          this.attributeObservers.set(inputsToLearn[i], obs);
        }
        obs.observeAttribute(instance.valueInputAttribute(inputsToLearn[i]), exampleStatistics);
      }
    }

    // Transform instance for learning
    Instance transformedInstance = instanceTransformer.sourceInstanceToTarget(instance);
    Prediction prediction = null;
    // Predict before training so the error measurer sees the learner's pre-update output.
    Prediction targetPrediction = learner.getPredictionForInstance(transformedInstance);
    if (targetPrediction != null) {
      prediction = instanceTransformer.targetPredictionToSource(targetPrediction);
    }

    if (prediction != null) {
      errorMeasurer.addPrediction(prediction, instance);
    }

    learner.trainOnInstance(transformedInstance);

    weightSeen += instance.weight();
  }