/**
 * Builds a decision tree model, evaluates it against the testing data and returns its summary.
 * <p>
 * The RDDs this method caches, and the training data it is handed, are released in a {@code finally}
 * block. This cleans up the Spark cache even when training or evaluation fails. The predictions stay
 * cached until every metric has been read.
 *
 * @param sparkContext           JavaSparkContext initialized with the application
 * @param modelID                Model ID (not used by this method)
 * @param trainingData           Training data as a JavaRDD of LabeledPoints
 * @param testingData            Testing data as a JavaRDD of LabeledPoints
 * @param workflow               Machine learning workflow supplying the hyper-parameters
 *                               ({@code IMPURITY}, {@code MAX_DEPTH}, {@code MAX_BINS}) and the dataset version
 * @param mlModel                Deployable machine learning model; receives the trained decision tree
 * @param includedFeatures       Features used by the model; their names are stored in the summary
 * @param categoricalFeatureInfo Categorical feature information passed to the decision tree trainer
 * @return Class classification model summary of the trained model
 * @throws MLModelBuilderException if training or evaluation fails for any reason
 */
private ModelSummary buildDecisionTreeModel(JavaSparkContext sparkContext, long modelID,
        JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testingData, Workflow workflow,
        MLModel mlModel, SortedMap<Integer, String> includedFeatures,
        Map<Integer, Integer> categoricalFeatureInfo) throws MLModelBuilderException {
    JavaPairRDD<Double, Double> predictionsAndLabels = null;
    try {
        Map<String, String> hyperParameters = workflow.getHyperParameters();
        DecisionTree decisionTree = new DecisionTree();
        DecisionTreeModel decisionTreeModel = decisionTree.train(trainingData, getNoOfClasses(mlModel),
                categoricalFeatureInfo, hyperParameters.get(MLConstants.IMPURITY),
                Integer.parseInt(hyperParameters.get(MLConstants.MAX_DEPTH)),
                Integer.parseInt(hyperParameters.get(MLConstants.MAX_BINS)));
        // training is complete: release the training data as early as possible
        trainingData.unpersist();
        // the testing data is traversed more than once during evaluation, so keep it cached
        testingData.cache();
        predictionsAndLabels = decisionTree.test(decisionTreeModel, testingData).cache();

        ClassClassificationAndRegressionModelSummary modelSummary =
                SparkModelUtils.getClassClassificationModelSummary(sparkContext, testingData,
                        predictionsAndLabels);
        mlModel.setModel(new MLDecisionTreeModel(decisionTreeModel));
        modelSummary.setFeatures(includedFeatures.values().toArray(new String[0]));
        modelSummary.setAlgorithm(SUPERVISED_ALGORITHM.DECISION_TREE.toString());

        // read every metric while the predictions are still cached: the metrics object may evaluate lazily
        MulticlassMetrics multiclassMetrics = getMulticlassMetrics(sparkContext, predictionsAndLabels);
        modelSummary.setMulticlassConfusionMatrix(getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
        Double modelAccuracy = getModelAccuracy(multiclassMetrics);
        modelSummary.setModelAccuracy(modelAccuracy);
        modelSummary.setDatasetVersion(workflow.getDatasetVersion());
        return modelSummary;
    } catch (Exception e) {
        throw new MLModelBuilderException(
                "An error occurred while building decision tree model: " + e.getMessage(), e);
    } finally {
        // Release everything this method cached, on success and failure alike.
        // unpersist() on an RDD that is no longer cached is expected to be a no-op in Spark.
        trainingData.unpersist();
        testingData.unpersist();
        if (predictionsAndLabels != null) {
            predictionsAndLabels.unpersist();
        }
    }
}
/**
 * Builds a naive Bayes model, evaluates it against the testing data and returns its summary.
 * <p>
 * The RDDs this method caches, and the training data it is handed, are released in a {@code finally}
 * block. This cleans up the Spark cache even when training or evaluation fails. The predictions stay
 * cached until every metric has been read.
 *
 * @param sparkContext     JavaSparkContext initialized with the application
 * @param modelID          Model ID (not used by this method)
 * @param trainingData     Training data as a JavaRDD of LabeledPoints
 * @param testingData      Testing data as a JavaRDD of LabeledPoints
 * @param workflow         Machine learning workflow supplying the {@code LAMBDA} hyper-parameter and the
 *                         dataset version
 * @param mlModel          Deployable machine learning model; receives the trained naive Bayes model
 * @param includedFeatures Features used by the model; their names are stored in the summary
 * @return Class classification model summary of the trained model
 * @throws MLModelBuilderException if training or evaluation fails for any reason
 */
private ModelSummary buildNaiveBayesModel(JavaSparkContext sparkContext, long modelID,
        JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testingData, Workflow workflow,
        MLModel mlModel, SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException {
    JavaPairRDD<Double, Double> predictionsAndLabels = null;
    try {
        Map<String, String> hyperParameters = workflow.getHyperParameters();
        NaiveBayesClassifier naiveBayesClassifier = new NaiveBayesClassifier();
        NaiveBayesModel naiveBayesModel = naiveBayesClassifier.train(trainingData,
                Double.parseDouble(hyperParameters.get(MLConstants.LAMBDA)));
        // training is complete: release the training data as early as possible
        trainingData.unpersist();
        // the testing data is traversed more than once during evaluation, so keep it cached
        testingData.cache();
        predictionsAndLabels = naiveBayesClassifier.test(naiveBayesModel, testingData).cache();

        ClassClassificationAndRegressionModelSummary modelSummary =
                SparkModelUtils.getClassClassificationModelSummary(sparkContext, testingData,
                        predictionsAndLabels);
        mlModel.setModel(new MLClassificationModel(naiveBayesModel));
        modelSummary.setFeatures(includedFeatures.values().toArray(new String[0]));
        modelSummary.setAlgorithm(SUPERVISED_ALGORITHM.NAIVE_BAYES.toString());

        // read every metric while the predictions are still cached: the metrics object may evaluate lazily
        MulticlassMetrics multiclassMetrics = getMulticlassMetrics(sparkContext, predictionsAndLabels);
        modelSummary.setMulticlassConfusionMatrix(getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
        Double modelAccuracy = getModelAccuracy(multiclassMetrics);
        modelSummary.setModelAccuracy(modelAccuracy);
        modelSummary.setDatasetVersion(workflow.getDatasetVersion());
        return modelSummary;
    } catch (Exception e) {
        throw new MLModelBuilderException(
                "An error occurred while building naive bayes model: " + e.getMessage(), e);
    } finally {
        // Release everything this method cached, on success and failure alike.
        // unpersist() on an RDD that is no longer cached is expected to be a no-op in Spark.
        trainingData.unpersist();
        testingData.unpersist();
        if (predictionsAndLabels != null) {
            predictionsAndLabels.unpersist();
        }
    }
}
/** * This method builds a lasso regression model * * @param sparkContext JavaSparkContext initialized with the application * @param modelID Model ID * @param trainingData Training data as a JavaRDD of LabeledPoints * @param testingData Testing data as a JavaRDD of LabeledPoints * @param workflow Machine learning workflow * @param mlModel Deployable machine learning model * @throws MLModelBuilderException */ private ModelSummary buildLassoRegressionModel( JavaSparkContext sparkContext, long modelID, JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testingData, Workflow workflow, MLModel mlModel, SortedMap<Integer, String> includedFeatures) throws MLModelBuilderException { try { LassoRegression lassoRegression = new LassoRegression(); Map<String, String> hyperParameters = workflow.getHyperParameters(); LassoModel lassoModel = lassoRegression.train( trainingData, Integer.parseInt(hyperParameters.get(MLConstants.ITERATIONS)), Double.parseDouble(hyperParameters.get(MLConstants.LEARNING_RATE)), Double.parseDouble(hyperParameters.get(MLConstants.REGULARIZATION_PARAMETER)), Double.parseDouble(hyperParameters.get(MLConstants.SGD_DATA_FRACTION))); // remove from cache trainingData.unpersist(); // add test data to cache testingData.cache(); Vector weights = lassoModel.weights(); if (!isValidWeights(weights)) { throw new MLModelBuilderException( "Weights of the model generated are null or infinity. 
[Weights] " + vectorToString(weights)); } JavaRDD<Tuple2<Double, Double>> predictionsAndLabels = lassoRegression.test(lassoModel, testingData).cache(); ClassClassificationAndRegressionModelSummary regressionModelSummary = SparkModelUtils.generateRegressionModelSummary( sparkContext, testingData, predictionsAndLabels); // remove from cache testingData.unpersist(); mlModel.setModel(new MLGeneralizedLinearModel(lassoModel)); List<FeatureImportance> featureWeights = getFeatureWeights(includedFeatures, lassoModel.weights().toArray()); regressionModelSummary.setFeatures(includedFeatures.values().toArray(new String[0])); regressionModelSummary.setAlgorithm(SUPERVISED_ALGORITHM.LASSO_REGRESSION.toString()); regressionModelSummary.setFeatureImportance(featureWeights); RegressionMetrics regressionMetrics = getRegressionMetrics(sparkContext, predictionsAndLabels); predictionsAndLabels.unpersist(); Double meanSquaredError = regressionMetrics.meanSquaredError(); regressionModelSummary.setMeanSquaredError(meanSquaredError); regressionModelSummary.setDatasetVersion(workflow.getDatasetVersion()); return regressionModelSummary; } catch (Exception e) { throw new MLModelBuilderException( "An error occurred while building lasso regression model: " + e.getMessage(), e); } }
/** * A utility method to generate class classification model summary * * @param predictionsAndLabels Predictions and actual labels * @return Class classification model summary */ public static ClassClassificationAndRegressionModelSummary getClassClassificationModelSummary( JavaSparkContext sparkContext, JavaRDD<LabeledPoint> testingData, JavaPairRDD<Double, Double> predictionsAndLabels) { ClassClassificationAndRegressionModelSummary classClassificationModelSummary = new ClassClassificationAndRegressionModelSummary(); // store predictions and actuals List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>(); for (Tuple2<Double, Double> scoreAndLabel : predictionsAndLabels.collect()) { PredictedVsActual predictedVsActual = new PredictedVsActual(); predictedVsActual.setPredicted(scoreAndLabel._1()); predictedVsActual.setActual(scoreAndLabel._2()); predictedVsActuals.add(predictedVsActual); } // create a list of feature values List<double[]> features = new ArrayList<double[]>(); for (LabeledPoint labeledPoint : testingData.collect()) { if (labeledPoint != null && labeledPoint.features() != null) { double[] rowFeatures = labeledPoint.features().toArray(); features.add(rowFeatures); } } // create a list of feature values with predicted vs. 
actuals List<TestResultDataPoint> testResultDataPoints = new ArrayList<TestResultDataPoint>(); for (int i = 0; i < features.size(); i++) { TestResultDataPoint testResultDataPoint = new TestResultDataPoint(); testResultDataPoint.setPredictedVsActual(predictedVsActuals.get(i)); testResultDataPoint.setFeatureValues(features.get(i)); testResultDataPoints.add(testResultDataPoint); } // covert List to JavaRDD JavaRDD<TestResultDataPoint> testResultDataPointsJavaRDD = sparkContext.parallelize(testResultDataPoints); // collect RDD as a sampled list List<TestResultDataPoint> testResultDataPointsSample; if (testResultDataPointsJavaRDD.count() > MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize()) { testResultDataPointsSample = testResultDataPointsJavaRDD.takeSample( true, MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize()); } else { testResultDataPointsSample = testResultDataPointsJavaRDD.collect(); } classClassificationModelSummary.setTestResultDataPointsSample(testResultDataPointsSample); classClassificationModelSummary.setPredictedVsActuals(predictedVsActuals); // calculate test error double error = 1.0 * predictionsAndLabels .filter( new Function<Tuple2<Double, Double>, Boolean>() { private static final long serialVersionUID = -3063364114286182333L; @Override public Boolean call(Tuple2<Double, Double> pl) { return !pl._1().equals(pl._2()); } }) .count() / predictionsAndLabels.count(); classClassificationModelSummary.setError(error); return classClassificationModelSummary; }
/** * A utility method to generate regression model summary * * @param predictionsAndLabels Tuple2 containing predicted and actual values * @return Regression model summary */ public static ClassClassificationAndRegressionModelSummary generateRegressionModelSummary( JavaSparkContext sparkContext, JavaRDD<LabeledPoint> testingData, JavaRDD<Tuple2<Double, Double>> predictionsAndLabels) { ClassClassificationAndRegressionModelSummary regressionModelSummary = new ClassClassificationAndRegressionModelSummary(); // store predictions and actuals List<PredictedVsActual> predictedVsActuals = new ArrayList<PredictedVsActual>(); DecimalFormat decimalFormat = new DecimalFormat(MLConstants.DECIMAL_FORMAT); for (Tuple2<Double, Double> scoreAndLabel : predictionsAndLabels.collect()) { PredictedVsActual predictedVsActual = new PredictedVsActual(); predictedVsActual.setPredicted(Double.parseDouble(decimalFormat.format(scoreAndLabel._1()))); predictedVsActual.setActual(Double.parseDouble(decimalFormat.format(scoreAndLabel._2()))); predictedVsActuals.add(predictedVsActual); } // create a list of feature values List<double[]> features = new ArrayList<double[]>(); for (LabeledPoint labeledPoint : testingData.collect()) { if (labeledPoint != null && labeledPoint.features() != null) { double[] rowFeatures = labeledPoint.features().toArray(); features.add(rowFeatures); } } // create a list of feature values with predicted vs. 
actuals List<TestResultDataPoint> testResultDataPoints = new ArrayList<TestResultDataPoint>(); for (int i = 0; i < features.size(); i++) { TestResultDataPoint testResultDataPoint = new TestResultDataPoint(); testResultDataPoint.setPredictedVsActual(predictedVsActuals.get(i)); testResultDataPoint.setFeatureValues(features.get(i)); testResultDataPoints.add(testResultDataPoint); } // covert List to JavaRDD JavaRDD<TestResultDataPoint> testResultDataPointsJavaRDD = sparkContext.parallelize(testResultDataPoints); // collect RDD as a sampled list List<TestResultDataPoint> testResultDataPointsSample; if (testResultDataPointsJavaRDD.count() > MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize()) { testResultDataPointsSample = testResultDataPointsJavaRDD.takeSample( true, MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize()); } else { testResultDataPointsSample = testResultDataPointsJavaRDD.collect(); } regressionModelSummary.setTestResultDataPointsSample(testResultDataPointsSample); regressionModelSummary.setPredictedVsActuals(predictedVsActuals); // calculate mean squared error (MSE) double meanSquaredError = new JavaDoubleRDD( predictionsAndLabels .map( new Function<Tuple2<Double, Double>, Object>() { private static final long serialVersionUID = -162193633199074816L; public Object call(Tuple2<Double, Double> pair) { return Math.pow(pair._1() - pair._2(), 2.0); } }) .rdd()) .mean(); regressionModelSummary.setError(meanSquaredError); return regressionModelSummary; }