public void onMatch(RelOptRuleCall call) { AggregateRel aggRel = call.rel(0); UnionRel unionRel = call.rel(1); if (!unionRel.all) { // This transformation is only valid for UNION ALL. // Consider t1(i) with rows (5), (5) and t2(i) with // rows (5), (10), and the query // select sum(i) from (select i from t1) union (select i from t2). // The correct answer is 15. If we apply the transformation, // we get // select sum(i) from // (select sum(i) as i from t1) union (select sum(i) as i from t2) // which yields 25 (incorrect). return; } RelOptCluster cluster = unionRel.getCluster(); List<AggregateCall> transformedAggCalls = transformAggCalls( aggRel.getCluster().getTypeFactory(), aggRel.getGroupSet().cardinality(), aggRel.getAggCallList()); if (transformedAggCalls == null) { // we've detected the presence of something like AVG, // which we can't handle return; } boolean anyTransformed = false; // create corresponding aggs on top of each union child List<RelNode> newUnionInputs = new ArrayList<RelNode>(); for (RelNode input : unionRel.getInputs()) { boolean alreadyUnique = RelMdUtil.areColumnsDefinitelyUnique(input, aggRel.getGroupSet()); if (alreadyUnique) { newUnionInputs.add(input); } else { anyTransformed = true; newUnionInputs.add( new AggregateRel(cluster, input, aggRel.getGroupSet(), aggRel.getAggCallList())); } } if (!anyTransformed) { // none of the children could benefit from the pushdown, // so bail out (preventing the infinite loop to which most // planners would succumb) return; } // create a new union whose children are the aggs created above UnionRel newUnionRel = new UnionRel(cluster, newUnionInputs, true); AggregateRel newTopAggRel = new AggregateRel(cluster, newUnionRel, aggRel.getGroupSet(), transformedAggCalls); // In case we transformed any COUNT (which is always NOT NULL) // to SUM (which is always NULLABLE), cast back to keep the // planner happy. RelNode castRel = RelOptUtil.createCastRel(newTopAggRel, aggRel.getRowType(), false); call.transformTo(castRel); }
/** Variant of {@link #trimFields(RelNode, BitSet, Set)} for {@link AggregateRel}. */ public TrimResult trimFields( AggregateRel aggregate, BitSet fieldsUsed, Set<RelDataTypeField> extraFields) { // Fields: // // | sys fields | group fields | agg functions | // // Two kinds of trimming: // // 1. If agg rel has system fields but none of these are used, create an // agg rel with no system fields. // // 2. If aggregate functions are not used, remove them. // // But grouping fields stay, even if they are not used. final RelDataType rowType = aggregate.getRowType(); // Compute which input fields are used. BitSet inputFieldsUsed = new BitSet(); // 1. group fields are always used for (int i : Util.toIter(aggregate.getGroupSet())) { inputFieldsUsed.set(i); } // 2. agg functions for (AggregateCall aggCall : aggregate.getAggCallList()) { for (int i : aggCall.getArgList()) { inputFieldsUsed.set(i); } } // Create input with trimmed columns. final RelNode input = aggregate.getInput(0); final Set<RelDataTypeField> inputExtraFields = Collections.emptySet(); final TrimResult trimResult = trimChild(aggregate, input, inputFieldsUsed, inputExtraFields); final RelNode newInput = trimResult.left; final Mapping inputMapping = trimResult.right; // If the input is unchanged, and we need to project all columns, // there's nothing to do. if (input == newInput && fieldsUsed.equals(Util.bitSetBetween(0, rowType.getFieldCount()))) { return new TrimResult(aggregate, Mappings.createIdentity(rowType.getFieldCount())); } // Which agg calls are used by our consumer? final int groupCount = aggregate.getGroupSet().cardinality(); int j = groupCount; int usedAggCallCount = 0; for (int i = 0; i < aggregate.getAggCallList().size(); i++) { if (fieldsUsed.get(j++)) { ++usedAggCallCount; } } // Offset due to the number of system fields having changed. Mapping mapping = Mappings.create( MappingType.InverseSurjection, rowType.getFieldCount(), groupCount + usedAggCallCount); final BitSet newGroupSet = Mappings.apply(inputMapping, aggregate.getGroupSet()); // Populate mapping of where to find the fields. System and grouping // fields first. for (IntPair pair : inputMapping) { if (pair.source < groupCount) { mapping.set(pair.source, pair.target); } } // Now create new agg calls, and populate mapping for them. final List<AggregateCall> newAggCallList = new ArrayList<AggregateCall>(); j = groupCount; for (AggregateCall aggCall : aggregate.getAggCallList()) { if (fieldsUsed.get(j)) { AggregateCall newAggCall = new AggregateCall( aggCall.getAggregation(), aggCall.isDistinct(), Mappings.apply2(inputMapping, aggCall.getArgList()), aggCall.getType(), aggCall.getName()); if (newAggCall.equals(aggCall)) { newAggCall = aggCall; // immutable -> canonize to save space } mapping.set(j, groupCount + newAggCallList.size()); newAggCallList.add(newAggCall); } ++j; } RelNode newAggregate = new AggregateRel(aggregate.getCluster(), newInput, newGroupSet, newAggCallList); assert newAggregate.getClass() == aggregate.getClass(); return new TrimResult(newAggregate, mapping); }