diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java index 60bd4a9dc2a2..c5fe782c36e3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/WithKeys.java @@ -28,6 +28,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; /** * {@code WithKeys} takes a {@code PCollection}, and either a constant key of type {@code @@ -74,17 +75,20 @@ public static WithKeys of(SerializableFunction fn) { */ @SuppressWarnings("unchecked") public static WithKeys of(@Nullable final K key) { - return new WithKeys<>(value -> key, (Class) (key == null ? Void.class : key.getClass())); + return new WithKeys<>( + value -> key, + (TypeDescriptor) + (key == null ? TypeDescriptors.voids() : TypeDescriptor.of(key.getClass()))); } ///////////////////////////////////////////////////////////////////////////// private SerializableFunction fn; - @CheckForNull private transient Class keyClass; + @CheckForNull private transient TypeDescriptor keyType; - private WithKeys(SerializableFunction fn, Class keyClass) { + private WithKeys(SerializableFunction fn, TypeDescriptor keyType) { this.fn = fn; - this.keyClass = keyClass; + this.keyType = keyType; } /** @@ -95,10 +99,7 @@ private WithKeys(SerializableFunction fn, Class keyClass) { * PCollection}. */ public WithKeys withKeyType(TypeDescriptor keyType) { - // Safe cast - @SuppressWarnings("unchecked") - Class rawType = (Class) keyType.getRawType(); - return new WithKeys<>(fn, rawType); + return new WithKeys<>(fn, keyType); } @Override @@ -117,10 +118,10 @@ public KV apply(V element) { try { Coder keyCoder; CoderRegistry coderRegistry = in.getPipeline().getCoderRegistry(); - if (keyClass == null) { + if (keyType == null) { keyCoder = coderRegistry.getOutputCoder(fn, in.getCoder()); } else { - keyCoder = coderRegistry.getCoder(TypeDescriptor.of(keyClass)); + keyCoder = coderRegistry.getCoder(keyType); } // TODO: Remove when we can set the coder inference context. result.setCoder(KvCoder.of(keyCoder, in.getCoder())); diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java index 1baa3d475c42..5a8da194f4e7 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/WithKeysTest.java @@ -20,6 +20,7 @@ import static org.junit.Assert.assertEquals; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.testing.NeedsRunner; @@ -28,6 +29,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; import org.junit.Rule; import org.junit.Test; import org.junit.experimental.categories.Category; @@ -142,6 +144,24 @@ public void withLambdaAndTypeDescriptorShouldSucceed() { p.run(); } + @Test + @Category(NeedsRunner.class) + public void withLambdaAndParameterizedTypeDescriptorShouldSucceed() { + + PCollection values = p.apply(Create.of("1234", "3210")); + PCollection, String>> kvs = + values.apply( + WithKeys.of((SerializableFunction>) Collections::singletonList) + .withKeyType(TypeDescriptors.lists(TypeDescriptors.strings()))); + + PAssert.that(kvs) + .containsInAnyOrder( + KV.of(Collections.singletonList("1234"), "1234"), + KV.of(Collections.singletonList("3210"), "3210")); + + p.run(); + } + @Test @Category(NeedsRunner.class) public void withLambdaAndNoTypeDescriptorShouldThrow() {