/**
 * Round-trips tuples of arity 2 through 5 through datasets created with the matching
 * {@code Encoders.tuple(...)} encoder and checks the collected rows equal the input.
 */
@Test
public void testTupleEncoder() {
  // Arity 2.
  Encoder<Tuple2<Integer, String>> pairEncoder =
      Encoders.tuple(Encoders.INT(), Encoders.STRING());
  List<Tuple2<Integer, String>> pairs = Arrays.asList(tuple2(1, "a"), tuple2(2, "b"));
  Dataset<Tuple2<Integer, String>> pairDs = context.createDataset(pairs, pairEncoder);
  Assert.assertEquals(pairs, pairDs.collectAsList());

  // Arity 3.
  Encoder<Tuple3<Integer, Long, String>> tripleEncoder =
      Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING());
  List<Tuple3<Integer, Long, String>> triples =
      Arrays.asList(new Tuple3<Integer, Long, String>(1, 2L, "a"));
  Dataset<Tuple3<Integer, Long, String>> tripleDs =
      context.createDataset(triples, tripleEncoder);
  Assert.assertEquals(triples, tripleDs.collectAsList());

  // Arity 4.
  Encoder<Tuple4<Integer, String, Long, String>> quadEncoder =
      Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING());
  List<Tuple4<Integer, String, Long, String>> quads =
      Arrays.asList(new Tuple4<Integer, String, Long, String>(1, "b", 2L, "a"));
  Dataset<Tuple4<Integer, String, Long, String>> quadDs =
      context.createDataset(quads, quadEncoder);
  Assert.assertEquals(quads, quadDs.collectAsList());

  // Arity 5.
  Encoder<Tuple5<Integer, String, Long, String, Boolean>> quintEncoder =
      Encoders.tuple(
          Encoders.INT(),
          Encoders.STRING(),
          Encoders.LONG(),
          Encoders.STRING(),
          Encoders.BOOLEAN());
  List<Tuple5<Integer, String, Long, String, Boolean>> quints =
      Arrays.asList(
          new Tuple5<Integer, String, Long, String, Boolean>(1, "b", 2L, "a", true));
  Dataset<Tuple5<Integer, String, Long, String, Boolean>> quintDs =
      context.createDataset(quints, quintEncoder);
  Assert.assertEquals(quints, quintDs.collectAsList());
}
@Test public void testNestedTupleEncoder() { // test ((int, string), string) Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder = Encoders.tuple(Encoders.tuple(Encoders.INT(), Encoders.STRING()), Encoders.STRING()); List<Tuple2<Tuple2<Integer, String>, String>> data = Arrays.asList(tuple2(tuple2(1, "a"), "a"), tuple2(tuple2(2, "b"), "b")); Dataset<Tuple2<Tuple2<Integer, String>, String>> ds = context.createDataset(data, encoder); Assert.assertEquals(data, ds.collectAsList()); // test (int, (string, string, long)) Encoder<Tuple2<Integer, Tuple3<String, String, Long>>> encoder2 = Encoders.tuple( Encoders.INT(), Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG())); List<Tuple2<Integer, Tuple3<String, String, Long>>> data2 = Arrays.asList(tuple2(1, new Tuple3<String, String, Long>("a", "b", 3L))); Dataset<Tuple2<Integer, Tuple3<String, String, Long>>> ds2 = context.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); // test (int, ((string, long), string)) Encoder<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> encoder3 = Encoders.tuple( Encoders.INT(), Encoders.tuple(Encoders.tuple(Encoders.STRING(), Encoders.LONG()), Encoders.STRING())); List<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> data3 = Arrays.asList(tuple2(1, tuple2(tuple2("a", 2L), "b"))); Dataset<Tuple2<Integer, Tuple2<Tuple2<String, Long>, String>>> ds3 = context.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); }
/**
 * Loads text resources found on the context class loader through {@code context.read().text(...)}
 * and checks the row counts: 4 for {@code text-suite.txt} alone, 5 together with
 * {@code text-suite2.txt}.
 */
@Test
public void testTextLoad() {
  // Single path.
  String firstPath =
      Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString();
  DataFrame single = context.read().text(firstPath);
  Assert.assertEquals(4L, single.count());

  // Two paths; both resources are looked up again, only after the first assertion has passed.
  String firstPathAgain =
      Thread.currentThread().getContextClassLoader().getResource("text-suite.txt").toString();
  String secondPath =
      Thread.currentThread().getContextClassLoader().getResource("text-suite2.txt").toString();
  DataFrame both = context.read().text(firstPathAgain, secondPath);
  Assert.assertEquals(5L, both.count());
}
/**
 * Joins two integer datasets aliased "a" and "b" on equal {@code value} columns with
 * {@code joinWith} and checks that only the shared values come back as pairs.
 */
@Test
public void testJoin() {
  Dataset<Integer> left = context.createDataset(Arrays.asList(1, 2, 3), Encoders.INT()).as("a");
  Dataset<Integer> right = context.createDataset(Arrays.asList(2, 3, 4), Encoders.INT()).as("b");

  Dataset<Tuple2<Integer, Integer>> matches =
      left.joinWith(right, col("a.value").equalTo(col("b.value")));

  List<Tuple2<Integer, Integer>> expected = Arrays.asList(tuple2(2, 2), tuple2(3, 3));
  Assert.assertEquals(expected, matches.collectAsList());
}
/**
 * Builds Bloom filters over {@code context.range(1000)} using the four
 * {@code df.stat().bloomFilter(...)} call shapes exercised here (column name or column
 * expression, combined with either a double or an integral third argument) and checks that
 * every inserted value is reported by {@code mightContain}.
 */
@Test
public void testBloomFilter() {
  DataFrame df = context.range(1000);

  BloomFilter filter1 = df.stat().bloomFilter("id", 1000, 0.03);
  // Upper-bound check only: passes whenever expectedFpp() is below 0.03 + 1e-3.
  Assert.assertTrue(filter1.expectedFpp() - 0.03 < 1e-3);
  for (int i = 0; i < 1000; i++) {
    Assert.assertTrue(filter1.mightContain(i));
  }

  BloomFilter filter2 = df.stat().bloomFilter(col("id").multiply(3), 1000, 0.03);
  Assert.assertTrue(filter2.expectedFpp() - 0.03 < 1e-3);
  for (int i = 0; i < 1000; i++) {
    Assert.assertTrue(filter2.mightContain(i * 3));
  }

  BloomFilter filter3 = df.stat().bloomFilter("id", 1000, 64 * 5);
  // assertEquals (rather than assertTrue(a == b)) reports the actual bit size on failure.
  Assert.assertEquals(64 * 5, filter3.bitSize());
  for (int i = 0; i < 1000; i++) {
    Assert.assertTrue(filter3.mightContain(i));
  }

  BloomFilter filter4 = df.stat().bloomFilter(col("id").multiply(3), 1000, 64 * 5);
  Assert.assertEquals(64 * 5, filter4.bitSize());
  for (int i = 0; i < 1000; i++) {
    Assert.assertTrue(filter4.mightContain(i * 3));
  }
}
/**
 * Groups (string, int) tuples by their first element and aggregates each group with
 * {@code IntSumOf}, checking the result both directly and after re-typing it with an explicit
 * tuple encoder.
 */
@Test
public void testTypedAggregation() {
  List<Tuple2<String, Integer>> input =
      Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
  Encoder<Tuple2<String, Integer>> inputEncoder =
      Encoders.tuple(Encoders.STRING(), Encoders.INT());
  Dataset<Tuple2<String, Integer>> dataset = context.createDataset(input, inputEncoder);

  // Grouping key is the tuple's first element.
  MapFunction<Tuple2<String, Integer>, String> firstElement =
      new MapFunction<Tuple2<String, Integer>, String>() {
        @Override
        public String call(Tuple2<String, Integer> pair) throws Exception {
          return pair._1();
        }
      };
  GroupedDataset<String, Tuple2<String, Integer>> byKey =
      dataset.groupBy(firstElement, Encoders.STRING());

  Dataset<Tuple2<String, Integer>> summed =
      byKey.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
  Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), summed.collectAsList());

  Dataset<Tuple2<String, Integer>> summedRetyped =
      byKey
          .agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
          .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
  Assert.assertEquals(
      Arrays.asList(new Tuple2<>("a", 3), new Tuple2<>("b", 3)),
      summedRetyped.collectAsList());
}
/**
 * Runs {@code freqItems} on column "a" of table "testData2" with support 0.2 and checks that
 * the first result row's sequence contains 1.
 */
@Test
public void testFrequentItems() {
  DataFrame testData2 = context.table("testData2");
  DataFrame frequent = testData2.stat().freqItems(new String[] {"a"}, 0.2);

  boolean containsOne = frequent.collect()[0].getSeq(0).contains(1);
  Assert.assertTrue(containsOne);
}
/** Checks that {@code takeAsList(1)} on a two-element string dataset returns only "hello". */
@Test
public void testTake() {
  Dataset<String> words =
      context.createDataset(Arrays.asList("hello", "world"), Encoders.STRING());

  List<String> firstOnly = words.takeAsList(1);

  Assert.assertEquals(Arrays.asList("hello"), firstOnly);
}
/**
 * Creates a DataFrame from a {@code JavaRDD} holding a single {@code Bean} and hands it to
 * {@code validateDataFrameWithBeans} for verification.
 */
@Test
public void testCreateDataFrameFromJavaBeans() {
  Bean expected = new Bean();
  JavaRDD<Bean> beans = jsc.parallelize(Arrays.asList(expected));

  DataFrame fromRdd = context.createDataFrame(beans, Bean.class);

  validateDataFrameWithBeans(expected, fromRdd);
}
/**
 * Creates a DataFrame from a local {@code List} holding a single {@code Bean} and hands it to
 * {@code validateDataFrameWithBeans} for verification.
 */
@Test
public void testCreateDataFrameFromLocalJavaBeans() {
  Bean expected = new Bean();
  List<Bean> beans = Arrays.asList(expected);

  DataFrame fromList = context.createDataFrame(beans, Bean.class);

  validateDataFrameWithBeans(expected, fromList);
}
@Ignore public void testShow() { // This test case is intended ignored, but to make sure it compiles correctly DataFrame df = context.table("testData"); df.show(); df.show(1000); }
/**
 * Exercises {@code first}, {@code filter}, {@code map}, {@code mapPartitions} and
 * {@code flatMap} on a two-element string dataset ("hello", "world").
 */
@Test
public void testCommonOperation() {
  Dataset<String> words =
      context.createDataset(Arrays.asList("hello", "world"), Encoders.STRING());
  Assert.assertEquals("hello", words.first());

  // filter: keep words starting with "h".
  FilterFunction<String> startsWithH =
      new FilterFunction<String>() {
        @Override
        public boolean call(String word) throws Exception {
          return word.startsWith("h");
        }
      };
  Dataset<String> onlyH = words.filter(startsWithH);
  Assert.assertEquals(Arrays.asList("hello"), onlyH.collectAsList());

  // map: word -> length.
  MapFunction<String, Integer> toLength =
      new MapFunction<String, Integer>() {
        @Override
        public Integer call(String word) throws Exception {
          return word.length();
        }
      };
  Dataset<Integer> lengths = words.map(toLength, Encoders.INT());
  Assert.assertEquals(Arrays.asList(5, 5), lengths.collectAsList());

  // mapPartitions: upper-case every element of a partition.
  MapPartitionsFunction<String, String> upperCaseAll =
      new MapPartitionsFunction<String, String>() {
        @Override
        public Iterable<String> call(Iterator<String> partition) throws Exception {
          List<String> upper = new LinkedList<>();
          while (partition.hasNext()) {
            String word = partition.next();
            upper.add(word.toUpperCase());
          }
          return upper;
        }
      };
  Dataset<String> upperCased = words.mapPartitions(upperCaseAll, Encoders.STRING());
  Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), upperCased.collectAsList());

  // flatMap: word -> one single-character string per char.
  FlatMapFunction<String, String> toLetters =
      new FlatMapFunction<String, String>() {
        @Override
        public Iterable<String> call(String word) throws Exception {
          List<String> letters = new LinkedList<>();
          for (int i = 0; i < word.length(); i++) {
            letters.add(String.valueOf(word.charAt(i)));
          }
          return letters;
        }
      };
  Dataset<String> letters = words.flatMap(toLetters, Encoders.STRING());
  Assert.assertEquals(
      Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
      letters.collectAsList());
}
/**
 * Creates a DataFrame from a local list holding one row and a single-field schema
 * (nullable integer "i"), then checks that exactly one row is collected.
 */
@Test
public void testCreateDataFromFromList() {
  List<Row> singleRow = Arrays.asList(RowFactory.create(0));
  StructType singleIntField =
      createStructType(Arrays.asList(createStructField("i", IntegerType, true)));

  DataFrame frame = context.createDataFrame(singleRow, singleIntField);

  Row[] collected = frame.collect();
  Assert.assertEquals(1, collected.length);
}
/**
 * Samples {@code id % 3} keys of a 100-row, 2-partition range by key with fractions
 * 0 -> 0.1 and 1 -> 0.2 (seed 0) and checks the exact per-key counts.
 */
@Test
public void testSampleBy() {
  DataFrame keyed = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));

  DataFrame stratified =
      keyed.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
  Row[] countsByKey = stratified.groupBy("key").count().orderBy("key").collect();

  Row[] expectedCounts = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};
  Assert.assertArrayEquals(expectedCounts, countsByKey);
}
/**
 * Round-trips {@code JavaSerializable} values through a dataset that uses
 * {@code Encoders.javaSerialization}.
 */
@Test
public void testJavaEncoder() {
  List<JavaSerializable> values =
      Arrays.asList(new JavaSerializable("hello"), new JavaSerializable("world"));
  Encoder<JavaSerializable> javaEncoder = Encoders.javaSerialization(JavaSerializable.class);

  Dataset<JavaSerializable> dataset = context.createDataset(values, javaEncoder);

  Assert.assertEquals(values, dataset.collectAsList());
}
@Before public void setUp() { // Trigger static initializer of TestData SparkContext sc = new SparkContext("local[*]", "testing"); jsc = new JavaSparkContext(sc); context = new TestSQLContext(sc); context.loadTestData(); }
/** Round-trips {@code KryoSerializable} values through a dataset that uses {@code Encoders.kryo}. */
@Test
public void testKryoEncoder() {
  List<KryoSerializable> values =
      Arrays.asList(new KryoSerializable("hello"), new KryoSerializable("world"));
  Encoder<KryoSerializable> kryoEncoder = Encoders.kryo(KryoSerializable.class);

  Dataset<KryoSerializable> dataset = context.createDataset(values, kryoEncoder);

  Assert.assertEquals(values, dataset.collectAsList());
}
/**
 * Samples {@code id % 3} keys of a 100-row, 2-partition range by key with fractions
 * 0 -> 0.1 and 1 -> 0.2 (seed 0) and checks that each per-key count falls inside a
 * tolerated range instead of matching an exact value.
 */
@Test
public void testSampleBy() {
  DataFrame keyed = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
  DataFrame stratified =
      keyed.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
  Row[] countsByKey = stratified.groupBy("key").count().orderBy("key").collect();

  // Key 0: between 0 and 8 sampled rows.
  Assert.assertEquals(0, countsByKey[0].getLong(0));
  long key0Count = countsByKey[0].getLong(1);
  Assert.assertTrue(key0Count >= 0 && key0Count <= 8);

  // Key 1: between 2 and 13 sampled rows.
  Assert.assertEquals(1, countsByKey[1].getLong(0));
  long key1Count = countsByKey[1].getLong(1);
  Assert.assertTrue(key1Count >= 2 && key1Count <= 13);
}
/** See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java. */ @Test public void testVarargMethods() { DataFrame df = context.table("testData"); df.toDF("key1", "value1"); df.select("key", "value"); df.select(col("key"), col("value")); df.selectExpr("key", "value + 1"); df.sort("key", "value"); df.sort(col("key"), col("value")); df.orderBy("key", "value"); df.orderBy(col("key"), col("value")); df.groupBy("key", "value").agg(col("key"), col("value"), sum("value")); df.groupBy(col("key"), col("value")).agg(col("key"), col("value"), sum("value")); df.agg(first("key"), sum("value")); df.groupBy().avg("key"); df.groupBy().mean("key"); df.groupBy().max("key"); df.groupBy().min("key"); df.groupBy().stddev("key"); df.groupBy().sum("key"); // Varargs in column expressions df.groupBy().agg(countDistinct("key", "value")); df.groupBy().agg(countDistinct(col("key"), col("value"))); df.select(coalesce(col("key"))); // Varargs with mathfunctions DataFrame df2 = context.table("testData2"); df2.select(exp("a"), exp("b")); df2.select(exp(log("a"))); df2.select(pow("a", "a"), pow("b", 2.0)); df2.select(pow(col("a"), col("b")), exp("b")); df2.select(sin("a"), acos("b")); df2.select(rand(), acos("b")); df2.select(col("*"), randn(5L)); }
/**
 * Selects {@code value + 1} and {@code value} cast to string from an integer dataset, types
 * the result as (int, string) tuples, and checks the collected pairs.
 */
@Test
public void testSelect() {
  Dataset<Integer> numbers = context.createDataset(Arrays.asList(2, 6), Encoders.INT());

  Dataset<Tuple2<Integer, String>> projected =
      numbers
          .select(expr("value + 1"), col("value").cast("string"))
          .as(Encoders.tuple(Encoders.INT(), Encoders.STRING()));

  List<Tuple2<Integer, String>> expected = Arrays.asList(tuple2(3, "2"), tuple2(7, "6"));
  Assert.assertEquals(expected, projected.collectAsList());
}
/**
 * Exercises {@code distinct}, {@code intersect}, {@code union} and {@code subtract} on two
 * small string datasets. The distinct and union results are passed through the {@code sort}
 * helper before comparison; the other two are compared as collected.
 */
@Test
public void testSetOperation() {
  Dataset<String> left =
      context.createDataset(Arrays.asList("abc", "abc", "xyz"), Encoders.STRING());

  String[] distinctValues = left.distinct().collectAsList().toArray(new String[0]);
  Assert.assertEquals(Arrays.asList("abc", "xyz"), sort(distinctValues));

  Dataset<String> right =
      context.createDataset(Arrays.asList("xyz", "foo", "foo"), Encoders.STRING());

  Dataset<String> common = left.intersect(right);
  Assert.assertEquals(Arrays.asList("xyz"), common.collectAsList());

  Dataset<String> combined = left.union(right);
  String[] combinedValues = combined.collectAsList().toArray(new String[0]);
  Assert.assertEquals(
      Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"), sort(combinedValues));

  Dataset<String> leftOnly = left.subtract(right);
  Assert.assertEquals(Arrays.asList("abc", "abc"), leftOnly.collectAsList());
}
/**
 * Runs {@code foreach} over a three-element dataset, adding 1 to an accumulator per element,
 * and checks the accumulator ends at 3.
 */
@Test
public void testForeach() {
  final Accumulator<Integer> visited = jsc.accumulator(0);
  Dataset<String> letters =
      context.createDataset(Arrays.asList("a", "b", "c"), Encoders.STRING());

  ForeachFunction<String> countOne =
      new ForeachFunction<String>() {
        @Override
        public void call(String ignoredElement) throws Exception {
          visited.add(1);
        }
      };
  letters.foreach(countOne);

  Assert.assertEquals(3, visited.value().intValue());
}
/** Reduces the integer dataset (1, 2, 3) with addition and checks the result is 6. */
@Test
public void testReduce() {
  Dataset<Integer> numbers = context.createDataset(Arrays.asList(1, 2, 3), Encoders.INT());

  ReduceFunction<Integer> add =
      new ReduceFunction<Integer>() {
        @Override
        public Integer call(Integer left, Integer right) throws Exception {
          return left + right;
        }
      };
  int total = numbers.reduce(add);

  Assert.assertEquals(6, total);
}
/**
 * Computes the crosstab of columns "a" and "b" of table "testData2" and checks the generated
 * column names and, after sorting the rows with {@code crosstabRowComparator}, that the first
 * column counts up from 1 and both count columns are 1 in every row.
 */
@Test
public void testCrosstab() {
  DataFrame df = context.table("testData2");
  DataFrame crosstab = df.stat().crosstab("a", "b");

  String[] columnNames = crosstab.schema().fieldNames();
  Assert.assertEquals("a_b", columnNames[0]);
  Assert.assertEquals("2", columnNames[1]);
  Assert.assertEquals("1", columnNames[2]);

  Row[] rows = crosstab.collect();
  Arrays.sort(rows, crosstabRowComparator);

  // Primitive counter avoids re-boxing an Integer on every iteration.
  int expectedKey = 1;
  for (Row row : rows) {
    // JUnit order is (expected, actual) so a failure message names the right side.
    Assert.assertEquals(String.valueOf(expectedKey), row.get(0).toString());
    Assert.assertEquals(1L, row.getLong(1));
    Assert.assertEquals(1L, row.getLong(2));
    expectedKey++;
  }
}
/**
 * Pivots table "courseSales" on "course" (values "dotNET" and "Java"), summing "earnings" per
 * "year", and checks the two resulting rows ordered by year.
 */
@Test
public void pivot() {
  DataFrame courseSales = context.table("courseSales");
  DataFrame pivoted =
      courseSales
          .groupBy("year")
          .pivot("course", Arrays.<Object>asList("dotNET", "Java"))
          .agg(sum("earnings"));
  Row[] byYear = pivoted.orderBy("year").collect();

  Row first = byYear[0];
  Assert.assertEquals(2012, first.getInt(0));
  Assert.assertEquals(15000.0, first.getDouble(1), 0.01);
  Assert.assertEquals(20000.0, first.getDouble(2), 0.01);

  Row second = byYear[1];
  Assert.assertEquals(2013, second.getInt(0));
  Assert.assertEquals(48000.0, second.getDouble(1), 0.01);
  Assert.assertEquals(30000.0, second.getDouble(2), 0.01);
}
/**
 * Round-trips a single (double, decimal, date, timestamp, float) tuple through a dataset built
 * from the corresponding {@code Encoders} factories and checks the collected row equals the
 * input.
 */
@Test
public void testPrimitiveEncoder() {
  Tuple5<Double, BigDecimal, Date, Timestamp, Float> sample =
      new Tuple5<Double, BigDecimal, Date, Timestamp, Float>(
          Double.MAX_VALUE, // same value as the literal 1.7976931348623157E308
          new BigDecimal("0.922337203685477589"),
          Date.valueOf("1970-01-01"),
          new Timestamp(System.currentTimeMillis()),
          Float.MAX_VALUE);
  List<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> input = Arrays.asList(sample);

  Encoder<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> tupleEncoder =
      Encoders.tuple(
          Encoders.DOUBLE(),
          Encoders.DECIMAL(),
          Encoders.DATE(),
          Encoders.TIMESTAMP(),
          Encoders.FLOAT());

  Dataset<Tuple5<Double, BigDecimal, Date, Timestamp, Float>> dataset =
      context.createDataset(input, tupleEncoder);
  Assert.assertEquals(input, dataset.collectAsList());
}
/**
 * Groups strings by the length of their {@code value} column, keys the groups as integers,
 * and maps each group to its key followed by the concatenation of its members.
 */
@Test
public void testGroupByColumn() {
  Dataset<String> words =
      context.createDataset(Arrays.asList("a", "foo", "bar"), Encoders.STRING());
  GroupedDataset<Integer, String> byLength =
      words.groupBy(length(col("value"))).keyAs(Encoders.INT());

  MapGroupsFunction<Integer, String, String> keyThenMembers =
      new MapGroupsFunction<Integer, String, String>() {
        @Override
        public String call(Integer key, Iterator<String> members) throws Exception {
          StringBuilder joined = new StringBuilder(key.toString());
          while (members.hasNext()) {
            joined.append(members.next());
          }
          return joined.toString();
        }
      };
  Dataset<String> summaries = byLength.mapGroups(keyThenMembers, Encoders.STRING());

  Assert.assertEquals(Arrays.asList("1a", "3foobar"), summaries.collectAsList());
}
/**
 * Builds count-min sketches over {@code context.range(1000)} using the four
 * {@code df.stat().countMinSketch(...)} call shapes exercised here (column name or column
 * expression, combined with either integer or double second and third arguments) and checks
 * the properties each sketch reports.
 *
 * <p>All assertions use JUnit's (expected, actual) argument order so failure messages name
 * the right side.
 */
@Test
public void testCountMinSketch() {
  DataFrame df = context.range(1000);

  CountMinSketch sketch1 = df.stat().countMinSketch("id", 10, 20, 42);
  Assert.assertEquals(1000, sketch1.totalCount());
  Assert.assertEquals(10, sketch1.depth());
  Assert.assertEquals(20, sketch1.width());

  CountMinSketch sketch2 = df.stat().countMinSketch(col("id"), 10, 20, 42);
  Assert.assertEquals(1000, sketch2.totalCount());
  Assert.assertEquals(10, sketch2.depth());
  Assert.assertEquals(20, sketch2.width());

  CountMinSketch sketch3 = df.stat().countMinSketch("id", 0.001, 0.99, 42);
  Assert.assertEquals(1000, sketch3.totalCount());
  Assert.assertEquals(0.001, sketch3.relativeError(), 1e-4);
  Assert.assertEquals(0.99, sketch3.confidence(), 5e-3);

  CountMinSketch sketch4 = df.stat().countMinSketch(col("id"), 0.001, 0.99, 42);
  Assert.assertEquals(1000, sketch4.totalCount());
  Assert.assertEquals(0.001, sketch4.relativeError(), 1e-4);
  Assert.assertEquals(0.99, sketch4.confidence(), 5e-3);
}
/**
 * Round-trips {@code SimpleJavaBean} and {@code NestedJavaBean} instances through datasets
 * created with {@code Encoders.bean}, then builds the same simple beans from generic rows with
 * an explicit schema and checks they convert to equal bean instances.
 */
@Test
public void testJavaBeanEncoder() {
  OuterScopes.addOuterScope(this);

  SimpleJavaBean firstBean = new SimpleJavaBean();
  firstBean.setA(true);
  firstBean.setB(3);
  firstBean.setC(new byte[] {1, 2});
  firstBean.setD(new String[] {"hello", null});
  firstBean.setE(Arrays.asList("a", "b"));
  firstBean.setF(Arrays.asList(100L, null, 200L));

  SimpleJavaBean secondBean = new SimpleJavaBean();
  secondBean.setA(false);
  secondBean.setB(30);
  secondBean.setC(new byte[] {3, 4});
  secondBean.setD(new String[] {null, "world"});
  secondBean.setE(Arrays.asList("x", "y"));
  secondBean.setF(Arrays.asList(300L, null, 400L));

  // Simple beans straight through the bean encoder.
  List<SimpleJavaBean> simpleBeans = Arrays.asList(firstBean, secondBean);
  Dataset<SimpleJavaBean> simpleDs =
      context.createDataset(simpleBeans, Encoders.bean(SimpleJavaBean.class));
  Assert.assertEquals(simpleBeans, simpleDs.collectAsList());

  // A bean that holds another bean.
  NestedJavaBean wrapper = new NestedJavaBean();
  wrapper.setA(firstBean);
  List<NestedJavaBean> nestedBeans = Arrays.asList(wrapper);
  Dataset<NestedJavaBean> nestedDs =
      context.createDataset(nestedBeans, Encoders.bean(NestedJavaBean.class));
  Assert.assertEquals(nestedBeans, nestedDs.collectAsList());

  // The same two simple beans, expressed as generic rows with an explicit schema.
  Object[] firstValues = {
    true,
    3,
    new byte[] {1, 2},
    new String[] {"hello", null},
    Arrays.asList("a", "b"),
    Arrays.asList(100L, null, 200L)
  };
  Row firstRow = new GenericRow(firstValues);
  Object[] secondValues = {
    false,
    30,
    new byte[] {3, 4},
    new String[] {null, "world"},
    Arrays.asList("x", "y"),
    Arrays.asList(300L, null, 400L)
  };
  Row secondRow = new GenericRow(secondValues);

  StructType beanSchema =
      new StructType()
          .add("a", BooleanType, false)
          .add("b", IntegerType, false)
          .add("c", BinaryType)
          .add("d", createArrayType(StringType))
          .add("e", createArrayType(StringType))
          .add("f", createArrayType(LongType));

  Dataset<SimpleJavaBean> fromRows =
      context
          .createDataFrame(Arrays.asList(firstRow, secondRow), beanSchema)
          .as(Encoders.bean(SimpleJavaBean.class));
  Assert.assertEquals(simpleBeans, fromRows.collectAsList());
}
/**
 * Stops the SparkContext started by {@code setUp} and clears the {@code context} and
 * {@code jsc} fields.
 *
 * <p>Tolerates a partially completed {@code setUp}: if {@code context} was never assigned the
 * context is stopped through {@code jsc} instead, and if neither was assigned nothing is
 * stopped. The fields are cleared even when stopping throws.
 */
@After
public void tearDown() {
  try {
    if (context != null) {
      context.sparkContext().stop();
    } else if (jsc != null) {
      // setUp failed after creating jsc but before assigning context.
      jsc.stop();
    }
  } finally {
    context = null;
    jsc = null;
  }
}