/**
 * Fills in the output matrix characteristics for this operation.
 *
 * <p>The output is configured through {@code out.set(1, 1, ...)}, reusing the
 * rows-per-block and columns-per-block values of the input. The two leading
 * {@code 1} arguments are presumably the output row and column counts
 * (NOTE(review): confirm against {@code MatrixCharacteristics.set}).
 *
 * @param in  characteristics of the input; only its block sizes are read
 * @param out characteristics object that is overwritten with the result
 * @return always {@code true}
 * @throws DMLRuntimeException declared by the signature; not thrown directly here
 */
public boolean computeDimension(MatrixCharacteristics in, MatrixCharacteristics out)
        throws DMLRuntimeException {
    out.set(
            1,
            1,
            in.getRowsPerBlock(),
            in.getColsPerBlock());
    return true;
}
@Override public void processInstruction(ExecutionContext ec) throws DMLRuntimeException, DMLUnsupportedOperationException { SparkExecutionContext sec = (SparkExecutionContext) ec; // get filename (literal or variable expression) String fname = ec.getScalarInput(input2.getName(), ValueType.STRING, input2.isLiteral()).getStringValue(); try { // if the file already exists on HDFS, remove it. MapReduceTool.deleteFileIfExistOnHDFS(fname); // prepare output info according to meta data String outFmt = input3.getName(); OutputInfo oi = OutputInfo.stringToOutputInfo(outFmt); // get input rdd JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec.getBinaryBlockRDDHandleForVariable(input1.getName()); MatrixCharacteristics mc = sec.getMatrixCharacteristics(input1.getName()); if (oi == OutputInfo.MatrixMarketOutputInfo || oi == OutputInfo.TextCellOutputInfo) { // recompute nnz if necessary (required for header if matrix market) if (isInputMatrixBlock && !mc.nnzKnown()) mc.setNonZeros(SparkUtils.computeNNZFromBlocks(in1)); JavaRDD<String> header = null; if (outFmt.equalsIgnoreCase("matrixmarket")) { ArrayList<String> headerContainer = new ArrayList<String>(1); // First output MM header String headerStr = "%%MatrixMarket matrix coordinate real general\n" + // output number of rows, number of columns and number of nnz mc.getRows() + " " + mc.getCols() + " " + mc.getNonZeros(); headerContainer.add(headerStr); header = sec.getSparkContext().parallelize(headerContainer); } JavaRDD<String> ijv = in1.flatMap( new ConvertMatrixBlockToIJVLines(mc.getRowsPerBlock(), mc.getColsPerBlock())); if (header != null) customSaveTextFile(header.union(ijv), fname, true); else customSaveTextFile(ijv, fname, false); } else if (oi == OutputInfo.CSVOutputInfo) { JavaRDD<String> out = null; Accumulator<Double> aNnz = null; if (isInputMatrixBlock) { // piggyback nnz computation on actual write if (!mc.nnzKnown()) { aNnz = sec.getSparkContext().accumulator(0L); in1 = in1.mapValues(new 
ComputeBinaryBlockNnzFunction(aNnz)); } out = RDDConverterUtils.binaryBlockToCsv( in1, mc, (CSVFileFormatProperties) formatProperties, true); } else { // This case is applicable when the CSV output from transform() is written out @SuppressWarnings("unchecked") JavaPairRDD<Long, String> rdd = (JavaPairRDD<Long, String>) ((MatrixObject) sec.getVariable(input1.getName())).getRDDHandle().getRDD(); out = rdd.values(); String sep = ","; boolean hasHeader = false; if (formatProperties != null) { sep = ((CSVFileFormatProperties) formatProperties).getDelim(); hasHeader = ((CSVFileFormatProperties) formatProperties).hasHeader(); } if (hasHeader) { StringBuffer buf = new StringBuffer(); for (int j = 1; j < mc.getCols(); j++) { if (j != 1) { buf.append(sep); } buf.append("C" + j); } ArrayList<String> headerContainer = new ArrayList<String>(1); headerContainer.add(0, buf.toString()); JavaRDD<String> header = sec.getSparkContext().parallelize(headerContainer); out = header.union(out); } } customSaveTextFile(out, fname, false); if (isInputMatrixBlock && !mc.nnzKnown()) mc.setNonZeros((long) aNnz.value().longValue()); } else if (oi == OutputInfo.BinaryBlockOutputInfo) { // piggyback nnz computation on actual write Accumulator<Double> aNnz = null; if (!mc.nnzKnown()) { aNnz = sec.getSparkContext().accumulator(0L); in1 = in1.mapValues(new ComputeBinaryBlockNnzFunction(aNnz)); } // save binary block rdd on hdfs in1.saveAsHadoopFile( fname, MatrixIndexes.class, MatrixBlock.class, SequenceFileOutputFormat.class); if (!mc.nnzKnown()) mc.setNonZeros((long) aNnz.value().longValue()); } else { // unsupported formats: binarycell (not externalized) throw new DMLRuntimeException("Unexpected data format: " + outFmt); } // write meta data file MapReduceTool.writeMetaDataFile(fname + ".mtd", ValueType.DOUBLE, mc, oi, formatProperties); } catch (IOException ex) { throw new DMLRuntimeException("Failed to process write instruction", ex); } }