Integration with REPL Spark job (opensearch-project#2327)
* add InteractiveSession and SessionManager

Signed-off-by: Peng Huo <[email protected]>

* add statement

Signed-off-by: Peng Huo <[email protected]>

* add statement

Signed-off-by: Peng Huo <[email protected]>

* fix format

Signed-off-by: Peng Huo <[email protected]>

* snapshot

Signed-off-by: Peng Huo <[email protected]>

* address comments

Signed-off-by: Peng Huo <[email protected]>

* update

Signed-off-by: Peng Huo <[email protected]>

* Update REST and Transport interface

Signed-off-by: Peng Huo <[email protected]>

* Revert on transport layer

Signed-off-by: Peng Huo <[email protected]>

* format code

Signed-off-by: Peng Huo <[email protected]>

* add API doc

Signed-off-by: Peng Huo <[email protected]>

* modify api

Signed-off-by: Peng Huo <[email protected]>

* create query_execution_request index on demand

Signed-off-by: Peng Huo <[email protected]>

* add REPL spark parameters

Signed-off-by: Peng Huo <[email protected]>

* Add IT

Signed-off-by: Peng Huo <[email protected]>

* format code

Signed-off-by: Peng Huo <[email protected]>

* bind request index to datasource

Signed-off-by: Peng Huo <[email protected]>

* fix bug when fetching query result

Signed-off-by: Peng Huo <[email protected]>

* revert entrypoint class

Signed-off-by: Peng Huo <[email protected]>

* update mapping

Signed-off-by: Peng Huo <[email protected]>

---------

Signed-off-by: Peng Huo <[email protected]>
penghuo authored Oct 20, 2023
1 parent f835112 commit 7b4156e
Showing 22 changed files with 810 additions and 157 deletions.
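At a glance, this commit adds a session/statement layer on top of EMR Serverless: a SessionManager persists InteractiveSession and Statement state in a per-datasource query_execution_request index, and a long-running FlintREPL job picks statements up from there. A rough usage sketch assembled from the diffs below; the wiring of client, clusterService, emrServerlessClient, pluginSettings, and the request objects is assumed, not shown here, and only signatures visible in this commit are used.

    // Sketch only; variable setup is assumed.
    SessionManager sessionManager =
        new SessionManager(
            new StateStore(client, clusterService), emrServerlessClient, pluginSettings);

    // First interactive query for a datasource: starts the REPL job and persists a session doc.
    Session session = sessionManager.createSession(createSessionRequest);

    // Each submitted query becomes a Statement document in the request index.
    StatementId statementId = session.submit(queryRequest);

    // A later request can rehydrate the session from its id (the id encodes the datasource name).
    Optional<Session> existing = sessionManager.getSession(session.getSessionId());
    Optional<Statement> statement = existing.flatMap(s -> s.get(statementId));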
SQLPlugin.java
@@ -7,7 +7,6 @@
 
 import static org.opensearch.sql.common.setting.Settings.Key.SPARK_EXECUTION_ENGINE_CONFIG;
 import static org.opensearch.sql.datasource.model.DataSourceMetadata.defaultOpenSearchDataSourceMetadata;
-import static org.opensearch.sql.spark.data.constants.SparkConstants.SPARK_REQUEST_BUFFER_INDEX_NAME;
 
 import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
 import com.amazonaws.services.emrserverless.AWSEMRServerless;
@@ -321,9 +320,7 @@ private AsyncQueryExecutorService createAsyncQueryExecutorService(
             new FlintIndexMetadataReaderImpl(client),
             client,
             new SessionManager(
-                new StateStore(SPARK_REQUEST_BUFFER_INDEX_NAME, client),
-                emrServerlessClient,
-                pluginSettings));
+                new StateStore(client, clusterService), emrServerlessClient, pluginSettings));
     return new AsyncQueryExecutorServiceImpl(
         asyncQueryJobMetadataStorageService,
         sparkQueryDispatcher,
1 change: 1 addition & 0 deletions spark/build.gradle
@@ -68,6 +68,7 @@ dependencies {
         because 'allows tests to run from IDEs that bundle older version of launcher'
     }
     testImplementation("org.opensearch.test:framework:${opensearch_version}")
+    testImplementation project(':opensearch')
 }
 
 test {
SparkSubmitParameters.java
@@ -12,6 +12,7 @@
 import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_INDEX_STORE_OPENSEARCH_URI;
 import static org.opensearch.sql.datasources.glue.GlueDataSourceFactory.GLUE_ROLE_ARN;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.*;
+import static org.opensearch.sql.spark.execution.statestore.StateStore.DATASOURCE_TO_REQUEST_INDEX;
 
 import java.net.URI;
 import java.net.URISyntaxException;
@@ -39,7 +40,7 @@ public class SparkSubmitParameters {
 
   public static class Builder {
 
-    private final String className;
+    private String className;
     private final Map<String, String> config;
     private String extraParameters;
 
@@ -70,6 +71,11 @@ public static Builder builder() {
       return new Builder();
     }
 
+    public Builder className(String className) {
+      this.className = className;
+      return this;
+    }
+
     public Builder dataSource(DataSourceMetadata metadata) {
       if (DataSourceType.S3GLUE.equals(metadata.getConnector())) {
         String roleArn = metadata.getProperties().get(GLUE_ROLE_ARN);
@@ -141,6 +147,12 @@ public Builder extraParameters(String params) {
       return this;
     }
 
+    public Builder sessionExecution(String sessionId, String datasourceName) {
+      config.put(FLINT_JOB_REQUEST_INDEX, DATASOURCE_TO_REQUEST_INDEX.apply(datasourceName));
+      config.put(FLINT_JOB_SESSION_ID, sessionId);
+      return this;
+    }
+
     public SparkSubmitParameters build() {
       return new SparkSubmitParameters(className, config, extraParameters);
     }
SparkConstants.java
@@ -87,4 +87,8 @@ public class SparkConstants
   public static final String EMR_ASSUME_ROLE_CREDENTIALS_PROVIDER =
       "com.amazonaws.emr.AssumeRoleAWSCredentialsProvider";
   public static final String JAVA_HOME_LOCATION = "/usr/lib/jvm/java-17-amazon-corretto.x86_64/";
+
+  public static final String FLINT_JOB_REQUEST_INDEX = "spark.flint.job.requestIndex";
+  public static final String FLINT_JOB_SESSION_ID = "spark.flint.job.sessionId";
+  public static final String FLINT_SESSION_CLASS_NAME = "org.apache.spark.sql.FlintREPL";
 }
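Since SparkSubmitParameters carries its config map into the submit string, the two new keys should surface as ordinary Spark confs on the REPL job, roughly like the lines below. The request-index naming comes from StateStore.DATASOURCE_TO_REQUEST_INDEX, which is outside this diff, so the values shown are illustrative placeholders.

    --class org.apache.spark.sql.FlintREPL
    --conf spark.flint.job.requestIndex=<request index for the datasource>
    --conf spark.flint.job.sessionId=<base64-encoded session id>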
SparkQueryDispatcher.java
@@ -7,6 +7,7 @@
 
 import static org.opensearch.sql.spark.data.constants.SparkConstants.DATA_FIELD;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.ERROR_FIELD;
+import static org.opensearch.sql.spark.data.constants.SparkConstants.FLINT_SESSION_CLASS_NAME;
 import static org.opensearch.sql.spark.data.constants.SparkConstants.STATUS_FIELD;
 
 import com.amazonaws.services.emrserverless.model.CancelJobRunResult;
@@ -96,12 +97,19 @@ public JSONObject getQueryResponse(AsyncQueryJobMetadata asyncQueryJobMetadata)
       return DropIndexResult.fromJobId(asyncQueryJobMetadata.getJobId()).result();
     }
 
-    // either empty json when the result is not available or data with status
-    // Fetch from Result Index
-    JSONObject result =
-        jobExecutionResponseReader.getResultFromOpensearchIndex(
-            asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
+    JSONObject result;
+    if (asyncQueryJobMetadata.getSessionId() == null) {
+      // either empty json when the result is not available or data with status
+      // Fetch from Result Index
+      result =
+          jobExecutionResponseReader.getResultFromOpensearchIndex(
+              asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
+    } else {
+      // when session enabled, jobId in asyncQueryJobMetadata is actually queryId.
+      result =
+          jobExecutionResponseReader.getResultWithQueryId(
+              asyncQueryJobMetadata.getJobId(), asyncQueryJobMetadata.getResultIndex());
+    }
     // if result index document has a status, we are gonna use the status directly; otherwise, we
     // will use emr-s job status.
     // That a job is successful does not mean there is no error in execution. For example, even if
@@ -230,22 +238,7 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
     dataSourceUserAuthorizationHelper.authorizeDataSource(dataSourceMetadata);
     String jobName = dispatchQueryRequest.getClusterName() + ":" + "non-index-query";
     Map<String, String> tags = getDefaultTagsForJobSubmission(dispatchQueryRequest);
-    StartJobRequest startJobRequest =
-        new StartJobRequest(
-            dispatchQueryRequest.getQuery(),
-            jobName,
-            dispatchQueryRequest.getApplicationId(),
-            dispatchQueryRequest.getExecutionRoleARN(),
-            SparkSubmitParameters.Builder.builder()
-                .dataSource(
-                    dataSourceService.getRawDataSourceMetadata(
-                        dispatchQueryRequest.getDatasource()))
-                .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
-                .build()
-                .toString(),
-            tags,
-            false,
-            dataSourceMetadata.getResultIndex());
 
     if (sessionManager.isEnabled()) {
       Session session;
       if (dispatchQueryRequest.getSessionId() != null) {
@@ -260,7 +253,19 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
         // create session if not exist
         session =
             sessionManager.createSession(
-                new CreateSessionRequest(startJobRequest, dataSourceMetadata.getName()));
+                new CreateSessionRequest(
+                    jobName,
+                    dispatchQueryRequest.getApplicationId(),
+                    dispatchQueryRequest.getExecutionRoleARN(),
+                    SparkSubmitParameters.Builder.builder()
+                        .className(FLINT_SESSION_CLASS_NAME)
+                        .dataSource(
+                            dataSourceService.getRawDataSourceMetadata(
+                                dispatchQueryRequest.getDatasource()))
+                        .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams()),
+                    tags,
+                    dataSourceMetadata.getResultIndex(),
+                    dataSourceMetadata.getName()));
       }
       StatementId statementId =
           session.submit(
@@ -272,6 +277,22 @@ private DispatchQueryResponse handleNonIndexQuery(DispatchQueryRequest dispatchQ
             dataSourceMetadata.getResultIndex(),
             session.getSessionId().getSessionId());
     } else {
+      StartJobRequest startJobRequest =
+          new StartJobRequest(
+              dispatchQueryRequest.getQuery(),
+              jobName,
+              dispatchQueryRequest.getApplicationId(),
+              dispatchQueryRequest.getExecutionRoleARN(),
+              SparkSubmitParameters.Builder.builder()
+                  .dataSource(
+                      dataSourceService.getRawDataSourceMetadata(
+                          dispatchQueryRequest.getDatasource()))
+                  .extraParameters(dispatchQueryRequest.getExtraSparkSubmitParams())
+                  .build()
+                  .toString(),
+              tags,
+              false,
+              dataSourceMetadata.getResultIndex());
       String jobId = emrServerlessClient.startJobRun(startJobRequest);
       return new DispatchQueryResponse(jobId, false, dataSourceMetadata.getResultIndex(), null);
     }
CreateSessionRequest.java
@@ -5,11 +5,30 @@
 
 package org.opensearch.sql.spark.execution.session;
 
+import java.util.Map;
 import lombok.Data;
+import org.opensearch.sql.spark.asyncquery.model.SparkSubmitParameters;
 import org.opensearch.sql.spark.client.StartJobRequest;
 
 @Data
 public class CreateSessionRequest {
-  private final StartJobRequest startJobRequest;
+  private final String jobName;
+  private final String applicationId;
+  private final String executionRoleArn;
+  private final SparkSubmitParameters.Builder sparkSubmitParametersBuilder;
+  private final Map<String, String> tags;
+  private final String resultIndex;
   private final String datasourceName;
+
+  public StartJobRequest getStartJobRequest() {
+    return new StartJobRequest(
+        "select 1",
+        jobName,
+        applicationId,
+        executionRoleArn,
+        sparkSubmitParametersBuilder.build().toString(),
+        tags,
+        false,
+        resultIndex);
+  }
 }
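One detail worth noting: getStartJobRequest() hard-codes "select 1" as the query. With FLINT_SESSION_CLASS_NAME as the entry class, the REPL job does not execute the job's query argument; real statements appear to arrive through the per-datasource request index instead, so the string acts as a placeholder. A minimal sketch of the call site, with variables assumed:

    // Sketch: the session path submits the long-running job via the request object.
    StartJobRequest job = createSessionRequest.getStartJobRequest(); // query is the "select 1" placeholder
    String jobId = emrServerlessClient.startJobRun(job); // starts the FlintREPL job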
InteractiveSession.java
@@ -42,13 +42,17 @@ public class InteractiveSession implements Session {
   @Override
   public void open(CreateSessionRequest createSessionRequest) {
     try {
+      // append session id;
+      createSessionRequest
+          .getSparkSubmitParametersBuilder()
+          .sessionExecution(sessionId.getSessionId(), createSessionRequest.getDatasourceName());
       String jobID = serverlessClient.startJobRun(createSessionRequest.getStartJobRequest());
       String applicationId = createSessionRequest.getStartJobRequest().getApplicationId();
 
       sessionModel =
           initInteractiveSession(
               applicationId, jobID, sessionId, createSessionRequest.getDatasourceName());
-      createSession(stateStore).apply(sessionModel);
+      createSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel);
     } catch (VersionConflictEngineException e) {
       String errorMsg = "session already exist. " + sessionId;
       LOG.error(errorMsg);
@@ -59,7 +63,8 @@ public void open(CreateSessionRequest createSessionRequest) {
   /** todo. StatementSweeper will delete doc. */
   @Override
   public void close() {
-    Optional<SessionModel> model = getSession(stateStore).apply(sessionModel.getId());
+    Optional<SessionModel> model =
+        getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId());
     if (model.isEmpty()) {
       throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId());
     } else {
@@ -69,7 +74,8 @@ public void close() {
 
   /** Submit statement. If submit successfully, Statement in waiting state. */
   public StatementId submit(QueryRequest request) {
-    Optional<SessionModel> model = getSession(stateStore).apply(sessionModel.getId());
+    Optional<SessionModel> model =
+        getSession(stateStore, sessionModel.getDatasourceName()).apply(sessionModel.getId());
     if (model.isEmpty()) {
       throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId());
     } else {
@@ -84,6 +90,7 @@ public StatementId submit(QueryRequest request) {
               .stateStore(stateStore)
               .statementId(statementId)
               .langType(LangType.SQL)
+              .datasourceName(sessionModel.getDatasourceName())
               .query(request.getQuery())
               .queryId(statementId.getId())
               .build();
@@ -103,7 +110,7 @@ public StatementId submit(QueryRequest request) {
 
   @Override
   public Optional<Statement> get(StatementId stID) {
-    return StateStore.getStatement(stateStore)
+    return StateStore.getStatement(stateStore, sessionModel.getDatasourceName())
         .apply(stID.getId())
         .map(
             model ->
SessionId.java
@@ -5,15 +5,32 @@
 
 package org.opensearch.sql.spark.execution.session;
 
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
 import lombok.Data;
 import org.apache.commons.lang3.RandomStringUtils;
 
 @Data
 public class SessionId {
+  public static final int PREFIX_LEN = 10;
+
   private final String sessionId;
 
-  public static SessionId newSessionId() {
-    return new SessionId(RandomStringUtils.randomAlphanumeric(16));
+  public static SessionId newSessionId(String datasourceName) {
+    return new SessionId(encode(datasourceName));
   }
 
+  public String getDataSourceName() {
+    return decode(sessionId);
+  }
+
+  private static String decode(String sessionId) {
+    return new String(Base64.getDecoder().decode(sessionId)).substring(PREFIX_LEN);
+  }
+
+  private static String encode(String datasourceName) {
+    String randomId = RandomStringUtils.randomAlphanumeric(PREFIX_LEN) + datasourceName;
+    return Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8));
+  }
+
   @Override
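The encode/decode pair makes the session id self-describing: the datasource name rides inside the id, so getDataSourceName() can route any later request to the right per-datasource index without a lookup. A self-contained round trip mirroring the logic above (JDK only; the prefix is fixed here instead of random so the output is reproducible):

    import java.nio.charset.StandardCharsets;
    import java.util.Base64;

    public class SessionIdRoundTrip {
      private static final int PREFIX_LEN = 10;

      public static void main(String[] args) {
        // encode: PREFIX_LEN random alphanumerics + datasource name, then Base64
        String randomId = "a1B2c3D4e5" + "mys3"; // fixed 10-char prefix for illustration
        String sessionId =
            Base64.getEncoder().encodeToString(randomId.getBytes(StandardCharsets.UTF_8));
        // decode: Base64-decode, then drop the random prefix to recover the datasource
        String datasource = new String(Base64.getDecoder().decode(sessionId)).substring(PREFIX_LEN);
        System.out.println(sessionId + " -> " + datasource); // YTFCMmMzRDRlNW15czM= -> mys3
      }
    }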
SessionManager.java
@@ -28,7 +28,7 @@ public class SessionManager {
   public Session createSession(CreateSessionRequest request) {
     InteractiveSession session =
         InteractiveSession.builder()
-            .sessionId(newSessionId())
+            .sessionId(newSessionId(request.getDatasourceName()))
             .stateStore(stateStore)
             .serverlessClient(emrServerlessClient)
             .build();
@@ -37,7 +37,8 @@ public Session createSession(CreateSessionRequest request) {
   }
 
   public Optional<Session> getSession(SessionId sid) {
-    Optional<SessionModel> model = StateStore.getSession(stateStore).apply(sid.getSessionId());
+    Optional<SessionModel> model =
+        StateStore.getSession(stateStore, sid.getDataSourceName()).apply(sid.getSessionId());
     if (model.isPresent()) {
       InteractiveSession session =
           InteractiveSession.builder()
SessionState.java
@@ -8,6 +8,7 @@
 
 import com.google.common.collect.ImmutableList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Locale;
 import java.util.Map;
 import java.util.stream.Collectors;
 import lombok.Getter;
@@ -32,8 +33,10 @@ public enum SessionState {
           .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t));
 
   public static SessionState fromString(String key) {
-    if (STATES.containsKey(key)) {
-      return STATES.get(key);
+    for (SessionState ss : SessionState.values()) {
+      if (ss.getSessionState().toLowerCase(Locale.ROOT).equals(key)) {
+        return ss;
+      }
     }
     throw new IllegalArgumentException("Invalid session state: " + key);
   }
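The lookup now matches each state's sessionState string lower-cased with Locale.ROOT, which sidesteps the locale-dependent behavior of the default-locale toLowerCase() used by the old STATES map. Matching stays case-sensitive on the input key, as this sketch shows; it assumes a RUNNING state exists, which this hunk does not display.

    // Sketch of the new behavior (assumes SessionState.RUNNING exists).
    SessionState ok = SessionState.fromString("running"); // -> SessionState.RUNNING
    SessionState.fromString("RUNNING"); // throws IllegalArgumentException: keys must be lower case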
SessionType.java
@@ -5,9 +5,7 @@
 
 package org.opensearch.sql.spark.execution.session;
 
-import java.util.Arrays;
-import java.util.Map;
-import java.util.stream.Collectors;
+import java.util.Locale;
 import lombok.Getter;
 
 @Getter
@@ -20,13 +18,11 @@ public enum SessionType {
     this.sessionType = sessionType;
   }
 
-  private static Map<String, SessionType> TYPES =
-      Arrays.stream(SessionType.values())
-          .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t));
-
   public static SessionType fromString(String key) {
-    if (TYPES.containsKey(key)) {
-      return TYPES.get(key);
+    for (SessionType sType : SessionType.values()) {
+      if (sType.getSessionType().toLowerCase(Locale.ROOT).equals(key)) {
+        return sType;
+      }
     }
     throw new IllegalArgumentException("Invalid session type: " + key);
   }