From 09221dd465f62f7b2df6019bdef6aa0f2790ae39 Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 17 Oct 2023 15:24:08 -0700 Subject: [PATCH] address esay PR comments Signed-off-by: HenryL27 --- .../SearchConversationsTransportAction.java | 16 ++- .../SearchInteractionsTransportAction.java | 16 ++- .../memory/index/ConversationMetaIndex.java | 99 ++++++++++--------- .../ml/memory/index/InteractionsIndex.java | 92 +++++++++-------- ...archConversationsTransportActionTests.java | 4 +- ...archInteractionsTransportActionsTests.java | 4 +- plugin/build.gradle | 1 - 7 files changed, 133 insertions(+), 99 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java index c528ac4391..6aa8d79ca8 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportAction.java @@ -22,8 +22,10 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.memory.ConversationalMemoryHandler; @@ -31,9 +33,13 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import lombok.extern.log4j.Log4j2; + +@Log4j2 public class SearchConversationsTransportAction extends HandledTransportAction { private ConversationalMemoryHandler cmHandler; + private Client client; private volatile boolean featureIsEnabled; @@ -50,10 +56,12 @@ public SearchConversationsTransportAction( TransportService transportService, ActionFilters actionFilters, OpenSearchConversationalMemoryHandler cmHandler, + Client client, ClusterService clusterService ) { super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new); this.cmHandler = cmHandler; + this.client = client; this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); clusterService .getClusterSettings() @@ -72,7 +80,13 @@ public void doExecute(Task task, SearchRequest request, ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + cmHandler.searchConversations(request, internalListener); + } catch (Exception e) { + log.error("Failed to search conversations", e); + actionListener.onFailure(e); + } } } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java index 80f56d28ca..5060e6111a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportAction.java @@ -21,8 +21,10 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.ml.memory.ConversationalMemoryHandler; @@ -30,9 +32,13 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import lombok.extern.log4j.Log4j2; + +@Log4j2 public class SearchInteractionsTransportAction extends HandledTransportAction { private ConversationalMemoryHandler cmHandler; + private Client client; private volatile boolean featureIsEnabled; @@ -49,10 +55,12 @@ public SearchInteractionsTransportAction( TransportService transportService, ActionFilters actionFilters, OpenSearchConversationalMemoryHandler cmHandler, + Client client, ClusterService clusterService ) { super(SearchInteractionsAction.NAME, transportService, actionFilters, SearchInteractionsRequest::new); this.cmHandler = cmHandler; + this.client = client; this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); clusterService .getClusterSettings() @@ -71,7 +79,13 @@ public void doExecute(Task task, SearchInteractionsRequest request, ActionListen ); return; } else { - cmHandler.searchInteractions(request.getConversationId(), request, actionListener); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + cmHandler.searchInteractions(request.getConversationId(), request, internalListener); + } catch (Exception e) { + log.error("Failed to search conversations", e); + actionListener.onFailure(e); + } } } } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index a75e2aa0c0..47c55ac1e7 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.memory.index; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_INDEX_NAME; + import java.io.IOException; import java.time.Instant; import java.util.LinkedList; @@ -68,21 +70,24 @@ public class ConversationMetaIndex { private Client client; private ClusterService clusterService; - private static final String indexName = ConversationalIndexConstants.META_INDEX_NAME; + + private String userstr() { + return client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + } /** * Creates the conversational meta index if it doesn't already exist * @param listener listener to wait for this to finish */ public void initConversationMetaIndexIfAbsent(ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { log.debug("No conversational meta index found. Adding it"); - CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.META_MAPPING); + CreateIndexRequest request = Requests.createIndexRequest(META_INDEX_NAME).mapping(ConversationalIndexConstants.META_MAPPING); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(createIndexResponse -> { - if (createIndexResponse.equals(new CreateIndexResponse(true, true, indexName))) { - log.info("created index [" + indexName + "]"); + if (createIndexResponse.equals(new CreateIndexResponse(true, true, META_INDEX_NAME))) { + log.info("created index [" + META_INDEX_NAME + "]"); internalListener.onResponse(true); } else { internalListener.onResponse(false); @@ -92,7 +97,7 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { internalListener.onResponse(true); } else { - log.error("failed to create index [" + indexName + "]", e); + log.error("failed to create index [" + META_INDEX_NAME + "]", e); internalListener.onFailure(e); } }); @@ -102,7 +107,7 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { listener.onResponse(true); } else { - log.error("failed to create index [" + indexName + "]", e); + log.error("failed to create index [" + META_INDEX_NAME + "]", e); listener.onFailure(e); } } @@ -119,12 +124,9 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) public void createConversation(String name, ActionListener listener) { initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { if (indexExists) { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); IndexRequest request = Requests - .indexRequest(indexName) + .indexRequest(META_INDEX_NAME) .source( ConversationalIndexConstants.META_CREATED_FIELD, Instant.now(), @@ -171,12 +173,12 @@ public void createConversation(ActionListener listener) { * @param listener gets the list of conversation metadata objects in the index */ public void getConversations(int from, int maxResults, ActionListener> listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(List.of()); return; } - SearchRequest request = Requests.searchRequest(indexName); - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + SearchRequest request = Requests.searchRequest(META_INDEX_NAME); + String userstr = userstr(); QueryBuilder queryBuilder; if (userstr == null) queryBuilder = new MatchAllQueryBuilder(); @@ -197,13 +199,12 @@ public void getConversations(int from, int maxResults, ActionListener { client.search(request, al); }, e -> { - log.error("Failed to retrieve conversations during refresh", e); - internalListener.onFailure(e); - })); + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.search(request, al); + }, e -> { + log.error("Failed to retrieve conversations during refresh", e); + internalListener.onFailure(e); + })); } catch (Exception e) { log.error("Failed to retrieve conversations", e); listener.onFailure(e); @@ -225,12 +226,12 @@ public void getConversations(int maxResults, ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(true); return; } - DeleteRequest delRequest = Requests.deleteRequest(indexName).id(conversationId); - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); this.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { @@ -267,14 +268,14 @@ public void deleteConversation(String conversationId, ActionListener li */ public void checkAccess(String conversationId, ActionListener listener) { // If the index doesn't exist, you have permission. Just won't get you anywhere - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener.onResponse(true); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - GetRequest getRequest = Requests.getRequest(indexName).id(conversationId); + GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId); ActionListener al = ActionListener.wrap(getResponse -> { // If the conversation doesn't exist, fail if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) { @@ -294,13 +295,12 @@ public void checkAccess(String conversationId, ActionListener listener) } internalListener.onResponse(true); }, e -> { internalListener.onFailure(e); }); - client - .admin() - .indices() - .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(getRequest, al); }, e -> { - log.error("Failed to refresh conversations index during check access ", e); - internalListener.onFailure(e); - })); + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(getRequest, al); + }, e -> { + log.error("Failed to refresh conversations index during check access ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } @@ -313,11 +313,11 @@ public void checkAccess(String conversationId, ActionListener listener) * @param listener receives the search response for the wrapped query */ public void searchConversations(SearchRequest request, ActionListener listener) { - request.indices(indexName); + request.indices(META_INDEX_NAME); QueryBuilder originalQuery = request.source().query(); BoolQueryBuilder newQuery = new BoolQueryBuilder(); newQuery.must(originalQuery); - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); if (userstr != null) { String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, user)); @@ -325,7 +325,7 @@ public void searchConversations(SearchRequest request, ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - client.admin().indices().refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { client.search(request, internalListener); }, e -> { log.error("Failed to refresh conversations index during search conversations ", e); @@ -342,15 +342,17 @@ public void searchConversations(SearchRequest request, ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) { listener - .onFailure(new IndexNotFoundException("cannot get conversation since the conversation index does not exist", indexName)); + .onFailure( + new IndexNotFoundException("cannot get conversation since the conversation index does not exist", META_INDEX_NAME) + ); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - GetRequest request = Requests.getRequest(indexName).id(conversationId); + GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId); ActionListener al = ActionListener.wrap(getResponse -> { // If the conversation doesn't exist, fail if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) { @@ -374,13 +376,12 @@ public void getConversation(String conversationId, ActionListener { internalListener.onFailure(e); }); - client - .admin() - .indices() - .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(request, al); }, e -> { - log.error("Failed to refresh conversations index during get conversation ", e); - internalListener.onFailure(e); - })); + client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(request, al); + }, e -> { + log.error("Failed to refresh conversations index during get conversation ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index 30e47019fd..bd4eb1e39a 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -17,6 +17,8 @@ */ package org.opensearch.ml.memory.index; +import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_INDEX_NAME; + import java.io.IOException; import java.time.Instant; import java.util.LinkedList; @@ -69,23 +71,28 @@ public class InteractionsIndex { private Client client; private ClusterService clusterService; private ConversationMetaIndex conversationMetaIndex; - private final String indexName = ConversationalIndexConstants.INTERACTIONS_INDEX_NAME; // How big the steps should be when gathering *ALL* interactions in a conversation private final int resultsAtATime = 300; + private String userstr() { + return client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + } + /** * 'PUT's the index in opensearch if it's not there already * @param listener gets whether the index needed to be initialized. Throws error if it fails to init */ public void initInteractionsIndexIfAbsent(ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { log.debug("No interactions index found. Adding it"); - CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.INTERACTIONS_MAPPINGS); + CreateIndexRequest request = Requests + .createIndexRequest(INTERACTIONS_INDEX_NAME) + .mapping(ConversationalIndexConstants.INTERACTIONS_MAPPINGS); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(r -> { - if (r.equals(new CreateIndexResponse(true, true, indexName))) { - log.info("created index [" + indexName + "]"); + if (r.equals(new CreateIndexResponse(true, true, INTERACTIONS_INDEX_NAME))) { + log.info("created index [" + INTERACTIONS_INDEX_NAME + "]"); internalListener.onResponse(true); } else { internalListener.onResponse(false); @@ -95,7 +102,7 @@ public void initInteractionsIndexIfAbsent(ActionListener listener) { || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { internalListener.onResponse(true); } else { - log.error("Failed to create index [" + indexName + "]", e); + log.error("Failed to create index [" + INTERACTIONS_INDEX_NAME + "]", e); internalListener.onFailure(e); } }); @@ -105,7 +112,7 @@ public void initInteractionsIndexIfAbsent(ActionListener listener) { || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { listener.onResponse(true); } else { - log.error("Failed to create index [" + indexName + "]", e); + log.error("Failed to create index [" + INTERACTIONS_INDEX_NAME + "]", e); listener.onFailure(e); } } @@ -136,16 +143,13 @@ public void createInteraction( ActionListener listener ) { initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { IndexRequest request = Requests - .indexRequest(indexName) + .indexRequest(INTERACTIONS_INDEX_NAME) .source( ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin, @@ -215,7 +219,7 @@ public void createInteraction( * @param listener gets the list, sorted by recency, of interactions */ public void getInteractions(String conversationId, int from, int maxResults, ActionListener> listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { listener.onResponse(List.of()); return; } @@ -236,7 +240,7 @@ public void getInteractions(String conversationId, int from, int maxResults, Act @VisibleForTesting void innerGetInteractions(String conversationId, int from, int maxResults, ActionListener> listener) { - SearchRequest request = Requests.searchRequest(indexName); + SearchRequest request = Requests.searchRequest(INTERACTIONS_INDEX_NAME); TermQueryBuilder builder = new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); request.source().query(builder); request.source().from(from).size(maxResults); @@ -253,7 +257,7 @@ void innerGetInteractions(String conversationId, int from, int maxResults, Actio client .admin() .indices() - .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(r -> { client.search(request, al); }, e -> { + .refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(r -> { client.search(request, al); }, e -> { internalListener.onFailure(e); })); } catch (Exception e) { @@ -313,18 +317,18 @@ ActionListener> nextGetListener( * @param listener gets whether the deletion was successful */ public void deleteConversation(String conversationId, ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { listener.onResponse(true); return; } - String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener> searchListener = ActionListener.wrap(interactions -> { BulkRequest request = Requests.bulkRequest(); for (Interaction interaction : interactions) { - DeleteRequest delRequest = Requests.deleteRequest(indexName).id(interaction.getId()); + DeleteRequest delRequest = Requests.deleteRequest(INTERACTIONS_INDEX_NAME).id(interaction.getId()); request.add(delRequest); } client @@ -358,26 +362,26 @@ public void searchInteractions(String conversationId, SearchRequest request, Act if (access) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - request.indices(indexName); + request.indices(INTERACTIONS_INDEX_NAME); QueryBuilder originalQuery = request.source().query(); BoolQueryBuilder newQuery = new BoolQueryBuilder(); newQuery.must(originalQuery); newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId)); request.source().query(newQuery); - client.admin().indices().refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { - client.search(request, internalListener); - }, e -> { - log.error("Failed to refresh interactions index during search interactions ", e); - internalListener.onFailure(e); - })); + client + .admin() + .indices() + .refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.search(request, internalListener); + }, e -> { + log.error("Failed to refresh interactions index during search interactions ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } } else { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } @@ -391,15 +395,21 @@ public void searchInteractions(String conversationId, SearchRequest request, Act * @param listener receives the interaction */ public void getInteraction(String conversationId, String interactionId, ActionListener listener) { - if (!clusterService.state().metadata().hasIndex(indexName)) { - listener.onFailure(new IndexNotFoundException("cannot get interaction since the interactions index does not exist", indexName)); + if (!clusterService.state().metadata().hasIndex(INTERACTIONS_INDEX_NAME)) { + listener + .onFailure( + new IndexNotFoundException( + "cannot get interaction since the interactions index does not exist", + INTERACTIONS_INDEX_NAME + ) + ); return; } conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - GetRequest request = Requests.getRequest(indexName).id(interactionId); + GetRequest request = Requests.getRequest(INTERACTIONS_INDEX_NAME).id(interactionId); ActionListener al = ActionListener.wrap(getResponse -> { // If the conversation doesn't exist, fail if (!(getResponse.isExists() && getResponse.getId().equals(interactionId))) { @@ -411,21 +421,17 @@ public void getInteraction(String conversationId, String interactionId, ActionLi client .admin() .indices() - .refresh( - Requests.refreshRequest(indexName), - ActionListener.wrap(refreshResponse -> { client.get(request, al); }, e -> { - log.error("Failed to refresh interactions index during get interaction ", e); - internalListener.onFailure(e); - }) - ); + .refresh(Requests.refreshRequest(INTERACTIONS_INDEX_NAME), ActionListener.wrap(refreshResponse -> { + client.get(request, al); + }, e -> { + log.error("Failed to refresh interactions index during get interaction ", e); + internalListener.onFailure(e); + })); } catch (Exception e) { listener.onFailure(e); } } else { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String userstr = userstr(); String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java index a42188ce8b..7ec9d8c042 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchConversationsTransportActionTests.java @@ -92,7 +92,7 @@ public void setup() throws IOException { when(this.clusterService.getClusterSettings()) .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); - this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); } public void testEnabled_ThenSucceed() { @@ -111,7 +111,7 @@ public void testEnabled_ThenSucceed() { public void testDisabled_ThenFail() { when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); - this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + this.action = spy(new SearchConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java index 85e0c1bf0d..abe5204c65 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/SearchInteractionsTransportActionsTests.java @@ -92,7 +92,7 @@ public void setup() throws IOException { .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); when(this.request.getConversationId()).thenReturn("test_cid"); - this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); } public void testFeatureEnabled_ThenSucceed() { @@ -111,7 +111,7 @@ public void testFeatureEnabled_ThenSucceed() { public void testDisabled_ThenFail() { when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); - this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, clusterService)); + this.action = spy(new SearchInteractionsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); action.doExecute(null, request, actionListener); ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); diff --git a/plugin/build.gradle b/plugin/build.gradle index 2a9607a5d9..b389360e0b 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -256,7 +256,6 @@ jacocoTestReport { xml.getRequired().set(true) csv.getRequired().set(false) html.getRequired().set(true) - html.outputLocation = layout.buildDirectory.dir('jacocoHtml') } dependsOn test