示例#1
0
    /**
     * Splits the inherited predicate into a part that is handed to the aggregation's source and a
     * part that is re-applied in a {@code FilterNode} placed on top of the aggregation.
     *
     * <p>Conjuncts selected by {@code not(deterministic())} always go to the upper filter. Each
     * remaining non-inferrable conjunct is rewritten by the equality inference against the group-by
     * symbols: a non-null rewrite is handed to the source, a null rewrite keeps the original
     * conjunct in the upper filter. Equalities produced by the inference are distributed the same
     * way (scope equalities to the source, complement and straddling equalities to the filter).
     *
     * @param node the aggregation being rewritten
     * @param inheritedPredicate the predicate arriving from above this node
     * @param planRewriter used to rewrite the source with the predicate handed down
     * @return {@code node} itself when the source is unchanged and nothing has to be re-applied;
     *     otherwise a rebuilt aggregation, wrapped in a filter when conjuncts remain above it
     */
    @Override
    public PlanNode rewriteAggregation(
        AggregationNode node,
        Expression inheritedPredicate,
        PlanRewriter<Expression> planRewriter) {
      // The inference is built from the full predicate, before anything is stripped from it.
      EqualityInference inference = createEqualityInference(inheritedPredicate);

      List<Expression> belowAggregation = new ArrayList<>();
      List<Expression> aboveAggregation = new ArrayList<>();

      // Non-deterministic conjuncts are never handed to the source.
      aboveAggregation.addAll(
          ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(deterministic()))));
      Expression deterministicPredicate = stripNonDeterministicConjuncts(inheritedPredicate);

      // Classify every non-equality conjunct by whether it has a rewrite over the group-by symbols.
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(deterministicPredicate)) {
        Expression overGroupBy = inference.rewriteExpression(conjunct, in(node.getGroupBy()));
        if (overGroupBy == null) {
          aboveAggregation.add(conjunct);
          continue;
        }
        belowAggregation.add(overGroupBy);
      }

      // Re-introduce the equalities, partitioned by the same group-by scope.
      EqualityInference.EqualityPartition partition =
          inference.generateEqualitiesPartitionedBy(in(node.getGroupBy()));
      belowAggregation.addAll(partition.getScopeEqualities());
      aboveAggregation.addAll(partition.getScopeComplementEqualities());
      aboveAggregation.addAll(partition.getScopeStraddlingEqualities());

      PlanNode originalSource = node.getSource();
      PlanNode rewrittenSource =
          planRewriter.rewrite(originalSource, combineConjuncts(belowAggregation));

      // Identity comparison: the aggregation is only rebuilt when the rewriter returned a new node.
      PlanNode result =
          (rewrittenSource == originalSource)
              ? node
              : new AggregationNode(
                  node.getId(),
                  rewrittenSource,
                  node.getGroupBy(),
                  node.getAggregations(),
                  node.getFunctions(),
                  node.getMasks(),
                  node.getStep(),
                  node.getSampleWeight(),
                  node.getConfidence());

      if (aboveAggregation.isEmpty()) {
        return result;
      }
      return new FilterNode(idAllocator.getNextId(), result, combineConjuncts(aboveAggregation));
    }
 /**
  * Prints an aggregation node and then visits its source.
  *
  * <p>The detail text passed to {@code printNode} holds one {@code symbol := call} entry per
  * aggregation. Each entry ends with the two characters backslash and {@code n} (an escape
  * sequence in the emitted text, not a Java line break).
  *
  * @param node the aggregation node to print
  * @param context passed through unchanged to the source's {@code accept} call
  * @return whatever visiting the source returns
  */
 @Override
 public Void visitAggregation(AggregationNode node, Void context) {
   StringBuilder details = new StringBuilder();
   for (Map.Entry<Symbol, FunctionCall> aggregation : node.getAggregations().entrySet()) {
     Symbol output = aggregation.getKey();
     FunctionCall call = aggregation.getValue();
     details.append(format("%s := %s\\n", output, call));
   }

   String title = format("Aggregate[%s]", node.getStep());
   printNode(node, title, details.toString(), NODE_COLORS.get(NodeType.AGGREGATE));

   return node.getSource().accept(this, context);
 }
    /**
     * Rebuilds an aggregation node, keeping only the aggregations whose output symbol is contained
     * in the symbol set carried by {@code context}.
     *
     * <p>The source is rewritten with the set of symbols this node still needs, collected in this
     * order: the group-by symbols, the hash symbol if present, then for every retained aggregation
     * the symbols returned by {@code DependencyExtractor.extractUnique} for its call followed by
     * its mask symbol (if one is registered), and finally the sample weight if present.
     *
     * @param node the aggregation node to rebuild
     * @param context carries the set of output symbols to keep and performs the source rewrite
     * @return a new {@code AggregationNode} with the same id, group-by, step, sample weight,
     *     confidence and hash symbol, but only the retained aggregations, functions and masks
     */
    @Override
    public PlanNode visitAggregation(AggregationNode node, RewriteContext<Set<Symbol>> context) {
      ImmutableSet.Builder<Symbol> requiredInputs = ImmutableSet.builder();
      requiredInputs.addAll(node.getGroupBy());
      if (node.getHashSymbol().isPresent()) {
        requiredInputs.add(node.getHashSymbol().get());
      }

      ImmutableMap.Builder<Symbol, Signature> keptFunctions = ImmutableMap.builder();
      ImmutableMap.Builder<Symbol, FunctionCall> keptCalls = ImmutableMap.builder();
      ImmutableMap.Builder<Symbol, Symbol> keptMasks = ImmutableMap.builder();

      Set<Symbol> requestedOutputs = context.get();
      for (Map.Entry<Symbol, FunctionCall> aggregation : node.getAggregations().entrySet()) {
        Symbol output = aggregation.getKey();
        if (!requestedOutputs.contains(output)) {
          // Not requested: drop the aggregation together with its function and mask.
          continue;
        }

        FunctionCall call = aggregation.getValue();
        requiredInputs.addAll(DependencyExtractor.extractUnique(call));

        if (node.getMasks().containsKey(output)) {
          Symbol mask = node.getMasks().get(output);
          requiredInputs.add(mask);
          keptMasks.put(output, mask);
        }

        keptCalls.put(output, call);
        keptFunctions.put(output, node.getFunctions().get(output));
      }

      if (node.getSampleWeight().isPresent()) {
        requiredInputs.add(node.getSampleWeight().get());
      }

      PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs.build());

      return new AggregationNode(
          node.getId(),
          rewrittenSource,
          node.getGroupBy(),
          keptCalls.build(),
          keptFunctions.build(),
          keptMasks.build(),
          node.getStep(),
          node.getSampleWeight(),
          node.getConfidence(),
          node.getHashSymbol());
    }