/**
   * A utility method to generate class classification model summary.
   *
   * <p>Each prediction is paired with the test row at the same position, so this assumes
   * {@code predictionsAndLabels} is in the same order as {@code testingData} (presumably it was
   * derived from it by a map) — confirm at the call sites.
   *
   * @param sparkContext Spark context used to draw the random sample of test result data points
   * @param testingData Labeled test rows whose feature values are attached to the predictions
   * @param predictionsAndLabels Predictions and actual labels, as (predicted, actual) pairs
   * @return Class classification model summary holding every predicted-vs-actual pair, a sample
   *     of at most the configured sample size of test result data points, and the
   *     misclassification rate as the error (NaN when there are no predictions)
   */
  public static ClassClassificationAndRegressionModelSummary getClassClassificationModelSummary(
      JavaSparkContext sparkContext,
      JavaRDD<LabeledPoint> testingData,
      JavaPairRDD<Double, Double> predictionsAndLabels) {
    ClassClassificationAndRegressionModelSummary classClassificationModelSummary =
        new ClassClassificationAndRegressionModelSummary();
    // Collect the predictions once and derive both the stored pairs and the error from this
    // single snapshot, instead of re-evaluating the RDD in two more Spark jobs.
    List<Tuple2<Double, Double>> scoresAndLabels = predictionsAndLabels.collect();
    List<PredictedVsActual> predictedVsActuals =
        new ArrayList<PredictedVsActual>(scoresAndLabels.size());
    long misclassifiedCount = 0;
    for (Tuple2<Double, Double> scoreAndLabel : scoresAndLabels) {
      PredictedVsActual predictedVsActual = new PredictedVsActual();
      predictedVsActual.setPredicted(scoreAndLabel._1());
      predictedVsActual.setActual(scoreAndLabel._2());
      predictedVsActuals.add(predictedVsActual);
      if (!scoreAndLabel._1().equals(scoreAndLabel._2())) {
        misclassifiedCount++;
      }
    }
    // Pair every usable test row with the prediction at the SAME index. Skipping a null row (or a
    // row without features) must not shift later rows onto the wrong predictions, and a size
    // mismatch between the two lists must not throw IndexOutOfBoundsException.
    List<LabeledPoint> testPoints = testingData.collect();
    int pairableRows = Math.min(testPoints.size(), predictedVsActuals.size());
    List<TestResultDataPoint> testResultDataPoints =
        new ArrayList<TestResultDataPoint>(pairableRows);
    for (int i = 0; i < pairableRows; i++) {
      LabeledPoint labeledPoint = testPoints.get(i);
      if (labeledPoint == null || labeledPoint.features() == null) {
        continue;
      }
      TestResultDataPoint testResultDataPoint = new TestResultDataPoint();
      testResultDataPoint.setPredictedVsActual(predictedVsActuals.get(i));
      testResultDataPoint.setFeatureValues(labeledPoint.features().toArray());
      testResultDataPoints.add(testResultDataPoint);
    }
    // Keep at most the configured number of data points in the summary.
    int sampleSize =
        MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize();
    List<TestResultDataPoint> testResultDataPointsSample;
    if (testResultDataPoints.size() > sampleSize) {
      // Sample without replacement so the same test row cannot appear in the sample twice.
      testResultDataPointsSample =
          sparkContext.parallelize(testResultDataPoints).takeSample(false, sampleSize);
    } else {
      // Everything fits; the list is already on the driver, so no Spark job is needed.
      testResultDataPointsSample = testResultDataPoints;
    }
    classClassificationModelSummary.setTestResultDataPointsSample(testResultDataPointsSample);
    classClassificationModelSummary.setPredictedVsActuals(predictedVsActuals);
    // Misclassification rate: fraction of pairs whose predicted label differs from the actual
    // one. With no predictions this is 0.0 / 0 = NaN, the same result as before.
    double error = (double) misclassifiedCount / scoresAndLabels.size();
    classClassificationModelSummary.setError(error);
    return classClassificationModelSummary;
  }
  /**
   * A utility method to generate regression model summary.
   *
   * <p>Each prediction is paired with the test row at the same position, so this assumes
   * {@code predictionsAndLabels} is in the same order as {@code testingData} (presumably it was
   * derived from it by a map) — confirm at the call sites.
   *
   * @param sparkContext Spark context used to draw the random sample of test result data points
   * @param testingData Labeled test rows whose feature values are attached to the predictions
   * @param predictionsAndLabels Tuple2 containing predicted and actual values
   * @return Regression model summary holding every predicted-vs-actual pair (finite values
   *     rounded with {@code MLConstants.DECIMAL_FORMAT}), a sample of at most the configured
   *     sample size of test result data points, and the mean squared error of the unrounded
   *     values as the error (NaN when there are no predictions)
   */
  public static ClassClassificationAndRegressionModelSummary generateRegressionModelSummary(
      JavaSparkContext sparkContext,
      JavaRDD<LabeledPoint> testingData,
      JavaRDD<Tuple2<Double, Double>> predictionsAndLabels) {
    ClassClassificationAndRegressionModelSummary regressionModelSummary =
        new ClassClassificationAndRegressionModelSummary();
    // The rounded text is parsed back with Double.parseDouble, which only accepts '.' as the
    // decimal separator and no grouping separators. A default-locale DecimalFormat emits e.g.
    // "1,5" in comma-decimal locales and makes parseDouble throw, so pin the symbols.
    DecimalFormat decimalFormat =
        new DecimalFormat(
            MLConstants.DECIMAL_FORMAT,
            java.text.DecimalFormatSymbols.getInstance(java.util.Locale.ROOT));
    decimalFormat.setGroupingUsed(false);
    // Collect once; the stored pairs and the MSE are both derived from this single snapshot
    // instead of re-evaluating the RDD in another Spark job.
    List<Tuple2<Double, Double>> scoresAndLabels = predictionsAndLabels.collect();
    List<PredictedVsActual> predictedVsActuals =
        new ArrayList<PredictedVsActual>(scoresAndLabels.size());
    double sumOfSquaredErrors = 0.0;
    for (Tuple2<Double, Double> scoreAndLabel : scoresAndLabels) {
      double predicted = scoreAndLabel._1();
      double actual = scoreAndLabel._2();
      PredictedVsActual predictedVsActual = new PredictedVsActual();
      predictedVsActual.setPredicted(roundWithDecimalFormat(decimalFormat, predicted));
      predictedVsActual.setActual(roundWithDecimalFormat(decimalFormat, actual));
      predictedVsActuals.add(predictedVsActual);
      double residual = predicted - actual;
      sumOfSquaredErrors += residual * residual;
    }
    // Pair every usable test row with the prediction at the SAME index. Skipping a null row (or a
    // row without features) must not shift later rows onto the wrong predictions, and a size
    // mismatch between the two lists must not throw IndexOutOfBoundsException.
    List<LabeledPoint> testPoints = testingData.collect();
    int pairableRows = Math.min(testPoints.size(), predictedVsActuals.size());
    List<TestResultDataPoint> testResultDataPoints =
        new ArrayList<TestResultDataPoint>(pairableRows);
    for (int i = 0; i < pairableRows; i++) {
      LabeledPoint labeledPoint = testPoints.get(i);
      if (labeledPoint == null || labeledPoint.features() == null) {
        continue;
      }
      TestResultDataPoint testResultDataPoint = new TestResultDataPoint();
      testResultDataPoint.setPredictedVsActual(predictedVsActuals.get(i));
      testResultDataPoint.setFeatureValues(labeledPoint.features().toArray());
      testResultDataPoints.add(testResultDataPoint);
    }
    // Keep at most the configured number of data points in the summary.
    int sampleSize =
        MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize();
    List<TestResultDataPoint> testResultDataPointsSample;
    if (testResultDataPoints.size() > sampleSize) {
      // Sample without replacement so the same test row cannot appear in the sample twice.
      testResultDataPointsSample =
          sparkContext.parallelize(testResultDataPoints).takeSample(false, sampleSize);
    } else {
      // Everything fits; the list is already on the driver, so no Spark job is needed.
      testResultDataPointsSample = testResultDataPoints;
    }
    regressionModelSummary.setTestResultDataPointsSample(testResultDataPointsSample);
    regressionModelSummary.setPredictedVsActuals(predictedVsActuals);
    // Mean squared error over the unrounded values; 0.0 / 0 = NaN when there are no predictions.
    double meanSquaredError = sumOfSquaredErrors / scoresAndLabels.size();
    regressionModelSummary.setError(meanSquaredError);
    return regressionModelSummary;
  }

  /**
   * Rounds a value by formatting it with the given format and parsing the text back.
   *
   * <p>Non-finite values are returned unchanged, because their formatted symbols (such as the
   * infinity sign) are not accepted by {@link Double#parseDouble(String)}.
   *
   * @param format format whose pattern defines the rounding; must use '.' as decimal separator
   * @param value value to round
   * @return the rounded value, or {@code value} itself when it is NaN or infinite
   */
  private static double roundWithDecimalFormat(DecimalFormat format, double value) {
    if (Double.isNaN(value) || Double.isInfinite(value)) {
      return value;
    }
    return Double.parseDouble(format.format(value));
  }