public void addInput(MutableObjectIterator<PactRecord> input, int groupId) {
  // Hand the record source to the mock environment first.
  this.mockEnv.addInput(input);

  // Then describe that input in the task configuration: which input group it belongs to and
  // which serializer reads it. Assumes TaskConfig writes through to the wrapped Configuration
  // (otherwise these settings would be lost) — TODO confirm against TaskConfig.
  final TaskConfig taskConfig = new TaskConfig(this.mockEnv.getTaskConfiguration());
  taskConfig.addInputToGroup(groupId);

  final PactRecordSerializerFactory recordSerializer = PactRecordSerializerFactory.get();
  taskConfig.setInputSerializer(recordSerializer, groupId);
}
 public void addOutput(List<PactRecord> output) {
   this.mockEnv.addOutput(output);
   TaskConfig conf = new TaskConfig(this.mockEnv.getTaskConfiguration());
   conf.addOutputShipStrategy(ShipStrategyType.FORWARD);
   conf.setOutputSerializer(PactRecordSerializerFactory.get());
 }
/**
 * Tests a map task ({@link MapDriver} with {@link MockMapStub}) that has a
 * {@link ChainedCombineDriver} chained behind it.
 */
public class ChainTaskTest extends TaskTestBase {

  /**
   * Memory budget used both for the test environment and for the chained combiner's driver.
   * Presumably in bytes — TODO confirm.
   */
  private static final int MEMORY_SIZE = 3 * 1024 * 1024;

  /** Receives the records emitted by the task under test. */
  private final List<PactRecord> outList = new ArrayList<PactRecord>();

  // Comparator on record field 0 (a PactInteger). The sort flag `true` presumably means ascending.
  // Used as the combiner's grouping key.
  @SuppressWarnings("unchecked")
  private final PactRecordComparatorFactory compFact =
      new PactRecordComparatorFactory(
          new int[] {0}, new Class[] {PactInteger.class}, new boolean[] {true});

  private final PactRecordSerializerFactory serFact = PactRecordSerializerFactory.get();

  /**
   * Runs the map task with a chained combiner using {@code MockReduceStub}. Expects exactly
   * {@code keyCnt} output records, which presumably means the combiner emits one record per
   * distinct key.
   */
  @Test
  public void testMapTask() {
    final int keyCnt = 100;
    final int valCnt = 20;

    try {
      // environment
      initEnvironment(MEMORY_SIZE);
      addInput(new UniformPactRecordGenerator(keyCnt, valCnt, false), 0);
      addOutput(this.outList);

      // chained combiner
      addChainedCombine(MockReduceStub.class);

      // map task that feeds the chained combiner
      final RegularPactTask<GenericMapper<PactRecord, PactRecord>, PactRecord> testTask =
          new RegularPactTask<GenericMapper<PactRecord, PactRecord>, PactRecord>();
      registerTask(testTask, MapDriver.class, MockMapStub.class);

      try {
        testTask.invoke();
      } catch (Exception e) {
        e.printStackTrace();
        Assert.fail("Invoke method caused exception.");
      }

      Assert.assertEquals(keyCnt, this.outList.size());
    } catch (Exception e) {
      e.printStackTrace();
      // Include the exception itself: getMessage() alone may be null and would hide the cause.
      Assert.fail("Test caused exception: " + e);
    }
  }

  /**
   * Runs the map task with a chained combiner whose stub throws. Asserts that
   * {@code invoke()} ends with an exception.
   */
  @Test
  public void testFailingMapTask() {
    final int keyCnt = 100;
    final int valCnt = 20;

    try {
      // environment
      initEnvironment(MEMORY_SIZE);
      addInput(new UniformPactRecordGenerator(keyCnt, valCnt, false), 0);
      addOutput(this.outList);

      // chained combiner whose stub fails
      addChainedCombine(MockFailingCombineStub.class);

      // map task that feeds the chained combiner
      final RegularPactTask<GenericMapper<PactRecord, PactRecord>, PactRecord> testTask =
          new RegularPactTask<GenericMapper<PactRecord, PactRecord>, PactRecord>();
      registerTask(testTask, MapDriver.class, MockMapStub.class);

      boolean stubFailed = false;
      try {
        testTask.invoke();
      } catch (Exception expected) {
        // Expected: the combiner stub's exception must surface from invoke().
        stubFailed = true;
      }

      Assert.assertTrue("Stub exception was not forwarded.", stubFailed);
    } catch (Exception e) {
      e.printStackTrace();
      // Include the exception itself: getMessage() alone may be null and would hide the cause.
      Assert.fail("Test caused exception: " + e);
    }
  }

  /**
   * Adds a {@link ChainedCombineDriver} named "combine" to the task configuration.
   *
   * <p>The combiner has one input in group 0, read with {@link #serFact}. It has one
   * FORWARD-shipped output, written with {@link #serFact}. It runs the PARTIAL_GROUP driver
   * strategy keyed by {@link #compFact} with {@link #MEMORY_SIZE} of driver memory.
   *
   * @param stubClass the reduce stub the chained combiner runs
   */
  private void addChainedCombine(Class<? extends ReduceStub> stubClass) {
    final TaskConfig combineConfig = new TaskConfig(new Configuration());

    // input
    combineConfig.addInputToGroup(0);
    combineConfig.setInputSerializer(serFact, 0);

    // output
    combineConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);
    combineConfig.setOutputSerializer(serFact);

    // driver
    combineConfig.setDriverStrategy(DriverStrategy.PARTIAL_GROUP);
    combineConfig.setDriverComparator(compFact, 0);
    combineConfig.setMemoryDriver(MEMORY_SIZE);

    // udf
    combineConfig.setStubClass(stubClass);

    getTaskConfig().addChainedTask(ChainedCombineDriver.class, combineConfig, "combine");
  }

  /**
   * Reduce stub that forwards every record of a group unchanged. It throws a RuntimeException
   * when called for the fifth time, and on every later call.
   */
  public static final class MockFailingCombineStub extends ReduceStub {

    // Number of reduce() calls on this instance so far.
    private int cnt = 0;

    @Override
    public void reduce(Iterator<PactRecord> records, Collector<PactRecord> out) throws Exception {
      if (++this.cnt >= 5) {
        throw new RuntimeException("Expected Test Exception");
      }
      while (records.hasNext()) {
        out.collect(records.next());
      }
    }
  }
}