/**
   * Builds a naive bayes classification model, evaluates it on the testing data and returns the
   * resulting model summary.
   *
   * <p>The model is trained by {@code NaiveBayesClassifier#train}, which receives the value of the
   * workflow's {@code MLConstants.LAMBDA} hyper-parameter as its numeric argument. The trained
   * model is stored into {@code mlModel} as a side effect. This method takes ownership of the
   * caching of {@code trainingData} and {@code testingData}: both, as well as the intermediate
   * predictions RDD, are unpersisted before the method returns, also when it fails.
   *
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID (only used in error messages here)
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow supplying hyper-parameters and dataset version
   * @param mlModel Deployable machine learning model; receives the trained model
   * @param includedFeatures Features used for training; their names (map values, in key order)
   *     are recorded in the summary
   * @return the classification summary (confusion matrix, accuracy, features, algorithm, dataset
   *     version)
   * @throws MLModelBuilderException if the hyper-parameter is missing or invalid, or if training
   *     or evaluation fails for any other reason
   */
  private ModelSummary buildNaiveBayesModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures)
      throws MLModelBuilderException {
    // Declared outside the try so the finally block can release it if it was created.
    JavaPairRDD<Double, Double> predictionsAndLabels = null;
    try {
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      String lambda = hyperParameters.get(MLConstants.LAMBDA);
      if (lambda == null) {
        // Double.parseDouble(null) would throw a NullPointerException with no useful message;
        // fail with a descriptive one instead (wrapped by the catch below).
        throw new IllegalArgumentException(
            "Missing hyper-parameter '" + MLConstants.LAMBDA + "' for model " + modelID);
      }
      NaiveBayesClassifier naiveBayesClassifier = new NaiveBayesClassifier();
      NaiveBayesModel naiveBayesModel =
          naiveBayesClassifier.train(trainingData, Double.parseDouble(lambda));

      // Training is finished: release its cache before caching the testing data.
      trainingData.unpersist();
      testingData.cache();

      predictionsAndLabels = naiveBayesClassifier.test(naiveBayesModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary summary =
          SparkModelUtils.getClassClassificationModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // The testing data is no longer needed once the summary has been computed.
      testingData.unpersist();

      mlModel.setModel(new MLClassificationModel(naiveBayesModel));

      summary.setFeatures(includedFeatures.values().toArray(new String[0]));
      summary.setAlgorithm(SUPERVISED_ALGORITHM.NAIVE_BAYES.toString());

      MulticlassMetrics multiclassMetrics =
          getMulticlassMetrics(sparkContext, predictionsAndLabels);

      predictionsAndLabels.unpersist();

      summary.setMulticlassConfusionMatrix(
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      summary.setModelAccuracy(modelAccuracy);
      summary.setDatasetVersion(workflow.getDatasetVersion());

      return summary;
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building naive bayes model: " + e.getMessage(), e);
    } finally {
      // Safety net for the failure path, so that no RDD this method cached (or owns) stays
      // cached. On the success path these repeat the eager unpersists above; unpersisting an RDD
      // that is no longer cached is expected to be a no-op in Spark.
      trainingData.unpersist();
      testingData.unpersist();
      if (predictionsAndLabels != null) {
        predictionsAndLabels.unpersist();
      }
    }
  }
  // Example 2
  /**
   * Trains a naive bayes spam/ham classifier on a training file and prints, for every instance of
   * a test file, the classifier's label followed by the instance's own label, then the log
   * probabilities of spam and of ham.
   *
   * <p>Usage: {@code java HW3 <trainingFilename> <testFilename>}. If fewer than two arguments are
   * given, the usage line is printed and the method returns without reading any file.
   *
   * @param args command-line arguments: {@code args[0]} is the training file, {@code args[1]} the
   *     test file
   * @throws IOException if reading either file fails (presumably raised by createInstances —
   *     confirm)
   */
  public static void main(String[] args) throws IOException {
    if (args.length < 2) {
      System.out.println("usage: java HW3 <trainingFilename> <testFilename>");
      // Without this return, args[0]/args[1] below throw ArrayIndexOutOfBoundsException.
      return;
    }

    File trainingFile = new File(args[0]);
    File testFile = new File(args[1]);

    Instance[] trainingData = createInstances(trainingFile);
    Instance[] testData = createInstances(testFile);

    // The vocabulary size is computed over both files' instances.
    NaiveBayesClassifier nbc = getNewClassifier();
    nbc.train(trainingData, vocabularySize(trainingData, testData));

    // Output classifications on test data.
    for (Instance instance : testData) {
      ClassifyResult result = nbc.classify(instance.words);
      System.out.println(String.format("%s %s", result.label, instance.label));
      System.out.println(String.format("Log probability of spam: %f", result.log_prob_spam));
      System.out.println(String.format("Log probability of ham: %f", result.log_prob_ham));
    }
  }