예제 #1
0
  /**
   * Invokes a user-defined function.
   *
   * <p>Each argument is refined against its declared parameter field and declared in a new
   * function-scoped context built from the calling context. The function's expression is then
   * evaluated in that scope. The result is refined to the function's declared data type and
   * operational type.
   *
   * @param defineFunction the function definition.
   * @param values the argument values, matched to the parameter fields by position.
   * @param context the calling evaluation context.
   * @return the refined result of evaluating the function expression.
   * @throws InvalidFeatureException if the function declares no parameters or has no expression.
   * @throws EvaluationException if the number of arguments differs from the number of parameters.
   */
  public static FieldValue evaluate(
      DefineFunction defineFunction, List<FieldValue> values, EvaluationContext context) {
    List<ParameterField> parameters = defineFunction.getParameterFields();

    // A function definition must declare at least one parameter
    if (parameters.isEmpty()) {
      throw new InvalidFeatureException(defineFunction);
    }

    int arity = parameters.size();
    if (values.size() != arity) {
      throw new EvaluationException();
    }

    FunctionEvaluationContext scope = new FunctionEvaluationContext(context);

    int position = 0;
    for (ParameterField parameter : parameters) {
      FieldValue argument = values.get(position++);

      scope.declare(parameter.getName(), FieldValueUtil.refine(parameter, argument));
    }

    Expression body = defineFunction.getExpression();
    if (body == null) {
      throw new InvalidFeatureException(defineFunction);
    }

    FieldValue rawResult = ExpressionUtil.evaluate(body, scope);

    return FieldValueUtil.refine(
        defineFunction.getDataType(), defineFunction.getOptype(), rawResult);
  }
예제 #2
0
    /**
     * Feeds every non-missing argument into a fresh statistic and returns its result.
     *
     * <p>The result is cast to the result type derived from the combined data type of all
     * non-missing arguments.
     *
     * @param values the arguments; {@code null} entries are skipped.
     * @return the aggregated value.
     * @throws MissingResultException if every argument was missing.
     */
    @Override
    public FieldValue evaluate(List<FieldValue> values) {
      StorelessUnivariateStatistic statistic = createStatistic();

      DataType combinedType = null;

      for (FieldValue value : values) {
        // "Missing values in the input to an aggregate function are simply ignored"
        if (value != null) {
          statistic.increment(value.asNumber().doubleValue());

          combinedType =
              (combinedType == null)
                  ? value.getDataType()
                  : TypeUtil.getResultDataType(combinedType, value.getDataType());
        }
      }

      boolean nothingAggregated = (statistic.getN() == 0);
      if (nothingAggregated) {
        throw new MissingResultException(null);
      }

      Object aggregate = cast(getResultType(combinedType), statistic.getResult());

      return FieldValueUtil.create(aggregate);
    }
예제 #3
0
    /**
     * Applies this binary arithmetic operation to exactly two arguments.
     *
     * @param values the two operands.
     * @return the result cast to the combined data type of the operands, or {@code null} if
     *     either operand is missing.
     * @throws EvaluationException if the number of arguments is not two.
     * @throws InvalidResultException if the operation raises an {@link ArithmeticException}; the
     *     original exception is attached as the cause.
     */
    @Override
    public FieldValue evaluate(List<FieldValue> values) {

      if (values.size() != 2) {
        throw new EvaluationException();
      }

      FieldValue left = values.get(0);
      FieldValue right = values.get(1);

      // "If one of the input fields of a simple arithmetic function is a missing value, the result
      // evaluates to missing value"
      if (left == null || right == null) {
        return null;
      }

      DataType dataType = TypeUtil.getResultDataType(left.getDataType(), right.getDataType());

      Number result;

      try {
        result = evaluate(left.asNumber(), right.asNumber());
      } catch (ArithmeticException ae) {
        // Translate the arithmetic failure (e.g. integer division by zero), but keep the
        // original exception as the cause so the failure can still be diagnosed
        InvalidResultException ire = new InvalidResultException(null);
        ire.initCause(ae);

        throw ire;
      }

      return FieldValueUtil.create(cast(dataType, result));
    }
예제 #4
0
    /**
     * Evaluates the first argument against the remaining arguments.
     *
     * @param values at least two arguments, as checked by {@code checkVariableArguments}.
     * @return the boolean outcome wrapped as a field value.
     */
    @Override
    public FieldValue evaluate(List<FieldValue> values) {
      checkVariableArguments(values, 2);

      FieldValue head = values.get(0);
      List<FieldValue> tail = values.subList(1, values.size());

      Boolean outcome = evaluate(head, tail);

      return FieldValueUtil.create(outcome);
    }
예제 #5
0
  /**
   * Computes the residual of a categorical prediction.
   *
   * <p>The residual is the indicator "predicted equals expected" (1 or 0) minus the probability
   * the model assigned to its predicted value. Values are compared as formatted strings.
   *
   * @param object the model result; must be castable to {@code HasProbability}.
   * @param expectedObject the expected (actual) target value.
   * @return the residual.
   */
  public static Double getCategoricalResidual(Object object, FieldValue expectedObject) {
    HasProbability probabilities = TypeUtil.cast(HasProbability.class, object);

    String predicted = TypeUtil.format(getPredictedValue(object));
    String expected = TypeUtil.format(FieldValueUtil.getValue(expectedObject));

    double indicator = TypeUtil.equals(DataType.STRING, predicted, expected) ? 1d : 0d;

    return Double.valueOf(indicator - probabilities.getProbability(predicted));
  }
예제 #6
0
    /**
     * Applies this string function to its single argument.
     *
     * @param values exactly one argument, as checked by {@code checkArguments}.
     * @return the transformed string wrapped as a field value.
     */
    @Override
    public FieldValue evaluate(List<FieldValue> values) {
      checkArguments(values, 1);

      String input = values.get(0).asString();
      String transformed = evaluate(input);

      return FieldValueUtil.create(transformed);
    }
예제 #7
0
    /**
     * Applies this unary predicate to its single argument.
     *
     * <p>The third {@code checkArguments} argument is {@code true} here, unlike sibling functions.
     * Presumably this permits a missing argument; confirm against {@code checkArguments}.
     *
     * @param values exactly one argument.
     * @return the boolean outcome wrapped as a field value.
     */
    @Override
    public FieldValue evaluate(List<FieldValue> values) {
      checkArguments(values, 1, true);

      Boolean outcome = evaluate(values.get(0));

      return FieldValueUtil.create(outcome);
    }
예제 #8
0
    /**
     * Applies this unary math function to its single numeric argument.
     *
     * @param values exactly one argument, as checked by {@code checkArguments}.
     * @return the result cast to the result type derived from the argument's data type.
     */
    @Override
    public FieldValue evaluate(List<FieldValue> values) {
      checkArguments(values, 1);

      FieldValue argument = values.get(0);
      Number computed =
          cast(getResultType(argument.getDataType()), evaluate(argument.asNumber()));

      return FieldValueUtil.create(computed);
    }
예제 #9
0
    /**
     * Compares the first argument to the second and maps the comparison outcome to a boolean.
     *
     * @param values exactly two arguments, as checked by {@code checkArguments}.
     * @return the boolean outcome wrapped as a field value.
     */
    @Override
    public FieldValue evaluate(List<FieldValue> values) {
      checkArguments(values, 2);

      FieldValue first = values.get(0);
      FieldValue second = values.get(1);

      Boolean outcome = evaluate(first.compareToValue(second));

      return FieldValueUtil.create(outcome);
    }
예제 #10
0
    /**
     * Folds this binary boolean operator over all arguments, from left to right.
     *
     * @param values at least two arguments, as checked by {@code checkVariableArguments}.
     * @return the folded boolean wrapped as a field value.
     */
    @Override
    public FieldValue evaluate(List<FieldValue> values) {
      checkVariableArguments(values, 2);

      Boolean accumulator = values.get(0).asBoolean();

      for (FieldValue operand : values.subList(1, values.size())) {
        accumulator = evaluate(accumulator, operand.asBoolean());
      }

      return FieldValueUtil.create(accumulator);
    }
예제 #11
0
  /**
   * Resolves the display value of a prediction.
   *
   * <p>Sources are consulted in this order:
   *
   * <ol>
   *   <li>the result's own display value, when the result implements {@code HasDisplayValue};
   *   <li>the display value of the matching {@code TargetValue}, when a target is given;
   *   <li>for categorical/ordinal fields, the display value of the matching valid {@code Value};
   *   <li>otherwise the raw predicted value.
   * </ol>
   *
   * @param object the model result.
   * @param dataField the data field of the target.
   * @param target the target element, or {@code null}.
   * @return the display value, or the raw predicted value if none is specified.
   * @throws UnsupportedFeatureException if the data field's operational type is not continuous,
   *     categorical or ordinal.
   */
  private static Object getPredictedDisplayValue(
      Object object, DataField dataField, Target target) {

    if (object instanceof HasDisplayValue) {
      return TypeUtil.cast(HasDisplayValue.class, object).getDisplayValue();
    }

    Object predictedValue = getPredictedValue(object);

    TargetValue targetValue =
        (target != null) ? TargetUtil.getTargetValue(target, predictedValue) : null;
    String targetDisplayValue = (targetValue != null) ? targetValue.getDisplayValue() : null;
    if (targetDisplayValue != null) {
      return targetDisplayValue;
    }

    // "If the display value is not specified explicitly, then the raw predicted value is used by
    // default"
    OpType opType = dataField.getOpType();
    switch (opType) {
      case CONTINUOUS:
        return predictedValue;
      case CATEGORICAL:
      case ORDINAL:
        Value validValue = FieldValueUtil.getValidValue(dataField, predictedValue);
        String valueDisplayValue = (validValue != null) ? validValue.getDisplayValue() : null;

        return (valueDisplayValue != null) ? valueDisplayValue : predictedValue;
      default:
        throw new UnsupportedFeatureException(dataField, opType);
    }
  }
예제 #12
0
  /**
   * Post-processes a regression result for a target field.
   *
   * <p>For the default target, the value is only cast to the data type of the model's data field.
   *
   * <p>For a named target with a {@link Target} element, a missing value is first replaced by
   * {@code getDefaultValue(target)}. A non-missing value is then passed through
   * {@code processValue}. The value is cast to the data type of the target's data field and
   * declared in the evaluation context as the target value.
   *
   * @param name the target field name, possibly {@code Evaluator.DEFAULT_TARGET}.
   * @param value the raw regression result, or {@code null} if missing.
   * @param context the evaluation context of the model.
   * @return the post-processed value, or {@code null} if still missing.
   * @throws MissingFieldException if a named target has no data field.
   */
  static Object evaluateRegressionInternal(
      FieldName name, Object value, ModelEvaluationContext context) {
    ModelEvaluator<?> evaluator = context.getModelEvaluator();

    if (Objects.equals(Evaluator.DEFAULT_TARGET, name)) {
      DataField dataField = evaluator.getDataField();

      if (value != null) {
        value = TypeUtil.cast(dataField.getDataType(), value);
      }
    } else {
      Target target = evaluator.getTarget(name);
      if (target != null) {

        if (value == null) {
          value = getDefaultValue(target);
        } // End if

        if (value != null) {
          // Accept any numeric result (e.g. Integer or Float), not only Double; a plain (Double)
          // cast would throw ClassCastException for other Number subtypes
          value = processValue(target, ((Number) value).doubleValue());
        }
      }

      DataField dataField = evaluator.getDataField(name);
      if (dataField == null) {
        throw new MissingFieldException(name);
      } // End if

      if (value != null) {
        value = TypeUtil.cast(dataField.getDataType(), value);
      }

      MiningField miningField = evaluator.getMiningField(name);

      context.declare(
          name, FieldValueUtil.createTargetValue(dataField, miningField, target, value));
    }

    return value;
  }
예제 #13
0
  /**
   * Post-processes a classification result for a target field.
   *
   * <p>For the default target, the result is only computed against the data type of the model's
   * data field.
   *
   * <p>For a named target with a {@link Target} element, a missing result is replaced by
   * {@code getPriorProbabilities(target)}. The result is computed against the data type of the
   * target's data field. Its value is then declared in the evaluation context as the target
   * value.
   *
   * @param name the target field name, possibly {@code Evaluator.DEFAULT_TARGET}.
   * @param value the classification result, or {@code null} if missing.
   * @param context the evaluation context of the model.
   * @return the (possibly substituted) classification, or {@code null} if still missing.
   * @throws MissingFieldException if a named target has no data field.
   */
  static Classification evaluateClassificationInternal(
      FieldName name, Classification value, ModelEvaluationContext context) {
    ModelEvaluator<?> modelEvaluator = context.getModelEvaluator();

    if (!Objects.equals(Evaluator.DEFAULT_TARGET, name)) {
      Target target = modelEvaluator.getTarget(name);

      // Fall back to the target's prior probabilities when the model produced no result
      if (target != null && value == null) {
        value = getPriorProbabilities(target);
      }

      DataField targetField = modelEvaluator.getDataField(name);
      if (targetField == null) {
        throw new MissingFieldException(name);
      }

      if (value != null) {
        value.computeResult(targetField.getDataType());
      }

      MiningField targetMiningField = modelEvaluator.getMiningField(name);

      context.declare(
          name,
          FieldValueUtil.createTargetValue(
              targetField, targetMiningField, target, (value == null) ? null : value.getResult()));
    } else {
      DataField targetField = modelEvaluator.getDataField();

      if (value != null) {
        value.computeResult(targetField.getDataType());
      }
    }

    return value;
  }
예제 #14
0
  /**
   * Evaluates the {@link Output} element.
   *
   * <p>Output fields are processed in declaration order. Each computed value is declared in the
   * context so that later output fields can refer to it. An output field whose target value (or
   * referenced segment result) is missing is skipped.
   *
   * @param predictions A map of {@link Evaluator#getTargetFields() target field} values.
   * @param context The evaluation context of the model that produced {@code predictions}.
   * @return A map of {@link Evaluator#getTargetFields() target field} values together with {@link
   *     Evaluator#getOutputFields() output field} values. If the model has no {@link Output}
   *     element, {@code predictions} is returned as-is.
   * @throws MissingValueException if a required target value is not available.
   * @throws MissingFieldException if the target field of a display value or residual output field
   *     has no data field.
   */
  @SuppressWarnings(value = {"fallthrough"})
  public static Map<FieldName, ?> evaluate(
      Map<FieldName, ?> predictions, ModelEvaluationContext context) {
    ModelEvaluator<?> modelEvaluator = context.getModelEvaluator();

    Model model = modelEvaluator.getModel();

    Output output = model.getOutput();
    if (output == null) {
      return predictions;
    }

    Map<FieldName, Object> result = new LinkedHashMap<>(predictions);

    List<OutputField> outputFields = output.getOutputFields();

    outputFields:
    for (OutputField outputField : outputFields) {
      FieldName targetFieldName = outputField.getTargetField();

      Object targetValue = null;

      ResultFeature resultFeature = outputField.getResultFeature();

      String segmentId = outputField.getSegmentId();

      SegmentResult segmentPredictions = null;

      // Load the target value of the specified segment
      if (segmentId != null) {

        if (!(model instanceof MiningModel)) {
          throw new InvalidFeatureException(outputField);
        }

        MiningModelEvaluationContext miningModelContext = (MiningModelEvaluationContext) context;

        segmentPredictions = miningModelContext.getResult(segmentId);

        // "If there is no Segment matching segmentId or if the predicate of the matching Segment
        // evaluated to false, then the result delivered by this OutputField is missing"
        if (segmentPredictions == null) {
          continue outputFields;
        } // End if

        if (targetFieldName != null) {

          if (!segmentPredictions.containsKey(targetFieldName)) {
            throw new MissingValueException(targetFieldName, outputField);
          }

          targetValue = segmentPredictions.get(targetFieldName);
        } else {
          targetValue = segmentPredictions.getTargetValue();
        }
      } else

      // Load the target value
      {
        switch (resultFeature) {
          case ENTITY_ID:
            {
              // "Result feature entityId returns the id of the winning segment"
              if (model instanceof MiningModel) {
                targetValue = TypeUtil.cast(HasEntityId.class, predictions);

                break;
              }
            }
            // Falls through
          default:
            {
              if (targetFieldName == null) {
                targetFieldName = modelEvaluator.getTargetFieldName();
              } // End if

              if (!predictions.containsKey(targetFieldName)) {
                throw new MissingValueException(targetFieldName, outputField);
              }

              targetValue = predictions.get(targetFieldName);
            }
            break;
        }
      }

      // "If the target value is missing, then the result delivered by this OutputField is missing"
      if (targetValue == null) {
        continue outputFields;
      }

      Object value;

      // Perform the requested computation on the target value
      switch (resultFeature) {
        case PREDICTED_VALUE:
          {
            value = getPredictedValue(targetValue);
          }
          break;
        case PREDICTED_DISPLAY_VALUE:
          {
            DataField dataField = modelEvaluator.getDataField(targetFieldName);
            if (dataField == null) {
              throw new MissingFieldException(targetFieldName, outputField);
            }

            Target target = modelEvaluator.getTarget(targetFieldName);

            value = getPredictedDisplayValue(targetValue, dataField, target);
          }
          break;
        case TRANSFORMED_VALUE:
        case DECISION:
          {
            if (segmentId != null) {
              String name = outputField.getValue();
              if (name == null) {
                throw new InvalidFeatureException(outputField);
              }

              Expression expression = outputField.getExpression();
              if (expression != null) {
                throw new InvalidFeatureException(outputField);
              }

              value = segmentPredictions.get(FieldName.create(name));

              break;
            }

            Expression expression = outputField.getExpression();
            if (expression == null) {
              throw new InvalidFeatureException(outputField);
            }

            value = FieldValueUtil.getValue(ExpressionUtil.evaluate(expression, context));
          }
          break;
        case PROBABILITY:
          {
            value = getProbability(targetValue, outputField);
          }
          break;
        case RESIDUAL:
          {
            FieldValue expectedTargetValue = context.evaluate(targetFieldName);
            if (expectedTargetValue == null) {
              throw new MissingValueException(targetFieldName, outputField);
            }

            DataField dataField = modelEvaluator.getDataField(targetFieldName);
            // Fail with a descriptive error (consistent with PREDICTED_DISPLAY_VALUE) instead of
            // a NullPointerException on getOpType()
            if (dataField == null) {
              throw new MissingFieldException(targetFieldName, outputField);
            }

            OpType opType = dataField.getOpType();
            switch (opType) {
              case CONTINUOUS:
                value = getContinuousResidual(targetValue, expectedTargetValue);
                break;
              case CATEGORICAL:
                value = getCategoricalResidual(targetValue, expectedTargetValue);
                break;
              default:
                throw new UnsupportedFeatureException(dataField, opType);
            }
          }
          break;
        case CLUSTER_ID:
          {
            value = getClusterId(targetValue);
          }
          break;
        case ENTITY_ID:
          {
            if (targetValue instanceof HasRuleValues) {
              value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.RULE_ID);

              break;
            }

            value = getEntityId(targetValue, outputField);
          }
          break;
        case AFFINITY:
          {
            value = getAffinity(targetValue, outputField);
          }
          break;
        case CLUSTER_AFFINITY:
        case ENTITY_AFFINITY:
          {
            String entityId = outputField.getValue();

            // Select the specified entity instead of the winning entity
            if (entityId != null) {
              value = getAffinity(targetValue, outputField);

              break;
            }

            value = getEntityAffinity(targetValue);
          }
          break;
        case REASON_CODE:
          {
            value = getReasonCode(targetValue, outputField);
          }
          break;
        case RULE_VALUE:
          {
            value = getRuleValue(targetValue, outputField);
          }
          break;
        case ANTECEDENT:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.ANTECEDENT);
          }
          break;
        case CONSEQUENT:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.CONSEQUENT);
          }
          break;
        case RULE:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.RULE);
          }
          break;
        case RULE_ID:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.RULE_ID);
          }
          break;
        case CONFIDENCE:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.CONFIDENCE);
          }
          break;
        case SUPPORT:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.SUPPORT);
          }
          break;
        case LIFT:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.LIFT);
          }
          break;
        case LEVERAGE:
          {
            value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.LEVERAGE);
          }
          break;
        case WARNING:
          {
            value = context.getWarnings();
          }
          break;
        default:
          throw new UnsupportedFeatureException(outputField, resultFeature);
      }

      FieldValue outputValue = FieldValueUtil.create(outputField, value);

      // The result of one output field becomes available to other output fields
      context.declare(outputField.getName(), outputValue);

      result.put(outputField.getName(), FieldValueUtil.getValue(outputValue));
    }

    return result;
  }
예제 #15
0
  /**
   * Computes the residual of a continuous prediction as {@code expected - predicted}.
   *
   * @param object the model result, whose predicted value must be a {@link Number}.
   * @param expectedObject the expected (actual) target value, which must hold a {@link Number}.
   * @return the residual.
   */
  private static Double getContinuousResidual(Object object, FieldValue expectedObject) {
    Number predicted = (Number) getPredictedValue(object);
    Number actual = (Number) FieldValueUtil.getValue(expectedObject);

    double residual = actual.doubleValue() - predicted.doubleValue();

    return residual;
  }