예제 #1
0
 /**
  * Pushes {@code inheritedPredicate} through a {@link MarkDistinctNode} by delegating to
  * {@code planRewriter.defaultRewrite}.
  *
  * @throws IllegalStateException if the predicate references the node's marker symbol
  */
 @Override
 public PlanNode rewriteMarkDistinct(
     MarkDistinctNode node,
     Expression inheritedPredicate,
     PlanRewriter<Expression> planRewriter) {
   // NOTE(review): presumably the marker symbol is produced by this node, so a predicate that
   // reads it could not be evaluated below the node — confirm.
   boolean referencesMarker =
       DependencyExtractor.extractUnique(inheritedPredicate).contains(node.getMarkerSymbol());
   checkState(!referencesMarker, "predicate depends on marker symbol");
   return planRewriter.defaultRewrite(node, inheritedPredicate);
 }
예제 #2
0
    /**
     * Rewrites the filter's source so that it still produces every symbol the filter predicate
     * reads, plus every symbol requested from this node, and re-wraps it in an equivalent filter.
     */
    @Override
    public PlanNode visitFilter(FilterNode node, RewriteContext<Set<Symbol>> context) {
      // Predicate dependencies first, then the symbols requested by the parent.
      ImmutableSet.Builder<Symbol> requiredFromSource = ImmutableSet.<Symbol>builder();
      requiredFromSource.addAll(DependencyExtractor.extractUnique(node.getPredicate()));
      requiredFromSource.addAll(context.get());

      PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredFromSource.build());
      return new FilterNode(node.getId(), rewrittenSource, node.getPredicate());
    }
예제 #3
0
    /**
     * Drops window functions whose output symbol is not requested and rewrites the source to
     * produce only what the remaining node needs.
     *
     * <p>The source is asked for: the requested symbols, the partition-by and order-by symbols,
     * the optional frame start/end value symbols, the optional hash symbol, and every symbol
     * referenced by a retained window function call.
     */
    @Override
    public PlanNode visitWindow(WindowNode node, RewriteContext<Set<Symbol>> context) {
      ImmutableSet.Builder<Symbol> sourceSymbols = ImmutableSet.<Symbol>builder();
      sourceSymbols.addAll(context.get());
      sourceSymbols.addAll(node.getPartitionBy());
      sourceSymbols.addAll(node.getOrderBy());

      // Frame bounds and the hash symbol are optional; include them only when present.
      if (node.getFrame().getStartValue().isPresent()) {
        sourceSymbols.add(node.getFrame().getStartValue().get());
      }
      if (node.getFrame().getEndValue().isPresent()) {
        sourceSymbols.add(node.getFrame().getEndValue().get());
      }
      if (node.getHashSymbol().isPresent()) {
        sourceSymbols.add(node.getHashSymbol().get());
      }

      ImmutableMap.Builder<Symbol, FunctionCall> keptCalls = ImmutableMap.builder();
      ImmutableMap.Builder<Symbol, Signature> keptSignatures = ImmutableMap.builder();
      for (Map.Entry<Symbol, FunctionCall> windowFunction : node.getWindowFunctions().entrySet()) {
        Symbol output = windowFunction.getKey();
        if (!context.get().contains(output)) {
          // Not requested by the parent: drop this window function entirely.
          continue;
        }

        FunctionCall call = windowFunction.getValue();
        sourceSymbols.addAll(DependencyExtractor.extractUnique(call));
        keptCalls.put(output, call);
        keptSignatures.put(output, node.getSignatures().get(output));
      }

      PlanNode rewrittenSource = context.rewrite(node.getSource(), sourceSymbols.build());

      return new WindowNode(
          node.getId(),
          rewrittenSource,
          node.getPartitionBy(),
          node.getOrderBy(),
          node.getOrderings(),
          node.getFrame(),
          keptCalls.build(),
          keptSignatures.build(),
          node.getHashSymbol(),
          node.getPrePartitionedInputs(),
          node.getPreSortedOrderPrefix());
    }
예제 #4
0
 // TODO: temporary addition to infer result type from expression. fix this with the new planner
 // refactoring (martint)
 /**
  * Infers the type of {@code expression} by analyzing it against a tuple descriptor that has one
  * unqualified field per distinct symbol the expression references, typed via
  * {@code symbolAllocator.getTypes()}.
  */
 private Type extractType(Expression expression) {
   ExpressionAnalyzer analyzer =
       new ExpressionAnalyzer(new Analysis(), session, metadata, experimentalSyntaxEnabled);

   // Maps each referenced symbol to a field named after it, carrying the allocator's type.
   Function<Symbol, Field> symbolToField =
       new Function<Symbol, Field>() {
         @Override
         public Field apply(Symbol symbol) {
           return Field.newUnqualified(symbol.getName(), symbolAllocator.getTypes().get(symbol));
         }
       };
   List<Field> fields =
       IterableTransformer.<Symbol>on(DependencyExtractor.extractUnique(expression))
           .transform(symbolToField)
           .list();

   TupleDescriptor descriptor = new TupleDescriptor(fields);
   return analyzer.analyze(expression, descriptor, new AnalysisContext());
 }
예제 #5
0
    /**
     * Drops aggregations whose output symbol is not requested and rewrites the source to produce
     * only what the remaining node needs.
     *
     * <p>The source is asked for: the group-by symbols, the optional hash symbol, the symbols
     * referenced by each retained aggregation call along with its mask symbol (if any), and the
     * optional sample weight symbol.
     */
    @Override
    public PlanNode visitAggregation(AggregationNode node, RewriteContext<Set<Symbol>> context) {
      ImmutableSet.Builder<Symbol> sourceSymbols = ImmutableSet.<Symbol>builder();
      sourceSymbols.addAll(node.getGroupBy());
      if (node.getHashSymbol().isPresent()) {
        sourceSymbols.add(node.getHashSymbol().get());
      }

      ImmutableMap.Builder<Symbol, Signature> keptSignatures = ImmutableMap.builder();
      ImmutableMap.Builder<Symbol, FunctionCall> keptCalls = ImmutableMap.builder();
      ImmutableMap.Builder<Symbol, Symbol> keptMasks = ImmutableMap.builder();
      for (Map.Entry<Symbol, FunctionCall> aggregation : node.getAggregations().entrySet()) {
        Symbol output = aggregation.getKey();
        if (!context.get().contains(output)) {
          // Not requested by the parent: drop this aggregation and its mask.
          continue;
        }

        FunctionCall call = aggregation.getValue();
        sourceSymbols.addAll(DependencyExtractor.extractUnique(call));
        if (node.getMasks().containsKey(output)) {
          // A retained aggregation keeps its mask, which the source must therefore produce.
          Symbol mask = node.getMasks().get(output);
          sourceSymbols.add(mask);
          keptMasks.put(output, mask);
        }

        keptCalls.put(output, call);
        keptSignatures.put(output, node.getFunctions().get(output));
      }
      if (node.getSampleWeight().isPresent()) {
        sourceSymbols.add(node.getSampleWeight().get());
      }

      PlanNode rewrittenSource = context.rewrite(node.getSource(), sourceSymbols.build());

      return new AggregationNode(
          node.getId(),
          rewrittenSource,
          node.getGroupBy(),
          keptCalls.build(),
          keptSignatures.build(),
          keptMasks.build(),
          node.getStep(),
          node.getSampleWeight(),
          node.getConfidence(),
          node.getHashSymbol());
    }
예제 #6
0
    /**
     * Keeps only the assignments whose output symbol is requested, and asks the source for
     * exactly the symbols referenced by those retained expressions.
     */
    @Override
    public PlanNode visitProject(ProjectNode node, RewriteContext<Set<Symbol>> context) {
      ImmutableSet.Builder<Symbol> sourceSymbols = ImmutableSet.builder();
      ImmutableMap.Builder<Symbol, Expression> keptAssignments = ImmutableMap.builder();

      // Walk outputs in declared order so the retained assignments keep that order.
      for (Symbol output : node.getOutputSymbols()) {
        if (!context.get().contains(output)) {
          continue;
        }
        Expression expression = node.getAssignments().get(output);
        sourceSymbols.addAll(DependencyExtractor.extractUnique(expression));
        keptAssignments.put(output, expression);
      }

      PlanNode rewrittenSource = context.rewrite(node.getSource(), sourceSymbols.build());
      return new ProjectNode(node.getId(), rewrittenSource, keptAssignments.build());
    }
예제 #7
0
  /**
   * Analyzes {@code expressions} against a tuple descriptor that contains one unqualified field
   * for each distinct symbol the expressions reference.
   *
   * @param types type of each symbol; every symbol referenced by {@code expressions} must have an
   *     entry
   * @throws IllegalArgumentException if a referenced symbol has no entry in {@code types}
   */
  public static ExpressionAnalysis analyzeExpressionsWithSymbols(
      Session session,
      Metadata metadata,
      SqlParser sqlParser,
      Map<Symbol, Type> types,
      Iterable<? extends Expression> expressions) {
    List<Field> fields =
        DependencyExtractor.extractUnique(expressions).stream()
            .map(
                referenced -> {
                  Type symbolType = types.get(referenced);
                  checkArgument(symbolType != null, "No type for symbol %s", referenced);
                  return Field.newUnqualified(referenced.getName(), symbolType);
                })
            .collect(toImmutableList());

    TupleDescriptor descriptor = new TupleDescriptor(fields);
    return analyzeExpressions(session, metadata, sqlParser, descriptor, expressions);
  }
예제 #8
0
    /**
     * Splits the predicates surrounding an inner join into the parts that can be pushed to the
     * left source, the parts that can be pushed to the right source, the equality conjuncts that
     * become the join criteria, and the remaining conjuncts applied as a filter after the join.
     *
     * @param inheritedPredicate predicate pushed down from above the join
     * @param leftEffectivePredicate effective predicate of the left source; must reference only
     *     {@code leftSymbols}
     * @param rightEffectivePredicate effective predicate of the right source; must not reference
     *     any of {@code leftSymbols}
     * @param joinPredicate the join's own predicate
     * @param leftSymbols output symbols of the left side; any symbol not in this collection is
     *     treated as belonging to the right side
     * @return the left push-down, right push-down, join (equality-only) and post-join predicates
     * @throws IllegalArgumentException if either effective predicate violates the symbol
     *     constraints above
     */
    private InnerJoinPushDownResult processInnerJoin(
        Expression inheritedPredicate,
        Expression leftEffectivePredicate,
        Expression rightEffectivePredicate,
        Expression joinPredicate,
        Collection<Symbol> leftSymbols) {
      checkArgument(
          Iterables.all(DependencyExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)),
          "leftEffectivePredicate must only contain symbols from leftSymbols");
      checkArgument(
          Iterables.all(
              DependencyExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))),
          "rightEffectivePredicate must not contain symbols from leftSymbols");

      ImmutableList.Builder<Expression> leftPushDownConjuncts = ImmutableList.builder();
      ImmutableList.Builder<Expression> rightPushDownConjuncts = ImmutableList.builder();
      ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder();

      // Strip out non-deterministic conjuncts. Those from the inherited and join predicates are
      // kept at the join rather than pushed down (presumably so they are not evaluated in a
      // different place or a different number of times); those from the effective predicates are
      // simply dropped, since they only feed inference.
      joinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(deterministic())));
      inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate);

      joinConjuncts.addAll(filter(extractConjuncts(joinPredicate), not(deterministic())));
      joinPredicate = stripNonDeterministicConjuncts(joinPredicate);

      leftEffectivePredicate = stripNonDeterministicConjuncts(leftEffectivePredicate);
      rightEffectivePredicate = stripNonDeterministicConjuncts(rightEffectivePredicate);

      // Generate equality inferences.
      // allInference sees every predicate and is used for rewriting conjuncts to either side.
      // The two "Without...Inferred" variants omit one side's effective predicate and are used
      // only to generate equalities for that side, presumably so equalities that side already
      // guarantees are not pushed back into it.
      EqualityInference allInference =
          createEqualityInference(
              inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate);
      EqualityInference allInferenceWithoutLeftInferred =
          createEqualityInference(inheritedPredicate, rightEffectivePredicate, joinPredicate);
      EqualityInference allInferenceWithoutRightInferred =
          createEqualityInference(inheritedPredicate, leftEffectivePredicate, joinPredicate);

      // Sort through conjuncts in inheritedPredicate that were not used for inference
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
        Expression leftRewrittenConjunct =
            allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (leftRewrittenConjunct != null) {
          leftPushDownConjuncts.add(leftRewrittenConjunct);
        }

        Expression rightRewrittenConjunct =
            allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rightRewrittenConjunct != null) {
          rightPushDownConjuncts.add(rightRewrittenConjunct);
        }

        // Keep the predicate at the join only if it could not be pushed down to either side
        if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) {
          joinConjuncts.add(conjunct);
        }
      }

      // See if we can push the right effective predicate to the left side
      for (Expression conjunct :
          EqualityInference.nonInferrableConjuncts(rightEffectivePredicate)) {
        Expression rewritten = allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (rewritten != null) {
          leftPushDownConjuncts.add(rewritten);
        }
      }

      // See if we can push the left effective predicate to the right side
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(leftEffectivePredicate)) {
        Expression rewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rewritten != null) {
          rightPushDownConjuncts.add(rewritten);
        }
      }

      // See if we can push any parts of the join predicates to either side; conjuncts that fit
      // neither side stay with the join
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(joinPredicate)) {
        Expression leftRewritten = allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (leftRewritten != null) {
          leftPushDownConjuncts.add(leftRewritten);
        }

        Expression rightRewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rightRewritten != null) {
          rightPushDownConjuncts.add(rightRewritten);
        }

        if (leftRewritten == null && rightRewritten == null) {
          joinConjuncts.add(conjunct);
        }
      }

      // Add equalities from the inference back in
      leftPushDownConjuncts.addAll(
          allInferenceWithoutLeftInferred
              .generateEqualitiesPartitionedBy(in(leftSymbols))
              .getScopeEqualities());
      rightPushDownConjuncts.addAll(
          allInferenceWithoutRightInferred
              .generateEqualitiesPartitionedBy(not(in(leftSymbols)))
              .getScopeEqualities());
      // Equalities that straddle both sides become part of the join predicate
      joinConjuncts.addAll(
          allInference
              .generateEqualitiesPartitionedBy(in(leftSymbols))
              .getScopeStraddlingEqualities());

      // Since we only currently support equality in join conjuncts, factor out the non-equality
      // conjuncts to a post-join filter
      List<Expression> joinConjunctsList = joinConjuncts.build();
      List<Expression> postJoinConjuncts =
          ImmutableList.copyOf(filter(joinConjunctsList, not(joinEqualityExpression(leftSymbols))));
      joinConjunctsList =
          ImmutableList.copyOf(filter(joinConjunctsList, joinEqualityExpression(leftSymbols)));

      return new InnerJoinPushDownResult(
          combineConjuncts(leftPushDownConjuncts.build()),
          combineConjuncts(rightPushDownConjuncts.build()),
          combineConjuncts(joinConjunctsList),
          combineConjuncts(postJoinConjuncts));
    }
예제 #9
0
    /**
     * Splits the predicates surrounding an outer join into the parts that can be pushed to the
     * outer (row-preserving) side, the parts that can be pushed to the inner side, and the parts
     * that must be applied as a filter after the join. The join predicate itself is not rewritten
     * here; it is only used for inference.
     *
     * @param inheritedPredicate predicate pushed down from above the join
     * @param outerEffectivePredicate effective predicate of the outer source; must reference only
     *     {@code outerSymbols}
     * @param innerEffectivePredicate effective predicate of the inner source; must not reference
     *     any of {@code outerSymbols}
     * @param joinPredicate the join's own predicate
     * @param outerSymbols output symbols of the outer side; any other symbol is treated as
     *     belonging to the inner side
     * @return the outer push-down, inner push-down and post-join predicates
     * @throws IllegalArgumentException if either effective predicate violates the symbol
     *     constraints above
     */
    private OuterJoinPushDownResult processOuterJoin(
        Expression inheritedPredicate,
        Expression outerEffectivePredicate,
        Expression innerEffectivePredicate,
        Expression joinPredicate,
        Collection<Symbol> outerSymbols) {
      checkArgument(
          Iterables.all(
              DependencyExtractor.extractUnique(outerEffectivePredicate), in(outerSymbols)),
          "outerEffectivePredicate must only contain symbols from outerSymbols");
      checkArgument(
          Iterables.all(
              DependencyExtractor.extractUnique(innerEffectivePredicate), not(in(outerSymbols))),
          "innerEffectivePredicate must not contain symbols from outerSymbols");

      ImmutableList.Builder<Expression> outerPushdownConjuncts = ImmutableList.builder();
      ImmutableList.Builder<Expression> innerPushdownConjuncts = ImmutableList.builder();
      ImmutableList.Builder<Expression> postJoinConjuncts = ImmutableList.builder();

      // Strip out non-deterministic conjuncts. Non-deterministic inherited conjuncts are applied
      // after the join; the stripped copies of the other predicates are only used for inference.
      // NOTE(review): non-deterministic join-predicate conjuncts are not re-added anywhere here;
      // this relies on the caller keeping the original join predicate — confirm.
      postJoinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(deterministic())));
      inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate);

      outerEffectivePredicate = stripNonDeterministicConjuncts(outerEffectivePredicate);
      innerEffectivePredicate = stripNonDeterministicConjuncts(innerEffectivePredicate);
      joinPredicate = stripNonDeterministicConjuncts(joinPredicate);

      // Generate equality inferences
      // inheritedInference: inherited predicate only; its partition below decides where the
      // inherited equalities go. outerInference: used to rewrite inherited conjuncts purely in
      // terms of outer-side symbols.
      EqualityInference inheritedInference = createEqualityInference(inheritedPredicate);
      EqualityInference outerInference =
          createEqualityInference(inheritedPredicate, outerEffectivePredicate);

      EqualityInference.EqualityPartition equalityPartition =
          inheritedInference.generateEqualitiesPartitionedBy(in(outerSymbols));
      // Only inherited equalities lying entirely on the outer side are fed into the inferences
      // used to derive inner-side predicates.
      Expression outerOnlyInheritedEqualities =
          combineConjuncts(equalityPartition.getScopeEqualities());
      EqualityInference potentialNullSymbolInference =
          createEqualityInference(
              outerOnlyInheritedEqualities,
              outerEffectivePredicate,
              innerEffectivePredicate,
              joinPredicate);
      // Same as above but without innerEffectivePredicate; used only to generate the inner-side
      // equalities at the end (presumably so equalities the inner side already guarantees are
      // not pushed back into it).
      EqualityInference potentialNullSymbolInferenceWithoutInnerInferred =
          createEqualityInference(
              outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate);

      // Sort through conjuncts in inheritedPredicate that were not used for inference
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
        Expression outerRewritten = outerInference.rewriteExpression(conjunct, in(outerSymbols));
        if (outerRewritten != null) {
          outerPushdownConjuncts.add(outerRewritten);

          // A conjunct can only be pushed down into an inner side if it can be rewritten in terms
          // of the outer side
          Expression innerRewritten =
              potentialNullSymbolInference.rewriteExpression(outerRewritten, not(in(outerSymbols)));
          if (innerRewritten != null) {
            innerPushdownConjuncts.add(innerRewritten);
          }
        } else {
          // Cannot be expressed over outer symbols alone, so it is evaluated after the join
          postJoinConjuncts.add(conjunct);
        }
      }

      // See if we can push down any outer or join predicates to the inner side
      for (Expression conjunct :
          EqualityInference.nonInferrableConjuncts(and(outerEffectivePredicate, joinPredicate))) {
        Expression rewritten =
            potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerSymbols)));
        if (rewritten != null) {
          innerPushdownConjuncts.add(rewritten);
        }
      }

      // TODO: consider adding join predicate optimizations to outer joins

      // Add the equalities from the inferences back in: outer-only inherited equalities go to
      // the outer side; inner-only and straddling inherited equalities are applied after the join
      outerPushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
      postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
      postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
      innerPushdownConjuncts.addAll(
          potentialNullSymbolInferenceWithoutInnerInferred
              .generateEqualitiesPartitionedBy(not(in(outerSymbols)))
              .getScopeEqualities());

      return new OuterJoinPushDownResult(
          combineConjuncts(outerPushdownConjuncts.build()),
          combineConjuncts(innerPushdownConjuncts.build()),
          combineConjuncts(postJoinConjuncts.build()));
    }
예제 #10
0
    /**
     * Pushes {@code inheritedPredicate} through a join. The inherited predicate, the join's own
     * predicate and the effective predicates of both sources are split into a predicate for the
     * left source, one for the right source, a (possibly new) join predicate, and a filter placed
     * above the join.
     *
     * <p>INNER joins may get new equi-join criteria, backed by projections added to each side.
     * LEFT and RIGHT joins keep their original join predicate. Any other join type reaching the
     * switch raises {@link UnsupportedOperationException}.
     */
    @Override
    public PlanNode rewriteJoin(
        JoinNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      // Captured before normalization. NOTE(review): a join that is still CROSS after
      // tryNormalizeToInnerJoin falls into the default branch below and throws, so this flag
      // presumably only matters if normalization turns CROSS into INNER — confirm.
      boolean isCrossJoin = (node.getType() == JoinNode.Type.CROSS);

      // See if we can rewrite outer joins in terms of a plain inner join
      node = tryNormalizeToInnerJoin(node, inheritedPredicate);

      Expression leftEffectivePredicate = EffectivePredicateExtractor.extract(node.getLeft());
      Expression rightEffectivePredicate = EffectivePredicateExtractor.extract(node.getRight());
      Expression joinPredicate = extractJoinPredicate(node);

      Expression leftPredicate;
      Expression rightPredicate;
      Expression postJoinPredicate;
      Expression newJoinPredicate;

      switch (node.getType()) {
        case INNER:
          InnerJoinPushDownResult innerJoinPushDownResult =
              processInnerJoin(
                  inheritedPredicate,
                  leftEffectivePredicate,
                  rightEffectivePredicate,
                  joinPredicate,
                  node.getLeft().getOutputSymbols());
          leftPredicate = innerJoinPushDownResult.getLeftPredicate();
          rightPredicate = innerJoinPushDownResult.getRightPredicate();
          postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
          newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
          break;
        case LEFT:
          OuterJoinPushDownResult leftOuterJoinPushDownResult =
              processOuterJoin(
                  inheritedPredicate,
                  leftEffectivePredicate,
                  rightEffectivePredicate,
                  joinPredicate,
                  node.getLeft().getOutputSymbols());
          leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate();
          rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate();
          postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate();
          newJoinPredicate = joinPredicate; // Use the same as the original
          break;
        case RIGHT:
          // The right side is the outer side here, so the arguments are swapped relative to the
          // LEFT case and the results are mapped back to left/right.
          OuterJoinPushDownResult rightOuterJoinPushDownResult =
              processOuterJoin(
                  inheritedPredicate,
                  rightEffectivePredicate,
                  leftEffectivePredicate,
                  joinPredicate,
                  node.getRight().getOutputSymbols());
          leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate();
          rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate();
          postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate();
          newJoinPredicate = joinPredicate; // Use the same as the original
          break;
        default:
          throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
      }

      PlanNode leftSource = planRewriter.rewrite(node.getLeft(), leftPredicate);
      PlanNode rightSource = planRewriter.rewrite(node.getRight(), rightPredicate);

      // Rebuild the join only if a source or the join predicate changed; otherwise reuse the
      // original node
      PlanNode output = node;
      if (leftSource != node.getLeft()
          || rightSource != node.getRight()
          || !newJoinPredicate.equals(joinPredicate)) {
        List<JoinNode.EquiJoinClause> criteria = node.getCriteria();

        // Rewrite criteria and add projections if there is a new join predicate

        if (!newJoinPredicate.equals(joinPredicate) || isCrossJoin) {
          // Create identity projections for all existing symbols
          ImmutableMap.Builder<Symbol, Expression> leftProjections = ImmutableMap.builder();
          leftProjections.putAll(
              IterableTransformer.<Symbol>on(node.getLeft().getOutputSymbols())
                  .toMap(symbolToQualifiedNameReference())
                  .map());
          ImmutableMap.Builder<Symbol, Expression> rightProjections = ImmutableMap.builder();
          rightProjections.putAll(
              IterableTransformer.<Symbol>on(node.getRight().getOutputSymbols())
                  .toMap(symbolToQualifiedNameReference())
                  .map());

          // HACK! we don't support cross joins right now, so put in a simple fake join predicate
          // instead if all of the join clauses got simplified out
          // TODO: remove this code when cross join support is added
          Iterable<Expression> simplifiedJoinConjuncts =
              transform(extractConjuncts(newJoinPredicate), simplifyExpressions());
          simplifiedJoinConjuncts =
              filter(
                  simplifiedJoinConjuncts,
                  not(Predicates.<Expression>equalTo(BooleanLiteral.TRUE_LITERAL)));
          if (Iterables.isEmpty(simplifiedJoinConjuncts)) {
            simplifiedJoinConjuncts =
                ImmutableList.<Expression>of(
                    new ComparisonExpression(
                        ComparisonExpression.Type.EQUAL,
                        new LongLiteral("0"),
                        new LongLiteral("0")));
          }

          // Create new projections for the new join clauses
          ImmutableList.Builder<JoinNode.EquiJoinClause> builder = ImmutableList.builder();
          for (Expression conjunct : simplifiedJoinConjuncts) {
            checkState(
                joinEqualityExpression(node.getLeft().getOutputSymbols()).apply(conjunct),
                "Expected join predicate to be a valid join equality");

            ComparisonExpression equality = (ComparisonExpression) conjunct;

            // Orient the equality so that leftExpression is the operand whose symbols all come
            // from the left side
            boolean alignedComparison =
                Iterables.all(
                    DependencyExtractor.extractUnique(equality.getLeft()),
                    in(node.getLeft().getOutputSymbols()));
            Expression leftExpression =
                (alignedComparison) ? equality.getLeft() : equality.getRight();
            Expression rightExpression =
                (alignedComparison) ? equality.getRight() : equality.getLeft();

            // Project each operand to a fresh symbol on its own side so the join criterion can
            // be a plain symbol-to-symbol equality
            Symbol leftSymbol =
                symbolAllocator.newSymbol(leftExpression, extractType(leftExpression));
            leftProjections.put(leftSymbol, leftExpression);
            Symbol rightSymbol =
                symbolAllocator.newSymbol(rightExpression, extractType(rightExpression));
            rightProjections.put(rightSymbol, rightExpression);

            builder.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol));
          }

          leftSource =
              new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build());
          rightSource =
              new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build());
          criteria = builder.build();
        }
        output = new JoinNode(node.getId(), node.getType(), leftSource, rightSource, criteria);
      }
      // Whatever could not be pushed into the sources or the join criteria is applied by a
      // filter above the join
      if (!postJoinPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
        output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate);
      }
      return output;
    }