/**
 * Pushes {@code inheritedPredicate} through a mark-distinct node by delegating to the default
 * rewrite, after verifying that the predicate does not reference the node's marker symbol.
 *
 * @throws IllegalStateException if the predicate references the marker symbol
 */
@Override
public PlanNode rewriteMarkDistinct(
        MarkDistinctNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
    Set<Symbol> referencedSymbols = DependencyExtractor.extractUnique(inheritedPredicate);
    boolean referencesMarker = referencedSymbols.contains(node.getMarkerSymbol());
    // A predicate that reads the marker symbol is treated as an internal planner error.
    checkState(!referencesMarker, "predicate depends on marker symbol");
    return planRewriter.defaultRewrite(node, inheritedPredicate);
}
/**
 * Rewrites the filter's source so it still produces every symbol the filter predicate reads,
 * in addition to the symbols requested by the parent; the predicate itself is unchanged.
 */
@Override
public PlanNode visitFilter(FilterNode node, RewriteContext<Set<Symbol>> context) {
    Expression predicate = node.getPredicate();

    // Predicate inputs first, then the parent's requirements (same insertion order as before).
    ImmutableSet.Builder<Symbol> requiredFromSource = ImmutableSet.builder();
    requiredFromSource.addAll(DependencyExtractor.extractUnique(predicate));
    requiredFromSource.addAll(context.get());

    PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredFromSource.build());
    return new FilterNode(node.getId(), rewrittenSource, predicate);
}
/**
 * Drops window functions whose outputs the parent does not request, and rewrites the source to
 * produce only what is still needed: the parent's symbols, partitioning/ordering symbols, the
 * optional frame bounds and hash symbol, and the arguments of the surviving window functions.
 */
@Override
public PlanNode visitWindow(WindowNode node, RewriteContext<Set<Symbol>> context) {
    Set<Symbol> requestedOutputs = context.get();

    // Symbols the window operator needs regardless of which functions survive.
    ImmutableSet.Builder<Symbol> requiredFromSource = ImmutableSet.builder();
    requiredFromSource.addAll(requestedOutputs);
    requiredFromSource.addAll(node.getPartitionBy());
    requiredFromSource.addAll(node.getOrderBy());
    if (node.getFrame().getStartValue().isPresent()) {
        requiredFromSource.add(node.getFrame().getStartValue().get());
    }
    if (node.getFrame().getEndValue().isPresent()) {
        requiredFromSource.add(node.getFrame().getEndValue().get());
    }
    if (node.getHashSymbol().isPresent()) {
        requiredFromSource.add(node.getHashSymbol().get());
    }

    // Keep only referenced window functions, and require their argument symbols from the source.
    ImmutableMap.Builder<Symbol, Signature> keptSignatures = ImmutableMap.builder();
    ImmutableMap.Builder<Symbol, FunctionCall> keptCalls = ImmutableMap.builder();
    node.getWindowFunctions().forEach((output, call) -> {
        if (!requestedOutputs.contains(output)) {
            return;
        }
        requiredFromSource.addAll(DependencyExtractor.extractUnique(call));
        keptCalls.put(output, call);
        keptSignatures.put(output, node.getSignatures().get(output));
    });

    PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredFromSource.build());
    return new WindowNode(
            node.getId(),
            rewrittenSource,
            node.getPartitionBy(),
            node.getOrderBy(),
            node.getOrderings(),
            node.getFrame(),
            keptCalls.build(),
            keptSignatures.build(),
            node.getHashSymbol(),
            node.getPrePartitionedInputs(),
            node.getPreSortedOrderPrefix());
}
// TODO: temporary addition to infer result type from expression. fix this with the new planner // refactoring (martint) private Type extractType(Expression expression) { ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(new Analysis(), session, metadata, experimentalSyntaxEnabled); List<Field> fields = IterableTransformer.<Symbol>on(DependencyExtractor.extractUnique(expression)) .transform( new Function<Symbol, Field>() { @Override public Field apply(Symbol symbol) { return Field.newUnqualified( symbol.getName(), symbolAllocator.getTypes().get(symbol)); } }) .list(); return expressionAnalyzer.analyze( expression, new TupleDescriptor(fields), new AnalysisContext()); }
/**
 * Drops aggregations whose outputs the parent does not request (together with their masks), and
 * rewrites the source to produce only the group-by symbols, the optional hash and sample-weight
 * symbols, and the arguments and masks of the surviving aggregations.
 */
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Set<Symbol>> context) {
    Set<Symbol> requestedOutputs = context.get();

    ImmutableSet.Builder<Symbol> requiredFromSource = ImmutableSet.builder();
    requiredFromSource.addAll(node.getGroupBy());
    if (node.getHashSymbol().isPresent()) {
        requiredFromSource.add(node.getHashSymbol().get());
    }

    ImmutableMap.Builder<Symbol, Signature> keptSignatures = ImmutableMap.builder();
    ImmutableMap.Builder<Symbol, FunctionCall> keptCalls = ImmutableMap.builder();
    ImmutableMap.Builder<Symbol, Symbol> keptMasks = ImmutableMap.builder();
    node.getAggregations().forEach((output, call) -> {
        if (!requestedOutputs.contains(output)) {
            return;
        }
        requiredFromSource.addAll(DependencyExtractor.extractUnique(call));
        // A surviving aggregation keeps its mask, so the mask symbol must come from the source.
        if (node.getMasks().containsKey(output)) {
            Symbol mask = node.getMasks().get(output);
            requiredFromSource.add(mask);
            keptMasks.put(output, mask);
        }
        keptCalls.put(output, call);
        keptSignatures.put(output, node.getFunctions().get(output));
    });

    if (node.getSampleWeight().isPresent()) {
        requiredFromSource.add(node.getSampleWeight().get());
    }

    PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredFromSource.build());
    return new AggregationNode(
            node.getId(),
            rewrittenSource,
            node.getGroupBy(),
            keptCalls.build(),
            keptSignatures.build(),
            keptMasks.build(),
            node.getStep(),
            node.getSampleWeight(),
            node.getConfidence(),
            node.getHashSymbol());
}
/**
 * Keeps only the assignments whose outputs the parent requests, and rewrites the source to
 * produce just the symbols those remaining assignment expressions read.
 */
@Override
public PlanNode visitProject(ProjectNode node, RewriteContext<Set<Symbol>> context) {
    ImmutableSet.Builder<Symbol> requiredFromSource = ImmutableSet.builder();
    ImmutableMap.Builder<Symbol, Expression> keptAssignments = ImmutableMap.builder();

    // Walk outputs in declaration order so the surviving assignments keep their relative order.
    for (Symbol output : node.getOutputSymbols()) {
        if (!context.get().contains(output)) {
            continue;
        }
        Expression assignment = node.getAssignments().get(output);
        requiredFromSource.addAll(DependencyExtractor.extractUnique(assignment));
        keptAssignments.put(output, assignment);
    }

    PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredFromSource.build());
    return new ProjectNode(node.getId(), rewrittenSource, keptAssignments.build());
}
/**
 * Analyzes {@code expressions} against a synthetic tuple with one unqualified field per symbol
 * the expressions reference; each field's type is looked up in {@code types}.
 *
 * @throws IllegalArgumentException if a referenced symbol has no entry in {@code types}
 */
public static ExpressionAnalysis analyzeExpressionsWithSymbols(
        Session session,
        Metadata metadata,
        SqlParser sqlParser,
        Map<Symbol, Type> types,
        Iterable<? extends Expression> expressions) {
    List<Field> symbolFields = DependencyExtractor.extractUnique(expressions).stream()
            .map(symbol -> {
                Type symbolType = types.get(symbol);
                checkArgument(symbolType != null, "No type for symbol %s", symbol);
                return Field.newUnqualified(symbol.getName(), symbolType);
            })
            .collect(toImmutableList());

    TupleDescriptor inputTuple = new TupleDescriptor(symbolFields);
    return analyzeExpressions(session, metadata, sqlParser, inputTuple, expressions);
}
/**
 * Distributes the predicates around an inner join. The result has four parts:
 * <ul>
 *   <li>conjuncts that can be pushed into the left source;</li>
 *   <li>conjuncts that can be pushed into the right source;</li>
 *   <li>equality conjuncts that stay as the join predicate;</li>
 *   <li>all remaining conjuncts, which become a post-join filter.</li>
 * </ul>
 *
 * @param inheritedPredicate predicate pushed down from above the join
 * @param leftEffectivePredicate effective predicate of the left source; must reference only
 *     {@code leftSymbols}
 * @param rightEffectivePredicate effective predicate of the right source; must not reference any
 *     of {@code leftSymbols}
 * @param joinPredicate the join's current predicate
 * @param leftSymbols symbols belonging to the left side; any symbol not in this collection is
 *     treated as belonging to the right side
 * @return the left, right, join, and post-join predicates, each combined into a single conjunction
 * @throws IllegalArgumentException if either effective predicate violates the symbol
 *     constraints above
 */
private InnerJoinPushDownResult processInnerJoin(
        Expression inheritedPredicate,
        Expression leftEffectivePredicate,
        Expression rightEffectivePredicate,
        Expression joinPredicate,
        Collection<Symbol> leftSymbols) {
    checkArgument(
            Iterables.all(DependencyExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)),
            "leftEffectivePredicate must only contain symbols from leftSymbols");
    checkArgument(
            Iterables.all(
                    DependencyExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))),
            "rightEffectivePredicate must not contain symbols from leftSymbols");

    ImmutableList.Builder<Expression> leftPushDownConjuncts = ImmutableList.builder();
    ImmutableList.Builder<Expression> rightPushDownConjuncts = ImmutableList.builder();
    ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder();

    // Strip out non-deterministic conjuncts. Those from the inherited and join predicates are
    // kept at the join (never pushed down); everything below works only on deterministic parts.
    joinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(deterministic())));
    inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate);

    joinConjuncts.addAll(filter(extractConjuncts(joinPredicate), not(deterministic())));
    joinPredicate = stripNonDeterministicConjuncts(joinPredicate);

    leftEffectivePredicate = stripNonDeterministicConjuncts(leftEffectivePredicate);
    rightEffectivePredicate = stripNonDeterministicConjuncts(rightEffectivePredicate);

    // Generate equality inferences.
    // The two "Without...Inferred" variants each omit one side's effective predicate. They are
    // used only when emitting that same side's equalities below. Presumably this avoids pushing
    // equalities into a side that it already guarantees itself (TODO confirm).
    EqualityInference allInference =
            createEqualityInference(
                    inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate);
    EqualityInference allInferenceWithoutLeftInferred =
            createEqualityInference(inheritedPredicate, rightEffectivePredicate, joinPredicate);
    EqualityInference allInferenceWithoutRightInferred =
            createEqualityInference(inheritedPredicate, leftEffectivePredicate, joinPredicate);

    // Sort through conjuncts in inheritedPredicate that were not used for inference.
    // rewriteExpression returns null when the conjunct cannot be expressed over the given scope.
    for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
        Expression leftRewrittenConjunct =
                allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (leftRewrittenConjunct != null) {
            leftPushDownConjuncts.add(leftRewrittenConjunct);
        }

        Expression rightRewrittenConjunct =
                allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rightRewrittenConjunct != null) {
            rightPushDownConjuncts.add(rightRewrittenConjunct);
        }

        // Drop predicate after join only if unable to push down to either side
        if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) {
            joinConjuncts.add(conjunct);
        }
    }

    // See if we can push the right effective predicate to the left side
    for (Expression conjunct : EqualityInference.nonInferrableConjuncts(rightEffectivePredicate)) {
        Expression rewritten = allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (rewritten != null) {
            leftPushDownConjuncts.add(rewritten);
        }
    }

    // See if we can push the left effective predicate to the right side
    for (Expression conjunct : EqualityInference.nonInferrableConjuncts(leftEffectivePredicate)) {
        Expression rewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rewritten != null) {
            rightPushDownConjuncts.add(rewritten);
        }
    }

    // See if we can push any parts of the join predicates to either side
    for (Expression conjunct : EqualityInference.nonInferrableConjuncts(joinPredicate)) {
        Expression leftRewritten = allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (leftRewritten != null) {
            leftPushDownConjuncts.add(leftRewritten);
        }

        Expression rightRewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rightRewritten != null) {
            rightPushDownConjuncts.add(rightRewritten);
        }

        // Keep the conjunct at the join only if neither side can absorb it
        if (leftRewritten == null && rightRewritten == null) {
            joinConjuncts.add(conjunct);
        }
    }

    // Add equalities from the inference back in
    leftPushDownConjuncts.addAll(
            allInferenceWithoutLeftInferred
                    .generateEqualitiesPartitionedBy(in(leftSymbols))
                    .getScopeEqualities());
    rightPushDownConjuncts.addAll(
            allInferenceWithoutRightInferred
                    .generateEqualitiesPartitionedBy(not(in(leftSymbols)))
                    .getScopeEqualities());
    // Scope-straddling equalities (one operand on each side) become part of the join predicate
    joinConjuncts.addAll(
            allInference
                    .generateEqualitiesPartitionedBy(in(leftSymbols))
                    .getScopeStraddlingEqualities());

    // Since we only currently support equality in join conjuncts, factor out the non-equality
    // conjuncts to a post-join filter
    List<Expression> joinConjunctsList = joinConjuncts.build();
    List<Expression> postJoinConjuncts =
            ImmutableList.copyOf(filter(joinConjunctsList, not(joinEqualityExpression(leftSymbols))));
    joinConjunctsList =
            ImmutableList.copyOf(filter(joinConjunctsList, joinEqualityExpression(leftSymbols)));

    return new InnerJoinPushDownResult(
            combineConjuncts(leftPushDownConjuncts.build()),
            combineConjuncts(rightPushDownConjuncts.build()),
            combineConjuncts(joinConjunctsList),
            combineConjuncts(postJoinConjuncts));
}
/**
 * Distributes the predicates around an outer join. The result has three parts:
 * <ul>
 *   <li>conjuncts that can be pushed into the outer (preserved) source;</li>
 *   <li>conjuncts that can be pushed into the inner source;</li>
 *   <li>conjuncts that must be evaluated after the join.</li>
 * </ul>
 * The join predicate itself is not rewritten here.
 *
 * @param inheritedPredicate predicate pushed down from above the join
 * @param outerEffectivePredicate effective predicate of the outer source; must reference only
 *     {@code outerSymbols}
 * @param innerEffectivePredicate effective predicate of the inner source; must not reference any
 *     of {@code outerSymbols}
 * @param joinPredicate the join's predicate (only read, never returned rewritten)
 * @param outerSymbols symbols belonging to the outer side; any symbol not in this collection is
 *     treated as belonging to the inner side
 * @return the outer, inner, and post-join predicates, each combined into a single conjunction
 * @throws IllegalArgumentException if either effective predicate violates the symbol
 *     constraints above
 */
private OuterJoinPushDownResult processOuterJoin(
        Expression inheritedPredicate,
        Expression outerEffectivePredicate,
        Expression innerEffectivePredicate,
        Expression joinPredicate,
        Collection<Symbol> outerSymbols) {
    checkArgument(
            Iterables.all(
                    DependencyExtractor.extractUnique(outerEffectivePredicate), in(outerSymbols)),
            "outerEffectivePredicate must only contain symbols from outerSymbols");
    checkArgument(
            Iterables.all(
                    DependencyExtractor.extractUnique(innerEffectivePredicate), not(in(outerSymbols))),
            "innerEffectivePredicate must not contain symbols from outerSymbols");

    ImmutableList.Builder<Expression> outerPushdownConjuncts = ImmutableList.builder();
    ImmutableList.Builder<Expression> innerPushdownConjuncts = ImmutableList.builder();
    ImmutableList.Builder<Expression> postJoinConjuncts = ImmutableList.builder();

    // Strip out non-deterministic conjuncts. Those from the inherited predicate are evaluated
    // after the join. The join predicate's non-deterministic conjuncts are only stripped for
    // inference here, not collected. NOTE(review): callers are expected to keep the original
    // join predicate on the node.
    postJoinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(deterministic())));
    inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate);

    outerEffectivePredicate = stripNonDeterministicConjuncts(outerEffectivePredicate);
    innerEffectivePredicate = stripNonDeterministicConjuncts(innerEffectivePredicate);
    joinPredicate = stripNonDeterministicConjuncts(joinPredicate);

    // Generate equality inferences
    EqualityInference inheritedInference = createEqualityInference(inheritedPredicate);
    EqualityInference outerInference =
            createEqualityInference(inheritedPredicate, outerEffectivePredicate);

    // Only the inherited equalities that lie entirely within the outer scope feed the inference
    // used to derive inner-side predicates. Presumably this is because inner-side symbols may be
    // null-padded by the outer join (cf. "potentialNullSymbol") — TODO confirm.
    EqualityInference.EqualityPartition equalityPartition =
            inheritedInference.generateEqualitiesPartitionedBy(in(outerSymbols));
    Expression outerOnlyInheritedEqualities =
            combineConjuncts(equalityPartition.getScopeEqualities());
    EqualityInference potentialNullSymbolInference =
            createEqualityInference(
                    outerOnlyInheritedEqualities,
                    outerEffectivePredicate,
                    innerEffectivePredicate,
                    joinPredicate);
    EqualityInference potentialNullSymbolInferenceWithoutInnerInferred =
            createEqualityInference(
                    outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate);

    // Sort through conjuncts in inheritedPredicate that were not used for inference.
    // rewriteExpression returns null when the conjunct cannot be expressed over the given scope.
    for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
        Expression outerRewritten = outerInference.rewriteExpression(conjunct, in(outerSymbols));
        if (outerRewritten != null) {
            outerPushdownConjuncts.add(outerRewritten);

            // A conjunct can only be pushed down into an inner side if it can be rewritten in terms
            // of the outer side
            Expression innerRewritten =
                    potentialNullSymbolInference.rewriteExpression(outerRewritten, not(in(outerSymbols)));
            if (innerRewritten != null) {
                innerPushdownConjuncts.add(innerRewritten);
            }
        } else {
            // Not expressible over the outer side alone: must be evaluated after the join
            postJoinConjuncts.add(conjunct);
        }
    }

    // See if we can push down any outer or join predicates to the inner side
    for (Expression conjunct :
            EqualityInference.nonInferrableConjuncts(and(outerEffectivePredicate, joinPredicate))) {
        Expression rewritten =
                potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerSymbols)));
        if (rewritten != null) {
            innerPushdownConjuncts.add(rewritten);
        }
    }

    // TODO: consider adding join predicate optimizations to outer joins

    // Add the equalities from the inferences back in
    outerPushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
    postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
    postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
    innerPushdownConjuncts.addAll(
            potentialNullSymbolInferenceWithoutInnerInferred
                    .generateEqualitiesPartitionedBy(not(in(outerSymbols)))
                    .getScopeEqualities());

    return new OuterJoinPushDownResult(
            combineConjuncts(outerPushdownConjuncts.build()),
            combineConjuncts(innerPushdownConjuncts.build()),
            combineConjuncts(postJoinConjuncts.build()));
}
/**
 * Pushes {@code inheritedPredicate} and inferred predicates into the join's sources.
 *
 * <p>For inner joins it may also replace the join criteria. New equi-join clauses are backed by
 * projections added on each side. Any residual conjuncts are wrapped in a post-join
 * {@code FilterNode}.
 *
 * @throws UnsupportedOperationException if, after normalization, the join type is not INNER,
 *     LEFT or RIGHT
 */
@Override
public PlanNode rewriteJoin(
        JoinNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
    // Captured before normalization because `node` is reassigned below. The switch has no CROSS
    // case, so presumably tryNormalizeToInnerJoin maps CROSS to a handled type (TODO confirm).
    boolean isCrossJoin = (node.getType() == JoinNode.Type.CROSS);

    // See if we can rewrite outer joins in terms of a plain inner join
    node = tryNormalizeToInnerJoin(node, inheritedPredicate);

    Expression leftEffectivePredicate = EffectivePredicateExtractor.extract(node.getLeft());
    Expression rightEffectivePredicate = EffectivePredicateExtractor.extract(node.getRight());
    Expression joinPredicate = extractJoinPredicate(node);

    Expression leftPredicate;
    Expression rightPredicate;
    Expression postJoinPredicate;
    Expression newJoinPredicate;

    switch (node.getType()) {
        case INNER:
            InnerJoinPushDownResult innerJoinPushDownResult =
                    processInnerJoin(
                            inheritedPredicate,
                            leftEffectivePredicate,
                            rightEffectivePredicate,
                            joinPredicate,
                            node.getLeft().getOutputSymbols());
            leftPredicate = innerJoinPushDownResult.getLeftPredicate();
            rightPredicate = innerJoinPushDownResult.getRightPredicate();
            postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
            newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
            break;
        case LEFT:
            // The left side is the outer (preserved) side
            OuterJoinPushDownResult leftOuterJoinPushDownResult =
                    processOuterJoin(
                            inheritedPredicate,
                            leftEffectivePredicate,
                            rightEffectivePredicate,
                            joinPredicate,
                            node.getLeft().getOutputSymbols());
            leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate();
            rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate();
            postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate();
            newJoinPredicate = joinPredicate; // Use the same as the original
            break;
        case RIGHT:
            // The right side is the outer side, so the arguments and results are mirrored
            OuterJoinPushDownResult rightOuterJoinPushDownResult =
                    processOuterJoin(
                            inheritedPredicate,
                            rightEffectivePredicate,
                            leftEffectivePredicate,
                            joinPredicate,
                            node.getRight().getOutputSymbols());
            leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate();
            rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate();
            postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate();
            newJoinPredicate = joinPredicate; // Use the same as the original
            break;
        default:
            throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
    }

    PlanNode leftSource = planRewriter.rewrite(node.getLeft(), leftPredicate);
    PlanNode rightSource = planRewriter.rewrite(node.getRight(), rightPredicate);

    PlanNode output = node;
    // Rebuild the join only if a source instance or the join predicate changed (identity check
    // on the sources).
    // NOTE(review): the `isCrossJoin` disjunct in the inner `if` is reached only when this outer
    // condition holds. If a cross join's sources come back as the same instances and the predicate
    // is unchanged, `node` is returned without the fake equi-clause. Confirm that cannot happen.
    if (leftSource != node.getLeft()
            || rightSource != node.getRight()
            || !newJoinPredicate.equals(joinPredicate)) {
        List<JoinNode.EquiJoinClause> criteria = node.getCriteria();

        // Rewrite criteria and add projections if there is a new join predicate
        if (!newJoinPredicate.equals(joinPredicate) || isCrossJoin) {
            // Create identity projections for all existing symbols
            ImmutableMap.Builder<Symbol, Expression> leftProjections = ImmutableMap.builder();
            leftProjections.putAll(
                    IterableTransformer.<Symbol>on(node.getLeft().getOutputSymbols())
                            .toMap(symbolToQualifiedNameReference())
                            .map());
            ImmutableMap.Builder<Symbol, Expression> rightProjections = ImmutableMap.builder();
            rightProjections.putAll(
                    IterableTransformer.<Symbol>on(node.getRight().getOutputSymbols())
                            .toMap(symbolToQualifiedNameReference())
                            .map());

            // HACK! we don't support cross joins right now, so put in a simple fake join predicate
            // instead if all of the join clauses got simplified out
            // TODO: remove this code when cross join support is added
            Iterable<Expression> simplifiedJoinConjuncts =
                    transform(extractConjuncts(newJoinPredicate), simplifyExpressions());
            simplifiedJoinConjuncts =
                    filter(
                            simplifiedJoinConjuncts,
                            not(Predicates.<Expression>equalTo(BooleanLiteral.TRUE_LITERAL)));
            if (Iterables.isEmpty(simplifiedJoinConjuncts)) {
                // Guarantee at least one equi-join clause: the always-true equality 0 = 0
                simplifiedJoinConjuncts =
                        ImmutableList.<Expression>of(
                                new ComparisonExpression(
                                        ComparisonExpression.Type.EQUAL,
                                        new LongLiteral("0"),
                                        new LongLiteral("0")));
            }

            // Create new projections for the new join clauses
            ImmutableList.Builder<JoinNode.EquiJoinClause> builder = ImmutableList.builder();
            for (Expression conjunct : simplifiedJoinConjuncts) {
                checkState(
                        joinEqualityExpression(node.getLeft().getOutputSymbols()).apply(conjunct),
                        "Expected join predicate to be a valid join equality");

                ComparisonExpression equality = (ComparisonExpression) conjunct;

                // Orient the equality: if its left operand does not read only left-side symbols,
                // swap the operands so each projection gets the expression for its own side
                boolean alignedComparison =
                        Iterables.all(
                                DependencyExtractor.extractUnique(equality.getLeft()),
                                in(node.getLeft().getOutputSymbols()));
                Expression leftExpression =
                        (alignedComparison) ? equality.getLeft() : equality.getRight();
                Expression rightExpression =
                        (alignedComparison) ? equality.getRight() : equality.getLeft();

                // Materialize each operand as a fresh projected symbol on its side
                Symbol leftSymbol =
                        symbolAllocator.newSymbol(leftExpression, extractType(leftExpression));
                leftProjections.put(leftSymbol, leftExpression);
                Symbol rightSymbol =
                        symbolAllocator.newSymbol(rightExpression, extractType(rightExpression));
                rightProjections.put(rightSymbol, rightExpression);

                builder.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol));
            }

            leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build());
            rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build());
            criteria = builder.build();
        }
        output = new JoinNode(node.getId(), node.getType(), leftSource, rightSource, criteria);
    }
    // Residual conjuncts that could not be pushed down or used as join criteria
    if (!postJoinPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
        output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate);
    }
    return output;
}