/**
   * Fits a logistic regression pipeline on the union of {@code positives} and {@code negatives},
   * attempts to serialize the fitted model to {@code modelFileName}, and returns the metrics
   * reported for the training and test splits.
   *
   * <p>The combined data is split 80/20 using a fixed seed of 1, so the split is repeatable for
   * the same input. Every DataFrame cached by this method is unpersisted before it returns.
   *
   * @param sqlContext not used by this method; kept so existing callers continue to compile
   * @param positives first data set to combine
   * @param negatives second data set to combine; must be union-compatible with {@code positives}
   * @param modelFileName destination passed to {@code ObjectSerializer.serialize}
   * @return the result of {@code getMetrics} for the training split followed by the result for
   *     the test split
   */
  private static String train(
      SQLContext sqlContext, DataFrame positives, DataFrame negatives, String modelFileName) {
    // combine data sets
    DataFrame all = positives.unionAll(negatives);

    // split into training and test sets
    DataFrame[] split = all.randomSplit(new double[] {.80, .20}, 1);
    DataFrame training = split[0].cache();
    DataFrame test = split[1].cache();
    DataFrame trainingResults = null;
    DataFrame testResults = null;

    try {
      // fit logistic regression model
      PipelineModel model = fitLogisticRegressionModel(training);

      try {
        ObjectSerializer.serialize(model, modelFileName);
      } catch (IOException e) {
        // Saving the model is best-effort: report which file failed and still evaluate the fit.
        System.err.println("Failed to serialize model to '" + modelFileName + "'");
        e.printStackTrace();
      }

      // predict on training data to evaluate goodness of fit
      trainingResults = model.transform(training).cache();

      // predict on test set to evaluate goodness of fit
      testResults = model.transform(test).cache();

      StringBuilder sb = new StringBuilder();
      sb.append(getMetrics(trainingResults, "Training\n"));
      sb.append(getMetrics(testResults, "Testing\n"));
      return sb.toString();
    } finally {
      // The metrics are already rendered into a String, so the cached data is no longer needed.
      if (testResults != null) {
        testResults.unpersist();
      }
      if (trainingResults != null) {
        trainingResults.unpersist();
      }
      test.unpersist();
      training.unpersist();
    }
  }
  /**
   * Runs a decision tree regression example: loads the LIBSVM sample data, indexes categorical
   * features, trains a {@code DecisionTreeRegressor} on a random 70% split, and prints sample
   * rows, the RMSE on the remaining 30%, and the learned tree.
   *
   * <p>The split is unseeded, so the printed RMSE and tree can differ between runs. The Spark
   * context is stopped when the example finishes, whether or not it succeeded.
   *
   * @param args command-line arguments; not used
   */
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    try {
      SQLContext sqlContext = new SQLContext(jsc);
      // $example on$
      // Load the data stored in LIBSVM format as a DataFrame.
      DataFrame data =
          sqlContext.read().format("libsvm").load("data/mllib/sample_libsvm_data.txt");

      // Automatically identify categorical features, and index them.
      // Set maxCategories so features with > 4 distinct values are treated as continuous.
      VectorIndexerModel featureIndexer =
          new VectorIndexer()
              .setInputCol("features")
              .setOutputCol("indexedFeatures")
              .setMaxCategories(4)
              .fit(data);

      // Split the data into training and test sets (30% held out for testing)
      DataFrame[] splits = data.randomSplit(new double[] {0.7, 0.3});
      DataFrame trainingData = splits[0];
      DataFrame testData = splits[1];

      // Train a DecisionTree model.
      DecisionTreeRegressor dt = new DecisionTreeRegressor().setFeaturesCol("indexedFeatures");

      // Chain indexer and tree in a Pipeline
      Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {featureIndexer, dt});

      // Train model.  This also runs the indexer.
      PipelineModel model = pipeline.fit(trainingData);

      // Make predictions.
      DataFrame predictions = model.transform(testData);

      // Select example rows to display.
      predictions.select("label", "features").show(5);

      // Select (prediction, true label) and compute test error
      RegressionEvaluator evaluator =
          new RegressionEvaluator()
              .setLabelCol("label")
              .setPredictionCol("prediction")
              .setMetricName("rmse");
      double rmse = evaluator.evaluate(predictions);
      System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

      // Stage 1 of the fitted pipeline is the tree (stage 0 is the feature indexer).
      DecisionTreeRegressionModel treeModel = (DecisionTreeRegressionModel) (model.stages()[1]);
      System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
      // $example off$
    } finally {
      // Release the Spark context even when the example fails part-way through.
      jsc.stop();
    }
  }
  // Example #3
  // 0
  /**
   * Loads documents from Solr, scores a sample of them one at a time with a previously saved
   * cross-validated pipeline model, and prints how long context creation and scoring took.
   *
   * <p>Reads the {@code content_txt} field of the {@code ml20news} collection via the ZooKeeper
   * host {@code localhost:9983}, samples a fraction of 0.1 without replacement (seed 5150), and
   * loads the model from {@code ml-pipeline-model}. Each sampled document is wrapped in its own
   * single-row DataFrame before scoring, so the reported time reflects per-document scoring. The
   * Spark context created here is stopped before the method returns or throws.
   *
   * @param conf Spark configuration; {@code spark.ui.enabled} is set to {@code false} on it
   * @param cli parsed command line; not used by this method
   * @return always 0
   * @throws Exception if loading the data, loading the model, or scoring fails
   */
  @Override
  public int run(SparkConf conf, CommandLine cli) throws Exception {

    long startMs = System.currentTimeMillis();

    conf.set("spark.ui.enabled", "false");

    JavaSparkContext jsc = new JavaSparkContext(conf);
    try {
      SQLContext sqlContext = new SQLContext(jsc);

      long diffMs = (System.currentTimeMillis() - startMs);
      System.out.println(">> took " + diffMs + " ms to create SQLContext");

      Map<String, String> options = new HashMap<>();
      options.put("zkhost", "localhost:9983");
      options.put("collection", "ml20news");
      options.put("query", "content_txt:[* TO *]");
      options.put("fields", "content_txt");

      DataFrame solrData = sqlContext.read().format("solr").options(options).load();
      DataFrame sample = solrData.sample(false, 0.1d, 5150).select("content_txt");
      List<Row> rows = sample.collectAsList();
      System.out.println(">> loaded " + rows.size() + " docs to classify");

      StructType schema = sample.schema();

      CrossValidatorModel cvModel = CrossValidatorModel.load("ml-pipeline-model");
      PipelineModel bestModel = (PipelineModel) cvModel.bestModel();

      startMs = System.currentTimeMillis();
      for (Row next : rows) {
        // Score each document in its own single-row DataFrame; this is the path being timed.
        Row oneRow = RowFactory.create(next.getString(0));
        DataFrame oneRowDF =
            sqlContext.createDataFrame(Collections.<Row>singletonList(oneRow), schema);
        DataFrame scored = bestModel.transform(oneRowDF);
        Row scoredRow = scored.collect()[0];

        // An actual app would save the predictedLabel; here it is only read back.
        String predictedLabel = scoredRow.getString(scoredRow.fieldIndex("predictedLabel"));
      }
      diffMs = (System.currentTimeMillis() - startMs);
      System.out.println(">> took " + diffMs + " ms to score " + rows.size() + " docs");

      return 0;
    } finally {
      // This method created the context, so it is responsible for shutting it down.
      jsc.stop();
    }
  }