/**
   * This method builds a decision tree classification model.
   *
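   * <p>For reference, a minimal sketch of the Spark MLlib call the {@code DecisionTree}
   * wrapper below is assumed to delegate to; the literal hyperparameter values
   * ({@code "gini"}, depth 5, 32 bins) are illustrative placeholders only:
   *
   * <pre>{@code
   * DecisionTreeModel model =
   *     org.apache.spark.mllib.tree.DecisionTree.trainClassifier(
   *         trainingData, numClasses, categoricalFeatureInfo, "gini", 5, 32);
   * }</pre>
   *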
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model
   * @param includedFeatures Features included in the model, keyed by feature index
   * @param categoricalFeatureInfo Categorical feature indices mapped to their number of categories
   * @return Decision tree model summary
   * @throws MLModelBuilderException If an error occurs while building the decision tree model
   */
  private ModelSummary buildDecisionTreeModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures,
      Map<Integer, Integer> categoricalFeatureInfo)
      throws MLModelBuilderException {
    try {
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      DecisionTree decisionTree = new DecisionTree();
      DecisionTreeModel decisionTreeModel =
          decisionTree.train(
              trainingData,
              getNoOfClasses(mlModel),
              categoricalFeatureInfo,
              hyperParameters.get(MLConstants.IMPURITY),
              Integer.parseInt(hyperParameters.get(MLConstants.MAX_DEPTH)),
              Integer.parseInt(hyperParameters.get(MLConstants.MAX_BINS)));

      // remove from cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      JavaPairRDD<Double, Double> predictionsAndLabels =
          decisionTree.test(decisionTreeModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary =
          SparkModelUtils.getClassClassificationModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLDecisionTreeModel(decisionTreeModel));

      classClassificationAndRegressionModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      classClassificationAndRegressionModelSummary.setAlgorithm(
          SUPERVISED_ALGORITHM.DECISION_TREE.toString());

      MulticlassMetrics multiclassMetrics =
          getMulticlassMetrics(sparkContext, predictionsAndLabels);

      predictionsAndLabels.unpersist();

      classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
      classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return classClassificationAndRegressionModelSummary;
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building decision tree model: " + e.getMessage(), e);
    }
  }
  /**
   * This method builds a Naive Bayes model.
   *
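   * <p>For reference, a minimal sketch of the Spark MLlib call the
   * {@code NaiveBayesClassifier} wrapper below is assumed to delegate to; the
   * smoothing value {@code 1.0} is an illustrative placeholder for the
   * {@code LAMBDA} hyperparameter:
   *
   * <pre>{@code
   * NaiveBayesModel model =
   *     org.apache.spark.mllib.classification.NaiveBayes.train(trainingData.rdd(), 1.0);
   * }</pre>
   *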
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model
   * @param includedFeatures Features included in the model, keyed by feature index
   * @return Naive Bayes model summary
   * @throws MLModelBuilderException If an error occurs while building the Naive Bayes model
   */
  private ModelSummary buildNaiveBayesModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures)
      throws MLModelBuilderException {
    try {
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      NaiveBayesClassifier naiveBayesClassifier = new NaiveBayesClassifier();
      NaiveBayesModel naiveBayesModel =
          naiveBayesClassifier.train(
              trainingData, Double.parseDouble(hyperParameters.get(MLConstants.LAMBDA)));

      // remove from cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      JavaPairRDD<Double, Double> predictionsAndLabels =
          naiveBayesClassifier.test(naiveBayesModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary =
          SparkModelUtils.getClassClassificationModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLClassificationModel(naiveBayesModel));

      classClassificationAndRegressionModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      classClassificationAndRegressionModelSummary.setAlgorithm(
          SUPERVISED_ALGORITHM.NAIVE_BAYES.toString());

      MulticlassMetrics multiclassMetrics =
          getMulticlassMetrics(sparkContext, predictionsAndLabels);

      predictionsAndLabels.unpersist();

      classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
      classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return classClassificationAndRegressionModelSummary;
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building naive bayes model: " + e.getMessage(), e);
    }
  }
  /**
   * This method builds a Lasso regression model.
   *
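   * <p>For reference, a minimal sketch of the Spark MLlib call the
   * {@code LassoRegression} wrapper below is assumed to delegate to; the literal
   * values (100 iterations, step size 0.1, regularization 0.01, full mini-batch
   * fraction) are illustrative placeholders for the corresponding hyperparameters:
   *
   * <pre>{@code
   * LassoModel model =
   *     org.apache.spark.mllib.regression.LassoWithSGD.train(
   *         trainingData.rdd(), 100, 0.1, 0.01, 1.0);
   * }</pre>
   *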
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model
   * @param includedFeatures Features included in the model, keyed by feature index
   * @return Lasso regression model summary
   * @throws MLModelBuilderException If an error occurs while building the Lasso regression model
   */
  private ModelSummary buildLassoRegressionModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures)
      throws MLModelBuilderException {
    try {
      LassoRegression lassoRegression = new LassoRegression();
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      LassoModel lassoModel =
          lassoRegression.train(
              trainingData,
              Integer.parseInt(hyperParameters.get(MLConstants.ITERATIONS)),
              Double.parseDouble(hyperParameters.get(MLConstants.LEARNING_RATE)),
              Double.parseDouble(hyperParameters.get(MLConstants.REGULARIZATION_PARAMETER)),
              Double.parseDouble(hyperParameters.get(MLConstants.SGD_DATA_FRACTION)));

      // remove from cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      Vector weights = lassoModel.weights();
      if (!isValidWeights(weights)) {
        throw new MLModelBuilderException(
            "Weights of the model generated are null or infinity. [Weights] "
                + vectorToString(weights));
      }
      JavaRDD<Tuple2<Double, Double>> predictionsAndLabels =
          lassoRegression.test(lassoModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary regressionModelSummary =
          SparkModelUtils.generateRegressionModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLGeneralizedLinearModel(lassoModel));

      List<FeatureImportance> featureWeights =
          getFeatureWeights(includedFeatures, lassoModel.weights().toArray());
      regressionModelSummary.setFeatures(includedFeatures.values().toArray(new String[0]));
      regressionModelSummary.setAlgorithm(SUPERVISED_ALGORITHM.LASSO_REGRESSION.toString());
      regressionModelSummary.setFeatureImportance(featureWeights);

      RegressionMetrics regressionMetrics =
          getRegressionMetrics(sparkContext, predictionsAndLabels);

      predictionsAndLabels.unpersist();

      Double meanSquaredError = regressionMetrics.meanSquaredError();
      regressionModelSummary.setMeanSquaredError(meanSquaredError);
      regressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return regressionModelSummary;
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building lasso regression model: " + e.getMessage(), e);
    }
  }
  /**
   * A utility method to generate a class classification model summary.
   *
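   * <p>The error this method reports is the misclassification rate over the test set. An
   * equivalent formulation of the anonymous-class filter below, assuming Java 8 lambdas
   * are available, would be:
   *
   * <pre>{@code
   * double error = (double) predictionsAndLabels.filter(pl -> !pl._1().equals(pl._2())).count()
   *     / predictionsAndLabels.count();
   * }</pre>
   *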
   * @param sparkContext JavaSparkContext initialized with the application
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param predictionsAndLabels Predictions and actual labels
   * @return Class classification model summary
   */
  public static ClassClassificationAndRegressionModelSummary getClassClassificationModelSummary(
      JavaSparkContext sparkContext,
      JavaRDD<LabeledPoint> testingData,
      JavaPairRDD<Double, Double> predictionsAndLabels) {
    ClassClassificationAndRegressionModelSummary classClassificationModelSummary =
        new ClassClassificationAndRegressionModelSummary();
    // store predictions and actuals
    List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>();
    for (Tuple2<Double, Double> scoreAndLabel : predictionsAndLabels.collect()) {
      PredictedVsActual predictedVsActual = new PredictedVsActual();
      predictedVsActual.setPredicted(scoreAndLabel._1());
      predictedVsActual.setActual(scoreAndLabel._2());
      predictedVsActuals.add(predictedVsActual);
    }
    // create a list of feature values
    List<double[]> features = new ArrayList<double[]>();
    for (LabeledPoint labeledPoint : testingData.collect()) {
      if (labeledPoint != null && labeledPoint.features() != null) {
        double[] rowFeatures = labeledPoint.features().toArray();
        features.add(rowFeatures);
      }
    }
    // create a list of feature values with predicted vs. actuals
    List<TestResultDataPoint> testResultDataPoints = new ArrayList<TestResultDataPoint>();
    for (int i = 0; i < features.size(); i++) {
      TestResultDataPoint testResultDataPoint = new TestResultDataPoint();
      testResultDataPoint.setPredictedVsActual(predictedVsActuals.get(i));
      testResultDataPoint.setFeatureValues(features.get(i));
      testResultDataPoints.add(testResultDataPoint);
    }
    // convert the List to a JavaRDD
    JavaRDD<TestResultDataPoint> testResultDataPointsJavaRDD =
        sparkContext.parallelize(testResultDataPoints);
    // collect the RDD as a list, down-sampling with replacement if it exceeds the configured sample size
    List<TestResultDataPoint> testResultDataPointsSample;
    if (testResultDataPointsJavaRDD.count()
        > MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize()) {
      testResultDataPointsSample =
          testResultDataPointsJavaRDD.takeSample(
              true,
              MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize());
    } else {
      testResultDataPointsSample = testResultDataPointsJavaRDD.collect();
    }
    classClassificationModelSummary.setTestResultDataPointsSample(testResultDataPointsSample);
    classClassificationModelSummary.setPredictedVsActuals(predictedVsActuals);
    // calculate test error: the fraction of misclassified instances (1.0 * forces floating-point division)
    double error =
        1.0
            * predictionsAndLabels
                .filter(
                    new Function<Tuple2<Double, Double>, Boolean>() {
                      private static final long serialVersionUID = -3063364114286182333L;

                      @Override
                      public Boolean call(Tuple2<Double, Double> pl) {
                        return !pl._1().equals(pl._2());
                      }
                    })
                .count()
            / predictionsAndLabels.count();
    classClassificationModelSummary.setError(error);
    return classClassificationModelSummary;
  }
  /**
   * A utility method to generate a regression model summary.
   *
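   * <p>The error this method reports is the mean squared error,
   * MSE = (1/n) * sum((predicted - actual)^2). An equivalent formulation of the
   * map-and-mean computation below, assuming Java 8 lambdas are available, would be:
   *
   * <pre>{@code
   * double mse = predictionsAndLabels
   *     .mapToDouble(pair -> Math.pow(pair._1() - pair._2(), 2.0))
   *     .mean();
   * }</pre>
   *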
   * @param sparkContext JavaSparkContext initialized with the application
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param predictionsAndLabels Tuple2 containing predicted and actual values
   * @return Regression model summary
   */
  public static ClassClassificationAndRegressionModelSummary generateRegressionModelSummary(
      JavaSparkContext sparkContext,
      JavaRDD<LabeledPoint> testingData,
      JavaRDD<Tuple2<Double, Double>> predictionsAndLabels) {
    ClassClassificationAndRegressionModelSummary regressionModelSummary =
        new ClassClassificationAndRegressionModelSummary();
    // store predictions and actuals
    List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>();
    DecimalFormat decimalFormat = new DecimalFormat(MLConstants.DECIMAL_FORMAT);
    for (Tuple2<Double, Double> scoreAndLabel : predictionsAndLabels.collect()) {
      PredictedVsActual predictedVsActual = new PredictedVsActual();
      predictedVsActual.setPredicted(Double.parseDouble(decimalFormat.format(scoreAndLabel._1())));
      predictedVsActual.setActual(Double.parseDouble(decimalFormat.format(scoreAndLabel._2())));
      predictedVsActuals.add(predictedVsActual);
    }
    // create a list of feature values
    List<double[]> features = new ArrayList<double[]>();
    for (LabeledPoint labeledPoint : testingData.collect()) {
      if (labeledPoint != null && labeledPoint.features() != null) {
        double[] rowFeatures = labeledPoint.features().toArray();
        features.add(rowFeatures);
      }
    }
    // create a list of feature values with predicted vs. actuals
    List<TestResultDataPoint> testResultDataPoints = new ArrayList<TestResultDataPoint>();
    for (int i = 0; i < features.size(); i++) {
      TestResultDataPoint testResultDataPoint = new TestResultDataPoint();
      testResultDataPoint.setPredictedVsActual(predictedVsActuals.get(i));
      testResultDataPoint.setFeatureValues(features.get(i));
      testResultDataPoints.add(testResultDataPoint);
    }
    // convert the List to a JavaRDD
    JavaRDD<TestResultDataPoint> testResultDataPointsJavaRDD =
        sparkContext.parallelize(testResultDataPoints);
    // collect the RDD as a list, down-sampling with replacement if it exceeds the configured sample size
    List<TestResultDataPoint> testResultDataPointsSample;
    if (testResultDataPointsJavaRDD.count()
        > MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize()) {
      testResultDataPointsSample =
          testResultDataPointsJavaRDD.takeSample(
              true,
              MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize());
    } else {
      testResultDataPointsSample = testResultDataPointsJavaRDD.collect();
    }
    regressionModelSummary.setTestResultDataPointsSample(testResultDataPointsSample);
    regressionModelSummary.setPredictedVsActuals(predictedVsActuals);
    // calculate mean squared error (MSE)
    double meanSquaredError =
        new JavaDoubleRDD(
                predictionsAndLabels
                    .map(
                        new Function<Tuple2<Double, Double>, Object>() {
                          private static final long serialVersionUID = -162193633199074816L;

                          @Override
                          public Object call(Tuple2<Double, Double> pair) {
                            return Math.pow(pair._1() - pair._2(), 2.0);
                          }
                        })
                    .rdd())
            .mean();
    regressionModelSummary.setError(meanSquaredError);
    return regressionModelSummary;
  }