diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java index 241fa1f8ee54..749baecd2eb0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java @@ -46,7 +46,6 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.transforms.display.HasDisplayData; import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress; @@ -151,7 +150,8 @@ public final PCollection expand(PBegin input) { .apply(ParDo.of(new OutputSingleSource<>(source))) .setCoder(SerializableCoder.of(new TypeDescriptor>() {})) .apply(ParDo.of(new BoundedSourceAsSDFWrapperFn<>())) - .setCoder(source.getOutputCoder()); + .setCoder(source.getOutputCoder()) + .setTypeDescriptor(source.getOutputCoder().getEncodedTypeDescriptor()); } /** Returns the {@code BoundedSource} used to create this {@code Read} {@code PTransform}. */ @@ -259,26 +259,26 @@ public void populateDisplayData(DisplayData.Builder builder) { *

We model the element as the original source and the restriction as the sub-source. This * allows us to split the sub-source over and over yet still receive "source" objects as inputs. */ - static class BoundedSourceAsSDFWrapperFn extends DoFn, T> { + static class BoundedSourceAsSDFWrapperFn> + extends DoFn { private static final Logger LOG = LoggerFactory.getLogger(BoundedSourceAsSDFWrapperFn.class); private static final long DEFAULT_DESIRED_BUNDLE_SIZE_BYTES = 64 * (1 << 20); @GetInitialRestriction - public BoundedSource initialRestriction(@Element BoundedSource element) { + public BoundedSourceT initialRestriction(@Element BoundedSourceT element) { return element; } @GetSize - public double getSize( - @Restriction BoundedSource restriction, PipelineOptions pipelineOptions) + public double getSize(@Restriction BoundedSourceT restriction, PipelineOptions pipelineOptions) throws Exception { return restriction.getEstimatedSizeBytes(pipelineOptions); } @SplitRestriction public void splitRestriction( - @Restriction BoundedSource restriction, - OutputReceiver> receiver, + @Restriction BoundedSourceT restriction, + OutputReceiver receiver, PipelineOptions pipelineOptions) throws Exception { long estimatedSize = restriction.getEstimatedSizeBytes(pipelineOptions); @@ -288,20 +288,22 @@ public void splitRestriction( Math.min( DEFAULT_DESIRED_BUNDLE_SIZE_BYTES, Math.max(1L, estimatedSize / DEFAULT_DESIRED_NUM_SPLITS)); - for (BoundedSource split : restriction.split(splitBundleSize, pipelineOptions)) { + List splits = + (List) restriction.split(splitBundleSize, pipelineOptions); + for (BoundedSourceT split : splits) { receiver.output(split); } } @NewTracker - public RestrictionTracker, TimestampedValue[]> restrictionTracker( - @Restriction BoundedSource restriction, PipelineOptions pipelineOptions) { + public RestrictionTracker[]> restrictionTracker( + @Restriction BoundedSourceT restriction, PipelineOptions pipelineOptions) { return new BoundedSourceAsSDFRestrictionTracker<>(restriction, pipelineOptions); } @ProcessElement public void processElement( - RestrictionTracker, TimestampedValue[]> tracker, + RestrictionTracker[]> tracker, OutputReceiver receiver) throws IOException { TimestampedValue[] out = new TimestampedValue[1]; @@ -311,23 +313,24 @@ public void processElement( } @GetRestrictionCoder - public Coder> restrictionCoder() { - return SerializableCoder.of(new TypeDescriptor>() {}); + public Coder restrictionCoder() { + return SerializableCoder.of(new TypeDescriptor() {}); } /** * A fake restriction tracker which adapts to the {@link BoundedSource} API. The restriction * object is used to advance the underlying source and to "return" the current element. */ - private static class BoundedSourceAsSDFRestrictionTracker - extends RestrictionTracker, TimestampedValue[]> { - private final BoundedSource initialRestriction; + private static class BoundedSourceAsSDFRestrictionTracker< + BoundedSourceT extends BoundedSource, T> + extends RestrictionTracker[]> { + private final BoundedSourceT initialRestriction; private final PipelineOptions pipelineOptions; private BoundedSource.BoundedReader currentReader; private boolean claimedAll; BoundedSourceAsSDFRestrictionTracker( - BoundedSource initialRestriction, PipelineOptions pipelineOptions) { + BoundedSourceT initialRestriction, PipelineOptions pipelineOptions) { this.initialRestriction = initialRestriction; this.pipelineOptions = pipelineOptions; } @@ -393,15 +396,15 @@ protected void finalize() throws Throwable { /** The value is invalid if {@link #tryClaim} has ever thrown an exception. */ @Override - public BoundedSource currentRestriction() { + public BoundedSourceT currentRestriction() { if (currentReader == null) { return initialRestriction; } - return currentReader.getCurrentSource(); + return (BoundedSourceT) currentReader.getCurrentSource(); } @Override - public SplitResult> trySplit(double fractionOfRemainder) { + public SplitResult trySplit(double fractionOfRemainder) { if (currentReader == null) { return null; } @@ -416,7 +419,7 @@ public SplitResult> trySplit(double fractionOfRemainder) { return null; } BoundedSource primary = currentReader.getCurrentSource(); - return SplitResult.of(primary, residual); + return (SplitResult) SplitResult.of(primary, residual); } @Override @@ -990,15 +993,16 @@ public Progress getProgress() { } } - private static class OutputSingleSource extends DoFn { - private final T source; + private static class OutputSingleSource> + extends DoFn { + private final SourceT source; - private OutputSingleSource(T source) { + private OutputSingleSource(SourceT source) { this.source = source; } @ProcessElement - public void processElement(OutputReceiver outputReceiver) { + public void processElement(OutputReceiver outputReceiver) { outputReceiver.output(source); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java index a637951ff1d8..5e85eb8c4adb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java @@ -124,7 +124,13 @@ public void finishSpecifying(PInput input, PTransform transform) { * override this to enable better {@code Coder} inference. */ public @Nullable TypeDescriptor getTypeDescriptor() { - return typeDescriptor; + if (typeDescriptor != null) { + return typeDescriptor; + } + if (coderOrFailure.coder != null) { + return coderOrFailure.coder.getEncodedTypeDescriptor(); + } + return null; } /** diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java index 1f9031f0408e..8dd4a5f29c23 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java @@ -21,6 +21,7 @@ import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFor; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; import java.io.IOException; import java.io.InputStream; @@ -41,6 +42,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.io.CountingSource.CounterMark; @@ -63,6 +65,7 @@ import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; @@ -149,6 +152,17 @@ public void populateDisplayData(DisplayData.Builder builder) { assertThat(unboundedDisplayData, hasDisplayItem("maxReadTime", maxReadTime)); } + @Test + public void testReadBoundedPreservesTypeDescriptor() { + PCollection input = pipeline.apply(Read.from(new SerializableBoundedSource())); + TypeDescriptor typeDescriptor = input.getTypeDescriptor(); + assertEquals(String.class, typeDescriptor.getType()); + + ListBoundedSource longs = new ListBoundedSource<>(VarLongCoder.of()); + PCollection> numbers = pipeline.apply(Read.from(longs)); + assertEquals(new TypeDescriptor>() {}, numbers.getTypeDescriptor()); + } + @Test @Category({ NeedsRunner.class, @@ -261,6 +275,35 @@ public Coder getOutputCoder() { } } + private static class ListBoundedSource extends BoundedSource> { + private Coder coder; + + ListBoundedSource(Coder coder) { + this.coder = coder; + } + + @Override + public List>> split( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + return null; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + return 0; + } + + @Override + public BoundedReader> createReader(PipelineOptions options) throws IOException { + return null; + } + + @Override + public Coder> getOutputCoder() { + return ListCoder.of(coder); + } + } + private static class NotSerializableBoundedSource extends CustomBoundedSource { @SuppressWarnings("unused") private final NotSerializableClass notSerializableClass = new NotSerializableClass();