diff --git a/.changes/next-release/feature-AmazonS3-04b3f91.json b/.changes/next-release/feature-AmazonS3-04b3f91.json new file mode 100644 index 000000000000..e72075a7f42a --- /dev/null +++ b/.changes/next-release/feature-AmazonS3-04b3f91.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "Amazon S3", + "contributor": "", + "description": "Propagating client apiCallTimeout values to S3Express createSession calls. If existing, this value overrides the default timeout value of 10s when making the nested S3Express session credentials call." +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressIdentityCache.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressIdentityCache.java index ba4cc335ec5b..bb521c62eba9 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressIdentityCache.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressIdentityCache.java @@ -16,12 +16,13 @@ package software.amazon.awssdk.services.s3.internal.s3express; import java.time.Duration; -import java.util.function.Consumer; +import java.util.Optional; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.auth.credentials.AwsCredentials; import software.amazon.awssdk.auth.credentials.CredentialUtils; import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.SdkClient; +import software.amazon.awssdk.core.SdkServiceClientConfiguration; import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; import software.amazon.awssdk.identity.spi.IdentityProvider; import software.amazon.awssdk.services.s3.S3AsyncClient; @@ -75,28 +76,42 @@ private CachedS3ExpressCredentials getCachedCredentials(S3ExpressIdentityKey key .build(); } - //TODO (s3express) user experience and error messaging when calls fail SessionCredentials getCredentials(S3ExpressIdentityKey key, IdentityProvider provider) { SdkClient client = key.client(); String bucket = key.bucket(); + SdkServiceClientConfiguration serviceClientConfiguration = client.serviceClientConfiguration(); if (client instanceof S3AsyncClient) { // TODO (s3express) don't join here - return ((S3AsyncClient) client).createSession(createSessionRequest(bucket, provider)).join().credentials(); + return ((S3AsyncClient) client).createSession(createSessionRequest(bucket, provider, serviceClientConfiguration)) + .join() + .credentials(); } if (client instanceof S3Client) { - return ((S3Client) client).createSession(createSessionRequest(bucket, provider)).credentials(); + return ((S3Client) client).createSession(createSessionRequest(bucket, provider, serviceClientConfiguration)) + .credentials(); } throw new UnsupportedOperationException("SdkClient must be either an S3Client or an S3AsyncClient, but was " + client.getClass()); } - private static Consumer + private static CreateSessionRequest createSessionRequest(String bucket, - IdentityProvider provider) { - return r -> r.bucket(bucket) + IdentityProvider provider, + SdkServiceClientConfiguration serviceClientConfiguration) { + + Duration requestApiCallTimeout = clientSetTimeoutIfExists(serviceClientConfiguration).orElse(DEFAULT_API_CALL_TIMEOUT); + + return CreateSessionRequest.builder().bucket(bucket) .sessionMode(SessionMode.READ_WRITE) .overrideConfiguration(o -> o.credentialsProvider(provider) - .apiCallTimeout(DEFAULT_API_CALL_TIMEOUT)); + .apiCallTimeout(requestApiCallTimeout)).build(); + } + + private static Optional clientSetTimeoutIfExists(SdkServiceClientConfiguration serviceClientConfiguration) { + if (serviceClientConfiguration != null && serviceClientConfiguration.overrideConfiguration() != null) { + return serviceClientConfiguration.overrideConfiguration().apiCallTimeout(); + } + return Optional.empty(); } } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/S3PresignerTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/S3PresignerTest.java index 6e6be1935528..16145290abe3 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/S3PresignerTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/S3PresignerTest.java @@ -950,7 +950,7 @@ private void verifyS3ExpressGetRequest(PresignedGetObjectRequest presigned, Stri private S3Presigner presignerWithS3ExpressWithMockS3Client(boolean disableS3ExpressSessionAuth) { S3Client mockS3SyncClient = mock(S3Client.class); - when(mockS3SyncClient.createSession((Consumer) any())).thenReturn( + when(mockS3SyncClient.createSession((CreateSessionRequest) any())).thenReturn( createS3ExpressSessionResponse()); return presignerForS3Express(disableS3ExpressSessionAuth, mockS3SyncClient); diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/S3ExpressCreateSessionTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/S3ExpressCreateSessionTest.java new file mode 100644 index 000000000000..a2751cba65f0 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/functionaltests/S3ExpressCreateSessionTest.java @@ -0,0 +1,337 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.functionaltests; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.anyUrl; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlMatching; +import static java.lang.Boolean.TRUE; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static software.amazon.awssdk.http.SdkHttpConfigurationOption.TRUST_ALL_CERTIFICATES; + +import com.github.tomakehurst.wiremock.client.ResponseDefinitionBuilder; +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.github.tomakehurst.wiremock.stubbing.Scenario; +import java.net.URI; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletionException; +import java.util.function.Function; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.core.exception.ApiCallTimeoutException; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.core.rules.testing.BaseRuleSetClientTest; +import software.amazon.awssdk.http.SdkHttpRequest; +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3AsyncClientBuilder; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3ClientBuilder; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.Protocol; +import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.utils.AttributeMap; +import software.amazon.awssdk.utils.http.SdkHttpUtils; + +@WireMockTest(httpsEnabled = true) +public class S3ExpressCreateSessionTest extends BaseRuleSetClientTest { + + private static final Function WM_HTTP_ENDPOINT = wm -> URI.create(wm.getHttpBaseUrl()); + private static final Function WM_HTTPS_ENDPOINT = wm -> URI.create(wm.getHttpsBaseUrl()); + private static final AwsCredentialsProvider CREDENTIALS_PROVIDER = + StaticCredentialsProvider.create(AwsBasicCredentials.create("akid", "skid")); + private static final PathStyleEnforcingInterceptor PATH_STYLE_INTERCEPTOR = new PathStyleEnforcingInterceptor(); + private static final CapturingInterceptor CAPTURING_INTERCEPTOR = new CapturingInterceptor(); + private static final String DEFAULT_BUCKET = "s3expressformat--use1-az1--x-s3"; + private static final String DEFAULT_KEY = "foo.txt"; + private static final String GET_BODY = "Hello world!"; + private static final int DEFAULT_API_CALL_TIMEOUT_VALUE_MILLIS = 10000; + + private static final String CREATE_SESSION_RESPONSE = String.format( + "\n" + + "\n" + + "\n" + + "%s\n" + + "%s\n" + + "%s" + + "\n" + + "", "TheToken", "TheSecret", "TheAccessKey"); + + @BeforeEach + public void commonSetup() { + stubFor(get(anyUrl()).willReturn(aResponse().withStatus(200) + .withBody(GET_BODY)) + .withName("OuterGetCall")); + } + + @Test + public void when_clientDefaultIsUsed_andOkResponse_callIsSuccessful(WireMockRuntimeInfo wm) { + stubFor(get(urlMatching("/.*session")).willReturn(aResponse().withStatus(200) + .withBody(CREATE_SESSION_RESPONSE))); + createClientAndCallGetObject(null, ClientType.SYNC, wm); + } + + @Test + public void when_clientDefaultIsUsed_andResponseIsDelayed_timeoutExceptionIsPropagated(WireMockRuntimeInfo wm) { + Integer delayResponseTimeInMillis = 10000; + stubFor(get(urlMatching("/.*session")).willReturn(aResponse().withStatus(200) + .withBody(CREATE_SESSION_RESPONSE) + .withFixedDelay(delayResponseTimeInMillis))); + assertThatThrownBy(() -> createClientAndCallGetObject(null, ClientType.SYNC, wm)) + .isInstanceOf(ApiCallTimeoutException.class) + .hasMessageContaining(String.valueOf(DEFAULT_API_CALL_TIMEOUT_VALUE_MILLIS)); + } + + @Test + public void when_asyncClientDefaultIsUsed_andResponseIsDelayed_timeoutExceptionIsPropagated(WireMockRuntimeInfo wm) { + Integer delayResponseTimeInMillis = 10000; + stubFor(get(urlMatching("/.*session")).willReturn(aResponse().withStatus(200) + .withBody(CREATE_SESSION_RESPONSE) + .withFixedDelay(delayResponseTimeInMillis))); + assertThatThrownBy(() -> createClientAndCallGetObject(null, ClientType.ASYNC, wm)) + .isInstanceOf(CompletionException.class) + .hasMessageContaining(String.valueOf(DEFAULT_API_CALL_TIMEOUT_VALUE_MILLIS)) + .hasCauseInstanceOf(ApiCallTimeoutException.class); + } + + @Test + public void when_clientDefaultIsUsed_andResponseHasRetryableError_exceptionIsPropagated(WireMockRuntimeInfo wm) { + stubFor(get(urlMatching("/.*session")).willReturn(aResponse().withStatus(500).withBody(""))); + try { + createClientAndCallGetObject(null, ClientType.SYNC, wm); + } catch (Exception e) { + assertThat(e).isInstanceOf(S3Exception.class); + assertThat(e.getSuppressed()).anySatisfy(throwable -> assertThat(throwable).isInstanceOf(SdkClientException.class)); + } + } + + @Test + public void when_asyncClientDefaultIsUsed_andResponseHasRetryableError_exceptionIsPropagated(WireMockRuntimeInfo wm) { + stubFor(get(urlMatching("/.*session")).willReturn(aResponse().withStatus(500).withBody(""))); + try { + createClientAndCallGetObject(null, ClientType.ASYNC, wm); + } catch (Exception e) { + assertThat(e).isInstanceOf(CompletionException.class); + Throwable cause = e.getCause(); + assertThat(cause).isInstanceOf(S3Exception.class); + assertThat(cause.getSuppressed()).anySatisfy(throwable -> assertThat(throwable).isInstanceOf(SdkClientException.class)); + } + } + + @Test + public void when_asyncClientDefaultIsUsed_andResponseHasRetryableErrorWithDelays_timeoutExceptionIsPropagated(WireMockRuntimeInfo wm) { + stubForaResponseWithDelayedRetryableException(); + assertThatThrownBy(() -> createClientAndCallGetObject(null, ClientType.ASYNC, wm)) + .isInstanceOf(CompletionException.class) + .hasMessageContaining(String.valueOf(DEFAULT_API_CALL_TIMEOUT_VALUE_MILLIS)) + .hasCauseInstanceOf(ApiCallTimeoutException.class); + } + + private static Stream apiCallTimeoutValues() { + return Stream.of( + Arguments.of(1000L), + Arguments.of(5000L) + ); + } + + @ParameterizedTest + @MethodSource("apiCallTimeoutValues") + public void when_clientApiCallTimeoutConfigured_andOkResponse_callIsSuccessful(Long apiCallTimeoutValue, + WireMockRuntimeInfo wm) { + stubFor(get(urlMatching("/.*session")).atPriority(1).willReturn(aResponse() + .withStatus(200) + .withBody(CREATE_SESSION_RESPONSE))); + createClientAndCallGetObject(apiCallTimeoutValue, ClientType.SYNC, wm); + } + + @ParameterizedTest + @MethodSource("apiCallTimeoutValues") + public void when_clientApiCallTimeoutConfigured_andResponseIsDelayed_timeoutExceptionIsPropagated(Long apiCallTimeoutValue, + WireMockRuntimeInfo wm) { + Integer delayResponseTimeInMillis = apiCallTimeoutValue.intValue() + 500; + stubFor(get(urlMatching("/.*session")).atPriority(1).willReturn(aResponse() + .withStatus(200) + .withBody(CREATE_SESSION_RESPONSE) + .withFixedDelay(delayResponseTimeInMillis))); + assertThatThrownBy(() -> createClientAndCallGetObject(apiCallTimeoutValue, ClientType.SYNC, wm)) + .isInstanceOf(ApiCallTimeoutException.class) + .hasMessageContaining(String.valueOf(apiCallTimeoutValue)); + } + + @ParameterizedTest + @MethodSource("apiCallTimeoutValues") + public void when_clientApiCallTimeoutConfigured_andResponseHasRetryableError_exceptionIsPropagated(Long apiCallTimeoutValue, + WireMockRuntimeInfo wm) { + stubFor(get(urlMatching("/.*session")).atPriority(1).willReturn(aResponse().withStatus(500).withBody(""))); + try { + createClientAndCallGetObject(apiCallTimeoutValue, ClientType.SYNC, wm); + } catch (Exception e) { + assertThat(e).isInstanceOf(S3Exception.class); + assertThat(e.getSuppressed()).anySatisfy(throwable -> assertThat(throwable).isInstanceOf(SdkClientException.class)); + } + } + + @ParameterizedTest + @MethodSource("apiCallTimeoutValues") + public void when_asyncClientApiCallTimeoutConfigured_andResponseHasRetryableError_exceptionIsPropagated(Long apiCallTimeoutValue, + WireMockRuntimeInfo wm) { + stubFor(get(urlMatching("/.*session")).atPriority(1).willReturn(aResponse().withStatus(500).withBody(""))); + try { + createClientAndCallGetObject(apiCallTimeoutValue, ClientType.ASYNC, wm); + } catch (Exception e) { + assertThat(e).isInstanceOf(CompletionException.class); + Throwable cause = e.getCause(); + assertThat(cause).isInstanceOf(S3Exception.class); + assertThat(cause.getSuppressed()).anySatisfy(throwable -> assertThat(throwable).isInstanceOf(SdkClientException.class)); + } + } + + private void createClientAndCallGetObject(Long apiCallTimeoutValue, ClientType clientType, + WireMockRuntimeInfo wm) { + GetObjectRequest getObjectRequest = GetObjectRequest.builder().bucket(DEFAULT_BUCKET).key(DEFAULT_KEY).build(); + ClientOverrideConfiguration.Builder overrideConfiguration = + ClientOverrideConfiguration.builder() + .addExecutionInterceptor(CAPTURING_INTERCEPTOR) + .addExecutionInterceptor(PATH_STYLE_INTERCEPTOR); + if (apiCallTimeoutValue != null) { + overrideConfiguration.apiCallTimeout(Duration.ofMillis(apiCallTimeoutValue)); + } + if (clientType == ClientType.SYNC) { + S3Client s3Client = s3Client(overrideConfiguration.build(), wm); + s3Client.getObject(getObjectRequest); + } else { + S3AsyncClient s3Client = s3AsyncClient(overrideConfiguration.build(), wm); + s3Client.getObject(getObjectRequest, AsyncResponseTransformer.toBytes()).join(); + } + } + + private enum ClientType { + SYNC, + ASYNC + } + + private void stubForaResponseWithDelayedRetryableException() { + ResponseDefinitionBuilder errorResponse = aResponse().withStatus(500).withBody(""); + stubFor(get(urlMatching("/.*session")).atPriority(1) + .inScenario("retriesWithDelay") + .willSetStateTo("second") + .whenScenarioStateIs(Scenario.STARTED) + .willReturn(errorResponse.withFixedDelay(5500))); + stubFor(get(urlMatching("/.*session")).atPriority(1) + .inScenario("retriesWithDelay") + .whenScenarioStateIs("second") + .willSetStateTo("third") + .willReturn(errorResponse.withFixedDelay(5000))); + stubFor(get(urlMatching("/.*session")).atPriority(1) + .inScenario("retriesWithDelay") + .whenScenarioStateIs("third") + .willSetStateTo("finish") + .willReturn(errorResponse.withFixedDelay(5000))); + } + + private S3Client s3Client(ClientOverrideConfiguration overrideConfiguration, WireMockRuntimeInfo wm) { + S3ClientBuilder syncClientBuilder = S3Client.builder() + .region(Region.US_EAST_1) + .overrideConfiguration(overrideConfiguration) + .credentialsProvider(CREDENTIALS_PROVIDER); + setEndpointParametersSync(syncClientBuilder, Protocol.HTTPS, wm); + return syncClientBuilder.build(); + } + + private S3AsyncClient s3AsyncClient(ClientOverrideConfiguration overrideConfiguration, WireMockRuntimeInfo wm) { + S3AsyncClientBuilder asyncClientBuilder = S3AsyncClient.builder() + .region(Region.US_EAST_1) + .overrideConfiguration(overrideConfiguration) + .credentialsProvider(CREDENTIALS_PROVIDER); + setEndpointParametersAsync(asyncClientBuilder, Protocol.HTTPS, wm); + return asyncClientBuilder.build(); + } + + private void setEndpointParametersAsync(S3AsyncClientBuilder clientBuilder, Protocol protocol, WireMockRuntimeInfo wm) { + if (protocol == Protocol.HTTP) { + clientBuilder.endpointOverride(WM_HTTP_ENDPOINT.apply(wm)); + } else { + clientBuilder.endpointOverride(WM_HTTPS_ENDPOINT.apply(wm)) + .httpClient(NettyNioAsyncHttpClient.builder() + .buildWithDefaults(AttributeMap.builder() + .put(TRUST_ALL_CERTIFICATES, true).build())); + } + } + + private void setEndpointParametersSync(S3ClientBuilder clientBuilder, Protocol protocol, WireMockRuntimeInfo wm) { + if (protocol == Protocol.HTTP) { + clientBuilder.endpointOverride(WM_HTTP_ENDPOINT.apply(wm)); + } else { + clientBuilder.endpointOverride(WM_HTTPS_ENDPOINT.apply(wm)) + .httpClient(ApacheHttpClient.builder() + .buildWithDefaults(AttributeMap.builder() + .put(TRUST_ALL_CERTIFICATES, TRUE) + .build())); + } + } + + /** + * S3Express does not support path style enforcement through client configuration and the endpoint will resolve + * to virtual style. However, path style is required for the HTTP client to be able to direct requests to localhost + * and the WireMock port. + */ + private static final class PathStyleEnforcingInterceptor implements ExecutionInterceptor { + + @Override + public SdkHttpRequest modifyHttpRequest(Context.ModifyHttpRequest context, ExecutionAttributes executionAttributes) { + SdkHttpRequest sdkHttpRequest = context.httpRequest(); + String host = sdkHttpRequest.host(); + String bucket = host.substring(0, host.indexOf(".localhost")); + + return sdkHttpRequest.toBuilder().host("localhost") + .encodedPath(SdkHttpUtils.appendUri(bucket, sdkHttpRequest.encodedPath())) + .build(); + } + } + + private static final class CapturingInterceptor implements ExecutionInterceptor { + private Map> headers; + + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + SdkHttpRequest sdkHttpRequest = context.httpRequest(); + this.headers = sdkHttpRequest.headers(); + System.out.printf("%s %s%n", sdkHttpRequest.method(), sdkHttpRequest.encodedPath()); + headers.forEach((k, strings) -> System.out.printf("%s, %s%n", k, strings)); + System.out.println(); + } + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressCreateSessionConfigurationTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressCreateSessionConfigurationTest.java new file mode 100644 index 000000000000..47c167f99cc2 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/s3express/S3ExpressCreateSessionConfigurationTest.java @@ -0,0 +1,144 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.s3express; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.auth.credentials.AwsCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; +import software.amazon.awssdk.core.SdkClient; +import software.amazon.awssdk.core.client.config.ClientOverrideConfiguration; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.S3ServiceClientConfiguration; +import software.amazon.awssdk.services.s3.model.CreateSessionRequest; +import software.amazon.awssdk.services.s3.model.CreateSessionResponse; +import software.amazon.awssdk.services.s3.model.SessionCredentials; + + +@ExtendWith(MockitoExtension.class) +class S3ExpressCreateSessionConfigurationTest { + + private static final Duration DEFAULT_API_CALL_TIMEOUT_VALUE = Duration.ofSeconds(10); + private static final StaticCredentialsProvider CREDENTIALS_PROVIDER = StaticCredentialsProvider.create(mock(AwsCredentials.class)); + + private static final CreateSessionResponse EMPTY_RESPONSE = CreateSessionResponse.builder() + .credentials(SessionCredentials.builder().build()) + .build(); + + @Mock + S3Client s3Client; + @Mock + S3AsyncClient s3AsyncClient; + @Captor + ArgumentCaptor requestCaptor; + + @Test + void when_noApiCallTimeoutIsSet_DefaultValueIsUsedByCreateSessionRequest() { + when(s3Client.createSession((CreateSessionRequest) any())).thenReturn(EMPTY_RESPONSE); + when(s3Client.serviceClientConfiguration()).thenReturn(serviceClientConfigurationWithApiCallTimeout(null)); + + S3ExpressIdentityCache s3ExpressIdentityCache = S3ExpressIdentityCache.create(); + s3ExpressIdentityCache.getCredentials(key(s3Client), CREDENTIALS_PROVIDER); + + verifyCreateSessionApiCallTimeoutOverride(DEFAULT_API_CALL_TIMEOUT_VALUE); + } + + @Test + void when_clientApiCallTimeoutIsSet_valueIsUsedByCreateSessionRequest() { + Duration clientApiCallTimeout = Duration.ofSeconds(3); + + when(s3Client.serviceClientConfiguration()).thenReturn(serviceClientConfigurationWithApiCallTimeout(clientApiCallTimeout)); + when(s3Client.createSession((CreateSessionRequest) any())).thenReturn(EMPTY_RESPONSE); + + S3ExpressIdentityCache s3ExpressIdentityCache = S3ExpressIdentityCache.create(); + s3ExpressIdentityCache.getCredentials(key(s3Client), CREDENTIALS_PROVIDER); + + verifyCreateSessionApiCallTimeoutOverride(clientApiCallTimeout); + } + + @Test + void async_when_noApiCallTimeoutIsSet_DefaultValueIsUsedByCreateSessionRequest() { + when(s3AsyncClient.createSession((CreateSessionRequest) any())).thenReturn(CompletableFuture.completedFuture(EMPTY_RESPONSE)); + when(s3AsyncClient.serviceClientConfiguration()).thenReturn(serviceClientConfigurationWithApiCallTimeout(null)); + + S3ExpressIdentityCache s3ExpressIdentityCache = S3ExpressIdentityCache.create(); + s3ExpressIdentityCache.getCredentials(key(s3AsyncClient), CREDENTIALS_PROVIDER); + + asyncVerifyCreateSessionApiCallTimeoutOverride(DEFAULT_API_CALL_TIMEOUT_VALUE); + } + + @Test + void async_when_clientpiCallTimeoutIsSet_valueIsUsedByCreateSessionRequest() { + Duration clientApiCallTimeout = Duration.ofSeconds(3); + + when(s3AsyncClient.serviceClientConfiguration()).thenReturn(serviceClientConfigurationWithApiCallTimeout(clientApiCallTimeout)); + when(s3AsyncClient.createSession((CreateSessionRequest) any())).thenReturn(CompletableFuture.completedFuture(EMPTY_RESPONSE)); + + S3ExpressIdentityCache s3ExpressIdentityCache = S3ExpressIdentityCache.create(); + s3ExpressIdentityCache.getCredentials(key(s3AsyncClient), CREDENTIALS_PROVIDER); + + asyncVerifyCreateSessionApiCallTimeoutOverride(clientApiCallTimeout); + } + + private S3ServiceClientConfiguration serviceClientConfigurationWithApiCallTimeout(Duration apiCallTimeout) { + return S3ServiceClientConfiguration.builder() + .overrideConfiguration(ClientOverrideConfiguration.builder() + .apiCallTimeout(apiCallTimeout) + .build()) + .build(); + } + + private S3ExpressIdentityKey key(SdkClient client) { + return S3ExpressIdentityKey.builder() + .bucket("Bucket-1") + .client(client) + .identity(mock(AwsCredentialsIdentity.class)) + .build(); + } + + private void asyncVerifyCreateSessionApiCallTimeoutOverride(Duration expectedTimeout) { + verify(s3AsyncClient, times(1)).createSession(requestCaptor.capture()); + verifyApiCallTimeoutOverride(expectedTimeout); + } + + private void verifyCreateSessionApiCallTimeoutOverride(Duration expectedTimeout) { + verify(s3Client, times(1)).createSession(requestCaptor.capture()); + verifyApiCallTimeoutOverride(expectedTimeout); + } + + private void verifyApiCallTimeoutOverride(Duration expectedTimeout) { + Optional awsRequestOverrideConfiguration = requestCaptor.getValue().overrideConfiguration(); + assertThat(awsRequestOverrideConfiguration.isPresent()); + assertThat(awsRequestOverrideConfiguration.get().apiCallTimeout()).isPresent().hasValue(expectedTimeout); + } +} \ No newline at end of file