コード例 #1
0
  /**
   * Runs the aggregate-unary operator held in {@code optr} over every cached block of {@code
   * input} and stores each result under {@code output}.
   *
   * <p>When {@code input == output} the result is first written into {@code tempValue} and then
   * added to the cache; otherwise a fresh slot is reserved in the cache via {@code holdPlace}.
   *
   * @param valueClass matrix value class used when reserving a new output slot
   * @param cachedValues cache holding the input blocks and receiving the output blocks
   * @param tempValue scratch value used as the output target when input and output coincide
   * @param zeroInput not used by this implementation
   * @param blockRowFactor forwarded unchanged to the aggregate-unary kernel
   * @param blockColFactor forwarded unchanged to the aggregate-unary kernel
   * @throws DMLUnsupportedOperationException declared by the overridden method
   * @throws DMLRuntimeException declared by the overridden method
   */
  @Override
  public void processInstruction(
      Class<? extends MatrixValue> valueClass,
      CachedValueMap cachedValues,
      IndexedMatrixValue tempValue,
      IndexedMatrixValue zeroInput,
      int blockRowFactor,
      int blockColFactor)
      throws DMLUnsupportedOperationException, DMLRuntimeException {

    ArrayList<IndexedMatrixValue> inputBlocks = cachedValues.get(input);
    if (inputBlocks == null) {
      return; // nothing cached for this input
    }

    for (IndexedMatrixValue inBlock : inputBlocks) {
      if (inBlock == null) {
        continue;
      }

      // pick the output target: scratch value for in-place, otherwise a new cache slot
      IndexedMatrixValue outBlock =
          (input == output) ? tempValue : cachedValues.holdPlace(output, valueClass);

      MatrixIndexes inIx = inBlock.getIndexes();
      AggregateUnaryOperator aggUnaryOp = (AggregateUnaryOperator) optr;

      // for a diagonal reduction (trace), blocks off the block-diagonal are pruned
      boolean prunedForTrace =
          aggUnaryOp.indexFn instanceof ReduceDiag
              && inIx.getColumnIndex() != inIx.getRowIndex();

      if (!prunedForTrace) {
        // general case: run the aggregation kernel
        OperationsOnMatrixValues.performAggregateUnary(
            inIx,
            inBlock.getValue(),
            outBlock.getIndexes(),
            outBlock.getValue(),
            aggUnaryOp,
            blockRowFactor,
            blockColFactor);
        if (_dropCorr) {
          ((MatrixBlock) outBlock.getValue())
              .dropLastRowsOrColums(aggUnaryOp.aggOp.correctionLocation);
        }
      } else {
        // no computation for this block, but the output value is still reset
        outBlock.getValue().reset();
      }

      // a scratch-value result still has to be registered in the cache
      if (outBlock == tempValue) {
        cachedValues.add(output, outBlock);
      }
    }
  }
コード例 #2
0
  /**
   * Splits every cached block of {@code input} into single-row output blocks whose row index is
   * shifted down by one.
   *
   * <p>For a block in the first row of blocks (row index 1), an extra 1-row output at row index 1
   * is emitted first; it is filled with {@code _initValue} when that value is non-zero. Row
   * {@code i} of an input block is then emitted at row index {@code rixOffset + i + 2}. The last
   * row of the block whose row index equals {@code _lastRowBlockIndex} is not emitted.
   *
   * @param valueClass matrix value class used when reserving output slots
   * @param cachedValues cache holding the input blocks and receiving the output blocks
   * @param tempValue not used by this implementation
   * @param zeroInput not used by this implementation
   * @param blockRowFactor multiplier applied to the block row index to obtain the row offset
   * @param blockColFactor not used by this implementation
   * @throws DMLRuntimeException declared by the overridden method
   */
  @Override
  public void processInstruction(
      Class<? extends MatrixValue> valueClass,
      CachedValueMap cachedValues,
      IndexedMatrixValue tempValue,
      IndexedMatrixValue zeroInput,
      int blockRowFactor,
      int blockColFactor)
      throws DMLRuntimeException {
    ArrayList<IndexedMatrixValue> inputBlocks = cachedValues.get(input);
    if (inputBlocks != null) {
      for (IndexedMatrixValue inBlock : inputBlocks) {
        if (inBlock == null) {
          continue;
        }

        MatrixIndexes inIx = inBlock.getIndexes();
        MatrixBlock src = (MatrixBlock) inBlock.getValue();
        long rixOffset = (inIx.getRowIndex() - 1) * blockRowFactor;
        boolean isFirstRowBlock = (inIx.getRowIndex() == 1);
        boolean isLastRowBlock = (inIx.getRowIndex() == _lastRowBlockIndex);
        int numCols = src.getNumColumns();
        int numRows = src.getNumRows();

        // first row of the output: the initial value (left empty when it is zero)
        if (isFirstRowBlock) {
          IndexedMatrixValue initOut = cachedValues.holdPlace(output, valueClass);
          MatrixBlock initBlk = (MatrixBlock) initOut.getValue();
          initBlk.reset(1, numCols);
          if (_initValue != 0) {
            for (int col = 0; col < numCols; col++) {
              initBlk.appendValue(0, col, _initValue);
            }
          }
          initOut.getIndexes().setIndexes(1, inIx.getColumnIndex());
        }

        // one output per row, shifted by one; the last row of the last block is dropped
        int rowsToEmit = isLastRowBlock ? numRows - 1 : numRows;
        for (int row = 0; row < rowsToEmit; row++) {
          IndexedMatrixValue rowOut = cachedValues.holdPlace(output, valueClass);
          MatrixBlock rowBlk = (MatrixBlock) rowOut.getValue();
          rowBlk.reset(1, numCols);
          src.sliceOperations(row, row, 0, numCols - 1, rowBlk);
          rowOut.getIndexes().setIndexes(rixOffset + row + 2, inIx.getColumnIndex());
        }
      }
    }
  }
コード例 #3
0
  /**
   * Applies the quaternary operator held in {@code optr} to every cached block of {@code _input1}
   * and stores each result under {@code output}.
   *
   * <p>Per block X(i,j) the method gathers the remaining operands: the weights W (from {@code
   * _input4}, or a 1x1 block parsed from the instruction string when none is cached and the
   * operator expects four inputs), U and V (from the cache or, if {@code _cacheU} / {@code
   * _cacheV} is set, from {@code MRBaseForCommonInstructions.dcValues}). V is transposed when its
   * column count differs from that of U. Output indexes depend on which weight type of the
   * operator is set.
   *
   * @param valueClass matrix value class used when reserving a new output slot
   * @param cachedValues cache holding the input blocks and receiving the output blocks
   * @param tempValue scratch value used as the output target when {@code output == _input1}
   * @param zeroInput not used by this implementation
   * @param blockRowFactor not used by this implementation
   * @param blockColFactor not used by this implementation
   * @throws DMLRuntimeException declared by the overridden method
   */
  @Override
  public void processInstruction(
      Class<? extends MatrixValue> valueClass,
      CachedValueMap cachedValues,
      IndexedMatrixValue tempValue,
      IndexedMatrixValue zeroInput,
      int blockRowFactor,
      int blockColFactor)
      throws DMLRuntimeException {
    QuaternaryOperator qop = (QuaternaryOperator) optr;

    ArrayList<IndexedMatrixValue> xBlocks = cachedValues.get(_input1);
    if (xBlocks == null) {
      return; // nothing cached for the main input
    }

    for (IndexedMatrixValue xBlock : xBlocks) {
      if (xBlock == null) {
        continue;
      }

      // Step 1: main input and output target
      MatrixIndexes xIx = xBlock.getIndexes();
      MatrixValue x = xBlock.getValue();

      IndexedMatrixValue result;
      if (output == _input1) {
        result = tempValue;
      } else {
        result = cachedValues.holdPlace(output, valueClass);
      }
      MatrixIndexes resultIx = result.getIndexes();
      MatrixValue resultVal = result.getValue();

      // Step 2a: weights W, if any are cached under _input4
      MatrixValue w = null;
      if (_input4 != -1) {
        IndexedMatrixValue wBlock = cachedValues.getFirst(_input4);
        if (wBlock != null) {
          w = wBlock.getValue();
        }
      }
      // no cached weights but a fourth input is expected: use the scalar from the instruction
      if (w == null && qop.hasFourInputs()) {
        MatrixBlock scalar = new MatrixBlock(1, 1, false);
        String[] parts = InstructionUtils.getInstructionParts(instString);
        scalar.quickSetValue(0, 0, Double.parseDouble(parts[4]));
        w = scalar;
      }

      // Step 2b: U and V, potentially through the distributed cache
      MatrixValue u;
      if (_cacheU) {
        u =
            MRBaseForCommonInstructions.dcValues
                .get(_input2)
                .getDataBlock((int) xIx.getRowIndex(), 1)
                .getValue();
      } else {
        u = cachedValues.getFirst(_input2).getValue();
      }

      MatrixValue v;
      if (_cacheV) {
        v =
            MRBaseForCommonInstructions.dcValues
                .get(_input3)
                .getDataBlock((int) xIx.getColumnIndex(), 1)
                .getValue();
      } else {
        v = cachedValues.getFirst(_input3).getValue();
      }

      // special case: V arrived untransposed (through shuffle), so swap its indexes
      if (u.getNumColumns() != v.getNumColumns()) {
        MatrixBlock transposed =
            new MatrixBlock(v.getNumColumns(), v.getNumRows(), v.isInSparseFormat());
        v =
            LibMatrixReorg.reorg(
                (MatrixBlock) v, transposed, new ReorgOperator(SwapIndex.getSwapIndexFnObject()));
      }

      // Step 3: execute the quaternary operation
      x.quaternaryOperations(qop, u, v, w, resultVal);

      // Step 4: output indexes depend on the weight type that is set
      if (qop.wtype1 != null || qop.wtype4 != null) {
        // wsloss
        resultIx.setIndexes(1, 1);
      } else if (qop.wtype2 != null
          || qop.wtype5 != null
          || (qop.wtype3 != null && qop.wtype3.isBasic())) {
        // wsigmoid / wdivmm-basic: same position as the input block
        resultIx.setIndexes(xIx);
      } else if (qop.wtype3.isLeft()) {
        // wdivmm (left)
        resultIx.setIndexes(xIx.getColumnIndex(), 1);
      } else {
        // wdivmm (right)
        resultIx.setIndexes(xIx.getRowIndex(), 1);
      }

      // a scratch-value result still has to be registered in the cache
      if (result == tempValue) {
        cachedValues.add(output, result);
      }
    }
  }