/**
  * Initializes this edge manager from the serialized configuration carried in the context's
  * user payload, and checks it against the destination vertex's actual parallelism.
  *
  * @throws RuntimeException if the user payload is missing, empty, or cannot be parsed
  * @throws IllegalStateException if the configured destination task count differs from the
  *     destination vertex's task count reported by the context
  */
 @Override
 public void initialize() {
   final String initFailure =
       "Could not initialize CustomShuffleEdgeManager" + " from provided user payload";
   UserPayload userPayload = getContext().getUserPayload();
   // A null or zero-length payload means no configuration was attached to the edge.
   if (userPayload == null
       || userPayload.getPayload() == null
       || userPayload.getPayload().limit() == 0) {
     throw new RuntimeException(initFailure);
   }
   CustomShuffleEdgeManagerConfig config;
   try {
     config = CustomShuffleEdgeManagerConfig.fromUserPayload(userPayload);
   } catch (InvalidProtocolBufferException e) {
     throw new RuntimeException(initFailure, e);
   }
   this.numSourceTaskOutputs = config.numSourceTaskOutputs;
   this.numDestinationTasks = config.numDestinationTasks;
   this.basePartitionRange = config.basePartitionRange;
   this.remainderRangeForLastShuffler = config.remainderRangeForLastShuffler;
   this.numSourceTasks = getContext().getSourceVertexNumTasks();
   // The serialized config must agree with the parallelism the destination vertex really has.
   int destinationVertexTasks = getContext().getDestinationVertexNumTasks();
   Preconditions.checkState(
       this.numDestinationTasks == destinationVertexTasks,
       "Configured numDestinationTasks (%s) does not match destination vertex task count (%s)",
       this.numDestinationTasks,
       destinationVertexTasks);
 }
 /**
  * Decodes a {@code CustomShuffleEdgeManagerConfig} from the protobuf-encoded bytes stored in a
  * {@link UserPayload}, such as one produced by {@code toUserPayload()}.
  *
  * @param payload user payload whose bytes hold a {@code ShuffleEdgeManagerConfigPayloadProto}
  * @return a config built from the proto's source-output count, destination-task count, base
  *     partition range and last-shuffler remainder range
  * @throws InvalidProtocolBufferException if the bytes are not a valid proto message
  */
 public static CustomShuffleEdgeManagerConfig fromUserPayload(UserPayload payload)
     throws InvalidProtocolBufferException {
   ByteString serializedBytes = ByteString.copyFrom(payload.getPayload());
   ShuffleEdgeManagerConfigPayloadProto parsed =
       ShuffleEdgeManagerConfigPayloadProto.parseFrom(serializedBytes);
   return new CustomShuffleEdgeManagerConfig(
       parsed.getNumSourceTaskOutputs(),
       parsed.getNumDestinationTasks(),
       parsed.getBasePartitionRange(),
       parsed.getRemainderRangeForLastShuffler());
 }
 /**
  * Encodes this configuration as a {@code ShuffleEdgeManagerConfigPayloadProto} and wraps the
  * resulting bytes in a {@link UserPayload}.
  *
  * @return a user payload holding the serialized proto bytes
  */
 public UserPayload toUserPayload() {
   ShuffleEdgeManagerConfigPayloadProto.Builder protoBuilder =
       ShuffleEdgeManagerConfigPayloadProto.newBuilder();
   protoBuilder.setNumSourceTaskOutputs(numSourceTaskOutputs);
   protoBuilder.setNumDestinationTasks(numDestinationTasks);
   protoBuilder.setBasePartitionRange(basePartitionRange);
   protoBuilder.setRemainderRangeForLastShuffler(remainderRangeForLastShuffler);
   byte[] encoded = protoBuilder.build().toByteArray();
   return UserPayload.create(ByteBuffer.wrap(encoded));
 }
Beispiel #4
0
 /**
  * Creates a Mockito mock of {@link InputContext} with the stubbing these tests rely on:
  * 100 MiB of task memory, a fresh {@link TezCounters}, input index 1, source vertex
  * "srcVertex", task vertex index 1, and a 1 KiB zero-filled user payload.
  *
  * @return the stubbed input context
  */
 private InputContext createTezInputContext() {
   TezCounters counters = new TezCounters();
   InputContext inputContext = mock(InputContext.class);
   // Upper-case 'L' (the lowercase form reads like "1001"); leading with the long operand keeps
   // the whole product in long arithmetic. Value is unchanged: 104857600 bytes.
   doReturn(100L * 1024 * 1024).when(inputContext).getTotalMemoryAvailableToTask();
   doReturn(counters).when(inputContext).getCounters();
   doReturn(1).when(inputContext).getInputIndex();
   doReturn("srcVertex").when(inputContext).getSourceVertexName();
   doReturn(1).when(inputContext).getTaskVertexIndex();
   doReturn(UserPayload.create(ByteBuffer.wrap(new byte[1024])))
       .when(inputContext)
       .getUserPayload();
   return inputContext;
 }
  @Test(timeout = 5000)
  public void testGetBytePayload() throws IOException {
    int numBuckets = 10;
    VertexManagerPluginContext context = mock(VertexManagerPluginContext.class);
    CustomVertexConfiguration vertexConf =
        new CustomVertexConfiguration(numBuckets, TezWork.VertexType.INITIALIZED_EDGES);
    DataOutputBuffer dob = new DataOutputBuffer();
    vertexConf.write(dob);
    UserPayload payload = UserPayload.create(ByteBuffer.wrap(dob.getData()));
    when(context.getUserPayload()).thenReturn(payload);

    CustomPartitionVertex vm = new CustomPartitionVertex(context);
    vm.initialize();

    // prepare empty routing table
    Multimap<Integer, Integer> routingTable = HashMultimap.<Integer, Integer>create();
    payload = vm.getBytePayload(routingTable);
    // get conf from user payload
    CustomEdgeConfiguration edgeConf = new CustomEdgeConfiguration();
    DataInputByteBuffer dibb = new DataInputByteBuffer();
    dibb.reset(payload.getPayload());
    edgeConf.readFields(dibb);
    assertEquals(numBuckets, edgeConf.getNumBuckets());
  }