Ejemplo n.º 1
0
  /**
   * Entry point of the KMeans example.
   *
   * <p>Builds a bulk iteration over the centroid data set in which every point is mapped to its
   * nearest centroid and the centroids are recomputed from the grouped points, then assigns each
   * point to one of the resulting centroids. The assignment is written as CSV when
   * {@code --output} is present and printed otherwise.
   *
   * @param args command line arguments, parsed with {@link ParameterTool#fromArgs(String[])}
   * @throws Exception propagated from the calls made while building or running the program
   */
  public static void main(String[] args) throws Exception {

    // Parse the command line; the usage line is printed on every run
    final ParameterTool params = ParameterTool.fromArgs(args);
    System.out.println(
        "Usage: KMeans --points <path> --centroids <path> --output <path> --iterations <n>");

    // Obtain the execution environment and register the parameters globally
    // so that they are available in the web interface
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.getConfig().setGlobalJobParameters(params);

    // Input: points and initial centroids, as supplied by the helper methods
    DataSet<Point> points = getPointDataSet(params, env);
    DataSet<Centroid> centroids = getCentroidDataSet(params, env);

    // Number of bulk iterations; 10 when --iterations is not given
    int iterationCount = params.getInt("iterations", 10);
    IterativeDataSet<Centroid> loop = centroids.iterate(iterationCount);

    // Inside the loop: tag every point with its closest current centroid
    DataSet<Tuple2<Integer, Point>> assignedInLoop =
        points.map(new SelectNearestCenter()).withBroadcastSet(loop, "centroids");

    // Count and sum the point coordinates per centroid, then derive the new centroids
    DataSet<Centroid> newCentroids =
        assignedInLoop
            .map(new CountAppender())
            .groupBy(0)
            .reduce(new CentroidAccumulator())
            .map(new CentroidAverager());

    // Close the loop: the new centroids feed the next iteration
    DataSet<Centroid> finalCentroids = loop.closeWith(newCentroids);

    // Assign the points to the final clusters
    DataSet<Tuple2<Integer, Point>> clusteredPoints =
        points.map(new SelectNearestCenter()).withBroadcastSet(finalCentroids, "centroids");

    // No output path: print the result and finish
    if (!params.has("output")) {
      System.out.println("Printing result to stdout. Use --output to specify output path.");
      clusteredPoints.print();
      return;
    }

    // File sinks are lazy, so execution is triggered explicitly after registering the sink
    clusteredPoints.writeAsCsv(params.get("output"), "\n", " ");
    env.execute("KMeans Example");
  }
  /**
   * Builds and executes a program whose iteration step unions the iteration data set with
   * itself and passes the union through an {@code IdentityMapper}.
   *
   * <p>Input is read from {@code this.dataPath} with a {@code PointInFormat}; the closed
   * iteration is written to {@code this.resultPath} with a {@code PointOutFormat}.
   *
   * @throws Exception propagated from the calls made while building or running the program
   */
  @Override
  protected void testProgram() throws Exception {
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

    // Source with its parallelism set to 1
    DataSet<Record> source = env.readFile(new PointInFormat(), this.dataPath).setParallelism(1);

    // Open the iteration, passing 2 as the iteration count argument
    IterativeDataSet<Record> iteration = source.iterate(2);

    // Step function: union of the iteration set with itself, then an identity map
    DataSet<Record> selfUnion = iteration.union(iteration);
    DataSet<Record> stepResult = selfUnion.map(new IdentityMapper());

    // Close the iteration and attach the output sink
    DataSet<Record> iterationResult = iteration.closeWith(stepResult);
    iterationResult.write(new PointOutFormat(), this.resultPath);

    env.execute();
  }
  /**
   * Builds and executes a program in which the iteration data set is consumed only as a
   * broadcast set (named {@code "bc"}) of a group-reduce over the original input, then checks
   * the collected result.
   *
   * <p>The test passes when the first collected element equals 8.
   *
   * @throws Exception propagated from the calls made while building or running the program
   */
  @Override
  protected void testProgram() throws Exception {

    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    env.setDegreeOfParallelism(4);

    DataSet<Integer> numbers = env.fromElements(1, 2, 3, 4, 5, 6, 7, 8);

    // Open the iteration, passing 10 as the iteration count argument
    IterativeDataSet<Integer> iteration = numbers.iterate(10);

    // The step reduces the static input; the iteration set reaches it via broadcast only
    DataSet<Integer> stepResult =
        numbers.reduceGroup(new PickOneAllReduce()).withBroadcastSet(iteration, "bc");

    // Collect the closed iteration's output into a local list
    final List<Integer> collected = new ArrayList<Integer>();
    DataSet<Integer> iterationResult = iteration.closeWith(stepResult);
    iterationResult.output(new LocalCollectionOutputFormat<Integer>(collected));

    env.execute();

    int firstValue = collected.get(0).intValue();
    Assert.assertEquals(8, firstValue);
  }