/**
 * Exercises the basic element-wise Dataset operations (first, filter, map, mapPartitions and
 * flatMap) on a two-element string Dataset, checking each result in order.
 */
@Test
public void testCommonOperation() {
  List<String> words = Arrays.asList("hello", "world");
  Dataset<String> ds = context.createDataset(words, Encoders.STRING());
  Assert.assertEquals("hello", ds.first());

  // Keep only the strings that begin with "h".
  FilterFunction<String> startsWithH = new FilterFunction<String>() {
    @Override
    public boolean call(String v) throws Exception {
      return v.startsWith("h");
    }
  };
  Dataset<String> onlyH = ds.filter(startsWithH);
  Assert.assertEquals(Arrays.asList("hello"), onlyH.collectAsList());

  // Replace every string with its length.
  MapFunction<String, Integer> toLength = new MapFunction<String, Integer>() {
    @Override
    public Integer call(String v) throws Exception {
      return v.length();
    }
  };
  Dataset<Integer> lengths = ds.map(toLength, Encoders.INT());
  Assert.assertEquals(Arrays.asList(5, 5), lengths.collectAsList());

  // Upper-case a whole partition at a time.
  MapPartitionsFunction<String, String> upperCasePartition =
      new MapPartitionsFunction<String, String>() {
        @Override
        public Iterable<String> call(Iterator<String> it) throws Exception {
          List<String> upper = new LinkedList<String>();
          for (; it.hasNext(); ) {
            upper.add(it.next().toUpperCase());
          }
          return upper;
        }
      };
  Dataset<String> upperCased = ds.mapPartitions(upperCasePartition, Encoders.STRING());
  Assert.assertEquals(Arrays.asList("HELLO", "WORLD"), upperCased.collectAsList());

  // Explode each string into its individual characters.
  FlatMapFunction<String, String> toChars = new FlatMapFunction<String, String>() {
    @Override
    public Iterable<String> call(String s) throws Exception {
      List<String> chars = new LinkedList<String>();
      for (int i = 0; i < s.length(); i++) {
        chars.add(String.valueOf(s.charAt(i)));
      }
      return chars;
    }
  };
  Dataset<String> characters = ds.flatMap(toChars, Encoders.STRING());
  Assert.assertEquals(
      Arrays.asList("h", "e", "l", "l", "o", "w", "o", "r", "l", "d"),
      characters.collectAsList());
}
/** Checks that collectAsList returns the Dataset's elements in the order they were supplied. */
@Test
public void testCollect() {
  List<String> input = Arrays.asList("hello", "world");
  List<String> expected = Arrays.asList("hello", "world");
  Assert.assertEquals(
      expected, context.createDataset(input, Encoders.STRING()).collectAsList());
}
/**
 * Round-trips objects through an encoder built with {@code Encoders.javaSerialization} and
 * asserts the collected list equals the input list.
 */
@Test
public void testJavaEncoder() {
  Encoder<JavaSerializable> javaEncoder = Encoders.javaSerialization(JavaSerializable.class);
  JavaSerializable hello = new JavaSerializable("hello");
  JavaSerializable world = new JavaSerializable("world");
  List<JavaSerializable> expected = Arrays.asList(hello, world);
  Assert.assertEquals(expected, context.createDataset(expected, javaEncoder).collectAsList());
}
/**
 * Round-trips objects through an encoder built with {@code Encoders.kryo} and asserts the
 * collected list equals the input list.
 */
@Test
public void testKryoEncoder() {
  Encoder<KryoSerializable> kryoEncoder = Encoders.kryo(KryoSerializable.class);
  KryoSerializable hello = new KryoSerializable("hello");
  KryoSerializable world = new KryoSerializable("world");
  List<KryoSerializable> expected = Arrays.asList(hello, world);
  Assert.assertEquals(expected, context.createDataset(expected, kryoEncoder).collectAsList());
}
/**
 * Exercises distinct, intersect, union and subtract between two small string Datasets.
 * Results whose order is not asserted directly are first passed through the {@code sort}
 * helper.
 */
@Test
public void testSetOperation() {
  List<String> left = Arrays.asList("abc", "abc", "xyz");
  Dataset<String> ds = context.createDataset(left, Encoders.STRING());
  String[] distinctValues = ds.distinct().collectAsList().toArray(new String[0]);
  Assert.assertEquals(Arrays.asList("abc", "xyz"), sort(distinctValues));

  List<String> right = Arrays.asList("xyz", "foo", "foo");
  Dataset<String> ds2 = context.createDataset(right, Encoders.STRING());

  List<String> common = ds.intersect(ds2).collectAsList();
  Assert.assertEquals(Arrays.asList("xyz"), common);

  String[] combined = ds.union(ds2).collectAsList().toArray(new String[0]);
  Assert.assertEquals(
      Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"), sort(combined));

  List<String> difference = ds.subtract(ds2).collectAsList();
  Assert.assertEquals(Arrays.asList("abc", "abc"), difference);
}
/**
 * Groups strings by the length of their {@code value} column (keyed as Integer) and
 * concatenates each group's members behind its key.
 *
 * <p>Spark does not guarantee the order in which groups are emitted after grouping, so the
 * per-group results are compared as a set rather than as an ordered list. The order of
 * elements within a group (e.g. "3foobar") is still asserted through the string contents.
 */
@Test
public void testGroupByColumn() {
  List<String> data = Arrays.asList("a", "foo", "bar");
  Dataset<String> ds = context.createDataset(data, Encoders.STRING());
  GroupedDataset<Integer, String> grouped =
      ds.groupBy(length(col("value"))).keyAs(Encoders.INT());

  Dataset<String> mapped = grouped.mapGroups(
      new MapGroupsFunction<Integer, String, String>() {
        @Override
        public String call(Integer key, Iterator<String> values) throws Exception {
          // Named "values" so it does not shadow the enclosing local "data".
          StringBuilder sb = new StringBuilder(key.toString());
          while (values.hasNext()) {
            sb.append(values.next());
          }
          return sb.toString();
        }
      },
      Encoders.STRING());

  // Group emission order is unspecified; compare without depending on it.
  Assert.assertEquals(
      new java.util.HashSet<String>(Arrays.asList("1a", "3foobar")),
      new java.util.HashSet<String>(mapped.collectAsList()));
}
/**
 * Round-trips JavaBeans through bean encoders in three ways: a list of flat
 * {@code SimpleJavaBean}s, a {@code NestedJavaBean} that wraps one of them, and a DataFrame
 * of Rows with a matching schema converted to beans via {@code as(...)}.
 */
@Test
public void testJavaBeanEncoder() {
  // Registers this suite instance with OuterScopes before any bean encoder is built;
  // presumably required because the bean classes are nested in this suite (confirm).
  OuterScopes.addOuterScope(this);

  SimpleJavaBean first = new SimpleJavaBean();
  first.setA(true);
  first.setB(3);
  first.setC(new byte[] {1, 2});
  first.setD(new String[] {"hello", null});
  first.setE(Arrays.asList("a", "b"));
  first.setF(Arrays.asList(100L, null, 200L));

  SimpleJavaBean second = new SimpleJavaBean();
  second.setA(false);
  second.setB(30);
  second.setC(new byte[] {3, 4});
  second.setD(new String[] {null, "world"});
  second.setE(Arrays.asList("x", "y"));
  second.setF(Arrays.asList(300L, null, 400L));

  // 1) Flat beans encoded directly.
  List<SimpleJavaBean> beans = Arrays.asList(first, second);
  Dataset<SimpleJavaBean> beanDs =
      context.createDataset(beans, Encoders.bean(SimpleJavaBean.class));
  Assert.assertEquals(beans, beanDs.collectAsList());

  // 2) A bean whose property is itself a bean.
  NestedJavaBean wrapper = new NestedJavaBean();
  wrapper.setA(first);
  List<NestedJavaBean> wrappers = Arrays.asList(wrapper);
  Dataset<NestedJavaBean> wrapperDs =
      context.createDataset(wrappers, Encoders.bean(NestedJavaBean.class));
  Assert.assertEquals(wrappers, wrapperDs.collectAsList());

  // 3) Generic Rows whose columns mirror the bean properties a..f.
  Object[] firstValues = {
    true, 3, new byte[] {1, 2}, new String[] {"hello", null},
    Arrays.asList("a", "b"), Arrays.asList(100L, null, 200L)
  };
  Object[] secondValues = {
    false, 30, new byte[] {3, 4}, new String[] {null, "world"},
    Arrays.asList("x", "y"), Arrays.asList(300L, null, 400L)
  };
  Row firstRow = new GenericRow(firstValues);
  Row secondRow = new GenericRow(secondValues);

  StructType schema = new StructType()
      .add("a", BooleanType, false)
      .add("b", IntegerType, false)
      .add("c", BinaryType)
      .add("d", createArrayType(StringType))
      .add("e", createArrayType(StringType))
      .add("f", createArrayType(LongType));

  Dataset<SimpleJavaBean> rowsAsBeans = context
      .createDataFrame(Arrays.asList(firstRow, secondRow), schema)
      .as(Encoders.bean(SimpleJavaBean.class));
  Assert.assertEquals(beans, rowsAsBeans.collectAsList());
}
/**
 * Exercises typed grouping via {@code groupBy(MapFunction, Encoder)}. Strings are grouped by
 * length and integers by {@code v / 2}, and the test checks mapGroups, flatMapGroups, reduce
 * and cogroup.
 *
 * <p>Spark does not guarantee the order in which groups are emitted, so every result is
 * compared as a set of per-group values rather than as an ordered list. The concatenation
 * order inside a group (e.g. "3foobar") is still asserted through the string contents.
 */
@Test
public void testGroupBy() {
  List<String> data = Arrays.asList("a", "foo", "bar");
  Dataset<String> ds = context.createDataset(data, Encoders.STRING());
  GroupedDataset<Integer, String> grouped = ds.groupBy(
      new MapFunction<String, Integer>() {
        @Override
        public Integer call(String v) throws Exception {
          return v.length();
        }
      },
      Encoders.INT());

  Dataset<String> mapped = grouped.mapGroups(
      new MapGroupsFunction<Integer, String, String>() {
        @Override
        public String call(Integer key, Iterator<String> values) throws Exception {
          StringBuilder sb = new StringBuilder(key.toString());
          while (values.hasNext()) {
            sb.append(values.next());
          }
          return sb.toString();
        }
      },
      Encoders.STRING());
  Assert.assertEquals(
      new java.util.HashSet<String>(Arrays.asList("1a", "3foobar")),
      new java.util.HashSet<String>(mapped.collectAsList()));

  Dataset<String> flatMapped = grouped.flatMapGroups(
      new FlatMapGroupsFunction<Integer, String, String>() {
        @Override
        public Iterable<String> call(Integer key, Iterator<String> values) throws Exception {
          StringBuilder sb = new StringBuilder(key.toString());
          while (values.hasNext()) {
            sb.append(values.next());
          }
          return Collections.singletonList(sb.toString());
        }
      },
      Encoders.STRING());
  Assert.assertEquals(
      new java.util.HashSet<String>(Arrays.asList("1a", "3foobar")),
      new java.util.HashSet<String>(flatMapped.collectAsList()));

  Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(
      new ReduceFunction<String>() {
        @Override
        public String call(String v1, String v2) throws Exception {
          return v1 + v2;
        }
      });
  Assert.assertEquals(
      new java.util.HashSet<Tuple2<Integer, String>>(
          Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar"))),
      new java.util.HashSet<Tuple2<Integer, String>>(reduced.collectAsList()));

  List<Integer> data2 = Arrays.asList(2, 6, 10);
  Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
  GroupedDataset<Integer, Integer> grouped2 = ds2.groupBy(
      new MapFunction<Integer, Integer>() {
        @Override
        public Integer call(Integer v) throws Exception {
          return v / 2;
        }
      },
      Encoders.INT());

  // Keys 1 and 3 appear on both sides; key 5 exists only in grouped2 (left side empty).
  Dataset<String> cogrouped = grouped.cogroup(
      grouped2,
      new CoGroupFunction<Integer, String, Integer, String>() {
        @Override
        public Iterable<String> call(
            Integer key, Iterator<String> left, Iterator<Integer> right) throws Exception {
          StringBuilder sb = new StringBuilder(key.toString());
          while (left.hasNext()) {
            sb.append(left.next());
          }
          sb.append("#");
          while (right.hasNext()) {
            sb.append(right.next());
          }
          return Collections.singletonList(sb.toString());
        }
      },
      Encoders.STRING());
  Assert.assertEquals(
      new java.util.HashSet<String>(Arrays.asList("1a#2", "3foobar#6", "5#10")),
      new java.util.HashSet<String>(cogrouped.collectAsList()));
}