/** * @param str * @return * @throws DMLRuntimeException */ public static MapmmChainSPInstruction parseInstruction(String str) throws DMLRuntimeException { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); InstructionUtils.checkNumFields(parts, 4, 5); String opcode = parts[0]; // check supported opcode if (!opcode.equalsIgnoreCase(MapMultChain.OPCODE)) { throw new DMLRuntimeException( "MapmmChainSPInstruction.parseInstruction():: Unknown opcode " + opcode); } // parse instruction parts (without exec type) CPOperand in1 = new CPOperand(parts[1]); CPOperand in2 = new CPOperand(parts[2]); if (parts.length == 5) { CPOperand out = new CPOperand(parts[3]); ChainType type = ChainType.valueOf(parts[4]); return new MapmmChainSPInstruction(null, in1, in2, out, type, opcode, str); } else // parts.length==6 { CPOperand in3 = new CPOperand(parts[3]); CPOperand out = new CPOperand(parts[4]); ChainType type = ChainType.valueOf(parts[5]); return new MapmmChainSPInstruction(null, in1, in2, in3, out, type, opcode, str); } }
/**
 * Parses an MR append instruction whose fields are
 * {@code opcode, in1, in2, out, cbind}.
 *
 * @param str serialized instruction
 * @return a new {@code AppendRInstruction}
 * @throws DMLRuntimeException if {@code InstructionUtils} rejects the field count
 */
public static Instruction parseInstruction(String str) throws DMLRuntimeException {
    final String[] tokens = InstructionUtils.getInstructionParts(str);
    InstructionUtils.checkNumFields(tokens, 4);

    // constructor arguments are evaluated left to right: in1, in2, out, cbind flag
    return new AppendRInstruction(null,
            Byte.parseByte(tokens[1]),
            Byte.parseByte(tokens[2]),
            Byte.parseByte(tokens[3]),
            Boolean.parseBoolean(tokens[4]),
            str);
}
/**
 * Parses an MR cumulative-split instruction whose fields are
 * {@code opcode, in, out, init}.
 *
 * @param str serialized instruction
 * @return a new {@code CumulativeSplitInstruction}
 * @throws DMLRuntimeException if {@code InstructionUtils} rejects the field count
 */
public static CumulativeSplitInstruction parseInstruction(String str) throws DMLRuntimeException {
    InstructionUtils.checkNumFields(str, 3);
    final String[] tokens = InstructionUtils.getInstructionParts(str);

    return new CumulativeSplitInstruction(
            Byte.parseByte(tokens[1]),      // input
            Byte.parseByte(tokens[2]),      // output
            Double.parseDouble(tokens[3]),  // init value
            str);
}
/**
 * Parses a Spark cumulative-aggregate instruction whose fields are
 * {@code opcode, in, out}. The aggregate operator is derived from the opcode.
 *
 * @param str serialized instruction, including value types
 * @return a new {@code CumulativeAggregateSPInstruction}
 * @throws DMLRuntimeException if field validation or operator lookup fails
 */
public static CumulativeAggregateSPInstruction parseInstruction(String str)
        throws DMLRuntimeException
{
    final String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
    InstructionUtils.checkNumFields(tokens, 2);

    final String op = tokens[0];
    final CPOperand source = new CPOperand(tokens[1]);
    final CPOperand target = new CPOperand(tokens[2]);

    return new CumulativeAggregateSPInstruction(
            InstructionUtils.parseCumulativeAggregateUnaryOperator(op), source, target, op, str);
}
/**
 * Parses an MR cumulative-offset instruction whose fields are
 * {@code opcode, in1, in2, out}.
 *
 * @param str serialized instruction
 * @return a new {@code CumulativeOffsetInstruction}
 * @throws DMLRuntimeException if {@code InstructionUtils} rejects the field count
 */
public static CumulativeOffsetInstruction parseInstruction(String str) throws DMLRuntimeException {
    InstructionUtils.checkNumFields(str, 3);
    final String[] tokens = InstructionUtils.getInstructionParts(str);

    return new CumulativeOffsetInstruction(
            Byte.parseByte(tokens[1]),  // in1
            Byte.parseByte(tokens[2]),  // in2
            Byte.parseByte(tokens[3]),  // out
            tokens[0],                  // opcode
            str);
}
/**
 * Parses an MR aggregate-unary instruction whose fields are
 * {@code opcode, in, out, drop}. The aggregate operator is chosen by the opcode.
 *
 * @param str serialized instruction
 * @return a new {@code AggregateUnaryInstruction}
 * @throws DMLRuntimeException if field validation or operator lookup fails
 */
public static AggregateUnaryInstruction parseInstruction(String str) throws DMLRuntimeException {
    InstructionUtils.checkNumFields(str, 3);
    final String[] tokens = InstructionUtils.getInstructionParts(str);

    final byte source = Byte.parseByte(tokens[1]);
    final byte target = Byte.parseByte(tokens[2]);
    final boolean dropFlag = Boolean.parseBoolean(tokens[3]);

    // operator lookup happens after all operand fields were parsed
    final AggregateUnaryOperator op = InstructionUtils.parseBasicAggregateUnaryOperator(tokens[0]);
    return new AggregateUnaryInstruction(op, source, target, dropFlag, str);
}
/**
 * Parses an MR binary instruction whose fields are {@code opcode, in1, in2, out}.
 *
 * @param str serialized instruction
 * @return a new {@code BinaryInstruction}, or {@code null} when
 *         {@code InstructionUtils.parseBinaryOperator} yields no operator for the opcode
 * @throws DMLRuntimeException if field validation or operator lookup fails
 */
public static BinaryInstruction parseInstruction(String str) throws DMLRuntimeException {
    InstructionUtils.checkNumFields(str, 3);
    final String[] tokens = InstructionUtils.getInstructionParts(str);

    final String op = tokens[0];
    final byte left = Byte.parseByte(tokens[1]);
    final byte right = Byte.parseByte(tokens[2]);
    final byte target = Byte.parseByte(tokens[3]);

    final BinaryOperator operator = InstructionUtils.parseBinaryOperator(op);
    return (operator == null) ? null : new BinaryInstruction(operator, left, right, target, str);
}
/**
 * Parses a Spark {@code pmm} instruction whose fields are
 * {@code opcode, in1, in2, nrow, out, cacheType}.
 *
 * <p>The resulting instruction uses an aggregate-binary operator built from
 * multiply and a plus aggregation.
 *
 * @param str serialized instruction, including value types
 * @return a new {@code PmmSPInstruction}
 * @throws DMLRuntimeException if the opcode is not {@code PMMJ.OPCODE}
 */
public static PmmSPInstruction parseInstruction(String str) throws DMLRuntimeException {
    final String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    final String opcode = InstructionUtils.getOpCode(str);

    if (!opcode.equalsIgnoreCase(PMMJ.OPCODE)) {
        throw new DMLRuntimeException(
                "PmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
    }

    final CPOperand left = new CPOperand(fields[1]);
    final CPOperand right = new CPOperand(fields[2]);
    final CPOperand rowCount = new CPOperand(fields[3]);
    final CPOperand target = new CPOperand(fields[4]);
    final CacheType cacheType = CacheType.valueOf(fields[5]);

    final AggregateOperator plusAgg = new AggregateOperator(0, Plus.getPlusFnObject());
    final AggregateBinaryOperator multPlus =
            new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), plusAgg);

    return new PmmSPInstruction(multPlus, left, right, target, rowCount, cacheType, opcode, str);
}
public static MultiReturnParameterizedBuiltinCPInstruction parseInstruction(String str) throws DMLRuntimeException { String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); ArrayList<CPOperand> outputs = new ArrayList<CPOperand>(); String opcode = parts[0]; if (opcode.equalsIgnoreCase("transformencode")) { // one input and two outputs CPOperand in1 = new CPOperand(parts[1]); CPOperand in2 = new CPOperand(parts[2]); outputs.add(new CPOperand(parts[3], ValueType.DOUBLE, DataType.MATRIX)); outputs.add(new CPOperand(parts[4], ValueType.STRING, DataType.FRAME)); return new MultiReturnParameterizedBuiltinCPInstruction(null, in1, in2, outputs, opcode, str); } else { throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode); } }
/**
 * Parses a Spark {@code mapmm} instruction whose fields are
 * {@code opcode, in1, in2, out, cacheType, outputEmpty, aggType}.
 *
 * <p>The resulting instruction uses an aggregate-binary operator built from
 * multiply and a plus aggregation.
 *
 * @param str serialized instruction, including value types
 * @return a new {@code MapmmSPInstruction}
 * @throws DMLRuntimeException if the opcode is not {@code MapMult.OPCODE}
 */
public static MapmmSPInstruction parseInstruction(String str) throws DMLRuntimeException {
    final String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    final String opcode = fields[0];

    if (!opcode.equalsIgnoreCase(MapMult.OPCODE)) {
        throw new DMLRuntimeException(
                "MapmmSPInstruction.parseInstruction():: Unknown opcode " + opcode);
    }

    final CPOperand left = new CPOperand(fields[1]);
    final CPOperand right = new CPOperand(fields[2]);
    final CPOperand target = new CPOperand(fields[3]);
    final CacheType cacheType = CacheType.valueOf(fields[4]);
    final boolean emptyOutput = Boolean.parseBoolean(fields[5]);
    final SparkAggType sparkAgg = SparkAggType.valueOf(fields[6]);

    final AggregateOperator plusAgg = new AggregateOperator(0, Plus.getPlusFnObject());
    final AggregateBinaryOperator multPlus =
            new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), plusAgg);

    return new MapmmSPInstruction(multPlus, left, right, target, cacheType, emptyOutput, sparkAgg, opcode, str);
}
/**
 * Applies the quaternary operator to every block cached under {@code _input1}.
 *
 * <p>For each block Xij of the first input, the remaining operands are gathered:
 * <ul>
 *   <li>Wij from {@code _input4}, if present;</li>
 *   <li>Ui from {@code _input2};</li>
 *   <li>Vj from {@code _input3}.</li>
 * </ul>
 * The operator is then evaluated into the output value. Finally, the output block
 * indexes are set depending on which operator type {@code qop} carries.
 *
 * @param valueClass     value class used when reserving an output slot via {@code holdPlace}
 * @param cachedValues   cached blocks; read for the inputs and written for the output
 * @param tempValue      scratch output used when the output index aliases {@code _input1}
 * @param zeroInput      not referenced by this method
 * @param blockRowFactor not referenced by this method
 * @param blockColFactor not referenced by this method
 * @throws DMLRuntimeException declared by the signature; presumably raised by the
 *         operator evaluation or the reorg of Vj — TODO confirm
 */
@Override
public void processInstruction(
        Class<? extends MatrixValue> valueClass,
        CachedValueMap cachedValues,
        IndexedMatrixValue tempValue,
        IndexedMatrixValue zeroInput,
        int blockRowFactor,
        int blockColFactor)
        throws DMLRuntimeException
{
    QuaternaryOperator qop = (QuaternaryOperator) optr;

    ArrayList<IndexedMatrixValue> blkList = cachedValues.get(_input1);
    if (blkList != null)
        for (IndexedMatrixValue imv : blkList) {
            // Step 1: prepare inputs and output
            if (imv == null)
                continue;
            MatrixIndexes inIx = imv.getIndexes();
            MatrixValue inVal = imv.getValue();

            // allocate space for the output value; if the output index aliases
            // input1, write into tempValue instead of reserving a new cache slot
            IndexedMatrixValue iout = null;
            if (output == _input1)
                iout = tempValue;
            else
                iout = cachedValues.holdPlace(output, valueClass);

            MatrixIndexes outIx = iout.getIndexes();
            MatrixValue outVal = iout.getValue();

            // Step 2: get remaining inputs: Wij, Ui, Vj
            MatrixValue Xij = inVal;

            // get Wij if existing (null for WeightsType.NONE or any WSigmoid type)
            IndexedMatrixValue iWij = (_input4 != -1) ? cachedValues.getFirst(_input4) : null;
            MatrixValue Wij = (iWij != null) ? iWij.getValue() : null;
            // the operator expects a fourth operand but none is cached: parse field 4
            // of the instruction string as a scalar and wrap it in a 1x1 block.
            // NOTE(review): instString is re-parsed for every block; this could be hoisted.
            if (null == Wij && qop.hasFourInputs()) {
                MatrixBlock mb = new MatrixBlock(1, 1, false);
                String[] parts = InstructionUtils.getInstructionParts(instString);
                mb.quickSetValue(0, 0, Double.valueOf(parts[4]));
                Wij = mb;
            }

            // get Ui and Vj. Without _cacheU/_cacheV they come from the cached values;
            // otherwise they come from the distributed cache, keyed by the row block
            // index (Ui) or the column block index (Vj) of Xij.
            MatrixValue Ui = (!_cacheU) ?
                    cachedValues.getFirst(_input2).getValue() // U
                    : MRBaseForCommonInstructions.dcValues
                        .get(_input2)
                        .getDataBlock((int) inIx.getRowIndex(), 1)
                        .getValue();
            MatrixValue Vj = (!_cacheV) ?
                    cachedValues.getFirst(_input3).getValue() // t(V)
                    : MRBaseForCommonInstructions.dcValues
                        .get(_input3)
                        .getDataBlock((int) inIx.getColumnIndex(), 1)
                        .getValue();

            // handle special input case: V through shuffle -> t(V).
            // This case is detected by a column-count mismatch between Ui and Vj,
            // in which case Vj is transposed.
            if (Ui.getNumColumns() != Vj.getNumColumns()) {
                Vj = LibMatrixReorg.reorg(
                        (MatrixBlock) Vj,
                        new MatrixBlock(Vj.getNumColumns(), Vj.getNumRows(), Vj.isInSparseFormat()),
                        new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
            }

            // Step 3: process instruction
            Xij.quaternaryOperations(qop, Ui, Vj, Wij, outVal);

            // set output indexes
            // NOTE(review): presumably wtype4 is wcemm and wtype5 is wumm — confirm
            // against QuaternaryOperator
            if (qop.wtype1 != null || qop.wtype4 != null)
                outIx.setIndexes(1, 1); // wsloss: every result block gets index (1,1)
            else if (qop.wtype2 != null || qop.wtype5 != null || qop.wtype3 != null && qop.wtype3.isBasic())
                outIx.setIndexes(inIx); // wsigmoid/wdivmm-basic: same index as Xij
            else { // wdivmm left/right: keyed by Xij's column index (left) or row index (right)
                boolean left = qop.wtype3.isLeft();
                outIx.setIndexes(left ? inIx.getColumnIndex() : inIx.getRowIndex(), 1);
            }

            // put the output value in the cache. This is only done for tempValue;
            // the holdPlace path presumably already registered its slot.
            if (iout == tempValue)
                cachedValues.add(output, iout);
        }
}
/**
 * Parses an MR quaternary instruction into a {@link QuaternaryInstruction}.
 * Supported opcodes are wsloss, wumm, wdivmm, wsigmoid and wcemm, plus their
 * reducer-side variants.
 *
 * <p>Reducer-side variants carry two trailing boolean fields (cacheU, cacheV).
 * These select whether U and V are read from the distributed cache. Map-side
 * variants always use the distributed cache, so both flags are {@code true}.
 *
 * @param str serialized instruction string (without exec type)
 * @return the parsed instruction, or {@code null} if the last branch matches neither
 *         a wsigmoid nor a wcemm opcode
 * @throws DMLRuntimeException if the opcode is not a distributed quaternary opcode
 */
public static QuaternaryInstruction parseInstruction(String str)
        throws DMLRuntimeException
{
    String opcode = InstructionUtils.getOpCode(str);

    // validity check
    if (!InstructionUtils.isDistQuaternaryOpcode(opcode)) {
        throw new DMLRuntimeException("Unexpected opcode in QuaternaryInstruction: " + str);
    }

    // instruction parsing
    if (WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode) // wsloss
            || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode)) {
        boolean isRed = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode);

        // check number of fields (4 inputs, output, type; reducer-side adds cacheU, cacheV)
        if (isRed)
            InstructionUtils.checkNumFields(str, 8);
        else
            InstructionUtils.checkNumFields(str, 6);

        // parse instruction parts (without exec type)
        // field layout: [1..4] inputs, [5] output, [6] WeightsType, [7..8] cache flags (red only)
        String[] parts = InstructionUtils.getInstructionParts(str);
        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        byte in3 = Byte.parseByte(parts[3]);
        byte in4 = Byte.parseByte(parts[4]);
        byte out = Byte.parseByte(parts[5]);
        WeightsType wtype = WeightsType.valueOf(parts[6]);

        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;

        return new QuaternaryInstruction(
                new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, str);
    }
    else if (WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode) // wumm
            || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode)) {
        boolean isRed = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);

        // check number of fields (unary opcode, 3 inputs, output, type;
        // reducer-side adds cacheU, cacheV)
        if (isRed)
            InstructionUtils.checkNumFields(str, 8);
        else
            InstructionUtils.checkNumFields(str, 6);

        // parse instruction parts (without exec type)
        // field layout: [1] unary opcode, [2..4] inputs, [5] output, [6] WUMMType,
        // [7..8] cache flags (red only)
        String[] parts = InstructionUtils.getInstructionParts(str);
        String uopcode = parts[1];
        byte in1 = Byte.parseByte(parts[2]);
        byte in2 = Byte.parseByte(parts[3]);
        byte in3 = Byte.parseByte(parts[4]);
        byte out = Byte.parseByte(parts[5]);
        WUMMType wtype = WUMMType.valueOf(parts[6]);

        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;

        // no fourth input: in4 is passed as -1
        return new QuaternaryInstruction(
                new QuaternaryOperator(wtype, uopcode), in1, in2, in3, (byte) -1, out, cacheU, cacheV, str);
    }
    else if (WeightedDivMM.OPCODE.equalsIgnoreCase(opcode) // wdivmm
            || WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode)) {
        boolean isRed = opcode.startsWith("red");

        // check number of fields (4 inputs, output, type; reducer-side adds cacheU, cacheV)
        if (isRed)
            InstructionUtils.checkNumFields(str, 8);
        else
            InstructionUtils.checkNumFields(str, 6);

        // parse instruction parts (without exec type)
        // field layout: [1..4] inputs, [5] output, [6] WDivMMType, [7..8] cache flags (red only).
        // The type is parsed first because it decides whether field 4 is an input index.
        String[] parts = InstructionUtils.getInstructionParts(str);
        final WDivMMType wtype = WDivMMType.valueOf(parts[6]);
        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        byte in3 = Byte.parseByte(parts[3]);
        // with a scalar fourth operand, field 4 is not an input index and in4 is -1
        byte in4 = wtype.hasScalar() ? -1 : Byte.parseByte(parts[4]);
        byte out = Byte.parseByte(parts[5]);

        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[7]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[8]) : true;

        return new QuaternaryInstruction(
                new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, str);
    }
    else { // wsigmoid / wcemm
        boolean isRed = opcode.startsWith("red");
        // wcemm carries one extra field at index 4, shifting output/type/cache flags by one
        int addInput4 = (opcode.endsWith("wcemm")) ? 1 : 0;

        // check number of fields (3 or 4 inputs, output, type; reducer-side adds cacheU, cacheV)
        if (isRed)
            InstructionUtils.checkNumFields(str, 7 + addInput4);
        else
            InstructionUtils.checkNumFields(str, 5 + addInput4);

        // parse instruction parts (without exec type)
        // field layout: [1..3] inputs, [4] extra field (wcemm only), [4+a] output,
        // [5+a] type, [6+a..7+a] cache flags (red only), where a = addInput4.
        // NOTE(review): for wcemm, field 4 is not read here and in4 is passed as -1;
        // presumably it is a scalar read later from the instruction string — confirm.
        String[] parts = InstructionUtils.getInstructionParts(str);
        byte in1 = Byte.parseByte(parts[1]);
        byte in2 = Byte.parseByte(parts[2]);
        byte in3 = Byte.parseByte(parts[3]);
        byte out = Byte.parseByte(parts[4 + addInput4]);

        // in mappers always through distcache, in reducers through distcache/shuffle
        boolean cacheU = isRed ? Boolean.parseBoolean(parts[6 + addInput4]) : true;
        boolean cacheV = isRed ? Boolean.parseBoolean(parts[7 + addInput4]) : true;

        // the type field is parts[5] for wsigmoid (a = 0) and parts[6] for wcemm (a = 1)
        if (opcode.endsWith("wsigmoid"))
            return new QuaternaryInstruction(
                    new QuaternaryOperator(WSigmoidType.valueOf(parts[5])), in1, in2, in3, (byte) -1, out, cacheU, cacheV, str);
        else if (opcode.endsWith("wcemm"))
            return new QuaternaryInstruction(
                    new QuaternaryOperator(WCeMMType.valueOf(parts[6])), in1, in2, in3, (byte) -1, out, cacheU, cacheV, str);
    }

    // reached only if the last branch sees an opcode ending in neither wsigmoid nor wcemm
    return null;
}