Example #1
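Two Java methods, apparently from Apache SystemML: a Hop-level inference of output matrix characteristics for unary operations, and a Spark instruction that writes a matrix RDD to HDFS in text, MatrixMarket, CSV, or binary-block format.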
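  // Infers the output characteristics [rows, cols, nnz] of a unary op from the
  // input's statistics; sparsity-preserving ops carry over the input's nnz,
  // all others report nnz as -1 (unknown).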
  @Override
  protected long[] inferOutputCharacteristics(MemoTable memo) {
    long[] ret = null;

    Hop input = getInput().get(0);
    MatrixCharacteristics mc = memo.getAllInputStats(input);
    if (mc.dimsKnown()) {
      if (_op == OpOp1.ABS
          || _op == OpOp1.COS
          || _op == OpOp1.SIN
          || _op == OpOp1.TAN
          || _op == OpOp1.ACOS
          || _op == OpOp1.ASIN
          || _op == OpOp1.ATAN
          || _op == OpOp1.SQRT
          || _op == OpOp1.ROUND) // sparsity preserving
      {
        ret = new long[] {mc.getRows(), mc.getCols(), mc.getNonZeros()};
      } else {
        ret = new long[] {mc.getRows(), mc.getCols(), -1}; // nnz unknown
      }
    }

    return ret;
  }
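
  // Writes the matrix bound to input1 to HDFS under the filename given by input2
  // (literal or variable expression), in the output format named by input3, and
  // finally writes the accompanying .mtd metadata file.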
  @Override
  public void processInstruction(ExecutionContext ec)
      throws DMLRuntimeException, DMLUnsupportedOperationException {
    SparkExecutionContext sec = (SparkExecutionContext) ec;

    // get filename (literal or variable expression)
    String fname =
        ec.getScalarInput(input2.getName(), ValueType.STRING, input2.isLiteral()).getStringValue();

    try {
      // if the file already exists on HDFS, remove it.
      MapReduceTool.deleteFileIfExistOnHDFS(fname);

      // prepare output info according to meta data
      String outFmt = input3.getName();
      OutputInfo oi = OutputInfo.stringToOutputInfo(outFmt);

      // get input rdd
      JavaPairRDD<MatrixIndexes, MatrixBlock> in1 =
          sec.getBinaryBlockRDDHandleForVariable(input1.getName());
      MatrixCharacteristics mc = sec.getMatrixCharacteristics(input1.getName());

      if (oi == OutputInfo.MatrixMarketOutputInfo || oi == OutputInfo.TextCellOutputInfo) {
        // recompute nnz if necessary (required for header if matrix market)
        if (isInputMatrixBlock && !mc.nnzKnown())
          mc.setNonZeros(SparkUtils.computeNNZFromBlocks(in1));

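        // optionally build a one-line header RDD (MatrixMarket needs one, textcell does not)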
        JavaRDD<String> header = null;
        if (outFmt.equalsIgnoreCase("matrixmarket")) {
          ArrayList<String> headerContainer = new ArrayList<String>(1);
          // MatrixMarket header: banner line plus "rows cols nnz"
          String headerStr =
              "%%MatrixMarket matrix coordinate real general\n"
                  + mc.getRows() + " " + mc.getCols() + " " + mc.getNonZeros();
          headerContainer.add(headerStr);
          header = sec.getSparkContext().parallelize(headerContainer);
        }

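        // convert binary blocks to "i j v" text lines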
        JavaRDD<String> ijv =
            in1.flatMap(
                new ConvertMatrixBlockToIJVLines(mc.getRowsPerBlock(), mc.getColsPerBlock()));
        if (header != null) customSaveTextFile(header.union(ijv), fname, true);
        else customSaveTextFile(ijv, fname, false);
      } else if (oi == OutputInfo.CSVOutputInfo) {
        JavaRDD<String> out = null;
        Accumulator<Double> aNnz = null;

        if (isInputMatrixBlock) {
          // piggyback nnz computation on actual write
          if (!mc.nnzKnown()) {
            aNnz = sec.getSparkContext().accumulator(0L);
            in1 = in1.mapValues(new ComputeBinaryBlockNnzFunction(aNnz));
          }

          out =
              RDDConverterUtils.binaryBlockToCsv(
                  in1, mc, (CSVFileFormatProperties) formatProperties, true);
        } else {
          // This case is applicable when the CSV output from transform() is written out
          @SuppressWarnings("unchecked")
          JavaPairRDD<Long, String> rdd =
              (JavaPairRDD<Long, String>)
                  ((MatrixObject) sec.getVariable(input1.getName())).getRDDHandle().getRDD();
          out = rdd.values();

          String sep = ",";
          boolean hasHeader = false;
          if (formatProperties != null) {
            sep = ((CSVFileFormatProperties) formatProperties).getDelim();
            hasHeader = ((CSVFileFormatProperties) formatProperties).hasHeader();
          }

          if (hasHeader) {
            // generate a default header C1..Cn, one name per column
            StringBuilder buf = new StringBuilder();
            for (int j = 1; j <= mc.getCols(); j++) {
              if (j != 1) {
                buf.append(sep);
              }
              buf.append("C" + j);
            }
            ArrayList<String> headerContainer = new ArrayList<String>(1);
            headerContainer.add(buf.toString());
            JavaRDD<String> header = sec.getSparkContext().parallelize(headerContainer);
            out = header.union(out);
          }
        }

        customSaveTextFile(out, fname, false);

        if (isInputMatrixBlock && !mc.nnzKnown()) mc.setNonZeros(aNnz.value().longValue());
      } else if (oi == OutputInfo.BinaryBlockOutputInfo) {
        // piggyback nnz computation on actual write
        Accumulator<Double> aNnz = null;
        if (!mc.nnzKnown()) {
          aNnz = sec.getSparkContext().accumulator(0L);
          in1 = in1.mapValues(new ComputeBinaryBlockNnzFunction(aNnz));
        }

        // save binary block rdd on hdfs
        in1.saveAsHadoopFile(
            fname, MatrixIndexes.class, MatrixBlock.class, SequenceFileOutputFormat.class);

        if (!mc.nnzKnown()) mc.setNonZeros(aNnz.value().longValue());
      } else {
        // unsupported formats: binarycell (not externalized)
        throw new DMLRuntimeException("Unexpected data format: " + outFmt);
      }

      // write meta data file
      MapReduceTool.writeMetaDataFile(fname + ".mtd", ValueType.DOUBLE, mc, oi, formatProperties);
    } catch (IOException ex) {
      throw new DMLRuntimeException("Failed to process write instruction", ex);
    }
  }
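Note how both the CSV and binary-block branches piggyback the nnz computation on the write itself via a Spark accumulator, so the .mtd metadata file can record the exact non-zero count without a second pass over the data.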