/** * A utility method to generate class classification model summary * * @param predictionsAndLabels Predictions and actual labels * @return Class classification model summary */ public static ClassClassificationAndRegressionModelSummary getClassClassificationModelSummary( JavaSparkContext sparkContext, JavaRDD<LabeledPoint> testingData, JavaPairRDD<Double, Double> predictionsAndLabels) { ClassClassificationAndRegressionModelSummary classClassificationModelSummary = new ClassClassificationAndRegressionModelSummary(); // store predictions and actuals List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>(); for (Tuple2<Double, Double> scoreAndLabel : predictionsAndLabels.collect()) { PredictedVsActual predictedVsActual = new PredictedVsActual(); predictedVsActual.setPredicted(scoreAndLabel._1()); predictedVsActual.setActual(scoreAndLabel._2()); predictedVsActuals.add(predictedVsActual); } // create a list of feature values List<double[]> features = new ArrayList<double[]>(); for (LabeledPoint labeledPoint : testingData.collect()) { if (labeledPoint != null && labeledPoint.features() != null) { double[] rowFeatures = labeledPoint.features().toArray(); features.add(rowFeatures); } } // create a list of feature values with predicted vs. actuals List<TestResultDataPoint> testResultDataPoints = new ArrayList<TestResultDataPoint>(); for (int i = 0; i < features.size(); i++) { TestResultDataPoint testResultDataPoint = new TestResultDataPoint(); testResultDataPoint.setPredictedVsActual(predictedVsActuals.get(i)); testResultDataPoint.setFeatureValues(features.get(i)); testResultDataPoints.add(testResultDataPoint); } // covert List to JavaRDD JavaRDD<TestResultDataPoint> testResultDataPointsJavaRDD = sparkContext.parallelize(testResultDataPoints); // collect RDD as a sampled list List<TestResultDataPoint> testResultDataPointsSample; if (testResultDataPointsJavaRDD.count() > MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize()) { testResultDataPointsSample = testResultDataPointsJavaRDD.takeSample( true, MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize()); } else { testResultDataPointsSample = testResultDataPointsJavaRDD.collect(); } classClassificationModelSummary.setTestResultDataPointsSample(testResultDataPointsSample); classClassificationModelSummary.setPredictedVsActuals(predictedVsActuals); // calculate test error double error = 1.0 * predictionsAndLabels .filter( new Function<Tuple2<Double, Double>, Boolean>() { private static final long serialVersionUID = -3063364114286182333L; @Override public Boolean call(Tuple2<Double, Double> pl) { return !pl._1().equals(pl._2()); } }) .count() / predictionsAndLabels.count(); classClassificationModelSummary.setError(error); return classClassificationModelSummary; }
/** * A utility method to generate regression model summary * * @param predictionsAndLabels Tuple2 containing predicted and actual values * @return Regression model summary */ public static ClassClassificationAndRegressionModelSummary generateRegressionModelSummary( JavaSparkContext sparkContext, JavaRDD<LabeledPoint> testingData, JavaRDD<Tuple2<Double, Double>> predictionsAndLabels) { ClassClassificationAndRegressionModelSummary regressionModelSummary = new ClassClassificationAndRegressionModelSummary(); // store predictions and actuals List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>(); DecimalFormat decimalFormat = new DecimalFormat(MLConstants.DECIMAL_FORMAT); for (Tuple2<Double, Double> scoreAndLabel : predictionsAndLabels.collect()) { PredictedVsActual predictedVsActual = new PredictedVsActual(); predictedVsActual.setPredicted(Double.parseDouble(decimalFormat.format(scoreAndLabel._1()))); predictedVsActual.setActual(Double.parseDouble(decimalFormat.format(scoreAndLabel._2()))); predictedVsActuals.add(predictedVsActual); } // create a list of feature values List<double[]> features = new ArrayList<double[]>(); for (LabeledPoint labeledPoint : testingData.collect()) { if (labeledPoint != null && labeledPoint.features() != null) { double[] rowFeatures = labeledPoint.features().toArray(); features.add(rowFeatures); } } // create a list of feature values with predicted vs. actuals List<TestResultDataPoint> testResultDataPoints = new ArrayList<TestResultDataPoint>(); for (int i = 0; i < features.size(); i++) { TestResultDataPoint testResultDataPoint = new TestResultDataPoint(); testResultDataPoint.setPredictedVsActual(predictedVsActuals.get(i)); testResultDataPoint.setFeatureValues(features.get(i)); testResultDataPoints.add(testResultDataPoint); } // covert List to JavaRDD JavaRDD<TestResultDataPoint> testResultDataPointsJavaRDD = sparkContext.parallelize(testResultDataPoints); // collect RDD as a sampled list List<TestResultDataPoint> testResultDataPointsSample; if (testResultDataPointsJavaRDD.count() > MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize()) { testResultDataPointsSample = testResultDataPointsJavaRDD.takeSample( true, MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize()); } else { testResultDataPointsSample = testResultDataPointsJavaRDD.collect(); } regressionModelSummary.setTestResultDataPointsSample(testResultDataPointsSample); regressionModelSummary.setPredictedVsActuals(predictedVsActuals); // calculate mean squared error (MSE) double meanSquaredError = new JavaDoubleRDD( predictionsAndLabels .map( new Function<Tuple2<Double, Double>, Object>() { private static final long serialVersionUID = -162193633199074816L; public Object call(Tuple2<Double, Double> pair) { return Math.pow(pair._1() - pair._2(), 2.0); } }) .rdd()) .mean(); regressionModelSummary.setError(meanSquaredError); return regressionModelSummary; }