/**
 * Builds the "Map[DotProducts]" vertex: a {@code RegularPactTask} that runs the
 * {@code DotProducts} stub through a {@code CollectorMapDriver}.
 *
 * <p>The vertex is given one regular input (index 0, no local strategy), one broadcast input
 * registered under the name "models" (index 0) and one forwarded output. The same serializer
 * factory is applied to all three.
 *
 * @param jobGraph the job graph handed to {@code JobGraphUtils.createTask}
 * @param numSubTasks value passed for both numeric arguments of {@code JobGraphUtils.createTask}
 * @param serializer serializer factory used for the input, the broadcast input and the output
 * @return the configured vertex
 */
private static JobTaskVertex createMapper(
      JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> serializer) {
    final JobTaskVertex mapVertex =
        JobGraphUtils.createTask(
            RegularPactTask.class, "Map[DotProducts]", jobGraph, numSubTasks, numSubTasks);
    final TaskConfig mapConfig = new TaskConfig(mapVertex.getConfiguration());

    // Every input-related setting below refers to index 0.
    final int firstIndex = 0;

    // user function, output and driver
    final UserCodeClassWrapper<DotProducts> dotProductsStub =
        new UserCodeClassWrapper<DotProducts>(DotProducts.class);
    mapConfig.setStubWrapper(dotProductsStub);
    mapConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
    mapConfig.setOutputSerializer(serializer);
    mapConfig.setDriver(CollectorMapDriver.class);
    mapConfig.setDriverStrategy(DriverStrategy.COLLECTOR_MAP);

    // regular input
    mapConfig.addInputToGroup(firstIndex);
    mapConfig.setInputLocalStrategy(firstIndex, LocalStrategy.NONE);
    mapConfig.setInputSerializer(serializer, firstIndex);

    // broadcast input "models"
    mapConfig.setBroadcastInputName("models", firstIndex);
    mapConfig.addBroadcastInputToGroup(firstIndex);
    mapConfig.setBroadcastInputSerializer(serializer, firstIndex);

    return mapVertex;
  }
Ejemplo n.º 2
0
  /**
   * Prepares a task for execution against this harness's mock environment.
   *
   * <p>The driver class and a class wrapper around the stub are written into a {@code TaskConfig}
   * backed by the mock environment's task configuration. The task is then attached to the mock
   * environment and, if it is a {@code RegularPactTask}, given this object's class loader as its
   * user-code class loader. Finally {@code registerInputOutput()} is invoked on the task.
   *
   * @param task the task to set up
   * @param driver the driver class stored in the task configuration
   * @param stubClass the stub class wrapped and stored in the task configuration
   */
  public void registerTask(
      AbstractTask task,
      @SuppressWarnings("rawtypes") Class<? extends PactDriver> driver,
      Class<? extends Stub> stubClass) {
    final TaskConfig taskSettings = new TaskConfig(this.mockEnv.getTaskConfiguration());
    taskSettings.setDriver(driver);
    final UserCodeClassWrapper<Stub> wrappedStub = new UserCodeClassWrapper<Stub>(stubClass);
    taskSettings.setStubWrapper(wrappedStub);

    task.setEnvironment(this.mockEnv);

    // Only regular PACT tasks accept an explicit user-code class loader here.
    if (task instanceof RegularPactTask<?, ?>) {
      final RegularPactTask<?, ?> pactTask = (RegularPactTask<?, ?>) task;
      pactTask.setUserCodeClassLoader(this.getClass().getClassLoader());
    }

    task.registerInputOutput();
  }
  /**
   * Builds the "Reduce / Iteration Tail" vertex: an {@code IterationTailPactTask} that runs a
   * {@code RecomputeClusterCenter} instance through a {@code ReduceDriver} using the
   * {@code SORTED_GROUP} driver strategy.
   *
   * <p>The single input (index 0) is sorted locally with the given comparator; the single output
   * is forwarded. The vertex is tagged with {@code ITERATION_ID} and flagged as a workset update.
   *
   * @param jobGraph the job graph handed to {@code JobGraphUtils.createTask}
   * @param numSubTasks value passed for both numeric arguments of {@code JobGraphUtils.createTask}
   * @param inputSerializer serializer factory for input 0
   * @param inputComparator comparator factory used both as driver comparator and as the local
   *     sort comparator of input 0
   * @param outputSerializer serializer factory for the output
   * @return the configured vertex
   */
  private static JobTaskVertex createReducer(
      JobGraph jobGraph,
      int numSubTasks,
      TypeSerializerFactory<?> inputSerializer,
      TypeComparatorFactory<?> inputComparator,
      TypeSerializerFactory<?> outputSerializer) {
    final JobTaskVertex reduceVertex =
        JobGraphUtils.createTask(
            IterationTailPactTask.class,
            "Reduce / Iteration Tail",
            jobGraph,
            numSubTasks,
            numSubTasks);
    final TaskConfig reduceConfig = new TaskConfig(reduceVertex.getConfiguration());

    // All input-related settings below refer to the only input, index 0.
    final int onlyInput = 0;

    // iteration membership
    reduceConfig.setIterationId(ITERATION_ID);
    reduceConfig.setIsWorksetUpdate();

    // driver and its input
    reduceConfig.setDriver(ReduceDriver.class);
    reduceConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP);
    reduceConfig.addInputToGroup(onlyInput);
    reduceConfig.setInputSerializer(inputSerializer, onlyInput);
    reduceConfig.setDriverComparator(inputComparator, onlyInput);

    // local sort of the input, including its memory / file-handle / spilling settings
    reduceConfig.setInputLocalStrategy(onlyInput, LocalStrategy.SORT);
    reduceConfig.setInputComparator(inputComparator, onlyInput);
    reduceConfig.setMemoryInput(onlyInput, MEMORY_PER_CONSUMER * JobGraphUtils.MEGABYTE);
    reduceConfig.setFilehandlesInput(onlyInput, 128);
    reduceConfig.setSpillingThresholdInput(onlyInput, 0.9f);

    // forwarded output
    reduceConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
    reduceConfig.setOutputSerializer(outputSerializer);

    // user function, wrapped as an object instance
    final RecomputeClusterCenter recompute = new RecomputeClusterCenter();
    reduceConfig.setStubWrapper(new UserCodeObjectWrapper<RecomputeClusterCenter>(recompute));

    return reduceVertex;
  }
  /**
   * Builds the "Map (Select nearest center)" vertex: an {@code IterationIntermediatePactTask}
   * that runs a {@code SelectNearestCenter} instance through a {@code CollectorMapDriver}.
   *
   * <p>The vertex has one regular input (index 0), one broadcast input registered under the name
   * "centers" (index 0) and one hash-partitioned output whose comparator is set at index 0. It is
   * tagged with {@code ITERATION_ID}.
   *
   * @param jobGraph the job graph handed to {@code JobGraphUtils.createTask}
   * @param numSubTasks value passed for both numeric arguments of {@code JobGraphUtils.createTask}
   * @param inputSerializer serializer factory for input 0
   * @param broadcastVarSerializer serializer factory for the "centers" broadcast input
   * @param outputSerializer serializer factory for the output
   * @param outputComparator comparator factory registered for output 0
   * @return the configured vertex
   */
  private static JobTaskVertex createMapper(
      JobGraph jobGraph,
      int numSubTasks,
      TypeSerializerFactory<?> inputSerializer,
      TypeSerializerFactory<?> broadcastVarSerializer,
      TypeSerializerFactory<?> outputSerializer,
      TypeComparatorFactory<?> outputComparator) {
    final JobTaskVertex nearestCenterVertex =
        JobGraphUtils.createTask(
            IterationIntermediatePactTask.class,
            "Map (Select nearest center)",
            jobGraph,
            numSubTasks,
            numSubTasks);
    final TaskConfig mapConfig = new TaskConfig(nearestCenterVertex.getConfiguration());

    // Input, broadcast input and output are each addressed by index 0.
    final int firstIndex = 0;

    mapConfig.setIterationId(ITERATION_ID);

    // driver and regular input
    mapConfig.setDriver(CollectorMapDriver.class);
    mapConfig.setDriverStrategy(DriverStrategy.COLLECTOR_MAP);
    mapConfig.addInputToGroup(firstIndex);
    mapConfig.setInputSerializer(inputSerializer, firstIndex);

    // hash-partitioned output
    mapConfig.setOutputSerializer(outputSerializer);
    mapConfig.addOutputShipStrategy(ShipStrategyType.PARTITION_HASH);
    mapConfig.setOutputComparator(outputComparator, firstIndex);

    // broadcast input "centers"
    mapConfig.setBroadcastInputName("centers", firstIndex);
    mapConfig.addBroadcastInputToGroup(firstIndex);
    mapConfig.setBroadcastInputSerializer(broadcastVarSerializer, firstIndex);

    // user function, wrapped as an object instance
    final SelectNearestCenter selectNearest = new SelectNearestCenter();
    mapConfig.setStubWrapper(new UserCodeObjectWrapper<SelectNearestCenter>(selectNearest));

    return nearestCenterVertex;
  }
  /**
   * Builds the "Iteration Head" vertex: an {@code IterationHeadPactTask} driven by a
   * {@code NoOpDriver} with the {@code UNARY_NO_OP} strategy.
   *
   * <p>Input 0 is registered as the partial-solution/workset input. The head's own output is
   * broadcast, while a separate, freshly created final-output configuration uses forwarding.
   * Both outputs and the input share the given serializer. Output index 2 is recorded as the
   * sync output (NOTE(review): the matching connections are not made here; verify against the
   * caller that wires the vertices).
   *
   * @param jobGraph the job graph handed to {@code JobGraphUtils.createTask}
   * @param numSubTasks value passed for both numeric arguments of {@code JobGraphUtils.createTask}
   * @param serializer serializer factory used for the input and for both output configurations
   * @return the configured vertex
   */
  private static JobTaskVertex createIterationHead(
      JobGraph jobGraph, int numSubTasks, TypeSerializerFactory<?> serializer) {
    final JobTaskVertex headVertex =
        JobGraphUtils.createTask(
            IterationHeadPactTask.class, "Iteration Head", jobGraph, numSubTasks, numSubTasks);
    final TaskConfig config = new TaskConfig(headVertex.getConfiguration());

    final int initialInput = 0;
    final int syncOutputIndex = 2;

    config.setIterationId(ITERATION_ID);

    // input 0 carries the initial partial solution / workset
    config.addInputToGroup(initialInput);
    config.setIterationHeadPartialSolutionOrWorksetInputIndex(initialInput);
    config.setInputSerializer(serializer, initialInput);

    // memory reserved for the back channel
    config.setBackChannelMemory(MEMORY_PER_CONSUMER * JobGraphUtils.MEGABYTE);

    // output that feeds the iteration: broadcast
    config.setOutputSerializer(serializer);
    config.addOutputShipStrategy(ShipStrategyType.BROADCAST);

    // final output lives in its own configuration object: forwarded
    final TaskConfig finalOutput = new TaskConfig(new Configuration());
    finalOutput.setOutputSerializer(serializer);
    finalOutput.addOutputShipStrategy(ShipStrategyType.FORWARD);
    config.setIterationHeadFinalOutputConfig(finalOutput);

    // which output index is the sync output
    config.setIterationHeadIndexOfSyncOutput(syncOutputIndex);

    // the head itself performs no transformation
    config.setDriver(NoOpDriver.class);
    config.setDriverStrategy(DriverStrategy.UNARY_NO_OP);

    return headVertex;
  }