  /**
   * Builds a decision tree classification model from the given training data and evaluates it
   * against the held-out test set.
   *
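   * <p>A minimal sketch of the inputs this method reads (the {@link MLConstants} keys are the
   * ones consumed below; the concrete values and the {@code setHyperParameters} setter are
   * illustrative assumptions, not values taken from this codebase):
   *
   * <pre>{@code
   * Map<String, String> hyperParameters = new HashMap<>();
   * hyperParameters.put(MLConstants.IMPURITY, "gini"); // assumed: "gini" or "entropy"
   * hyperParameters.put(MLConstants.MAX_DEPTH, "5");   // parsed with Integer.parseInt
   * hyperParameters.put(MLConstants.MAX_BINS, "32");   // parsed with Integer.parseInt
   * workflow.setHyperParameters(hyperParameters);      // assumes Workflow has this setter
   *
   * // Spark MLlib convention: categorical feature index -> number of categories;
   * // an empty map means every feature is treated as continuous
   * Map<Integer, Integer> categoricalFeatureInfo = new HashMap<>();
   * categoricalFeatureInfo.put(0, 4); // assumed: feature 0 is categorical with 4 categories
   * }</pre>
   *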
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model
   * @param includedFeatures Ordered map from feature index to feature name for the features
   *     used in training
   * @param categoricalFeatureInfo Map from categorical feature index to its number of categories
   * @return Summary of the built model, including its accuracy and multiclass confusion matrix
   * @throws MLModelBuilderException If an error occurs while building the decision tree model
   */
  private ModelSummary buildDecisionTreeModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures,
      Map<Integer, Integer> categoricalFeatureInfo)
      throws MLModelBuilderException {
    try {
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      DecisionTree decisionTree = new DecisionTree();
      DecisionTreeModel decisionTreeModel =
          decisionTree.train(
              trainingData,
              getNoOfClasses(mlModel),
              categoricalFeatureInfo,
              hyperParameters.get(MLConstants.IMPURITY),
              Integer.parseInt(hyperParameters.get(MLConstants.MAX_DEPTH)),
              Integer.parseInt(hyperParameters.get(MLConstants.MAX_BINS)));

      // remove from cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

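      // predictionsAndLabels is consumed twice below (for the model summary and again for the
      // multiclass metrics), so it is cached until both have been computed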
      JavaPairRDD<Double, Double> predictionsAndLabels =
          decisionTree.test(decisionTreeModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary =
          SparkModelUtils.getClassClassificationModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLDecisionTreeModel(decisionTreeModel));

      classClassificationAndRegressionModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      classClassificationAndRegressionModelSummary.setAlgorithm(
          SUPERVISED_ALGORITHM.DECISION_TREE.toString());

      MulticlassMetrics multiclassMetrics =
          getMulticlassMetrics(sparkContext, predictionsAndLabels);

      predictionsAndLabels.unpersist();

      classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
      classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return classClassificationAndRegressionModelSummary;
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building decision tree model: " + e.getMessage(), e);
    }
  }

  /**
   * Builds a naive Bayes classification model from the given training data and evaluates it
   * against the held-out test set.
   *
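   * <p>A minimal sketch of the single hyperparameter this method reads ({@code
   * MLConstants.LAMBDA} is handed to the trainer as the smoothing parameter; the value and the
   * {@code setHyperParameters} setter are illustrative assumptions):
   *
   * <pre>{@code
   * Map<String, String> hyperParameters = new HashMap<>();
   * hyperParameters.put(MLConstants.LAMBDA, "1.0"); // parsed with Double.parseDouble
   * workflow.setHyperParameters(hyperParameters);   // assumes Workflow has this setter
   * }</pre>
   *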
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model
   * @param includedFeatures Ordered map from feature index to feature name for the features
   *     used in training
   * @return Summary of the built model, including its accuracy and multiclass confusion matrix
   * @throws MLModelBuilderException If an error occurs while building the naive Bayes model
   */
  private ModelSummary buildNaiveBayesModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures)
      throws MLModelBuilderException {
    try {
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      NaiveBayesClassifier naiveBayesClassifier = new NaiveBayesClassifier();
      NaiveBayesModel naiveBayesModel =
          naiveBayesClassifier.train(
              trainingData, Double.parseDouble(hyperParameters.get(MLConstants.LAMBDA)));

      // remove from cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      JavaPairRDD<Double, Double> predictionsAndLabels =
          naiveBayesClassifier.test(naiveBayesModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary =
          SparkModelUtils.getClassClassificationModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLClassificationModel(naiveBayesModel));

      classClassificationAndRegressionModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      classClassificationAndRegressionModelSummary.setAlgorithm(
          SUPERVISED_ALGORITHM.NAIVE_BAYES.toString());

      MulticlassMetrics multiclassMetrics =
          getMulticlassMetrics(sparkContext, predictionsAndLabels);

      predictionsAndLabels.unpersist();

      classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
      classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return classClassificationAndRegressionModelSummary;
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building naive bayes model: " + e.getMessage(), e);
    }
  }

  /**
   * Builds a lasso regression model from the given training data and evaluates it against the
   * held-out test set.
   *
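   * <p>A minimal sketch of the hyperparameters this method reads (the {@link MLConstants} keys
   * are the ones consumed below; the values and the {@code setHyperParameters} setter are
   * illustrative assumptions):
   *
   * <pre>{@code
   * Map<String, String> hyperParameters = new HashMap<>();
   * hyperParameters.put(MLConstants.ITERATIONS, "100");               // SGD iterations
   * hyperParameters.put(MLConstants.LEARNING_RATE, "0.01");           // SGD step size
   * hyperParameters.put(MLConstants.REGULARIZATION_PARAMETER, "0.1"); // L1 penalty weight
   * hyperParameters.put(MLConstants.SGD_DATA_FRACTION, "1.0");        // mini-batch fraction
   * workflow.setHyperParameters(hyperParameters);                     // assumed setter
   * }</pre>
   *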
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model
   * @param includedFeatures Ordered map from feature index to feature name for the features
   *     used in training
   * @return Summary of the built model, including its mean squared error and feature weights
   * @throws MLModelBuilderException If the trained weights are invalid, or if any other error
   *     occurs while building the lasso regression model
   */
  private ModelSummary buildLassoRegressionModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures)
      throws MLModelBuilderException {
    try {
      LassoRegression lassoRegression = new LassoRegression();
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      LassoModel lassoModel =
          lassoRegression.train(
              trainingData,
              Integer.parseInt(hyperParameters.get(MLConstants.ITERATIONS)),
              Double.parseDouble(hyperParameters.get(MLConstants.LEARNING_RATE)),
              Double.parseDouble(hyperParameters.get(MLConstants.REGULARIZATION_PARAMETER)),
              Double.parseDouble(hyperParameters.get(MLConstants.SGD_DATA_FRACTION)));

      // remove from cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      Vector weights = lassoModel.weights();
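      // SGD can diverge (for example, when the step size is too large), leaving NaN or
      // infinite entries in the weight vector; fail fast instead of keeping an unusable model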
      if (!isValidWeights(weights)) {
        throw new MLModelBuilderException(
            "Weights of the model generated are null or infinity. [Weights] "
                + vectorToString(weights));
      }
      JavaRDD<Tuple2<Double, Double>> predictionsAndLabels =
          lassoRegression.test(lassoModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary regressionModelSummary =
          SparkModelUtils.generateRegressionModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLGeneralizedLinearModel(lassoModel));

      List<FeatureImportance> featureWeights =
          getFeatureWeights(includedFeatures, weights.toArray());
      regressionModelSummary.setFeatures(includedFeatures.values().toArray(new String[0]));
      regressionModelSummary.setAlgorithm(SUPERVISED_ALGORITHM.LASSO_REGRESSION.toString());
      regressionModelSummary.setFeatureImportance(featureWeights);

      RegressionMetrics regressionMetrics =
          getRegressionMetrics(sparkContext, predictionsAndLabels);

      predictionsAndLabels.unpersist();

      Double meanSquaredError = regressionMetrics.meanSquaredError();
      regressionModelSummary.setMeanSquaredError(meanSquaredError);
      regressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return regressionModelSummary;
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building lasso regression model: " + e.getMessage(), e);
    }
  }

  /**
   * Builds a binary support vector machine (SVM) classification model from the given training
   * data and evaluates it against the held-out test set. Fails fast if the response variable
   * has more than two distinct classes.
   *
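   * <p>A minimal sketch of the hyperparameters this method reads (the {@link MLConstants} keys
   * are the ones consumed below; the values and the {@code setHyperParameters} setter are
   * illustrative assumptions):
   *
   * <pre>{@code
   * Map<String, String> hyperParameters = new HashMap<>();
   * hyperParameters.put(MLConstants.ITERATIONS, "100");               // SGD iterations
   * hyperParameters.put(MLConstants.REGULARIZATION_TYPE, "L2");       // assumed: "L1" or "L2"
   * hyperParameters.put(MLConstants.REGULARIZATION_PARAMETER, "0.1"); // penalty weight
   * hyperParameters.put(MLConstants.LEARNING_RATE, "0.01");           // SGD step size
   * hyperParameters.put(MLConstants.SGD_DATA_FRACTION, "1.0");        // mini-batch fraction
   * workflow.setHyperParameters(hyperParameters);                     // assumed setter
   * }</pre>
   *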
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model
   * @param includedFeatures Ordered map from feature index to feature name for the features
   *     used in training
   * @return Summary of the built model, including its accuracy, confusion matrix, and feature
   *     weights
   * @throws MLModelBuilderException If the response variable has more than two distinct
   *     classes, if the trained weights are invalid, or if any other error occurs while
   *     building the SVM model
   */
  private ModelSummary buildSVMModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures)
      throws MLModelBuilderException {

    if (getNoOfClasses(mlModel) > 2) {
      throw new MLModelBuilderException(
          "A binary classification algorithm cannot have more than "
              + "two distinct values in response variable.");
    }

    try {
      SVM svm = new SVM();
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      SVMModel svmModel =
          svm.train(
              trainingData,
              Integer.parseInt(hyperParameters.get(MLConstants.ITERATIONS)),
              hyperParameters.get(MLConstants.REGULARIZATION_TYPE),
              Double.parseDouble(hyperParameters.get(MLConstants.REGULARIZATION_PARAMETER)),
              Double.parseDouble(hyperParameters.get(MLConstants.LEARNING_RATE)),
              Double.parseDouble(hyperParameters.get(MLConstants.SGD_DATA_FRACTION)));

      // remove from cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      Vector weights = svmModel.weights();
      if (!isValidWeights(weights)) {
        throw new MLModelBuilderException(
            "Weights of the model generated are null or infinity. [Weights] "
                + vectorToString(weights));
      }

      // evaluate with the default threshold in place so that predictions are discrete 0/1
      // labels, which the confusion matrix below requires
      JavaRDD<Tuple2<Object, Object>> scoresAndLabelsThresholded = svm.test(svmModel, testingData);
      MulticlassMetrics multiclassMetrics =
          new MulticlassMetrics(JavaRDD.toRDD(scoresAndLabelsThresholded));
      MulticlassConfusionMatrix multiclassConfusionMatrix =
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel);

      svmModel.clearThreshold();
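      // with the threshold cleared, predict() returns raw scores instead of 0/1 labels; the
      // probabilistic summary needs raw scores for threshold-independent metrics such as ROC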
      JavaRDD<Tuple2<Object, Object>> scoresAndLabels = svm.test(svmModel, testingData);
      ProbabilisticClassificationModelSummary probabilisticClassificationModelSummary =
          SparkModelUtils.generateProbabilisticClassificationModelSummary(
              sparkContext, testingData, scoresAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLClassificationModel(svmModel));

      List<FeatureImportance> featureWeights =
          getFeatureWeights(includedFeatures, weights.toArray());
      probabilisticClassificationModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      probabilisticClassificationModelSummary.setFeatureImportance(featureWeights);
      probabilisticClassificationModelSummary.setAlgorithm(SUPERVISED_ALGORITHM.SVM.toString());

      probabilisticClassificationModelSummary.setMulticlassConfusionMatrix(
          multiclassConfusionMatrix);
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      probabilisticClassificationModelSummary.setModelAccuracy(modelAccuracy);
      probabilisticClassificationModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return probabilisticClassificationModelSummary;
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building SVM model: " + e.getMessage(), e);
    }
  }