@Test public void testCreateNormalParDoFn() throws Exception { String stringState = "some state"; long longState = 42L; TestDoFn fn = new TestDoFn(stringState, longState); String serializedFn = StringUtils.byteArrayToJsonString( SerializableUtils.serializeToByteArray( new DoFnInfo(fn, WindowingStrategy.globalDefault()))); CloudObject cloudUserFn = CloudObject.forClassName("DoFn"); addString(cloudUserFn, "serialized_fn", serializedFn); String tag = "output"; MultiOutputInfo multiOutputInfo = new MultiOutputInfo(); multiOutputInfo.setTag(tag); List<MultiOutputInfo> multiOutputInfos = Arrays.asList(multiOutputInfo); PipelineOptions options = PipelineOptionsFactory.create(); DataflowExecutionContext context = BatchModeExecutionContext.fromOptions(options); CounterSet counters = new CounterSet(); StateSampler stateSampler = new StateSampler("test", counters.getAddCounterMutator()); ParDoFn parDoFn = factory.create( options, cloudUserFn, "name", "transformName", null, multiOutputInfos, 1, context, counters.getAddCounterMutator(), stateSampler); // Test that the factory created the correct class assertThat(parDoFn, instanceOf(NormalParDoFn.class)); // Test that the DoFnInfo reflects the one passed in NormalParDoFn normalParDoFn = (NormalParDoFn) parDoFn; DoFnInfo doFnInfo = normalParDoFn.getDoFnInfo(); DoFn actualDoFn = doFnInfo.getDoFn(); assertThat(actualDoFn, instanceOf(TestDoFn.class)); assertThat(doFnInfo.getWindowingStrategy().getWindowFn(), instanceOf(GlobalWindows.class)); assertThat( doFnInfo.getWindowingStrategy().getTrigger().getSpec(), instanceOf(DefaultTrigger.class)); // Test that the deserialized user DoFn is as expected TestDoFn actualTestDoFn = (TestDoFn) actualDoFn; assertEquals(stringState, actualTestDoFn.stringState); assertEquals(longState, actualTestDoFn.longState); assertEquals(context, normalParDoFn.getExecutionContext()); }
public static CombineValuesFn create( PipelineOptions options, CloudObject cloudUserFn, String stepName, @Nullable List<SideInputInfo> sideInputInfos, @Nullable List<MultiOutputInfo> multiOutputInfos, Integer numOutputs, ExecutionContext executionContext, CounterSet.AddCounterMutator addCounterMutator, StateSampler stateSampler /* unused */) throws Exception { Object deserializedFn = SerializableUtils.deserializeFromByteArray( getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN), "serialized user fn"); Preconditions.checkArgument(deserializedFn instanceof Combine.KeyedCombineFn); final Combine.KeyedCombineFn combineFn = (Combine.KeyedCombineFn) deserializedFn; // Get the combine phase, default to ALL. (The implementation // doesn't have to split the combiner). final String phase = getString(cloudUserFn, PropertyNames.PHASE, CombinePhase.ALL); Preconditions.checkArgument( sideInputInfos == null || sideInputInfos.size() == 0, "unexpected side inputs for CombineValuesFn"); Preconditions.checkArgument(numOutputs == 1, "expected exactly one output for CombineValuesFn"); DoFnInfoFactory fnFactory = new DoFnInfoFactory() { @Override public DoFnInfo createDoFnInfo() { DoFn doFn = null; switch (phase) { case CombinePhase.ALL: doFn = new CombineValuesDoFn(combineFn); break; case CombinePhase.ADD: doFn = new AddInputsDoFn(combineFn); break; case CombinePhase.MERGE: doFn = new MergeAccumulatorsDoFn(combineFn); break; case CombinePhase.EXTRACT: doFn = new ExtractOutputDoFn(combineFn); break; default: throw new IllegalArgumentException( "phase must be one of 'all', 'add', 'merge', 'extract'"); } return new DoFnInfo(doFn, null); } }; return new CombineValuesFn(options, fnFactory, stepName, executionContext, addCounterMutator); }
/**
 * Stores {@code conf} and, when it is non-null, loads the SAM header from it.
 *
 * <p>The header is read from the {@code SAM_HEADER_PROPERTY_NAME} property, Base64-decoded,
 * deserialized with {@code SerializableUtils}, cast to {@link SAMFileHeader}, and handed to
 * {@code setSAMHeader}.
 *
 * @param conf the Hadoop configuration; when {@code null} only the field is updated
 * @throws IllegalStateException if {@code conf} is non-null but lacks the header property
 */
@Override
public void setConf(Configuration conf) {
  this.conf = conf;
  if (conf == null) {
    return;
  }
  final String encodedHeader = conf.get(SAM_HEADER_PROPERTY_NAME);
  if (encodedHeader == null) {
    throw new IllegalStateException("SAM file header has not been set");
  }
  final byte[] serializedHeader = Base64.getDecoder().decode(encodedHeader);
  final Object header =
      SerializableUtils.deserializeFromByteArray(serializedHeader, "SAMFileHeader");
  setSAMHeader((SAMFileHeader) header);
}
private static void writeToHadoop( Pipeline pipeline, PCollection<GATKRead> reads, final SAMFileHeader header, final String destPath, final boolean parquet) { if (destPath.equals("/dev/null")) { return; } String headerString = Base64.getEncoder().encodeToString(SerializableUtils.serializeToByteArray(header)); @SuppressWarnings("unchecked") Class<? extends FileOutputFormat<NullWritable, SAMRecordWritable>> outputFormatClass = (Class<? extends FileOutputFormat<NullWritable, SAMRecordWritable>>) (Class<?>) TemplatedKeyIgnoringBAMOutputFormat.class; @SuppressWarnings("unchecked") HadoopIO.Write.Bound<NullWritable, SAMRecordWritable> write = HadoopIO.Write.to(destPath, outputFormatClass, NullWritable.class, SAMRecordWritable.class) .withConfigurationProperty( TemplatedKeyIgnoringBAMOutputFormat.SAM_HEADER_PROPERTY_NAME, headerString); PCollection<KV<NullWritable, SAMRecordWritable>> samReads = reads .apply( ParDo.of( new DoFn<GATKRead, KV<NullWritable, SAMRecordWritable>>() { private static final long serialVersionUID = 1L; @Override public void processElement(ProcessContext c) throws Exception { SAMRecord samRecord = c.element().convertToSAMRecord(header); SAMRecordWritable samRecordWritable = new SAMRecordWritable(); samRecordWritable.set(samRecord); c.output(KV.of(NullWritable.get(), samRecordWritable)); } })) .setCoder( KvCoder.of( WritableCoder.of(NullWritable.class), WritableCoder.of(SAMRecordWritable.class))); // write as a single (unsharded) file samReads.apply(write.withoutSharding()); }
/**
 * Creates a {@link GroupAlsoByWindowsParDoFn} from its cloud (service-side) representation.
 *
 * <p>Reads the serialized windowing strategy from {@code PropertyNames.SERIALIZED_FN} (an empty
 * byte array selects {@link WindowingStrategy#globalDefault()}), an optional serialized
 * {@link KeyedCombineFn} from {@code PropertyNames.COMBINE_FN}, and the input coder from
 * {@code PropertyNames.INPUT_CODER}, which must be a {@link WindowedValueCoder} whose value coder
 * is a {@link KvCoder}.
 *
 * @param options pipeline options; streaming mode is read from {@link StreamingOptions} when
 *     the options implement it
 * @param cloudUserFn the cloud object describing the step
 * @param stepName name of the step, passed through to the created fn
 * @param sideInputInfos not read by this method
 * @param multiOutputInfos not read by this method
 * @param numOutputs not read by this method
 * @param executionContext execution context, passed through to the created fn
 * @param addCounterMutator counter mutator, passed through to the created fn
 * @param sampler unused
 * @return a new {@link GroupAlsoByWindowsParDoFn}
 * @throws IllegalArgumentException if a combine fn is present and the phase is neither ALL nor
 *     MERGE
 * @throws Exception if a serialized field has an unexpected type, or propagated from decoding
 */
@SuppressWarnings({"rawtypes", "unchecked"})
public static GroupAlsoByWindowsParDoFn create(
    PipelineOptions options,
    CloudObject cloudUserFn,
    String stepName,
    @Nullable List<SideInputInfo> sideInputInfos,
    @Nullable List<MultiOutputInfo> multiOutputInfos,
    Integer numOutputs,
    ExecutionContext executionContext,
    CounterSet.AddCounterMutator addCounterMutator,
    StateSampler sampler /* unused */)
    throws Exception {
  WindowingStrategy windowingStrategy = deserializeWindowingStrategy(cloudUserFn);
  KeyedCombineFn combineFn = deserializeCombineFn(cloudUserFn);
  KvCoder kvCoder = extractKvCoder(cloudUserFn);

  boolean isStreamingPipeline =
      options instanceof StreamingOptions && ((StreamingOptions) options).isStreaming();

  KeyedCombineFn maybeMergingCombineFn = null;
  if (combineFn != null) {
    String phase = getString(cloudUserFn, PropertyNames.PHASE, CombinePhase.ALL);
    Preconditions.checkArgument(
        phase.equals(CombinePhase.ALL) || phase.equals(CombinePhase.MERGE),
        "Unexpected phase: %s",
        phase);
    if (phase.equals(CombinePhase.MERGE)) {
      maybeMergingCombineFn = new MergingKeyedCombineFn(combineFn);
    } else {
      maybeMergingCombineFn = combineFn;
    }
  }

  final DoFn groupAlsoByWindowsDoFn =
      getGroupAlsoByWindowsDoFn(
          isStreamingPipeline, windowingStrategy, kvCoder, maybeMergingCombineFn);
  DoFnInfoFactory fnFactory =
      new DoFnInfoFactory() {
        @Override
        public DoFnInfo createDoFnInfo() {
          return new DoFnInfo(groupAlsoByWindowsDoFn, null);
        }
      };
  return new GroupAlsoByWindowsParDoFn(
      options, fnFactory, stepName, executionContext, addCounterMutator);
}

/** Decodes the windowing strategy, using the global default when no bytes were serialized. */
@SuppressWarnings({"rawtypes", "unchecked"})
private static WindowingStrategy deserializeWindowingStrategy(CloudObject cloudUserFn)
    throws Exception {
  byte[] encodedWindowingStrategy = getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN);
  if (encodedWindowingStrategy.length == 0) {
    return WindowingStrategy.globalDefault();
  }
  Object windowingStrategyObj =
      SerializableUtils.deserializeFromByteArray(
          encodedWindowingStrategy, "serialized windowing strategy");
  if (!(windowingStrategyObj instanceof WindowingStrategy)) {
    throw new Exception(
        "unexpected kind of WindowingStrategy: " + classNameOf(windowingStrategyObj));
  }
  return (WindowingStrategy) windowingStrategyObj;
}

/** Decodes the optional combine fn; returns {@code null} when the step carries none. */
@SuppressWarnings({"rawtypes", "unchecked"})
private static KeyedCombineFn deserializeCombineFn(CloudObject cloudUserFn) throws Exception {
  byte[] serializedCombineFn = getBytes(cloudUserFn, PropertyNames.COMBINE_FN, null);
  if (serializedCombineFn == null) {
    return null;
  }
  Object combineFnObj =
      SerializableUtils.deserializeFromByteArray(serializedCombineFn, "serialized combine fn");
  if (!(combineFnObj instanceof KeyedCombineFn)) {
    throw new Exception("unexpected kind of KeyedCombineFn: " + classNameOf(combineFnObj));
  }
  return (KeyedCombineFn) combineFnObj;
}

/** Decodes the input coder and returns its value coder, which must be a {@link KvCoder}. */
@SuppressWarnings({"rawtypes", "unchecked"})
private static KvCoder extractKvCoder(CloudObject cloudUserFn) throws Exception {
  Map<String, Object> inputCoderObject = getObject(cloudUserFn, PropertyNames.INPUT_CODER);
  Coder inputCoder = Serializer.deserialize(inputCoderObject, Coder.class);
  if (!(inputCoder instanceof WindowedValueCoder)) {
    throw new Exception(
        "Expected WindowedValueCoder for inputCoder, got: " + classNameOf(inputCoder));
  }
  Coder elemCoder = ((WindowedValueCoder) inputCoder).getValueCoder();
  if (!(elemCoder instanceof KvCoder)) {
    throw new Exception("Expected KvCoder for inputCoder, got: " + classNameOf(elemCoder));
  }
  return (KvCoder) elemCoder;
}

/** Returns the runtime class name of {@code obj}, or "null", so error paths never NPE. */
private static String classNameOf(Object obj) {
  return obj == null ? "null" : obj.getClass().getName();
}