@Test
public void testCreateNormalParDoFn() throws Exception {
  String stringState = "some state";
  long longState = 42L;

  TestDoFn fn = new TestDoFn(stringState, longState);
  String serializedFn =
      StringUtils.byteArrayToJsonString(
          SerializableUtils.serializeToByteArray(
              new DoFnInfo(fn, WindowingStrategy.globalDefault())));
  CloudObject cloudUserFn = CloudObject.forClassName("DoFn");
  addString(cloudUserFn, "serialized_fn", serializedFn);

  String tag = "output";
  MultiOutputInfo multiOutputInfo = new MultiOutputInfo();
  multiOutputInfo.setTag(tag);
  List<MultiOutputInfo> multiOutputInfos = Arrays.asList(multiOutputInfo);

  PipelineOptions options = PipelineOptionsFactory.create();
  DataflowExecutionContext context = BatchModeExecutionContext.fromOptions(options);
  CounterSet counters = new CounterSet();
  StateSampler stateSampler = new StateSampler("test", counters.getAddCounterMutator());

  ParDoFn parDoFn = factory.create(
      options,
      cloudUserFn,
      "name",
      "transformName",
      null,
      multiOutputInfos,
      1,
      context,
      counters.getAddCounterMutator(),
      stateSampler);

  // Test that the factory created the correct class
  assertThat(parDoFn, instanceOf(NormalParDoFn.class));

  // Test that the DoFnInfo reflects the one passed in
  NormalParDoFn normalParDoFn = (NormalParDoFn) parDoFn;
  DoFnInfo doFnInfo = normalParDoFn.getDoFnInfo();
  DoFn actualDoFn = doFnInfo.getDoFn();
  assertThat(actualDoFn, instanceOf(TestDoFn.class));
  assertThat(doFnInfo.getWindowingStrategy().getWindowFn(), instanceOf(GlobalWindows.class));
  assertThat(
      doFnInfo.getWindowingStrategy().getTrigger().getSpec(),
      instanceOf(DefaultTrigger.class));

  // Test that the deserialized user DoFn is as expected
  TestDoFn actualTestDoFn = (TestDoFn) actualDoFn;
  assertEquals(stringState, actualTestDoFn.stringState);
  assertEquals(longState, actualTestDoFn.longState);

  assertEquals(context, normalParDoFn.getExecutionContext());
}
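/**
 * Writes the reads to {@code destPath} as a single, unsharded file through Dataflow's Hadoop
 * sink. The SAM header is serialized into a Hadoop configuration property so the output format
 * can emit it on the workers. Writing is skipped entirely when {@code destPath} is "/dev/null".
 */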
private static void writeToHadoop(
    Pipeline pipeline,
    PCollection<GATKRead> reads,
    final SAMFileHeader header,
    final String destPath,
    final boolean parquet) {
  if (destPath.equals("/dev/null")) {
    return;
  }

  // Serialize the SAM header into a configuration property so the output format can
  // reconstruct it when writing the file header.
  String headerString =
      Base64.getEncoder().encodeToString(SerializableUtils.serializeToByteArray(header));

  @SuppressWarnings("unchecked")
  Class<? extends FileOutputFormat<NullWritable, SAMRecordWritable>> outputFormatClass =
      (Class<? extends FileOutputFormat<NullWritable, SAMRecordWritable>>) (Class<?>)
          TemplatedKeyIgnoringBAMOutputFormat.class;

  @SuppressWarnings("unchecked")
  HadoopIO.Write.Bound<NullWritable, SAMRecordWritable> write =
      HadoopIO.Write.to(destPath, outputFormatClass, NullWritable.class, SAMRecordWritable.class)
          .withConfigurationProperty(
              TemplatedKeyIgnoringBAMOutputFormat.SAM_HEADER_PROPERTY_NAME, headerString);

  // Convert each GATKRead to a SAMRecordWritable keyed by NullWritable, the shape the
  // Hadoop output format expects.
  PCollection<KV<NullWritable, SAMRecordWritable>> samReads = reads
      .apply(ParDo.of(new DoFn<GATKRead, KV<NullWritable, SAMRecordWritable>>() {
        private static final long serialVersionUID = 1L;

        @Override
        public void processElement(ProcessContext c) throws Exception {
          SAMRecord samRecord = c.element().convertToSAMRecord(header);
          SAMRecordWritable samRecordWritable = new SAMRecordWritable();
          samRecordWritable.set(samRecord);
          c.output(KV.of(NullWritable.get(), samRecordWritable));
        }
      }))
      .setCoder(KvCoder.of(
          WritableCoder.of(NullWritable.class),
          WritableCoder.of(SAMRecordWritable.class)));

  // write as a single (unsharded) file
  samReads.apply(write.withoutSharding());
}