/**
 * Infers predicates for an Aggregate.
 *
 * <p>Pulls up predicates that only contain references to columns in the GroupSet. For example:
 *
 * <pre>
 * childPullUpExprs : { a > 7, b + c < 10, a + e = 9}
 * groupSet         : { a, b}
 * pulledUpExprs    : { a > 7}
 * </pre>
 */
public RelOptPredicateList getPredicates(Aggregate agg) {
  RelNode child = agg.getInput();
  RelOptPredicateList childInfo = RelMetadataQuery.getPulledUpPredicates(child);

  List<RexNode> aggPullUpPredicates = new ArrayList<RexNode>();

  ImmutableBitSet groupKeys = agg.getGroupSet();
  Mapping m = Mappings.create(
      MappingType.PARTIAL_FUNCTION,
      child.getRowType().getFieldCount(),
      agg.getRowType().getFieldCount());

  int i = 0;
  for (int j : groupKeys) {
    m.set(j, i++);
  }

  for (RexNode r : childInfo.pulledUpPredicates) {
    ImmutableBitSet rCols = RelOptUtil.InputFinder.bits(r);
    if (groupKeys.contains(rCols)) {
      r = r.accept(new RexPermuteInputsShuttle(m, child));
      aggPullUpPredicates.add(r);
    }
  }
  return RelOptPredicateList.of(aggPullUpPredicates);
}
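/**
 * Hedged, illustrative sketch (not part of the original code; column positions are made up):
 * shows the two pieces the method above relies on, using Calcite's Mappings and ImmutableBitSet
 * utilities directly. Group keys {2, 5} of a 7-column child collapse to Aggregate output
 * positions {0, 1}, and only predicates whose referenced columns are contained in the group set
 * survive the filter.
 */
static void aggregatePredicatePullUpSketch() {
  ImmutableBitSet groupKeys = ImmutableBitSet.of(2, 5);
  Mapping m = Mappings.create(MappingType.PARTIAL_FUNCTION, 7, 2);
  int i = 0;
  for (int j : groupKeys) {
    m.set(j, i++); // 2 -> 0, 5 -> 1
  }
  // A child predicate over column 2 only passes the groupKeys.contains check and would then be
  // remapped through RexPermuteInputsShuttle to reference output position 0 ...
  assert groupKeys.contains(ImmutableBitSet.of(2));
  // ... while a predicate that also touches non-group column 6 is not pulled up.
  assert !groupKeys.contains(ImmutableBitSet.of(2, 6));
}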
public static List<ExprNodeDesc> getExprNodes(
    List<Integer> inputRefs, RelNode inputRel, String inputTabAlias) {
  List<ExprNodeDesc> exprNodes = new ArrayList<ExprNodeDesc>();
  List<RexNode> rexInputRefs = getInputRef(inputRefs, inputRel);
  List<RexNode> exprs = inputRel.getChildExps();
  // TODO: Change ExprNodeConverter to be independent of Partition Expr
  ExprNodeConverter exprConv = new ExprNodeConverter(
      inputTabAlias,
      inputRel.getRowType(),
      new HashSet<Integer>(),
      inputRel.getCluster().getTypeFactory());
  for (int index = 0; index < rexInputRefs.size(); index++) {
    // The following check is only a guard against failures.
    // TODO: Knowing which expr is constant in GBY's aggregation function
    // arguments could be better done using Metadata provider of Calcite.
    if (exprs != null && index < exprs.size() && exprs.get(index) instanceof RexLiteral) {
      ExprNodeDesc exprNodeDesc = exprConv.visitLiteral((RexLiteral) exprs.get(index));
      exprNodes.add(exprNodeDesc);
    } else {
      RexNode iRef = rexInputRefs.get(index);
      exprNodes.add(iRef.accept(exprConv));
    }
  }
  return exprNodes;
}
public static ExprNodeDesc getExprNode(
    Integer inputRefIndx, RelNode inputRel, ExprNodeConverter exprConv) {
  RexNode rexInputRef = new RexInputRef(
      inputRefIndx,
      inputRel.getRowType().getFieldList().get(inputRefIndx).getType());
  return rexInputRef.accept(exprConv);
}
@Override
public String translate(ExprCompiler compiler, RexCall call) {
  String val = compiler.reserveName();
  PrintWriter pw = compiler.pw;
  RexNode op = call.getOperands().get(0);
  String lhs = op.accept(compiler);
  pw.print(String.format(
      "final %1$s %2$s = (%1$s) %3$s;\n", compiler.javaTypeName(call), val, lhs));
  return val;
}
public static boolean isDeterministicFuncOnLiterals(RexNode expr) {
  boolean deterministicFuncOnLiterals = true;

  RexVisitor<Void> visitor = new RexVisitorImpl<Void>(true) {
    @Override
    public Void visitCall(org.apache.calcite.rex.RexCall call) {
      if (!call.getOperator().isDeterministic()) {
        throw new Util.FoundOne(call);
      }
      return super.visitCall(call);
    }

    @Override
    public Void visitInputRef(RexInputRef inputRef) {
      throw new Util.FoundOne(inputRef);
    }

    @Override
    public Void visitLocalRef(RexLocalRef localRef) {
      throw new Util.FoundOne(localRef);
    }

    @Override
    public Void visitOver(RexOver over) {
      throw new Util.FoundOne(over);
    }

    @Override
    public Void visitDynamicParam(RexDynamicParam dynamicParam) {
      throw new Util.FoundOne(dynamicParam);
    }

    @Override
    public Void visitRangeRef(RexRangeRef rangeRef) {
      throw new Util.FoundOne(rangeRef);
    }

    @Override
    public Void visitFieldAccess(RexFieldAccess fieldAccess) {
      throw new Util.FoundOne(fieldAccess);
    }
  };

  try {
    expr.accept(visitor);
  } catch (Util.FoundOne e) {
    deterministicFuncOnLiterals = false;
  }

  return deterministicFuncOnLiterals;
}
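/**
 * Hedged usage sketch (not from the original code; the RexBuilder setup and the expressions are
 * illustrative): an expression built only from literals and deterministic operators is accepted,
 * while any expression that references an input column is rejected because visitInputRef throws
 * Util.FoundOne.
 */
static void isDeterministicFuncOnLiteralsSketch() {
  RexBuilder rexBuilder = new RexBuilder(new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT));
  RelDataType intType = rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER);

  // 1 + 2: deterministic operator over literals only -> true
  RexNode literalsOnly = rexBuilder.makeCall(
      SqlStdOperatorTable.PLUS,
      rexBuilder.makeExactLiteral(BigDecimal.ONE),
      rexBuilder.makeExactLiteral(BigDecimal.valueOf(2)));

  // $0 + 2: references input column 0 -> false
  RexNode withInputRef = rexBuilder.makeCall(
      SqlStdOperatorTable.PLUS,
      rexBuilder.makeInputRef(intType, 0),
      rexBuilder.makeExactLiteral(BigDecimal.valueOf(2)));

  assert isDeterministicFuncOnLiterals(literalsOnly);
  assert !isDeterministicFuncOnLiterals(withInputRef);
}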
/**
 * Infers predicates for a project.
 *
 * <ol>
 * <li>create a mapping from input to projection. Map only positions that directly reference an
 *     input column.
 * <li>Expressions that only contain the above columns are retained in the Project's
 *     pullExpressions list.
 * <li>For example, the expression 'a + e = 9' below will not be pulled up because 'e' is not in
 *     the projection list.
 * <pre>
 * childPullUpExprs:      {a > 7, b + c < 10, a + e = 9}
 * projectionExprs:       {a, b, c, e / 2}
 * projectionPullupExprs: {a > 7, b + c < 10}
 * </pre>
 * </ol>
 */
public RelOptPredicateList getPredicates(Project project) {
  RelNode child = project.getInput();
  final RexBuilder rexBuilder = project.getCluster().getRexBuilder();
  RelOptPredicateList childInfo = RelMetadataQuery.getPulledUpPredicates(child);

  List<RexNode> projectPullUpPredicates = new ArrayList<RexNode>();

  ImmutableBitSet.Builder columnsMappedBuilder = ImmutableBitSet.builder();
  Mapping m = Mappings.create(
      MappingType.PARTIAL_FUNCTION,
      child.getRowType().getFieldCount(),
      project.getRowType().getFieldCount());

  for (Ord<RexNode> o : Ord.zip(project.getProjects())) {
    if (o.e instanceof RexInputRef) {
      int sIdx = ((RexInputRef) o.e).getIndex();
      m.set(sIdx, o.i);
      columnsMappedBuilder.set(sIdx);
    }
  }

  // Go over childPullUpPredicates. If a predicate only contains columns in
  // 'columnsMapped', construct a new predicate based on the mapping.
  final ImmutableBitSet columnsMapped = columnsMappedBuilder.build();
  for (RexNode r : childInfo.pulledUpPredicates) {
    ImmutableBitSet rCols = RelOptUtil.InputFinder.bits(r);
    if (columnsMapped.contains(rCols)) {
      r = r.accept(new RexPermuteInputsShuttle(m, child));
      projectPullUpPredicates.add(r);
    }
  }

  // Project can also generate constants. We need to include them.
  for (Ord<RexNode> expr : Ord.zip(project.getProjects())) {
    if (RexLiteral.isNullLiteral(expr.e)) {
      projectPullUpPredicates.add(
          rexBuilder.makeCall(
              SqlStdOperatorTable.IS_NULL, rexBuilder.makeInputRef(project, expr.i)));
    } else if (RexUtil.isConstant(expr.e)) {
      final List<RexNode> args =
          ImmutableList.of(rexBuilder.makeInputRef(project, expr.i), expr.e);
      final SqlOperator op =
          args.get(0).getType().isNullable() || args.get(1).getType().isNullable()
              ? SqlStdOperatorTable.IS_NOT_DISTINCT_FROM
              : SqlStdOperatorTable.EQUALS;
      projectPullUpPredicates.add(rexBuilder.makeCall(op, args));
    }
  }
  return RelOptPredicateList.of(projectPullUpPredicates);
}
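/**
 * Hedged sketch of the constants branch above (not from the original code; the column index and
 * literal are made up): if output position 1 of the Project is the non-null constant 5, the
 * Project can advertise the predicate =($1, 5) to its parent; if either side were nullable,
 * IS NOT DISTINCT FROM would be used instead of EQUALS.
 */
static RexNode constantProjectionPredicateSketch() {
  RexBuilder rexBuilder = new RexBuilder(new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT));
  RelDataType intType = rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER);
  // Equivalent in spirit to the makeCall(op, args) at the end of getPredicates(Project).
  return rexBuilder.makeCall(
      SqlStdOperatorTable.EQUALS,
      rexBuilder.makeInputRef(intType, 1),
      rexBuilder.makeExactLiteral(BigDecimal.valueOf(5)));
}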
@Override
public String translate(ExprCompiler compiler, RexCall call) {
  String val = compiler.reserveName();
  PrintWriter pw = compiler.pw;
  pw.print(String.format("final %s %s;\n", compiler.javaTypeName(call), val));
  RexNode op0 = call.getOperands().get(0);
  RexNode op1 = call.getOperands().get(1);
  boolean lhsNullable = op0.getType().isNullable();
  boolean rhsNullable = op1.getType().isNullable();

  String lhs = op0.accept(compiler);
  if (!lhsNullable) {
    pw.print(String.format("if (%2$s) { %1$s = true; }\n", val, lhs));
    pw.print("else {\n");
    String rhs = op1.accept(compiler);
    pw.print(String.format(" %1$s = %2$s;\n}\n", val, rhs));
  } else {
    String foldedLHS =
        foldNullExpr(String.format("%1$s == null || !(%1$s)", lhs), "true", op0);
    pw.print(String.format("if (%s) {\n", foldedLHS));
    String rhs = op1.accept(compiler);
    String s;
    if (rhsNullable) {
      s = foldNullExpr(
          String.format(
              "(%2$s != null && %2$s) ? Boolean.TRUE : ((%1$s == null || %2$s == null) ? null : Boolean.FALSE)",
              lhs, rhs),
          "null", op1);
    } else {
      s = String.format("%2$s ? Boolean.valueOf(%2$s) : %1$s", lhs, rhs);
    }
    pw.print(String.format(" %1$s = %2$s;\n", val, s));
    pw.print(String.format("} else { %1$s = true; }\n", val));
  }
  return val;
}
@Override
public String translate(ExprCompiler compiler, RexCall call) {
  String val = compiler.reserveName();
  PrintWriter pw = compiler.pw;
  RexNode op = call.getOperands().get(0);
  String lhs = op.accept(compiler);
  boolean nullable = call.getType().isNullable();
  pw.print(String.format("final %s %s;\n", compiler.javaTypeName(call), val));
  if (!nullable) {
    pw.print(String.format("%1$s = !(%2$s);\n", val, lhs));
  } else {
    String s = foldNullExpr(String.format("%1$s == null ? null : !(%1$s)", lhs), "null", op);
    pw.print(String.format("%1$s = %2$s;\n", val, s));
  }
  return val;
}
public static boolean isDeterministic(RexNode expr) {
  boolean deterministic = true;

  RexVisitor<Void> visitor = new RexVisitorImpl<Void>(true) {
    @Override
    public Void visitCall(org.apache.calcite.rex.RexCall call) {
      if (!call.getOperator().isDeterministic()) {
        throw new Util.FoundOne(call);
      }
      return super.visitCall(call);
    }
  };

  try {
    expr.accept(visitor);
  } catch (Util.FoundOne e) {
    deterministic = false;
  }

  return deterministic;
}
private void infer(
    RexNode predicates,
    Set<String> allExprsDigests,
    List<RexNode> inferredPredicates,
    boolean includeEqualityInference,
    ImmutableBitSet inferringFields) {
  for (RexNode r : RelOptUtil.conjunctions(predicates)) {
    if (!includeEqualityInference && equalityPredicates.contains(r.toString())) {
      continue;
    }
    for (Mapping m : mappings(r)) {
      RexNode tr =
          r.accept(new RexPermuteInputsShuttle(m, joinRel.getInput(0), joinRel.getInput(1)));
      if (inferringFields.contains(RelOptUtil.InputFinder.bits(tr))
          && !allExprsDigests.contains(tr.toString())
          && !isAlwaysTrue(tr)) {
        inferredPredicates.add(tr);
        allExprsDigests.add(tr.toString());
      }
    }
  }
}
public Double getDistinctRowCount(Union rel, ImmutableBitSet groupKey, RexNode predicate) {
  Double rowCount = 0.0;
  int[] adjustments = new int[rel.getRowType().getFieldCount()];
  RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
  for (RelNode input : rel.getInputs()) {
    // convert the predicate to reference the types of the union child
    RexNode modifiedPred;
    if (predicate == null) {
      modifiedPred = null;
    } else {
      modifiedPred = predicate.accept(
          new RelOptUtil.RexInputConverter(
              rexBuilder, null, input.getRowType().getFieldList(), adjustments));
    }
    Double partialRowCount =
        RelMetadataQuery.getDistinctRowCount(input, groupKey, modifiedPred);
    if (partialRowCount == null) {
      return null;
    }
    rowCount += partialRowCount;
  }
  return rowCount;
}
@Override
public Prel visitProject(ProjectPrel project, Object unused) throws RelConversionException {
  // Apply the rule to the child
  RelNode originalInput = ((Prel) project.getInput(0)).accept(this, null);
  project = (ProjectPrel) project.copy(project.getTraitSet(), Lists.newArrayList(originalInput));

  List<RexNode> exprList = new ArrayList<>();
  List<RelDataTypeField> relDataTypes = new ArrayList<>();
  List<RelDataTypeField> origRelDataTypes = new ArrayList<>();
  int i = 0;

  final int lastColumnReferenced = PrelUtil.getLastUsedColumnReference(project.getProjects());
  if (lastColumnReferenced == -1) {
    return project;
  }

  final int lastRexInput = lastColumnReferenced + 1;
  RexVisitorComplexExprSplitter exprSplitter =
      new RexVisitorComplexExprSplitter(factory, funcReg, lastRexInput);

  for (RexNode rex : project.getChildExps()) {
    origRelDataTypes.add(project.getRowType().getFieldList().get(i));
    i++;
    exprList.add(rex.accept(exprSplitter));
  }
  List<RexNode> complexExprs = exprSplitter.getComplexExprs();

  if (complexExprs.size() == 1 && findTopComplexFunc(project.getChildExps()).size() == 1) {
    return project;
  }

  ProjectPrel childProject;

  List<RexNode> allExprs = new ArrayList<>();
  int exprIndex = 0;
  List<String> fieldNames = originalInput.getRowType().getFieldNames();
  for (int index = 0; index < lastRexInput; index++) {
    RexBuilder builder = new RexBuilder(factory);
    allExprs.add(
        builder.makeInputRef(new RelDataTypeDrillImpl(new RelDataTypeHolder(), factory), index));

    if (fieldNames.get(index).contains(StarColumnHelper.STAR_COLUMN)) {
      relDataTypes.add(
          new RelDataTypeFieldImpl(
              fieldNames.get(index), allExprs.size(), factory.createSqlType(SqlTypeName.ANY)));
    } else {
      relDataTypes.add(
          new RelDataTypeFieldImpl(
              "EXPR$" + exprIndex, allExprs.size(), factory.createSqlType(SqlTypeName.ANY)));
      exprIndex++;
    }
  }

  RexNode currRexNode;
  int index = lastRexInput - 1;

  // if the projection expressions contained complex outputs, split them into their own
  // individual projects
  if (complexExprs.size() > 0) {
    while (complexExprs.size() > 0) {
      if (index >= lastRexInput) {
        allExprs.remove(allExprs.size() - 1);
        RexBuilder builder = new RexBuilder(factory);
        allExprs.add(
            builder.makeInputRef(
                new RelDataTypeDrillImpl(new RelDataTypeHolder(), factory), index));
      }
      index++;
      exprIndex++;

      currRexNode = complexExprs.remove(0);
      allExprs.add(currRexNode);
      relDataTypes.add(
          new RelDataTypeFieldImpl(
              "EXPR$" + exprIndex, allExprs.size(), factory.createSqlType(SqlTypeName.ANY)));

      childProject =
          new ProjectPrel(
              project.getCluster(),
              project.getTraitSet(),
              originalInput,
              ImmutableList.copyOf(allExprs),
              new RelRecordType(relDataTypes));
      originalInput = childProject;
    }

    // copied from above, find a better way to do this
    allExprs.remove(allExprs.size() - 1);
    RexBuilder builder = new RexBuilder(factory);
    allExprs.add(
        builder.makeInputRef(new RelDataTypeDrillImpl(new RelDataTypeHolder(), factory), index));
    relDataTypes.add(
        new RelDataTypeFieldImpl(
            "EXPR$" + index, allExprs.size(), factory.createSqlType(SqlTypeName.ANY)));
  }

  return (Prel) project.copy(
      project.getTraitSet(), originalInput, exprList, new RelRecordType(origRelDataTypes));
}
/**
 * The PullUp Strategy is sound but not complete.
 *
 * <ol>
 * <li>We only pull up inferred predicates for now. Pulling up existing predicates causes an
 *     explosion of duplicates. The existing predicates are pushed back down as new predicates.
 *     Once we have rules to eliminate duplicate Filter conditions, we should pull up all
 *     predicates.
 * <li>For Left Outer: we infer new predicates from the left and set them as applicable on the
 *     Right side. No predicates are pulled up.
 * <li>Right Outer Joins are handled in an analogous manner.
 * <li>For Full Outer Joins no predicates are pulled up or inferred.
 * </ol>
 */
public RelOptPredicateList inferPredicates(boolean includeEqualityInference) {
  List<RexNode> inferredPredicates = new ArrayList<RexNode>();
  Set<String> allExprsDigests = new HashSet<String>(this.allExprsDigests);
  final JoinRelType joinType = joinRel.getJoinType();
  switch (joinType) {
    case INNER:
    case LEFT:
      infer(
          leftChildPredicates,
          allExprsDigests,
          inferredPredicates,
          includeEqualityInference,
          joinType == JoinRelType.LEFT ? rightFieldsBitSet : allFieldsBitSet);
      break;
  }
  switch (joinType) {
    case INNER:
    case RIGHT:
      infer(
          rightChildPredicates,
          allExprsDigests,
          inferredPredicates,
          includeEqualityInference,
          joinType == JoinRelType.RIGHT ? leftFieldsBitSet : allFieldsBitSet);
      break;
  }

  Mappings.TargetMapping rightMapping =
      Mappings.createShiftMapping(
          nSysFields + nFieldsLeft + nFieldsRight, 0, nSysFields + nFieldsLeft, nFieldsRight);
  final RexPermuteInputsShuttle rightPermute = new RexPermuteInputsShuttle(rightMapping, joinRel);
  Mappings.TargetMapping leftMapping =
      Mappings.createShiftMapping(nSysFields + nFieldsLeft, 0, nSysFields, nFieldsLeft);
  final RexPermuteInputsShuttle leftPermute = new RexPermuteInputsShuttle(leftMapping, joinRel);

  List<RexNode> leftInferredPredicates = new ArrayList<RexNode>();
  List<RexNode> rightInferredPredicates = new ArrayList<RexNode>();

  for (RexNode iP : inferredPredicates) {
    ImmutableBitSet iPBitSet = RelOptUtil.InputFinder.bits(iP);
    if (leftFieldsBitSet.contains(iPBitSet)) {
      leftInferredPredicates.add(iP.accept(leftPermute));
    } else if (rightFieldsBitSet.contains(iPBitSet)) {
      rightInferredPredicates.add(iP.accept(rightPermute));
    }
  }

  switch (joinType) {
    case INNER:
      Iterable<RexNode> pulledUpPredicates;
      if (isSemiJoin) {
        pulledUpPredicates =
            Iterables.concat(
                RelOptUtil.conjunctions(leftChildPredicates), leftInferredPredicates);
      } else {
        pulledUpPredicates =
            Iterables.concat(
                RelOptUtil.conjunctions(leftChildPredicates),
                RelOptUtil.conjunctions(rightChildPredicates),
                RelOptUtil.conjunctions(joinRel.getCondition()),
                inferredPredicates);
      }
      return RelOptPredicateList.of(
          pulledUpPredicates, leftInferredPredicates, rightInferredPredicates);
    case LEFT:
      return RelOptPredicateList.of(
          RelOptUtil.conjunctions(leftChildPredicates),
          leftInferredPredicates,
          rightInferredPredicates);
    case RIGHT:
      return RelOptPredicateList.of(
          RelOptUtil.conjunctions(rightChildPredicates), inferredPredicates, EMPTY_LIST);
    default:
      assert inferredPredicates.size() == 0;
      return RelOptPredicateList.EMPTY;
  }
}
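/**
 * Hedged numeric sketch of the two shift mappings built in inferPredicates (not from the
 * original code): assume nSysFields = 0, nFieldsLeft = 2, nFieldsRight = 3, i.e. five join-space
 * columns in total. The varargs to Mappings.createShiftMapping are (target, source, count)
 * triples.
 */
static void joinShiftMappingSketch() {
  // Right mapping: join-space refs {2, 3, 4} become right-input refs {0, 1, 2}, so an inferred
  // predicate whose InputFinder bits fall inside rightFieldsBitSet can be rewritten by
  // rightPermute into the right input's own column space.
  Mappings.TargetMapping rightMapping = Mappings.createShiftMapping(5, 0, 2, 3);

  // Left mapping: join-space refs {0, 1} stay at {0, 1}; it is the identity when there are no
  // system fields, but is applied anyway for uniformity.
  Mappings.TargetMapping leftMapping = Mappings.createShiftMapping(2, 0, 0, 2);
}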
private JoinConditionBasedPredicateInference(
    Join joinRel, boolean isSemiJoin, RexNode lPreds, RexNode rPreds) {
  super();
  this.joinRel = joinRel;
  this.isSemiJoin = isSemiJoin;
  nFieldsLeft = joinRel.getLeft().getRowType().getFieldList().size();
  nFieldsRight = joinRel.getRight().getRowType().getFieldList().size();
  nSysFields = joinRel.getSystemFieldList().size();
  leftFieldsBitSet = ImmutableBitSet.range(nSysFields, nSysFields + nFieldsLeft);
  rightFieldsBitSet =
      ImmutableBitSet.range(nSysFields + nFieldsLeft, nSysFields + nFieldsLeft + nFieldsRight);
  allFieldsBitSet = ImmutableBitSet.range(0, nSysFields + nFieldsLeft + nFieldsRight);

  exprFields = Maps.newHashMap();
  allExprsDigests = new HashSet<String>();

  if (lPreds == null) {
    leftChildPredicates = null;
  } else {
    Mappings.TargetMapping leftMapping =
        Mappings.createShiftMapping(nSysFields + nFieldsLeft, nSysFields, 0, nFieldsLeft);
    leftChildPredicates =
        lPreds.accept(new RexPermuteInputsShuttle(leftMapping, joinRel.getInput(0)));

    for (RexNode r : RelOptUtil.conjunctions(leftChildPredicates)) {
      exprFields.put(r.toString(), RelOptUtil.InputFinder.bits(r));
      allExprsDigests.add(r.toString());
    }
  }
  if (rPreds == null) {
    rightChildPredicates = null;
  } else {
    Mappings.TargetMapping rightMapping =
        Mappings.createShiftMapping(
            nSysFields + nFieldsLeft + nFieldsRight, nSysFields + nFieldsLeft, 0, nFieldsRight);
    rightChildPredicates =
        rPreds.accept(new RexPermuteInputsShuttle(rightMapping, joinRel.getInput(1)));

    for (RexNode r : RelOptUtil.conjunctions(rightChildPredicates)) {
      exprFields.put(r.toString(), RelOptUtil.InputFinder.bits(r));
      allExprsDigests.add(r.toString());
    }
  }

  equivalence = Maps.newTreeMap();
  equalityPredicates = new HashSet<String>();
  for (int i = 0; i < nSysFields + nFieldsLeft + nFieldsRight; i++) {
    equivalence.put(i, BitSets.of(i));
  }

  // Only process equivalences found in the join conditions. Processing
  // equivalences from the left or right side infers predicates that are
  // already present in the tree below the join.
  RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder();
  List<RexNode> exprs =
      RelOptUtil.conjunctions(compose(rexBuilder, ImmutableList.of(joinRel.getCondition())));

  final EquivalenceFinder eF = new EquivalenceFinder();
  // Lists.transform is lazy; copying it into a new ArrayList forces the visitor to run
  // over every conjunction of the join condition.
  new ArrayList<Void>(
      Lists.transform(
          exprs,
          new Function<RexNode, Void>() {
            public Void apply(RexNode input) {
              return input.accept(eF);
            }
          }));

  equivalence = BitSets.closure(equivalence);
}
public static Set<Integer> getInputRefs(RexNode expr) {
  InputRefsCollector irefColl = new InputRefsCollector(true);
  expr.accept(irefColl);
  return irefColl.getInputRefSet();
}
private static void splitJoinCondition(
    List<RelDataTypeField> sysFieldList,
    List<RelNode> inputs,
    RexNode condition,
    List<List<RexNode>> joinKeys,
    List<Integer> filterNulls,
    List<SqlOperator> rangeOp,
    List<RexNode> nonEquiList)
    throws CalciteSemanticException {
  final int sysFieldCount = sysFieldList.size();
  final RelOptCluster cluster = inputs.get(0).getCluster();
  final RexBuilder rexBuilder = cluster.getRexBuilder();

  if (condition instanceof RexCall) {
    RexCall call = (RexCall) condition;
    if (call.getOperator() == SqlStdOperatorTable.AND) {
      for (RexNode operand : call.getOperands()) {
        splitJoinCondition(
            sysFieldList, inputs, operand, joinKeys, filterNulls, rangeOp, nonEquiList);
      }
      return;
    }

    RexNode leftKey = null;
    RexNode rightKey = null;
    int leftInput = 0;
    int rightInput = 0;
    List<RelDataTypeField> leftFields = null;
    List<RelDataTypeField> rightFields = null;
    boolean reverse = false;

    SqlKind kind = call.getKind();

    // Only consider range operators if we haven't already seen one
    if ((kind == SqlKind.EQUALS)
        || (filterNulls != null && kind == SqlKind.IS_NOT_DISTINCT_FROM)
        || (rangeOp != null
            && rangeOp.isEmpty()
            && (kind == SqlKind.GREATER_THAN
                || kind == SqlKind.GREATER_THAN_OR_EQUAL
                || kind == SqlKind.LESS_THAN
                || kind == SqlKind.LESS_THAN_OR_EQUAL))) {
      final List<RexNode> operands = call.getOperands();
      RexNode op0 = operands.get(0);
      RexNode op1 = operands.get(1);

      final ImmutableBitSet projRefs0 = InputFinder.bits(op0);
      final ImmutableBitSet projRefs1 = InputFinder.bits(op1);

      final ImmutableBitSet[] inputsRange = new ImmutableBitSet[inputs.size()];
      int totalFieldCount = 0;
      for (int i = 0; i < inputs.size(); i++) {
        final int firstField = totalFieldCount + sysFieldCount;
        totalFieldCount = firstField + inputs.get(i).getRowType().getFieldCount();
        inputsRange[i] = ImmutableBitSet.range(firstField, totalFieldCount);
      }

      boolean foundBothInputs = false;
      for (int i = 0; i < inputs.size() && !foundBothInputs; i++) {
        if (projRefs0.intersects(inputsRange[i])
            && projRefs0.union(inputsRange[i]).equals(inputsRange[i])) {
          if (leftKey == null) {
            leftKey = op0;
            leftInput = i;
            leftFields = inputs.get(leftInput).getRowType().getFieldList();
          } else {
            rightKey = op0;
            rightInput = i;
            rightFields = inputs.get(rightInput).getRowType().getFieldList();
            reverse = true;
            foundBothInputs = true;
          }
        } else if (projRefs1.intersects(inputsRange[i])
            && projRefs1.union(inputsRange[i]).equals(inputsRange[i])) {
          if (leftKey == null) {
            leftKey = op1;
            leftInput = i;
            leftFields = inputs.get(leftInput).getRowType().getFieldList();
          } else {
            rightKey = op1;
            rightInput = i;
            rightFields = inputs.get(rightInput).getRowType().getFieldList();
            foundBothInputs = true;
          }
        }
      }

      if ((leftKey != null) && (rightKey != null)) {
        // adjustment array
        int[] adjustments = new int[totalFieldCount];
        for (int i = 0; i < inputs.size(); i++) {
          final int adjustment = inputsRange[i].nextSetBit(0);
          for (int j = adjustment; j < inputsRange[i].length(); j++) {
            adjustments[j] = -adjustment;
          }
        }

        // replace right Key input ref
        rightKey =
            rightKey.accept(
                new RelOptUtil.RexInputConverter(
                    rexBuilder, rightFields, rightFields, adjustments));

        // left key only needs to be adjusted if there are system
        // fields, but do it for uniformity
        leftKey =
            leftKey.accept(
                new RelOptUtil.RexInputConverter(
                    rexBuilder, leftFields, leftFields, adjustments));

        RelDataType leftKeyType = leftKey.getType();
        RelDataType rightKeyType = rightKey.getType();

        if (leftKeyType != rightKeyType) {
          // perform casting using Hive rules
          TypeInfo rType = TypeConverter.convert(rightKeyType);
          TypeInfo lType = TypeConverter.convert(leftKeyType);
          TypeInfo tgtType = FunctionRegistry.getCommonClassForComparison(lType, rType);

          if (tgtType == null) {
            throw new CalciteSemanticException(
                "Cannot find common type for join keys "
                    + leftKey
                    + " (type "
                    + leftKeyType
                    + ") and "
                    + rightKey
                    + " (type "
                    + rightKeyType
                    + ")");
          }
          RelDataType targetKeyType = TypeConverter.convert(tgtType, rexBuilder.getTypeFactory());

          if (leftKeyType != targetKeyType
              && TypeInfoUtils.isConversionRequiredForComparison(tgtType, lType)) {
            leftKey = rexBuilder.makeCast(targetKeyType, leftKey);
          }

          if (rightKeyType != targetKeyType
              && TypeInfoUtils.isConversionRequiredForComparison(tgtType, rType)) {
            rightKey = rexBuilder.makeCast(targetKeyType, rightKey);
          }
        }
      }
    }

    if ((leftKey != null) && (rightKey != null)) {
      // found suitable join keys
      // add them to key list, ensuring that if there is a
      // non-equi join predicate, it appears at the end of the
      // key list; also mark the null filtering property
      addJoinKey(joinKeys.get(leftInput), leftKey, (rangeOp != null) && !rangeOp.isEmpty());
      addJoinKey(joinKeys.get(rightInput), rightKey, (rangeOp != null) && !rangeOp.isEmpty());
      if (filterNulls != null && kind == SqlKind.EQUALS) {
        // nulls are considered not matching for equality comparison
        // add the position of the most recently inserted key
        filterNulls.add(joinKeys.get(leftInput).size() - 1);
      }
      if (rangeOp != null && kind != SqlKind.EQUALS && kind != SqlKind.IS_DISTINCT_FROM) {
        if (reverse) {
          kind = reverse(kind);
        }
        rangeOp.add(op(kind, call.getOperator()));
      }
      return;
    } // else fall through and add this condition as nonEqui condition
  }

  // The operator is not of RexCall type
  // So we fail. Fall through.
  // Add this condition to the list of non-equi-join conditions.
  nonEquiList.add(condition);
}
/**
 * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for
 * {@link org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin}.
 */
public TrimResult trimFields(
    HiveMultiJoin join, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
  final int fieldCount = join.getRowType().getFieldCount();
  final RexNode conditionExpr = join.getCondition();

  // Add in fields used in the condition.
  final Set<RelDataTypeField> combinedInputExtraFields =
      new LinkedHashSet<RelDataTypeField>(extraFields);
  RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(combinedInputExtraFields);
  inputFinder.inputBitSet.addAll(fieldsUsed);
  conditionExpr.accept(inputFinder);
  final ImmutableBitSet fieldsUsedPlus = inputFinder.inputBitSet.build();

  int inputStartPos = 0;
  int changeCount = 0;
  int newFieldCount = 0;
  List<RelNode> newInputs = new ArrayList<RelNode>();
  List<Mapping> inputMappings = new ArrayList<Mapping>();
  for (RelNode input : join.getInputs()) {
    final RelDataType inputRowType = input.getRowType();
    final int inputFieldCount = inputRowType.getFieldCount();

    // Compute required mapping.
    ImmutableBitSet.Builder inputFieldsUsed = ImmutableBitSet.builder();
    for (int bit : fieldsUsedPlus) {
      if (bit >= inputStartPos && bit < inputStartPos + inputFieldCount) {
        inputFieldsUsed.set(bit - inputStartPos);
      }
    }

    Set<RelDataTypeField> inputExtraFields = Collections.<RelDataTypeField>emptySet();
    TrimResult trimResult = trimChild(join, input, inputFieldsUsed.build(), inputExtraFields);
    newInputs.add(trimResult.left);
    if (trimResult.left != input) {
      ++changeCount;
    }

    final Mapping inputMapping = trimResult.right;
    inputMappings.add(inputMapping);

    // Move offset to point to start of next input.
    inputStartPos += inputFieldCount;
    newFieldCount += inputMapping.getTargetCount();
  }

  Mapping mapping = Mappings.create(MappingType.INVERSE_SURJECTION, fieldCount, newFieldCount);
  int offset = 0;
  int newOffset = 0;
  for (int i = 0; i < inputMappings.size(); i++) {
    Mapping inputMapping = inputMappings.get(i);
    for (IntPair pair : inputMapping) {
      mapping.set(pair.source + offset, pair.target + newOffset);
    }
    offset += inputMapping.getSourceCount();
    newOffset += inputMapping.getTargetCount();
  }

  if (changeCount == 0 && mapping.isIdentity()) {
    return new TrimResult(join, Mappings.createIdentity(fieldCount));
  }

  // Build new join.
  final RexVisitor<RexNode> shuttle =
      new RexPermuteInputsShuttle(mapping, newInputs.toArray(new RelNode[newInputs.size()]));
  RexNode newConditionExpr = conditionExpr.accept(shuttle);

  final RelDataType newRowType =
      RelOptUtil.permute(join.getCluster().getTypeFactory(), join.getRowType(), mapping);
  final RelNode newJoin =
      new HiveMultiJoin(
          join.getCluster(),
          newInputs,
          newConditionExpr,
          newRowType,
          join.getJoinInputs(),
          join.getJoinTypes(),
          join.getJoinFilters());

  return new TrimResult(newJoin, mapping);
}