From 04e66b8ac1413678ce7e2841510d8a9e9a0f3ae7 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 4 Oct 2023 12:45:49 -0700 Subject: [PATCH 01/14] add searchConversation Signed-off-by: HenryL27 --- .../memory/ConversationalMemoryHandler.java | 32 ++++++++++++++ .../memory/index/ConversationMetaIndex.java | 26 ++++++++++++ .../ml/memory/index/InteractionsIndex.java | 3 ++ ...OpenSearchConversationalMemoryHandler.java | 42 +++++++++++++++++++ 4 files changed, 103 insertions(+) diff --git a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java index 18d23eff0d..6d512418d7 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java @@ -19,6 +19,8 @@ import java.util.List; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationMeta; @@ -171,4 +173,34 @@ public ActionFuture createInteraction( */ public ActionFuture deleteConversation(String conversationId); + /** + * Search over conversations index + * @param request search request over the conversations index + * @param listener receives the search response + */ + public void searchConversations(SearchRequest request, ActionListener listener); + + /** + * Search over conversations index + * @param request search request over the conversations index + * @return ActionFuture for the search response + */ + public ActionFuture searchConversations(SearchRequest request); + + /** + * Search over interactions of a conversation + * @param conversationId id of the conversation to search through + * @param request search request over the interactions + * @param listener receives the search response + */ + public void searchInteractions(String conversationId, SearchRequest request, ActionListener listener); + + /** + * Search over interactions of a conversation + * @param conversationId id of the conversation to search through + * @param request search request over the interactions + * @return ActionFuture for the search response + */ + public ActionFuture searchInteractions(String conversationId, SearchRequest request); + } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index e36c296066..e6c18cf632 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -45,6 +45,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.TermQueryBuilder; @@ -301,4 +302,29 @@ public void checkAccess(String conversationId, ActionListener listener) listener.onFailure(e); } } + + /** + * Search over the conversations in the index by wrapping the original search request + * If security is enabled, add a {"term": {"user": username}} to the wrapper must clause + * @param request original search request + * @param listener receives the search response for the wrapped query + */ + public void searchConversations(SearchRequest request, ActionListener listener) { + request.indices(indexName); + QueryBuilder originalQuery = request.source().query(); + BoolQueryBuilder newQuery = new BoolQueryBuilder(); + newQuery.must(originalQuery); + String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + if (user != null) { + newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, user)); + } + request.source().query(newQuery); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + client.search(request, internalListener); + } catch (Exception e) { + listener.onFailure(e); + } + } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 54857b274c..2034ba1623 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -340,4 +340,7 @@ public void deleteConversation(String conversationId, ActionListener li } } + public void searchInteractions(String conversationId, SearchRequest request, ActionListener listener) { + + } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index 6b33533ee2..fcfcc0c783 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -21,6 +21,8 @@ import java.util.List; import org.opensearch.action.StepListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; @@ -286,4 +288,44 @@ public ActionFuture deleteConversation(String conversationId) { return fut; } + /** + * Search over conversations index + * @param request search request over the conversations index + * @param listener receives the search response + */ + public void searchConversations(SearchRequest request, ActionListener listener) { + conversationMetaIndex.searchConversations(request, listener); + } + + /** + * Search over conversations index + * @param request search request over the conversations index + * @return ActionFuture for the search response + */ + public ActionFuture searchConversations(SearchRequest request) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + searchConversations(request, fut); + return fut; + } + + /** + * Search over interactions of a conversation + * @param conversationId id of the conversation to search through + * @param request search request over the interactions + * @param listener receives the search response + */ + public void searchInteractions(String conversationId, SearchRequest request, ActionListener listener); + + /** + * Search over interactions of a conversation + * @param conversationId id of the conversation to search through + * @param request search request over the interactions + * @return ActionFuture for the search response + */ + public ActionFuture searchInteractions(String conversationId, SearchRequest request) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + searchInteractions(conversationId, request, fut); + return fut; + } + } From 6303d23e871a39d2a0e8a5ebce5883bc3a6878fe Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 4 Oct 2023 13:42:46 -0700 Subject: [PATCH 02/14] add searchinteractions Signed-off-by: HenryL27 --- .../ml/memory/index/InteractionsIndex.java | 25 ++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 2034ba1623..2a7797bb47 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -41,6 +41,8 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; @@ -341,6 +343,27 @@ public void deleteConversation(String conversationId, ActionListener li } public void searchInteractions(String conversationId, SearchRequest request, ActionListener listener) { - + conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { + if(access) { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + request.indices(indexName); + QueryBuilder originalQuery = request.source().query(); + BoolQueryBuilder newQuery = new BoolQueryBuilder(); + newQuery.must(originalQuery); + newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId)); + request.source().query(newQuery); + client.search(request, internalListener); + } catch (Exception e) { + listener.onFailure(e); + } + } else { + String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { + listener.onFailure(e); + })); } } From 95bb9f4b6d5bd0b5b876e6175d6e70dac105b139 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 4 Oct 2023 16:37:47 -0700 Subject: [PATCH 03/14] add searchConversationsITTests Signed-off-by: HenryL27 --- .../memory/index/ConversationMetaIndex.java | 11 +- .../ml/memory/index/InteractionsIndex.java | 18 ++- ...OpenSearchConversationalMemoryHandler.java | 4 +- .../index/ConversationMetaIndexITTests.java | 111 ++++++++++++++++++ 4 files changed, 135 insertions(+), 9 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index e6c18cf632..8ba93d09e8 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -315,14 +315,19 @@ public void searchConversations(SearchRequest request, ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - client.search(request, internalListener); + client.admin().indices().refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { + client.search(request, internalListener); + }, e -> { + log.error("Failed to refresh conversations index during search conversations ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 2a7797bb47..8d492c1207 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -342,9 +342,16 @@ public void deleteConversation(String conversationId, ActionListener li } } + /** + * Execute a search query over the interactions of a conversation by constructing a wrapper + * boolean query around the original query, AND a term query over conversation id + * @param conversationId the id of the conversation to query over + * @param request the original search request + * @param listener receives the search response from this query + */ public void searchInteractions(String conversationId, SearchRequest request, ActionListener listener) { conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { - if(access) { + if (access) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); request.indices(indexName); @@ -358,12 +365,13 @@ public void searchInteractions(String conversationId, SearchRequest request, Act listener.onFailure(e); } } else { - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } - }, e -> { - listener.onFailure(e); - })); + }, e -> { listener.onFailure(e); })); } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index fcfcc0c783..ba96fd00a9 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -314,7 +314,9 @@ public ActionFuture searchConversations(SearchRequest request) { * @param request search request over the interactions * @param listener receives the search response */ - public void searchInteractions(String conversationId, SearchRequest request, ActionListener listener); + public void searchInteractions(String conversationId, SearchRequest request, ActionListener listener) { + interactionsIndex.searchInteractions(conversationId, request, listener); + } /** * Search over interactions of a conversation diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index e1a0318758..1868e819d0 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -29,6 +29,8 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.StepListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.client.Requests; import org.opensearch.cluster.service.ClusterService; @@ -36,8 +38,10 @@ import org.opensearch.common.util.concurrent.ThreadContext.StoredContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.ml.common.conversation.ConversationMeta; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; @@ -415,4 +419,111 @@ public void testDifferentUsersCannotTouchOthersConversations() { } } + public void testCanQueryOverConversations() { + CountDownLatch cdl = new CountDownLatch(1); + StepListener convo1 = new StepListener<>(); + index.createConversation("Henry Conversation", convo1); + + StepListener convo2 = new StepListener<>(); + convo1.whenComplete(cid -> { index.createConversation("Mehul Conversation", convo2); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener search = new StepListener<>(); + convo2.whenComplete(cid -> { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new TermQueryBuilder(ConversationalIndexConstants.META_NAME_FIELD, "Henry Conversation")); + index.searchConversations(request, search); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + search.whenComplete(response -> { + log.info("SEARCH RESPONSE"); + log.info(response.toString()); + cdl.countDown(); + assert (response.getHits().getAt(0).getId().equals(convo1.result())); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + public void testCanQueryOverConversationsSecurely() { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + CountDownLatch cdl = new CountDownLatch(1); + Stack contextStack = new Stack<>(); + Consumer onFail = e -> { + while (!contextStack.empty()) { + contextStack.pop().close(); + } + cdl.countDown(); + log.error(e); + threadContext.restore(); + assert (false); + }; + + final String user1 = "Dhrubo"; + final String user2 = "Jing"; + contextStack.push(setUser(user1)); + + StepListener convo1 = new StepListener<>(); + index.createConversation("Dhrubo Conversation", convo1); + + StepListener convo2 = new StepListener<>(); + convo1.whenComplete(cid -> { + contextStack.push(setUser(user2)); + index.createConversation("Jing Conversation", convo2); + }, onFail); + + StepListener search1 = new StepListener<>(); + convo2.whenComplete(cid -> { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new TermQueryBuilder(ConversationalIndexConstants.META_NAME_FIELD, "Dhrubo Conversation")); + index.searchConversations(request, search1); + }, onFail); + + StepListener search2 = new StepListener<>(); + search1.whenComplete(response -> { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new TermQueryBuilder(ConversationalIndexConstants.META_NAME_FIELD, "Jing Conversation")); + index.searchConversations(request, search2); + }, onFail); + + search2.whenComplete(response -> { + cdl.countDown(); + assert (response.getHits().getAt(0).getId().equals(convo2.result())); + assert (search1.result().getHits().getHits().length == 0); + while(!contextStack.isEmpty()) { + contextStack.pop().close(); + } + }, onFail); + + try { + cdl.await(); + threadContext.restore(); + } catch (InterruptedException e) { + log.error(e); + threadContext.restore(); + } + + } catch (Exception e) { + log.error(e); + } + } + } From 16bd75b14c5f81c4032da5aab197624f6dffcbfc Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 5 Oct 2023 09:39:26 -0700 Subject: [PATCH 04/14] add searchInteractionsITTests Signed-off-by: HenryL27 --- .../ml/memory/index/InteractionsIndex.java | 7 +- .../index/InteractionsIndexITTests.java | 85 +++++++++++++++++++ 2 files changed, 91 insertions(+), 1 deletion(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 8d492c1207..1f273678f7 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -360,7 +360,12 @@ public void searchInteractions(String conversationId, SearchRequest request, Act newQuery.must(originalQuery); newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId)); request.source().query(newQuery); - client.search(request, internalListener); + client.admin().indices().refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { + client.search(request, internalListener); + }, e -> { + log.error("Failed to refresh interactions index during search interactions ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java index c23177bc2f..0bbeda03d1 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java @@ -19,16 +19,23 @@ import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.CountDownLatch; import org.junit.Before; import org.opensearch.action.LatchedActionListener; import org.opensearch.action.StepListener; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.search.SearchHit; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchIntegTestCase; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; @@ -348,4 +355,82 @@ public void testDeleteConversation() { log.error(e); } } + + public void testSearchInteractions() { + final String conversation1 = "conversation1"; + final String conversation2 = "conversation2"; + CountDownLatch cdl = new CountDownLatch(1); + StepListener iid1 = new StepListener<>(); + index.createInteraction(conversation1, "input about fish", "pt", "response about fish", "origin1", "lots of information about fish", iid1); + + StepListener iid2 = new StepListener<>(); + iid1 + .whenComplete( + r -> { index.createInteraction(conversation1, "input about squash", "pt", "response about squash", "origin1", "lots of information about squash", iid2); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); + + StepListener iid3 = new StepListener<>(); + iid2 + .whenComplete( + r -> { index.createInteraction(conversation2, "input about fish", "pt2", "response about fish", "origin1", "lots of information about fish", iid3); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); + + StepListener iid4 = new StepListener<>(); + iid3 + .whenComplete( + r -> { index.createInteraction(conversation1, "input about france", "pt", "response about france", "origin1", "lots of information about france", iid4); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); + + StepListener searchListener = new StepListener<>(); + iid4.whenComplete( + r -> { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new MatchQueryBuilder(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "fish input")); + index.searchInteractions(conversation1, request, searchListener); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + searchListener.whenComplete( + response -> { + cdl.countDown(); + assert (response.getHits().getHits().length == 3); + // BM25 was being a little unpredictable here so I don't assert ordering + List ids = new ArrayList<>(3); + for (SearchHit hit : response.getHits()) { + ids.add(hit.getId()); + } + assert (ids.contains(iid1.result())); + assert (ids.contains(iid2.result())); + assert (ids.contains(iid4.result())); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } } From f6d1c0b099f1b4dde8e78a6b9c52fadb521e3684 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 5 Oct 2023 11:39:29 -0700 Subject: [PATCH 05/14] add unit tests for storage-layer search Signed-off-by: HenryL27 --- .../memory/index/ConversationMetaIndex.java | 2 + .../index/ConversationMetaIndexTests.java | 80 ++++++++++++++++++- .../memory/index/InteractionsIndexTests.java | 53 ++++++++++++ ...earchConversationalMemoryHandlerTests.java | 29 ++++++- 4 files changed, 161 insertions(+), 3 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index 8ba93d09e8..ffd0d3225c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -277,6 +277,8 @@ public void checkAccess(String conversationId, ActionListener listener) if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) { throw new ResourceNotFoundException("Conversation [" + conversationId + "] not found"); } + log.info(userstr); + log.info(User.parse(userstr)); // If security is off - User doesn't exist - you have permission if (userstr == null || User.parse(userstr) == null) { internalListener.onResponse(true); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index 821d801cdf..799ff3b922 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -40,6 +40,7 @@ import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; @@ -52,7 +53,9 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.SendRequestTransportException; @@ -134,6 +137,13 @@ private void setupUser(String user) { }).when(threadPool).getThreadContext(); } + private SearchRequest dummyRequest() { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new MatchAllQueryBuilder()); + return request; + } + public void testInit_DoesNotCreateIndex() { setupDoesNotMakeIndex(); @SuppressWarnings("unchecked") @@ -402,6 +412,18 @@ public void testDelete_DeleteFails_ThenFail() { assert (argCaptor.getValue().getMessage().equals("Test Fail in Delete")); } + public void testDelete_ClientFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + blanketGrantAccess(); + doThrow(new RuntimeException("Client Fail in Delete")).when(client).delete(any(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + conversationMetaIndex.deleteConversation("test-id", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Fail in Delete")); + } + public void testCheckAccess_DoesNotExist_ThenFail() { setupUser("user"); setupRefreshSuccess(); @@ -464,7 +486,7 @@ public void testCheckAccess_ClientFails_ThenFail() { setupUser("user"); setupRefreshSuccess(); doReturn(true).when(metadata).hasIndex(anyString()); - doThrow(new RuntimeException("Client Test Fail")).when(client).get(any(), any()); + doThrow(new RuntimeException("Client Test Fail")).when(client).admin(); @SuppressWarnings("unchecked") ActionListener accessListener = mock(ActionListener.class); conversationMetaIndex.checkAccess("test id", accessListener); @@ -475,11 +497,65 @@ public void testCheckAccess_ClientFails_ThenFail() { public void testCheckAccess_EmptyStringUser_ThenReturnTrue() { setupUser(null); + setupRefreshSuccess(); + doReturn(true).when(metadata).hasIndex(anyString()); + final String id = "test_id"; + GetResponse dummyGetResponse = mock(GetResponse.class); + doReturn(true).when(dummyGetResponse).isExists(); + doReturn(id).when(dummyGetResponse).getId(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(dummyGetResponse); + return null; + }).when(client).get(any(), any()); @SuppressWarnings("unchecked") ActionListener accessListener = mock(ActionListener.class); - conversationMetaIndex.checkAccess("test id", accessListener); + conversationMetaIndex.checkAccess(id, accessListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); verify(accessListener, times(1)).onResponse(argCaptor.capture()); assert (argCaptor.getValue()); } + + public void testCheckAccess_RefreshFails_ThenFail() { + setupUser("user"); + doReturn(true).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Refresh Exception")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener accessListener = mock(ActionListener.class); + conversationMetaIndex.checkAccess("test id", accessListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(accessListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Refresh Exception")); + } + + public void testSearchConversations_RefreshFails_ThenFail() { + SearchRequest request = dummyRequest(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Refresh Exception")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener searchConversationsListener = mock(ActionListener.class); + conversationMetaIndex.searchConversations(request, searchConversationsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(searchConversationsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Refresh Exception")); + } + + public void testSearchConversations_ClientFails_ThenFail() { + SearchRequest request = dummyRequest(); + doReturn(true).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Client Test Fail")).when(client).admin(); + @SuppressWarnings("unchecked") + ActionListener accessListener = mock(ActionListener.class); + conversationMetaIndex.searchConversations(request, accessListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(accessListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Test Fail")); + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 0e97c7e9f6..69104c70fd 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -41,6 +41,7 @@ import org.opensearch.action.admin.indices.refresh.RefreshResponse; import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; @@ -53,8 +54,10 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.SendRequestTransportException; @@ -145,6 +148,13 @@ private void setupRefreshSuccess() { }).when(indicesAdminClient).refresh(any(), any()); } + private SearchRequest dummyRequest() { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new MatchAllQueryBuilder()); + return request; + } + public void testInit_DoesNotCreateIndex_ThenReturnFalse() { setupDoesNotMakeIndex(); @SuppressWarnings("unchecked") @@ -582,4 +592,47 @@ public void testDelete_MainFailure_ThenFail() { verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Test Failure")); } + + public void testSearch_RefreshFails_ThenFail() { + setupGrantAccess(); + SearchRequest request = dummyRequest(); + final String cid = "test_id"; + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Failed during Search Refresh")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener searchInteractionsListener = mock(ActionListener.class); + interactionsIndex.searchInteractions(cid, request, searchInteractionsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(searchInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failed during Search Refresh")); + } + + public void testSearch_ClientFails_ThenFail() { + setupGrantAccess(); + SearchRequest request = dummyRequest(); + final String cid = "test_cid"; + doThrow(new RuntimeException("Client Failure in Search Interactions")).when(client).admin(); + @SuppressWarnings("unchecked") + ActionListener searchInteractionsListener = mock(ActionListener.class); + interactionsIndex.searchInteractions(cid, request, searchInteractionsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(searchInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Failure in Search Interactions")); + } + + public void testSearch_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupDenyAccess("user"); + SearchRequest request = dummyRequest(); + final String cid = "test_cid"; + @SuppressWarnings("unchecked") + ActionListener searchInteractionsListener = mock(ActionListener.class); + interactionsIndex.searchInteractions(cid, request, searchInteractionsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(searchInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [user] does not have access to conversation test_cid")); + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index e39513d2d8..175949b0ef 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -30,6 +30,8 @@ import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.Mock; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; import org.opensearch.common.action.ActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationMeta; @@ -59,7 +61,7 @@ public void testCreateConversation_NoName_FutureSuccess() { ActionListener al = invocation.getArgument(0); al.onResponse("cid"); return null; - }).when(conversationMetaIndex).createConversation(any(ActionListener.class)); + }).when(conversationMetaIndex).createConversation(any()); ActionFuture result = cmHandler.createConversation(); assert (result.actionGet(200).equals("cid")); } @@ -241,4 +243,29 @@ public void testDelete_AsFuture() { ActionFuture result = cmHandler.deleteConversation("cid"); assert (result.actionGet(200)); } + + public void testSearchConversations_Future() { + SearchRequest request = mock(SearchRequest.class); + SearchResponse response = mock(SearchResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(conversationMetaIndex).searchConversations(any(), any()); + ActionFuture result = cmHandler.searchConversations(request); + assert (result.actionGet().equals(response)); + } + + public void testSearchInteractions_Future() { + SearchRequest request = mock(SearchRequest.class); + SearchResponse response = mock(SearchResponse.class); + String cid = "cid"; + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(interactionsIndex).searchInteractions(any(), any(), any()); + ActionFuture result = cmHandler.searchInteractions(cid, request); + assert (result.actionGet().equals(response)); + } } From 33d0c3b1d86d388d1dce20a03dfc945a266c11e3 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 5 Oct 2023 15:17:59 -0700 Subject: [PATCH 06/14] add Search transport actions and tests Signed-off-by: HenryL27 --- .../SearchConversationsAction.java | 32 +++++ .../SearchConversationsTransportAction.java | 78 +++++++++++ .../SearchInteractionsAction.java | 32 +++++ .../SearchInteractionsRequest.java | 54 ++++++++ .../SearchInteractionsTransportAction.java | 77 +++++++++++ .../conversation/ConversationActionTests.java | 1 + .../conversation/InteractionActionTests.java | 5 +- ...archConversationsTransportActionTests.java | 121 +++++++++++++++++ .../SearchInteractionsRequestTests.java | 65 ++++++++++ ...archInteractionsTransportActionsTests.java | 122 ++++++++++++++++++ 10 files changed, 586 insertions(+), 1 deletion(-) create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsAction.java new file mode 100644 index 0000000000..38b19009ac --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsAction.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; + +public class SearchConversationsAction extends ActionType { + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/search"; + /** Instance of this */ + public static final SearchConversationsAction INSTANCE = new SearchConversationsAction(); + + private SearchConversationsAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java new file mode 100644 index 0000000000..02e1bb78ed --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java @@ -0,0 +1,78 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class SearchConversationsTransportAction extends HandledTransportAction { + + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public SearchConversationsTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + ClusterService clusterService + ) { + super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new); + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, SearchRequest request, ActionListener actionListener) { + if(!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } else { + cmHandler.searchConversations(request, actionListener); + } + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsAction.java new file mode 100644 index 0000000000..9386d6b674 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsAction.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; +import org.opensearch.action.search.SearchResponse; + +public class SearchInteractionsAction extends ActionType { + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/interaction/search"; + /** Instance of this */ + public static final SearchInteractionsAction INSTANCE = new SearchInteractionsAction(); + + private SearchInteractionsAction() { + super(NAME, SearchResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java new file mode 100644 index 0000000000..7ecb498516 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; + +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class SearchInteractionsRequest extends SearchRequest { + + @Setter + @Getter + private String conversationId; + + public SearchInteractionsRequest(String conversationId, SearchRequest request) { + super(request); + this.conversationId = conversationId; + } + + public SearchInteractionsRequest(StreamInput in) throws IOException { + super(in); + log.info("Got here"); + this.conversationId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(conversationId); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java new file mode 100644 index 0000000000..adfc208e2b --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java @@ -0,0 +1,77 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +public class SearchInteractionsTransportAction extends HandledTransportAction { + + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public SearchInteractionsTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + ClusterService clusterService + ) { + super(SearchInteractionsAction.NAME, transportService, actionFilters, SearchInteractionsRequest::new); + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, SearchInteractionsRequest request, ActionListener actionListener) { + if(!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } else { + cmHandler.searchInteractions(request.getConversationId(), request, actionListener); + } + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java index 2975cd4c1d..65cbb7dfea 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java @@ -24,5 +24,6 @@ public void testActions() { assert (CreateConversationAction.INSTANCE instanceof CreateConversationAction); assert (DeleteConversationAction.INSTANCE instanceof DeleteConversationAction); assert (GetConversationsAction.INSTANCE instanceof GetConversationsAction); + assert (SearchConversationsAction.INSTANCE instanceof SearchConversationsAction); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java index 9002796bbc..89a6fae6a3 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java @@ -17,9 +17,12 @@ */ package org.opensearch.ml.memory.action.conversation; -public class InteractionActionTests { +import org.opensearch.test.OpenSearchTestCase; + +public class InteractionActionTests extends OpenSearchTestCase { public void testActions() { assert (CreateInteractionAction.INSTANCE instanceof CreateInteractionAction); assert (GetInteractionsAction.INSTANCE instanceof GetInteractionsAction); + assert (SearchInteractionsAction.INSTANCE instanceof SearchInteractionsAction); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java new file mode 100644 index 0000000000..e086b9d269 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class SearchConversationsTransportActionTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + @Mock + SearchRequest request; + + SearchConversationsTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + } + + public void testEnabled_ThenSucceed() { + SearchResponse response = mock(SearchResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(cmHandler).searchConversations(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().equals(response)); + } + + public void testDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java new file mode 100644 index 0000000000..b25af1beaa --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.indices.IndicesModule; +import org.opensearch.search.SearchModule; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; + + +public class SearchInteractionsRequestTests extends OpenSearchTestCase { + + protected NamedWriteableRegistry namedWriteableRegistry; + + public void setUp() throws Exception { + super.setUp(); + IndicesModule indicesModule = new IndicesModule(Collections.emptyList()); + SearchModule searchModule = new SearchModule(Settings.EMPTY, List.of()); + List entries = new ArrayList<>(); + entries.addAll(indicesModule.getNamedWriteables()); + entries.addAll(searchModule.getNamedWriteables()); + namedWriteableRegistry = new NamedWriteableRegistry(entries); + } + + public void testConstructorsAndStreaming() throws IOException { + SearchRequest original = new SearchRequest(); + original.source(new SearchSourceBuilder()); + original.source().query(new MatchAllQueryBuilder()); + + SearchInteractionsRequest request = new SearchInteractionsRequest("test_cid", original); + assert (request instanceof SearchRequest); + assert (request.getConversationId().equals("test_cid")); + assert (request.validate() == null); + + SearchInteractionsRequest newRequest = copyWriteable(request, namedWriteableRegistry, SearchInteractionsRequest::new); + assert (newRequest.getConversationId().equals("test_cid")); + assert (newRequest.validate() == null); + assert (newRequest.equals(request)); + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java new file mode 100644 index 0000000000..71cf791c6c --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java @@ -0,0 +1,122 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class SearchInteractionsTransportActionsTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + @Mock + SearchInteractionsRequest request; + + SearchInteractionsTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + when(this.request.getConversationId()).thenReturn("test_cid"); + + this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + } + + public void testFeatureEnabled_ThenSucceed() { + SearchResponse response = mock(SearchResponse.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(response); + return null; + }).when(cmHandler).searchInteractions(any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().equals(response)); + } + + public void testDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } + +} From e8b160a0fe46b8bcb9f68d812260ba6621023a36 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 5 Oct 2023 17:10:27 -0700 Subject: [PATCH 07/14] add rest search actions Signed-off-by: HenryL27 --- .../common/conversation/ActionConstants.java | 15 +- .../SearchConversationsTransportAction.java | 4 +- .../SearchInteractionsRequest.java | 2 +- .../SearchInteractionsTransportAction.java | 4 +- ...archConversationsTransportActionTests.java | 2 +- .../SearchInteractionsRequestTests.java | 3 +- ...archInteractionsTransportActionsTests.java | 2 +- .../index/ConversationMetaIndexITTests.java | 2 +- .../index/InteractionsIndexITTests.java | 142 +++++++++++------- .../RestMemorySearchConversationsAction.java | 43 ++++++ .../RestMemorySearchInteractionsAction.java | 83 ++++++++++ ...RestMemorySearchConversationsActionIT.java | 22 +++ ...tMemorySearchConversationsActionTests.java | 79 ++++++++++ .../RestMemorySearchInteractionsAction.java | 22 +++ .../RestMemorySearchInteractionsActionIT.java | 22 +++ 15 files changed, 375 insertions(+), 72 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index 5bb8334bc1..17a4bbe0b3 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -51,16 +51,21 @@ public class ActionConstants { /** name of success field in all requests */ public final static String SUCCESS_FIELD = "success"; + private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation"; /** path for create conversation */ - public final static String CREATE_CONVERSATION_REST_PATH = "/_plugins/_ml/memory/conversation"; + public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH; /** path for list conversations */ - public final static String GET_CONVERSATIONS_REST_PATH = "/_plugins/_ml/memory/conversation"; + public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH; /** path for put interaction */ - public final static String CREATE_INTERACTION_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}"; + public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; /** path for get interactions */ - public final static String GET_INTERACTIONS_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}"; + public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; /** path for delete conversation */ - public final static String DELETE_CONVERSATION_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}"; + public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; + /** path for search conversations */ + public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search"; + /** path for search interactions */ + public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search"; /** default max results returned by get operations */ public final static int DEFAULT_MAX_RESULTS = 10; diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java index 02e1bb78ed..c528ac4391 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java @@ -32,7 +32,7 @@ import org.opensearch.transport.TransportService; public class SearchConversationsTransportAction extends HandledTransportAction { - + private ConversationalMemoryHandler cmHandler; private volatile boolean featureIsEnabled; @@ -62,7 +62,7 @@ public SearchConversationsTransportAction( @Override public void doExecute(Task task, SearchRequest request, ActionListener actionListener) { - if(!featureIsEnabled) { + if (!featureIsEnabled) { actionListener .onFailure( new OpenSearchException( diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java index 7ecb498516..ac55350d7f 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequest.java @@ -29,7 +29,7 @@ @Log4j2 public class SearchInteractionsRequest extends SearchRequest { - + @Setter @Getter private String conversationId; diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java index adfc208e2b..80f56d28ca 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java @@ -31,7 +31,7 @@ import org.opensearch.transport.TransportService; public class SearchInteractionsTransportAction extends HandledTransportAction { - + private ConversationalMemoryHandler cmHandler; private volatile boolean featureIsEnabled; @@ -61,7 +61,7 @@ public SearchInteractionsTransportAction( @Override public void doExecute(Task task, SearchInteractionsRequest request, ActionListener actionListener) { - if(!featureIsEnabled) { + if (!featureIsEnabled) { actionListener .onFailure( new OpenSearchException( diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java index e086b9d269..a42188ce8b 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java @@ -49,7 +49,7 @@ import org.opensearch.transport.TransportService; public class SearchConversationsTransportActionTests extends OpenSearchTestCase { - + @Mock ThreadPool threadPool; diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java index b25af1beaa..af7dc33c9c 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsRequestTests.java @@ -31,7 +31,6 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; - public class SearchInteractionsRequestTests extends OpenSearchTestCase { protected NamedWriteableRegistry namedWriteableRegistry; @@ -45,7 +44,7 @@ public void setUp() throws Exception { entries.addAll(searchModule.getNamedWriteables()); namedWriteableRegistry = new NamedWriteableRegistry(entries); } - + public void testConstructorsAndStreaming() throws IOException { SearchRequest original = new SearchRequest(); original.source(new SearchSourceBuilder()); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java index 71cf791c6c..85e0c1bf0d 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java @@ -48,7 +48,7 @@ import org.opensearch.transport.TransportService; public class SearchInteractionsTransportActionsTests extends OpenSearchTestCase { - + @Mock ThreadPool threadPool; diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index 1868e819d0..dc75e1245d 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -508,7 +508,7 @@ public void testCanQueryOverConversationsSecurely() { cdl.countDown(); assert (response.getHits().getAt(0).getId().equals(convo2.result())); assert (search1.result().getHits().getHits().length == 0); - while(!contextStack.isEmpty()) { + while (!contextStack.isEmpty()) { contextStack.pop().close(); } }, onFail); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java index 0bbeda03d1..f089a48a75 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java @@ -361,71 +361,99 @@ public void testSearchInteractions() { final String conversation2 = "conversation2"; CountDownLatch cdl = new CountDownLatch(1); StepListener iid1 = new StepListener<>(); - index.createInteraction(conversation1, "input about fish", "pt", "response about fish", "origin1", "lots of information about fish", iid1); + index + .createInteraction( + conversation1, + "input about fish", + "pt", + "response about fish", + "origin1", + "lots of information about fish", + iid1 + ); StepListener iid2 = new StepListener<>(); - iid1 - .whenComplete( - r -> { index.createInteraction(conversation1, "input about squash", "pt", "response about squash", "origin1", "lots of information about squash", iid2); }, - e -> { - cdl.countDown(); - log.error(e); - assert (false); - } - ); + iid1.whenComplete(r -> { + index + .createInteraction( + conversation1, + "input about squash", + "pt", + "response about squash", + "origin1", + "lots of information about squash", + iid2 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); StepListener iid3 = new StepListener<>(); - iid2 - .whenComplete( - r -> { index.createInteraction(conversation2, "input about fish", "pt2", "response about fish", "origin1", "lots of information about fish", iid3); }, - e -> { - cdl.countDown(); - log.error(e); - assert (false); - } - ); + iid2.whenComplete(r -> { + index + .createInteraction( + conversation2, + "input about fish", + "pt2", + "response about fish", + "origin1", + "lots of information about fish", + iid3 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); StepListener iid4 = new StepListener<>(); - iid3 - .whenComplete( - r -> { index.createInteraction(conversation1, "input about france", "pt", "response about france", "origin1", "lots of information about france", iid4); }, - e -> { - cdl.countDown(); - log.error(e); - assert (false); - } - ); - + iid3.whenComplete(r -> { + index + .createInteraction( + conversation1, + "input about france", + "pt", + "response about france", + "origin1", + "lots of information about france", + iid4 + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + StepListener searchListener = new StepListener<>(); - iid4.whenComplete( - r -> { - SearchRequest request = new SearchRequest(); - request.source(new SearchSourceBuilder()); - request.source().query(new MatchQueryBuilder(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "fish input")); - index.searchInteractions(conversation1, request, searchListener); - }, e -> { - cdl.countDown(); - log.error(e); - assert (false); - }); + iid4.whenComplete(r -> { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.source().query(new MatchQueryBuilder(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "fish input")); + index.searchInteractions(conversation1, request, searchListener); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); - searchListener.whenComplete( - response -> { - cdl.countDown(); - assert (response.getHits().getHits().length == 3); - // BM25 was being a little unpredictable here so I don't assert ordering - List ids = new ArrayList<>(3); - for (SearchHit hit : response.getHits()) { - ids.add(hit.getId()); - } - assert (ids.contains(iid1.result())); - assert (ids.contains(iid2.result())); - assert (ids.contains(iid4.result())); - }, e -> { - cdl.countDown(); - log.error(e); - assert (false); - }); + searchListener.whenComplete(response -> { + cdl.countDown(); + assert (response.getHits().getHits().length == 3); + // BM25 was being a little unpredictable here so I don't assert ordering + List ids = new ArrayList<>(3); + for (SearchHit hit : response.getHits()) { + ids.add(hit.getId()); + } + assert (ids.contains(iid1.result())); + assert (ids.contains(iid2.result())); + assert (ids.contains(iid4.result())); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); try { cdl.await(); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java new file mode 100644 index 0000000000..5beee29c42 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchConversationsAction.java @@ -0,0 +1,43 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; + +import com.google.common.collect.ImmutableList; + +public class RestMemorySearchConversationsAction extends AbstractMLSearchAction { + private static final String SEARCH_CONVERSATIONS_NAME = "conversation_memory_search_conversations"; + + public RestMemorySearchConversationsAction() { + super( + ImmutableList.of(ActionConstants.SEARCH_CONVERSATIONS_REST_PATH), + ConversationalIndexConstants.META_INDEX_NAME, + ConversationMeta.class, + SearchConversationsAction.INSTANCE + ); + } + + @Override + public String getName() { + return SEARCH_CONVERSATIONS_NAME; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java new file mode 100644 index 0000000000..063e2502ce --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java @@ -0,0 +1,83 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; +import static org.opensearch.ml.utils.RestActionUtils.getSourceContext; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsAction; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.rest.action.RestResponseListener; +import org.opensearch.search.builder.SearchSourceBuilder; + +import com.google.common.collect.ImmutableList; + +public class RestMemorySearchInteractionsAction extends BaseRestHandler { + private static final String SEARCH_INTERACTIONS_NAME = "conversation_memory_search_interactions"; + + @Override + public List routes() { + return ImmutableList + .of( + new Route(RestRequest.Method.POST, ActionConstants.SEARCH_INTERACTIONS_REST_PATH), + new Route(RestRequest.Method.GET, ActionConstants.SEARCH_INTERACTIONS_REST_PATH) + ); + } + + @Override + public String getName() { + return SEARCH_INTERACTIONS_NAME; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.parseXContent(request.contentOrSourceParamParser()); + searchSourceBuilder.fetchSource(getSourceContext(request, searchSourceBuilder)); + searchSourceBuilder.seqNoAndPrimaryTerm(true).version(true); + SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder); + SearchInteractionsRequest siRequest = new SearchInteractionsRequest(conversationId, searchRequest); + return channel -> client.execute(SearchInteractionsAction.INSTANCE, siRequest, search(channel)); + } + + protected RestResponseListener search(RestChannel channel) { + return new RestResponseListener(channel) { + @Override + public RestResponse buildResponse(SearchResponse response) throws Exception { + if (response.isTimedOut()) { + return new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, response.toString()); + } + return new BytesRestResponse(RestStatus.OK, response.toXContent(channel.newBuilder(), EMPTY_PARAMS)); + } + }; + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java new file mode 100644 index 0000000000..5a0c0aca12 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +public class RestMemorySearchConversationsActionIT { + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java new file mode 100644 index 0000000000..fb1fcbc1b1 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java @@ -0,0 +1,79 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +import com.google.gson.Gson; + +public class RestMemorySearchConversationsActionTests extends OpenSearchTestCase { + + Gson gson; + + @Before + public void setup() { + gson = new Gson(); + } + + public void testBasics() { + RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(); + assert (action.getName().equals("conversation_memory_search_conversations")); + List routes = action.routes(); + assert (routes.size() == 2); + assert (routes.get(0).equals(new Route(RestRequest.Method.POST, ActionConstants.SEARCH_CONVERSATIONS_REST_PATH))); + assert (routes.get(1).equals(new Route(RestRequest.Method.GET, ActionConstants.SEARCH_CONVERSATIONS_REST_PATH))); + } + + public void testPreprareRequest() throws Exception { + RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withContent(new BytesArray(gson.toJson(Map.of("query", Map.of("match_all", Map.of())))), MediaTypeRegistry.JSON) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchRequest.class); + verify(client, times(1)).execute(eq(SearchConversationsAction.INSTANCE), argumentCaptor.capture(), any()); + assert (argumentCaptor.getValue().source().query() instanceof MatchAllQueryBuilder); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java new file mode 100644 index 0000000000..811cb40062 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +public class RestMemorySearchInteractionsAction { + +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java new file mode 100644 index 0000000000..a345e08fac --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +public class RestMemorySearchInteractionsActionIT { + +} From baf6619cf86642d1648e3c140c76bf276793d6de Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Fri, 6 Oct 2023 15:31:13 -0700 Subject: [PATCH 08/14] add search rest actions Signed-off-by: HenryL27 --- .../common/conversation/ActionConstants.java | 10 +- plugin/build.gradle | 1 + .../ml/plugin/MachineLearningPlugin.java | 16 ++- ...RestMemorySearchConversationsActionIT.java | 62 ++++++++- ...tMemorySearchConversationsActionTests.java | 10 +- .../RestMemorySearchInteractionsAction.java | 22 --- .../RestMemorySearchInteractionsActionIT.java | 129 +++++++++++++++++- ...stMemorySearchInteractionsActionTests.java | 103 ++++++++++++++ 8 files changed, 314 insertions(+), 39 deletions(-) delete mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index 17a4bbe0b3..96431e7679 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -53,15 +53,15 @@ public class ActionConstants { private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation"; /** path for create conversation */ - public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH; + public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create"; /** path for list conversations */ - public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH; + public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list"; /** path for put interaction */ - public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; + public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create"; /** path for get interactions */ - public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; + public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list"; /** path for delete conversation */ - public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; + public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_delete"; /** path for search conversations */ public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search"; /** path for search interactions */ diff --git a/plugin/build.gradle b/plugin/build.gradle index 042ac13423..567f8bd0d7 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -257,6 +257,7 @@ jacocoTestReport { xml.getRequired().set(true) csv.getRequired().set(false) html.getRequired().set(true) + html.outputLocation = layout.buildDirectory.dir('jacocoHtml') } dependsOn test diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index f5ba454c4d..8f23c84f4e 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -143,6 +143,10 @@ import org.opensearch.ml.memory.action.conversation.GetConversationsTransportAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsTransportAction; +import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; +import org.opensearch.ml.memory.action.conversation.SearchConversationsTransportAction; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsAction; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsTransportAction; import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; @@ -179,6 +183,8 @@ import org.opensearch.ml.rest.RestMemoryDeleteConversationAction; import org.opensearch.ml.rest.RestMemoryGetConversationsAction; import org.opensearch.ml.rest.RestMemoryGetInteractionsAction; +import org.opensearch.ml.rest.RestMemorySearchConversationsAction; +import org.opensearch.ml.rest.RestMemorySearchInteractionsAction; import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLClusterLevelStat; @@ -305,7 +311,9 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(CreateInteractionAction.INSTANCE, CreateInteractionTransportAction.class), new ActionHandler<>(GetInteractionsAction.INSTANCE, GetInteractionsTransportAction.class), new ActionHandler<>(DeleteConversationAction.INSTANCE, DeleteConversationTransportAction.class), - new ActionHandler<>(MLUpdateConnectorAction.INSTANCE, UpdateConnectorTransportAction.class) + new ActionHandler<>(MLUpdateConnectorAction.INSTANCE, UpdateConnectorTransportAction.class), + new ActionHandler<>(SearchInteractionsAction.INSTANCE, SearchInteractionsTransportAction.class), + new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class) ); } @@ -557,6 +565,8 @@ public List getRestHandlers( RestMemoryGetInteractionsAction restListInteractionsAction = new RestMemoryGetInteractionsAction(); RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction(); RestMLUpdateConnectorAction restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); + RestMemorySearchConversationsAction restSearchConversationsAction = new RestMemorySearchConversationsAction(); + RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); return ImmutableList .of( restMLStatsAction, @@ -591,7 +601,9 @@ public List getRestHandlers( restCreateInteractionAction, restListInteractionsAction, restDeleteConversationAction, - restMLUpdateConnectorAction + restMLUpdateConnectorAction, + restSearchConversationsAction, + restSearchInteractionsAction ); } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java index 5a0c0aca12..264ef5ea24 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionIT.java @@ -17,6 +17,66 @@ */ package org.opensearch.ml.rest; -public class RestMemorySearchConversationsActionIT { +import static org.opensearch.ml.utils.TestData.matchAllSearchQuery; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Map; + +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemorySearchConversationsActionIT extends MLCommonsRestTestCase { + + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testSearchConversations_Successful() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String id = (String) ccmap.get("conversation_id"); + + Response scresponse = TestHelper + .makeRequest(client(), "POST", ActionConstants.SEARCH_CONVERSATIONS_REST_PATH, null, matchAllSearchQuery(), null); + assert (scresponse != null); + assert (TestHelper.restStatus(scresponse) == RestStatus.OK); + HttpEntity schttpEntity = scresponse.getEntity(); + String scentityString = TestHelper.httpEntityToString(schttpEntity); + Map scmap = gson.fromJson(scentityString, Map.class); + assert (scmap.containsKey("hits")); + Map hitsmap = (Map) scmap.get("hits"); + assert (hitsmap.containsKey("hits")); + ArrayList hitsarray = (ArrayList) hitsmap.get("hits"); + assert (hitsarray.size() == 1); + for (Map hit : hitsarray) { + assert (hit.containsKey("_id")); + assert (hit.get("_id").equals(id)); + } + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java index fb1fcbc1b1..294e3deac4 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchConversationsActionTests.java @@ -24,23 +24,19 @@ import static org.mockito.Mockito.verify; import java.util.List; -import java.util.Map; import org.junit.Before; import org.mockito.ArgumentCaptor; import org.opensearch.action.search.SearchRequest; import org.opensearch.client.node.NodeClient; -import org.opensearch.core.common.bytes.BytesArray; -import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; +import org.opensearch.ml.utils.TestHelper; import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.test.rest.FakeRestRequest; import com.google.gson.Gson; @@ -64,9 +60,7 @@ public void testBasics() { public void testPreprareRequest() throws Exception { RestMemorySearchConversationsAction action = new RestMemorySearchConversationsAction(); - RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - .withContent(new BytesArray(gson.toJson(Map.of("query", Map.of("match_all", Map.of())))), MediaTypeRegistry.JSON) - .build(); + RestRequest request = TestHelper.getSearchAllRestRequest(); NodeClient client = mock(NodeClient.class); RestChannel channel = mock(RestChannel.class); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java deleted file mode 100644 index 811cb40062..0000000000 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsAction.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2023 Aryn - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.opensearch.ml.rest; - -public class RestMemorySearchInteractionsAction { - -} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java index a345e08fac..9de93ac103 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionIT.java @@ -17,6 +17,133 @@ */ package org.opensearch.ml.rest; -public class RestMemorySearchInteractionsActionIT { +import static org.opensearch.ml.utils.TestData.matchAllSearchQuery; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Map; + +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemorySearchInteractionsActionIT extends MLCommonsRestTestCase { + + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testSearchInteractions_Successfull() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String cid = (String) ccmap.get("conversation_id"); + + Map params1 = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "fish metadata" + ); + Response ciresponse1 = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params1), + null + ); + assert (ciresponse1 != null); + assert (TestHelper.restStatus(ciresponse1) == RestStatus.OK); + HttpEntity cihttpEntity1 = ciresponse1.getEntity(); + String cientityString1 = TestHelper.httpEntityToString(cihttpEntity1); + Map cimap1 = gson.fromJson(cientityString1, Map.class); + assert (cimap1.containsKey("interaction_id")); + String iid1 = (String) cimap1.get("interaction_id"); + + Map params2 = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "france metadata" + ); + Response ciresponse2 = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params2), + null + ); + assert (ciresponse2 != null); + assert (TestHelper.restStatus(ciresponse2) == RestStatus.OK); + HttpEntity cihttpEntity2 = ciresponse2.getEntity(); + String cientityString2 = TestHelper.httpEntityToString(cihttpEntity2); + Map cimap2 = gson.fromJson(cientityString2, Map.class); + assert (cimap2.containsKey("interaction_id")); + String iid2 = (String) cimap2.get("interaction_id"); + + Response siresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.SEARCH_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), + null, + matchAllSearchQuery(), + null + ); + assert (siresponse != null); + assert (TestHelper.restStatus(siresponse) == RestStatus.OK); + HttpEntity sihttpEntity = siresponse.getEntity(); + String sientityString = TestHelper.httpEntityToString(sihttpEntity); + Map simap = gson.fromJson(sientityString, Map.class); + assert (simap.containsKey("hits")); + Map hitsmap = (Map) simap.get("hits"); + assert (hitsmap.containsKey("hits")); + ArrayList hitsarray = (ArrayList) hitsmap.get("hits"); + assert (hitsarray.size() == 2); + for (Map hit : hitsarray) { + assert (hit.containsKey("_id")); + assert (hit.get("_id").equals(iid1) || hit.get("_id").equals(iid2)); + } + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionTests.java new file mode 100644 index 0000000000..dea27b4c42 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemorySearchInteractionsActionTests.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.node.NodeClient; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsAction; +import org.opensearch.ml.memory.action.conversation.SearchInteractionsRequest; +import org.opensearch.ml.utils.TestHelper; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.RestResponse; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.gson.Gson; + +public class RestMemorySearchInteractionsActionTests extends OpenSearchTestCase { + + Gson gson; + + @Before + public void setup() { + gson = new Gson(); + } + + public void testBasics() { + RestMemorySearchInteractionsAction action = new RestMemorySearchInteractionsAction(); + assert (action.getName().equals("conversation_memory_search_interactions")); + List routes = action.routes(); + assert (routes.size() == 2); + assert (routes.get(0).equals(new Route(RestRequest.Method.POST, ActionConstants.SEARCH_INTERACTIONS_REST_PATH))); + assert (routes.get(1).equals(new Route(RestRequest.Method.GET, ActionConstants.SEARCH_INTERACTIONS_REST_PATH))); + } + + public void testPreprareRequest() throws Exception { + RestMemorySearchInteractionsAction action = new RestMemorySearchInteractionsAction(); + RestRequest request = TestHelper.getSearchAllRestRequest(); + request.params().put(ActionConstants.CONVERSATION_ID_FIELD, "test_cid"); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SearchInteractionsRequest.class); + verify(client, times(1)).execute(eq(SearchInteractionsAction.INSTANCE), argumentCaptor.capture(), any()); + assert (argumentCaptor.getValue().source().query() instanceof MatchAllQueryBuilder); + assert (argumentCaptor.getValue().getConversationId().equals("test_cid")); + } + + public void testSearchListener_TimeOut() throws Exception { + RestMemorySearchInteractionsAction action = new RestMemorySearchInteractionsAction(); + RestChannel channel = mock(RestChannel.class); + SearchResponse response = mock(SearchResponse.class); + doReturn(true).when(response).isTimedOut(); + doReturn("timed out").when(response).toString(); + RestResponse brr = action.search(channel).buildResponse(response); + assert (brr.status() == RestStatus.REQUEST_TIMEOUT); + } + + public void testSearchListener_Success() throws Exception { + RestMemorySearchInteractionsAction action = new RestMemorySearchInteractionsAction(); + RestChannel channel = mock(RestChannel.class); + SearchResponse response = mock(SearchResponse.class); + doReturn(false).when(response).isTimedOut(); + XContentBuilder builder = XContentFactory.jsonBuilder(); + doReturn(builder).when(channel).newBuilder(); + doReturn(builder).when(response).toXContent(any(), any()); + RestResponse brr = action.search(channel).buildResponse(response); + assert (brr.status() == RestStatus.OK); + } +} From 9c3a693583c95951b2c4a25cd239db6977370fdf Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 9 Oct 2023 16:22:07 -0700 Subject: [PATCH 09/14] Add singular get actions at storage layer Signed-off-by: HenryL27 --- .../memory/ConversationalMemoryHandler.java | 30 +++++ .../memory/index/ConversationMetaIndex.java | 53 ++++++++- .../ml/memory/index/InteractionsIndex.java | 52 +++++++++ ...OpenSearchConversationalMemoryHandler.java | 42 +++++++ .../index/ConversationMetaIndexITTests.java | 104 ++++++++++++++++++ .../index/ConversationMetaIndexTests.java | 78 +++++++++++++ .../index/InteractionsIndexITTests.java | 47 ++++++++ .../memory/index/InteractionsIndexTests.java | 97 ++++++++++++++++ 8 files changed, 501 insertions(+), 2 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java index 6d512418d7..42cece3f2e 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java @@ -203,4 +203,34 @@ public ActionFuture createInteraction( */ public ActionFuture searchInteractions(String conversationId, SearchRequest request); + /** + * Get a single ConversationMeta object + * @param conversationId id of the conversation to get + * @param listener receives the conversationMeta object + */ + public void getConversation(String conversationId, ActionListener listener); + + /** + * Get a single ConversationMeta object + * @param conversationId id of the conversation to get + * @return ActionFuture for the conversationMeta object + */ + public ActionFuture getConversation(String conversationId); + + /** + * Get a single interaction + * @param conversationId id of the conversation this interaction belongs to + * @param interactionId id of this interaction + * @param listener receives the interaction + */ + public void getInteraction(String conversationId, String interactionId, ActionListener listener); + + /** + * Get a single interaction + * @param conversationId id of the conversation this interaction belongs to + * @param interactionId id of this interaction + * @return ActionFuture for the interaction + */ + public ActionFuture getInteraction(String conversationId, String interactionId); + } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index ffd0d3225c..5480b8ad96 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -45,6 +45,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; @@ -277,8 +278,6 @@ public void checkAccess(String conversationId, ActionListener listener) if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) { throw new ResourceNotFoundException("Conversation [" + conversationId + "] not found"); } - log.info(userstr); - log.info(User.parse(userstr)); // If security is off - User doesn't exist - you have permission if (userstr == null || User.parse(userstr) == null) { internalListener.onResponse(true); @@ -334,4 +333,54 @@ public void searchConversations(SearchRequest request, ActionListener listener) { + if (!clusterService.state().metadata().hasIndex(indexName)) { + listener + .onFailure(new IndexNotFoundException("cannot get conversation since the conversation index does not exist", indexName)); + return; + } + String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + GetRequest request = Requests.getRequest(indexName).id(conversationId); + ActionListener al = ActionListener.wrap(getResponse -> { + // If the conversation doesn't exist, fail + if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) { + throw new ResourceNotFoundException("Conversation [" + conversationId + "] not found"); + } + ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap()); + // If no security, return conversation + if (userstr == null || User.parse(userstr) == null) { + internalListener.onResponse(conversation); + return; + } + // If security and correct user, return conversation + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + if (user.equals(conversation.getUser())) { + internalListener.onResponse(conversation); + return; + } + // Otherwise you don't have permission + internalListener + .onFailure( + new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId) + ); + }, e -> { internalListener.onFailure(e); }); + client + .admin() + .indices() + .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(request, al); }, e -> { + log.error("Failed to refresh conversations index during get conversation ", e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } + } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 1f273678f7..30e47019fd 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -25,10 +25,13 @@ import org.opensearch.OpenSearchSecurityException; import org.opensearch.OpenSearchWrapperException; import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.ResourceNotFoundException; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; @@ -41,6 +44,7 @@ import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.TermQueryBuilder; @@ -379,4 +383,52 @@ public void searchInteractions(String conversationId, SearchRequest request, Act } }, e -> { listener.onFailure(e); })); } + + /** + * Get a single interaction + * @param conversationId id of the conversation this interaction belongs to + * @param interactionId id of this interaction + * @param listener receives the interaction + */ + public void getInteraction(String conversationId, String interactionId, ActionListener listener) { + if (!clusterService.state().metadata().hasIndex(indexName)) { + listener.onFailure(new IndexNotFoundException("cannot get interaction since the interactions index does not exist", indexName)); + return; + } + conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { + if (access) { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + GetRequest request = Requests.getRequest(indexName).id(interactionId); + ActionListener al = ActionListener.wrap(getResponse -> { + // If the conversation doesn't exist, fail + if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { + throw new ResourceNotFoundException("Interaction [" + interactionId + "] not found"); + } + Interaction interaction = Interaction.fromMap(interactionId, getResponse.getSourceAsMap()); + internalListener.onResponse(interaction); + }, e -> { internalListener.onFailure(e); }); + client + .admin() + .indices() + .refresh( + Requests.refreshRequest(indexName), + ActionListener.wrap(refreshResponse -> { client.get(request, al); }, e -> { + log.error("Failed to refresh interactions index during get interaction ", e); + internalListener.onFailure(e); + }) + ); + } catch (Exception e) { + listener.onFailure(e); + } + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { listener.onFailure(e); })); + } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java index ba96fd00a9..c1997be829 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -330,4 +330,46 @@ public ActionFuture searchInteractions(String conversationId, Se return fut; } + /** + * Get a single ConversationMeta object + * @param conversationId id of the conversation to get + * @param listener receives the conversationMeta object + */ + public void getConversation(String conversationId, ActionListener listener) { + conversationMetaIndex.getConversation(conversationId, listener); + } + + /** + * Get a single ConversationMeta object + * @param conversationId id of the conversation to get + * @return ActionFuture for the conversationMeta object + */ + public ActionFuture getConversation(String conversationId) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + getConversation(conversationId, fut); + return fut; + } + + /** + * Get a single interaction + * @param conversationId id of the conversation this interaction belongs to + * @param interactionId id of this interaction + * @param listener receives the interaction + */ + public void getInteraction(String conversationId, String interactionId, ActionListener listener) { + interactionsIndex.getInteraction(conversationId, interactionId, listener); + } + + /** + * Get a single interaction + * @param conversationId id of the conversation this interaction belongs to + * @param interactionId id of this interaction + * @return ActionFuture for the interaction + */ + public ActionFuture getInteraction(String conversationId, String interactionId) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + getInteraction(conversationId, interactionId, fut); + return fut; + } + } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java index dc75e1245d..fc605e3fb0 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -526,4 +526,108 @@ public void testCanQueryOverConversationsSecurely() { } } + public void testCanGetAConversationById() { + CountDownLatch cdl = new CountDownLatch(1); + StepListener cid1 = new StepListener<>(); + index.createConversation("convo1", cid1); + + StepListener cid2 = new StepListener<>(); + cid1.whenComplete(cid -> { index.createConversation("convo2", cid2); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener get1 = new StepListener<>(); + cid2.whenComplete(cid -> { index.getConversation(cid1.result(), get1); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener get2 = new StepListener<>(); + get1.whenComplete(convo1 -> { index.getConversation(cid2.result(), get2); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + get2.whenComplete(convo2 -> { + assert (cid1.result().equals(get1.result().getId())); + assert (cid2.result().equals(get2.result().getId())); + assert (get1.result().getName().equals("convo1")); + assert (get2.result().getName().equals("convo2")); + cdl.countDown(); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + public void testCanGetAConversationByIdSecurely() { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + CountDownLatch cdl = new CountDownLatch(1); + Stack contextStack = new Stack<>(); + Consumer onFail = e -> { + while (!contextStack.empty()) { + contextStack.pop().close(); + } + cdl.countDown(); + log.error(e); + threadContext.restore(); + assert (false); + }; + + final String user1 = "Austin"; + final String user2 = "Yaliang"; + contextStack.push(setUser(user1)); + + StepListener cid1 = new StepListener<>(); + index.createConversation("Austin Convo", cid1); + + StepListener cid2 = new StepListener<>(); + cid1.whenComplete(cid -> { + contextStack.push(setUser(user2)); + index.createConversation("Yaliang Convo", cid2); + }, onFail); + + StepListener get2 = new StepListener<>(); + cid2.whenComplete(cid -> { index.getConversation(cid2.result(), get2); }, onFail); + + StepListener get1 = new StepListener<>(); + get2.whenComplete(convo -> { index.getConversation(cid1.result(), get1); }, onFail); + + get1.whenComplete(convo -> { + while (!contextStack.isEmpty()) { + contextStack.pop().close(); + } + cdl.countDown(); + assert (false); + }, e -> { + cdl.countDown(); + assert (e.getMessage().startsWith("User [Yaliang] does not have access to conversation")); + assert (get2.result().getName().equals("Yaliang Convo")); + assert (get2.result().getId().equals(cid2.result())); + }); + + try { + cdl.await(); + threadContext.restore(); + } catch (InterruptedException e) { + log.error(e); + threadContext.restore(); + } + + } catch (Exception e) { + log.error(e); + } + } + } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index 799ff3b922..5445fd6213 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -558,4 +558,82 @@ public void testSearchConversations_ClientFails_ThenFail() { verify(accessListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("Client Test Fail")); } + + public void testGetConversation_NoIndex_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + conversationMetaIndex.getConversation("tester_id", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor + .getValue() + .getMessage() + .equals( + "no such index [.plugins-ml-conversation-meta] and cannot get conversation since the conversation index does not exist" + )); + } + + public void testGetConversation_ResponseNotExist_ThenFail() { + setupRefreshSuccess(); + doReturn(true).when(metadata).hasIndex(anyString()); + GetResponse response = mock(GetResponse.class); + doReturn(false).when(response).isExists(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(response); + return null; + }).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + conversationMetaIndex.getConversation("tester_id", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Conversation [tester_id] not found")); + } + + public void testGetConversation_WrongId_ThenFail() { + setupRefreshSuccess(); + doReturn(true).when(metadata).hasIndex(anyString()); + GetResponse response = mock(GetResponse.class); + doReturn(true).when(response).isExists(); + doReturn("wrong id").when(response).getId(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(response); + return null; + }).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + conversationMetaIndex.getConversation("tester_id", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Conversation [tester_id] not found")); + } + + public void testGetConversation_RefreshFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Refresh Exception")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + conversationMetaIndex.getConversation("tester_id", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Refresh Exception")); + } + + public void testGetConversation_ClientFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Clietn Failure")).when(client).admin(); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + conversationMetaIndex.getConversation("tester_id", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Clietn Failure")); + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java index f089a48a75..0c0791fb23 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java @@ -461,4 +461,51 @@ public void testSearchInteractions() { log.error(e); } } + + public void testGetInteractionById() { + final String conversation = "test-conversation"; + CountDownLatch cdl = new CountDownLatch(1); + StepListener iid1 = new StepListener<>(); + index.createInteraction(conversation, "test input", "pt", "test response", "test origin", "metadata", iid1); + + StepListener iid2 = new StepListener<>(); + iid1 + .whenComplete( + iid -> { index.createInteraction(conversation, "test input2", "pt", "test response", "test origin", "metadata", iid2); }, + e -> { + cdl.countDown(); + log.error(e); + assert false; + } + ); + + StepListener get1 = new StepListener<>(); + iid2.whenComplete(iid -> { index.getInteraction(conversation, iid1.result(), get1); }, e -> { + cdl.countDown(); + log.error(e); + }); + + StepListener get2 = new StepListener<>(); + get1.whenComplete(interaction1 -> { index.getInteraction(conversation, iid2.result(), get2); }, e -> { + cdl.countDown(); + log.error(e); + }); + + get2.whenComplete(interaction2 -> { + assert (get1.result().getId().equals(iid1.result())); + assert (get1.result().getInput().equals("test input")); + assert (get2.result().getId().equals(iid2.result())); + assert (get2.result().getInput().equals("test input2")); + cdl.countDown(); + }, e -> { + cdl.countDown(); + log.error(e); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 69104c70fd..54ec44df77 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -40,6 +40,7 @@ import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.admin.indices.refresh.RefreshResponse; import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.get.GetResponse; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -62,6 +63,9 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.SendRequestTransportException; +import lombok.extern.log4j.Log4j2; + +@Log4j2 public class InteractionsIndexTests extends OpenSearchTestCase { @Mock Client client; @@ -635,4 +639,97 @@ public void testSearch_NoAccess_ThenFail() { verify(searchInteractionsListener, times(1)).onFailure(argCaptor.capture()); assert (argCaptor.getValue().getMessage().equals("User [user] does not have access to conversation test_cid")); } + + public void testGetSg_NoIndex_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor + .getValue() + .getMessage() + .equals( + "no such index [.plugins-ml-conversation-interactions] and cannot get interaction since the interactions index does not exist" + )); + } + + public void testGetSg_InteractionNotExist_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + setupRefreshSuccess(); + GetResponse response = mock(GetResponse.class); + doReturn(false).when(response).isExists(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Interaction [iid] not found")); + } + + public void testGetSg_WrongId_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + setupRefreshSuccess(); + GetResponse response = mock(GetResponse.class); + doReturn(true).when(response).isExists(); + doReturn("wrong id").when(response).getId(); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Interaction [iid] not found")); + } + + public void testGetSg_RefreshFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Failed during Sg Get Refresh")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failed during Sg Get Refresh")); + } + + public void testGetSg_ClientFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + doThrow(new RuntimeException("Client Failure in Sg Get")).when(client).admin(); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Failure in Sg Get")); + } + + public void testGetSg_NoAccess_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupDenyAccess("Henry"); + @SuppressWarnings("unchecked") + ActionListener getListener = mock(ActionListener.class); + interactionsIndex.getInteraction("cid", "iid", getListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [Henry] does not have access to conversation cid")); + } } From a6305503e9e10d1846edbbb72a590b54625ed75c Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Mon, 9 Oct 2023 16:47:36 -0700 Subject: [PATCH 10/14] Add OpenSearhMemoryHandler unit tests for singular get Signed-off-by: HenryL27 --- .../memory/index/InteractionsIndexTests.java | 3 --- ...earchConversationalMemoryHandlerTests.java | 23 +++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java index 54ec44df77..2d4184eec3 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -63,9 +63,6 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.SendRequestTransportException; -import lombok.extern.log4j.Log4j2; - -@Log4j2 public class InteractionsIndexTests extends OpenSearchTestCase { @Mock Client client; diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java index 175949b0ef..c8df948bcb 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import java.time.Instant; import java.util.List; import org.junit.Before; @@ -268,4 +269,26 @@ public void testSearchInteractions_Future() { ActionFuture result = cmHandler.searchInteractions(cid, request); assert (result.actionGet().equals(response)); } + + public void testGetAConversation_Future() { + ConversationMeta response = new ConversationMeta("cid", Instant.now(), "boring name", null); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(response); + return null; + }).when(conversationMetaIndex).getConversation(any(), any()); + ActionFuture result = cmHandler.getConversation("cid"); + assert (result.actionGet().equals(response)); + } + + public void testGetAnInteraction_Future() { + Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(interaction); + return null; + }).when(interactionsIndex).getInteraction(any(), any(), any()); + ActionFuture result = cmHandler.getInteraction("cid", "iid"); + assert (result.actionGet().equals(interaction)); + } } From 3f7404cbfd7d87346a026dae0e36e95feae6baba Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 10 Oct 2023 14:26:40 -0700 Subject: [PATCH 11/14] Add singular get transport layer Signed-off-by: HenryL27 --- .../conversation/GetConversationAction.java | 34 ++++ .../conversation/GetConversationRequest.java | 77 +++++++++ .../conversation/GetConversationResponse.java | 60 +++++++ .../GetConversationTransportAction.java | 100 +++++++++++ .../conversation/GetInteractionAction.java | 34 ++++ .../conversation/GetInteractionRequest.java | 85 ++++++++++ .../conversation/GetInteractionResponse.java | 61 +++++++ .../GetInteractionTransportAction.java | 95 +++++++++++ .../conversation/ConversationActionTests.java | 1 + .../GetConversationRequestTests.java | 68 ++++++++ .../GetConversationResponseTests.java | 62 +++++++ .../GetConversationTransportActionTests.java | 150 +++++++++++++++++ .../GetInteractionRequestTests.java | 85 ++++++++++ .../GetInteractionResponseTests.java | 64 +++++++ .../GetInteractionTransportActionTests.java | 158 ++++++++++++++++++ .../conversation/InteractionActionTests.java | 1 + 16 files changed, 1135 insertions(+) create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java create mode 100644 memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java create mode 100644 memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java new file mode 100644 index 0000000000..7839915201 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationAction.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action for retrieving a top-level conversation object by id + */ +public class GetConversationAction extends ActionType { + /** Instance of this */ + public static final GetConversationAction INSTANCE = new GetConversationAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/get"; + + private GetConversationAction() { + super(NAME, GetConversationResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java new file mode 100644 index 0000000000..c5a6f6dd0e --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationRequest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Action Request object for GetConversation (singular) + */ +@AllArgsConstructor +public class GetConversationRequest extends ActionRequest { + @Getter + private String conversationId; + + /** + * Stream Constructor + * @param in input stream to read this from + * @throws IOException if something goes wrong reading from stream + */ + public GetConversationRequest(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.conversationId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (this.conversationId == null) { + exception = addValidationError("GetConversation Request must have a conversation id", exception); + } + return exception; + } + + /** + * Creates a GetConversationRequest from a rest request + * @param request Rest Request representing a GetConversationRequest + * @return the new GetConversationRequest + * @throws IOException if something goes wrong in translation + */ + public static GetConversationRequest fromRestRequest(RestRequest request) throws IOException { + String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); + return new GetConversationRequest(conversationId); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java new file mode 100644 index 0000000000..b757723e09 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationResponse.java @@ -0,0 +1,60 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ConversationMeta; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * ActionResponse object for GetConversation (singular) + */ +@AllArgsConstructor +public class GetConversationResponse extends ActionResponse implements ToXContentObject { + + @Getter + private ConversationMeta conversation; + + /** + * Stream Constructor + * @param in input stream to read this from + * @throws IOException if soething goes wrong in reading + */ + public GetConversationResponse(StreamInput in) throws IOException { + super(in); + this.conversation = ConversationMeta.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + this.conversation.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return this.conversation.toXContent(builder, params); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java new file mode 100644 index 0000000000..0f1c70ad51 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportAction.java @@ -0,0 +1,100 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +/** + * Transport Action for GetConversation + */ +@Log4j2 +public class GetConversationTransportAction extends HandledTransportAction { + private Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public GetConversationTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(GetConversationAction.NAME, transportService, actionFilters, GetConversationRequest::new); + this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, GetConversationRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } else { + String conversationId = request.getConversationId(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener + .runBefore(actionListener, () -> context.restore()); + ActionListener al = ActionListener.wrap(conversationMeta -> { + internalListener.onResponse(new GetConversationResponse(conversationMeta)); + }, e -> { internalListener.onFailure(e); }); + cmHandler.getConversation(conversationId, al); + } catch (Exception e) { + log.error("Failed to get Conversation " + conversationId, e); + actionListener.onFailure(e); + } + + } + + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java new file mode 100644 index 0000000000..adaffcb9dc --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionAction.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action for Get Interaction (singular) + */ +public class GetInteractionAction extends ActionType { + /** Instance of this */ + public static final GetInteractionAction INSTANCE = new GetInteractionAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/get"; + + private GetInteractionAction() { + super(NAME, GetInteractionResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java new file mode 100644 index 0000000000..6808857c40 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequest.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Action Request for GetInteraction + */ +@AllArgsConstructor +public class GetInteractionRequest extends ActionRequest { + @Getter + private String conversationId; + @Getter + private String interactionId; + + /** + * Stream Constructor + * @param in input stream to read this request from + * @throws IOException if somthing goes wrong reading + */ + public GetInteractionRequest(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + this.interactionId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(this.conversationId); + out.writeString(this.interactionId); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (conversationId == null) { + exception = addValidationError("Get Interaction Request must have a conversation id", exception); + } + if (interactionId == null) { + exception = addValidationError("Get Interaction Request must have an interaction id", exception); + } + return exception; + } + + /** + * Creates a GetInteractionRequest from a Rest Request + * @param request Rest Request representing a GetInteractionRequest + * @return new GetInteractionRequest built from the rest request + * @throws IOException if something goes wrong reading from the rest request + */ + public static GetInteractionRequest fromRestRequest(RestRequest request) throws IOException { + String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD); + String interactionId = request.param(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); + return new GetInteractionRequest(conversationId, interactionId); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java new file mode 100644 index 0000000000..7d3a1f3c73 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponse.java @@ -0,0 +1,61 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * ActionResponse for Get Interaction (sg) + */ +@AllArgsConstructor +public class GetInteractionResponse extends ActionResponse implements ToXContentObject { + + @Getter + private Interaction interaction; + + /** + * Stream Constructor + * @param in Stream Input to read this response from + * @throws IOException if something goes wrong reading from stream + */ + public GetInteractionResponse(StreamInput in) throws IOException { + super(in); + this.interaction = Interaction.fromStream(in); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + interaction.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return this.interaction.toXContent(builder, params); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java new file mode 100644 index 0000000000..16205ec8b9 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportAction.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetInteractionTransportAction extends HandledTransportAction { + + private Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public GetInteractionTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(GetInteractionAction.NAME, transportService, actionFilters, GetInteractionRequest::new); + this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, GetInteractionRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } + String conversationId = request.getConversationId(); + String interactionId = request.getInteractionId(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener al = ActionListener.wrap(interaction -> { + internalListener.onResponse(new GetInteractionResponse(interaction)); + }, e -> { internalListener.onFailure(e); }); + cmHandler.getInteraction(conversationId, interactionId, al); + } catch (Exception e) { + log.error("Failed to get interaction " + interactionId + " in conversation " + conversationId, e); + actionListener.onFailure(e); + } + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java index 65cbb7dfea..541ff5ed2e 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java @@ -25,5 +25,6 @@ public void testActions() { assert (DeleteConversationAction.INSTANCE instanceof DeleteConversationAction); assert (GetConversationsAction.INSTANCE instanceof GetConversationsAction); assert (SearchConversationsAction.INSTANCE instanceof SearchConversationsAction); + assert (GetConversationAction.INSTANCE instanceof GetConversationAction); } } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java new file mode 100644 index 0000000000..cb8b67b44b --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationRequestTests.java @@ -0,0 +1,68 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class GetConversationRequestTests extends OpenSearchTestCase { + + public void testConstructorAndStreaming() throws IOException { + GetConversationRequest request = new GetConversationRequest("Test-id"); + assert (request.validate() == null); + assert (request.getConversationId().equals("Test-id")); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetConversationRequest newRequest = new GetConversationRequest(in); + assert (newRequest.validate() == null); + assert (newRequest.getConversationId().equals("Test-id")); + } + + public void testNullConvoId_ThenFail() { + String id = null; + GetConversationRequest request = new GetConversationRequest(id); + ActionRequestValidationException exc = request.validate(); + assert (exc != null); + assert (exc.validationErrors().size() == 1); + assert (exc.validationErrors().get(0).equals("GetConversation Request must have a conversation id")); + } + + public void testFromRestRequest() throws IOException { + Map params = Map.of(ActionConstants.CONVERSATION_ID_FIELD, "testcid"); + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + GetConversationRequest request = GetConversationRequest.fromRestRequest(rrequest); + assert (request.validate() == null); + assert (request.getConversationId().equals("testcid")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java new file mode 100644 index 0000000000..4b8f3a8fed --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationResponseTests.java @@ -0,0 +1,62 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.time.Instant; + +import org.apache.lucene.search.spell.LevenshteinDistance; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.test.OpenSearchTestCase; + +public class GetConversationResponseTests extends OpenSearchTestCase { + + public void testGetConversationResponseStreaming() throws IOException { + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), "name", null); + GetConversationResponse response = new GetConversationResponse(convo); + assert (response.getConversation().equals(convo)); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetConversationResponse newResponse = new GetConversationResponse(in); + assert (newResponse.getConversation().equals(convo)); + } + + public void testToXContent() throws IOException { + ConversationMeta convo = new ConversationMeta("cid", Instant.now(), "name", null); + GetConversationResponse response = new GetConversationResponse(convo); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + String expected = "{\"conversation_id\":\"cid\",\"create_time\":\"" + convo.getCreatedTime() + "\",\"name\":\"name\"}"; + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + assert (ld.getDistance(result, expected) > 0.95); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java new file mode 100644 index 0000000000..3afcc1dd21 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationTransportActionTests.java @@ -0,0 +1,150 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class GetConversationTransportActionTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + GetConversationRequest request; + GetConversationTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + this.threadPool = Mockito.mock(ThreadPool.class); + this.client = Mockito.mock(Client.class); + this.clusterService = Mockito.mock(ClusterService.class); + this.xContentRegistry = Mockito.mock(NamedXContentRegistry.class); + this.transportService = Mockito.mock(TransportService.class); + this.actionFilters = Mockito.mock(ActionFilters.class); + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new GetConversationRequest("test-cid"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new GetConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testGetConversation() { + ConversationMeta result = new ConversationMeta("test-cid", Instant.now(), "name", null); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(result); + return null; + }).when(cmHandler).getConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetConversationResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getConversation().getId().equals("test-cid")); + } + + public void testGetConversationFails_ThenFail() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new Exception("CMHandler Failure")); + return null; + }).when(cmHandler).getConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("CMHandler Failure")); + } + + public void testHandlerThrows_ThenFail() { + doThrow(new RuntimeException("CMHandler Throws")).when(cmHandler).getConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("CMHandler Throws")); + } + + public void testFeatureDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new GetConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java new file mode 100644 index 0000000000..678004ae09 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionRequestTests.java @@ -0,0 +1,85 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class GetInteractionRequestTests extends OpenSearchTestCase { + + public void testConstructorAndStreaming() throws IOException { + GetInteractionRequest request = new GetInteractionRequest("cid", "iid"); + assert (request.validate() == null); + assert (request.getConversationId().equals("cid")); + assert (request.getInteractionId().equals("iid")); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetInteractionRequest newRequest = new GetInteractionRequest(in); + assert (newRequest.validate() == null); + assert (newRequest.getConversationId().equals("cid")); + assert (newRequest.getInteractionId().equals("iid")); + } + + public void testMalformedRequest_ThenInvalid() { + GetInteractionRequest bad1 = new GetInteractionRequest(null, "iid"); + GetInteractionRequest bad2 = new GetInteractionRequest("cid", null); + GetInteractionRequest bad3 = new GetInteractionRequest(null, null); + ActionRequestValidationException exc1 = bad1.validate(); + ActionRequestValidationException exc2 = bad2.validate(); + ActionRequestValidationException exc3 = bad3.validate(); + + assert (exc1 != null); + assert (exc1.validationErrors().size() == 1); + assert (exc1.validationErrors().get(0).equals("Get Interaction Request must have a conversation id")); + + assert (exc2 != null); + assert (exc2.validationErrors().size() == 1); + assert (exc2.validationErrors().get(0).equals("Get Interaction Request must have an interaction id")); + + assert (exc3 != null); + assert (exc3.validationErrors().size() == 2); + assert (exc3.validationErrors().get(0).equals("Get Interaction Request must have a conversation id")); + assert (exc3.validationErrors().get(1).equals("Get Interaction Request must have an interaction id")); + } + + public void testFromRestRequest() throws IOException { + Map params = Map + .of(ActionConstants.CONVERSATION_ID_FIELD, "testcid", ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "testiid"); + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + GetInteractionRequest request = GetInteractionRequest.fromRestRequest(rrequest); + assert (request.validate() == null); + assert (request.getConversationId().equals("testcid")); + assert (request.getInteractionId().equals("testiid")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java new file mode 100644 index 0000000000..b7cbc1c471 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionResponseTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.time.Instant; + +import org.apache.lucene.search.spell.LevenshteinDistance; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchTestCase; + +public class GetInteractionResponseTests extends OpenSearchTestCase { + + public void testConstructorAndStreaming() throws IOException { + Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra"); + GetInteractionResponse response = new GetInteractionResponse(interaction); + assert (response.getInteraction().equals(interaction)); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetInteractionResponse newResponse = new GetInteractionResponse(in); + assert (newResponse.getInteraction().equals(interaction)); + } + + public void testToXContent() throws IOException { + Interaction interaction = new Interaction("iid", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "extra"); + GetInteractionResponse response = new GetInteractionResponse(interaction); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + String expected = "{\"conversation_id\":\"cid\",\"interaction_id\":\"iid\",\"create_time\":\"" + + interaction.getCreateTime() + + "\",\"input\":\"inp\",\"prompt_template\":\"pt\",\"response\":\"rsp\",\"origin\":\"ogn\",\"additional_info\":\"extra\"}"; + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + assert (ld.getDistance(result, expected) > 0.95); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java new file mode 100644 index 0000000000..6ca8197b54 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionTransportActionTests.java @@ -0,0 +1,158 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class GetInteractionTransportActionTests extends OpenSearchTestCase { + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + GetInteractionRequest request; + GetInteractionTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + this.threadPool = Mockito.mock(ThreadPool.class); + this.client = Mockito.mock(Client.class); + this.clusterService = Mockito.mock(ClusterService.class); + this.xContentRegistry = Mockito.mock(NamedXContentRegistry.class); + this.transportService = Mockito.mock(TransportService.class); + this.actionFilters = Mockito.mock(ActionFilters.class); + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new GetInteractionRequest("cid", "iid"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new GetInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testGetInteraction() { + Interaction testInteraction = new Interaction( + "iid", + Instant.now(), + "cid", + "test-input", + "pt", + "test-response", + "test-origin", + "metadata" + ); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(testInteraction); + return null; + }).when(cmHandler).getInteraction(any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionResponse.class); + verify(actionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getInteraction().getId().equals("iid")); + } + + public void testGetInteractionFails_ThenFail() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new Exception("Storage layer failure")); + return null; + }).when(cmHandler).getInteraction(any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Storage layer failure")); + } + + public void testHandlerThrows_ThenFail() { + doThrow(new RuntimeException("CMHandler Failure")).when(cmHandler).getInteraction(any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("CMHandler Failure")); + } + + public void testFeatureDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new GetInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java index 89a6fae6a3..187ae4bdf7 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java @@ -24,5 +24,6 @@ public void testActions() { assert (CreateInteractionAction.INSTANCE instanceof CreateInteractionAction); assert (GetInteractionsAction.INSTANCE instanceof GetInteractionsAction); assert (SearchInteractionsAction.INSTANCE instanceof SearchInteractionsAction); + assert (GetInteractionAction.INSTANCE instanceof GetInteractionAction); } } From a0c49c28ea6a82f815360c7b4e09ced16b7af507 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Wed, 11 Oct 2023 09:41:59 -0700 Subject: [PATCH 12/14] add singular get rest actions Signed-off-by: HenryL27 --- .../common/conversation/ActionConstants.java | 8 +- .../conversation/GetInteractionsAction.java | 2 +- .../ml/plugin/MachineLearningPlugin.java | 16 ++- .../rest/RestMemoryGetConversationAction.java | 51 +++++++ .../rest/RestMemoryGetInteractionAction.java | 51 +++++++ .../RestMemoryGetConversationActionIT.java | 81 +++++++++++ .../RestMemoryGetConversationActionTests.java | 64 +++++++++ .../RestMemoryGetInteractionActionIT.java | 127 ++++++++++++++++++ .../RestMemoryGetInteractionActionTests.java | 65 +++++++++ 9 files changed, 460 insertions(+), 5 deletions(-) create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationAction.java create mode 100644 plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionAction.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionTests.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java create mode 100644 plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java index 96431e7679..f87da7c433 100644 --- a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -54,9 +54,9 @@ public class ActionConstants { private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation"; /** path for create conversation */ public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create"; - /** path for list conversations */ + /** path for get conversations */ public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list"; - /** path for put interaction */ + /** path for create interaction */ public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create"; /** path for get interactions */ public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list"; @@ -66,6 +66,10 @@ public class ActionConstants { public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search"; /** path for search interactions */ public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search"; + /** path for get conversation */ + public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}"; + /** path for get interaction */ + public final static String GET_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/{interaction_id}"; /** default max results returned by get operations */ public final static int DEFAULT_MAX_RESULTS = 10; diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java index 024abe17ff..7a49d062d5 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java @@ -26,7 +26,7 @@ public class GetInteractionsAction extends ActionType { /** Instance of this */ public static final GetInteractionsAction INSTANCE = new GetInteractionsAction(); /** Name of this action */ - public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/get"; + public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/list"; private GetInteractionsAction() { super(NAME, GetInteractionsResponse::new); diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 8f23c84f4e..13b4834d59 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -139,8 +139,12 @@ import org.opensearch.ml.memory.action.conversation.CreateInteractionTransportAction; import org.opensearch.ml.memory.action.conversation.DeleteConversationAction; import org.opensearch.ml.memory.action.conversation.DeleteConversationTransportAction; +import org.opensearch.ml.memory.action.conversation.GetConversationAction; +import org.opensearch.ml.memory.action.conversation.GetConversationTransportAction; import org.opensearch.ml.memory.action.conversation.GetConversationsAction; import org.opensearch.ml.memory.action.conversation.GetConversationsTransportAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionTransportAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; import org.opensearch.ml.memory.action.conversation.GetInteractionsTransportAction; import org.opensearch.ml.memory.action.conversation.SearchConversationsAction; @@ -181,7 +185,9 @@ import org.opensearch.ml.rest.RestMemoryCreateConversationAction; import org.opensearch.ml.rest.RestMemoryCreateInteractionAction; import org.opensearch.ml.rest.RestMemoryDeleteConversationAction; +import org.opensearch.ml.rest.RestMemoryGetConversationAction; import org.opensearch.ml.rest.RestMemoryGetConversationsAction; +import org.opensearch.ml.rest.RestMemoryGetInteractionAction; import org.opensearch.ml.rest.RestMemoryGetInteractionsAction; import org.opensearch.ml.rest.RestMemorySearchConversationsAction; import org.opensearch.ml.rest.RestMemorySearchInteractionsAction; @@ -313,7 +319,9 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc new ActionHandler<>(DeleteConversationAction.INSTANCE, DeleteConversationTransportAction.class), new ActionHandler<>(MLUpdateConnectorAction.INSTANCE, UpdateConnectorTransportAction.class), new ActionHandler<>(SearchInteractionsAction.INSTANCE, SearchInteractionsTransportAction.class), - new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class) + new ActionHandler<>(SearchConversationsAction.INSTANCE, SearchConversationsTransportAction.class), + new ActionHandler<>(GetConversationAction.INSTANCE, GetConversationTransportAction.class), + new ActionHandler<>(GetInteractionAction.INSTANCE, GetInteractionTransportAction.class) ); } @@ -567,6 +575,8 @@ public List getRestHandlers( RestMLUpdateConnectorAction restMLUpdateConnectorAction = new RestMLUpdateConnectorAction(mlFeatureEnabledSetting); RestMemorySearchConversationsAction restSearchConversationsAction = new RestMemorySearchConversationsAction(); RestMemorySearchInteractionsAction restSearchInteractionsAction = new RestMemorySearchInteractionsAction(); + RestMemoryGetConversationAction restGetConversationAction = new RestMemoryGetConversationAction(); + RestMemoryGetInteractionAction restGetInteractionAction = new RestMemoryGetInteractionAction(); return ImmutableList .of( restMLStatsAction, @@ -603,7 +613,9 @@ public List getRestHandlers( restDeleteConversationAction, restMLUpdateConnectorAction, restSearchConversationsAction, - restSearchInteractionsAction + restSearchInteractionsAction, + restGetConversationAction, + restGetInteractionAction ); } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationAction.java new file mode 100644 index 0000000000..dbabd40953 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationAction.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetConversationAction; +import org.opensearch.ml.memory.action.conversation.GetConversationRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetConversationAction extends BaseRestHandler { + private final static String GET_CONVERSATION_NAME = "conversational_get_conversation"; + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, ActionConstants.GET_CONVERSATION_REST_PATH)); + } + + @Override + public String getName() { + return GET_CONVERSATION_NAME; + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + GetConversationRequest gcRequest = GetConversationRequest.fromRestRequest(request); + return channel -> client.execute(GetConversationAction.INSTANCE, gcRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionAction.java new file mode 100644 index 0000000000..ad2b35dbf6 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionAction.java @@ -0,0 +1,51 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetInteractionAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetInteractionAction extends BaseRestHandler { + private final static String GET_INTERACTION_NAME = "conversational_get_interaction"; + + @Override + public List routes() { + return ImmutableList.of(new Route(RestRequest.Method.GET, ActionConstants.GET_INTERACTION_REST_PATH)); + } + + @Override + public String getName() { + return GET_INTERACTION_NAME; + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + GetInteractionRequest giRequest = GetInteractionRequest.fromRestRequest(request); + return channel -> client.execute(GetInteractionAction.INSTANCE, giRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java new file mode 100644 index 0000000000..5a55b1c301 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionIT.java @@ -0,0 +1,81 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.Map; + +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetConversationActionIT extends MLCommonsRestTestCase { + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testGetConversation() throws IOException { + Response ccresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_CONVERSATION_REST_PATH, + null, + gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "name")), + null + ); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + @SuppressWarnings("unchecked") + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD)); + String id = (String) ccmap.get(ActionConstants.CONVERSATION_ID_FIELD); + + Response gcresponse = TestHelper + .makeRequest(client(), "GET", ActionConstants.GET_CONVERSATION_REST_PATH.replace("{conversation_id}", id), null, "", null); + assert (gcresponse != null); + assert (TestHelper.restStatus(gcresponse) == RestStatus.OK); + HttpEntity gchttpEntity = gcresponse.getEntity(); + String gcentitiyString = TestHelper.httpEntityToString(gchttpEntity); + @SuppressWarnings("unchecked") + Map gcmap = gson.fromJson(gcentitiyString, Map.class); + assert (gcmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD) && gcmap.get(ActionConstants.CONVERSATION_ID_FIELD).equals(id)); + assert (gcmap.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD) + && gcmap.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD).equals("name")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionTests.java new file mode 100644 index 0000000000..0e81f2dacb --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationActionTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetConversationAction; +import org.opensearch.ml.memory.action.conversation.GetConversationRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryGetConversationActionTests extends OpenSearchTestCase { + public void testBasics() { + RestMemoryGetConversationAction action = new RestMemoryGetConversationAction(); + assert (action.getName().equals("conversational_get_conversation")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new Route(RestRequest.Method.GET, ActionConstants.GET_CONVERSATION_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryGetConversationAction action = new RestMemoryGetConversationAction(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid")) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetConversationRequest.class); + verify(client, times(1)).execute(eq(GetConversationAction.INSTANCE), argCaptor.capture(), any()); + assert (argCaptor.getValue().getConversationId().equals("cid")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java new file mode 100644 index 0000000000..691195a99b --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionIT.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.Map; + +import org.apache.hc.core5.http.HttpEntity; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetInteractionActionIT extends MLCommonsRestTestCase { + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testGetInteraction() throws IOException { + Response ccresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_CONVERSATION_REST_PATH, + null, + gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "name")), + null + ); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + @SuppressWarnings("unchecked") + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey(ActionConstants.CONVERSATION_ID_FIELD)); + String cid = (String) ccmap.get(ActionConstants.CONVERSATION_ID_FIELD); + + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "some metadata" + ); + Response ciresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params), + null + ); + assert (ciresponse != null); + assert (TestHelper.restStatus(ciresponse) == RestStatus.OK); + HttpEntity cihttpEntity = ciresponse.getEntity(); + String cientityString = TestHelper.httpEntityToString(cihttpEntity); + @SuppressWarnings("unchecked") + Map cimap = gson.fromJson(cientityString, Map.class); + assert (cimap.containsKey(ActionConstants.RESPONSE_INTERACTION_ID_FIELD)); + String iid = cimap.get(ActionConstants.RESPONSE_INTERACTION_ID_FIELD); + + Response giresponse = TestHelper + .makeRequest( + client(), + "GET", + ActionConstants.GET_INTERACTION_REST_PATH.replace("{conversation_id}", cid).replace("{interaction_id}", iid), + null, + "", + null + ); + assert (giresponse != null); + assert (TestHelper.restStatus(giresponse) == RestStatus.OK); + HttpEntity gihttpEntity = giresponse.getEntity(); + String gientityString = TestHelper.httpEntityToString(gihttpEntity); + @SuppressWarnings("unchecked") + Map gimap = gson.fromJson(gientityString, Map.class); + assert (gimap.containsKey(ActionConstants.RESPONSE_INTERACTION_ID_FIELD) + && gimap.get(ActionConstants.RESPONSE_INTERACTION_ID_FIELD).equals(iid)); + assert (gimap.containsKey(ActionConstants.CONVERSATION_ID_FIELD) && gimap.get(ActionConstants.CONVERSATION_ID_FIELD).equals(cid)); + assert (gimap.containsKey(ActionConstants.INPUT_FIELD) && gimap.get(ActionConstants.INPUT_FIELD).equals("input")); + assert (gimap.containsKey(ActionConstants.PROMPT_TEMPLATE_FIELD) + && gimap.get(ActionConstants.PROMPT_TEMPLATE_FIELD).equals("promtp template")); + assert (gimap.containsKey(ActionConstants.AI_RESPONSE_FIELD) && gimap.get(ActionConstants.AI_RESPONSE_FIELD).equals("response")); + assert (gimap.containsKey(ActionConstants.RESPONSE_ORIGIN_FIELD) + && gimap.get(ActionConstants.RESPONSE_ORIGIN_FIELD).equals("origin")); + assert (gimap.containsKey(ActionConstants.ADDITIONAL_INFO_FIELD) + && gimap.get(ActionConstants.ADDITIONAL_INFO_FIELD).equals("some metadata")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java new file mode 100644 index 0000000000..9d0cc6515b --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionActionTests.java @@ -0,0 +1,65 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetInteractionAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryGetInteractionActionTests extends OpenSearchTestCase { + public void testBasics() { + RestMemoryGetInteractionAction action = new RestMemoryGetInteractionAction(); + assert (action.getName().equals("conversational_get_interaction")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new Route(RestRequest.Method.GET, ActionConstants.GET_INTERACTION_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryGetInteractionAction action = new RestMemoryGetInteractionAction(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid", ActionConstants.RESPONSE_INTERACTION_ID_FIELD, "iid")) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionRequest.class); + verify(client, times(1)).execute(eq(GetInteractionAction.INSTANCE), argCaptor.capture(), any()); + assert (argCaptor.getValue().getConversationId().equals("cid")); + assert (argCaptor.getValue().getInteractionId().equals("iid")); + } +} From 9b4c461824c00acf6dd2510829b7cd88e68a8e48 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 12 Oct 2023 09:41:41 -0700 Subject: [PATCH 13/14] fix async return value problem Signed-off-by: HenryL27 --- .../org/opensearch/ml/memory/index/ConversationMetaIndex.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index 5480b8ad96..a75e2aa0c0 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -173,6 +173,7 @@ public void createConversation(ActionListener listener) { public void getConversations(int from, int maxResults, ActionListener> listener) { if (!clusterService.state().metadata().hasIndex(indexName)) { listener.onResponse(List.of()); + return; } SearchRequest request = Requests.searchRequest(indexName); String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); @@ -226,6 +227,7 @@ public void getConversations(int maxResults, ActionListener listener) { if (!clusterService.state().metadata().hasIndex(indexName)) { listener.onResponse(true); + return; } DeleteRequest delRequest = Requests.deleteRequest(indexName).id(conversationId); String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); From 1f4bfdb6a012fac0ae55ba8d979025fe74ebc881 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 17 Oct 2023 15:24:08 -0700 Subject: [PATCH 14/14] address esay PR comments Signed-off-by: HenryL27 --- .../SearchConversationsTransportAction.java | 16 ++- .../SearchInteractionsTransportAction.java | 16 ++- .../memory/index/ConversationMetaIndex.java | 99 ++++++++++--------- .../ml/memory/index/InteractionsIndex.java | 92 +++++++++-------- ...archConversationsTransportActionTests.java | 4 +- ...archInteractionsTransportActionsTests.java | 4 +- plugin/build.gradle | 1 - 7 files changed, 133 insertions(+), 99 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java index c528ac4391..6aa8d79ca8 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java @@ -22,8 +22,10 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.memory.ConversationalMemoryHandler; @@ -31,9 +33,13 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import lombok.extern.log4j.Log4j2; + +@Log4j2 public class SearchConversationsTransportAction extends HandledTransportAction { private ConversationalMemoryHandler cmHandler; + private Client client; private volatile boolean featureIsEnabled; @@ -50,10 +56,12 @@ public SearchConversationsTransportAction( TransportService transportService, ActionFilters actionFilters, OpenSearchConversationalMemoryHandler cmHandler, + Client client, ClusterService clusterService ) { super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new); this.cmHandler = cmHandler; + this.client = client; this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); clusterService .getClusterSettings() @@ -72,7 +80,13 @@ public void doExecute(Task task, SearchRequest request, ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + cmHandler.searchConversations(request, internalListener); + } catch (Exception e) { + log.error("Failed to search conversations", e); + actionListener.onFailure(e); + } } } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java index 80f56d28ca..5060e6111a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java @@ -21,8 +21,10 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.memory.ConversationalMemoryHandler; @@ -30,9 +32,13 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import lombok.extern.log4j.Log4j2; + +@Log4j2 public class SearchInteractionsTransportAction extends HandledTransportAction { private ConversationalMemoryHandler cmHandler; + private Client client; private volatile boolean featureIsEnabled; @@ -49,10 +55,12 @@ public SearchInteractionsTransportAction( TransportService transportService, ActionFilters actionFilters, OpenSearchConversationalMemoryHandler cmHandler, + Client client, ClusterService clusterService ) { super(SearchInteractionsAction.NAME, transportService, actionFilters, SearchInteractionsRequest::new); this.cmHandler = cmHandler; + this.client = client; this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); clusterService .getClusterSettings() @@ -71,7 +79,13 @@ public void doExecute(Task task, SearchInteractionsRequest request, ActionListen ); return; } else { - cmHandler.searchInteractions(request.getConversationId(), request, actionListener); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + cmHandler.searchInteractions(request.getConversationId(), request, internalListener); + } catch (Exception e) { + log.error("Failed to search conversations", e); + actionListener.onFailure(e); + } } } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index a75e2aa0c0..47c55ac1e7 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.memory.index; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_INDEX_NAME; + import java.io.IOException; import java.time.Instant; import java.util.LinkedList; @@ -68,21 +70,24 @@ public class ConversationMetaIndex { private Client client; private ClusterService clusterService; - private static final String indexName = ConversationalIndexConstants.META_INDEX_NAME; + + private String userstr() { + return client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + } /** * Creates the conversational meta index if it doesn't already exist * @param listener listener to wait for this to finish */ public void initConversationMetaIndexIfAbsent(ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { log.debug("No conversational meta index found. Adding it"); - CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.META_MAPPING); + CreateIndexRequest request = Requests.createIndexRequest(META_INDEX_NAME).mapping(ConversationalIndexConstants.META_MAPPING); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(createIndexResponse -> { - if (createIndexResponse.equals(new CreateIndexResponse(true, true, indexName))) { - log.info("created index [" + indexName + "]"); + if (createIndexResponse.equals(new CreateIndexResponse(true, true, META_INDEX_NAME))) { + log.info("created index [" + META_INDEX_NAME + "]"); internalListener.onResponse(true); } else { internalListener.onResponse(false); @@ -92,7 +97,7 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { internalListener.onResponse(true); } else { - log.error("failed to create index [" + indexName + "]", e); + log.error("failed to create index [" + META_INDEX_NAME + "]", e); internalListener.onFailure(e); } }); @@ -102,7 +107,7 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { listener.onResponse(true); } else { - log.error("failed to create index [" + indexName + "]", e); + log.error("failed to create index [" + META_INDEX_NAME + "]", e); listener.onFailure(e); } } @@ -119,12 +124,9 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) public void createConversation(String name, ActionListener listener) { initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { if (indexExists) { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); IndexRequest request = Requests - .indexRequest(indexName) + .indexRequest(META_INDEX_NAME) .source( ConversationalIndexConstants.META_CREATED_FIELD, Instant.now(), @@ -171,12 +173,12 @@ public void createConversation(ActionListener listener) { * @param listener gets the list of conversation metadata objects in the index */ public void getConversations(int from, int maxResults, ActionListener> listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(List.of()); return; } - SearchRequest request = Requests.searchRequest(indexName); - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + SearchRequest request = Requests.searchRequest(META_INDEX_NAME); + String userstr = userstr(); QueryBuilder queryBuilder; if (userstr == null) queryBuilder = new MatchAllQueryBuilder(); @@ -197,13 +199,12 @@ public void getConversations(int from, int maxResults, ActionListener { client.search(request, al); }, e -> { - log.error("Failed to retrieve conversations during refresh", e); - internalListener.onFailure(e); - })); + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.search(request, al); + }, e -> { + log.error("Failed to retrieve conversations during refresh", e); + internalListener.onFailure(e); + })); } catch (Exception e) { log.error("Failed to retrieve conversations", e); listener.onFailure(e); @@ -225,12 +226,12 @@ public void getConversations(int maxResults, ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(true); return; } - DeleteRequest delRequest = Requests.deleteRequest(indexName).id(conversationId); - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); this.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { @@ -267,14 +268,14 @@ public void deleteConversation(String conversationId, ActionListener li */ public void checkAccess(String conversationId, ActionListener listener) { // If the index doesn't exist, you have permission. Just won't get you anywhere - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(true); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - GetRequest getRequest = Requests.getRequest(indexName).id(conversationId); + GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId); ActionListener al = ActionListener.wrap(getResponse -> { // If the conversation doesn't exist, fail if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) { @@ -294,13 +295,12 @@ public void checkAccess(String conversationId, ActionListener listener) } internalListener.onResponse(true); }, e -> { internalListener.onFailure(e); }); - client - .admin() - .indices() - .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(getRequest, al); }, e -> { - log.error("Failed to refresh conversations index during check access ", e); - internalListener.onFailure(e); - })); + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(getRequest, al); + }, e -> { + log.error("Failed to refresh conversations index during check access ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } @@ -313,11 +313,11 @@ public void checkAccess(String conversationId, ActionListener listener) * @param listener receives the search response for the wrapped query */ public void searchConversations(SearchRequest request, ActionListener listener) { - request.indices(indexName); + request.indices(META_INDEX_NAME); QueryBuilder originalQuery = request.source().query(); BoolQueryBuilder newQuery = new BoolQueryBuilder(); newQuery.must(originalQuery); - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); if (userstr != null) { String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, user)); @@ -325,7 +325,7 @@ public void searchConversations(SearchRequest request, ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - client.admin().indices().refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { client.search(request, internalListener); }, e -> { log.error("Failed to refresh conversations index during search conversations ", e); @@ -342,15 +342,17 @@ public void searchConversations(SearchRequest request, ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener - .onFailure(new IndexNotFoundException("cannot get conversation since the conversation index does not exist", indexName)); + .onFailure( + new IndexNotFoundException("cannot get conversation since the conversation index does not exist", META_INDEX_NAME) + ); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - GetRequest request = Requests.getRequest(indexName).id(conversationId); + GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId); ActionListener al = ActionListener.wrap(getResponse -> { // If the conversation doesn't exist, fail if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) { @@ -374,13 +376,12 @@ public void getConversation(String conversationId, ActionListener { internalListener.onFailure(e); }); - client - .admin() - .indices() - .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(request, al); }, e -> { - log.error("Failed to refresh conversations index during get conversation ", e); - internalListener.onFailure(e); - })); + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(request, al); + }, e -> { + log.error("Failed to refresh conversations index during get conversation ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 30e47019fd..bd4eb1e39a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.memory.index; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INDEX_NAME; + import java.io.IOException; import java.time.Instant; import java.util.LinkedList; @@ -69,23 +71,28 @@ public class InteractionsIndex { private Client client; private ClusterService clusterService; private ConversationMetaIndex conversationMetaIndex; - private final String indexName = ConversationalIndexConstants.INTERACTIONS_INDEX_NAME; // How big the steps should be when gathering *ALL* interactions in a conversation private final int resultsAtATime = 300; + private String userstr() { + return client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + } + /** * 'PUT's the index in opensearch if it's not there already * @param listener gets whether the index needed to be initialized. Throws error if it fails to init */ public void initInteractionsIndexIfAbsent(ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { log.debug("No interactions index found. Adding it"); - CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.INTERACTIONS_MAPPINGS); + CreateIndexRequest request = Requests + .createIndexRequest(INTERACTIONS_INDEX_NAME) + .mapping(ConversationalIndexConstants.INTERACTIONS_MAPPINGS); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(r -> { - if (r.equals(new CreateIndexResponse(true, true, indexName))) { - log.info("created index [" + indexName + "]"); + if (r.equals(new CreateIndexResponse(true, true, INTERACTIONS_INDEX_NAME))) { + log.info("created index [" + INTERACTIONS_INDEX_NAME + "]"); internalListener.onResponse(true); } else { internalListener.onResponse(false); @@ -95,7 +102,7 @@ public void initInteractionsIndexIfAbsent(ActionListener listener) { || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { internalListener.onResponse(true); } else { - log.error("Failed to create index [" + indexName + "]", e); + log.error("Failed to create index [" + INTERACTIONS_INDEX_NAME + "]", e); internalListener.onFailure(e); } }); @@ -105,7 +112,7 @@ public void initInteractionsIndexIfAbsent(ActionListener listener) { || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { listener.onResponse(true); } else { - log.error("Failed to create index [" + indexName + "]", e); + log.error("Failed to create index [" + INTERACTIONS_INDEX_NAME + "]", e); listener.onFailure(e); } } @@ -136,16 +143,13 @@ public void createInteraction( ActionListener listener ) { initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { IndexRequest request = Requests - .indexRequest(indexName) + .indexRequest(INTERACTIONS_INDEX_NAME) .source( ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin, @@ -215,7 +219,7 @@ public void createInteraction( * @param listener gets the list, sorted by recency, of interactions */ public void getInteractions(String conversationId, int from, int maxResults, ActionListener> listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { listener.onResponse(List.of()); return; } @@ -236,7 +240,7 @@ public void getInteractions(String conversationId, int from, int maxResults, Act @VisibleForTesting void innerGetInteractions(String conversationId, int from, int maxResults, ActionListener> listener) { - SearchRequest request = Requests.searchRequest(indexName); + SearchRequest request = Requests.searchRequest(INTERACTIONS_INDEX_NAME); TermQueryBuilder builder = new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); request.source().query(builder); request.source().from(from).size(maxResults); @@ -253,7 +257,7 @@ void innerGetInteractions(String conversationId, int from, int maxResults, Actio client .admin() .indices() - .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(r -> { client.search(request, al); }, e -> { + .refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(r -> { client.search(request, al); }, e -> { internalListener.onFailure(e); })); } catch (Exception e) { @@ -313,18 +317,18 @@ ActionListener> nextGetListener( * @param listener gets whether the deletion was successful */ public void deleteConversation(String conversationId, ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { listener.onResponse(true); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener> searchListener = ActionListener.wrap(interactions -> { BulkRequest request = Requests.bulkRequest(); for (Interaction interaction : interactions) { - DeleteRequest delRequest = Requests.deleteRequest(indexName).id(interaction.getId()); + DeleteRequest delRequest = Requests.deleteRequest(INTERACTIONS_INDEX_NAME).id(interaction.getId()); request.add(delRequest); } client @@ -358,26 +362,26 @@ public void searchInteractions(String conversationId, SearchRequest request, Act if (access) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - request.indices(indexName); + request.indices(INTERACTIONS_INDEX_NAME); QueryBuilder originalQuery = request.source().query(); BoolQueryBuilder newQuery = new BoolQueryBuilder(); newQuery.must(originalQuery); newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId)); request.source().query(newQuery); - client.admin().indices().refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { - client.search(request, internalListener); - }, e -> { - log.error("Failed to refresh interactions index during search interactions ", e); - internalListener.onFailure(e); - })); + client + .admin() + .indices() + .refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.search(request, internalListener); + }, e -> { + log.error("Failed to refresh interactions index during search interactions ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } } else { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } @@ -391,15 +395,21 @@ public void searchInteractions(String conversationId, SearchRequest request, Act * @param listener receives the interaction */ public void getInteraction(String conversationId, String interactionId, ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { - listener.onFailure(new IndexNotFoundException("cannot get interaction since the interactions index does not exist", indexName)); + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { + listener + .onFailure( + new IndexNotFoundException( + "cannot get interaction since the interactions index does not exist", + INTERACTIONS_INDEX_NAME + ) + ); return; } conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - GetRequest request = Requests.getRequest(indexName).id(interactionId); + GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId); ActionListener al = ActionListener.wrap(getResponse -> { // If the conversation doesn't exist, fail if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { @@ -411,21 +421,17 @@ public void getInteraction(String conversationId, String interactionId, ActionLi client .admin() .indices() - .refresh( - Requests.refreshRequest(indexName), - ActionListener.wrap(refreshResponse -> { client.get(request, al); }, e -> { - log.error("Failed to refresh interactions index during get interaction ", e); - internalListener.onFailure(e); - }) - ); + .refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(request, al); + }, e -> { + log.error("Failed to refresh interactions index during get interaction ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } } else { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java index a42188ce8b..7ec9d8c042 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java @@ -92,7 +92,7 @@ public void setup() throws IOException { when(this.clusterService.getClusterSettings()) .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); - this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); } public void testEnabled_ThenSucceed() { @@ -111,7 +111,7 @@ public void testEnabled_ThenSucceed() { public void testDisabled_ThenFail() { when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); - this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java index 85e0c1bf0d..abe5204c65 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java @@ -92,7 +92,7 @@ public void setup() throws IOException { .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); when(this.request.getConversationId()).thenReturn("test_cid"); - this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); } public void testFeatureEnabled_ThenSucceed() { @@ -111,7 +111,7 @@ public void testFeatureEnabled_ThenSucceed() { public void testDisabled_ThenFail() { when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); - this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/build.gradle b/plugin/build.gradle index 567f8bd0d7..042ac13423 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -257,7 +257,6 @@ jacocoTestReport { xml.getRequired().set(true) csv.getRequired().set(false) html.getRequired().set(true) - html.outputLocation = layout.buildDirectory.dir('jacocoHtml') } dependsOn test