コード例 #1
0
  /**
   * Builds a naive Bayes model from the training data, evaluates it against the testing data and
   * stores the trained model in {@code mlModel}.
   *
   * <p>The training data is unpersisted as soon as training finishes; the testing data is cached
   * for the evaluation and unpersisted afterwards. If any step fails, the RDDs handled by this
   * method are unpersisted on a best-effort basis before the failure is reported.
   *
   * @param sparkContext JavaSparkContext initialized with the application
   * @param modelID Model ID (not read by this method)
   * @param trainingData Training data as a JavaRDD of LabeledPoints
   * @param testingData Testing data as a JavaRDD of LabeledPoints
   * @param workflow Machine learning workflow; supplies the hyper-parameters and dataset version
   * @param mlModel Deployable machine learning model that receives the trained model
   * @param includedFeatures Features used for the model; the map values become the summary's
   *     feature names
   * @return Summary holding the feature names, algorithm name, multiclass confusion matrix,
   *     model accuracy and dataset version
   * @throws MLModelBuilderException If the {@code MLConstants.LAMBDA} hyper-parameter is missing
   *     or any training/evaluation step fails; the underlying exception is kept as the cause
   */
  private ModelSummary buildNaiveBayesModel(
      JavaSparkContext sparkContext,
      long modelID,
      JavaRDD<LabeledPoint> trainingData,
      JavaRDD<LabeledPoint> testingData,
      Workflow workflow,
      MLModel mlModel,
      SortedMap<Integer, String> includedFeatures)
      throws MLModelBuilderException {
    // Declared outside the try block so the failure path below can release it.
    JavaPairRDD<Double, Double> predictionsAndLabels = null;
    try {
      Map<String, String> hyperParameters = workflow.getHyperParameters();
      String lambda = hyperParameters.get(MLConstants.LAMBDA);
      if (lambda == null) {
        // Fail with a readable message instead of the bare NullPointerException that
        // Double.parseDouble(null) would raise.
        throw new IllegalArgumentException(
            "Hyper-parameter " + MLConstants.LAMBDA + " is not set");
      }

      NaiveBayesClassifier naiveBayesClassifier = new NaiveBayesClassifier();
      NaiveBayesModel naiveBayesModel =
          naiveBayesClassifier.train(trainingData, Double.parseDouble(lambda));

      // training is done: remove the training data from cache
      trainingData.unpersist();
      // add test data to cache
      testingData.cache();

      predictionsAndLabels = naiveBayesClassifier.test(naiveBayesModel, testingData).cache();
      ClassClassificationAndRegressionModelSummary classClassificationAndRegressionModelSummary =
          SparkModelUtils.getClassClassificationModelSummary(
              sparkContext, testingData, predictionsAndLabels);

      // remove from cache
      testingData.unpersist();

      mlModel.setModel(new MLClassificationModel(naiveBayesModel));

      classClassificationAndRegressionModelSummary.setFeatures(
          includedFeatures.values().toArray(new String[0]));
      classClassificationAndRegressionModelSummary.setAlgorithm(
          SUPERVISED_ALGORITHM.NAIVE_BAYES.toString());

      MulticlassMetrics multiclassMetrics =
          getMulticlassMetrics(sparkContext, predictionsAndLabels);

      // the metrics have been derived: the cached predictions are no longer needed
      predictionsAndLabels.unpersist();

      classClassificationAndRegressionModelSummary.setMulticlassConfusionMatrix(
          getMulticlassConfusionMatrix(multiclassMetrics, mlModel));
      Double modelAccuracy = getModelAccuracy(multiclassMetrics);
      classClassificationAndRegressionModelSummary.setModelAccuracy(modelAccuracy);
      classClassificationAndRegressionModelSummary.setDatasetVersion(workflow.getDatasetVersion());

      return classClassificationAndRegressionModelSummary;
    } catch (Exception e) {
      // Best-effort cleanup so a failed build does not leave RDDs cached. Some of these may
      // already have been unpersisted above; any problem raised here is deliberately dropped
      // because the original failure is the one worth reporting.
      try {
        trainingData.unpersist();
        testingData.unpersist();
        if (predictionsAndLabels != null) {
          predictionsAndLabels.unpersist();
        }
      } catch (Exception ignored) {
        // intentionally ignored: must not mask the original exception
      }
      throw new MLModelBuilderException(
          "An error occurred while building naive bayes model: " + e.getMessage(), e);
    }
  }
コード例 #2
0
  /**
   * Command-line driver: reads a training file and a test file, then trains naive Bayes
   * classifiers and reports their accuracy on both data sets.
   *
   * <p>Two configurations are run one after the other, each repeated 100 times: first with
   * {@code NaiveBayesClassifierFactory.CL}, then with {@code NaiveBayesClassifierFactory.UCL}.
   * The section headings go to standard output and the accuracies to standard error.
   *
   * @param args {@code args[0]} is the training data file, {@code args[1]} the test data file
   * @throws IllegalArgumentException if fewer than two arguments are supplied
   */
  public static void main(String[] args) {
    if (args == null || args.length < 2) {
      throw new IllegalArgumentException("Usage: <trainFile> <testFile>");
    }
    String trainFile = args[0];
    String testFile = args[1];

    // Both files are read with the same index map, training file first.
    NominalDataReader nR = new NominalDataReader();
    Map<Integer, Index<String>> indices = Generics.newHashMap();
    List<RVFDatum<String, Integer>> train = nR.readData(trainFile, indices);
    List<RVFDatum<String, Integer>> test = nR.readData(testFile, indices);

    System.out.println("Constrained conditional likelihood no prior :");
    runAccuracyTrials(true, train, test);

    System.out.println("Unconstrained conditional likelihood no prior :");
    runAccuracyTrials(false, train, test);
  }

  /**
   * Trains a classifier on {@code train} 100 times with a fresh factory each time, printing the
   * classifier and its accuracy on {@code train} and {@code test} after every run.
   *
   * @param constrained {@code true} to build the factory with
   *     {@code NaiveBayesClassifierFactory.CL}, {@code false} to use
   *     {@code NaiveBayesClassifierFactory.UCL}
   * @param train data the classifier is trained on and first evaluated against
   * @param test held-out data the classifier is evaluated against
   */
  private static void runAccuracyTrials(
      boolean constrained,
      List<RVFDatum<String, Integer>> train,
      List<RVFDatum<String, Integer>> test) {
    for (int trial = 0; trial < 100; trial++) {
      NaiveBayesClassifier<String, Integer> classifier =
          new NaiveBayesClassifierFactory<String, Integer>(
                  0.1,
                  0.01,
                  0.6,
                  LogPrior.LogPriorType.NULL.ordinal(),
                  constrained ? NaiveBayesClassifierFactory.CL : NaiveBayesClassifierFactory.UCL)
              .trainClassifier(train);
      classifier.print();

      // now classify: accuracy on the training data first, then on the test data
      float accTrain = classifier.accuracy(train.iterator());
      System.err.println("training accuracy " + accTrain);
      float accTest = classifier.accuracy(test.iterator());
      System.err.println("test accuracy " + accTest);
    }
  }