Ejemplo n.º 1
0
  /**
   * Infers predicates for an Aggregate.
   *
   * <p>Pulls up predicates that only contain references to columns in the group set. For example:
   *
   * <pre>
   * childPullUpExprs : { a &gt; 7, b + c &lt; 10, a + e = 9}
   * groupSet         : { a, b}
   * pulledUpExprs    : { a &gt; 7}
   * </pre>
   *
   * <p>Nothing is pulled up when the group set is empty. "GROUP BY ()" can turn an empty input
   * into a non-empty output, so a predicate that holds (vacuously) on every input row, such as
   * {@code false}, need not hold on the output.
   *
   * @param agg the aggregate whose output predicates are inferred
   * @return the input predicates that reference only group columns, rewritten in terms of the
   *     aggregate's output fields
   */
  public RelOptPredicateList getPredicates(Aggregate agg) {
    final ImmutableBitSet groupKeys = agg.getGroupSet();
    if (groupKeys.isEmpty()) {
      // Without this guard, constant predicates (which reference no columns) would pass the
      // "contained in the group set" test below and be pulled up incorrectly.
      return RelOptPredicateList.EMPTY;
    }

    final RelNode child = agg.getInput();
    final RelOptPredicateList childInfo = RelMetadataQuery.getPulledUpPredicates(child);

    // Map each grouped input column to its ordinal among the group keys.
    final Mapping m =
        Mappings.create(
            MappingType.PARTIAL_FUNCTION,
            child.getRowType().getFieldCount(),
            agg.getRowType().getFieldCount());
    int i = 0;
    for (int j : groupKeys) {
      m.set(j, i++);
    }

    final List<RexNode> aggPullUpPredicates = new ArrayList<RexNode>();
    for (RexNode r : childInfo.pulledUpPredicates) {
      final ImmutableBitSet rCols = RelOptUtil.InputFinder.bits(r);
      if (groupKeys.contains(rCols)) {
        aggPullUpPredicates.add(r.accept(new RexPermuteInputsShuttle(m, child)));
      }
    }
    return RelOptPredicateList.of(aggPullUpPredicates);
  }
Ejemplo n.º 2
0
 /**
  * Converts a list of input-field ordinals of {@code inputRel} into {@link ExprNodeDesc}s.
  *
  * <p>For each position, if {@code inputRel}'s child expression at that same position is a
  * {@link RexLiteral}, the literal is converted; otherwise the input reference built for that
  * position is converted.
  *
  * @param inputRefs ordinals of the fields to reference
  * @param inputRel relational node supplying the row type and the child expressions
  * @param inputTabAlias table alias handed to the {@link ExprNodeConverter}
  * @return one converted expression per entry produced by {@code getInputRef}
  */
 public static List<ExprNodeDesc> getExprNodes(
     List<Integer> inputRefs, RelNode inputRel, String inputTabAlias) {
   final List<ExprNodeDesc> converted = new ArrayList<ExprNodeDesc>();
   final List<RexNode> refs = getInputRef(inputRefs, inputRel);
   final List<RexNode> childExps = inputRel.getChildExps();
   // TODO: Change ExprNodeConverter to be independent of Partition Expr
   final ExprNodeConverter converter =
       new ExprNodeConverter(
           inputTabAlias,
           inputRel.getRowType(),
           new HashSet<Integer>(),
           inputRel.getCluster().getTypeFactory());
   final int childExpCount = childExps == null ? 0 : childExps.size();

   int position = 0;
   for (RexNode ref : refs) {
     // The bounds/null handling is only a guard against failures.
     // TODO: Knowing which expr is constant in GBY's aggregation function
     // arguments could be better done using Metadata provider of Calcite.
     final RexNode childExp = position < childExpCount ? childExps.get(position) : null;
     if (childExp instanceof RexLiteral) {
       converted.add(converter.visitLiteral((RexLiteral) childExp));
     } else {
       converted.add(ref.accept(converter));
     }
     position++;
   }
   return converted;
 }
Ejemplo n.º 3
0
  /**
   * Builds a {@link RexInputRef} for field {@code inputRefIndx} of {@code inputRel}, typed with
   * that field's declared type, and converts it using {@code exprConv}.
   *
   * @param inputRefIndx ordinal of the field in {@code inputRel}'s row type
   * @param inputRel relational node whose row type supplies the field type
   * @param exprConv converter applied to the input reference
   * @return the result of visiting the input reference with {@code exprConv}
   */
  public static ExprNodeDesc getExprNode(
      Integer inputRefIndx, RelNode inputRel, ExprNodeConverter exprConv) {
    final RexNode fieldRef =
        new RexInputRef(
            inputRefIndx, inputRel.getRowType().getFieldList().get(inputRefIndx).getType());
    return fieldRef.accept(exprConv);
  }
Ejemplo n.º 4
0
 /**
  * Emits a declaration of a fresh final variable initialised with the translated first operand
  * cast to the Java type of {@code call}.
  *
  * @param compiler compiler that reserves names, translates operands and owns the output writer
  * @param call the call being translated; only its first operand is used
  * @return the name of the variable that holds the cast value
  */
 @Override
 public String translate(ExprCompiler compiler, RexCall call) {
   final String resultVar = compiler.reserveName();
   final PrintWriter out = compiler.pw;
   final RexNode operand = call.getOperands().get(0);
   // Translate the operand first so that its variable name is available for the declaration.
   final String operandVar = operand.accept(compiler);
   final String targetType = compiler.javaTypeName(call);
   final String declaration =
       String.format("final %1$s %2$s = (%1$s) %3$s;\n", targetType, resultVar, operandVar);
   out.print(declaration);
   return resultVar;
 }
Ejemplo n.º 5
0
  /**
   * Tells whether {@code expr} consists solely of deterministic calls over literal values.
   *
   * <p>The walk is aborted, and {@code false} returned, as soon as it meets a call whose operator
   * is not deterministic, or any input reference, local reference, windowed call, dynamic
   * parameter, range reference or field access.
   *
   * @param expr expression to inspect
   * @return {@code true} if none of the rejected node kinds was encountered
   */
  public static boolean isDeterministicFuncOnLiterals(RexNode expr) {
    final RexVisitor<Void> rejectingVisitor =
        new RexVisitorImpl<Void>(true) {
          @Override
          public Void visitCall(org.apache.calcite.rex.RexCall call) {
            if (call.getOperator().isDeterministic()) {
              return super.visitCall(call);
            }
            throw new Util.FoundOne(call);
          }

          @Override
          public Void visitOver(RexOver over) {
            throw new Util.FoundOne(over);
          }

          @Override
          public Void visitInputRef(RexInputRef ref) {
            throw new Util.FoundOne(ref);
          }

          @Override
          public Void visitLocalRef(RexLocalRef ref) {
            throw new Util.FoundOne(ref);
          }

          @Override
          public Void visitRangeRef(RexRangeRef ref) {
            throw new Util.FoundOne(ref);
          }

          @Override
          public Void visitFieldAccess(RexFieldAccess access) {
            throw new Util.FoundOne(access);
          }

          @Override
          public Void visitDynamicParam(RexDynamicParam param) {
            throw new Util.FoundOne(param);
          }
        };

    try {
      expr.accept(rejectingVisitor);
    } catch (Util.FoundOne found) {
      // FoundOne is the visitor's signal that a disqualifying node was reached.
      return false;
    }
    return true;
  }
Ejemplo n.º 6
0
  /**
   * Infers predicates for a project.
   *
   * <ol>
   *   <li>Build a mapping from input columns to projection positions, covering only projection
   *       expressions that are a direct reference to an input column.
   *   <li>Keep each input predicate whose columns are all covered by that mapping, rewritten in
   *       terms of the projection's output.
   *   <li>For example, expression 'a + e = 9' below is not pulled up because 'e' is not in the
   *       projection list.
   *       <pre>
   * childPullUpExprs:      {a &gt; 7, b + c &lt; 10, a + e = 9}
   * projectionExprs:       {a, b, c, e / 2}
   * projectionPullupExprs: {a &gt; 7, b + c &lt; 10}
   * </pre>
   *   <li>Add a predicate for every projected constant: {@code IS NULL} for a null literal,
   *       otherwise an equality (or {@code IS NOT DISTINCT FROM} when either side is nullable)
   *       between the output column and the constant.
   * </ol>
   *
   * @param project the project whose output predicates are inferred
   * @return the predicates known to hold on the project's output
   */
  public RelOptPredicateList getPredicates(Project project) {
    final RelNode input = project.getInput();
    final RexBuilder rexBuilder = project.getCluster().getRexBuilder();
    final RelOptPredicateList inputPredicates = RelMetadataQuery.getPulledUpPredicates(input);

    final List<RexNode> pulledUp = new ArrayList<RexNode>();

    final ImmutableBitSet.Builder mappedColumns = ImmutableBitSet.builder();
    final Mapping inputToOutput =
        Mappings.create(
            MappingType.PARTIAL_FUNCTION,
            input.getRowType().getFieldCount(),
            project.getRowType().getFieldCount());

    // Record every projection that is a plain reference to an input column.
    final List<RexNode> projects = project.getProjects();
    for (int target = 0; target < projects.size(); target++) {
      final RexNode projected = projects.get(target);
      if (projected instanceof RexInputRef) {
        final int source = ((RexInputRef) projected).getIndex();
        inputToOutput.set(source, target);
        mappedColumns.set(source);
      }
    }

    // Rewrite the input predicates that touch only mapped columns.
    final ImmutableBitSet columnsMapped = mappedColumns.build();
    for (RexNode predicate : inputPredicates.pulledUpPredicates) {
      final ImmutableBitSet referenced = RelOptUtil.InputFinder.bits(predicate);
      if (columnsMapped.contains(referenced)) {
        pulledUp.add(predicate.accept(new RexPermuteInputsShuttle(inputToOutput, input)));
      }
    }

    // Project can also generate constants. We need to include them.
    for (int ordinal = 0; ordinal < projects.size(); ordinal++) {
      final RexNode projected = projects.get(ordinal);
      if (RexLiteral.isNullLiteral(projected)) {
        final RexNode outputRef = rexBuilder.makeInputRef(project, ordinal);
        pulledUp.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, outputRef));
      } else if (RexUtil.isConstant(projected)) {
        final RexNode outputRef = rexBuilder.makeInputRef(project, ordinal);
        final boolean eitherNullable =
            outputRef.getType().isNullable() || projected.getType().isNullable();
        final SqlOperator comparison;
        if (eitherNullable) {
          comparison = SqlStdOperatorTable.IS_NOT_DISTINCT_FROM;
        } else {
          comparison = SqlStdOperatorTable.EQUALS;
        }
        final List<RexNode> operands = ImmutableList.of(outputRef, projected);
        pulledUp.add(rexBuilder.makeCall(comparison, operands));
      }
    }
    return RelOptPredicateList.of(pulledUp);
  }
Ejemplo n.º 7
0
 /**
  * Emits Java source for a two-operand boolean call and returns the name of the variable that
  * receives its value.
  *
  * <p>The emitted code assigns {@code true} when the left operand is true and otherwise derives
  * the value from the right operand, with extra null handling when an operand's type is
  * nullable. NOTE(review): this looks like SQL {@code OR} under three-valued logic; the operator
  * this printer is registered for is not visible here, so confirm.
  *
  * @param compiler compiler that reserves names, translates operands and owns the output writer
  * @param call the call being translated; its first two operands are used
  * @return the name of the variable declared for the result
  */
 @Override
 public String translate(ExprCompiler compiler, RexCall call) {
   String val = compiler.reserveName();
   PrintWriter pw = compiler.pw;
   // Declare the result as a blank final; each emitted branch below assigns it exactly once.
   pw.print(String.format("final %s %s;\n", compiler.javaTypeName(call), val));
   RexNode op0 = call.getOperands().get(0);
   RexNode op1 = call.getOperands().get(1);
   boolean lhsNullable = op0.getType().isNullable();
   boolean rhsNullable = op1.getType().isNullable();
   String lhs = op0.accept(compiler);
   if (!lhsNullable) {
     // Left side cannot be null: true short-circuits, otherwise the result is the right side.
     pw.print(String.format("if (%2$s) { %1$s = true; }\n", val, lhs));
     pw.print("else {\n");
     // op1 is translated only after "else {" has been printed, so anything the compiler emits
     // for it lands inside the else block.
     String rhs = op1.accept(compiler);
     pw.print(String.format("  %1$s = %2$s;\n}\n", val, rhs));
   } else {
     // Left side may be null: the right side is consulted when the left is null or false.
     // foldNullExpr is defined elsewhere; presumably it substitutes the given constant when the
     // operand is known to be null - TODO confirm.
     String foldedLHS =
         foldNullExpr(String.format("%1$s == null || !(%1$s)", lhs), "true", op0);
     pw.print(String.format("if (%s) {\n", foldedLHS));
     // As above, op1's emitted code is placed inside the block that was just opened.
     String rhs = op1.accept(compiler);
     String s;
     if (rhsNullable) {
       // true if rhs is true; null if either side is null; false otherwise.
       s =
           foldNullExpr(
               String.format(
                   "(%2$s != null && %2$s) ? Boolean.TRUE : ((%1$s == null || %2$s == null) ? null : Boolean.FALSE)",
                   lhs, rhs),
               "null",
               op1);
     } else {
       // rhs is never null: true wins; otherwise the result is lhs, which is null or false here.
       s = String.format("%2$s ? Boolean.valueOf(%2$s) : %1$s", lhs, rhs);
     }
     pw.print(String.format("  %1$s = %2$s;\n", val, s));
     // lhs was non-null and true.
     pw.print(String.format("} else { %1$s = true; }\n", val));
   }
   return val;
 }
Ejemplo n.º 8
0
 /**
  * Emits Java source that stores the logical negation of the translated first operand in a
  * fresh final variable.
  *
  * <p>When the call's type is nullable, the emitted expression yields {@code null} for a null
  * operand (after being passed through {@code foldNullExpr}); otherwise a plain {@code !} is
  * emitted.
  *
  * @param compiler compiler that reserves names, translates operands and owns the output writer
  * @param call the call being translated; only its first operand is used
  * @return the name of the variable declared for the result
  */
 @Override
 public String translate(ExprCompiler compiler, RexCall call) {
   final String resultVar = compiler.reserveName();
   final PrintWriter out = compiler.pw;
   final RexNode operand = call.getOperands().get(0);
   final String operandVar = operand.accept(compiler);
   final boolean resultNullable = call.getType().isNullable();
   out.print(String.format("final %s %s;\n", compiler.javaTypeName(call), resultVar));
   if (resultNullable) {
     final String nullAware =
         foldNullExpr(
             String.format("%1$s == null ? null : !(%1$s)", operandVar), "null", operand);
     out.print(String.format("%1$s = %2$s;\n", resultVar, nullAware));
   } else {
     out.print(String.format("%1$s = !(%2$s);\n", resultVar, operandVar));
   }
   return resultVar;
 }
Ejemplo n.º 9
0
  /**
   * Tells whether every call visited in {@code expr} uses a deterministic operator.
   *
   * @param expr expression to inspect
   * @return {@code false} as soon as a call with a non-deterministic operator is found,
   *     {@code true} otherwise
   */
  public static boolean isDeterministic(RexNode expr) {
    final RexVisitor<Void> nonDeterministicFinder =
        new RexVisitorImpl<Void>(true) {
          @Override
          public Void visitCall(org.apache.calcite.rex.RexCall call) {
            if (call.getOperator().isDeterministic()) {
              return super.visitCall(call);
            }
            throw new Util.FoundOne(call);
          }
        };

    try {
      expr.accept(nonDeterministicFinder);
    } catch (Util.FoundOne found) {
      // FoundOne is the visitor's signal that a non-deterministic call was reached.
      return false;
    }
    return true;
  }
Ejemplo n.º 10
0
 /**
  * Derives additional predicates from the conjuncts of {@code predicates} by rewriting each
  * conjunct with every mapping returned by {@code mappings}.
  *
  * <p>A rewritten predicate is kept only if all of its columns lie within
  * {@code inferringFields}, its digest has not been seen before, and it is not always true.
  *
  * @param predicates predicate whose conjuncts are the starting point
  * @param allExprsDigests digests already known; digests of accepted predicates are added
  * @param inferedPredicates output list that receives the accepted predicates
  * @param includeEqualityInference when {@code false}, conjuncts recorded in
  *     {@code equalityPredicates} are skipped
  * @param inferringFields the fields a rewritten predicate is allowed to reference
  */
 private void infer(
     RexNode predicates,
     Set<String> allExprsDigests,
     List<RexNode> inferedPredicates,
     boolean includeEqualityInference,
     ImmutableBitSet inferringFields) {
   for (RexNode conjunct : RelOptUtil.conjunctions(predicates)) {
     final boolean skipEquality =
         !includeEqualityInference && equalityPredicates.contains(conjunct.toString());
     if (skipEquality) {
       continue;
     }
     for (Mapping mapping : mappings(conjunct)) {
       final RexNode candidate =
           conjunct.accept(
               new RexPermuteInputsShuttle(mapping, joinRel.getInput(0), joinRel.getInput(1)));
       if (!inferringFields.contains(RelOptUtil.InputFinder.bits(candidate))) {
         continue;
       }
       final String digest = candidate.toString();
       if (allExprsDigests.contains(digest) || isAlwaysTrue(candidate)) {
         continue;
       }
       inferedPredicates.add(candidate);
       allExprsDigests.add(digest);
     }
   }
 }
 /**
  * Estimates the distinct row count of a union as the sum of the distinct row counts of its
  * inputs.
  *
  * @param rel the union
  * @param groupKey columns whose distinct combinations are counted
  * @param predicate optional filter; when non-null it is rewritten against each input's fields
  *     before being passed down
  * @return the summed estimate, or {@code null} if any input has no estimate
  */
 public Double getDistinctRowCount(Union rel, ImmutableBitSet groupKey, RexNode predicate) {
   // All-zero adjustments: field ordinals stay as they are, only the referenced types change.
   final int[] adjustments = new int[rel.getRowType().getFieldCount()];
   final RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
   double total = 0.0;
   for (RelNode input : rel.getInputs()) {
     // convert the predicate to reference the types of the union child
     final RexNode inputPredicate =
         predicate == null
             ? null
             : predicate.accept(
                 new RelOptUtil.RexInputConverter(
                     rexBuilder, null, input.getRowType().getFieldList(), adjustments));
     final Double inputCount =
         RelMetadataQuery.getDistinctRowCount(input, groupKey, inputPredicate);
     if (inputCount == null) {
       return null;
     }
     total += inputCount;
   }
   return total;
 }
  /**
   * Visits a project, first rewriting its input, then passing every projection expression
   * through a {@code RexVisitorComplexExprSplitter} and stacking one child {@link ProjectPrel}
   * per complex expression the splitter collected.
   *
   * <p>The project is returned unchanged (apart from the rewritten input) when its expressions
   * reference no columns, or when exactly one complex expression was collected and
   * {@code findTopComplexFunc} also reports exactly one.
   *
   * @param project the project to rewrite
   * @param unused ignored
   * @return the project itself, or a copy whose input is the chain of child projects and whose
   *     expressions are the ones returned by the splitter
   * @throws RelConversionException declared by the visitor contract; not thrown directly here
   */
  @Override
  public Prel visitProject(ProjectPrel project, Object unused) throws RelConversionException {

    // Apply the rule to the child
    RelNode originalInput = ((Prel) project.getInput(0)).accept(this, null);
    project = (ProjectPrel) project.copy(project.getTraitSet(), Lists.newArrayList(originalInput));

    List<RexNode> exprList = new ArrayList<>();

    List<RelDataTypeField> relDataTypes = new ArrayList<>();
    List<RelDataTypeField> origRelDataTypes = new ArrayList<>();
    int i = 0;
    final int lastColumnReferenced = PrelUtil.getLastUsedColumnReference(project.getProjects());

    // -1 means no column is referenced at all; nothing to split.
    if (lastColumnReferenced == -1) {
      return project;
    }

    final int lastRexInput = lastColumnReferenced + 1;
    RexVisitorComplexExprSplitter exprSplitter =
        new RexVisitorComplexExprSplitter(factory, funcReg, lastRexInput);

    // Rewrite each expression through the splitter and remember the original output fields so
    // the final project keeps its row type.
    for (RexNode rex : project.getChildExps()) {
      origRelDataTypes.add(project.getRowType().getFieldList().get(i));
      i++;
      exprList.add(rex.accept(exprSplitter));
    }
    List<RexNode> complexExprs = exprSplitter.getComplexExprs();

    if (complexExprs.size() == 1 && findTopComplexFunc(project.getChildExps()).size() == 1) {
      return project;
    }

    ProjectPrel childProject;

    // Seed the child projection with pass-through references to the first lastRexInput input
    // columns. Star columns keep their name; the others are named EXPR$n.
    List<RexNode> allExprs = new ArrayList<>();
    int exprIndex = 0;
    List<String> fieldNames = originalInput.getRowType().getFieldNames();
    for (int index = 0; index < lastRexInput; index++) {
      RexBuilder builder = new RexBuilder(factory);
      allExprs.add(
          builder.makeInputRef(new RelDataTypeDrillImpl(new RelDataTypeHolder(), factory), index));

      if (fieldNames.get(index).contains(StarColumnHelper.STAR_COLUMN)) {
        relDataTypes.add(
            new RelDataTypeFieldImpl(
                fieldNames.get(index), allExprs.size(), factory.createSqlType(SqlTypeName.ANY)));
      } else {
        relDataTypes.add(
            new RelDataTypeFieldImpl(
                "EXPR$" + exprIndex, allExprs.size(), factory.createSqlType(SqlTypeName.ANY)));
        exprIndex++;
      }
    }
    RexNode currRexNode;
    int index = lastRexInput - 1;

    // if the projection expressions contained complex outputs, split them into their own individual
    // projects
    if (!complexExprs.isEmpty()) {
      while (!complexExprs.isEmpty()) {
        // From the second iteration on, the complex expression appended by the previous
        // iteration is replaced with a reference to the column it produced in the child below.
        if (index >= lastRexInput) {
          allExprs.remove(allExprs.size() - 1);
          RexBuilder builder = new RexBuilder(factory);
          allExprs.add(
              builder.makeInputRef(
                  new RelDataTypeDrillImpl(new RelDataTypeHolder(), factory), index));
        }
        index++;
        exprIndex++;

        currRexNode = complexExprs.remove(0);
        allExprs.add(currRexNode);
        relDataTypes.add(
            new RelDataTypeFieldImpl(
                "EXPR$" + exprIndex, allExprs.size(), factory.createSqlType(SqlTypeName.ANY)));
        childProject =
            new ProjectPrel(
                project.getCluster(),
                project.getTraitSet(),
                originalInput,
                ImmutableList.copyOf(allExprs),
                new RelRecordType(relDataTypes));
        originalInput = childProject;
      }
      // copied from above, find a better way to do this
      // NOTE(review): allExprs and relDataTypes are not read again after this point in this
      // method, so the updates below appear to have no effect on the result - confirm before
      // removing.
      allExprs.remove(allExprs.size() - 1);
      RexBuilder builder = new RexBuilder(factory);
      allExprs.add(
          builder.makeInputRef(new RelDataTypeDrillImpl(new RelDataTypeHolder(), factory), index));
      relDataTypes.add(
          new RelDataTypeFieldImpl(
              "EXPR$" + index, allExprs.size(), factory.createSqlType(SqlTypeName.ANY)));
    }
    return (Prel)
        project.copy(
            project.getTraitSet(), originalInput, exprList, new RelRecordType(origRelDataTypes));
  }
Ejemplo n.º 13
0
    /**
     * Infers predicates across the join and reports which predicates hold on the join's output
     * and which newly inferred ones apply to the left and right inputs.
     *
     * <p>The pull-up strategy is sound but not complete.
     *
     * <ol>
     *   <li>We only pull up inferred predicates for now. Pulling up existing predicates causes an
     *       explosion of duplicates, because the existing predicates are pushed back down as new
     *       predicates. Once we have rules to eliminate duplicate Filter conditions, we should
     *       pull up all predicates.
     *   <li>For Left Outer joins we infer new predicates from the left and set them as applicable
     *       on the right side. No predicates are pulled up.
     *   <li>Right Outer joins are handled in an analogous manner.
     *   <li>For Full Outer joins no predicates are pulled up or inferred.
     * </ol>
     *
     * @param includeEqualityInference forwarded to {@code infer}
     * @return the pulled-up, left-inferred and right-inferred predicates
     */
    public RelOptPredicateList inferPredicates(boolean includeEqualityInference) {
      final JoinRelType joinType = joinRel.getJoinType();
      final boolean inner = joinType == JoinRelType.INNER;
      final List<RexNode> inferredPredicates = new ArrayList<RexNode>();
      // Work on a copy so that repeated calls start from the digests collected at construction.
      final Set<String> seenDigests = new HashSet<String>(this.allExprsDigests);

      // Left-side predicates drive inference for INNER and LEFT joins; for LEFT only predicates
      // landing entirely on the right side are accepted.
      if (inner || joinType == JoinRelType.LEFT) {
        infer(
            leftChildPredicates,
            seenDigests,
            inferredPredicates,
            includeEqualityInference,
            inner ? allFieldsBitSet : rightFieldsBitSet);
      }
      // Mirror image for the right-side predicates.
      if (inner || joinType == JoinRelType.RIGHT) {
        infer(
            rightChildPredicates,
            seenDigests,
            inferredPredicates,
            includeEqualityInference,
            inner ? allFieldsBitSet : leftFieldsBitSet);
      }

      // Shuttles that shift join-level field references back to each input's own ordinals.
      final int leftEnd = nSysFields + nFieldsLeft;
      final Mappings.TargetMapping toRightInput =
          Mappings.createShiftMapping(leftEnd + nFieldsRight, 0, leftEnd, nFieldsRight);
      final RexPermuteInputsShuttle rightPermute =
          new RexPermuteInputsShuttle(toRightInput, joinRel);
      final Mappings.TargetMapping toLeftInput =
          Mappings.createShiftMapping(leftEnd, 0, nSysFields, nFieldsLeft);
      final RexPermuteInputsShuttle leftPermute = new RexPermuteInputsShuttle(toLeftInput, joinRel);

      // Split the inferred predicates by the side whose fields they reference exclusively.
      final List<RexNode> leftInferredPredicates = new ArrayList<RexNode>();
      final List<RexNode> rightInferredPredicates = new ArrayList<RexNode>();
      for (RexNode inferred : inferredPredicates) {
        final ImmutableBitSet referenced = RelOptUtil.InputFinder.bits(inferred);
        if (leftFieldsBitSet.contains(referenced)) {
          leftInferredPredicates.add(inferred.accept(leftPermute));
        } else if (rightFieldsBitSet.contains(referenced)) {
          rightInferredPredicates.add(inferred.accept(rightPermute));
        }
      }

      switch (joinType) {
        case INNER:
          {
            final Iterable<RexNode> pulledUp;
            if (!isSemiJoin) {
              pulledUp =
                  Iterables.concat(
                      RelOptUtil.conjunctions(leftChildPredicates),
                      RelOptUtil.conjunctions(rightChildPredicates),
                      RelOptUtil.conjunctions(joinRel.getCondition()),
                      inferredPredicates);
            } else {
              pulledUp =
                  Iterables.concat(
                      RelOptUtil.conjunctions(leftChildPredicates), leftInferredPredicates);
            }
            return RelOptPredicateList.of(
                pulledUp, leftInferredPredicates, rightInferredPredicates);
          }
        case LEFT:
          {
            final List<RexNode> pulledUp = RelOptUtil.conjunctions(leftChildPredicates);
            return RelOptPredicateList.of(
                pulledUp, leftInferredPredicates, rightInferredPredicates);
          }
        case RIGHT:
          {
            final List<RexNode> pulledUp = RelOptUtil.conjunctions(rightChildPredicates);
            return RelOptPredicateList.of(pulledUp, inferredPredicates, EMPTY_LIST);
          }
        default:
          assert inferredPredicates.size() == 0;
          return RelOptPredicateList.EMPTY;
      }
    }
Ejemplo n.º 14
0
    /**
     * Prepares predicate inference for a join.
     *
     * <p>Records the field ranges of the system, left and right fields, shifts the predicates of
     * each input into the join's field numbering, indexes their conjuncts by digest, and computes
     * the closure of the column equivalences found in the join condition.
     *
     * @param joinRel the join to analyse
     * @param isSemiJoin whether the join is a semi-join
     * @param lPreds predicates holding on the left input, in the left input's field numbering;
     *     may be {@code null}
     * @param rPreds predicates holding on the right input, in the right input's field numbering;
     *     may be {@code null}
     */
    private JoinConditionBasedPredicateInference(
        Join joinRel, boolean isSemiJoin, RexNode lPreds, RexNode rPreds) {
      super();
      this.joinRel = joinRel;
      this.isSemiJoin = isSemiJoin;
      nFieldsLeft = joinRel.getLeft().getRowType().getFieldList().size();
      nFieldsRight = joinRel.getRight().getRowType().getFieldList().size();
      nSysFields = joinRel.getSystemFieldList().size();
      leftFieldsBitSet = ImmutableBitSet.range(nSysFields, nSysFields + nFieldsLeft);
      rightFieldsBitSet =
          ImmutableBitSet.range(nSysFields + nFieldsLeft, nSysFields + nFieldsLeft + nFieldsRight);
      allFieldsBitSet = ImmutableBitSet.range(0, nSysFields + nFieldsLeft + nFieldsRight);

      exprFields = Maps.newHashMap();
      allExprsDigests = new HashSet<String>();

      if (lPreds == null) {
        leftChildPredicates = null;
      } else {
        // Shift left-input field references up by the number of system fields.
        Mappings.TargetMapping leftMapping =
            Mappings.createShiftMapping(nSysFields + nFieldsLeft, nSysFields, 0, nFieldsLeft);
        leftChildPredicates =
            lPreds.accept(new RexPermuteInputsShuttle(leftMapping, joinRel.getInput(0)));

        for (RexNode r : RelOptUtil.conjunctions(leftChildPredicates)) {
          exprFields.put(r.toString(), RelOptUtil.InputFinder.bits(r));
          allExprsDigests.add(r.toString());
        }
      }
      if (rPreds == null) {
        rightChildPredicates = null;
      } else {
        // Shift right-input field references past the system and left fields.
        Mappings.TargetMapping rightMapping =
            Mappings.createShiftMapping(
                nSysFields + nFieldsLeft + nFieldsRight, nSysFields + nFieldsLeft, 0, nFieldsRight);
        rightChildPredicates =
            rPreds.accept(new RexPermuteInputsShuttle(rightMapping, joinRel.getInput(1)));

        for (RexNode r : RelOptUtil.conjunctions(rightChildPredicates)) {
          exprFields.put(r.toString(), RelOptUtil.InputFinder.bits(r));
          allExprsDigests.add(r.toString());
        }
      }

      // Start with every field equivalent only to itself.
      equivalence = Maps.newTreeMap();
      equalityPredicates = new HashSet<String>();
      for (int i = 0; i < nSysFields + nFieldsLeft + nFieldsRight; i++) {
        equivalence.put(i, BitSets.of(i));
      }

      // Only process equivalences found in the join conditions. Processing
      // Equivalences from the left or right side infer predicates that are
      // already present in the Tree below the join.
      RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder();
      List<RexNode> exprs =
          RelOptUtil.conjunctions(compose(rexBuilder, ImmutableList.of(joinRel.getCondition())));

      // Visit each conjunct once, in order, purely for the finder's side effects. A plain loop
      // replaces the previous trick of copying a lazily transformed list to force evaluation.
      final EquivalenceFinder eF = new EquivalenceFinder();
      for (RexNode conjunct : exprs) {
        conjunct.accept(eF);
      }

      equivalence = BitSets.closure(equivalence);
    }
Ejemplo n.º 15
0
 /**
  * Collects the input references found in {@code expr}.
  *
  * @param expr expression to scan
  * @return the set reported by the {@code InputRefsCollector} after visiting {@code expr}
  */
 public static Set<Integer> getInputRefs(RexNode expr) {
   final InputRefsCollector collector = new InputRefsCollector(true);
   expr.accept(collector);
   return collector.getInputRefSet();
 }
Ejemplo n.º 16
0
  /**
   * Splits a join condition into per-input join keys and a residual list of conditions that
   * cannot be used as keys.
   *
   * <p>{@code AND} calls are processed recursively, operand by operand. Any other call is
   * accepted as a key pair when it is an {@code EQUALS}, an {@code IS NOT DISTINCT FROM} (only if
   * {@code filterNulls} is supplied) or a range comparison (only if {@code rangeOp} is supplied
   * and still empty), and each of its two operands references fields of exactly one input. Every
   * other condition is appended to {@code nonEquiList}.
   *
   * @param sysFieldList system fields of the join; only its size is used
   * @param inputs the join inputs, in field order
   * @param condition the condition to split
   * @param joinKeys one list per input; accepted keys are appended to the lists of the two
   *     inputs involved
   * @param filterNulls if non-null, receives the position of each key added for an
   *     {@code EQUALS} comparison
   * @param rangeOp if non-null, receives the operator of an accepted non-equality comparison
   * @param nonEquiList receives every condition that was not turned into join keys
   * @throws CalciteSemanticException if the two keys have different types and no common
   *     comparison type can be found
   */
  private static void splitJoinCondition(
      List<RelDataTypeField> sysFieldList,
      List<RelNode> inputs,
      RexNode condition,
      List<List<RexNode>> joinKeys,
      List<Integer> filterNulls,
      List<SqlOperator> rangeOp,
      List<RexNode> nonEquiList)
      throws CalciteSemanticException {
    final int sysFieldCount = sysFieldList.size();
    final RelOptCluster cluster = inputs.get(0).getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();

    if (condition instanceof RexCall) {
      RexCall call = (RexCall) condition;
      // Conjunctions are split operand by operand.
      if (call.getOperator() == SqlStdOperatorTable.AND) {
        for (RexNode operand : call.getOperands()) {
          splitJoinCondition(
              sysFieldList, inputs, operand, joinKeys, filterNulls, rangeOp, nonEquiList);
        }
        return;
      }

      RexNode leftKey = null;
      RexNode rightKey = null;
      int leftInput = 0;
      int rightInput = 0;
      List<RelDataTypeField> leftFields = null;
      List<RelDataTypeField> rightFields = null;
      // Set when operand 0 ends up as the right key, i.e. the operands are swapped relative to
      // the input order.
      boolean reverse = false;

      SqlKind kind = call.getKind();

      // Only consider range operators if we haven't already seen one
      if ((kind == SqlKind.EQUALS)
          || (filterNulls != null && kind == SqlKind.IS_NOT_DISTINCT_FROM)
          || (rangeOp != null
              && rangeOp.isEmpty()
              && (kind == SqlKind.GREATER_THAN
                  || kind == SqlKind.GREATER_THAN_OR_EQUAL
                  || kind == SqlKind.LESS_THAN
                  || kind == SqlKind.LESS_THAN_OR_EQUAL))) {
        final List<RexNode> operands = call.getOperands();
        RexNode op0 = operands.get(0);
        RexNode op1 = operands.get(1);

        // Fields referenced by each side of the comparison.
        final ImmutableBitSet projRefs0 = InputFinder.bits(op0);
        final ImmutableBitSet projRefs1 = InputFinder.bits(op1);

        // Compute the range of join-level field ordinals that belongs to each input.
        // NOTE(review): sysFieldCount is added on every iteration rather than once; this makes
        // no difference when there are no system fields - confirm the intent otherwise.
        final ImmutableBitSet[] inputsRange = new ImmutableBitSet[inputs.size()];
        int totalFieldCount = 0;
        for (int i = 0; i < inputs.size(); i++) {
          final int firstField = totalFieldCount + sysFieldCount;
          totalFieldCount = firstField + inputs.get(i).getRowType().getFieldCount();
          inputsRange[i] = ImmutableBitSet.range(firstField, totalFieldCount);
        }

        // Walk the inputs in order. An operand is assigned to an input when it references at
        // least one of that input's fields and nothing outside them. The first operand matched
        // becomes the left key, the second the right key.
        boolean foundBothInputs = false;
        for (int i = 0; i < inputs.size() && !foundBothInputs; i++) {
          if (projRefs0.intersects(inputsRange[i])
              && projRefs0.union(inputsRange[i]).equals(inputsRange[i])) {
            if (leftKey == null) {
              leftKey = op0;
              leftInput = i;
              leftFields = inputs.get(leftInput).getRowType().getFieldList();
            } else {
              rightKey = op0;
              rightInput = i;
              rightFields = inputs.get(rightInput).getRowType().getFieldList();
              reverse = true;
              foundBothInputs = true;
            }
          } else if (projRefs1.intersects(inputsRange[i])
              && projRefs1.union(inputsRange[i]).equals(inputsRange[i])) {
            if (leftKey == null) {
              leftKey = op1;
              leftInput = i;
              leftFields = inputs.get(leftInput).getRowType().getFieldList();
            } else {
              rightKey = op1;
              rightInput = i;
              rightFields = inputs.get(rightInput).getRowType().getFieldList();
              foundBothInputs = true;
            }
          }
        }

        if ((leftKey != null) && (rightKey != null)) {
          // adjustment array
          // For every join-level field, the offset that converts its ordinal into the ordinal
          // within its own input (minus the first ordinal of that input's range).
          int[] adjustments = new int[totalFieldCount];
          for (int i = 0; i < inputs.size(); i++) {
            final int adjustment = inputsRange[i].nextSetBit(0);
            for (int j = adjustment; j < inputsRange[i].length(); j++) {
              adjustments[j] = -adjustment;
            }
          }

          // replace right Key input ref
          rightKey =
              rightKey.accept(
                  new RelOptUtil.RexInputConverter(
                      rexBuilder, rightFields, rightFields, adjustments));

          // left key only needs to be adjusted if there are system
          // fields, but do it for uniformity
          leftKey =
              leftKey.accept(
                  new RelOptUtil.RexInputConverter(
                      rexBuilder, leftFields, leftFields, adjustments));

          RelDataType leftKeyType = leftKey.getType();
          RelDataType rightKeyType = rightKey.getType();

          // Reference comparison: the types are treated as different unless they are the same
          // object.
          if (leftKeyType != rightKeyType) {
            // perform casting using Hive rules
            TypeInfo rType = TypeConverter.convert(rightKeyType);
            TypeInfo lType = TypeConverter.convert(leftKeyType);
            TypeInfo tgtType = FunctionRegistry.getCommonClassForComparison(lType, rType);

            if (tgtType == null) {
              throw new CalciteSemanticException(
                  "Cannot find common type for join keys "
                      + leftKey
                      + " (type "
                      + leftKeyType
                      + ") and "
                      + rightKey
                      + " (type "
                      + rightKeyType
                      + ")");
            }
            RelDataType targetKeyType = TypeConverter.convert(tgtType, rexBuilder.getTypeFactory());

            // Cast each side only when its type differs from the target and the conversion is
            // required for comparison.
            if (leftKeyType != targetKeyType
                && TypeInfoUtils.isConversionRequiredForComparison(tgtType, lType)) {
              leftKey = rexBuilder.makeCast(targetKeyType, leftKey);
            }

            if (rightKeyType != targetKeyType
                && TypeInfoUtils.isConversionRequiredForComparison(tgtType, rType)) {
              rightKey = rexBuilder.makeCast(targetKeyType, rightKey);
            }
          }
        }
      }

      if ((leftKey != null) && (rightKey != null)) {
        // found suitable join keys
        // add them to key list, ensuring that if there is a
        // non-equi join predicate, it appears at the end of the
        // key list; also mark the null filtering property
        addJoinKey(joinKeys.get(leftInput), leftKey, (rangeOp != null) && !rangeOp.isEmpty());
        addJoinKey(joinKeys.get(rightInput), rightKey, (rangeOp != null) && !rangeOp.isEmpty());
        if (filterNulls != null && kind == SqlKind.EQUALS) {
          // nulls are considered not matching for equality comparison
          // add the position of the most recently inserted key
          filterNulls.add(joinKeys.get(leftInput).size() - 1);
        }
        // NOTE(review): the kinds admitted above include IS_NOT_DISTINCT_FROM but never
        // IS_DISTINCT_FROM, so the second test below is always true here and an
        // IS_NOT_DISTINCT_FROM key is also recorded in rangeOp - confirm whether
        // IS_NOT_DISTINCT_FROM was intended.
        if (rangeOp != null && kind != SqlKind.EQUALS && kind != SqlKind.IS_DISTINCT_FROM) {
          // The operands were swapped relative to the input order, so flip the comparison.
          if (reverse) {
            kind = reverse(kind);
          }
          rangeOp.add(op(kind, call.getOperator()));
        }
        return;
      } // else fall through and add this condition as nonEqui condition
    }

    // Reached when the condition is not a RexCall, or is a call that could not be turned into
    // a pair of join keys.
    // Add this condition to the list of non-equi-join conditions.
    nonEquiList.add(condition);
  }
Ejemplo n.º 17
0
  /**
   * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for {@link
   * org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin}.
   *
   * <p>Each input is trimmed to the fields that are either requested by the consumer or
   * referenced by the join condition; the per-input mappings are then stitched into a single
   * mapping for the join, and the join is rebuilt over the trimmed inputs unless nothing
   * changed.
   *
   * @param join the multi-join to trim
   * @param fieldsUsed fields of the join's output that the consumer needs
   * @param extraFields extra fields requested by the consumer
   * @return the original join with an identity mapping when nothing changed, otherwise the new
   *     join and the mapping from old to new field ordinals
   */
  public TrimResult trimFields(
      HiveMultiJoin join, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
    final int originalFieldCount = join.getRowType().getFieldCount();
    final RexNode condition = join.getCondition();

    // Add in fields used in the condition.
    final Set<RelDataTypeField> conditionExtraFields =
        new LinkedHashSet<RelDataTypeField>(extraFields);
    final RelOptUtil.InputFinder finder = new RelOptUtil.InputFinder(conditionExtraFields);
    finder.inputBitSet.addAll(fieldsUsed);
    condition.accept(finder);
    final ImmutableBitSet requiredFields = finder.inputBitSet.build();

    final List<RelNode> trimmedInputs = new ArrayList<RelNode>();
    final List<Mapping> perInputMappings = new ArrayList<Mapping>();
    boolean anyInputChanged = false;
    int trimmedFieldCount = 0;
    int inputStart = 0;
    for (RelNode input : join.getInputs()) {
      final int inputFieldCount = input.getRowType().getFieldCount();
      final int inputEnd = inputStart + inputFieldCount;

      // Translate the required join-level fields that fall inside this input into the input's
      // own ordinals.
      final ImmutableBitSet.Builder usedInInput = ImmutableBitSet.builder();
      for (int bit : requiredFields) {
        if (inputStart <= bit && bit < inputEnd) {
          usedInInput.set(bit - inputStart);
        }
      }

      final Set<RelDataTypeField> noExtraFields = Collections.<RelDataTypeField>emptySet();
      final TrimResult trimmed = trimChild(join, input, usedInInput.build(), noExtraFields);
      trimmedInputs.add(trimmed.left);
      if (trimmed.left != input) {
        anyInputChanged = true;
      }

      final Mapping inputMapping = trimmed.right;
      perInputMappings.add(inputMapping);
      trimmedFieldCount += inputMapping.getTargetCount();

      // Move offset to point to start of next input.
      inputStart = inputEnd;
    }

    // Concatenate the per-input mappings, offsetting sources and targets by the field counts
    // of the inputs that precede each one.
    final Mapping mapping =
        Mappings.create(MappingType.INVERSE_SURJECTION, originalFieldCount, trimmedFieldCount);
    int sourceOffset = 0;
    int targetOffset = 0;
    for (Mapping inputMapping : perInputMappings) {
      for (IntPair pair : inputMapping) {
        mapping.set(pair.source + sourceOffset, pair.target + targetOffset);
      }
      sourceOffset += inputMapping.getSourceCount();
      targetOffset += inputMapping.getTargetCount();
    }

    if (!anyInputChanged && mapping.isIdentity()) {
      return new TrimResult(join, Mappings.createIdentity(originalFieldCount));
    }

    // Build new join.
    final RelNode[] inputArray = trimmedInputs.toArray(new RelNode[trimmedInputs.size()]);
    final RexVisitor<RexNode> shuttle = new RexPermuteInputsShuttle(mapping, inputArray);
    final RexNode trimmedCondition = condition.accept(shuttle);

    final RelDataType trimmedRowType =
        RelOptUtil.permute(join.getCluster().getTypeFactory(), join.getRowType(), mapping);
    final RelNode trimmedJoin =
        new HiveMultiJoin(
            join.getCluster(),
            trimmedInputs,
            trimmedCondition,
            trimmedRowType,
            join.getJoinInputs(),
            join.getJoinTypes(),
            join.getJoinFilters());

    return new TrimResult(trimmedJoin, mapping);
  }