示例#1
0
 /**
  * Verifies that {@code df} exposes the expected schema for columns "a" through "d" and that its
  * first row carries the same values as {@code bean}.
  *
  * @param bean the bean whose getters supply the expected values
  * @param df the DataFrame under test
  */
 void validateDataFrameWithBeans(Bean bean, DataFrame df) {
   StructType actualSchema = df.schema();

   // Column "a": double type.
   StructField expectedA = new StructField("a", DoubleType$.MODULE$, false, Metadata.empty());
   Assert.assertEquals(expectedA, actualSchema.apply("a"));

   // Column "b": array of integers.
   StructField expectedB =
       new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty());
   Assert.assertEquals(expectedB, actualSchema.apply("b"));

   // Column "c": map keyed by string whose values are integer arrays.
   MapType expectedCType =
       new MapType(DataTypes.StringType, new ArrayType(DataTypes.IntegerType, false), true);
   StructField expectedC = new StructField("c", expectedCType, true, Metadata.empty());
   Assert.assertEquals(expectedC, actualSchema.apply("c"));

   // Column "d": array of strings.
   StructField expectedD =
       new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty());
   Assert.assertEquals(expectedD, actualSchema.apply("d"));

   Row firstRow = df.select("a", "b", "c", "d").first();
   Assert.assertEquals(bean.getA(), firstRow.getDouble(0), 0.0);

   // Java lists and maps come back as Scala Seq's and Map's, so each Seq is checked for the
   // expected length first and then compared element by element.
   Seq<Integer> actualB = firstRow.getAs(1);
   Assert.assertEquals(bean.getB().length, actualB.length());
   for (int idx = 0; idx < actualB.length(); idx++) {
     Assert.assertEquals(bean.getB()[idx], actualB.apply(idx));
   }

   // The map column is read back as a Java map; only the "hello" entry is compared.
   @SuppressWarnings("unchecked")
   Seq<Integer> helloValues = (Seq<Integer>) firstRow.getJavaMap(2).get("hello");
   int[] expectedHello = bean.getC().get("hello");
   int[] actualHello =
       Ints.toArray(JavaConverters.seqAsJavaListConverter(helloValues).asJava());
   Assert.assertArrayEquals(expectedHello, actualHello);

   Seq<String> actualD = firstRow.getAs(3);
   Assert.assertEquals(bean.getD().size(), actualD.length());
   for (int idx = 0; idx < actualD.length(); idx++) {
     Assert.assertEquals(bean.getD().get(idx), actualD.apply(idx));
   }
 }
 /**
  * Runs {@code analyzer} over a DataFrame built from {@code testData} and asserts that, in every
  * resulting row, the "wantedTokens" column equals the "tokens" column.
  *
  * @param analyzer the transformer under test
  * @param testData bean instances to analyze; the bean class is taken from the first element
  */
 private <T> void assertExpectedTokens(LuceneAnalyzer analyzer, List<T> testData) {
   JavaRDD<T> inputRdd = jsc.parallelize(testData);
   DataFrame input = jsql.createDataFrame(inputRdd, testData.get(0).getClass());
   DataFrame analyzed = analyzer.transform(input);
   Row[] pairs = analyzed.select("wantedTokens", "tokens").collect();

   // Column 0 is "wantedTokens" and column 1 is "tokens", per the select above.
   for (int i = 0; i < pairs.length; i++) {
     Row pair = pairs[i];
     Assert.assertEquals(pair.get(0), pair.get(1));
   }
 }
示例#3
0
  /**
   * Writes the first {@code n} rows of {@code df} to the workflow output, then delegates to the
   * superclass implementation with the same arguments.
   */
  @Override
  public void execute(
      JavaSparkContext ctx, SQLContext sqlContext, WorkflowContext workflowContext, DataFrame df) {
    workflowContext.out("Executing NodePrintFirstNRows : " + id);

    // Only the leading rows are fetched; each is emitted via its toString() form.
    Row[] firstRows = df.take(n);
    for (int i = 0; i < firstRows.length; i++) {
      workflowContext.out(firstRows[i].toString());
    }

    super.execute(ctx, sqlContext, workflowContext, df);
  }
示例#4
0
  /**
   * Loads a sample of documents from Solr and scores them one row at a time with a previously
   * saved cross-validated pipeline model, printing how long context creation and scoring took.
   *
   * <p>The Spark context created here is stopped before returning, including when loading or
   * scoring throws.
   *
   * @param conf Spark configuration; "spark.ui.enabled" is set to "false" on it
   * @param cli parsed command line (not read by this method)
   * @return always 0
   * @throws Exception if context creation, data loading, model loading or scoring fails
   */
  @Override
  public int run(SparkConf conf, CommandLine cli) throws Exception {

    long startMs = System.currentTimeMillis();

    conf.set("spark.ui.enabled", "false");

    JavaSparkContext jsc = new JavaSparkContext(conf);
    try {
      SQLContext sqlContext = new SQLContext(jsc);

      long diffMs = (System.currentTimeMillis() - startMs);
      System.out.println(">> took " + diffMs + " ms to create SQLContext");

      Map<String, String> options = new HashMap<>();
      options.put("zkhost", "localhost:9983");
      options.put("collection", "ml20news");
      options.put("query", "content_txt:[* TO *]");
      options.put("fields", "content_txt");

      DataFrame solrData = sqlContext.read().format("solr").options(options).load();
      // Sample without replacement, fraction 0.1, fixed seed so runs are repeatable.
      DataFrame sample = solrData.sample(false, 0.1d, 5150).select("content_txt");
      List<Row> rows = sample.collectAsList();
      System.out.println(">> loaded " + rows.size() + " docs to classify");

      StructType schema = sample.schema();

      CrossValidatorModel cvModel = CrossValidatorModel.load("ml-pipeline-model");
      PipelineModel bestModel = (PipelineModel) cvModel.bestModel();

      startMs = System.currentTimeMillis();
      for (Row next : rows) {
        // Each document is wrapped in its own single-row DataFrame and scored individually.
        Row oneRow = RowFactory.create(next.getString(0));
        DataFrame oneRowDF =
            sqlContext.createDataFrame(Collections.<Row>singletonList(oneRow), schema);
        DataFrame scored = bestModel.transform(oneRowDF);
        Row scoredRow = scored.collect()[0];

        // An actual app would save the predicted label; here it is only extracted.
        String predictedLabel = scoredRow.getString(scoredRow.fieldIndex("predictedLabel"));
      }
      diffMs = (System.currentTimeMillis() - startMs);
      System.out.println(">> took " + diffMs + " ms to score " + rows.size() + " docs");

      return 0;
    } finally {
      // Release the context's resources even when loading or scoring fails.
      jsc.stop();
    }
  }
 /**
  * Returns the next row, or {@code null} when no rows remain.
  *
  * <p>When {@code incEnable} is set, {@code incMaxTS} is raised to the value of the returned
  * row's {@code timestampIndex} column whenever that value is larger.
  *
  * @return the next row, or {@code null} if {@link #hasNext()} is false
  */
 @Override
 public Row next() {
   if (!this.hasNext()) {
     // Exhausted: keep the existing "return null" contract, and do not touch the row below
     // (dereferencing it used to throw a NullPointerException when incEnable was set).
     return null;
   }
   Row row = this.rows.next();
   if (this.incEnable && row != null) {
     // Read the timestamp once and reuse it for both the comparison and the update.
     long rowTs = row.getLong(this.timestampIndex);
     if (rowTs > this.incMaxTS) {
       this.incMaxTS = rowTs;
     }
   }
   return row;
 }
示例#6
0
    /**
     * Loads tweets from the test JSON file, registers them as the temp table "tweets", and
     * returns the string form of at most ten rows (text, retweetCount) ordered by retweetCount.
     *
     * <p>NOTE(review): the query has no DESC, so with the default sort direction these are the
     * ten lowest retweet counts rather than the highest; confirm this is intended.
     *
     * @param jc job context supplying the SQLContext
     * @return one string per selected row, in query order
     */
    @Override
    public ArrayList<String> call(JobContext jc) {
      SQLContext sqlctx = jc.sqlctx();

      // Register the JSON data so it can be queried by name below.
      DataFrame tweets = sqlctx.jsonFile("src/test/resources/testweet.json");
      tweets.registerTempTable("tweets");

      DataFrame selected =
          sqlctx.sql("SELECT text, retweetCount FROM tweets ORDER BY retweetCount LIMIT 10");

      ArrayList<String> rendered = new ArrayList<>();
      for (Row tweet : selected.collect()) {
        rendered.add(tweet.toString());
      }
      return rendered;
    }
示例#7
0
 /**
  * Checks the crosstab of columns "a" and "b" of table "testData2": the result's column names,
  * and, after sorting the rows with {@code crosstabRowComparator}, that the first column counts
  * up from 1 while both count columns hold 1.
  */
 @Test
 public void testCrosstab() {
   DataFrame df = context.table("testData2");
   DataFrame crosstab = df.stat().crosstab("a", "b");
   String[] columnNames = crosstab.schema().fieldNames();
   Assert.assertEquals("a_b", columnNames[0]);
   Assert.assertEquals("2", columnNames[1]);
   Assert.assertEquals("1", columnNames[2]);
   Row[] rows = crosstab.collect();
   // Sort so the rows can be compared against a running counter.
   Arrays.sort(rows, crosstabRowComparator);
   int expectedKey = 1;
   for (Row row : rows) {
     // Expected value goes first so a failure message reports expected/actual correctly.
     Assert.assertEquals(String.valueOf(expectedKey), row.get(0).toString());
     Assert.assertEquals(1L, row.getLong(1));
     Assert.assertEquals(1L, row.getLong(2));
     expectedKey++;
   }
 }
示例#8
0
 /**
  * Orders two rows by the natural {@link String} ordering of their first column.
  *
  * @return a negative, zero or positive value as defined by {@link String#compareTo(String)}
  */
 @Override
 public int compare(Row row1, Row row2) {
   return row1.getString(0).compareTo(row2.getString(0));
 }