diff --git a/plugin/build.gradle b/plugin/build.gradle index f0bad12c2d..5239dd81b4 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -108,6 +108,8 @@ configurations.all { resolutionStrategy.force "com.squareup.okhttp3:okhttp:4.9.3" resolutionStrategy.force "joda-time:joda-time:2.10.12" resolutionStrategy.force "org.slf4j:slf4j-api:1.7.36" + resolutionStrategy.force "org.apache.httpcomponents:httpcore:4.4.15" + resolutionStrategy.force "org.apache.httpcomponents:httpclient:4.5.13" } compileJava { options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor']) diff --git a/prometheus/build.gradle b/prometheus/build.gradle index 7cf1e56085..ca70813e58 100644 --- a/prometheus/build.gradle +++ b/prometheus/build.gradle @@ -22,6 +22,8 @@ dependencies { implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: "${versions.jackson}" implementation group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.9.3' implementation 'com.github.babbel:okhttp-aws-signer:1.0.2' + implementation group: 'com.amazonaws', name: 'aws-java-sdk-core', version: '1.12.1' + implementation group: 'com.amazonaws', name: 'aws-java-sdk-sts', version: '1.12.1' implementation group: 'org.json', name: 'json', version: '20180813' testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') diff --git a/prometheus/src/main/java/org/opensearch/sql/prometheus/authinterceptors/AwsSigningInterceptor.java b/prometheus/src/main/java/org/opensearch/sql/prometheus/authinterceptors/AwsSigningInterceptor.java index f3d91c55a2..56e66431fd 100644 --- a/prometheus/src/main/java/org/opensearch/sql/prometheus/authinterceptors/AwsSigningInterceptor.java +++ b/prometheus/src/main/java/org/opensearch/sql/prometheus/authinterceptors/AwsSigningInterceptor.java @@ -7,6 +7,9 @@ package org.opensearch.sql.prometheus.authinterceptors; +import com.amazonaws.auth.AWSCredentials; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; import com.babbel.mobile.android.commons.okhttpawssigner.OkHttpAwsV4Signer; import java.io.IOException; import java.time.ZoneId; @@ -16,29 +19,29 @@ import okhttp3.Interceptor; import okhttp3.Request; import okhttp3.Response; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; public class AwsSigningInterceptor implements Interceptor { private OkHttpAwsV4Signer okHttpAwsV4Signer; - private String accessKey; + private AWSCredentialsProvider awsCredentialsProvider; - private String secretKey; + private static final Logger LOG = LogManager.getLogger(); /** * AwsSigningInterceptor which intercepts http requests * and adds required headers for sigv4 authentication. * - * @param accessKey accessKey. - * @param secretKey secretKey. + * @param awsCredentialsProvider awsCredentialsProvider. * @param region region. * @param serviceName serviceName. */ - public AwsSigningInterceptor(@NonNull String accessKey, @NonNull String secretKey, + public AwsSigningInterceptor(@NonNull AWSCredentialsProvider awsCredentialsProvider, @NonNull String region, @NonNull String serviceName) { this.okHttpAwsV4Signer = new OkHttpAwsV4Signer(region, serviceName); - this.accessKey = accessKey; - this.secretKey = secretKey; + this.awsCredentialsProvider = awsCredentialsProvider; } @Override @@ -48,11 +51,21 @@ public Response intercept(Interceptor.Chain chain) throws IOException { DateTimeFormatter timestampFormat = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'") .withZone(ZoneId.of("GMT")); - Request newRequest = request.newBuilder() + + Request.Builder newRequestBuilder = request.newBuilder() .addHeader("x-amz-date", timestampFormat.format(ZonedDateTime.now())) - .addHeader("host", request.url().host()) - .build(); - Request signed = okHttpAwsV4Signer.sign(newRequest, accessKey, secretKey); + .addHeader("host", request.url().host()); + + AWSCredentials awsCredentials = awsCredentialsProvider.getCredentials(); + if (awsCredentialsProvider instanceof STSAssumeRoleSessionCredentialsProvider) { + newRequestBuilder.addHeader("x-amz-security-token", + ((STSAssumeRoleSessionCredentialsProvider) awsCredentialsProvider) + .getCredentials() + .getSessionToken()); + } + Request newRequest = newRequestBuilder.build(); + Request signed = okHttpAwsV4Signer.sign(newRequest, + awsCredentials.getAWSAccessKeyId(), awsCredentials.getAWSSecretKey()); return chain.proceed(signed); } diff --git a/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/PrometheusStorageFactory.java b/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/PrometheusStorageFactory.java index 4e8b30af2f..d65f315c8a 100644 --- a/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/PrometheusStorageFactory.java +++ b/prometheus/src/main/java/org/opensearch/sql/prometheus/storage/PrometheusStorageFactory.java @@ -7,6 +7,8 @@ package org.opensearch.sql.prometheus.storage; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; import java.net.URI; import java.net.URISyntaxException; import java.util.HashSet; @@ -75,7 +77,8 @@ private OkHttpClient getHttpClient(Map config) { } else if (AuthenticationType.AWSSIGV4AUTH.equals(authenticationType)) { validateFieldsInConfig(config, Set.of(REGION, ACCESS_KEY, SECRET_KEY)); okHttpClient.addInterceptor(new AwsSigningInterceptor( - config.get(ACCESS_KEY), config.get(SECRET_KEY), + new AWSStaticCredentialsProvider( + new BasicAWSCredentials(config.get(ACCESS_KEY), config.get(SECRET_KEY))), config.get(REGION), "aps")); } else { throw new IllegalArgumentException( diff --git a/prometheus/src/test/java/org/opensearch/sql/prometheus/authinterceptors/AwsSigningInterceptorTest.java b/prometheus/src/test/java/org/opensearch/sql/prometheus/authinterceptors/AwsSigningInterceptorTest.java index a9224bf80f..5d5471edc0 100644 --- a/prometheus/src/test/java/org/opensearch/sql/prometheus/authinterceptors/AwsSigningInterceptorTest.java +++ b/prometheus/src/test/java/org/opensearch/sql/prometheus/authinterceptors/AwsSigningInterceptorTest.java @@ -10,6 +10,12 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSSessionCredentials; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; +import com.amazonaws.auth.STSSessionCredentialsProvider; import lombok.SneakyThrows; import okhttp3.Interceptor; import okhttp3.Request; @@ -30,16 +36,19 @@ public class AwsSigningInterceptorTest { @Captor ArgumentCaptor requestArgumentCaptor; + @Mock + private STSAssumeRoleSessionCredentialsProvider stsAssumeRoleSessionCredentialsProvider; + @Test void testConstructors() { Assertions.assertThrows(NullPointerException.class, () -> - new AwsSigningInterceptor(null, "secretKey", "us-east-1", "aps")); + new AwsSigningInterceptor(null, "us-east-1", "aps")); Assertions.assertThrows(NullPointerException.class, () -> - new AwsSigningInterceptor("accessKey", null, "us-east-1", "aps")); + new AwsSigningInterceptor(getStaticAWSCredentialsProvider("accessKey", "secretKey"), null, + "aps")); Assertions.assertThrows(NullPointerException.class, () -> - new AwsSigningInterceptor("accessKey", "secretKey", null, "aps")); - Assertions.assertThrows(NullPointerException.class, () -> - new AwsSigningInterceptor("accessKey", "secretKey", "us-east-1", null)); + new AwsSigningInterceptor(getStaticAWSCredentialsProvider("accessKey", "secretKey"), + "us-east-1", null)); } @Test @@ -49,7 +58,9 @@ void testIntercept() { .url("http://localhost:9090") .build()); AwsSigningInterceptor awsSigningInterceptor - = new AwsSigningInterceptor("testAccessKey", "testSecretKey", "us-east-1", "aps"); + = new AwsSigningInterceptor( + getStaticAWSCredentialsProvider("testAccessKey", "testSecretKey"), + "us-east-1", "aps"); awsSigningInterceptor.intercept(chain); verify(chain).proceed(requestArgumentCaptor.capture()); Request request = requestArgumentCaptor.getValue(); @@ -58,4 +69,51 @@ void testIntercept() { Assertions.assertNotNull(request.headers("host")); } + + @Test + @SneakyThrows + void testSTSCredentialsProviderInterceptor() { + when(chain.request()).thenReturn(new Request.Builder() + .url("http://localhost:9090") + .build()); + when(stsAssumeRoleSessionCredentialsProvider.getCredentials()) + .thenReturn(getAWSSessionCredentials()); + AwsSigningInterceptor awsSigningInterceptor + = new AwsSigningInterceptor(stsAssumeRoleSessionCredentialsProvider, + "us-east-1", "aps"); + awsSigningInterceptor.intercept(chain); + verify(chain).proceed(requestArgumentCaptor.capture()); + Request request = requestArgumentCaptor.getValue(); + Assertions.assertNotNull(request.headers("Authorization")); + Assertions.assertNotNull(request.headers("x-amz-date")); + Assertions.assertNotNull(request.headers("host")); + Assertions.assertEquals("session_token", + request.headers("x-amz-security-token").get(0)); + } + + + private AWSCredentialsProvider getStaticAWSCredentialsProvider(String accessKey, + String secretKey) { + return new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey)); + } + + private AWSSessionCredentials getAWSSessionCredentials() { + return new AWSSessionCredentials() { + @Override + public String getSessionToken() { + return "session_token"; + } + + @Override + public String getAWSAccessKeyId() { + return "access_key"; + } + + @Override + public String getAWSSecretKey() { + return "secret_key"; + } + }; + } + }