private static boolean isKeyPreserving(PTransform<?, ?> transform) { // This is a hacky check for what is considered key-preserving to the direct runner. // The most obvious alternative would be a package-private marker interface, but // better to make this obviously hacky so it is less likely to proliferate. Meanwhile // we intend to allow explicit expression of key-preserving DoFn in the model. if (transform instanceof ParDo.BoundMulti) { ParDo.BoundMulti<?, ?> parDo = (ParDo.BoundMulti<?, ?>) transform; return parDo.getFn() instanceof ParDoMultiOverrideFactory.ToKeyedWorkItem; } else { return false; } }
@Override public void translateNode( ParDo.BoundMulti<InputT, OutputT> transform, FlinkBatchTranslationContext context) { DoFn<InputT, OutputT> doFn = transform.getFn(); rejectStateAndTimers(doFn); DataSet<WindowedValue<InputT>> inputDataSet = context.getInputDataSet(context.getInput(transform)); List<TaggedPValue> outputs = context.getOutputs(transform); Map<TupleTag<?>, Integer> outputMap = Maps.newHashMap(); // put the main output at index 0, FlinkMultiOutputDoFnFunction expects this outputMap.put(transform.getMainOutputTag(), 0); int count = 1; for (TaggedPValue taggedValue : outputs) { if (!outputMap.containsKey(taggedValue.getTag())) { outputMap.put(taggedValue.getTag(), count++); } } // assume that the windowing strategy is the same for all outputs WindowingStrategy<?, ?> windowingStrategy = null; // collect all output Coders and create a UnionCoder for our tagged outputs List<Coder<?>> outputCoders = Lists.newArrayList(); for (TaggedPValue taggedValue : outputs) { checkState( taggedValue.getValue() instanceof PCollection, "Within ParDo, got a non-PCollection output %s of type %s", taggedValue.getValue(), taggedValue.getValue().getClass().getSimpleName()); PCollection<?> coll = (PCollection<?>) taggedValue.getValue(); outputCoders.add(coll.getCoder()); windowingStrategy = coll.getWindowingStrategy(); } if (windowingStrategy == null) { throw new IllegalStateException("No outputs defined."); } UnionCoder unionCoder = UnionCoder.of(outputCoders); TypeInformation<WindowedValue<RawUnionValue>> typeInformation = new CoderTypeInformation<>( WindowedValue.getFullCoder( unionCoder, windowingStrategy.getWindowFn().windowCoder())); List<PCollectionView<?>> sideInputs = transform.getSideInputs(); // construct a map from side input to WindowingStrategy so that // the OldDoFn runner can map main-input windows to side input windows Map<PCollectionView<?>, WindowingStrategy<?, ?>> sideInputStrategies = new HashMap<>(); for (PCollectionView<?> sideInput : sideInputs) { sideInputStrategies.put(sideInput, sideInput.getWindowingStrategyInternal()); } @SuppressWarnings("unchecked") FlinkMultiOutputDoFnFunction<InputT, OutputT> doFnWrapper = new FlinkMultiOutputDoFnFunction( doFn, windowingStrategy, sideInputStrategies, context.getPipelineOptions(), outputMap); MapPartitionOperator<WindowedValue<InputT>, WindowedValue<RawUnionValue>> taggedDataSet = new MapPartitionOperator<>( inputDataSet, typeInformation, doFnWrapper, transform.getName()); transformSideInputs(sideInputs, taggedDataSet, context); for (TaggedPValue output : outputs) { pruneOutput( taggedDataSet, context, outputMap.get(output.getTag()), (PCollection) output.getValue()); } }