From adc7da0cbb2baa500a8be553ddae34cb4ff2a4ff Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 30 Nov 2023 11:46:11 -0800 Subject: [PATCH] Add search and singular APIs to conversation memory (#1504) * add searchConversation Signed-off-by: HenryL27 * add searchinteractions Signed-off-by: HenryL27 * add searchConversationsITTests Signed-off-by: HenryL27 * add searchInteractionsITTests Signed-off-by: HenryL27 * add unit tests for storage-layer search Signed-off-by: HenryL27 * add Search transport actions and tests Signed-off-by: HenryL27 * add rest search actions Signed-off-by: HenryL27 * add search rest actions Signed-off-by: HenryL27 * Add singular get actions at storage layer Signed-off-by: HenryL27 * Add OpenSearhMemoryHandler unit tests for singular get Signed-off-by: HenryL27 * Add singular get transport layer Signed-off-by: HenryL27 * add singular get rest actions Signed-off-by: HenryL27 * fix async return value problem Signed-off-by: HenryL27 * address esay PR comments Signed-off-by: HenryL27 --------- Signed-off-by: HenryL27 --- .../common/conversation/ActionConstants.java | 23 +- .../memory/ConversationalMemoryHandler.java | 62 +++++ .../conversation/GetConversationAction.java | 34 +++ .../conversation/GetConversationRequest.java | 77 +++++++ .../conversation/GetConversationResponse.java | 60 +++++ .../GetConversationTransportAction.java | 100 ++++++++ .../conversation/GetInteractionAction.java | 34 +++ .../conversation/GetInteractionRequest.java | 85 +++++++ .../conversation/GetInteractionResponse.java | 61 +++++ .../GetInteractionTransportAction.java | 95 ++++++++ .../conversation/GetInteractionsAction.java | 2 +- .../SearchConversationsAction.java | 32 +++ .../SearchConversationsTransportAction.java | 92 ++++++++ .../SearchInteractionsAction.java | 32 +++ .../SearchInteractionsRequest.java | 54 +++++ .../SearchInteractionsTransportAction.java | 91 ++++++++ .../memory/index/ConversationMetaIndex.java | 155 ++++++++++--- .../ml/memory/index/InteractionsIndex.java | 133 +++++++++-- ...OpenSearchConversationalMemoryHandler.java | 86 +++++++ .../conversation/ConversationActionTests.java | 2 + .../GetConversationRequestTests.java | 68 ++++++ .../GetConversationResponseTests.java | 62 +++++ .../GetConversationTransportActionTests.java | 150 ++++++++++++ .../GetInteractionRequestTests.java | 85 +++++++ .../GetInteractionResponseTests.java | 64 ++++++ .../GetInteractionTransportActionTests.java | 158 +++++++++++++ .../conversation/InteractionActionTests.java | 6 +- ...archConversationsTransportActionTests.java | 121 ++++++++++ .../SearchInteractionsRequestTests.java | 64 ++++++ ...archInteractionsTransportActionsTests.java | 122 ++++++++++ .../index/ConversationMetaIndexITTests.java | 215 ++++++++++++++++++ .../index/ConversationMetaIndexTests.java | 158 ++++++++++++- .../index/InteractionsIndexITTests.java | 160 +++++++++++++ .../memory/index/InteractionsIndexTests.java | 147 ++++++++++++ ...earchConversationalMemoryHandlerTests.java | 52 ++++- .../ml/plugin/MachineLearningPlugin.java | 28 ++- .../rest/RestMemoryGetConversationAction.java | 51 +++++ .../rest/RestMemoryGetInteractionAction.java | 51 +++++ .../RestMemorySearchConversationsAction.java | 43 ++++ .../RestMemorySearchInteractionsAction.java | 83 +++++++ .../RestMemoryGetConversationActionIT.java | 81 +++++++ .../RestMemoryGetConversationActionTests.java | 64 ++++++ .../RestMemoryGetInteractionActionIT.java | 127 +++++++++++ .../RestMemoryGetInteractionActionTests.java | 65 ++++++ ...RestMemorySearchConversationsActionIT.java | 82 +++++++ ...tMemorySearchConversationsActionTests.java | 73 ++++++ .../RestMemorySearchInteractionsActionIT.java | 149 ++++++++++++ ...stMemorySearchInteractionsActionTests.java | 103 +++++++++ 48 files changed, 3875 insertions(+), 67 deletions(-) create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index 5bb8334bc1..f87da7c433 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -51,16 +51,25 @@ public class ActionConstants { /** name of success field in all requests */ public final static String SUCCESS_FIELD = "success"; + private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation"; /** path for create conversation */ - public final static String CREATE_CONVERSATION_REST_PATH = "/_plugins/_ml/memory/conversation"; - /** path for list conversations */ - public final static String GET_CONVERSATIONS_REST_PATH = "/_plugins/_ml/memory/conversation"; - /** path for put interaction */ - public final static String CREATE_INTERACTION_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}"; + public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create"; + /** path for get conversations */ + public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list"; + /** path for create interaction */ + public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create"; /** path for get interactions */ - public final static String GET_INTERACTIONS_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}"; + public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list"; /** path for delete conversation */ - public final static String DELETE_CONVERSATION_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}"; + public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_delete"; + /** path for search conversations */ + public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search"; + /** path for search interactions */ + public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search"; + /** path for get conversation */ + public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; + /** path for get interaction */ + public final static String GET_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/{interaction_id}"; /** default max results returned by get operations */ public final static int DEFAULT_MAX_RESULTS = 10; diff --git a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java index 18d23eff0d..42cece3f2e 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java @@ -19,6 +19,8 @@ import java.util.List; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationMeta; @@ -171,4 +173,64 @@ public ActionFuture createInteraction( */ public ActionFuture deleteConversation(String conversationId); + /** + * Search over conversations index + * @param request search request over the conversations index + * @param listener receives the search response + */ + public void searchConversations(SearchRequest request, ActionListener listener); + + /** + * Search over conversations index + * @param request search request over the conversations index + * @return ActionFuture for the search response + */ + public ActionFuture searchConversations(SearchRequest request); + + /** + * Search over interactions of a conversation + * @param conversationId id of the conversation to search through + * @param request search request over the interactions + * @param listener receives the search response + */ + public void searchInteractions(String conversationId, SearchRequest request, ActionListener listener); + + /** + * Search over interactions of a conversation + * @param conversationId id of the conversation to search through + * @param request search request over the interactions + * @return ActionFuture for the search response + */ + public ActionFuture searchInteractions(String conversationId, SearchRequest request); + + /** + * Get a single ConversationMeta object + * @param conversationId id of the conversation to get + * @param listener receives the conversationMeta object + */ + public void getConversation(String conversationId, ActionListener listener); + + /** + * Get a single ConversationMeta object + * @param conversationId id of the conversation to get + * @return ActionFuture for the conversationMeta object + */ + public ActionFuture getConversation(String conversationId); + + /** + * Get a single interaction + * @param conversationId id of the conversation this interaction belongs to + * @param interactionId id of this interaction + * @param listener receives the interaction + */ + public void getInteraction(String conversationId, String interactionId, ActionListener listener); + + /** + * Get a single interaction + * @param conversationId id of the conversation this interaction belongs to + * @param interactionId id of this interaction + * @return ActionFuture for the interaction + */ + public ActionFuture getInteraction(String conversationId, String interactionId); + } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java new file mode 100644 index 0000000000..7839915201 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action for retrieving a top-level conversation object by id + */ +public class GetConversationAction extends ActionType { + /** Instance of this */ + public static final GetConversationAction INSTANCE = new GetConversationAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/get"; + + private GetConversationAction() { + super(NAME, GetConversationResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java new file mode 100644 index 0000000000..c5a6f6dd0e --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Action Request object for GetConversation (singular) + */ +@AllArgsConstructor +public class GetConversationRequest extends ActionRequest { + @Getter + private String conversationId; + + /** + * Stream Constructor + * @param in input stream to read this from + * @throws IOException if something goes wrong reading from stream + */ + public GetConversationRequest(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.conversationId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (this.conversationId == null) { + exception = addValidationError("GetConversation Request must have a conversation id", exception); + } + return exception; + } + + /** + * Creates a GetConversationRequest from a rest request + * @param request Rest Request representing a GetConversationRequest + * @return the new GetConversationRequest + * @throws IOException if something goes wrong in translation + */ + public static GetConversationRequest fromRestRequest(RestRequest request) throws IOException { + String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); + return new GetConversationRequest(conversationId); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java new file mode 100644 index 0000000000..b757723e09 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java @@ -0,0 +1,60 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ConversationMeta; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * ActionResponse object for GetConversation (singular) + */ +@AllArgsConstructor +public class GetConversationResponse extends ActionResponse implements ToXContentObject { + + @Getter + private ConversationMeta conversation; + + /** + * Stream Constructor + * @param in input stream to read this from + * @throws IOException if soething goes wrong in reading + */ + public GetConversationResponse(StreamInput in) throws IOException { + super(in); + this.conversation = ConversationMeta.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + this.conversation.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return this.conversation.toXContent(builder, params); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java new file mode 100644 index 0000000000..0f1c70ad51 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java @@ -0,0 +1,100 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +/** + * Transport Action for GetConversation + */ +@Log4j2 +public class GetConversationTransportAction extends HandledTransportAction { + private Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public GetConversationTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(GetConversationAction.NAME, transportService, actionFilters, GetConversationRequest::new); + this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, GetConversationRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } else { + String conversationId = request.getConversationId(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener + .runBefore(actionListener, () -> context.restore()); + ActionListener al = ActionListener.wrap(conversationMeta -> { + internalListener.onResponse(new GetConversationResponse(conversationMeta)); + }, e -> { internalListener.onFailure(e); }); + cmHandler.getConversation(conversationId, al); + } catch (Exception e) { + log.error("Failed to get Conversation " + conversationId, e); + actionListener.onFailure(e); + } + + } + + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java new file mode 100644 index 0000000000..adaffcb9dc --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action for Get Interaction (singular) + */ +public class GetInteractionAction extends ActionType { + /** Instance of this */ + public static final GetInteractionAction INSTANCE = new GetInteractionAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/get"; + + private GetInteractionAction() { + super(NAME, GetInteractionResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java new file mode 100644 index 0000000000..6808857c40 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Action Request for GetInteraction + */ +@AllArgsConstructor +public class GetInteractionRequest extends ActionRequest { + @Getter + private String conversationId; + @Getter + private String interactionId; + + /** + * Stream Constructor + * @param in input stream to read this request from + * @throws IOException if somthing goes wrong reading + */ + public GetInteractionRequest(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + this.interactionId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.conversationId); + out.writeString(this.interactionId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (conversationId == null) { + exception = addValidationError("Get Interaction Request must have a conversation id", exception); + } + if (interactionId == null) { + exception = addValidationError("Get Interaction Request must have an interaction id", exception); + } + return exception; + } + + /** + * Creates a GetInteractionRequest from a Rest Request + * @param request Rest Request representing a GetInteractionRequest + * @return new GetInteractionRequest built from the rest request + * @throws IOException if something goes wrong reading from the rest request + */ + public static GetInteractionRequest fromRestRequest(RestRequest request) throws IOException { + String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); + String interactionId = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); + return new GetInteractionRequest(conversationId, interactionId); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java new file mode 100644 index 0000000000..7d3a1f3c73 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java @@ -0,0 +1,61 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * ActionResponse for Get Interaction (sg) + */ +@AllArgsConstructor +public class GetInteractionResponse extends ActionResponse implements ToXContentObject { + + @Getter + private Interaction interaction; + + /** + * Stream Constructor + * @param in Stream Input to read this response from + * @throws IOException if something goes wrong reading from stream + */ + public GetInteractionResponse(StreamInput in) throws IOException { + super(in); + this.interaction = Interaction.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + interaction.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return this.interaction.toXContent(builder, params); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java new file mode 100644 index 0000000000..16205ec8b9 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetInteractionTransportAction extends HandledTransportAction { + + private Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public GetInteractionTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(GetInteractionAction.NAME, transportService, actionFilters, GetInteractionRequest::new); + this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, GetInteractionRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } + String conversationId = request.getConversationId(); + String interactionId = request.getInteractionId(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener al = ActionListener.wrap(interaction -> { + internalListener.onResponse(new GetInteractionResponse(interaction)); + }, e -> { internalListener.onFailure(e); }); + cmHandler.getInteraction(conversationId, interactionId, al); + } catch (Exception e) { + log.error("Failed to get interaction " + interactionId + " in conversation " + conversationId, e); + actionListener.onFailure(e); + } + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java index 024abe17ff..7a49d062d5 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java @@ -26,7 +26,7 @@ public class GetInteractionsAction extends ActionType { /** Instance of this */ public static final GetInteractionsAction INSTANCE = new GetInteractionsAction(); /** Name of this action */ - public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/get"; + public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/list"; private GetInteractionsAction() { super(NAME, GetInteractionsResponse::new); diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsAction.java new file mode 100644 index 0000000000..38b19009ac --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsAction.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; + +public class SearchConversationsAction extends ActionType { + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/search"; + /** Instance of this */ + public static final SearchConversationsAction INSTANCE = new SearchConversationsAction(); + + private SearchConversationsAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java new file mode 100644 index 0000000000..6aa8d79ca8 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java @@ -0,0 +1,92 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class SearchConversationsTransportAction extends HandledTransportAction { + + private ConversationalMemoryHandler cmHandler; + private Client client; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public SearchConversationsTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new); + this.cmHandler = cmHandler; + this.client = client; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, SearchRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } else { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + cmHandler.searchConversations(request, internalListener); + } catch (Exception e) { + log.error("Failed to search conversations", e); + actionListener.onFailure(e); + } + } + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsAction.java new file mode 100644 index 0000000000..9386d6b674 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsAction.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; + +public class SearchInteractionsAction extends ActionType { + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/interaction/search"; + /** Instance of this */ + public static final SearchInteractionsAction INSTANCE = new SearchInteractionsAction(); + + private SearchInteractionsAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java new file mode 100644 index 0000000000..ac55350d7f --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class SearchInteractionsRequest extends SearchRequest { + + @Setter + @Getter + private String conversationId; + + public SearchInteractionsRequest(String conversationId, SearchRequest request) { + super(request); + this.conversationId = conversationId; + } + + public SearchInteractionsRequest(StreamInput in) throws IOException { + super(in); + log.info("Got here"); + this.conversationId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(conversationId); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java new file mode 100644 index 0000000000..5060e6111a --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java @@ -0,0 +1,91 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class SearchInteractionsTransportAction extends HandledTransportAction { + + private ConversationalMemoryHandler cmHandler; + private Client client; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public SearchInteractionsTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(SearchInteractionsAction.NAME, transportService, actionFilters, SearchInteractionsRequest::new); + this.cmHandler = cmHandler; + this.client = client; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, SearchInteractionsRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } else { + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + cmHandler.searchInteractions(request.getConversationId(), request, internalListener); + } catch (Exception e) { + log.error("Failed to search conversations", e); + actionListener.onFailure(e); + } + } + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index e36c296066..47c55ac1e7 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.memory.index; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_INDEX_NAME; + import java.io.IOException; import java.time.Instant; import java.util.LinkedList; @@ -45,6 +47,8 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.TermQueryBuilder; @@ -66,21 +70,24 @@ public class ConversationMetaIndex { private Client client; private ClusterService clusterService; - private static final String indexName = ConversationalIndexConstants.META_INDEX_NAME; + + private String userstr() { + return client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + } /** * Creates the conversational meta index if it doesn't already exist * @param listener listener to wait for this to finish */ public void initConversationMetaIndexIfAbsent(ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { log.debug("No conversational meta index found. Adding it"); - CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.META_MAPPING); + CreateIndexRequest request = Requests.createIndexRequest(META_INDEX_NAME).mapping(ConversationalIndexConstants.META_MAPPING); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(createIndexResponse -> { - if (createIndexResponse.equals(new CreateIndexResponse(true, true, indexName))) { - log.info("created index [" + indexName + "]"); + if (createIndexResponse.equals(new CreateIndexResponse(true, true, META_INDEX_NAME))) { + log.info("created index [" + META_INDEX_NAME + "]"); internalListener.onResponse(true); } else { internalListener.onResponse(false); @@ -90,7 +97,7 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { internalListener.onResponse(true); } else { - log.error("failed to create index [" + indexName + "]", e); + log.error("failed to create index [" + META_INDEX_NAME + "]", e); internalListener.onFailure(e); } }); @@ -100,7 +107,7 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { listener.onResponse(true); } else { - log.error("failed to create index [" + indexName + "]", e); + log.error("failed to create index [" + META_INDEX_NAME + "]", e); listener.onFailure(e); } } @@ -117,12 +124,9 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) public void createConversation(String name, ActionListener listener) { initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { if (indexExists) { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); IndexRequest request = Requests - .indexRequest(indexName) + .indexRequest(META_INDEX_NAME) .source( ConversationalIndexConstants.META_CREATED_FIELD, Instant.now(), @@ -169,11 +173,12 @@ public void createConversation(ActionListener listener) { * @param listener gets the list of conversation metadata objects in the index */ public void getConversations(int from, int maxResults, ActionListener> listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(List.of()); + return; } - SearchRequest request = Requests.searchRequest(indexName); - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + SearchRequest request = Requests.searchRequest(META_INDEX_NAME); + String userstr = userstr(); QueryBuilder queryBuilder; if (userstr == null) queryBuilder = new MatchAllQueryBuilder(); @@ -194,13 +199,12 @@ public void getConversations(int from, int maxResults, ActionListener { client.search(request, al); }, e -> { - log.error("Failed to retrieve conversations during refresh", e); - internalListener.onFailure(e); - })); + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.search(request, al); + }, e -> { + log.error("Failed to retrieve conversations during refresh", e); + internalListener.onFailure(e); + })); } catch (Exception e) { log.error("Failed to retrieve conversations", e); listener.onFailure(e); @@ -222,11 +226,12 @@ public void getConversations(int maxResults, ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(true); + return; } - DeleteRequest delRequest = Requests.deleteRequest(indexName).id(conversationId); - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); this.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { @@ -263,14 +268,14 @@ public void deleteConversation(String conversationId, ActionListener li */ public void checkAccess(String conversationId, ActionListener listener) { // If the index doesn't exist, you have permission. Just won't get you anywhere - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(true); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - GetRequest getRequest = Requests.getRequest(indexName).id(conversationId); + GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId); ActionListener al = ActionListener.wrap(getResponse -> { // If the conversation doesn't exist, fail if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) { @@ -290,13 +295,93 @@ public void checkAccess(String conversationId, ActionListener listener) } internalListener.onResponse(true); }, e -> { internalListener.onFailure(e); }); - client - .admin() - .indices() - .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(getRequest, al); }, e -> { - log.error("Failed to refresh conversations index during check access ", e); - internalListener.onFailure(e); - })); + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(getRequest, al); + }, e -> { + log.error("Failed to refresh conversations index during check access ", e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Search over the conversations in the index by wrapping the original search request + * If security is enabled, add a {"term": {"user": username}} to the wrapper must clause + * @param request original search request + * @param listener receives the search response for the wrapped query + */ + public void searchConversations(SearchRequest request, ActionListener listener) { + request.indices(META_INDEX_NAME); + QueryBuilder originalQuery = request.source().query(); + BoolQueryBuilder newQuery = new BoolQueryBuilder(); + newQuery.must(originalQuery); + String userstr = userstr(); + if (userstr != null) { + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, user)); + } + request.source().query(newQuery); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.search(request, internalListener); + }, e -> { + log.error("Failed to refresh conversations index during search conversations ", e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Get a single ConversationMeta object + * @param conversationId id of the conversation to get + * @param listener receives the conversationMeta object + */ + public void getConversation(String conversationId, ActionListener listener) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { + listener + .onFailure( + new IndexNotFoundException("cannot get conversation since the conversation index does not exist", META_INDEX_NAME) + ); + return; + } + String userstr = userstr(); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId); + ActionListener al = ActionListener.wrap(getResponse -> { + // If the conversation doesn't exist, fail + if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) { + throw new ResourceNotFoundException("Conversation [" + conversationId + "] not found"); + } + ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap()); + // If no security, return conversation + if (userstr == null || User.parse(userstr) == null) { + internalListener.onResponse(conversation); + return; + } + // If security and correct user, return conversation + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + if (user.equals(conversation.getUser())) { + internalListener.onResponse(conversation); + return; + } + // Otherwise you don't have permission + internalListener + .onFailure( + new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId) + ); + }, e -> { internalListener.onFailure(e); }); + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(request, al); + }, e -> { + log.error("Failed to refresh conversations index during get conversation ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 54857b274c..bd4eb1e39a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.memory.index; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INDEX_NAME; + import java.io.IOException; import java.time.Instant; import java.util.LinkedList; @@ -25,10 +27,13 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.OpenSearchWrapperException; import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; @@ -41,6 +46,9 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.IndexNotFoundException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; @@ -63,23 +71,28 @@ public class InteractionsIndex { private Client client; private ClusterService clusterService; private ConversationMetaIndex conversationMetaIndex; - private final String indexName = ConversationalIndexConstants.INTERACTIONS_INDEX_NAME; // How big the steps should be when gathering *ALL* interactions in a conversation private final int resultsAtATime = 300; + private String userstr() { + return client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + } + /** * 'PUT's the index in opensearch if it's not there already * @param listener gets whether the index needed to be initialized. Throws error if it fails to init */ public void initInteractionsIndexIfAbsent(ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { log.debug("No interactions index found. Adding it"); - CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.INTERACTIONS_MAPPINGS); + CreateIndexRequest request = Requests + .createIndexRequest(INTERACTIONS_INDEX_NAME) + .mapping(ConversationalIndexConstants.INTERACTIONS_MAPPINGS); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(r -> { - if (r.equals(new CreateIndexResponse(true, true, indexName))) { - log.info("created index [" + indexName + "]"); + if (r.equals(new CreateIndexResponse(true, true, INTERACTIONS_INDEX_NAME))) { + log.info("created index [" + INTERACTIONS_INDEX_NAME + "]"); internalListener.onResponse(true); } else { internalListener.onResponse(false); @@ -89,7 +102,7 @@ public void initInteractionsIndexIfAbsent(ActionListener listener) { || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { internalListener.onResponse(true); } else { - log.error("Failed to create index [" + indexName + "]", e); + log.error("Failed to create index [" + INTERACTIONS_INDEX_NAME + "]", e); internalListener.onFailure(e); } }); @@ -99,7 +112,7 @@ public void initInteractionsIndexIfAbsent(ActionListener listener) { || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { listener.onResponse(true); } else { - log.error("Failed to create index [" + indexName + "]", e); + log.error("Failed to create index [" + INTERACTIONS_INDEX_NAME + "]", e); listener.onFailure(e); } } @@ -130,16 +143,13 @@ public void createInteraction( ActionListener listener ) { initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { IndexRequest request = Requests - .indexRequest(indexName) + .indexRequest(INTERACTIONS_INDEX_NAME) .source( ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin, @@ -209,7 +219,7 @@ public void createInteraction( * @param listener gets the list, sorted by recency, of interactions */ public void getInteractions(String conversationId, int from, int maxResults, ActionListener> listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { listener.onResponse(List.of()); return; } @@ -230,7 +240,7 @@ public void getInteractions(String conversationId, int from, int maxResults, Act @VisibleForTesting void innerGetInteractions(String conversationId, int from, int maxResults, ActionListener> listener) { - SearchRequest request = Requests.searchRequest(indexName); + SearchRequest request = Requests.searchRequest(INTERACTIONS_INDEX_NAME); TermQueryBuilder builder = new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); request.source().query(builder); request.source().from(from).size(maxResults); @@ -247,7 +257,7 @@ void innerGetInteractions(String conversationId, int from, int maxResults, Actio client .admin() .indices() - .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(r -> { client.search(request, al); }, e -> { + .refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(r -> { client.search(request, al); }, e -> { internalListener.onFailure(e); })); } catch (Exception e) { @@ -307,18 +317,18 @@ ActionListener> nextGetListener( * @param listener gets whether the deletion was successful */ public void deleteConversation(String conversationId, ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { listener.onResponse(true); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener> searchListener = ActionListener.wrap(interactions -> { BulkRequest request = Requests.bulkRequest(); for (Interaction interaction : interactions) { - DeleteRequest delRequest = Requests.deleteRequest(indexName).id(interaction.getId()); + DeleteRequest delRequest = Requests.deleteRequest(INTERACTIONS_INDEX_NAME).id(interaction.getId()); request.add(delRequest); } client @@ -340,4 +350,91 @@ public void deleteConversation(String conversationId, ActionListener li } } + /** + * Execute a search query over the interactions of a conversation by constructing a wrapper + * boolean query around the original query, AND a term query over conversation id + * @param conversationId the id of the conversation to query over + * @param request the original search request + * @param listener receives the search response from this query + */ + public void searchInteractions(String conversationId, SearchRequest request, ActionListener listener) { + conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { + if (access) { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + request.indices(INTERACTIONS_INDEX_NAME); + QueryBuilder originalQuery = request.source().query(); + BoolQueryBuilder newQuery = new BoolQueryBuilder(); + newQuery.must(originalQuery); + newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId)); + request.source().query(newQuery); + client + .admin() + .indices() + .refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.search(request, internalListener); + }, e -> { + log.error("Failed to refresh interactions index during search interactions ", e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } + } else { + String userstr = userstr(); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { listener.onFailure(e); })); + } + + /** + * Get a single interaction + * @param conversationId id of the conversation this interaction belongs to + * @param interactionId id of this interaction + * @param listener receives the interaction + */ + public void getInteraction(String conversationId, String interactionId, ActionListener listener) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { + listener + .onFailure( + new IndexNotFoundException( + "cannot get interaction since the interactions index does not exist", + INTERACTIONS_INDEX_NAME + ) + ); + return; + } + conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { + if (access) { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId); + ActionListener al = ActionListener.wrap(getResponse -> { + // If the conversation doesn't exist, fail + if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { + throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found"); + } + Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap()); + internalListener.onResponse(interaction); + }, e -> { internalListener.onFailure(e); }); + client + .admin() + .indices() + .refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(request, al); + }, e -> { + log.error("Failed to refresh interactions index during get interaction ", e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } + } else { + String userstr = userstr(); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { listener.onFailure(e); })); + } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index 6b33533ee2..c1997be829 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -21,6 +21,8 @@ import java.util.List; import org.opensearch.action.StepListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -286,4 +288,88 @@ public ActionFuture deleteConversation(String conversationId) { return fut; } + /** + * Search over conversations index + * @param request search request over the conversations index + * @param listener receives the search response + */ + public void searchConversations(SearchRequest request, ActionListener listener) { + conversationMetaIndex.searchConversations(request, listener); + } + + /** + * Search over conversations index + * @param request search request over the conversations index + * @return ActionFuture for the search response + */ + public ActionFuture searchConversations(SearchRequest request) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + searchConversations(request, fut); + return fut; + } + + /** + * Search over interactions of a conversation + * @param conversationId id of the conversation to search through + * @param request search request over the interactions + * @param listener receives the search response + */ + public void searchInteractions(String conversationId, SearchRequest request, ActionListener listener) { + interactionsIndex.searchInteractions(conversationId, request, listener); + } + + /** + * Search over interactions of a conversation + * @param conversationId id of the conversation to search through + * @param request search request over the interactions + * @return ActionFuture for the search response + */ + public ActionFuture searchInteractions(String conversationId, SearchRequest request) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + searchInteractions(conversationId, request, fut); + return fut; + } + + /** + * Get a single ConversationMeta object + * @param conversationId id of the conversation to get + * @param listener receives the conversationMeta object + */ + public void getConversation(String conversationId, ActionListener listener) { + conversationMetaIndex.getConversation(conversationId, listener); + } + + /** + * Get a single ConversationMeta object + * @param conversationId id of the conversation to get + * @return ActionFuture for the conversationMeta object + */ + public ActionFuture getConversation(String conversationId) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + getConversation(conversationId, fut); + return fut; + } + + /** + * Get a single interaction + * @param conversationId id of the conversation this interaction belongs to + * @param interactionId id of this interaction + * @param listener receives the interaction + */ + public void getInteraction(String conversationId, String interactionId, ActionListener listener) { + interactionsIndex.getInteraction(conversationId, interactionId, listener); + } + + /** + * Get a single interaction + * @param conversationId id of the conversation this interaction belongs to + * @param interactionId id of this interaction + * @return ActionFuture for the interaction + */ + public ActionFuture getInteraction(String conversationId, String interactionId) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + getInteraction(conversationId, interactionId, fut); + return fut; + } + } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java index 2975cd4c1d..541ff5ed2e 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java @@ -24,5 +24,7 @@ public void testActions() { assert (CreateConversationAction.INSTANCE instanceof CreateConversationAction); assert (DeleteConversationAction.INSTANCE instanceof DeleteConversationAction); assert (GetConversationsAction.INSTANCE instanceof GetConversationsAction); + assert (SearchConversationsAction.INSTANCE instanceof SearchConversationsAction); + assert (GetConversationAction.INSTANCE instanceof GetConversationAction); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java new file mode 100644 index 0000000000..cb8b67b44b --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class GetConversationRequestTests extends OpenSearchTestCase { + + public void testConstructorAndStreaming() throws IOException { + GetConversationRequest request = new GetConversationRequest("Test-id"); + assert (request.validate() == null); + assert (request.getConversationId().equals("Test-id")); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetConversationRequest newRequest = new GetConversationRequest(in); + assert (newRequest.validate() == null); + assert (newRequest.getConversationId().equals("Test-id")); + } + + public void testNullConvoId_ThenFail() { + String id = null; + GetConversationRequest request = new GetConversationRequest(id); + ActionRequestValidationException exc = request.validate(); + assert (exc != null); + assert (exc.validationErrors().size() == 1); + assert (exc.validationErrors().get(0).equals("GetConversation Request must have a conversation id")); + } + + public void testFromRestRequest() throws IOException { + Map params = Map.of(ActionConstants.CONVERSATION_ID_FIELD, "testcid"); + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + GetConversationRequest request = GetConversationRequest.fromRestRequest(rrequest); + assert (request.validate() == null); + assert (request.getConversationId().equals("testcid")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java new file mode 100644 index 0000000000..4b8f3a8fed --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.time.Instant; + +import org.apache.lucene.search.spell.LevenshteinDistance; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.test.OpenSearchTestCase; + +public class GetConversationResponseTests extends OpenSearchTestCase { + + public void testGetConversationResponseStreaming() throws IOException { + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), "name", null); + GetConversationResponse response = new GetConversationResponse(convo); + assert (response.getConversation().equals(convo)); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetConversationResponse newResponse = new GetConversationResponse(in); + assert (newResponse.getConversation().equals(convo)); + } + + public void testToXContent() throws IOException { + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), "name", null); + GetConversationResponse response = new GetConversationResponse(convo); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + String expected = "{\"conversation_id\":\"cid\",\"create_time\":\"" + convo.getCreatedTime() + "\",\"name\":\"name\"}"; + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + assert (ld.getDistance(result, expected) > 0.95); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java new file mode 100644 index 0000000000..3afcc1dd21 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java @@ -0,0 +1,150 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class GetConversationTransportActionTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + GetConversationRequest request; + GetConversationTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + this.threadPool = Mockito.mock(ThreadPool.class); + this.client = Mockito.mock(Client.class); + this.clusterService = Mockito.mock(ClusterService.class); + this.xContentRegistry = Mockito.mock(NamedXContentRegistry.class); + this.transportService = Mockito.mock(TransportService.class); + this.actionFilters = Mockito.mock(ActionFilters.class); + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new GetConversationRequest("test-cid"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new GetConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testGetConversation() { + ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), "name", null); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(result); + return null; + }).when(cmHandler).getConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetConversationResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getConversation().getId().equals("test-cid")); + } + + public void testGetConversationFails_ThenFail() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new Exception("CMHandler Failure")); + return null; + }).when(cmHandler).getConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("CMHandler Failure")); + } + + public void testHandlerThrows_ThenFail() { + doThrow(new RuntimeException("CMHandler Throws")).when(cmHandler).getConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("CMHandler Throws")); + } + + public void testFeatureDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new GetConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java new file mode 100644 index 0000000000..678004ae09 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class GetInteractionRequestTests extends OpenSearchTestCase { + + public void testConstructorAndStreaming() throws IOException { + GetInteractionRequest request = new GetInteractionRequest("cid", "iid"); + assert (request.validate() == null); + assert (request.getConversationId().equals("cid")); + assert (request.getInteractionId().equals("iid")); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetInteractionRequest newRequest = new GetInteractionRequest(in); + assert (newRequest.validate() == null); + assert (newRequest.getConversationId().equals("cid")); + assert (newRequest.getInteractionId().equals("iid")); + } + + public void testMalformedRequest_ThenInvalid() { + GetInteractionRequest bad1 = new GetInteractionRequest(null, "iid"); + GetInteractionRequest bad2 = new GetInteractionRequest("cid", null); + GetInteractionRequest bad3 = new GetInteractionRequest(null, null); + ActionRequestValidationException exc1 = bad1.validate(); + ActionRequestValidationException exc2 = bad2.validate(); + ActionRequestValidationException exc3 = bad3.validate(); + + assert (exc1 != null); + assert (exc1.validationErrors().size() == 1); + assert (exc1.validationErrors().get(0).equals("Get Interaction Request must have a conversation id")); + + assert (exc2 != null); + assert (exc2.validationErrors().size() == 1); + assert (exc2.validationErrors().get(0).equals("Get Interaction Request must have an interaction id")); + + assert (exc3 != null); + assert (exc3.validationErrors().size() == 2); + assert (exc3.validationErrors().get(0).equals("Get Interaction Request must have a conversation id")); + assert (exc3.validationErrors().get(1).equals("Get Interaction Request must have an interaction id")); + } + + public void testFromRestRequest() throws IOException { + Map params = Map + .of(ActionConstants.CONVERSATION_ID_FIELD, "testcid", ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "testiid"); + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + GetInteractionRequest request = GetInteractionRequest.fromRestRequest(rrequest); + assert (request.validate() == null); + assert (request.getConversationId().equals("testcid")); + assert (request.getInteractionId().equals("testiid")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java new file mode 100644 index 0000000000..b7cbc1c471 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.time.Instant; + +import org.apache.lucene.search.spell.LevenshteinDistance; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchTestCase; + +public class GetInteractionResponseTests extends OpenSearchTestCase { + + public void testConstructorAndStreaming() throws IOException { + Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra"); + GetInteractionResponse response = new GetInteractionResponse(interaction); + assert (response.getInteraction().equals(interaction)); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetInteractionResponse newResponse = new GetInteractionResponse(in); + assert (newResponse.getInteraction().equals(interaction)); + } + + public void testToXContent() throws IOException { + Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra"); + GetInteractionResponse response = new GetInteractionResponse(interaction); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + String expected = "{\"conversation_id\":\"cid\",\"interaction_id\":\"iid\",\"create_time\":\"" + + interaction.getCreateTime() + + "\",\"input\":\"inp\",\"prompt_template\":\"pt\",\"response\":\"rsp\",\"origin\":\"ogn\",\"additional_info\":\"extra\"}"; + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + assert (ld.getDistance(result, expected) > 0.95); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java new file mode 100644 index 0000000000..6ca8197b54 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java @@ -0,0 +1,158 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class GetInteractionTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + GetInteractionRequest request; + GetInteractionTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + this.threadPool = Mockito.mock(ThreadPool.class); + this.client = Mockito.mock(Client.class); + this.clusterService = Mockito.mock(ClusterService.class); + this.xContentRegistry = Mockito.mock(NamedXContentRegistry.class); + this.transportService = Mockito.mock(TransportService.class); + this.actionFilters = Mockito.mock(ActionFilters.class); + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new GetInteractionRequest("cid", "iid"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new GetInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testGetInteraction() { + Interaction testInteraction = new Interaction( + "iid", + Instant.now(), + "cid", + "test-input", + "pt", + "test-response", + "test-origin", + "metadata" + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(testInteraction); + return null; + }).when(cmHandler).getInteraction(any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getInteraction().getId().equals("iid")); + } + + public void testGetInteractionFails_ThenFail() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new Exception("Storage layer failure")); + return null; + }).when(cmHandler).getInteraction(any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Storage layer failure")); + } + + public void testHandlerThrows_ThenFail() { + doThrow(new RuntimeException("CMHandler Failure")).when(cmHandler).getInteraction(any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("CMHandler Failure")); + } + + public void testFeatureDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new GetInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java index 9002796bbc..187ae4bdf7 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java @@ -17,9 +17,13 @@ */ package org.opensearch.ml.memory.action.conversation; -public class InteractionActionTests { +import org.opensearch.test.OpenSearchTestCase; + +public class InteractionActionTests extends OpenSearchTestCase { public void testActions() { assert (CreateInteractionAction.INSTANCE instanceof CreateInteractionAction); assert (GetInteractionsAction.INSTANCE instanceof GetInteractionsAction); + assert (SearchInteractionsAction.INSTANCE instanceof SearchInteractionsAction); + assert (GetInteractionAction.INSTANCE instanceof GetInteractionAction); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java new file mode 100644 index 0000000000..7ec9d8c042 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class SearchConversationsTransportActionTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + @Mock + SearchRequest request; + + SearchConversationsTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testEnabled_ThenSucceed() { + SearchResponse response = mock(SearchResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(cmHandler).searchConversations(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().equals(response)); + } + + public void testDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java new file mode 100644 index 0000000000..af7dc33c9c --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.indices.IndicesModule; +import org.opensearch.search.SearchModule; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; + +public class SearchInteractionsRequestTests extends OpenSearchTestCase { + + protected NamedWriteableRegistry namedWriteableRegistry; + + public void setUp() throws Exception { + super.setUp(); + IndicesModule indicesModule = new IndicesModule(Collections.emptyList()); + SearchModule searchModule = new SearchModule(Settings.EMPTY, List.of()); + List entries = new ArrayList<>(); + entries.addAll(indicesModule.getNamedWriteables()); + entries.addAll(searchModule.getNamedWriteables()); + namedWriteableRegistry = new NamedWriteableRegistry(entries); + } + + public void testConstructorsAndStreaming() throws IOException { + SearchRequest original = new SearchRequest(); + original.source(new SearchSourceBuilder()); + original.source().query(new MatchAllQueryBuilder()); + + SearchInteractionsRequest request = new SearchInteractionsRequest("test_cid", original); + assert (request instanceof SearchRequest); + assert (request.getConversationId().equals("test_cid")); + assert (request.validate() == null); + + SearchInteractionsRequest newRequest = copyWriteable(request, namedWriteableRegistry, SearchInteractionsRequest::new); + assert (newRequest.getConversationId().equals("test_cid")); + assert (newRequest.validate() == null); + assert (newRequest.equals(request)); + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java new file mode 100644 index 0000000000..abe5204c65 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class SearchInteractionsTransportActionsTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + @Mock + SearchInteractionsRequest request; + + SearchInteractionsTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + when(this.request.getConversationId()).thenReturn("test_cid"); + + this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testFeatureEnabled_ThenSucceed() { + SearchResponse response = mock(SearchResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(cmHandler).searchInteractions(any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().equals(response)); + } + + public void testDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index e1a0318758..fc605e3fb0 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -29,6 +29,8 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.StepListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.client.Requests; import org.opensearch.cluster.service.ClusterService; @@ -36,8 +38,10 @@ import org.opensearch.common.util.concurrent.ThreadContext.StoredContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.conversation.ConversationMeta; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; @@ -415,4 +419,215 @@ public void testDifferentUsersCannotTouchOthersConversations() { } } + public void testCanQueryOverConversations() { + CountDownLatch cdl = new CountDownLatch(1); + StepListener convo1 = new StepListener<>(); + index.createConversation("Henry Conversation", convo1); + + StepListener convo2 = new StepListener<>(); + convo1.whenComplete(cid -> { index.createConversation("Mehul Conversation", convo2); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener search = new StepListener<>(); + convo2.whenComplete(cid -> { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new TermQueryBuilder(ConversationalIndexConstants.META_NAME_FIELD, "Henry Conversation")); + index.searchConversations(request, search); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + search.whenComplete(response -> { + log.info("SEARCH RESPONSE"); + log.info(response.toString()); + cdl.countDown(); + assert (response.getHits().getAt(0).getId().equals(convo1.result())); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + public void testCanQueryOverConversationsSecurely() { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + CountDownLatch cdl = new CountDownLatch(1); + Stack contextStack = new Stack<>(); + Consumer onFail = e -> { + while (!contextStack.empty()) { + contextStack.pop().close(); + } + cdl.countDown(); + log.error(e); + threadContext.restore(); + assert (false); + }; + + final String user1 = "Dhrubo"; + final String user2 = "Jing"; + contextStack.push(setUser(user1)); + + StepListener convo1 = new StepListener<>(); + index.createConversation("Dhrubo Conversation", convo1); + + StepListener convo2 = new StepListener<>(); + convo1.whenComplete(cid -> { + contextStack.push(setUser(user2)); + index.createConversation("Jing Conversation", convo2); + }, onFail); + + StepListener search1 = new StepListener<>(); + convo2.whenComplete(cid -> { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new TermQueryBuilder(ConversationalIndexConstants.META_NAME_FIELD, "Dhrubo Conversation")); + index.searchConversations(request, search1); + }, onFail); + + StepListener search2 = new StepListener<>(); + search1.whenComplete(response -> { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new TermQueryBuilder(ConversationalIndexConstants.META_NAME_FIELD, "Jing Conversation")); + index.searchConversations(request, search2); + }, onFail); + + search2.whenComplete(response -> { + cdl.countDown(); + assert (response.getHits().getAt(0).getId().equals(convo2.result())); + assert (search1.result().getHits().getHits().length == 0); + while (!contextStack.isEmpty()) { + contextStack.pop().close(); + } + }, onFail); + + try { + cdl.await(); + threadContext.restore(); + } catch (InterruptedException e) { + log.error(e); + threadContext.restore(); + } + + } catch (Exception e) { + log.error(e); + } + } + + public void testCanGetAConversationById() { + CountDownLatch cdl = new CountDownLatch(1); + StepListener cid1 = new StepListener<>(); + index.createConversation("convo1", cid1); + + StepListener cid2 = new StepListener<>(); + cid1.whenComplete(cid -> { index.createConversation("convo2", cid2); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener get1 = new StepListener<>(); + cid2.whenComplete(cid -> { index.getConversation(cid1.result(), get1); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener get2 = new StepListener<>(); + get1.whenComplete(convo1 -> { index.getConversation(cid2.result(), get2); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + get2.whenComplete(convo2 -> { + assert (cid1.result().equals(get1.result().getId())); + assert (cid2.result().equals(get2.result().getId())); + assert (get1.result().getName().equals("convo1")); + assert (get2.result().getName().equals("convo2")); + cdl.countDown(); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + public void testCanGetAConversationByIdSecurely() { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + CountDownLatch cdl = new CountDownLatch(1); + Stack contextStack = new Stack<>(); + Consumer onFail = e -> { + while (!contextStack.empty()) { + contextStack.pop().close(); + } + cdl.countDown(); + log.error(e); + threadContext.restore(); + assert (false); + }; + + final String user1 = "Austin"; + final String user2 = "Yaliang"; + contextStack.push(setUser(user1)); + + StepListener cid1 = new StepListener<>(); + index.createConversation("Austin Convo", cid1); + + StepListener cid2 = new StepListener<>(); + cid1.whenComplete(cid -> { + contextStack.push(setUser(user2)); + index.createConversation("Yaliang Convo", cid2); + }, onFail); + + StepListener get2 = new StepListener<>(); + cid2.whenComplete(cid -> { index.getConversation(cid2.result(), get2); }, onFail); + + StepListener get1 = new StepListener<>(); + get2.whenComplete(convo -> { index.getConversation(cid1.result(), get1); }, onFail); + + get1.whenComplete(convo -> { + while (!contextStack.isEmpty()) { + contextStack.pop().close(); + } + cdl.countDown(); + assert (false); + }, e -> { + cdl.countDown(); + assert (e.getMessage().startsWith("User [Yaliang] does not have access to conversation")); + assert (get2.result().getName().equals("Yaliang Convo")); + assert (get2.result().getId().equals(cid2.result())); + }); + + try { + cdl.await(); + threadContext.restore(); + } catch (InterruptedException e) { + log.error(e); + threadContext.restore(); + } + + } catch (Exception e) { + log.error(e); + } + } + } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index 821d801cdf..5445fd6213 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -40,6 +40,7 @@ import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; @@ -52,7 +53,9 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.SendRequestTransportException; @@ -134,6 +137,13 @@ private void setupUser(String user) { }).when(threadPool).getThreadContext(); } + private SearchRequest dummyRequest() { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new MatchAllQueryBuilder()); + return request; + } + public void testInit_DoesNotCreateIndex() { setupDoesNotMakeIndex(); @SuppressWarnings("unchecked") @@ -402,6 +412,18 @@ public void testDelete_DeleteFails_ThenFail() { assert (argCaptor.getValue().getMessage().equals("Test Fail in Delete")); } + public void testDelete_ClientFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + blanketGrantAccess(); + doThrow(new RuntimeException("Client Fail in Delete")).when(client).delete(any(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + conversationMetaIndex.deleteConversation("test-id", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Fail in Delete")); + } + public void testCheckAccess_DoesNotExist_ThenFail() { setupUser("user"); setupRefreshSuccess(); @@ -464,7 +486,7 @@ public void testCheckAccess_ClientFails_ThenFail() { setupUser("user"); setupRefreshSuccess(); doReturn(true).when(metadata).hasIndex(anyString()); - doThrow(new RuntimeException("Client Test Fail")).when(client).get(any(), any()); + doThrow(new RuntimeException("Client Test Fail")).when(client).admin(); @SuppressWarnings("unchecked") ActionListener accessListener = mock(ActionListener.class); conversationMetaIndex.checkAccess("test id", accessListener); @@ -475,11 +497,143 @@ public void testCheckAccess_ClientFails_ThenFail() { public void testCheckAccess_EmptyStringUser_ThenReturnTrue() { setupUser(null); + setupRefreshSuccess(); + doReturn(true).when(metadata).hasIndex(anyString()); + final String id = "test_id"; + GetResponse dummyGetResponse = mock(GetResponse.class); + doReturn(true).when(dummyGetResponse).isExists(); + doReturn(id).when(dummyGetResponse).getId(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(dummyGetResponse); + return null; + }).when(client).get(any(), any()); @SuppressWarnings("unchecked") ActionListener accessListener = mock(ActionListener.class); - conversationMetaIndex.checkAccess("test id", accessListener); + conversationMetaIndex.checkAccess(id, accessListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); verify(accessListener, times(1)).onResponse(argCaptor.capture()); assert (argCaptor.getValue()); } + + public void testCheckAccess_RefreshFails_ThenFail() { + setupUser("user"); + doReturn(true).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Refresh Exception")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener accessListener = mock(ActionListener.class); + conversationMetaIndex.checkAccess("test id", accessListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(accessListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Refresh Exception")); + } + + public void testSearchConversations_RefreshFails_ThenFail() { + SearchRequest request = dummyRequest(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Refresh Exception")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener searchConversationsListener = mock(ActionListener.class); + conversationMetaIndex.searchConversations(request, searchConversationsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(searchConversationsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Refresh Exception")); + } + + public void testSearchConversations_ClientFails_ThenFail() { + SearchRequest request = dummyRequest(); + doReturn(true).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Client Test Fail")).when(client).admin(); + @SuppressWarnings("unchecked") + ActionListener accessListener = mock(ActionListener.class); + conversationMetaIndex.searchConversations(request, accessListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(accessListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Test Fail")); + } + + public void testGetConversation_NoIndex_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + conversationMetaIndex.getConversation("tester_id", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor + .getValue() + .getMessage() + .equals( + "no such index [.plugins-ml-conversation-meta] and cannot get conversation since the conversation index does not exist" + )); + } + + public void testGetConversation_ResponseNotExist_ThenFail() { + setupRefreshSuccess(); + doReturn(true).when(metadata).hasIndex(anyString()); + GetResponse response = mock(GetResponse.class); + doReturn(false).when(response).isExists(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(response); + return null; + }).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + conversationMetaIndex.getConversation("tester_id", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Conversation [tester_id] not found")); + } + + public void testGetConversation_WrongId_ThenFail() { + setupRefreshSuccess(); + doReturn(true).when(metadata).hasIndex(anyString()); + GetResponse response = mock(GetResponse.class); + doReturn(true).when(response).isExists(); + doReturn("wrong id").when(response).getId(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(response); + return null; + }).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + conversationMetaIndex.getConversation("tester_id", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Conversation [tester_id] not found")); + } + + public void testGetConversation_RefreshFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Refresh Exception")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + conversationMetaIndex.getConversation("tester_id", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Refresh Exception")); + } + + public void testGetConversation_ClientFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Clietn Failure")).when(client).admin(); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + conversationMetaIndex.getConversation("tester_id", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Clietn Failure")); + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java index c23177bc2f..0c0791fb23 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java @@ -19,16 +19,23 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; import org.junit.Before; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.StepListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; @@ -348,4 +355,157 @@ public void testDeleteConversation() { log.error(e); } } + + public void testSearchInteractions() { + final String conversation1 = "conversation1"; + final String conversation2 = "conversation2"; + CountDownLatch cdl = new CountDownLatch(1); + StepListener iid1 = new StepListener<>(); + index + .createInteraction( + conversation1, + "input about fish", + "pt", + "response about fish", + "origin1", + "lots of information about fish", + iid1 + ); + + StepListener iid2 = new StepListener<>(); + iid1.whenComplete(r -> { + index + .createInteraction( + conversation1, + "input about squash", + "pt", + "response about squash", + "origin1", + "lots of information about squash", + iid2 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener iid3 = new StepListener<>(); + iid2.whenComplete(r -> { + index + .createInteraction( + conversation2, + "input about fish", + "pt2", + "response about fish", + "origin1", + "lots of information about fish", + iid3 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener iid4 = new StepListener<>(); + iid3.whenComplete(r -> { + index + .createInteraction( + conversation1, + "input about france", + "pt", + "response about france", + "origin1", + "lots of information about france", + iid4 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener searchListener = new StepListener<>(); + iid4.whenComplete(r -> { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new MatchQueryBuilder(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "fish input")); + index.searchInteractions(conversation1, request, searchListener); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + searchListener.whenComplete(response -> { + cdl.countDown(); + assert (response.getHits().getHits().length == 3); + // BM25 was being a little unpredictable here so I don't assert ordering + List ids = new ArrayList<>(3); + for (SearchHit hit : response.getHits()) { + ids.add(hit.getId()); + } + assert (ids.contains(iid1.result())); + assert (ids.contains(iid2.result())); + assert (ids.contains(iid4.result())); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + public void testGetInteractionById() { + final String conversation = "test-conversation"; + CountDownLatch cdl = new CountDownLatch(1); + StepListener iid1 = new StepListener<>(); + index.createInteraction(conversation, "test input", "pt", "test response", "test origin", "metadata", iid1); + + StepListener iid2 = new StepListener<>(); + iid1 + .whenComplete( + iid -> { index.createInteraction(conversation, "test input2", "pt", "test response", "test origin", "metadata", iid2); }, + e -> { + cdl.countDown(); + log.error(e); + assert false; + } + ); + + StepListener get1 = new StepListener<>(); + iid2.whenComplete(iid -> { index.getInteraction(conversation, iid1.result(), get1); }, e -> { + cdl.countDown(); + log.error(e); + }); + + StepListener get2 = new StepListener<>(); + get1.whenComplete(interaction1 -> { index.getInteraction(conversation, iid2.result(), get2); }, e -> { + cdl.countDown(); + log.error(e); + }); + + get2.whenComplete(interaction2 -> { + assert (get1.result().getId().equals(iid1.result())); + assert (get1.result().getInput().equals("test input")); + assert (get2.result().getId().equals(iid2.result())); + assert (get2.result().getInput().equals("test input2")); + cdl.countDown(); + }, e -> { + cdl.countDown(); + log.error(e); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 0e97c7e9f6..2d4184eec3 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -40,7 +40,9 @@ import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.admin.indices.refresh.RefreshResponse; import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; @@ -53,8 +55,10 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.SendRequestTransportException; @@ -145,6 +149,13 @@ private void setupRefreshSuccess() { }).when(indicesAdminClient).refresh(any(), any()); } + private SearchRequest dummyRequest() { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new MatchAllQueryBuilder()); + return request; + } + public void testInit_DoesNotCreateIndex_ThenReturnFalse() { setupDoesNotMakeIndex(); @SuppressWarnings("unchecked") @@ -582,4 +593,140 @@ public void testDelete_MainFailure_ThenFail() { verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Test Failure")); } + + public void testSearch_RefreshFails_ThenFail() { + setupGrantAccess(); + SearchRequest request = dummyRequest(); + final String cid = "test_id"; + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Failed during Search Refresh")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener searchInteractionsListener = mock(ActionListener.class); + interactionsIndex.searchInteractions(cid, request, searchInteractionsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(searchInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failed during Search Refresh")); + } + + public void testSearch_ClientFails_ThenFail() { + setupGrantAccess(); + SearchRequest request = dummyRequest(); + final String cid = "test_cid"; + doThrow(new RuntimeException("Client Failure in Search Interactions")).when(client).admin(); + @SuppressWarnings("unchecked") + ActionListener searchInteractionsListener = mock(ActionListener.class); + interactionsIndex.searchInteractions(cid, request, searchInteractionsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(searchInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Failure in Search Interactions")); + } + + public void testSearch_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupDenyAccess("user"); + SearchRequest request = dummyRequest(); + final String cid = "test_cid"; + @SuppressWarnings("unchecked") + ActionListener searchInteractionsListener = mock(ActionListener.class); + interactionsIndex.searchInteractions(cid, request, searchInteractionsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(searchInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [user] does not have access to conversation test_cid")); + } + + public void testGetSg_NoIndex_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor + .getValue() + .getMessage() + .equals( + "no such index [.plugins-ml-conversation-interactions] and cannot get interaction since the interactions index does not exist" + )); + } + + public void testGetSg_InteractionNotExist_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + setupRefreshSuccess(); + GetResponse response = mock(GetResponse.class); + doReturn(false).when(response).isExists(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Interaction [iid] not found")); + } + + public void testGetSg_WrongId_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + setupRefreshSuccess(); + GetResponse response = mock(GetResponse.class); + doReturn(true).when(response).isExists(); + doReturn("wrong id").when(response).getId(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Interaction [iid] not found")); + } + + public void testGetSg_RefreshFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Failed during Sg Get Refresh")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failed during Sg Get Refresh")); + } + + public void testGetSg_ClientFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + doThrow(new RuntimeException("Client Failure in Sg Get")).when(client).admin(); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Failure in Sg Get")); + } + + public void testGetSg_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupDenyAccess("Henry"); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [Henry] does not have access to conversation cid")); + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index e39513d2d8..c8df948bcb 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -25,11 +25,14 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import java.time.Instant; import java.util.List; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationMeta; @@ -59,7 +62,7 @@ public void testCreateConversation_NoName_FutureSuccess() { ActionListener al = invocation.getArgument(0); al.onResponse("cid"); return null; - }).when(conversationMetaIndex).createConversation(any(ActionListener.class)); + }).when(conversationMetaIndex).createConversation(any()); ActionFuture result = cmHandler.createConversation(); assert (result.actionGet(200).equals("cid")); } @@ -241,4 +244,51 @@ public void testDelete_AsFuture() { ActionFuture result = cmHandler.deleteConversation("cid"); assert (result.actionGet(200)); } + + public void testSearchConversations_Future() { + SearchRequest request = mock(SearchRequest.class); + SearchResponse response = mock(SearchResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(conversationMetaIndex).searchConversations(any(), any()); + ActionFuture result = cmHandler.searchConversations(request); + assert (result.actionGet().equals(response)); + } + + public void testSearchInteractions_Future() { + SearchRequest request = mock(SearchRequest.class); + SearchResponse response = mock(SearchResponse.class); + String cid = "cid"; + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(interactionsIndex).searchInteractions(any(), any(), any()); + ActionFuture result = cmHandler.searchInteractions(cid, request); + assert (result.actionGet().equals(response)); + } + + public void testGetAConversation_Future() { + ConversationMeta response = new ConversationMeta("cid", Instant.now(), "boring name", null); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(conversationMetaIndex).getConversation(any(), any()); + ActionFuture result = cmHandler.getConversation("cid"); + assert (result.actionGet().equals(response)); + } + + public void testGetAnInteraction_Future() { + Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(interaction); + return null; + }).when(interactionsIndex).getInteraction(any(), any(), any()); + ActionFuture result = cmHandler.getInteraction("cid", "iid"); + assert (result.actionGet().equals(interaction)); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index f5ba454c4d..13b4834d59 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -139,10 +139,18 @@ import org.opensearch.ml.memory.action.conversation.CreateInteractionTransportAction; import org.opensearch.ml.memory.action.conversation.DeleteConversationAction; import org.opensearch.ml.memory.action.conversation.DeleteConversationTransportAction; +import org.opensearch.ml.memory.action.conversation.GetConversationAction; +import org.opensearch.ml.memory.action.conversation.GetConversationTransportAction; import org.opensearch.ml.memory.action.conversation.GetConversationsAction; import org.opensearch.ml.memory.action.conversation.GetConversationsTransportAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionTransportAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsTransportAction; +import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; +import org.opensearch.ml.memory.action.conversation.SearchConversationsTransportAction; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsAction; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsTransportAction; import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; @@ -177,8 +185,12 @@ import org.opensearch.ml.rest.RestMemoryCreateConversationAction; import org.opensearch.ml.rest.RestMemoryCreateInteractionAction; import org.opensearch.ml.rest.RestMemoryDeleteConversationAction; +import org.opensearch.ml.rest.RestMemoryGetConversationAction; import org.opensearch.ml.rest.RestMemoryGetConversationsAction; +import org.opensearch.ml.rest.RestMemoryGetInteractionAction; import org.opensearch.ml.rest.RestMemoryGetInteractionsAction; +import org.opensearch.ml.rest.RestMemorySearchConversationsAction; +import org.opensearch.ml.rest.RestMemorySearchInteractionsAction; import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLClusterLevelStat; @@ -305,7 +317,11 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(CreateInteractionAction.INSTANCE, CreateInteractionTransportAction.class), new ActionHandler<>(GetInteractionsAction.INSTANCE, GetInteractionsTransportAction.class), new ActionHandler<>(DeleteConversationAction.INSTANCE, DeleteConversationTransportAction.class), - new ActionHandler<>(MLUpdateConnectorAction.INSTANCE, UpdateConnectorTransportAction.class) + new ActionHandler<>(MLUpdateConnectorAction.INSTANCE, UpdateConnectorTransportAction.class), + new ActionHandler<>(SearchInteractionsAction.INSTANCE, SearchInteractionsTransportAction.class), + new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class), + new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class), + new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class) ); } @@ -557,6 +573,10 @@ public List getRestHandlers( RestMemoryGetInteractionsAction restListInteractionsAction = new RestMemoryGetInteractionsAction(); RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction(); RestMLUpdateConnectorAction restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); + RestMemorySearchConversationsAction restSearchConversationsAction = new RestMemorySearchConversationsAction(); + RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); + RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); + RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); return ImmutableList .of( restMLStatsAction, @@ -591,7 +611,11 @@ public List getRestHandlers( restCreateInteractionAction, restListInteractionsAction, restDeleteConversationAction, - restMLUpdateConnectorAction + restMLUpdateConnectorAction, + restSearchConversationsAction, + restSearchInteractionsAction, + restGetConversationAction, + restGetInteractionAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationAction.java new file mode 100644 index 0000000000..dbabd40953 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationAction.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetConversationAction; +import org.opensearch.ml.memory.action.conversation.GetConversationRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetConversationAction extends BaseRestHandler { + private final static String GET_CONVERSATION_NAME = "conversational_get_conversation"; + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, ActionConstants.GET_CONVERSATION_REST_PATH)); + } + + @Override + public String getName() { + return GET_CONVERSATION_NAME; + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + GetConversationRequest gcRequest = GetConversationRequest.fromRestRequest(request); + return channel -> client.execute(GetConversationAction.INSTANCE, gcRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionAction.java new file mode 100644 index 0000000000..ad2b35dbf6 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionAction.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetInteractionAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetInteractionAction extends BaseRestHandler { + private final static String GET_INTERACTION_NAME = "conversational_get_interaction"; + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, ActionConstants.GET_INTERACTION_REST_PATH)); + } + + @Override + public String getName() { + return GET_INTERACTION_NAME; + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + GetInteractionRequest giRequest = GetInteractionRequest.fromRestRequest(request); + return channel -> client.execute(GetInteractionAction.INSTANCE, giRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java new file mode 100644 index 0000000000..5beee29c42 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java @@ -0,0 +1,43 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; + +import com.google.common.collect.ImmutableList; + +public class RestMemorySearchConversationsAction extends AbstractMLSearchAction { + private static final String SEARCH_CONVERSATIONS_NAME = "conversation_memory_search_conversations"; + + public RestMemorySearchConversationsAction() { + super( + ImmutableList.of(ActionConstants.SEARCH_CONVERSATIONS_REST_PATH), + ConversationalIndexConstants.META_INDEX_NAME, + ConversationMeta.class, + SearchConversationsAction.INSTANCE + ); + } + + @Override + public String getName() { + return SEARCH_CONVERSATIONS_NAME; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java new file mode 100644 index 0000000000..063e2502ce --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java @@ -0,0 +1,83 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.ml.utils.RestActionUtils.getSourceContext; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsAction; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.action.RestResponseListener; +import org.opensearch.search.builder.SearchSourceBuilder; + +import com.google.common.collect.ImmutableList; + +public class RestMemorySearchInteractionsAction extends BaseRestHandler { + private static final String SEARCH_INTERACTIONS_NAME = "conversation_memory_search_interactions"; + + @Override + public List routes() { + return ImmutableList + .of( + new Route(RestRequest.Method.POST, ActionConstants.SEARCH_INTERACTIONS_REST_PATH), + new Route(RestRequest.Method.GET, ActionConstants.SEARCH_INTERACTIONS_REST_PATH) + ); + } + + @Override + public String getName() { + return SEARCH_INTERACTIONS_NAME; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); + searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); + searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder); + SearchInteractionsRequest siRequest = new SearchInteractionsRequest(conversationId, searchRequest); + return channel -> client.execute(SearchInteractionsAction.INSTANCE, siRequest, search(channel)); + } + + protected RestResponseListener search(RestChannel channel) { + return new RestResponseListener(channel) { + @Override + public RestResponse buildResponse(SearchResponse response) throws Exception { + if (response.isTimedOut()) { + return new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, response.toString()); + } + return new BytesRestResponse(RestStatus.OK, response.toXContent(channel.newBuilder(), EMPTY_PARAMS)); + } + }; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java new file mode 100644 index 0000000000..5a55b1c301 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java @@ -0,0 +1,81 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.Map; + +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetConversationActionIT extends MLCommonsRestTestCase { + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testGetConversation() throws IOException { + Response ccresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_CONVERSATION_REST_PATH, + null, + gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "name")), + null + ); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + @SuppressWarnings("unchecked") + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD)); + String id = (String) ccmap.get(ActionConstants.CONVERSATION_ID_FIELD); + + Response gcresponse = TestHelper + .makeRequest(client(), "GET", ActionConstants.GET_CONVERSATION_REST_PATH.replace("{conversation_id}", id), null, "", null); + assert (gcresponse != null); + assert (TestHelper.restStatus(gcresponse) == RestStatus.OK); + HttpEntity gchttpEntity = gcresponse.getEntity(); + String gcentitiyString = TestHelper.httpEntityToString(gchttpEntity); + @SuppressWarnings("unchecked") + Map gcmap = gson.fromJson(gcentitiyString, Map.class); + assert (gcmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD) && gcmap.get(ActionConstants.CONVERSATION_ID_FIELD).equals(id)); + assert (gcmap.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) + && gcmap.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD).equals("name")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionTests.java new file mode 100644 index 0000000000..0e81f2dacb --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetConversationAction; +import org.opensearch.ml.memory.action.conversation.GetConversationRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryGetConversationActionTests extends OpenSearchTestCase { + public void testBasics() { + RestMemoryGetConversationAction action = new RestMemoryGetConversationAction(); + assert (action.getName().equals("conversational_get_conversation")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new Route(RestRequest.Method.GET, ActionConstants.GET_CONVERSATION_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryGetConversationAction action = new RestMemoryGetConversationAction(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid")) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetConversationRequest.class); + verify(client, times(1)).execute(eq(GetConversationAction.INSTANCE), argCaptor.capture(), any()); + assert (argCaptor.getValue().getConversationId().equals("cid")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java new file mode 100644 index 0000000000..691195a99b --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.Map; + +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetInteractionActionIT extends MLCommonsRestTestCase { + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testGetInteraction() throws IOException { + Response ccresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_CONVERSATION_REST_PATH, + null, + gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "name")), + null + ); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + @SuppressWarnings("unchecked") + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD)); + String cid = (String) ccmap.get(ActionConstants.CONVERSATION_ID_FIELD); + + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "some metadata" + ); + Response ciresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params), + null + ); + assert (ciresponse != null); + assert (TestHelper.restStatus(ciresponse) == RestStatus.OK); + HttpEntity cihttpEntity = ciresponse.getEntity(); + String cientityString = TestHelper.httpEntityToString(cihttpEntity); + @SuppressWarnings("unchecked") + Map cimap = gson.fromJson(cientityString, Map.class); + assert (cimap.containsKey(ActionConstants.RESPONSE_INTERACTION_ID_FIELD)); + String iid = cimap.get(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); + + Response giresponse = TestHelper + .makeRequest( + client(), + "GET", + ActionConstants.GET_INTERACTION_REST_PATH.replace("{conversation_id}", cid).replace("{interaction_id}", iid), + null, + "", + null + ); + assert (giresponse != null); + assert (TestHelper.restStatus(giresponse) == RestStatus.OK); + HttpEntity gihttpEntity = giresponse.getEntity(); + String gientityString = TestHelper.httpEntityToString(gihttpEntity); + @SuppressWarnings("unchecked") + Map gimap = gson.fromJson(gientityString, Map.class); + assert (gimap.containsKey(ActionConstants.RESPONSE_INTERACTION_ID_FIELD) + && gimap.get(ActionConstants.RESPONSE_INTERACTION_ID_FIELD).equals(iid)); + assert (gimap.containsKey(ActionConstants.CONVERSATION_ID_FIELD) && gimap.get(ActionConstants.CONVERSATION_ID_FIELD).equals(cid)); + assert (gimap.containsKey(ActionConstants.INPUT_FIELD) && gimap.get(ActionConstants.INPUT_FIELD).equals("input")); + assert (gimap.containsKey(ActionConstants.PROMPT_TEMPLATE_FIELD) + && gimap.get(ActionConstants.PROMPT_TEMPLATE_FIELD).equals("promtp template")); + assert (gimap.containsKey(ActionConstants.AI_RESPONSE_FIELD) && gimap.get(ActionConstants.AI_RESPONSE_FIELD).equals("response")); + assert (gimap.containsKey(ActionConstants.RESPONSE_ORIGIN_FIELD) + && gimap.get(ActionConstants.RESPONSE_ORIGIN_FIELD).equals("origin")); + assert (gimap.containsKey(ActionConstants.ADDITIONAL_INFO_FIELD) + && gimap.get(ActionConstants.ADDITIONAL_INFO_FIELD).equals("some metadata")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java new file mode 100644 index 0000000000..9d0cc6515b --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetInteractionAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryGetInteractionActionTests extends OpenSearchTestCase { + public void testBasics() { + RestMemoryGetInteractionAction action = new RestMemoryGetInteractionAction(); + assert (action.getName().equals("conversational_get_interaction")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new Route(RestRequest.Method.GET, ActionConstants.GET_INTERACTION_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryGetInteractionAction action = new RestMemoryGetInteractionAction(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid", ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid")) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionRequest.class); + verify(client, times(1)).execute(eq(GetInteractionAction.INSTANCE), argCaptor.capture(), any()); + assert (argCaptor.getValue().getConversationId().equals("cid")); + assert (argCaptor.getValue().getInteractionId().equals("iid")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java new file mode 100644 index 0000000000..264ef5ea24 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java @@ -0,0 +1,82 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.opensearch.ml.utils.TestData.matchAllSearchQuery; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Map; + +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemorySearchConversationsActionIT extends MLCommonsRestTestCase { + + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testSearchConversations_Successful() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String id = (String) ccmap.get("conversation_id"); + + Response scresponse = TestHelper + .makeRequest(client(), "POST", ActionConstants.SEARCH_CONVERSATIONS_REST_PATH, null, matchAllSearchQuery(), null); + assert (scresponse != null); + assert (TestHelper.restStatus(scresponse) == RestStatus.OK); + HttpEntity schttpEntity = scresponse.getEntity(); + String scentityString = TestHelper.httpEntityToString(schttpEntity); + Map scmap = gson.fromJson(scentityString, Map.class); + assert (scmap.containsKey("hits")); + Map hitsmap = (Map) scmap.get("hits"); + assert (hitsmap.containsKey("hits")); + ArrayList hitsarray = (ArrayList) hitsmap.get("hits"); + assert (hitsarray.size() == 1); + for (Map hit : hitsarray) { + assert (hit.containsKey("_id")); + assert (hit.get("_id").equals(id)); + } + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java new file mode 100644 index 0000000000..294e3deac4 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.gson.Gson; + +public class RestMemorySearchConversationsActionTests extends OpenSearchTestCase { + + Gson gson; + + @Before + public void setup() { + gson = new Gson(); + } + + public void testBasics() { + RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(); + assert (action.getName().equals("conversation_memory_search_conversations")); + List routes = action.routes(); + assert (routes.size() == 2); + assert (routes.get(0).equals(new Route(RestRequest.Method.POST, ActionConstants.SEARCH_CONVERSATIONS_REST_PATH))); + assert (routes.get(1).equals(new Route(RestRequest.Method.GET, ActionConstants.SEARCH_CONVERSATIONS_REST_PATH))); + } + + public void testPreprareRequest() throws Exception { + RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(); + RestRequest request = TestHelper.getSearchAllRestRequest(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + verify(client, times(1)).execute(eq(SearchConversationsAction.INSTANCE), argumentCaptor.capture(), any()); + assert (argumentCaptor.getValue().source().query() instanceof MatchAllQueryBuilder); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java new file mode 100644 index 0000000000..9de93ac103 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java @@ -0,0 +1,149 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.opensearch.ml.utils.TestData.matchAllSearchQuery; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Map; + +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemorySearchInteractionsActionIT extends MLCommonsRestTestCase { + + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testSearchInteractions_Successfull() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String cid = (String) ccmap.get("conversation_id"); + + Map params1 = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "fish metadata" + ); + Response ciresponse1 = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params1), + null + ); + assert (ciresponse1 != null); + assert (TestHelper.restStatus(ciresponse1) == RestStatus.OK); + HttpEntity cihttpEntity1 = ciresponse1.getEntity(); + String cientityString1 = TestHelper.httpEntityToString(cihttpEntity1); + Map cimap1 = gson.fromJson(cientityString1, Map.class); + assert (cimap1.containsKey("interaction_id")); + String iid1 = (String) cimap1.get("interaction_id"); + + Map params2 = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "france metadata" + ); + Response ciresponse2 = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params2), + null + ); + assert (ciresponse2 != null); + assert (TestHelper.restStatus(ciresponse2) == RestStatus.OK); + HttpEntity cihttpEntity2 = ciresponse2.getEntity(); + String cientityString2 = TestHelper.httpEntityToString(cihttpEntity2); + Map cimap2 = gson.fromJson(cientityString2, Map.class); + assert (cimap2.containsKey("interaction_id")); + String iid2 = (String) cimap2.get("interaction_id"); + + Response siresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.SEARCH_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), + null, + matchAllSearchQuery(), + null + ); + assert (siresponse != null); + assert (TestHelper.restStatus(siresponse) == RestStatus.OK); + HttpEntity sihttpEntity = siresponse.getEntity(); + String sientityString = TestHelper.httpEntityToString(sihttpEntity); + Map simap = gson.fromJson(sientityString, Map.class); + assert (simap.containsKey("hits")); + Map hitsmap = (Map) simap.get("hits"); + assert (hitsmap.containsKey("hits")); + ArrayList hitsarray = (ArrayList) hitsmap.get("hits"); + assert (hitsarray.size() == 2); + for (Map hit : hitsarray) { + assert (hit.containsKey("_id")); + assert (hit.get("_id").equals(iid1) || hit.get("_id").equals(iid2)); + } + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionTests.java new file mode 100644 index 0000000000..dea27b4c42 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionTests.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsAction; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsRequest; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.gson.Gson; + +public class RestMemorySearchInteractionsActionTests extends OpenSearchTestCase { + + Gson gson; + + @Before + public void setup() { + gson = new Gson(); + } + + public void testBasics() { + RestMemorySearchInteractionsAction action = new RestMemorySearchInteractionsAction(); + assert (action.getName().equals("conversation_memory_search_interactions")); + List routes = action.routes(); + assert (routes.size() == 2); + assert (routes.get(0).equals(new Route(RestRequest.Method.POST, ActionConstants.SEARCH_INTERACTIONS_REST_PATH))); + assert (routes.get(1).equals(new Route(RestRequest.Method.GET, ActionConstants.SEARCH_INTERACTIONS_REST_PATH))); + } + + public void testPreprareRequest() throws Exception { + RestMemorySearchInteractionsAction action = new RestMemorySearchInteractionsAction(); + RestRequest request = TestHelper.getSearchAllRestRequest(); + request.params().put(ActionConstants.CONVERSATION_ID_FIELD, "test_cid"); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchInteractionsRequest.class); + verify(client, times(1)).execute(eq(SearchInteractionsAction.INSTANCE), argumentCaptor.capture(), any()); + assert (argumentCaptor.getValue().source().query() instanceof MatchAllQueryBuilder); + assert (argumentCaptor.getValue().getConversationId().equals("test_cid")); + } + + public void testSearchListener_TimeOut() throws Exception { + RestMemorySearchInteractionsAction action = new RestMemorySearchInteractionsAction(); + RestChannel channel = mock(RestChannel.class); + SearchResponse response = mock(SearchResponse.class); + doReturn(true).when(response).isTimedOut(); + doReturn("timed out").when(response).toString(); + RestResponse brr = action.search(channel).buildResponse(response); + assert (brr.status() == RestStatus.REQUEST_TIMEOUT); + } + + public void testSearchListener_Success() throws Exception { + RestMemorySearchInteractionsAction action = new RestMemorySearchInteractionsAction(); + RestChannel channel = mock(RestChannel.class); + SearchResponse response = mock(SearchResponse.class); + doReturn(false).when(response).isTimedOut(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + doReturn(builder).when(channel).newBuilder(); + doReturn(builder).when(response).toXContent(any(), any()); + RestResponse brr = action.search(channel).buildResponse(response); + assert (brr.status() == RestStatus.OK); + } +}