Skip to content

Commit

Permalink
Merge pull request #14854: [BEAM-12384] Set output typeDescriptor exp…
Browse files Browse the repository at this point in the history
…lictly in Read.Bounded transform
  • Loading branch information
iemejia authored May 28, 2021
2 parents 31988c8 + 83bccf9 commit b03e429
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 27 deletions.
56 changes: 30 additions & 26 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/Read.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.HasDisplayData;
import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.HasProgress;
Expand Down Expand Up @@ -151,7 +150,8 @@ public final PCollection<T> expand(PBegin input) {
.apply(ParDo.of(new OutputSingleSource<>(source)))
.setCoder(SerializableCoder.of(new TypeDescriptor<BoundedSource<T>>() {}))
.apply(ParDo.of(new BoundedSourceAsSDFWrapperFn<>()))
.setCoder(source.getOutputCoder());
.setCoder(source.getOutputCoder())
.setTypeDescriptor(source.getOutputCoder().getEncodedTypeDescriptor());
}

/** Returns the {@code BoundedSource} used to create this {@code Read} {@code PTransform}. */
Expand Down Expand Up @@ -259,26 +259,26 @@ public void populateDisplayData(DisplayData.Builder builder) {
* <p>We model the element as the original source and the restriction as the sub-source. This
* allows us to split the sub-source over and over yet still receive "source" objects as inputs.
*/
static class BoundedSourceAsSDFWrapperFn<T> extends DoFn<BoundedSource<T>, T> {
static class BoundedSourceAsSDFWrapperFn<T, BoundedSourceT extends BoundedSource<T>>
extends DoFn<BoundedSourceT, T> {
private static final Logger LOG = LoggerFactory.getLogger(BoundedSourceAsSDFWrapperFn.class);
private static final long DEFAULT_DESIRED_BUNDLE_SIZE_BYTES = 64 * (1 << 20);

@GetInitialRestriction
public BoundedSource<T> initialRestriction(@Element BoundedSource<T> element) {
public BoundedSourceT initialRestriction(@Element BoundedSourceT element) {
return element;
}

@GetSize
public double getSize(
@Restriction BoundedSource<T> restriction, PipelineOptions pipelineOptions)
public double getSize(@Restriction BoundedSourceT restriction, PipelineOptions pipelineOptions)
throws Exception {
return restriction.getEstimatedSizeBytes(pipelineOptions);
}

@SplitRestriction
public void splitRestriction(
@Restriction BoundedSource<T> restriction,
OutputReceiver<BoundedSource<T>> receiver,
@Restriction BoundedSourceT restriction,
OutputReceiver<BoundedSourceT> receiver,
PipelineOptions pipelineOptions)
throws Exception {
long estimatedSize = restriction.getEstimatedSizeBytes(pipelineOptions);
Expand All @@ -288,20 +288,22 @@ public void splitRestriction(
Math.min(
DEFAULT_DESIRED_BUNDLE_SIZE_BYTES,
Math.max(1L, estimatedSize / DEFAULT_DESIRED_NUM_SPLITS));
for (BoundedSource<T> split : restriction.split(splitBundleSize, pipelineOptions)) {
List<BoundedSourceT> splits =
(List<BoundedSourceT>) restriction.split(splitBundleSize, pipelineOptions);
for (BoundedSourceT split : splits) {
receiver.output(split);
}
}

@NewTracker
public RestrictionTracker<BoundedSource<T>, TimestampedValue<T>[]> restrictionTracker(
@Restriction BoundedSource<T> restriction, PipelineOptions pipelineOptions) {
public RestrictionTracker<BoundedSourceT, TimestampedValue<T>[]> restrictionTracker(
@Restriction BoundedSourceT restriction, PipelineOptions pipelineOptions) {
return new BoundedSourceAsSDFRestrictionTracker<>(restriction, pipelineOptions);
}

@ProcessElement
public void processElement(
RestrictionTracker<BoundedSource<T>, TimestampedValue<T>[]> tracker,
RestrictionTracker<BoundedSourceT, TimestampedValue<T>[]> tracker,
OutputReceiver<T> receiver)
throws IOException {
TimestampedValue<T>[] out = new TimestampedValue[1];
Expand All @@ -311,23 +313,24 @@ public void processElement(
}

@GetRestrictionCoder
public Coder<BoundedSource<T>> restrictionCoder() {
return SerializableCoder.of(new TypeDescriptor<BoundedSource<T>>() {});
public Coder<BoundedSourceT> restrictionCoder() {
return SerializableCoder.of(new TypeDescriptor<BoundedSourceT>() {});
}

/**
* A fake restriction tracker which adapts to the {@link BoundedSource} API. The restriction
* object is used to advance the underlying source and to "return" the current element.
*/
private static class BoundedSourceAsSDFRestrictionTracker<T>
extends RestrictionTracker<BoundedSource<T>, TimestampedValue<T>[]> {
private final BoundedSource<T> initialRestriction;
private static class BoundedSourceAsSDFRestrictionTracker<
BoundedSourceT extends BoundedSource<T>, T>
extends RestrictionTracker<BoundedSourceT, TimestampedValue<T>[]> {
private final BoundedSourceT initialRestriction;
private final PipelineOptions pipelineOptions;
private BoundedSource.BoundedReader<T> currentReader;
private boolean claimedAll;

BoundedSourceAsSDFRestrictionTracker(
BoundedSource<T> initialRestriction, PipelineOptions pipelineOptions) {
BoundedSourceT initialRestriction, PipelineOptions pipelineOptions) {
this.initialRestriction = initialRestriction;
this.pipelineOptions = pipelineOptions;
}
Expand Down Expand Up @@ -393,15 +396,15 @@ protected void finalize() throws Throwable {

/** The value is invalid if {@link #tryClaim} has ever thrown an exception. */
@Override
public BoundedSource<T> currentRestriction() {
public BoundedSourceT currentRestriction() {
if (currentReader == null) {
return initialRestriction;
}
return currentReader.getCurrentSource();
return (BoundedSourceT) currentReader.getCurrentSource();
}

@Override
public SplitResult<BoundedSource<T>> trySplit(double fractionOfRemainder) {
public SplitResult<BoundedSourceT> trySplit(double fractionOfRemainder) {
if (currentReader == null) {
return null;
}
Expand All @@ -416,7 +419,7 @@ public SplitResult<BoundedSource<T>> trySplit(double fractionOfRemainder) {
return null;
}
BoundedSource<T> primary = currentReader.getCurrentSource();
return SplitResult.of(primary, residual);
return (SplitResult<BoundedSourceT>) SplitResult.of(primary, residual);
}

@Override
Expand Down Expand Up @@ -990,15 +993,16 @@ public Progress getProgress() {
}
}

private static class OutputSingleSource<T extends HasDisplayData> extends DoFn<byte[], T> {
private final T source;
private static class OutputSingleSource<T, SourceT extends Source<T>>
extends DoFn<byte[], SourceT> {
private final SourceT source;

private OutputSingleSource(T source) {
private OutputSingleSource(SourceT source) {
this.source = source;
}

@ProcessElement
public void processElement(OutputReceiver<T> outputReceiver) {
public void processElement(OutputReceiver<SourceT> outputReceiver) {
outputReceiver.output(source);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,13 @@ public void finishSpecifying(PInput input, PTransform<?, ?> transform) {
* override this to enable better {@code Coder} inference.
*/
public @Nullable TypeDescriptor<T> getTypeDescriptor() {
return typeDescriptor;
if (typeDescriptor != null) {
return typeDescriptor;
}
if (coderOrFailure.coder != null) {
return coderOrFailure.coder.getEncodedTypeDescriptor();
}
return null;
}

/**
Expand Down
43 changes: 43 additions & 0 deletions sdks/java/core/src/test/java/org/apache/beam/sdk/io/ReadTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.includesDisplayDataFor;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;

import java.io.IOException;
import java.io.InputStream;
Expand All @@ -41,6 +42,7 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.CustomCoder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarLongCoder;
import org.apache.beam.sdk.io.CountingSource.CounterMark;
Expand All @@ -63,6 +65,7 @@
import org.apache.beam.sdk.transforms.windowing.GlobalWindows;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.checkerframework.checker.nullness.qual.Nullable;
Expand Down Expand Up @@ -149,6 +152,17 @@ public void populateDisplayData(DisplayData.Builder builder) {
assertThat(unboundedDisplayData, hasDisplayItem("maxReadTime", maxReadTime));
}

@Test
public void testReadBoundedPreservesTypeDescriptor() {
PCollection<String> input = pipeline.apply(Read.from(new SerializableBoundedSource()));
TypeDescriptor<String> typeDescriptor = input.getTypeDescriptor();
assertEquals(String.class, typeDescriptor.getType());

ListBoundedSource<Long> longs = new ListBoundedSource<>(VarLongCoder.of());
PCollection<List<Long>> numbers = pipeline.apply(Read.from(longs));
assertEquals(new TypeDescriptor<List<Long>>() {}, numbers.getTypeDescriptor());
}

@Test
@Category({
NeedsRunner.class,
Expand Down Expand Up @@ -261,6 +275,35 @@ public Coder<String> getOutputCoder() {
}
}

private static class ListBoundedSource<T> extends BoundedSource<List<T>> {
private Coder<T> coder;

ListBoundedSource(Coder<T> coder) {
this.coder = coder;
}

@Override
public List<? extends BoundedSource<List<T>>> split(
long desiredBundleSizeBytes, PipelineOptions options) throws Exception {
return null;
}

@Override
public long getEstimatedSizeBytes(PipelineOptions options) throws Exception {
return 0;
}

@Override
public BoundedReader<List<T>> createReader(PipelineOptions options) throws IOException {
return null;
}

@Override
public Coder<List<T>> getOutputCoder() {
return ListCoder.of(coder);
}
}

private static class NotSerializableBoundedSource extends CustomBoundedSource {
@SuppressWarnings("unused")
private final NotSerializableClass notSerializableClass = new NotSerializableClass();
Expand Down

0 comments on commit b03e429

Please sign in to comment.