Skip to content

Commit

Permalink
[BEAM-9464] Fix WithKeys to respect parameterized types
Browse files Browse the repository at this point in the history
  • Loading branch information
lukecwik committed Mar 6, 2020
1 parent 28d05d3 commit b5301d9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.TypeDescriptors;

/**
* {@code WithKeys<K, V>} takes a {@code PCollection<V>}, and either a constant key of type {@code
Expand Down Expand Up @@ -74,17 +75,20 @@ public static <K, V> WithKeys<K, V> of(SerializableFunction<V, K> fn) {
*/
@SuppressWarnings("unchecked")
public static <K, V> WithKeys<K, V> of(@Nullable final K key) {
// NOTE(review): this span is a diff rendering whose +/- markers were lost. The single-line
// return directly below and the multi-line return after it are two versions of the same
// statement; presumably the Class<K>-based one is the version this commit removes — confirm
// against the commit before treating either as current.
return new WithKeys<>(value -> key, (Class<K>) (key == null ? Void.class : key.getClass()));
// Every input value is mapped to the same constant key. The key type is
// TypeDescriptors.voids() when the key is null, otherwise a descriptor built from the key's
// runtime class. The cast to TypeDescriptor<K> is unchecked, which is what the
// @SuppressWarnings("unchecked") on this method covers.
return new WithKeys<>(
value -> key,
(TypeDescriptor<K>)
(key == null ? TypeDescriptors.voids() : TypeDescriptor.of(key.getClass())));
}

/////////////////////////////////////////////////////////////////////////////

// Function that produces the key (K) for a given value (V).
private SerializableFunction<V, K> fn;
// NOTE(review): diff rendering without +/- markers. The two field declarations below are the
// two versions of the same field; presumably keyClass (Class<K>) is replaced by keyType
// (TypeDescriptor<K>) — confirm against the commit.
@CheckForNull private transient Class<K> keyClass;
// Declared type of the key; may be null, in which case expand() asks the coder registry to
// infer the key coder from fn instead.
@CheckForNull private transient TypeDescriptor<K> keyType;

// Private constructor; instances are created through the static of(...) factories and
// withKeyType(...). The two signature lines below are the two versions of the same constructor
// header (see the note on the fields above), as are the two key-type assignments.
private WithKeys(SerializableFunction<V, K> fn, Class<K> keyClass) {
private WithKeys(SerializableFunction<V, K> fn, TypeDescriptor<K> keyType) {
this.fn = fn;
this.keyClass = keyClass;
this.keyType = keyType;
}

/**
Expand All @@ -95,10 +99,7 @@ private WithKeys(SerializableFunction<V, K> fn, Class<K> keyClass) {
* PCollection}.
*/
public WithKeys<K, V> withKeyType(TypeDescriptor<K> keyType) {
// NOTE(review): diff rendering without +/- markers. The raw-type lines and the first return
// below are one version of this body, and the final return is the other; presumably the
// raw-type version is the one being removed — confirm against the commit.
// Reducing the descriptor to its raw Class discards any type arguments (a Class<K> cannot
// carry them), so a key type such as List<String> would be seen only as List.
// Safe cast
@SuppressWarnings("unchecked")
Class<K> rawType = (Class<K>) keyType.getRawType();
return new WithKeys<>(fn, rawType);
// Returns a new WithKeys that reuses the same key function and carries the full
// TypeDescriptor, type parameters included.
return new WithKeys<>(fn, keyType);
}

@Override
Expand All @@ -117,10 +118,10 @@ public KV<K, V> apply(V element) {
try {
Coder<K> keyCoder;
CoderRegistry coderRegistry = in.getPipeline().getCoderRegistry();
if (keyClass == null) {
if (keyType == null) {
keyCoder = coderRegistry.getOutputCoder(fn, in.getCoder());
} else {
keyCoder = coderRegistry.getCoder(TypeDescriptor.of(keyClass));
keyCoder = coderRegistry.getCoder(keyType);
}
// TODO: Remove when we can set the coder inference context.
result.setCoder(KvCoder.of(keyCoder, in.getCoder()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.junit.Assert.assertEquals;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.testing.NeedsRunner;
Expand All @@ -28,6 +29,7 @@
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
Expand Down Expand Up @@ -142,6 +144,24 @@ public void withLambdaAndTypeDescriptorShouldSucceed() {
p.run();
}

@Test
@Category(NeedsRunner.class)
public void withLambdaAndParameterizedTypeDescriptorShouldSucceed() {
  // Two string inputs; each one is keyed by a singleton list containing itself.
  PCollection<String> input = p.apply(Create.of("1234", "3210"));

  // The key function is a method reference, so the parameterized key type List<String> is
  // declared explicitly through withKeyType.
  SerializableFunction<String, List<String>> toSingletonList = Collections::singletonList;
  TypeDescriptor<List<String>> listOfStrings = TypeDescriptors.lists(TypeDescriptors.strings());

  PCollection<KV<List<String>, String>> keyed =
      input.apply(WithKeys.of(toSingletonList).withKeyType(listOfStrings));

  PAssert.that(keyed)
      .containsInAnyOrder(
          KV.of(Collections.singletonList("1234"), "1234"),
          KV.of(Collections.singletonList("3210"), "3210"));

  p.run();
}

@Test
@Category(NeedsRunner.class)
public void withLambdaAndNoTypeDescriptorShouldThrow() {
Expand Down

0 comments on commit b5301d9

Please sign in to comment.