Skip to content

Commit

Permalink
Merge pull request #144 from sidyag/main
Browse files Browse the repository at this point in the history
Enabling STS regional endpoint to be used by specifying STS region.
  • Loading branch information
sidyag authored Dec 4, 2023
2 parents 5ffbb58 + c2b72fc commit 29b3372
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 42 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ The recommended way to use this library is to consume it from maven central whil
<dependency>
<groupId>software.amazon.msk</groupId>
<artifactId>aws-msk-iam-auth</artifactId>
<version>2.0.0</version>
<version>2.0.1</version>
</dependency>
```
If you want to use it with a pre-existing Kafka client, you could build the uber jar and place it in the Kafka client's
Expand Down Expand Up @@ -519,6 +519,9 @@ public static String UriEncode(CharSequence input, boolean encodeSlash) {

## Release Notes

### Release 2.0.1
- Enable STS region support to set regional endpoints

### Release 2.0.0
- Add SASL/OAUTHBEARER mechanism with IAM

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
import com.amazonaws.auth.SystemPropertiesCredentialsProvider;
import com.amazonaws.auth.WebIdentityTokenCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration;
import com.amazonaws.regions.Region;
import com.amazonaws.regions.RegionUtils;
import com.amazonaws.retry.PredefinedBackoffStrategies;
import com.amazonaws.retry.v2.AndRetryCondition;
import com.amazonaws.retry.v2.MaxNumberOfRetriesCondition;
Expand Down Expand Up @@ -84,6 +87,7 @@ public class MSKCredentialProvider implements AWSCredentialsProvider, AutoClosea
private static final String AWS_DEBUG_CREDS_KEY = "awsDebugCreds";
private static final String AWS_MAX_RETRIES = "awsMaxRetries";
private static final String AWS_MAX_BACK_OFF_TIME_MS = "awsMaxBackOffTimeMs";
private static final String GLOBAL_REGION = "aws-global";
private static final int DEFAULT_MAX_RETRIES = 3;
private static final int DEFAULT_MAX_BACK_OFF_TIME_MS = 5000;
private static final int BASE_DELAY = 500;
Expand All @@ -105,10 +109,10 @@ public MSKCredentialProvider(Map<String, ?> options) {
}

MSKCredentialProvider(List<AWSCredentialsProvider> providers,
Boolean shouldDebugCreds,
String stsRegion,
int maxRetries,
int maxBackOffTimeMs) {
Boolean shouldDebugCreds,
String stsRegion,
int maxRetries,
int maxBackOffTimeMs) {
List<AWSCredentialsProvider> delegateList = new ArrayList<>(providers);
delegateList.add(getDefaultProvider());
compositeDelegate = new AWSCredentialsProviderChain(delegateList);
Expand Down Expand Up @@ -199,19 +203,19 @@ private void logCallerIdentity(AWSCredentials credentials) {

AWSSecurityTokenService getStsClientForDebuggingCreds(AWSCredentials credentials) {
return AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.withCredentials(new AWSCredentialsProvider() {
@Override
public AWSCredentials getCredentials() {
return credentials;
}

@Override
public void refresh() {

}
})
.build();
.withRegion(stsRegion)
.withCredentials(new AWSCredentialsProvider() {
@Override
public AWSCredentials getCredentials() {
return credentials;
}

@Override
public void refresh() {

}
})
.build();
}

@Override
Expand Down Expand Up @@ -253,7 +257,7 @@ public Boolean shouldDebugCreds() {

public String getStsRegion() {
return Optional.ofNullable((String) optionsMap.get(AWS_STS_REGION))
.orElse("aws-global");
.orElse(GLOBAL_REGION);
}

public int getMaxRetries() {
Expand All @@ -267,6 +271,27 @@ public int getMaxBackOffTimeMs() {
.orElse(DEFAULT_MAX_BACK_OFF_TIME_MS);
}

public EndpointConfiguration buildEndpointConfiguration(String stsRegion){
Region region = RegionUtils.getRegion(stsRegion);
String serviceEndpoint = region.getServiceEndpoint("sts");
EndpointConfiguration endpointConfiguration =
new EndpointConfiguration(
String.format(serviceEndpoint, stsRegion),
stsRegion);

return endpointConfiguration;
}

private AWSSecurityTokenServiceClientBuilder getStsClientBuilder(String stsRegion) {
if (GLOBAL_REGION.equals(stsRegion)) {
return AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion);
} else {
return AWSSecurityTokenServiceClientBuilder.standard()
.withEndpointConfiguration(buildEndpointConfiguration(stsRegion));
}
}

private Optional<EnhancedProfileCredentialsProvider> getProfileProvider() {
return Optional.ofNullable(optionsMap.get(AWS_PROFILE_NAME_KEY)).map(p -> {
if (log.isDebugEnabled()) {
Expand Down Expand Up @@ -298,7 +323,6 @@ private Optional<STSAssumeRoleSessionCredentialsProvider> getStsRoleProvider() {
sessionToken != null
? new BasicSessionCredentials(accessKey, secretKey, sessionToken)
: new BasicAWSCredentials(accessKey, secretKey));

return createSTSRoleCredentialProvider((String) p, sessionName, stsRegion, credentials);
}
else if (externalId != null) {
Expand All @@ -311,37 +335,25 @@ else if (externalId != null) {

STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion) {
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.build();
return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
.withStsClient(stsClient)
.withStsClient(getStsClientBuilder(stsRegion).build())
.build();
}

STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion,
AWSCredentialsProvider credentials) {
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.withCredentials(credentials)
.build();

return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
.withStsClient(stsClient)
.withStsClient(getStsClientBuilder(stsRegion).withCredentials(credentials).build())
.build();
}

STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String externalId,
String sessionName,
String stsRegion) {
AWSSecurityTokenService stsClient = AWSSecurityTokenServiceClientBuilder.standard()
.withRegion(stsRegion)
.build();

return new STSAssumeRoleSessionCredentialsProvider.Builder(roleArn, sessionName)
.withStsClient(stsClient)
.withStsClient(getStsClientBuilder(stsRegion).build())
.withExternalId(externalId)
.build();
}
Expand Down
4 changes: 2 additions & 2 deletions src/main/resources/version.properties
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#Updated on 2023-11-08T16:12:00Z
#Updated on 2023-12-04T11:23:00Z
platform=java
version=2.0.0
version=2.0.1
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import com.amazonaws.client.builder.AwsClientBuilder;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

Expand Down Expand Up @@ -312,6 +314,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
assertEquals("eu-west-1", stsRegion);
AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint());
return mockStsRoleProvider;
}
};
Expand Down Expand Up @@ -347,6 +351,8 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
assertEquals(TEST_ROLE_EXTERNAL_ID, externalId);
assertEquals(TEST_ROLE_SESSION_NAME, sessionName);
assertEquals("eu-west-1", stsRegion);
AwsClientBuilder.EndpointConfiguration endpointConfiguration = buildEndpointConfiguration(stsRegion);
assertEquals("sts.eu-west-1.amazonaws.com", endpointConfiguration.getServiceEndpoint());
return mockStsRoleProvider;
}
};
Expand Down Expand Up @@ -531,10 +537,10 @@ protected AWSCredentialsProviderChain getDefaultProvider() {
}

private MSKCredentialProvider.ProviderBuilder getProviderBuilder(STSAssumeRoleSessionCredentialsProvider mockStsRoleProvider,
Map<String, String> optionsMap, String s) {
Map<String, String> optionsMap, String s) {
return new MSKCredentialProvider.ProviderBuilder(optionsMap) {
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion) {
String sessionName, String stsRegion) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(s, sessionName);
return mockStsRoleProvider;
Expand All @@ -543,11 +549,11 @@ STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String r
}

private MSKCredentialProvider.ProviderBuilder getProviderBuilderWithCredentials(STSAssumeRoleSessionCredentialsProvider mockStsRoleProvider,
Map<String, String> optionsMap, String s) {
Map<String, String> optionsMap, String s) {
return new MSKCredentialProvider.ProviderBuilder(optionsMap) {
STSAssumeRoleSessionCredentialsProvider createSTSRoleCredentialProvider(String roleArn,
String sessionName, String stsRegion,
AWSCredentialsProvider credentials) {
String sessionName, String stsRegion,
AWSCredentialsProvider credentials) {
assertEquals(TEST_ROLE_ARN, roleArn);
assertEquals(s, sessionName);
return mockStsRoleProvider;
Expand Down

0 comments on commit 29b3372

Please sign in to comment.