Пример #1
0
  /**
   * Writes the matrix bound to {@code input1} to the file named by {@code input2}, using the output
   * format named by {@code input3}, and then writes the accompanying {@code <fname>.mtd} metadata
   * file.
   *
   * <p>Any existing file at the target path is deleted first. Supported formats are text cell,
   * matrix market, CSV and binary block; any other format raises a {@link DMLRuntimeException}.
   * If the number of non-zeros is not yet known, it is either computed up front (text cell / matrix
   * market with a matrix-block input) or collected via an accumulator during the write (CSV with a
   * matrix-block input, binary block). It is then stored in the matrix characteristics before the
   * metadata file is written.
   *
   * @param ec execution context; it is cast unconditionally to {@link SparkExecutionContext}
   * @throws DMLRuntimeException if the output format is unsupported or an {@link IOException}
   *     occurs while deleting or writing files
   * @throws DMLUnsupportedOperationException declared in the signature; not thrown directly here
   */
  @Override
  public void processInstruction(ExecutionContext ec)
      throws DMLRuntimeException, DMLUnsupportedOperationException {
    SparkExecutionContext sec = (SparkExecutionContext) ec;

    // get filename (literal or variable expression)
    String fname =
        ec.getScalarInput(input2.getName(), ValueType.STRING, input2.isLiteral()).getStringValue();

    try {
      // if the file already exists on HDFS, remove it
      MapReduceTool.deleteFileIfExistOnHDFS(fname);

      // prepare output info according to meta data
      String outFmt = input3.getName();
      OutputInfo oi = OutputInfo.stringToOutputInfo(outFmt);

      // get input rdd and its characteristics
      JavaPairRDD<MatrixIndexes, MatrixBlock> in1 =
          sec.getBinaryBlockRDDHandleForVariable(input1.getName());
      MatrixCharacteristics mc = sec.getMatrixCharacteristics(input1.getName());

      if (oi == OutputInfo.MatrixMarketOutputInfo || oi == OutputInfo.TextCellOutputInfo) {
        // recompute nnz if necessary (required for the matrix market header)
        if (isInputMatrixBlock && !mc.nnzKnown()) {
          mc.setNonZeros(SparkUtils.computeNNZFromBlocks(in1));
        }

        JavaRDD<String> ijv =
            in1.flatMap(
                new ConvertMatrixBlockToIJVLines(mc.getRowsPerBlock(), mc.getColsPerBlock()));

        // Decide on the resolved OutputInfo rather than re-parsing the raw format string, so the
        // header is written for every format name that stringToOutputInfo maps to matrix market.
        if (oi == OutputInfo.MatrixMarketOutputInfo) {
          // MM header: banner line, then "<rows> <cols> <nnz>"
          String headerStr =
              "%%MatrixMarket matrix coordinate real general\n"
                  + mc.getRows()
                  + " "
                  + mc.getCols()
                  + " "
                  + mc.getNonZeros();
          ArrayList<String> headerContainer = new ArrayList<String>(1);
          headerContainer.add(headerStr);
          JavaRDD<String> header = sec.getSparkContext().parallelize(headerContainer);
          customSaveTextFile(header.union(ijv), fname, true);
        } else {
          customSaveTextFile(ijv, fname, false);
        }
      } else if (oi == OutputInfo.CSVOutputInfo) {
        JavaRDD<String> out;
        // non-null only when nnz is collected during the write below
        Accumulator<Double> aNnz = null;

        if (isInputMatrixBlock) {
          // piggyback nnz computation on actual write
          if (!mc.nnzKnown()) {
            aNnz = sec.getSparkContext().accumulator(0L);
            in1 = in1.mapValues(new ComputeBinaryBlockNnzFunction(aNnz));
          }

          out =
              RDDConverterUtils.binaryBlockToCsv(
                  in1, mc, (CSVFileFormatProperties) formatProperties, true);
        } else {
          // This case is applicable when the CSV output from transform() is written out
          @SuppressWarnings("unchecked")
          JavaPairRDD<Long, String> rdd =
              (JavaPairRDD<Long, String>)
                  ((MatrixObject) sec.getVariable(input1.getName())).getRDDHandle().getRDD();
          out = rdd.values();

          String sep = ",";
          boolean hasHeader = false;
          if (formatProperties != null) {
            CSVFileFormatProperties csvProps = (CSVFileFormatProperties) formatProperties;
            sep = csvProps.getDelim();
            hasHeader = csvProps.hasHeader();
          }

          if (hasHeader) {
            // One name per column: C1, C2, ..., Cn. The bound is inclusive; an exclusive bound
            // would drop the name of the last column.
            StringBuilder buf = new StringBuilder();
            for (long j = 1; j <= mc.getCols(); j++) {
              if (j != 1) {
                buf.append(sep);
              }
              buf.append('C').append(j);
            }
            ArrayList<String> headerContainer = new ArrayList<String>(1);
            headerContainer.add(buf.toString());
            JavaRDD<String> header = sec.getSparkContext().parallelize(headerContainer);
            out = header.union(out);
          }
        }

        customSaveTextFile(out, fname, false);

        // the accumulator is only populated once the write action above has executed
        if (aNnz != null) {
          mc.setNonZeros(aNnz.value().longValue());
        }
      } else if (oi == OutputInfo.BinaryBlockOutputInfo) {
        // piggyback nnz computation on actual write
        Accumulator<Double> aNnz = null;
        if (!mc.nnzKnown()) {
          aNnz = sec.getSparkContext().accumulator(0L);
          in1 = in1.mapValues(new ComputeBinaryBlockNnzFunction(aNnz));
        }

        // save binary block rdd on hdfs
        in1.saveAsHadoopFile(
            fname, MatrixIndexes.class, MatrixBlock.class, SequenceFileOutputFormat.class);

        // the accumulator is only populated once the save action above has executed
        if (aNnz != null) {
          mc.setNonZeros(aNnz.value().longValue());
        }
      } else {
        // unsupported formats: binarycell (not externalized)
        throw new DMLRuntimeException("Unexpected data format: " + outFmt);
      }

      // write meta data file
      MapReduceTool.writeMetaDataFile(fname + ".mtd", ValueType.DOUBLE, mc, oi, formatProperties);
    } catch (IOException ex) {
      throw new DMLRuntimeException("Failed to process write instruction", ex);
    }
  }
  /**
   * Runs one min/max comparison test case and checks that DML and R produce matching output.
   *
   * <p>The script is chosen by which operands are scalars (TEST_NAME1..TEST_NAME4). Each operand is
   * generated as a {@code rows x cols} matrix when its data type is {@code MATRIX}, and as a 1x1
   * matrix otherwise. Both inputs are written with text-cell metadata. The DML result {@code C} is
   * then compared against the R result within {@code eps}. The previous runtime platform is
   * restored afterwards.
   *
   * @param type {@code OpType.MIN} passes flag 1 to both scripts; any other value passes 0
   * @param dtM1 data type of the first operand (A); selects the script and A's dimensions
   * @param dtM2 data type of the second operand (B); selects the script and B's dimensions
   * @param sparseM1 if true, A is generated with {@code sparsity2}, otherwise with {@code sparsity1}
   * @param sparseM2 if true, B is generated with {@code sparsity2}, otherwise with {@code sparsity1}
   * @param instType {@code ExecType.MR} runs on the HADOOP platform; anything else on HYBRID
   */
  private void runMinMaxComparisonTest(
      OpType type,
      DataType dtM1,
      DataType dtM2,
      boolean sparseM1,
      boolean sparseM2,
      ExecType instType) {
    // rtplatform for MR
    RUNTIME_PLATFORM platformOld = rtplatform;
    rtplatform = (instType == ExecType.MR) ? RUNTIME_PLATFORM.HADOOP : RUNTIME_PLATFORM.HYBRID;

    int minFlag = (type == OpType.MIN) ? 1 : 0;
    boolean s1Flag = (dtM1 == DataType.SCALAR);
    boolean s2Flag = (dtM2 == DataType.SCALAR);

    // select the test script by which operands are scalars
    String testName;
    if (s1Flag && s2Flag) {
      testName = TEST_NAME4;
    } else if (s1Flag) {
      testName = TEST_NAME2;
    } else if (s2Flag) {
      testName = TEST_NAME3;
    } else {
      testName = TEST_NAME1;
    }

    // Operand shapes and sparsities; computed once and used for both the cache key and the data.
    int mrows1 = (dtM1 == DataType.MATRIX) ? rows : 1;
    int mcols1 = (dtM1 == DataType.MATRIX) ? cols : 1;
    int mrows2 = (dtM2 == DataType.MATRIX) ? rows : 1;
    int mcols2 = (dtM2 == DataType.MATRIX) ? cols : 1;
    double sparsityLeft = sparseM1 ? sparsity2 : sparsity1;
    double sparsityRight = sparseM2 ? sparsity2 : sparsity1;

    String cacheDir = "";
    if (TEST_CACHE_ENABLED) {
      cacheDir =
          minFlag + "_" + mrows1 + "_" + mrows2 + "_" + sparsityLeft + "_" + sparsityRight + "/";
    }

    try {
      TestConfiguration config = getTestConfiguration(testName);
      loadTestConfiguration(config, cacheDir);

      // This is for running the junit test the new way, i.e., construct the arguments directly
      String home = SCRIPT_DIR + TEST_DIR;
      fullDMLScriptName = home + testName + ".dml";
      programArgs =
          new String[] {
            "-explain", "-args", input("A"), input("B"), Integer.toString(minFlag), output("C")
          };

      fullRScriptName = home + TEST_NAME_R + ".R";
      rCmd =
          "Rscript"
              + " "
              + fullRScriptName
              + " "
              + inputDir()
              + " "
              + minFlag
              + " "
              + expectedDir();

      // generate actual dataset
      double[][] A = getRandomMatrix(mrows1, mcols1, -1, 1, sparsityLeft, 7);
      writeInputMatrix("A", A, true);
      MatrixCharacteristics mc1 = new MatrixCharacteristics(mrows1, mcols1, 1000, 1000);
      MapReduceTool.writeMetaDataFile(
          input("A.mtd"), ValueType.DOUBLE, mc1, OutputInfo.TextCellOutputInfo);

      double[][] B = getRandomMatrix(mrows2, mcols2, -1, 1, sparsityRight, 3);
      writeInputMatrix("B", B, true);
      MatrixCharacteristics mc2 = new MatrixCharacteristics(mrows2, mcols2, 1000, 1000);
      MapReduceTool.writeMetaDataFile(
          input("B.mtd"), ValueType.DOUBLE, mc2, OutputInfo.TextCellOutputInfo);

      // run test
      runTest(true, false, null, -1);
      runRScript(true);

      // compare matrices
      HashMap<CellIndex, Double> dmlfile = readDMLMatrixFromHDFS("C");
      HashMap<CellIndex, Double> rfile = readRMatrixFromFS("C");
      TestUtils.compareMatrices(dmlfile, rfile, eps, "Stat-DML", "Stat-R");
    } catch (IOException e) {
      // the wrapped exception carries the cause; no need to also print it
      throw new RuntimeException(e);
    } finally {
      rtplatform = platformOld;
    }
  }