예제 #1
0
  /**
   * Infers predicates for an Aggregate.
   *
   * <p>Pulls up those input predicates that reference only columns of the group set, rewriting
   * their input references to the output positions of the corresponding group keys. For example:
   *
   * <pre>
   * childPullUpExprs : { a &gt; 7, b + c &lt; 10, a + e = 9}
   * groupSet         : { a, b}
   * pulledUpExprs    : { a &gt; 7}
   * </pre>
   *
   * <p>Nothing is pulled up when the group set is empty. "GROUP BY ()" emits a row even when its
   * input has no rows, so a column-free predicate that holds vacuously on an empty input (for
   * example {@code false}) does not necessarily hold on the Aggregate's output.
   *
   * @param agg the Aggregate whose output predicates are requested
   * @return the input predicates that can be restated over the output of {@code agg}; empty when
   *     the group set is empty
   */
  public RelOptPredicateList getPredicates(Aggregate agg) {
    final List<RexNode> aggPullUpPredicates = new ArrayList<RexNode>();

    final ImmutableBitSet groupKeys = agg.getGroupSet();
    if (groupKeys.isEmpty()) {
      // Without this guard every predicate that references no column would pass the
      // containment test below and be pulled up unsoundly.
      return RelOptPredicateList.of(aggPullUpPredicates);
    }

    final RelNode child = agg.getInput();
    final RelOptPredicateList childInfo = RelMetadataQuery.getPulledUpPredicates(child);

    // Map each group key (an input field ordinal) to its ordinal in the Aggregate's output: the
    // i-th key in the bit set's iteration order becomes output field i.
    final Mapping m =
        Mappings.create(
            MappingType.PARTIAL_FUNCTION,
            child.getRowType().getFieldCount(),
            agg.getRowType().getFieldCount());
    int i = 0;
    for (int j : groupKeys) {
      m.set(j, i++);
    }

    for (RexNode r : childInfo.pulledUpPredicates) {
      final ImmutableBitSet rCols = RelOptUtil.InputFinder.bits(r);
      // Only predicates expressed purely over group keys can be restated on the output.
      if (groupKeys.contains(rCols)) {
        aggPullUpPredicates.add(r.accept(new RexPermuteInputsShuttle(m, child)));
      }
    }
    return RelOptPredicateList.of(aggPullUpPredicates);
  }
예제 #2
0
  /**
   * Infers predicates for a Project.
   *
   * <ol>
   *   <li>Build a mapping from input fields to projection positions, covering only the positions
   *       whose expression is a direct reference to an input column.
   *   <li>Keep each input predicate that refers exclusively to mapped columns, rewritten in
   *       terms of the Project's output fields.
   *   <li>In the example below 'a + e = 9' is not pulled up because 'e' is not projected
   *       directly.
   *       <pre>
   * childPullUpExprs:      {a &gt; 7, b + c &lt; 10, a + e = 9}
   * projectionExprs:       {a, b, c, e / 2}
   * projectionPullupExprs: {a &gt; 7, b + c &lt; 10}
   * </pre>
   *   <li>For every projected constant, add a predicate tying the output field to that constant.
   * </ol>
   *
   * @param project the Project whose output predicates are requested
   * @return the pulled-up input predicates followed by the predicates derived from constants
   */
  public RelOptPredicateList getPredicates(Project project) {
    final RelNode input = project.getInput();
    final RexBuilder exprBuilder = project.getCluster().getRexBuilder();
    final RelOptPredicateList inputPredicates = RelMetadataQuery.getPulledUpPredicates(input);
    final List<RexNode> projections = project.getProjects();

    final List<RexNode> result = new ArrayList<RexNode>();

    // Record which input columns are projected as-is, and where they land in the output.
    final ImmutableBitSet.Builder directRefs = ImmutableBitSet.builder();
    final Mapping inputToOutput =
        Mappings.create(
            MappingType.PARTIAL_FUNCTION,
            input.getRowType().getFieldCount(),
            project.getRowType().getFieldCount());
    for (int outIdx = 0; outIdx < projections.size(); outIdx++) {
      final RexNode projection = projections.get(outIdx);
      if (!(projection instanceof RexInputRef)) {
        continue;
      }
      final int inIdx = ((RexInputRef) projection).getIndex();
      inputToOutput.set(inIdx, outIdx);
      directRefs.set(inIdx);
    }

    // An input predicate survives only if every column it touches is projected directly; it is
    // then rewritten through the mapping.
    final ImmutableBitSet mappedColumns = directRefs.build();
    for (RexNode predicate : inputPredicates.pulledUpPredicates) {
      final ImmutableBitSet referenced = RelOptUtil.InputFinder.bits(predicate);
      if (mappedColumns.contains(referenced)) {
        final RexNode rewritten =
            predicate.accept(new RexPermuteInputsShuttle(inputToOutput, input));
        result.add(rewritten);
      }
    }

    // Projected constants yield predicates of their own on the output field.
    for (int outIdx = 0; outIdx < projections.size(); outIdx++) {
      final RexNode projection = projections.get(outIdx);
      if (RexLiteral.isNullLiteral(projection)) {
        result.add(
            exprBuilder.makeCall(
                SqlStdOperatorTable.IS_NULL, exprBuilder.makeInputRef(project, outIdx)));
        continue;
      }
      if (!RexUtil.isConstant(projection)) {
        continue;
      }
      final List<RexNode> operands =
          ImmutableList.of(exprBuilder.makeInputRef(project, outIdx), projection);
      // Use the null-safe comparison whenever either side is nullable.
      final SqlOperator comparison;
      if (operands.get(0).getType().isNullable() || operands.get(1).getType().isNullable()) {
        comparison = SqlStdOperatorTable.IS_NOT_DISTINCT_FROM;
      } else {
        comparison = SqlStdOperatorTable.EQUALS;
      }
      result.add(exprBuilder.makeCall(comparison, operands));
    }
    return RelOptPredicateList.of(result);
  }
예제 #3
0
 /**
  * Advances {@code nextMapping} to the next combination of targets, treating the levels like the
  * digits of an odometer.
  *
  * <p>If {@code columnSets[level]} has another set bit at or after {@code iterationIdx[level]},
  * {@code columns[level]} is pointed at it. Otherwise the level wraps back to its first set bit
  * and the advance carries into {@code level - 1}. When level 0 itself is exhausted,
  * {@code nextMapping} is set to {@code null}.
  *
  * <p>Wrapping the level's entry in {@code nextMapping}, rather than only resetting
  * {@code iterationIdx[level]}, keeps the mapping in step with the indexes. Without it the
  * exhausted level would keep its last target while the lower level advanced, and one
  * combination would be produced twice.
  *
  * @param level index into {@code columnSets}, {@code columns} and {@code iterationIdx} of the
  *     level to advance
  */
 private void computeNextMapping(int level) {
   final int next = columnSets[level].nextSetBit(iterationIdx[level]);
   if (next >= 0) {
     nextMapping.set(columns[level], next);
     iterationIdx[level] = next + 1;
     return;
   }

   if (level == 0) {
     // Every combination has been produced.
     nextMapping = null;
     return;
   }

   // This level is exhausted: restart it at its first target, then carry into the previous one.
   final int first = columnSets[level].nextSetBit(0);
   if (first >= 0) {
     nextMapping.set(columns[level], first);
     iterationIdx[level] = first + 1;
   } else {
     // Empty candidate set: nothing to point at, just rewind the index.
     iterationIdx[level] = 0;
   }
   computeNextMapping(level - 1);
 }
예제 #4
0
 /**
  * Creates {@code nextMapping} and seeds it with the first combination.
  *
  * <p>The mapping is a partial function whose source and target counts are both
  * {@code nSysFields + nFieldsLeft + nFieldsRight}. For each level, {@code columns[level]} is
  * mapped to the first set bit of {@code columnSets[level]} at or after
  * {@code iterationIdx[level]}, and the index is moved just past that bit. If any level has no
  * such bit, {@code nextMapping} is set to {@code null} and the method returns immediately.
  */
 private void initializeMapping() {
   final int totalFields = nSysFields + nFieldsLeft + nFieldsRight;
   nextMapping = Mappings.create(MappingType.PARTIAL_FUNCTION, totalFields, totalFields);

   int level = 0;
   while (level < columnSets.length) {
     final int target = columnSets[level].nextSetBit(iterationIdx[level]);
     if (target < 0) {
       // A column without any candidate means no mapping can be formed at all.
       nextMapping = null;
       return;
     }
     nextMapping.set(columns[level], target);
     iterationIdx[level] = target + 1;
     level++;
   }
 }
예제 #5
0
  /**
   * Variant of {@link #trimFields(RelNode, ImmutableBitSet, Set)} for {@link
   * org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin}.
   *
   * <p>Each input is trimmed to the fields that are requested by the consumer or referenced by
   * the join condition or the join filters. If any input changed, a new {@code HiveMultiJoin} is
   * built over the trimmed inputs, with its condition and its join filters rewritten to the new
   * field ordinals.
   *
   * @param join the multi-join to trim
   * @param fieldsUsed ordinals of the join's output fields that the consumer needs
   * @param extraFields extra fields requested by the consumer; copied into the set handed to the
   *     input finder
   * @return the (possibly new) join together with the mapping from old to new field ordinals; the
   *     original join and an identity mapping when nothing could be trimmed
   */
  public TrimResult trimFields(
      HiveMultiJoin join, ImmutableBitSet fieldsUsed, Set<RelDataTypeField> extraFields) {
    final int fieldCount = join.getRowType().getFieldCount();
    final RexNode conditionExpr = join.getCondition();

    // Add in fields used in the condition and in the join filters. The filters are rewritten
    // below with the same mapping as the condition, so every field they reference must survive.
    final Set<RelDataTypeField> combinedInputExtraFields =
        new LinkedHashSet<RelDataTypeField>(extraFields);
    RelOptUtil.InputFinder inputFinder = new RelOptUtil.InputFinder(combinedInputExtraFields);
    inputFinder.inputBitSet.addAll(fieldsUsed);
    conditionExpr.accept(inputFinder);
    for (RexNode joinFilter : join.getJoinFilters()) {
      if (joinFilter != null) {
        joinFilter.accept(inputFinder);
      }
    }
    final ImmutableBitSet fieldsUsedPlus = inputFinder.inputBitSet.build();

    int inputStartPos = 0;
    int changeCount = 0;
    int newFieldCount = 0;
    final List<RelNode> newInputs = new ArrayList<RelNode>();
    final List<Mapping> inputMappings = new ArrayList<Mapping>();
    final Set<RelDataTypeField> inputExtraFields = Collections.<RelDataTypeField>emptySet();
    for (RelNode input : join.getInputs()) {
      final RelDataType inputRowType = input.getRowType();
      final int inputFieldCount = inputRowType.getFieldCount();

      // Compute required mapping: translate the join-level ordinals that fall inside this
      // input's range into ordinals relative to the input.
      ImmutableBitSet.Builder inputFieldsUsed = ImmutableBitSet.builder();
      for (int bit : fieldsUsedPlus) {
        if (bit >= inputStartPos && bit < inputStartPos + inputFieldCount) {
          inputFieldsUsed.set(bit - inputStartPos);
        }
      }

      TrimResult trimResult = trimChild(join, input, inputFieldsUsed.build(), inputExtraFields);
      newInputs.add(trimResult.left);
      if (trimResult.left != input) {
        ++changeCount;
      }

      final Mapping inputMapping = trimResult.right;
      inputMappings.add(inputMapping);

      // Move offset to point to start of next input.
      inputStartPos += inputFieldCount;
      newFieldCount += inputMapping.getTargetCount();
    }

    // Stitch the per-input mappings together into one mapping over the join's row type.
    Mapping mapping = Mappings.create(MappingType.INVERSE_SURJECTION, fieldCount, newFieldCount);
    int offset = 0;
    int newOffset = 0;
    for (int i = 0; i < inputMappings.size(); i++) {
      Mapping inputMapping = inputMappings.get(i);
      for (IntPair pair : inputMapping) {
        mapping.set(pair.source + offset, pair.target + newOffset);
      }
      offset += inputMapping.getSourceCount();
      newOffset += inputMapping.getTargetCount();
    }

    if (changeCount == 0 && mapping.isIdentity()) {
      return new TrimResult(join, Mappings.createIdentity(fieldCount));
    }

    // Build new join.
    final RexVisitor<RexNode> shuttle =
        new RexPermuteInputsShuttle(mapping, newInputs.toArray(new RelNode[newInputs.size()]));
    RexNode newConditionExpr = conditionExpr.accept(shuttle);

    // The join filters reference the same field ordinals as the condition, so they must be
    // permuted too; otherwise they would point at pre-trim positions in the new join.
    final List<RexNode> newJoinFilters = new ArrayList<RexNode>();
    for (RexNode joinFilter : join.getJoinFilters()) {
      newJoinFilters.add(joinFilter == null ? null : joinFilter.accept(shuttle));
    }

    final RelDataType newRowType =
        RelOptUtil.permute(join.getCluster().getTypeFactory(), join.getRowType(), mapping);
    final RelNode newJoin =
        new HiveMultiJoin(
            join.getCluster(),
            newInputs,
            newConditionExpr,
            newRowType,
            join.getJoinInputs(),
            join.getJoinTypes(),
            newJoinFilters);

    return new TrimResult(newJoin, mapping);
  }