Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose client-level context params on service client configuration #4834

Merged
merged 7 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/next-release/feature-AWSSDKforJavav2-7624538.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"type": "feature",
"category": "AWS SDK for Java v2",
"contributor": "anirudh9391",
"description": "Allowing SDK plugins to read and modify S3's crossRegionEnabled and SQS's checksumValidationEnabled"
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ private MethodSpec buildClientMethod() {
builder.addStatement("$1T client = new $2T(clientConfiguration)",
clientInterfaceName, clientClassName);
if (model.asyncClientDecoratorClassName().isPresent()) {
builder.addStatement("return new $T().decorate(client, clientConfiguration, clientContextParams.copy().build())",
builder.addStatement("return new $T().decorate(client, clientConfiguration)",
PoetUtils.classNameFromFqcn(model.asyncClientDecoratorClassName().get()));
} else {
builder.addStatement("return client");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ private MethodSpec buildClientMethod() {
builder.addStatement("$1T client = new $2T(clientConfiguration)",
clientInterfaceName, clientClassName);
if (model.syncClientDecoratorClassName().isPresent()) {
builder.addStatement("return new $T().decorate(client, clientConfiguration, clientContextParams.copy().build())",
builder.addStatement("return new $T().decorate(client, clientConfiguration)",
PoetUtils.classNameFromFqcn(model.syncClientDecoratorClassName().get()));
} else {
builder.addStatement("return client");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;
import java.util.stream.Stream;
Expand All @@ -47,6 +49,7 @@
import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel;
import software.amazon.awssdk.codegen.model.intermediate.OperationModel;
import software.amazon.awssdk.codegen.model.intermediate.Protocol;
import software.amazon.awssdk.codegen.model.service.ClientContextParam;
import software.amazon.awssdk.codegen.poet.PoetExtension;
import software.amazon.awssdk.codegen.poet.PoetUtils;
import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils;
Expand All @@ -56,6 +59,7 @@
import software.amazon.awssdk.codegen.poet.client.specs.QueryProtocolSpec;
import software.amazon.awssdk.codegen.poet.client.specs.XmlProtocolSpec;
import software.amazon.awssdk.codegen.poet.model.ServiceClientConfigurationUtils;
import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils;
import software.amazon.awssdk.core.RequestOverrideConfiguration;
import software.amazon.awssdk.core.SdkPlugin;
import software.amazon.awssdk.core.SdkRequest;
Expand All @@ -69,8 +73,11 @@
import software.amazon.awssdk.metrics.MetricCollector;
import software.amazon.awssdk.metrics.MetricPublisher;
import software.amazon.awssdk.metrics.NoOpMetricCollector;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.CompletableFutureUtils;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;

public class SyncClientClass extends SyncClientInterface {

Expand Down Expand Up @@ -418,7 +425,7 @@ protected MethodSpec.Builder waiterOperationBody(MethodSpec.Builder builder) {
poetExtensions.getSyncWaiterInterface());
}

protected static MethodSpec updateSdkClientConfigurationMethod(
protected MethodSpec updateSdkClientConfigurationMethod(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to enable the method to access the model instance. I am unsure as to why a code gen method is static in the first place.

TypeName serviceClientConfigurationBuilderClassName,
boolean shouldAddClientReference) {
MethodSpec.Builder builder = MethodSpec.methodBuilder("updateSdkClientConfiguration")
Expand All @@ -442,9 +449,34 @@ protected static MethodSpec updateSdkClientConfigurationMethod(
.addStatement("$1T serviceConfigBuilder = new $1T(configuration)", serviceClientConfigurationBuilderClassName)
.beginControlFlow("for ($T plugin : plugins)", SdkPlugin.class)
.addStatement("plugin.configureClient(serviceConfigBuilder)")
.endControlFlow()
.addStatement("return configuration.build()");
.endControlFlow();
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I only modified the Sync class and did not modify the Async class as it would suffice to incorporate the "request-level" modification check into the DefaultServiceClient Client

EndpointRulesSpecUtils endpointRulesSpecUtils = new EndpointRulesSpecUtils(this.model);

if (model.getCustomizationConfig() == null ||
CollectionUtils.isNullOrEmpty(model.getCustomizationConfig().getCustomClientContextParams())) {
builder.addStatement("return configuration.build()");
return builder.build();
}

Map<String, ClientContextParam> customClientConfigParams = model.getCustomizationConfig().getCustomClientContextParams();

builder.addCode("$1T newContextParams = configuration.option($2T.CLIENT_CONTEXT_PARAMS);\n"
+ "$1T originalContextParams = clientConfiguration.option($2T.CLIENT_CONTEXT_PARAMS);",
AttributeMap.class, SdkClientOption.class);

builder.addCode("newContextParams = (newContextParams != null) ? newContextParams : $1T.empty();\n"
+ "originalContextParams = originalContextParams != null ? originalContextParams : $1T.empty();",
AttributeMap.class);

customClientConfigParams.forEach((n, m) -> {
String keyName = model.getNamingStrategy().getEnumValueName(n);
builder.addStatement("$1T.validState($2T.equals(originalContextParams.get($3T.$4N), newContextParams.get($3T.$4N)),"
+ " $5S)",
Validate.class, Objects.class, endpointRulesSpecUtils.clientContextParamsName(), keyName,
keyName + " cannot be modified by request level plugins");
});

builder.addStatement("return configuration.build()");
return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.squareup.javapoet.TypeName;
import com.squareup.javapoet.WildcardTypeName;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -31,6 +32,7 @@
import software.amazon.awssdk.awscore.client.config.AwsClientOption;
import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel;
import software.amazon.awssdk.codegen.poet.auth.scheme.AuthSchemeSpecUtils;
import software.amazon.awssdk.codegen.poet.rules.EndpointRulesSpecUtils;
import software.amazon.awssdk.core.client.config.ClientOption;
import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
Expand All @@ -41,12 +43,14 @@
import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity;
import software.amazon.awssdk.identity.spi.IdentityProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.utils.AttributeMap;
import software.amazon.awssdk.utils.Validate;

public class ServiceClientConfigurationUtils {
private final AuthSchemeSpecUtils authSchemeSpecUtils;
private final ClassName configurationClassName;
private final ClassName configurationBuilderClassName;
private final EndpointRulesSpecUtils endpointRulesSpecUtils;
private final List<Field> fields;

public ServiceClientConfigurationUtils(IntermediateModel model) {
Expand All @@ -56,7 +60,8 @@ public ServiceClientConfigurationUtils(IntermediateModel model) {
configurationBuilderClassName = ClassName.get(model.getMetadata().getFullClientInternalPackageName(),
serviceId + "ServiceClientConfigurationBuilder");
authSchemeSpecUtils = new AuthSchemeSpecUtils(model);
fields = fields();
endpointRulesSpecUtils = new EndpointRulesSpecUtils(model);
fields = fields(model);
}

/**
Expand All @@ -81,16 +86,44 @@ public List<Field> serviceClientConfigurationFields() {
return Collections.unmodifiableList(fields);
}

private List<Field> fields() {
return Arrays.asList(
private List<Field> fields(IntermediateModel model) {
List<Field> fields = new ArrayList<>();

fields.addAll(Arrays.asList(
overrideConfigurationField(),
endpointOverrideField(),
endpointProviderField(),
regionField(),
credentialsProviderField(),
authSchemesField(),
authSchemeProviderField()
);
));
fields.addAll(addCustomClientParams(model));
return fields;
}

private List<Field> addCustomClientParams(IntermediateModel model) {
List<Field> customClientParamFields = new ArrayList<>();

if (model.getCustomizationConfig() != null && model.getCustomizationConfig().getCustomClientContextParams() != null) {
anirudh9391 marked this conversation as resolved.
Show resolved Hide resolved
model.getCustomizationConfig().getCustomClientContextParams().forEach((n, m) -> {

String paramName = endpointRulesSpecUtils.paramMethodName(n);
String keyName = model.getNamingStrategy().getEnumValueName(n);
TypeName type = endpointRulesSpecUtils.toJavaType(m.getType());

customClientParamFields.add(fieldBuilder(paramName, type)
.doc(m.getDocumentation())
.isInherited(false)
.localSetter(basicLocalSetterCode(paramName))
.localGetter(basicLocalGetterCode(paramName))
.configSetter(customClientConfigParamSetter(paramName, keyName))
.configGetter(customClientConfigParamGetter(keyName))
.build());
});
}

return customClientParamFields;
}

private Field overrideConfigurationField() {
Expand Down Expand Up @@ -268,6 +301,27 @@ private CodeBlock authSchemeProviderConfigGetter() {
.build();
}

private CodeBlock customClientConfigParamSetter(String parameterName, String keyName) {
return CodeBlock.builder()
.addStatement("config.option($1T.CLIENT_CONTEXT_PARAMS, "
+ "config.computeOptionIfAbsent($1T.CLIENT_CONTEXT_PARAMS, $2T::empty)"
+ ".toBuilder().put($3T.$4N, $5N).build())",
SdkClientOption.class,
AttributeMap.class,
endpointRulesSpecUtils.clientContextParamsName(),
keyName, parameterName)
.addStatement("return this")
.build();
}

private CodeBlock customClientConfigParamGetter(String keyName) {
return CodeBlock.builder()
.addStatement("return config.computeOptionIfAbsent($T.CLIENT_CONTEXT_PARAMS, $T::empty)\n"
+ ".get($T.$N)", SdkClientOption.class, AttributeMap.class,
endpointRulesSpecUtils.clientContextParamsName(), keyName)
.build();
}

private CodeBlock basicLocalSetterCode(String fieldName) {
return CodeBlock.builder()
.addStatement("this.$1N = $1N", fieldName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ protected final JsonAsyncClient buildClient() {
SdkClientConfiguration clientConfiguration = super.asyncClientConfiguration();
this.validateClientOptions(clientConfiguration);
JsonAsyncClient client = new DefaultJsonAsyncClient(clientConfiguration);
return new AsyncClientDecorator().decorate(client, clientConfiguration, clientContextParams.copy().build());
return new AsyncClientDecorator().decorate(client, clientConfiguration);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ protected final JsonClient buildClient() {
SdkClientConfiguration clientConfiguration = super.syncClientConfiguration();
this.validateClientOptions(clientConfiguration);
JsonClient client = new DefaultJsonClient(clientConfiguration);
return new SyncClientDecorator().decorate(client, clientConfiguration, clientContextParams.copy().build());
return new SyncClientDecorator().decorate(client, clientConfiguration);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.function.Predicate;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams;
import software.amazon.awssdk.services.s3.internal.crossregion.S3CrossRegionAsyncClient;
Expand All @@ -39,8 +40,8 @@ public S3AsyncClientDecorator() {
}

public S3AsyncClient decorate(S3AsyncClient base,
SdkClientConfiguration clientConfiguration,
AttributeMap clientContextParams) {
SdkClientConfiguration clientConfiguration) {
AttributeMap clientContextParams = clientConfiguration.option(SdkClientOption.CLIENT_CONTEXT_PARAMS);
List<ConditionalDecorator<S3AsyncClient>> decorators = new ArrayList<>();
decorators.add(ConditionalDecorator.create(
isCrossRegionEnabledAsync(clientContextParams),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.function.Predicate;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams;
import software.amazon.awssdk.services.s3.internal.crossregion.S3CrossRegionSyncClient;
Expand All @@ -33,8 +34,8 @@ public S3SyncClientDecorator() {
}

public S3Client decorate(S3Client base,
SdkClientConfiguration clientConfiguration,
AttributeMap clientContextParams) {
SdkClientConfiguration clientConfiguration) {
AttributeMap clientContextParams = clientConfiguration.option(SdkClientOption.CLIENT_CONTEXT_PARAMS);
List<ConditionalDecorator<S3Client>> decorators = new ArrayList<>();
decorators.add(ConditionalDecorator.create(isCrossRegionEnabledSync(clientContextParams),
S3CrossRegionSyncClient::new));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import software.amazon.awssdk.core.client.config.ClientOption;
import software.amazon.awssdk.core.client.config.SdkClientConfiguration;
import software.amazon.awssdk.core.client.config.SdkClientOption;
import software.amazon.awssdk.services.s3.S3AsyncClient;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.endpoints.S3ClientContextParams;
Expand All @@ -33,14 +36,13 @@

public class ClientDecorationFactoryTest {

AttributeMap.Builder clientContextParams = AttributeMap.builder();
static SdkClientConfiguration.Builder clientConfiguration = SdkClientConfiguration.builder();

@ParameterizedTest
@MethodSource("syncTestCases")

void syncClientTest(AttributeMap clientContextParams, Class<Object> clazz, boolean isClass) {
void syncClientTest(SdkClientConfiguration clientConfiguration, Class<Object> clazz, boolean isClass) {
S3SyncClientDecorator decorator = new S3SyncClientDecorator();
S3Client decorateClient = decorator.decorate(S3Client.create(), null, clientContextParams);
S3Client decorateClient = decorator.decorate(S3Client.create(), clientConfiguration);
if (isClass) {
assertThat(decorateClient).isInstanceOf(clazz);
} else {
Expand All @@ -50,10 +52,10 @@ void syncClientTest(AttributeMap clientContextParams, Class<Object> clazz, boole

@ParameterizedTest
@MethodSource("asyncTestCases")
void asyncClientTest(AttributeMap clientContextParams, Class<Object> clazz, boolean isClass) {
void asyncClientTest(SdkClientConfiguration clientConfiguration, Class<Object> clazz, boolean isClass) {
S3AsyncClientDecorator decorator = new S3AsyncClientDecorator();
S3AsyncClient decoratedClient = decorator.decorate(S3AsyncClient.create(),
null ,clientContextParams);
clientConfiguration);
if (isClass) {
assertThat(decoratedClient).isInstanceOf(clazz);
} else {
Expand All @@ -64,24 +66,28 @@ void asyncClientTest(AttributeMap clientContextParams, Class<Object> clazz, bool

private static Stream<Arguments> syncTestCases() {
return Stream.of(
Arguments.of(AttributeMap.builder().build(), S3CrossRegionSyncClient.class, false),
Arguments.of(AttributeMap.builder().put(S3ClientContextParams.CROSS_REGION_ACCESS_ENABLED, false).build(),
Arguments.of(clientConfiguration.option(SdkClientOption.CLIENT_CONTEXT_PARAMS, AttributeMap.builder().build()).build(), S3CrossRegionSyncClient.class, false),
Arguments.of(clientConfiguration.option(SdkClientOption.CLIENT_CONTEXT_PARAMS,
AttributeMap.builder().put(S3ClientContextParams.CROSS_REGION_ACCESS_ENABLED, false).build()).build(),
S3CrossRegionSyncClient.class, false),
Arguments.of(AttributeMap.builder().put(S3ClientContextParams.CROSS_REGION_ACCESS_ENABLED, true).build(),
Arguments.of(clientConfiguration.option(SdkClientOption.CLIENT_CONTEXT_PARAMS,
AttributeMap.builder().put(S3ClientContextParams.CROSS_REGION_ACCESS_ENABLED, true).build()).build(),
S3CrossRegionSyncClient.class, true)
);
}

private static Stream<Arguments> asyncTestCases() {
return Stream.of(
Arguments.of(AttributeMap.builder().build(),
Arguments.of(clientConfiguration.option(SdkClientOption.CLIENT_CONTEXT_PARAMS, AttributeMap.builder().build()).build(),
S3CrossRegionAsyncClient.class,
false),
Arguments.of(AttributeMap.builder().put(S3ClientContextParams.CROSS_REGION_ACCESS_ENABLED, false).build(),
Arguments.of(clientConfiguration.option(SdkClientOption.CLIENT_CONTEXT_PARAMS,
AttributeMap.builder().put(S3ClientContextParams.CROSS_REGION_ACCESS_ENABLED, false).build()).build(),
S3CrossRegionAsyncClient.class,
false),
Arguments.of(AttributeMap.builder().put(S3ClientContextParams.CROSS_REGION_ACCESS_ENABLED, true).build()
, S3CrossRegionAsyncClient.class,
Arguments.of(clientConfiguration.option(SdkClientOption.CLIENT_CONTEXT_PARAMS,
AttributeMap.builder().put(S3ClientContextParams.CROSS_REGION_ACCESS_ENABLED, true).build()).build(),
S3CrossRegionAsyncClient.class,
true)
);
}
Expand Down
Loading
Loading