Example #1
0
  /**
   * Rewrites the matched {@code JoinRel} as a {@code CorrelatorRel}.
   *
   * <p>The join condition is split into left/right key lists plus a
   * remaining condition. For every key pair a fresh correlation variable is
   * created which, for each left row, is bound to the value of the left key
   * column. The right input is then filtered on
   * {@code rightKey = correlationVariable} for all pairs. The remaining
   * condition becomes the condition of the new correlator.
   *
   * @param call rule call whose first relational expression is the join
   */
  public void onMatch(RelOptRuleCall call) {
    assert matches(call);
    final JoinRel join = (JoinRel) call.rels[0];
    RelNode rightInput = join.getRight();
    final RelNode leftInput = join.getLeft();

    // Key positions are filled in by splitJoinCondition. Each list is used
    // below as field positions within its own input. Whatever is not
    // expressed through the key lists comes back as the residue.
    final List<Integer> leftKeys = new ArrayList<Integer>();
    final List<Integer> rightKeys = new ArrayList<Integer>();
    final RexNode residue =
        RelOptUtil.splitJoinCondition(
            leftInput, rightInput, join.getCondition(), leftKeys, rightKeys);
    assert leftKeys.size() == rightKeys.size();

    final List<CorrelatorRel.Correlation> correlations =
        new ArrayList<CorrelatorRel.Correlation>();
    if (!leftKeys.isEmpty()) {
      final RelOptCluster cluster = join.getCluster();
      final RexBuilder builder = cluster.getRexBuilder();
      RexNode filterCondition = null;
      for (int i = 0; i < leftKeys.size(); i++) {
        final Integer leftKey = leftKeys.get(i);
        final Integer rightKey = rightKeys.get(i);
        final String correlName = cluster.getQuery().createCorrel();
        final int correlId = RelOptQuery.getCorrelOrdinal(correlName);

        // For each left row, bind correlation variable #correlId to the
        // value of left column #leftKey.
        correlations.add(new CorrelatorRel.Correlation(correlId, leftKey));

        // Right-side filter predicate: right.$rightKey = $correlName.
        final RexNode rightRef =
            builder.makeInputRef(
                rightInput.getRowType().getFieldList().get(rightKey).getType(), rightKey);
        final RexNode correlRef =
            builder.makeCorrel(
                leftInput.getRowType().getFieldList().get(leftKey).getType(), correlName);
        final RexNode keyEquality =
            builder.makeCall(SqlStdOperatorTable.equalsOperator, rightRef, correlRef);

        // The first iteration passes a null accumulator, as the original code did.
        filterCondition = RelOptUtil.andJoinFilters(builder, filterCondition, keyEquality);
      }
      rightInput = CalcRel.createFilter(rightInput, filterCondition);
    }

    final RelNode correlator =
        new CorrelatorRel(
            join.getCluster(),
            leftInput,
            rightInput,
            residue,
            correlations,
            join.getJoinType());
    call.transformTo(correlator);
  }
  /**
   * Rewrites {@code (A join B) join C} as {@code (A join C) join B},
   * topped by a project that restores the original {@code A | B | C} field
   * order.
   *
   * <p>Applies only when both joins are {@link JoinRelType#INNER}, and only
   * when at least one conjunct of the top condition avoids B's columns.
   * Otherwise nothing could move into the new bottom join, and the method
   * returns without transforming.
   *
   * @param call rule call whose operands are the top join, the bottom join,
   *     and the top join's right input ({@code C})
   */
  private void onMatchRight(RelOptRuleCall call) {
    final JoinRelBase topJoin = call.rel(0);
    final JoinRelBase bottomJoin = call.rel(1);
    final RelNode relC = call.rel(2);

    // Both joins must be inner; otherwise bail out. (Is this too strict?)
    if (topJoin.getJoinType() != JoinRelType.INNER
        || bottomJoin.getJoinType() != JoinRelType.INNER) {
      return;
    }

    final RelNode relA = bottomJoin.getLeft();
    final RelNode relB = bottomJoin.getRight();

    //        topJoin                         newTopJoin
    //        /     \                         /        \
    //   bottomJoin  C       becomes     newBottomJoin  B
    //    /    \                          /    \
    //   A      B                        A      C

    final int aCount = relA.getRowType().getFieldCount();
    final int bCount = relB.getRowType().getFieldCount();
    final int cCount = relC.getRowType().getFieldCount();
    final int totalCount = aCount + bCount + cCount;
    // Positions of B's fields within the original A | B | C row.
    final BitSet bColumns = BitSets.range(aCount, aCount + bCount);

    // Top condition: conjuncts that avoid B's columns can be pushed down.
    final List<RexNode> topUsesB = new ArrayList<RexNode>();
    final List<RexNode> topAvoidsB = new ArrayList<RexNode>();
    split(topJoin.getCondition(), bColumns, topUsesB, topAvoidsB);
    if (topAvoidsB.isEmpty()) {
      // Nothing to push down; the rewrite is not worth doing.
      return;
    }

    // Bottom condition: conjuncts that use B's columns must be pulled up.
    final List<RexNode> bottomUsesB = new ArrayList<RexNode>();
    final List<RexNode> bottomAvoidsB = new ArrayList<RexNode>();
    split(bottomJoin.getCondition(), bColumns, bottomUsesB, bottomAvoidsB);

    final RexBuilder rexBuilder = topJoin.getCluster().getRexBuilder();

    // ---- New bottom join over A | C ----
    // target: | A       | C      |
    // source: | A       | B | C      |
    final Mappings.TargetMapping toBottomFields =
        Mappings.createShiftMapping(
            totalCount,
            0, 0, aCount,
            aCount, aCount + bCount, cCount);
    // The bottom join's row is A | B, so conjuncts that avoid B touch only
    // A's leading fields. This mapping keeps those positions unchanged.
    // NOTE(review): for such conjuncts it agrees with toBottomFields.
    final Mappings.TargetMapping keepAFields =
        Mappings.createShiftMapping(aCount + bCount, 0, 0, aCount);
    final List<RexNode> bottomConjuncts = new ArrayList<RexNode>();
    new RexPermuteInputsShuttle(toBottomFields, relA, relC)
        .visitList(topAvoidsB, bottomConjuncts);
    new RexPermuteInputsShuttle(keepAFields, relA, relC)
        .visitList(bottomAvoidsB, bottomConjuncts);
    final JoinRelBase newBottomJoin =
        bottomJoin.copy(
            bottomJoin.getTraitSet(),
            RexUtil.composeConjunction(rexBuilder, bottomConjuncts, false),
            relA,
            relC,
            bottomJoin.getJoinType());

    // ---- New top join over (A | C) | B ----
    // target: | A       | C      | B |
    // source: | A       | B | C      |
    // Bottom conjuncts use A | B positions. These are the leading positions
    // of A | B | C, so the same mapping serves both condition lists.
    final Mappings.TargetMapping toTopFields =
        Mappings.createShiftMapping(
            totalCount,
            0, 0, aCount,
            aCount + cCount, aCount, bCount,
            aCount, aCount + bCount, cCount);
    final List<RexNode> topConjuncts = new ArrayList<RexNode>();
    new RexPermuteInputsShuttle(toTopFields, newBottomJoin, relB)
        .visitList(topUsesB, topConjuncts);
    new RexPermuteInputsShuttle(toTopFields, newBottomJoin, relB)
        .visitList(bottomUsesB, topConjuncts);
    @SuppressWarnings("SuspiciousNameCombination")
    final JoinRelBase newTopJoin =
        topJoin.copy(
            topJoin.getTraitSet(),
            RexUtil.composeConjunction(rexBuilder, topConjuncts, false),
            newBottomJoin,
            relB,
            topJoin.getJoinType());

    // Project the new (A | C | B) row back into the original A | B | C order.
    assert !Mappings.isIdentity(toTopFields);
    call.transformTo(
        RelFactories.createProject(
            projectFactory, newTopJoin, Mappings.asList(toTopFields)));
  }