/**
   * Computes the overall accuracy of a model, as a percentage, from the given multi-class metrics.
   *
   * <p>Accuracy is the sum of the confusion matrix diagonal divided by the sum of all confusion
   * matrix cells. The percentage is rounded by formatting it with {@code
   * MLConstants.DECIMAL_FORMAT} and parsing the formatted text back into a number.
   *
   * @param multiclassMetrics multi-class metrics object
   * @return the accuracy percentage, or zero when the confusion matrix holds no samples
   */
  private Double getModelAccuracy(MulticlassMetrics multiclassMetrics) {
    // Pin the format symbols to a locale-neutral set: with the default-locale symbols the
    // formatted text may use ',' as decimal separator, which Double.parseDouble rejects.
    DecimalFormat decimalFormat =
        new DecimalFormat(
            MLConstants.DECIMAL_FORMAT,
            java.text.DecimalFormatSymbols.getInstance(java.util.Locale.ROOT));

    // Fetch the flattened matrix once rather than calling confusionMatrix().toArray() on every
    // loop iteration.
    double[] confusionValues = multiclassMetrics.confusionMatrix().toArray();
    int confusionMatrixSize = multiclassMetrics.confusionMatrix().numCols();
    long totalPopulation = arraySum(confusionValues);

    // A confusion matrix is square; in an n x n matrix flattened to n * n values, cell (i, i)
    // sits at offset i * (n + 1) for both column-major and row-major layouts.
    // The sum is kept as a double so the cell values are not narrowed to int while accumulating.
    double confusionMatrixDiagonal = 0.0;
    for (int i = 0; i < confusionMatrixSize; i++) {
      confusionMatrixDiagonal += confusionValues[i * (confusionMatrixSize + 1)];
    }

    double modelAccuracy = 0.0;
    // guard against division by zero for an empty confusion matrix
    if (totalPopulation > 0) {
      modelAccuracy = confusionMatrixDiagonal / totalPopulation;
    }
    return Double.parseDouble(decimalFormat.format(modelAccuracy * 100));
  }
  /**
   * Builds a multiclass confusion matrix object for a given multiclass metric object.
   *
   * <p>The flattened confusion matrix is copied into a square 2D array. Labels are decoded
   * through the last entry of the model's encodings (the response variable encoding); when the
   * model or its encodings are missing or empty, the raw labels are used instead.
   *
   * @param multiclassMetrics Multiclass metric object; when {@code null}, a newly constructed
   *     confusion matrix object with nothing set on it is returned
   * @param mlModel model whose encodings are used to decode the labels
   * @return the confusion matrix object carrying the matrix, its labels and its size
   */
  private MulticlassConfusionMatrix getMulticlassConfusionMatrix(
      MulticlassMetrics multiclassMetrics, MLModel mlModel) {
    MulticlassConfusionMatrix multiclassConfusionMatrix = new MulticlassConfusionMatrix();
    if (multiclassMetrics == null) {
      return multiclassConfusionMatrix;
    }

    int size = multiclassMetrics.confusionMatrix().numCols();
    double[] matrixArray = multiclassMetrics.confusionMatrix().toArray();
    double[][] matrix = new double[size][size];
    // set values of matrix into a 2D array; cell (i, j) is read from offset (j * size) + i
    for (int i = 0; i < size; i++) {
      for (int j = 0; j < size; j++) {
        matrix[i][j] = matrixArray[(j * size) + i];
      }
    }
    multiclassConfusionMatrix.setMatrix(matrix);

    // A missing model is treated the same as a model without encodings.
    List<Map<String, Integer>> encodings = (mlModel == null) ? null : mlModel.getEncodings();
    // decode only if encodings are available; an empty list has no response variable entry
    if (encodings != null && !encodings.isEmpty()) {
      // last index is response variable encoding
      Map<String, Integer> encodingMap = encodings.get(encodings.size() - 1);
      List<String> decodedLabels = new ArrayList<String>();
      for (double label : multiclassMetrics.labels()) {
        Integer labelInt = (int) label;
        String decodedLabel = MLUtils.getKeyByValue(encodingMap, labelInt);
        // NOTE(review): labels without a decoded name are skipped, so the label list can end up
        // shorter than the matrix size - confirm that consumers tolerate this.
        if (decodedLabel != null) {
          decodedLabels.add(decodedLabel);
        }
      }
      multiclassConfusionMatrix.setLabels(decodedLabels);
    } else {
      List<String> labelList = toStringList(multiclassMetrics.labels());
      multiclassConfusionMatrix.setLabels(labelList);
    }

    multiclassConfusionMatrix.setSize(size);
    return multiclassConfusionMatrix;
  }
  /**
   * Trains a One-vs-Rest classifier backed by logistic regression on LibSVM formatted input,
   * scores it on test data, and prints the confusion matrix followed by the false positive rate
   * of each label.
   *
   * <p>When no separate test input is given, the input data is randomly split into training and
   * test sets using the configured test fraction.
   *
   * @param args command line arguments, interpreted by {@code parse}
   */
  public static void main(String[] args) {
    // parse the arguments
    Params params = parse(args);
    SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    try {
      SQLContext jsql = new SQLContext(jsc);

      // configure the base classifier
      LogisticRegression classifier =
          new LogisticRegression()
              .setMaxIter(params.maxIter)
              .setTol(params.tol)
              .setFitIntercept(params.fitIntercept);

      // optional parameters are only applied when they were supplied
      if (params.regParam != null) {
        classifier.setRegParam(params.regParam);
      }
      if (params.elasticNetParam != null) {
        classifier.setElasticNetParam(params.elasticNetParam);
      }

      // instantiate the One Vs Rest Classifier
      OneVsRest ovr = new OneVsRest().setClassifier(classifier);

      String input = params.input;
      RDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), input);
      RDD<LabeledPoint> train;
      RDD<LabeledPoint> test;

      // compute the train/ test split: if testInput is not provided use part of input
      String testInput = params.testInput;
      if (testInput != null) {
        train = inputData;
        // compute the number of features in the training set.
        int numFeatures = inputData.first().features().size();
        test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures);
      } else {
        double f = params.fracTest;
        // fixed seed keeps the split reproducible between runs
        RDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[] {1 - f, f}, 12345);
        train = tmp[0];
        test = tmp[1];
      }

      // train the multiclass model
      DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class);
      OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache());

      // score the model on test data
      DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class);
      DataFrame predictions =
          ovrModel.transform(testDataFrame.cache()).select("prediction", "label");

      // obtain metrics
      MulticlassMetrics metrics = new MulticlassMetrics(predictions);
      StructField predictionColSchema = predictions.schema().apply("prediction");
      // unboxed once here so the loop bound below is a plain int
      int numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get();

      // compute the false positive rate per label
      StringBuilder results = new StringBuilder();
      results.append("label\tfpr\n");
      for (int label = 0; label < numClasses; label++) {
        results.append(label);
        results.append("\t");
        results.append(metrics.falsePositiveRate((double) label));
        results.append("\n");
      }

      Matrix confusionMatrix = metrics.confusionMatrix();
      // output the Confusion Matrix
      System.out.println("Confusion Matrix");
      System.out.println(confusionMatrix);
      System.out.println();
      System.out.println(results);
    } finally {
      // stop the context even when loading, training or scoring fails
      jsc.stop();
    }
  }