Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AWSSigV4 auth to support different AWSCredentialProviders #1389

Merged
merged 2 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ configurations.all {
resolutionStrategy.force "com.squareup.okhttp3:okhttp:4.9.3"
resolutionStrategy.force "joda-time:joda-time:2.10.12"
resolutionStrategy.force "org.slf4j:slf4j-api:1.7.36"
resolutionStrategy.force "org.apache.httpcomponents:httpcore:4.4.15"
resolutionStrategy.force "org.apache.httpcomponents:httpclient:4.5.13"
}
compileJava {
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
Expand Down
2 changes: 2 additions & 0 deletions prometheus/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ dependencies {
implementation group: 'com.fasterxml.jackson.dataformat', name: 'jackson-dataformat-cbor', version: "${versions.jackson}"
implementation group: 'com.squareup.okhttp3', name: 'okhttp', version: '4.9.3'
implementation 'com.github.babbel:okhttp-aws-signer:1.0.2'
implementation group: 'com.amazonaws', name: 'aws-java-sdk-core', version: '1.12.1'
implementation group: 'com.amazonaws', name: 'aws-java-sdk-sts', version: '1.12.1'
implementation group: 'org.json', name: 'json', version: '20180813'

testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

package org.opensearch.sql.prometheus.authinterceptors;

import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
import com.babbel.mobile.android.commons.okhttpawssigner.OkHttpAwsV4Signer;
import java.io.IOException;
import java.time.ZoneId;
Expand All @@ -16,29 +19,29 @@
import okhttp3.Interceptor;
import okhttp3.Request;
import okhttp3.Response;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

public class AwsSigningInterceptor implements Interceptor {

private OkHttpAwsV4Signer okHttpAwsV4Signer;

private String accessKey;
private AWSCredentialsProvider awsCredentialsProvider;

private String secretKey;
private static final Logger LOG = LogManager.getLogger();

/**
* AwsSigningInterceptor which intercepts http requests
* and adds required headers for sigv4 authentication.
*
* @param accessKey accessKey.
* @param secretKey secretKey.
* @param awsCredentialsProvider awsCredentialsProvider.
* @param region region.
* @param serviceName serviceName.
*/
public AwsSigningInterceptor(@NonNull String accessKey, @NonNull String secretKey,
public AwsSigningInterceptor(@NonNull AWSCredentialsProvider awsCredentialsProvider,
@NonNull String region, @NonNull String serviceName) {
this.okHttpAwsV4Signer = new OkHttpAwsV4Signer(region, serviceName);
this.accessKey = accessKey;
this.secretKey = secretKey;
this.awsCredentialsProvider = awsCredentialsProvider;
}

@Override
Expand All @@ -48,11 +51,21 @@ public Response intercept(Interceptor.Chain chain) throws IOException {
DateTimeFormatter timestampFormat = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'")
.withZone(ZoneId.of("GMT"));

Request newRequest = request.newBuilder()

Request.Builder newRequestBuilder = request.newBuilder()
.addHeader("x-amz-date", timestampFormat.format(ZonedDateTime.now()))
.addHeader("host", request.url().host())
.build();
Request signed = okHttpAwsV4Signer.sign(newRequest, accessKey, secretKey);
.addHeader("host", request.url().host());

AWSCredentials awsCredentials = awsCredentialsProvider.getCredentials();
if (awsCredentialsProvider instanceof STSAssumeRoleSessionCredentialsProvider) {
newRequestBuilder.addHeader("x-amz-security-token",
((STSAssumeRoleSessionCredentialsProvider) awsCredentialsProvider)
.getCredentials()
.getSessionToken());
}
Request newRequest = newRequestBuilder.build();
Request signed = okHttpAwsV4Signer.sign(newRequest,
awsCredentials.getAWSAccessKeyId(), awsCredentials.getAWSSecretKey());
return chain.proceed(signed);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

package org.opensearch.sql.prometheus.storage;

import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.HashSet;
Expand Down Expand Up @@ -75,7 +77,8 @@ private OkHttpClient getHttpClient(Map<String, String> config) {
} else if (AuthenticationType.AWSSIGV4AUTH.equals(authenticationType)) {
validateFieldsInConfig(config, Set.of(REGION, ACCESS_KEY, SECRET_KEY));
okHttpClient.addInterceptor(new AwsSigningInterceptor(
config.get(ACCESS_KEY), config.get(SECRET_KEY),
new AWSStaticCredentialsProvider(
new BasicAWSCredentials(config.get(ACCESS_KEY), config.get(SECRET_KEY))),
config.get(REGION), "aps"));
} else {
throw new IllegalArgumentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSSessionCredentials;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
import com.amazonaws.auth.STSSessionCredentialsProvider;
import lombok.SneakyThrows;
import okhttp3.Interceptor;
import okhttp3.Request;
Expand All @@ -30,16 +36,19 @@ public class AwsSigningInterceptorTest {
@Captor
ArgumentCaptor<Request> requestArgumentCaptor;

@Mock
private STSAssumeRoleSessionCredentialsProvider stsAssumeRoleSessionCredentialsProvider;

@Test
void testConstructors() {
Assertions.assertThrows(NullPointerException.class, () ->
new AwsSigningInterceptor(null, "secretKey", "us-east-1", "aps"));
new AwsSigningInterceptor(null, "us-east-1", "aps"));
Assertions.assertThrows(NullPointerException.class, () ->
new AwsSigningInterceptor("accessKey", null, "us-east-1", "aps"));
new AwsSigningInterceptor(getStaticAWSCredentialsProvider("accessKey", "secretKey"), null,
"aps"));
Assertions.assertThrows(NullPointerException.class, () ->
new AwsSigningInterceptor("accessKey", "secretKey", null, "aps"));
Assertions.assertThrows(NullPointerException.class, () ->
new AwsSigningInterceptor("accessKey", "secretKey", "us-east-1", null));
new AwsSigningInterceptor(getStaticAWSCredentialsProvider("accessKey", "secretKey"),
"us-east-1", null));
}

@Test
Expand All @@ -49,7 +58,9 @@ void testIntercept() {
.url("http://localhost:9090")
.build());
AwsSigningInterceptor awsSigningInterceptor
= new AwsSigningInterceptor("testAccessKey", "testSecretKey", "us-east-1", "aps");
= new AwsSigningInterceptor(
getStaticAWSCredentialsProvider("testAccessKey", "testSecretKey"),
"us-east-1", "aps");
awsSigningInterceptor.intercept(chain);
verify(chain).proceed(requestArgumentCaptor.capture());
Request request = requestArgumentCaptor.getValue();
Expand All @@ -58,4 +69,51 @@ void testIntercept() {
Assertions.assertNotNull(request.headers("host"));
}


@Test
@SneakyThrows
void testSTSCredentialsProviderInterceptor() {
when(chain.request()).thenReturn(new Request.Builder()
.url("http://localhost:9090")
.build());
when(stsAssumeRoleSessionCredentialsProvider.getCredentials())
.thenReturn(getAWSSessionCredentials());
AwsSigningInterceptor awsSigningInterceptor
= new AwsSigningInterceptor(stsAssumeRoleSessionCredentialsProvider,
"us-east-1", "aps");
awsSigningInterceptor.intercept(chain);
verify(chain).proceed(requestArgumentCaptor.capture());
Request request = requestArgumentCaptor.getValue();
Assertions.assertNotNull(request.headers("Authorization"));
Assertions.assertNotNull(request.headers("x-amz-date"));
Assertions.assertNotNull(request.headers("host"));
Assertions.assertEquals("session_token",
request.headers("x-amz-security-token").get(0));
}


private AWSCredentialsProvider getStaticAWSCredentialsProvider(String accessKey,
String secretKey) {
return new AWSStaticCredentialsProvider(new BasicAWSCredentials(accessKey, secretKey));
}

private AWSSessionCredentials getAWSSessionCredentials() {
return new AWSSessionCredentials() {
@Override
public String getSessionToken() {
return "session_token";
}

@Override
public String getAWSAccessKeyId() {
return "access_key";
}

@Override
public String getAWSSecretKey() {
return "secret_key";
}
};
}

}