/**
 * Verifies that when a single {@code ParDo.Bound} transform wraps a DoFn carrying two
 * aggregators, the extractor maps each aggregator to exactly that one transform and reports
 * no other aggregators.
 */
@SuppressWarnings("unchecked")
@Test
public void testGetAggregatorStepsWithParDoBoundExtractsSteps() {
  @SuppressWarnings("rawtypes")
  ParDo.Bound bound = mock(ParDo.Bound.class, "Bound");
  AggregatorProvidingDoFn<ThreadGroup, StrictMath> fn = new AggregatorProvidingDoFn<>();
  when(bound.getFn()).thenReturn(fn);

  Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn());
  Aggregator<Integer, Integer> aggregatorTwo = fn.addAggregator(new Min.MinIntegerFn());

  TransformTreeNode transformNode = mock(TransformTreeNode.class);
  when(transformNode.getTransform()).thenReturn(bound);

  // Stub the pipeline's traversal to answer with VisitNodesAnswer over the single node
  // (presumably handing it to the visitor; VisitNodesAnswer is defined elsewhere).
  doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode)))
      .when(p)
      .traverseTopologically(Mockito.any(PipelineVisitor.class));

  AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p);

  Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
      extractor.getAggregatorSteps();

  assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne));
  assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorTwo));
  // JUnit's assertEquals takes (expected, actual); the original had these swapped, which
  // produced an inverted failure message.
  assertEquals(2, aggregatorSteps.size());
}
/**
 * Checks that when one DoFn backs two transforms (a {@code ParDo.Bound} and a
 * {@code ParDo.BoundMulti}), each aggregator on that DoFn is reported against both
 * transforms, and that exactly two aggregators appear in the result.
 */
@SuppressWarnings("unchecked")
@Test
public void testGetAggregatorStepsWithOneAggregatorInMultipleStepsAddsSteps() {
  @SuppressWarnings("rawtypes")
  ParDo.Bound singleOutput = mock(ParDo.Bound.class, "Bound");
  @SuppressWarnings("rawtypes")
  ParDo.BoundMulti multiOutput = mock(ParDo.BoundMulti.class, "otherBound");

  // Both mocked transforms hand back the very same DoFn instance.
  AggregatorProvidingDoFn<String, Math> sharedFn = new AggregatorProvidingDoFn<>();
  when(singleOutput.getFn()).thenReturn(sharedFn);
  when(multiOutput.getFn()).thenReturn(sharedFn);

  Aggregator<Long, Long> sumAggregator = sharedFn.addAggregator(new Sum.SumLongFn());
  Aggregator<Double, Double> minAggregator = sharedFn.addAggregator(new Min.MinDoubleFn());

  TransformTreeNode singleOutputNode = mock(TransformTreeNode.class);
  when(singleOutputNode.getTransform()).thenReturn(singleOutput);
  TransformTreeNode multiOutputNode = mock(TransformTreeNode.class);
  when(multiOutputNode.getTransform()).thenReturn(multiOutput);

  // Stub p.traverseTopologically to answer with VisitNodesAnswer over both nodes
  // (presumably visiting each in turn; VisitNodesAnswer is defined elsewhere).
  doAnswer(new VisitNodesAnswer(ImmutableList.of(singleOutputNode, multiOutputNode)))
      .when(p)
      .traverseTopologically(Mockito.any(PipelineVisitor.class));

  Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> stepsByAggregator =
      new AggregatorPipelineExtractor(p).getAggregatorSteps();

  ImmutableSet<PTransform<?, ?>> bothSteps =
      ImmutableSet.<PTransform<?, ?>>of(singleOutput, multiOutput);
  assertEquals(bothSteps, stepsByAggregator.get(sumAggregator));
  assertEquals(bothSteps, stepsByAggregator.get(minAggregator));
  assertEquals(2, stepsByAggregator.size());
}