/** * @param hop * @param vars * @return * @throws DMLRuntimeException */ private static long getIntValueDataLiteral(Hop hop, LocalVariableMap vars) throws DMLRuntimeException { long value = -1; try { if (hop instanceof LiteralOp) { value = HopRewriteUtils.getIntValue((LiteralOp) hop); } else if (hop instanceof UnaryOp && ((UnaryOp) hop).getOp() == OpOp1.NROW) { // get the dimension information from the matrix object because the hop // dimensions might not have been updated during recompile MatrixObject mo = (MatrixObject) vars.get(hop.getInput().get(0).getName()); value = mo.getNumRows(); } else if (hop instanceof UnaryOp && ((UnaryOp) hop).getOp() == OpOp1.NCOL) { // get the dimension information from the matrix object because the hop // dimensions might not have been updated during recompile MatrixObject mo = (MatrixObject) vars.get(hop.getInput().get(0).getName()); value = mo.getNumColumns(); } else { ScalarObject sdat = (ScalarObject) vars.get(hop.getName()); value = sdat.getLongValue(); } } catch (HopsException ex) { throw new DMLRuntimeException("Failed to get int value for literal replacement", ex); } return value; }
/** * @param hop * @param vars * @throws DMLRuntimeException */ protected static void rReplaceLiterals(Hop hop, LocalVariableMap vars) throws DMLRuntimeException { if (hop.getVisited() == VisitStatus.DONE) return; if (hop.getInput() != null) { // indexed access to allow parent-child modifications for (int i = 0; i < hop.getInput().size(); i++) { Hop c = hop.getInput().get(i); Hop lit = null; // conditional apply of literal replacements lit = (lit == null) ? replaceLiteralScalarRead(c, vars) : lit; lit = (lit == null) ? replaceLiteralValueTypeCastScalarRead(c, vars) : lit; lit = (lit == null) ? replaceLiteralValueTypeCastLiteral(c, vars) : lit; lit = (lit == null) ? replaceLiteralDataTypeCastMatrixRead(c, vars) : lit; lit = (lit == null) ? replaceLiteralValueTypeCastRightIndexing(c, vars) : lit; lit = (lit == null) ? replaceLiteralFullUnaryAggregate(c, vars) : lit; lit = (lit == null) ? replaceLiteralFullUnaryAggregateRightIndexing(c, vars) : lit; // replace hop w/ literal on demand if (lit != null) { // replace hop c by literal, for all parents to prevent (1) missed opportunities // because hop c marked as visited, and (2) repeated evaluation of uagg ops if (c.getParent().size() > 1) { // multiple parents ArrayList<Hop> parents = new ArrayList<Hop>(c.getParent()); for (Hop p : parents) { int pos = HopRewriteUtils.getChildReferencePos(p, c); HopRewriteUtils.removeChildReferenceByPos(p, c, pos); HopRewriteUtils.addChildReference(p, lit, pos); } } else { // current hop is only parent HopRewriteUtils.removeChildReferenceByPos(hop, c, i); HopRewriteUtils.addChildReference(hop, lit, i); } } // recursively process children else { rReplaceLiterals(c, vars); } } } hop.setVisited(VisitStatus.DONE); }
/** * Indicates if the lbound:rbound expressions is of the form "(c * (i - 1) + 1) : (c * i)", where * we could use c as a tight size estimate. * * @param lbound * @param ubound * @return */ private boolean isBlockIndexingExpression(Hop lbound, Hop ubound) { boolean ret = false; LiteralOp constant = null; DataOp var = null; // handle lower bound if (lbound instanceof BinaryOp && ((BinaryOp) lbound).getOp() == OpOp2.PLUS && lbound.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) lbound.getInput().get(1)) == 1 && lbound.getInput().get(0) instanceof BinaryOp) { BinaryOp lmult = (BinaryOp) lbound.getInput().get(0); if (lmult.getOp() == OpOp2.MULT && lmult.getInput().get(0) instanceof LiteralOp && lmult.getInput().get(1) instanceof BinaryOp) { BinaryOp lminus = (BinaryOp) lmult.getInput().get(1); if (lminus.getOp() == OpOp2.MINUS && lminus.getInput().get(1) instanceof LiteralOp && HopRewriteUtils.getDoubleValueSafe((LiteralOp) lminus.getInput().get(1)) == 1 && lminus.getInput().get(0) instanceof DataOp) { constant = (LiteralOp) lmult.getInput().get(0); var = (DataOp) lminus.getInput().get(0); } } } // handle upper bound if (var != null && constant != null && ubound instanceof BinaryOp && ubound.getInput().get(0) instanceof LiteralOp && ubound.getInput().get(1) instanceof DataOp && ubound.getInput().get(1).getName().equals(var.getName())) { LiteralOp constant2 = (LiteralOp) ubound.getInput().get(0); ret = (HopRewriteUtils.getDoubleValueSafe(constant) == HopRewriteUtils.getDoubleValueSafe(constant2)); } return ret; }
/** * @param c * @param vars * @return * @throws DMLRuntimeException */ private static LiteralOp replaceLiteralValueTypeCastLiteral(Hop c, LocalVariableMap vars) throws DMLRuntimeException { LiteralOp ret = null; // as.double/as.integer/as.boolean over scalar literal (potentially created by other replacement // rewrite in same dag) - literal replacement if (c instanceof UnaryOp && (((UnaryOp) c).getOp() == OpOp1.CAST_AS_DOUBLE || ((UnaryOp) c).getOp() == OpOp1.CAST_AS_INT || ((UnaryOp) c).getOp() == OpOp1.CAST_AS_BOOLEAN) && c.getInput().get(0) instanceof LiteralOp) { LiteralOp sdat = (LiteralOp) c.getInput().get(0); UnaryOp cast = (UnaryOp) c; try { switch (cast.getOp()) { case CAST_AS_INT: long ival = HopRewriteUtils.getIntValue(sdat); ret = new LiteralOp(ival); break; case CAST_AS_DOUBLE: double dval = HopRewriteUtils.getDoubleValue(sdat); ret = new LiteralOp(dval); break; case CAST_AS_BOOLEAN: boolean bval = HopRewriteUtils.getBooleanValue(sdat); ret = new LiteralOp(bval); break; default: // otherwise: do nothing } } catch (HopsException ex) { throw new DMLRuntimeException(ex); } } return ret; }
@Override public void refreshSizeInformation() { Hop input1 = getInput().get(0); // original matrix Hop input2 = getInput().get(1); // inpRowL Hop input3 = getInput().get(2); // inpRowU Hop input4 = getInput().get(3); // inpColL Hop input5 = getInput().get(4); // inpColU // parse input information boolean allRows = (input2 instanceof LiteralOp && HopRewriteUtils.getIntValueSafe((LiteralOp) input2) == 1 && input3 instanceof UnaryOp && ((UnaryOp) input3).getOp() == OpOp1.NROW); boolean allCols = (input4 instanceof LiteralOp && HopRewriteUtils.getIntValueSafe((LiteralOp) input4) == 1 && input5 instanceof UnaryOp && ((UnaryOp) input5).getOp() == OpOp1.NCOL); boolean constRowRange = (input2 instanceof LiteralOp && input3 instanceof LiteralOp); boolean constColRange = (input4 instanceof LiteralOp && input5 instanceof LiteralOp); // set dimension information if (_rowLowerEqualsUpper) // ROWS setDim1(1); else if (allRows) setDim1(input1.getDim1()); else if (constRowRange) { setDim1( HopRewriteUtils.getIntValueSafe((LiteralOp) input3) - HopRewriteUtils.getIntValueSafe((LiteralOp) input2) + 1); } else if (isBlockIndexingExpression(input2, input3)) { setDim1(getBlockIndexingExpressionSize(input2, input3)); } if (_colLowerEqualsUpper) // COLS setDim2(1); else if (allCols) setDim2(input1.getDim2()); else if (constColRange) { setDim2( HopRewriteUtils.getIntValueSafe((LiteralOp) input5) - HopRewriteUtils.getIntValueSafe((LiteralOp) input4) + 1); } else if (isBlockIndexingExpression(input4, input5)) { setDim2(getBlockIndexingExpressionSize(input4, input5)); } }
private boolean isTernaryAggregateRewriteApplicable() throws HopsException { boolean ret = false; // currently we support only sum over binary multiply but potentially // it can be generalized to any RC aggregate over two common binary operations if (OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES && _direction == Direction.RowCol && _op == AggOp.SUM) { Hop input1 = getInput().get(0); if (input1.getParent().size() == 1 && // sum single consumer input1 instanceof BinaryOp && ((BinaryOp) input1).getOp() == OpOp2.MULT // As unary agg instruction is not implemented in MR and since MR is in maintenance mode, // postponed it. && input1.optFindExecType() != ExecType.MR) { Hop input11 = input1.getInput().get(0); Hop input12 = input1.getInput().get(1); if (input11 instanceof BinaryOp && ((BinaryOp) input11).getOp() == OpOp2.MULT) { // ternary, arbitrary matrices but no mv/outer operations. ret = HopRewriteUtils.isEqualSize(input11.getInput().get(0), input1) && HopRewriteUtils.isEqualSize(input11.getInput().get(1), input1) && HopRewriteUtils.isEqualSize(input12, input1); } else if (input12 instanceof BinaryOp && ((BinaryOp) input12).getOp() == OpOp2.MULT) { // ternary, arbitrary matrices but no mv/outer operations. ret = HopRewriteUtils.isEqualSize(input12.getInput().get(0), input1) && HopRewriteUtils.isEqualSize(input12.getInput().get(1), input1) && HopRewriteUtils.isEqualSize(input11, input1); } else { // binary, arbitrary matrices but no mv/outer operations. ret = HopRewriteUtils.isEqualSize(input11, input12); } } } return ret; }
/** * @param lbound * @param ubound * @return */ private long getBlockIndexingExpressionSize(Hop lbound, Hop ubound) { // NOTE: ensure consistency with isBlockIndexingExpression LiteralOp c = (LiteralOp) ubound.getInput().get(0); // (c*i) return HopRewriteUtils.getIntValueSafe(c); }