From 1213b5430b1bdceee6e78ff325d0c747c6ffa133 Mon Sep 17 00:00:00 2001 From: Vamsi Manohar Date: Wed, 31 Jan 2024 14:00:16 -0800 Subject: [PATCH] Fix wrong 503 error response code Signed-off-by: Vamsi Manohar --- .../rest/RestDataSourceQueryAction.java | 4 +- .../opensearch/sql/ppl/ResourceMonitorIT.java | 2 +- .../sql/legacy/plugin/RestSqlAction.java | 29 +++++++------ .../sql/legacy/plugin/RestSqlStatsAction.java | 6 +-- .../matchtoterm/TermFieldRewriter.java | 10 ++++- .../rewriter/term/TermFieldRewriterTest.java | 10 +++++ .../org/opensearch/sql/plugin/SQLPlugin.java | 2 +- .../sql/plugin/rest/RestPPLQueryAction.java | 3 +- .../sql/plugin/rest/RestPPLStatsAction.java | 6 +-- ...chAsyncQueryJobMetadataStorageService.java | 9 ++++ .../rest/RestAsyncQueryManagementAction.java | 27 ++++++++---- .../rest/model/CreateAsyncQueryRequest.java | 37 +++++++++------- .../AsyncQueryExecutorServiceSpec.java | 2 +- ...yncQueryJobMetadataStorageServiceTest.java | 43 ++++++++++++++++++- 14 files changed, 137 insertions(+), 53 deletions(-) diff --git a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java index 02f87a69f2..779a8bf772 100644 --- a/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java +++ b/datasources/src/main/java/org/opensearch/sql/datasources/rest/RestDataSourceQueryAction.java @@ -8,8 +8,8 @@ package org.opensearch.sql.datasources.rest; import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.NOT_FOUND; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; import static org.opensearch.rest.RestRequest.Method.*; import com.google.common.collect.ImmutableList; @@ -293,7 +293,7 @@ private void handleException(Exception e, RestChannel restChannel) { reportError(restChannel, e, BAD_REQUEST); } else { MetricUtils.incrementNumericalMetric(MetricName.DATASOURCE_FAILED_REQ_COUNT_SYS); - reportError(restChannel, e, SERVICE_UNAVAILABLE); + reportError(restChannel, e, INTERNAL_SERVER_ERROR); } } } diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java index 56b54ba748..eed2369590 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ResourceMonitorIT.java @@ -31,7 +31,7 @@ public void queryExceedResourceLimitShouldFail() throws IOException { String query = String.format("search source=%s age=20", TEST_INDEX_DOG); ResponseException exception = expectThrows(ResponseException.class, () -> executeQuery(query)); - assertEquals(503, exception.getResponse().getStatusLine().getStatusCode()); + assertEquals(500, exception.getResponse().getStatusLine().getStatusCode()); assertThat( exception.getMessage(), Matchers.containsString("resource is not enough to run the" + " query, quit.")); diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java index fc8934dd73..d75f5abc76 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlAction.java @@ -6,8 +6,8 @@ package org.opensearch.sql.legacy.plugin; import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.OK; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; import com.alibaba.druid.sql.parser.ParserException; import com.google.common.collect.ImmutableList; @@ -23,6 +23,7 @@ import java.util.regex.Pattern; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchSecurityException; import org.opensearch.client.Client; import org.opensearch.client.node.NodeClient; import org.opensearch.common.inject.Injector; @@ -171,21 +172,23 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli QueryAction queryAction = explainRequest(client, sqlRequest, format); executeSqlRequest(request, queryAction, client, restChannel); } catch (Exception e) { - logAndPublishMetrics(e); - reportError(restChannel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE); + handleException(restChannel, e); } }, - (restChannel, exception) -> { - logAndPublishMetrics(exception); - reportError( - restChannel, - exception, - isClientError(exception) ? BAD_REQUEST : SERVICE_UNAVAILABLE); - }); + this::handleException); } catch (Exception e) { - logAndPublishMetrics(e); - return channel -> - reportError(channel, e, isClientError(e) ? BAD_REQUEST : SERVICE_UNAVAILABLE); + return channel -> handleException(channel, e); + } + } + + private void handleException(RestChannel restChannel, Exception exception) { + logAndPublishMetrics(exception); + if (exception instanceof OpenSearchSecurityException) { + OpenSearchSecurityException securityException = (OpenSearchSecurityException) exception; + reportError(restChannel, exception, securityException.status()); + } else { + reportError( + restChannel, exception, isClientError(exception) ? BAD_REQUEST : INTERNAL_SERVER_ERROR); } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java index bc0f3c73b8..b95974424b 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/plugin/RestSqlStatsAction.java @@ -5,7 +5,7 @@ package org.opensearch.sql.legacy.plugin; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import com.google.common.collect.ImmutableList; import java.util.Arrays; @@ -77,8 +77,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli return channel -> channel.sendResponse( new BytesRestResponse( - SERVICE_UNAVAILABLE, - ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus()) + INTERNAL_SERVER_ERROR, + ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus()) .toString())); } } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java index 2c837a7b2b..f9744ab841 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/rewriter/matchtoterm/TermFieldRewriter.java @@ -151,7 +151,15 @@ public void collect( indexToType.put(tableName, null); } else if (sqlExprTableSource.getExpr() instanceof SQLBinaryOpExpr) { SQLBinaryOpExpr sqlBinaryOpExpr = (SQLBinaryOpExpr) sqlExprTableSource.getExpr(); - tableName = ((SQLIdentifierExpr) sqlBinaryOpExpr.getLeft()).getName(); + SQLExpr leftSideOfExpression = sqlBinaryOpExpr.getLeft(); + if (leftSideOfExpression instanceof SQLIdentifierExpr) { + tableName = ((SQLIdentifierExpr) sqlBinaryOpExpr.getLeft()).getName(); + } else { + throw new ParserException( + "Left side of the expression [" + + leftSideOfExpression.toString() + + "] is expected to be an identifier"); + } SQLExpr rightSideOfExpression = sqlBinaryOpExpr.getRight(); // This assumes that right side of the expression is different name in query diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java index 44d3e2cbc0..7922d60647 100644 --- a/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java +++ b/legacy/src/test/java/org/opensearch/sql/legacy/unittest/rewriter/term/TermFieldRewriterTest.java @@ -10,6 +10,7 @@ import com.alibaba.druid.sql.SQLUtils; import com.alibaba.druid.sql.ast.expr.SQLQueryExpr; +import com.alibaba.druid.sql.parser.ParserException; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -100,6 +101,15 @@ public void testSelectTheFieldWithConflictMappingShouldThrowException() { rewriteTerm(sql); } + @Test + public void testIssue2391_WithWrongBinaryOperation() { + String sql = "SELECT * from I_THINK/IM/A_URL"; + exception.expect(ParserException.class); + exception.expectMessage( + "Left side of the expression [I_THINK / IM] is expected to be an identifier"); + rewriteTerm(sql); + } + private String rewriteTerm(String sql) { SQLQueryExpr sqlQueryExpr = SqlParserUtils.parse(sql); sqlQueryExpr.accept(new TermFieldRewriter()); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java index f0689a0966..890b340176 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java @@ -325,7 +325,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService( StateStore stateStore = new StateStore(client, clusterService); registerStateStoreMetrics(stateStore); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(stateStore); + new OpensearchAsyncQueryJobMetadataStorageService(stateStore, this.dataSourceService); EMRServerlessClient emrServerlessClient = createEMRServerlessClient(sparkExecutionEngineConfig.getRegion()); JobExecutionResponseReader jobExecutionResponseReader = new JobExecutionResponseReader(client); diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java index d35962be91..4dda11718f 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLQueryAction.java @@ -8,7 +8,6 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.OK; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; import com.google.common.collect.ImmutableList; import java.util.Arrays; @@ -130,7 +129,7 @@ public void onFailure(Exception e) { Metrics.getInstance() .getNumericalMetric(MetricName.PPL_FAILED_REQ_COUNT_SYS) .increment(); - reportError(channel, e, SERVICE_UNAVAILABLE); + reportError(channel, e, INTERNAL_SERVER_ERROR); } } } diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java index 7a51fc282b..e7edd66a54 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/rest/RestPPLStatsAction.java @@ -5,7 +5,7 @@ package org.opensearch.sql.plugin.rest; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import com.google.common.collect.ImmutableList; import java.util.Arrays; @@ -75,8 +75,8 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli return channel -> channel.sendResponse( new BytesRestResponse( - SERVICE_UNAVAILABLE, - ErrorMessageFactory.createErrorMessage(e, SERVICE_UNAVAILABLE.getStatus()) + INTERNAL_SERVER_ERROR, + ErrorMessageFactory.createErrorMessage(e, INTERNAL_SERVER_ERROR.getStatus()) .toString())); } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java index 6de8c35f03..af35e63543 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java +++ b/spark/src/main/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryJobMetadataStorageService.java @@ -11,6 +11,9 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.StringUtils; +import org.opensearch.sql.datasource.DataSourceService; +import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -22,6 +25,8 @@ public class OpensearchAsyncQueryJobMetadataStorageService private final StateStore stateStore; + private final DataSourceService dataSourceService; + @Override public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { AsyncQueryId queryId = asyncQueryJobMetadata.getQueryId(); @@ -31,6 +36,10 @@ public void storeJobMetadata(AsyncQueryJobMetadata asyncQueryJobMetadata) { @Override public Optional getJobMetadata(String qid) { AsyncQueryId queryId = new AsyncQueryId(qid); + if (!dataSourceService.dataSourceExists(queryId.getDataSourceName()) + || StringUtils.isEmpty(queryId.getId())) { + throw new AsyncQueryNotFoundException(String.format("Invalid queryId: %s", qid)); + } return StateStore.getJobMetaData(stateStore, queryId.getDataSourceName()) .apply(queryId.docId()); } diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java index ae4adc6de9..90d5d73696 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/RestAsyncQueryManagementAction.java @@ -6,7 +6,7 @@ package org.opensearch.sql.spark.rest; import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; -import static org.opensearch.core.rest.RestStatus.SERVICE_UNAVAILABLE; +import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.rest.RestStatus.TOO_MANY_REQUESTS; import static org.opensearch.rest.RestRequest.Method.DELETE; import static org.opensearch.rest.RestRequest.Method.GET; @@ -26,10 +26,12 @@ import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestRequest; +import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException; import org.opensearch.sql.datasources.exceptions.ErrorMessage; import org.opensearch.sql.datasources.utils.Scheduler; import org.opensearch.sql.legacy.metrics.MetricName; import org.opensearch.sql.legacy.utils.MetricUtils; +import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException; import org.opensearch.sql.spark.leasemanager.ConcurrencyLimitExceededException; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.transport.TransportCancelAsyncQueryRequestAction; @@ -112,12 +114,12 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient } } - private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient) - throws IOException { - MetricUtils.incrementNumericalMetric(MetricName.ASYNC_QUERY_CREATE_API_REQUEST_COUNT); - CreateAsyncQueryRequest submitJobRequest = - CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser()); - return restChannel -> + private RestChannelConsumer executePostRequest(RestRequest restRequest, NodeClient nodeClient) { + return restChannel -> { + try { + MetricUtils.incrementNumericalMetric(MetricName.ASYNC_QUERY_CREATE_API_REQUEST_COUNT); + CreateAsyncQueryRequest submitJobRequest = + CreateAsyncQueryRequest.fromXContentParser(restRequest.contentParser()); Scheduler.schedule( nodeClient, () -> @@ -140,6 +142,10 @@ public void onFailure(Exception e) { handleException(e, restChannel, restRequest.method()); } })); + } catch (Exception e) { + handleException(e, restChannel, restRequest.method()); + } + }; } private RestChannelConsumer executeGetAsyncQueryResultRequest( @@ -187,7 +193,7 @@ private void handleException( reportError(restChannel, e, BAD_REQUEST); addCustomerErrorMetric(requestMethod); } else { - reportError(restChannel, e, SERVICE_UNAVAILABLE); + reportError(restChannel, e, INTERNAL_SERVER_ERROR); addSystemErrorMetric(requestMethod); } } @@ -227,7 +233,10 @@ private void reportError(final RestChannel channel, final Exception e, final Res } private static boolean isClientError(Exception e) { - return e instanceof IllegalArgumentException || e instanceof IllegalStateException; + return e instanceof IllegalArgumentException + || e instanceof IllegalStateException + || e instanceof DataSourceNotFoundException + || e instanceof AsyncQueryNotFoundException; } private void addSystemErrorMetric(RestRequest.Method requestMethod) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java index 6acf6bc9a8..98527b6241 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java +++ b/spark/src/main/java/org/opensearch/sql/spark/rest/model/CreateAsyncQueryRequest.java @@ -41,23 +41,28 @@ public static CreateAsyncQueryRequest fromXContentParser(XContentParser parser) LangType lang = null; String datasource = null; String sessionId = null; - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - if (fieldName.equals("query")) { - query = parser.textOrNull(); - } else if (fieldName.equals("lang")) { - String langString = parser.textOrNull(); - lang = LangType.fromString(langString); - } else if (fieldName.equals("datasource")) { - datasource = parser.textOrNull(); - } else if (fieldName.equals(SESSION_ID)) { - sessionId = parser.textOrNull(); - } else { - throw new IllegalArgumentException("Unknown field: " + fieldName); + try { + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = parser.currentName(); + parser.nextToken(); + if (fieldName.equals("query")) { + query = parser.textOrNull(); + } else if (fieldName.equals("lang")) { + String langString = parser.textOrNull(); + lang = LangType.fromString(langString); + } else if (fieldName.equals("datasource")) { + datasource = parser.textOrNull(); + } else if (fieldName.equals(SESSION_ID)) { + sessionId = parser.textOrNull(); + } else { + throw new IllegalArgumentException("Unknown field: " + fieldName); + } } + return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); + } catch (Exception e) { + throw new IllegalArgumentException( + String.format("Error while parsing the request body: %s", e.getMessage())); } - return new CreateAsyncQueryRequest(query, datasource, lang, sessionId); } } diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java index c7054dd200..9c422bf4a6 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryExecutorServiceSpec.java @@ -206,7 +206,7 @@ protected AsyncQueryExecutorService createAsyncQueryExecutorService( JobExecutionResponseReader jobExecutionResponseReader) { StateStore stateStore = new StateStore(client, clusterService); AsyncQueryJobMetadataStorageService asyncQueryJobMetadataStorageService = - new OpensearchAsyncQueryJobMetadataStorageService(stateStore); + new OpensearchAsyncQueryJobMetadataStorageService(stateStore, dataSourceService); SparkQueryDispatcher sparkQueryDispatcher = new SparkQueryDispatcher( emrServerlessClient, diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java index cf838db829..b60a843b50 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest.java @@ -5,15 +5,27 @@ package org.opensearch.sql.spark.asyncquery; +import static org.opensearch.sql.opensearch.setting.OpenSearchSettings.DATASOURCE_URI_HOSTS_DENY_LIST; import static org.opensearch.sql.spark.constants.TestConstants.EMRS_APPLICATION_ID; import static org.opensearch.sql.spark.constants.TestConstants.EMR_JOB_ID; +import com.google.common.collect.ImmutableSet; +import java.util.Collections; import java.util.Optional; import org.junit.Before; import org.junit.Test; +import org.opensearch.client.node.NodeClient; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.sql.datasources.encryptor.EncryptorImpl; +import org.opensearch.sql.datasources.glue.GlueDataSourceFactory; +import org.opensearch.sql.datasources.service.DataSourceMetadataStorage; +import org.opensearch.sql.datasources.service.DataSourceServiceImpl; +import org.opensearch.sql.datasources.storage.OpenSearchDataSourceMetadataStorage; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryId; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryJobMetadata; import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.storage.DataSourceFactory; import org.opensearch.test.OpenSearchIntegTestCase; public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest @@ -23,12 +35,41 @@ public class OpensearchAsyncQueryAsyncQueryJobMetadataStorageServiceTest private static final String MOCK_SESSION_ID = "sessionId"; private static final String MOCK_RESULT_INDEX = "resultIndex"; private OpensearchAsyncQueryJobMetadataStorageService opensearchJobMetadataStorageService; + protected ClusterService clusterService; + protected org.opensearch.sql.common.setting.Settings pluginSettings; + protected NodeClient client; + protected DataSourceServiceImpl dataSourceService; @Before public void setup() { + clusterService = clusterService(); + client = (NodeClient) cluster().client(); opensearchJobMetadataStorageService = new OpensearchAsyncQueryJobMetadataStorageService( - new StateStore(client(), clusterService())); + new StateStore(client, clusterService), dataSourceService); + client + .admin() + .cluster() + .prepareUpdateSettings() + .setTransientSettings( + Settings.builder() + .putList(DATASOURCE_URI_HOSTS_DENY_LIST.getKey(), Collections.emptyList()) + .build()) + .get(); + dataSourceService = createDataSourceService(); + } + + private DataSourceServiceImpl createDataSourceService() { + String masterKey = "a57d991d9b573f75b9bba1df"; + DataSourceMetadataStorage dataSourceMetadataStorage = + new OpenSearchDataSourceMetadataStorage( + client, clusterService, new EncryptorImpl(masterKey)); + return new DataSourceServiceImpl( + new ImmutableSet.Builder() + .add(new GlueDataSourceFactory(pluginSettings)) + .build(), + dataSourceMetadataStorage, + meta -> {}); } @Test