Skip to content

Commit

Permalink
Add AOSS remote cluster connection configuration support (#125)
Browse files Browse the repository at this point in the history
Signed-off-by: Manasvini B S <[email protected]>
  • Loading branch information
manasvinibs authored Nov 6, 2024
1 parent 83a9bc6 commit 6c02a69
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 33 deletions.
26 changes: 14 additions & 12 deletions src/main/java/org/opensearch/jdbc/ConnectionImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

package org.opensearch.jdbc;

import org.opensearch.jdbc.auth.AuthenticationType;
import org.opensearch.jdbc.config.ConnectionConfig;
import org.opensearch.jdbc.internal.JdbcWrapper;
import org.opensearch.jdbc.internal.Version;
Expand Down Expand Up @@ -83,21 +84,22 @@ log, new SQLNonTransientException("Could not initialize transport for the connec

log.debug(() -> logMessage("Initialized Transport: %s, Protocol: %s", transport, protocol));

try {
ConnectionResponse connectionResponse = this.protocol.connect(connectionConfig.getLoginTimeout() * 1000);
this.clusterMetadata = connectionResponse.getClusterMetadata();
this.open = true;
} catch (HttpException ex) {
if (ex.getStatusCode() == 401) {
logAndThrowSQLException(log, new SQLException("Connection error " + ex.getMessage(),
INCORRECT_CREDENTIALS_SQLSTATE, ex));
} else {
if (connectionConfig.getAuthenticationType() != AuthenticationType.AWS_SIGV4_SERVERLESS) {
try {
ConnectionResponse connectionResponse = this.protocol.connect(connectionConfig.getLoginTimeout() * 1000);
this.clusterMetadata = connectionResponse.getClusterMetadata();
this.open = true;
} catch (HttpException ex) {
if (ex.getStatusCode() == 401) {
logAndThrowSQLException(log, new SQLException("Connection error " + ex.getMessage(),
INCORRECT_CREDENTIALS_SQLSTATE, ex));
} else {
logAndThrowSQLException(log, new SQLException("Connection error " + ex.getMessage(), ex));
}
} catch (ResponseException | IOException ex) {
logAndThrowSQLException(log, new SQLException("Connection error " + ex.getMessage(), ex));
}
} catch (ResponseException | IOException ex) {
logAndThrowSQLException(log, new SQLException("Connection error " + ex.getMessage(), ex));
}

}

public String getUser() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,10 @@ public enum AuthenticationType {
/**
* AWS Signature V4
*/
AWS_SIGV4;
AWS_SIGV4,

/**
* AWS Signature V4 for AOSS Serverless collection
*/
AWS_SIGV4_SERVERLESS;
}
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ private void validateConfig() throws ConnectionPropertyException {
throw new ConnectionPropertyException(authConnectionProperty.getKey(),
"Basic authentication requires a valid username but none was provided.");

} else if (authenticationType == AuthenticationType.AWS_SIGV4 &&
} else if ((authenticationType == AuthenticationType.AWS_SIGV4 || authenticationType == AuthenticationType.AWS_SIGV4_SERVERLESS) &&
regionConnectionProperty.getValue() == null) {

// aws sdk auto-detection does not work for AWS ES endpoints
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ public JsonHttpResponseHandler getJsonHttpResponseHandler() {
@Override
public ConnectionResponse connect(int timeout) throws ResponseException, IOException {
try (CloseableHttpResponse response = transport.doGet(
"/",
defaultEmptyRequestBodyJsonHeaders,
null, timeout)) {
"/",
defaultEmptyRequestBodyJsonHeaders,
null, timeout)) {

return jsonHttpResponseHandler.handleResponse(response, this::processConnectionResponse);

Expand All @@ -79,10 +79,10 @@ public ConnectionResponse connect(int timeout) throws ResponseException, IOExcep
@Override
public QueryResponse execute(QueryRequest request) throws ResponseException, IOException {
try (CloseableHttpResponse response = transport.doPost(
sqlContextPath,
defaultJsonHeaders,
defaultJdbcParams,
buildQueryRequestBody(request), 0)) {
sqlContextPath,
defaultJsonHeaders,
defaultJdbcParams,
buildQueryRequestBody(request), 0)) {

return jsonHttpResponseHandler.handleResponse(response, this::processQueryResponse);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@

package org.opensearch.jdbc.transport.http;

import com.amazonaws.auth.AWS4Signer;
import org.opensearch.jdbc.auth.AuthenticationType;
import org.opensearch.jdbc.config.ConnectionConfig;
import org.opensearch.jdbc.logging.Logger;
import org.opensearch.jdbc.logging.LoggingSource;
import org.opensearch.jdbc.transport.TransportException;
import org.opensearch.jdbc.transport.http.auth.aws.AWSRequestSigningApacheInterceptor;
import com.amazonaws.auth.AWS4Signer;
import com.amazonaws.auth.AWS4UnsignedPayloadSigner;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import org.apache.http.Header;
Expand Down Expand Up @@ -121,8 +122,21 @@ public ApacheHttpTransport(ConnectionConfig connectionConfig, Logger log, String
signer,
provider,
connectionConfig.tunnelHost()));
}
} else if (connectionConfig.getAuthenticationType() == AuthenticationType.AWS_SIGV4_SERVERLESS) {
AWS4UnsignedPayloadSigner signer = new AWS4UnsignedPayloadSigner();
signer.setServiceName("aoss");
signer.setRegionName(connectionConfig.getRegion());

AWSCredentialsProvider provider = connectionConfig.getAwsCredentialsProvider() != null ?
connectionConfig.getAwsCredentialsProvider() : new DefaultAWSCredentialsProviderChain();

httpClientBuilder.addInterceptorLast(
new AWSRequestSigningApacheInterceptor(
"aoss",
signer,
provider,
null));
}
// TODO - can apply settings retry & backoff
this.httpClient = httpClientBuilder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
package org.opensearch.jdbc.transport.http.auth.aws;

import com.amazonaws.DefaultRequest;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.Signer;
import com.amazonaws.http.HttpMethodName;
Expand All @@ -23,13 +22,10 @@
import org.apache.http.message.BasicHeader;
import org.apache.http.protocol.HttpContext;

import java.io.IOException;
import java.io.*;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.*;

import static org.apache.http.protocol.HttpCoreContext.HTTP_TARGET_HOST;

Expand Down Expand Up @@ -113,18 +109,31 @@ public void process(final HttpRequest request, final HttpContext context)
if (request instanceof HttpEntityEnclosingRequest) {
HttpEntityEnclosingRequest httpEntityEnclosingRequest =
(HttpEntityEnclosingRequest) request;
if (httpEntityEnclosingRequest.getEntity() != null) {

if (httpEntityEnclosingRequest.getEntity() == null) {
signableRequest.setContent(new ByteArrayInputStream(new byte[0]));
} else {
signableRequest.setContent(httpEntityEnclosingRequest.getEntity().getContent());
}
}

signableRequest.setParameters(nvpToMapParams(uriBuilder.getQueryParams()));
signableRequest.setHeaders(headerArrayToMap(request.getAllHeaders()));

Map<String, String> cleanedHeadersBeforeSign = headerArrayToMap(request.getAllHeaders());
signableRequest.setHeaders(cleanedHeadersBeforeSign);

// Sign it
signer.sign(signableRequest, awsCredentialsProvider.getCredentials());

// Now copy everything back
Header[] headers = request.getHeaders("content-length");
request.setHeaders(mapToHeaderArray(signableRequest.getHeaders()));
if (headers != null) {
Arrays.stream(headers)
.filter(h -> !"0".equals(h.getValue()))
.forEach(h -> request.addHeader(h));
}

if (request instanceof HttpEntityEnclosingRequest) {
HttpEntityEnclosingRequest httpEntityEnclosingRequest =
(HttpEntityEnclosingRequest) request;
Expand Down Expand Up @@ -172,8 +181,7 @@ private static Map<String, String> headerArrayToMap(final Header[] headers) {
*/
private static boolean skipHeader(final Header header) {
return ("content-length".equalsIgnoreCase(header.getName())
&& "0".equals(header.getValue())) // Strip Content-Length: 0
|| "host".equalsIgnoreCase(header.getName()); // Host comes from endpoint
|| "host".equalsIgnoreCase(header.getName())); // Host comes from endpoint
}

/**
Expand Down

0 comments on commit 6c02a69

Please sign in to comment.