/** * @param c * @param vars * @return * @throws DMLRuntimeException */ private static LiteralOp replaceLiteralFullUnaryAggregate(Hop c, LocalVariableMap vars) throws DMLRuntimeException { LiteralOp ret = null; // full unary aggregate w/ matrix less than 10^6 cells if (c instanceof AggUnaryOp && isReplaceableUnaryAggregate((AggUnaryOp) c) && c.getInput().get(0) instanceof DataOp && vars.keySet().contains(c.getInput().get(0).getName())) { Hop data = c.getInput().get(0); MatrixObject mo = (MatrixObject) vars.get(data.getName()); // get the dimension information from the matrix object because the hop // dimensions might not have been updated during recompile if (mo.getNumRows() * mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE) { MatrixBlock mBlock = mo.acquireRead(); double value = replaceUnaryAggregate((AggUnaryOp) c, mBlock); mo.release(); // literal substitution (always double) ret = new LiteralOp(value); } } return ret; }
/** * @param hop * @param vars * @return * @throws DMLRuntimeException */ private static long getIntValueDataLiteral(Hop hop, LocalVariableMap vars) throws DMLRuntimeException { long value = -1; try { if (hop instanceof LiteralOp) { value = HopRewriteUtils.getIntValue((LiteralOp) hop); } else if (hop instanceof UnaryOp && ((UnaryOp) hop).getOp() == OpOp1.NROW) { // get the dimension information from the matrix object because the hop // dimensions might not have been updated during recompile MatrixObject mo = (MatrixObject) vars.get(hop.getInput().get(0).getName()); value = mo.getNumRows(); } else if (hop instanceof UnaryOp && ((UnaryOp) hop).getOp() == OpOp1.NCOL) { // get the dimension information from the matrix object because the hop // dimensions might not have been updated during recompile MatrixObject mo = (MatrixObject) vars.get(hop.getInput().get(0).getName()); value = mo.getNumColumns(); } else { ScalarObject sdat = (ScalarObject) vars.get(hop.getName()); value = sdat.getLongValue(); } } catch (HopsException ex) { throw new DMLRuntimeException("Failed to get int value for literal replacement", ex); } return value; }
/** * @param c * @param vars * @return * @throws DMLRuntimeException */ private static LiteralOp replaceLiteralDataTypeCastMatrixRead(Hop c, LocalVariableMap vars) throws DMLRuntimeException { LiteralOp ret = null; // as.scalar/matrix read - literal replacement if (c instanceof UnaryOp && ((UnaryOp) c).getOp() == OpOp1.CAST_AS_SCALAR && c.getInput().get(0) instanceof DataOp && c.getInput().get(0).getDataType() == DataType.MATRIX) { Data dat = vars.get(c.getInput().get(0).getName()); if (dat != null) // required for selective constant propagation { // cast as scalar (see VariableCPInstruction) MatrixObject mo = (MatrixObject) dat; MatrixBlock mBlock = mo.acquireRead(); if (mBlock.getNumRows() != 1 || mBlock.getNumColumns() != 1) throw new DMLRuntimeException( "Dimension mismatch - unable to cast matrix of dimension (" + mBlock.getNumRows() + " x " + mBlock.getNumColumns() + ") to scalar."); double value = mBlock.getValue(0, 0); mo.release(); // literal substitution (always double) ret = new LiteralOp(value); } } return ret; }
/** * @param c * @param vars * @return */ private static LiteralOp replaceLiteralValueTypeCastScalarRead(Hop c, LocalVariableMap vars) { LiteralOp ret = null; // as.double/as.integer/as.boolean over scalar read - literal replacement if (c instanceof UnaryOp && (((UnaryOp) c).getOp() == OpOp1.CAST_AS_DOUBLE || ((UnaryOp) c).getOp() == OpOp1.CAST_AS_INT || ((UnaryOp) c).getOp() == OpOp1.CAST_AS_BOOLEAN) && c.getInput().get(0) instanceof DataOp && c.getDataType() == DataType.SCALAR) { Data dat = vars.get(c.getInput().get(0).getName()); if (dat != null) // required for selective constant propagation { ScalarObject sdat = (ScalarObject) dat; UnaryOp cast = (UnaryOp) c; switch (cast.getOp()) { case CAST_AS_INT: ret = new LiteralOp(sdat.getLongValue()); break; case CAST_AS_DOUBLE: ret = new LiteralOp(sdat.getDoubleValue()); break; case CAST_AS_BOOLEAN: ret = new LiteralOp(sdat.getBooleanValue()); break; default: // otherwise: do nothing } } } return ret; }
/** * @param c * @param vars * @return */ private static LiteralOp replaceLiteralScalarRead(Hop c, LocalVariableMap vars) { LiteralOp ret = null; // scalar read - literal replacement if (c instanceof DataOp && ((DataOp) c).getDataOpType() != DataOpTypes.PERSISTENTREAD && c.getDataType() == DataType.SCALAR) { Data dat = vars.get(c.getName()); if (dat != null) // required for selective constant propagation { ScalarObject sdat = (ScalarObject) dat; switch (sdat.getValueType()) { case INT: ret = new LiteralOp(sdat.getLongValue()); break; case DOUBLE: ret = new LiteralOp(sdat.getDoubleValue()); break; case BOOLEAN: ret = new LiteralOp(sdat.getBooleanValue()); break; default: // otherwise: do nothing } } } return ret; }
/**
 * Checks whether a hop has one of the shapes whose value can be looked up at rewrite time.
 *
 * <p>Accepted shapes are: a {@code DataOp} whose name is present in {@code vars}, any
 * {@code LiteralOp}, or a {@code UnaryOp} of type {@code NROW}/{@code NCOL} whose first input
 * is a {@code DataOp} with a name present in {@code vars}.
 *
 * @param h hop to inspect
 * @param vars variable map whose key set is consulted for variable names
 * @return {@code true} if {@code h} matches one of the accepted shapes
 */
private static boolean isIntValueDataLiteral(Hop h, LocalVariableMap vars) {
  // variable read with a known binding
  if (h instanceof DataOp && vars.keySet().contains(h.getName())) {
    return true;
  }

  // plain literal
  if (h instanceof LiteralOp) {
    return true;
  }

  // nrow/ncol over a variable read with a known binding
  if (!(h instanceof UnaryOp)) {
    return false;
  }
  OpOp1 unaryType = ((UnaryOp) h).getOp();
  if (unaryType != OpOp1.NROW && unaryType != OpOp1.NCOL) {
    return false;
  }
  Hop operand = h.getInput().get(0);
  return operand instanceof DataOp && vars.keySet().contains(operand.getName());
}
/** * @param c * @param vars * @return * @throws DMLRuntimeException */ private static LiteralOp replaceLiteralFullUnaryAggregateRightIndexing( Hop c, LocalVariableMap vars) throws DMLRuntimeException { LiteralOp ret = null; // full unary aggregate w/ indexed matrix less than 10^6 cells if (c instanceof AggUnaryOp && isReplaceableUnaryAggregate((AggUnaryOp) c) && c.getInput().get(0) instanceof IndexingOp && c.getInput().get(0).getInput().get(0) instanceof DataOp) { IndexingOp rix = (IndexingOp) c.getInput().get(0); Hop data = rix.getInput().get(0); Hop rl = rix.getInput().get(1); Hop ru = rix.getInput().get(2); Hop cl = rix.getInput().get(3); Hop cu = rix.getInput().get(4); if (data instanceof DataOp && vars.keySet().contains(data.getName()) && isIntValueDataLiteral(rl, vars) && isIntValueDataLiteral(ru, vars) && isIntValueDataLiteral(cl, vars) && isIntValueDataLiteral(cu, vars)) { long rlval = getIntValueDataLiteral(rl, vars); long ruval = getIntValueDataLiteral(ru, vars); long clval = getIntValueDataLiteral(cl, vars); long cuval = getIntValueDataLiteral(cu, vars); MatrixObject mo = (MatrixObject) vars.get(data.getName()); // get the dimension information from the matrix object because the hop // dimensions might not have been updated during recompile if (mo.getNumRows() * mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE) { MatrixBlock mBlock = mo.acquireRead(); MatrixBlock mBlock2 = mBlock.sliceOperations( (int) (rlval - 1), (int) (ruval - 1), (int) (clval - 1), (int) (cuval - 1), new MatrixBlock()); double value = replaceUnaryAggregate((AggUnaryOp) c, mBlock2); mo.release(); // literal substitution (always double) ret = new LiteralOp(value); } } } return ret; }
/** * @param c * @param vars * @return * @throws DMLRuntimeException */ private static LiteralOp replaceLiteralValueTypeCastRightIndexing(Hop c, LocalVariableMap vars) throws DMLRuntimeException { LiteralOp ret = null; // as.scalar/right indexing w/ literals/vars and matrix less than 10^6 cells if (c instanceof UnaryOp && ((UnaryOp) c).getOp() == OpOp1.CAST_AS_SCALAR && c.getInput().get(0) instanceof IndexingOp && c.getInput().get(0).getDataType() == DataType.MATRIX) { IndexingOp rix = (IndexingOp) c.getInput().get(0); Hop data = rix.getInput().get(0); Hop rl = rix.getInput().get(1); Hop ru = rix.getInput().get(2); Hop cl = rix.getInput().get(3); Hop cu = rix.getInput().get(4); if (rix.dimsKnown() && rix.getDim1() == 1 && rix.getDim2() == 1 && data instanceof DataOp && vars.keySet().contains(data.getName()) && isIntValueDataLiteral(rl, vars) && isIntValueDataLiteral(ru, vars) && isIntValueDataLiteral(cl, vars) && isIntValueDataLiteral(cu, vars)) { long rlval = getIntValueDataLiteral(rl, vars); long clval = getIntValueDataLiteral(cl, vars); MatrixObject mo = (MatrixObject) vars.get(data.getName()); // get the dimension information from the matrix object because the hop // dimensions might not have been updated during recompile if (mo.getNumRows() * mo.getNumColumns() < REPLACE_LITERALS_MAX_MATRIX_SIZE) { MatrixBlock mBlock = mo.acquireRead(); double value = mBlock.getValue((int) rlval - 1, (int) clval - 1); mo.release(); // literal substitution (always double) ret = new LiteralOp(value); } } } return ret; }
/** * @param c * @param vars * @return * @throws DMLRuntimeException */ private static LiteralOp replaceLiteralValueTypeCastLiteral(Hop c, LocalVariableMap vars) throws DMLRuntimeException { LiteralOp ret = null; // as.double/as.integer/as.boolean over scalar literal (potentially created by other replacement // rewrite in same dag) - literal replacement if (c instanceof UnaryOp && (((UnaryOp) c).getOp() == OpOp1.CAST_AS_DOUBLE || ((UnaryOp) c).getOp() == OpOp1.CAST_AS_INT || ((UnaryOp) c).getOp() == OpOp1.CAST_AS_BOOLEAN) && c.getInput().get(0) instanceof LiteralOp) { LiteralOp sdat = (LiteralOp) c.getInput().get(0); UnaryOp cast = (UnaryOp) c; try { switch (cast.getOp()) { case CAST_AS_INT: long ival = HopRewriteUtils.getIntValue(sdat); ret = new LiteralOp(ival); break; case CAST_AS_DOUBLE: double dval = HopRewriteUtils.getDoubleValue(sdat); ret = new LiteralOp(dval); break; case CAST_AS_BOOLEAN: boolean bval = HopRewriteUtils.getBooleanValue(sdat); ret = new LiteralOp(bval); break; default: // otherwise: do nothing } } catch (HopsException ex) { throw new DMLRuntimeException(ex); } } return ret; }
/** * @param hop * @param vars * @throws DMLRuntimeException */ protected static void rReplaceLiterals(Hop hop, LocalVariableMap vars) throws DMLRuntimeException { if (hop.getVisited() == VisitStatus.DONE) return; if (hop.getInput() != null) { // indexed access to allow parent-child modifications for (int i = 0; i < hop.getInput().size(); i++) { Hop c = hop.getInput().get(i); Hop lit = null; // conditional apply of literal replacements lit = (lit == null) ? replaceLiteralScalarRead(c, vars) : lit; lit = (lit == null) ? replaceLiteralValueTypeCastScalarRead(c, vars) : lit; lit = (lit == null) ? replaceLiteralValueTypeCastLiteral(c, vars) : lit; lit = (lit == null) ? replaceLiteralDataTypeCastMatrixRead(c, vars) : lit; lit = (lit == null) ? replaceLiteralValueTypeCastRightIndexing(c, vars) : lit; lit = (lit == null) ? replaceLiteralFullUnaryAggregate(c, vars) : lit; lit = (lit == null) ? replaceLiteralFullUnaryAggregateRightIndexing(c, vars) : lit; // replace hop w/ literal on demand if (lit != null) { // replace hop c by literal, for all parents to prevent (1) missed opportunities // because hop c marked as visited, and (2) repeated evaluation of uagg ops if (c.getParent().size() > 1) { // multiple parents ArrayList<Hop> parents = new ArrayList<Hop>(c.getParent()); for (Hop p : parents) { int pos = HopRewriteUtils.getChildReferencePos(p, c); HopRewriteUtils.removeChildReferenceByPos(p, c, pos); HopRewriteUtils.addChildReference(p, lit, pos); } } else { // current hop is only parent HopRewriteUtils.removeChildReferenceByPos(hop, c, i); HopRewriteUtils.addChildReference(hop, lit, i); } } // recursively process children else { rReplaceLiterals(c, vars); } } } hop.setVisited(VisitStatus.DONE); }