/**
   * Parses a Spark map-side matrix-multiplication-chain instruction.
   *
   * <p>The tokenized instruction (opcode at index 0) has either 5 entries
   * ({@code opcode in1 in2 out chainType}) or 6 entries ({@code opcode in1 in2 in3 out chainType}).
   *
   * @param str the serialized instruction
   * @return the parsed {@link MapmmChainSPInstruction}
   * @throws DMLRuntimeException if the opcode is not {@code MapMultChain.OPCODE}, or as raised by
   *     {@code InstructionUtils.checkNumFields} for an unexpected field count
   */
  public static MapmmChainSPInstruction parseInstruction(String str) throws DMLRuntimeException {
    final String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    InstructionUtils.checkNumFields(fields, 4, 5);
    final String op = fields[0];

    // only the map-mult-chain opcode is handled by this parser
    if (!op.equalsIgnoreCase(MapMultChain.OPCODE)) {
      throw new DMLRuntimeException(
          "MapmmChainSPInstruction.parseInstruction():: Unknown opcode " + op);
    }

    final CPOperand first = new CPOperand(fields[1]);
    final CPOperand second = new CPOperand(fields[2]);

    if (fields.length != 5) {
      // three-input layout: in1 in2 in3 out chainType
      final CPOperand third = new CPOperand(fields[3]);
      final CPOperand result = new CPOperand(fields[4]);
      final ChainType chain = ChainType.valueOf(fields[5]);
      return new MapmmChainSPInstruction(null, first, second, third, result, chain, op, str);
    }

    // two-input layout: in1 in2 out chainType
    final CPOperand result = new CPOperand(fields[3]);
    final ChainType chain = ChainType.valueOf(fields[4]);
    return new MapmmChainSPInstruction(null, first, second, result, chain, op, str);
  }
  /**
   * Parses an MR append instruction of the form {@code opcode in1 in2 out cbind}.
   *
   * @param str the serialized instruction
   * @return a new {@code AppendRInstruction}
   * @throws DMLRuntimeException as raised by {@code InstructionUtils.checkNumFields} when the
   *     instruction does not carry exactly 4 fields after the opcode
   */
  public static Instruction parseInstruction(String str) throws DMLRuntimeException {
    final String[] tokens = InstructionUtils.getInstructionParts(str);
    InstructionUtils.checkNumFields(tokens, 4);

    final byte left = Byte.parseByte(tokens[1]);
    final byte right = Byte.parseByte(tokens[2]);
    final byte target = Byte.parseByte(tokens[3]);
    // Boolean.parseBoolean yields false for anything other than a case-insensitive "true"
    final boolean cbindFlag = Boolean.parseBoolean(tokens[4]);

    return new AppendRInstruction(null, left, right, target, cbindFlag, str);
  }
  /**
   * Parses an MR cumulative-split instruction of the form {@code opcode in out initValue}.
   *
   * @param str the serialized instruction
   * @return a new {@code CumulativeSplitInstruction}
   * @throws DMLRuntimeException as raised by {@code InstructionUtils.checkNumFields} when the
   *     instruction does not carry exactly 3 fields after the opcode
   */
  public static CumulativeSplitInstruction parseInstruction(String str) throws DMLRuntimeException {
    // validate the field count on the raw string before tokenizing
    InstructionUtils.checkNumFields(str, 3);
    final String[] tokens = InstructionUtils.getInstructionParts(str);

    final byte source = Byte.parseByte(tokens[1]);
    final byte target = Byte.parseByte(tokens[2]);
    final double initValue = Double.parseDouble(tokens[3]);
    return new CumulativeSplitInstruction(source, target, initValue, str);
  }
  /**
   * Parses a Spark cumulative-aggregate instruction of the form {@code opcode in out}.
   *
   * <p>The aggregate operator is derived solely from the opcode via
   * {@code InstructionUtils.parseCumulativeAggregateUnaryOperator}.
   *
   * @param str the serialized instruction (operands carry value types)
   * @return a new {@code CumulativeAggregateSPInstruction}
   * @throws DMLRuntimeException as raised by {@code InstructionUtils.checkNumFields} or the
   *     operator parser
   */
  public static CumulativeAggregateSPInstruction parseInstruction(String str)
      throws DMLRuntimeException {
    final String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
    InstructionUtils.checkNumFields(tokens, 2);

    final String op = tokens[0];
    final CPOperand source = new CPOperand(tokens[1]);
    final CPOperand target = new CPOperand(tokens[2]);

    final AggregateUnaryOperator aggOp =
        InstructionUtils.parseCumulativeAggregateUnaryOperator(op);
    return new CumulativeAggregateSPInstruction(aggOp, source, target, op, str);
  }
  /**
   * Parses an MR cumulative-offset instruction of the form {@code opcode in1 in2 out}.
   *
   * @param str the serialized instruction
   * @return a new {@code CumulativeOffsetInstruction}
   * @throws DMLRuntimeException as raised by {@code InstructionUtils.checkNumFields} when the
   *     instruction does not carry exactly 3 fields after the opcode
   */
  public static CumulativeOffsetInstruction parseInstruction(String str)
      throws DMLRuntimeException {
    // validate the field count on the raw string before tokenizing
    InstructionUtils.checkNumFields(str, 3);
    final String[] tokens = InstructionUtils.getInstructionParts(str);

    final String op = tokens[0];
    final byte first = Byte.parseByte(tokens[1]);
    final byte second = Byte.parseByte(tokens[2]);
    final byte target = Byte.parseByte(tokens[3]);
    return new CumulativeOffsetInstruction(first, second, target, op, str);
  }
  /**
   * Parses an MR aggregate-unary instruction of the form {@code opcode in out drop}.
   *
   * <p>The operator is derived from the opcode via
   * {@code InstructionUtils.parseBasicAggregateUnaryOperator}; the last field is read with
   * {@link Boolean#parseBoolean(String)}.
   *
   * @param str the serialized instruction
   * @return a new {@code AggregateUnaryInstruction}
   * @throws DMLRuntimeException as raised by {@code InstructionUtils.checkNumFields} or the
   *     operator parser
   */
  public static AggregateUnaryInstruction parseInstruction(String str) throws DMLRuntimeException {
    // validate the field count on the raw string before tokenizing
    InstructionUtils.checkNumFields(str, 3);
    final String[] tokens = InstructionUtils.getInstructionParts(str);

    final String op = tokens[0];
    final byte source = Byte.parseByte(tokens[1]);
    final byte target = Byte.parseByte(tokens[2]);
    final boolean dropFlag = Boolean.parseBoolean(tokens[3]);

    final AggregateUnaryOperator aggOp = InstructionUtils.parseBasicAggregateUnaryOperator(op);
    return new AggregateUnaryInstruction(aggOp, source, target, dropFlag, str);
  }
  /**
   * Parses an MR binary instruction of the form {@code opcode in1 in2 out}.
   *
   * @param str the serialized instruction
   * @return a new {@code BinaryInstruction}, or {@code null} when
   *     {@code InstructionUtils.parseBinaryOperator} returns {@code null} for the opcode
   * @throws DMLRuntimeException as raised by {@code InstructionUtils.checkNumFields} or the
   *     operator parser
   */
  public static BinaryInstruction parseInstruction(String str) throws DMLRuntimeException {
    // validate the field count on the raw string before tokenizing
    InstructionUtils.checkNumFields(str, 3);
    final String[] tokens = InstructionUtils.getInstructionParts(str);

    final String op = tokens[0];
    final byte left = Byte.parseByte(tokens[1]);
    final byte right = Byte.parseByte(tokens[2]);
    final byte target = Byte.parseByte(tokens[3]);

    final BinaryOperator binOp = InstructionUtils.parseBinaryOperator(op);
    return (binOp == null) ? null : new BinaryInstruction(binOp, left, right, target, str);
  }
// Example #8
// 0
  /**
   * Parses a Spark {@code PMMJ.OPCODE} instruction of the form
   * {@code opcode in1 in2 nrow out cacheType}.
   *
   * <p>The resulting instruction multiplies with {@code Multiply} and aggregates with
   * {@code Plus} (initial value 0).
   *
   * @param str the serialized instruction (operands carry value types)
   * @return a new {@code PmmSPInstruction}
   * @throws DMLRuntimeException if the opcode is not {@code PMMJ.OPCODE}
   */
  public static PmmSPInstruction parseInstruction(String str) throws DMLRuntimeException {
    final String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    final String op = InstructionUtils.getOpCode(str);

    if (!op.equalsIgnoreCase(PMMJ.OPCODE)) {
      throw new DMLRuntimeException(
          "PmmSPInstruction.parseInstruction():: Unknown opcode " + op);
    }

    final CPOperand left = new CPOperand(fields[1]);
    final CPOperand right = new CPOperand(fields[2]);
    final CPOperand rowCount = new CPOperand(fields[3]);
    final CPOperand result = new CPOperand(fields[4]);
    final CacheType cache = CacheType.valueOf(fields[5]);

    // sum of products: element-wise Multiply combined by a Plus aggregation starting at 0
    final AggregateOperator sum = new AggregateOperator(0, Plus.getPlusFnObject());
    final AggregateBinaryOperator product =
        new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), sum);
    return new PmmSPInstruction(product, left, right, result, rowCount, cache, op, str);
  }
  /**
   * Parses a multi-return parameterized builtin CP instruction.
   *
   * <p>Only {@code transformencode} is accepted. Its layout is {@code opcode in1 in2 out1 out2}:
   * two input operands, then two outputs typed as a DOUBLE MATRIX and a STRING FRAME.
   *
   * @param str the serialized instruction (operands carry value types)
   * @return a new {@code MultiReturnParameterizedBuiltinCPInstruction}
   * @throws DMLRuntimeException for any opcode other than {@code transformencode}
   */
  public static MultiReturnParameterizedBuiltinCPInstruction parseInstruction(String str)
      throws DMLRuntimeException {
    final String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
    final String op = tokens[0];

    if (!op.equalsIgnoreCase("transformencode")) {
      throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + op);
    }

    // two inputs (fields 1-2) ...
    final CPOperand first = new CPOperand(tokens[1]);
    final CPOperand second = new CPOperand(tokens[2]);

    // ... and two outputs (fields 3-4)
    final ArrayList<CPOperand> results = new ArrayList<CPOperand>(2);
    results.add(new CPOperand(tokens[3], ValueType.DOUBLE, DataType.MATRIX));
    results.add(new CPOperand(tokens[4], ValueType.STRING, DataType.FRAME));
    return new MultiReturnParameterizedBuiltinCPInstruction(null, first, second, results, op, str);
  }
// Example #10
// 0
  /**
   * Parses a Spark {@code MapMult.OPCODE} instruction of the form
   * {@code opcode in1 in2 out cacheType outputEmpty aggType}.
   *
   * <p>The resulting instruction multiplies with {@code Multiply} and aggregates with
   * {@code Plus} (initial value 0).
   *
   * @param str the serialized instruction (operands carry value types)
   * @return a new {@code MapmmSPInstruction}
   * @throws DMLRuntimeException if the opcode is not {@code MapMult.OPCODE}
   */
  public static MapmmSPInstruction parseInstruction(String str) throws DMLRuntimeException {
    final String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
    final String op = fields[0];

    if (!op.equalsIgnoreCase(MapMult.OPCODE)) {
      throw new DMLRuntimeException(
          "MapmmSPInstruction.parseInstruction():: Unknown opcode " + op);
    }

    final CPOperand left = new CPOperand(fields[1]);
    final CPOperand right = new CPOperand(fields[2]);
    final CPOperand result = new CPOperand(fields[3]);
    final CacheType cache = CacheType.valueOf(fields[4]);
    final boolean emptyOutput = Boolean.parseBoolean(fields[5]);
    final SparkAggType sparkAgg = SparkAggType.valueOf(fields[6]);

    // sum of products: element-wise Multiply combined by a Plus aggregation starting at 0
    final AggregateOperator sum = new AggregateOperator(0, Plus.getPlusFnObject());
    final AggregateBinaryOperator product =
        new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), sum);
    return new MapmmSPInstruction(
        product, left, right, result, cache, emptyOutput, sparkAgg, op, str);
  }
  /**
   * Applies the quaternary operator {@code optr} to every cached block of {@code _input1} and
   * stores each result under {@code output}.
   *
   * <p>Per non-null block X(i,j) of {@code _input1}:
   *
   * <ul>
   *   <li>W(i,j) is taken from the cache when {@code _input4 != -1}. If it is absent and the
   *       operator has four inputs, a 1x1 block holding the value of field 4 of the instruction
   *       string is used instead.
   *   <li>U(i) and V(j) come from {@code cachedValues}, or, when {@code _cacheU}/{@code _cacheV} is
   *       set, from the distributed-cache data blocks at (row index, 1) and (column index, 1).
   *   <li>If U and V differ in column count, V is transposed first (V arrived via shuffle rather
   *       than as t(V)).
   * </ul>
   *
   * <p>The output block index is (1,1) when {@code wtype1} or {@code wtype4} is set (wsloss). It is
   * the input index when {@code wtype2}, {@code wtype5} or a basic {@code wtype3} is set
   * (wsigmoid / wdivmm-basic). Otherwise (wdivmm) it is (column index, 1) for left and
   * (row index, 1) for right.
   *
   * @param valueClass class used when reserving an output slot in the cache
   * @param cachedValues block cache holding the inputs; outputs are written back into it
   * @param tempValue scratch output used when {@code output} equals {@code _input1}
   * @param zeroInput not used by this instruction
   * @param blockRowFactor not used by this instruction
   * @param blockColFactor not used by this instruction
   * @throws DMLRuntimeException as raised by the underlying block operations
   */
  @Override
  public void processInstruction(
      Class<? extends MatrixValue> valueClass,
      CachedValueMap cachedValues,
      IndexedMatrixValue tempValue,
      IndexedMatrixValue zeroInput,
      int blockRowFactor,
      int blockColFactor)
      throws DMLRuntimeException {
    QuaternaryOperator qop = (QuaternaryOperator) optr;

    ArrayList<IndexedMatrixValue> blkList = cachedValues.get(_input1);
    if (blkList == null) {
      return;
    }

    // Scalar substitute for a missing W block. It depends only on the instruction string, so it
    // is parsed lazily at most once instead of re-tokenizing the instruction for every block.
    Double scalarW = null;

    for (IndexedMatrixValue imv : blkList) {
      // Step 1: prepare inputs and output
      if (imv == null) {
        continue;
      }
      MatrixIndexes inIx = imv.getIndexes();
      MatrixValue inVal = imv.getValue();

      // allocate space for the output value (reuse tempValue when the output aliases _input1)
      IndexedMatrixValue iout =
          (output == _input1) ? tempValue : cachedValues.holdPlace(output, valueClass);

      MatrixIndexes outIx = iout.getIndexes();
      MatrixValue outVal = iout.getValue();

      // Step 2: get remaining inputs: Wij, Ui, Vj
      MatrixValue Xij = inVal;

      // get Wij if existing (null for WeightsType.NONE or any WSigmoid type)
      IndexedMatrixValue iWij = (_input4 != -1) ? cachedValues.getFirst(_input4) : null;
      MatrixValue Wij = (iWij != null) ? iWij.getValue() : null;
      if (Wij == null && qop.hasFourInputs()) {
        if (scalarW == null) {
          String[] parts = InstructionUtils.getInstructionParts(instString);
          scalarW = Double.valueOf(parts[4]);
        }
        // a fresh 1x1 block per input block, exactly as before the parse was hoisted
        MatrixBlock mb = new MatrixBlock(1, 1, false);
        mb.quickSetValue(0, 0, scalarW);
        Wij = mb;
      }

      // get Ui and Vj, potentially through distributed cache
      MatrixValue Ui =
          (!_cacheU)
              ? cachedValues.getFirst(_input2).getValue() // U
              : MRBaseForCommonInstructions.dcValues
                  .get(_input2)
                  .getDataBlock((int) inIx.getRowIndex(), 1)
                  .getValue();
      MatrixValue Vj =
          (!_cacheV)
              ? cachedValues.getFirst(_input3).getValue() // t(V)
              : MRBaseForCommonInstructions.dcValues
                  .get(_input3)
                  .getDataBlock((int) inIx.getColumnIndex(), 1)
                  .getValue();
      // handle special input case: V through shuffle -> t(V)
      if (Ui.getNumColumns() != Vj.getNumColumns()) {
        Vj =
            LibMatrixReorg.reorg(
                (MatrixBlock) Vj,
                new MatrixBlock(Vj.getNumColumns(), Vj.getNumRows(), Vj.isInSparseFormat()),
                new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
      }

      // Step 3: process instruction
      Xij.quaternaryOperations(qop, Ui, Vj, Wij, outVal);

      // set output indexes
      if (qop.wtype1 != null || qop.wtype4 != null) {
        outIx.setIndexes(1, 1); // wsloss
      } else if (qop.wtype2 != null
          || qop.wtype5 != null
          || (qop.wtype3 != null && qop.wtype3.isBasic())) {
        outIx.setIndexes(inIx); // wsigmoid/wdivmm-basic
      } else { // wdivmm
        boolean left = qop.wtype3.isLeft();
        outIx.setIndexes(left ? inIx.getColumnIndex() : inIx.getRowIndex(), 1);
      }

      // put the output value in the cache
      if (iout == tempValue) {
        cachedValues.add(output, iout);
      }
    }
  }
  /**
   * Parses a distributed quaternary instruction into a {@code QuaternaryInstruction}.
   *
   * <p>Field layouts after the opcode. The trailing {@code cacheU cacheV} pair is present only for
   * the reducer ("red") variants. Mapper variants always read U and V through the distributed
   * cache, so both flags are {@code true} for them.
   *
   * <ul>
   *   <li>wsloss: {@code in1 in2 in3 in4 out WeightsType}
   *   <li>wumm: {@code unaryOpcode in1 in2 in3 out WUMMType} (in4 recorded as -1)
   *   <li>wdivmm: {@code in1 in2 in3 in4 out WDivMMType}; in4 is recorded as -1 and not parsed
   *       when the type has a scalar
   *   <li>wsigmoid: {@code in1 in2 in3 out WSigmoidType} (in4 recorded as -1)
   *   <li>wcemm: {@code in1 in2 in3 in4 out WCeMMType}; field 4 is not parsed here and in4 is
   *       recorded as -1
   * </ul>
   *
   * @param str the serialized instruction
   * @return the parsed instruction, or {@code null} if the opcode passes the distributed
   *     quaternary check but belongs to none of the families above
   * @throws DMLRuntimeException if the opcode is not a distributed quaternary opcode, or as raised
   *     by {@code InstructionUtils.checkNumFields} for a wrong field count
   */
  public static QuaternaryInstruction parseInstruction(String str) throws DMLRuntimeException {
    final String opcode = InstructionUtils.getOpCode(str);

    // reject anything that is not a distributed quaternary opcode up front
    if (!InstructionUtils.isDistQuaternaryOpcode(opcode)) {
      throw new DMLRuntimeException("Unexpected opcode in QuaternaryInstruction: " + str);
    }

    // wsloss (mapper or reducer): in1 in2 in3 in4 out type [cacheU cacheV]
    if (WeightedSquaredLoss.OPCODE.equalsIgnoreCase(opcode)
        || WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode)) {
      final boolean reducer = WeightedSquaredLossR.OPCODE.equalsIgnoreCase(opcode);
      InstructionUtils.checkNumFields(str, reducer ? 8 : 6);
      final String[] fields = InstructionUtils.getInstructionParts(str);

      final byte i1 = Byte.parseByte(fields[1]);
      final byte i2 = Byte.parseByte(fields[2]);
      final byte i3 = Byte.parseByte(fields[3]);
      final byte i4 = Byte.parseByte(fields[4]);
      final byte o = Byte.parseByte(fields[5]);
      final WeightsType type = WeightsType.valueOf(fields[6]);
      final boolean cacheU = !reducer || Boolean.parseBoolean(fields[7]);
      final boolean cacheV = !reducer || Boolean.parseBoolean(fields[8]);

      return new QuaternaryInstruction(
          new QuaternaryOperator(type), i1, i2, i3, i4, o, cacheU, cacheV, str);
    }

    // wumm (mapper or reducer): unaryOpcode in1 in2 in3 out type [cacheU cacheV]
    if (WeightedUnaryMM.OPCODE.equalsIgnoreCase(opcode)
        || WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode)) {
      final boolean reducer = WeightedUnaryMMR.OPCODE.equalsIgnoreCase(opcode);
      InstructionUtils.checkNumFields(str, reducer ? 8 : 6);
      final String[] fields = InstructionUtils.getInstructionParts(str);

      final String unaryOp = fields[1];
      final byte i1 = Byte.parseByte(fields[2]);
      final byte i2 = Byte.parseByte(fields[3]);
      final byte i3 = Byte.parseByte(fields[4]);
      final byte o = Byte.parseByte(fields[5]);
      final WUMMType type = WUMMType.valueOf(fields[6]);
      final boolean cacheU = !reducer || Boolean.parseBoolean(fields[7]);
      final boolean cacheV = !reducer || Boolean.parseBoolean(fields[8]);

      return new QuaternaryInstruction(
          new QuaternaryOperator(type, unaryOp), i1, i2, i3, (byte) -1, o, cacheU, cacheV, str);
    }

    // wdivmm (mapper or reducer): in1 in2 in3 in4 out type [cacheU cacheV]
    if (WeightedDivMM.OPCODE.equalsIgnoreCase(opcode)
        || WeightedDivMMR.OPCODE.equalsIgnoreCase(opcode)) {
      final boolean reducer = opcode.startsWith("red");
      InstructionUtils.checkNumFields(str, reducer ? 8 : 6);
      final String[] fields = InstructionUtils.getInstructionParts(str);

      // the type is read first: when it has a scalar, field 4 is not a byte input index
      final WDivMMType type = WDivMMType.valueOf(fields[6]);
      final byte i1 = Byte.parseByte(fields[1]);
      final byte i2 = Byte.parseByte(fields[2]);
      final byte i3 = Byte.parseByte(fields[3]);
      final byte i4 = type.hasScalar() ? (byte) -1 : Byte.parseByte(fields[4]);
      final byte o = Byte.parseByte(fields[5]);
      final boolean cacheU = !reducer || Boolean.parseBoolean(fields[7]);
      final boolean cacheV = !reducer || Boolean.parseBoolean(fields[8]);

      return new QuaternaryInstruction(
          new QuaternaryOperator(type), i1, i2, i3, i4, o, cacheU, cacheV, str);
    }

    // wsigmoid / wcemm: wcemm carries one extra field (index 4) that shifts the rest by one
    final boolean reducer = opcode.startsWith("red");
    final int shift = opcode.endsWith("wcemm") ? 1 : 0;
    InstructionUtils.checkNumFields(str, (reducer ? 7 : 5) + shift);
    final String[] fields = InstructionUtils.getInstructionParts(str);

    final byte i1 = Byte.parseByte(fields[1]);
    final byte i2 = Byte.parseByte(fields[2]);
    final byte i3 = Byte.parseByte(fields[3]);
    final byte o = Byte.parseByte(fields[4 + shift]);
    final boolean cacheU = !reducer || Boolean.parseBoolean(fields[6 + shift]);
    final boolean cacheV = !reducer || Boolean.parseBoolean(fields[7 + shift]);

    if (opcode.endsWith("wsigmoid")) {
      final QuaternaryOperator sigOp = new QuaternaryOperator(WSigmoidType.valueOf(fields[5]));
      return new QuaternaryInstruction(sigOp, i1, i2, i3, (byte) -1, o, cacheU, cacheV, str);
    }
    if (opcode.endsWith("wcemm")) {
      final QuaternaryOperator ceOp = new QuaternaryOperator(WCeMMType.valueOf(fields[6]));
      return new QuaternaryInstruction(ceOp, i1, i2, i3, (byte) -1, o, cacheU, cacheV, str);
    }

    return null;
  }