[BEAM-12384] Set output typeDescriptor explicitly in Read.Bounded transform #14854

Merged: 3 commits, May 28, 2021
56 changes: 30 additions & 26 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
@@ -46,7 +46,6 @@
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.display.DisplayData;
- import org.apache.beam.sdk.transforms.display.HasDisplayData;
import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress;
@@ -151,7 +150,8 @@ public final PCollection<T> expand(PBegin input) {
.apply(ParDo.of(new OutputSingleSource<>(source)))
.setCoder(SerializableCoder.of(new TypeDescriptor<BoundedSource<T>>() {}))
.apply(ParDo.of(new BoundedSourceAsSDFWrapperFn<>()))
- .setCoder(source.getOutputCoder());
+ .setCoder(source.getOutputCoder())
+ .setTypeDescriptor(source.getOutputCoder().getEncodedTypeDescriptor());
Contributor

@boyuanzz boyuanzz May 21, 2021

Did we maintain the TypeDescriptor information for Read before? I was under the impression that in most cases we only set the Coder for an output PCollection.

Member Author

You are right, and I don't know why we don't pay more attention to this. Probably because coders seem to include the TypeDescriptor. Any ideas, @kennknowles? Is this redundant somehow?

In any case, having this information seems important for the downstream transforms.

Contributor

It seems like the typeDescriptor can be inferred from Coder.getEncodedTypeDescriptor(). If we really want to populate this information in a consistent way, we could consider changing PCollection.getTypeDescriptor() to infer the typeDescriptor from the Coder when the typeDescriptor is not set.
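
The lookup order suggested in this comment can be sketched in plain Java. This is only an illustration under stated assumptions: `SketchCoder` and `SketchCollection` are hypothetical stand-ins, not Beam's `Coder` or `PCollection`, and a `Class<T>` stands in for a `TypeDescriptor<T>`. It shows the order of preference: the explicitly set type first, then the type reported by the coder, then null.

```java
/** Hypothetical, simplified model of the fallback; these are NOT the Beam classes. */
public class TypeDescriptorFallbackSketch {

  /** Stand-in for a coder that knows which Java type it encodes. */
  interface SketchCoder<T> {
    Class<T> getEncodedType();
  }

  /** Stand-in for a collection with an optional explicit type and an optional coder. */
  static final class SketchCollection<T> {
    private Class<T> explicitType; // null until setType is called
    private SketchCoder<T> coder; // null until setCoder is called

    SketchCollection<T> setType(Class<T> type) {
      this.explicitType = type;
      return this;
    }

    SketchCollection<T> setCoder(SketchCoder<T> coder) {
      this.coder = coder;
      return this;
    }

    /** Prefers the explicit type, falls back to the coder, and otherwise returns null. */
    Class<T> getType() {
      if (explicitType != null) {
        return explicitType;
      }
      if (coder != null) {
        return coder.getEncodedType();
      }
      return null;
    }
  }

  public static void main(String[] args) {
    SketchCoder<String> stringCoder = () -> String.class;

    // Only the coder is set, so the type is inferred from it.
    SketchCollection<String> coderOnly = new SketchCollection<String>().setCoder(stringCoder);
    System.out.println(coderOnly.getType());

    // Nothing is set, so no type information is available.
    SketchCollection<String> nothingSet = new SketchCollection<>();
    System.out.println(nothingSet.getType());
  }
}
```

With such a fallback, a transform that only sets the coder would still expose type information to downstream consumers, which is the consistency this comment is asking about.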

Member

Yes, the expected use of this method is to set the type descriptor but not the coder. This way, the coder registry still can choose the coder.

Setting both is redundant, in theory. Setting just the coder should suffice. Maybe some plumbing is needed? Nobody was really expected to look at either one in this way.

Another angle to consider is that type descriptor is Java-specific, while coder is the portable "type" of the data. I don't know if that matters here.
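
The point about letting the coder registry choose can be illustrated with a small plain-Java sketch. `SketchRegistry` and `resolveCoder` are hypothetical names, and coders are represented as plain strings rather than Beam `Coder` objects. The only assumption modeled here is that an explicitly set coder takes precedence, and that a declared type is otherwise looked up in a registry of defaults.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical model of coder resolution; this is NOT Beam's CoderRegistry. */
public class CoderRegistrySketch {

  /** Maps a Java type to the name of its default coder. */
  static final class SketchRegistry {
    private final Map<Class<?>, String> defaults = new HashMap<>();

    void register(Class<?> type, String coderName) {
      defaults.put(type, coderName);
    }

    String coderFor(Class<?> type) {
      String name = defaults.get(type);
      if (name == null) {
        throw new IllegalStateException("No default coder for " + type.getName());
      }
      return name;
    }
  }

  /** An explicit coder wins; otherwise the registry picks one from the declared type. */
  static String resolveCoder(String explicitCoder, Class<?> declaredType, SketchRegistry registry) {
    if (explicitCoder != null) {
      return explicitCoder;
    }
    if (declaredType != null) {
      return registry.coderFor(declaredType);
    }
    throw new IllegalStateException("Neither a coder nor a type was set");
  }

  public static void main(String[] args) {
    SketchRegistry registry = new SketchRegistry();
    registry.register(String.class, "StringUtf8Coder");

    // Only the type is declared: the registry chooses the coder.
    System.out.println(resolveCoder(null, String.class, registry));

    // A coder is set explicitly: the registry is not consulted.
    System.out.println(resolveCoder("CustomStringCoder", String.class, registry));
  }
}
```

In this model, setting both values adds nothing that the coder alone does not already carry, which matches the "redundant, in theory" remark.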

Contributor

I'm thinking about changes like: #14870

Member Author

I like @boyuanzz's fix because, even in the presence of different Coders, the TypeDescriptor is commonly preserved inside the Coders. WDYT, @kennknowles? Can you spot any particular issues with it?
I can rebase this PR to target a generic implementation like the one in #14870, but I did not do it that way because I was not really familiar with the reasoning behind not relying on the coder's typeDescriptor.

Member

Yeah, that makes a lot of sense.

}

/** Returns the {@code BoundedSource} used to create this {@code Read} {@code PTransform}. */
@@ -259,26 +259,26 @@ public void populateDisplayData(DisplayData.Builder builder) {
* <p>We model the element as the original source and the restriction as the sub-source. This
* allows us to split the sub-source over and over yet still receive "source" objects as inputs.
*/
- static class BoundedSourceAsSDFWrapperFn<T> extends DoFn<BoundedSource<T>, T> {
+ static class BoundedSourceAsSDFWrapperFn<T, BoundedSourceT extends BoundedSource<T>>
+ extends DoFn<BoundedSourceT, T> {
private static final Logger LOG = LoggerFactory.getLogger(BoundedSourceAsSDFWrapperFn.class);
private static final long DEFAULT_DESIRED_BUNDLE_SIZE_BYTES = 64 * (1 << 20);

@GetInitialRestriction
- public BoundedSource<T> initialRestriction(@Element BoundedSource<T> element) {
+ public BoundedSourceT initialRestriction(@Element BoundedSourceT element) {
return element;
}

@GetSize
- public double getSize(
- @Restriction BoundedSource<T> restriction, PipelineOptions pipelineOptions)
+ public double getSize(@Restriction BoundedSourceT restriction, PipelineOptions pipelineOptions)
throws Exception {
return restriction.getEstimatedSizeBytes(pipelineOptions);
}

@SplitRestriction
public void splitRestriction(
- @Restriction BoundedSource<T> restriction,
- OutputReceiver<BoundedSource<T>> receiver,
+ @Restriction BoundedSourceT restriction,
+ OutputReceiver<BoundedSourceT> receiver,
PipelineOptions pipelineOptions)
throws Exception {
long estimatedSize = restriction.getEstimatedSizeBytes(pipelineOptions);
@@ -288,20 +288,22 @@ public void splitRestriction(
Math.min(
DEFAULT_DESIRED_BUNDLE_SIZE_BYTES,
Math.max(1L, estimatedSize / DEFAULT_DESIRED_NUM_SPLITS));
- for (BoundedSource<T> split : restriction.split(splitBundleSize, pipelineOptions)) {
+ List<BoundedSourceT> splits =
+ (List<BoundedSourceT>) restriction.split(splitBundleSize, pipelineOptions);
+ for (BoundedSourceT split : splits) {
receiver.output(split);
}
}

@NewTracker
- public RestrictionTracker<BoundedSource<T>, TimestampedValue<T>[]> restrictionTracker(
- @Restriction BoundedSource<T> restriction, PipelineOptions pipelineOptions) {
+ public RestrictionTracker<BoundedSourceT, TimestampedValue<T>[]> restrictionTracker(
+ @Restriction BoundedSourceT restriction, PipelineOptions pipelineOptions) {
return new BoundedSourceAsSDFRestrictionTracker<>(restriction, pipelineOptions);
}

@ProcessElement
public void processElement(
- RestrictionTracker<BoundedSource<T>, TimestampedValue<T>[]> tracker,
+ RestrictionTracker<BoundedSourceT, TimestampedValue<T>[]> tracker,
OutputReceiver<T> receiver)
throws IOException {
TimestampedValue<T>[] out = new TimestampedValue[1];
@@ -311,23 +313,24 @@ public void processElement(
}

@GetRestrictionCoder
- public Coder<BoundedSource<T>> restrictionCoder() {
- return SerializableCoder.of(new TypeDescriptor<BoundedSource<T>>() {});
+ public Coder<BoundedSourceT> restrictionCoder() {
+ return SerializableCoder.of(new TypeDescriptor<BoundedSourceT>() {});
}

/**
* A fake restriction tracker which adapts to the {@link BoundedSource} API. The restriction
* object is used to advance the underlying source and to "return" the current element.
*/
- private static class BoundedSourceAsSDFRestrictionTracker<T>
- extends RestrictionTracker<BoundedSource<T>, TimestampedValue<T>[]> {
- private final BoundedSource<T> initialRestriction;
+ private static class BoundedSourceAsSDFRestrictionTracker<
+ BoundedSourceT extends BoundedSource<T>, T>
+ extends RestrictionTracker<BoundedSourceT, TimestampedValue<T>[]> {
+ private final BoundedSourceT initialRestriction;
private final PipelineOptions pipelineOptions;
private BoundedSource.BoundedReader<T> currentReader;
private boolean claimedAll;

BoundedSourceAsSDFRestrictionTracker(
- BoundedSource<T> initialRestriction, PipelineOptions pipelineOptions) {
+ BoundedSourceT initialRestriction, PipelineOptions pipelineOptions) {
this.initialRestriction = initialRestriction;
this.pipelineOptions = pipelineOptions;
}
@@ -393,15 +396,15 @@ protected void finalize() throws Throwable {

/** The value is invalid if {@link #tryClaim} has ever thrown an exception. */
@Override
- public BoundedSource<T> currentRestriction() {
+ public BoundedSourceT currentRestriction() {
if (currentReader == null) {
return initialRestriction;
}
- return currentReader.getCurrentSource();
+ return (BoundedSourceT) currentReader.getCurrentSource();
}

@Override
- public SplitResult<BoundedSource<T>> trySplit(double fractionOfRemainder) {
+ public SplitResult<BoundedSourceT> trySplit(double fractionOfRemainder) {
if (currentReader == null) {
return null;
}
@@ -416,7 +419,7 @@ public SplitResult<BoundedSource<T>> trySplit(double fractionOfRemainder) {
return null;
}
BoundedSource<T> primary = currentReader.getCurrentSource();
- return SplitResult.of(primary, residual);
+ return (SplitResult<BoundedSourceT>) SplitResult.of(primary, residual);
}

@Override
@@ -990,15 +993,16 @@ public Progress getProgress() {
}
}

- private static class OutputSingleSource<T extends HasDisplayData> extends DoFn<byte[], T> {
- private final T source;
+ private static class OutputSingleSource<T, SourceT extends Source<T>>
+ extends DoFn<byte[], SourceT> {
+ private final SourceT source;

- private OutputSingleSource(T source) {
+ private OutputSingleSource(SourceT source) {
this.source = source;
}

@ProcessElement
- public void processElement(OutputReceiver<T> outputReceiver) {
+ public void processElement(OutputReceiver<SourceT> outputReceiver) {
outputReceiver.output(source);
}

@@ -124,7 +124,13 @@ public void finishSpecifying(PInput input, PTransform<?, ?> transform) {
* override this to enable better {@code Coder} inference.
*/
public @Nullable TypeDescriptor<T> getTypeDescriptor() {
- return typeDescriptor;
+ if (typeDescriptor != null) {
+ return typeDescriptor;
+ }
+ if (coderOrFailure.coder != null) {
+ return coderOrFailure.coder.getEncodedTypeDescriptor();
+ }
+ return null;
}

/**
43 changes: 43 additions & 0 deletions sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java
@@ -21,6 +21,7 @@
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFor;
import static org.hamcrest.MatcherAssert.assertThat;
+ import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.io.InputStream;
@@ -41,6 +42,7 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CustomCoder;
+ import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.io.CountingSource.CounterMark;
@@ -63,6 +65,7 @@
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.PCollection;
+ import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.checkerframework.checker.nullness.qual.Nullable;
@@ -149,6 +152,17 @@ public void populateDisplayData(DisplayData.Builder builder) {
assertThat(unboundedDisplayData, hasDisplayItem("maxReadTime", maxReadTime));
}

+ @Test
+ public void testReadBoundedPreservesTypeDescriptor() {
+ PCollection<String> input = pipeline.apply(Read.from(new SerializableBoundedSource()));
+ TypeDescriptor<String> typeDescriptor = input.getTypeDescriptor();
+ assertEquals(String.class, typeDescriptor.getType());
+
+ ListBoundedSource<Long> longs = new ListBoundedSource<>(VarLongCoder.of());
+ PCollection<List<Long>> numbers = pipeline.apply(Read.from(longs));
+ assertEquals(new TypeDescriptor<List<Long>>() {}, numbers.getTypeDescriptor());
+ }
+
@Test
@Category({
NeedsRunner.class,
@@ -261,6 +275,35 @@ public Coder<String> getOutputCoder() {
}
}

+ private static class ListBoundedSource<T> extends BoundedSource<List<T>> {
+ private Coder<T> coder;
+
+ ListBoundedSource(Coder<T> coder) {
+ this.coder = coder;
+ }
+
+ @Override
+ public List<? extends BoundedSource<List<T>>> split(
+ long desiredBundleSizeBytes, PipelineOptions options) throws Exception {
+ return null;
+ }
+
+ @Override
+ public long getEstimatedSizeBytes(PipelineOptions options) throws Exception {
+ return 0;
+ }
+
+ @Override
+ public BoundedReader<List<T>> createReader(PipelineOptions options) throws IOException {
+ return null;
+ }
+
+ @Override
+ public Coder<List<T>> getOutputCoder() {
+ return ListCoder.of(coder);
+ }
+ }
+
private static class NotSerializableBoundedSource extends CustomBoundedSource {
@SuppressWarnings("unused")
private final NotSerializableClass notSerializableClass = new NotSerializableClass();