Skip to content

Commit

Permalink
Add search and singular APIs to conversation memory (#1504)
Browse files Browse the repository at this point in the history
* add searchConversation

Signed-off-by: HenryL27 <[email protected]>

* add searchinteractions

Signed-off-by: HenryL27 <[email protected]>

* add searchConversationsITTests

Signed-off-by: HenryL27 <[email protected]>

* add searchInteractionsITTests

Signed-off-by: HenryL27 <[email protected]>

* add unit tests for storage-layer search

Signed-off-by: HenryL27 <[email protected]>

* add Search transport actions and tests

Signed-off-by: HenryL27 <[email protected]>

* add rest search actions

Signed-off-by: HenryL27 <[email protected]>

* add search rest actions

Signed-off-by: HenryL27 <[email protected]>

* Add singular get actions at storage layer

Signed-off-by: HenryL27 <[email protected]>

* Add OpenSearhMemoryHandler unit tests for singular get

Signed-off-by: HenryL27 <[email protected]>

* Add singular get transport layer

Signed-off-by: HenryL27 <[email protected]>

* add singular get rest actions

Signed-off-by: HenryL27 <[email protected]>

* fix async return value problem

Signed-off-by: HenryL27 <[email protected]>

* address esay PR comments

Signed-off-by: HenryL27 <[email protected]>

---------

Signed-off-by: HenryL27 <[email protected]>
(cherry picked from commit adc7da0)
  • Loading branch information
HenryL27 authored and github-actions[bot] committed Nov 30, 2023
1 parent f46d258 commit 9bc0c0c
Show file tree
Hide file tree
Showing 48 changed files with 3,875 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,25 @@ public class ActionConstants {
/** name of success field in all requests */
public final static String SUCCESS_FIELD = "success";

private final static String BASE_REST_PATH = "/_plugins/_ml/memory/conversation";
/** path for create conversation */
public final static String CREATE_CONVERSATION_REST_PATH = "/_plugins/_ml/memory/conversation";
/** path for list conversations */
public final static String GET_CONVERSATIONS_REST_PATH = "/_plugins/_ml/memory/conversation";
/** path for put interaction */
public final static String CREATE_INTERACTION_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}";
public final static String CREATE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/_create";
/** path for get conversations */
public final static String GET_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_list";
/** path for create interaction */
public final static String CREATE_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_create";
/** path for get interactions */
public final static String GET_INTERACTIONS_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}";
public final static String GET_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_list";
/** path for delete conversation */
public final static String DELETE_CONVERSATION_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}";
public final static String DELETE_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_delete";
/** path for search conversations */
public final static String SEARCH_CONVERSATIONS_REST_PATH = BASE_REST_PATH + "/_search";
/** path for search interactions */
public final static String SEARCH_INTERACTIONS_REST_PATH = BASE_REST_PATH + "/{conversation_id}/_search";
/** path for get conversation */
public final static String GET_CONVERSATION_REST_PATH = BASE_REST_PATH + "/{conversation_id}";
/** path for get interaction */
public final static String GET_INTERACTION_REST_PATH = BASE_REST_PATH + "/{conversation_id}/{interaction_id}";

/** default max results returned by get operations */
public final static int DEFAULT_MAX_RESULTS = 10;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import java.util.List;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationMeta;
Expand Down Expand Up @@ -171,4 +173,64 @@ public ActionFuture<String> createInteraction(
*/
public ActionFuture<Boolean> deleteConversation(String conversationId);

/**
* Search over conversations index
* @param request search request over the conversations index
* @param listener receives the search response
*/
public void searchConversations(SearchRequest request, ActionListener<SearchResponse> listener);

/**
* Search over conversations index
* @param request search request over the conversations index
* @return ActionFuture for the search response
*/
public ActionFuture<SearchResponse> searchConversations(SearchRequest request);

/**
* Search over interactions of a conversation
* @param conversationId id of the conversation to search through
* @param request search request over the interactions
* @param listener receives the search response
*/
public void searchInteractions(String conversationId, SearchRequest request, ActionListener<SearchResponse> listener);

/**
* Search over interactions of a conversation
* @param conversationId id of the conversation to search through
* @param request search request over the interactions
* @return ActionFuture for the search response
*/
public ActionFuture<SearchResponse> searchInteractions(String conversationId, SearchRequest request);

/**
* Get a single ConversationMeta object
* @param conversationId id of the conversation to get
* @param listener receives the conversationMeta object
*/
public void getConversation(String conversationId, ActionListener<ConversationMeta> listener);

/**
* Get a single ConversationMeta object
* @param conversationId id of the conversation to get
* @return ActionFuture for the conversationMeta object
*/
public ActionFuture<ConversationMeta> getConversation(String conversationId);

/**
* Get a single interaction
* @param conversationId id of the conversation this interaction belongs to
* @param interactionId id of this interaction
* @param listener receives the interaction
*/
public void getInteraction(String conversationId, String interactionId, ActionListener<Interaction> listener);

/**
* Get a single interaction
* @param conversationId id of the conversation this interaction belongs to
* @param interactionId id of this interaction
* @return ActionFuture for the interaction
*/
public ActionFuture<Interaction> getInteraction(String conversationId, String interactionId);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.memory.action.conversation;

import org.opensearch.action.ActionType;

/**
* Action for retrieving a top-level conversation object by id
*/
public class GetConversationAction extends ActionType<GetConversationResponse> {
/** Instance of this */
public static final GetConversationAction INSTANCE = new GetConversationAction();
/** Name of this action */
public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/get";

private GetConversationAction() {
super(NAME, GetConversationResponse::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.memory.action.conversation;

import static org.opensearch.action.ValidateActions.addValidationError;

import java.io.IOException;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.ml.common.conversation.ActionConstants;
import org.opensearch.rest.RestRequest;

import lombok.AllArgsConstructor;
import lombok.Getter;

/**
* Action Request object for GetConversation (singular)
*/
@AllArgsConstructor
public class GetConversationRequest extends ActionRequest {
@Getter
private String conversationId;

/**
* Stream Constructor
* @param in input stream to read this from
* @throws IOException if something goes wrong reading from stream
*/
public GetConversationRequest(StreamInput in) throws IOException {
super(in);
this.conversationId = in.readString();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(this.conversationId);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;
if (this.conversationId == null) {
exception = addValidationError("GetConversation Request must have a conversation id", exception);
}
return exception;
}

/**
* Creates a GetConversationRequest from a rest request
* @param request Rest Request representing a GetConversationRequest
* @return the new GetConversationRequest
* @throws IOException if something goes wrong in translation
*/
public static GetConversationRequest fromRestRequest(RestRequest request) throws IOException {
String conversationId = request.param(ActionConstants.CONVERSATION_ID_FIELD);
return new GetConversationRequest(conversationId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.memory.action.conversation;

import java.io.IOException;

import org.opensearch.core.action.ActionResponse;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.conversation.ConversationMeta;

import lombok.AllArgsConstructor;
import lombok.Getter;

/**
* ActionResponse object for GetConversation (singular)
*/
@AllArgsConstructor
public class GetConversationResponse extends ActionResponse implements ToXContentObject {

@Getter
private ConversationMeta conversation;

/**
* Stream Constructor
* @param in input stream to read this from
* @throws IOException if soething goes wrong in reading
*/
public GetConversationResponse(StreamInput in) throws IOException {
super(in);
this.conversation = ConversationMeta.fromStream(in);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
this.conversation.writeTo(out);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return this.conversation.toXContent(builder, params);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.memory.action.conversation;

import org.opensearch.OpenSearchException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.conversation.ConversationMeta;
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
import org.opensearch.ml.memory.ConversationalMemoryHandler;
import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import lombok.extern.log4j.Log4j2;

/**
* Transport Action for GetConversation
*/
@Log4j2
public class GetConversationTransportAction extends HandledTransportAction<GetConversationRequest, GetConversationResponse> {
private Client client;
private ConversationalMemoryHandler cmHandler;

private volatile boolean featureIsEnabled;

/**
* Constructor
* @param transportService for inter-node communications
* @param actionFilters for filtering actions
* @param cmHandler Handler for conversational memory operations
* @param client OS Client for dealing with OS
* @param clusterService for some cluster ops
*/
@Inject
public GetConversationTransportAction(
TransportService transportService,
ActionFilters actionFilters,
OpenSearchConversationalMemoryHandler cmHandler,
Client client,
ClusterService clusterService
) {
super(GetConversationAction.NAME, transportService, actionFilters, GetConversationRequest::new);
this.client = client;
this.cmHandler = cmHandler;
this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings());
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it);
}

@Override
public void doExecute(Task task, GetConversationRequest request, ActionListener<GetConversationResponse> actionListener) {
if (!featureIsEnabled) {
actionListener
.onFailure(
new OpenSearchException(
"The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
+ ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
)
);
return;
} else {
String conversationId = request.getConversationId();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) {
ActionListener<GetConversationResponse> internalListener = ActionListener
.runBefore(actionListener, () -> context.restore());
ActionListener<ConversationMeta> al = ActionListener.wrap(conversationMeta -> {
internalListener.onResponse(new GetConversationResponse(conversationMeta));
}, e -> { internalListener.onFailure(e); });
cmHandler.getConversation(conversationId, al);
} catch (Exception e) {
log.error("Failed to get Conversation " + conversationId, e);
actionListener.onFailure(e);
}

}

}
}
Loading

0 comments on commit 9bc0c0c

Please sign in to comment.