diff --git a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java index 87fad13..b69f3a7 100644 --- a/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java +++ b/src/main/java/software/amazon/msk/auth/iam/internals/MSKCredentialProvider.java @@ -28,6 +28,9 @@ import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; import com.amazonaws.auth.SystemPropertiesCredentialsProvider; import com.amazonaws.auth.WebIdentityTokenCredentialsProvider; +import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration; +import com.amazonaws.regions.Region; +import com.amazonaws.regions.RegionUtils; import com.amazonaws.retry.PredefinedBackoffStrategies; import com.amazonaws.retry.v2.AndRetryCondition; import com.amazonaws.retry.v2.MaxNumberOfRetriesCondition; @@ -199,19 +202,19 @@ private void logCallerIdentity(AWSCredentials credentials) { AWSSecurityTokenService getStsClientForDebuggingCreds(AWSCredentials credentials) { return AWSSecurityTokenServiceClientBuilder.standard() - .withRegion(stsRegion) - .withCredentials(new AWSCredentialsProvider() { - @Override - public AWSCredentials getCredentials() { - return credentials; - } - - @Override - public void refresh() { - - } - }) - .build(); + .withRegion(stsRegion) + .withCredentials(new AWSCredentialsProvider() { + @Override + public AWSCredentials getCredentials() { + return credentials; + } + + @Override + public void refresh() { + + } + }) + .build(); } @Override @@ -267,6 +270,17 @@ public int getMaxBackOffTimeMs() { .orElse(DEFAULT_MAX_BACK_OFF_TIME_MS); } + public EndpointConfiguration buildEndpointConfiguration(String stsRegion){ + Region region = RegionUtils.getRegion(stsRegion); + String serviceEndpoint = region.getServiceEndpoint("sts"); + EndpointConfiguration endpointConfiguration = + new EndpointConfiguration( + String.format(serviceEndpoint, stsRegion), + stsRegion); + + return endpointConfiguration; + } + private Optional getProfileProvider() { return Optional.ofNullable(optionsMap.get(AWS_PROFILE_NAME_KEY)).map(p -> { if (log.isDebugEnabled()) { @@ -311,8 +325,9 @@ else if (externalId != null) { STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn, String sessionName, String stsRegion) { + EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard() - .withRegion(stsRegion) + .withEndpointConfiguration(endpointConfiguration) .build(); return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName) .withStsClient(stsClient) @@ -322,8 +337,9 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn, String sessionName, String stsRegion, AWSCredentialsProvider credentials) { + EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard() - .withRegion(stsRegion) + .withEndpointConfiguration(endpointConfiguration) .withCredentials(credentials) .build(); @@ -336,8 +352,10 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r String externalId, String sessionName, String stsRegion) { + + EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard() - .withRegion(stsRegion) + .withEndpointConfiguration(endpointConfiguration) .build(); return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName) diff --git a/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java b/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java index ddecffc..412dc0a 100644 --- a/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java +++ b/src/test/java/software/amazon/msk/auth/iam/internals/MSKCredentialProviderTest.java @@ -23,6 +23,8 @@ import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; + +import com.amazonaws.client.builder.AwsClientBuilder; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -312,6 +314,9 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r assertEquals(TEST_ROLE_ARN, roleArn); assertEquals(TEST_ROLE_SESSION_NAME, sessionName); assertEquals("eu-west-1", stsRegion); + AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); + assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint()); + return mockStsRoleProvider; } }; @@ -347,6 +352,9 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r assertEquals(TEST_ROLE_EXTERNAL_ID, externalId); assertEquals(TEST_ROLE_SESSION_NAME, sessionName); assertEquals("eu-west-1", stsRegion); + AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); + assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint()); + return mockStsRoleProvider; } }; @@ -381,6 +389,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r String sessionName, String stsRegion) { assertEquals(TEST_ROLE_ARN, roleArn); assertEquals("aws-msk-iam-auth", sessionName); + AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); + assertEquals("sts.amazonaws.com", endpointConfiguration.getServiceEndpoint()); return mockStsRoleProvider; } }; @@ -537,6 +547,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r String sessionName, String stsRegion) { assertEquals(TEST_ROLE_ARN, roleArn); assertEquals(s, sessionName); + AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); + assertEquals("sts.amazonaws.com", endpointConfiguration.getServiceEndpoint()); return mockStsRoleProvider; } }; @@ -550,6 +562,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r AWSCredentialsProvider credentials) { assertEquals(TEST_ROLE_ARN, roleArn); assertEquals(s, sessionName); + AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion); + assertEquals("sts.amazonaws.com", endpointConfiguration.getServiceEndpoint()); return mockStsRoleProvider; } };