Exemplo n.º 1
0
 /**
  * Synchronizes {@code NAME_TO_RDD} with the RDDs currently persisted in the Spark context.
  *
  * <p>Every persisted RDD that has a non-null name is (re)registered under that name, and every
  * name in {@code NAME_TO_RDD} that does not belong to such an RDD is removed.
  *
  * @throws IllegalStateException if {@code CONTEXT} is {@code null}
  */
 public static void refresh() {
   if (null == CONTEXT) throw new IllegalStateException("The Spark context has not been created.");
   final Set<String> keepNames = new HashSet<>();
   for (final RDD<?> rdd : JavaConversions.asJavaIterable(CONTEXT.persistentRdds().values())) {
     // read the name once so the null check and the registration see the same value
     final String name = rdd.name();
     if (null != name) {
       keepNames.add(name);
       NAME_TO_RDD.put(name, rdd);
     }
   }
   // remove all stale names in one bulk operation; the key set is a live view of the map,
   // so dropping a key here also drops its mapping
   NAME_TO_RDD.keySet().retainAll(keepNames);
 }
  /**
   * Checks that the wrapper returned by {@code javaFunctions} for a {@code JavaRDD} exposes the
   * underlying RDD, and that {@code saveToGemfire} asks that RDD for its Spark context once and
   * submits exactly one job to it.
   */
  @Test
  @SuppressWarnings("unchecked")
  public void testJavaRDDFunctions() throws Exception {
    // a mocked JavaRDD that hands back a mocked Scala RDD
    final JavaRDD<String> javaRdd = mock(JavaRDD.class);
    final RDD<String> scalaRdd = mock(RDD.class);
    when(javaRdd.rdd()).thenReturn(scalaRdd);

    final GemFireJavaRDDFunctions functions = javaFunctions(javaRdd);
    // identity comparison: the wrapper must hold the very same RDD instance
    assertTrue(scalaRdd == functions.rddf.rdd());

    final Tuple3<SparkContext, GemFireConnectionConf, GemFireConnection> mocks =
        createCommonMocks();
    final SparkContext sparkContext = mocks._1();
    final GemFireConnectionConf connectionConf = mocks._2();
    when(scalaRdd.sparkContext()).thenReturn(sparkContext);

    final PairFunction<String, String, Integer> pairFunction = mock(PairFunction.class);
    final String regionPath = "testregion";
    functions.saveToGemfire(regionPath, pairFunction, connectionConf);

    verify(scalaRdd, times(1)).sparkContext();
    verify(sparkContext, times(1))
        .runJob(eq(scalaRdd), any(Function2.class), any(ClassTag.class));
  }
Exemplo n.º 3
0
 /**
  * Removes the RDD registered under the given name from {@code NAME_TO_RDD} and unpersists it.
  *
  * <p>Does nothing when {@code name} is {@code null}. Otherwise {@code Spark.refresh()} is called
  * first, and the RDD is unpersisted only if a mapping for the name existed.
  *
  * @param name the name of the RDD to remove; may be {@code null}
  */
 public static void removeRDD(final String name) {
   if (null != name) {
     Spark.refresh();
     final RDD<?> removed = NAME_TO_RDD.remove(name);
     if (null != removed) {
       removed.unpersist(true);
     }
   }
 }
Exemplo n.º 4
0
  /**
   * Trains a One-vs-Rest classifier backed by logistic regression on data loaded with
   * {@code MLUtils.loadLibSVMFile}, then prints the confusion matrix and the per-label false
   * positive rate computed on the test data.
   *
   * <p>When {@code params.testInput} is set, the whole input is used for training and the test set
   * is loaded from that path; otherwise the input is randomly split using {@code params.fracTest}
   * as the test fraction.
   *
   * @param args command-line arguments, handed to {@code parse}
   */
  public static void main(String[] args) {
    // parse the arguments
    Params params = parse(args);
    SparkConf conf = new SparkConf().setAppName("JavaOneVsRestExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);

    // everything after context creation runs under try/finally so the context is stopped
    // even when loading, training or scoring throws
    try {
      SQLContext jsql = new SQLContext(jsc);

      // configure the base classifier
      LogisticRegression classifier =
          new LogisticRegression()
              .setMaxIter(params.maxIter)
              .setTol(params.tol)
              .setFitIntercept(params.fitIntercept);

      // optional hyper-parameters are applied only when they were supplied
      if (params.regParam != null) {
        classifier.setRegParam(params.regParam);
      }
      if (params.elasticNetParam != null) {
        classifier.setElasticNetParam(params.elasticNetParam);
      }

      // instantiate the One Vs Rest Classifier
      OneVsRest ovr = new OneVsRest().setClassifier(classifier);

      String input = params.input;
      RDD<LabeledPoint> inputData = MLUtils.loadLibSVMFile(jsc.sc(), input);
      RDD<LabeledPoint> train;
      RDD<LabeledPoint> test;

      // compute the train/ test split: if testInput is not provided use part of input
      String testInput = params.testInput;
      if (testInput != null) {
        train = inputData;
        // compute the number of features in the training set.
        int numFeatures = inputData.first().features().size();
        test = MLUtils.loadLibSVMFile(jsc.sc(), testInput, numFeatures);
      } else {
        double f = params.fracTest;
        // fixed seed keeps the split reproducible between runs
        RDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[] {1 - f, f}, 12345);
        train = tmp[0];
        test = tmp[1];
      }

      // train the multiclass model
      DataFrame trainingDataFrame = jsql.createDataFrame(train, LabeledPoint.class);
      OneVsRestModel ovrModel = ovr.fit(trainingDataFrame.cache());

      // score the model on test data
      DataFrame testDataFrame = jsql.createDataFrame(test, LabeledPoint.class);
      DataFrame predictions =
          ovrModel.transform(testDataFrame.cache()).select("prediction", "label");

      // obtain metrics
      MulticlassMetrics metrics = new MulticlassMetrics(predictions);
      StructField predictionColSchema = predictions.schema().apply("prediction");
      // unbox once here instead of on every loop comparison
      int numClasses = (Integer) MetadataUtils.getNumClasses(predictionColSchema).get();

      // compute the false positive rate per label
      StringBuilder results = new StringBuilder();
      results.append("label\tfpr\n");
      for (int label = 0; label < numClasses; label++) {
        results.append(label);
        results.append("\t");
        results.append(metrics.falsePositiveRate((double) label));
        results.append("\n");
      }

      Matrix confusionMatrix = metrics.confusionMatrix();
      // output the Confusion Matrix
      System.out.println("Confusion Matrix");
      System.out.println(confusionMatrix);
      System.out.println();
      System.out.println(results);
    } finally {
      jsc.stop();
    }
  }