/**
 * Scores a regression tree.
 *
 * <p>The tree is walked with a fresh {@code Trail}. When no node is selected, the default regression
 * result from {@code TargetUtil.evaluateRegressionDefault} is returned. Otherwise the node's score
 * is parsed or cast to a double, run through {@code TargetUtil.evaluateRegressionInternal} for the
 * target field, and wrapped together with the node via {@code createNodeScore}.
 *
 * @param context the evaluation context passed to tree evaluation and target post-processing
 * @return a singleton map from the target field's name to the node score, or the default result
 */
private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context) {
    Node selected = evaluateTree(new Trail(), context);

    // No matching node: fall back to the model-level default
    if (selected == null) {
      return TargetUtil.evaluateRegressionDefault(context);
    }

    Double rawScore = (Double) TypeUtil.parseOrCast(DataType.DOUBLE, selected.getScore());

    TargetField target = getTargetField();

    return Collections.singletonMap(
        target.getName(),
        createNodeScore(
            selected, TargetUtil.evaluateRegressionInternal(target, rawScore, context)));
  }
  /**
   * Scores a classification tree.
   *
   * <p>The tree is walked while recording a {@code Trail}. When no node is selected, the default
   * classification result from {@code TargetUtil.evaluateClassificationDefault} is returned.
   * Otherwise a penalty factor of 1 is used, unless the trail reports missing levels. In that case
   * the factor is the model's missing value penalty raised to the number of missing levels. The
   * factor is handed to {@code createNodeScoreDistribution}.
   *
   * @param context the evaluation context passed to tree evaluation and target post-processing
   * @return the classification produced by {@code TargetUtil.evaluateClassification}, or the
   *     default result
   */
  private Map<FieldName, ? extends Classification> evaluateClassification(
      ModelEvaluationContext context) {
    TreeModel model = getModel();

    Trail nodeTrail = new Trail();

    Node selected = evaluateTree(nodeTrail, context);
    if (selected == null) {
      return TargetUtil.evaluateClassificationDefault(context);
    }

    int missingLevels = nodeTrail.getMissingLevels();

    // The penalty compounds once per level at which a missing value was encountered
    double penalty = (missingLevels > 0) ? Math.pow(model.getMissingValuePenalty(), missingLevels) : 1d;

    NodeScoreDistribution distribution = createNodeScoreDistribution(selected, penalty);

    return TargetUtil.evaluateClassification(distribution, context);
  }
// Example no. 3 (score: 0)
  /**
   * Resolves the display value for a prediction.
   *
   * <p>Sources are consulted in this order:
   *
   * <ol>
   *   <li>the object's own display value, when it implements {@code HasDisplayValue}; this is
   *       returned as-is, even if {@code null};
   *   <li>the display value of the {@code TargetValue} that {@code TargetUtil.getTargetValue}
   *       finds in {@code target} for the predicted value;
   *   <li>for categorical and ordinal fields, the display value of the valid {@code Value} that
   *       {@code FieldValueUtil.getValidValue} finds in {@code dataField}.
   * </ol>
   *
   * <p>If none yields a non-null display value, the raw predicted value is returned.
   *
   * @param object the prediction as produced by the model
   * @param dataField the target data field; its op type selects the value lookup
   * @param target the target definition, or {@code null} to skip the target-value lookup
   * @return the display value, or the raw predicted value as a fallback
   * @throws UnsupportedFeatureException if the op type is not continuous, categorical or ordinal
   */
  private static Object getPredictedDisplayValue(
      Object object, DataField dataField, Target target) {

    if (object instanceof HasDisplayValue) {
      return TypeUtil.cast(HasDisplayValue.class, object).getDisplayValue();
    }

    Object predictedValue = getPredictedValue(object);

    if (target != null) {
      TargetValue targetValue = TargetUtil.getTargetValue(target, predictedValue);
      String targetDisplayValue = (targetValue != null) ? targetValue.getDisplayValue() : null;

      if (targetDisplayValue != null) {
        return targetDisplayValue;
      }
    }

    OpType opType = dataField.getOpType();
    switch (opType) {
      case CONTINUOUS:
        // No per-value display lookup for continuous fields
        break;
      case CATEGORICAL:
      case ORDINAL:
        {
          Value validValue = FieldValueUtil.getValidValue(dataField, predictedValue);
          String valueDisplayValue = (validValue != null) ? validValue.getDisplayValue() : null;

          if (valueDisplayValue != null) {
            return valueDisplayValue;
          }
        }
        break;
      default:
        throw new UnsupportedFeatureException(dataField, opType);
    }

    // "If the display value is not specified explicitly, then the raw predicted value is used by
    // default"
    return predictedValue;
  }
  /**
   * Scores a regression support vector machine model.
   *
   * <p>This method handles exactly one {@code SupportVectorMachine}. That machine is evaluated on
   * the input vector built from {@code context}, and the resulting value is passed through
   * {@code TargetUtil.evaluateRegression}.
   *
   * @param context the evaluation context supplying field values
   * @return the regression result after target post-processing
   * @throws InvalidFeatureException if the model does not contain exactly one machine
   */
  private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context) {
    SupportVectorMachineModel supportVectorMachineModel = getModel();

    List<SupportVectorMachine> machines = supportVectorMachineModel.getSupportVectorMachines();

    if (machines.size() == 1) {
      SupportVectorMachine onlyMachine = machines.get(0);

      Double prediction = evaluateSupportVectorMachine(onlyMachine, createInput(context));

      return TargetUtil.evaluateRegression(prediction, context);
    }

    throw new InvalidFeatureException(supportVectorMachineModel);
  }
  /**
   * Scores a classification support vector machine model.
   *
   * <p>Every {@code SupportVectorMachine} of the model is evaluated against the same input vector.
   * How the results are combined depends on the classification method:
   *
   * <ul>
   *   <li>one-against-all: each machine's value is stored under its target category in a
   *       distance-type {@code Classification};
   *   <li>one-against-one: each machine adds one vote to a {@code VoteDistribution}. If the model
   *       declares an alternate binary target category, the value is rounded: 1 votes for the
   *       machine's target category and 0 votes for the model's alternate binary target category.
   *       Otherwise, a value comparing below the threshold (the machine's own, falling back to the
   *       model's) votes for the target category, and any other value votes for the alternate
   *       target category.
   * </ul>
   *
   * @param context the evaluation context supplying field values
   * @return the classification result after target post-processing
   * @throws InvalidFeatureException if the model has no machines, or a machine's category
   *     attributes do not fit the classification method
   * @throws UnsupportedFeatureException if the classification method is neither one-against-all
   *     nor one-against-one
   * @throws EvaluationException if a binary one-against-one value rounds to neither 0 nor 1
   */
  private Map<FieldName, ? extends Classification> evaluateClassification(
      ModelEvaluationContext context) {
    SupportVectorMachineModel supportVectorMachineModel = getModel();

    List<SupportVectorMachine> machines = supportVectorMachineModel.getSupportVectorMachines();
    if (machines.isEmpty()) {
      throw new InvalidFeatureException(supportVectorMachineModel);
    }

    String alternateBinaryTargetCategory =
        supportVectorMachineModel.getAlternateBinaryTargetCategory();

    SvmClassificationMethodType classificationMethod = getClassificationMethod();

    Classification classification;
    switch (classificationMethod) {
      case ONE_AGAINST_ALL:
        classification = new Classification(Classification.Type.DISTANCE);
        break;
      case ONE_AGAINST_ONE:
        classification = new VoteDistribution();
        break;
      default:
        throw new UnsupportedFeatureException(supportVectorMachineModel, classificationMethod);
    }

    double[] input = createInput(context);

    for (SupportVectorMachine machine : machines) {
      String targetCategory = machine.getTargetCategory();
      String alternateTargetCategory = machine.getAlternateTargetCategory();

      // Evaluated before the attribute checks below, so evaluation errors surface first
      Double value = evaluateSupportVectorMachine(machine, input);

      switch (classificationMethod) {
        case ONE_AGAINST_ALL:
          if (targetCategory == null || alternateTargetCategory != null) {
            throw new InvalidFeatureException(machine);
          }

          classification.put(targetCategory, value);
          break;
        case ONE_AGAINST_ONE:
          {
            String label;

            if (alternateBinaryTargetCategory != null) {
              if (targetCategory == null || alternateTargetCategory != null) {
                throw new InvalidFeatureException(machine);
              }

              long roundedValue = Math.round(value);

              // "A rounded value of 1 corresponds to the targetCategory attribute of the
              // SupportVectorMachine element"
              // "A rounded value of 0 corresponds to the alternateBinaryTargetCategory attribute of
              // the SupportVectorMachineModel element"
              // "The numeric prediction must be between 0 and 1"
              if (roundedValue == 1) {
                label = targetCategory;
              } else if (roundedValue == 0) {
                label = alternateBinaryTargetCategory;
              } else {
                throw new EvaluationException("Invalid numeric prediction " + value);
              }
            } else {
              if (targetCategory == null || alternateTargetCategory == null) {
                throw new InvalidFeatureException(machine);
              }

              // A machine-level threshold takes precedence over the model-level one
              Double threshold = machine.getThreshold();
              if (threshold == null) {
                threshold = supportVectorMachineModel.getThreshold();
              }

              // "If the numeric prediction is smaller than the threshold, it corresponds to the
              // targetCategory attribute"
              label = (value.compareTo(threshold) < 0) ? targetCategory : alternateTargetCategory;
            }

            // Each machine adds exactly one vote to its chosen label
            Double votes = classification.get(label);
            classification.put(label, (votes != null ? votes : 0d) + 1d);
          }
          break;
        default:
          // Unreachable: other methods were rejected before the loop
          break;
      }
    }

    return TargetUtil.evaluateClassification(classification, context);
  }
// Example no. 6 (score: 0)
  /**
   * Scores a rule set model.
   *
   * <p>All rules are passed to {@code collectFiredRules}, which gathers the fired
   * {@code SimpleRule}s into a multimap keyed by string. Presumably the key is the rule's predicted
   * score — confirm against {@code collectFiredRules}. The first rule selection method's criterion
   * then decides what is recorded for each key:
   *
   * <ul>
   *   <li>FIRST_HIT: the key's first fired rule's confidence. The very first rule of the first key
   *       also becomes the result's entity;
   *   <li>WEIGHTED_SUM: the key's heaviest rule, with the key's summed weight divided by the total
   *       number of fired rules across all keys;
   *   <li>WEIGHTED_MAX: the key's heaviest rule, with that rule's confidence.
   * </ul>
   *
   * @param context the evaluation context used for rule firing and target post-processing
   * @return the classification produced by {@code TargetUtil.evaluateClassification}
   * @throws InvalidFeatureException if the rule set declares no rule selection method
   * @throws UnsupportedFeatureException if at least one rule fired and the criterion is unsupported
   */
  private Map<FieldName, ? extends ClassificationMap<?>> evaluateRuleSet(
      ModelManagerEvaluationContext context) {
    RuleSetModel ruleSetModel = getModel();

    RuleSet ruleSet = ruleSetModel.getRuleSet();

    List<RuleSelectionMethod> ruleSelectionMethods = ruleSet.getRuleSelectionMethods();
    if (ruleSelectionMethods.isEmpty()) {
      throw new InvalidFeatureException(ruleSet);
    }

    // "If more than one method is included, the first method is used as the default method for
    // scoring"
    RuleSelectionMethod ruleSelectionMethod = ruleSelectionMethods.get(0);

    // Both the ordering of keys and values is significant
    ListMultimap<String, SimpleRule> firedRules = LinkedListMultimap.create();
    for (Rule rule : ruleSet.getRules()) {
      collectFiredRules(firedRules, rule, context);
    }

    RuleClassificationMap result = new RuleClassificationMap();

    RuleSelectionMethod.Criterion criterion = ruleSelectionMethod.getCriterion();

    // Multimap size counts key-value pairs, i.e. every fired rule regardless of key
    int firedRuleCount = firedRules.size();

    for (String key : firedRules.keySet()) {
      List<SimpleRule> keyRules = firedRules.get(key);

      switch (criterion) {
        case FIRST_HIT:
          {
            SimpleRule first = keyRules.get(0);

            // The first value of the first key
            if (result.getEntity() == null) {
              result.setEntity(first);
            }

            result.put(key, first.getConfidence());
          }
          break;
        case WEIGHTED_SUM:
          {
            double totalWeight = 0;
            for (SimpleRule keyRule : keyRules) {
              totalWeight += keyRule.getWeight();
            }

            result.put(heaviestRule(keyRules), key, totalWeight / firedRuleCount);
          }
          break;
        case WEIGHTED_MAX:
          {
            SimpleRule heaviest = heaviestRule(keyRules);

            result.put(heaviest, key, heaviest.getConfidence());
          }
          break;
        default:
          throw new UnsupportedFeatureException(ruleSelectionMethod, criterion);
      }
    }

    return TargetUtil.evaluateClassification(result, context);
  }

  /**
   * Returns the rule with the greatest weight. On ties the earliest rule wins, because the
   * comparison is strict. Returns {@code null} when {@code rules} is empty.
   */
  private static SimpleRule heaviestRule(List<SimpleRule> rules) {
    SimpleRule heaviest = null;
    for (SimpleRule candidate : rules) {
      if (heaviest == null || heaviest.getWeight() < candidate.getWeight()) {
        heaviest = candidate;
      }
    }
    return heaviest;
  }