/**
   * Builds a Naive Bayes model.
   *
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow
   * @param mlModel Deployable machine learning model
   * @param includedFeatures Features included in the model, keyed by feature index
   * @return Summary of the trained classification model
   * @throws MLModelBuilderException if model building fails
   */
  private ModelSummary buildNaiveBayesModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures)
      throws MLModelBuilderException {
    try {
      Map<String, String> hyperParameters = workflow.getHyperParameters();
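      // Train a Naive Bayes model with the configured lambda (additive smoothing) value.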
      NaiveBayesClassifier naiveBayesClassifier = new NaiveBayesClassifier();
      NaiveBayesModel naiveBayesModel =
          naiveBayesClassifier.train(
              trainingData, Double.parseDouble(hyperParameters.get(MLConstants.LAMBDA)));

      // remove from cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      JavaPairRDD<Double, Double> predictionsAndLabels =
          naiveBayesClassifier.test(naiveBayesModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary =
          SparkModelUtils.getClassClassificationModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLClassificationModel(naiveBayesModel));

      classClassificationAndRegressionModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      classClassificationAndRegressionModelSummary.setAlgorithm(
          SUPERVISED_ALGORITHM.NAIVE_BAYES.toString());

      MulticlassMetrics multiclassMetrics =
          getMulticlassMetrics(sparkContext, predictionsAndLabels);

      predictionsAndLabels.unpersist();

      classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
      classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return classClassificationAndRegressionModelSummary;
    } catch (Exception e) {
      throw new MLModelBuilderException(
          "An error occurred while building the Naive Bayes model: " + e.getMessage(), e);
    }
  }
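
For reference, here is a minimal sketch of the raw Spark MLlib calls that a wrapper like the NaiveBayesClassifier above presumably delegates to. The NaiveBayesSketch class and trainAndTest method names are illustrative and not part of the original code; only NaiveBayes.train, NaiveBayesModel, LabeledPoint, and the JavaRDD/JavaPairRDD types come from Spark itself.

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Tuple2;

public class NaiveBayesSketch {

  /** Trains a Naive Bayes model and pairs each test prediction with its true label. */
  public static JavaPairRDD<Double, Double> trainAndTest(
      JavaRDD<LabeledPoint> trainingData, JavaRDD<LabeledPoint> testingData, double lambda) {
    // lambda is the additive-smoothing parameter, as read from MLConstants.LAMBDA above.
    NaiveBayesModel model = NaiveBayes.train(trainingData.rdd(), lambda);
    // Mirrors the predictionsAndLabels pairing consumed by the model summary above.
    return testingData.mapToPair(
        point -> new Tuple2<>(model.predict(point.features()), point.label()));
  }
}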
Example 2
  /**
   * Main method: reads the command-line arguments and either outputs the classifications of the
   * test file or uses cross-validation to compute the mean accuracy of the classifier.
   *
   * @param args command-line arguments: the training filename followed by the test filename
   * @throws IOException if either input file cannot be read
   */
  public static void main(String[] args) throws IOException {
    if (args.length < 2) {
      System.out.println("usage: java HW3 <trainingFilename> <testFilename>");
      return; // exit early when the required file arguments are missing
    }

    // Output classifications on test data
    File trainingFile = new File(args[0]);
    File testFile = new File(args[1]);

    Instance[] trainingData = createInstances(trainingFile);
    Instance[] testData = createInstances(testFile);

    NaiveBayesClassifier nbc = getNewClassifier();
    nbc.train(trainingData, vocabularySize(trainingData, testData));

    for (Instance i : testData) {
      ClassifyResult cr = nbc.classify(i.words);
      System.out.printf("%s %s%n", cr.label, i.label);
      System.out.printf("Log probability of spam: %f%n", cr.log_prob_spam);
      System.out.printf("Log probability of ham: %f%n", cr.log_prob_ham);
    }
  }
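
The Javadoc above also mentions a cross-validation path that is not shown in this main method. Below is a minimal sketch of how such a pass could look in the same class, assuming the Instance, ClassifyResult, getNewClassifier() and vocabularySize() helpers used above (plus java.util.ArrayList and java.util.List imports); the crossValidate name, the fold-splitting scheme, and the label comparison via equals() are illustrative assumptions, not part of the original assignment.

  static double crossValidate(Instance[] data, int folds) {
    int correct = 0;
    int total = 0;
    for (int f = 0; f < folds; f++) {
      // Hold out every folds-th instance starting at f; train on the rest.
      List<Instance> train = new ArrayList<>();
      List<Instance> test = new ArrayList<>();
      for (int i = 0; i < data.length; i++) {
        if (i % folds == f) {
          test.add(data[i]);
        } else {
          train.add(data[i]);
        }
      }
      Instance[] trainArr = train.toArray(new Instance[0]);
      Instance[] testArr = test.toArray(new Instance[0]);

      NaiveBayesClassifier nbc = getNewClassifier();
      nbc.train(trainArr, vocabularySize(trainArr, testArr));

      for (Instance inst : testArr) {
        ClassifyResult cr = nbc.classify(inst.words);
        if (cr.label.equals(inst.label)) { // assumes label is a String or enum value
          correct++;
        }
        total++;
      }
    }
    // Mean accuracy across all held-out instances.
    return (double) correct / total;
  }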
Example 3
  /**
   * Preprocesses the census-income training and test data, trains several Weka classifiers, and
   * reports the accuracy of a simple majority-vote ensemble on the test set.
   *
   * @param args the command line arguments
   * @throws Exception if preprocessing, training, or evaluation fails
   */
  public static void main(String[] args) throws Exception {
    PreProcessor p = new PreProcessor("census-income.data", "census-income-preprocessed.arff");

    p.smote();

    PreProcessor p_test =
        new PreProcessor("census-income.test", "census-income-test-preprocessed.arff");

    p_test.run();

    BufferedReader traindata =
        new BufferedReader(new FileReader("census-income-preprocessed.arff"));
    BufferedReader testdata =
        new BufferedReader(new FileReader("census-income-test-preprocessed.arff"));
    Instances traininstance = new Instances(traindata);
    Instances testinstance = new Instances(testdata);

    traindata.close();
    testdata.close();
    traininstance.setClassIndex(traininstance.numAttributes() - 1);
    testinstance.setClassIndex(testinstance.numAttributes() - 1);
    int numOfAttributes = testinstance.numAttributes();
    int numOfInstances = testinstance.numInstances();

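    // Train each base classifier on the same preprocessed training data.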
    NaiveBayesClassifier nb = new NaiveBayesClassifier("census-income-preprocessed.arff");
    Classifier cnaive = nb.NBClassify();

    DecisionTree dt = new DecisionTree("census-income-preprocessed.arff");
    Classifier cls = dt.DTClassify();

    AdaBoost ab = new AdaBoost("census-income-preprocessed.arff");
    AdaBoostM1 m1 = ab.AdaBoostDTClassify();

    BaggingMethod b = new BaggingMethod("census-income-preprocessed.arff");
    Bagging bag = b.BaggingDTClassify();

    SVM s = new SVM("census-income-preprocessed.arff");
    SMO svm = s.SMOClassifier();

    knn knnclass = new knn("census-income-preprocessed.arff");
    IBk knnc = knnclass.knnclassifier();

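    // Note: this logistic regression model is trained but never consulted in the vote below.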
    Logistic log = new Logistic();
    log.buildClassifier(traininstance);

    int match = 0;
    int error = 0;
    int greater = 0;
    int less = 0;

    for (int i = 0; i < numOfInstances; i++) {
      String predicted = "";
      greater = 0;
      less = 0;
      double[] predictions = new double[6]; // one slot per voting classifier

      double pred = cls.classifyInstance(testinstance.instance(i));
      predictions[0] = pred;

      double abpred = m1.classifyInstance(testinstance.instance(i));
      predictions[1] = abpred;

      double naivepred = cnaive.classifyInstance(testinstance.instance(i));
      predictions[2] = naivepred;

      double bagpred = bag.classifyInstance(testinstance.instance(i));
      predictions[3] = bagpred;

      double smopred = svm.classifyInstance(testinstance.instance(i));
      predictions[4] = smopred;

      double knnpred = knnc.classifyInstance(testinstance.instance(i));
      predictions[5] = knnpred;

      // Tally how many of the six classifiers predict ">50K" for this instance.
      for (int j = 0; j < 6; j++) {
        String vote = testinstance.instance(i).classAttribute().value((int) predictions[j]);
        if (vote.equals(">50K")) {
          greater++;
        } else {
          less++;
        }
      }

      // Majority vote across the six classifiers.
      predicted = greater > less ? ">50K" : "<=50K";

      // Compare the ensemble prediction against the actual class label of the test instance.
      if (testinstance.instance(i).stringValue(numOfAttributes - 1).equals(predicted)) {
        match++;
      } else {
        error++;
      }
    }

    System.out.println("Correctly classified Instances: " + match);
    System.out.println("Misclassified Instances: " + error);

    double accuracy = (double) match / (double) numOfInstances * 100;
    double errorRate = 100 - accuracy;
    System.out.println("Accuracy: " + accuracy + "%");
    System.out.println("Error: " + errorRate + "%");
  }
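
For comparison with the hand-rolled accuracy count above, Weka's weka.classifiers.Evaluation class can report a single classifier's accuracy on the same test set directly. A minimal sketch, assuming the traininstance, testinstance, and cnaive objects built in the main method above:

    // Illustrative only: accuracy of the Naive Bayes base classifier on the test set,
    // using the traininstance and testinstance Instances from the code above.
    Evaluation eval = new Evaluation(traininstance);
    eval.evaluateModel(cnaive, testinstance);
    System.out.println("Naive Bayes alone: " + eval.pctCorrect() + "% correct");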