@Test
  public void testCreateNormalParDoFn() throws Exception {
    // Fixture: a user DoFn carrying two pieces of state that must survive the round trip.
    String expectedString = "some state";
    long expectedLong = 42L;
    TestDoFn originalFn = new TestDoFn(expectedString, expectedLong);

    // Serialize the DoFn together with the global-default windowing strategy and
    // store the JSON-string form under the "serialized_fn" property of the spec.
    DoFnInfo originalInfo = new DoFnInfo(originalFn, WindowingStrategy.globalDefault());
    byte[] infoBytes = SerializableUtils.serializeToByteArray(originalInfo);
    String encodedInfo = StringUtils.byteArrayToJsonString(infoBytes);
    CloudObject cloudUserFn = CloudObject.forClassName("DoFn");
    addString(cloudUserFn, "serialized_fn", encodedInfo);

    // A single tagged output.
    MultiOutputInfo outputInfo = new MultiOutputInfo();
    String outputTag = "output";
    outputInfo.setTag(outputTag);
    List<MultiOutputInfo> outputInfos = Arrays.asList(outputInfo);

    // Execution environment handed to the factory.
    PipelineOptions pipelineOptions = PipelineOptionsFactory.create();
    DataflowExecutionContext executionContext =
        BatchModeExecutionContext.fromOptions(pipelineOptions);
    CounterSet counterSet = new CounterSet();
    StateSampler sampler = new StateSampler("test", counterSet.getAddCounterMutator());

    ParDoFn created =
        factory.create(
            pipelineOptions,
            cloudUserFn,
            "name",
            "transformName",
            null,
            outputInfos,
            1,
            executionContext,
            counterSet.getAddCounterMutator(),
            sampler);

    // The factory must hand back a NormalParDoFn.
    assertThat(created, instanceOf(NormalParDoFn.class));
    NormalParDoFn normalFn = (NormalParDoFn) created;

    // The DoFnInfo it exposes must mirror what was serialized above.
    DoFnInfo restoredInfo = normalFn.getDoFnInfo();
    DoFn restoredFn = restoredInfo.getDoFn();
    assertThat(restoredFn, instanceOf(TestDoFn.class));
    assertThat(restoredInfo.getWindowingStrategy().getWindowFn(), instanceOf(GlobalWindows.class));
    assertThat(
        restoredInfo.getWindowingStrategy().getTrigger().getSpec(),
        instanceOf(DefaultTrigger.class));

    // The user DoFn's state and the execution context must come through unchanged.
    TestDoFn restoredTestFn = (TestDoFn) restoredFn;
    assertEquals(expectedString, restoredTestFn.stringState);
    assertEquals(expectedLong, restoredTestFn.longState);
    assertEquals(executionContext, normalFn.getExecutionContext());
  }
  /**
   * Builds a {@code CombineValuesFn} from the cloud specification of a combine step.
   *
   * <p>The {@code PropertyNames.SERIALIZED_FN} property of {@code cloudUserFn} is deserialized
   * and must be a {@code Combine.KeyedCombineFn}. The optional {@code PropertyNames.PHASE}
   * property (defaulting to {@code CombinePhase.ALL}) selects which DoFn wraps the combine
   * function; the phase is only inspected when the returned factory's
   * {@code createDoFnInfo()} is invoked.
   *
   * @param options pipeline options forwarded to the {@code CombineValuesFn} constructor
   * @param cloudUserFn specification holding the serialized combine function and the phase
   * @param stepName step name forwarded to the {@code CombineValuesFn} constructor
   * @param sideInputInfos must be {@code null} or empty
   * @param multiOutputInfos not read by this method
   * @param numOutputs must equal 1
   * @param executionContext forwarded to the {@code CombineValuesFn} constructor
   * @param addCounterMutator forwarded to the {@code CombineValuesFn} constructor
   * @param stateSampler unused
   * @return the configured {@code CombineValuesFn}
   * @throws Exception if deserialization fails or one of the argument checks is violated
   */
  public static CombineValuesFn create(
      PipelineOptions options,
      CloudObject cloudUserFn,
      String stepName,
      @Nullable List<SideInputInfo> sideInputInfos,
      @Nullable List<MultiOutputInfo> multiOutputInfos,
      Integer numOutputs,
      ExecutionContext executionContext,
      CounterSet.AddCounterMutator addCounterMutator,
      StateSampler stateSampler /* unused */)
      throws Exception {
    byte[] fnBytes = getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN);
    Object userFn = SerializableUtils.deserializeFromByteArray(fnBytes, "serialized user fn");
    Preconditions.checkArgument(userFn instanceof Combine.KeyedCombineFn);
    final Combine.KeyedCombineFn keyedFn = (Combine.KeyedCombineFn) userFn;

    // The combine phase defaults to ALL, since the implementation is not
    // obliged to split the combiner.
    final String combinePhase = getString(cloudUserFn, PropertyNames.PHASE, CombinePhase.ALL);

    boolean hasSideInputs = sideInputInfos != null && sideInputInfos.size() != 0;
    Preconditions.checkArgument(!hasSideInputs, "unexpected side inputs for CombineValuesFn");
    Preconditions.checkArgument(numOutputs == 1, "expected exactly one output for CombineValuesFn");

    return new CombineValuesFn(
        options,
        new DoFnInfoFactory() {
          @Override
          public DoFnInfo createDoFnInfo() {
            // Each phase maps to its own DoFn wrapper; no windowing strategy is attached.
            switch (combinePhase) {
              case CombinePhase.ALL:
                return new DoFnInfo(new CombineValuesDoFn(keyedFn), null);
              case CombinePhase.ADD:
                return new DoFnInfo(new AddInputsDoFn(keyedFn), null);
              case CombinePhase.MERGE:
                return new DoFnInfo(new MergeAccumulatorsDoFn(keyedFn), null);
              case CombinePhase.EXTRACT:
                return new DoFnInfo(new ExtractOutputDoFn(keyedFn), null);
              default:
                throw new IllegalArgumentException(
                    "phase must be one of 'all', 'add', 'merge', 'extract'");
            }
          }
        },
        stepName,
        executionContext,
        addCounterMutator);
  }
示例#3
0
 /**
  * Stores the configuration and, when it is non-null, restores the SAM header from it.
  *
  * <p>The header is read from the {@code SAM_HEADER_PROPERTY_NAME} property as a Base64
  * string, decoded, deserialized and passed to {@code setSAMHeader}.
  *
  * @param conf the configuration to store; may be {@code null}, in which case no header is read
  * @throws IllegalStateException if {@code conf} is non-null but carries no header property
  */
 @Override
 public void setConf(Configuration conf) {
   this.conf = conf;
   if (conf == null) {
     return;
   }

   String encodedHeader = conf.get(SAM_HEADER_PROPERTY_NAME);
   if (encodedHeader == null) {
     throw new IllegalStateException("SAM file header has not been set");
   }

   byte[] rawHeader = Base64.getDecoder().decode(encodedHeader);
   Object restoredHeader = SerializableUtils.deserializeFromByteArray(rawHeader, "SAMFileHeader");
   setSAMHeader((SAMFileHeader) restoredHeader);
 }
示例#4
0
  /**
   * Converts the reads to SAM records and writes them to {@code destPath} through
   * {@code HadoopIO} as a single, unsharded output.
   *
   * <p>Nothing is written when {@code destPath} is exactly {@code "/dev/null"}. The header is
   * serialized, Base64-encoded and handed to the output format through the
   * {@code SAM_HEADER_PROPERTY_NAME} configuration property.
   *
   * @param pipeline not read by this method
   * @param reads the reads to write
   * @param header the SAM header used for record conversion and passed to the output format
   * @param destPath destination path
   * @param parquet not read by this method
   */
  private static void writeToHadoop(
      Pipeline pipeline,
      PCollection<GATKRead> reads,
      final SAMFileHeader header,
      final String destPath,
      final boolean parquet) {
    if (destPath.equals("/dev/null")) {
      return;
    }

    byte[] headerBytes = SerializableUtils.serializeToByteArray(header);
    String encodedHeader = Base64.getEncoder().encodeToString(headerBytes);

    @SuppressWarnings("unchecked")
    Class<? extends FileOutputFormat<NullWritable, SAMRecordWritable>> formatClass =
        (Class<? extends FileOutputFormat<NullWritable, SAMRecordWritable>>)
            (Class<?>) TemplatedKeyIgnoringBAMOutputFormat.class;
    @SuppressWarnings("unchecked")
    HadoopIO.Write.Bound<NullWritable, SAMRecordWritable> hadoopWrite =
        HadoopIO.Write.to(destPath, formatClass, NullWritable.class, SAMRecordWritable.class)
            .withConfigurationProperty(
                TemplatedKeyIgnoringBAMOutputFormat.SAM_HEADER_PROPERTY_NAME, encodedHeader);

    // Wraps each read as a (NullWritable, SAMRecordWritable) pair.
    DoFn<GATKRead, KV<NullWritable, SAMRecordWritable>> toSamRecordPair =
        new DoFn<GATKRead, KV<NullWritable, SAMRecordWritable>>() {
          private static final long serialVersionUID = 1L;

          @Override
          public void processElement(ProcessContext c) throws Exception {
            SAMRecord converted = c.element().convertToSAMRecord(header);
            SAMRecordWritable wrapped = new SAMRecordWritable();
            wrapped.set(converted);
            c.output(KV.of(NullWritable.get(), wrapped));
          }
        };

    PCollection<KV<NullWritable, SAMRecordWritable>> samReads =
        reads
            .apply(ParDo.of(toSamRecordPair))
            .setCoder(
                KvCoder.of(
                    WritableCoder.of(NullWritable.class),
                    WritableCoder.of(SAMRecordWritable.class)));

    // Emit one unsharded file.
    samReads.apply(hadoopWrite.withoutSharding());
  }
  /**
   * Builds a {@code GroupAlsoByWindowsParDoFn} from the cloud specification of a step.
   *
   * <p>The following properties of {@code cloudUserFn} are read:
   *
   * <ul>
   *   <li>{@code PropertyNames.SERIALIZED_FN}: the serialized windowing strategy; a zero-length
   *       value selects {@code WindowingStrategy.globalDefault()}.
   *   <li>{@code PropertyNames.COMBINE_FN}: an optional serialized {@code KeyedCombineFn}.
   *   <li>{@code PropertyNames.INPUT_CODER}: must deserialize to a {@code WindowedValueCoder}
   *       whose value coder is a {@code KvCoder}.
   *   <li>{@code PropertyNames.PHASE}: consulted only when a combine function is present; must
   *       be {@code CombinePhase.ALL} (the default) or {@code CombinePhase.MERGE}.
   * </ul>
   *
   * @param options pipeline options; when they are {@code StreamingOptions}, their
   *     {@code isStreaming()} value is passed to {@code getGroupAlsoByWindowsDoFn}
   * @param cloudUserFn specification carrying the properties listed above
   * @param stepName step name forwarded to the {@code GroupAlsoByWindowsParDoFn} constructor
   * @param sideInputInfos not read by this method
   * @param multiOutputInfos not read by this method
   * @param numOutputs not read by this method
   * @param executionContext forwarded to the {@code GroupAlsoByWindowsParDoFn} constructor
   * @param addCounterMutator forwarded to the {@code GroupAlsoByWindowsParDoFn} constructor
   * @param sampler unused
   * @return the configured {@code GroupAlsoByWindowsParDoFn}
   * @throws Exception if a deserialized object has an unexpected type, the phase is not
   *     accepted, or deserialization itself fails
   */
  @SuppressWarnings({"rawtypes", "unchecked"})
  public static GroupAlsoByWindowsParDoFn create(
      PipelineOptions options,
      CloudObject cloudUserFn,
      String stepName,
      @Nullable List<SideInputInfo> sideInputInfos,
      @Nullable List<MultiOutputInfo> multiOutputInfos,
      Integer numOutputs,
      ExecutionContext executionContext,
      CounterSet.AddCounterMutator addCounterMutator,
      StateSampler sampler /* unused */)
      throws Exception {
    // Windowing strategy: an empty payload selects the global default.
    WindowingStrategy windowingStrategy;
    byte[] strategyBytes = getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN);
    if (strategyBytes.length == 0) {
      windowingStrategy = WindowingStrategy.globalDefault();
    } else {
      Object decodedStrategy =
          SerializableUtils.deserializeFromByteArray(
              strategyBytes, "serialized windowing strategy");
      if (!(decodedStrategy instanceof WindowingStrategy)) {
        throw new Exception(
            "unexpected kind of WindowingStrategy: " + decodedStrategy.getClass().getName());
      }
      windowingStrategy = (WindowingStrategy) decodedStrategy;
    }

    // Optional combine function.
    KeyedCombineFn combineFn = null;
    byte[] combineFnBytes = getBytes(cloudUserFn, PropertyNames.COMBINE_FN, null);
    if (combineFnBytes != null) {
      Object decodedCombineFn =
          SerializableUtils.deserializeFromByteArray(combineFnBytes, "serialized combine fn");
      if (!(decodedCombineFn instanceof KeyedCombineFn)) {
        throw new Exception(
            "unexpected kind of KeyedCombineFn: " + decodedCombineFn.getClass().getName());
      }
      combineFn = (KeyedCombineFn) decodedCombineFn;
    }

    // Input coder: a WindowedValueCoder wrapping a KvCoder is required.
    Map<String, Object> coderSpec = getObject(cloudUserFn, PropertyNames.INPUT_CODER);
    Coder inputCoder = Serializer.deserialize(coderSpec, Coder.class);
    if (!(inputCoder instanceof WindowedValueCoder)) {
      throw new Exception(
          "Expected WindowedValueCoder for inputCoder, got: " + inputCoder.getClass().getName());
    }
    Coder valueCoder = ((WindowedValueCoder) inputCoder).getValueCoder();
    if (!(valueCoder instanceof KvCoder)) {
      throw new Exception(
          "Expected KvCoder for inputCoder, got: " + valueCoder.getClass().getName());
    }
    KvCoder kvCoder = (KvCoder) valueCoder;

    boolean streaming =
        options instanceof StreamingOptions && ((StreamingOptions) options).isStreaming();

    // In the MERGE phase the combine function is wrapped in a MergingKeyedCombineFn.
    KeyedCombineFn effectiveCombineFn = null;
    if (combineFn != null) {
      String phase = getString(cloudUserFn, PropertyNames.PHASE, CombinePhase.ALL);
      boolean mergePhase = phase.equals(CombinePhase.MERGE);
      Preconditions.checkArgument(
          phase.equals(CombinePhase.ALL) || mergePhase, "Unexpected phase: " + phase);
      effectiveCombineFn = mergePhase ? new MergingKeyedCombineFn(combineFn) : combineFn;
    }

    final DoFn groupAlsoByWindowsDoFn =
        getGroupAlsoByWindowsDoFn(streaming, windowingStrategy, kvCoder, effectiveCombineFn);

    // The factory returns the same DoFn instance on every call, with no windowing strategy.
    DoFnInfoFactory infoFactory =
        new DoFnInfoFactory() {
          @Override
          public DoFnInfo createDoFnInfo() {
            return new DoFnInfo(groupAlsoByWindowsDoFn, null);
          }
        };
    return new GroupAlsoByWindowsParDoFn(
        options, infoFactory, stepName, executionContext, addCounterMutator);
  }