From c1a2b66867bbc5e465ad1b07bf2c10f8dee7c8ab Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Tue, 12 Sep 2023 09:01:58 -0700 Subject: [PATCH] fix system index access bug #1272 (#1320) Signed-off-by: HenryL27 --- .../memory/index/ConversationMetaIndex.java | 70 +++++++++---------- .../ml/memory/index/InteractionsIndex.java | 25 +++---- .../index/ConversationMetaIndexTests.java | 11 --- 3 files changed, 42 insertions(+), 64 deletions(-) diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java index d7a4169fe7..64b4f5267f 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -48,6 +48,7 @@ import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; import org.opensearch.ml.common.conversation.ConversationMeta; import org.opensearch.ml.common.conversation.ConversationalIndexConstants; import org.opensearch.search.SearchHit; @@ -75,7 +76,7 @@ public void initConversationMetaIndexIfAbsent(ActionListener listener) if (!clusterService.state().metadata().hasIndex(indexName)) { log.debug("No conversational meta index found. Adding it"); CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.META_MAPPING); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(createIndexResponse -> { if (createIndexResponse.equals(new CreateIndexResponse(true, true, indexName))) { @@ -130,7 +131,7 @@ public void createConversation(String name, ActionListener listener) { ConversationalIndexConstants.USER_FIELD, userstr == null ? null : User.parse(userstr).getName() ); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { if (resp.status() == RestStatus.CREATED) { @@ -181,7 +182,7 @@ public void getConversations(int from, int maxResults, ActionListener> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(searchResponse -> { List result = new LinkedList(); @@ -225,37 +226,34 @@ public void deleteConversation(String conversationId, ActionListener li listener.onResponse(true); } DeleteRequest delRequest = Requests.deleteRequest(indexName).id(conversationId); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { - ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - // When we get the delete response, do this: - ActionListener al = ActionListener.wrap(deleteResponse -> { - if (deleteResponse.getResult() == Result.DELETED) { - internalListener.onResponse(true); - } else if (deleteResponse.status() == RestStatus.NOT_FOUND) { - internalListener.onResponse(true); - } else { - internalListener.onResponse(false); - } - }, e -> { - log.error("Failure deleting conversation " + conversationId, e); - internalListener.onFailure(e); - }); - this.checkAccess(conversationId, ActionListener.wrap(access -> { - if (access) { + String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + this.checkAccess(conversationId, ActionListener.wrap(access -> { + if (access) { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + // When we get the delete response, do this: + ActionListener al = ActionListener.wrap(deleteResponse -> { + if (deleteResponse.getResult() == Result.DELETED) { + internalListener.onResponse(true); + } else if (deleteResponse.status() == RestStatus.NOT_FOUND) { + internalListener.onResponse(true); + } else { + internalListener.onResponse(false); + } + }, e -> { + log.error("Failure deleting conversation " + conversationId, e); + internalListener.onFailure(e); + }); client.delete(delRequest, al); - } else { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr).getName(); - throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } catch (Exception e) { + log.error("Failed deleting conversation with id=" + conversationId, e); + listener.onFailure(e); } - }, e -> { internalListener.onFailure(e); })); - } catch (Exception e) { - log.error("Failed deleting conversation with id=" + conversationId, e); - listener.onFailure(e); - } + } else { + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { listener.onFailure(e); })); } /** @@ -269,13 +267,9 @@ public void checkAccess(String conversationId, ActionListener listener) listener.onResponse(true); return; } - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - log.info("USERSTR: " + userstr); // If security is off - User doesn't exist - you have permission if (userstr == null || User.parse(userstr) == null) { internalListener.onResponse(true); diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java index a6714c63c3..54857b274c 100644 --- a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -75,7 +75,7 @@ public void initInteractionsIndexIfAbsent(ActionListener listener) { if (!clusterService.state().metadata().hasIndex(indexName)) { log.debug("No interactions index found. Adding it"); CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.INTERACTIONS_MAPPINGS); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(r -> { if (r.equals(new CreateIndexResponse(true, true, indexName))) { @@ -130,6 +130,11 @@ public void createInteraction( ActionListener listener ) { initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); if (indexExists) { this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { if (access) { @@ -151,7 +156,7 @@ public void createInteraction( ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, timestamp ); - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener al = ActionListener.wrap(resp -> { if (resp.status() == RestStatus.CREATED) { @@ -165,13 +170,6 @@ public void createInteraction( listener.onFailure(e); } } else { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null - ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS - : User.parse(userstr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } }, e -> { listener.onFailure(e); })); @@ -313,7 +311,9 @@ public void deleteConversation(String conversationId, ActionListener li listener.onResponse(true); return; } - try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); ActionListener> searchListener = ActionListener.wrap(interactions -> { BulkRequest request = Requests.bulkRequest(); @@ -330,11 +330,6 @@ public void deleteConversation(String conversationId, ActionListener li if (access) { getAllInteractions(conversationId, resultsAtATime, searchListener); } else { - String userstr = client - .threadPool() - .getThreadContext() - .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); } }, e -> { listener.onFailure(e); }); diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java index 8d9667536b..c85384407e 100644 --- a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -402,17 +402,6 @@ public void testDelete_DeleteFails_ThenFail() { assert (argCaptor.getValue().getMessage().equals("Test Fail in Delete")); } - public void testDelete_HighLevelFailure_ThenFail() { - doReturn(true).when(metadata).hasIndex(anyString()); - doThrow(new RuntimeException("Check Fail")).when(conversationMetaIndex).checkAccess(any(), any()); - @SuppressWarnings("unchecked") - ActionListener deleteConversationListener = mock(ActionListener.class); - conversationMetaIndex.deleteConversation("test-id", deleteConversationListener); - ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); - verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); - assert (argCaptor.getValue().getMessage().equals("Check Fail")); - } - public void testCheckAccess_DoesNotExist_ThenFail() { setupUser("user"); doReturn(true).when(metadata).hasIndex(anyString());