Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enabling STS regional endpoint to be used by specifying STS region. #144

Merged
merged 2 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
import com.amazonaws.auth.SystemPropertiesCredentialsProvider;
import com.amazonaws.auth.WebIdentityTokenCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.RegionUtils;
import com.amazonaws.retry.PredefinedBackoffStrategies;
import com.amazonaws.retry.v2.AndRetryCondition;
import com.amazonaws.retry.v2.MaxNumberOfRetriesCondition;
Expand Down Expand Up @@ -84,6 +87,7 @@ public class MSKCredentialProvider implements AWSCredentialsProvider, AutoClosea
private static final String AWS_DEBUG_CREDS_KEY = "awsDebugCreds";
private static final String AWS_MAX_RETRIES = "awsMaxRetries";
private static final String AWS_MAX_BACK_OFF_TIME_MS = "awsMaxBackOffTimeMs";
private static final String GLOBAL_REGION = "aws-global";
private static final int DEFAULT_MAX_RETRIES = 3;
private static final int DEFAULT_MAX_BACK_OFF_TIME_MS = 5000;
private static final int BASE_DELAY = 500;
Expand All @@ -105,10 +109,10 @@ public MSKCredentialProvider(Map<String, ?> options) {
}

MSKCredentialProvider(List<AWSCredentialsProvider> providers,
Boolean shouldDebugCreds,
String stsRegion,
int maxRetries,
int maxBackOffTimeMs) {
Boolean shouldDebugCreds,
String stsRegion,
int maxRetries,
int maxBackOffTimeMs) {
List<AWSCredentialsProvider> delegateList = new ArrayList<>(providers);
delegateList.add(getDefaultProvider());
compositeDelegate = new AWSCredentialsProviderChain(delegateList);
Expand Down Expand Up @@ -199,19 +203,19 @@ private void logCallerIdentity(AWSCredentials credentials) {

AWSSecurityTokenService getStsClientForDebuggingCreds(AWSCredentials credentials) {
return AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.withCredentials(new AWSCredentialsProvider() {
@Override
public AWSCredentials getCredentials() {
return credentials;
}

@Override
public void refresh() {

}
})
.build();
.withRegion(stsRegion)
.withCredentials(new AWSCredentialsProvider() {
@Override
public AWSCredentials getCredentials() {
return credentials;
}

@Override
public void refresh() {

}
})
.build();
}

@Override
Expand Down Expand Up @@ -253,7 +257,7 @@ public Boolean shouldDebugCreds() {

public String getStsRegion() {
return Optional.ofNullable((String) optionsMap.get(AWS_STS_REGION))
.orElse("aws-global");
.orElse(GLOBAL_REGION);
}

public int getMaxRetries() {
Expand All @@ -267,6 +271,27 @@ public int getMaxBackOffTimeMs() {
.orElse(DEFAULT_MAX_BACK_OFF_TIME_MS);
}

public EndpointConfiguration buildEndpointConfiguration(String stsRegion){
Region region = RegionUtils.getRegion(stsRegion);
String serviceEndpoint = region.getServiceEndpoint("sts");
EndpointConfiguration endpointConfiguration =
new EndpointConfiguration(
String.format(serviceEndpoint, stsRegion),
stsRegion);

return endpointConfiguration;
}

private AWSSecurityTokenServiceClientBuilder getStsClientBuilder(String stsRegion) {
if (GLOBAL_REGION.equals(stsRegion)) {
return AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion);
} else {
return AWSSecurityTokenServiceClientBuilder.standard()
.withEndpointConfiguration(buildEndpointConfiguration(stsRegion));
}
}

private Optional<EnhancedProfileCredentialsProvider> getProfileProvider() {
return Optional.ofNullable(optionsMap.get(AWS_PROFILE_NAME_KEY)).map(p -> {
if (log.isDebugEnabled()) {
Expand Down Expand Up @@ -298,7 +323,6 @@ private Optional<STSAssumeRoleSessionCredentialsProvider> getStsRoleProvider() {
sessionToken != null
? new BasicSessionCredentials(accessKey, secretKey, sessionToken)
: new BasicAWSCredentials(accessKey, secretKey));

return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion, credentials);
}
else if (externalId != null) {
Expand All @@ -311,37 +335,25 @@ else if (externalId != null) {

STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion) {
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.build();
return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
.withStsClient(stsClient)
.withStsClient(getStsClientBuilder(stsRegion).build())
.build();
}

STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion,
AWSCredentialsProvider credentials) {
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.withCredentials(credentials)
.build();

return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
.withStsClient(stsClient)
.withStsClient(getStsClientBuilder(stsRegion).withCredentials(credentials).build())
.build();
}

STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String externalId,
String sessionName,
String stsRegion) {
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.build();

return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
.withStsClient(stsClient)
.withStsClient(getStsClientBuilder(stsRegion).build())
.withExternalId(externalId)
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.amazonaws.client.builder.AwsClientBuilder;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

Expand Down Expand Up @@ -312,6 +314,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
assertEquals("eu-west-1", stsRegion);
AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint());
return mockStsRoleProvider;
}
};
Expand Down Expand Up @@ -347,6 +351,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
assertEquals(TEST_ROLE_EXTERNAL_ID, externalId);
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
assertEquals("eu-west-1", stsRegion);
AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint());
return mockStsRoleProvider;
}
};
Expand Down Expand Up @@ -531,10 +537,10 @@ protected AWSCredentialsProviderChain getDefaultProvider() {
}

private MSKCredentialProvider.ProviderBuilder getProviderBuilder(STSAssumeRoleSessionCredentialsProvider mockStsRoleProvider,
Map<String, String> optionsMap, String s) {
Map<String, String> optionsMap, String s) {
return new MSKCredentialProvider.ProviderBuilder(optionsMap) {
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion) {
String sessionName, String stsRegion) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(s, sessionName);
return mockStsRoleProvider;
Expand All @@ -543,11 +549,11 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
}

private MSKCredentialProvider.ProviderBuilder getProviderBuilderWithCredentials(STSAssumeRoleSessionCredentialsProvider mockStsRoleProvider,
Map<String, String> optionsMap, String s) {
Map<String, String> optionsMap, String s) {
return new MSKCredentialProvider.ProviderBuilder(optionsMap) {
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion,
AWSCredentialsProvider credentials) {
String sessionName, String stsRegion,
AWSCredentialsProvider credentials) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(s, sessionName);
return mockStsRoleProvider;
Expand Down
Loading