@SuppressWarnings("unchecked")
  @Test
  public void testGetAggregatorStepsWithParDoBoundExtractsSteps() {
    // A single mocked ParDo.Bound whose DoFn declares two aggregators.
    @SuppressWarnings("rawtypes")
    ParDo.Bound bound = mock(ParDo.Bound.class, "Bound");
    AggregatorProvidingDoFn<ThreadGroup, StrictMath> fn = new AggregatorProvidingDoFn<>();
    when(bound.getFn()).thenReturn(fn);

    Aggregator<Long, Long> aggregatorOne = fn.addAggregator(new Sum.SumLongFn());
    Aggregator<Integer, Integer> aggregatorTwo = fn.addAggregator(new Min.MinIntegerFn());

    TransformTreeNode transformNode = mock(TransformTreeNode.class);
    when(transformNode.getTransform()).thenReturn(bound);

    // The pipeline's traversal is stubbed with a VisitNodesAnswer over just this one node.
    doAnswer(new VisitNodesAnswer(ImmutableList.of(transformNode)))
        .when(p)
        .traverseTopologically(Mockito.any(PipelineVisitor.class));

    AggregatorPipelineExtractor extractor = new AggregatorPipelineExtractor(p);

    Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> aggregatorSteps =
        extractor.getAggregatorSteps();

    // Each aggregator should map to exactly the one transform whose DoFn declared it.
    assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorOne));
    assertEquals(ImmutableSet.<PTransform<?, ?>>of(bound), aggregatorSteps.get(aggregatorTwo));
    // assertEquals takes (expected, actual); keeping that order makes failure messages accurate.
    assertEquals(2, aggregatorSteps.size());
  }
  /**
   * Checks that when the same DoFn instance backs two different transforms (a {@code ParDo.Bound}
   * and a {@code ParDo.BoundMulti}), every aggregator declared by that DoFn is reported against
   * both transforms, and that the result contains exactly one entry per aggregator.
   */
  @SuppressWarnings("unchecked")
  @Test
  public void testGetAggregatorStepsWithOneAggregatorInMultipleStepsAddsSteps() {
    AggregatorProvidingDoFn<String, Math> sharedFn = new AggregatorProvidingDoFn<>();

    // Two distinct transforms, both returning the same DoFn.
    @SuppressWarnings("rawtypes")
    ParDo.Bound singleOutput = mock(ParDo.Bound.class, "Bound");
    when(singleOutput.getFn()).thenReturn(sharedFn);
    @SuppressWarnings("rawtypes")
    ParDo.BoundMulti multiOutput = mock(ParDo.BoundMulti.class, "otherBound");
    when(multiOutput.getFn()).thenReturn(sharedFn);

    Aggregator<Long, Long> sumAggregator = sharedFn.addAggregator(new Sum.SumLongFn());
    Aggregator<Double, Double> minAggregator = sharedFn.addAggregator(new Min.MinDoubleFn());

    TransformTreeNode firstNode = mock(TransformTreeNode.class);
    TransformTreeNode secondNode = mock(TransformTreeNode.class);
    when(firstNode.getTransform()).thenReturn(singleOutput);
    when(secondNode.getTransform()).thenReturn(multiOutput);

    // Stub the pipeline traversal so that it visits both nodes.
    doAnswer(new VisitNodesAnswer(ImmutableList.of(firstNode, secondNode)))
        .when(p)
        .traverseTopologically(Mockito.any(PipelineVisitor.class));

    Map<Aggregator<?, ?>, Collection<PTransform<?, ?>>> stepsByAggregator =
        new AggregatorPipelineExtractor(p).getAggregatorSteps();

    ImmutableSet<PTransform<?, ?>> expectedSteps =
        ImmutableSet.<PTransform<?, ?>>of(singleOutput, multiOutput);
    assertEquals(expectedSteps, stepsByAggregator.get(sumAggregator));
    assertEquals(expectedSteps, stepsByAggregator.get(minAggregator));
    assertEquals(2, stepsByAggregator.size());
  }