示例#1
0
  /**
   * Trains a logistic-regression model from a comma-separated text file and saves it.
   *
   * <p>Each line is split on commas. Columns 1-3 are parsed as the three features and column 3
   * as the label. Lines with fewer than four columns, or with a non-numeric value in columns
   * 1-3 (for example a header row), are skipped instead of failing the job. The trained model
   * is saved to {@code outputFile}.
   *
   * <p>{@code parseArgs}, {@code inputFile}, {@code outputFile} and {@code lrs} are declared
   * elsewhere in the enclosing class.
   *
   * @param args command-line arguments, interpreted by {@code parseArgs}
   * @throws IllegalArgumentException if {@code parseArgs} rejects the arguments
   */
  public static void main(String[] args) {
    if (!parseArgs(args)) {
      throw new IllegalArgumentException("Parse arguments failed.");
    }

    SparkConf conf = new SparkConf().setAppName("Logistic Regression with SGD");
    SparkContext sc = new SparkContext(conf);
    try {
      JavaRDD<String> data = sc.textFile(inputFile, 1).toJavaRDD();
      JavaRDD<LabeledPoint> training =
          data.map(
                  new Function<String, LabeledPoint>() {
                    @Override
                    public LabeledPoint call(String line) {
                      String[] splits = line.split(",");
                      // Without this guard a short line throws
                      // ArrayIndexOutOfBoundsException and aborts the whole job.
                      if (splits.length < 4) {
                        return null;
                      }
                      try {
                        double[] features = {
                          Double.parseDouble(splits[1]),
                          Double.parseDouble(splits[2]),
                          Double.parseDouble(splits[3])
                        };
                        // NOTE(review): column 3 is used both as the label and as
                        // features[2], so the label leaks into the feature vector.
                        // Confirm the intended column layout.
                        double label = Double.parseDouble(splits[3]);
                        return new LabeledPoint(label, Vectors.dense(features));
                      } catch (NumberFormatException e) {
                        // Non-numeric field (e.g. a header line): drop this line.
                        return null;
                      }
                    }
                  })
              // Remove the lines that were rejected above.
              .filter(
                  new Function<LabeledPoint, Boolean>() {
                    @Override
                    public Boolean call(LabeledPoint p) {
                      return p != null;
                    }
                  })
              .cache();

      LogisticRegressionModel model = lrs.run(training.rdd());
      model.save(sc, outputFile);
    } finally {
      // Release the Spark context even when training or saving fails.
      sc.stop();
    }
  }
  /**
   * Trains the PDB primary-citation classifier and writes its metrics to a text file.
   *
   * <p>Reads {@code PositivesI.parquet} and {@code PositivesII.parquet} from the working
   * directory. Each row's {@code label} is {@code primary_citation} cast to double. Rows of
   * PositivesII whose sentence contains "deposited" or "submitted" are excluded. Both sets are
   * passed to {@code train}, and its result is written to
   * {@code PdbPrimaryCitationMetrics.txt} in the working directory.
   *
   * @param args {@code args[0]} is the working directory holding inputs and receiving outputs
   * @throws FileNotFoundException if the metrics file cannot be created
   * @throws IllegalArgumentException if no working directory argument is supplied
   */
  public static void main(String[] args) throws FileNotFoundException {
    // Validate arguments before acquiring the Spark context so a bad invocation leaks nothing.
    if (args.length < 1) {
      throw new IllegalArgumentException("Usage: <workingDirectory>");
    }
    String workingDirectory = args[0];

    // Set up contexts.
    SparkContext sc = SparkUtils.getSparkContext();
    try {
      SQLContext sqlContext = SparkUtils.getSqlContext(sc);

      // Read positive data set, cases where the PDB ID occurs in the sentence of the primary
      // citation.
      String positivesIFileName = workingDirectory + "/" + "PositivesI.parquet";
      DataFrame positivesI = sqlContext.read().parquet(positivesIFileName);

      sqlContext.registerDataFrameAsTable(positivesI, "positivesI");
      DataFrame positives =
          sqlContext.sql(
              "SELECT pdb_id, match_type, deposition_year, pmc_id, pm_id, publication_year, CAST(primary_citation AS double) AS label, sentence, blinded_sentence  FROM positivesI");
      positives.show(10);

      String positivesIIFileName = workingDirectory + "/" + "PositivesII.parquet";
      DataFrame positivesII = sqlContext.read().parquet(positivesIIFileName);
      // Optional down-sampling used in earlier experiments:
      // positivesII.sample(false, 0.16, 1) keeps roughly 16% of the rows.
      sqlContext.registerDataFrameAsTable(positivesII, "positivesII");
      // PositivesII rows (minus "deposited"/"submitted" sentences) form the negative set.
      DataFrame negatives =
          sqlContext.sql(
              "SELECT pdb_id, match_type, deposition_year, pmc_id, pm_id, publication_year, CAST(primary_citation AS double) AS label, sentence, blinded_sentence FROM positivesII WHERE sentence NOT LIKE '%deposited%' AND sentence NOT LIKE '%submitted%'");
      negatives.show(10);

      long start = System.nanoTime();

      String metricFileName = workingDirectory + "/" + "PdbPrimaryCitationMetrics.txt";
      String modelFileName = workingDirectory + "/" + "PdbPrimaryCitationModel.ser";
      // try-with-resources closes the metrics file even if train(...) throws.
      try (PrintWriter writer = new PrintWriter(metricFileName)) {
        writer.println("PDB Primary Citation Classification: Logistic Regression Results");
        writer.println(train(sqlContext, positives, negatives, modelFileName));
      }

      long end = System.nanoTime();
      System.out.println("Time: " + (end - start) / 1E9 + " sec.");
    } finally {
      // Stop Spark even when reading, training, or writing fails.
      sc.stop();
    }
  }
 /**
  * Clears {@code NAME_TO_RDD}, stops {@code CONTEXT} if one is set, and resets it to null.
  *
  * <p>{@code CONTEXT} is reset in a {@code finally} block so that an exception from
  * {@code stop()} still propagates but does not leave a stale, half-stopped reference behind.
  */
 public static void close() {
   NAME_TO_RDD.clear();
   try {
     if (null != CONTEXT) {
       CONTEXT.stop();
     }
   } finally {
     CONTEXT = null;
   }
 }