@Override
  /**
   * Assigns every attribute of {@code exampleSet.getAttributes()} a weight equal to the absolute
   * value of its correlation with the label attribute (squared correlation when
   * {@code PARAMETER_SQUARED_CORRELATION} is set).
   *
   * @param exampleSet the data whose attributes are weighted
   * @return the computed weights, one entry per attribute name
   * @throws OperatorException if a parameter cannot be read
   */
  public AttributeWeights calculateWeights(ExampleSet exampleSet) throws OperatorException {
    final Attributes allAttributes = exampleSet.getAttributes();
    final Attribute label = allAttributes.getLabel();
    final boolean squared = getParameterAsBoolean(PARAMETER_SQUARED_CORRELATION);

    AttributeWeights result = new AttributeWeights(exampleSet);
    getProgress().setTotal(allAttributes.size());
    final int rowsPerAttribute = exampleSet.size();

    int processed = 0;
    // Progress is only pushed after roughly PROGRESS_UPDATE_STEPS rows were scanned, to keep
    // the update overhead low on small example sets.
    int rowsSinceUpdate = 0;
    for (Attribute current : allAttributes) {
      double r = MathFunctions.correlation(exampleSet, label, current, squared);
      result.setWeight(current.getName(), Math.abs(r));
      processed++;
      rowsSinceUpdate += rowsPerAttribute;
      if (rowsSinceUpdate > PROGRESS_UPDATE_STEPS) {
        rowsSinceUpdate = 0;
        getProgress().setCompleted(processed);
      }
    }
    return result;
  }
Exemplo n.º 2
0
  /**
   * Builds attribute weights from the eigenvector of the requested (1-based) component.
   *
   * <p>Values below 1 are treated as 1. Values above the number of attributes are clamped to the
   * last component, and a warning is logged.
   *
   * @param component 1-based component index
   * @return weights mapping each attribute name to the matching eigenvector entry
   * @throws OperatorException as declared by the overridden method
   */
  @Override
  public AttributeWeights getWeightsOfComponent(int component) throws OperatorException {
    int selected = Math.max(component, 1);
    final int attributeCount = attributeNames.length;
    if (selected > attributeCount) {
      logWarning("Creating weights of component " + attributeNames.length + "!");
      selected = attributeCount;
    }
    AttributeWeights result = new AttributeWeights();

    final double[] vector = eigenVectors.get(selected - 1).getEigenvector();
    int idx = 0;
    while (idx < attributeCount) {
      result.setWeight(attributeNames[idx], vector[idx]);
      idx++;
    }
    return result;
  }
Exemplo n.º 3
0
  /**
   * Builds attribute weights from the singular vector of the requested (1-based) component.
   *
   * <p>Values below 1 are treated as 1. Values above the number of attributes are clamped to the
   * last component, and a warning is logged. The component is converted to a 0-based row index
   * of {@code vMatrix}. The original code indexed with the 1-based value directly, which skipped
   * row 0 and overran a square matrix for the last component.
   *
   * @param component 1-based component index
   * @return weights mapping each attribute name to the matching singular-vector entry
   * @throws OperatorException as declared by the overridden method
   */
  @Override
  public AttributeWeights getWeightsOfComponent(int component) throws OperatorException {
    if (component < 1) {
      component = 1;
    }
    if (component > attributeNames.length) {
      logWarning("Creating weights of component " + attributeNames.length + "!");
      component = attributeNames.length;
    }
    AttributeWeights weights = new AttributeWeights();

    // component is 1-based here; array rows are 0-based.
    double[] singularVector = vMatrix.getArray()[component - 1];
    for (int i = 0; i < attributeNames.length; i++) {
      weights.setWeight(attributeNames[i], singularVector[i]);
    }

    return weights;
  }
 /**
  * Lets the user pick a ".wgt" attribute weight file and merges its weights into the table model,
  * then refreshes the view. An IOException while loading is reported via an error dialog.
  */
 private void load() {
   File file = SwingTools.chooseFile(null, null, true, "wgt", "attribute weight file");
   if (file == null) {
     // No file chosen (presumably the dialog was cancelled): nothing to load or refresh.
     return;
   }
   try {
     AttributeWeights fileWeights = AttributeWeights.load(file);
     attributeTableModel.mergeWeights(fileWeights);
   } catch (IOException e) {
     SwingTools.showSimpleErrorMessage("cannot_load_attr_weights_from_file", e, file.getName());
   }
   update();
 }
  /**
   * Converts the attribute weights into a data table.
   *
   * <p>When {@code isRendering} is true, all weights are returned unfiltered. Otherwise only
   * attributes whose name fully matches the {@code PARAMETER_ATTRIBUTE_SELECTION} regex are kept,
   * and the input weights are left untouched because a clone is filtered.
   *
   * <p>NOTE(review): the original comment read "use parameters only during rendering", yet the
   * filter is applied when {@code isRendering} is false. Confirm which is intended.
   */
  @Override
  public DataTable getDataTable(Object renderable, IOContainer ioContainer, boolean isRendering) {
    AttributeWeights weights = (AttributeWeights) renderable;
    if (isRendering) {
      return weights.createDataTable();
    }

    AttributeWeights filtered = (AttributeWeights) weights.clone();
    try {
      Pattern selection = Pattern.compile(getParameterAsString(PARAMETER_ATTRIBUTE_SELECTION));
      for (String name : weights.getAttributeNames()) {
        if (!selection.matcher(name).matches()) {
          filtered.removeAttributeWeight(name);
        }
      }
    } catch (UndefinedParameterError ignored) {
      // Selection parameter not available: fall back to the unfiltered clone.
    }
    return filtered.createDataTable();
  }
  /**
   * Runs a backward elimination over the regular attributes of the input example set.
   *
   * <p>Each round tentatively removes every still-selected attribute, evaluates the performance,
   * and permanently removes the attribute whose removal scored best. Rounds continue until
   * {@code PARAMETER_MAX_ATTRIBUTES} removals have been made (capped at
   * {@code numberOfAttributes - 1}) or the configured stopping behaviour fires.
   *
   * <p>Up to {@code PARAMETER_ALLOWED_CONSECUTIVE_FAILS} consecutive failing rounds are tolerated
   * speculatively. If the run ends while such speculative removals are pending, those attributes
   * are added back.
   *
   * <p>Delivers the reduced example set, the best accepted performance, and 0/1 weights marking
   * the kept attributes.
   *
   * @throws OperatorException if a parameter is invalid or performance evaluation fails
   */
  @Override
  public void doWork() throws OperatorException {
    ExampleSet exampleSetOriginal = exampleSetInput.getData(ExampleSet.class);
    ExampleSet exampleSet = (ExampleSet) exampleSetOriginal.clone();
    int numberOfAttributes = exampleSet.getAttributes().size();
    Attributes attributes = exampleSet.getAttributes();

    int maxNumberOfAttributes =
        Math.min(getParameterAsInt(PARAMETER_MAX_ATTRIBUTES), numberOfAttributes - 1);
    int maxNumberOfFails = getParameterAsInt(PARAMETER_ALLOWED_CONSECUTIVE_FAILS);
    int behavior = getParameterAsInt(PARAMETER_STOPPING_BEHAVIOR);

    boolean useRelativeIncrease =
        (behavior == WITH_DECREASE_EXCEEDS)
            ? getParameterAsBoolean(PARAMETER_USE_RELATIVE_DECREASE)
            : false;
    double maximalDecrease = 0;
    // The threshold is needed in both relative and absolute "decrease exceeds" mode. Guarding on
    // the behaviour (not on useRelativeIncrease) keeps the absolute branch reachable.
    if (behavior == WITH_DECREASE_EXCEEDS) {
      maximalDecrease =
          useRelativeIncrease
              ? getParameterAsDouble(PARAMETER_MAX_RELATIVE_DECREASE)
              : getParameterAsDouble(PARAMETER_MAX_ABSOLUT_DECREASE);
    }
    double alpha =
        (behavior == WITH_DECREASE_SIGNIFICANT) ? getParameterAsDouble(PARAMETER_ALPHA) : 0d;

    // remember the attributes so they can be switched off/on by index
    Attribute[] attributeArray = new Attribute[numberOfAttributes];
    int i = 0;
    for (Attribute attribute : attributes) {
      attributeArray[i] = attribute;
      i++;
    }

    boolean[] selected = new boolean[numberOfAttributes];
    Arrays.fill(selected, true);

    boolean earlyAbort = false;
    List<Integer> speculativeList = new ArrayList<Integer>(maxNumberOfFails);
    int numberOfFails = maxNumberOfFails;
    currentNumberOfFeatures = numberOfAttributes;
    currentAttributes = attributes;
    PerformanceVector lastPerformance = getPerformance(exampleSet);
    PerformanceVector bestPerformanceEver = lastPerformance;
    for (i = 0; i < maxNumberOfAttributes && !earlyAbort; i++) {
      // setting values for logging
      currentNumberOfFeatures = numberOfAttributes - i - 1;

      // performing a round: find the attribute whose removal performs best
      int bestIndex = 0;
      PerformanceVector currentBestPerformance = null;
      for (int current = 0; current < numberOfAttributes; current++) {
        if (selected[current]) {
          // switching off
          attributes.remove(attributeArray[current]);
          currentAttributes = attributes;

          // evaluate performance
          PerformanceVector performance = getPerformance(exampleSet);
          if (currentBestPerformance == null || performance.compareTo(currentBestPerformance) > 0) {
            bestIndex = current;
            currentBestPerformance = performance;
          }

          // switching on
          attributes.addRegular(attributeArray[current]);
          currentAttributes = null; // removing reference
        }
      }
      double currentFitness = currentBestPerformance.getMainCriterion().getFitness();
      if (i != 0) {
        double lastFitness = lastPerformance.getMainCriterion().getFitness();
        // switch stopping behavior
        switch (behavior) {
          case WITH_DECREASE:
            if (lastFitness >= currentFitness) {
              earlyAbort = true;
            }
            break;
          case WITH_DECREASE_EXCEEDS:
            if (useRelativeIncrease) {
              // relative decrease testing
              if (currentFitness < lastFitness - Math.abs(lastFitness * maximalDecrease)) {
                earlyAbort = true;
              }
            } else {
              // absolute decrease testing
              if (currentFitness < lastFitness - maximalDecrease) {
                earlyAbort = true;
              }
            }
            break;
          case WITH_DECREASE_SIGNIFICANT:
            AnovaCalculator calculator = new AnovaCalculator();
            calculator.setAlpha(alpha);

            PerformanceCriterion pc = currentBestPerformance.getMainCriterion();
            calculator.addGroup(pc.getAverageCount(), pc.getAverage(), pc.getVariance());
            pc = lastPerformance.getMainCriterion();
            calculator.addGroup(pc.getAverageCount(), pc.getAverage(), pc.getVariance());

            SignificanceTestResult result;
            try {
              result = calculator.performSignificanceTest();
            } catch (SignificanceCalculationException e) {
              throw new UserError(this, 920, e.getMessage());
            }
            if (lastFitness > currentFitness && result.getProbability() < alpha) {
              earlyAbort = true;
            }
            break;
          default:
            // unknown behaviour: never abort early (same as before)
            break;
        }
      }
      if (earlyAbort) {
        // check if there are some free tries left
        if (numberOfFails == 0) {
          break;
        }
        numberOfFails--;
        speculativeList.add(bestIndex);
        earlyAbort = false;

        // needs performance increase compared to better performance of current and last!
        if (currentBestPerformance.compareTo(lastPerformance) > 0) {
          lastPerformance = currentBestPerformance;
        }
      } else {
        // resetting maximal number of fails.
        numberOfFails = maxNumberOfFails;
        speculativeList.clear();
        lastPerformance = currentBestPerformance;
        bestPerformanceEver = currentBestPerformance;
      }

      // switching best index off
      attributes.remove(attributeArray[bestIndex]);
      selected[bestIndex] = false;
    }
    // re-add speculatively removed attributes: speculative execution did not yield a good result
    for (Integer removeIndex : speculativeList) {
      selected[removeIndex] = true;
      attributes.addRegular(attributeArray[removeIndex]);
    }

    AttributeWeights weights = new AttributeWeights();
    i = 0;
    for (Attribute attribute : attributeArray) {
      weights.setWeight(attribute.getName(), selected[i] ? 1d : 0d);
      i++;
    }

    exampleSetOutput.deliver(exampleSet);
    performanceOutput.deliver(bestPerformanceEver);
    weightsOutput.deliver(weights);
  }
 /**
  * Renders the given {@code AttributeWeights} object as a data table. The IO container is unused.
  */
 @Override
 public DataTable getDataTable(Object renderable, IOContainer ioContainer) {
   return ((AttributeWeights) renderable).createDataTable();
 }