/**
   * Builds a decision tree model and summarises how it performs on the testing data.
   *
   * <p>The tree is trained on {@code trainingData} with the impurity, maximum depth and maximum
   * bin count read from the workflow's hyper-parameters. The trained model is then applied to
   * {@code testingData}, stored on {@code mlModel}, and the evaluation results (features,
   * algorithm name, multiclass confusion matrix, accuracy, dataset version) are returned.
   *
   * <p>Cache handling: {@code trainingData} is unpersisted as soon as training ends, and
   * {@code testingData} plus the predictions RDD are cached only for the duration of the
   * evaluation. All of them are released in {@code finally} blocks so that a failure part-way
   * through does not leave RDDs cached.
   *
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID (not referenced by this method)
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow; supplies the hyper-parameters and dataset version
   * @param mlModel Deployable machine learning model; receives the trained decision tree
   * @param includedFeatures features whose names (the map values) are recorded in the summary
   * @param categoricalFeatureInfo categorical feature information handed to the trainer as-is
   *     (presumably feature index to number of categories; confirm against the trainer)
   * @return the summary of the trained model's performance on the testing data
   * @throws MLModelBuilderException if any step of training or evaluation fails; the original
   *     exception is attached as the cause
   */
  private ModelSummary buildDecisionTreeModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures,
      Map<Integer, Integer> categoricalFeatureInfo)
      throws MLModelBuilderException {
    try {
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      DecisionTree decisionTree = new DecisionTree();

      DecisionTreeModel decisionTreeModel;
      try {
        decisionTreeModel =
            decisionTree.train(
                trainingData,
                getNoOfClasses(mlModel),
                categoricalFeatureInfo,
                hyperParameters.get(MLConstants.IMPURITY),
                Integer.parseInt(hyperParameters.get(MLConstants.MAX_DEPTH)),
                Integer.parseInt(hyperParameters.get(MLConstants.MAX_BINS)));
      } finally {
        // remove from cache, whether or not training succeeded
        trainingData.unpersist();
      }

      // Declared outside the inner try so the finally block can release it; stays null when
      // the failure happens before the predictions RDD exists.
      JavaPairRDD<Double, Double> predictionsAndLabels = null;

      // add test data to cache
      testingData.cache();
      try {
        predictionsAndLabels = decisionTree.test(decisionTreeModel, testingData).cache();
        ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary =
            SparkModelUtils.getClassClassificationModelSummary(
                sparkContext, testingData, predictionsAndLabels);

        mlModel.setModel(new MLDecisionTreeModel(decisionTreeModel));

        classClassificationAndRegressionModelSummary.setFeatures(
            includedFeatures.values().toArray(new String[0]));
        classClassificationAndRegressionModelSummary.setAlgorithm(
            SUPERVISED_ALGORITHM.DECISION_TREE.toString());

        MulticlassMetrics multiclassMetrics =
            getMulticlassMetrics(sparkContext, predictionsAndLabels);

        classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(
            getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
        Double modelAccuracy = getModelAccuracy(multiclassMetrics);
        classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
        classClassificationAndRegressionModelSummary.setDatasetVersion(
            workflow.getDatasetVersion());

        return classClassificationAndRegressionModelSummary;
      } finally {
        // remove from cache on both the success and the failure path
        testingData.unpersist();
        if (predictionsAndLabels != null) {
          predictionsAndLabels.unpersist();
        }
      }
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building decision tree model: " + e.getMessage(), e);
    }
  }
  // Example no. 2
  // 0
  /**
   * Recursively writes a textual rendering of the tree rooted at {@code root} through
   * {@code DecisionTree.appendText}.
   *
   * <p>A node without children is written as a "THEN" leaf line. Any other node is written as
   * an interior line, followed, for each child, by an "IF" line carrying the matching link
   * label and then the child's own rendering one level deeper.
   *
   * @param root the Node that is the root of the tree to be displayed
   * @param offset the indentation prefix placed before every line written for this node; each
   *     level of recursion extends it by three spaces
   */
  public static void displayTree(Node root, String offset) {
    // Leaf: nothing to descend into, emit the decision and stop.
    if (root.children.size() == 0) {
      DecisionTree.appendText("\n" + offset + "    THEN (" + root.label + ")  (Leaf node)");
      return;
    }

    // Children and link labels are walked in lock-step: the i-th label belongs to the i-th child.
    Enumeration<?> childNodes = root.children.elements();
    Enumeration<?> branchLabels = root.linkLabels.elements();
    String childOffset = offset + "   ";

    DecisionTree.appendText("\n" + offset + "   " + root.label + " (Interior node)");
    while (childNodes.hasMoreElements()) {
      String branchLabel = (String) branchLabels.nextElement();
      DecisionTree.appendText("\n" + offset + "   IF (" + branchLabel + ")");

      Node child = (Node) childNodes.nextElement();
      displayTree(child, childOffset);
    }
  }
  /**
   * Demo entry point: builds the tree returned by {@code makeOutlookTree()}, classifies one
   * hard-coded sample case and prints the outcome to standard output. Prints {@code "?"} when
   * classification throws {@code UnknownDecisionException}.
   *
   * @param args command-line arguments (ignored)
   */
  public static void main(String[] args) {
    DecisionTree outlookTree = makeOutlookTree();

    // Attribute/value pairs describing the single case to classify.
    String[][] attributeValues = {
      {"Outlook", "Overcast"},
      {"Temperature", "Hot"},
      {"Humidity", "High"},
      {"Wind", "Strong"},
    };

    Map<String, String> sampleCase = new HashMap<String, String>();
    for (String[] pair : attributeValues) {
      sampleCase.put(pair[0], pair[1]);
    }

    try {
      System.out.println(outlookTree.classify(sampleCase));
    } catch (UnknownDecisionException e) {
      System.out.println("?");
    }
  }