@Override
  /**
   * Orders two entries by the string value of their keys.
   *
   * <p>The comparison is case-sensitive or case-insensitive depending on what {@code
   * isCaseSensitive()} reports at call time.
   *
   * @param left the first entry
   * @param right the second entry
   * @return a negative, zero or positive number, as produced by {@code compareTo} or {@code
   *     compareToIgnoreCase} on the two key values
   */
  public int compare(E left, E right) {
    FieldName first = left.getKey();
    FieldName second = right.getKey();

    // Case-insensitive ordering is handled up front; the exact comparison is the remaining path
    if (!isCaseSensitive()) {
      return first.getValue().compareToIgnoreCase(second.getValue());
    }

    return first.getValue().compareTo(second.getValue());
  }
Exemplo n.º 2
0
  /**
   * Evaluates the {@link Output} element.
   *
   * <p>If the model has no {@link Output} element, the {@code predictions} map is returned as-is.
   * Otherwise the predictions are copied into a new {@link LinkedHashMap}, and the output fields
   * are processed in the order of {@link Output#getOutputFields()}. Every computed output value
   * is declared in the {@code context} before the next output field is processed, and is added to
   * the result map under the output field's name. An output field whose segment result or target
   * value is missing is skipped, so no entry is added for it.
   *
   * @param predictions A map of {@link Evaluator#getTargetFields() target field} values.
   * @param context The evaluation context. It supplies the model evaluator, is used for evaluating
   *     expressions and residual target values, and receives the computed output values. It is
   *     cast to {@link MiningModelEvaluationContext} when an output field specifies a segment
   *     identifier.
   * @return A map of {@link Evaluator#getTargetFields() target field} values together with {@link
   *     Evaluator#getOutputFields() output field} values.
   * @throws InvalidFeatureException If an output field specifies a segment identifier but the model
   *     is not a {@link MiningModel}, or if a transformed value/decision output field has an
   *     invalid combination of value and expression.
   * @throws MissingValueException If the required target field is absent from the predictions (or
   *     from the segment predictions), or if the expected target value of a residual is missing.
   * @throws MissingFieldException If the data field of the target field cannot be found.
   * @throws UnsupportedFeatureException If the result feature, or the operational type of a
   *     residual target field, is not supported.
   */
  @SuppressWarnings(value = {"fallthrough"})
  public static Map<FieldName, ?> evaluate(
      Map<FieldName, ?> predictions, ModelEvaluationContext context) {
    ModelEvaluator<?> modelEvaluator = context.getModelEvaluator();

    Model model = modelEvaluator.getModel();

    Output output = model.getOutput();
    if (output == null) {
      return predictions;
    }

    // Work on a copy, so that the caller's map is never modified
    Map<FieldName, Object> result = new LinkedHashMap<>(predictions);

    List<OutputField> outputFields = output.getOutputFields();

    outputFields:
    for (OutputField outputField : outputFields) {
      FieldName targetFieldName = outputField.getTargetField();

      Object targetValue = null;

      ResultFeature resultFeature = outputField.getResultFeature();

      String segmentId = outputField.getSegmentId();

      SegmentResult segmentPredictions = null;

      if (segmentId != null) {
        // Load the target value of the specified segment

        if (!(model instanceof MiningModel)) {
          throw new InvalidFeatureException(outputField);
        }

        MiningModelEvaluationContext miningModelContext = (MiningModelEvaluationContext) context;

        segmentPredictions = miningModelContext.getResult(segmentId);

        // "If there is no Segment matching segmentId or if the predicate of the matching Segment
        // evaluated to false, then the result delivered by this OutputField is missing"
        if (segmentPredictions == null) {
          continue outputFields;
        } // End if

        if (targetFieldName != null) {

          if (!segmentPredictions.containsKey(targetFieldName)) {
            throw new MissingValueException(targetFieldName, outputField);
          }

          targetValue = segmentPredictions.get(targetFieldName);
        } else {
          // NOTE(review): targetFieldName stays null on this path; the PREDICTED_DISPLAY_VALUE and
          // RESIDUAL cases below then look up the data field with a null name - confirm whether it
          // should default to modelEvaluator.getTargetFieldName() as in the non-segment path.
          targetValue = segmentPredictions.getTargetValue();
        }
      } else {
        // Load the target value
        switch (resultFeature) {
          case ENTITY_ID:
            {
              // "Result feature entityId returns the id of the winning segment"
              if (model instanceof MiningModel) {
                // The predictions map itself is used as the target value
                targetValue = TypeUtil.cast(HasEntityId.class, predictions);

                break;
              }
            }
            // Falls through: other model types use the regular target value lookup
          default:
            {
              if (targetFieldName == null) {
                targetFieldName = modelEvaluator.getTargetFieldName();
              } // End if

              if (!predictions.containsKey(targetFieldName)) {
                throw new MissingValueException(targetFieldName, outputField);
              }

              targetValue = predictions.get(targetFieldName);
            }
            break;
        }
      }

      // "If the target value is missing, then the result delivered by this OutputField is missing"
      if (targetValue == null) {
        continue outputFields;
      }

      Object value;

      // Perform the requested computation on the target value
      switch (resultFeature) {
        case PREDICTED_VALUE:
          {
            value = getPredictedValue(targetValue);
          }
          break;
        case PREDICTED_DISPLAY_VALUE:
          {
            DataField dataField = modelEvaluator.getDataField(targetFieldName);
            if (dataField == null) {
              throw new MissingFieldException(targetFieldName, outputField);
            }

            Target target = modelEvaluator.getTarget(targetFieldName);

            value = getPredictedDisplayValue(targetValue, dataField, target);
          }
          break;
        case TRANSFORMED_VALUE:
        case DECISION:
          {
            // With a segment identifier, the value is read from the segment predictions by name;
            // a name is then mandatory and an expression is not allowed
            if (segmentId != null) {
              String name = outputField.getValue();
              if (name == null) {
                throw new InvalidFeatureException(outputField);
              }

              Expression expression = outputField.getExpression();
              if (expression != null) {
                throw new InvalidFeatureException(outputField);
              }

              value = segmentPredictions.get(FieldName.create(name));

              break;
            }

            // Without a segment identifier, the value is computed by the mandatory expression
            Expression expression = outputField.getExpression();
            if (expression == null) {
              throw new InvalidFeatureException(outputField);
            }

            value = FieldValueUtil.getValue(ExpressionUtil.evaluate(expression, context));
          }
          break;
        case PROBABILITY:
          {
            value = getProbability(targetValue, outputField);
          }
          break;
        case RESIDUAL:
          {
            FieldValue expectedTargetValue = context.evaluate(targetFieldName);
            if (expectedTargetValue == null) {
              throw new MissingValueException(targetFieldName, outputField);
            }

            // Same guard as for PREDICTED_DISPLAY_VALUE: fail with a descriptive exception
            // instead of a NullPointerException when the data field cannot be found
            DataField dataField = modelEvaluator.getDataField(targetFieldName);
            if (dataField == null) {
              throw new MissingFieldException(targetFieldName, outputField);
            }

            OpType opType = dataField.getOpType();
            switch (opType) {
              case CONTINUOUS:
                value = getContinuousResidual(targetValue, expectedTargetValue);
                break;
              case CATEGORICAL:
                value = getCategoricalResidual(targetValue, expectedTargetValue);
                break;
              default:
                throw new UnsupportedFeatureException(dataField, opType);
            }
          }
          break;
        case CLUSTER_ID:
          {
            value = getClusterId(targetValue);
          }
          break;
        case ENTITY_ID:
          {
            // Target values that carry rule values report the rule identifier
            if (targetValue instanceof HasRuleValues) {
              value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.RULE_ID);

              break;
            }

            value = getEntityId(targetValue, outputField);
          }
          break;
        case AFFINITY:
          {
            value = getAffinity(targetValue, outputField);
          }
          break;
        case CLUSTER_AFFINITY:
        case ENTITY_AFFINITY:
          {
            String entityId = outputField.getValue();

            // Select the specified entity instead of the winning entity
            if (entityId != null) {
              value = getAffinity(targetValue, outputField);

              break;
            }

            value = getEntityAffinity(targetValue);
          }
          break;
        case REASON_CODE:
          {
            value = getReasonCode(targetValue, outputField);
          }
          break;
        case RULE_VALUE:
          {
            value = getRuleValue(targetValue, outputField);
          }
          break;
        case ANTECEDENT:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.ANTECEDENT);
          }
          break;
        case CONSEQUENT:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.CONSEQUENT);
          }
          break;
        case RULE:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.RULE);
          }
          break;
        case RULE_ID:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.RULE_ID);
          }
          break;
        case CONFIDENCE:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.CONFIDENCE);
          }
          break;
        case SUPPORT:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.SUPPORT);
          }
          break;
        case LIFT:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.LIFT);
          }
          break;
        case LEVERAGE:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.LEVERAGE);
          }
          break;
        case WARNING:
          {
            value = context.getWarnings();
          }
          break;
        default:
          throw new UnsupportedFeatureException(outputField, resultFeature);
      }

      FieldValue outputValue = FieldValueUtil.create(outputField, value);

      // The result of one output field becomes available to other output fields
      context.declare(outputField.getName(), outputValue);

      result.put(outputField.getName(), FieldValueUtil.getValue(outputValue));
    }

    return result;
  }