/**
 * Attaches an input to the task under test.
 *
 * The iterator is handed to the mock environment. The task configuration is then updated so that it
 * assigns the next input to the given group and reads it with the PactRecord serializer.
 *
 * @param input   the records to feed into the task
 * @param groupId the input group the new input belongs to
 */
public void addInput(MutableObjectIterator<PactRecord> input, int groupId) {
	this.mockEnv.addInput(input);

	final TaskConfig taskConfig = new TaskConfig(this.mockEnv.getTaskConfiguration());
	taskConfig.addInputToGroup(groupId);

	final PactRecordSerializerFactory serializer = PactRecordSerializerFactory.get();
	taskConfig.setInputSerializer(serializer, groupId);
}
/**
 * Attaches an output to the task under test.
 *
 * The list is handed to the mock environment, which the tests read back to inspect what the task
 * emitted. The task configuration then gets a FORWARD ship strategy and the PactRecord serializer
 * for that output.
 *
 * @param output the list registered as the task's output
 */
public void addOutput(List<PactRecord> output) {
	this.mockEnv.addOutput(output);

	final TaskConfig taskConfig = new TaskConfig(this.mockEnv.getTaskConfiguration());
	taskConfig.addOutputShipStrategy(ShipStrategyType.FORWARD);

	final PactRecordSerializerFactory serializer = PactRecordSerializerFactory.get();
	taskConfig.setOutputSerializer(serializer);
}
public class ChainTaskTest extends TaskTestBase { private final List<PactRecord> outList = new ArrayList<PactRecord>(); @SuppressWarnings("unchecked") private final PactRecordComparatorFactory compFact = new PactRecordComparatorFactory( new int[] {0}, new Class[] {PactInteger.class}, new boolean[] {true}); private final PactRecordSerializerFactory serFact = PactRecordSerializerFactory.get(); @Test public void testMapTask() { final int keyCnt = 100; final int valCnt = 20; try { // environment initEnvironment(3 * 1024 * 1024); addInput(new UniformPactRecordGenerator(keyCnt, valCnt, false), 0); addOutput(this.outList); // chained combine config { final TaskConfig combineConfig = new TaskConfig(new Configuration()); // input combineConfig.addInputToGroup(0); combineConfig.setInputSerializer(serFact, 0); // output combineConfig.addOutputShipStrategy(ShipStrategyType.FORWARD); combineConfig.setOutputSerializer(serFact); // driver combineConfig.setDriverStrategy(DriverStrategy.PARTIAL_GROUP); combineConfig.setDriverComparator(compFact, 0); combineConfig.setMemoryDriver(3 * 1024 * 1024); // udf combineConfig.setStubClass(MockReduceStub.class); getTaskConfig().addChainedTask(ChainedCombineDriver.class, combineConfig, "combine"); } // chained map+combine { RegularPactTask<GenericMapper<PactRecord, PactRecord>, PactRecord> testTask = new RegularPactTask<GenericMapper<PactRecord, PactRecord>, PactRecord>(); registerTask(testTask, MapDriver.class, MockMapStub.class); try { testTask.invoke(); } catch (Exception e) { e.printStackTrace(); Assert.fail("Invoke method caused exception."); } } Assert.assertEquals(keyCnt, this.outList.size()); } catch (Exception e) { e.printStackTrace(); Assert.fail(e.getMessage()); } } @Test public void testFailingMapTask() { int keyCnt = 100; int valCnt = 20; try { // environment initEnvironment(3 * 1024 * 1024); addInput(new UniformPactRecordGenerator(keyCnt, valCnt, false), 0); addOutput(this.outList); // chained combine config { final TaskConfig 
combineConfig = new TaskConfig(new Configuration()); // input combineConfig.addInputToGroup(0); combineConfig.setInputSerializer(serFact, 0); // output combineConfig.addOutputShipStrategy(ShipStrategyType.FORWARD); combineConfig.setOutputSerializer(serFact); // driver combineConfig.setDriverStrategy(DriverStrategy.PARTIAL_GROUP); combineConfig.setDriverComparator(compFact, 0); combineConfig.setMemoryDriver(3 * 1024 * 1024); // udf combineConfig.setStubClass(MockFailingCombineStub.class); getTaskConfig().addChainedTask(ChainedCombineDriver.class, combineConfig, "combine"); } // chained map+combine { final RegularPactTask<GenericMapper<PactRecord, PactRecord>, PactRecord> testTask = new RegularPactTask<GenericMapper<PactRecord, PactRecord>, PactRecord>(); super.registerTask(testTask, MapDriver.class, MockMapStub.class); boolean stubFailed = false; try { testTask.invoke(); } catch (Exception e) { stubFailed = true; } Assert.assertTrue("Stub exception was not forwarded.", stubFailed); } } catch (Exception e) { e.printStackTrace(); Assert.fail(e.getMessage()); } } public static final class MockFailingCombineStub extends ReduceStub { private int cnt = 0; @Override public void reduce(Iterator<PactRecord> records, Collector<PactRecord> out) throws Exception { if (++this.cnt >= 5) { throw new RuntimeException("Expected Test Exception"); } while (records.hasNext()) out.collect(records.next()); } } }