예제 #1
0
 /**
  * Standardizes {@code value} against the running statistics kept for one output attribute.
  *
  * <p>The mean is taken as slot 1 divided by slot 0 of the statistics vector, and the standard
  * deviation is obtained from {@code Utils.computeSD} applied to slots 2, 1 and 0 (presumably
  * sum of squares, sum and count/weight -- confirm against {@code Utils.computeSD}).
  *
  * @param outputToLearnIndex index into {@code literalStatistics} of the output attribute
  * @param value raw output value to standardize
  * @return {@code (value - mean) / sd}, or {@code 0.0} when the standard deviation is not
  *     greater than {@code 1e-7}
  */
 private double normalizeOutputValue(int outputToLearnIndex, double value) {
   // Read each statistic once; the same three slots feed both the mean and the deviation.
   final double slot0 = this.literalStatistics[outputToLearnIndex].getValue(0);
   final double slot1 = this.literalStatistics[outputToLearnIndex].getValue(1);
   final double slot2 = this.literalStatistics[outputToLearnIndex].getValue(2);

   final double mean = slot1 / slot0;
   final double deviation = Utils.computeSD(slot2, slot1, slot0);

   // A (near-)zero or NaN deviation fails the comparison, so the result falls back to 0.0.
   return (deviation > 0.0000001) ? (value - mean) / deviation : 0.0;
 }
예제 #2
0
  /**
   * Decides whether this literal should be expanded (split) and, if so, prepares the learning
   * literals that result from the expansion.
   *
   * <p>Steps, in order:
   *
   * <ol>
   *   <li>Collect the split suggestions, record the positive merits in {@code meritPerInput}
   *       (set to {@code null} when no suggestion has positive merit) and sort the suggestions.
   *   <li>Ask {@code inputSelector} for the next input indices and drop the attribute observers
   *       of previously learned inputs that are no longer selected.
   *   <li>Decide whether to split: a single suggestion is always accepted; with two or more, the
   *       last-ranked suggestion is accepted when its merit exceeds the second-to-last one by
   *       more than the Hoeffding bound, or when the bound is below {@code tieThreshold}. With no
   *       suggestions at all, no split is made.
   *   <li>On a split, pick the branch with the higher branch merit (negating the predicate when
   *       it is branch 1), select the outputs to keep learning, and build
   *       {@code otherBranchLearningLiteral}, optionally {@code otherOutputsLearningLiteral}, and
   *       {@code expandedLearningLiteral}.
   * </ol>
   *
   * @param splitConfidence confidence value passed to {@code computeHoeffdingBound}
   * @param tieThreshold bound value below which the top suggestion is accepted regardless of the
   *     merit difference
   * @return {@code true} if a split was decided and the expansion literals were prepared,
   *     {@code false} otherwise
   */
  @Override
  public boolean tryToExpand(double splitConfidence, double tieThreshold) {
    boolean shouldSplit = false;
    // find the best split per attribute and rank the results
    AttributeExpansionSuggestion[] bestSplitSuggestions =
        this.getBestSplitSuggestions(splitCriterion);

    double sumMerit = 0;
    meritPerInput = new double[attributesMask.length];
    for (int i = 0; i < bestSplitSuggestions.length; i++) {
      double merit = bestSplitSuggestions[i].getMerit();
      if (merit > 0) {
        meritPerInput[bestSplitSuggestions[i].predicate.getAttributeIndex()] = merit;
        sumMerit += merit;
      }
    }

    // if merit==0 it means the split does not have enough examples in the smallest branch
    if (sumMerit == 0) {
      // null indicates that no merit should be considered (e.g. for feature ranking)
      meritPerInput = null;
    }

    // Sorting puts the preferred suggestion last; the code below reads from the end.
    Arrays.sort(bestSplitSuggestions);

    // disable attributes that are not relevant
    int[] oldInputs = inputsToLearn.clone();
    inputsToLearn = inputSelector.getNextInputIndices(bestSplitSuggestions);
    Arrays.sort(this.inputsToLearn); // must be sorted for the binary search below
    for (int i = 0; i < oldInputs.length; i++) {
      if (attributesMask[oldInputs[i]]) {
        if (Arrays.binarySearch(inputsToLearn, oldInputs[i]) < 0) {
          // Input was learned before but is no longer selected: release its observer.
          this.attributeObservers.set(oldInputs[i], null);
        }
      }
    }

    if (bestSplitSuggestions.length < 2) {
      // If exactly one split was returned, use it. With zero suggestions there is nothing to
      // split on, and indexing [length - 1] would throw ArrayIndexOutOfBoundsException, so
      // shouldSplit is left false in that case.
      if (bestSplitSuggestions.length == 1) {
        bestSuggestion = bestSplitSuggestions[0];
        shouldSplit = true;
      }
    } // Otherwise, consider which of the splits proposed may be worth trying
    else {
      double hoeffdingBound =
          computeHoeffdingBound(
              splitCriterion.getRangeOfMerit(this.literalStatistics), splitConfidence, weightSeen);
      // Determine the top two ranked splitting suggestions
      bestSuggestion = bestSplitSuggestions[bestSplitSuggestions.length - 1];
      AttributeExpansionSuggestion secondBestSuggestion =
          bestSplitSuggestions[bestSplitSuggestions.length - 2];
      // Split when the lead over the runner-up exceeds the bound, or when the bound itself is
      // small enough to treat the two candidates as tied.
      if ((((bestSuggestion.merit - secondBestSuggestion.merit)) > hoeffdingBound)
          || (hoeffdingBound < tieThreshold)) {
        shouldSplit = true;
      }
    }
    if (shouldSplit) {
      // check which branch is better and update bestSuggestion (in amrules the splits are binary)
      DoubleVector[][] resultingStatistics = bestSuggestion.getResultingNodeStatistics();
      // if the second branch is better, change predicate (negate condition)
      double[] branchMerits = splitCriterion.getBranchesSplitMerits(resultingStatistics);
      DoubleVector[] newLiteralStatistics;
      if (branchMerits[1] > branchMerits[0]) {
        bestSuggestion.getPredicate().negateCondition();
        newLiteralStatistics = getBranchStatistics(resultingStatistics, 1);
      } else {
        newLiteralStatistics = getBranchStatistics(resultingStatistics, 0);
      }

      int[] newOutputs =
          outputSelector.getNextOutputIndices(
              newLiteralStatistics, literalStatistics, outputsToLearn);
      Arrays.sort(newOutputs); // Must be ordered for later correspondence algorithm to work

      // set other branch (only used if default rule expands)
      otherBranchLearningLiteral = new LearningLiteralRegression();
      otherBranchLearningLiteral.instanceHeader = instanceHeader;
      otherBranchLearningLiteral.learner = (MultiLabelLearner) learner.copy();
      otherBranchLearningLiteral.instanceTransformer =
          (InstanceTransformer) this.instanceTransformer;

      // keep a rule learning to the complement set of newOutputs

      // Set expanding branch
      // if is AMRulesFunction and the number of output attributes changes, start learning a new
      // predictor
      // should we do the same for input attributes (attributesMask)?. It would have impact in
      // RandomAMRules
      if (learner instanceof AMRulesFunction) { // Reset learning
        if (newOutputs.length != outputsToLearn.length) {
          // other outputs
          int[] otherOutputs = Utils.complementSet(outputsToLearn, newOutputs);
          int[] indices;
          if (otherOutputs.length > 0) {
            otherOutputsLearningLiteral = new LearningLiteralRegression(otherOutputs);
            // Copy the learner BEFORE the current one is narrowed to newOutputs below.
            MultiLabelLearner otherOutputsLearner = (MultiLabelLearner) learner.copy();
            indices = Utils.getIndexCorrespondence(outputsToLearn, otherOutputs);
            ((AMRulesFunction) otherOutputsLearner).selectOutputsToLearn(indices);
            ((AMRulesFunction) otherOutputsLearner).resetWithMemory();
            otherOutputsLearningLiteral.learner = otherOutputsLearner;
            otherOutputsLearningLiteral.instanceHeader = instanceHeader;
            otherOutputsLearningLiteral.instanceTransformer =
                new InstanceOutputAttributesSelector(instanceHeader, otherOutputs);
          }
          // expanded
          indices = Utils.getIndexCorrespondence(outputsToLearn, newOutputs);
          ((AMRulesFunction) learner).selectOutputsToLearn(indices);
        }

        ((AMRulesFunction) learner).resetWithMemory();
      }
      // just reset learning
      else {
        // other outputs //TODO JD: Test for general learner (other than AMRules functions)
        if (newOutputs.length != outputsToLearn.length) {
          int[] otherOutputs = Utils.complementSet(outputsToLearn, newOutputs);
          if (otherOutputs.length > 0) {
            otherOutputsLearningLiteral = new LearningLiteralRegression();
            MultiLabelLearner otherOutputsLearner = (MultiLabelLearner) learner.copy();
            otherOutputsLearner.resetLearning();
            otherOutputsLearningLiteral.learner = otherOutputsLearner;
            otherOutputsLearningLiteral.instanceHeader = instanceHeader;
            otherOutputsLearningLiteral.instanceTransformer =
                new InstanceOutputAttributesSelector(instanceHeader, otherOutputs);
          }
        }
        // expanded
        learner.resetLearning();
      }
      // The expanded literal gets a copy of the (now reset) learner restricted to newOutputs.
      expandedLearningLiteral = new LearningLiteralRegression(newOutputs);
      expandedLearningLiteral.learner = (MultiLabelLearner) this.learner.copy();
      expandedLearningLiteral.instanceHeader = instanceHeader;
      expandedLearningLiteral.instanceTransformer =
          new InstanceOutputAttributesSelector(instanceHeader, newOutputs);
    }
    return shouldSplit;
  }