/**
   * Trains a decision tree classifier on the training split, evaluates it on the testing split and
   * returns the resulting model summary.
   *
   * <p>The impurity, maximum depth and maximum number of bins are read from the workflow's
   * hyper-parameters. The trained model is stored on {@code mlModel}. {@code trainingData} is
   * unpersisted once training finishes; {@code testingData} is cached for the evaluation and
   * unpersisted afterwards. If anything fails, every RDD this method may still hold in the cache
   * is released (best effort) before the exception is reported.
   *
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID (not used by this method)
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model; receives the trained model
   * @param includedFeatures Features used by the model; the map values are reported as the
   *     summary's feature names
   * @param categoricalFeatureInfo Categorical feature information passed to the decision tree
   *     trainer
   * @return Summary of the trained model, including confusion matrix and accuracy
   * @throws MLModelBuilderException If training or evaluating the model fails for any reason
   */
  private ModelSummary buildDecisionTreeModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures,
      Map<Integer, Integer> categoricalFeatureInfo)
      throws MLModelBuilderException {
    // Declared outside the try block so the failure path can release it as well.
    JavaPairRDD<Double, Double> predictionsAndLabels = null;
    try {
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      DecisionTree decisionTree = new DecisionTree();
      DecisionTreeModel decisionTreeModel =
          decisionTree.train(
              trainingData,
              getNoOfClasses(mlModel),
              categoricalFeatureInfo,
              hyperParameters.get(MLConstants.IMPURITY),
              Integer.parseInt(hyperParameters.get(MLConstants.MAX_DEPTH)),
              Integer.parseInt(hyperParameters.get(MLConstants.MAX_BINS)));

      // training is done: remove the training split from the cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      predictionsAndLabels = decisionTree.test(decisionTreeModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary =
          SparkModelUtils.getClassClassificationModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLDecisionTreeModel(decisionTreeModel));

      classClassificationAndRegressionModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      classClassificationAndRegressionModelSummary.setAlgorithm(
          SUPERVISED_ALGORITHM.DECISION_TREE.toString());

      MulticlassMetrics multiclassMetrics =
          getMulticlassMetrics(sparkContext, predictionsAndLabels);

      predictionsAndLabels.unpersist();

      classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
      classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return classClassificationAndRegressionModelSummary;
    } catch (Exception e) {
      // The regular unpersist calls above are skipped when an exception interrupts the flow, so
      // release whatever may still be cached here. Unpersisting an RDD that is not cached is
      // harmless.
      try {
        trainingData.unpersist();
        testingData.unpersist();
        if (predictionsAndLabels != null) {
          predictionsAndLabels.unpersist();
        }
      } catch (Exception ignored) {
        // Best-effort clean-up: a failure here must not hide the original error reported below.
      }
      throw new MLModelBuilderException(
          "An error occurred while building decision tree model: " + e.getMessage(), e);
    }
  }
  /**
   * Trains a naive Bayes classifier on the training split, evaluates it on the testing split and
   * returns the resulting model summary.
   *
   * <p>The smoothing parameter (lambda) is read from the workflow's hyper-parameters. The trained
   * model is stored on {@code mlModel}. {@code trainingData} is unpersisted once training
   * finishes; {@code testingData} is cached for the evaluation and unpersisted afterwards. If
   * anything fails, every RDD this method may still hold in the cache is released (best effort)
   * before the exception is reported.
   *
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID (not used by this method)
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model; receives the trained model
   * @param includedFeatures Features used by the model; the map values are reported as the
   *     summary's feature names
   * @return Summary of the trained model, including confusion matrix and accuracy
   * @throws MLModelBuilderException If training or evaluating the model fails for any reason
   */
  private ModelSummary buildNaiveBayesModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures)
      throws MLModelBuilderException {
    // Declared outside the try block so the failure path can release it as well.
    JavaPairRDD<Double, Double> predictionsAndLabels = null;
    try {
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      NaiveBayesClassifier naiveBayesClassifier = new NaiveBayesClassifier();
      NaiveBayesModel naiveBayesModel =
          naiveBayesClassifier.train(
              trainingData, Double.parseDouble(hyperParameters.get(MLConstants.LAMBDA)));

      // training is done: remove the training split from the cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      predictionsAndLabels = naiveBayesClassifier.test(naiveBayesModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary =
          SparkModelUtils.getClassClassificationModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLClassificationModel(naiveBayesModel));

      classClassificationAndRegressionModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      classClassificationAndRegressionModelSummary.setAlgorithm(
          SUPERVISED_ALGORITHM.NAIVE_BAYES.toString());

      MulticlassMetrics multiclassMetrics =
          getMulticlassMetrics(sparkContext, predictionsAndLabels);

      predictionsAndLabels.unpersist();

      classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
      classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return classClassificationAndRegressionModelSummary;
    } catch (Exception e) {
      // The regular unpersist calls above are skipped when an exception interrupts the flow, so
      // release whatever may still be cached here. Unpersisting an RDD that is not cached is
      // harmless.
      try {
        trainingData.unpersist();
        testingData.unpersist();
        if (predictionsAndLabels != null) {
          predictionsAndLabels.unpersist();
        }
      } catch (Exception ignored) {
        // Best-effort clean-up: a failure here must not hide the original error reported below.
      }
      throw new MLModelBuilderException(
          "An error occurred while building naive bayes model: " + e.getMessage(), e);
    }
  }
  /**
   * Trains a lasso regression model on the training split, evaluates it on the testing split and
   * returns the resulting model summary.
   *
   * <p>The number of iterations, learning rate, regularization parameter and SGD data fraction
   * are read from the workflow's hyper-parameters. Model building is aborted when
   * {@code isValidWeights} rejects the learned weights. The trained model is stored on
   * {@code mlModel}. {@code trainingData} is unpersisted once training finishes;
   * {@code testingData} is cached for the evaluation and unpersisted afterwards. If anything
   * fails, every RDD this method may still hold in the cache is released (best effort) before the
   * exception is reported.
   *
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID (not used by this method)
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model; receives the trained model
   * @param includedFeatures Features used by the model; the map values are reported as the
   *     summary's feature names and paired with the learned weights as feature importance
   * @return Summary of the trained model, including feature importance and mean squared error
   * @throws MLModelBuilderException If the learned weights are invalid, or training or evaluating
   *     the model fails for any other reason
   */
  private ModelSummary buildLassoRegressionModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures)
      throws MLModelBuilderException {
    // Declared outside the try block so the failure path can release it as well.
    JavaRDD<Tuple2<Double, Double>> predictionsAndLabels = null;
    try {
      LassoRegression lassoRegression = new LassoRegression();
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      LassoModel lassoModel =
          lassoRegression.train(
              trainingData,
              Integer.parseInt(hyperParameters.get(MLConstants.ITERATIONS)),
              Double.parseDouble(hyperParameters.get(MLConstants.LEARNING_RATE)),
              Double.parseDouble(hyperParameters.get(MLConstants.REGULARIZATION_PARAMETER)),
              Double.parseDouble(hyperParameters.get(MLConstants.SGD_DATA_FRACTION)));

      // training is done: remove the training split from the cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      Vector weights = lassoModel.weights();
      if (!isValidWeights(weights)) {
        throw new MLModelBuilderException(
            "Weights of the model generated are null or infinity. [Weights] "
                + vectorToString(weights));
      }
      predictionsAndLabels = lassoRegression.test(lassoModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary regressionModelSummary =
          SparkModelUtils.generateRegressionModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLGeneralizedLinearModel(lassoModel));

      // reuse the weight vector fetched for the validity check above
      List<FeatureImportance> featureWeights =
          getFeatureWeights(includedFeatures, weights.toArray());
      regressionModelSummary.setFeatures(includedFeatures.values().toArray(new String[0]));
      regressionModelSummary.setAlgorithm(SUPERVISED_ALGORITHM.LASSO_REGRESSION.toString());
      regressionModelSummary.setFeatureImportance(featureWeights);

      RegressionMetrics regressionMetrics =
          getRegressionMetrics(sparkContext, predictionsAndLabels);

      predictionsAndLabels.unpersist();

      Double meanSquaredError = regressionMetrics.meanSquaredError();
      regressionModelSummary.setMeanSquaredError(meanSquaredError);
      regressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return regressionModelSummary;
    } catch (Exception e) {
      // The regular unpersist calls above are skipped when an exception interrupts the flow
      // (including the invalid-weights check), so release whatever may still be cached here.
      // Unpersisting an RDD that is not cached is harmless.
      try {
        trainingData.unpersist();
        testingData.unpersist();
        if (predictionsAndLabels != null) {
          predictionsAndLabels.unpersist();
        }
      } catch (Exception ignored) {
        // Best-effort clean-up: a failure here must not hide the original error reported below.
      }
      throw new MLModelBuilderException(
          "An error occurred while building lasso regression model: " + e.getMessage(), e);
    }
  }
  /**
   * Trains a support vector machine (SVM) model on the training split, evaluates it on the
   * testing split and returns the resulting model summary.
   *
   * <p>SVM is treated as a binary classifier: building is rejected up front when
   * {@code getNoOfClasses(mlModel)} reports more than two classes. The number of iterations,
   * regularization type and parameter, learning rate and SGD data fraction are read from the
   * workflow's hyper-parameters. The model is evaluated twice: once with its threshold in place
   * (confusion matrix and accuracy) and once after {@code clearThreshold()} (probabilistic
   * summary). The model stored on {@code mlModel} is the one whose threshold has been cleared.
   * If anything fails inside the training/evaluation section, the training and testing splits
   * are released from the cache (best effort) before the exception is reported.
   *
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID (not used by this method)
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model; receives the trained model
   * @param includedFeatures Features used by the model; the map values are reported as the
   *     summary's feature names and paired with the learned weights as feature importance
   * @return Summary of the trained model, including feature importance, confusion matrix and
   *     accuracy
   * @throws MLModelBuilderException If the response variable has more than two classes, the
   *     learned weights are invalid, or training or evaluating the model fails for any other
   *     reason
   */
  private ModelSummary buildSVMModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures)
      throws MLModelBuilderException {

    if (getNoOfClasses(mlModel) > 2) {
      throw new MLModelBuilderException(
          "A binary classification algorithm cannot have more than "
              + "two distinct values in response variable.");
    }

    try {
      SVM svm = new SVM();
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      SVMModel svmModel =
          svm.train(
              trainingData,
              Integer.parseInt(hyperParameters.get(MLConstants.ITERATIONS)),
              hyperParameters.get(MLConstants.REGULARIZATION_TYPE),
              Double.parseDouble(hyperParameters.get(MLConstants.REGULARIZATION_PARAMETER)),
              Double.parseDouble(hyperParameters.get(MLConstants.LEARNING_RATE)),
              Double.parseDouble(hyperParameters.get(MLConstants.SGD_DATA_FRACTION)));

      // training is done: remove the training split from the cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      Vector weights = svmModel.weights();
      if (!isValidWeights(weights)) {
        throw new MLModelBuilderException(
            "Weights of the model generated are null or infinity. [Weights] "
                + vectorToString(weights));
      }

      // getting scores and labels without clearing threshold to get confusion matrix
      JavaRDD<Tuple2<Object, Object>> scoresAndLabelsThresholded = svm.test(svmModel, testingData);
      MulticlassMetrics multiclassMetrics =
          new MulticlassMetrics(JavaRDD.toRDD(scoresAndLabelsThresholded));
      MulticlassConfusionMatrix multiclassConfusionMatrix =
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel);
      // Derive the accuracy while the model still has its threshold. scoresAndLabelsThresholded
      // is not cached, so a metric first evaluated after clearThreshold() could presumably be
      // recomputed from raw scores instead of class labels.
      // NOTE(review): this depends on how SVM.test captures the model - confirm.
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);

      // from here on the model emits raw scores instead of thresholded labels
      svmModel.clearThreshold();
      JavaRDD<Tuple2<Object, Object>> scoresAndLabels = svm.test(svmModel, testingData);
      ProbabilisticClassificationModelSummary probabilisticClassificationModelSummary =
          SparkModelUtils.generateProbabilisticClassificationModelSummary(
              sparkContext, testingData, scoresAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLClassificationModel(svmModel));

      // reuse the weight vector fetched for the validity check above
      List<FeatureImportance> featureWeights =
          getFeatureWeights(includedFeatures, weights.toArray());
      probabilisticClassificationModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      probabilisticClassificationModelSummary.setFeatureImportance(featureWeights);
      probabilisticClassificationModelSummary.setAlgorithm(SUPERVISED_ALGORITHM.SVM.toString());

      probabilisticClassificationModelSummary.setMulticlassConfusionMatrix(
          multiclassConfusionMatrix);
      probabilisticClassificationModelSummary.setModelAccuracy(modelAccuracy);
      probabilisticClassificationModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return probabilisticClassificationModelSummary;
    } catch (Exception e) {
      // The regular unpersist calls above are skipped when an exception interrupts the flow
      // (including the invalid-weights check), so release whatever may still be cached here.
      // Unpersisting an RDD that is not cached is harmless.
      try {
        trainingData.unpersist();
        testingData.unpersist();
      } catch (Exception ignored) {
        // Best-effort clean-up: a failure here must not hide the original error reported below.
      }
      throw new MLModelBuilderException(
          "An error occurred while building SVM model: " + e.getMessage(), e);
    }
  }
  /**
   * Builds a supervised machine learning model for the workflow held by the current context.
   *
   * <p>The method validates the response variable, splits the pre-processed dataset into training
   * and testing data using the workflow's train data fraction, populates a deployable
   * {@code MLModel}, delegates to the builder of the algorithm named by the workflow and finally
   * persists the resulting model summary through the database service.
   *
   * <p>The training split is cached here and normally released by the algorithm-specific builder
   * once training has finished. If building fails at any point, the training split is released
   * here as well (best effort) so that it does not stay cached.
   *
   * @return The deployable model populated with the trained model and its metadata
   * @throws MLModelBuilderException If the response variable is invalid for the selected
   *     algorithm, the algorithm name is not supported, or any step of the model building fails
   */
  public MLModel build() throws MLModelBuilderException {
    MLModelConfigurationContext context = getContext();
    DatabaseService databaseService = MLCoreServiceValueHolder.getInstance().getDatabaseService();
    MLModel mlModel = new MLModel();
    // Declared outside the try block so the failure path can release the cached training split.
    JavaRDD<LabeledPoint> trainingData = null;
    try {
      JavaSparkContext sparkContext = context.getSparkContext();
      Workflow workflow = context.getFacts();
      long modelId = context.getModelId();

      // Verify validity of response variable
      String typeOfResponseVariable =
          getTypeOfResponseVariable(workflow.getResponseVariable(), workflow.getFeatures());

      if (typeOfResponseVariable == null) {
        throw new MLModelBuilderException(
            "Type of response variable cannot be null for supervised learning " + "algorithms.");
      }

      // Stops model building if a categorical attribute is used with numerical prediction
      if (workflow.getAlgorithmClass().equals(AlgorithmType.NUMERICAL_PREDICTION.getValue())
          && typeOfResponseVariable.equals(FeatureType.CATEGORICAL)) {
        throw new MLModelBuilderException(
            "Categorical attribute "
                + workflow.getResponseVariable()
                + " cannot be used as the response variable of the Numerical Prediction algorithm: "
                + workflow.getAlgorithmName());
      }

      // generate train and test datasets by converting tokens to labeled points
      int responseIndex = context.getResponseIndex();
      SortedMap<Integer, String> includedFeatures =
          MLUtils.getIncludedFeaturesAfterReordering(
              workflow, context.getNewToOldIndicesList(), responseIndex);

      // gets the pre-processed dataset
      JavaRDD<LabeledPoint> labeledPoints = preProcess().cache();

      JavaRDD<LabeledPoint>[] dataSplit =
          labeledPoints.randomSplit(
              new double[] {workflow.getTrainDataFraction(), 1 - workflow.getTrainDataFraction()},
              MLConstants.RANDOM_SEED);

      // remove from cache
      // NOTE(review): no Spark action appears to run between cache() and unpersist() here
      // (randomSplit is a lazy transformation), so caching labeledPoints presumably has no
      // effect - confirm whether it should stay cached until both splits are materialized.
      labeledPoints.unpersist();

      // only the training split is cached here; the algorithm-specific builders cache the testing
      // split themselves once training has finished
      trainingData = dataSplit[0].cache();
      JavaRDD<LabeledPoint> testingData = dataSplit[1];
      // create a deployable MLModel object
      mlModel.setAlgorithmName(workflow.getAlgorithmName());
      mlModel.setAlgorithmClass(workflow.getAlgorithmClass());
      mlModel.setFeatures(workflow.getIncludedFeatures());
      mlModel.setResponseVariable(workflow.getResponseVariable());
      mlModel.setEncodings(context.getEncodings());
      mlModel.setNewToOldIndicesList(context.getNewToOldIndicesList());
      mlModel.setResponseIndex(responseIndex);

      ModelSummary summaryModel;
      Map<Integer, Integer> categoricalFeatureInfo;

      // build a machine learning model according to user selected algorithm
      // (valueOf itself throws IllegalArgumentException for a name that matches no enum constant)
      SUPERVISED_ALGORITHM supervisedAlgorithm =
          SUPERVISED_ALGORITHM.valueOf(workflow.getAlgorithmName());
      switch (supervisedAlgorithm) {
        case LOGISTIC_REGRESSION:
          summaryModel =
              buildLogisticRegressionModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures,
                  true);
          break;
        case LOGISTIC_REGRESSION_LBFGS:
          summaryModel =
              buildLogisticRegressionModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures,
                  false);
          break;
        case DECISION_TREE:
          categoricalFeatureInfo = getCategoricalFeatureInfo(context.getEncodings());
          summaryModel =
              buildDecisionTreeModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures,
                  categoricalFeatureInfo);
          break;
        case RANDOM_FOREST:
          categoricalFeatureInfo = getCategoricalFeatureInfo(context.getEncodings());
          summaryModel =
              buildRandomForestTreeModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures,
                  categoricalFeatureInfo);
          break;
        case SVM:
          summaryModel =
              buildSVMModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures);
          break;
        case NAIVE_BAYES:
          summaryModel =
              buildNaiveBayesModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures);
          break;
        case LINEAR_REGRESSION:
          summaryModel =
              buildLinearRegressionModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures);
          break;
        case RIDGE_REGRESSION:
          summaryModel =
              buildRidgeRegressionModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures);
          break;
        case LASSO_REGRESSION:
          summaryModel =
              buildLassoRegressionModel(
                  sparkContext,
                  modelId,
                  trainingData,
                  testingData,
                  workflow,
                  mlModel,
                  includedFeatures);
          break;
        default:
          // reached only for an enum constant that has no dedicated case above
          throw new AlgorithmNameException("Incorrect algorithm name");
      }

      // persist model summary
      databaseService.updateModelSummary(modelId, summaryModel);
      return mlModel;
    } catch (Exception e) {
      // The training split is cached by this method; make sure a failed build does not leave it
      // behind. Unpersisting an RDD that is not (or no longer) cached is harmless.
      if (trainingData != null) {
        try {
          trainingData.unpersist();
        } catch (Exception ignored) {
          // Best-effort clean-up: a failure here must not hide the original error reported below.
        }
      }
      throw new MLModelBuilderException(
          "An error occurred while building supervised machine learning model: " + e.getMessage(),
          e);
    }
  }