コード例 #1
0
  /**
   * Checks that this driver is configured with {@link DriverStrategy#ALL_REDUCE}, then caches the
   * serializer and the iterator of input 0.
   *
   * @throws Exception if the task config names any other driver strategy
   */
  @Override
  public void prepare() throws Exception {
    final TaskConfig taskConfig = this.taskContext.getTaskConfig();
    if (taskConfig.getDriverStrategy() != DriverStrategy.ALL_REDUCE) {
      final String message =
          "Unrecognized driver strategy for AllReduce driver: "
              + taskConfig.getDriverStrategy().name();
      throw new Exception(message);
    }

    // this driver consumes a single input, at index 0
    final TypeSerializerFactory<T> inputSerializerFactory = this.taskContext.getInputSerializer(0);
    this.serializer = inputSerializerFactory.getSerializer();
    this.input = this.taskContext.getInput(0);
  }
コード例 #2
0
  /**
   * Prepares the co-group driver.
   *
   * <p>Checks that the task is configured with {@link DriverStrategy#CO_GROUP}. It then builds a
   * {@link SortMergeCoGroupIterator} over inputs 0 and 1, using their serializers, their
   * comparators and a pair comparator from the configured factory. Finally it opens the iterator;
   * opening triggers the sort and blocks until the iterator is ready.
   *
   * @throws Exception if the driver strategy is not {@code CO_GROUP}, if no pair comparator
   *     factory is configured, or if opening the iterator fails
   */
  @Override
  public void prepare() throws Exception {
    final TaskConfig config = this.taskContext.getTaskConfig();
    if (config.getDriverStrategy() != DriverStrategy.CO_GROUP) {
      throw new Exception(
          "Unrecognized driver strategy for CoGroup driver: " + config.getDriverStrategy().name());
    }

    final MutableObjectIterator<IT1> in1 = this.taskContext.getInput(0);
    final MutableObjectIterator<IT2> in2 = this.taskContext.getInput(1);

    // get the key positions and types
    final TypeSerializer<IT1> serializer1 =
        this.taskContext.<IT1>getInputSerializer(0).getSerializer();
    final TypeSerializer<IT2> serializer2 =
        this.taskContext.<IT2>getInputSerializer(1).getSerializer();
    final TypeComparator<IT1> groupComparator1 = this.taskContext.getInputComparator(0);
    final TypeComparator<IT2> groupComparator2 = this.taskContext.getInputComparator(1);

    // the pair comparator relates keys of input 1 to keys of input 2
    final TypePairComparatorFactory<IT1, IT2> pairComparatorFactory =
        config.getPairComparatorFactory(this.taskContext.getUserCodeClassLoader());
    if (pairComparatorFactory == null) {
      throw new Exception("Missing pair comparator factory for CoGroup driver");
    }

    // create the CoGroupTaskIterator according to the provided local strategy
    this.coGroupIterator =
        new SortMergeCoGroupIterator<IT1, IT2>(
            in1,
            in2,
            serializer1,
            groupComparator1,
            serializer2,
            groupComparator2,
            pairComparatorFactory.createComparator12(groupComparator1, groupComparator2));

    // open CoGroupTaskIterator - this triggers the sorting and blocks until the iterator is ready
    this.coGroupIterator.open();

    if (LOG.isDebugEnabled()) {
      LOG.debug(this.taskContext.formatLogString("CoGroup task iterator ready."));
    }
  }
コード例 #3
0
ファイル: DataSinkTask.java プロジェクト: HuangWHWHW/flink
  /**
   * Instantiates the user's {@link OutputFormat} from the task configuration and passes it the stub
   * parameters via {@code configure()}.
   *
   * <p>While {@code configure()} runs, the user-code class loader is installed as the current
   * thread's context class loader. The previous context class loader is restored afterwards, even
   * on failure.
   *
   * @throws RuntimeException if the configured stub is not an {@link OutputFormat}, or if the
   *     user-defined {@code configure()} method throws
   */
  private void initOutputFormat() {
    final ClassLoader userLoader = getUserCodeClassLoader();
    // wrap the raw task configuration; it also carries the stub parameters
    final Configuration rawTaskConfig = getTaskConfiguration();
    this.config = new TaskConfig(rawTaskConfig);

    try {
      this.format =
          this.config
              .<OutputFormat<IT>>getStubWrapper(userLoader)
              .getUserCodeObject(OutputFormat.class, userLoader);

      // defensive check that the instantiated stub really is an OutputFormat
      final Class<?> formatClass = this.format.getClass();
      if (!OutputFormat.class.isAssignableFrom(formatClass)) {
        final String problem =
            "The class '"
                + formatClass.getName()
                + "' is not a subclass of '"
                + OutputFormat.class.getName()
                + "' as is required.";
        throw new RuntimeException(problem);
      }
    } catch (ClassCastException castFailure) {
      throw new RuntimeException(
          "The stub class is not a proper subclass of " + OutputFormat.class.getName(),
          castFailure);
    }

    final Thread currentThread = Thread.currentThread();
    final ClassLoader previousLoader = currentThread.getContextClassLoader();
    // Configure the format under the user-code class loader. Any failure here is reported as
    // coming from the user's configure() implementation.
    try {
      currentThread.setContextClassLoader(userLoader);
      this.format.configure(this.config.getStubParameters());
    } catch (Throwable failure) {
      throw new RuntimeException(
          "The user defined 'configure()' method in the Output Format caused an error: "
              + failure.getMessage(),
          failure);
    } finally {
      currentThread.setContextClassLoader(previousLoader);
    }
  }
  /**
   * Assembles the {@link JobGraph} of the compensatable dangling-vertex PageRank iteration with a
   * chained rank combiner.
   *
   * <p>The graph has these vertices:
   *
   * <ul>
   *   <li>two inputs: pages with rank and dangling flag, and adjacency lists;
   *   <li>an iteration head that sorts its input and runs {@code CustomCompensatingMap};
   *   <li>an intermediate task that matches head output with the adjacency lists
   *       ({@code CustomCompensatableDotProductMatch}) and feeds a chained
   *       {@code CustomRankCombiner};
   *   <li>an iteration tail that co-groups head and intermediate output
   *       ({@code CustomCompensatableDotProductCoGroup});
   *   <li>the final file output and a fake output for the tail;
   *   <li>a sync vertex that applies {@code DiffL1NormConvergenceCriterion} to the
   *       {@code PageRankStatsAggregator}.
   * </ul>
   *
   * <p>The positional arguments are only read when {@code args.length >= 14}. Otherwise every value
   * keeps its built-in default, including the empty default input paths.
   *
   * <ul>
   *   <li>{@code args[0]} degree of parallelism
   *   <li>{@code args[1]} page-with-rank input path
   *   <li>{@code args[2]} adjacency-list input path
   *   <li>{@code args[3]} output path
   *   <li>{@code args[4]} config path (not read by this method)
   *   <li>{@code args[5..7]} memory weights: minor consumer, match, co-group sort
   *   <li>{@code args[8]} number of iterations
   *   <li>{@code args[9]}, {@code args[10]} number of vertices, number of dangling vertices
   *   <li>{@code args[11..13]} failing workers, failing iteration, message loss
   * </ul>
   *
   * <p>This method uses the memory weights only as ratios: each memory consumer gets its weight
   * divided by {@code totalMemoryConsumption} as a relative memory fraction.
   *
   * @param args positional arguments as listed above; may be shorter than 14 to use the defaults
   * @return the assembled job graph
   * @throws Exception as declared; e.g. a {@link NumberFormatException} if a numeric argument
   *     cannot be parsed
   */
  public static JobGraph getJobGraph(String[] args) throws Exception {

    int degreeOfParallelism = 2;
    String pageWithRankInputPath =
        ""; // "file://" + PlayConstants.PLAY_DIR + "test-inputs/danglingpagerank/pageWithRank";
    String adjacencyListInputPath = ""; // "file://" + PlayConstants.PLAY_DIR +
    //			"test-inputs/danglingpagerank/adjacencylists";
    String outputPath =
        Path.constructTestURI(
            CustomCompensatableDanglingPageRankWithCombiner.class, "flink_iterations");
    int minorConsumer = 2;
    int matchMemory = 5;
    int coGroupSortMemory = 5;
    int numIterations = 25;
    long numVertices = 5;
    long numDanglingVertices = 1;

    String failingWorkers = "1";
    int failingIteration = 2;
    double messageLoss = 0.75;

    // all-or-nothing: with fewer than 14 arguments, every default above is kept
    if (args.length >= 14) {
      degreeOfParallelism = Integer.parseInt(args[0]);
      pageWithRankInputPath = args[1];
      adjacencyListInputPath = args[2];
      outputPath = args[3];
      // [4] is config path
      minorConsumer = Integer.parseInt(args[5]);
      matchMemory = Integer.parseInt(args[6]);
      coGroupSortMemory = Integer.parseInt(args[7]);
      numIterations = Integer.parseInt(args[8]);
      numVertices = Long.parseLong(args[9]);
      numDanglingVertices = Long.parseLong(args[10]);
      failingWorkers = args[11];
      failingIteration = Integer.parseInt(args[12]);
      messageLoss = Double.parseDouble(args[13]);
    }

    // Sum of all memory weights used below:
    // - minorConsumer 3x: head input sort, head back channel, tail input materialization;
    // - coGroupSortMemory 2x: combiner driver, tail input sort;
    // - matchMemory 1x: intermediate match driver.
    int totalMemoryConsumption = 3 * minorConsumer + 2 * coGroupSortMemory + matchMemory;

    JobGraph jobGraph = new JobGraph("CompensatableDanglingPageRank");

    // --------------- the inputs ---------------------

    // page rank input
    JobInputVertex pageWithRankInput =
        JobGraphUtils.createInput(
            new CustomImprovedDanglingPageRankInputFormat(),
            pageWithRankInputPath,
            "DanglingPageWithRankInput",
            jobGraph,
            degreeOfParallelism);
    TaskConfig pageWithRankInputConfig = new TaskConfig(pageWithRankInput.getConfiguration());
    pageWithRankInputConfig.addOutputShipStrategy(ShipStrategyType.PARTITION_HASH);
    pageWithRankInputConfig.setOutputComparator(vertexWithRankAndDanglingComparator, 0);
    pageWithRankInputConfig.setOutputSerializer(vertexWithRankAndDanglingSerializer);
    pageWithRankInputConfig.setStubParameter("pageRank.numVertices", String.valueOf(numVertices));

    // edges as adjacency list
    // NOTE(review): the vertex name below is misspelled ("Adjancency"); it is a runtime string,
    // so it is left unchanged here
    JobInputVertex adjacencyListInput =
        JobGraphUtils.createInput(
            new CustomImprovedAdjacencyListInputFormat(),
            adjacencyListInputPath,
            "AdjancencyListInput",
            jobGraph,
            degreeOfParallelism);
    TaskConfig adjacencyListInputConfig = new TaskConfig(adjacencyListInput.getConfiguration());
    adjacencyListInputConfig.addOutputShipStrategy(ShipStrategyType.PARTITION_HASH);
    adjacencyListInputConfig.setOutputSerializer(vertexWithAdjacencyListSerializer);
    adjacencyListInputConfig.setOutputComparator(vertexWithAdjacencyListComparator, 0);

    // --------------- the head ---------------------
    JobTaskVertex head =
        JobGraphUtils.createTask(
            IterationHeadPactTask.class, "IterationHead", jobGraph, degreeOfParallelism);
    TaskConfig headConfig = new TaskConfig(head.getConfiguration());
    headConfig.setIterationId(ITERATION_ID);

    // initial input / partial solution
    headConfig.addInputToGroup(0);
    headConfig.setIterationHeadPartialSolutionOrWorksetInputIndex(0);
    headConfig.setInputSerializer(vertexWithRankAndDanglingSerializer, 0);
    headConfig.setInputComparator(vertexWithRankAndDanglingComparator, 0);
    headConfig.setInputLocalStrategy(0, LocalStrategy.SORT);
    headConfig.setRelativeMemoryInput(0, (double) minorConsumer / totalMemoryConsumption);
    headConfig.setFilehandlesInput(0, NUM_FILE_HANDLES_PER_SORT);
    headConfig.setSpillingThresholdInput(0, SORT_SPILL_THRESHOLD);

    // back channel / iterations
    headConfig.setRelativeBackChannelMemory((double) minorConsumer / totalMemoryConsumption);

    // output into iteration
    // two FORWARD outputs; presumably to the intermediate task and the tail, matching the first
    // two head connections in the wiring section below
    headConfig.setOutputSerializer(vertexWithRankAndDanglingSerializer);
    headConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
    headConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);

    // final output
    TaskConfig headFinalOutConfig = new TaskConfig(new Configuration());
    headFinalOutConfig.setOutputSerializer(vertexWithRankAndDanglingSerializer);
    headFinalOutConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
    headConfig.setIterationHeadFinalOutputConfig(headFinalOutConfig);

    // the sync
    // NOTE(review): index 3 is presumably the head's fourth outgoing connection. The head is
    // connected to intermediate, tail, output and then sync in the wiring section. Confirm that
    // JobGraphUtils.connect assigns output indices in call order.
    headConfig.setIterationHeadIndexOfSyncOutput(3);
    headConfig.setNumberOfIterations(numIterations);

    // the driver
    headConfig.setDriver(CollectorMapDriver.class);
    headConfig.setDriverStrategy(DriverStrategy.COLLECTOR_MAP);
    headConfig.setStubWrapper(
        new UserCodeClassWrapper<CustomCompensatingMap>(CustomCompensatingMap.class));
    headConfig.setStubParameter("pageRank.numVertices", String.valueOf(numVertices));
    headConfig.setStubParameter("compensation.failingWorker", failingWorkers);
    headConfig.setStubParameter("compensation.failingIteration", String.valueOf(failingIteration));
    headConfig.setStubParameter("compensation.messageLoss", String.valueOf(messageLoss));
    headConfig.addIterationAggregator(
        CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, new PageRankStatsAggregator());

    // --------------- the join ---------------------

    JobTaskVertex intermediate =
        JobGraphUtils.createTask(
            IterationIntermediatePactTask.class,
            "IterationIntermediate",
            jobGraph,
            degreeOfParallelism);
    TaskConfig intermediateConfig = new TaskConfig(intermediate.getConfiguration());
    intermediateConfig.setIterationId(ITERATION_ID);
    //		intermediateConfig.setDriver(RepeatableHashjoinMatchDriverWithCachedBuildside.class);
    intermediateConfig.setDriver(BuildSecondCachedMatchDriver.class);
    intermediateConfig.setDriverStrategy(DriverStrategy.HYBRIDHASH_BUILD_SECOND);
    intermediateConfig.setRelativeMemoryDriver((double) matchMemory / totalMemoryConsumption);
    intermediateConfig.addInputToGroup(0);
    intermediateConfig.addInputToGroup(1);
    intermediateConfig.setInputSerializer(vertexWithRankAndDanglingSerializer, 0);
    intermediateConfig.setInputSerializer(vertexWithAdjacencyListSerializer, 1);
    intermediateConfig.setDriverComparator(vertexWithRankAndDanglingComparator, 0);
    intermediateConfig.setDriverComparator(vertexWithAdjacencyListComparator, 1);
    intermediateConfig.setDriverPairComparator(matchComparator);

    intermediateConfig.setOutputSerializer(vertexWithRankSerializer);
    intermediateConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);

    intermediateConfig.setStubWrapper(
        new UserCodeClassWrapper<CustomCompensatableDotProductMatch>(
            CustomCompensatableDotProductMatch.class));
    intermediateConfig.setStubParameter("pageRank.numVertices", String.valueOf(numVertices));
    intermediateConfig.setStubParameter("compensation.failingWorker", failingWorkers);
    intermediateConfig.setStubParameter(
        "compensation.failingIteration", String.valueOf(failingIteration));
    intermediateConfig.setStubParameter("compensation.messageLoss", String.valueOf(messageLoss));

    // the combiner and the output
    // The combiner is chained into the intermediate task rather than being its own vertex. Its
    // hash-partitioned output is what leaves the intermediate vertex.
    TaskConfig combinerConfig = new TaskConfig(new Configuration());
    combinerConfig.addInputToGroup(0);
    combinerConfig.setInputSerializer(vertexWithRankSerializer, 0);
    combinerConfig.setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
    combinerConfig.setDriverComparator(vertexWithRankComparator, 0);
    combinerConfig.setRelativeMemoryDriver((double) coGroupSortMemory / totalMemoryConsumption);
    combinerConfig.setOutputSerializer(vertexWithRankSerializer);
    combinerConfig.addOutputShipStrategy(ShipStrategyType.PARTITION_HASH);
    combinerConfig.setOutputComparator(vertexWithRankComparator, 0);
    combinerConfig.setStubWrapper(
        new UserCodeClassWrapper<CustomRankCombiner>(CustomRankCombiner.class));
    intermediateConfig.addChainedTask(
        SynchronousChainedCombineDriver.class, combinerConfig, "Combiner");

    // ---------------- the tail (co group) --------------------

    JobTaskVertex tail =
        JobGraphUtils.createTask(
            IterationTailPactTask.class, "IterationTail", jobGraph, degreeOfParallelism);
    TaskConfig tailConfig = new TaskConfig(tail.getConfiguration());
    tailConfig.setIterationId(ITERATION_ID);
    tailConfig.setIsWorksetUpdate();

    // inputs and driver
    // Input 0 (from the head) is materialized asynchronously. Input 1 (from the intermediate
    // task) is sorted locally.
    tailConfig.setDriver(CoGroupDriver.class);
    tailConfig.setDriverStrategy(DriverStrategy.CO_GROUP);
    tailConfig.addInputToGroup(0);
    tailConfig.addInputToGroup(1);
    tailConfig.setInputSerializer(vertexWithRankAndDanglingSerializer, 0);
    tailConfig.setInputSerializer(vertexWithRankSerializer, 1);
    tailConfig.setDriverComparator(vertexWithRankAndDanglingComparator, 0);
    tailConfig.setDriverComparator(vertexWithRankComparator, 1);
    tailConfig.setDriverPairComparator(coGroupComparator);
    tailConfig.setInputAsynchronouslyMaterialized(0, true);
    tailConfig.setRelativeInputMaterializationMemory(
        0, (double) minorConsumer / totalMemoryConsumption);
    tailConfig.setInputLocalStrategy(1, LocalStrategy.SORT);
    tailConfig.setInputComparator(vertexWithRankComparator, 1);
    tailConfig.setRelativeMemoryInput(1, (double) coGroupSortMemory / totalMemoryConsumption);
    tailConfig.setFilehandlesInput(1, NUM_FILE_HANDLES_PER_SORT);
    tailConfig.setSpillingThresholdInput(1, SORT_SPILL_THRESHOLD);
    tailConfig.addIterationAggregator(
        CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, new PageRankStatsAggregator());

    // output
    tailConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
    tailConfig.setOutputSerializer(vertexWithRankAndDanglingSerializer);

    // the stub
    tailConfig.setStubWrapper(
        new UserCodeClassWrapper<CustomCompensatableDotProductCoGroup>(
            CustomCompensatableDotProductCoGroup.class));
    tailConfig.setStubParameter("pageRank.numVertices", String.valueOf(numVertices));
    tailConfig.setStubParameter(
        "pageRank.numDanglingVertices", String.valueOf(numDanglingVertices));
    tailConfig.setStubParameter("compensation.failingWorker", failingWorkers);
    tailConfig.setStubParameter("compensation.failingIteration", String.valueOf(failingIteration));
    tailConfig.setStubParameter("compensation.messageLoss", String.valueOf(messageLoss));

    // --------------- the output ---------------------

    JobOutputVertex output =
        JobGraphUtils.createFileOutput(jobGraph, "FinalOutput", degreeOfParallelism);
    TaskConfig outputConfig = new TaskConfig(output.getConfiguration());
    outputConfig.addInputToGroup(0);
    outputConfig.setInputSerializer(vertexWithRankAndDanglingSerializer, 0);
    outputConfig.setStubWrapper(
        new UserCodeClassWrapper<CustomPageWithRankOutFormat>(CustomPageWithRankOutFormat.class));
    outputConfig.setStubParameter(FileOutputFormat.FILE_PARAMETER_KEY, outputPath);

    // --------------- the auxiliaries ---------------------

    JobOutputVertex fakeTailOutput =
        JobGraphUtils.createFakeOutput(jobGraph, "FakeTailOutput", degreeOfParallelism);

    // the sync vertex evaluates the convergence criterion on the shared aggregator
    JobOutputVertex sync = JobGraphUtils.createSync(jobGraph, degreeOfParallelism);
    TaskConfig syncConfig = new TaskConfig(sync.getConfiguration());
    syncConfig.setNumberOfIterations(numIterations);
    syncConfig.addIterationAggregator(
        CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, new PageRankStatsAggregator());
    syncConfig.setConvergenceCriterion(
        CustomCompensatableDotProductCoGroup.AGGREGATOR_NAME, new DiffL1NormConvergenceCriterion());
    syncConfig.setIterationId(ITERATION_ID);

    // --------------- the wiring ---------------------
    // NOTE(review): the order of the connect() calls presumably determines the input/output
    // indices referenced in the configs above. Keep the order when editing.

    JobGraphUtils.connect(
        pageWithRankInput, head, ChannelType.NETWORK, DistributionPattern.BIPARTITE);

    JobGraphUtils.connect(head, intermediate, ChannelType.IN_MEMORY, DistributionPattern.POINTWISE);
    // intermediate input 0 (from the head, POINTWISE) is iterative and interrupts after 1 event
    intermediateConfig.setGateIterativeWithNumberOfEventsUntilInterrupt(0, 1);

    JobGraphUtils.connect(
        adjacencyListInput, intermediate, ChannelType.NETWORK, DistributionPattern.BIPARTITE);

    JobGraphUtils.connect(head, tail, ChannelType.NETWORK, DistributionPattern.POINTWISE);
    JobGraphUtils.connect(intermediate, tail, ChannelType.NETWORK, DistributionPattern.BIPARTITE);
    // Tail input 0 (from the head) is POINTWISE and waits for 1 event. Tail input 1 (from the
    // intermediate task) is BIPARTITE and waits for degreeOfParallelism events, presumably one
    // per sending intermediate instance.
    tailConfig.setGateIterativeWithNumberOfEventsUntilInterrupt(0, 1);
    tailConfig.setGateIterativeWithNumberOfEventsUntilInterrupt(1, degreeOfParallelism);

    JobGraphUtils.connect(head, output, ChannelType.IN_MEMORY, DistributionPattern.POINTWISE);
    JobGraphUtils.connect(
        tail, fakeTailOutput, ChannelType.IN_MEMORY, DistributionPattern.POINTWISE);

    JobGraphUtils.connect(head, sync, ChannelType.NETWORK, DistributionPattern.POINTWISE);

    // co-locate all vertices with the head (the fake tail output with the tail)
    fakeTailOutput.setVertexToShareInstancesWith(tail);
    tail.setVertexToShareInstancesWith(head);
    pageWithRankInput.setVertexToShareInstancesWith(head);
    adjacencyListInput.setVertexToShareInstancesWith(head);
    intermediate.setVertexToShareInstancesWith(head);
    output.setVertexToShareInstancesWith(head);
    sync.setVertexToShareInstancesWith(head);

    return jobGraph;
  }
コード例 #5
0
  /**
   * Drives the iteration synchronization loop.
   *
   * <p>Setup: registers the configured iteration aggregators by name, stores the optional
   * convergence criterion, and subscribes a {@link SyncEventHandler} to {@link WorkerDoneEvent}s on
   * the head event reader.
   *
   * <p>Loop: until termination is requested, it waits for the end of the current superstep and then
   * does one of two things:
   *
   * <ul>
   *   <li>if {@code checkForConvergence()} returns true, it requests termination and broadcasts a
   *       {@link TerminationEvent};
   *   <li>otherwise it broadcasts an {@link AllWorkersDoneEvent} carrying the aggregators, resets
   *       every aggregator and advances {@code currentIteration}.
   * </ul>
   *
   * @throws Exception if any of the runtime operations called here fails
   */
  @Override
  public void invoke() throws Exception {
    userCodeClassLoader = LibraryCacheManager.getClassLoader(getEnvironment().getJobID());
    final TaskConfig syncTaskConfig = new TaskConfig(getTaskConfiguration());

    // register every configured aggregator under its name
    this.aggregators = new HashMap<String, Aggregator<?>>();
    for (AggregatorWithName<?> namedAggregator : syncTaskConfig.getIterationAggregators()) {
      aggregators.put(namedAggregator.getName(), namedAggregator.getAggregator());
    }

    // a convergence criterion is optional, but when present it must name its aggregator
    if (syncTaskConfig.usesConvergenceCriterion()) {
      convergenceCriterion = syncTaskConfig.getConvergenceCriterion();
      convergenceAggregatorName = syncTaskConfig.getConvergenceCriterionAggregatorName();
      Preconditions.checkNotNull(convergenceAggregatorName);
    }

    maxNumberOfIterations = syncTaskConfig.getNumberOfIterations();

    // number of events the handler expects on gate 0 before a superstep counts as finished
    final int eventsPerSuperstep =
        syncTaskConfig.getNumberOfEventsUntilInterruptInIterativeGate(0);
    eventHandler = new SyncEventHandler(eventsPerSuperstep, aggregators, userCodeClassLoader);
    headEventReader.subscribeToEvent(eventHandler, WorkerDoneEvent.class);

    final IntegerRecord record = new IntegerRecord();

    while (!terminationRequested()) {
      if (log.isInfoEnabled()) {
        log.info(formatLogString("starting iteration [" + currentIteration + "]"));
      }

      // consume head events until the end of the current superstep is reached
      readHeadEventChannel(record);

      if (log.isInfoEnabled()) {
        log.info(formatLogString("finishing iteration [" + currentIteration + "]"));
      }

      final boolean converged = checkForConvergence();
      if (!converged) {
        if (log.isInfoEnabled()) {
          log.info(
              formatLogString(
                  "signaling that all workers are done in iteration [" + currentIteration + "]"));
        }

        sendToAllWorkers(new AllWorkersDoneEvent(aggregators));

        // start the next superstep with fresh aggregator values
        for (Aggregator<?> aggregator : aggregators.values()) {
          aggregator.reset();
        }

        currentIteration++;
      } else {
        if (log.isInfoEnabled()) {
          log.info(
              formatLogString(
                  "signaling that all workers are to terminate in iteration ["
                      + currentIteration
                      + "]"));
        }

        requestTermination();
        sendToAllWorkers(new TerminationEvent());
      }
    }
  }