public static void main(String[] args) throws Exception { // Checking input parameters final ParameterTool params = ParameterTool.fromArgs(args); System.out.println( "Usage: KMeans --points <path> --centroids <path> --output <path> --iterations <n>"); // set up execution environment ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); env.getConfig() .setGlobalJobParameters(params); // make parameters available in the web interface // get input data: // read the points and centroids from the provided paths or fall back to default data DataSet<Point> points = getPointDataSet(params, env); DataSet<Centroid> centroids = getCentroidDataSet(params, env); // set number of bulk iterations for KMeans algorithm IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations", 10)); DataSet<Centroid> newCentroids = points // compute closest centroid for each point .map(new SelectNearestCenter()) .withBroadcastSet(loop, "centroids") // count and sum point coordinates for each centroid .map(new CountAppender()) .groupBy(0) .reduce(new CentroidAccumulator()) // compute new centroids from point counts and coordinate sums .map(new CentroidAverager()); // feed new centroids back into next iteration DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids); DataSet<Tuple2<Integer, Point>> clusteredPoints = points // assign points to final clusters .map(new SelectNearestCenter()) .withBroadcastSet(finalCentroids, "centroids"); // emit result if (params.has("output")) { clusteredPoints.writeAsCsv(params.get("output"), "\n", " "); // since file sinks are lazy, we trigger the execution explicitly env.execute("KMeans Example"); } else { System.out.println("Printing result to stdout. Use --output to specify output path."); clusteredPoints.print(); } }
/**
 * Builds and runs a two-round bulk iteration whose step function unions the
 * partial solution with itself and passes the union through an {@code IdentityMapper}.
 *
 * <p>Input is read from {@code dataPath} using {@code PointInFormat} with parallelism 1;
 * the iteration result is written to {@code resultPath} using {@code PointOutFormat}.
 *
 * @throws Exception if building or executing the program fails
 */
@Override
protected void testProgram() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    // Source: read the records with a single parallel reader.
    DataSet<Record> source = env.readFile(new PointInFormat(), dataPath).setParallelism(1);

    // Open a bulk iteration that runs for two rounds.
    IterativeDataSet<Record> partialSolution = source.iterate(2);

    // Step function: union the partial solution with itself, then map it unchanged.
    DataSet<Record> unioned = partialSolution.union(partialSolution);
    DataSet<Record> nextPartialSolution = unioned.map(new IdentityMapper());

    // Close the iteration and attach the file sink.
    DataSet<Record> iterationResult = partialSolution.closeWith(nextPartialSolution);
    iterationResult.write(new PointOutFormat(), resultPath);

    env.execute();
}
/**
 * Builds and runs a ten-round bulk iteration in which the partial solution is only
 * consumed as the broadcast set {@code "bc"} of a group-reduce over the static input.
 *
 * <p>The iteration result is collected into a local list, and the test asserts that the
 * first collected element equals 8.
 *
 * @throws Exception if building or executing the program fails
 */
@Override
protected void testProgram() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.setDegreeOfParallelism(4);

    DataSet<Integer> numbers = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8);

    // Open the iteration on the input itself, limited to ten rounds.
    IterativeDataSet<Integer> partialSolution = numbers.iterate(10);

    // Step function: reduce the static input as one group, with the current
    // partial solution available to the reducer as broadcast set "bc".
    DataSet<Integer> nextPartialSolution =
            numbers.reduceGroup(new PickOneAllReduce()).withBroadcastSet(partialSolution, "bc");

    // Close the iteration and collect its output locally.
    final List<Integer> collected = new ArrayList<Integer>();
    DataSet<Integer> iterationResult = partialSolution.closeWith(nextPartialSolution);
    iterationResult.output(new LocalCollectionOutputFormat<Integer>(collected));

    env.execute();

    int firstValue = collected.get(0).intValue();
    Assert.assertEquals(8, firstValue);
}