/**
 * Builds a {@link TezProcessorContextImpl} from a fixed identifier chain and checks that the
 * context reports the DAG number, attempt number, application id, DAG/vertex names, vertex index
 * and work directories it was constructed with.
 */
@Test(timeout = 5000)
  public void testDagNumber() {
    // Identifier chain: application -> DAG -> vertex -> task -> task attempt.
    int expectedDagNumber = 52;
    ApplicationId applicationId = ApplicationId.newInstance(10000, 13);
    TezDAGID dag = TezDAGID.getInstance(applicationId, expectedDagNumber);
    TezVertexID vertex = TezVertexID.getInstance(dag, 6);
    TezTaskID taskOfVertex = TezTaskID.getInstance(vertex, 4);
    TezTaskAttemptID attempt = TezTaskAttemptID.getInstance(taskOfVertex, 2);

    // Plain values the context is expected to echo back.
    String[] workDirs = new String[] {"dummyLocalDir"};
    int expectedAttemptNumber = 1;
    String expectedDagName = "DAG_NAME";
    String expectedVertexName = "VERTEX_NAME";
    int parallelism = 20;
    long availableMemory = 10000L;

    // Collaborators: mocked or trivially constructed.
    TezUmbilical umbilical = mock(TezUmbilical.class);
    LogicalIOProcessorRuntimeTask task = mock(LogicalIOProcessorRuntimeTask.class);
    doReturn(new TezCounters()).when(task).addAndGetTezCounter(any(String.class));
    Map<String, ByteBuffer> consumerMetadata = Maps.newHashMap();
    Map<String, String> serviceEnv = Maps.newHashMap();
    MemoryDistributor memoryDistributor = mock(MemoryDistributor.class);
    ProcessorDescriptor descriptor = mock(ProcessorDescriptor.class);
    InputReadyTracker readyTracker = mock(InputReadyTracker.class);
    ObjectRegistry registry = new ObjectRegistryImpl();
    ExecutionContext executionContext = new ExecutionContextImpl("localhost");

    TezProcessorContextImpl context =
        new TezProcessorContextImpl(
            new Configuration(),
            workDirs,
            expectedAttemptNumber,
            umbilical,
            expectedDagName,
            expectedVertexName,
            parallelism,
            attempt,
            null,
            task,
            consumerMetadata,
            serviceEnv,
            memoryDistributor,
            descriptor,
            readyTracker,
            registry,
            executionContext,
            availableMemory);

    assertEquals(expectedDagNumber, context.getDagIdentifier());
    assertEquals(expectedAttemptNumber, context.getDAGAttemptNumber());
    assertEquals(applicationId, context.getApplicationId());
    assertEquals(expectedDagName, context.getDAGName());
    assertEquals(expectedVertexName, context.getTaskVertexName());
    assertEquals(vertex.getId(), context.getTaskVertexIndex());
    assertTrue(Arrays.equals(workDirs, context.getWorkDirs()));
  }
 /**
  * Serializes this event into a {@link VertexParallelismUpdatedProto}.
  *
  * <p>The vertex id and task count are always written. The location hint, the per-edge manager
  * descriptors and the root input spec updates are written only when the corresponding field is
  * non-null.
  *
  * @return the populated protobuf message
  */
 public VertexParallelismUpdatedProto toProto() {
   VertexParallelismUpdatedProto.Builder protoBuilder =
       VertexParallelismUpdatedProto.newBuilder();
   protoBuilder.setVertexId(vertexID.toString());
   protoBuilder.setNumTasks(numTasks);

   if (vertexLocationHint != null) {
     protoBuilder.setVertexLocationHint(
         DagTypeConverters.convertVertexLocationHintToProto(vertexLocationHint));
   }

   if (sourceEdgeManagers != null) {
     // One descriptor message per source edge, keyed by edge name.
     for (Entry<String, EdgeManagerPluginDescriptor> edge : sourceEdgeManagers.entrySet()) {
       protoBuilder.addEdgeManagerDescriptors(
           EdgeManagerDescriptorProto.newBuilder()
               .setEdgeName(edge.getKey())
               .setEntityDescriptor(DagTypeConverters.convertToDAGPlan(edge.getValue()))
               .build());
     }
   }

   if (rootInputSpecUpdates != null) {
     // One update message per root input, keyed by input name.
     for (Entry<String, InputSpecUpdate> update : rootInputSpecUpdates.entrySet()) {
       InputSpecUpdate spec = update.getValue();
       protoBuilder.addRootInputSpecUpdates(
           RootInputSpecUpdateProto.newBuilder()
               .setInputName(update.getKey())
               .setForAllWorkUnits(spec.isForAllWorkUnits())
               .addAllNumPhysicalInputs(spec.getAllNumPhysicalInputs())
               .build());
     }
   }

   return protoBuilder.build();
 }
 /**
  * Populates this event from a {@link VertexParallelismUpdatedProto}.
  *
  * <p>The vertex id and task count are always read. The location hint is read only when present
  * in the message; the edge manager map and the root input spec update map are (re)created only
  * when the message carries at least one entry of that kind, otherwise the fields are left as
  * they were.
  *
  * @param proto the message to read from
  */
 public void fromProto(VertexParallelismUpdatedProto proto) {
   this.vertexID = TezVertexID.fromString(proto.getVertexId());
   this.numTasks = proto.getNumTasks();

   if (proto.hasVertexLocationHint()) {
     this.vertexLocationHint =
         DagTypeConverters.convertVertexLocationHintFromProto(proto.getVertexLocationHint());
   }

   final int edgeManagerCount = proto.getEdgeManagerDescriptorsCount();
   if (edgeManagerCount > 0) {
     this.sourceEdgeManagers =
         new HashMap<String, EdgeManagerPluginDescriptor>(edgeManagerCount);
     for (EdgeManagerDescriptorProto edgeProto : proto.getEdgeManagerDescriptorsList()) {
       EdgeManagerPluginDescriptor descriptor =
           DagTypeConverters.convertEdgeManagerPluginDescriptorFromDAGPlan(
               edgeProto.getEntityDescriptor());
       sourceEdgeManagers.put(edgeProto.getEdgeName(), descriptor);
     }
   }

   if (proto.getRootInputSpecUpdatesCount() > 0) {
     this.rootInputSpecUpdates = Maps.newHashMap();
     for (RootInputSpecUpdateProto updateProto : proto.getRootInputSpecUpdatesList()) {
       // An "all work units" update carries a single count in slot 0; a per-task update
       // carries the whole list.
       InputSpecUpdate restored =
           updateProto.getForAllWorkUnits()
               ? InputSpecUpdate.createAllTaskInputSpecUpdate(
                   updateProto.getNumPhysicalInputs(0))
               : InputSpecUpdate.createPerTaskInputSpecUpdate(
                   updateProto.getNumPhysicalInputsList());
       this.rootInputSpecUpdates.put(updateProto.getInputName(), restored);
     }
   }
 }
 /**
  * Creates fresh identifiers and fixtures before each test: an application and attempt id, a
  * DAG/vertex/task/task-attempt id chain with random numeric components, a minimal DAG plan, and
  * a container and node id.
  */
 @Before
 public void setup() {
   // Application-level identifiers.
   applicationId = ApplicationId.newInstance(9999L, 1);
   applicationAttemptId = ApplicationAttemptId.newInstance(applicationId, 1);

   // Each level of the id chain draws its own random number, in this order.
   tezDAGID = TezDAGID.getInstance(applicationId, random.nextInt());
   tezVertexID = TezVertexID.getInstance(tezDAGID, random.nextInt());
   tezTaskID = TezTaskID.getInstance(tezVertexID, random.nextInt());
   tezTaskAttemptID = TezTaskAttemptID.getInstance(tezTaskID, random.nextInt());

   // Remaining fixtures.
   dagPlan = DAGPlan.newBuilder().setName("DAGPlanMock").build();
   containerId = ContainerId.newInstance(applicationAttemptId, 111);
   nodeId = NodeId.newInstance("node", 13435);
 }
  /**
   * Runs a single-vertex DAG in which the original attempt of task 0 is made slow enough to
   * trigger speculation, then verifies that the DAG succeeds, that the speculative attempt is the
   * successful one, and that the original attempt was terminated for effective speculation.
   *
   * @param withProgress whether the mock launcher reports progress updates; the speculation
   *     counters are only checked when this is {@code true}
   * @throws Exception if submitting or waiting for the DAG fails
   */
  public void testBasicSpeculation(boolean withProgress) throws Exception {
    DAG dag = DAG.create("test");
    Vertex vA = Vertex.create("A", ProcessorDescriptor.create("Proc.class"), 5);
    dag.addVertex(vA);

    MockTezClient tezClient = createTezSession();

    DAGClient dagClient = tezClient.submitDAG(dag);
    DAGImpl dagImpl = (DAGImpl) mockApp.getContext().getCurrentDAG();
    TezVertexID vertexId = TezVertexID.getInstance(dagImpl.getID(), 0);
    // original attempt is killed and speculative one is successful
    TezTaskAttemptID killedTaId =
        TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 0);
    TezTaskAttemptID successTaId =
        TezTaskAttemptID.getInstance(TezTaskID.getInstance(vertexId, 0), 1);

    mockLauncher.updateProgress(withProgress);
    // cause speculation trigger
    mockLauncher.setStatusUpdatesForTask(killedTaId, 100);

    mockLauncher.startScheduling(true);
    dagClient.waitForCompletion();
    Assert.assertEquals(DAGStatus.State.SUCCEEDED, dagClient.getDAGStatus(null).getState());
    Task task = dagImpl.getTask(killedTaId.getTaskID());
    Assert.assertEquals(2, task.getAttempts().size());
    Assert.assertEquals(successTaId, task.getSuccessfulAttempt().getID());
    TaskAttempt killedAttempt = task.getAttempt(killedTaId);
    // The diagnostics check used to compute contains(...) and drop the result, so it could
    // never fail; assert on it and include the actual diagnostics in the failure message.
    String killedDiagnostics = Joiner.on(",").join(killedAttempt.getDiagnostics());
    Assert.assertTrue(
        "Unexpected diagnostics for killed attempt: " + killedDiagnostics,
        killedDiagnostics.contains("Killed as speculative attempt"));
    Assert.assertEquals(
        TaskAttemptTerminationCause.TERMINATED_EFFECTIVE_SPECULATION,
        killedAttempt.getTerminationCause());
    if (withProgress) {
      // without progress updates occasionally more than 1 task speculates
      Assert.assertEquals(
          1, task.getCounters().findCounter(TaskCounter.NUM_SPECULATIONS).getValue());
      Assert.assertEquals(
          1, dagImpl.getAllCounters().findCounter(TaskCounter.NUM_SPECULATIONS).getValue());
      org.apache.tez.dag.app.dag.Vertex v = dagImpl.getVertex(killedTaId.getTaskID().getVertexID());
      Assert.assertEquals(
          1, v.getAllCounters().findCounter(TaskCounter.NUM_SPECULATIONS).getValue());
    }
    tezClient.stop();
  }
Beispiel #6
0
    /**
     * Wires up an {@link AMContainerImpl} around mocked collaborators: fixed application,
     * container and node identifiers, a DAG/vertex/task/attempt id chain, a mocked heartbeat
     * handler and task attempt listener, and a mocked {@link AppContext} that hands back those
     * fixtures.
     *
     * <p>Neither parameter is read anywhere in this constructor body.
     *
     * @param shouldProfile not used by the visible code
     * @param profileString not used by the visible code
     */
    public WrappedContainer(boolean shouldProfile, String profileString) {
      // YARN-side identifiers and the container under management.
      applicationID = ApplicationId.newInstance(rmIdentifier, 1);
      appAttemptID = ApplicationAttemptId.newInstance(applicationID, 1);
      containerID = ContainerId.newInstance(appAttemptID, 1);
      nodeID = NodeId.newInstance("host", 12500);
      nodeHttpAddress = "host:12501";
      resource = Resource.newInstance(1024, 1);
      priority = Priority.newInstance(1);
      container =
          Container.newInstance(containerID, nodeID, nodeHttpAddress, resource, priority, null);

      // Tez-side identifier chain.
      dagID = TezDAGID.getInstance(applicationID, 1);
      vertexID = TezVertexID.getInstance(dagID, 1);
      taskID = TezTaskID.getInstance(vertexID, 1);
      taskAttemptID = TezTaskAttemptID.getInstance(taskID, 1);

      // Mocked collaborators.
      chh = mock(ContainerHeartbeatHandler.class);
      tal = mock(TaskAttemptListener.class);
      InetSocketAddress listenerAddress = new InetSocketAddress("localhost", 0);
      doReturn(listenerAddress).when(tal).getAddress();
      eventHandler = mock(EventHandler.class);
      historyEventHandler = mock(HistoryEventHandler.class);

      // Application context that serves the fixtures created above.
      Configuration amConf = new Configuration(false);
      appContext = mock(AppContext.class);
      doReturn(new HashMap<ApplicationAccessType, String>()).when(appContext).getApplicationACLs();
      doReturn(eventHandler).when(appContext).getEventHandler();
      doReturn(appAttemptID).when(appContext).getApplicationAttemptId();
      doReturn(applicationID).when(appContext).getApplicationID();
      doReturn(new SystemClock()).when(appContext).getClock();
      doReturn(historyEventHandler).when(appContext).getHistoryHandler();
      doReturn(amConf).when(appContext).getAMConf();
      mockDAGID();

      taskSpec = mock(TaskSpec.class);
      doReturn(taskAttemptID).when(taskSpec).getTaskAttemptID();

      amContainer =
          new AMContainerImpl(container, chh, tal, new ContainerContextMatcher(), appContext);
    }
Beispiel #7
0
  /**
   * Runs two task attempts for each of three DAGs on one container and checks, for every pulled
   * task, whether it is flagged as carrying changed credentials and which tokens it exposes:
   * the first attempt of each DAG reports a change, the second does not and carries no
   * credentials.
   */
  @SuppressWarnings("unchecked")
  @Test
  public void testCredentialsTransfer() {
    WrappedContainerMultipleDAGs wrapped = new WrappedContainerMultipleDAGs();

    // Task ids for the second and third DAG; the first DAG's task comes from the wrapper.
    TezDAGID secondDag = TezDAGID.getInstance("800", 500, 2);
    TezDAGID thirdDag = TezDAGID.getInstance("800", 500, 3);
    TezTaskID secondDagTask = TezTaskID.getInstance(TezVertexID.getInstance(secondDag, 1), 1);
    TezTaskID thirdDagTask = TezTaskID.getInstance(TezVertexID.getInstance(thirdDag, 1), 1);

    // Two attempts per DAG.
    TezTaskAttemptID dag1First = TezTaskAttemptID.getInstance(wrapped.taskID, 200);
    TezTaskAttemptID dag1Second = TezTaskAttemptID.getInstance(wrapped.taskID, 300);
    TezTaskAttemptID dag2First = TezTaskAttemptID.getInstance(secondDagTask, 200);
    TezTaskAttemptID dag2Second = TezTaskAttemptID.getInstance(secondDagTask, 300);
    TezTaskAttemptID dag3First = TezTaskAttemptID.getInstance(thirdDagTask, 200);
    TezTaskAttemptID dag3Second = TezTaskAttemptID.getInstance(thirdDagTask, 300);

    Map<String, LocalResource> localResources = new HashMap<String, LocalResource>();

    Token<TokenIdentifier> sessionToken = mock(Token.class);
    Token<TokenIdentifier> dag1Token = mock(Token.class);
    Token<TokenIdentifier> dag3Token = mock(Token.class);

    Credentials containerCredentials = new Credentials();
    TokenCache.setSessionToken(sessionToken, containerCredentials);

    Text dag1TokenName = new Text("tokenDag1");
    Text dag3TokenName = new Text("tokenDag3");

    Credentials dag1Credentials = new Credentials();
    dag1Credentials.addToken(new Text(dag1TokenName), dag1Token);
    Credentials dag3Credentials = new Credentials();
    dag3Credentials.addToken(new Text(dag3TokenName), dag3Token);

    wrapped.launchContainer(new HashMap<String, LocalResource>(), containerCredentials);
    wrapped.containerLaunched();

    // First DAG, first attempt: credentials are new and include the DAG 1 token.
    wrapped.assignTaskAttempt(dag1First, localResources, dag1Credentials);
    AMContainerTask pulled = wrapped.pullTaskToRun();
    assertTrue(pulled.haveCredentialsChanged());
    assertNotNull(pulled.getCredentials());
    assertNotNull(pulled.getCredentials().getToken(dag1TokenName));
    wrapped.taskAttemptSucceeded(dag1First);

    // First DAG, second attempt: nothing changed, nothing shipped.
    wrapped.assignTaskAttempt(dag1Second, localResources, dag1Credentials);
    pulled = wrapped.pullTaskToRun();
    assertFalse(pulled.haveCredentialsChanged());
    assertNull(pulled.getCredentials());
    wrapped.taskAttemptSucceeded(dag1Second);

    // Move to running a second DAG, with no credentials.
    wrapped.setNewDAGID(secondDag);
    wrapped.assignTaskAttempt(dag2First, localResources, null);
    pulled = wrapped.pullTaskToRun();
    assertTrue(pulled.haveCredentialsChanged());
    assertNull(pulled.getCredentials());
    wrapped.taskAttemptSucceeded(dag2First);

    wrapped.assignTaskAttempt(dag2Second, localResources, null);
    pulled = wrapped.pullTaskToRun();
    assertFalse(pulled.haveCredentialsChanged());
    assertNull(pulled.getCredentials());
    wrapped.taskAttemptSucceeded(dag2Second);

    // Move to running a third DAG, with Credentials this time
    wrapped.setNewDAGID(thirdDag);
    wrapped.assignTaskAttempt(dag3First, localResources, dag3Credentials);
    pulled = wrapped.pullTaskToRun();
    assertTrue(pulled.haveCredentialsChanged());
    assertNotNull(pulled.getCredentials());
    assertNotNull(pulled.getCredentials().getToken(dag3TokenName));
    assertNull(pulled.getCredentials().getToken(dag1TokenName));
    wrapped.taskAttemptSucceeded(dag3First);

    // Same DAG again (assigned with the DAG 1 credentials object): no change is reported.
    wrapped.assignTaskAttempt(dag3Second, localResources, dag1Credentials);
    pulled = wrapped.pullTaskToRun();
    assertFalse(pulled.haveCredentialsChanged());
    assertNull(pulled.getCredentials());
    wrapped.taskAttemptSucceeded(dag3Second);
  }
Beispiel #8
0
/**
 * Tests for recovering a {@link TaskImpl}: each test replays a sequence of history events through
 * {@code restoreFromEvent}, then (usually) sends a {@code TaskEventRecoverTask} and checks the
 * resulting task state and attempt bookkeeping.
 */
public class TestTaskRecovery {

  // Task under test and the dispatcher that feeds it; both are recreated in setUp().
  private TaskImpl task;
  private DrainDispatcher dispatcher;

  // NOTE(review): presumably numbers the attempt ids handed out by a helper outside this
  // view (getNewTaskAttemptID) — confirm.
  private int taskAttemptCounter = 0;

  private Configuration conf = new Configuration();
  private AppContext mockAppContext;
  private MockHistoryEventHandler mockHistoryEventHandler;
  // Fixed identifier chain shared by all tests: application -> DAG -> vertex.
  private ApplicationId appId = ApplicationId.newInstance(System.currentTimeMillis(), 1);
  private TezDAGID dagId = TezDAGID.getInstance(appId, 1);
  private TezVertexID vertexId = TezVertexID.getInstance(dagId, 1);
  private Vertex vertex;
  private String vertexName = "v1";
  // Synthetic timestamps used when building history events; each is 100 after the previous.
  private long taskScheduledTime = 100L;
  private long taskStartTime = taskScheduledTime + 100L;
  private long taskFinishTime = taskStartTime + 100L;
  // Records task attempt events and forwards them to the matching attempt.
  private TaskAttemptEventHandler taEventHandler = new TaskAttemptEventHandler();

  /** Hands every dispatched {@link TaskEvent} straight to the task under test. */
  private class TaskEventHandler implements EventHandler<TaskEvent> {
    @Override
    public void handle(TaskEvent taskEvent) {
      task.handle(taskEvent);
    }
  }

  /**
   * Remembers every {@link TaskAttemptEvent} it sees and then forwards the event to the task
   * attempt it is addressed to.
   */
  private class TaskAttemptEventHandler implements EventHandler<TaskAttemptEvent> {

    private List<TaskAttemptEvent> events = Lists.newArrayList();

    @Override
    public void handle(TaskAttemptEvent attemptEvent) {
      events.add(attemptEvent);
      TaskAttemptImpl target = (TaskAttemptImpl) task.getAttempt(attemptEvent.getTaskAttemptID());
      target.handle(attemptEvent);
    }

    /** Returns the events received so far, in arrival order. */
    public List<TaskAttemptEvent> getEvents() {
      return events;
    }
  }

  /**
   * Committer stub whose recovery behaviour is configurable: it reports whether task recovery is
   * supported and can be told to throw from {@link #recoverTask}. All other lifecycle methods
   * do nothing.
   */
  private class TestOutputCommitter extends OutputCommitter {

    boolean recoverySupported;
    boolean throwExceptionWhenRecovery;

    /**
     * @param committerContext context passed to the superclass
     * @param recoverySupported value returned by {@link #isTaskRecoverySupported()}
     * @param throwExceptionWhenRecovery whether {@link #recoverTask} should throw
     */
    public TestOutputCommitter(
        OutputCommitterContext committerContext,
        boolean recoverySupported,
        boolean throwExceptionWhenRecovery) {
      super(committerContext);
      this.recoverySupported = recoverySupported;
      this.throwExceptionWhenRecovery = throwExceptionWhenRecovery;
    }

    @Override
    public void initialize() throws Exception {}

    @Override
    public void setupOutput() throws Exception {}

    @Override
    public void commitOutput() throws Exception {}

    @Override
    public void abortOutput(State finalState) throws Exception {}

    @Override
    public boolean isTaskRecoverySupported() {
      return recoverySupported;
    }

    @Override
    public void recoverTask(int taskIndex, int previousDAGAttempt) throws Exception {
      if (!throwExceptionWhenRecovery) {
        return;
      }
      throw new Exception("fail recovery Task");
    }
  }

  /**
   * Starts a fresh dispatcher, builds deep-stubbed vertex and application context mocks, creates
   * the {@link TaskImpl} under test and registers one recovery-capable output committer on its
   * vertex.
   */
  @Before
  public void setUp() {
    // Dispatcher: DAG and vertex events go nowhere, task and attempt events reach the fixtures.
    dispatcher = new DrainDispatcher();
    dispatcher.register(DAGEventType.class, mock(EventHandler.class));
    dispatcher.register(VertexEventType.class, mock(EventHandler.class));
    dispatcher.register(TaskEventType.class, new TaskEventHandler());
    dispatcher.register(TaskAttemptEventType.class, taEventHandler);
    dispatcher.init(new Configuration());
    dispatcher.start();

    vertex = mock(Vertex.class, RETURNS_DEEP_STUBS);
    when(vertex.getProcessorDescriptor().getClassName()).thenReturn("");

    // The app context resolves any vertex id to the mocked vertex and records history events.
    mockAppContext = mock(AppContext.class, RETURNS_DEEP_STUBS);
    when(mockAppContext.getCurrentDAG().getVertex(any(TezVertexID.class))).thenReturn(vertex);
    mockHistoryEventHandler = new MockHistoryEventHandler(mockAppContext);
    when(mockAppContext.getHistoryHandler()).thenReturn(mockHistoryEventHandler);

    task =
        new TaskImpl(
            vertexId,
            0,
            dispatcher.getEventHandler(),
            new Configuration(),
            mock(TaskCommunicatorManagerInterface.class),
            new SystemClock(),
            mock(TaskHeartbeatHandler.class),
            mockAppContext,
            false,
            Resource.newInstance(1, 1),
            mock(ContainerContext.class),
            mock(StateChangeNotifier.class),
            vertex);

    // Single committer that supports recovery and does not throw while recovering.
    OutputCommitter recoverableCommitter =
        new TestOutputCommitter(mock(OutputCommitterContext.class), true, false);
    Map<String, OutputCommitter> committersByOutput = new HashMap<String, OutputCommitter>();
    committersByOutput.put("out1", recoverableCommitter);
    when(task.getVertex().getOutputCommitters()).thenReturn(committersByOutput);
  }

  /**
   * Replays a {@link TaskStartedEvent} into the task and checks that it is restored to
   * SCHEDULED with the expected scheduled time and no attempts.
   */
  private void restoreFromTaskStartEvent() {
    TaskStartedEvent startedEvent =
        new TaskStartedEvent(task.getTaskId(), vertexName, taskScheduledTime, taskStartTime);
    TaskState stateAfterRestore = task.restoreFromEvent(startedEvent);

    assertEquals(TaskState.SCHEDULED, stateAfterRestore);
    assertEquals(0, task.getFinishedAttemptsCount());
    assertEquals(taskScheduledTime, task.scheduledTime);
    assertEquals(0, task.getAttempts().size());
  }

  /**
   * Replays a {@link TaskAttemptStartedEvent} for the given attempt and checks that the task is
   * restored to RUNNING with exactly one, still uncompleted, attempt in internal state NEW.
   *
   * @param taId id of the attempt whose start is replayed
   */
  private void restoreFromFirstTaskAttemptStartEvent(TezTaskAttemptID taId) {
    long attemptStartTime = taskStartTime + 100L;
    TaskAttemptStartedEvent attemptStarted =
        new TaskAttemptStartedEvent(
            taId,
            vertexName,
            attemptStartTime,
            mock(ContainerId.class),
            mock(NodeId.class),
            "",
            "",
            "",
            0,
            null,
            0);
    TaskState stateAfterRestore = task.restoreFromEvent(attemptStarted);

    assertEquals(TaskState.RUNNING, stateAfterRestore);
    assertEquals(0, task.getFinishedAttemptsCount());
    assertEquals(taskScheduledTime, task.scheduledTime);
    assertEquals(1, task.getAttempts().size());
    TaskAttemptImpl restoredAttempt = (TaskAttemptImpl) task.getAttempt(taId);
    assertEquals(TaskAttemptStateInternal.NEW, restoredAttempt.getInternalState());
    assertEquals(1, task.getUncompletedAttemptsCount());
  }

  /** New -> RecoverTransition: with nothing restored, the task stays in internal state NEW. */
  @Test(timeout = 5000)
  public void testRecovery_New() {
    TaskEventRecoverTask recoverEvent = new TaskEventRecoverTask(task.getTaskId());
    task.handle(recoverEvent);
    assertEquals(TaskStateInternal.NEW, task.getInternalState());
  }

  /**
   * -> restoreFromTaskFinishEvent ( no TaskStartEvent )
   *
   * <p>Restoring a SUCCEEDED {@link TaskFinishedEvent} without a preceding start event must
   * fail, and the failure must explain that no started event was seen.
   */
  @Test(timeout = 5000)
  public void testRecovery_NoStartEvent() {
    try {
      task.restoreFromEvent(
          new TaskFinishedEvent(
              task.getTaskId(),
              vertexName,
              taskStartTime,
              taskFinishTime,
              null,
              TaskState.SUCCEEDED,
              "",
              new TezCounters(),
              0));
    } catch (Throwable e) {
      // String.valueOf keeps a message-less exception from turning into an NPE here.
      String message = String.valueOf(e.getMessage());
      assertTrue(
          "Unexpected failure: " + e,
          message.contains("Finished Event seen but" + " no Started Event was encountered earlier"));
      return;
    }
    // Kept outside the try block: inside it, the AssertionError raised by fail() was caught by
    // the Throwable handler above and replaced by an unrelated message-mismatch failure.
    fail("Should fail due to no TaskStartEvent before TaskFinishEvent");
  }

  /**
   * -> restoreFromTaskFinishEvent (KILLED, no TaskStartEvent)
   *
   * <p>Restoring a KILLED {@link TaskFinishedEvent} without a preceding start event is expected
   * to complete without throwing; no further state is asserted.
   */
  @Test(timeout = 5000)
  public void testRecoveryNewToKilled_NoStartEvent() {
    TaskFinishedEvent killedWithoutStart =
        new TaskFinishedEvent(
            task.getTaskId(),
            vertexName,
            taskStartTime,
            taskFinishTime,
            null,
            TaskState.KILLED,
            "",
            new TezCounters(),
            0);
    task.restoreFromEvent(killedWithoutStart);
  }

  /**
   * restoreFromTaskStartedEvent -> RecoverTransition
   *
   * <p>After recovery the task is RUNNING with one freshly scheduled attempt and no finished,
   * failed or successful attempts.
   */
  @Test(timeout = 5000)
  public void testRecovery_Started() {
    restoreFromTaskStartEvent();

    TaskEventRecoverTask recoverEvent = new TaskEventRecoverTask(task.getTaskId());
    task.handle(recoverEvent);

    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    // new task attempt is scheduled
    assertEquals(1, task.getAttempts().size());
    assertEquals(0, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(null, task.successfulAttempt);
  }

  /**
   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptFinishedEvent (KILLED) ->
   * RecoverTransition
   *
   * <p>The attempt has a finished event but no started event. After recovery the task is RUNNING
   * with two attempts, one of them finished and none failed.
   */
  @Test(timeout = 5000)
  public void testRecovery_OnlyTAFinishedEvent_KILLED() {
    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    TaskAttemptFinishedEvent killedEvent =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            0L,
            0L,
            TaskAttemptState.KILLED,
            TaskAttemptTerminationCause.TERMINATED_BY_CLIENT,
            "",
            new TezCounters(),
            0,
            null);
    task.restoreFromEvent(killedEvent);

    task.handle(new TaskEventRecoverTask(task.getTaskId()));
    // wait until the second task attempt is scheduled
    dispatcher.await();

    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    // taskAttempt_1 is recovered to KILLED, and new task attempt is scheduled
    assertEquals(2, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(null, task.successfulAttempt);
  }

  /**
   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptFinishedEvent (FAILED) ->
   * RecoverTransition
   *
   * <p>The attempt has a finished event but no started event. After recovery the task is RUNNING
   * with two attempts, one of them finished and counted as failed.
   */
  @Test(timeout = 5000)
  public void testRecovery_OnlyTAFinishedEvent_FAILED() {
    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    TaskAttemptFinishedEvent failedEvent =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            0L,
            0L,
            TaskAttemptState.FAILED,
            TaskAttemptTerminationCause.CONTAINER_LAUNCH_FAILED,
            "",
            new TezCounters(),
            0,
            null);
    task.restoreFromEvent(failedEvent);

    task.handle(new TaskEventRecoverTask(task.getTaskId()));
    // wait until the second task attempt is scheduled
    dispatcher.await();

    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    // taskAttempt_1 is recovered to FAILED, and new task attempt is scheduled
    assertEquals(2, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(1, task.failedAttempts);
    assertEquals(null, task.successfulAttempt);
  }

  /**
   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptFinishedEvent (SUCCEEDED) ->
   * RecoverTransition
   *
   * <p>A SUCCEEDED attempt-finished event without a matching attempt-started event must be
   * rejected with a {@link TezUncheckedException} saying the attempt could not be found.
   */
  @Test(timeout = 5000)
  public void testRecovery_OnlyTAFinishedEvent_SUCCEEDED() {
    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    TaskAttemptFinishedEvent succeededEvent =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            0L,
            0L,
            TaskAttemptState.SUCCEEDED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    try {
      task.restoreFromEvent(succeededEvent);
      fail(
          "Should fail due to no TaskAttemptStartedEvent but with TaskAttemptFinishedEvent(Succeeded)");
    } catch (TezUncheckedException expected) {
      assertTrue(
          expected.getMessage().contains("Could not find task attempt when trying to recover"));
    }
  }

  /**
   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent -> RecoverTransition
   *
   * <p>The only restored attempt started but never finished. After recovery the task is RUNNING
   * with two attempts, one of them finished and none failed.
   */
  @Test(timeout = 5000)
  public void testRecovery_OneTAStarted() {
    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    restoreFromFirstTaskAttemptStartEvent(attemptId);

    TaskEventRecoverTask recoverEvent = new TaskEventRecoverTask(task.getTaskId());
    task.handle(recoverEvent);
    // wait until the second task attempt is scheduled
    dispatcher.await();

    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    // taskAttempt_1 is recovered to KILLED, and new task attempt is scheduled
    assertEquals(2, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(null, task.successfulAttempt);
  }

  /**
   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
   * restoreFromTaskAttemptFinishedEvent (SUCCEEDED) -> RecoverTransition
   *
   * <p>The task is restored to SUCCEEDED with its single attempt recorded as the successful one,
   * stays that way after recovery, and one task-finished history event is verified.
   */
  @Test(timeout = 5000)
  public void testRecovery_OneTAStarted_SUCCEEDED() {
    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    restoreFromFirstTaskAttemptStartEvent(attemptId);

    long attemptStart = taskStartTime + 100L;
    long attemptFinish = attemptStart + 100L;
    TaskAttemptFinishedEvent succeededEvent =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            attemptStart,
            attemptFinish,
            TaskAttemptState.SUCCEEDED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    TaskState stateAfterRestore = task.restoreFromEvent(succeededEvent);

    // State right after replaying the history.
    assertEquals(TaskState.SUCCEEDED, stateAfterRestore);
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(attemptId, task.successfulAttempt);

    task.handle(new TaskEventRecoverTask(task.getTaskId()));

    // Recovery must not change any of it.
    assertEquals(TaskStateInternal.SUCCEEDED, task.getInternalState());
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(attemptId, task.successfulAttempt);
    mockHistoryEventHandler.verifyTaskFinishedEvent(task.getTaskId(), TaskState.SUCCEEDED, 1);
  }

  /**
   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
   * restoreFromTaskAttemptFinishedEvent (FAILED) -> RecoverTransition
   *
   * <p>The restored attempt is counted as failed; after recovery the task is RUNNING and a
   * second, uncompleted attempt exists.
   */
  @Test(timeout = 5000)
  public void testRecovery_OneTAStarted_FAILED() {
    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    restoreFromFirstTaskAttemptStartEvent(attemptId);

    long attemptStart = taskStartTime + 100L;
    long attemptFinish = attemptStart + 100L;
    TaskAttemptFinishedEvent failedEvent =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            attemptStart,
            attemptFinish,
            TaskAttemptState.FAILED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    TaskState stateAfterRestore = task.restoreFromEvent(failedEvent);

    // State right after replaying the history.
    assertEquals(TaskState.RUNNING, stateAfterRestore);
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(1, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);

    task.handle(new TaskEventRecoverTask(task.getTaskId()));

    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    // new task attempt is scheduled
    assertEquals(2, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(1, task.failedAttempts);
    assertEquals(1, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);
  }

  /**
   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
   * restoreFromTaskAttemptFinishedEvent (KILLED) -> RecoverTransition
   *
   * <p>The restored attempt is finished but not counted as failed; after recovery the task is
   * RUNNING and a second, uncompleted attempt exists.
   */
  @Test(timeout = 5000)
  public void testRecovery_OneTAStarted_KILLED() {
    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    restoreFromFirstTaskAttemptStartEvent(attemptId);

    long attemptStart = taskStartTime + 100L;
    long attemptFinish = attemptStart + 100L;
    TaskAttemptFinishedEvent killedEvent =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            attemptStart,
            attemptFinish,
            TaskAttemptState.KILLED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    TaskState stateAfterRestore = task.restoreFromEvent(killedEvent);

    // State right after replaying the history.
    assertEquals(TaskState.RUNNING, stateAfterRestore);
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);

    task.handle(new TaskEventRecoverTask(task.getTaskId()));

    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    // new task attempt is scheduled
    assertEquals(2, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(1, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);
  }

  /**
   * restoreFromTaskStartedEvent -> restoreFromTaskAttemptStartedEvent ->
   * restoreFromTaskAttemptFinishedEvent (SUCCEEDED) -> restoreFromTaskFinishedEvent ->
   * RecoverTransition
   *
   * <p>Both the attempt and the task itself have SUCCEEDED finish events. The task is restored
   * to SUCCEEDED, stays that way after recovery, and one task-finished history event is
   * verified.
   */
  @Test(timeout = 5000)
  public void testRecovery_OneTAStarted_SUCCEEDED_Finished() {

    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    restoreFromFirstTaskAttemptStartEvent(attemptId);

    long attemptStart = taskStartTime + 100L;
    long attemptFinish = attemptStart + 100L;
    TaskAttemptFinishedEvent attemptSucceeded =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            attemptStart,
            attemptFinish,
            TaskAttemptState.SUCCEEDED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    TaskState stateAfterRestore = task.restoreFromEvent(attemptSucceeded);

    // State after the attempt's finish event.
    assertEquals(TaskState.SUCCEEDED, stateAfterRestore);
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(attemptId, task.successfulAttempt);

    TaskFinishedEvent taskSucceeded =
        new TaskFinishedEvent(
            task.getTaskId(),
            vertexName,
            taskStartTime,
            taskFinishTime,
            attemptId,
            TaskState.SUCCEEDED,
            "",
            new TezCounters(),
            0);
    stateAfterRestore = task.restoreFromEvent(taskSucceeded);

    // State after the task's own finish event.
    assertEquals(TaskState.SUCCEEDED, stateAfterRestore);
    assertEquals(attemptId, task.successfulAttempt);

    task.handle(new TaskEventRecoverTask(task.getTaskId()));

    // Recovery must not change any of it.
    assertEquals(TaskStateInternal.SUCCEEDED, task.getInternalState());
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(attemptId, task.successfulAttempt);
    mockHistoryEventHandler.verifyTaskFinishedEvent(task.getTaskId(), TaskState.SUCCEEDED, 1);
  }

  /**
   * Replays TaskStartedEvent -> TaskAttemptStartedEvent -> TaskAttemptFinishedEvent (SUCCEEDED)
   * -> TaskAttemptFinishedEvent (FAILED, e.g. after an output failure) and then triggers the
   * recover transition. No TaskFinishedEvent is replayed here. The asserts expect the task to
   * remain RUNNING and to hold a second attempt afterwards.
   */
  @Test(timeout = 5000)
  public void testRecovery_OneTAStarted_SUCCEEDED_FAILED() {
    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    restoreFromFirstTaskAttemptStartEvent(attemptId);

    long attemptStart = taskStartTime + 100L;
    long attemptFinish = attemptStart + 100L;

    // Step 1: the attempt is restored as SUCCEEDED and becomes the successful attempt.
    TaskAttemptFinishedEvent succeeded =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            attemptStart,
            attemptFinish,
            TaskAttemptState.SUCCEEDED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    TaskState stateAfterSuccess = task.restoreFromEvent(succeeded);
    assertEquals(TaskState.SUCCEEDED, stateAfterSuccess);
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(attemptId, task.successfulAttempt);

    // Step 2: the very same attempt is later restored as FAILED; a task attempt may move from
    // SUCCEEDED to FAILED because of an output failure.
    TaskAttemptFinishedEvent failed =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            attemptStart,
            attemptFinish,
            TaskAttemptState.FAILED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    TaskState stateAfterFailure = task.restoreFromEvent(failed);
    assertEquals(TaskState.RUNNING, stateAfterFailure);
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(1, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);

    // Step 3: recovery keeps the task RUNNING and adds one uncompleted attempt.
    TaskEventRecoverTask recoverEvent = new TaskEventRecoverTask(task.getTaskId());
    task.handle(recoverEvent);
    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    assertEquals(2, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(1, task.failedAttempts);
    assertEquals(1, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);
  }

  /**
   * Replays TaskStartedEvent -> TaskAttemptStartedEvent -> TaskAttemptFinishedEvent (SUCCEEDED)
   * -> TaskAttemptFinishedEvent (KILLED, e.g. after a node failure) and then triggers the recover
   * transition. No TaskFinishedEvent is replayed here. The asserts expect the task to remain
   * RUNNING with no failed attempts and to hold a second attempt afterwards.
   */
  @Test(timeout = 5000)
  public void testRecovery_OneTAStarted_SUCCEEDED_KILLED() {
    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    restoreFromFirstTaskAttemptStartEvent(attemptId);

    long attemptStart = taskStartTime + 100L;
    long attemptFinish = attemptStart + 100L;

    // Step 1: the attempt is restored as SUCCEEDED and becomes the successful attempt.
    TaskAttemptFinishedEvent succeeded =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            attemptStart,
            attemptFinish,
            TaskAttemptState.SUCCEEDED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    TaskState stateAfterSuccess = task.restoreFromEvent(succeeded);
    assertEquals(TaskState.SUCCEEDED, stateAfterSuccess);
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(attemptId, task.successfulAttempt);

    // Step 2: the very same attempt is later restored as KILLED; a task attempt may move from
    // SUCCEEDED to KILLED because of a node failure. A kill is not counted as a failure.
    TaskAttemptFinishedEvent killed =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            attemptStart,
            attemptFinish,
            TaskAttemptState.KILLED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    TaskState stateAfterKill = task.restoreFromEvent(killed);
    assertEquals(TaskState.RUNNING, stateAfterKill);
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);

    // Step 3: recovery keeps the task RUNNING and adds one uncompleted attempt.
    TaskEventRecoverTask recoverEvent = new TaskEventRecoverTask(task.getTaskId());
    task.handle(recoverEvent);
    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    assertEquals(2, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(1, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);
  }

  /**
   * Replays TaskStartedEvent -> TaskAttemptStartedEvent -> TaskAttemptFinishedEvent (SUCCEEDED)
   * and then triggers the recover transition, with the vertex reporting a single output committer
   * built as {@code TestOutputCommitter(context, false, false)}. Going by the test name, the two
   * flags presumably model a committer without recovery support (confirm against
   * TestOutputCommitter). The asserts expect the previously successful attempt to be dropped and a
   * new attempt to be scheduled.
   */
  @Test(timeout = 5000)
  public void testRecovery_Commit_Failed_Recovery_Not_Supported() {
    OutputCommitter committer =
        new TestOutputCommitter(mock(OutputCommitterContext.class), false, false);
    Map<String, OutputCommitter> committersByOutput = new HashMap<String, OutputCommitter>();
    committersByOutput.put("out1", committer);
    when(task.getVertex().getOutputCommitters()).thenReturn(committersByOutput);

    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    restoreFromFirstTaskAttemptStartEvent(attemptId);

    // Restore the attempt as SUCCEEDED.
    long attemptStart = taskStartTime + 100L;
    long attemptFinish = attemptStart + 100L;
    TaskAttemptFinishedEvent succeeded =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            attemptStart,
            attemptFinish,
            TaskAttemptState.SUCCEEDED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    TaskState stateAfterSuccess = task.restoreFromEvent(succeeded);
    assertEquals(TaskState.SUCCEEDED, stateAfterSuccess);
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(attemptId, task.successfulAttempt);

    TaskEventRecoverTask recoverEvent = new TaskEventRecoverTask(task.getTaskId());
    task.handle(recoverEvent);
    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    // A new task attempt has been scheduled and no attempt counts as successful any more.
    assertEquals(2, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(1, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);
  }

  /**
   * Replays TaskStartedEvent -> TaskAttemptStartedEvent -> TaskAttemptFinishedEvent (SUCCEEDED)
   * and then triggers the recover transition, with the vertex reporting a single output committer
   * built as {@code TestOutputCommitter(context, true, true)}. Going by the test name, the two
   * flags presumably model a committer whose task recovery fails (confirm against
   * TestOutputCommitter). The asserts expect the previously successful attempt to be dropped and a
   * new attempt to be scheduled.
   */
  @Test(timeout = 5000)
  public void testRecovery_Commit_Failed_recover_fail() {
    OutputCommitter committer =
        new TestOutputCommitter(mock(OutputCommitterContext.class), true, true);
    Map<String, OutputCommitter> committersByOutput = new HashMap<String, OutputCommitter>();
    committersByOutput.put("out1", committer);
    when(task.getVertex().getOutputCommitters()).thenReturn(committersByOutput);

    restoreFromTaskStartEvent();
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
    restoreFromFirstTaskAttemptStartEvent(attemptId);

    // Restore the attempt as SUCCEEDED.
    long attemptStart = taskStartTime + 100L;
    long attemptFinish = attemptStart + 100L;
    TaskAttemptFinishedEvent succeeded =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            attemptStart,
            attemptFinish,
            TaskAttemptState.SUCCEEDED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    TaskState stateAfterSuccess = task.restoreFromEvent(succeeded);
    assertEquals(TaskState.SUCCEEDED, stateAfterSuccess);
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(attemptId, task.successfulAttempt);

    TaskEventRecoverTask recoverEvent = new TaskEventRecoverTask(task.getTaskId());
    task.handle(recoverEvent);
    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    // A new task attempt has been scheduled and no attempt counts as successful any more.
    assertEquals(2, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(1, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);
  }

  /**
   * After restoring the task start and one attempt start, a recover event carrying the desired
   * state SUCCEEDED must leave the task in SUCCEEDED with no events recorded by taEventHandler.
   */
  @Test(timeout = 5000)
  public void testRecovery_WithDesired_SUCCEEDED() {
    restoreFromTaskStartEvent();
    restoreFromFirstTaskAttemptStartEvent(getNewTaskAttemptID(task.getTaskId()));

    TaskEventRecoverTask recoverEvent =
        new TaskEventRecoverTask(task.getTaskId(), TaskState.SUCCEEDED, false);
    task.handle(recoverEvent);

    assertEquals(TaskStateInternal.SUCCEEDED, task.getInternalState());
    // No recovery event may have been sent on to the task attempt.
    assertEquals(0, taEventHandler.getEvents().size());
  }

  /**
   * After restoring the task start and one attempt start, a recover event carrying the desired
   * state FAILED must leave the task in FAILED with no events recorded by taEventHandler.
   */
  @Test(timeout = 5000)
  public void testRecovery_WithDesired_FAILED() {
    restoreFromTaskStartEvent();
    restoreFromFirstTaskAttemptStartEvent(getNewTaskAttemptID(task.getTaskId()));

    TaskEventRecoverTask recoverEvent =
        new TaskEventRecoverTask(task.getTaskId(), TaskState.FAILED, false);
    task.handle(recoverEvent);

    assertEquals(TaskStateInternal.FAILED, task.getInternalState());
    // No recovery event may have been sent on to the task attempt.
    assertEquals(0, taEventHandler.getEvents().size());
  }

  /**
   * After restoring the task start and one attempt start, a recover event carrying the desired
   * state KILLED must leave the task in KILLED with no events recorded by taEventHandler.
   */
  @Test(timeout = 5000)
  public void testRecovery_WithDesired_KILLED() {
    restoreFromTaskStartEvent();
    restoreFromFirstTaskAttemptStartEvent(getNewTaskAttemptID(task.getTaskId()));

    TaskEventRecoverTask recoverEvent =
        new TaskEventRecoverTask(task.getTaskId(), TaskState.KILLED, false);
    task.handle(recoverEvent);

    assertEquals(TaskStateInternal.KILLED, task.getInternalState());
    // No recovery event may have been sent on to the task attempt.
    assertEquals(0, taEventHandler.getEvents().size());
  }

  /**
   * Replays TaskStartedEvent -> TaskAttemptStartedEvent -> TaskAttemptFinishedEvent (KILLED) and
   * then triggers the recover transition. The restored attempt is expected to stay in NEW until
   * recovery runs, then end up KILLED, while the task stays RUNNING and schedules a new attempt.
   */
  @Test(timeout = 5000)
  public void testRecovery_OneTAStarted_Killed() {
    restoreFromTaskStartEvent();

    long attemptStart = taskStartTime + 100L;
    TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());

    // Restore the attempt start: one uncompleted attempt, still in its NEW internal state.
    TaskAttemptStartedEvent started =
        new TaskAttemptStartedEvent(
            attemptId,
            vertexName,
            attemptStart,
            mock(ContainerId.class),
            mock(NodeId.class),
            "",
            "",
            "",
            0,
            null,
            0);
    TaskState stateAfterStart = task.restoreFromEvent(started);
    assertEquals(TaskState.RUNNING, stateAfterStart);
    assertEquals(
        TaskAttemptStateInternal.NEW,
        ((TaskAttemptImpl) task.getAttempt(attemptId)).getInternalState());
    assertEquals(1, task.getAttempts().size());
    assertEquals(0, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(1, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);

    // Restore the attempt as KILLED: it now counts as finished but not as failed, and its
    // internal state is still NEW because the attempt itself has not been recovered yet.
    long attemptFinish = attemptStart + 100L;
    TaskAttemptFinishedEvent killed =
        new TaskAttemptFinishedEvent(
            attemptId,
            vertexName,
            attemptStart,
            attemptFinish,
            TaskAttemptState.KILLED,
            null,
            "",
            new TezCounters(),
            0,
            null);
    TaskState stateAfterKill = task.restoreFromEvent(killed);
    assertEquals(TaskState.RUNNING, stateAfterKill);
    assertEquals(
        TaskAttemptStateInternal.NEW,
        ((TaskAttemptImpl) task.getAttempt(attemptId)).getInternalState());
    assertEquals(1, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(0, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);

    TaskEventRecoverTask recoverEvent = new TaskEventRecoverTask(task.getTaskId());
    task.handle(recoverEvent);
    // Wait until the task has sent TA_RECOVER to the attempt and the attempt has completed its
    // own recover transition.
    dispatcher.await();
    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    assertEquals(
        TaskAttemptStateInternal.KILLED,
        ((TaskAttemptImpl) task.getAttempt(attemptId)).getInternalState());
    // A new task attempt has been scheduled.
    assertEquals(2, task.getAttempts().size());
    assertEquals(1, task.getFinishedAttemptsCount());
    assertEquals(0, task.failedAttempts);
    assertEquals(1, task.getUncompletedAttemptsCount());
    assertEquals(null, task.successfulAttempt);
  }

  /**
   * With n = maxFailedAttempts: the previous AM attempt saw n task attempts, all of them killed.
   * On recovery the task must stay RUNNING and schedule one more task attempt.
   */
  @Test(timeout = 5000)
  public void testTaskRecovery_MultipleAttempts1() {
    int maxFailedAttempts =
        conf.getInt(
            TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS,
            TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS_DEFAULT);
    restoreFromTaskStartEvent();

    // Replay a started + KILLED pair for each of the n attempts.
    for (int replayed = 1; replayed <= maxFailedAttempts; replayed++) {
      TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
      TaskAttemptStartedEvent started =
          new TaskAttemptStartedEvent(
              attemptId,
              vertexName,
              0L,
              mock(ContainerId.class),
              mock(NodeId.class),
              "",
              "",
              "",
              0,
              null,
              0);
      task.restoreFromEvent(started);
      TaskAttemptFinishedEvent killed =
          new TaskAttemptFinishedEvent(
              attemptId, vertexName, 0, 0, TaskAttemptState.KILLED, null, "", null, 0, null);
      task.restoreFromEvent(killed);
    }
    assertEquals(maxFailedAttempts, task.getAttempts().size());
    assertEquals(0, task.failedAttempts);

    TaskEventRecoverTask recoverEvent = new TaskEventRecoverTask(task.getTaskId());
    task.handle(recoverEvent);
    // Killed attempts must not be taken into account when checking whether the maximum number
    // of failed attempts has been exceeded.
    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    // One new task attempt has been scheduled.
    assertEquals(maxFailedAttempts + 1, task.getAttempts().size());
  }

  /**
   * With n = maxFailedAttempts: the previous AM attempt saw n task attempts, all of them failed.
   * On recovery the task must move to FAILED because the failed-attempt limit has been reached,
   * and no further attempt is added.
   */
  @Test(timeout = 5000)
  public void testTaskRecovery_MultipleAttempts2() {
    int maxFailedAttempts =
        conf.getInt(
            TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS,
            TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS_DEFAULT);
    restoreFromTaskStartEvent();

    // Replay a started + FAILED pair for each of the n attempts.
    for (int replayed = 1; replayed <= maxFailedAttempts; replayed++) {
      TezTaskAttemptID attemptId = getNewTaskAttemptID(task.getTaskId());
      TaskAttemptStartedEvent started =
          new TaskAttemptStartedEvent(
              attemptId,
              vertexName,
              0L,
              mock(ContainerId.class),
              mock(NodeId.class),
              "",
              "",
              "",
              0,
              null,
              0);
      task.restoreFromEvent(started);
      TaskAttemptFinishedEvent failed =
          new TaskAttemptFinishedEvent(
              attemptId, vertexName, 0, 0, TaskAttemptState.FAILED, null, "", null, 0, null);
      task.restoreFromEvent(failed);
    }
    assertEquals(maxFailedAttempts, task.getAttempts().size());
    assertEquals(maxFailedAttempts, task.failedAttempts);

    TaskEventRecoverTask recoverEvent = new TaskEventRecoverTask(task.getTaskId());
    task.handle(recoverEvent);
    // The task must move to FAILED because of the task attempts that failed during the last
    // application attempt.
    assertEquals(TaskStateInternal.FAILED, task.getInternalState());
    assertEquals(maxFailedAttempts, task.getAttempts().size());
  }

  /**
   * With n = maxFailedAttempts: in the previous AM attempt, n-1 task attempts failed and the last
   * task attempt was started but never finished. When recovering, that last attempt should
   * transit to KILLED (which does not count as a failure), so the task stays below the
   * failed-attempt limit, remains in RUNNING state, and a new task attempt is scheduled.
   */
  @Test(timeout = 5000)
  public void testTaskRecovery_MultipleAttempts3() {
    int maxFailedAttempts =
        conf.getInt(
            TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS,
            TezConfiguration.TEZ_AM_TASK_MAX_FAILED_ATTEMPTS_DEFAULT);
    restoreFromTaskStartEvent();

    // Replay n-1 attempts that both started and FAILED in the previous AM attempt.
    for (int i = 0; i < maxFailedAttempts - 1; ++i) {
      TezTaskAttemptID taId = getNewTaskAttemptID(task.getTaskId());
      task.restoreFromEvent(
          new TaskAttemptStartedEvent(
              taId,
              vertexName,
              0L,
              mock(ContainerId.class),
              mock(NodeId.class),
              "",
              "",
              "",
              0,
              null,
              0));
      task.restoreFromEvent(
          new TaskAttemptFinishedEvent(
              taId, vertexName, 0, 0, TaskAttemptState.FAILED, null, "", null, 0, null));
    }
    assertEquals(maxFailedAttempts - 1, task.getAttempts().size());
    assertEquals(maxFailedAttempts - 1, task.failedAttempts);

    // The n-th attempt only has a start event: it was still running when the AM went away.
    TezTaskAttemptID newTaskAttemptId = getNewTaskAttemptID(task.getTaskId());
    TaskState recoveredState =
        task.restoreFromEvent(
            new TaskAttemptStartedEvent(
                newTaskAttemptId,
                vertexName,
                0L,
                mock(ContainerId.class),
                mock(NodeId.class),
                "",
                "",
                "",
                0,
                null,
                0));

    assertEquals(TaskState.RUNNING, recoveredState);
    assertEquals(
        TaskAttemptStateInternal.NEW,
        ((TaskAttemptImpl) task.getAttempt(newTaskAttemptId)).getInternalState());
    assertEquals(maxFailedAttempts, task.getAttempts().size());

    task.handle(new TaskEventRecoverTask(task.getTaskId()));
    // wait until the task attempt has received the recover event from the task
    dispatcher.await();

    assertEquals(TaskStateInternal.RUNNING, task.getInternalState());
    assertEquals(
        TaskAttemptStateInternal.KILLED,
        ((TaskAttemptImpl) (task.getAttempt(newTaskAttemptId))).getInternalState());
    // the killed attempt must not be counted as a failed one
    assertEquals(maxFailedAttempts - 1, task.failedAttempts);

    // a new task attempt is added
    assertEquals(maxFailedAttempts + 1, task.getAttempts().size());
  }

  /**
   * Creates a task attempt id for {@code taskId}, numbered with the current value of
   * taskAttemptCounter, and advances that counter by one.
   */
  private TezTaskAttemptID getNewTaskAttemptID(TezTaskID taskId) {
    final int attemptNumber = taskAttemptCounter++;
    return TezTaskAttemptID.getInstance(taskId, attemptNumber);
  }
}
  /**
   * Checks the priorities DAGSchedulerMRR assigns for a three-vertex chain (distances from root 0,
   * 1 and 2; labelled M, R1 and R2 in the comments below) under four interleavings of task
   * scheduling and vertex completion, including rescheduled (retry) attempts of R1.
   *
   * <p>Priorities are compared with assertEquals so that a failure reports both the expected and
   * the actual value.
   */
  @Ignore
  @Test(timeout = 10000)
  public void testDAGSchedulerMRR() {
    DAG mockDag = mock(DAG.class);
    TezDAGID dagId = TezDAGID.getInstance("1", 1, 1);

    TaskSchedulerEventHandler mockTaskScheduler = mock(TaskSchedulerEventHandler.class);

    // M: vertex at distance 0 from the root.
    Vertex mockVertex1 = mock(Vertex.class);
    TezVertexID mockVertexId1 = TezVertexID.getInstance(dagId, 1);
    when(mockVertex1.getVertexId()).thenReturn(mockVertexId1);
    when(mockVertex1.getDistanceFromRoot()).thenReturn(0);
    TaskAttempt mockAttempt1 = mock(TaskAttempt.class);
    when(mockAttempt1.getVertexID()).thenReturn(mockVertexId1);
    when(mockAttempt1.getIsRescheduled()).thenReturn(false);
    when(mockDag.getVertex(mockVertexId1)).thenReturn(mockVertex1);

    // R1: vertex at distance 1, with one regular and one rescheduled attempt.
    Vertex mockVertex2 = mock(Vertex.class);
    TezVertexID mockVertexId2 = TezVertexID.getInstance(dagId, 2);
    when(mockVertex2.getVertexId()).thenReturn(mockVertexId2);
    when(mockVertex2.getDistanceFromRoot()).thenReturn(1);
    TaskAttempt mockAttempt2 = mock(TaskAttempt.class);
    when(mockAttempt2.getVertexID()).thenReturn(mockVertexId2);
    when(mockAttempt2.getIsRescheduled()).thenReturn(false);
    when(mockDag.getVertex(mockVertexId2)).thenReturn(mockVertex2);
    TaskAttempt mockAttempt2f = mock(TaskAttempt.class);
    when(mockAttempt2f.getVertexID()).thenReturn(mockVertexId2);
    when(mockAttempt2f.getIsRescheduled()).thenReturn(true);

    // R2: vertex at distance 2.
    Vertex mockVertex3 = mock(Vertex.class);
    TezVertexID mockVertexId3 = TezVertexID.getInstance(dagId, 3);
    when(mockVertex3.getVertexId()).thenReturn(mockVertexId3);
    when(mockVertex3.getDistanceFromRoot()).thenReturn(2);
    TaskAttempt mockAttempt3 = mock(TaskAttempt.class);
    when(mockAttempt3.getVertexID()).thenReturn(mockVertexId3);
    when(mockAttempt3.getIsRescheduled()).thenReturn(false);
    when(mockDag.getVertex(mockVertexId3)).thenReturn(mockVertex3);

    DAGEventSchedulerUpdate mockEvent1 = mock(DAGEventSchedulerUpdate.class);
    when(mockEvent1.getAttempt()).thenReturn(mockAttempt1);
    DAGEventSchedulerUpdate mockEvent2 = mock(DAGEventSchedulerUpdate.class);
    when(mockEvent2.getAttempt()).thenReturn(mockAttempt2);
    DAGEventSchedulerUpdate mockEvent2f = mock(DAGEventSchedulerUpdate.class);
    when(mockEvent2f.getAttempt()).thenReturn(mockAttempt2f);
    DAGEventSchedulerUpdate mockEvent3 = mock(DAGEventSchedulerUpdate.class);
    when(mockEvent3.getAttempt()).thenReturn(mockAttempt3);
    DAGScheduler scheduler =
        new DAGSchedulerMRR(mockDag, mockEventHandler, mockTaskScheduler, 0.5f);

    // M starts. M completes. R1 starts. R1 completes. R2 starts. R2 completes
    scheduleAndAssertPriority(scheduler, mockEvent1, 3); // M starts
    scheduleAndAssertPriority(scheduler, mockEvent1, 3); // M runs another
    scheduler.vertexCompleted(mockVertex1); // M completes
    scheduleAndAssertPriority(scheduler, mockEvent2, 6); // R1 starts
    scheduleAndAssertPriority(scheduler, mockEvent2, 6); // R1 runs another
    scheduleAndAssertPriority(scheduler, mockEvent2f, 4); // R1 runs retry. Retry priority
    scheduler.vertexCompleted(mockVertex2); // R1 completes
    scheduleAndAssertPriority(scheduler, mockEvent3, 9); // R2 starts
    scheduleAndAssertPriority(scheduler, mockEvent3, 9); // R2 runs another
    scheduler.vertexCompleted(mockVertex3); // R2 completes

    // M starts. R1 starts. M completes. R2 starts. R1 completes. R2 completes
    scheduleAndAssertPriority(scheduler, mockEvent1, 3); // M starts
    scheduleAndAssertPriority(scheduler, mockEvent2, 2); // R1 starts. Reordered priority
    scheduleAndAssertPriority(scheduler, mockEvent1, 3); // M runs another
    scheduleAndAssertPriority(scheduler, mockEvent2, 2); // R1 runs another. Reordered priority
    scheduleAndAssertPriority(scheduler, mockEvent2f, 2); // R1 runs retry. Reordered priority
    scheduler.vertexCompleted(mockVertex1); // M completes
    scheduleAndAssertPriority(scheduler, mockEvent3, 5); // R2 starts. Reordered priority
    scheduleAndAssertPriority(scheduler, mockEvent2, 6); // R1 runs another. Normal priority
    scheduleAndAssertPriority(scheduler, mockEvent2f, 4); // R1 runs retry. Retry priority
    scheduleAndAssertPriority(scheduler, mockEvent3, 5); // R2 runs another. Reordered priority
    scheduler.vertexCompleted(mockVertex2); // R1 completes
    scheduler.vertexCompleted(mockVertex3); // R2 completes

    // M starts. M completes. R1 starts. R2 starts. R1 completes. R2 completes
    scheduleAndAssertPriority(scheduler, mockEvent1, 3); // M starts
    scheduler.vertexCompleted(mockVertex1); // M completes
    scheduleAndAssertPriority(scheduler, mockEvent2, 6); // R1 starts
    scheduleAndAssertPriority(scheduler, mockEvent3, 5); // R2 starts. Reordered priority
    scheduleAndAssertPriority(scheduler, mockEvent2, 6); // R1 runs another
    scheduler.vertexCompleted(mockVertex2); // R1 completes
    scheduler.vertexCompleted(mockVertex3); // R2 completes

    // M starts. R1 starts. M completes. R1 completes. R2 starts. R2 completes
    scheduleAndAssertPriority(scheduler, mockEvent1, 3); // M starts
    scheduleAndAssertPriority(scheduler, mockEvent2, 2); // R1 starts. Reordered priority
    scheduler.vertexCompleted(mockVertex1); // M completes
    scheduleAndAssertPriority(scheduler, mockEvent2, 6); // R1 starts. Normal priority
    scheduler.vertexCompleted(mockVertex2); // R1 completes
    scheduleAndAssertPriority(scheduler, mockEvent3, 9); // R2 starts
    scheduler.vertexCompleted(mockVertex3); // R2 completes
  }

  /**
   * Schedules {@code event} on {@code scheduler} and asserts that the event then held by
   * mockEventHandler carries {@code expectedPriority}.
   */
  private void scheduleAndAssertPriority(
      DAGScheduler scheduler, DAGEventSchedulerUpdate event, int expectedPriority) {
    scheduler.scheduleTask(event);
    Assert.assertEquals(expectedPriority, mockEventHandler.event.getPriority().getPriority());
  }