Example #1
  public static void main(String[] args) {
    if (!parseArgs(args)) throw new IllegalArgumentException("Failed to parse arguments.");

    SparkConf conf = new SparkConf().setAppName("Logistic Regression with SGD");
    SparkContext sc = new SparkContext(conf);

    // 'inputFile' (and 'outputFile' below) are assumed to be class fields populated by parseArgs().
    JavaRDD<String> data = sc.textFile(inputFile, 1).toJavaRDD();
    JavaRDD<LabeledPoint> training =
        data.map(
                new Function<String, LabeledPoint>() {
                  public LabeledPoint call(String line) {
                    String[] splits = line.split(",");
                    double[] features = new double[3];
                    try {
                      features[0] = Double.valueOf(splits[1]);
                      features[1] = Double.valueOf(splits[2]);
                      features[2] = Double.valueOf(splits[3]);
                      // the label is assumed to come from the first CSV column; columns 1-3 hold the features
                      return new LabeledPoint(Double.valueOf(splits[0]), Vectors.dense(features));
                    } catch (NumberFormatException e) {
                      return null; // malformed line; dropped by the filter below
                    }
                  }
                })
            .filter(
                new Function<LabeledPoint, Boolean>() {
                  public Boolean call(LabeledPoint p) {
                    return p != null;
                  }
                })
            .cache();

    // 'lrs' is assumed to be a LogisticRegressionWithSGD trainer held in a class field
    LogisticRegressionModel model = lrs.run(training.rdd());
    model.save(sc, outputFile);
    sc.stop();
  }
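
Example #1 calls lrs.run(...) on a trainer that the snippet never declares. A minimal sketch of how that field might be set up, assuming Spark MLlib's org.apache.spark.mllib.classification.LogisticRegressionWithSGD and purely illustrative parameter values:

  // Hypothetical declaration of the 'lrs' field used in main() above;
  // the iteration count and step size are example values only.
  private static final LogisticRegressionWithSGD lrs = new LogisticRegressionWithSGD();

  static {
    lrs.optimizer()
        .setNumIterations(100) // number of SGD iterations
        .setStepSize(1.0);     // SGD step size
  }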
Example #2
 public static void create(final Configuration configuration) {
   final SparkConf sparkConf = new SparkConf();
   configuration
       .getKeys()
       .forEachRemaining(key -> sparkConf.set(key, configuration.getProperty(key).toString()));
   sparkConf.setAppName("Apache TinkerPop's Spark-Gremlin");
   CONTEXT = SparkContext.getOrCreate(sparkConf);
 }
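
The create(Configuration) method above copies every property from the passed configuration (presumably an Apache Commons Configuration) into the SparkConf before creating the context. A hypothetical usage sketch, with example property values only:

   // build a Commons Configuration with Spark settings and hand it to create()
   final BaseConfiguration configuration = new BaseConfiguration();
   configuration.setProperty("spark.master", "local[4]");
   configuration.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
   create(configuration);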
Example #3
  public static void main(String[] args) throws FileNotFoundException {
    // Set up contexts.
    SparkContext sc = SparkUtils.getSparkContext();
    SQLContext sqlContext = SparkUtils.getSqlContext(sc);

    String workingDirectory = args[0];
    // read positive data set, cases where the PDB ID occurs in the sentence of the primary citation
    String positivesIFileName = workingDirectory + "/" + "PositivesI.parquet";
    DataFrame positivesI = sqlContext.read().parquet(positivesIFileName);

    sqlContext.registerDataFrameAsTable(positivesI, "positivesI");
    DataFrame positives =
        sqlContext.sql(
            "SELECT pdb_id, match_type, deposition_year, pmc_id, pm_id, publication_year, CAST(primary_citation AS double) AS label, sentence, blinded_sentence  FROM positivesI");
    positives.show(10);

    String positivesIIFileName = workingDirectory + "/" + "PositivesII.parquet";
    DataFrame positivesII = sqlContext.read().parquet(positivesIIFileName);
    //		System.out.println("Sampling 16% of positivesII");
    //		positivesII = positivesII.sample(false,  0.16, 1);
    sqlContext.registerDataFrameAsTable(positivesII, "positivesII");
    DataFrame negatives =
        sqlContext.sql(
            "SELECT pdb_id, match_type, deposition_year, pmc_id, pm_id, publication_year, CAST(primary_citation AS double) AS label, sentence, blinded_sentence FROM positivesII WHERE sentence NOT LIKE '%deposited%' AND sentence NOT LIKE '%submitted%'");
    negatives.show(10);

    long start = System.nanoTime();

    String metricFileName = workingDirectory + "/" + "PdbPrimaryCitationMetrics.txt";
    PrintWriter writer = new PrintWriter(metricFileName);

    writer.println("PDB Primary Citation Classification: Logistic Regression Results");

    String modelFileName = workingDirectory + "/" + "PdbPrimaryCitationModel.ser";
    writer.println(train(sqlContext, positives, negatives, modelFileName));
    writer.close();

    long end = System.nanoTime();
    System.out.println("Time: " + (end - start) / 1E9 + " sec.");

    sc.stop();
  }
Example #4
 public static void refresh() {
   if (null == CONTEXT) throw new IllegalStateException("The Spark context has not been created.");
   final Set<String> keepNames = new HashSet<>();
   for (final RDD<?> rdd : JavaConversions.asJavaIterable(CONTEXT.persistentRdds().values())) {
     if (null != rdd.name()) {
       keepNames.add(rdd.name());
       NAME_TO_RDD.put(rdd.name(), rdd);
     }
   }
   // remove all stale names in the NAME_TO_RDD map
   NAME_TO_RDD
       .keySet()
       .stream()
       .filter(key -> !keepNames.contains(key))
       .collect(Collectors.toList())
       .forEach(NAME_TO_RDD::remove);
 }
Example #5
  /** Converts a POLoad into an RDD of Tuples, backed by a NewHadoopRDD created via newAPIHadoopFile. */
  @Override
  public RDD<Tuple> convert(List<RDD<Tuple>> predecessorRdds, POLoad poLoad) throws IOException {
    if (predecessorRdds.size() != 0) {
      throw new RuntimeException(
          "Should not have predecessorRdds for Load. Got : " + predecessorRdds);
    }

    JobConf loadJobConf = SparkUtil.newJobConf(pigContext);
    configureLoader(physicalPlan, poLoad, loadJobConf);

    RDD<Tuple2<Text, Tuple>> hadoopRDD =
        sparkContext.newAPIHadoopFile(
            poLoad.getLFile().getFileName(),
            PigInputFormat.class, // InputFormat class
            Text.class, // K class
            Tuple.class, // V class
            loadJobConf);

    // map to get just RDD<Tuple>
    return hadoopRDD.map(TO_TUPLE_FUNCTION, ScalaUtil.getClassTag(Tuple.class));
  }
Example #6
 public static void create(final String master) {
   final SparkConf sparkConf = new SparkConf();
   sparkConf.setAppName("Apache TinkerPop's Spark-Gremlin");
   sparkConf.setMaster(master);
   CONTEXT = SparkContext.getOrCreate(sparkConf);
 }
Example #7
 public static void close() {
   NAME_TO_RDD.clear();
   if (null != CONTEXT) CONTEXT.stop();
   CONTEXT = null;
 }
Example #8
 @BeforeClass
 public static void setUp() {
   SparkConf sparkConf = new SparkConf().setMaster("local[*]").setAppName("SparkIT");
   javaSparkContext = JavaSparkContext.fromSparkContext(SparkContext.getOrCreate(sparkConf));
 }
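
A @BeforeClass setup like this is typically paired with a teardown that stops the context once the test class finishes. A minimal sketch, assuming the same static javaSparkContext field:

  @AfterClass
  public static void tearDown() {
    // stop the shared context so later test classes can start a fresh one
    if (javaSparkContext != null) {
      javaSparkContext.stop();
      javaSparkContext = null;
    }
  }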