Skip to content

Commit

Permalink
address esay PR comments
Browse files Browse the repository at this point in the history
Signed-off-by: HenryL27 <[email protected]>
  • Loading branch information
HenryL27 committed Oct 17, 2023
1 parent 0f80691 commit 09221dd
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,24 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class SearchConversationsTransportAction extends HandledTransportAction<SearchRequest, SearchResponse> {

private ConversationalMemoryHandler cmHandler;
private Client client;

private volatile boolean featureIsEnabled;

Expand All @@ -50,10 +56,12 @@ public SearchConversationsTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client,
ClusterService clusterService
) {
super(SearchConversationsAction.NAME, transportService, actionFilters, SearchRequest::new);
this.cmHandler = cmHandler;
this.client = client;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
Expand All @@ -72,7 +80,13 @@ public void doExecute(Task task, SearchRequest request, ActionListener<SearchRes
);
return;
} else {
cmHandler.searchConversations(request, actionListener);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
cmHandler.searchConversations(request, internalListener);
} catch (Exception e) {
log.error("Failed to search conversations", e);
actionListener.onFailure(e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,24 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

@Log4j2
public class SearchInteractionsTransportAction extends HandledTransportAction<SearchInteractionsRequest, SearchResponse> {

private ConversationalMemoryHandler cmHandler;
private Client client;

private volatile boolean featureIsEnabled;

Expand All @@ -49,10 +55,12 @@ public SearchInteractionsTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client,
ClusterService clusterService
) {
super(SearchInteractionsAction.NAME, transportService, actionFilters, SearchInteractionsRequest::new);
this.cmHandler = cmHandler;
this.client = client;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
Expand All @@ -71,7 +79,13 @@ public void doExecute(Task task, SearchInteractionsRequest request, ActionListen
);
return;
} else {
cmHandler.searchInteractions(request.getConversationId(), request, actionListener);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(actionListener, () -> context.restore());
cmHandler.searchInteractions(request.getConversationId(), request, internalListener);
} catch (Exception e) {
log.error("Failed to search conversations", e);
actionListener.onFailure(e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.opensearch.ml.memory.index;

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.META_INDEX_NAME;

import java.io.IOException;
import java.time.Instant;
import java.util.LinkedList;
Expand Down Expand Up @@ -68,21 +70,24 @@ public class ConversationMetaIndex {

private Client client;
private ClusterService clusterService;
private static final String indexName = ConversationalIndexConstants.META_INDEX_NAME;

private String userstr() {
return client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
}

/**
* Creates the conversational meta index if it doesn't already exist
* @param listener listener to wait for this to finish
*/
public void initConversationMetaIndexIfAbsent(ActionListener<Boolean> listener) {
if (!clusterService.state().metadata().hasIndex(indexName)) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
log.debug("No conversational meta index found. Adding it");
CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.META_MAPPING);
CreateIndexRequest request = Requests.createIndexRequest(META_INDEX_NAME).mapping(ConversationalIndexConstants.META_MAPPING);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
ActionListener<CreateIndexResponse> al = ActionListener.wrap(createIndexResponse -> {
if (createIndexResponse.equals(new CreateIndexResponse(true, true, indexName))) {
log.info("created index [" + indexName + "]");
if (createIndexResponse.equals(new CreateIndexResponse(true, true, META_INDEX_NAME))) {
log.info("created index [" + META_INDEX_NAME + "]");
internalListener.onResponse(true);
} else {
internalListener.onResponse(false);
Expand All @@ -92,7 +97,7 @@ public void initConversationMetaIndexIfAbsent(ActionListener<Boolean> listener)
|| (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) {
internalListener.onResponse(true);
} else {
log.error("failed to create index [" + indexName + "]", e);
log.error("failed to create index [" + META_INDEX_NAME + "]", e);
internalListener.onFailure(e);
}
});
Expand All @@ -102,7 +107,7 @@ public void initConversationMetaIndexIfAbsent(ActionListener<Boolean> listener)
|| (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) {
listener.onResponse(true);
} else {
log.error("failed to create index [" + indexName + "]", e);
log.error("failed to create index [" + META_INDEX_NAME + "]", e);
listener.onFailure(e);
}
}
Expand All @@ -119,12 +124,9 @@ public void initConversationMetaIndexIfAbsent(ActionListener<Boolean> listener)
public void createConversation(String name, ActionListener<String> listener) {
initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> {
if (indexExists) {
String userstr = client
.threadPool()
.getThreadContext()
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String userstr = userstr();
IndexRequest request = Requests
.indexRequest(indexName)
.indexRequest(META_INDEX_NAME)
.source(
ConversationalIndexConstants.META_CREATED_FIELD,
Instant.now(),
Expand Down Expand Up @@ -171,12 +173,12 @@ public void createConversation(ActionListener<String> listener) {
* @param listener gets the list of conversation metadata objects in the index
*/
public void getConversations(int from, int maxResults, ActionListener<List<ConversationMeta>> listener) {
if (!clusterService.state().metadata().hasIndex(indexName)) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener.onResponse(List.of());
return;
}
SearchRequest request = Requests.searchRequest(indexName);
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
SearchRequest request = Requests.searchRequest(META_INDEX_NAME);
String userstr = userstr();
QueryBuilder queryBuilder;
if (userstr == null)
queryBuilder = new MatchAllQueryBuilder();
Expand All @@ -197,13 +199,12 @@ public void getConversations(int from, int maxResults, ActionListener<List<Conve
log.error("Failed to retrieve conversations", e);
internalListener.onFailure(e);
});
client
.admin()
.indices()
.refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.search(request, al); }, e -> {
log.error("Failed to retrieve conversations during refresh", e);
internalListener.onFailure(e);
}));
client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> {
client.search(request, al);
}, e -> {
log.error("Failed to retrieve conversations during refresh", e);
internalListener.onFailure(e);
}));
} catch (Exception e) {
log.error("Failed to retrieve conversations", e);
listener.onFailure(e);
Expand All @@ -225,12 +226,12 @@ public void getConversations(int maxResults, ActionListener<List<ConversationMet
* @param listener gets whether the deletion was successful
*/
public void deleteConversation(String conversationId, ActionListener<Boolean> listener) {
if (!clusterService.state().metadata().hasIndex(indexName)) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener.onResponse(true);
return;
}
DeleteRequest delRequest = Requests.deleteRequest(indexName).id(conversationId);
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId);
String userstr = userstr();
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
this.checkAccess(conversationId, ActionListener.wrap(access -> {
if (access) {
Expand Down Expand Up @@ -267,14 +268,14 @@ public void deleteConversation(String conversationId, ActionListener<Boolean> li
*/
public void checkAccess(String conversationId, ActionListener<Boolean> listener) {
// If the index doesn't exist, you have permission. Just won't get you anywhere
if (!clusterService.state().metadata().hasIndex(indexName)) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener.onResponse(true);
return;
}
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String userstr = userstr();
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
GetRequest getRequest = Requests.getRequest(indexName).id(conversationId);
GetRequest getRequest = Requests.getRequest(META_INDEX_NAME).id(conversationId);
ActionListener<GetResponse> al = ActionListener.wrap(getResponse -> {
// If the conversation doesn't exist, fail
if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) {
Expand All @@ -294,13 +295,12 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
}
internalListener.onResponse(true);
}, e -> { internalListener.onFailure(e); });
client
.admin()
.indices()
.refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(getRequest, al); }, e -> {
log.error("Failed to refresh conversations index during check access ", e);
internalListener.onFailure(e);
}));
client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> {
client.get(getRequest, al);
}, e -> {
log.error("Failed to refresh conversations index during check access ", e);
internalListener.onFailure(e);
}));
} catch (Exception e) {
listener.onFailure(e);
}
Expand All @@ -313,19 +313,19 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
* @param listener receives the search response for the wrapped query
*/
public void searchConversations(SearchRequest request, ActionListener<SearchResponse> listener) {
request.indices(indexName);
request.indices(META_INDEX_NAME);
QueryBuilder originalQuery = request.source().query();
BoolQueryBuilder newQuery = new BoolQueryBuilder();
newQuery.must(originalQuery);
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String userstr = userstr();
if (userstr != null) {
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
newQuery.must(new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, user));
}
request.source().query(newQuery);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<SearchResponse> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
client.admin().indices().refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> {
client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> {
client.search(request, internalListener);
}, e -> {
log.error("Failed to refresh conversations index during search conversations ", e);
Expand All @@ -342,15 +342,17 @@ public void searchConversations(SearchRequest request, ActionListener<SearchResp
* @param listener receives the conversationMeta object
*/
public void getConversation(String conversationId, ActionListener<ConversationMeta> listener) {
if (!clusterService.state().metadata().hasIndex(indexName)) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener
.onFailure(new IndexNotFoundException("cannot get conversation since the conversation index does not exist", indexName));
.onFailure(
new IndexNotFoundException("cannot get conversation since the conversation index does not exist", META_INDEX_NAME)
);
return;
}
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
String userstr = userstr();
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<ConversationMeta> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
GetRequest request = Requests.getRequest(indexName).id(conversationId);
GetRequest request = Requests.getRequest(META_INDEX_NAME).id(conversationId);
ActionListener<GetResponse> al = ActionListener.wrap(getResponse -> {
// If the conversation doesn't exist, fail
if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) {
Expand All @@ -374,13 +376,12 @@ public void getConversation(String conversationId, ActionListener<ConversationMe
new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId)
);
}, e -> { internalListener.onFailure(e); });
client
.admin()
.indices()
.refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.get(request, al); }, e -> {
log.error("Failed to refresh conversations index during get conversation ", e);
internalListener.onFailure(e);
}));
client.admin().indices().refresh(Requests.refreshRequest(META_INDEX_NAME), ActionListener.wrap(refreshResponse -> {
client.get(request, al);
}, e -> {
log.error("Failed to refresh conversations index during get conversation ", e);
internalListener.onFailure(e);
}));
} catch (Exception e) {
listener.onFailure(e);
}
Expand Down
Loading

0 comments on commit 09221dd

Please sign in to comment.