void updatePendingTasks() {
   pendingTasks.clear();
   for (int i = 0; i < context.getVertexNumTasks(context.getVertexName()); ++i) {
     pendingTasks.add(Integer.valueOf(i)); // valueOf instead of the deprecated Integer constructor
   }
   totalTasksToSchedule = pendingTasks.size();
 }
  @Override
  public void onVertexStarted(Map<String, List<Integer>> completions) {
    pendingTasks =
        Lists.newArrayListWithCapacity(context.getVertexNumTasks(context.getVertexName()));
    // track the tasks in this vertex
    updatePendingTasks();
    updateSourceTaskCount();

    LOG.info(
        "OnVertexStarted vertex: "
            + context.getVertexName()
            + " with "
            + numSourceTasks
            + " source tasks and "
            + totalTasksToSchedule
            + " pending tasks");

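    // replay source task completions that were reported before this vertex started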
    if (completions != null) {
      for (Map.Entry<String, List<Integer>> entry : completions.entrySet()) {
        for (Integer taskId : entry.getValue()) {
          onSourceTaskCompleted(entry.getKey(), taskId);
        }
      }
    }
    // for the special case when source has 0 tasks or min fraction == 0
    schedulePendingTasks();
  }
 void updateSourceTaskCount() {
   // track source vertices
   int numSrcTasks = 0;
   for (String vertex : bipartiteSources.keySet()) {
     numSrcTasks += context.getVertexNumTasks(vertex);
   }
   numSourceTasks = numSrcTasks;
 }
  @Test(timeout = 5000)
  public void testGetBytePayload() throws IOException {
    int numBuckets = 10;
    VertexManagerPluginContext context = mock(VertexManagerPluginContext.class);
    CustomVertexConfiguration vertexConf =
        new CustomVertexConfiguration(numBuckets, TezWork.VertexType.INITIALIZED_EDGES);
    DataOutputBuffer dob = new DataOutputBuffer();
    vertexConf.write(dob);
    UserPayload payload = UserPayload.create(ByteBuffer.wrap(dob.getData()));
    when(context.getUserPayload()).thenReturn(payload);

    CustomPartitionVertex vm = new CustomPartitionVertex(context);
    vm.initialize();

    // prepare empty routing table
    Multimap<Integer, Integer> routingTable = HashMultimap.create();
    payload = vm.getBytePayload(routingTable);
    // deserialize the edge configuration from the payload returned by getBytePayload
    CustomEdgeConfiguration edgeConf = new CustomEdgeConfiguration();
    DataInputByteBuffer dibb = new DataInputByteBuffer();
    dibb.reset(payload.getPayload());
    edgeConf.readFields(dibb);
    assertEquals(numBuckets, edgeConf.getNumBuckets());
  }
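 // schedules up to numTasksToSchedule tasks from the head of the pending list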
 void schedulePendingTasks(int numTasksToSchedule) {
   // determine parallelism before scheduling the first time.
   // this is the latest we can wait before determining parallelism;
   // currently parallelism depends on source task completion, so this is the
   // last point at which we can calculate it before tasks must be launched
   // as specified by the user. If/when we move to some other method of
   // calculating parallelism, or change parallelism while tasks are already
   // running, we can add other parameters to trigger this calculation.
   if (enableAutoParallelism && !parallelismDetermined) {
     // do this once
     parallelismDetermined = true;
     determineParallelismAndApply();
   }
   List<TaskWithLocationHint> scheduledTasks = Lists.newArrayListWithCapacity(numTasksToSchedule);
   while (!pendingTasks.isEmpty() && numTasksToSchedule > 0) {
     numTasksToSchedule--;
     // List.remove(int) returns the removed element, so no separate get(0) is needed
     scheduledTasks.add(new TaskWithLocationHint(pendingTasks.remove(0), null));
   }
   context.scheduleVertexTasks(scheduledTasks);
 }
  @Override
  public void initialize(VertexManagerPluginContext context) {
    Configuration conf;
    try {
      conf = TezUtils.createConfFromUserPayload(context.getUserPayload());
    } catch (IOException e) {
      throw new TezUncheckedException(e);
    }

    this.context = context;

    this.slowStartMinSrcCompletionFraction =
        conf.getFloat(
            ShuffleVertexManager.TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION,
            ShuffleVertexManager.TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_SRC_FRACTION_DEFAULT);
    this.slowStartMaxSrcCompletionFraction =
        conf.getFloat(
            ShuffleVertexManager.TEZ_AM_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION,
            ShuffleVertexManager.TEZ_AM_SHUFFLE_VERTEX_MANAGER_MAX_SRC_FRACTION_DEFAULT);

    if (slowStartMinSrcCompletionFraction < 0
        || slowStartMaxSrcCompletionFraction < slowStartMinSrcCompletionFraction) {
      throw new IllegalArgumentException(
          "Invalid values for slowStartMinSrcCompletionFraction"
              + "/slowStartMaxSrcCompletionFraction. Min cannot be < 0 and "
              + "max cannot be < min.");
    }

    enableAutoParallelism =
        conf.getBoolean(
            ShuffleVertexManager.TEZ_AM_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL,
            ShuffleVertexManager.TEZ_AM_SHUFFLE_VERTEX_MANAGER_ENABLE_AUTO_PARALLEL_DEFAULT);
    desiredTaskInputDataSize =
        conf.getLong(
            ShuffleVertexManager.TEZ_AM_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE,
            ShuffleVertexManager.TEZ_AM_SHUFFLE_VERTEX_MANAGER_DESIRED_TASK_INPUT_SIZE_DEFAULT);
    minTaskParallelism =
        conf.getInt(
            ShuffleVertexManager.TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM,
            ShuffleVertexManager.TEZ_AM_SHUFFLE_VERTEX_MANAGER_MIN_TASK_PARALLELISM_DEFAULT);
    LOG.info(
        "Shuffle Vertex Manager: settings"
            + " minFrac:"
            + slowStartMinSrcCompletionFraction
            + " maxFrac:"
            + slowStartMaxSrcCompletionFraction
            + " auto:"
            + enableAutoParallelism
            + " desiredTaskInput:"
            + desiredTaskInputDataSize
            + " minTasks:"
            + minTaskParallelism);

    Map<String, EdgeProperty> inputs = context.getInputVertexEdgeProperties();
    for (Map.Entry<String, EdgeProperty> entry : inputs.entrySet()) {
      if (entry.getValue().getDataMovementType() == DataMovementType.SCATTER_GATHER) {
        String vertex = entry.getKey();
        bipartiteSources.put(vertex, new HashSet<Integer>());
      }
    }
    if (bipartiteSources.isEmpty()) {
      throw new TezUncheckedException("At least one bipartite source should exist");
    }
    // don't track the source tasks here since those tasks may themselves be
    // dynamically changed as the DAG progresses.
  }
  void schedulePendingTasks() {
    int numPendingTasks = pendingTasks.size();
    if (numPendingTasks == 0) {
      return;
    }

    // the early return above guarantees numPendingTasks > 0 here
    if (numSourceTasksCompleted == numSourceTasks) {
      LOG.info(
          "All source tasks assigned. "
              + "Ramping up "
              + numPendingTasks
              + " remaining tasks for vertex: "
              + context.getVertexName());
      schedulePendingTasks(numPendingTasks);
      return;
    }

    float completedSourceTaskFraction;
    if (numSourceTasks != 0) { // support for 0 source tasks
      completedSourceTaskFraction = (float) numSourceTasksCompleted / numSourceTasks;
    } else {
      completedSourceTaskFraction = 1;
    }

    // start scheduling when source tasks completed fraction is more than min.
    // linearly increase the number of scheduled tasks such that all tasks are
    // scheduled when source tasks completed fraction reaches max
    float tasksFractionToSchedule = 1;
    float percentRange = slowStartMaxSrcCompletionFraction - slowStartMinSrcCompletionFraction;
    if (percentRange > 0) {
      tasksFractionToSchedule =
          (completedSourceTaskFraction - slowStartMinSrcCompletionFraction) / percentRange;
    } else {
      // min and max are equal. schedule 100% on reaching min
      if (completedSourceTaskFraction < slowStartMinSrcCompletionFraction) {
        tasksFractionToSchedule = 0;
      }
    }

    if (tasksFractionToSchedule > 1) {
      tasksFractionToSchedule = 1;
    } else if (tasksFractionToSchedule < 0) {
      tasksFractionToSchedule = 0;
    }

    int numTasksToSchedule =
        ((int) (tasksFractionToSchedule * totalTasksToSchedule)
            - (totalTasksToSchedule - numPendingTasks));
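    // Worked example (illustrative numbers only): with min = 0.25, max = 0.75,
    // 100 source tasks of which 50 have completed, completedSourceTaskFraction
    // = 0.5 and tasksFractionToSchedule = (0.5 - 0.25) / 0.5 = 0.5. With
    // totalTasksToSchedule = 40 and 30 still pending, numTasksToSchedule =
    // (int) (0.5 * 40) - (40 - 30) = 10.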

    if (numTasksToSchedule > 0) {
      // numTasksToSchedule can be negative if numSourceTasksCompleted does
      // not increase monotonically
      LOG.info(
          "Scheduling "
              + numTasksToSchedule
              + " tasks for vertex: "
              + context.getVertexName()
              + " with totalTasks: "
              + totalTasksToSchedule
              + ". "
              + numSourceTasksCompleted
              + " source tasks completed out of "
              + numSourceTasks
              + ". SourceTaskCompletedFraction: "
              + completedSourceTaskFraction
              + " min: "
              + slowStartMinSrcCompletionFraction
              + " max: "
              + slowStartMaxSrcCompletionFraction);
      schedulePendingTasks(numTasksToSchedule);
    }
  }
  void determineParallelismAndApply() {
    if (numSourceTasksCompleted == 0) {
      return;
    }

    if (numVertexManagerEventsReceived == 0) {
      return;
    }

    int currentParallelism = pendingTasks.size();
    long expectedTotalSourceTasksOutputSize =
        (numSourceTasks * completedSourceTasksOutputSize) / numVertexManagerEventsReceived;
    int desiredTaskParallelism =
        (int)
            ((expectedTotalSourceTasksOutputSize + desiredTaskInputDataSize - 1)
                / desiredTaskInputDataSize);
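    // the expression above is an integer ceiling division:
    // ceil(expectedTotalSourceTasksOutputSize / desiredTaskInputDataSize)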
    if (desiredTaskParallelism < minTaskParallelism) {
      desiredTaskParallelism = minTaskParallelism;
    }

    if (desiredTaskParallelism >= currentParallelism) {
      return;
    }

    // most shufflers will be assigned this range
    int basePartitionRange = currentParallelism / desiredTaskParallelism;

    if (basePartitionRange <= 1) {
      // nothing to do if the range is a single partition; the shuffler handles that by default
      return;
    }

    int numShufflersWithBaseRange = currentParallelism / basePartitionRange;
    int remainderRangeForLastShuffler = currentParallelism % basePartitionRange;

    int finalTaskParallelism =
        (remainderRangeForLastShuffler > 0)
            ? (numShufflersWithBaseRange + 1)
            : (numShufflersWithBaseRange);
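    // Worked example (illustrative numbers only): currentParallelism = 100 and
    // desiredTaskParallelism = 30 give basePartitionRange = 100 / 30 = 3,
    // numShufflersWithBaseRange = 100 / 3 = 33 and remainderRangeForLastShuffler
    // = 100 % 3 = 1, so finalTaskParallelism = 34: 33 tasks each handle 3
    // partitions and one task handles the single remaining partition.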

    LOG.info(
        "Reduce auto parallelism for vertex: "
            + context.getVertexName()
            + " to "
            + finalTaskParallelism
            + " from "
            + pendingTasks.size()
            + ". Expected output: "
            + expectedTotalSourceTasksOutputSize
            + " based on actual output: "
            + completedSourceTasksOutputSize
            + " from "
            + numVertexManagerEventsReceived
            + " vertex manager events."
            + " desiredTaskInputSize: "
            + desiredTaskInputDataSize);

    if (finalTaskParallelism < currentParallelism) {
      // final parallelism is less than actual parallelism
      Map<String, EdgeManagerDescriptor> edgeManagers =
          new HashMap<String, EdgeManagerDescriptor>(bipartiteSources.size());
      for (String vertex : bipartiteSources.keySet()) {
        // use currentParallelism for numSourceTasks to maintain original state
        // for the source tasks
        CustomShuffleEdgeManagerConfig edgeManagerConfig =
            new CustomShuffleEdgeManagerConfig(
                currentParallelism,
                finalTaskParallelism,
                numSourceTasks,
                basePartitionRange,
                ((remainderRangeForLastShuffler > 0)
                    ? remainderRangeForLastShuffler
                    : basePartitionRange));
        EdgeManagerDescriptor edgeManagerDescriptor =
            new EdgeManagerDescriptor(CustomShuffleEdgeManager.class.getName());
        edgeManagerDescriptor.setUserPayload(edgeManagerConfig.toUserPayload());
        edgeManagers.put(vertex, edgeManagerDescriptor);
      }

      context.setVertexParallelism(finalTaskParallelism, null, edgeManagers, null);
      updatePendingTasks();
    }
  }