Beispiel #1
0
    /**
     * Distributes the inherited predicate around a semi join.
     *
     * <p>Three buckets of conjuncts are built: those handed to the source, those handed to the
     * filtering source, and those that stay in a filter placed on top of the join. Both children
     * are then rewritten with their respective combined conjuncts.
     *
     * @param node the semi join being rewritten
     * @param inheritedPredicate the predicate arriving from above this node
     * @param planRewriter rewriter used to continue into the source and the filtering source
     * @return {@code node} itself when neither child changed and no conjunct has to stay above
     *     the join; otherwise a new {@code SemiJoinNode} and/or a wrapping {@code FilterNode}
     */
    @Override
    public PlanNode rewriteSemiJoin(
        SemiJoinNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      PlanNode source = node.getSource();
      PlanNode filteringSource = node.getFilteringSource();
      Expression sourceEffectivePredicate = EffectivePredicateExtractor.extract(source);

      List<Expression> pushToSource = new ArrayList<>();
      List<Expression> pushToFilteringSource = new ArrayList<>();
      List<Expression> keepAboveJoin = new ArrayList<>();

      // TODO: see if there are predicates that can be inferred from the semi join output

      // A contrived "source symbol = filtering symbol" predicate lets inherited and source
      // predicates be re-expressed on the filtering side. Rows whose filtering join symbol is
      // NULL must not be affected, hence the OR ... IS NULL wrapping below.
      Expression contrivedJoinPredicate =
          equalsExpression(node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol());
      EqualityInference joinInference =
          createEqualityInference(
              inheritedPredicate, sourceEffectivePredicate, contrivedJoinPredicate);
      for (Expression candidate :
          Iterables.concat(
              EqualityInference.nonInferrableConjuncts(inheritedPredicate),
              EqualityInference.nonInferrableConjuncts(sourceEffectivePredicate))) {
        Expression onFilteringSide =
            joinInference.rewriteExpression(candidate, equalTo(node.getFilteringSourceJoinSymbol()));
        // Only deterministic conjuncts that could be rewritten go to the filtering source
        if (onFilteringSide == null || !DeterminismEvaluator.isDeterministic(onFilteringSide)) {
          continue;
        }
        // Extend the conjunct with an OR filteringSourceJoinSymbol IS NULL disjunct
        pushToFilteringSource.add(
            expressionOrNullSymbols(equalTo(node.getFilteringSourceJoinSymbol()))
                .apply(onFilteringSide));
      }
      EqualityInference.EqualityPartition filteringSidePartition =
          joinInference.generateEqualitiesPartitionedBy(
              equalTo(node.getFilteringSourceJoinSymbol()));
      List<Expression> filteringSideEqualities =
          ImmutableList.copyOf(
              transform(
                  filteringSidePartition.getScopeEqualities(),
                  expressionOrNullSymbols(equalTo(node.getFilteringSourceJoinSymbol()))));
      pushToFilteringSource.addAll(filteringSideEqualities);

      // Inherited conjuncts go down to the source unless they involve the semi join output
      EqualityInference inheritedInference = createEqualityInference(inheritedPredicate);
      for (Expression candidate : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
        Expression onSourceSide =
            inheritedInference.rewriteExpression(candidate, in(source.getOutputSymbols()));
        if (onSourceSide == null) {
          keepAboveJoin.add(candidate);
        } else {
          // Every source row shows up exactly once in the output, so even non-deterministic
          // predicates are fine to push down
          pushToSource.add(onSourceSide);
        }
      }

      // Re-introduce the equalities carried by the inherited predicate
      EqualityInference.EqualityPartition sourceSidePartition =
          inheritedInference.generateEqualitiesPartitionedBy(in(source.getOutputSymbols()));
      pushToSource.addAll(sourceSidePartition.getScopeEqualities());
      keepAboveJoin.addAll(sourceSidePartition.getScopeComplementEqualities());
      keepAboveJoin.addAll(sourceSidePartition.getScopeStraddlingEqualities());

      PlanNode rewrittenSource = planRewriter.rewrite(source, combineConjuncts(pushToSource));
      PlanNode rewrittenFilteringSource =
          planRewriter.rewrite(filteringSource, combineConjuncts(pushToFilteringSource));

      // Reuse the existing node when rewriting left both children untouched
      boolean childrenUntouched =
          rewrittenSource == source && rewrittenFilteringSource == filteringSource;
      PlanNode result =
          childrenUntouched
              ? node
              : new SemiJoinNode(
                  node.getId(),
                  rewrittenSource,
                  rewrittenFilteringSource,
                  node.getSourceJoinSymbol(),
                  node.getFilteringSourceJoinSymbol(),
                  node.getSemiJoinOutput());

      if (keepAboveJoin.isEmpty()) {
        return result;
      }
      return new FilterNode(idAllocator.getNextId(), result, combineConjuncts(keepAboveJoin));
    }
Beispiel #2
0
    /**
     * Distributes the inherited predicate around a join.
     *
     * <p>The node is first passed through {@code tryNormalizeToInnerJoin}. Depending on the
     * resulting join type, the inherited, effective and join predicates are split into a
     * predicate for the left child, one for the right child, one that stays above the join, and
     * (for inner joins only) a possibly different join predicate. Outer joins keep their original
     * join predicate. When the join predicate changed, or the node was a cross join on entry, the
     * join criteria are rebuilt on top of new projections over both children.
     *
     * @param node the join being rewritten
     * @param inheritedPredicate the predicate arriving from above this node
     * @param planRewriter rewriter used to continue into the left and right children
     * @return the (normalized) join itself when nothing changed; otherwise a new
     *     {@code JoinNode}, wrapped in a {@code FilterNode} if a non-TRUE predicate has to stay
     *     above the join
     * @throws UnsupportedOperationException if the join type after normalization is not INNER,
     *     LEFT or RIGHT
     */
    @Override
    public PlanNode rewriteJoin(
        JoinNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      // Must be captured before normalization replaces the node
      boolean wasCrossJoin = (node.getType() == JoinNode.Type.CROSS);

      // Try to express outer joins as a plain inner join
      JoinNode join = tryNormalizeToInnerJoin(node, inheritedPredicate);

      PlanNode originalLeft = join.getLeft();
      PlanNode originalRight = join.getRight();
      Expression leftEffective = EffectivePredicateExtractor.extract(originalLeft);
      Expression rightEffective = EffectivePredicateExtractor.extract(originalRight);
      Expression originalJoinPredicate = extractJoinPredicate(join);

      Expression pushToLeft;
      Expression pushToRight;
      Expression keepAboveJoin;
      Expression updatedJoinPredicate;

      switch (join.getType()) {
        case INNER: {
          InnerJoinPushDownResult inner =
              processInnerJoin(
                  inheritedPredicate,
                  leftEffective,
                  rightEffective,
                  originalJoinPredicate,
                  originalLeft.getOutputSymbols());
          pushToLeft = inner.getLeftPredicate();
          pushToRight = inner.getRightPredicate();
          keepAboveJoin = inner.getPostJoinPredicate();
          updatedJoinPredicate = inner.getJoinPredicate();
          break;
        }
        case LEFT: {
          // The left child is the outer side
          OuterJoinPushDownResult leftOuter =
              processOuterJoin(
                  inheritedPredicate,
                  leftEffective,
                  rightEffective,
                  originalJoinPredicate,
                  originalLeft.getOutputSymbols());
          pushToLeft = leftOuter.getOuterJoinPredicate();
          pushToRight = leftOuter.getInnerJoinPredicate();
          keepAboveJoin = leftOuter.getPostJoinPredicate();
          updatedJoinPredicate = originalJoinPredicate; // outer joins keep the original
          break;
        }
        case RIGHT: {
          // The right child is the outer side, so the effective predicates swap places
          OuterJoinPushDownResult rightOuter =
              processOuterJoin(
                  inheritedPredicate,
                  rightEffective,
                  leftEffective,
                  originalJoinPredicate,
                  originalRight.getOutputSymbols());
          pushToLeft = rightOuter.getInnerJoinPredicate();
          pushToRight = rightOuter.getOuterJoinPredicate();
          keepAboveJoin = rightOuter.getPostJoinPredicate();
          updatedJoinPredicate = originalJoinPredicate; // outer joins keep the original
          break;
        }
        default:
          throw new UnsupportedOperationException("Unsupported join type: " + join.getType());
      }

      PlanNode leftSource = planRewriter.rewrite(originalLeft, pushToLeft);
      PlanNode rightSource = planRewriter.rewrite(originalRight, pushToRight);

      boolean joinPredicateChanged = !updatedJoinPredicate.equals(originalJoinPredicate);
      boolean childrenChanged = leftSource != originalLeft || rightSource != originalRight;

      PlanNode output = join;
      if (childrenChanged || joinPredicateChanged) {
        List<JoinNode.EquiJoinClause> criteria = join.getCriteria();

        // A new join predicate requires new criteria plus projections that feed them
        if (joinPredicateChanged || wasCrossJoin) {
          // Start from identity projections of every symbol each child already produces
          ImmutableMap.Builder<Symbol, Expression> leftProjections =
              ImmutableMap.<Symbol, Expression>builder()
                  .putAll(
                      IterableTransformer.<Symbol>on(originalLeft.getOutputSymbols())
                          .toMap(symbolToQualifiedNameReference())
                          .map());
          ImmutableMap.Builder<Symbol, Expression> rightProjections =
              ImmutableMap.<Symbol, Expression>builder()
                  .putAll(
                      IterableTransformer.<Symbol>on(originalRight.getOutputSymbols())
                          .toMap(symbolToQualifiedNameReference())
                          .map());

          // HACK! cross joins are not supported right now, so if every join clause was
          // simplified away, substitute a trivial fake join predicate (0 = 0)
          // TODO: remove this code when cross join support is added
          Iterable<Expression> simplified =
              transform(extractConjuncts(updatedJoinPredicate), simplifyExpressions());
          Iterable<Expression> joinConjuncts =
              filter(simplified, not(Predicates.<Expression>equalTo(BooleanLiteral.TRUE_LITERAL)));
          if (Iterables.isEmpty(joinConjuncts)) {
            joinConjuncts =
                ImmutableList.<Expression>of(
                    new ComparisonExpression(
                        ComparisonExpression.Type.EQUAL,
                        new LongLiteral("0"),
                        new LongLiteral("0")));
          }

          // Each join conjunct gets a fresh symbol on either side and becomes one clause
          ImmutableList.Builder<JoinNode.EquiJoinClause> clauses = ImmutableList.builder();
          for (Expression joinConjunct : joinConjuncts) {
            checkState(
                joinEqualityExpression(originalLeft.getOutputSymbols()).apply(joinConjunct),
                "Expected join predicate to be a valid join equality");

            ComparisonExpression equality = (ComparisonExpression) joinConjunct;

            // Work out which operand depends only on left-child symbols
            Expression leftExpression;
            Expression rightExpression;
            if (Iterables.all(
                DependencyExtractor.extractUnique(equality.getLeft()),
                in(originalLeft.getOutputSymbols()))) {
              leftExpression = equality.getLeft();
              rightExpression = equality.getRight();
            } else {
              leftExpression = equality.getRight();
              rightExpression = equality.getLeft();
            }

            Symbol leftSymbol =
                symbolAllocator.newSymbol(leftExpression, extractType(leftExpression));
            leftProjections.put(leftSymbol, leftExpression);
            Symbol rightSymbol =
                symbolAllocator.newSymbol(rightExpression, extractType(rightExpression));
            rightProjections.put(rightSymbol, rightExpression);

            clauses.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol));
          }

          leftSource =
              new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build());
          rightSource =
              new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build());
          criteria = clauses.build();
        }
        output = new JoinNode(join.getId(), join.getType(), leftSource, rightSource, criteria);
      }

      // Nothing left to filter above the join when the remaining predicate is literally TRUE
      if (keepAboveJoin.equals(BooleanLiteral.TRUE_LITERAL)) {
        return output;
      }
      return new FilterNode(idAllocator.getNextId(), output, keepAboveJoin);
    }