/**
 * Computes the overall accuracy of a multi-class model from its confusion matrix.
 *
 * <p>Accuracy is the sum of the confusion-matrix diagonal (entries whose predicted class equals
 * the actual class) divided by the sum of all entries. The result is multiplied by 100, formatted
 * with {@code MLConstants.DECIMAL_FORMAT} and parsed back, which limits its decimal places.
 *
 * @param multiclassMetrics multi-class metrics object
 * @return model accuracy as a percentage; 0.0 when the confusion-matrix total is not positive
 */
private Double getModelAccuracy(MulticlassMetrics multiclassMetrics) {
    // Use Locale.ROOT symbols so the formatted text always uses '.' as the decimal separator.
    // With the default locale (e.g. de_DE) DecimalFormat emits ',' and Double.parseDouble fails.
    DecimalFormat decimalFormat = new DecimalFormat(MLConstants.DECIMAL_FORMAT,
            java.text.DecimalFormatSymbols.getInstance(java.util.Locale.ROOT));
    // Materialize the flattened matrix once instead of rebuilding it on every loop iteration.
    double[] matrixValues = multiclassMetrics.confusionMatrix().toArray();
    int confusionMatrixSize = multiclassMetrics.confusionMatrix().numCols();
    long totalPopulation = arraySum(matrixValues);
    // Accumulate as double: an int accumulator updated with '+=' from double values truncates
    // fractional parts and saturates at Integer.MAX_VALUE.
    double confusionMatrixDiagonal = 0.0;
    for (int i = 0; i < confusionMatrixSize; i++) {
        // For a square matrix the (i, i) element is at i * size + i in the flat array,
        // whether the storage is row-major or column-major.
        confusionMatrixDiagonal += matrixValues[i * confusionMatrixSize + i];
    }
    double modelAccuracy = 0.0;
    if (totalPopulation > 0) {
        modelAccuracy = confusionMatrixDiagonal / totalPopulation;
    }
    return Double.parseDouble(decimalFormat.format(modelAccuracy * 100));
}
/** * This method returns multiclass confusion matrix for a given multiclass metric object * * @param multiclassMetrics Multiclass metric object */ private MulticlassConfusionMatrix getMulticlassConfusionMatrix( MulticlassMetrics multiclassMetrics, MLModel mlModel) { MulticlassConfusionMatrix multiclassConfusionMatrix = new MulticlassConfusionMatrix(); if (multiclassMetrics != null) { int size = multiclassMetrics.confusionMatrix().numCols(); double[] matrixArray = multiclassMetrics.confusionMatrix().toArray(); double[][] matrix = new double[size][size]; // set values of matrix into a 2D array for (int i = 0; i < size; i++) { for (int j = 0; j < size; j++) { matrix[i][j] = matrixArray[(j * size) + i]; } } multiclassConfusionMatrix.setMatrix(matrix); List<Map<String, Integer>> encodings = mlModel.getEncodings(); // decode only if encodings are available if (encodings != null) { // last index is response variable encoding Map<String, Integer> encodingMap = encodings.get(encodings.size() - 1); List<String> decodedLabels = new ArrayList<String>(); for (double label : multiclassMetrics.labels()) { Integer labelInt = (int) label; String decodedLabel = MLUtils.getKeyByValue(encodingMap, labelInt); if (decodedLabel != null) { decodedLabels.add(decodedLabel); } else { continue; } } multiclassConfusionMatrix.setLabels(decodedLabels); } else { List<String> labelList = toStringList(multiclassMetrics.labels()); multiclassConfusionMatrix.setLabels(labelList); } multiclassConfusionMatrix.setSize(size); } return multiclassConfusionMatrix; }
public static void main(String[] args) { // parse the arguments Params params = parse(args); SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample"); JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext jsql = new SQLContext(jsc); // configure the base classifier LogisticRegression classifier = new LogisticRegression() .setMaxIter(params.maxIter) .setTol(params.tol) .setFitIntercept(params.fitIntercept); if (params.regParam != null) { classifier.setRegParam(params.regParam); } if (params.elasticNetParam != null) { classifier.setElasticNetParam(params.elasticNetParam); } // instantiate the One Vs Rest Classifier OneVsRest ovr = new OneVsRest().setClassifier(classifier); String input = params.input; RDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), input); RDD<LabeledPoint> train; RDD<LabeledPoint> test; // compute the train/ test split: if testInput is not provided use part of input String testInput = params.testInput; if (testInput != null) { train = inputData; // compute the number of features in the training set. 
int numFeatures = inputData.first().features().size(); test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures); } else { double f = params.fracTest; RDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[] {1 - f, f}, 12345); train = tmp[0]; test = tmp[1]; } // train the multiclass model DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class); OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache()); // score the model on test data DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class); DataFrame predictions = ovrModel.transform(testDataFrame.cache()).select("prediction", "label"); // obtain metrics MulticlassMetrics metrics = new MulticlassMetrics(predictions); StructField predictionColSchema = predictions.schema().apply("prediction"); Integer numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get(); // compute the false positive rate per label StringBuilder results = new StringBuilder(); results.append("label\tfpr\n"); for (int label = 0; label < numClasses; label++) { results.append(label); results.append("\t"); results.append(metrics.falsePositiveRate((double) label)); results.append("\n"); } Matrix confusionMatrix = metrics.confusionMatrix(); // output the Confusion Matrix System.out.println("Confusion Matrix"); System.out.println(confusionMatrix); System.out.println(); System.out.println(results); jsc.stop(); }