private Trail handleDefaultChild(Trail trail, Node node, EvaluationContext context) { // "The defaultChild missing value strategy requires the presence of the defaultChild attribute // in every non-leaf Node" String defaultChild = node.getDefaultChild(); if (defaultChild == null) { throw new InvalidFeatureException(node); } trail.addMissingLevel(); List<Node> children = node.getNodes(); for (int i = 0, max = children.size(); i < max; i++) { Node child = children.get(i); String id = child.getId(); if (id != null && (id).equals(defaultChild)) { // The predicate of the referenced Node is not evaluated return handleTrue(trail, child, context); } } // "Only Nodes which are immediate children of the respective Node can be referenced" throw new InvalidFeatureException(node); }
private Trail handleTrue(Trail trail, Node node, EvaluationContext context) { // A "true" leaf node if (!node.hasNodes()) { return trail.selectNode(node); } trail.push(node); List<Node> children = node.getNodes(); for (int i = 0, max = children.size(); i < max; i++) { Node child = children.get(i); Boolean status = evaluateNode(trail, child, context); if (status == null) { Trail destination = handleMissingValue(trail, node, child, context); if (destination != null) { return destination; } } else if (status.booleanValue()) { return handleTrue(trail, child, context); } } // A "true" non-leaf node return handleNoTrueChild(trail); }
private Boolean evaluateNode(Trail trail, Node node, EvaluationContext context) { EmbeddedModel embeddedModel = node.getEmbeddedModel(); if (embeddedModel != null) { throw new UnsupportedFeatureException(embeddedModel); } Predicate predicate = node.getPredicate(); if (predicate == null) { throw new InvalidFeatureException(node); } // End if // A compound predicate whose boolean operator is "surrogate" represents a special case if (predicate instanceof CompoundPredicate) { CompoundPredicate compoundPredicate = (CompoundPredicate) predicate; PredicateUtil.CompoundPredicateResult result = PredicateUtil.evaluateCompoundPredicateInternal(compoundPredicate, context); if (result.isAlternative()) { trail.addMissingLevel(); } return result.getResult(); } else { return PredicateUtil.evaluate(predicate, context); } }
/**
 * Registers {@code node} and, recursively in pre-order, all of its descendants into the entity
 * builder via {@code EntityUtil.put}.
 *
 * @param node the subtree root to register
 * @param index shared counter passed through to {@code EntityUtil.put} (presumably used to
 *     assign ids to Nodes lacking one; verify against EntityUtil)
 * @param builder the builder to add entries to
 * @return the builder returned by the last {@code EntityUtil.put} call
 */
private ImmutableBiMap.Builder<String, Node> collectNodes(
        Node node, AtomicInteger index, ImmutableBiMap.Builder<String, Node> builder) {
    ImmutableBiMap.Builder<String, Node> accumulator = EntityUtil.put(node, index, builder);

    if (node.hasNodes()) {
        for (Node child : node.getNodes()) {
            accumulator = collectNodes(child, index, accumulator);
        }
    }

    return accumulator;
}
/**
 * Scores the tree as a regression model.
 *
 * @param context the model evaluation context
 * @return the default regression result when no Node is selected. Otherwise, a singleton map
 *     from the target field name to a {@code NodeScore}, built from the selected Node's score
 *     parsed or cast to DOUBLE.
 */
private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context) {
    Node selected = evaluateTree(new Trail(), context);
    if (selected == null) {
        return TargetUtil.evaluateRegressionDefault(context);
    }

    Double rawScore = (Double) TypeUtil.parseOrCast(DataType.DOUBLE, selected.getScore());

    TargetField targetField = getTargetField();
    Double targetValue = TargetUtil.evaluateRegressionInternal(targetField, rawScore, context);
    NodeScore nodeScore = createNodeScore(selected, targetValue);

    return Collections.singletonMap(targetField.getName(), nodeScore);
}
private NodeScoreDistribution createNodeScoreDistribution(Node node, double missingValuePenalty) { BiMap<String, Node> entityRegistry = getEntityRegistry(); NodeScoreDistribution result = new NodeScoreDistribution(entityRegistry, node); if (!node.hasScoreDistributions()) { return result; } List<ScoreDistribution> scoreDistributions = node.getScoreDistributions(); double sum = 0; for (int i = 0, max = scoreDistributions.size(); i < max; i++) { ScoreDistribution scoreDistribution = scoreDistributions.get(i); Double recordCount = scoreDistribution.getRecordCount(); if (recordCount == null) { throw new InvalidFeatureException(scoreDistribution); } sum += recordCount; } // End for for (int i = 0, max = scoreDistributions.size(); i < max; i++) { ScoreDistribution scoreDistribution = scoreDistributions.get(i); Double probability = scoreDistribution.getProbability(); if (probability == null) { Double recordCount = scoreDistribution.getRecordCount(); probability = (recordCount / sum); } result.put(scoreDistribution.getValue(), probability); Double confidence = scoreDistribution.getConfidence(); if (confidence != null) { result.putConfidence(scoreDistribution.getValue(), confidence * missingValuePenalty); } } return result; }
private Trail handleNoTrueChild(Trail trail) { TreeModel treeModel = getModel(); TreeModel.NoTrueChildStrategy noTrueChildStrategy = treeModel.getNoTrueChildStrategy(); switch (noTrueChildStrategy) { case RETURN_NULL_PREDICTION: return trail.selectNull(); case RETURN_LAST_PREDICTION: Node lastPrediction = trail.getLastPrediction(); // "Return the parent Node only if it specifies a score attribute" if (lastPrediction.hasScore()) { return trail.selectLastPrediction(); } return trail.selectNull(); default: throw new UnsupportedFeatureException(treeModel, noTrueChildStrategy); } }
private Node evaluateTree(Trail trail, EvaluationContext context) { TreeModel treeModel = getModel(); Node root = treeModel.getNode(); Boolean status = evaluateNode(trail, root, context); if (status != null && status.booleanValue()) { trail = handleTrue(trail, root, context); Node node = trail.getResult(); // "It is not possible that the scoring process ends in a Node which does not have a score // attribute" if (node != null && !node.hasScore()) { throw new InvalidFeatureException(node); } return node; } return null; }