@Override public Lop constructLops() throws HopsException, LopsException { // return already created lops if (getLops() != null) return getLops(); try { ExecType et = optFindExecType(); Hop input = getInput().get(0); if (et == ExecType.CP) { Lop agg1 = null; if (isTernaryAggregateRewriteApplicable()) { agg1 = constructLopsTernaryAggregateRewrite(et); } else if (isUnaryAggregateOuterCPRewriteApplicable()) { OperationTypes op = HopsAgg2Lops.get(_op); DirectionTypes dir = HopsDirection2Lops.get(_direction); BinaryOp binput = (BinaryOp) getInput().get(0); agg1 = new UAggOuterChain( binput.getInput().get(0).constructLops(), binput.getInput().get(1).constructLops(), op, dir, HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.CP); PartialAggregate.setDimensionsBasedOnDirection( agg1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir); if (getDataType() == DataType.SCALAR) { UnaryCP unary1 = new UnaryCP( agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), getDataType(), getValueType()); unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1); setLineNumbers(unary1); setLops(unary1); } } else { // general case int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); if (DMLScript.USE_ACCELERATOR && (DMLScript.FORCE_ACCELERATOR || getMemEstimate() < OptimizerUtils.GPU_MEMORY_BUDGET) && (_op == AggOp.SUM)) { et = ExecType.GPU; k = 1; } agg1 = new PartialAggregate( input.constructLops(), HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), getDataType(), getValueType(), et, k); } setOutputDimensions(agg1); setLineNumbers(agg1); setLops(agg1); if (getDataType() == DataType.SCALAR) { agg1.getOutputParameters() .setDimensions(1, 1, getRowsInBlock(), getColsInBlock(), getNnz()); } } else if (et == ExecType.MR) { OperationTypes op = HopsAgg2Lops.get(_op); DirectionTypes dir = HopsDirection2Lops.get(_direction); // unary aggregate operation Lop transform1 = null; if (isUnaryAggregateOuterRewriteApplicable()) { BinaryOp binput = (BinaryOp) getInput().get(0); transform1 = new UAggOuterChain( binput.getInput().get(0).constructLops(), binput.getInput().get(1).constructLops(), op, dir, HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.MR); PartialAggregate.setDimensionsBasedOnDirection( transform1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir); } else // default { transform1 = new PartialAggregate(input.constructLops(), op, dir, DataType.MATRIX, getValueType()); ((PartialAggregate) transform1) .setDimensionsBasedOnDirection( getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock()); } setLineNumbers(transform1); // aggregation if required Lop aggregate = null; Group group1 = null; Aggregate agg1 = null; if (requiresAggregation(input, _direction) || transform1 instanceof UAggOuterChain) { group1 = new Group(transform1, Group.OperationTypes.Sort, DataType.MATRIX, getValueType()); group1 .getOutputParameters() .setDimensions( getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), getNnz()); setLineNumbers(group1); agg1 = new Aggregate(group1, HopsAgg2Lops.get(_op), DataType.MATRIX, getValueType(), et); agg1.getOutputParameters() .setDimensions( getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), getNnz()); agg1.setupCorrectionLocation(PartialAggregate.getCorrectionLocation(op, dir)); setLineNumbers(agg1); aggregate = agg1; } else { ((PartialAggregate) transform1).setDropCorrection(); aggregate = transform1; } setLops(aggregate); // cast if required if (getDataType() == DataType.SCALAR) { // Set the dimensions of PartialAggregate LOP based on the // direction in which aggregation is performed PartialAggregate.setDimensionsBasedOnDirection( transform1, input.getDim1(), input.getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir); if (group1 != null && agg1 != null) { // if aggregation required group1 .getOutputParameters() .setDimensions( input.getDim1(), input.getDim2(), input.getRowsInBlock(), input.getColsInBlock(), getNnz()); agg1.getOutputParameters() .setDimensions(1, 1, input.getRowsInBlock(), input.getColsInBlock(), getNnz()); } UnaryCP unary1 = new UnaryCP( aggregate, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), getDataType(), getValueType()); unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1); setLineNumbers(unary1); setLops(unary1); } } else if (et == ExecType.SPARK) { OperationTypes op = HopsAgg2Lops.get(_op); DirectionTypes dir = HopsDirection2Lops.get(_direction); // unary aggregate if (isTernaryAggregateRewriteApplicable()) { Lop aggregate = constructLopsTernaryAggregateRewrite(et); setOutputDimensions(aggregate); // 0x0 (scalar) setLineNumbers(aggregate); setLops(aggregate); } else if (isUnaryAggregateOuterSPRewriteApplicable()) { BinaryOp binput = (BinaryOp) getInput().get(0); Lop transform1 = new UAggOuterChain( binput.getInput().get(0).constructLops(), binput.getInput().get(1).constructLops(), op, dir, HopsOpOp2LopsB.get(binput.getOp()), DataType.MATRIX, getValueType(), ExecType.SPARK); PartialAggregate.setDimensionsBasedOnDirection( transform1, getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock(), dir); setLineNumbers(transform1); setLops(transform1); if (getDataType() == DataType.SCALAR) { UnaryCP unary1 = new UnaryCP( transform1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), getDataType(), getValueType()); unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1); setLineNumbers(unary1); setLops(unary1); } } else // default { boolean needAgg = requiresAggregation(input, _direction); SparkAggType aggtype = getSparkUnaryAggregationType(needAgg); PartialAggregate aggregate = new PartialAggregate( input.constructLops(), HopsAgg2Lops.get(_op), HopsDirection2Lops.get(_direction), DataType.MATRIX, getValueType(), aggtype, et); aggregate.setDimensionsBasedOnDirection( getDim1(), getDim2(), input.getRowsInBlock(), input.getColsInBlock()); setLineNumbers(aggregate); setLops(aggregate); if (getDataType() == DataType.SCALAR) { UnaryCP unary1 = new UnaryCP( aggregate, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), getDataType(), getValueType()); unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1); setLineNumbers(unary1); setLops(unary1); } } } } catch (Exception e) { throw new HopsException( this.printErrorLocation() + "In AggUnary Hop, error constructing Lops ", e); } // add reblock/checkpoint lops if necessary constructAndSetLopsDataFlowProperties(); // return created lops return getLops(); }
private Lop constructLopsIQM() throws HopsException, LopsException { ExecType et = optFindExecType(); Hop input = getInput().get(0); if (et == ExecType.MR) { CombineUnary combine = CombineUnary.constructCombineLop(input.constructLops(), DataType.MATRIX, getValueType()); combine .getOutputParameters() .setDimensions( input.getDim1(), input.getDim2(), input.getRowsInBlock(), input.getColsInBlock(), input.getNnz()); SortKeys sort = SortKeys.constructSortByValueLop( combine, SortKeys.OperationTypes.WithoutWeights, DataType.MATRIX, ValueType.DOUBLE, ExecType.MR); // Sort dimensions are same as the first input sort.getOutputParameters() .setDimensions( input.getDim1(), input.getDim2(), input.getRowsInBlock(), input.getColsInBlock(), input.getNnz()); Data lit = Data.createLiteralLop(ValueType.DOUBLE, Double.toString(0.25)); lit.setAllPositions( this.getBeginLine(), this.getBeginColumn(), this.getEndLine(), this.getEndColumn()); PickByCount pick = new PickByCount( sort, lit, DataType.MATRIX, getValueType(), PickByCount.OperationTypes.RANGEPICK); pick.getOutputParameters().setDimensions(-1, -1, getRowsInBlock(), getColsInBlock(), -1); setLineNumbers(pick); PartialAggregate pagg = new PartialAggregate( pick, HopsAgg2Lops.get(Hop.AggOp.SUM), HopsDirection2Lops.get(Hop.Direction.RowCol), DataType.MATRIX, getValueType()); setLineNumbers(pagg); // Set the dimensions of PartialAggregate LOP based on the // direction in which aggregation is performed pagg.setDimensionsBasedOnDirection(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock()); Group group1 = new Group(pagg, Group.OperationTypes.Sort, DataType.MATRIX, getValueType()); group1 .getOutputParameters() .setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz()); setLineNumbers(group1); Aggregate agg1 = new Aggregate( group1, HopsAgg2Lops.get(Hop.AggOp.SUM), DataType.MATRIX, getValueType(), ExecType.MR); agg1.getOutputParameters() .setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz()); agg1.setupCorrectionLocation(pagg.getCorrectionLocation()); setLineNumbers(agg1); UnaryCP unary1 = new UnaryCP( agg1, HopsOpOp1LopsUS.get(OpOp1.CAST_AS_SCALAR), getDataType(), getValueType()); unary1.getOutputParameters().setDimensions(0, 0, 0, 0, -1); setLineNumbers(unary1); Unary iqm = new Unary( sort, unary1, Unary.OperationTypes.MR_IQM, DataType.SCALAR, ValueType.DOUBLE, ExecType.CP); iqm.getOutputParameters().setDimensions(0, 0, 0, 0, -1); setLineNumbers(iqm); return iqm; } else { SortKeys sort = SortKeys.constructSortByValueLop( input.constructLops(), SortKeys.OperationTypes.WithoutWeights, DataType.MATRIX, ValueType.DOUBLE, et); sort.getOutputParameters() .setDimensions( input.getDim1(), input.getDim2(), input.getRowsInBlock(), input.getColsInBlock(), input.getNnz()); PickByCount pick = new PickByCount( sort, null, getDataType(), getValueType(), PickByCount.OperationTypes.IQM, et, true); pick.getOutputParameters() .setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz()); setLineNumbers(pick); return pick; } }