Example #1
0
 /**
  * Narrows the lookup symbols to those that are grouping keys of the aggregation and continues the
  * visit into the aggregation's source with that narrowed set.
  *
  * <p>Fails through {@code checkState} when none of the lookup symbols is a grouping key.
  */
 @Override
 public Map<Symbol, Symbol> visitAggregation(AggregationNode node, Set<Symbol> lookupSymbols) {
   // Only grouping keys pass through the aggregation, so only they can be traced further down.
   Set<Symbol> survivingSymbols = lookupSymbols.stream()
       .filter(node.getGroupingKeys()::contains)
       .collect(toImmutableSet());

   boolean anySurvived = !survivingSymbols.isEmpty();
   checkState(anySurvived, "No lookup symbols were able to pass through the aggregation group by");

   return node.getSource().accept(this, survivingSymbols);
 }
 /**
  * Prints the aggregation node and then continues into its source.
  *
  * <p>The node is printed with a header naming the aggregation step and a body holding one
  * "symbol := call" entry per aggregation.
  */
 @Override
 public Void visitAggregation(AggregationNode node, Void context) {
   StringBuilder details = new StringBuilder();
   // "\\n" emits a literal backslash-n, not a newline character. Presumably it is the line-break
   // escape of the target label syntax (NOTE(review): confirm against the renderer).
   node.getAggregations().forEach(
       (output, call) -> details.append(format("%s := %s\\n", output, call)));

   String header = format("Aggregate[%s]", node.getStep());
   printNode(node, header, details.toString(), NODE_COLORS.get(NodeType.AGGREGATE));

   return node.getSource().accept(this, context);
 }
    /**
     * Derives the properties of an aggregation's output from its single input.
     *
     * <p>The input properties are translated with a mapping that keeps each grouping key as itself
     * and maps every other symbol to empty. The result is then declared locally grouped on the
     * grouping keys.
     */
    @Override
    public ActualProperties visitAggregation(
        AggregationNode node, List<ActualProperties> inputProperties) {
      ActualProperties sourceProperties = Iterables.getOnlyElement(inputProperties);

      ActualProperties groupingKeyProperties = sourceProperties.translate(symbol -> {
        if (node.getGroupingKeys().contains(symbol)) {
          return Optional.of(symbol);
        }
        return Optional.<Symbol>empty();
      });

      return ActualProperties.builderFrom(groupingKeyProperties)
          .local(LocalProperties.grouped(node.getGroupingKeys()))
          .build();
    }
Example #4
0
    /**
     * Continues the rewrite below an aggregation using only the lookup symbols that are grouping
     * keys.
     *
     * <p>When none of the current lookup symbols is a grouping key, the node is returned unchanged
     * and its source is not visited.
     */
    @Override
    public PlanNode visitAggregation(AggregationNode node, RewriteContext<Context> context) {
      Context current = context.get();

      // Lookup symbols can only pass through the aggregation if they are group-by columns.
      Set<Symbol> passThroughSymbols = current.getLookupSymbols().stream()
          .filter(node.getGroupingKeys()::contains)
          .collect(toImmutableSet());

      return passThroughSymbols.isEmpty()
          ? node
          : context.defaultRewrite(node, new Context(passThroughSymbols, current.getSuccess()));
    }
    /**
     * Plans an aggregation on top of the sub-plan built for its source.
     *
     * <p>If the source sub-plan is not distributed, the aggregation becomes the new root of the
     * current fragment as a {@code SINGLE} step. Otherwise the choice depends on the aggregate
     * functions:
     * <ul>
     *   <li>if every one of them is decomposable, the work is split with
     *       {@code addDistributedAggregation};
     *   <li>otherwise {@code addSingleNodeAggregation} is used.
     * </ul>
     */
    @Override
    public SubPlanBuilder visitAggregation(AggregationNode node, Void context) {
      SubPlanBuilder current = node.getSource().accept(this, context);

      if (!current.isDistributed()) {
        // Undistributed input: the aggregation simply becomes the root of this fragment.
        AggregationNode singleStep = new AggregationNode(
            node.getId(),
            current.getRoot(),
            node.getGroupBy(),
            node.getAggregations(),
            node.getFunctions(),
            node.getMasks(),
            SINGLE,
            node.getSampleWeight(),
            node.getConfidence());
        current.setRoot(singleStep);
        return current;
      }

      Map<Symbol, FunctionCall> aggregations = node.getAggregations();
      Map<Symbol, Signature> functions = node.getFunctions();
      Map<Symbol, Symbol> masks = node.getMasks();
      List<Symbol> groupBy = node.getGroupBy();

      // A partial/final split is only possible when every aggregate function is decomposable;
      // the check stops at the first function that is not.
      boolean allDecomposable = functions.values().stream()
          .allMatch(signature ->
              metadata.getFunction(signature).getAggregationFunction().isDecomposable());

      if (!allDecomposable) {
        return addSingleNodeAggregation(
            current, aggregations, functions, masks, groupBy,
            node.getSampleWeight(), node.getConfidence());
      }
      return addDistributedAggregation(
          current, aggregations, functions, masks, groupBy,
          node.getSampleWeight(), node.getConfidence());
    }
    /**
     * Splits an aggregation into a PARTIAL step over the current plan's root and a FINAL step in a
     * new parent plan that reads the partial results through an exchange.
     *
     * <p>For every aggregation, a fresh intermediate symbol is allocated, named after the function
     * and typed with the function's intermediate type.
     * <ul>
     *   <li>The partial step computes the original call into that symbol.
     *   <li>The final step invokes the same function again with the intermediate symbol as its only
     *       argument.
     *   <li>Masks are applied only in the partial step. The final step gets no masks and no sample
     *       weight.
     * </ul>
     *
     * <p>Placement of the final step depends on the grouping keys:
     * <ul>
     *   <li>with no group-by keys, it goes into a single-node plan;
     *   <li>otherwise the partial plan's output is hash-partitioned on the group-by keys and the
     *       final step goes into a fixed-distribution plan.
     * </ul>
     *
     * @param plan sub-plan whose root is the aggregation input; its root is replaced by a sink over
     *     the partial aggregation, and it becomes the child of the returned plan
     * @param aggregations aggregation output symbol to the original function call
     * @param functions aggregation output symbol to its function signature (also reused, keyed by
     *     the original output symbols, by the final step)
     * @param masks aggregation output symbol to mask symbol; only masked aggregations have entries
     * @param groupBy grouping keys used by both the partial and the final step
     * @param sampleWeight sample-weight symbol, passed to the partial step only
     * @param confidence passed unchanged to both steps
     * @return the new plan containing the final aggregation
     */
    private SubPlanBuilder addDistributedAggregation(
        SubPlanBuilder plan,
        Map<Symbol, FunctionCall> aggregations,
        Map<Symbol, Signature> functions,
        Map<Symbol, Symbol> masks,
        List<Symbol> groupBy,
        Optional<Symbol> sampleWeight,
        double confidence) {
      // finalCalls is keyed by the original output symbols; the intermediate* maps are keyed by
      // the newly allocated intermediate symbols.
      Map<Symbol, FunctionCall> finalCalls = new HashMap<>();
      Map<Symbol, FunctionCall> intermediateCalls = new HashMap<>();
      Map<Symbol, Signature> intermediateFunctions = new HashMap<>();
      Map<Symbol, Symbol> intermediateMask = new HashMap<>();
      for (Map.Entry<Symbol, FunctionCall> entry : aggregations.entrySet()) {
        Signature signature = functions.get(entry.getKey());
        FunctionInfo function = metadata.getFunction(signature);

        // The partial step writes its result for this aggregation into a new symbol of the
        // function's intermediate type.
        Symbol intermediateSymbol =
            allocator.newSymbol(function.getName().getSuffix(), function.getIntermediateType());
        intermediateCalls.put(intermediateSymbol, entry.getValue());
        intermediateFunctions.put(intermediateSymbol, signature);
        // Masks are only carried over to the partial step, which reads the original input rows.
        if (masks.containsKey(entry.getKey())) {
          intermediateMask.put(intermediateSymbol, masks.get(entry.getKey()));
        }

        // rewrite final aggregation in terms of intermediate function
        finalCalls.put(
            entry.getKey(),
            new FunctionCall(
                function.getName(),
                ImmutableList.<Expression>of(
                    new QualifiedNameReference(intermediateSymbol.toQualifiedName()))));
      }

      // create partial aggregation plan
      AggregationNode partialAggregation =
          new AggregationNode(
              idAllocator.getNextId(),
              plan.getRoot(),
              groupBy,
              intermediateCalls,
              intermediateFunctions,
              intermediateMask,
              PARTIAL,
              sampleWeight,
              confidence);
      // The current plan now ends in a sink that emits the partial aggregation's outputs.
      plan.setRoot(
          new SinkNode(
              idAllocator.getNextId(), partialAggregation, partialAggregation.getOutputSymbols()));

      // create final aggregation plan
      // The exchange reads the sink's output symbols from the plan identified by plan.getId().
      ExchangeNode source =
          new ExchangeNode(
              idAllocator.getNextId(), plan.getId(), plan.getRoot().getOutputSymbols());
      AggregationNode finalAggregation =
          new AggregationNode(
              idAllocator.getNextId(),
              source,
              groupBy,
              finalCalls,
              functions,
              ImmutableMap.<Symbol, Symbol>of(),
              FINAL,
              Optional.<Symbol>absent(),
              confidence);

      // Global aggregation: a single node computes the final result. Grouped aggregation: the
      // partial output is hash-partitioned on the group-by keys for a fixed-distribution plan.
      if (groupBy.isEmpty()) {
        plan = createSingleNodePlan(finalAggregation).addChild(plan.build());
      } else {
        plan.setHashOutputPartitioning(groupBy);
        plan = createFixedDistributionPlan(finalAggregation).addChild(plan.build());
      }
      return plan;
    }
  /**
   * Computes the predicate for the aggregation's source and passes it through
   * {@code pullExpressionThroughSymbols} with the aggregation's group-by symbols.
   *
   * <p>Presumably this keeps only the parts expressible in terms of those symbols
   * (NOTE(review): confirm against the helper).
   */
  @Override
  public Expression visitAggregation(AggregationNode node, Void context) {
    return pullExpressionThroughSymbols(
        node.getSource().accept(this, context), node.getGroupBy());
  }
Example #8
0
    /**
     * Pushes the conjuncts of {@code inheritedPredicate} that can be rewritten purely in terms of
     * the aggregation's group-by symbols below the aggregation. The remaining conjuncts are kept in
     * a filter above it.
     *
     * <p>Non-deterministic conjuncts always stay above the aggregation.
     *
     * <p>For a global aggregation (no group-by symbols) nothing except the neutral predicate is
     * pushed. Such an aggregation produces a row even when its input is empty, so filtering its
     * input is not equivalent to filtering its output.
     *
     * @param node the aggregation being rewritten
     * @param inheritedPredicate predicate that must hold on the aggregation's output
     * @param planRewriter used to continue the rewrite into the aggregation's source
     * @return the aggregation (rebuilt only if its source changed), wrapped in a
     *     {@code FilterNode} when some conjuncts must be evaluated above it
     */
    @Override
    public PlanNode rewriteAggregation(
        AggregationNode node,
        Expression inheritedPredicate,
        PlanRewriter<Expression> planRewriter) {
      EqualityInference equalityInference = createEqualityInference(inheritedPredicate);

      List<Expression> pushdownConjuncts = new ArrayList<>();
      List<Expression> postAggregationConjuncts = new ArrayList<>();

      // Strip out non-deterministic conjuncts; they are always evaluated above the aggregation.
      postAggregationConjuncts.addAll(
          ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(deterministic()))));
      inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate);

      // Sort non-equality predicates by those that can be pushed down and those that cannot
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
        Expression rewrittenConjunct =
            equalityInference.rewriteExpression(conjunct, in(node.getGroupBy()));
        if (rewrittenConjunct != null) {
          pushdownConjuncts.add(rewrittenConjunct);
        } else {
          postAggregationConjuncts.add(conjunct);
        }
      }

      // Add the equality predicates back in
      EqualityInference.EqualityPartition equalityPartition =
          equalityInference.generateEqualitiesPartitionedBy(in(node.getGroupBy()));
      pushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
      postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
      postAggregationConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

      if (node.getGroupBy().isEmpty()) {
        // Global aggregation. With no grouping symbols in scope, only symbol-free conjuncts (for
        // example a constant FALSE) can have been classified as pushable above. Pushing them would
        // turn the aggregation's one-row result into zero rows' worth of input, changing the
        // answer, so evaluate them after the aggregation instead.
        // The neutral predicate is what combineConjuncts yields for an empty list. The original
        // code pushed exactly that when nothing was pushable, presumably TRUE
        // (NOTE(review): confirm). It is dropped here rather than becoming a no-op filter.
        Expression neutralPredicate = combineConjuncts(ImmutableList.<Expression>of());
        for (Expression conjunct : pushdownConjuncts) {
          if (!conjunct.equals(neutralPredicate)) {
            postAggregationConjuncts.add(conjunct);
          }
        }
        pushdownConjuncts.clear();
      }

      PlanNode rewrittenSource =
          planRewriter.rewrite(node.getSource(), combineConjuncts(pushdownConjuncts));

      PlanNode output = node;
      if (rewrittenSource != node.getSource()) {
        output =
            new AggregationNode(
                node.getId(),
                rewrittenSource,
                node.getGroupBy(),
                node.getAggregations(),
                node.getFunctions(),
                node.getMasks(),
                node.getStep(),
                node.getSampleWeight(),
                node.getConfidence());
      }
      if (!postAggregationConjuncts.isEmpty()) {
        output =
            new FilterNode(
                idAllocator.getNextId(), output, combineConjuncts(postAggregationConjuncts));
      }
      return output;
    }
    /**
     * Drops aggregations whose output symbols are not referenced by the parent, then rewrites the
     * source requiring only the symbols this node still reads.
     *
     * <p>The symbols still read are:
     * <ul>
     *   <li>the group-by symbols;
     *   <li>the hash symbol, when present;
     *   <li>the symbols each kept aggregation call depends on, and its mask, if any;
     *   <li>the sample weight, when present.
     * </ul>
     */
    @Override
    public PlanNode visitAggregation(AggregationNode node, RewriteContext<Set<Symbol>> context) {
      Set<Symbol> referencedOutputs = context.get();
      Map<Symbol, Symbol> allMasks = node.getMasks();

      ImmutableSet.Builder<Symbol> requiredInputs = ImmutableSet.builder();
      requiredInputs.addAll(node.getGroupBy());
      if (node.getHashSymbol().isPresent()) {
        requiredInputs.add(node.getHashSymbol().get());
      }

      ImmutableMap.Builder<Symbol, FunctionCall> keptCalls = ImmutableMap.builder();
      ImmutableMap.Builder<Symbol, Signature> keptFunctions = ImmutableMap.builder();
      ImmutableMap.Builder<Symbol, Symbol> keptMasks = ImmutableMap.builder();
      for (Map.Entry<Symbol, FunctionCall> aggregation : node.getAggregations().entrySet()) {
        Symbol output = aggregation.getKey();
        if (!referencedOutputs.contains(output)) {
          // Nobody above reads this aggregate, so it is pruned.
          continue;
        }

        FunctionCall call = aggregation.getValue();
        requiredInputs.addAll(DependencyExtractor.extractUnique(call));
        if (allMasks.containsKey(output)) {
          Symbol mask = allMasks.get(output);
          requiredInputs.add(mask);
          keptMasks.put(output, mask);
        }

        keptCalls.put(output, call);
        keptFunctions.put(output, node.getFunctions().get(output));
      }

      if (node.getSampleWeight().isPresent()) {
        requiredInputs.add(node.getSampleWeight().get());
      }

      PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs.build());

      return new AggregationNode(
          node.getId(),
          rewrittenSource,
          node.getGroupBy(),
          keptCalls.build(),
          keptFunctions.build(),
          keptMasks.build(),
          node.getStep(),
          node.getSampleWeight(),
          node.getConfidence(),
          node.getHashSymbol());
    }