Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for STS Regional Endpoint (#118) #119

Merged
merged 5 commits into from
Jul 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
import com.amazonaws.auth.SystemPropertiesCredentialsProvider;
import com.amazonaws.auth.WebIdentityTokenCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.RegionUtils;
import com.amazonaws.retry.PredefinedBackoffStrategies;
import com.amazonaws.retry.v2.AndRetryCondition;
import com.amazonaws.retry.v2.MaxNumberOfRetriesCondition;
Expand Down Expand Up @@ -199,19 +202,19 @@ private void logCallerIdentity(AWSCredentials credentials) {

AWSSecurityTokenService getStsClientForDebuggingCreds(AWSCredentials credentials) {
return AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.withCredentials(new AWSCredentialsProvider() {
@Override
public AWSCredentials getCredentials() {
return credentials;
}

@Override
public void refresh() {

}
})
.build();
menuetb marked this conversation as resolved.
Show resolved Hide resolved
.withRegion(stsRegion)
.withCredentials(new AWSCredentialsProvider() {
@Override
public AWSCredentials getCredentials() {
return credentials;
}

@Override
public void refresh() {

}
})
.build();
}

@Override
Expand Down Expand Up @@ -267,6 +270,17 @@ public int getMaxBackOffTimeMs() {
.orElse(DEFAULT_MAX_BACK_OFF_TIME_MS);
}

public EndpointConfiguration buildEndpointConfiguration(String stsRegion){
Region region = RegionUtils.getRegion(stsRegion);
String serviceEndpoint = region.getServiceEndpoint("sts");
EndpointConfiguration endpointConfiguration =
new EndpointConfiguration(
String.format(serviceEndpoint, stsRegion),
stsRegion);

return endpointConfiguration;
}

private Optional<EnhancedProfileCredentialsProvider> getProfileProvider() {
return Optional.ofNullable(optionsMap.get(AWS_PROFILE_NAME_KEY)).map(p -> {
if (log.isDebugEnabled()) {
Expand Down Expand Up @@ -311,8 +325,9 @@ else if (externalId != null) {

STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion) {
EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.withEndpointConfiguration(endpointConfiguration)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to set the endpoint configuration? if the stsRegion is set to aws-global it should be using sts.amazonaws.com anyway.

regional – The SDK or tool always uses the AWS STS endpoint for the currently configured Region. For example, if the client is configured to use us-west-2, all calls to AWS STS are made to the Regional endpoint sts.us-west-2.amazonaws.com, instead of the global sts.amazonaws.com endpoint. To send a request to the global endpoint while this setting is enabled, you can set the Region to aws-global.

https://docs.aws.amazon.com/sdkref/latest/guide/feature-sts-regionalized-endpoints.html

Have you tested this out? In KDA you should be able to define an environment variable and override the REGION config.

My concern with manually setting endpointconfiguration is the following endpoint might not be valid for all aws regions and partitions in the future.

EndpointConfiguration(
                            String.format("sts.%s.amazonaws.com", stsRegion),
                            stsRegion);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I published a new commit to use region.getServiceEndpoint API instead of a string.

Copy link
Contributor

@plazma-prizma plazma-prizma Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also try setting region as aws-global in KDA?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

aws-global is not reachable for customer who are running KDA or MSK Connect in a private subnet. They have to use a regional endpoint for STS.

.build();
return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
.withStsClient(stsClient)
Expand All @@ -322,8 +337,9 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion,
AWSCredentialsProvider credentials) {
EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.withEndpointConfiguration(endpointConfiguration)
.withCredentials(credentials)
.build();

Expand All @@ -336,8 +352,10 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
String externalId,
String sessionName,
String stsRegion) {

EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.withEndpointConfiguration(endpointConfiguration)
.build();

return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.amazonaws.client.builder.AwsClientBuilder;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

Expand Down Expand Up @@ -312,6 +314,9 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
assertEquals("eu-west-1", stsRegion);
AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint());

return mockStsRoleProvider;
}
};
Expand Down Expand Up @@ -347,6 +352,9 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
assertEquals(TEST_ROLE_EXTERNAL_ID, externalId);
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
assertEquals("eu-west-1", stsRegion);
AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint());

return mockStsRoleProvider;
}
};
Expand Down Expand Up @@ -381,6 +389,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
String sessionName, String stsRegion) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals("aws-msk-iam-auth", sessionName);
AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
assertEquals("sts.amazonaws.com", endpointConfiguration.getServiceEndpoint());
return mockStsRoleProvider;
}
};
Expand Down Expand Up @@ -537,6 +547,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
String sessionName, String stsRegion) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(s, sessionName);
AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
assertEquals("sts.amazonaws.com", endpointConfiguration.getServiceEndpoint());
return mockStsRoleProvider;
}
};
Expand All @@ -550,6 +562,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
AWSCredentialsProvider credentials) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(s, sessionName);
AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
assertEquals("sts.amazonaws.com", endpointConfiguration.getServiceEndpoint());
return mockStsRoleProvider;
}
};
Expand Down