diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java index 87fad13..5133314 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java @@ -28,6 +28,8 @@ import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; import com.amazonaws.auth.SystemPropertiesCredentialsProvider; import com.amazonaws.auth.WebIdentityTokenCredentialsProvider; +import com.amazonaws.client.builder.AwsClientBuilder; +import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration; import com.amazonaws.retry.PredefinedBackoffStrategies; import com.amazonaws.retry.v2.AndRetryCondition; import com.amazonaws.retry.v2.MaxNumberOfRetriesCondition; @@ -267,6 +269,22 @@ public int getMaxBackOffTimeMs() { .orElse(DEFAULT_MAX_BACK_OFF_TIME_MS); } + public EndpointConfiguration buildEndpointConfiguration(String stsRegion){ + //An AWSSecurityTokenService with a regional endpoint configuration + EndpointConfiguration endpointConfiguration = + new AwsClientBuilder.EndpointConfiguration( + String.format("sts.%s.amazonaws.com", stsRegion), + stsRegion); + //An AWSSecurityTokenService with a global endpoint configuration + if (stsRegion.equals("aws-global")) { + endpointConfiguration = + new EndpointConfiguration( + "sts.amazonaws.com", + stsRegion); + } + return endpointConfiguration; + } + private Optional getProfileProvider() { return Optional.ofNullable(optionsMap.get(AWS_PROFILE_NAME_KEY)).map(p -> { if (log.isDebugEnabled()) { @@ -311,8 +329,9 @@ else if (externalId != null) { STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn, String sessionName, String stsRegion) { + EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard() - .withRegion(stsRegion) + .withEndpointConfiguration(endpointConfiguration) .build(); return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName) .withStsClient(stsClient) @@ -322,8 +341,9 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn, String sessionName, String stsRegion, AWSCredentialsProvider credentials) { + EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard() - .withRegion(stsRegion) + .withEndpointConfiguration(endpointConfiguration) .withCredentials(credentials) .build(); @@ -336,8 +356,10 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r String externalId, String sessionName, String stsRegion) { + + EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard() - .withRegion(stsRegion) + .withEndpointConfiguration(endpointConfiguration) .build(); return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)