Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RequestContext parameter to verifyDataSourceAccessAndGetRawMetada method #2866

Merged
merged 3 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.sql.spark.asyncquery.model;

import org.opensearch.sql.datasource.RequestContext;

/** Context interface to provide additional request related information */
public interface AsyncQueryRequestContext {
Object getAttribute(String name);
}
public interface AsyncQueryRequestContext extends RequestContext {}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public DispatchQueryResponse dispatch(
AsyncQueryRequestContext asyncQueryRequestContext) {
DataSourceMetadata dataSourceMetadata =
this.dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
dispatchQueryRequest.getDatasource());
dispatchQueryRequest.getDatasource(), asyncQueryRequestContext);

if (LangType.SQL.equals(dispatchQueryRequest.getLangType())) {
String query = dispatchQueryRequest.getQuery();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,8 @@ private void givenFlintIndexMetadataExists(String indexName) {
}

private void givenValidDataSourceMetadataExist() {
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(DATASOURCE_NAME))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
DATASOURCE_NAME, asyncQueryRequestContext))
.thenReturn(
new DataSourceMetadata.Builder()
.setName(DATASOURCE_NAME)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ void testDispatchSelectQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -223,7 +224,8 @@ void testDispatchSelectQueryWithLakeFormation() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithLakeFormation();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -255,7 +257,8 @@ void testDispatchSelectQueryWithBasicAuthIndexStoreDatasource() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadataWithBasicAuth();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -278,7 +281,8 @@ void testDispatchSelectQueryCreateNewSession() {
doReturn(new StatementId(MOCK_STATEMENT_ID)).when(session).submit(any(), any());
when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -304,7 +308,8 @@ void testDispatchSelectQueryReuseSession() {
when(session.getSessionModel().getJobId()).thenReturn(EMR_JOB_ID);
when(session.isOperationalForDataSource(any())).thenReturn(true);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -324,7 +329,8 @@ void testDispatchSelectQueryFailedCreateSession() {
doReturn(true).when(sessionManager).isEnabled();
doThrow(RuntimeException.class).when(sessionManager).createSession(any(), any());
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

Assertions.assertThrows(
Expand Down Expand Up @@ -358,7 +364,8 @@ void testDispatchCreateAutoRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -393,7 +400,8 @@ void testDispatchCreateManualRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -426,7 +434,8 @@ void testDispatchWithPPLQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -450,7 +459,8 @@ void testDispatchWithSparkUDFQuery() {
"CREATE TEMPORARY FUNCTION square AS 'org.apache.spark.sql.functions.expr(\"num * num\")'");
for (String query : udfQueries) {
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

IllegalArgumentException illegalArgumentException =
Expand Down Expand Up @@ -489,7 +499,8 @@ void testInvalidSQLQueryDispatchToSpark() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -532,7 +543,8 @@ void testDispatchQueryWithoutATableAndDataSourceName() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -568,7 +580,8 @@ void testDispatchIndexQueryWithoutADatasourceName() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -589,8 +602,7 @@ void testDispatchMaterializedViewQuery() {
tags.put(CLUSTER_NAME_TAG_KEY, TEST_CLUSTER_NAME);
tags.put(JOB_TYPE_TAG_KEY, JobType.STREAMING.getText());
String query =
"CREATE MATERIALIZED VIEW mv_1 AS query=select * from my_glue.default.logs WITH"
+ " (auto_refresh = true)";
"CREATE MATERIALIZED VIEW mv_1 AS select * from logs WITH" + " (auto_refresh = true)";
String sparkSubmitParameters = constructExpectedSparkSubmitParameterString(query, "streaming");
StartJobRequest expected =
new StartJobRequest(
Expand All @@ -604,7 +616,8 @@ void testDispatchMaterializedViewQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -637,7 +650,8 @@ void testDispatchShowMVQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -670,7 +684,8 @@ void testRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -703,7 +718,8 @@ void testDispatchDescribeIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand Down Expand Up @@ -739,7 +755,8 @@ void testDispatchAlterToAutoRefreshIndexQuery() {
"query_execution_result_my_glue");
when(emrServerlessClient.startJobRun(expected)).thenReturn(EMR_JOB_ID);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

DispatchQueryResponse dispatchQueryResponse =
Expand All @@ -762,7 +779,8 @@ void testDispatchAlterToManualRefreshIndexQuery() {
"ALTER INDEX elb_and_requestUri ON my_glue.default.http_logs WITH"
+ " (auto_refresh = false)";
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);
when(queryHandlerFactory.getIndexDMLHandler())
.thenReturn(
Expand All @@ -785,7 +803,8 @@ void testDispatchDropIndexQuery() {

String query = "DROP INDEX elb_and_requestUri ON my_glue.default.http_logs";
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);
when(queryHandlerFactory.getIndexDMLHandler())
.thenReturn(
Expand All @@ -808,7 +827,8 @@ void testDispatchVacuumIndexQuery() {

String query = "VACUUM INDEX elb_and_requestUri ON my_glue.default.http_logs";
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_glue"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_glue", asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);
when(queryHandlerFactory.getIndexDMLHandler())
.thenReturn(
Expand All @@ -824,7 +844,8 @@ void testDispatchVacuumIndexQuery() {

@Test
void testDispatchWithUnSupportedDataSourceType() {
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata("my_prometheus"))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"my_prometheus", asyncQueryRequestContext))
.thenReturn(constructPrometheusDataSourceType());
String query = "select * from my_prometheus.default.http_logs";

Expand Down Expand Up @@ -1018,7 +1039,8 @@ void testGetQueryResponseWithSuccess() {
void testDispatchQueryWithExtraSparkSubmitParameters() {
when(emrServerlessClientFactory.getClient(any())).thenReturn(emrServerlessClient);
DataSourceMetadata dataSourceMetadata = constructMyGlueDataSourceMetadata();
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(MY_GLUE))
when(dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
MY_GLUE, asyncQueryRequestContext))
.thenReturn(dataSourceMetadata);

String extraParameters = "--conf spark.dynamicAllocation.enabled=false";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ public interface DataSourceService {
* Specifically for addressing use cases in SparkQueryDispatcher.
*
* @param dataSourceName of the {@link DataSource}
* @param context request context used by the implementation. It is passed by async-query-core.
* refer {@link RequestContext}
*/
DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName);
Copy link
Member

@vamsimanohar vamsimanohar Jul 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add more information to the java documentation, otherwise someone reading the code would be clueless of why we introduced the field without using. Just mention few details regarding method change.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added comments

DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(
String dataSourceName, RequestContext context);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.datasource;

/**
* Context interface to provide additional request related information. It is introduced to allow
* async-query-core library user to pass request context information to implementations of data
* accessors.
*/
public interface RequestContext {
Object getAttribute(String name);
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.sql.config.TestConfig;
import org.opensearch.sql.data.type.ExprType;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.RequestContext;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceType;
Expand Down Expand Up @@ -236,7 +237,8 @@ public Boolean dataSourceExists(String dataSourceName) {
}

@Override
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) {
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(
String dataSourceName, RequestContext requestContext) {
return null;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.*;
import java.util.stream.Collectors;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.RequestContext;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceStatus;
Expand Down Expand Up @@ -122,7 +123,8 @@ public Boolean dataSourceExists(String dataSourceName) {
}

@Override
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(String dataSourceName) {
public DataSourceMetadata verifyDataSourceAccessAndGetRawMetadata(
String dataSourceName, RequestContext requestContext) {
DataSourceMetadata dataSourceMetadata = getRawDataSourceMetadata(dataSourceName);
verifyDataSourceAccess(dataSourceMetadata);
return dataSourceMetadata;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.datasource.DataSourceService;
import org.opensearch.sql.datasource.RequestContext;
import org.opensearch.sql.datasource.model.DataSource;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasource.model.DataSourceStatus;
Expand All @@ -52,6 +53,7 @@ class DataSourceServiceImplTest {
@Mock private DataSourceFactory dataSourceFactory;
@Mock private StorageEngine storageEngine;
@Mock private DataSourceMetadataStorage dataSourceMetadataStorage;
@Mock private RequestContext requestContext;

@Mock private DataSourceUserAuthorizationHelper dataSourceUserAuthorizationHelper;

Expand Down Expand Up @@ -461,7 +463,9 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadataWithDisabledData() {
DatasourceDisabledException datasourceDisabledException =
Assertions.assertThrows(
DatasourceDisabledException.class,
() -> dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS"));
() ->
dataSourceService.verifyDataSourceAccessAndGetRawMetadata(
"testDS", requestContext));
Assertions.assertEquals(
"Datasource testDS is disabled.", datasourceDisabledException.getMessage());
}
Expand All @@ -484,7 +488,7 @@ void testVerifyDataSourceAccessAndGetRawDataSourceMetadata() {
when(dataSourceMetadataStorage.getDataSourceMetadata("testDS"))
.thenReturn(Optional.of(dataSourceMetadata));
DataSourceMetadata dataSourceMetadata1 =
dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS");
dataSourceService.verifyDataSourceAccessAndGetRawMetadata("testDS", requestContext);
assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.uri"));
assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.type"));
assertTrue(dataSourceMetadata1.getProperties().containsKey("prometheus.auth.username"));
Expand Down
Loading