Skip to content

Commit

Permalink
fix system index access bug #1272 (#1320)
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 authored Sep 12, 2023
1 parent 65f47e9 commit 8cdac91
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.ml.common.conversation.ConversationMeta;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.search.SearchHit;
Expand Down Expand Up @@ -75,7 +76,7 @@ public void initConversationMetaIndexIfAbsent(ActionListener<Boolean> listener)
if (!clusterService.state().metadata().hasIndex(indexName)) {
log.debug("No conversational meta index found. Adding it");
CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.META_MAPPING);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
ActionListener<CreateIndexResponse> al = ActionListener.wrap(createIndexResponse -> {
if (createIndexResponse.equals(new CreateIndexResponse(true, true, indexName))) {
Expand Down Expand Up @@ -130,7 +131,7 @@ public void createConversation(String name, ActionListener<String> listener) {
ConversationalIndexConstants.USER_FIELD,
userstr == null ? null : User.parse(userstr).getName()
);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<String> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
ActionListener<IndexResponse> al = ActionListener.wrap(resp -> {
if (resp.status() == RestStatus.CREATED) {
Expand Down Expand Up @@ -181,7 +182,7 @@ public void getConversations(int from, int maxResults, ActionListener<List<Conve
request.source().query(queryBuilder);
request.source().from(from).size(maxResults);
request.source().sort(ConversationalIndexConstants.META_CREATED_FIELD, SortOrder.DESC);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<List<ConversationMeta>> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
ActionListener<SearchResponse> al = ActionListener.wrap(searchResponse -> {
List<ConversationMeta> result = new LinkedList<ConversationMeta>();
Expand Down Expand Up @@ -225,37 +226,34 @@ public void deleteConversation(String conversationId, ActionListener<Boolean> li
listener.onResponse(true);
}
DeleteRequest delRequest = Requests.deleteRequest(indexName).id(conversationId);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
// When we get the delete response, do this:
ActionListener<DeleteResponse> al = ActionListener.wrap(deleteResponse -> {
if (deleteResponse.getResult() == Result.DELETED) {
internalListener.onResponse(true);
} else if (deleteResponse.status() == RestStatus.NOT_FOUND) {
internalListener.onResponse(true);
} else {
internalListener.onResponse(false);
}
}, e -> {
log.error("Failure deleting conversation " + conversationId, e);
internalListener.onFailure(e);
});
this.checkAccess(conversationId, ActionListener.wrap(access -> {
if (access) {
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
this.checkAccess(conversationId, ActionListener.wrap(access -> {
if (access) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
// When we get the delete response, do this:
ActionListener<DeleteResponse> al = ActionListener.wrap(deleteResponse -> {
if (deleteResponse.getResult() == Result.DELETED) {
internalListener.onResponse(true);
} else if (deleteResponse.status() == RestStatus.NOT_FOUND) {
internalListener.onResponse(true);
} else {
internalListener.onResponse(false);
}
}, e -> {
log.error("Failure deleting conversation " + conversationId, e);
internalListener.onFailure(e);
});
client.delete(delRequest, al);
} else {
String userstr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr).getName();
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
} catch (Exception e) {
log.error("Failed deleting conversation with id=" + conversationId, e);
listener.onFailure(e);
}
}, e -> { internalListener.onFailure(e); }));
} catch (Exception e) {
log.error("Failed deleting conversation with id=" + conversationId, e);
listener.onFailure(e);
}
} else {
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
}
}, e -> { listener.onFailure(e); }));
}

/**
Expand All @@ -269,13 +267,9 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
listener.onResponse(true);
return;
}
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
String userstr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
log.info("USERSTR: " + userstr);
// If security is off - User doesn't exist - you have permission
if (userstr == null || User.parse(userstr) == null) {
internalListener.onResponse(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public void initInteractionsIndexIfAbsent(ActionListener<Boolean> listener) {
if (!clusterService.state().metadata().hasIndex(indexName)) {
log.debug("No interactions index found. Adding it");
CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.INTERACTIONS_MAPPINGS);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
ActionListener<CreateIndexResponse> al = ActionListener.wrap(r -> {
if (r.equals(new CreateIndexResponse(true, true, indexName))) {
Expand Down Expand Up @@ -130,6 +130,11 @@ public void createInteraction(
ActionListener<String> listener
) {
initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> {
String userstr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
if (indexExists) {
this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> {
if (access) {
Expand All @@ -151,7 +156,7 @@ public void createInteraction(
ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD,
timestamp
);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<String> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
ActionListener<IndexResponse> al = ActionListener.wrap(resp -> {
if (resp.status() == RestStatus.CREATED) {
Expand All @@ -165,13 +170,6 @@ public void createInteraction(
listener.onFailure(e);
}
} else {
String userstr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr) == null
? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS
: User.parse(userstr).getName();
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
}
}, e -> { listener.onFailure(e); }));
Expand Down Expand Up @@ -313,7 +311,9 @@ public void deleteConversation(String conversationId, ActionListener<Boolean> li
listener.onResponse(true);
return;
}
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
ActionListener<List<Interaction>> searchListener = ActionListener.wrap(interactions -> {
BulkRequest request = Requests.bulkRequest();
Expand All @@ -330,11 +330,6 @@ public void deleteConversation(String conversationId, ActionListener<Boolean> li
if (access) {
getAllInteractions(conversationId, resultsAtATime, searchListener);
} else {
String userstr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
}
}, e -> { listener.onFailure(e); });
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,17 +402,6 @@ public void testDelete_DeleteFails_ThenFail() {
assert (argCaptor.getValue().getMessage().equals("Test Fail in Delete"));
}

public void testDelete_HighLevelFailure_ThenFail() {
doReturn(true).when(metadata).hasIndex(anyString());
doThrow(new RuntimeException("Check Fail")).when(conversationMetaIndex).checkAccess(any(), any());
@SuppressWarnings("unchecked")
ActionListener<Boolean> deleteConversationListener = mock(ActionListener.class);
conversationMetaIndex.deleteConversation("test-id", deleteConversationListener);
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture());
assert (argCaptor.getValue().getMessage().equals("Check Fail"));
}

public void testCheckAccess_DoesNotExist_ThenFail() {
setupUser("user");
doReturn(true).when(metadata).hasIndex(anyString());
Expand Down

0 comments on commit 8cdac91

Please sign in to comment.