From 38d9641aaf81522e57481d1b2005e7a556277ef9 Mon Sep 17 00:00:00 2001 From: Moritz Mack Date: Fri, 4 Mar 2022 20:11:51 +0100 Subject: [PATCH] [BEAM-13416] Introduce Schema provider for AWS models and deprecate low level coders --- .../java/io/amazon-web-services2/build.gradle | 1 + .../beam/sdk/io/aws2/coders/AwsCoders.java | 9 +- .../io/aws2/schemas/AwsBuilderFactory.java | 35 ++ .../io/aws2/schemas/AwsSchemaProvider.java | 219 +++++++++++ .../AwsSchemaRegistrar.java} | 19 +- .../sdk/io/aws2/schemas/AwsSchemaUtils.java | 130 +++++++ .../beam/sdk/io/aws2/schemas/AwsTypes.java | 297 ++++++++++++++ .../sdk/io/aws2/schemas/package-info.java | 24 ++ .../io/aws2/sns/PublishResponseCoders.java | 7 +- .../aws2/sns/SnsCoderProviderRegistrar.java | 39 -- .../apache/beam/sdk/io/aws2/sns/SnsIO.java | 19 +- .../sdk/io/aws2/sns/SnsResponseCoder.java | 7 +- .../beam/sdk/io/aws2/sqs/MessageCoder.java | 50 --- .../io/aws2/sqs/SendMessageRequestCoder.java | 51 --- .../aws2/schemas/AwsSchemaProviderTest.java | 361 ++++++++++++++++++ .../io/aws2/schemas/AwsSchemaUtilsTest.java} | 30 +- .../beam/sdk/io/aws2/schemas/Sample.java | 341 +++++++++++++++++ .../beam/sdk/io/aws2/sns/SnsIOTest.java | 2 + 18 files changed, 1467 insertions(+), 174 deletions(-) create mode 100644 sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsBuilderFactory.java create mode 100644 sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java rename sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/{sqs/MessageCoderRegistrar.java => schemas/AwsSchemaRegistrar.java} (59%) create mode 100644 sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java create mode 100644 sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java create mode 100644 sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/package-info.java delete mode 100644 sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsCoderProviderRegistrar.java delete mode 100644 sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoder.java delete mode 100644 sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoder.java create mode 100644 sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProviderTest.java rename sdks/java/io/amazon-web-services2/src/{main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoderRegistrar.java => test/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtilsTest.java} (51%) create mode 100644 sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/schemas/Sample.java diff --git a/sdks/java/io/amazon-web-services2/build.gradle b/sdks/java/io/amazon-web-services2/build.gradle index 84672c408484..1c5d3dc82683 100644 --- a/sdks/java/io/amazon-web-services2/build.gradle +++ b/sdks/java/io/amazon-web-services2/build.gradle @@ -54,6 +54,7 @@ dependencies { implementation "software.amazon.kinesis:amazon-kinesis-client:2.3.4", excludeNetty implementation library.java.netty_all // force version of netty used by Beam permitUnusedDeclared library.java.netty_all + implementation library.java.byte_buddy implementation library.java.jackson_core implementation library.java.jackson_annotations implementation library.java.jackson_databind diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/coders/AwsCoders.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/coders/AwsCoders.java index 9797618301ce..a15818875c2f 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/coders/AwsCoders.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/coders/AwsCoders.java @@ -37,7 +37,14 @@ import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.utils.ImmutableMap; -/** {@link Coder}s for common AWS SDK objects. */ +/** + * {@link Coder}s for common AWS SDK objects. + * + * @deprecated {@link org.apache.beam.sdk.schemas.SchemaCoder SchemaCoders} for {@link + * software.amazon.awssdk.core.SdkPojo AWS model classes} will be automatically inferred by + * means of {@link org.apache.beam.sdk.io.aws2.schemas.AwsSchemaProvider AwsSchemaProvider}. + */ +@Deprecated public final class AwsCoders { private AwsCoders() {} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsBuilderFactory.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsBuilderFactory.java new file mode 100644 index 000000000000..53aa10ad6093 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsBuilderFactory.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.schemas; + +import java.io.Serializable; +import java.util.List; +import software.amazon.awssdk.core.SdkField; +import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.utils.builder.SdkBuilder; + +/** Builder factory for AWS {@link SdkPojo} to avoid using reflection to instantiate a builder. */ +public abstract class AwsBuilderFactory< + PojoT extends SdkPojo, BuilderT extends SdkBuilder & SdkPojo> + implements Serializable { + protected List> sdkFields() { + return get().sdkFields(); + } + + protected abstract BuilderT get(); +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java new file mode 100644 index 000000000000..7f8182e6136a --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProvider.java @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.schemas; + +import static java.util.function.Function.identity; +import static java.util.stream.Collectors.toMap; +import static org.apache.beam.sdk.io.aws2.schemas.AwsSchemaUtils.getter; +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets.difference; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets.newHashSet; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Objects; +import java.util.Set; +import java.util.function.BiConsumer; +import org.apache.beam.sdk.io.aws2.schemas.AwsSchemaUtils.SdkBuilderSetter; +import org.apache.beam.sdk.io.aws2.schemas.AwsTypes.ConverterFactory; +import org.apache.beam.sdk.schemas.CachingFactory; +import org.apache.beam.sdk.schemas.Factory; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.FieldValueTypeInformation; +import org.apache.beam.sdk.schemas.GetterBasedSchemaProvider; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.SchemaUserTypeCreator; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.RowWithGetters; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.checkerframework.checker.nullness.qual.Nullable; +import software.amazon.awssdk.core.SdkField; +import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.utils.builder.SdkBuilder; + +/** + * Schema provider for AWS {@link SdkPojo} models using the provided field metadata (@see {@link + * SdkPojo#sdkFields()}) rather than reflection. + * + *

Note: Beam doesn't support self-referential schemas. Some AWS models are not compatible with + * schemas for that reason and require a dedicated coder, such as {@link + * software.amazon.awssdk.services.dynamodb.model.AttributeValue DynamoDB AttributeValue} ({@link + * org.apache.beam.sdk.io.aws2.dynamodb.AttributeValueCoder coder}). + */ +public class AwsSchemaProvider extends GetterBasedSchemaProvider { + /** Byte-code generated {@link SdkBuilder} factories. */ + @SuppressWarnings("rawtypes") // Crashes checker otherwise + private static final Map FACTORIES = Maps.newConcurrentMap(); + + @Override + public @Nullable Schema schemaFor(TypeDescriptor type) { + if (!SdkPojo.class.isAssignableFrom(type.getRawType())) { + return null; + } + return AwsTypes.schemaFor(sdkFields((Class) type.getRawType())); + } + + @SuppressWarnings("rawtypes") + @Override + public List fieldValueGetters(Class clazz, Schema schema) { + ConverterFactory fromAws = ConverterFactory.fromAws(); + Map> sdkFields = sdkFieldsByName((Class) clazz); + List getters = new ArrayList<>(schema.getFieldCount()); + for (String field : schema.getFieldNames()) { + SdkField sdkField = checkStateNotNull(sdkFields.get(field), "Unknown field"); + getters.add(getter(field, fromAws.create(sdkField::getValueOrDefault, sdkField))); + } + return getters; + } + + // Overriding `fromRowFunction` to instead use the generated builder factories with SDK provided + // setters from `SdkField`s. + @Override + public SerializableFunction fromRowFunction(TypeDescriptor type) { + checkState(SdkPojo.class.isAssignableFrom(type.getRawType()), "Unsupported type %s", type); + return FromRowFactory.create(type.getRawType()); + } + + private static class FromRowWithBuilder + implements SerializableFunction { + private final Class cls; + private final Factory> factory; + + FromRowWithBuilder(Class cls, Factory> factory) { + this.cls = cls; + this.factory = factory; + } + + @Override + @SuppressWarnings("nullness") // checker doesn't recognize the builder type + public T apply(Row row) { + if (row instanceof RowWithGetters) { + Object target = ((RowWithGetters) row).getGetterTarget(); + if (target.getClass().equals(cls)) { + return (T) target; // simply extract the underlying object instead of creating a new one. + } + } + SdkBuilder builder = sdkBuilder(cls); + List setters = factory.create(cls, row.getSchema()); + for (SdkBuilderSetter set : setters) { + if (!row.getSchema().hasField(set.name())) { + continue; + } + set.set(builder, row.getValue(set.name())); + } + return builder.build(); + } + + @Override + public boolean equals(@Nullable Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FromRowWithBuilder that = (FromRowWithBuilder) o; + return cls.equals(that.cls); + } + + @Override + public int hashCode() { + return Objects.hash(cls); + } + } + + private static class FromRowFactory implements Factory> { + @SuppressWarnings("nullness") // circular initialization + private final Factory> cachingFactory = new CachingFactory<>(this); + + private final Factory> settersFactory = + new CachingFactory<>(new SettersFactory()); + + @SuppressWarnings("nullness") // schema nullable for this factory + static SerializableFunction create(Class clazz) { + checkState(SdkPojo.class.isAssignableFrom(clazz), "Unsupported clazz %s", clazz); + return (SerializableFunction) new FromRowFactory().cachingFactory.create(clazz, null); + } + + @Override + public SerializableFunction create(Class clazz, Schema ignored) { + return new FromRowWithBuilder<>((Class) clazz, settersFactory); + } + + private class SettersFactory implements Factory> { + private final ConverterFactory toAws; + + private SettersFactory() { + this.toAws = ConverterFactory.toAws(cachingFactory); + } + + @Override + public List create(Class clazz, Schema schema) { + Map> fields = sdkFieldsByName((Class) clazz); + checkForUnknownFields(schema, fields); + + List setters = new ArrayList<>(schema.getFieldCount()); + for (Entry> entry : fields.entrySet()) { + SdkField sdkField = entry.getValue(); + BiConsumer, Object> setter = + toAws.needsConversion(sdkField) + ? ConverterFactory.createSetter(sdkField::set, toAws.create(sdkField)) + : sdkField::set; + setters.add(AwsSchemaUtils.setter(entry.getKey(), setter)); + } + return setters; + } + } + + private void checkForUnknownFields(Schema schema, Map> fields) { + Set unknowns = difference(newHashSet(schema.getFieldNames()), fields.keySet()); + checkState(unknowns.isEmpty(), "Row schema contains unknown fields: %s", unknowns); + } + } + + @Override + public List fieldValueTypeInformations(Class cls, Schema schema) { + throw new UnsupportedOperationException("FieldValueTypeInformation not available"); + } + + @Override + public SchemaUserTypeCreator schemaTypeCreator(Class cls, Schema schema) { + throw new UnsupportedOperationException("SchemaUserTypeCreator not available"); + } + + private static AwsBuilderFactory builderFactory(Class cls) { + return FACTORIES.computeIfAbsent(cls, c -> AwsSchemaUtils.builderFactory(cls)); + } + + private static List> sdkFields(Class cls) { + return builderFactory(cls).sdkFields(); + } + + private static SdkBuilder sdkBuilder(Class cls) { + return builderFactory(cls).get(); + } + + private static Map> sdkFieldsByName(Class cls) { + return sdkFields(cls).stream().collect(toMap(AwsTypes::normalizedNameOf, identity())); + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoderRegistrar.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaRegistrar.java similarity index 59% rename from sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoderRegistrar.java rename to sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaRegistrar.java index 0b72338dc141..131a62eac46c 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoderRegistrar.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaRegistrar.java @@ -15,23 +15,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.io.aws2.sqs; +package org.apache.beam.sdk.io.aws2.schemas; import com.google.auto.service.AutoService; import java.util.List; -import org.apache.beam.sdk.coders.CoderProvider; -import org.apache.beam.sdk.coders.CoderProviderRegistrar; -import org.apache.beam.sdk.coders.CoderProviders; -import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.schemas.SchemaProvider; +import org.apache.beam.sdk.schemas.SchemaProviderRegistrar; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; -import software.amazon.awssdk.services.sqs.model.Message; -/** A {@link CoderProviderRegistrar} for standard types used with {@link SqsIO}. */ -@AutoService(CoderProviderRegistrar.class) -public class MessageCoderRegistrar implements CoderProviderRegistrar { +@AutoService(SchemaProviderRegistrar.class) +public class AwsSchemaRegistrar implements SchemaProviderRegistrar { @Override - public List getCoderProviders() { - return ImmutableList.of( - CoderProviders.forCoder(TypeDescriptor.of(Message.class), MessageCoder.of())); + public List getSchemaProviders() { + return ImmutableList.of(new AwsSchemaProvider()); } } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java new file mode 100644 index 000000000000..7aa23335b583 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtils.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.schemas; + +import static net.bytebuddy.matcher.ElementMatchers.isStatic; +import static net.bytebuddy.matcher.ElementMatchers.named; + +import java.util.function.BiConsumer; +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.description.type.TypeDescription.ForLoadedType; +import net.bytebuddy.description.type.TypeDescription.Generic; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.implementation.MethodCall; +import org.apache.beam.sdk.schemas.FieldValueGetter; +import org.apache.beam.sdk.schemas.FieldValueSetter; +import org.apache.beam.sdk.schemas.utils.ByteBuddyUtils; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.checkerframework.checker.nullness.qual.Nullable; +import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.utils.builder.SdkBuilder; + +class AwsSchemaUtils { + private static final ByteBuddy BYTE_BUDDY = new ByteBuddy(); + private static final TypeDescription FACTORY_TYPE = ForLoadedType.of(AwsBuilderFactory.class); + + private AwsSchemaUtils() {} + + /** + * Generate an efficient implementation of {@link AwsBuilderFactory} for the given {@code clazz} + * in byte code avoiding reflective access. + */ + static & SdkPojo> + AwsBuilderFactory builderFactory(Class clazz) { + + Generic pojoType = new ForLoadedType(clazz).asGenericType(); + MethodDescription builderMethod = + pojoType.getDeclaredMethods().filter(named("builder").and(isStatic())).getOnly(); + Generic providerType = + Generic.Builder.parameterizedType(FACTORY_TYPE, pojoType, builderMethod.getReturnType()) + .build(); + + try { + return (AwsBuilderFactory) + BYTE_BUDDY + .with(new ByteBuddyUtils.InjectPackageStrategy(clazz)) + .subclass(providerType) + .method(named("get")) + .intercept(MethodCall.invoke(builderMethod)) + .make() + .load(ReflectHelpers.findClassLoader(), ClassLoadingStrategy.Default.INJECTION) + .getLoaded() + .getDeclaredConstructor() + .newInstance(); + } catch (ReflectiveOperationException e) { + throw new RuntimeException("Unable to generate builder factory for " + clazz, e); + } + } + + static SdkBuilderSetter setter(String name, BiConsumer, Object> setter) { + return new ValueSetter(name, setter); + } + + static FieldValueGetter getter( + String name, SerializableFunction getter) { + return new ValueGetter<>(name, getter); + } + + interface SdkBuilderSetter extends FieldValueSetter, Object> {} + + private static class ValueSetter implements SdkBuilderSetter { + private final BiConsumer, Object> setter; + private final String name; + + ValueSetter(String name, BiConsumer, Object> setter) { + this.name = name; + this.setter = setter; + } + + @Override + public void set(SdkBuilder builder, @Nullable Object value) { + if (value != null) { + setter.accept(builder, value); // don't call setter if value is absent + } + } + + @Override + public String name() { + return name; + } + } + + private static class ValueGetter implements FieldValueGetter { + private final SerializableFunction getter; + private final String name; + + ValueGetter(String name, SerializableFunction getter) { + this.name = name; + this.getter = getter; + } + + @Override + @Nullable + public ValT get(ObjT object) { + return getter.apply(object); + } + + @Override + public String name() { + return name; + } + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java new file mode 100644 index 000000000000..229ad43854d5 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/AwsTypes.java @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.schemas; + +import static java.util.Collections.singleton; +import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; +import static software.amazon.awssdk.core.protocol.MarshallingType.INSTANT; +import static software.amazon.awssdk.core.protocol.MarshallingType.LIST; +import static software.amazon.awssdk.core.protocol.MarshallingType.MAP; +import static software.amazon.awssdk.core.protocol.MarshallingType.SDK_BYTES; +import static software.amazon.awssdk.core.protocol.MarshallingType.SDK_POJO; + +import java.io.Serializable; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; +import org.apache.beam.sdk.schemas.Factory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Ascii; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; +import org.joda.time.Instant; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.SdkField; +import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.core.protocol.MarshallingType; +import software.amazon.awssdk.core.traits.ListTrait; +import software.amazon.awssdk.core.traits.MapTrait; +import software.amazon.awssdk.core.util.DefaultSdkAutoConstructList; +import software.amazon.awssdk.core.util.DefaultSdkAutoConstructMap; +import software.amazon.awssdk.utils.ImmutableMap; + +public class AwsTypes { + // Mapping of simple AWS types to schema field types + private static final Map, FieldType> typeMapping = + ImmutableMap., FieldType>builder() + .put(MarshallingType.STRING, FieldType.STRING) + .put(MarshallingType.SHORT, FieldType.INT16) + .put(MarshallingType.INTEGER, FieldType.INT32) + .put(MarshallingType.LONG, FieldType.INT64) + .put(MarshallingType.FLOAT, FieldType.FLOAT) + .put(MarshallingType.DOUBLE, FieldType.DOUBLE) + .put(MarshallingType.BIG_DECIMAL, FieldType.DECIMAL) + .put(MarshallingType.BOOLEAN, FieldType.BOOLEAN) + .put(INSTANT, FieldType.DATETIME) + .put(SDK_BYTES, FieldType.BYTES) + .build(); + + private static FieldType fieldType(SdkField field, Set> seen) { + MarshallingType type = field.marshallingType(); + if (type == LIST) { + return FieldType.array(fieldType(elementField(field), seen)); + } else if (type == MAP) { + return FieldType.map(FieldType.STRING, fieldType(valueField(field), seen)); + } else if (type == SDK_POJO) { + SdkPojo builder = field.constructor().get(); + Class clazz = targetClassOf(builder); + checkState(!seen.contains(clazz), "Self-recursive types are not supported: %s", clazz); + return FieldType.row(schemaFor(builder.sdkFields(), Sets.union(seen, singleton(clazz)))); + } + FieldType fieldType = typeMapping.get(type); + if (fieldType != null) { + return fieldType; + } + throw new RuntimeException( + String.format("Type %s of field %s is unknown.", type, normalizedNameOf(field))); + } + + private static Schema schemaFor(List> fields, Set> seen) { + Schema.Builder builder = Schema.builder(); + for (SdkField sdkField : fields) { + // AWS SDK fields are all optional and marked as nullable + builder.addField(Field.nullable(normalizedNameOf(sdkField), fieldType(sdkField, seen))); + } + return builder.build(); + } + + static Schema schemaFor(List> fields) { + return schemaFor(fields, ImmutableSet.of()); + } + + /** + * Converter factory to handle specific AWS types. + * + *

Any occurrences of {@link java.time.Instant} or {@link SdkBytes} are converted to & from the + * corresponding Beam types. When used with {@link org.apache.beam.sdk.schemas.FieldValueSetter}, + * any {@link Row} has to be converted back to the respective {@link SdkPojo}. + */ + @SuppressWarnings("rawtypes") + abstract static class ConverterFactory implements Serializable { + @SuppressWarnings("nullness") + private static final SerializableFunction IDENTITY = x -> x; + + private final SerializableFunction instantConverter; + private final SerializableFunction bytesConverter; + private final boolean convertPojoType; + + private ConverterFactory( + SerializableFunction instantConverter, + SerializableFunction bytesConverter, + boolean convertPojoType) { + this.instantConverter = instantConverter; + this.bytesConverter = bytesConverter; + this.convertPojoType = convertPojoType; + } + + static ConverterFactory toAws(Factory> fromRowFactory) { + return new ToAws(fromRowFactory); + } + + static ConverterFactory fromAws() { + return FromAws.INSTANCE; + } + + static BiConsumer createSetter( + BiConsumer set, SerializableFunction fn) { + return (obj, value) -> set.accept(obj, ((SerializableFunction) fn).apply(value)); + } + + SerializableFunction pojoTypeConverter(SdkField field) { + throw new UnsupportedOperationException(); + } + + SerializableFunction create(SdkField field) { + return create(IDENTITY, field); + } + + SerializableFunction create(SerializableFunction fn, SdkField field) { + MarshallingType awsType = field.marshallingType(); + SerializableFunction converter; + if (awsType == SDK_POJO) { + converter = pojoTypeConverter(field); + } else if (awsType == INSTANT) { + converter = instantConverter; + } else if (awsType == SDK_BYTES) { + converter = bytesConverter; + } else if (awsType == LIST) { + converter = transformList(create(elementField(field))); + } else if (awsType == MAP) { + converter = transformMap(create(valueField(field))); + } else { + throw new IllegalStateException("Unexpected marshalling type " + awsType); + } + return fn != IDENTITY ? andThen(fn, nullSafe(converter)) : nullSafe(converter); + } + + boolean needsConversion(SdkField field) { + MarshallingType type = field.marshallingType(); + return (convertPojoType && type.equals(MarshallingType.SDK_POJO)) + || type.equals(INSTANT) + || type.equals(SDK_BYTES) + || (type.equals(MAP) && needsConversion(valueField(field))) + || (type.equals(LIST) && needsConversion(elementField(field))); + } + + private static SerializableFunction andThen( + SerializableFunction fn1, SerializableFunction fn2) { + return v -> fn2.apply(fn1.apply(v)); + } + + @SuppressWarnings("nullness") + private static SerializableFunction nullSafe(SerializableFunction fn) { + return v -> v == null ? null : fn.apply(v); + } + + @SuppressWarnings("nullness") + private static SerializableFunction transformList(SerializableFunction fn) { + return list -> Lists.transform((List) list, fn::apply); + } + + @SuppressWarnings("nullness") + private static SerializableFunction transformMap(SerializableFunction fn) { + return map -> Maps.transformValues((Map) map, fn::apply); + } + + /** Converter factory from Beam row value types to AWS types. This is applicable for setters. */ + private static class ToAws extends ConverterFactory { + private final Factory> fromRowFactory; + + ToAws(Factory> fromRowFactory) { + super(AwsTypes::toJavaInstant, AwsTypes::toSdkBytes, true); + this.fromRowFactory = fromRowFactory; + } + + @Override + @SuppressWarnings("nullness") // schema nullable for this factory + protected SerializableFunction pojoTypeConverter(SdkField field) { + return fromRowFactory.create(targetClassOf(field.constructor().get()), null); + } + } + + /** + * Converter factory from AWS types to Beam raw unmodified row types. This is applicable for + * getters and also removes default values for lists & maps to avoid serializing those. + */ + private static class FromAws extends ConverterFactory { + private static final ConverterFactory INSTANCE = new FromAws(); + + FromAws() { + super(AwsTypes::toJodaInstant, AwsTypes::toBytes, false); + } + + @Override + SerializableFunction create(SerializableFunction fn, SdkField field) { + MarshallingType type = field.marshallingType(); + if (type.equals(MAP)) { + fn = skipDefaultMap(fn); + } else if (type.equals(LIST)) { + fn = skipDefaultList(fn); + } + return needsConversion(field) ? super.create(fn, field) : fn; + } + + @SuppressWarnings("nullness") + private static SerializableFunction skipDefaultList(SerializableFunction fn) { + return in -> { + Object list = fn.apply(in); + return list != DefaultSdkAutoConstructList.getInstance() ? list : null; + }; + } + + @SuppressWarnings("nullness") + private static SerializableFunction skipDefaultMap(SerializableFunction fn) { + return in -> { + Object map = fn.apply(in); + return map != DefaultSdkAutoConstructMap.getInstance() ? map : null; + }; + } + } + } + + // Convert upper camel SDK field names to lower camel + static String normalizedNameOf(SdkField field) { + String name = field.memberName(); + return name.length() > 1 && Ascii.isLowerCase(name.charAt(1)) + ? Ascii.toLowerCase(name.charAt(0)) + name.substring(1) + : name.toLowerCase(Locale.ROOT); + } + + static java.time.Instant toJavaInstant(Object instant) { + return java.time.Instant.ofEpochMilli(((Instant) instant).getMillis()); + } + + private static Instant toJodaInstant(Object instant) { + return Instant.ofEpochMilli(((java.time.Instant) instant).toEpochMilli()); + } + + private static SdkBytes toSdkBytes(Object sdkBytes) { + // Unsafe operation, wrapping bytes directly as done by core for byte arrays / buffers + return SdkBytes.fromByteArrayUnsafe((byte[]) sdkBytes); + } + + private static byte[] toBytes(Object sdkBytes) { + // Unsafe operation, exposing bytes directly as done by core for byte arrays / buffers + return ((SdkBytes) sdkBytes).asByteArrayUnsafe(); + } + + private static SdkField elementField(SdkField field) { + return field.getTrait(ListTrait.class).memberFieldInfo(); + } + + private static SdkField valueField(SdkField field) { + return field.getTrait(MapTrait.class).valueFieldInfo(); + } + + private static Class targetClassOf(SdkPojo builder) { + // the declaring class is the class this builder produces + return checkArgumentNotNull( + builder.getClass().getDeclaringClass(), + "Expected nested builder class, but got %s", + builder.getClass()); + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/package-info.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/package-info.java new file mode 100644 index 000000000000..1d6465f5cb85 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/schemas/package-info.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Schemas for AWS model classes. */ +@Experimental(Kind.SCHEMAS) +package org.apache.beam.sdk.io.aws2.schemas; + +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/PublishResponseCoders.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/PublishResponseCoders.java index 016c5db3a4dc..57acd46f6151 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/PublishResponseCoders.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/PublishResponseCoders.java @@ -34,7 +34,12 @@ import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.services.sns.model.PublishResponse; -/** Coders for SNS {@link PublishResponse}. */ +/** + * Coders for SNS {@link PublishResponse}. + * + * @deprecated Schema based coder is inferred automatically. + */ +@Deprecated public class PublishResponseCoders { private static final Coder MESSAGE_ID_CODER = StringUtf8Coder.of(); private static final NullableCoder METADATA_CODER = diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsCoderProviderRegistrar.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsCoderProviderRegistrar.java deleted file mode 100644 index 6ce0d1e5660f..000000000000 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsCoderProviderRegistrar.java +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws2.sns; - -import com.google.auto.service.AutoService; -import java.util.List; -import org.apache.beam.sdk.coders.CoderProvider; -import org.apache.beam.sdk.coders.CoderProviderRegistrar; -import org.apache.beam.sdk.coders.CoderProviders; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; -import software.amazon.awssdk.services.sns.model.PublishResponse; - -/** A {@link CoderProviderRegistrar} for standard types used with {@link SnsIO}. */ -@AutoService(CoderProviderRegistrar.class) -public class SnsCoderProviderRegistrar implements CoderProviderRegistrar { - @Override - public List getCoderProviders() { - return ImmutableList.of( - CoderProviders.forCoder( - TypeDescriptor.of(PublishResponse.class), - PublishResponseCoders.defaultPublishResponse())); - } -} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsIO.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsIO.java index ee56200d9cd0..c3019dfaa297 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsIO.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsIO.java @@ -32,6 +32,7 @@ import org.apache.beam.sdk.io.aws2.common.ClientBuilderFactory; import org.apache.beam.sdk.io.aws2.common.ClientConfiguration; import org.apache.beam.sdk.io.aws2.options.AwsOptions; +import org.apache.beam.sdk.io.aws2.schemas.AwsSchemaProvider; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.PipelineOptions; @@ -281,22 +282,34 @@ public Write withRetryConfiguration(RetryConfiguration retry) { } /** - * Encode the full {@code PublishResult} object, including sdkResponseMetadata and + * Encode the full {@link PublishResponse} object, including sdkResponseMetadata and * sdkHttpMetadata with the HTTP response headers. + * + * @deprecated Writes fail exceptionally in case of errors, there is no need to check headers. */ + @Deprecated public Write withFullPublishResponse() { return withCoder(PublishResponseCoders.fullPublishResponse()); } /** - * Encode the full {@code PublishResult} object, including sdkResponseMetadata and + * Encode the full {@link PublishResponse} object, including sdkResponseMetadata and * sdkHttpMetadata but excluding the HTTP response headers. + * + * @deprecated Writes fail exceptionally in case of errors, there is no need to check headers. */ + @Deprecated public Write withFullPublishResponseWithoutHeaders() { return withCoder(PublishResponseCoders.fullPublishResponseWithoutHeaders()); } - /** Encode the {@code PublishResult} with the given coder. */ + /** + * Encode the {@link PublishResponse} with the given coder. + * + * @deprecated Explicit usage of coders is deprecated. Inferred schemas provided by {@link + * AwsSchemaProvider} will be used instead. + */ + @Deprecated public Write withCoder(Coder coder) { return builder().setCoder(coder).build(); } diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsResponseCoder.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsResponseCoder.java index 3fe979c6fe32..88d40f74edca 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsResponseCoder.java +++ b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sns/SnsResponseCoder.java @@ -30,7 +30,12 @@ import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; -/** Custom Coder for WrappedSnsResponse. */ +/** + * Custom Coder for WrappedSnsResponse. + * + * @deprecated Coder of deprecated {@link SnsResponse}. + */ +@Deprecated @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoder.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoder.java deleted file mode 100644 index 4cf5da34a0fd..000000000000 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/MessageCoder.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws2.sqs; - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.io.Serializable; -import org.apache.beam.sdk.coders.AtomicCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import software.amazon.awssdk.services.sqs.model.Message; - -/** Custom Coder for handling SendMessageRequest for using in Write. */ -public class MessageCoder extends AtomicCoder implements Serializable { - private static final MessageCoder INSTANCE = new MessageCoder(); - - private MessageCoder() {} - - static MessageCoder of() { - return INSTANCE; - } - - @Override - public void encode(Message value, OutputStream outStream) throws IOException { - StringUtf8Coder.of().encode(value.messageId(), outStream); - StringUtf8Coder.of().encode(value.body(), outStream); - } - - @Override - public Message decode(InputStream inStream) throws IOException { - final String messageId = StringUtf8Coder.of().decode(inStream); - final String body = StringUtf8Coder.of().decode(inStream); - return Message.builder().messageId(messageId).body(body).build(); - } -} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoder.java b/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoder.java deleted file mode 100644 index e8c0283317bc..000000000000 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoder.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.sdk.io.aws2.sqs; - -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.io.Serializable; -import org.apache.beam.sdk.coders.AtomicCoder; -import org.apache.beam.sdk.coders.StringUtf8Coder; -import software.amazon.awssdk.services.sqs.model.SendMessageRequest; - -/** Custom Coder for handling SendMessageRequest for using in Write. */ -public class SendMessageRequestCoder extends AtomicCoder - implements Serializable { - private static final SendMessageRequestCoder INSTANCE = new SendMessageRequestCoder(); - - private SendMessageRequestCoder() {} - - static SendMessageRequestCoder of() { - return INSTANCE; - } - - @Override - public void encode(SendMessageRequest value, OutputStream outStream) throws IOException { - StringUtf8Coder.of().encode(value.queueUrl(), outStream); - StringUtf8Coder.of().encode(value.messageBody(), outStream); - } - - @Override - public SendMessageRequest decode(InputStream inStream) throws IOException { - final String queueUrl = StringUtf8Coder.of().decode(inStream); - final String message = StringUtf8Coder.of().decode(inStream); - return SendMessageRequest.builder().queueUrl(queueUrl).messageBody(message).build(); - } -} diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProviderTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProviderTest.java new file mode 100644 index 000000000000..71767c915fe7 --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaProviderTest.java @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.schemas; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.apache.beam.sdk.schemas.Schema.TypeName.ARRAY; +import static org.apache.beam.sdk.util.CoderUtils.decodeFromByteArray; +import static org.apache.beam.sdk.util.CoderUtils.encodeToByteArray; +import static org.apache.beam.sdk.util.SerializableUtils.ensureSerializableRoundTrip; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.math.BigDecimal; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.List; +import java.util.function.Function; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.schemas.utils.SchemaTestUtils; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.assertj.core.api.Condition; +import org.assertj.core.api.SoftAssertions; +import org.junit.Test; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.dynamodb.model.AttributeValue; +import software.amazon.awssdk.services.sqs.model.MessageAttributeValue; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; + +public class AwsSchemaProviderTest { + private final SchemaRegistry registry = SchemaRegistry.createDefault(); + private final Instant now = Instant.now().truncatedTo(ChronoUnit.MINUTES); + + private interface Schemas { + Schema MESSAGE_ATTRIBUTES = + Schema.builder() + .addNullableField("stringValue", FieldType.STRING) + .addNullableField("binaryValue", FieldType.BYTES) + .addNullableField("stringListValues", FieldType.array(FieldType.STRING)) + .addNullableField("binaryListValues", FieldType.array(FieldType.BYTES)) + .addNullableField("dataType", FieldType.STRING) + .build(); + + Schema SEND_MESSAGE_REQUEST = + Schema.builder() + .addNullableField("queueUrl", FieldType.STRING) + .addNullableField("messageBody", FieldType.STRING) + .addNullableField("delaySeconds", FieldType.INT32) + .addNullableField( + "messageAttributes", + FieldType.map(FieldType.STRING, FieldType.row(MESSAGE_ATTRIBUTES))) + .addNullableField( + "messageSystemAttributes", + FieldType.map(FieldType.STRING, FieldType.row(MESSAGE_ATTRIBUTES))) + .addNullableField("messageDeduplicationId", FieldType.STRING) + .addNullableField("messageGroupId", FieldType.STRING) + .build(); + + Schema SAMPLE = + Schema.builder() + .addNullableField("string", FieldType.STRING) + .addNullableField("short", FieldType.INT16) + .addNullableField("integer", FieldType.INT32) + .addNullableField("long", FieldType.INT64) + .addNullableField("float", FieldType.FLOAT) + .addNullableField("double", FieldType.DOUBLE) + .addNullableField("decimal", FieldType.DECIMAL) + .addNullableField("boolean", FieldType.BOOLEAN) + .addNullableField("instant", FieldType.DATETIME) + .addNullableField("bytes", FieldType.BYTES) + .addNullableField("list", FieldType.array(FieldType.STRING)) + .addNullableField("map", FieldType.map(FieldType.STRING, FieldType.STRING)) + .build(); + } + + @Test + public void testSampleSchema() throws NoSuchSchemaException { + Schema schema = registry.getSchema(Sample.class); + SchemaTestUtils.assertSchemaEquivalent(Schemas.SAMPLE, schema); + } + + @Test + public void testAwsExampleSchema() throws NoSuchSchemaException { + Schema schema = registry.getSchema(SendMessageRequest.class); + SchemaTestUtils.assertSchemaEquivalent(Schemas.SEND_MESSAGE_REQUEST, schema); + } + + @Test + public void testRecursiveSchema() { + assertThatThrownBy(() -> registry.getSchema(AttributeValue.class)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Self-recursive types are not supported: " + AttributeValue.class); + } + + @Test + public void testToRowSerializable() throws NoSuchSchemaException { + ensureSerializableRoundTrip(registry.getToRowFunction(Sample.class)); + ensureSerializableRoundTrip(registry.getToRowFunction(SendMessageRequest.class)); + } + + @Test + public void testFromRowSerializable() throws NoSuchSchemaException { + ensureSerializableRoundTrip(registry.getFromRowFunction(Sample.class)); + ensureSerializableRoundTrip(registry.getFromRowFunction(SendMessageRequest.class)); + } + + @Test + public void testSampleToRow() throws NoSuchSchemaException { + Sample sample = + Sample.builder() + .stringField("0") + .shortField((short) 1) + .integerField(2) + .longField(3L) + .floatField((float) 4.0) + .doubleField(5.0) + .decimalField(BigDecimal.valueOf(6L)) + .instantField(now) + .bytesField(SdkBytes.fromUtf8String("7")) + .listField(ImmutableList.of("8")) + .mapField(ImmutableMap.of("9", "9")) + .build(); + + Row row = registry.getToRowFunction(Sample.class).apply(sample); + + assertThat(row) + .has(field("string", "0")) + .has(field("short", (short) 1)) + .has(field("integer", 2)) + .has(field("long", 3L)) + .has(field("float", (float) 4.0)) + .has(field("double", 5.0)) + .has(field("decimal", BigDecimal.valueOf(6L))) + .has(field("instant", org.joda.time.Instant.ofEpochMilli(now.toEpochMilli()))) + .has(field("bytes", "7".getBytes(UTF_8))) + .has(field("list", ImmutableList.of("8"))) + .has(field("map", ImmutableMap.of("9", "9"))); + } + + @Test + public void testAwsExampleToRow() throws NoSuchSchemaException { + SendMessageRequest request = + SendMessageRequest.builder() + .queueUrl("queue") + .messageBody("body") + .delaySeconds(100) + .messageDeduplicationId("dedupId") + .messageGroupId("groupId") + .messageAttributes( + ImmutableMap.of( + "string", + attribute(b -> b.stringValue("v").dataType("String")), + "binary", + attribute(b -> b.binaryValue(sdkBytes("v")).dataType("Binary")), + "stringList", + attribute(b -> b.stringListValues("v1", "v2")), + "binaryList", + attribute(b -> b.binaryListValues(sdkBytes("v1"), sdkBytes("v2"))))) + .build(); + + Row row = registry.getToRowFunction(SendMessageRequest.class).apply(request); + + assertThat(row) + .has(field("queueUrl", "queue")) + .has(field("messageBody", "body")) + .has(field("delaySeconds", 100)) + .has(field("messageDeduplicationId", "dedupId")) + .has(field("messageGroupId", "groupId")); + + assertThat((Row) row.getMap("messageAttributes").get("string")) + .has(field("dataType", "String")) + .has(field("stringValue", "v")) + .has(field("binaryValue", null)) + .has(field("stringListValues", null)) + .has(field("binaryListValues", null)); + + assertThat((Row) row.getMap("messageAttributes").get("binary")) + .has(field("dataType", "Binary")) + .has(field("stringValue", null)) + .has(field("binaryValue", bytes("v"))) + .has(field("stringListValues", null)) + .has(field("binaryListValues", null)); + + assertThat((Row) row.getMap("messageAttributes").get("stringList")) + .has(field("dataType", null)) + .has(field("stringValue", null)) + .has(field("binaryValue", null)) + .has(field("stringListValues", ImmutableList.of("v1", "v2"))) + .has(field("binaryListValues", null)); + + assertThat((Row) row.getMap("messageAttributes").get("binaryList")) + .has(field("dataType", null)) + .has(field("stringValue", null)) + .has(field("binaryValue", null)) + .has(field("stringListValues", null)) + .has(field("binaryListValues", ImmutableList.of(bytes("v1"), bytes("v2")))); + } + + @Test + public void testSampleFromRow() throws NoSuchSchemaException, CoderException { + Sample sample = + Sample.builder() + .stringField("0") + .shortField((short) 1) + .integerField(2) + .longField(3L) + .floatField((float) 4.0) + .doubleField(5.0) + .decimalField(BigDecimal.valueOf(6L)) + .instantField(now) + .bytesField(SdkBytes.fromUtf8String("7")) + .listField(ImmutableList.of("8")) + .mapField(ImmutableMap.of("9", "9")) + .build(); + + SchemaCoder coder = registry.getSchemaCoder(Sample.class); + + Row row = coder.getToRowFunction().apply(sample); + assertThat(coder.getFromRowFunction().apply(row)).isEqualTo(sample); + + byte[] sampleBytes = encodeToByteArray(coder, sample); + Sample sampleFromBytes = decodeFromByteArray(coder, sampleBytes); + assertThat(sampleFromBytes).isEqualTo(sample); + + // verify still serializable after use + ensureSerializableRoundTrip(coder.getToRowFunction()); + ensureSerializableRoundTrip(coder.getFromRowFunction()); + } + + @Test + public void testAwsExampleFromRow() throws NoSuchSchemaException, CoderException { + SendMessageRequest request = + SendMessageRequest.builder() + .queueUrl("queue") + .messageBody("body") + .delaySeconds(100) + .messageDeduplicationId("dedupId") + .messageGroupId("groupId") + .messageAttributes( + ImmutableMap.of( + "string", + attribute(b -> b.stringValue("v").dataType("String")), + "binary", + attribute(b -> b.binaryValue(sdkBytes("v")).dataType("Binary")), + "stringList", + attribute(b -> b.stringListValues("v1", "v2")), + "binaryList", + attribute(b -> b.binaryListValues(sdkBytes("v1"), sdkBytes("v2"))))) + .build(); + + SchemaCoder coder = registry.getSchemaCoder(SendMessageRequest.class); + + Row row = coder.getToRowFunction().apply(request); + assertThat(coder.getFromRowFunction().apply(row)).isEqualTo(request); + + byte[] requestBytes = encodeToByteArray(coder, request); + SendMessageRequest requestFromBytes = decodeFromByteArray(coder, requestBytes); + assertThat(requestFromBytes).isEqualTo(request); + + // verify still serializable after use + ensureSerializableRoundTrip(coder.getToRowFunction()); + ensureSerializableRoundTrip(coder.getFromRowFunction()); + } + + @Test + public void testFromRowWithPartialSchema() throws NoSuchSchemaException { + SerializableFunction fromRow = + registry.getFromRowFunction(SendMessageRequest.class); + + Schema partialSchema = + Schema.builder() + .addNullableField("queueUrl", FieldType.STRING) + .addNullableField("messageBody", FieldType.STRING) + .addNullableField("delaySeconds", FieldType.INT32) + .build(); + + SendMessageRequest request = + SendMessageRequest.builder() + .queueUrl("queue") + .messageBody("body") + .delaySeconds(100) + .build(); + + Row row = Row.withSchema(partialSchema).addValues("queue", "body", 100).build(); + + assertThat(fromRow.apply(row)).isEqualTo(request); + } + + @Test + public void testFailFromRowOnUnknownField() throws NoSuchSchemaException { + SerializableFunction fromRow = + registry.getFromRowFunction(SendMessageRequest.class); + + Schema partialSchema = + Schema.builder() + .addNullableField("queueUrl", FieldType.STRING) + .addNullableField("messageBody", FieldType.STRING) + .addNullableField("unknownField", FieldType.INT32) + .build(); + + Row row = Row.withSchema(partialSchema).addValues("queue", "body", 100).build(); + + assertThatThrownBy(() -> fromRow.apply(row)) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Row schema contains unknown fields: [unknownField]"); + } + + private static MessageAttributeValue attribute( + Function b) { + return b.apply(MessageAttributeValue.builder()).build(); + } + + private static SdkBytes sdkBytes(String str) { + return SdkBytes.fromByteArrayUnsafe(bytes(str)); + } + + private static byte[] bytes(String str) { + return str.getBytes(UTF_8); + } + + private Condition field(String name, T value) { + return new Condition<>( + row -> { + FieldType type = row.getSchema().getField(name).getType(); + SoftAssertions soft = new SoftAssertions(); + Object actual = row.getValue(name); + if (type.getTypeName() == ARRAY && value != null && value instanceof List) { + soft.assertThat((List) actual).containsExactlyElementsOf((List) value); + } else { + soft.assertThat((T) actual).isEqualTo(value); + } + soft.errorsCollected().forEach(System.out::println); + return soft.errorsCollected().isEmpty(); + }, + "field %s of: %s", + name, + value); + } +} diff --git a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoderRegistrar.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtilsTest.java similarity index 51% rename from sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoderRegistrar.java rename to sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtilsTest.java index 814f1f34ad87..82ccecffcc48 100644 --- a/sdks/java/io/amazon-web-services2/src/main/java/org/apache/beam/sdk/io/aws2/sqs/SendMessageRequestCoderRegistrar.java +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/schemas/AwsSchemaUtilsTest.java @@ -15,24 +15,22 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.beam.sdk.io.aws2.sqs; +package org.apache.beam.sdk.io.aws2.schemas; -import com.google.auto.service.AutoService; -import java.util.List; -import org.apache.beam.sdk.coders.CoderProvider; -import org.apache.beam.sdk.coders.CoderProviderRegistrar; -import org.apache.beam.sdk.coders.CoderProviders; -import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.Test; import software.amazon.awssdk.services.sqs.model.SendMessageRequest; -/** A {@link CoderProviderRegistrar} for standard types used with {@link SqsIO}. */ -@AutoService(CoderProviderRegistrar.class) -public class SendMessageRequestCoderRegistrar implements CoderProviderRegistrar { - @Override - public List getCoderProviders() { - return ImmutableList.of( - CoderProviders.forCoder( - TypeDescriptor.of(SendMessageRequest.class), SendMessageRequestCoder.of())); +public class AwsSchemaUtilsTest { + + @Test + public void generateBuilderFactory() { + AwsBuilderFactory factory = + AwsSchemaUtils.builderFactory(SendMessageRequest.class); + + assertThat(factory.getClass().getPackage()).isEqualTo(SendMessageRequest.class.getPackage()); + assertThat(factory.get()).isInstanceOf(SendMessageRequest.Builder.class); + assertThat(factory.sdkFields()).isEqualTo(SendMessageRequest.builder().sdkFields()); } } diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/schemas/Sample.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/schemas/Sample.java new file mode 100644 index 000000000000..cb8cf16f63cd --- /dev/null +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/schemas/Sample.java @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.aws2.schemas; + +import static software.amazon.awssdk.core.protocol.MarshallLocation.PAYLOAD; +import static software.amazon.awssdk.core.protocol.MarshallingType.BIG_DECIMAL; +import static software.amazon.awssdk.core.protocol.MarshallingType.BOOLEAN; +import static software.amazon.awssdk.core.protocol.MarshallingType.DOUBLE; +import static software.amazon.awssdk.core.protocol.MarshallingType.FLOAT; +import static software.amazon.awssdk.core.protocol.MarshallingType.INSTANT; +import static software.amazon.awssdk.core.protocol.MarshallingType.INTEGER; +import static software.amazon.awssdk.core.protocol.MarshallingType.LIST; +import static software.amazon.awssdk.core.protocol.MarshallingType.LONG; +import static software.amazon.awssdk.core.protocol.MarshallingType.MAP; +import static software.amazon.awssdk.core.protocol.MarshallingType.SDK_BYTES; +import static software.amazon.awssdk.core.protocol.MarshallingType.SHORT; +import static software.amazon.awssdk.core.protocol.MarshallingType.STRING; + +import java.io.Serializable; +import java.math.BigDecimal; +import java.time.Instant; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Function; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.builder.EqualsBuilder; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.builder.HashCodeBuilder; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.commons.lang3.builder.ReflectionToStringBuilder; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.SdkField; +import software.amazon.awssdk.core.SdkPojo; +import software.amazon.awssdk.core.protocol.MarshallingType; +import software.amazon.awssdk.core.traits.ListTrait; +import software.amazon.awssdk.core.traits.LocationTrait; +import software.amazon.awssdk.core.traits.MapTrait; +import software.amazon.awssdk.utils.builder.SdkBuilder; + +/** + * Sample AWS SDK pojo with all supported types. Nested objects as well as collections of objects + * are tested using "real" AWS model classes. + */ +public class Sample implements SdkPojo, Serializable { + private static SdkField.Builder sdkField(MarshallingType type) { + return SdkField.builder(type).traits(LocationTrait.builder().location(PAYLOAD).build()); + } + + private static SdkField.Builder sdkField( + MarshallingType type, + String name, + Function g, + BiConsumer s) { + return sdkField((MarshallingType) type) + .memberName(name) + .getter(obj -> g.apply((Sample) obj)) + .setter((obj, val) -> s.accept((Builder) obj, val)); + } + + private static final SdkField STRING_F = + sdkField(STRING, "String", Sample::stringField, Builder::stringField).build(); + + private static final SdkField SHORT_F = + sdkField(SHORT, "Short", Sample::shortField, Builder::shortField).build(); + + private static final SdkField INTEGER_F = + sdkField(INTEGER, "Integer", Sample::integerField, Builder::integerField).build(); + + private static final SdkField LONG_F = + sdkField(LONG, "Long", Sample::longField, Builder::longField).build(); + + private static final SdkField FLOAT_F = + sdkField(FLOAT, "Float", Sample::floatField, Builder::floatField).build(); + + private static final SdkField DOUBLE_F = + sdkField(DOUBLE, "Double", Sample::doubleField, Builder::doubleField).build(); + + private static final SdkField DECIMAL_F = + sdkField(BIG_DECIMAL, "Decimal", Sample::decimalField, Builder::decimalField).build(); + + private static final SdkField BOOLEAN_F = + sdkField(BOOLEAN, "Boolean", Sample::booleanField, Builder::booleanField).build(); + + private static final SdkField INSTANT_F = + sdkField(INSTANT, "Instant", Sample::instantField, Builder::instantField).build(); + + private static final SdkField BYTES_F = + sdkField(SDK_BYTES, "Bytes", Sample::bytesField, Builder::bytesField).build(); + + private static final SdkField> LIST_F = + sdkField(LIST, "List", Sample::listField, Builder::listField) + .traits(ListTrait.builder().memberFieldInfo(sdkField(STRING).build()).build()) + .build(); + + private static final SdkField> MAP_F = + sdkField(MAP, "Map", Sample::mapField, Builder::mapField) + .traits(MapTrait.builder().valueFieldInfo(sdkField(STRING).build()).build()) + .build(); + + private static final List> SDK_FIELDS = + ImmutableList.of( + STRING_F, SHORT_F, INTEGER_F, LONG_F, FLOAT_F, DOUBLE_F, DECIMAL_F, BOOLEAN_F, INSTANT_F, + BYTES_F, LIST_F, MAP_F); + + private final String stringField; + private final Short shortField; + private final Integer integerField; + private final Long longField; + private final Float floatField; + private final Double doubleField; + private final BigDecimal decimalField; + private final Boolean booleanField; + private final Instant instantField; + private final SdkBytes bytesField; + private final List listField; + private final Map mapField; + + private Sample(BuilderImpl builder) { + this.stringField = builder.stringField; + this.shortField = builder.shortField; + this.integerField = builder.integerField; + this.longField = builder.longField; + this.floatField = builder.floatField; + this.doubleField = builder.doubleField; + this.decimalField = builder.decimalField; + this.booleanField = builder.booleanField; + this.instantField = builder.instantField; + this.bytesField = builder.bytesField; + this.listField = builder.listField; + this.mapField = builder.mapField; + } + + public final String stringField() { + return stringField; + } + + public final Short shortField() { + return shortField; + } + + public final Integer integerField() { + return integerField; + } + + public final Long longField() { + return longField; + } + + public final Float floatField() { + return floatField; + } + + public final Double doubleField() { + return doubleField; + } + + public final BigDecimal decimalField() { + return decimalField; + } + + public final Boolean booleanField() { + return booleanField; + } + + public final Instant instantField() { + return instantField; + } + + public final SdkBytes bytesField() { + return bytesField; + } + + public final List listField() { + return listField; + } + + public final Map mapField() { + return mapField; + } + + public static Builder builder() { + return new BuilderImpl(); + } + + @Override + public final int hashCode() { + return HashCodeBuilder.reflectionHashCode(this); + } + + @Override + public final boolean equals(Object obj) { + return EqualsBuilder.reflectionEquals(this, obj); + } + + @Override + public final String toString() { + return ReflectionToStringBuilder.toString(this); + } + + @Override + public final List> sdkFields() { + return SDK_FIELDS; + } + + public interface Builder extends SdkPojo, SdkBuilder { + Builder stringField(String stringField); + + Builder shortField(Short shortField); + + Builder integerField(Integer integerField); + + Builder longField(Long longField); + + Builder floatField(Float floatField); + + Builder doubleField(Double doubleField); + + Builder decimalField(BigDecimal decimalField); + + Builder booleanField(Boolean booleanField); + + Builder instantField(Instant instantField); + + Builder bytesField(SdkBytes bytesField); + + Builder listField(List listField); + + Builder mapField(Map mapField); + } + + static final class BuilderImpl implements Builder { + private String stringField; + private Short shortField; + private Integer integerField; + private Long longField; + private Float floatField; + private Double doubleField; + private BigDecimal decimalField; + private Boolean booleanField; + private Instant instantField; + private SdkBytes bytesField; + private List listField; + private Map mapField; + + @Override + public Builder stringField(String stringField) { + this.stringField = stringField; + return this; + } + + @Override + public Builder shortField(Short shortField) { + this.shortField = shortField; + return this; + } + + @Override + public Builder integerField(Integer integerField) { + this.integerField = integerField; + return this; + } + + @Override + public Builder longField(Long longField) { + this.longField = longField; + return this; + } + + @Override + public Builder floatField(Float floatField) { + this.floatField = floatField; + return this; + } + + @Override + public Builder doubleField(Double doubleField) { + this.doubleField = doubleField; + return this; + } + + @Override + public Builder decimalField(BigDecimal decimalField) { + this.decimalField = decimalField; + return this; + } + + @Override + public Builder booleanField(Boolean booleanField) { + this.booleanField = booleanField; + return this; + } + + @Override + public Builder instantField(Instant instantField) { + this.instantField = instantField; + return this; + } + + @Override + public Builder bytesField(SdkBytes bytesField) { + this.bytesField = bytesField; + return this; + } + + @Override + public Builder listField(List listField) { + this.listField = listField; + return this; + } + + @Override + public Builder mapField(Map mapField) { + this.mapField = mapField; + return this; + } + + @Override + public Sample build() { + return new Sample(this); + } + + @Override + public List> sdkFields() { + return SDK_FIELDS; + } + } +} diff --git a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sns/SnsIOTest.java b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sns/SnsIOTest.java index 236c025d6260..c61aa94c016c 100644 --- a/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sns/SnsIOTest.java +++ b/sdks/java/io/amazon-web-services2/src/test/java/org/apache/beam/sdk/io/aws2/sns/SnsIOTest.java @@ -43,6 +43,7 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.junit.Before; @@ -104,6 +105,7 @@ private void failOnTopicValidation(Function, Write> fn) { public void testSkipTopicValidation() { PCollection input = mock(PCollection.class); when(input.getPipeline()).thenReturn(p); + when(input.apply(any(PTransform.class))).thenReturn(mock(PCollection.class)); Write snsWrite = SnsIO.write().withPublishRequestBuilder(msg -> requestBuilder(msg, topicArn));