/**
 * Builds a decision tree model from the training data and evaluates it against the testing data.
 *
 * <p>The impurity, maximum depth and maximum bins settings are read from the workflow's
 * hyper-parameters (keys {@code MLConstants.IMPURITY}, {@code MLConstants.MAX_DEPTH} and
 * {@code MLConstants.MAX_BINS}). The trained model is stored on {@code mlModel}, and the
 * returned summary carries the included feature names, the algorithm name, the multiclass
 * confusion matrix, the model accuracy and the dataset version.
 *
 * <p>{@code trainingData} is unpersisted once training has completed. {@code testingData} and
 * the prediction/label pairs are cached by this method and are released again even when a later
 * evaluation step fails.
 *
 * @param sparkContext           JavaSparkContext initialized with the application
 * @param modelID                Model ID (not read by this method)
 * @param trainingData           Training data as a JavaRDD of LabeledPoints
 * @param testingData            Testing data as a JavaRDD of LabeledPoints
 * @param workflow               Machine learning workflow; supplies the hyper-parameters and
 *                               the dataset version
 * @param mlModel                Deployable machine learning model; receives the trained tree
 * @param includedFeatures       Included features; the map values are recorded in the summary
 *                               as the feature names
 * @param categoricalFeatureInfo Categorical feature information, passed unchanged to the trainer
 * @return Summary of the built decision tree model
 * @throws MLModelBuilderException if any step of training or evaluation fails
 */
private ModelSummary buildDecisionTreeModel(
        JavaSparkContext sparkContext,
        long modelID,
        JavaRDD<LabeledPoint> trainingData,
        JavaRDD<LabeledPoint> testingData,
        Workflow workflow,
        MLModel mlModel,
        SortedMap<Integer, String> includedFeatures,
        Map<Integer, Integer> categoricalFeatureInfo)
        throws MLModelBuilderException {
    try {
        Map<String, String> hyperParameters = workflow.getHyperParameters();
        DecisionTree decisionTree = new DecisionTree();
        DecisionTreeModel decisionTreeModel = decisionTree.train(
                trainingData,
                getNoOfClasses(mlModel),
                categoricalFeatureInfo,
                hyperParameters.get(MLConstants.IMPURITY),
                Integer.parseInt(hyperParameters.get(MLConstants.MAX_DEPTH)),
                Integer.parseInt(hyperParameters.get(MLConstants.MAX_BINS)));

        // remove from cache
        trainingData.unpersist();

        // add test data to cache
        testingData.cache();
        boolean testingDataCached = true;
        JavaPairRDD<Double, Double> predictionsAndLabels = null;
        try {
            predictionsAndLabels = decisionTree.test(decisionTreeModel, testingData).cache();
            ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary =
                    SparkModelUtils.getClassClassificationModelSummary(
                            sparkContext, testingData, predictionsAndLabels);

            // remove from cache as early as possible; the remaining steps only need the predictions
            testingData.unpersist();
            testingDataCached = false;

            mlModel.setModel(new MLDecisionTreeModel(decisionTreeModel));
            classClassificationAndRegressionModelSummary.setFeatures(
                    includedFeatures.values().toArray(new String[0]));
            classClassificationAndRegressionModelSummary.setAlgorithm(
                    SUPERVISED_ALGORITHM.DECISION_TREE.toString());

            MulticlassMetrics multiclassMetrics =
                    getMulticlassMetrics(sparkContext, predictionsAndLabels);
            predictionsAndLabels.unpersist();
            predictionsAndLabels = null;

            classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(
                    getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
            Double modelAccuracy = getModelAccuracy(multiclassMetrics);
            classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
            classClassificationAndRegressionModelSummary.setDatasetVersion(
                    workflow.getDatasetVersion());
            return classClassificationAndRegressionModelSummary;
        } finally {
            // Failure path only: on success both RDDs were already released above, so the
            // guards below are false/null and nothing is unpersisted twice.
            if (testingDataCached) {
                testingData.unpersist();
            }
            if (predictionsAndLabels != null) {
                predictionsAndLabels.unpersist();
            }
        }
    } catch (Exception e) {
        throw new MLModelBuilderException(
                "An error occurred while building decision tree model: " + e.getMessage(), e);
    }
}
/**
 * Displays the tree, starting with the given root node.
 *
 * <p>Output is emitted through {@code DecisionTree.appendText}. A node with no children is
 * written as a leaf ("THEN"); any other node is written as an interior node, followed by one
 * "IF" line per child (using the matching entry of {@code linkLabels}) and that child's
 * subtree, rendered recursively.
 *
 * @param root   the Node that is the root of the tree to be displayed
 * @param offset the text placed after the newline at the start of every line written for this
 *               node; each recursion level extends it by one space
 */
public static void displayTree(Node root, String offset) {
    // Leaf: nothing to recurse into.
    if (root.children.size() == 0) {
        DecisionTree.appendText("\n" + offset + " THEN (" + root.label + ") (Leaf node)");
        return;
    }

    // Wildcard-typed enumerations avoid raw types; elements are cast at the point of use.
    Enumeration<?> childNodes = root.children.elements();
    Enumeration<?> branchLabels = root.linkLabels.elements();

    DecisionTree.appendText("\n" + offset + " " + root.label + " (Interior node)");

    // The two enumerations are walked in lockstep: the label is consumed before its child.
    while (childNodes.hasMoreElements()) {
        DecisionTree.appendText("\n" + offset + " IF (" + (String) branchLabels.nextElement() + ")");
        displayTree((Node) childNodes.nextElement(), offset + " ");
    }
}
/**
 * Demo entry point: classifies one hand-built sample with the tree returned by
 * {@code makeOutlookTree()} and prints the result to standard output.
 *
 * <p>If classification raises {@code UnknownDecisionException}, a single "?" is printed instead.
 *
 * @param args command-line arguments (not used)
 */
public static void main(String[] args) {
    // Attribute name -> attribute value for the sample to classify.
    Map<String, String> sample = new HashMap<String, String>();
    sample.put("Outlook", "Overcast");
    sample.put("Temperature", "Hot");
    sample.put("Humidity", "High");
    sample.put("Wind", "Strong");

    DecisionTree outlookTree = makeOutlookTree();

    try {
        System.out.println(outlookTree.classify(sample));
    } catch (UnknownDecisionException e) {
        // The tree could not reach a decision for this sample.
        System.out.println("?");
    }
}