From 331c67c5f08a1cbfd73a1b9df5b268057c46a0bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Fri, 21 May 2021 09:35:49 +0200 Subject: [PATCH 1/3] [BEAM-12384] Refine generic types on Read.Bounded internals --- .../java/org/apache/beam/sdk/io/Read.java | 53 ++++++++++--------- 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java index 241fa1f8ee54..0e96fce277af 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java @@ -46,7 +46,6 @@ import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; import org.apache.beam.sdk.transforms.display.DisplayData; -import org.apache.beam.sdk.transforms.display.HasDisplayData; import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress; @@ -259,26 +258,26 @@ public void populateDisplayData(DisplayData.Builder builder) { *

We model the element as the original source and the restriction as the sub-source. This * allows us to split the sub-source over and over yet still receive "source" objects as inputs. */ - static class BoundedSourceAsSDFWrapperFn extends DoFn, T> { + static class BoundedSourceAsSDFWrapperFn> + extends DoFn { private static final Logger LOG = LoggerFactory.getLogger(BoundedSourceAsSDFWrapperFn.class); private static final long DEFAULT_DESIRED_BUNDLE_SIZE_BYTES = 64 * (1 << 20); @GetInitialRestriction - public BoundedSource initialRestriction(@Element BoundedSource element) { + public BoundedSourceT initialRestriction(@Element BoundedSourceT element) { return element; } @GetSize - public double getSize( - @Restriction BoundedSource restriction, PipelineOptions pipelineOptions) + public double getSize(@Restriction BoundedSourceT restriction, PipelineOptions pipelineOptions) throws Exception { return restriction.getEstimatedSizeBytes(pipelineOptions); } @SplitRestriction public void splitRestriction( - @Restriction BoundedSource restriction, - OutputReceiver> receiver, + @Restriction BoundedSourceT restriction, + OutputReceiver receiver, PipelineOptions pipelineOptions) throws Exception { long estimatedSize = restriction.getEstimatedSizeBytes(pipelineOptions); @@ -288,20 +287,22 @@ public void splitRestriction( Math.min( DEFAULT_DESIRED_BUNDLE_SIZE_BYTES, Math.max(1L, estimatedSize / DEFAULT_DESIRED_NUM_SPLITS)); - for (BoundedSource split : restriction.split(splitBundleSize, pipelineOptions)) { + List splits = + (List) restriction.split(splitBundleSize, pipelineOptions); + for (BoundedSourceT split : splits) { receiver.output(split); } } @NewTracker - public RestrictionTracker, TimestampedValue[]> restrictionTracker( - @Restriction BoundedSource restriction, PipelineOptions pipelineOptions) { + public RestrictionTracker[]> restrictionTracker( + @Restriction BoundedSourceT restriction, PipelineOptions pipelineOptions) { return new BoundedSourceAsSDFRestrictionTracker<>(restriction, pipelineOptions); } @ProcessElement public void processElement( - RestrictionTracker, TimestampedValue[]> tracker, + RestrictionTracker[]> tracker, OutputReceiver receiver) throws IOException { TimestampedValue[] out = new TimestampedValue[1]; @@ -311,23 +312,24 @@ public void processElement( } @GetRestrictionCoder - public Coder> restrictionCoder() { - return SerializableCoder.of(new TypeDescriptor>() {}); + public Coder restrictionCoder() { + return SerializableCoder.of(new TypeDescriptor() {}); } /** * A fake restriction tracker which adapts to the {@link BoundedSource} API. The restriction * object is used to advance the underlying source and to "return" the current element. */ - private static class BoundedSourceAsSDFRestrictionTracker - extends RestrictionTracker, TimestampedValue[]> { - private final BoundedSource initialRestriction; + private static class BoundedSourceAsSDFRestrictionTracker< + BoundedSourceT extends BoundedSource, T> + extends RestrictionTracker[]> { + private final BoundedSourceT initialRestriction; private final PipelineOptions pipelineOptions; private BoundedSource.BoundedReader currentReader; private boolean claimedAll; BoundedSourceAsSDFRestrictionTracker( - BoundedSource initialRestriction, PipelineOptions pipelineOptions) { + BoundedSourceT initialRestriction, PipelineOptions pipelineOptions) { this.initialRestriction = initialRestriction; this.pipelineOptions = pipelineOptions; } @@ -393,15 +395,15 @@ protected void finalize() throws Throwable { /** The value is invalid if {@link #tryClaim} has ever thrown an exception. */ @Override - public BoundedSource currentRestriction() { + public BoundedSourceT currentRestriction() { if (currentReader == null) { return initialRestriction; } - return currentReader.getCurrentSource(); + return (BoundedSourceT) currentReader.getCurrentSource(); } @Override - public SplitResult> trySplit(double fractionOfRemainder) { + public SplitResult trySplit(double fractionOfRemainder) { if (currentReader == null) { return null; } @@ -416,7 +418,7 @@ public SplitResult> trySplit(double fractionOfRemainder) { return null; } BoundedSource primary = currentReader.getCurrentSource(); - return SplitResult.of(primary, residual); + return (SplitResult) SplitResult.of(primary, residual); } @Override @@ -990,15 +992,16 @@ public Progress getProgress() { } } - private static class OutputSingleSource extends DoFn { - private final T source; + private static class OutputSingleSource> + extends DoFn { + private final SourceT source; - private OutputSingleSource(T source) { + private OutputSingleSource(SourceT source) { this.source = source; } @ProcessElement - public void processElement(OutputReceiver outputReceiver) { + public void processElement(OutputReceiver outputReceiver) { outputReceiver.output(source); } From 3b705c1ee6c789611f12c0700f7a255f48256f44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Isma=C3=ABl=20Mej=C3=ADa?= Date: Fri, 21 May 2021 22:32:11 +0200 Subject: [PATCH 2/3] [BEAM-12384] Set output typeDescriptor explictly in Read.Bounded transform --- .../java/org/apache/beam/sdk/io/Read.java | 3 +- .../java/org/apache/beam/sdk/io/ReadTest.java | 43 +++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java index 0e96fce277af..749baecd2eb0 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java @@ -150,7 +150,8 @@ public final PCollection expand(PBegin input) { .apply(ParDo.of(new OutputSingleSource<>(source))) .setCoder(SerializableCoder.of(new TypeDescriptor>() {})) .apply(ParDo.of(new BoundedSourceAsSDFWrapperFn<>())) - .setCoder(source.getOutputCoder()); + .setCoder(source.getOutputCoder()) + .setTypeDescriptor(source.getOutputCoder().getEncodedTypeDescriptor()); } /** Returns the {@code BoundedSource} used to create this {@code Read} {@code PTransform}. */ diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java index 1f9031f0408e..8dd4a5f29c23 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java @@ -21,6 +21,7 @@ import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem; import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFor; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; import java.io.IOException; import java.io.InputStream; @@ -41,6 +42,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.io.CountingSource.CounterMark; @@ -63,6 +65,7 @@ import org.apache.beam.sdk.transforms.windowing.GlobalWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; @@ -149,6 +152,17 @@ public void populateDisplayData(DisplayData.Builder builder) { assertThat(unboundedDisplayData, hasDisplayItem("maxReadTime", maxReadTime)); } + @Test + public void testReadBoundedPreservesTypeDescriptor() { + PCollection input = pipeline.apply(Read.from(new SerializableBoundedSource())); + TypeDescriptor typeDescriptor = input.getTypeDescriptor(); + assertEquals(String.class, typeDescriptor.getType()); + + ListBoundedSource longs = new ListBoundedSource<>(VarLongCoder.of()); + PCollection> numbers = pipeline.apply(Read.from(longs)); + assertEquals(new TypeDescriptor>() {}, numbers.getTypeDescriptor()); + } + @Test @Category({ NeedsRunner.class, @@ -261,6 +275,35 @@ public Coder getOutputCoder() { } } + private static class ListBoundedSource extends BoundedSource> { + private Coder coder; + + ListBoundedSource(Coder coder) { + this.coder = coder; + } + + @Override + public List>> split( + long desiredBundleSizeBytes, PipelineOptions options) throws Exception { + return null; + } + + @Override + public long getEstimatedSizeBytes(PipelineOptions options) throws Exception { + return 0; + } + + @Override + public BoundedReader> createReader(PipelineOptions options) throws IOException { + return null; + } + + @Override + public Coder> getOutputCoder() { + return ListCoder.of(coder); + } + } + private static class NotSerializableBoundedSource extends CustomBoundedSource { @SuppressWarnings("unused") private final NotSerializableClass notSerializableClass = new NotSerializableClass(); From 83bccf9c49056de775324e1c580391b59525397f Mon Sep 17 00:00:00 2001 From: Boyuan Zhang Date: Mon, 24 May 2021 11:01:59 -0700 Subject: [PATCH 3/3] [BEAM-12384] Infer typeDescriptor from coder if typeDescriptor is not set explicitly. --- .../main/java/org/apache/beam/sdk/values/PCollection.java | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java index a637951ff1d8..5e85eb8c4adb 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollection.java @@ -124,7 +124,13 @@ public void finishSpecifying(PInput input, PTransform transform) { * override this to enable better {@code Coder} inference. */ public @Nullable TypeDescriptor getTypeDescriptor() { - return typeDescriptor; + if (typeDescriptor != null) { + return typeDescriptor; + } + if (coderOrFailure.coder != null) { + return coderOrFailure.coder.getEncodedTypeDescriptor(); + } + return null; } /**