Skip to content

Commit

Permalink
Fix bug: use "basic" instead of "basicauth" as the Flint index store auth value
Browse files Browse the repository at this point in the history
Signed-off-by: Peng Huo <[email protected]>
  • Loading branch information
penghuo committed Oct 23, 2023
1 parent b30d3c9 commit fbfa1be
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
public class SparkSubmitParameters {
public static final String SPACE = " ";
public static final String EQUALS = "=";
public static final String FLINT_BASIC_AUTH = "basic";

private final String className;
private final Map<String, String> config;
Expand Down Expand Up @@ -114,7 +115,7 @@ private void setFlintIndexStoreAuthProperties(
Supplier<String> password,
Supplier<String> region) {
if (AuthenticationType.get(authType).equals(AuthenticationType.BASICAUTH)) {
config.put(FLINT_INDEX_STORE_AUTH_KEY, authType);
config.put(FLINT_INDEX_STORE_AUTH_KEY, FLINT_BASIC_AUTH);
config.put(FLINT_INDEX_STORE_AUTH_USERNAME, userName.get());
config.put(FLINT_INDEX_STORE_AUTH_PASSWORD, password.get());
} else if (AuthenticationType.get(authType).equals(AuthenticationType.AWSSIGV4AUTH)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ protected JSONObject getResponseFromExecutor(AsyncQueryJobMetadata asyncQueryJob
Statement statement = getStatementByQueryId(asyncQueryJobMetadata.getSessionId(), queryId);
StatementState statementState = statement.getStatementState();
result.put(STATUS_FIELD, statementState.getState());
result.put(ERROR_FIELD, "");
result.put(ERROR_FIELD, Optional.of(statement.getStatementModel().getError()).orElse(""));
return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,31 @@ public class CreateSessionRequest {
private final String datasourceName;

public StartJobRequest getStartJobRequest() {
return new StartJobRequest(
return new InteractiveSessionStartJobRequest(
"select 1",
jobName,
applicationId,
executionRoleArn,
sparkSubmitParametersBuilder.build().toString(),
tags,
false,
resultIndex);
}

/**
 * {@link StartJobRequest} variant used for interactive sessions.
 *
 * <p>The Spark job backing a session must stay alive to serve follow-up statements, so its
 * execution timeout is disabled (see {@link #executionTimeout()}).
 */
static class InteractiveSessionStartJobRequest extends StartJobRequest {

  public InteractiveSessionStartJobRequest(
      String query,
      String jobName,
      String applicationId,
      String executionRoleArn,
      String sparkSubmitParams,
      Map<String, String> tags,
      String resultIndex) {
    // The boolean flag forwarded to super is always false for interactive sessions.
    super(query, jobName, applicationId, executionRoleArn, sparkSubmitParams, tags, false,
        resultIndex);
  }

  /** Interactive query keep running: returning 0 disables the execution timeout. */
  @Override
  public Long executionTimeout() {
    return 0L;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import static org.opensearch.sql.spark.execution.statement.StatementModel.STATEMENT_DOC_TYPE;
import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement;
import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState;

import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
import com.amazonaws.services.emrserverless.model.GetJobRunResult;
Expand Down Expand Up @@ -269,8 +270,123 @@ public void reuseSessionWhenCreateAsyncQuery() {
assertEquals(second.getQueryId(), secondModel.get().getQueryId());
}

@Test
public void batchQueryHasTimeout() {
  LocalEMRSClient emrsClient = new LocalEMRSClient();
  AsyncQueryExecutorService asyncQueryExecutorService =
      createAsyncQueryExecutorService(emrsClient);

  // Disable sessions so the query takes the batch (non-interactive) path.
  enableSession(false);
  // The returned response was unused in the original test; the assertion below inspects the
  // captured EMR-S job request instead.
  asyncQueryExecutorService.createAsyncQuery(
      new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));

  // Batch queries carry the default 120-second execution timeout.
  assertEquals(120L, (long) emrsClient.getJobRequest().executionTimeout());
}

@Test
public void interactiveQueryNoTimeout() {
  LocalEMRSClient emrClient = new LocalEMRSClient();
  AsyncQueryExecutorService queryService = createAsyncQueryExecutorService(emrClient);

  // Run through the interactive-session path.
  enableSession(true);
  queryService.createAsyncQuery(
      new CreateAsyncQueryRequest("select 1", DATASOURCE, LangType.SQL, null));

  // Interactive session jobs must not time out (timeout of 0).
  assertEquals(0L, (long) emrClient.getJobRequest().executionTimeout());
}

@Test
public void datasourceWithBasicAuth() {
  // Register a datasource whose Flint index store is configured with basic auth
  // (username "username", password "admin").
  dataSourceService.createDataSource(
      new DataSourceMetadata(
          "mybasicauth",
          DataSourceType.S3GLUE,
          ImmutableList.of(),
          ImmutableMap.of(
              "glue.auth.type", "iam_role",
              "glue.auth.role_arn", "arn:aws:iam::924196221507:role/FlintOpensearchServiceRole",
              "glue.indexstore.opensearch.uri",
                  "http://ec2-18-237-133-156.us-west-2.compute.amazonaws" + ".com:9200",
              "glue.indexstore.opensearch.auth", "basicauth",
              "glue.indexstore.opensearch.auth.username", "username",
              "glue.indexstore.opensearch.auth.password", "admin"),
          null));
  LocalEMRSClient emrsClient = new LocalEMRSClient();
  AsyncQueryExecutorService asyncQueryExecutorService =
      createAsyncQueryExecutorService(emrsClient);

  // enable session
  enableSession(true);

  asyncQueryExecutorService.createAsyncQuery(
      new CreateAsyncQueryRequest("select 1", "mybasicauth", LangType.SQL, null));
  String params = emrsClient.getJobRequest().getSparkSubmitParams();
  // BUG FIX: the datasource auth type "basicauth" must be translated to Spark's "basic" value
  // (FLINT_BASIC_AUTH) — the original assertion expected the datasource name instead.
  assertTrue(params.contains("--conf spark.datasource.flint.auth=basic"));
  assertTrue(params.contains("--conf spark.datasource.flint.auth.username=username"));
  // BUG FIX: assert the configured password ("admin"), not the placeholder "password".
  // (No-op String.format wrappers around constant strings were also removed.)
  assertTrue(params.contains("--conf spark.datasource.flint.auth.password=admin"));
}

@Test
public void withSessionCreateAsyncQueryFailed() {
  LocalEMRSClient emrClient = new LocalEMRSClient();
  AsyncQueryExecutorService queryService = createAsyncQueryExecutorService(emrClient);

  // enable session
  enableSession(true);

  // 1. Submit the async query; it should be persisted in the WAITING state.
  CreateAsyncQueryResponse response =
      queryService.createAsyncQuery(
          new CreateAsyncQueryRequest("myselect 1", DATASOURCE, LangType.SQL, null));
  assertNotNull(response.getSessionId());
  Optional<StatementModel> stored =
      getStatement(stateStore, DATASOURCE).apply(response.getQueryId());
  assertTrue(stored.isPresent());
  assertEquals(StatementState.WAITING, stored.get().getStatementState());

  // 2. No result has been written to SPARK_RESPONSE_BUFFER_INDEX_NAME yet, so mock a failed
  //    statement by copying the stored model with an error message attached.
  StatementModel submitted = stored.get();
  StatementModel mocked =
      StatementModel.builder()
          .version("1.0")
          .statementState(submitted.getStatementState())
          .statementId(submitted.getStatementId())
          .sessionId(submitted.getSessionId())
          .applicationId(submitted.getApplicationId())
          .jobId(submitted.getJobId())
          .langType(submitted.getLangType())
          .datasourceName(submitted.getDatasourceName())
          .query(submitted.getQuery())
          .queryId(submitted.getQueryId())
          .submitTime(submitted.getSubmitTime())
          .error("mock error")
          .seqNo(submitted.getSeqNo())
          .primaryTerm(submitted.getPrimaryTerm())
          .build();
  updateStatementState(stateStore, DATASOURCE).apply(mocked, StatementState.FAILED);

  // 3. Fetching results must surface both the FAILED state and the stored error message.
  AsyncQueryExecutionResponse results = queryService.getAsyncQueryResults(response.getQueryId());
  assertEquals(StatementState.FAILED.getState(), results.getStatus());
  assertEquals("mock error", results.getError());
}

private DataSourceServiceImpl createDataSourceService() {
String masterKey = "1234567890";
String masterKey = "a57d991d9b573f75b9bba1df";
DataSourceMetadataStorage dataSourceMetadataStorage =
new OpenSearchDataSourceMetadataStorage(
client, clusterService, new EncryptorImpl(masterKey));
Expand Down

0 comments on commit fbfa1be

Please sign in to comment.