diff --git a/spark/build.gradle b/spark/build.gradle index c2c925ecaf..d8bb08657d 100644 --- a/spark/build.gradle +++ b/spark/build.gradle @@ -119,8 +119,9 @@ jacocoTestCoverageVerification { 'org.opensearch.sql.spark.dispatcher.model.*', 'org.opensearch.sql.spark.flint.FlintIndexType', // ignore because XContext IOException - 'org.opensearch.sql.spark.execution.statestore.SessionStateStore', - 'org.opensearch.sql.spark.execution.session.SessionModel' + 'org.opensearch.sql.spark.execution.statestore.StateStore', + 'org.opensearch.sql.spark.execution.session.SessionModel', + 'org.opensearch.sql.spark.execution.statement.StatementModel' ] limit { counter = 'LINE' diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java index 620e46b9be..e33ef4245a 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/InteractiveSession.java @@ -6,6 +6,10 @@ package org.opensearch.sql.spark.execution.session; import static org.opensearch.sql.spark.execution.session.SessionModel.initInteractiveSession; +import static org.opensearch.sql.spark.execution.session.SessionState.END_STATE; +import static org.opensearch.sql.spark.execution.statement.StatementId.newStatementId; +import static org.opensearch.sql.spark.execution.statestore.StateStore.createSession; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import java.util.Optional; import lombok.Builder; @@ -14,7 +18,11 @@ import org.apache.logging.log4j.Logger; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.sql.spark.execution.statement.QueryRequest; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.rest.model.LangType; /** * Interactive session. @@ -27,9 +35,8 @@ public class InteractiveSession implements Session { private static final Logger LOG = LogManager.getLogger(); private final SessionId sessionId; - private final SessionStateStore sessionStateStore; + private final StateStore stateStore; private final EMRServerlessClient serverlessClient; - private SessionModel sessionModel; @Override @@ -41,7 +48,7 @@ public void open(CreateSessionRequest createSessionRequest) { sessionModel = initInteractiveSession( applicationId, jobID, sessionId, createSessionRequest.getDatasourceName()); - sessionStateStore.create(sessionModel); + createSession(stateStore).apply(sessionModel); } catch (VersionConflictEngineException e) { String errorMsg = "session already exist. " + sessionId; LOG.error(errorMsg); @@ -49,13 +56,67 @@ public void open(CreateSessionRequest createSessionRequest) { } } + /** todo. StatementSweeper will delete doc. */ @Override public void close() { - Optional model = sessionStateStore.get(sessionModel.getSessionId()); + Optional model = getSession(stateStore).apply(sessionModel.getId()); if (model.isEmpty()) { - throw new IllegalStateException("session not exist. " + sessionModel.getSessionId()); + throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); } else { serverlessClient.cancelJobRun(sessionModel.getApplicationId(), sessionModel.getJobId()); } } + + /** Submit statement. If submit successfully, Statement in waiting state. */ + public StatementId submit(QueryRequest request) { + Optional model = getSession(stateStore).apply(sessionModel.getId()); + if (model.isEmpty()) { + throw new IllegalStateException("session does not exist. " + sessionModel.getSessionId()); + } else { + sessionModel = model.get(); + if (!END_STATE.contains(sessionModel.getSessionState())) { + StatementId statementId = newStatementId(); + Statement st = + Statement.builder() + .sessionId(sessionId) + .applicationId(sessionModel.getApplicationId()) + .jobId(sessionModel.getJobId()) + .stateStore(stateStore) + .statementId(statementId) + .langType(LangType.SQL) + .query(request.getQuery()) + .queryId(statementId.getId()) + .build(); + st.open(); + return statementId; + } else { + String errMsg = + String.format( + "can't submit statement, session should not be in end state, " + + "current session state is: %s", + sessionModel.getSessionState().getSessionState()); + LOG.debug(errMsg); + throw new IllegalStateException(errMsg); + } + } + } + + @Override + public Optional get(StatementId stID) { + return StateStore.getStatement(stateStore) + .apply(stID.getId()) + .map( + model -> + Statement.builder() + .sessionId(sessionId) + .applicationId(model.getApplicationId()) + .jobId(model.getJobId()) + .statementId(model.getStatementId()) + .langType(model.getLangType()) + .query(model.getQuery()) + .queryId(model.getQueryId()) + .stateStore(stateStore) + .statementModel(model) + .build()); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java index ec9775e60a..4d919d5e2e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/Session.java @@ -5,6 +5,11 @@ package org.opensearch.sql.spark.execution.session; +import java.util.Optional; +import org.opensearch.sql.spark.execution.statement.QueryRequest; +import org.opensearch.sql.spark.execution.statement.Statement; +import org.opensearch.sql.spark.execution.statement.StatementId; + /** Session define the statement execution context. Each session is binding to one Spark Job. */ public interface Session { /** open session. */ @@ -13,6 +18,22 @@ public interface Session { /** close session. */ void close(); + /** + * submit {@link QueryRequest}. + * + * @param request {@link QueryRequest} + * @return {@link StatementId} + */ + StatementId submit(QueryRequest request); + + /** + * get {@link Statement}. + * + * @param stID {@link StatementId} + * @return {@link Statement} + */ + Optional get(StatementId stID); + SessionModel getSessionModel(); SessionId getSessionId(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java index 3d0916bac8..217af80caf 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionManager.java @@ -10,7 +10,7 @@ import java.util.Optional; import lombok.RequiredArgsConstructor; import org.opensearch.sql.spark.client.EMRServerlessClient; -import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.sql.spark.execution.statestore.StateStore; /** * Singleton Class @@ -19,14 +19,14 @@ */ @RequiredArgsConstructor public class SessionManager { - private final SessionStateStore stateStore; + private final StateStore stateStore; private final EMRServerlessClient emrServerlessClient; public Session createSession(CreateSessionRequest request) { InteractiveSession session = InteractiveSession.builder() .sessionId(newSessionId()) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrServerlessClient) .build(); session.open(request); @@ -34,12 +34,12 @@ public Session createSession(CreateSessionRequest request) { } public Optional getSession(SessionId sid) { - Optional model = stateStore.get(sid); + Optional model = StateStore.getSession(stateStore).apply(sid.getSessionId()); if (model.isPresent()) { InteractiveSession session = InteractiveSession.builder() .sessionId(sid) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrServerlessClient) .sessionModel(model.get()) .build(); diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java index 656f0ec8ce..806cdb083e 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionModel.java @@ -12,16 +12,16 @@ import lombok.Builder; import lombok.Data; import lombok.SneakyThrows; -import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.sql.spark.execution.statestore.StateModel; /** Session data in flint.ql.sessions index. */ @Data @Builder -public class SessionModel implements ToXContentObject { +public class SessionModel extends StateModel { public static final String VERSION = "version"; public static final String TYPE = "type"; public static final String SESSION_TYPE = "sessionType"; @@ -73,6 +73,27 @@ public static SessionModel of(SessionModel copy, long seqNo, long primaryTerm) { .sessionId(new SessionId(copy.sessionId.getSessionId())) .sessionState(copy.sessionState) .datasourceName(copy.datasourceName) + .applicationId(copy.getApplicationId()) + .jobId(copy.jobId) + .error(UNKNOWN) + .lastUpdateTime(copy.getLastUpdateTime()) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + public static SessionModel copyWithState( + SessionModel copy, SessionState state, long seqNo, long primaryTerm) { + return builder() + .version(copy.version) + .sessionType(copy.sessionType) + .sessionId(new SessionId(copy.sessionId.getSessionId())) + .sessionState(state) + .datasourceName(copy.datasourceName) + .applicationId(copy.getApplicationId()) + .jobId(copy.jobId) + .error(UNKNOWN) + .lastUpdateTime(copy.getLastUpdateTime()) .seqNo(seqNo) .primaryTerm(primaryTerm) .build(); @@ -140,4 +161,9 @@ public static SessionModel initInteractiveSession( .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) .build(); } + + @Override + public String getId() { + return sessionId.getSessionId(); + } } diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java index 509d5105e9..a4da957f12 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/session/SessionState.java @@ -5,7 +5,9 @@ package org.opensearch.sql.spark.execution.session; +import com.google.common.collect.ImmutableList; import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.stream.Collectors; import lombok.Getter; @@ -17,6 +19,8 @@ public enum SessionState { DEAD("dead"), FAIL("fail"); + public static List END_STATE = ImmutableList.of(DEAD, FAIL); + private final String sessionState; SessionState(String sessionState) { diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java new file mode 100644 index 0000000000..10061404ca --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/QueryRequest.java @@ -0,0 +1,15 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import lombok.Data; +import org.opensearch.sql.spark.rest.model.LangType; + +@Data +public class QueryRequest { + private final LangType langType; + private final String query; +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java new file mode 100644 index 0000000000..8fcedb5fca --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/Statement.java @@ -0,0 +1,85 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.opensearch.sql.spark.execution.statement.StatementModel.submitStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.createStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; + +import lombok.Builder; +import lombok.Getter; +import lombok.Setter; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.index.engine.DocumentMissingException; +import org.opensearch.index.engine.VersionConflictEngineException; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.rest.model.LangType; + +/** Statement represent query to execute in session. One statement map to one session. */ +@Getter +@Builder +public class Statement { + private static final Logger LOG = LogManager.getLogger(); + + private final SessionId sessionId; + private final String applicationId; + private final String jobId; + private final StatementId statementId; + private final LangType langType; + private final String query; + private final String queryId; + private final StateStore stateStore; + + @Setter private StatementModel statementModel; + + /** Open a statement. */ + public void open() { + try { + statementModel = + submitStatement(sessionId, applicationId, jobId, statementId, langType, query, queryId); + statementModel = createStatement(stateStore).apply(statementModel); + } catch (VersionConflictEngineException e) { + String errorMsg = "statement already exist. " + statementId; + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + } + + /** Cancel a statement. */ + public void cancel() { + if (statementModel.getStatementState().equals(StatementState.RUNNING)) { + String errorMsg = + String.format("can't cancel statement in waiting state. statement: %s.", statementId); + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + try { + this.statementModel = + updateStatementState(stateStore).apply(this.statementModel, StatementState.CANCELLED); + } catch (DocumentMissingException e) { + String errorMsg = + String.format("cancel statement failed. no statement found. statement: %s.", statementId); + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } catch (VersionConflictEngineException e) { + this.statementModel = + getStatement(stateStore).apply(statementModel.getId()).orElse(this.statementModel); + String errorMsg = + String.format( + "cancel statement failed. current statementState: %s " + "statement: %s.", + this.statementModel.getStatementState(), statementId); + LOG.error(errorMsg); + throw new IllegalStateException(errorMsg); + } + } + + public StatementState getStatementState() { + return statementModel.getStatementState(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java new file mode 100644 index 0000000000..4baff71493 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementId.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import lombok.Data; +import org.apache.commons.lang3.RandomStringUtils; + +@Data +public class StatementId { + private final String id; + + public static StatementId newStatementId() { + return new StatementId(RandomStringUtils.random(10, true, true)); + } + + @Override + public String toString() { + return "statementId=" + id; + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java new file mode 100644 index 0000000000..c7f681c541 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementModel.java @@ -0,0 +1,194 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.opensearch.sql.spark.execution.session.SessionModel.APPLICATION_ID; +import static org.opensearch.sql.spark.execution.session.SessionModel.JOB_ID; +import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; + +import java.io.IOException; +import lombok.Builder; +import lombok.Data; +import lombok.SneakyThrows; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.statestore.StateModel; +import org.opensearch.sql.spark.rest.model.LangType; + +/** Statement data in flint.ql.sessions index. */ +@Data +@Builder +public class StatementModel extends StateModel { + public static final String VERSION = "version"; + public static final String TYPE = "type"; + public static final String STATEMENT_STATE = "state"; + public static final String STATEMENT_ID = "statementId"; + public static final String SESSION_ID = "sessionId"; + public static final String LANG = "lang"; + public static final String QUERY = "query"; + public static final String QUERY_ID = "queryId"; + public static final String SUBMIT_TIME = "submitTime"; + public static final String ERROR = "error"; + public static final String UNKNOWN = "unknown"; + public static final String STATEMENT_DOC_TYPE = "statement"; + + private final String version; + private final StatementState statementState; + private final StatementId statementId; + private final SessionId sessionId; + private final String applicationId; + private final String jobId; + private final LangType langType; + private final String query; + private final String queryId; + private final long submitTime; + private final String error; + + private final long seqNo; + private final long primaryTerm; + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder + .startObject() + .field(VERSION, version) + .field(TYPE, STATEMENT_DOC_TYPE) + .field(STATEMENT_STATE, statementState.getState()) + .field(STATEMENT_ID, statementId.getId()) + .field(SESSION_ID, sessionId.getSessionId()) + .field(APPLICATION_ID, applicationId) + .field(JOB_ID, jobId) + .field(LANG, langType.getText()) + .field(QUERY, query) + .field(QUERY_ID, queryId) + .field(SUBMIT_TIME, submitTime) + .field(ERROR, error) + .endObject(); + return builder; + } + + public static StatementModel copy(StatementModel copy, long seqNo, long primaryTerm) { + return builder() + .version("1.0") + .statementState(copy.statementState) + .statementId(copy.statementId) + .sessionId(copy.sessionId) + .applicationId(copy.applicationId) + .jobId(copy.jobId) + .langType(copy.langType) + .query(copy.query) + .queryId(copy.queryId) + .submitTime(copy.submitTime) + .error(copy.error) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + public static StatementModel copyWithState( + StatementModel copy, StatementState state, long seqNo, long primaryTerm) { + return builder() + .version("1.0") + .statementState(state) + .statementId(copy.statementId) + .sessionId(copy.sessionId) + .applicationId(copy.applicationId) + .jobId(copy.jobId) + .langType(copy.langType) + .query(copy.query) + .queryId(copy.queryId) + .submitTime(copy.submitTime) + .error(copy.error) + .seqNo(seqNo) + .primaryTerm(primaryTerm) + .build(); + } + + @SneakyThrows + public static StatementModel fromXContent(XContentParser parser, long seqNo, long primaryTerm) { + StatementModel.StatementModelBuilder builder = StatementModel.builder(); + XContentParserUtils.ensureExpectedToken( + XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + while (!XContentParser.Token.END_OBJECT.equals(parser.nextToken())) { + String fieldName = parser.currentName(); + parser.nextToken(); + switch (fieldName) { + case VERSION: + builder.version(parser.text()); + break; + case TYPE: + // do nothing + break; + case STATEMENT_STATE: + builder.statementState(StatementState.fromString(parser.text())); + break; + case STATEMENT_ID: + builder.statementId(new StatementId(parser.text())); + break; + case SESSION_ID: + builder.sessionId(new SessionId(parser.text())); + break; + case APPLICATION_ID: + builder.applicationId(parser.text()); + break; + case JOB_ID: + builder.jobId(parser.text()); + break; + case LANG: + builder.langType(LangType.fromString(parser.text())); + break; + case QUERY: + builder.query(parser.text()); + break; + case QUERY_ID: + builder.queryId(parser.text()); + break; + case SUBMIT_TIME: + builder.submitTime(parser.longValue()); + break; + case ERROR: + builder.error(parser.text()); + break; + } + } + builder.seqNo(seqNo); + builder.primaryTerm(primaryTerm); + return builder.build(); + } + + public static StatementModel submitStatement( + SessionId sid, + String applicationId, + String jobId, + StatementId statementId, + LangType langType, + String query, + String queryId) { + return builder() + .version("1.0") + .statementState(WAITING) + .statementId(statementId) + .sessionId(sid) + .applicationId(applicationId) + .jobId(jobId) + .langType(langType) + .query(query) + .queryId(queryId) + .submitTime(System.currentTimeMillis()) + .error(UNKNOWN) + .seqNo(SequenceNumbers.UNASSIGNED_SEQ_NO) + .primaryTerm(SequenceNumbers.UNASSIGNED_PRIMARY_TERM) + .build(); + } + + @Override + public String getId() { + return statementId.getId(); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java new file mode 100644 index 0000000000..33f7f5e831 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statement/StatementState.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import java.util.Arrays; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.Getter; + +/** {@link Statement} State. */ +@Getter +public enum StatementState { + WAITING("waiting"), + RUNNING("running"), + SUCCESS("success"), + FAILED("failed"), + CANCELLED("cancelled"); + + private final String state; + + StatementState(String state) { + this.state = state; + } + + private static Map STATES = + Arrays.stream(StatementState.values()) + .collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + + public static StatementState fromString(String key) { + if (STATES.containsKey(key)) { + return STATES.get(key); + } + throw new IllegalArgumentException("Invalid statement state: " + key); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java deleted file mode 100644 index 6ddce55360..0000000000 --- a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/SessionStateStore.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statestore; - -import java.io.IOException; -import java.util.Locale; -import java.util.Optional; -import lombok.RequiredArgsConstructor; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.get.GetResponse; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.action.support.WriteRequest; -import org.opensearch.client.Client; -import org.opensearch.common.xcontent.LoggingDeprecationHandler; -import org.opensearch.common.xcontent.XContentFactory; -import org.opensearch.common.xcontent.XContentType; -import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.core.xcontent.ToXContent; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.sql.spark.execution.session.SessionId; -import org.opensearch.sql.spark.execution.session.SessionModel; - -@RequiredArgsConstructor -public class SessionStateStore { - private static final Logger LOG = LogManager.getLogger(); - - private final String indexName; - private final Client client; - - public SessionModel create(SessionModel session) { - try { - IndexRequest indexRequest = - new IndexRequest(indexName) - .id(session.getSessionId().getSessionId()) - .source(session.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) - .setIfSeqNo(session.getSeqNo()) - .setIfPrimaryTerm(session.getPrimaryTerm()) - .create(true) - .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); - IndexResponse indexResponse = client.index(indexRequest).actionGet(); - if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { - LOG.debug("Successfully created doc. id: {}", session.getSessionId()); - return SessionModel.of(session, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); - } else { - throw new RuntimeException( - String.format( - Locale.ROOT, - "Failed create doc. id: %s, error: %s", - session.getSessionId(), - indexResponse.getResult().getLowercase())); - } - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public Optional get(SessionId sid) { - try { - GetRequest getRequest = new GetRequest().index(indexName).id(sid.getSessionId()); - GetResponse getResponse = client.get(getRequest).actionGet(); - if (getResponse.isExists()) { - XContentParser parser = - XContentType.JSON - .xContent() - .createParser( - NamedXContentRegistry.EMPTY, - LoggingDeprecationHandler.INSTANCE, - getResponse.getSourceAsString()); - parser.nextToken(); - return Optional.of( - SessionModel.fromXContent( - parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); - } else { - return Optional.empty(); - } - } catch (IOException e) { - throw new RuntimeException(e); - } - } -} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java new file mode 100644 index 0000000000..b5bf31a6ba --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateModel.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentParser; + +public abstract class StateModel implements ToXContentObject { + + public abstract String getId(); + + public abstract long getSeqNo(); + + public abstract long getPrimaryTerm(); + + public interface CopyBuilder { + T of(T copy, long seqNo, long primaryTerm); + } + + public interface StateCopyBuilder { + T of(T copy, S state, long seqNo, long primaryTerm); + } + + public interface FromXContent { + T fromXContent(XContentParser parser, long seqNo, long primaryTerm); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java new file mode 100644 index 0000000000..bd72b17353 --- /dev/null +++ b/spark/src/main/java/org/opensearch/sql/spark/execution/statestore/StateStore.java @@ -0,0 +1,149 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statestore; + +import java.io.IOException; +import java.util.Locale; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.function.Function; +import lombok.RequiredArgsConstructor; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.update.UpdateRequest; +import org.opensearch.action.update.UpdateResponse; +import org.opensearch.client.Client; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.sql.spark.execution.session.SessionModel; +import org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.statement.StatementModel; +import org.opensearch.sql.spark.execution.statement.StatementState; + +@RequiredArgsConstructor +public class StateStore { + private static final Logger LOG = LogManager.getLogger(); + + private final String indexName; + private final Client client; + + protected T create(T st, StateModel.CopyBuilder builder) { + try { + IndexRequest indexRequest = + new IndexRequest(indexName) + .id(st.getId()) + .source(st.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .setIfSeqNo(st.getSeqNo()) + .setIfPrimaryTerm(st.getPrimaryTerm()) + .create(true) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + IndexResponse indexResponse = client.index(indexRequest).actionGet(); + if (indexResponse.getResult().equals(DocWriteResponse.Result.CREATED)) { + LOG.debug("Successfully created doc. id: {}", st.getId()); + return builder.of(st, indexResponse.getSeqNo(), indexResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed create doc. id: %s, error: %s", + st.getId(), + indexResponse.getResult().getLowercase())); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected Optional get(String sid, StateModel.FromXContent builder) { + try { + GetRequest getRequest = new GetRequest().index(indexName).id(sid); + GetResponse getResponse = client.get(getRequest).actionGet(); + if (getResponse.isExists()) { + XContentParser parser = + XContentType.JSON + .xContent() + .createParser( + NamedXContentRegistry.EMPTY, + LoggingDeprecationHandler.INSTANCE, + getResponse.getSourceAsString()); + parser.nextToken(); + return Optional.of( + builder.fromXContent(parser, getResponse.getSeqNo(), getResponse.getPrimaryTerm())); + } else { + return Optional.empty(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + protected T updateState( + T st, S state, StateModel.StateCopyBuilder builder) { + try { + T model = builder.of(st, state, st.getSeqNo(), st.getPrimaryTerm()); + UpdateRequest updateRequest = + new UpdateRequest() + .index(indexName) + .id(model.getId()) + .setIfSeqNo(model.getSeqNo()) + .setIfPrimaryTerm(model.getPrimaryTerm()) + .doc(model.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS)) + .fetchSource(true) + .setRefreshPolicy(WriteRequest.RefreshPolicy.WAIT_UNTIL); + UpdateResponse updateResponse = client.update(updateRequest).actionGet(); + if (updateResponse.getResult().equals(DocWriteResponse.Result.UPDATED)) { + LOG.debug("Successfully update doc. id: {}", st.getId()); + return builder.of(model, state, updateResponse.getSeqNo(), updateResponse.getPrimaryTerm()); + } else { + throw new RuntimeException( + String.format( + Locale.ROOT, + "Failed update doc. id: %s, error: %s", + st.getId(), + updateResponse.getResult().getLowercase())); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** Helper Functions */ + public static Function createStatement(StateStore stateStore) { + return (st) -> stateStore.create(st, StatementModel::copy); + } + + public static Function> getStatement(StateStore stateStore) { + return (docId) -> stateStore.get(docId, StatementModel::fromXContent); + } + + public static BiFunction updateStatementState( + StateStore stateStore) { + return (old, state) -> stateStore.updateState(old, state, StatementModel::copyWithState); + } + + public static Function createSession(StateStore stateStore) { + return (session) -> stateStore.create(session, SessionModel::of); + } + + public static Function> getSession(StateStore stateStore) { + return (docId) -> stateStore.get(docId, SessionModel::fromXContent); + } + + public static BiFunction updateSessionState( + StateStore stateStore) { + return (old, state) -> stateStore.updateState(old, state, SessionModel::copyWithState); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java index 53dc211ded..488252d05a 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/InteractiveSessionTest.java @@ -7,6 +7,7 @@ import static org.opensearch.sql.spark.execution.session.InteractiveSessionTest.TestSession.testSession; import static org.opensearch.sql.spark.execution.session.SessionState.NOT_STARTED; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getSession; import com.amazonaws.services.emrserverless.model.CancelJobRunResult; import com.amazonaws.services.emrserverless.model.GetJobRunResult; @@ -20,7 +21,7 @@ import org.opensearch.action.delete.DeleteRequest; import org.opensearch.sql.spark.client.EMRServerlessClient; import org.opensearch.sql.spark.client.StartJobRequest; -import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.test.OpenSearchSingleNodeTestCase; /** mock-maker-inline does not work with OpenSearchTestCase. */ @@ -30,13 +31,13 @@ public class InteractiveSessionTest extends OpenSearchSingleNodeTestCase { private TestEMRServerlessClient emrsClient; private StartJobRequest startJobRequest; - private SessionStateStore stateStore; + private StateStore stateStore; @Before public void setup() { emrsClient = new TestEMRServerlessClient(); startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); - stateStore = new SessionStateStore(indexName, client()); + stateStore = new StateStore(indexName, client()); createIndex(indexName); } @@ -50,7 +51,7 @@ public void openCloseSession() { InteractiveSession session = InteractiveSession.builder() .sessionId(SessionId.newSessionId()) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .build(); @@ -74,7 +75,7 @@ public void openSessionFailedConflict() { InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .build(); session.open(new CreateSessionRequest(startJobRequest, "datasource")); @@ -82,7 +83,7 @@ public void openSessionFailedConflict() { InteractiveSession duplicateSession = InteractiveSession.builder() .sessionId(sessionId) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .build(); IllegalStateException exception = @@ -98,15 +99,15 @@ public void closeNotExistSession() { InteractiveSession session = InteractiveSession.builder() .sessionId(sessionId) - .sessionStateStore(stateStore) + .stateStore(stateStore) .serverlessClient(emrsClient) .build(); session.open(new CreateSessionRequest(startJobRequest, "datasource")); - client().delete(new DeleteRequest(indexName, sessionId.getSessionId())); + client().delete(new DeleteRequest(indexName, sessionId.getSessionId())).actionGet(); IllegalStateException exception = assertThrows(IllegalStateException.class, session::close); - assertEquals("session not exist. " + sessionId, exception.getMessage()); + assertEquals("session does not exist. " + sessionId, exception.getMessage()); emrsClient.cancelJobRunCalled(0); } @@ -142,9 +143,9 @@ public void sessionManagerGetSessionNotExist() { @RequiredArgsConstructor static class TestSession { private final Session session; - private final SessionStateStore stateStore; + private final StateStore stateStore; - public static TestSession testSession(Session session, SessionStateStore stateStore) { + public static TestSession testSession(Session session, StateStore stateStore) { return new TestSession(session, stateStore); } @@ -152,7 +153,7 @@ public TestSession assertSessionState(SessionState expected) { assertEquals(expected, session.getSessionModel().getSessionState()); Optional sessionStoreState = - stateStore.get(session.getSessionModel().getSessionId()); + getSession(stateStore).apply(session.getSessionModel().getId()); assertTrue(sessionStoreState.isPresent()); assertEquals(expected, sessionStoreState.get().getSessionState()); @@ -180,7 +181,7 @@ public TestSession close() { } } - static class TestEMRServerlessClient implements EMRServerlessClient { + public static class TestEMRServerlessClient implements EMRServerlessClient { private int startJobRunCalled = 0; private int cancelJobRunCalled = 0; diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java index d35105f787..95b85613be 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/session/SessionManagerTest.java @@ -5,29 +5,20 @@ package org.opensearch.sql.spark.execution.session; -import static org.junit.jupiter.api.Assertions.*; - import org.junit.After; import org.junit.Before; -import org.mockito.MockMakers; -import org.mockito.MockSettings; -import org.mockito.Mockito; import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; -import org.opensearch.sql.spark.execution.statestore.SessionStateStore; +import org.opensearch.sql.spark.execution.statestore.StateStore; import org.opensearch.test.OpenSearchSingleNodeTestCase; class SessionManagerTest extends OpenSearchSingleNodeTestCase { private static final String indexName = "mockindex"; - // mock-maker-inline does not work with OpenSearchTestCase. make sure use mockSettings when mock. - private static final MockSettings mockSettings = - Mockito.withSettings().mockMaker(MockMakers.SUBCLASS); - - private SessionStateStore stateStore; + private StateStore stateStore; @Before public void setup() { - stateStore = new SessionStateStore(indexName, client()); + stateStore = new StateStore(indexName, client()); createIndex(indexName); } diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java new file mode 100644 index 0000000000..b7af1123ba --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementStateTest.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.junit.Assert.assertThrows; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class StatementStateTest { + @Test + public void invalidStatementState() { + IllegalArgumentException exception = + assertThrows(IllegalArgumentException.class, () -> StatementState.fromString("invalid")); + Assertions.assertEquals("Invalid statement state: invalid", exception.getMessage()); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java new file mode 100644 index 0000000000..331955e14e --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/execution/statement/StatementTest.java @@ -0,0 +1,356 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.execution.statement; + +import static org.opensearch.sql.spark.execution.statement.StatementState.CANCELLED; +import static org.opensearch.sql.spark.execution.statement.StatementState.WAITING; +import static org.opensearch.sql.spark.execution.statement.StatementTest.TestStatement.testStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.getStatement; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateSessionState; +import static org.opensearch.sql.spark.execution.statestore.StateStore.updateStatementState; + +import java.util.HashMap; +import java.util.Optional; +import lombok.RequiredArgsConstructor; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.action.admin.indices.delete.DeleteIndexRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.sql.spark.client.StartJobRequest; +import org.opensearch.sql.spark.execution.session.CreateSessionRequest; +import org.opensearch.sql.spark.execution.session.InteractiveSessionTest; +import org.opensearch.sql.spark.execution.session.Session; +import org.opensearch.sql.spark.execution.session.SessionId; +import org.opensearch.sql.spark.execution.session.SessionManager; +import org.opensearch.sql.spark.execution.session.SessionState; +import org.opensearch.sql.spark.execution.statestore.StateStore; +import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.test.OpenSearchSingleNodeTestCase; + +public class StatementTest extends OpenSearchSingleNodeTestCase { + + private static final String indexName = "mockindex"; + + private StartJobRequest startJobRequest; + private StateStore stateStore; + private InteractiveSessionTest.TestEMRServerlessClient emrsClient = + new InteractiveSessionTest.TestEMRServerlessClient(); + + @Before + public void setup() { + startJobRequest = new StartJobRequest("", "", "appId", "", "", new HashMap<>(), false, ""); + stateStore = new StateStore(indexName, client()); + createIndex(indexName); + } + + @After + public void clean() { + client().admin().indices().delete(new DeleteIndexRequest(indexName)).actionGet(); + } + + @Test + public void openThenCancelStatement() { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + + // submit statement + TestStatement testStatement = testStatement(st, stateStore); + testStatement + .open() + .assertSessionState(WAITING) + .assertStatementId(new StatementId("statementId")); + + // close statement + testStatement.cancel().assertSessionState(CANCELLED); + } + + @Test + public void openFailedBecauseConflict() { + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + // open statement with same statement id + Statement dupSt = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(new StatementId("statementId")) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + IllegalStateException exception = assertThrows(IllegalStateException.class, dupSt::open); + assertEquals("statement already exist. statementId=statementId", exception.getMessage()); + } + + @Test + public void cancelNotExistStatement() { + StatementId stId = new StatementId("statementId"); + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(stId) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + client().delete(new DeleteRequest(indexName, stId.getId())); + + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("cancel statement failed. no statement found. statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelFailedBecauseOfConflict() { + StatementId stId = new StatementId("statementId"); + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(stId) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + StatementModel running = + updateStatementState(stateStore).apply(st.getStatementModel(), CANCELLED); + + assertEquals(StatementState.CANCELLED, running.getStatementState()); + + // cancel conflict + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format( + "cancel statement failed. current statementState: CANCELLED " + "statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void cancelRunningStatementFailed() { + StatementId stId = new StatementId("statementId"); + Statement st = + Statement.builder() + .sessionId(new SessionId("sessionId")) + .applicationId("appId") + .jobId("jobId") + .statementId(stId) + .langType(LangType.SQL) + .query("query") + .queryId("statementId") + .stateStore(stateStore) + .build(); + st.open(); + + // update to running state + StatementModel model = st.getStatementModel(); + st.setStatementModel( + StatementModel.copyWithState( + st.getStatementModel(), + StatementState.RUNNING, + model.getSeqNo(), + model.getPrimaryTerm())); + + // cancel conflict + IllegalStateException exception = assertThrows(IllegalStateException.class, st::cancel); + assertEquals( + String.format("can't cancel statement in waiting state. statement: %s.", stId), + exception.getMessage()); + } + + @Test + public void submitStatementInRunningSession() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + // App change state to running + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + assertFalse(statementId.getId().isEmpty()); + } + + @Test + public void submitStatementInNotStartedState() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + assertFalse(statementId.getId().isEmpty()); + } + + @Test + public void failToSubmitStatementInDeadState() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.DEAD); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertEquals( + "can't submit statement, session should not be in end state, current session state is:" + + " dead", + exception.getMessage()); + } + + @Test + public void failToSubmitStatementInFailState() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.FAIL); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertEquals( + "can't submit statement, session should not be in end state, current session state is:" + + " fail", + exception.getMessage()); + } + + @Test + public void newStatementFieldAssert() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + Optional statement = session.get(statementId); + + assertTrue(statement.isPresent()); + assertEquals(session.getSessionId(), statement.get().getSessionId()); + assertEquals("appId", statement.get().getApplicationId()); + assertEquals("jobId", statement.get().getJobId()); + assertEquals(statementId, statement.get().getStatementId()); + assertEquals(WAITING, statement.get().getStatementState()); + assertEquals(LangType.SQL, statement.get().getLangType()); + assertEquals("select 1", statement.get().getQuery()); + } + + @Test + public void failToSubmitStatementInDeletedSession() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + + // other's delete session + client() + .delete(new DeleteRequest(indexName, session.getSessionId().getSessionId())) + .actionGet(); + + IllegalStateException exception = + assertThrows( + IllegalStateException.class, + () -> session.submit(new QueryRequest(LangType.SQL, "select 1"))); + assertEquals("session does not exist. " + session.getSessionId(), exception.getMessage()); + } + + @Test + public void getStatementSuccess() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + // App change state to running + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + StatementId statementId = session.submit(new QueryRequest(LangType.SQL, "select 1")); + + Optional statement = session.get(statementId); + assertTrue(statement.isPresent()); + assertEquals(WAITING, statement.get().getStatementState()); + assertEquals(statementId, statement.get().getStatementId()); + } + + @Test + public void getStatementNotExist() { + Session session = + new SessionManager(stateStore, emrsClient) + .createSession(new CreateSessionRequest(startJobRequest, "datasource")); + // App change state to running + updateSessionState(stateStore).apply(session.getSessionModel(), SessionState.RUNNING); + + Optional statement = session.get(StatementId.newStatementId()); + assertFalse(statement.isPresent()); + } + + @RequiredArgsConstructor + static class TestStatement { + private final Statement st; + private final StateStore stateStore; + + public static TestStatement testStatement(Statement st, StateStore stateStore) { + return new TestStatement(st, stateStore); + } + + public TestStatement assertSessionState(StatementState expected) { + assertEquals(expected, st.getStatementModel().getStatementState()); + + Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); + assertTrue(model.isPresent()); + assertEquals(expected, model.get().getStatementState()); + + return this; + } + + public TestStatement assertStatementId(StatementId expected) { + assertEquals(expected, st.getStatementModel().getStatementId()); + + Optional model = getStatement(stateStore).apply(st.getStatementId().getId()); + assertTrue(model.isPresent()); + assertEquals(expected, model.get().getStatementId()); + return this; + } + + public TestStatement open() { + st.open(); + return this; + } + + public TestStatement cancel() { + st.cancel(); + return this; + } + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java b/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java deleted file mode 100644 index 9c779555d7..0000000000 --- a/spark/src/test/java/org/opensearch/sql/spark/execution/statestore/SessionStateStoreTest.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.spark.execution.statestore; - -import static org.junit.Assert.assertThrows; -import static org.mockito.Answers.RETURNS_DEEP_STUBS; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.when; - -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.action.DocWriteResponse; -import org.opensearch.action.index.IndexResponse; -import org.opensearch.client.Client; -import org.opensearch.sql.spark.execution.session.SessionId; -import org.opensearch.sql.spark.execution.session.SessionModel; - -@ExtendWith(MockitoExtension.class) -class SessionStateStoreTest { - @Mock(answer = RETURNS_DEEP_STUBS) - private Client client; - - @Mock private IndexResponse indexResponse; - - @Test - public void createWithException() { - when(client.index(any()).actionGet()).thenReturn(indexResponse); - doReturn(DocWriteResponse.Result.NOT_FOUND).when(indexResponse).getResult(); - SessionModel sessionModel = - SessionModel.initInteractiveSession( - "appId", "jobId", SessionId.newSessionId(), "datasource"); - SessionStateStore sessionStateStore = new SessionStateStore("indexName", client); - - assertThrows(RuntimeException.class, () -> sessionStateStore.create(sessionModel)); - } -}