/**
 * Carries the lookup symbols through an aggregation.
 *
 * <p>Only lookup symbols that are also grouping keys are forwarded to the source; the method fails
 * if none of them are grouping keys.
 */
@Override
public Map<Symbol, Symbol> visitAggregation(AggregationNode node, Set<Symbol> lookupSymbols) {
    Set<Symbol> survivingLookupSymbols = lookupSymbols.stream()
            .filter(node.getGroupingKeys()::contains)
            .collect(toImmutableSet());

    boolean anySurvived = !survivingLookupSymbols.isEmpty();
    checkState(anySurvived, "No lookup symbols were able to pass through the aggregation group by");

    return node.getSource().accept(this, survivingLookupSymbols);
}
/**
 * Prints an aggregation node and then continues with its source.
 *
 * <p>The node body lists one {@code symbol := call} entry per aggregation, each followed by the
 * two characters backslash and 'n' (the Java literal {@code "\\n"}).
 */
@Override
public Void visitAggregation(AggregationNode node, Void context) {
    StringBuilder label = new StringBuilder();
    for (Map.Entry<Symbol, FunctionCall> aggregation : node.getAggregations().entrySet()) {
        Symbol output = aggregation.getKey();
        FunctionCall call = aggregation.getValue();
        label.append(format("%s := %s\\n", output, call));
    }

    String title = format("Aggregate[%s]", node.getStep());
    printNode(node, title, label.toString(), NODE_COLORS.get(NodeType.AGGREGATE));

    return node.getSource().accept(this, context);
}
/**
 * Derives the properties of an aggregation's output from its single input.
 *
 * <p>Input properties are kept only for symbols that are grouping keys; every other symbol is
 * translated away. The output is additionally declared locally grouped on the grouping keys.
 */
@Override
public ActualProperties visitAggregation(
        AggregationNode node, List<ActualProperties> inputProperties) {
    ActualProperties sourceProperties = Iterables.getOnlyElement(inputProperties);

    ActualProperties groupingKeyProperties = sourceProperties.translate(symbol -> {
        if (node.getGroupingKeys().contains(symbol)) {
            return Optional.of(symbol);
        }
        return Optional.<Symbol>empty();
    });

    return ActualProperties.builderFrom(groupingKeyProperties)
            .local(LocalProperties.grouped(node.getGroupingKeys()))
            .build();
}
@Override public PlanNode visitAggregation(AggregationNode node, RewriteContext<Context> context) { // Lookup symbols can only be passed through if they are part of the group by columns Set<Symbol> groupByLookupSymbols = context .get() .getLookupSymbols() .stream() .filter(node.getGroupingKeys()::contains) .collect(toImmutableSet()); if (groupByLookupSymbols.isEmpty()) { return node; } return context.defaultRewrite( node, new Context(groupByLookupSymbols, context.get().getSuccess())); }
@Override public SubPlanBuilder visitAggregation(AggregationNode node, Void context) { SubPlanBuilder current = node.getSource().accept(this, context); if (!current.isDistributed()) { // add the aggregation node as the root of the current fragment current.setRoot( new AggregationNode( node.getId(), current.getRoot(), node.getGroupBy(), node.getAggregations(), node.getFunctions(), node.getMasks(), SINGLE, node.getSampleWeight(), node.getConfidence())); return current; } Map<Symbol, FunctionCall> aggregations = node.getAggregations(); Map<Symbol, Signature> functions = node.getFunctions(); Map<Symbol, Symbol> masks = node.getMasks(); List<Symbol> groupBy = node.getGroupBy(); boolean decomposable = true; for (Signature function : functions.values()) { if (!metadata.getFunction(function).getAggregationFunction().isDecomposable()) { decomposable = false; break; } } // else, we need to "close" the current fragment and create an unpartitioned fragment for the // final aggregation if (decomposable) { return addDistributedAggregation( current, aggregations, functions, masks, groupBy, node.getSampleWeight(), node.getConfidence()); } return addSingleNodeAggregation( current, aggregations, functions, masks, groupBy, node.getSampleWeight(), node.getConfidence()); }
private SubPlanBuilder addDistributedAggregation( SubPlanBuilder plan, Map<Symbol, FunctionCall> aggregations, Map<Symbol, Signature> functions, Map<Symbol, Symbol> masks, List<Symbol> groupBy, Optional<Symbol> sampleWeight, double confidence) { Map<Symbol, FunctionCall> finalCalls = new HashMap<>(); Map<Symbol, FunctionCall> intermediateCalls = new HashMap<>(); Map<Symbol, Signature> intermediateFunctions = new HashMap<>(); Map<Symbol, Symbol> intermediateMask = new HashMap<>(); for (Map.Entry<Symbol, FunctionCall> entry : aggregations.entrySet()) { Signature signature = functions.get(entry.getKey()); FunctionInfo function = metadata.getFunction(signature); Symbol intermediateSymbol = allocator.newSymbol(function.getName().getSuffix(), function.getIntermediateType()); intermediateCalls.put(intermediateSymbol, entry.getValue()); intermediateFunctions.put(intermediateSymbol, signature); if (masks.containsKey(entry.getKey())) { intermediateMask.put(intermediateSymbol, masks.get(entry.getKey())); } // rewrite final aggregation in terms of intermediate function finalCalls.put( entry.getKey(), new FunctionCall( function.getName(), ImmutableList.<Expression>of( new QualifiedNameReference(intermediateSymbol.toQualifiedName())))); } // create partial aggregation plan AggregationNode partialAggregation = new AggregationNode( idAllocator.getNextId(), plan.getRoot(), groupBy, intermediateCalls, intermediateFunctions, intermediateMask, PARTIAL, sampleWeight, confidence); plan.setRoot( new SinkNode( idAllocator.getNextId(), partialAggregation, partialAggregation.getOutputSymbols())); // create final aggregation plan ExchangeNode source = new ExchangeNode( idAllocator.getNextId(), plan.getId(), plan.getRoot().getOutputSymbols()); AggregationNode finalAggregation = new AggregationNode( idAllocator.getNextId(), source, groupBy, finalCalls, functions, ImmutableMap.<Symbol, Symbol>of(), FINAL, Optional.<Symbol>absent(), confidence); if (groupBy.isEmpty()) { plan = 
createSingleNodePlan(finalAggregation).addChild(plan.build()); } else { plan.setHashOutputPartitioning(groupBy); plan = createFixedDistributionPlan(finalAggregation).addChild(plan.build()); } return plan; }
/**
 * Computes the predicate known to hold on an aggregation's output.
 *
 * <p>Takes the source's predicate and passes it through {@code pullExpressionThroughSymbols} with
 * the group-by symbols (presumably keeping only what can be stated over those symbols).
 */
@Override
public Expression visitAggregation(AggregationNode node, Void context) {
    return pullExpressionThroughSymbols(node.getSource().accept(this, context), node.getGroupBy());
}
@Override public PlanNode rewriteAggregation( AggregationNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { EqualityInference equalityInference = createEqualityInference(inheritedPredicate); List<Expression> pushdownConjuncts = new ArrayList<>(); List<Expression> postAggregationConjuncts = new ArrayList<>(); // Strip out non-deterministic conjuncts postAggregationConjuncts.addAll( ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(deterministic())))); inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate); // Sort non-equality predicates by those that can be pushed down and those that cannot for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) { Expression rewrittenConjunct = equalityInference.rewriteExpression(conjunct, in(node.getGroupBy())); if (rewrittenConjunct != null) { pushdownConjuncts.add(rewrittenConjunct); } else { postAggregationConjuncts.add(conjunct); } } // Add the equality predicates back in EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(node.getGroupBy())); pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postAggregationConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); PlanNode rewrittenSource = planRewriter.rewrite(node.getSource(), combineConjuncts(pushdownConjuncts)); PlanNode output = node; if (rewrittenSource != node.getSource()) { output = new AggregationNode( node.getId(), rewrittenSource, node.getGroupBy(), node.getAggregations(), node.getFunctions(), node.getMasks(), node.getStep(), node.getSampleWeight(), node.getConfidence()); } if (!postAggregationConjuncts.isEmpty()) { output = new FilterNode( idAllocator.getNextId(), output, combineConjuncts(postAggregationConjuncts)); } return output; }
/**
 * Prunes an aggregation down to the outputs its consumer still requires.
 *
 * <p>Aggregations whose output symbol is not in the required set are dropped. The source is then
 * rewritten to supply only the symbols the remaining node reads: the group-by keys, the optional
 * hash symbol, the arguments and masks of the surviving aggregations, and the optional sample
 * weight.
 */
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Set<Symbol>> context) {
    Set<Symbol> requiredOutputs = context.get();

    ImmutableSet.Builder<Symbol> sourceInputs = ImmutableSet.builder();
    sourceInputs.addAll(node.getGroupBy());
    if (node.getHashSymbol().isPresent()) {
        sourceInputs.add(node.getHashSymbol().get());
    }

    ImmutableMap.Builder<Symbol, Signature> keptFunctions = ImmutableMap.builder();
    ImmutableMap.Builder<Symbol, FunctionCall> keptCalls = ImmutableMap.builder();
    ImmutableMap.Builder<Symbol, Symbol> keptMasks = ImmutableMap.builder();

    for (Map.Entry<Symbol, FunctionCall> aggregation : node.getAggregations().entrySet()) {
        Symbol output = aggregation.getKey();
        if (!requiredOutputs.contains(output)) {
            // not requested by the consumer: drop this aggregate entirely
            continue;
        }

        FunctionCall call = aggregation.getValue();
        sourceInputs.addAll(DependencyExtractor.extractUnique(call));

        if (node.getMasks().containsKey(output)) {
            Symbol mask = node.getMasks().get(output);
            sourceInputs.add(mask);
            keptMasks.put(output, mask);
        }

        keptCalls.put(output, call);
        keptFunctions.put(output, node.getFunctions().get(output));
    }

    if (node.getSampleWeight().isPresent()) {
        sourceInputs.add(node.getSampleWeight().get());
    }

    PlanNode prunedSource = context.rewrite(node.getSource(), sourceInputs.build());

    return new AggregationNode(
            node.getId(),
            prunedSource,
            node.getGroupBy(),
            keptCalls.build(),
            keptFunctions.build(),
            keptMasks.build(),
            node.getStep(),
            node.getSampleWeight(),
            node.getConfidence(),
            node.getHashSymbol());
}