/**
 * Returns the affinity value selected by an output field.
 *
 * <p>For rank 1 this is the affinity of the category named by the output field's {@code value}
 * attribute (or of the predicted category when that attribute is absent). For higher ranks the
 * element at that rank of the affinity ranking is returned; only descending rank order is
 * supported.
 *
 * @param object the model result; must be a {@code HasAffinity} (and a {@code HasAffinityRanking}
 *     when {@code rank > 1})
 * @param outputField the output field that requests the affinity
 * @return the selected affinity
 * @throws InvalidFeatureException if the rank is not positive
 * @throws UnsupportedFeatureException if a rank above 1 is requested with a non-descending order
 */
public static Double getAffinity(Object object, OutputField outputField) {
  // Validate the object type before looking at the rank, so incompatible objects always fail here
  HasAffinity hasAffinity = TypeUtil.cast(HasAffinity.class, object);

  int rank = outputField.getRank();
  if (rank <= 0) {
    throw new InvalidFeatureException(outputField);
  }

  if (rank == 1) {
    String category = getCategoryValue(object, outputField);
    return hasAffinity.getAffinity(category);
  }

  HasAffinityRanking ranking = TypeUtil.cast(HasAffinityRanking.class, object);

  OutputField.RankOrder order = outputField.getRankOrder();
  switch (order) {
    case DESCENDING:
      return getElement(ranking.getAffinityRanking(), rank);
    default:
      throw new UnsupportedFeatureException(outputField, order);
  }
}
/**
 * Computes the residual of a categorical prediction.
 *
 * <p>The residual is {@code indicator - p}, where {@code indicator} is 1 if the predicted category
 * equals the expected category (compared as strings) and 0 otherwise, and {@code p} is the
 * probability the model assigned to the predicted category.
 *
 * @param object the model result; must be a {@code HasProbability}
 * @param expectedObject the expected (actual) target value
 * @return the categorical residual
 */
public static Double getCategoricalResidual(Object object, FieldValue expectedObject) {
  HasProbability probabilities = TypeUtil.cast(HasProbability.class, object);

  String predictedCategory = TypeUtil.format(getPredictedValue(object));
  String expectedCategory = TypeUtil.format(FieldValueUtil.getValue(expectedObject));

  double indicator;
  if (TypeUtil.equals(DataType.STRING, predictedCategory, expectedCategory)) {
    indicator = 1d;
  } else {
    indicator = 0d;
  }

  return Double.valueOf(indicator - probabilities.getProbability(predictedCategory));
}
/**
 * Finds the {@code TargetValue} entry of a target whose value matches the given value.
 *
 * <p>Each candidate's declared value is converted to the data type of {@code value} before
 * comparison.
 *
 * @param target the target whose {@code TargetValue} entries are searched
 * @param value the value to look up
 * @return the first matching entry, or {@code null} if none matches
 */
public static TargetValue getTargetValue(Target target, Object value) {
  DataType dataType = TypeUtil.getDataType(value);

  List<TargetValue> candidates = target.getTargetValues();
  for (int i = 0; i < candidates.size(); i++) {
    TargetValue candidate = candidates.get(i);

    Object candidateValue = TypeUtil.parseOrCast(dataType, candidate.getValue());
    if (TypeUtil.equals(dataType, value, candidateValue)) {
      return candidate;
    }
  }

  return null;
}
/**
 * Returns the probability of the category selected by an output field.
 *
 * <p>The category is the output field's {@code value} attribute, or the predicted category when
 * that attribute is absent.
 *
 * @param object the model result; must be a {@code HasProbability}
 * @param outputField the output field that requests the probability
 * @return the probability reported by the model result for that category
 */
private static Double getProbability(Object object, OutputField outputField) {
  HasProbability distribution = TypeUtil.cast(HasProbability.class, object);

  return distribution.getProbability(getCategoryValue(object, outputField));
}
@Override public FieldValue evaluate(List<FieldValue> values) { if (values.size() != 2) { throw new EvaluationException(); } FieldValue left = values.get(0); FieldValue right = values.get(1); // "If one of the input fields of a simple arithmetic function is a missing value, the result // evaluates to missing value" if (left == null || right == null) { return null; } DataType dataType = TypeUtil.getResultDataType(left.getDataType(), right.getDataType()); Number result; try { result = evaluate(left.asNumber(), right.asNumber()); } catch (ArithmeticException ae) { throw new InvalidResultException(null); } return FieldValueUtil.create(cast(dataType, result)); }
@Override public FieldValue evaluate(List<FieldValue> values) { StorelessUnivariateStatistic statistic = createStatistic(); DataType dataType = null; for (FieldValue value : values) { // "Missing values in the input to an aggregate function are simply ignored" if (value == null) { continue; } statistic.increment((value.asNumber()).doubleValue()); if (dataType != null) { dataType = TypeUtil.getResultDataType(dataType, value.getDataType()); } else { dataType = value.getDataType(); } } if (statistic.getN() == 0) { throw new MissingResultException(null); } Object result = cast(getResultType(dataType), statistic.getResult()); return FieldValueUtil.create(result); }
/**
 * Returns the reason code at the rank requested by an output field.
 *
 * @param object the model result; must be a {@code HasReasonCodeRanking}
 * @param outputField the output field that requests the reason code
 * @return the element of the reason code ranking at that rank (as resolved by {@code getElement})
 * @throws InvalidFeatureException if the rank is not positive
 */
public static String getReasonCode(Object object, OutputField outputField) {
  HasReasonCodeRanking ranking = TypeUtil.cast(HasReasonCodeRanking.class, object);

  int rank = outputField.getRank();
  if (rank > 0) {
    return getElement(ranking.getReasonCodeRanking(), rank);
  }

  throw new InvalidFeatureException(outputField);
}
public static Object getRuleValue(Object object, OutputField outputField) { HasRuleValues hasRuleValues = TypeUtil.cast(HasRuleValues.class, object); List<AssociationRule> associationRules = getRuleValues(hasRuleValues, outputField); String isMultiValued = outputField.getIsMultiValued(); // Return a single result if (("0").equals(isMultiValued)) { int rank = outputField.getRank(); if (rank <= 0) { throw new InvalidFeatureException(outputField); } AssociationRule associationRule = getElement(associationRules, rank); if (associationRule != null) { return getRuleFeature(hasRuleValues, associationRule, outputField); } return null; } else // Return multiple results if (("1").equals(isMultiValued)) { int size; int rank = outputField.getRank(); if (rank < 0) { throw new InvalidFeatureException(outputField); } else // "A zero value indicates that all output values are to be returned" if (rank == 0) { size = associationRules.size(); } else // "A positive value indicates the number of output values to be returned" { size = Math.min(rank, associationRules.size()); } associationRules = associationRules.subList(0, size); List<Object> result = new ArrayList<>(associationRules.size()); for (AssociationRule associationRule : associationRules) { result.add(getRuleFeature(hasRuleValues, associationRule, outputField)); } return result; } else { throw new InvalidFeatureException(outputField); } }
private static String getCategoryValue(Object object, OutputField outputField) { String value = outputField.getValue(); // "If the value attribute is not specified, then the predicted categorical value should be // returned as a result" if (value == null) { return TypeUtil.format(getPredictedValue(object)); } return value; }
static Object evaluateRegressionInternal( FieldName name, Object value, ModelEvaluationContext context) { ModelEvaluator<?> evaluator = context.getModelEvaluator(); if (Objects.equals(Evaluator.DEFAULT_TARGET, name)) { DataField dataField = evaluator.getDataField(); if (value != null) { value = TypeUtil.cast(dataField.getDataType(), value); } } else { Target target = evaluator.getTarget(name); if (target != null) { if (value == null) { value = getDefaultValue(target); } // End if if (value != null) { value = processValue(target, (Double) value); } } DataField dataField = evaluator.getDataField(name); if (dataField == null) { throw new MissingFieldException(name); } // End if if (value != null) { value = TypeUtil.cast(dataField.getDataType(), value); } MiningField miningField = evaluator.getMiningField(name); context.declare( name, FieldValueUtil.createTargetValue(dataField, miningField, target, value)); } return value; }
private static Object getPredictedDisplayValue( Object object, DataField dataField, Target target) { if (object instanceof HasDisplayValue) { HasDisplayValue hasDisplayValue = TypeUtil.cast(HasDisplayValue.class, object); return hasDisplayValue.getDisplayValue(); } object = getPredictedValue(object); if (target != null) { TargetValue targetValue = TargetUtil.getTargetValue(target, object); if (targetValue != null) { String displayValue = targetValue.getDisplayValue(); if (displayValue != null) { return displayValue; } } } OpType opType = dataField.getOpType(); switch (opType) { case CONTINUOUS: break; case CATEGORICAL: case ORDINAL: { Value value = FieldValueUtil.getValidValue(dataField, object); if (value != null) { String displayValue = value.getDisplayValue(); if (displayValue != null) { return displayValue; } } } break; default: throw new UnsupportedFeatureException(dataField, opType); } // "If the display value is not specified explicitly, then the raw predicted value is used by // default" return object; }
/**
 * Returns the entity identifier selected by an output field.
 *
 * <p>For rank 1 this is the identifier of the winning entity. For higher ranks the element at that
 * rank of the entity identifier ranking is returned; only descending rank order is supported.
 *
 * @param object the model result; must be a {@code HasEntityId} (and a {@code HasEntityIdRanking}
 *     when {@code rank > 1})
 * @param outputField the output field that requests the entity identifier
 * @return the selected entity identifier
 * @throws InvalidFeatureException if the rank is not positive
 * @throws UnsupportedFeatureException if a rank above 1 is requested with a non-descending order
 */
private static String getEntityId(Object object, OutputField outputField) {
  // Validate the object type before looking at the rank, so incompatible objects always fail here
  HasEntityId hasEntityId = TypeUtil.cast(HasEntityId.class, object);

  int rank = outputField.getRank();
  if (rank <= 0) {
    throw new InvalidFeatureException(outputField);
  }

  if (rank == 1) {
    return hasEntityId.getEntityId();
  }

  HasEntityIdRanking ranking = TypeUtil.cast(HasEntityIdRanking.class, object);

  OutputField.RankOrder order = outputField.getRankOrder();
  switch (order) {
    case DESCENDING:
      return getElement(ranking.getEntityIdRanking(), rank);
    default:
      throw new UnsupportedFeatureException(outputField, order);
  }
}
/**
 * Scores the tree for a regression target.
 *
 * <p>The score of the selected node is parsed as a double, post-processed for the target field,
 * and wrapped in a node score.
 *
 * @param context the evaluation context
 * @return a single-entry map from the target field name to the node score, or the default
 *     regression result when no node is selected
 */
private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context) {
  Node selectedNode = evaluateTree(new Trail(), context);
  if (selectedNode == null) {
    return TargetUtil.evaluateRegressionDefault(context);
  }

  Double rawScore = (Double) TypeUtil.parseOrCast(DataType.DOUBLE, selectedNode.getScore());

  TargetField targetField = getTargetField();

  NodeScore nodeScore =
      createNodeScore(
          selectedNode, TargetUtil.evaluateRegressionInternal(targetField, rawScore, context));

  return Collections.singletonMap(targetField.getName(), nodeScore);
}
private static Double evaluate(Collection<?> values, int quantile) { List<Double> doubleValues = new ArrayList<>(); for (Object value : values) { Double doubleValue = (Double) TypeUtil.parseOrCast(DataType.DOUBLE, value); doubleValues.add(doubleValue); } double[] data = Doubles.toArray(doubleValues); // The data must be (at least partially) ordered Arrays.sort(data); Percentile percentile = new Percentile(); percentile.setData(data); return percentile.evaluate(quantile); }
/**
 * Returns one feature of the association rule at the rank requested by an output field.
 *
 * @param object the model result; must be a {@code HasRuleValues}
 * @param outputField the output field that requests the rule feature; must be single-valued
 * @param ruleFeature the rule feature to extract
 * @return the feature value, or {@code null} if no rule exists at the requested rank
 * @throws UnsupportedFeatureException if {@code isMultiValued} is not {@code "0"}
 * @throws InvalidFeatureException if the rank is not positive
 */
public static Object getRuleValue(
    Object object, OutputField outputField, OutputField.RuleFeature ruleFeature) {
  HasRuleValues hasRuleValues = TypeUtil.cast(HasRuleValues.class, object);

  List<AssociationRule> rules = getRuleValues(hasRuleValues, outputField);

  // Only the single-valued form is supported here
  boolean singleValued = ("0").equals(outputField.getIsMultiValued());
  if (!singleValued) {
    throw new UnsupportedFeatureException(outputField);
  }

  int rank = outputField.getRank();
  if (rank <= 0) {
    throw new InvalidFeatureException(outputField);
  }

  AssociationRule rule = getElement(rules, rank);
  if (rule == null) {
    return null;
  }

  return getRuleFeature(hasRuleValues, rule, outputField, ruleFeature);
}
/**
 * Returns the affinity of the winning entity.
 *
 * @param object the model result; must be a {@code HasEntityAffinity}
 * @return the entity affinity reported by the model result
 */
public static Double getEntityAffinity(Object object) {
  return TypeUtil.cast(HasEntityAffinity.class, object).getEntityAffinity();
}
/** * Evaluates the {@link Output} element. * * @param predictions A map of {@link Evaluator#getTargetFields() target field} values. * @return A map of {@link Evaluator#getTargetFields() target field} values together with {@link * Evaluator#getOutputFields() output field} values. */ @SuppressWarnings(value = {"fallthrough"}) public static Map<FieldName, ?> evaluate( Map<FieldName, ?> predictions, ModelEvaluationContext context) { ModelEvaluator<?> modelEvaluator = context.getModelEvaluator(); Model model = modelEvaluator.getModel(); Output output = model.getOutput(); if (output == null) { return predictions; } Map<FieldName, Object> result = new LinkedHashMap<>(predictions); List<OutputField> outputFields = output.getOutputFields(); outputFields: for (OutputField outputField : outputFields) { FieldName targetFieldName = outputField.getTargetField(); Object targetValue = null; ResultFeature resultFeature = outputField.getResultFeature(); String segmentId = outputField.getSegmentId(); SegmentResult segmentPredictions = null; // Load the target value of the specified segment if (segmentId != null) { if (!(model instanceof MiningModel)) { throw new InvalidFeatureException(outputField); } MiningModelEvaluationContext miningModelContext = (MiningModelEvaluationContext) context; segmentPredictions = miningModelContext.getResult(segmentId); // "If there is no Segment matching segmentId or if the predicate of the matching Segment // evaluated to false, then the result delivered by this OutputField is missing" if (segmentPredictions == null) { continue outputFields; } // End if if (targetFieldName != null) { if (!segmentPredictions.containsKey(targetFieldName)) { throw new MissingValueException(targetFieldName, outputField); } targetValue = segmentPredictions.get(targetFieldName); } else { targetValue = segmentPredictions.getTargetValue(); } } else // Load the target value { switch (resultFeature) { case ENTITY_ID: { // "Result feature entityId returns the id of the winning 
segment" if (model instanceof MiningModel) { targetValue = TypeUtil.cast(HasEntityId.class, predictions); break; } } // Falls through default: { if (targetFieldName == null) { targetFieldName = modelEvaluator.getTargetFieldName(); } // End if if (!predictions.containsKey(targetFieldName)) { throw new MissingValueException(targetFieldName, outputField); } targetValue = predictions.get(targetFieldName); } break; } } // "If the target value is missing, then the result delivered by this OutputField is missing" if (targetValue == null) { continue outputFields; } Object value; // Perform the requested computation on the target value switch (resultFeature) { case PREDICTED_VALUE: { value = getPredictedValue(targetValue); } break; case PREDICTED_DISPLAY_VALUE: { DataField dataField = modelEvaluator.getDataField(targetFieldName); if (dataField == null) { throw new MissingFieldException(targetFieldName, outputField); } Target target = modelEvaluator.getTarget(targetFieldName); value = getPredictedDisplayValue(targetValue, dataField, target); } break; case TRANSFORMED_VALUE: case DECISION: { if (segmentId != null) { String name = outputField.getValue(); if (name == null) { throw new InvalidFeatureException(outputField); } Expression expression = outputField.getExpression(); if (expression != null) { throw new InvalidFeatureException(outputField); } value = segmentPredictions.get(FieldName.create(name)); break; } Expression expression = outputField.getExpression(); if (expression == null) { throw new InvalidFeatureException(outputField); } value = FieldValueUtil.getValue(ExpressionUtil.evaluate(expression, context)); } break; case PROBABILITY: { value = getProbability(targetValue, outputField); } break; case RESIDUAL: { FieldValue expectedTargetValue = context.evaluate(targetFieldName); if (expectedTargetValue == null) { throw new MissingValueException(targetFieldName, outputField); } DataField dataField = modelEvaluator.getDataField(targetFieldName); OpType opType = 
dataField.getOpType(); switch (opType) { case CONTINUOUS: value = getContinuousResidual(targetValue, expectedTargetValue); break; case CATEGORICAL: value = getCategoricalResidual(targetValue, expectedTargetValue); break; default: throw new UnsupportedFeatureException(dataField, opType); } } break; case CLUSTER_ID: { value = getClusterId(targetValue); } break; case ENTITY_ID: { if (targetValue instanceof HasRuleValues) { value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.RULE_ID); break; } value = getEntityId(targetValue, outputField); } break; case AFFINITY: { value = getAffinity(targetValue, outputField); } break; case CLUSTER_AFFINITY: case ENTITY_AFFINITY: { String entityId = outputField.getValue(); // Select the specified entity instead of the winning entity if (entityId != null) { value = getAffinity(targetValue, outputField); break; } value = getEntityAffinity(targetValue); } break; case REASON_CODE: { value = getReasonCode(targetValue, outputField); } break; case RULE_VALUE: { value = getRuleValue(targetValue, outputField); } break; case ANTECEDENT: { value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.ANTECEDENT); } break; case CONSEQUENT: { value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.CONSEQUENT); } break; case RULE: { value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.RULE); } break; case RULE_ID: { value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.RULE_ID); } break; case CONFIDENCE: { value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.CONFIDENCE); } break; case SUPPORT: { value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.SUPPORT); } break; case LIFT: { value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.LIFT); } break; case LEVERAGE: { value = getRuleValue(targetValue, outputField, OutputField.RuleFeature.LEVERAGE); } break; case WARNING: { value = context.getWarnings(); } break; default: 
throw new UnsupportedFeatureException(outputField, resultFeature); } FieldValue outputValue = FieldValueUtil.create(outputField, value); // The result of one output field becomes available to other output fields context.declare(outputField.getName(), outputValue); result.put(outputField.getName(), FieldValueUtil.getValue(outputValue)); } return result; }
/**
 * Returns the identifier of the winning cluster.
 *
 * @param object the model result; must be a {@code HasEntityId}
 * @return the entity identifier reported by the model result
 */
private static String getClusterId(Object object) {
  return TypeUtil.cast(HasEntityId.class, object).getEntityId();
}