diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java new file mode 100644 index 0000000000..5bb8334bc1 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ActionConstants.java @@ -0,0 +1,70 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.common.conversation; + +/** + * Constants for conversational actions + */ +public class ActionConstants { + + /** name of conversation Id field in all responses */ + public final static String CONVERSATION_ID_FIELD = "conversation_id"; + + /** name of list of conversations in all responses */ + public final static String RESPONSE_CONVERSATION_LIST_FIELD = "conversations"; + /** name of list on interactions in all responses */ + public final static String RESPONSE_INTERACTION_LIST_FIELD = "interactions"; + /** name of interaction Id field in all responses */ + public final static String RESPONSE_INTERACTION_ID_FIELD = "interaction_id"; + + /** name of conversation name in all requests */ + public final static String REQUEST_CONVERSATION_NAME_FIELD = "name"; + /** name of maxResults field name in all requests */ + public final static String REQUEST_MAX_RESULTS_FIELD = "max_results"; + /** name of nextToken field name in all messages */ + public final static String NEXT_TOKEN_FIELD = "next_token"; + /** name of input field in all requests */ + public final static String INPUT_FIELD = "input"; + /** name of AI response field in all respopnses */ + public final static String AI_RESPONSE_FIELD = "response"; + /** name of origin field in all requests */ + public final static String RESPONSE_ORIGIN_FIELD = "origin"; + /** name of prompt template field in all requests */ + public final static String PROMPT_TEMPLATE_FIELD = "prompt_template"; + /** name of metadata field in all requests */ + public final static String ADDITIONAL_INFO_FIELD = "additional_info"; + /** name of success field in all requests */ + public final static String SUCCESS_FIELD = "success"; + + /** path for create conversation */ + public final static String CREATE_CONVERSATION_REST_PATH = "/_plugins/_ml/memory/conversation"; + /** path for list conversations */ + public final static String GET_CONVERSATIONS_REST_PATH = "/_plugins/_ml/memory/conversation"; + /** path for put interaction */ + public final static String CREATE_INTERACTION_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}"; + /** path for get interactions */ + public final static String GET_INTERACTIONS_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}"; + /** path for delete conversation */ + public final static String DELETE_CONVERSATION_REST_PATH = "/_plugins/_ml/memory/conversation/{conversation_id}"; + + /** default max results returned by get operations */ + public final static int DEFAULT_MAX_RESULTS = 10; + + /** default username for reporting security errors if no or malformed username */ + public final static String DEFAULT_USERNAME_FOR_ERRORS = "BAD_USER"; +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java new file mode 100644 index 0000000000..8ba518a065 --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationMeta.java @@ -0,0 +1,145 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.common.conversation; + +import java.io.IOException; +import java.time.Instant; +import java.util.Map; +import java.util.Objects; + +import org.opensearch.action.index.IndexRequest; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.SearchHit; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Class for holding conversational metadata + */ +@AllArgsConstructor +public class ConversationMeta implements Writeable, ToXContentObject { + + @Getter + private String id; + @Getter + private Instant createdTime; + @Getter + private String name; + @Getter + private String user; + + /** + * Creates a conversationMeta object from a SearchHit object + * @param hit the search hit to transform into a conversationMeta object + * @return a new conversationMeta object representing the search hit + */ + public static ConversationMeta fromSearchHit(SearchHit hit) { + String id = hit.getId(); + return ConversationMeta.fromMap(id, hit.getSourceAsMap()); + } + + /** + * Creates a conversationMeta object from a Map of fields in the OS index + * @param id the conversation's id + * @param docFields the map of source fields + * @return a new conversationMeta object representing the map + */ + public static ConversationMeta fromMap(String id, Map docFields) { + Instant created = Instant.parse((String) docFields.get(ConversationalIndexConstants.META_CREATED_FIELD)); + String name = (String) docFields.get(ConversationalIndexConstants.META_NAME_FIELD); + String user = (String) docFields.get(ConversationalIndexConstants.USER_FIELD); + return new ConversationMeta(id, created, name, user); + } + + /** + * Creates a conversationMeta from a stream, given the stream was written to by + * conversationMeta.writeTo + * @param in stream to read from + * @return new conversationMeta object + * @throws IOException if you're reading from a stream without a conversationMeta in it + */ + public static ConversationMeta fromStream(StreamInput in) throws IOException { + String id = in.readString(); + Instant created = in.readInstant(); + String name = in.readString(); + String user = in.readOptionalString(); + return new ConversationMeta(id, created, name, user); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + out.writeInstant(createdTime); + out.writeString(name); + out.writeOptionalString(user); + } + + + /** + * Convert this conversationMeta object into an IndexRequest so it can be indexed + * @param index the index to send this conversation to. Should usually be .conversational-meta + * @return the IndexRequest for the client to send + */ + public IndexRequest toIndexRequest(String index) { + IndexRequest request = new IndexRequest(index); + return request.id(this.id).source( + ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime, + ConversationalIndexConstants.META_NAME_FIELD, this.name + ); + } + + @Override + public String toString() { + return "{id=" + id + + ", name=" + name + + ", created=" + createdTime.toString() + + ", user=" + user + + "}"; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Params params) throws IOException { + builder.startObject(); + builder.field(ActionConstants.CONVERSATION_ID_FIELD, this.id); + builder.field(ConversationalIndexConstants.META_CREATED_FIELD, this.createdTime); + builder.field(ConversationalIndexConstants.META_NAME_FIELD, this.name); + if(this.user != null) { + builder.field(ConversationalIndexConstants.USER_FIELD, this.user); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object other) { + if(!(other instanceof ConversationMeta)) { + return false; + } + ConversationMeta otherConversation = (ConversationMeta) other; + return Objects.equals(this.id, otherConversation.id) && + Objects.equals(this.user, otherConversation.user) && + Objects.equals(this.createdTime, otherConversation.createdTime) && + Objects.equals(this.name, otherConversation.name); + } + +} diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java new file mode 100644 index 0000000000..c8e652265b --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/conversation/ConversationalIndexConstants.java @@ -0,0 +1,105 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.common.conversation; + +import org.opensearch.common.settings.Setting; + +/** + * Class containing a bunch of constant defining how the conversational indices are formatted + */ +public class ConversationalIndexConstants { + /** Version of the meta index schema */ + public final static Integer META_INDEX_SCHEMA_VERSION = 1; + /** Name of the conversational metadata index */ + public final static String META_INDEX_NAME = ".plugins-ml-conversation-meta"; + /** Name of the metadata field for initial timestamp */ + public final static String META_CREATED_FIELD = "create_time"; + /** Name of the metadata field for name of the conversation */ + public final static String META_NAME_FIELD = "name"; + /** Name of the owning user field in all indices */ + public final static String USER_FIELD = "user"; + /** Mappings for the conversational metadata index */ + public final static String META_MAPPING = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + META_INDEX_SCHEMA_VERSION + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + META_NAME_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + META_CREATED_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + USER_FIELD + + "\": {\"type\": \"keyword\"}\n" + + " }\n" + + "}"; + + /** Version of the interactions index schema */ + public final static Integer INTERACTIONS_INDEX_SCHEMA_VERSION = 1; + /** Name of the conversational interactions index */ + public final static String INTERACTIONS_INDEX_NAME = ".plugins-ml-conversation-interactions"; + /** Name of the interaction field for the conversation Id */ + public final static String INTERACTIONS_CONVERSATION_ID_FIELD = "conversation_id"; + /** Name of the interaction field for the human input */ + public final static String INTERACTIONS_INPUT_FIELD = "input"; + /** Name of the interaction field for the prompt template */ + public final static String INTERACTIONS_PROMPT_TEMPLATE_FIELD = "prompt_template"; + /** Name of the interaction field for the AI response */ + public final static String INTERACTIONS_RESPONSE_FIELD = "response"; + /** Name of the interaction field for the response's origin */ + public final static String INTERACTIONS_ORIGIN_FIELD = "origin"; + /** Name of the interaction field for additional metadata */ + public final static String INTERACTIONS_ADDITIONAL_INFO_FIELD = "additional_info"; + /** Name of the interaction field for the timestamp */ + public final static String INTERACTIONS_CREATE_TIME_FIELD = "create_time"; + /** Mappings for the interactions index */ + public final static String INTERACTIONS_MAPPINGS = "{\n" + + " \"_meta\": {\n" + + " \"schema_version\": " + INTERACTIONS_INDEX_SCHEMA_VERSION + "\n" + + " },\n" + + " \"properties\": {\n" + + " \"" + + INTERACTIONS_CONVERSATION_ID_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_CREATE_TIME_FIELD + + "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"},\n" + + " \"" + + INTERACTIONS_INPUT_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_PROMPT_TEMPLATE_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_RESPONSE_FIELD + + "\": {\"type\": \"text\"},\n" + + " \"" + + INTERACTIONS_ORIGIN_FIELD + + "\": {\"type\": \"keyword\"},\n" + + " \"" + + INTERACTIONS_ADDITIONAL_INFO_FIELD + + "\": {\"type\": \"text\"}\n" + + " }\n" + + "}"; + + /** Feature Flag setting for conversational memory */ + public static final Setting ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting + .boolSetting("plugins.ml_commons.memory_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); +} \ No newline at end of file diff --git a/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java new file mode 100644 index 0000000000..9b6ec636bd --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/conversation/Interaction.java @@ -0,0 +1,165 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.common.conversation; + +import java.io.IOException; +import java.time.Instant; +import java.util.Map; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.SearchHit; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; + +/** + * Class for dealing with Interactions + */ +@Builder +@AllArgsConstructor +public class Interaction implements Writeable, ToXContentObject { + + @Getter + private String id; + @Getter + private Instant createTime; + @Getter + private String conversationId; + @Getter + private String input; + @Getter + private String promptTemplate; + @Getter + private String response; + @Getter + private String origin; + @Getter + private String additionalInfo; + + /** + * Creates an Interaction object from a map of fields in the OS index + * @param id the Interaction id + * @param fields the field mapping from the OS document + * @return a new Interaction object representing the OS document + */ + public static Interaction fromMap(String id, Map fields) { + Instant createTime = Instant.parse((String) fields.get(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD)); + String conversationId = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD); + String input = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD); + String promptTemplate = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD); + String response = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD); + String origin = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD); + String additionalInfo = (String) fields.get(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD); + return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo); + } + + /** + * Creates an Interaction object from a search hit + * @param hit the search hit from the interactions index + * @return a new Interaction object representing the search hit + */ + public static Interaction fromSearchHit(SearchHit hit) { + String id = hit.getId(); + return fromMap(id, hit.getSourceAsMap()); + } + + /** + * Creates a new Interaction object from a stream + * @param in stream to read from; assumes Interactions.writeTo was called on it + * @return a new Interaction + * @throws IOException if can't read or the stream isn't pointing to an intraction or something + */ + public static Interaction fromStream(StreamInput in) throws IOException { + String id = in.readString(); + Instant createTime = in.readInstant(); + String conversationId = in.readString(); + String input = in.readString(); + String promptTemplate = in.readString(); + String response = in.readString(); + String origin = in.readString(); + String additionalInfo = in.readOptionalString(); + return new Interaction(id, createTime, conversationId, input, promptTemplate, response, origin, additionalInfo); + } + + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + out.writeInstant(createTime); + out.writeString(conversationId); + out.writeString(input); + out.writeString(promptTemplate); + out.writeString(response); + out.writeString(origin); + out.writeOptionalString(additionalInfo); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Params params) throws IOException { + builder.startObject(); + builder.field(ActionConstants.CONVERSATION_ID_FIELD, conversationId); + builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, id); + builder.field(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, createTime); + builder.field(ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, input); + builder.field(ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, promptTemplate); + builder.field(ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, response); + builder.field(ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, origin); + if(additionalInfo != null) { + builder.field(ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, additionalInfo); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object other) { + return ( + other instanceof Interaction && + ((Interaction) other).id.equals(this.id) && + ((Interaction) other).conversationId.equals(this.conversationId) && + ((Interaction) other).createTime.equals(this.createTime) && + ((Interaction) other).input.equals(this.input) && + ((Interaction) other).promptTemplate.equals(this.promptTemplate) && + ((Interaction) other).response.equals(this.response) && + ((Interaction) other).origin.equals(this.origin) && + ( (((Interaction) other).additionalInfo == null && this.additionalInfo == null) || + ((Interaction) other).additionalInfo.equals(this.additionalInfo)) + ); + } + + @Override + public String toString() { + return "Interaction{" + + "id=" + id + + ",cid=" + conversationId + + ",create_time=" + createTime + + ",origin=" + origin + + ",input=" + input + + ",promt_template=" + promptTemplate + + ",response=" + response + + ",additional_info=" + additionalInfo + + "}"; + } + + +} \ No newline at end of file diff --git a/memory/build.gradle b/memory/build.gradle new file mode 100644 index 0000000000..1c67e02c7f --- /dev/null +++ b/memory/build.gradle @@ -0,0 +1,88 @@ +/* + * Copyright Aryn, Inc 2023 + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +plugins { + id 'java' + id 'jacoco' + id "io.freefair.lombok" + id 'com.diffplug.spotless' version '6.18.0' +} + +dependencies { + implementation project(":opensearch-ml-common") + implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" + implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.1' + implementation "org.opensearch:common-utils:${common_utils_version}" + implementation group: 'com.google.guava', name: 'guava', version: '32.0.1-jre' + testImplementation (group: 'junit', name: 'junit', version: '4.13.2') { + exclude module : 'hamcrest' + exclude module : 'hamcrest-core' + } + testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.4.0' + testImplementation "org.opensearch.test:framework:${opensearch_version}" + testImplementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" + testImplementation group: 'com.google.code.gson', name: 'gson', version: '2.10.1' +} + +test { + include "**/*Tests.class" + jvmArgs '-Dtests.security.manager=false' +} + +jacocoTestReport { + reports { + html.required = true + html.outputLocation = layout.buildDirectory.dir('jacocoHtml') + } + + dependsOn test +} + +List jacocoExclusions = [] + +jacocoTestCoverageVerification { + violationRules { + rule { + element = 'CLASS' + excludes = jacocoExclusions + limit { + counter = 'BRANCH' + minimum = 0.7 //TODO: change this value to 0.7 + } + } + rule { + element = 'CLASS' + excludes = jacocoExclusions + limit { + counter = 'LINE' + value = 'COVEREDRATIO' + minimum = 0.8 //TODO: change this value to 0.8 + } + } + } + dependsOn jacocoTestReport +} + +spotless { + java { + removeUnusedImports() + importOrder 'java', 'javax', 'org', 'com' + + eclipse().configFile rootProject.file('.eclipseformat.xml') + } +} \ No newline at end of file diff --git a/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java new file mode 100644 index 0000000000..18d23eff0d --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/ConversationalMemoryHandler.java @@ -0,0 +1,174 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory; + +import java.util.List; + +import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.conversation.Interaction.InteractionBuilder; + +/** + * Interface for handling all Conversational Memory operations + */ +public interface ConversationalMemoryHandler { + + /** + * Create a new conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation(ActionListener listener); + + /** + * Create a new conversation + * @return ActionFuture for the conversationId of the new conversation + */ + public ActionFuture createConversation(); + + /** + * Create a new conversation + * @param name the name of the new conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation(String name, ActionListener listener); + + /** + * Create a new conversation + * @param name the name of the new conversation + * @return ActionFuture for the conversationId of the new conversation + */ + public ActionFuture createConversation(String name); + + /** + * Adds an interaction to the conversation indicated, updating the conversational metadata + * @param conversationId the conversation to add the interaction to + * @param input the human input for the interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the Gen AI response for this interaction + * @param origin the name of the GenAI agent in this interaction + * @param additionalInfo additional information used in constructing the LLM prompt + * @param listener gets the ID of the new interaction + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String additionalInfo, + ActionListener listener + ); + + /** + * Adds an interaction to the conversation indicated, updating the conversational metadata + * @param conversationId the conversation to add the interaction to + * @param input the human input for the interaction + * @param promptTemplate the prompt template used in this interaction + * @param response the Gen AI response for this interaction + * @param origin the name of the GenAI agent in this interaction + * @param additionalInfo additional information used in constructing the LLM prompt + * @return ActionFuture for the interactionId of the new interaction + */ + public ActionFuture createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String additionalInfo + ); + + /** + * Adds an interaction to the index, updating the associated Conversational Metadata + * @param builder Interaction builder that creates the Interaction to be added. id should be null + * @param listener gets the interactionId of the newly created interaction + */ + public void createInteraction(InteractionBuilder builder, ActionListener listener); + + /** + * Adds an interaction to the index, updating the associated Conversational Metadata + * @param builder Interaction builder that creates the Interaction to be added. id should be null + * @return ActionFuture for the interactionId of the newly created interaction + */ + public ActionFuture createInteraction(InteractionBuilder builder); + + /** + * Get the interactions associate with this conversation, sorted by recency + * @param conversationId the conversation whose interactions to get + * @param from where to start listing from + * @param maxResults how many interactions to get + * @param listener gets the list of interactions in this conversation, sorted by recency + */ + public void getInteractions(String conversationId, int from, int maxResults, ActionListener> listener); + + /** + * Get the interactions associate with this conversation, sorted by recency + * @param conversationId the conversation whose interactions to get + * @param from where to start listing from + * @param maxResults how many interactions to get + * @return ActionFuture the list of interactions in this conversation, sorted by recency + */ + public ActionFuture> getInteractions(String conversationId, int from, int maxResults); + + /** + * Get all conversations (not the interactions in them, just the headers) + * @param from where to start listing from + * @param maxResults how many conversations to list + * @param listener gets the list of all conversations, sorted by recency + */ + public void getConversations(int from, int maxResults, ActionListener> listener); + + /** + * Get all conversations (not the interactions in them, just the headers) + * @param from where to start listing from + * @param maxResults how many conversations to list + * @return ActionFuture for the list of all conversations, sorted by recency + */ + public ActionFuture> getConversations(int from, int maxResults); + + /** + * Get all conversations (not the interactions in them, just the headers) + * @param maxResults how many conversations to get + * @param listener receives the list of conversations, sorted by recency + */ + public void getConversations(int maxResults, ActionListener> listener); + + /** + * Get all conversations (not the interactions in them, just the headers) + * @param maxResults how many conversations to list + * @return ActionFuture for the list of all conversations, sorted by recency + */ + public ActionFuture> getConversations(int maxResults); + + /** + * Delete a conversation and all of its interactions + * @param conversationId the id of the conversation to delete + * @param listener receives whether the conversationMeta object and all of its interactions were deleted. i.e. false => the ConvoMeta or a subset of its Interactions were not deleted + */ + public void deleteConversation(String conversationId, ActionListener listener); + + /** + * Delete a conversation and all of its interactions + * @param conversationId the id of the conversation to delete + * @return ActionFuture for whether the conversationMeta object and all of its interactions were deleted. i.e. false => the ConvoMeta or a subset of its Interactions were not deleted + */ + public ActionFuture deleteConversation(String conversationId); + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationAction.java new file mode 100644 index 0000000000..842f3c9f2b --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationAction.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action for creating a new conversation in the index + */ +public class CreateConversationAction extends ActionType { + /** Instance of this */ + public static final CreateConversationAction INSTANCE = new CreateConversationAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/create"; + + private CreateConversationAction() { + super(NAME, CreateConversationResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java new file mode 100644 index 0000000000..e0a03f13eb --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequest.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.Getter; + +/** + * Action Request for creating a conversation + */ +public class CreateConversationRequest extends ActionRequest { + @Getter + private String name = null; + + /** + * Constructor + * @param in input stream to read from + * @throws IOException if something breaks + */ + public CreateConversationRequest(StreamInput in) throws IOException { + super(in); + this.name = in.readOptionalString(); + } + + /** + * Constructor + * @param name name of the conversation + */ + public CreateConversationRequest(String name) { + super(); + this.name = name; + } + + /** + * Constructor + * name will be null + */ + public CreateConversationRequest() {} + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalString(name); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + + return exception; + } + + /** + * Creates a CreateConversationRequest from a RestRequest + * @param restRequest a RestRequest for a CreateConversation + * @return a new CreateConversationRequest + * @throws IOException if something breaks + */ + public static CreateConversationRequest fromRestRequest(RestRequest restRequest) throws IOException { + if (!restRequest.hasContent()) { + return new CreateConversationRequest(); + } + Map body = restRequest.contentParser().mapStrings(); + if (body.containsKey(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)) { + return new CreateConversationRequest(body.get(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD)); + } else { + return new CreateConversationRequest(); + } + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java new file mode 100644 index 0000000000..79f6fb6bf0 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponse.java @@ -0,0 +1,70 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; + +import lombok.AllArgsConstructor; + +/** + * Action Response for CreateConversation + */ +@AllArgsConstructor +public class CreateConversationResponse extends ActionResponse implements ToXContentObject { + + String conversationId; + + /** + * Constructor + * @param in input stream to create this from + * @throws IOException if something breaks + */ + public CreateConversationResponse(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.conversationId); + } + + /** + * @return the unique id of the newly created conversation + */ + public String getId() { + return conversationId; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.field(ActionConstants.CONVERSATION_ID_FIELD, this.conversationId); + builder.endObject(); + return builder; + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java new file mode 100644 index 0000000000..f6856b7c66 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportAction.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +/** + * The CreateConversationAction that actually does all of the work + */ +@Log4j2 +public class CreateConversationTransportAction extends HandledTransportAction { + + private ConversationalMemoryHandler cmHandler; + private Client client; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public CreateConversationTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(CreateConversationAction.NAME, transportService, actionFilters, CreateConversationRequest::new); + this.cmHandler = cmHandler; + this.client = client; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + protected void doExecute(Task task, CreateConversationRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } + String name = request.getName(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener al = ActionListener.wrap(r -> { internalListener.onResponse(new CreateConversationResponse(r)); }, e -> { + log.error("Failed to create new conversation with name " + request.getName(), e); + internalListener.onFailure(e); + }); + + if (name == null) { + cmHandler.createConversation(al); + } else { + cmHandler.createConversation(name, al); + } + } catch (Exception e) { + log.error("Failed to create new conversation with name " + request.getName(), e); + actionListener.onFailure(e); + } + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionAction.java new file mode 100644 index 0000000000..0a8b06d4c2 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionAction.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action for adding and interaction to a conversation + */ +public class CreateInteractionAction extends ActionType { + /** Instance of this */ + public static CreateInteractionAction INSTANCE = new CreateInteractionAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/create"; + + private CreateInteractionAction() { + super(NAME, CreateInteractionResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java new file mode 100644 index 0000000000..52344b3792 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequest.java @@ -0,0 +1,105 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Action Request for create interaction + */ +@AllArgsConstructor +public class CreateInteractionRequest extends ActionRequest { + @Getter + private String conversationId; + @Getter + private String input; + @Getter + private String promptTemplate; + @Getter + private String response; + @Getter + private String origin; + @Getter + private String additionalInfo; + + /** + * Constructor + * @param in stream to read this request from + * @throws IOException if something breaks or there's no p.i.request in the stream + */ + public CreateInteractionRequest(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + this.input = in.readString(); + this.promptTemplate = in.readString(); + this.response = in.readString(); + this.origin = in.readOptionalString(); + this.additionalInfo = in.readOptionalString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(conversationId); + out.writeString(input); + out.writeString(promptTemplate); + out.writeString(response); + out.writeOptionalString(origin); + out.writeOptionalString(additionalInfo); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (this.conversationId == null) { + exception = addValidationError("Interaction MUST belong to a conversation ID", exception); + } + return exception; + } + + /** + * Create a PutInteractionRequest from a RestRequest + * @param request a RestRequest for a put interaction op + * @return new PutInteractionRequest object + * @throws IOException if something goes wrong reading from request + */ + public static CreateInteractionRequest fromRestRequest(RestRequest request) throws IOException { + Map body = request.contentParser().mapStrings(); + String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD); + String inp = body.get(ActionConstants.INPUT_FIELD); + String prmpt = body.get(ActionConstants.PROMPT_TEMPLATE_FIELD); + String rsp = body.get(ActionConstants.AI_RESPONSE_FIELD); + String ogn = body.get(ActionConstants.RESPONSE_ORIGIN_FIELD); + String addinf = body.get(ActionConstants.ADDITIONAL_INFO_FIELD); + return new CreateInteractionRequest(cid, inp, prmpt, rsp, ogn, addinf); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionResponse.java new file mode 100644 index 0000000000..ca55ae9c90 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionResponse.java @@ -0,0 +1,67 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; + +import lombok.AllArgsConstructor; + +/** + * Action Response for create interaction + */ +@AllArgsConstructor +public class CreateInteractionResponse extends ActionResponse implements ToXContentObject { + private String interactionId; + + /** + * Constructor + * @param in input stream to create this from + * @throws IOException if something breaks + */ + public CreateInteractionResponse(StreamInput in) throws IOException { + super(in); + this.interactionId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(this.interactionId); + } + + /** + * @return the id of the newly created interaction + */ + public String getId() { + return this.interactionId; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Params params) throws IOException { + builder.startObject(); + builder.field(ActionConstants.RESPONSE_INTERACTION_ID_FIELD, this.interactionId); + builder.endObject(); + return builder; + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java new file mode 100644 index 0000000000..2273cc32e8 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportAction.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +/** + * The put interaction action that does the work (of calling cmHandler) + */ +@Log4j2 +public class CreateInteractionTransportAction extends HandledTransportAction { + + private ConversationalMemoryHandler cmHandler; + private Client client; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for doing intra-cluster communication + * @param actionFilters for filtering actions + * @param cmHandler handler for conversational memory + * @param client client for general opensearch ops + * @param clusterService for some cluster ops + */ + @Inject + public CreateInteractionTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(CreateInteractionAction.NAME, transportService, actionFilters, CreateInteractionRequest::new); + this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + protected void doExecute(Task task, CreateInteractionRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } + String cid = request.getConversationId(); + String inp = request.getInput(); + String rsp = request.getResponse(); + String ogn = request.getOrigin(); + String prompt = request.getPromptTemplate(); + String additionalInfo = request.getAdditionalInfo(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener al = ActionListener + .wrap(iid -> { internalListener.onResponse(new CreateInteractionResponse(iid)); }, e -> { + internalListener.onFailure(e); + }); + cmHandler.createInteraction(cid, inp, prompt, rsp, ogn, additionalInfo, al); + } catch (Exception e) { + log.error("Failed to create interaction for conversation " + cid, e); + actionListener.onFailure(e); + } + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationAction.java new file mode 100644 index 0000000000..d51011107b --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationAction.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action for deleting a conversation from conversational memory + */ +public class DeleteConversationAction extends ActionType { + /** Instance of this */ + public static final DeleteConversationAction INSTANCE = new DeleteConversationAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/delete"; + + private DeleteConversationAction() { + super(NAME, DeleteConversationResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationRequest.java new file mode 100644 index 0000000000..d63ecc996f --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationRequest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.AllArgsConstructor; + +/** + * Action Request for Delete Conversation + */ +@AllArgsConstructor +public class DeleteConversationRequest extends ActionRequest { + private String conversationId; + + /** + * Constructor + * @param in input stream, assumes one of these requests was written to it + * @throws IOException if something breaks + */ + public DeleteConversationRequest(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(conversationId); + } + + /** + * Get the conversation id of the conversation to be deleted + * @return the id of the conversation to be deleted + */ + public String getId() { + return conversationId; + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (conversationId == null) { + exception = addValidationError("conversation id must not be null", exception); + } + return exception; + } + + /** + * Create a new DeleteConversationRequest from a RestRequest + * @param request RestRequest representing a DeleteConversationRequest + * @return a new DeleteConversationRequest + * @throws IOException if something breaks + */ + public static DeleteConversationRequest fromRestRequest(RestRequest request) throws IOException { + String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD); + return new DeleteConversationRequest(cid); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationResponse.java new file mode 100644 index 0000000000..f08b5dc05c --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationResponse.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; + +import lombok.AllArgsConstructor; + +/** + * Action Response for Delete Conversation Action + */ +@AllArgsConstructor +public class DeleteConversationResponse extends ActionResponse implements ToXContentObject { + private boolean success; + + /** + * Constructor + * @param in stream input. Assumes there was one of these written to the stream + * @throws IOException if something breaks + */ + public DeleteConversationResponse(StreamInput in) throws IOException { + super(in); + success = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeBoolean(success); + } + + /** + * Gets whether this delete action succeeded + * @return whether this deletion was successful + */ + public boolean wasSuccessful() { + return success; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContentObject.Params params) throws IOException { + builder.startObject(); + builder.field(ActionConstants.SUCCESS_FIELD, success); + builder.endObject(); + return builder; + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportAction.java new file mode 100644 index 0000000000..3b9a23a49b --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportAction.java @@ -0,0 +1,97 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +/** + * Transport action that does all the work for DeleteConversation + */ +@Log4j2 +public class DeleteConversationTransportAction extends HandledTransportAction { + + private Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public DeleteConversationTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(DeleteConversationAction.NAME, transportService, actionFilters, DeleteConversationRequest::new); + this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, DeleteConversationRequest request, ActionListener listener) { + if (!featureIsEnabled) { + listener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } + String conversationId = request.getId(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> context.restore()); + ActionListener al = ActionListener.wrap(success -> { + DeleteConversationResponse response = new DeleteConversationResponse(success); + internalListener.onResponse(response); + }, e -> { internalListener.onFailure(e); }); + cmHandler.deleteConversation(conversationId, al); + } catch (Exception e) { + log.error("Failed to delete conversation " + conversationId, e); + listener.onFailure(e); + } + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsAction.java new file mode 100644 index 0000000000..48e5c6db6c --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsAction.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action for listing all conversations in the index + */ +public class GetConversationsAction extends ActionType { + /** Instance of this */ + public static final GetConversationsAction INSTANCE = new GetConversationsAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/conversation/list"; + + private GetConversationsAction() { + super(NAME, GetConversationsResponse::new); + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsRequest.java new file mode 100644 index 0000000000..3adb068a7a --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsRequest.java @@ -0,0 +1,119 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.Getter; + +/** + * ActionRequest for list conversations action + */ +public class GetConversationsRequest extends ActionRequest { + @Getter + private int maxResults = ActionConstants.DEFAULT_MAX_RESULTS; + @Getter + private int from = 0; + + /** + * Constructor; returns from position 0 + * @param maxResults number of results to return + */ + public GetConversationsRequest(int maxResults) { + super(); + this.maxResults = maxResults; + } + + /** + * Constructor + * @param maxResults number of results to return + * @param from where to start from + */ + public GetConversationsRequest(int maxResults, int from) { + super(); + this.maxResults = maxResults; + this.from = from; + } + + /** + * Constructor; defaults to 10 results returned from position 0 + */ + public GetConversationsRequest() { + super(); + } + + /** + * Constructor + * @param in Input stream to read from. assumes there was a writeTo + * @throws IOException if I can't read + */ + public GetConversationsRequest(StreamInput in) throws IOException { + super(in); + this.maxResults = in.readInt(); + this.from = in.readInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeInt(maxResults); + out.writeInt(from); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (this.maxResults <= 0) { + exception = addValidationError("Can't list 0 or negative conversations", exception); + } + return exception; + } + + /** + * Creates a ListConversationsRequest from a RestRequest + * @param request a RestRequest for a ListConversations + * @return a new ListConversationsRequest + * @throws IOException if something breaks + */ + public static GetConversationsRequest fromRestRequest(RestRequest request) throws IOException { + if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) { + int maxResults = request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD) + ? Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) + : ActionConstants.DEFAULT_MAX_RESULTS; + + int nextToken = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD)); + + return new GetConversationsRequest(maxResults, nextToken); + } else { + int maxResults = request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD) + ? Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) + : ActionConstants.DEFAULT_MAX_RESULTS; + + return new GetConversationsRequest(maxResults); + } + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponse.java new file mode 100644 index 0000000000..b70ca62323 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponse.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.conversation.ConversationMeta; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Action Response for CreateConversation + */ +@AllArgsConstructor +public class GetConversationsResponse extends ActionResponse implements ToXContentObject { + @Getter + private List conversations; + @Getter + private int nextToken; + private boolean hasMoreTokens; + + /** + * Constructor + * @param in input stream to create this from + * @throws IOException if something breaks + */ + public GetConversationsResponse(StreamInput in) throws IOException { + super(in); + this.conversations = in.readList(ConversationMeta::fromStream); + this.nextToken = in.readInt(); + this.hasMoreTokens = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(conversations); + out.writeInt(nextToken); + out.writeBoolean(hasMoreTokens); + } + + /** + * are there more pages of results in this search + * @return whether there are more pages of results in this search + */ + public boolean hasMorePages() { + return hasMoreTokens; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.startArray(ActionConstants.RESPONSE_CONVERSATION_LIST_FIELD); + for (ConversationMeta conversation : conversations) { + conversation.toXContent(builder, params); + } + builder.endArray(); + if (hasMoreTokens) { + builder.field(ActionConstants.NEXT_TOKEN_FIELD, nextToken); + } + builder.endObject(); + return builder; + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportAction.java new file mode 100644 index 0000000000..f515f0f50b --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportAction.java @@ -0,0 +1,104 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.util.List; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +/** + * ListConversationsAction that does the work of asking for them from the ConversationalMemoryHandler + */ +@Log4j2 +public class GetConversationsTransportAction extends HandledTransportAction { + + private Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public GetConversationsTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(GetConversationsAction.NAME, transportService, actionFilters, GetConversationsRequest::new); + this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, GetConversationsRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } + int maxResults = request.getMaxResults(); + int from = request.getFrom(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener> al = ActionListener.wrap(conversations -> { + internalListener + .onResponse(new GetConversationsResponse(conversations, from + maxResults, conversations.size() == maxResults)); + }, e -> { + log.error("Failed to get conversations", e); + internalListener.onFailure(e); + }); + cmHandler.getConversations(from, maxResults, al); + } catch (Exception e) { + log.error("Failed to get conversations", e); + actionListener.onFailure(e); + } + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java new file mode 100644 index 0000000000..024abe17ff --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsAction.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.action.ActionType; + +/** + * Action to return the interactions associated with a conversation + */ +public class GetInteractionsAction extends ActionType { + /** Instance of this */ + public static final GetInteractionsAction INSTANCE = new GetInteractionsAction(); + /** Name of this action */ + public static final String NAME = "cluster:admin/opensearch/ml/memory/interaction/get"; + + private GetInteractionsAction() { + super(NAME, GetInteractionsResponse::new); + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequest.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequest.java new file mode 100644 index 0000000000..4554300f1c --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequest.java @@ -0,0 +1,135 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.opensearch.action.ValidateActions.addValidationError; + +import java.io.IOException; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; + +import lombok.Getter; + +/** + * ActionRequest for get interactions + */ +public class GetInteractionsRequest extends ActionRequest { + @Getter + private int maxResults = ActionConstants.DEFAULT_MAX_RESULTS; + @Getter + private int from = 0; + @Getter + private String conversationId; + + /** + * Constructor + * @param conversationId UID of the conversation to get interactions from + * @param maxResults number of interactions to retrieve + */ + public GetInteractionsRequest(String conversationId, int maxResults) { + this.conversationId = conversationId; + this.maxResults = maxResults; + } + + /** + * Constructor + * @param conversationId UID of the conversation to get interactions from + * @param maxResults number of interactions to retrieve + * @param from position of first interaction to retrieve + */ + public GetInteractionsRequest(String conversationId, int maxResults, int from) { + this.conversationId = conversationId; + this.maxResults = maxResults; + this.from = from; + } + + /** + * Constructor + * @param conversationId the UID of the conversation to get interactions from + */ + public GetInteractionsRequest(String conversationId) { + this.conversationId = conversationId; + } + + /** + * Constructor + * @param in streaminput to read this from. assumes there was a GetInteractionsRequest.writeTo + * @throws IOException if there wasn't a GIR in the stream + */ + public GetInteractionsRequest(StreamInput in) throws IOException { + super(in); + this.conversationId = in.readString(); + this.maxResults = in.readInt(); + this.from = in.readInt(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(conversationId); + out.writeInt(maxResults); + out.writeInt(from); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException exception = null; + if (conversationId == null) { + exception = addValidationError("Interactions must be retrieved from a conversation", exception); + } + if (maxResults <= 0) { + exception = addValidationError("The number of interactions to retrieve must be positive", exception); + } + if (from < 0) { + exception = addValidationError("The starting position must be nonnegative", exception); + } + return exception; + } + + /** + * Makes a GetInteractionsRequest out of a RestRequest + * @param request Rest Request representing a get interactions request + * @return a new GetInteractionsRequest + * @throws IOException if something goes wrong + */ + public static GetInteractionsRequest fromRestRequest(RestRequest request) throws IOException { + String cid = request.param(ActionConstants.CONVERSATION_ID_FIELD); + if (request.hasParam(ActionConstants.NEXT_TOKEN_FIELD)) { + int from = Integer.parseInt(request.param(ActionConstants.NEXT_TOKEN_FIELD)); + if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { + int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); + return new GetInteractionsRequest(cid, maxResults, from); + } else { + return new GetInteractionsRequest(cid, ActionConstants.DEFAULT_MAX_RESULTS, from); + } + } else { + if (request.hasParam(ActionConstants.REQUEST_MAX_RESULTS_FIELD)) { + int maxResults = Integer.parseInt(request.param(ActionConstants.REQUEST_MAX_RESULTS_FIELD)); + return new GetInteractionsRequest(cid, maxResults); + } else { + return new GetInteractionsRequest(cid); + } + } + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponse.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponse.java new file mode 100644 index 0000000000..86b02cbac9 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponse.java @@ -0,0 +1,87 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.conversation.Interaction; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +/** + * Action Response for get interactions + */ +@AllArgsConstructor +public class GetInteractionsResponse extends ActionResponse implements ToXContentObject { + @Getter + private List interactions; + @Getter + private int nextToken; + private boolean hasMoreTokens; + + /** + * Constructor + * @param in stream input; assumes GetInteractionsResponse.writeTo was called + * @throws IOException if theres not a G.I.R. in the stream + */ + public GetInteractionsResponse(StreamInput in) throws IOException { + super(in); + interactions = in.readList(Interaction::fromStream); + nextToken = in.readInt(); + hasMoreTokens = in.readBoolean(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeList(interactions); + out.writeInt(nextToken); + out.writeBoolean(hasMoreTokens); + } + + /** + * Are there more pages in this search results + * @return whether there are more pages in this search + */ + public boolean hasMorePages() { + return hasMoreTokens; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.startArray(ActionConstants.RESPONSE_INTERACTION_LIST_FIELD); + for (Interaction inter : interactions) { + inter.toXContent(builder, params); + } + builder.endArray(); + if (hasMoreTokens) { + builder.field(ActionConstants.NEXT_TOKEN_FIELD, nextToken); + } + builder.endObject(); + return builder; + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportAction.java b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportAction.java new file mode 100644 index 0000000000..ee857a4174 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportAction.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.util.List; + +import org.opensearch.OpenSearchException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.tasks.Task; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +/** + * Get Interactions action that does the work of calling stuff + */ +@Log4j2 +public class GetInteractionsTransportAction extends HandledTransportAction { + + private Client client; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean featureIsEnabled; + + /** + * Constructor + * @param transportService for inter-node communications + * @param actionFilters for filtering actions + * @param cmHandler Handler for conversational memory operations + * @param client OS Client for dealing with OS + * @param clusterService for some cluster ops + */ + @Inject + public GetInteractionsTransportAction( + TransportService transportService, + ActionFilters actionFilters, + OpenSearchConversationalMemoryHandler cmHandler, + Client client, + ClusterService clusterService + ) { + super(GetInteractionsAction.NAME, transportService, actionFilters, GetInteractionsRequest::new); + this.client = client; + this.cmHandler = cmHandler; + this.featureIsEnabled = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED, it -> featureIsEnabled = it); + } + + @Override + public void doExecute(Task task, GetInteractionsRequest request, ActionListener actionListener) { + if (!featureIsEnabled) { + actionListener + .onFailure( + new OpenSearchException( + "The experimental Conversation Memory feature is not enabled. To enable, please update the setting " + + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + ) + ); + return; + } + int maxResults = request.getMaxResults(); + int from = request.getFrom(); + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(actionListener, () -> context.restore()); + ActionListener> al = ActionListener.wrap(interactions -> { + internalListener + .onResponse(new GetInteractionsResponse(interactions, from + maxResults, interactions.size() == maxResults)); + }, e -> { internalListener.onFailure(e); }); + cmHandler.getInteractions(request.getConversationId(), from, maxResults, al); + } catch (Exception e) { + log.error("Failed to get interactions for conversation " + request.getConversationId(), e); + actionListener.onFailure(e); + } + + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java new file mode 100644 index 0000000000..d7a4169fe7 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java @@ -0,0 +1,304 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.index; + +import java.io.IOException; +import java.time.Instant; +import java.util.LinkedList; +import java.util.List; + +import org.opensearch.OpenSearchSecurityException; +import org.opensearch.OpenSearchWrapperException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.DocWriteResponse.Result; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.client.Requests; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.search.SearchHit; +import org.opensearch.search.sort.SortOrder; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +/** + * Class for handling the conversational metadata index + */ +@Log4j2 +@AllArgsConstructor +public class ConversationMetaIndex { + + private Client client; + private ClusterService clusterService; + private static final String indexName = ConversationalIndexConstants.META_INDEX_NAME; + + /** + * Creates the conversational meta index if it doesn't already exist + * @param listener listener to wait for this to finish + */ + public void initConversationMetaIndexIfAbsent(ActionListener listener) { + if (!clusterService.state().metadata().hasIndex(indexName)) { + log.debug("No conversational meta index found. Adding it"); + CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.META_MAPPING); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + ActionListener al = ActionListener.wrap(createIndexResponse -> { + if (createIndexResponse.equals(new CreateIndexResponse(true, true, indexName))) { + log.info("created index [" + indexName + "]"); + internalListener.onResponse(true); + } else { + internalListener.onResponse(false); + } + }, e -> { + if (e instanceof ResourceAlreadyExistsException + || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { + internalListener.onResponse(true); + } else { + log.error("failed to create index [" + indexName + "]", e); + internalListener.onFailure(e); + } + }); + client.admin().indices().create(request, al); + } catch (Exception e) { + if (e instanceof ResourceAlreadyExistsException + || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { + listener.onResponse(true); + } else { + log.error("failed to create index [" + indexName + "]", e); + listener.onFailure(e); + } + } + } else { + listener.onResponse(true); + } + } + + /** + * Adds a new conversation with the specified name to the index + * @param name user-specified name of the conversation to be added + * @param listener listener to wait for this to finish + */ + public void createConversation(String name, ActionListener listener) { + initConversationMetaIndexIfAbsent(ActionListener.wrap(indexExists -> { + if (indexExists) { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + IndexRequest request = Requests + .indexRequest(indexName) + .source( + ConversationalIndexConstants.META_CREATED_FIELD, + Instant.now(), + ConversationalIndexConstants.META_NAME_FIELD, + name, + ConversationalIndexConstants.USER_FIELD, + userstr == null ? null : User.parse(userstr).getName() + ); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + ActionListener al = ActionListener.wrap(resp -> { + if (resp.status() == RestStatus.CREATED) { + internalListener.onResponse(resp.getId()); + } else { + internalListener.onFailure(new IOException("failed to create conversation")); + } + }, e -> { + log.error("Failed to create conversation", e); + internalListener.onFailure(e); + }); + client.index(request, al); + } catch (Exception e) { + log.error("Failed to create conversation", e); + listener.onFailure(e); + } + } else { + listener.onFailure(new IOException("Failed to add conversation due to missing index")); + } + }, e -> { listener.onFailure(e); })); + } + + /** + * Adds a new conversation named "" + * @param listener listener to wait for this to finish + */ + public void createConversation(ActionListener listener) { + createConversation("", listener); + } + + /** + * list size conversations in the index + * @param from where to start listing from + * @param maxResults how many conversations to list + * @param listener gets the list of conversation metadata objects in the index + */ + public void getConversations(int from, int maxResults, ActionListener> listener) { + if (!clusterService.state().metadata().hasIndex(indexName)) { + listener.onResponse(List.of()); + } + SearchRequest request = Requests.searchRequest(indexName); + String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + QueryBuilder queryBuilder; + if (userstr == null) + queryBuilder = new MatchAllQueryBuilder(); + else + queryBuilder = new TermQueryBuilder(ConversationalIndexConstants.USER_FIELD, User.parse(userstr).getName()); + request.source().query(queryBuilder); + request.source().from(from).size(maxResults); + request.source().sort(ConversationalIndexConstants.META_CREATED_FIELD, SortOrder.DESC); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + ActionListener al = ActionListener.wrap(searchResponse -> { + List result = new LinkedList(); + for (SearchHit hit : searchResponse.getHits()) { + result.add(ConversationMeta.fromSearchHit(hit)); + } + internalListener.onResponse(result); + }, e -> { + log.error("Failed to retrieve conversations", e); + internalListener.onFailure(e); + }); + client + .admin() + .indices() + .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(refreshResponse -> { client.search(request, al); }, e -> { + log.error("Failed to retrieve conversations during refresh", e); + internalListener.onFailure(e); + })); + } catch (Exception e) { + log.error("Failed to retrieve conversations", e); + listener.onFailure(e); + } + } + + /** + * list size conversations in the index + * @param maxResults how many conversations to list + * @param listener gets the list of conversation metadata objects in the index + */ + public void getConversations(int maxResults, ActionListener> listener) { + getConversations(0, maxResults, listener); + } + + /** + * Deletes a conversation from the conversation metadata index + * @param conversationId id of the conversation to delete + * @param listener gets whether the deletion was successful + */ + public void deleteConversation(String conversationId, ActionListener listener) { + if (!clusterService.state().metadata().hasIndex(indexName)) { + listener.onResponse(true); + } + DeleteRequest delRequest = Requests.deleteRequest(indexName).id(conversationId); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + // When we get the delete response, do this: + ActionListener al = ActionListener.wrap(deleteResponse -> { + if (deleteResponse.getResult() == Result.DELETED) { + internalListener.onResponse(true); + } else if (deleteResponse.status() == RestStatus.NOT_FOUND) { + internalListener.onResponse(true); + } else { + internalListener.onResponse(false); + } + }, e -> { + log.error("Failure deleting conversation " + conversationId, e); + internalListener.onFailure(e); + }); + this.checkAccess(conversationId, ActionListener.wrap(access -> { + if (access) { + client.delete(delRequest, al); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { internalListener.onFailure(e); })); + } catch (Exception e) { + log.error("Failed deleting conversation with id=" + conversationId, e); + listener.onFailure(e); + } + } + + /** + * Checks whether the current requesting user has permission to see this conversation + * @param conversationId the conversation to check + * @param listener receives whether access should be granted + */ + public void checkAccess(String conversationId, ActionListener listener) { + // If the index doesn't exist, you have permission. Just won't get you anywhere + if (!clusterService.state().metadata().hasIndex(indexName)) { + listener.onResponse(true); + return; + } + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + log.info("USERSTR: " + userstr); + // If security is off - User doesn't exist - you have permission + if (userstr == null || User.parse(userstr) == null) { + internalListener.onResponse(true); + return; + } + GetRequest getRequest = Requests.getRequest(indexName).id(conversationId); + ActionListener al = ActionListener.wrap(getResponse -> { + // If the conversation doesn't exist, fail + if (!(getResponse.isExists() && getResponse.getId().equals(conversationId))) { + throw new ResourceNotFoundException("Conversation [" + conversationId + "] not found"); + } + ConversationMeta conversation = ConversationMeta.fromMap(conversationId, getResponse.getSourceAsMap()); + String user = User.parse(userstr).getName(); + // If you're not the owner of this conversation, you do not have permission + if (!user.equals(conversation.getUser())) { + internalListener.onResponse(false); + return; + } + internalListener.onResponse(true); + }, e -> { internalListener.onFailure(e); }); + client.get(getRequest, al); + } catch (Exception e) { + listener.onFailure(e); + } + } +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java new file mode 100644 index 0000000000..a6714c63c3 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java @@ -0,0 +1,348 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.index; + +import java.io.IOException; +import java.time.Instant; +import java.util.LinkedList; +import java.util.List; + +import org.opensearch.OpenSearchSecurityException; +import org.opensearch.OpenSearchWrapperException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.delete.DeleteRequest; +import org.opensearch.action.index.IndexRequest; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.client.Requests; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.search.SearchHit; +import org.opensearch.search.sort.SortOrder; + +import com.google.common.annotations.VisibleForTesting; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +/** + * Class for handling the interactions index + */ +@Log4j2 +@AllArgsConstructor +public class InteractionsIndex { + + private Client client; + private ClusterService clusterService; + private ConversationMetaIndex conversationMetaIndex; + private final String indexName = ConversationalIndexConstants.INTERACTIONS_INDEX_NAME; + // How big the steps should be when gathering *ALL* interactions in a conversation + private final int resultsAtATime = 300; + + /** + * 'PUT's the index in opensearch if it's not there already + * @param listener gets whether the index needed to be initialized. Throws error if it fails to init + */ + public void initInteractionsIndexIfAbsent(ActionListener listener) { + if (!clusterService.state().metadata().hasIndex(indexName)) { + log.debug("No interactions index found. Adding it"); + CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.INTERACTIONS_MAPPINGS); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + ActionListener al = ActionListener.wrap(r -> { + if (r.equals(new CreateIndexResponse(true, true, indexName))) { + log.info("created index [" + indexName + "]"); + internalListener.onResponse(true); + } else { + internalListener.onResponse(false); + } + }, e -> { + if (e instanceof ResourceAlreadyExistsException + || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { + internalListener.onResponse(true); + } else { + log.error("Failed to create index [" + indexName + "]", e); + internalListener.onFailure(e); + } + }); + client.admin().indices().create(request, al); + } catch (Exception e) { + if (e instanceof ResourceAlreadyExistsException + || (e instanceof OpenSearchWrapperException && e.getCause() instanceof ResourceAlreadyExistsException)) { + listener.onResponse(true); + } else { + log.error("Failed to create index [" + indexName + "]", e); + listener.onFailure(e); + } + } + } else { + listener.onResponse(true); + } + } + + /** + * Add an interaction to this index. Return the ID of the newly created interaction + * @param conversationId The id of the conversation this interaction belongs to + * @param input the user (human) input into this interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the GenAI response for this interaction + * @param origin the origin of the response for this interaction + * @param additionalInfo additional information used for constructing the LLM prompt + * @param timestamp when this interaction happened + * @param listener gets the id of the newly created interaction record + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String additionalInfo, + Instant timestamp, + ActionListener listener + ) { + initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> { + if (indexExists) { + this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> { + if (access) { + IndexRequest request = Requests + .indexRequest(indexName) + .source( + ConversationalIndexConstants.INTERACTIONS_ORIGIN_FIELD, + origin, + ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, + conversationId, + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + input, + ConversationalIndexConstants.INTERACTIONS_PROMPT_TEMPLATE_FIELD, + promptTemplate, + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + response, + ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD, + additionalInfo, + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + timestamp + ); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + ActionListener al = ActionListener.wrap(resp -> { + if (resp.status() == RestStatus.CREATED) { + internalListener.onResponse(resp.getId()); + } else { + internalListener.onFailure(new IOException("Failed to create interaction")); + } + }, e -> { internalListener.onFailure(e); }); + client.index(request, al); + } catch (Exception e) { + listener.onFailure(e); + } + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null + ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS + : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { listener.onFailure(e); })); + } else { + listener.onFailure(new IOException("no index to add conversation to")); + } + }, e -> { listener.onFailure(e); })); + } + + /** + * Add an interaction to this index, timestamped now. Return the id of the newly created interaction + * @param conversationId The id of the converation this interaction belongs to + * @param input the user (human) input into this interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the GenAI response for this interaction + * @param origin the name of the GenAI agent this interaction belongs to + * @param additionalInfo additional information used to construct the LLM prompt + * @param listener gets the id of the newly created interaction record + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String additionalInfo, + ActionListener listener + ) { + createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, Instant.now(), listener); + } + + /** + * Gets a list of interactions belonging to a conversation + * @param conversationId the conversation to read from + * @param from where to start in the reading + * @param maxResults how many interactions to return + * @param listener gets the list, sorted by recency, of interactions + */ + public void getInteractions(String conversationId, int from, int maxResults, ActionListener> listener) { + if (!clusterService.state().metadata().hasIndex(indexName)) { + listener.onResponse(List.of()); + return; + } + ActionListener accessListener = ActionListener.wrap(access -> { + if (access) { + innerGetInteractions(conversationId, from, maxResults, listener); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { listener.onFailure(e); }); + conversationMetaIndex.checkAccess(conversationId, accessListener); + } + + @VisibleForTesting + void innerGetInteractions(String conversationId, int from, int maxResults, ActionListener> listener) { + SearchRequest request = Requests.searchRequest(indexName); + TermQueryBuilder builder = new TermQueryBuilder(ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, conversationId); + request.source().query(builder); + request.source().from(from).size(maxResults); + request.source().sort(ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, SortOrder.DESC); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + ActionListener al = ActionListener.wrap(response -> { + List result = new LinkedList(); + for (SearchHit hit : response.getHits()) { + result.add(Interaction.fromSearchHit(hit)); + } + internalListener.onResponse(result); + }, e -> { internalListener.onFailure(e); }); + client + .admin() + .indices() + .refresh(Requests.refreshRequest(indexName), ActionListener.wrap(r -> { client.search(request, al); }, e -> { + internalListener.onFailure(e); + })); + } catch (Exception e) { + listener.onFailure(e); + } + } + + /** + * Gets all of the interactions in a conversation, regardless of conversation size + * @param conversationId conversation to get all interactions of + * @param maxResults how many interactions to get per search query + * @param listener receives the list of all interactions in the conversation + */ + @VisibleForTesting + void getAllInteractions(String conversationId, int maxResults, ActionListener> listener) { + ActionListener> al = nextGetListener(conversationId, 0, maxResults, listener, new LinkedList<>()); + innerGetInteractions(conversationId, 0, maxResults, al); + } + + /** + * Recursively builds the list of interactions for getAllInteractions by returning an + * ActionListener for handling the next search query + * @param conversationId conversation to get interactions from + * @param from where to start in this step + * @param maxResults how many to get in this step + * @param mainListener listener for the final result + * @param result partially built list of interactions + * @return an ActionListener to handle the next search query + */ + @VisibleForTesting + ActionListener> nextGetListener( + String conversationId, + int from, + int maxResults, + ActionListener> mainListener, + List result + ) { + if (maxResults < 1) { + mainListener.onFailure(new IllegalArgumentException("maxResults must be positive")); + return null; + } + return ActionListener.wrap(interactions -> { + result.addAll(interactions); + if (interactions.size() < maxResults) { + mainListener.onResponse(result); + } else { + ActionListener> al = nextGetListener(conversationId, from + maxResults, maxResults, mainListener, result); + innerGetInteractions(conversationId, from + maxResults, maxResults, al); + } + }, e -> { mainListener.onFailure(e); }); + } + + /** + * Deletes all interactions associated with a conversationId + * Note this uses a bulk delete request (and tries to delete an entire conversation) so it may be heavyweight + * @param conversationId the id of the conversation to delete from + * @param listener gets whether the deletion was successful + */ + public void deleteConversation(String conversationId, ActionListener listener) { + if (!clusterService.state().metadata().hasIndex(indexName)) { + listener.onResponse(true); + return; + } + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) { + ActionListener internalListener = ActionListener.runBefore(listener, () -> threadContext.restore()); + ActionListener> searchListener = ActionListener.wrap(interactions -> { + BulkRequest request = Requests.bulkRequest(); + for (Interaction interaction : interactions) { + DeleteRequest delRequest = Requests.deleteRequest(indexName).id(interaction.getId()); + request.add(delRequest); + } + client + .bulk(request, ActionListener.wrap(bulkResponse -> { internalListener.onResponse(!bulkResponse.hasFailures()); }, e -> { + internalListener.onFailure(e); + })); + }, e -> { internalListener.onFailure(e); }); + ActionListener accessListener = ActionListener.wrap(access -> { + if (access) { + getAllInteractions(conversationId, resultsAtATime, searchListener); + } else { + String userstr = client + .threadPool() + .getThreadContext() + .getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); + String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName(); + throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId); + } + }, e -> { listener.onFailure(e); }); + conversationMetaIndex.checkAccess(conversationId, accessListener); + } catch (Exception e) { + log.error("Failure while deleting interactions associated with conversation id=" + conversationId, e); + listener.onFailure(e); + } + } + +} diff --git a/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java new file mode 100644 index 0000000000..d2c70ff6e7 --- /dev/null +++ b/memory/src/main/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandler.java @@ -0,0 +1,280 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.index; + +import java.time.Instant; +import java.util.List; + +import org.opensearch.action.StepListener; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.conversation.Interaction.InteractionBuilder; +import org.opensearch.ml.memory.ConversationalMemoryHandler; + +import com.google.common.annotations.VisibleForTesting; + +/** + * Class for handling all Conversational Memory operactions + */ +public class OpenSearchConversationalMemoryHandler implements ConversationalMemoryHandler { + + private ConversationMetaIndex conversationMetaIndex; + private InteractionsIndex interactionsIndex; + + /** + * Constructor + * @param client opensearch client to use for talking to OS + * @param clusterService ClusterService object for managing OS + */ + public OpenSearchConversationalMemoryHandler(Client client, ClusterService clusterService) { + this.conversationMetaIndex = new ConversationMetaIndex(client, clusterService); + this.interactionsIndex = new InteractionsIndex(client, clusterService, this.conversationMetaIndex); + } + + @VisibleForTesting + OpenSearchConversationalMemoryHandler(ConversationMetaIndex conversationMetaIndex, InteractionsIndex interactionsIndex) { + this.conversationMetaIndex = conversationMetaIndex; + this.interactionsIndex = interactionsIndex; + } + + /** + * Create a new conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation(ActionListener listener) { + conversationMetaIndex.createConversation(listener); + } + + /** + * Create a new conversation + * @return ActionFuture for the conversationId of the new conversation + */ + public ActionFuture createConversation() { + PlainActionFuture fut = PlainActionFuture.newFuture(); + createConversation(fut); + return fut; + } + + /** + * Create a new conversation + * @param name the name of the new conversation + * @param listener listener to wait for this op to finish, gets unique id of new conversation + */ + public void createConversation(String name, ActionListener listener) { + conversationMetaIndex.createConversation(name, listener); + } + + /** + * Create a new conversation + * @param name the name of the new conversation + * @return ActionFuture for the conversationId of the new conversation + */ + public ActionFuture createConversation(String name) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + createConversation(name, fut); + return fut; + } + + /** + * Adds an interaction to the conversation indicated, updating the conversational metadata + * @param conversationId the conversation to add the interaction to + * @param input the human input for the interaction + * @param promptTemplate the prompt template used for this interaction + * @param response the Gen AI response for this interaction + * @param origin the name of the GenAI agent in this interaction + * @param additionalInfo additional information used in constructing the LLM prompt + * @param listener gets the ID of the new interaction + */ + public void createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String additionalInfo, + ActionListener listener + ) { + Instant time = Instant.now(); + interactionsIndex.createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, time, listener); + } + + /** + * Adds an interaction to the conversation indicated, updating the conversational metadata + * @param conversationId the conversation to add the interaction to + * @param input the human input for the interaction + * @param promptTemplate the prompt template used in this interaction + * @param response the Gen AI response for this interaction + * @param origin the name of the GenAI agent in this interaction + * @param additionalInfo additional information used in constructing the LLM prompt + * @return ActionFuture for the interactionId of the new interaction + */ + public ActionFuture createInteraction( + String conversationId, + String input, + String promptTemplate, + String response, + String origin, + String additionalInfo + ) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + createInteraction(conversationId, input, promptTemplate, response, origin, additionalInfo, fut); + return fut; + } + + /** + * Adds an interaction to the index, updating the associated Conversational Metadata + * @param builder Interaction builder that creates the Interaction to be added. id should be null + * @param listener gets the interactionId of the newly created interaction + */ + public void createInteraction(InteractionBuilder builder, ActionListener listener) { + builder.createTime(Instant.now()); + Interaction interaction = builder.build(); + interactionsIndex + .createInteraction( + interaction.getConversationId(), + interaction.getInput(), + interaction.getPromptTemplate(), + interaction.getResponse(), + interaction.getOrigin(), + interaction.getAdditionalInfo(), + interaction.getCreateTime(), + listener + ); + } + + /** + * Adds an interaction to the index, updating the associated Conversational Metadata + * @param builder Interaction builder that creates the Interaction to be added. id should be null + * @return ActionFuture for the interactionId of the newly created interaction + */ + public ActionFuture createInteraction(InteractionBuilder builder) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + createInteraction(builder, fut); + return fut; + } + + /** + * Get the interactions associate with this conversation, sorted by recency + * @param conversationId the conversation whose interactions to get + * @param from where to start listing from + * @param maxResults how many interactions to get + * @param listener gets the list of interactions in this conversation, sorted by recency + */ + public void getInteractions(String conversationId, int from, int maxResults, ActionListener> listener) { + interactionsIndex.getInteractions(conversationId, from, maxResults, listener); + } + + /** + * Get the interactions associate with this conversation, sorted by recency + * @param conversationId the conversation whose interactions to get + * @param from where to start listing from + * @param maxResults how many interactions to get + * @return ActionFuture the list of interactions in this conversation, sorted by recency + */ + public ActionFuture> getInteractions(String conversationId, int from, int maxResults) { + PlainActionFuture> fut = PlainActionFuture.newFuture(); + getInteractions(conversationId, from, maxResults, fut); + return fut; + } + + /** + * Get all conversations (not the interactions in them, just the headers) + * @param from where to start listing from + * @param maxResults how many conversations to list + * @param listener gets the list of all conversations, sorted by recency + */ + public void getConversations(int from, int maxResults, ActionListener> listener) { + conversationMetaIndex.getConversations(from, maxResults, listener); + } + + /** + * Get all conversations (not the interactions in them, just the headers) + * @param from where to start listing from + * @param maxResults how many conversations to list + * @return ActionFuture for the list of all conversations, sorted by recency + */ + public ActionFuture> getConversations(int from, int maxResults) { + PlainActionFuture> fut = PlainActionFuture.newFuture(); + getConversations(from, maxResults, fut); + return fut; + } + + /** + * Get all conversations (not the interactions in them, just the headers) + * @param maxResults how many conversations to get + * @param listener receives the list of conversations, sorted by recency + */ + public void getConversations(int maxResults, ActionListener> listener) { + conversationMetaIndex.getConversations(maxResults, listener); + } + + /** + * Get all conversations (not the interactions in them, just the headers) + * @param maxResults how many conversations to list + * @return ActionFuture for the list of all conversations, sorted by recency + */ + public ActionFuture> getConversations(int maxResults) { + PlainActionFuture> fut = PlainActionFuture.newFuture(); + getConversations(maxResults, fut); + return fut; + } + + /** + * Delete a conversation and all of its interactions + * @param conversationId the id of the conversation to delete + * @param listener receives whether the conversationMeta object and all of its interactions were deleted. i.e. false => there's something still in an index somewhere + */ + public void deleteConversation(String conversationId, ActionListener listener) { + StepListener accessListener = new StepListener<>(); + conversationMetaIndex.checkAccess(conversationId, accessListener); + + accessListener.whenComplete(access -> { + if (access) { + StepListener metaDeleteListener = new StepListener<>(); + StepListener interactionsListener = new StepListener<>(); + + conversationMetaIndex.deleteConversation(conversationId, metaDeleteListener); + interactionsIndex.deleteConversation(conversationId, interactionsListener); + + metaDeleteListener.whenComplete(metaResult -> { + interactionsListener + .whenComplete(interactionResult -> { listener.onResponse(metaResult && interactionResult); }, listener::onFailure); + }, listener::onFailure); + } else { + listener.onResponse(false); + } + }, listener::onFailure); + } + + /** + * Delete a conversation and all of its interactions + * @param conversationId the id of the conversation to delete + * @return ActionFuture for whether the conversationMeta object and all of its interactions were deleted. i.e. false => there's something still in an index somewhere + */ + public ActionFuture deleteConversation(String conversationId) { + PlainActionFuture fut = PlainActionFuture.newFuture(); + deleteConversation(conversationId, fut); + return fut; + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java b/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java new file mode 100644 index 0000000000..7c842bdcac --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/ConversationalMemoryHandlerITTests.java @@ -0,0 +1,479 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory; + +import java.util.List; +import java.util.Stack; +import java.util.concurrent.CountDownLatch; +import java.util.function.Consumer; + +import org.junit.Before; +import org.opensearch.OpenSearchSecurityException; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.StepListener; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.CheckedConsumer; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.util.concurrent.ThreadContext.StoredContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchIntegTestCase; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) +public class ConversationalMemoryHandlerITTests extends OpenSearchIntegTestCase { + + private Client client; + private ClusterService clusterService; + private ConversationalMemoryHandler cmHandler; + + @Before + private void setup() { + log.warn("started a test"); + client = client(); + clusterService = clusterService(); + cmHandler = new OpenSearchConversationalMemoryHandler(client, clusterService); + } + + private StoredContext setUser(String username) { + StoredContext stored = client + .threadPool() + .getThreadContext() + .newStoredContext(true, List.of(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT)); + ThreadContext context = client.threadPool().getThreadContext(); + context.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, username + "||"); + return stored; + } + + public void testCanStartConversations() { + log.warn("test can start conversations"); + CountDownLatch cdl = new CountDownLatch(3); + cmHandler.createConversation("test-1", new LatchedActionListener(ActionListener.wrap(cid0 -> { + cmHandler.createConversation("test-2", new LatchedActionListener(ActionListener.wrap(cid1 -> { + cmHandler.createConversation(new LatchedActionListener(ActionListener.wrap(cid2 -> { + assert (!cid0.equals(cid1) && !cid0.equals(cid2) && !cid1.equals(cid2)); + }, e -> { + cdl.countDown(); + cdl.countDown(); + log.error(e); + assert (false); + }), cdl)); + }, e -> { + cdl.countDown(); + cdl.countDown(); + log.error(e); + assert (false); + }), cdl)); + }, e -> { + cdl.countDown(); + cdl.countDown(); + log.error(e); + assert (false); + }), cdl)); + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + log.warn("test can start conversations finished"); + } + + public void testCanAddNewInteractionsToConversation() { + log.warn("test can add new interactions"); + CountDownLatch cdl = new CountDownLatch(1); + StepListener cidListener = new StepListener<>(); + cmHandler.createConversation("test", cidListener); + + StepListener iid1Listener = new StepListener<>(); + cidListener.whenComplete(cid -> { + cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1Listener); + }, e -> { + cdl.countDown(); + assert (false); + }); + + StepListener iid2Listener = new StepListener<>(); + iid1Listener.whenComplete(iid -> { + cmHandler.createInteraction(cidListener.result(), "test input1", "pt", "test response", "test origin", "meta", iid2Listener); + }, e -> { + cdl.countDown(); + assert (false); + }); + + LatchedActionListener finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(iid2 -> { + assert (!iid2.equals(iid1Listener.result())); + }, e -> { assert (false); }), cdl); + iid2Listener.whenComplete(finishAndAssert::onResponse, finishAndAssert::onFailure); + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + log.warn("test can add new interactions finished"); + } + + public void testCanGetInteractionsBackOut() { + log.warn("test can get interactions"); + CountDownLatch cdl = new CountDownLatch(1); + StepListener cidListener = new StepListener<>(); + cmHandler.createConversation("test", cidListener); + + StepListener iid1Listener = new StepListener<>(); + cidListener.whenComplete(cid -> { + cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1Listener); + }, e -> { + cdl.countDown(); + assert (false); + }); + + StepListener iid2Listener = new StepListener<>(); + iid1Listener.whenComplete(iid -> { + cmHandler.createInteraction(cidListener.result(), "test input1", "pt", "test response", "test origin", "meta", iid2Listener); + }, e -> { + cdl.countDown(); + assert (false); + }); + + StepListener> interactionsListener = new StepListener<>(); + iid2Listener.whenComplete(iid2 -> { cmHandler.getInteractions(cidListener.result(), 0, 2, interactionsListener); }, e -> { + cdl.countDown(); + assert (false); + }); + + LatchedActionListener> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(conversations -> { + List interactions = interactionsListener.result(); + String id1 = iid1Listener.result(); + String id2 = iid2Listener.result(); + String cid = cidListener.result(); + assert (interactions.size() == 2); + assert (interactions.get(0).getId().equals(id2)); + assert (interactions.get(1).getId().equals(id1)); + assert (conversations.size() == 1); + assert (conversations.get(0).getId().equals(cid)); + }, e -> { assert (false); }), cdl); + interactionsListener.whenComplete(r -> { cmHandler.getConversations(10, finishAndAssert); }, e -> { + cdl.countDown(); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + log.warn("test can get interactions finished"); + } + + public void testCanDeleteConversations() { + log.warn("test can delete conversations"); + CountDownLatch cdl = new CountDownLatch(1); + StepListener cid1 = new StepListener<>(); + cmHandler.createConversation("test", cid1); + + StepListener iid1 = new StepListener<>(); + cid1 + .whenComplete( + cid -> { cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid1); }, + e -> { + cdl.countDown(); + assert (false); + } + ); + + StepListener iid2 = new StepListener<>(); + iid1 + .whenComplete( + iid -> { cmHandler.createInteraction(cid1.result(), "test input1", "pt", "test response", "test origin", "meta", iid2); }, + e -> { + cdl.countDown(); + assert (false); + } + ); + + StepListener cid2 = new StepListener<>(); + iid2.whenComplete(iid -> { cmHandler.createConversation(cid2); }, e -> { + cdl.countDown(); + assert (false); + }); + + StepListener iid3 = new StepListener<>(); + cid2 + .whenComplete( + cid -> { cmHandler.createInteraction(cid, "test input1", "pt", "test response", "test origin", "meta", iid3); }, + e -> { + cdl.countDown(); + assert (false); + } + ); + + StepListener del = new StepListener<>(); + iid3.whenComplete(iid -> { cmHandler.deleteConversation(cid1.result(), del); }, e -> { + cdl.countDown(); + assert (false); + }); + + StepListener> conversations = new StepListener<>(); + del.whenComplete(success -> { cmHandler.getConversations(10, conversations); }, e -> { + cdl.countDown(); + assert (false); + }); + + StepListener> inters1 = new StepListener<>(); + conversations.whenComplete(cons -> { cmHandler.getInteractions(cid1.result(), 0, 10, inters1); }, e -> { + cdl.countDown(); + assert (false); + }); + + StepListener> inters2 = new StepListener<>(); + inters1.whenComplete(ints -> { cmHandler.getInteractions(cid2.result(), 0, 10, inters2); }, e -> { + cdl.countDown(); + assert (false); + }); + + LatchedActionListener> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(r -> { + assert (del.result()); + assert (conversations.result().size() == 1); + assert (conversations.result().get(0).getId().equals(cid2.result())); + assert (inters1.result().size() == 0); + assert (inters2.result().size() == 1); + assert (inters2.result().get(0).getId().equals(iid3.result())); + }, e -> { assert (false); }), cdl); + inters2.whenComplete(finishAndAssert::onResponse, finishAndAssert::onFailure); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + log.warn("test can delete conversations finished"); + } + + public void testDifferentUsers_DifferentConversations() { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + CountDownLatch cdl = new CountDownLatch(1); + Stack contextStack = new Stack<>(); + + Consumer onFail = e -> { + while (!contextStack.empty()) { + contextStack.pop().close(); + } + cdl.countDown(); + log.error(e); + threadContext.restore(); + assert (false); + }; + CheckedConsumer shouldHaveFailed = r -> { + while (!contextStack.empty()) { + contextStack.pop().close(); + } + OpenSearchSecurityException e = new OpenSearchSecurityException("User was given inappropriate access controls"); + log.error(e); + throw e; + }; + CheckedConsumer shouldHaveFailedAsString = r -> { shouldHaveFailed.accept(r); }; + CheckedConsumer, Exception> shouldHaveFailedAsInterList = r -> { shouldHaveFailed.accept(r); }; + + final String user1 = "test-user1"; + final String user2 = "test-user2"; + + contextStack.push(setUser(user1)); + + StepListener cid1 = new StepListener<>(); + StepListener cid2 = new StepListener<>(); + StepListener cid3 = new StepListener<>(); + StepListener iid1 = new StepListener<>(); + StepListener iid2 = new StepListener<>(); + StepListener iid3 = new StepListener<>(); + StepListener iid4 = new StepListener<>(); + StepListener iid5 = new StepListener<>(); + StepListener> conversations1 = new StepListener<>(); + StepListener> conversations2 = new StepListener<>(); + StepListener> inter1 = new StepListener<>(); + StepListener> inter2 = new StepListener<>(); + StepListener> inter3 = new StepListener<>(); + StepListener> failInter1 = new StepListener<>(); + StepListener> failInter2 = new StepListener<>(); + StepListener> failInter3 = new StepListener<>(); + StepListener failiid1 = new StepListener<>(); + StepListener failiid2 = new StepListener<>(); + StepListener failiid3 = new StepListener<>(); + + cmHandler.createConversation("conversation1", cid1); + + cid1.whenComplete(cid -> { cmHandler.createConversation("conversation2", cid2); }, onFail); + + cid2 + .whenComplete( + cid -> { + cmHandler.createInteraction(cid1.result(), "test input1", "pt", "test response", "test origin", "meta", iid1); + }, + onFail + ); + + iid1 + .whenComplete( + iid -> { + cmHandler.createInteraction(cid1.result(), "test input2", "pt", "test response", "test origin", "meta", iid2); + }, + onFail + ); + + iid2 + .whenComplete( + iid -> { + cmHandler.createInteraction(cid2.result(), "test input3", "pt", "test response", "test origin", "meta", iid3); + }, + onFail + ); + + iid3.whenComplete(iid -> { + contextStack.push(setUser(user2)); + cmHandler.createConversation("conversation3", cid3); + }, onFail); + + cid3 + .whenComplete( + cid -> { + cmHandler.createInteraction(cid3.result(), "test input4", "pt", "test response", "test origin", "meta", iid4); + }, + onFail + ); + + iid4 + .whenComplete( + iid -> { + cmHandler.createInteraction(cid3.result(), "test input5", "pt", "test response", "test origin", "meta", iid5); + }, + onFail + ); + + iid5.whenComplete(iid -> { + cmHandler.createInteraction(cid1.result(), "test inputf1", "pt", "test response", "test origin", "meta", failiid1); + }, onFail); + + failiid1.whenComplete(shouldHaveFailedAsString, e -> { + if (e instanceof OpenSearchSecurityException + && e.getMessage().startsWith("User [" + user2 + "] does not have access to conversation ")) { + cmHandler.createInteraction(cid1.result(), "test inputf2", "pt", "test response", "test origin", "meta", failiid2); + } else { + onFail.accept(e); + } + }); + + failiid2.whenComplete(shouldHaveFailedAsString, e -> { + if (e instanceof OpenSearchSecurityException + && e.getMessage().startsWith("User [" + user2 + "] does not have access to conversation ")) { + cmHandler.getConversations(10, conversations2); + } else { + onFail.accept(e); + } + }); + + conversations2.whenComplete(conversations -> { + assert (conversations.size() == 1); + assert (conversations.get(0).getId().equals(cid3.result())); + cmHandler.getInteractions(cid3.result(), 0, 10, inter3); + }, onFail); + + inter3.whenComplete(inters -> { + assert (inters.size() == 2); + assert (inters.get(0).getId().equals(iid5.result())); + assert (inters.get(1).getId().equals(iid4.result())); + cmHandler.getInteractions(cid2.result(), 0, 10, failInter2); + }, onFail); + + failInter2.whenComplete(shouldHaveFailedAsInterList, e -> { + if (e instanceof OpenSearchSecurityException + && e.getMessage().startsWith("User [" + user2 + "] does not have access to conversation ")) { + cmHandler.getInteractions(cid1.result(), 0, 10, failInter1); + } else { + onFail.accept(e); + } + }); + + failInter1.whenComplete(shouldHaveFailedAsInterList, e -> { + if (e instanceof OpenSearchSecurityException + && e.getMessage().startsWith("User [" + user2 + "] does not have access to conversation ")) { + contextStack.pop().restore(); + cmHandler.getConversations(0, 10, conversations1); + } else { + onFail.accept(e); + } + }); + + conversations1.whenComplete(conversations -> { + assert (conversations.size() == 2); + assert (conversations.get(0).getId().equals(cid2.result())); + assert (conversations.get(1).getId().equals(cid1.result())); + cmHandler.getInteractions(cid1.result(), 0, 10, inter1); + }, onFail); + + inter1.whenComplete(inters -> { + assert (inters.size() == 2); + assert (inters.get(0).getId().equals(iid2.result())); + assert (inters.get(1).getId().equals(iid1.result())); + cmHandler.getInteractions(cid2.result(), 0, 10, inter2); + }, onFail); + + inter2.whenComplete(inters -> { + assert (inters.size() == 1); + assert (inters.get(0).getId().equals(iid3.result())); + cmHandler.getInteractions(cid3.result(), 0, 10, failInter3); + }, onFail); + + failInter3.whenComplete(shouldHaveFailedAsInterList, e -> { + if (e instanceof OpenSearchSecurityException + && e.getMessage().startsWith("User [" + user1 + "] does not have access to conversation ")) { + cmHandler.createInteraction(cid3.result(), "test inputf3", "pt", "test response", "test origin", "meta", failiid3); + } else { + onFail.accept(e); + } + }); + + failiid3.whenComplete(shouldHaveFailedAsString, e -> { + if (e instanceof OpenSearchSecurityException + && e.getMessage().startsWith("User [" + user1 + "] does not have access to conversation ")) { + contextStack.pop().restore(); + cdl.countDown(); + } else { + onFail.accept(e); + } + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + threadContext.restore(); + } catch (Exception e) { + log.error(e); + throw e; + } + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/ConversationalTests.java b/memory/src/test/java/org/opensearch/ml/memory/ConversationalTests.java new file mode 100644 index 0000000000..3929aac595 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/ConversationalTests.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory; + +import org.opensearch.test.OpenSearchTestCase; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class ConversationalTests extends OpenSearchTestCase { + + // Add unit tests for your plugin + public void testNothing() { + log.info("testing nothing"); + assert (true); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java new file mode 100644 index 0000000000..2975cd4c1d --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/ConversationActionTests.java @@ -0,0 +1,28 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import org.opensearch.test.OpenSearchTestCase; + +public class ConversationActionTests extends OpenSearchTestCase { + public void testActions() { + assert (CreateConversationAction.INSTANCE instanceof CreateConversationAction); + assert (DeleteConversationAction.INSTANCE instanceof DeleteConversationAction); + assert (GetConversationsAction.INSTANCE instanceof GetConversationsAction); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java new file mode 100644 index 0000000000..22b55bb7c2 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationRequestTests.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +import com.google.gson.Gson; + +public class CreateConversationRequestTests extends OpenSearchTestCase { + + Gson gson; + + @Before + public void setup() { + gson = new Gson(); + } + + public void testConstructorsAndStreaming_Named() throws IOException { + CreateConversationRequest request = new CreateConversationRequest("test-name"); + assert (request.validate() == null); + assert (request.getName().equals("test-name")); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + CreateConversationRequest newRequest = new CreateConversationRequest(in); + assert (newRequest.getName().equals(request.getName())); + } + + public void testConstructorsAndStreaming_Unnamed() throws IOException { + CreateConversationRequest request = new CreateConversationRequest(); + assert (request.validate() == null); + assert (request.getName() == null); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + CreateConversationRequest newRequest = new CreateConversationRequest(in); + assert (newRequest.getName() == null); + } + + public void testEmptyRestRequest() throws IOException { + RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(); + CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req); + assert (request.getName() == null); + } + + public void testNamedRestRequest() throws IOException { + String name = "test-name"; + RestRequest req = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withContent(new BytesArray(gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, name))), MediaTypeRegistry.JSON) + .build(); + CreateConversationRequest request = CreateConversationRequest.fromRestRequest(req); + assert (request.getName().equals(name)); + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java new file mode 100644 index 0000000000..75f256ceca --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationResponseTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchTestCase; + +public class CreateConversationResponseTests extends OpenSearchTestCase { + + public void testCreateConversationResponseStreaming() throws IOException { + CreateConversationResponse response = new CreateConversationResponse("test-id"); + assert (response.getId().equals("test-id")); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + CreateConversationResponse newResp = new CreateConversationResponse(in); + assert (newResp.getId().equals("test-id")); + } + + public void testToXContent() throws IOException { + CreateConversationResponse response = new CreateConversationResponse("createme"); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + String expected = "{\"conversation_id\":\"createme\"}"; + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + assert (result.equals(expected)); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java new file mode 100644 index 0000000000..313071dc45 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateConversationTransportActionTests.java @@ -0,0 +1,165 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CreateConversationTransportActionTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + CreateConversationRequest request; + CreateConversationTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + this.threadPool = Mockito.mock(ThreadPool.class); + this.client = Mockito.mock(Client.class); + this.clusterService = Mockito.mock(ClusterService.class); + this.xContentRegistry = Mockito.mock(NamedXContentRegistry.class); + this.transportService = Mockito.mock(TransportService.class); + this.actionFilters = Mockito.mock(ActionFilters.class); + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new CreateConversationRequest("test"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testCreateConversation() { + log.info("testing create conversation transport"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse("testID"); + return null; + }).when(cmHandler).createConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getId().equals("testID")); + } + + public void testCreateConversationWithNullName() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse("testID-2"); + return null; + }).when(cmHandler).createConversation(any(ActionListener.class)); + String nullstr = null; + this.request = new CreateConversationRequest(nullstr); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateConversationResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getId().equals("testID-2")); + } + + public void testCreateConversationFails_thenFail() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onFailure(new Exception("Testing Error")); + return null; + }).when(cmHandler).createConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Testing Error")); + } + + public void testDoExecuteFails_thenFail() { + doThrow(new RuntimeException("Test doExecute Error")).when(cmHandler).createConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test doExecute Error")); + } + + public void testFeatureDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new CreateConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java new file mode 100644 index 0000000000..cf027aef79 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionRequestTests.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +import com.google.gson.Gson; + +public class CreateInteractionRequestTests extends OpenSearchTestCase { + + Gson gson; + + @Before + public void setup() { + gson = new Gson(); + } + + public void testConstructorsAndStreaming() throws IOException { + CreateInteractionRequest request = new CreateInteractionRequest("cid", "input", "pt", "response", "origin", "metadata"); + assert (request.validate() == null); + assert (request.getConversationId().equals("cid")); + assert (request.getInput().equals("input")); + assert (request.getResponse().equals("response")); + assert (request.getOrigin().equals("origin")); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + CreateInteractionRequest newReq = new CreateInteractionRequest(in); + assert (newReq.validate() == null); + assert (newReq.getConversationId().equals("cid")); + assert (newReq.getInput().equals("input")); + assert (newReq.getResponse().equals("response")); + assert (newReq.getOrigin().equals("origin")); + } + + public void testNullCID_thenFail() { + CreateInteractionRequest request = new CreateInteractionRequest(null, "input", "pt", "response", "origin", "metadata"); + assert (request.validate() != null); + assert (request.validate().validationErrors().size() == 1); + assert (request.validate().validationErrors().get(0).equals("Interaction MUST belong to a conversation ID")); + } + + public void testFromRestRequest() throws IOException { + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "pt", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.ADDITIONAL_INFO_FIELD, + "metadata" + ); + RestRequest rrequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + CreateInteractionRequest request = CreateInteractionRequest.fromRestRequest(rrequest); + assert (request.validate() == null); + assert (request.getConversationId().equals("cid")); + assert (request.getInput().equals("input")); + assert (request.getPromptTemplate().equals("pt")); + assert (request.getResponse().equals("response")); + assert (request.getOrigin().equals("origin")); + assert (request.getAdditionalInfo().equals("metadata")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionResponseTests.java new file mode 100644 index 0000000000..939acc0435 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionResponseTests.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchTestCase; + +public class CreateInteractionResponseTests extends OpenSearchTestCase { + + public void testCreateInteractionResponseStreaming() throws IOException { + CreateInteractionResponse response = new CreateInteractionResponse("test-iid"); + assert (response.getId().equals("test-iid")); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + CreateInteractionResponse newResp = new CreateInteractionResponse(in); + assert (newResp.getId().equals("test-iid")); + } + + public void testToXContent() throws IOException { + CreateInteractionResponse response = new CreateInteractionResponse("createme"); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + String expected = "{\"interaction_id\":\"createme\"}"; + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + assert (result.equals(expected)); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java new file mode 100644 index 0000000000..8321a0b65e --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/CreateInteractionTransportActionTests.java @@ -0,0 +1,156 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class CreateInteractionTransportActionTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + CreateInteractionRequest request; + CreateInteractionTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + this.threadPool = Mockito.mock(ThreadPool.class); + this.client = Mockito.mock(Client.class); + this.clusterService = Mockito.mock(ClusterService.class); + this.xContentRegistry = Mockito.mock(NamedXContentRegistry.class); + this.transportService = Mockito.mock(TransportService.class); + this.actionFilters = Mockito.mock(ActionFilters.class); + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new CreateInteractionRequest("test-cid", "input", "pt", "response", "origin", "metadata"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new CreateInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + } + + public void testCreateInteraction() { + log.info("testing create interaction transport"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(6); + listener.onResponse("testID"); + return null; + }).when(cmHandler).createInteraction(any(), any(), any(), any(), any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateInteractionResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getId().equals("testID")); + } + + public void testCreateInteractionFails_thenFail() { + log.info("testing create interaction transport"); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(6); + listener.onFailure(new Exception("Testing Failure")); + return null; + }).when(cmHandler).createInteraction(any(), any(), any(), any(), any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Testing Failure")); + } + + public void testDoExecuteFails_thenFail() { + log.info("testing create interaction transport"); + doThrow(new RuntimeException("Failure in doExecute")) + .when(cmHandler) + .createInteraction(any(), any(), any(), any(), any(), any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in doExecute")); + } + + public void testFeatureDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new CreateInteractionTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationRequestTests.java new file mode 100644 index 0000000000..a456293630 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationRequestTests.java @@ -0,0 +1,66 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class DeleteConversationRequestTests extends OpenSearchTestCase { + + public void testDeleteConversationRequestStreaming() throws IOException { + DeleteConversationRequest request = new DeleteConversationRequest("test-id"); + assert (request.validate() == null); + assert (request.getId().equals("test-id")); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + DeleteConversationRequest newReq = new DeleteConversationRequest(in); + assert (newReq.validate() == null); + assert (newReq.getId().equals("test-id")); + } + + public void testNullIdIsInvalid() { + String nullId = null; + DeleteConversationRequest request = new DeleteConversationRequest(nullId); + assert (request.validate() != null); + assert (request.validate().validationErrors().size() == 1); + assert (request.validate().validationErrors().get(0).equals("conversation id must not be null")); + } + + public void testFromRestRequest() throws IOException { + RestRequest rreq = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "deleteme")) + .build(); + DeleteConversationRequest req = DeleteConversationRequest.fromRestRequest(rreq); + assert (req.validate() == null); + assert (req.getId().equals("deleteme")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationResponseTests.java new file mode 100644 index 0000000000..50b6cd4afa --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationResponseTests.java @@ -0,0 +1,55 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.test.OpenSearchTestCase; + +public class DeleteConversationResponseTests extends OpenSearchTestCase { + + public void testDeleteConversationResponseStreaming() throws IOException { + DeleteConversationResponse response = new DeleteConversationResponse(true); + assert (response.wasSuccessful()); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + DeleteConversationResponse newResp = new DeleteConversationResponse(in); + assert (newResp.wasSuccessful()); + } + + public void testToXContent() throws IOException { + DeleteConversationResponse response = new DeleteConversationResponse(false); + assert (!response.wasSuccessful()); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + String expected = "{\"success\":false}"; + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + assert (result.equals(expected)); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportActionTests.java new file mode 100644 index 0000000000..984b9a2fbf --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/DeleteConversationTransportActionTests.java @@ -0,0 +1,151 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class DeleteConversationTransportActionTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + DeleteConversationRequest request; + DeleteConversationTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + this.threadPool = Mockito.mock(ThreadPool.class); + this.client = Mockito.mock(Client.class); + this.clusterService = Mockito.mock(ClusterService.class); + this.xContentRegistry = Mockito.mock(NamedXContentRegistry.class); + this.transportService = Mockito.mock(TransportService.class); + this.actionFilters = Mockito.mock(ActionFilters.class); + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new DeleteConversationRequest("test"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new DeleteConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testDeleteConversation() { + log.info("testing delete conversation transport"); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(true); + return null; + }).when(cmHandler).deleteConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(DeleteConversationResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().wasSuccessful()); + } + + public void testDeleteFails_thenFail() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Test Fail Case")); + return null; + }).when(cmHandler).deleteConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Fail Case")); + } + + public void testdoExecuteFails_thenFail() { + doThrow(new RuntimeException("Test doExecute Error")).when(cmHandler).deleteConversation(any(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test doExecute Error")); + } + + public void testFeatureDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new DeleteConversationTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsRequestTests.java new file mode 100644 index 0000000000..c3e7da58e7 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsRequestTests.java @@ -0,0 +1,84 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class GetConversationsRequestTests extends OpenSearchTestCase { + + public void testGetConversationsRequestAndStreaming() throws IOException { + GetConversationsRequest request = new GetConversationsRequest(); + assert (request.validate() == null); + assert (request.getFrom() == 0 && request.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetConversationsRequest newRequest = new GetConversationsRequest(in); + assert (newRequest.validate() == null); + assert (newRequest.getFrom() == request.getFrom() && newRequest.getMaxResults() == request.getMaxResults()); + } + + public void testVariousConstructors() { + GetConversationsRequest req1 = new GetConversationsRequest(2); + assert (req1.validate() == null); + assert (req1.getFrom() == 0 && req1.getMaxResults() == 2); + GetConversationsRequest req2 = new GetConversationsRequest(5, 2); + assert (req2.validate() == null); + assert (req2.getFrom() == 2 && req2.getMaxResults() == 5); + } + + public void testNegativeOrZeroMaxResults_thenFail() { + GetConversationsRequest req = new GetConversationsRequest(-3); + assert (req.validate() != null); + assert (req.validate().validationErrors().size() == 1); + assert (req.validate().validationErrors().get(0).equals("Can't list 0 or negative conversations")); + } + + public void testFromRestRequest() throws IOException { + Map maxResOnly = Map.of(ActionConstants.REQUEST_MAX_RESULTS_FIELD, "4"); + Map nextTokOnly = Map.of(ActionConstants.NEXT_TOKEN_FIELD, "6"); + Map bothFields = Map.of(ActionConstants.REQUEST_MAX_RESULTS_FIELD, "2", ActionConstants.NEXT_TOKEN_FIELD, "7"); + RestRequest req1 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(); + RestRequest req2 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(maxResOnly).build(); + RestRequest req3 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(nextTokOnly).build(); + RestRequest req4 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(bothFields).build(); + GetConversationsRequest gcr1 = GetConversationsRequest.fromRestRequest(req1); + GetConversationsRequest gcr2 = GetConversationsRequest.fromRestRequest(req2); + GetConversationsRequest gcr3 = GetConversationsRequest.fromRestRequest(req3); + GetConversationsRequest gcr4 = GetConversationsRequest.fromRestRequest(req4); + + assert (gcr1.validate() == null && gcr2.validate() == null && gcr3.validate() == null && gcr4.validate() == null); + assert (gcr1.getFrom() == 0 && gcr2.getFrom() == 0 && gcr3.getFrom() == 6 && gcr4.getFrom() == 7); + assert (gcr1.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gcr2.getMaxResults() == 4); + assert (gcr3.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gcr4.getMaxResults() == 2); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java new file mode 100644 index 0000000000..e6ed013b7a --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsResponseTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; + +import org.apache.lucene.search.spell.LevenshteinDistance; +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.test.OpenSearchTestCase; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetConversationsResponseTests extends OpenSearchTestCase { + + List conversations; + + @Before + public void setup() { + conversations = List + .of( + new ConversationMeta("0", Instant.now(), "name0", "user0"), + new ConversationMeta("1", Instant.now(), "name1", "user0"), + new ConversationMeta("2", Instant.now(), "name2", "user2") + ); + } + + public void testGetConversationsResponseStreaming() throws IOException { + GetConversationsResponse response = new GetConversationsResponse(conversations, 2, true); + assert (response.hasMorePages()); + assert (response.getConversations().equals(conversations)); + assert (response.getNextToken() == 2); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetConversationsResponse newResp = new GetConversationsResponse(in); + assert (newResp.hasMorePages()); + assert (newResp.getConversations().equals(conversations)); + assert (newResp.getNextToken() == 2); + } + + public void testToXContent_MoreTokens() throws IOException { + GetConversationsResponse response = new GetConversationsResponse(conversations.subList(0, 1), 2, true); + ConversationMeta conversation = response.getConversations().get(0); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + String expected = "{\"conversations\":[{\"conversation_id\":\"0\",\"create_time\":\"" + + conversation.getCreatedTime() + + "\",\"name\":\"name0\",\"user\":\"user0\"}],\"next_token\":2}"; + log.info("FINDME"); + log.info(result); + log.info(expected); + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + log.info(ld.getDistance(result, expected)); + assert (ld.getDistance(result, expected) > 0.95); + } + + public void testToXContent_NoMoreTokens() throws IOException { + GetConversationsResponse response = new GetConversationsResponse(conversations.subList(0, 1), 2, false); + ConversationMeta conversation = response.getConversations().get(0); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + String expected = "{\"conversations\":[{\"conversation_id\":\"0\",\"create_time\":\"" + + conversation.getCreatedTime() + + "\",\"name\":\"name0\",\"user\":\"user0\"}]}"; + log.info("FINDME"); + log.info(result); + log.info(expected); + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + log.info(ld.getDistance(result, expected)); + assert (ld.getDistance(result, expected) > 0.95); + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java new file mode 100644 index 0000000000..41c99bdc74 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetConversationsTransportActionTests.java @@ -0,0 +1,205 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetConversationsTransportActionTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + GetConversationsRequest request; + GetConversationsTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + this.threadPool = Mockito.mock(ThreadPool.class); + this.client = Mockito.mock(Client.class); + this.clusterService = Mockito.mock(ClusterService.class); + this.xContentRegistry = Mockito.mock(NamedXContentRegistry.class); + this.transportService = Mockito.mock(TransportService.class); + this.actionFilters = Mockito.mock(ActionFilters.class); + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new GetConversationsRequest(); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new GetConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testGetConversations() { + log.info("testing get conversations transport"); + List testResult = List + .of( + new ConversationMeta("testcid1", Instant.now(), "", null), + new ConversationMeta("testcid2", Instant.now(), "testname", null) + ); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(testResult); + return null; + }).when(cmHandler).getConversations(anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetConversationsResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getConversations().equals(testResult)); + assert (!argCaptor.getValue().hasMorePages()); + } + + public void testPagination() { + List testResult = List + .of( + new ConversationMeta("testcid1", Instant.now(), "", null), + new ConversationMeta("testcid2", Instant.now(), "testname", null), + new ConversationMeta("testcid3", Instant.now(), "testname", null) + ); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + int maxResults = invocation.getArgument(1); + if (maxResults <= 3) { + listener.onResponse(testResult.subList(0, maxResults)); + } else { + listener.onResponse(testResult); + } + return null; + }).when(cmHandler).getConversations(anyInt(), anyInt(), any()); + GetConversationsRequest r0 = new GetConversationsRequest(2); + action.doExecute(null, r0, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetConversationsResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getConversations().equals(testResult.subList(0, 2))); + assert (argCaptor.getValue().hasMorePages()); + assert (argCaptor.getValue().getNextToken() == 2); + + @SuppressWarnings("unchecked") + ActionListener al1 = (ActionListener) Mockito.mock(ActionListener.class); + GetConversationsRequest r1 = new GetConversationsRequest(2, 2); + action.doExecute(null, r1, al1); + argCaptor = ArgumentCaptor.forClass(GetConversationsResponse.class); + verify(al1).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getConversations().equals(testResult.subList(0, 2))); + assert (argCaptor.getValue().hasMorePages()); + assert (argCaptor.getValue().getNextToken() == 4); + + @SuppressWarnings("unchecked") + ActionListener al2 = (ActionListener) Mockito.mock(ActionListener.class); + GetConversationsRequest r2 = new GetConversationsRequest(20, 4); + action.doExecute(null, r2, al2); + argCaptor = ArgumentCaptor.forClass(GetConversationsResponse.class); + verify(al2).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().getConversations().equals(testResult)); + assert (!argCaptor.getValue().hasMorePages()); + } + + public void testGetFails_thenFail() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(2); + al.onFailure(new Exception("Test Fail Case")); + return null; + }).when(cmHandler).getConversations(anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Fail Case")); + } + + public void testdoExecuteFails_thenFail() { + doThrow(new RuntimeException("Test doExecute Error")).when(cmHandler).getConversations(anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test doExecute Error")); + } + + public void testFeatureDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new GetConversationsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequestTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequestTests.java new file mode 100644 index 0000000000..e1428d87c3 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsRequestTests.java @@ -0,0 +1,136 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.util.Map; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class GetInteractionsRequestTests extends OpenSearchTestCase { + + public void testConstructorsAndStreaming() throws IOException { + GetInteractionsRequest request = new GetInteractionsRequest("test-cid"); + assert (request.validate() == null); + assert (request.getConversationId().equals("test-cid")); + assert (request.getFrom() == 0); + assert (request.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); + + GetInteractionsRequest req2 = new GetInteractionsRequest("test-cid2", 3); + assert (req2.validate() == null); + assert (req2.getConversationId().equals("test-cid2")); + assert (req2.getFrom() == 0); + assert (req2.getMaxResults() == 3); + + GetInteractionsRequest req3 = new GetInteractionsRequest("test-cid3", 4, 5); + assert (req3.validate() == null); + assert (req3.getConversationId().equals("test-cid3")); + assert (req3.getFrom() == 5); + assert (req3.getMaxResults() == 4); + + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + request.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetInteractionsRequest req4 = new GetInteractionsRequest(in); + assert (req4.validate() == null); + assert (req4.getConversationId().equals("test-cid")); + assert (req4.getFrom() == 0); + assert (req4.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); + } + + public void testBadValues_thenFail() { + String nullstr = null; + GetInteractionsRequest request = new GetInteractionsRequest(nullstr); + assert (request.validate().validationErrors().get(0).equals("Interactions must be retrieved from a conversation")); + assert (request.validate().validationErrors().size() == 1); + + request = new GetInteractionsRequest("cid", -2); + assert (request.validate().validationErrors().size() == 1); + assert (request.validate().validationErrors().get(0).equals("The number of interactions to retrieve must be positive")); + + request = new GetInteractionsRequest("cid", 2, -2); + assert (request.validate().validationErrors().size() == 1); + assert (request.validate().validationErrors().get(0).equals("The starting position must be nonnegative")); + } + + public void testMultipleBadValues_thenFailMultipleWays() { + String nullstr = null; + GetInteractionsRequest request = new GetInteractionsRequest(nullstr, -2); + assert (request.validate().validationErrors().size() == 2); + assert (request.validate().validationErrors().get(0).equals("Interactions must be retrieved from a conversation")); + assert (request.validate().validationErrors().get(1).equals("The number of interactions to retrieve must be positive")); + + request = new GetInteractionsRequest(nullstr, 3, -2); + assert (request.validate().validationErrors().size() == 2); + assert (request.validate().validationErrors().get(0).equals("Interactions must be retrieved from a conversation")); + assert (request.validate().validationErrors().get(1).equals("The starting position must be nonnegative")); + + request = new GetInteractionsRequest("cid", -2, -2); + assert (request.validate().validationErrors().size() == 2); + assert (request.validate().validationErrors().get(0).equals("The number of interactions to retrieve must be positive")); + assert (request.validate().validationErrors().get(1).equals("The starting position must be nonnegative")); + + request = new GetInteractionsRequest(nullstr, -3, -4); + assert (request.validate().validationErrors().size() == 3); + assert (request.validate().validationErrors().get(0).equals("Interactions must be retrieved from a conversation")); + assert (request.validate().validationErrors().get(1).equals("The number of interactions to retrieve must be positive")); + assert (request.validate().validationErrors().get(2).equals("The starting position must be nonnegative")); + } + + public void testFromRestRequest() throws IOException { + Map basic = Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid1"); + Map maxResOnly = Map + .of(ActionConstants.CONVERSATION_ID_FIELD, "cid2", ActionConstants.REQUEST_MAX_RESULTS_FIELD, "4"); + Map nextTokOnly = Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid3", ActionConstants.NEXT_TOKEN_FIELD, "6"); + Map bothFields = Map + .of( + ActionConstants.CONVERSATION_ID_FIELD, + "cid4", + ActionConstants.REQUEST_MAX_RESULTS_FIELD, + "2", + ActionConstants.NEXT_TOKEN_FIELD, + "7" + ); + RestRequest req1 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(basic).build(); + RestRequest req2 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(maxResOnly).build(); + RestRequest req3 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(nextTokOnly).build(); + RestRequest req4 = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(bothFields).build(); + GetInteractionsRequest gir1 = GetInteractionsRequest.fromRestRequest(req1); + GetInteractionsRequest gir2 = GetInteractionsRequest.fromRestRequest(req2); + GetInteractionsRequest gir3 = GetInteractionsRequest.fromRestRequest(req3); + GetInteractionsRequest gir4 = GetInteractionsRequest.fromRestRequest(req4); + + assert (gir1.validate() == null && gir2.validate() == null && gir3.validate() == null && gir4.validate() == null); + assert (gir1.getConversationId().equals("cid1") && gir2.getConversationId().equals("cid2")); + assert (gir3.getConversationId().equals("cid3") && gir4.getConversationId().equals("cid4")); + assert (gir1.getFrom() == 0 && gir2.getFrom() == 0 && gir3.getFrom() == 6 && gir4.getFrom() == 7); + assert (gir1.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir2.getMaxResults() == 4); + assert (gir3.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS && gir4.getMaxResults() == 2); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java new file mode 100644 index 0000000000..bbd17b2603 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsResponseTests.java @@ -0,0 +1,103 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; + +import org.apache.lucene.search.spell.LevenshteinDistance; +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchTestCase; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetInteractionsResponseTests extends OpenSearchTestCase { + List interactions; + + @Before + public void setup() { + interactions = List + .of( + new Interaction("id0", Instant.now(), "cid", "input", "pt", "response", "origin", "metadata"), + new Interaction("id1", Instant.now(), "cid", "input", "pt", "response", "origin", "mteadata"), + new Interaction("id2", Instant.now(), "cid", "input", "pt", "response", "origin", "metadata") + ); + } + + public void testGetInteractionsResponseStreaming() throws IOException { + GetInteractionsResponse response = new GetInteractionsResponse(interactions, 4, true); + assert (response.getInteractions().equals(interactions)); + assert (response.getNextToken() == 4); + assert (response.hasMorePages()); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + response.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + GetInteractionsResponse newResp = new GetInteractionsResponse(in); + assert (newResp.getInteractions().equals(interactions)); + assert (newResp.getNextToken() == 4); + assert (newResp.hasMorePages()); + } + + public void testToXContent_MoreTokens() throws IOException { + GetInteractionsResponse response = new GetInteractionsResponse(interactions.subList(0, 1), 2, true); + Interaction interaction = response.getInteractions().get(0); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + String expected = "{\"interactions\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"create_time\":\"" + + interaction.getCreateTime() + + "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":\"metadata\"}],\"next_token\":2}"; + log.info(result); + log.info(expected); + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + log.info(ld.getDistance(result, expected)); + assert (ld.getDistance(result, expected) > 0.95); + } + + public void testToXContent_NoMoreTokens() throws IOException { + GetInteractionsResponse response = new GetInteractionsResponse(interactions.subList(0, 1), 2, false); + Interaction interaction = response.getInteractions().get(0); + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + String result = BytesReference.bytes(builder).utf8ToString(); + String expected = "{\"interactions\":[{\"conversation_id\":\"cid\",\"interaction_id\":\"id0\",\"create_time\":\"" + + interaction.getCreateTime() + + "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":\"metadata\"}]}"; + log.info(result); + log.info(expected); + // Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness + LevenshteinDistance ld = new LevenshteinDistance(); + log.info(ld.getDistance(result, expected)); + assert (ld.getDistance(result, expected) > 0.95); + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java new file mode 100644 index 0000000000..a7a245b680 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/GetInteractionsTransportActionTests.java @@ -0,0 +1,196 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.Set; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.ClusterSettings; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class GetInteractionsTransportActionTests extends OpenSearchTestCase { + + @Mock + ThreadPool threadPool; + + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + NamedXContentRegistry xContentRegistry; + + @Mock + TransportService transportService; + + @Mock + ActionFilters actionFilters; + + @Mock + ActionListener actionListener; + + @Mock + OpenSearchConversationalMemoryHandler cmHandler; + + GetInteractionsRequest request; + GetInteractionsTransportAction action; + ThreadContext threadContext; + + @Before + public void setup() throws IOException { + this.threadPool = Mockito.mock(ThreadPool.class); + this.client = Mockito.mock(Client.class); + this.clusterService = Mockito.mock(ClusterService.class); + this.xContentRegistry = Mockito.mock(NamedXContentRegistry.class); + this.transportService = Mockito.mock(TransportService.class); + this.actionFilters = Mockito.mock(ActionFilters.class); + @SuppressWarnings("unchecked") + ActionListener al = (ActionListener) Mockito.mock(ActionListener.class); + this.actionListener = al; + this.cmHandler = Mockito.mock(OpenSearchConversationalMemoryHandler.class); + + this.request = new GetInteractionsRequest("test-cid"); + + Settings settings = Settings.builder().put(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey(), true).build(); + this.threadContext = new ThreadContext(settings); + when(this.client.threadPool()).thenReturn(this.threadPool); + when(this.threadPool.getThreadContext()).thenReturn(this.threadContext); + when(this.clusterService.getSettings()).thenReturn(settings); + when(this.clusterService.getClusterSettings()) + .thenReturn(new ClusterSettings(settings, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + + this.action = spy(new GetInteractionsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + } + + public void testGetInteractions_noMorePages() { + log.info("test get interactions transport"); + Interaction testInteraction = new Interaction( + "test-iid", + Instant.now(), + "test-cid", + "test-input", + "pt", + "test-response", + "test-origin", + "metadata" + ); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(List.of(testInteraction)); + return null; + }).when(cmHandler).getInteractions(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionsResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + List interactions = argCaptor.getValue().getInteractions(); + assert (interactions.size() == 1); + Interaction interaction = interactions.get(0); + assert (interaction.equals(testInteraction)); + assert (!argCaptor.getValue().hasMorePages()); + } + + public void testGetInteractions_MorePages() { + log.info("test get interactions transport"); + Interaction testInteraction = new Interaction( + "test-iid", + Instant.now(), + "test-cid", + "test-input", + "pt", + "test-response", + "test-origin", + "metadata" + ); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onResponse(List.of(testInteraction)); + return null; + }).when(cmHandler).getInteractions(any(), anyInt(), anyInt(), any()); + GetInteractionsRequest shortPageRequest = new GetInteractionsRequest("test-cid", 1); + action.doExecute(null, shortPageRequest, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionsResponse.class); + verify(actionListener).onResponse(argCaptor.capture()); + List interactions = argCaptor.getValue().getInteractions(); + assert (interactions.size() == 1); + Interaction interaction = interactions.get(0); + assert (interaction.equals(testInteraction)); + assert (argCaptor.getValue().hasMorePages()); + } + + public void testGetInteractionsFails_thenFail() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + listener.onFailure(new Exception("Testing Failure")); + return null; + }).when(cmHandler).getInteractions(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Testing Failure")); + } + + public void testDoExecuteFails_thenFail() { + doThrow(new RuntimeException("Failure in doExecute")).when(cmHandler).getInteractions(any(), anyInt(), anyInt(), any()); + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in doExecute")); + } + + public void testFeatureDisabled_ThenFail() { + when(this.clusterService.getSettings()).thenReturn(Settings.EMPTY); + when(this.clusterService.getClusterSettings()).thenReturn(new ClusterSettings(Settings.EMPTY, Set.of(ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED))); + this.action = spy(new GetInteractionsTransportAction(transportService, actionFilters, cmHandler, client, clusterService)); + + action.doExecute(null, request, actionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().startsWith("The experimental Conversation Memory feature is not enabled.")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java new file mode 100644 index 0000000000..9002796bbc --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/action/conversation/InteractionActionTests.java @@ -0,0 +1,25 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.action.conversation; + +public class InteractionActionTests { + public void testActions() { + assert (CreateInteractionAction.INSTANCE instanceof CreateInteractionAction); + assert (GetInteractionsAction.INSTANCE instanceof GetInteractionsAction); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java new file mode 100644 index 0000000000..e1a0318758 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexITTests.java @@ -0,0 +1,418 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.index; + +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.Stack; +import java.util.concurrent.CountDownLatch; +import java.util.function.Consumer; + +import org.junit.Before; +import org.opensearch.OpenSearchSecurityException; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.StepListener; +import org.opensearch.client.Client; +import org.opensearch.client.Requests; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.common.util.concurrent.ThreadContext.StoredContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.test.OpenSearchIntegTestCase; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) +public class ConversationMetaIndexITTests extends OpenSearchIntegTestCase { + + private ClusterService clusterService; + private Client client; + private ConversationMetaIndex index; + + private void refreshIndex() { + client.admin().indices().refresh(Requests.refreshRequest(ConversationalIndexConstants.META_INDEX_NAME)); + } + + private StoredContext setUser(String username) { + StoredContext stored = client + .threadPool() + .getThreadContext() + .newStoredContext(true, List.of(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT)); + ThreadContext context = client.threadPool().getThreadContext(); + context.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, username + "||"); + return stored; + } + + @Before + public void setup() { + log.info("Setting up test"); + this.client = client(); + this.clusterService = clusterService(); + this.index = new ConversationMetaIndex(client, clusterService); + } + + /** + * Can the index be initialized? + */ + public void testConversationMetaIndexCanBeInitialized() { + CountDownLatch cdl = new CountDownLatch(1); + index.initConversationMetaIndexIfAbsent(new LatchedActionListener(ActionListener.wrap(r -> { assert (r); }, e -> { + log.error(e); + assert (false); + }), cdl)); + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + /** + * If the index tries to be initialized more than once does it break something + * Also make sure that only one initialization happens + */ + public void testConversationMetaIndexCanBeInitializedTwice() { + CountDownLatch cdl = new CountDownLatch(2); + index.initConversationMetaIndexIfAbsent(new LatchedActionListener(ActionListener.wrap(r -> { assert (r); }, e -> { + log.error(e); + assert (false); + }), cdl)); + index.initConversationMetaIndexIfAbsent(new LatchedActionListener(ActionListener.wrap(r -> { assert (r); }, e -> { + log.error(e); + assert (false); + }), cdl)); + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + /** + * If the index tries to be initialized by different objects does it break anything + * Also make sure that only one initialization happens + */ + public void testConversationMetaIndexCanBeInitializedByDifferentObjects() { + CountDownLatch cdl = new CountDownLatch(2); + index.initConversationMetaIndexIfAbsent(new LatchedActionListener(ActionListener.wrap(r -> { assert (r); }, e -> { + log.error(e); + assert (false); + }), cdl)); + ConversationMetaIndex otherIndex = new ConversationMetaIndex(client, clusterService); + otherIndex.initConversationMetaIndexIfAbsent(new LatchedActionListener(ActionListener.wrap(r -> { assert (r); }, e -> { + log.error(e); + assert (false); + }), cdl)); + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + /** + * Can I add a new conversation to the index without crashong? + */ + public void testCanAddNewConversation() { + CountDownLatch cdl = new CountDownLatch(1); + index + .createConversation(new LatchedActionListener(ActionListener.wrap(r -> { assert (r != null && r.length() > 0); }, e -> { + log.error(e); + assert (false); + }), cdl)); + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + /** + * Are conversation ids unique? + */ + public void testConversationIDsAreUnique() { + int numTries = 100; + CountDownLatch cdl = new CountDownLatch(numTries); + Set seenIds = Collections.synchronizedSet(new HashSet(numTries)); + for (int i = 0; i < numTries; i++) { + index.createConversation(new LatchedActionListener(ActionListener.wrap(r -> { + assert (!seenIds.contains(r)); + seenIds.add(r); + }, e -> { + log.error(e); + assert (false); + }), cdl)); + } + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + /** + * If I add a conversation, that id shows up in the list of conversations + */ + public void testConversationsCanBeListed() { + CountDownLatch cdl = new CountDownLatch(1); + StepListener addConversationListener = new StepListener<>(); + index.createConversation(addConversationListener); + + StepListener> listConversationListener = new StepListener<>(); + addConversationListener.whenComplete(cid -> { + refreshIndex(); + refreshIndex(); + index.getConversations(10, listConversationListener); + }, e -> { + cdl.countDown(); + log.error(e); + }); + + LatchedActionListener> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(conversations -> { + boolean foundConversation = false; + log.info("FINDME"); + log.info(addConversationListener.result()); + log.info(conversations); + for (ConversationMeta c : conversations) { + if (c.getId().equals(addConversationListener.result())) { + foundConversation = true; + } + } + assert (foundConversation); + }, e -> { log.error(e); }), cdl); + listConversationListener.whenComplete(finishAndAssert::onResponse, finishAndAssert::onFailure); + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + public void testConversationsCanBeListedPaginated() { + CountDownLatch cdl = new CountDownLatch(1); + StepListener addConversationListener1 = new StepListener<>(); + index.createConversation(addConversationListener1); + + StepListener addConversationListener2 = new StepListener<>(); + addConversationListener1.whenComplete(cid -> { index.createConversation(addConversationListener2); }, e -> { + cdl.countDown(); + assert (false); + }); + + StepListener> listConversationListener1 = new StepListener<>(); + addConversationListener2.whenComplete(cid2 -> { index.getConversations(1, listConversationListener1); }, e -> { + cdl.countDown(); + assert (false); + }); + + StepListener> listConversationListener2 = new StepListener<>(); + listConversationListener1.whenComplete(conversations1 -> { index.getConversations(1, 1, listConversationListener2); }, e -> { + cdl.countDown(); + assert (false); + }); + + LatchedActionListener> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(conversations2 -> { + List conversations1 = listConversationListener1.result(); + String cid1 = addConversationListener1.result(); + String cid2 = addConversationListener2.result(); + if (!conversations1.get(0).getCreatedTime().equals(conversations2.get(0).getCreatedTime())) { + assert (conversations1.get(0).getId().equals(cid2)); + assert (conversations2.get(0).getId().equals(cid1)); + } + }, e -> { assert (false); }), cdl); + listConversationListener2.whenComplete(finishAndAssert::onResponse, finishAndAssert::onFailure); + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + + } + + public void testConversationsCanBeDeleted() { + CountDownLatch cdl = new CountDownLatch(1); + StepListener addConversationListener = new StepListener<>(); + index.createConversation(addConversationListener); + + StepListener deleteConversationListener = new StepListener<>(); + addConversationListener.whenComplete(cid -> { index.deleteConversation(cid, deleteConversationListener); }, e -> { + cdl.countDown(); + assert (false); + }); + + LatchedActionListener> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(conversations -> { + assert (conversations.size() == 0); + }, e -> { + cdl.countDown(); + assert (false); + }), cdl); + deleteConversationListener.whenComplete(success -> { + if (success) { + index.getConversations(10, finishAndAssert); + } else { + cdl.countDown(); + assert (false); + } + }, e -> { + cdl.countDown(); + assert (false); + }); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + public void testConversationsForDifferentUsersAreDifferent() { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + CountDownLatch cdl = new CountDownLatch(1); + Stack contextStack = new Stack<>(); + Consumer onFail = e -> { + while (!contextStack.empty()) { + contextStack.pop().close(); + } + cdl.countDown(); + log.error(e); + threadContext.restore(); + assert (false); + }; + + final String user1 = "test-user1"; + final String user2 = "test-user2"; + + StepListener cid1 = new StepListener<>(); + contextStack.push(setUser(user1)); + index.createConversation(cid1); + + StepListener cid2 = new StepListener<>(); + cid1.whenComplete(cid -> { index.createConversation(cid2); }, onFail); + + StepListener cid3 = new StepListener<>(); + cid2.whenComplete(cid -> { + contextStack.push(setUser(user2)); + index.createConversation(cid3); + }, onFail); + + StepListener> conversationsListener = new StepListener<>(); + cid3.whenComplete(cid -> { index.getConversations(10, conversationsListener); }, onFail); + + StepListener> originalConversationsListener = new StepListener<>(); + conversationsListener.whenComplete(conversations -> { + assert (conversations.size() == 1); + assert (conversations.get(0).getId().equals(cid3.result())); + assert (conversations.get(0).getUser().equals(user2)); + contextStack.pop().restore(); + index.getConversations(10, originalConversationsListener); + }, onFail); + + originalConversationsListener.whenComplete(conversations -> { + assert (conversations.size() == 2); + if (!conversations.get(0).getCreatedTime().equals(conversations.get(1).getCreatedTime())) { + assert (conversations.get(0).getId().equals(cid2.result())); + assert (conversations.get(1).getId().equals(cid1.result())); + } + assert (conversations.get(0).getUser().equals(user1)); + assert (conversations.get(1).getUser().equals(user1)); + contextStack.pop().restore(); + cdl.countDown(); + }, onFail); + + try { + cdl.await(); + threadContext.restore(); + } catch (InterruptedException e) { + log.error(e); + } + } catch (Exception e) { + log.error(e); + throw e; + } + } + + public void testDifferentUsersCannotTouchOthersConversations() { + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + CountDownLatch cdl = new CountDownLatch(1); + Stack contextStack = new Stack<>(); + Consumer onFail = e -> { + while (!contextStack.empty()) { + contextStack.pop().close(); + } + cdl.countDown(); + log.error(e); + threadContext.restore(); + assert (false); + }; + + final String user1 = "user-1"; + final String user2 = "user-2"; + contextStack.push(setUser(user1)); + + StepListener cid1 = new StepListener<>(); + index.createConversation(cid1); + + StepListener delListener = new StepListener<>(); + cid1.whenComplete(cid -> { + contextStack.push(setUser(user2)); + index.deleteConversation(cid1.result(), delListener); + }, onFail); + + delListener.whenComplete(success -> { + Exception e = new OpenSearchSecurityException( + "Incorrect access was given to user [" + user2 + "] for conversation " + cid1.result() + ); + while (!contextStack.empty()) { + contextStack.pop().close(); + } + cdl.countDown(); + log.error(e); + assert (false); + }, e -> { + if (e instanceof OpenSearchSecurityException + && e.getMessage().startsWith("User [" + user2 + "] does not have access to conversation ")) { + contextStack.pop().restore(); + contextStack.pop().restore(); + cdl.countDown(); + } else { + onFail.accept(e); + } + }); + + try { + cdl.await(); + threadContext.restore(); + } catch (InterruptedException e) { + log.error(e); + } + } catch (Exception e) { + log.error(e); + throw e; + } + } + +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java new file mode 100644 index 0000000000..8d9667536b --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java @@ -0,0 +1,492 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.index; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.OpenSearchWrapperException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.ResourceNotFoundException; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.refresh.RefreshResponse; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.SendRequestTransportException; + +public class ConversationMetaIndexTests extends OpenSearchTestCase { + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + + @Mock + Metadata metadata; + + @Mock + AdminClient adminClient; + + @Mock + IndicesAdminClient indicesAdminClient; + + @Mock + ThreadPool threadPool; + + ConversationMetaIndex conversationMetaIndex; + + @Before + public void setup() { + this.client = mock(Client.class); + this.clusterService = mock(ClusterService.class); + this.clusterState = mock(ClusterState.class); + this.metadata = mock(Metadata.class); + this.adminClient = mock(AdminClient.class); + this.indicesAdminClient = mock(IndicesAdminClient.class); + this.threadPool = mock(ThreadPool.class); + + doReturn(clusterState).when(clusterService).state(); + doReturn(metadata).when(clusterState).metadata(); + doReturn(adminClient).when(client).admin(); + doReturn(indicesAdminClient).when(adminClient).indices(); + doReturn(threadPool).when(client).threadPool(); + doReturn(new ThreadContext(Settings.EMPTY)).when(threadPool).getThreadContext(); + conversationMetaIndex = spy(new ConversationMetaIndex(client, clusterService)); + } + + private void setupDoesNotMakeIndex() { + doReturn(false).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(new CreateIndexResponse(false, false, "some-other-index-entirely")); + return null; + }).when(indicesAdminClient).create(any(), any()); + } + + private void setupRefreshSuccess() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(mock(RefreshResponse.class)); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + } + + private void blanketGrantAccess() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(true); + return null; + }).when(conversationMetaIndex).checkAccess(any(), any()); + } + + private void setupUser(String user) { + String userstr = user == null ? "" : user + "||"; + doAnswer(invocation -> { + ThreadContext tc = new ThreadContext(Settings.EMPTY); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + return tc; + }).when(threadPool).getThreadContext(); + } + + public void testInit_DoesNotCreateIndex() { + setupDoesNotMakeIndex(); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + conversationMetaIndex.initConversationMetaIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(createIndexListener, times(1)).onResponse(argCaptor.capture()); + assert (!argCaptor.getValue()); + } + + public void testInit_CreateIndexFails_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Test Error")); + return null; + }).when(indicesAdminClient).create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + conversationMetaIndex.initConversationMetaIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createIndexListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Error")); + } + + public void testInit_CreateIndexFails_WithWrapped_OtherException_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new SendRequestTransportException(null, "action", new Exception("some other exception"))); + return null; + }).when(indicesAdminClient).create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + conversationMetaIndex.initConversationMetaIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createIndexListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof OpenSearchWrapperException); + assert (argCaptor.getValue().getCause().getMessage().equals("some other exception")); + } + + public void testInit_ClientFails_WithResourceExists_ThenOK() { + doReturn(false).when(metadata).hasIndex(anyString()); + doThrow(new ResourceAlreadyExistsException("Test index exists")).when(indicesAdminClient).create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + conversationMetaIndex.initConversationMetaIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(createIndexListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue()); + } + + public void testInit_ClientFails_WithWrappedResourceExists_ThenOK() { + doReturn(false).when(metadata).hasIndex(anyString()); + doThrow(new SendRequestTransportException(null, "action", new ResourceAlreadyExistsException("Test index exists"))) + .when(indicesAdminClient) + .create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + conversationMetaIndex.initConversationMetaIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(createIndexListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue()); + } + + public void testInit_ClientFails_WithWrappedOtherException_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + doThrow(new SendRequestTransportException(null, "action", new Exception("Some other exception"))) + .when(indicesAdminClient) + .create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + conversationMetaIndex.initConversationMetaIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createIndexListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof OpenSearchWrapperException); + assert (argCaptor.getValue().getCause().getMessage().equals("Some other exception")); + } + + public void testInit_ClientFails_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Test Client Failure")).when(indicesAdminClient).create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + conversationMetaIndex.initConversationMetaIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createIndexListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Client Failure")); + } + + public void testCreate_DoesntMakeIndex_ThenFail() { + setupDoesNotMakeIndex(); + @SuppressWarnings("unchecked") + ActionListener createConversationListener = mock(ActionListener.class); + conversationMetaIndex.createConversation(createConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failed to add conversation due to missing index")); + } + + public void testCreate_BadRestStatus_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + IndexResponse response = mock(IndexResponse.class); + doReturn(RestStatus.GONE).when(response).status(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(response); + return null; + }).when(client).index(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createConversationListener = mock(ActionListener.class); + conversationMetaIndex.createConversation(createConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("failed to create conversation")); + } + + public void testCreate_InternalFailure_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Test Failure")); + return null; + }).when(client).index(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createConversationListener = mock(ActionListener.class); + conversationMetaIndex.createConversation(createConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Failure")); + } + + public void testCreate_ClientFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Test Failure")).when(client).index(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createConversationListener = mock(ActionListener.class); + conversationMetaIndex.createConversation(createConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Failure")); + } + + public void testCreate_InitFails_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Test Init Client Failure")).when(indicesAdminClient).create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createConversationListener = mock(ActionListener.class); + conversationMetaIndex.createConversation(createConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Init Client Failure")); + } + + public void testGet_NoIndex_ThenEmpty() { + doReturn(false).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener> getConversationsListener = mock(ActionListener.class); + conversationMetaIndex.getConversations(10, getConversationsListener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(getConversationsListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().size() == 0); + } + + public void testGet_SearchFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupRefreshSuccess(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Test Exception")); + return null; + }).when(client).search(any(), any()); + @SuppressWarnings("unchecked") + ActionListener> getConversationsListener = mock(ActionListener.class); + conversationMetaIndex.getConversations(10, getConversationsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getConversationsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Exception")); + } + + public void testGet_RefreshFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Refresh Exception")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener> getConversationsListener = mock(ActionListener.class); + conversationMetaIndex.getConversations(10, getConversationsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getConversationsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Refresh Exception")); + } + + public void testGet_ClientFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Refresh Client Failure")).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener> getConversationsListener = mock(ActionListener.class); + conversationMetaIndex.getConversations(10, getConversationsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getConversationsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Refresh Client Failure")); + } + + public void testDelete_NoIndex_ThenReturnTrue() { + doReturn(false).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + conversationMetaIndex.deleteConversation("test-id", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(deleteConversationListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue()); + } + + public void testDelete_RestNotFoundStatus_ThenReturnTrue() { + doReturn(true).when(metadata).hasIndex(anyString()); + blanketGrantAccess(); + DeleteResponse response = mock(DeleteResponse.class); + doReturn(RestStatus.NOT_FOUND).when(response).status(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(response); + return null; + }).when(client).delete(any(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + conversationMetaIndex.deleteConversation("test-id", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(deleteConversationListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue()); + } + + public void testDelete_BadResponse_ThenReturnFalse() { + doReturn(true).when(metadata).hasIndex(anyString()); + blanketGrantAccess(); + DeleteResponse response = mock(DeleteResponse.class); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(response); + return null; + }).when(client).delete(any(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + conversationMetaIndex.deleteConversation("test-id", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(deleteConversationListener, times(1)).onResponse(argCaptor.capture()); + assert (!argCaptor.getValue()); + } + + public void testDelete_DeleteFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + blanketGrantAccess(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Test Fail in Delete")); + return null; + }).when(client).delete(any(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + conversationMetaIndex.deleteConversation("test-id", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Fail in Delete")); + } + + public void testDelete_HighLevelFailure_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Check Fail")).when(conversationMetaIndex).checkAccess(any(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + conversationMetaIndex.deleteConversation("test-id", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Check Fail")); + } + + public void testCheckAccess_DoesNotExist_ThenFail() { + setupUser("user"); + doReturn(true).when(metadata).hasIndex(anyString()); + GetResponse response = mock(GetResponse.class); + doReturn(false).when(response).isExists(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(response); + return null; + }).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener accessListener = mock(ActionListener.class); + conversationMetaIndex.checkAccess("test id", accessListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(accessListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof ResourceNotFoundException); + assert (argCaptor.getValue().getMessage().equals("Conversation [test id] not found")); + } + + public void testCheckAccess_WrongId_ThenFail() { + setupUser("user"); + doReturn(true).when(metadata).hasIndex(anyString()); + GetResponse response = mock(GetResponse.class); + doReturn(true).when(response).isExists(); + doReturn("wrong id").when(response).getId(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(response); + return null; + }).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener accessListener = mock(ActionListener.class); + conversationMetaIndex.checkAccess("test id", accessListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(accessListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof ResourceNotFoundException); + assert (argCaptor.getValue().getMessage().equals("Conversation [test id] not found")); + } + + public void testCheckAccess_GetFails_ThenFail() { + setupUser("user"); + doReturn(true).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Test Fail")); + return null; + }).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener accessListener = mock(ActionListener.class); + conversationMetaIndex.checkAccess("test id", accessListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(accessListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Fail")); + } + + public void testCheckAccess_ClientFails_ThenFail() { + setupUser("user"); + doReturn(true).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Client Test Fail")).when(client).get(any(), any()); + @SuppressWarnings("unchecked") + ActionListener accessListener = mock(ActionListener.class); + conversationMetaIndex.checkAccess("test id", accessListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(accessListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Test Fail")); + } + + public void testCheckAccess_EmptyStringUser_ThenReturnTrue() { + setupUser(null); + @SuppressWarnings("unchecked") + ActionListener accessListener = mock(ActionListener.class); + conversationMetaIndex.checkAccess("test id", accessListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(accessListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue()); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java new file mode 100644 index 0000000000..c23177bc2f --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexITTests.java @@ -0,0 +1,351 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.index; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.List; +import java.util.concurrent.CountDownLatch; + +import org.junit.Before; +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.StepListener; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchIntegTestCase; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 2) +public class InteractionsIndexITTests extends OpenSearchIntegTestCase { + + private Client client; + private ClusterService clusterService; + private InteractionsIndex index; + + @Before + public void setup() { + client = client(); + clusterService = clusterService(); + index = new InteractionsIndex(client, clusterService, new ConversationMetaIndex(client, clusterService)); + } + + /** + * Test the index intialization logic - can I create the index exactly once (while trying 3 times) + */ + public void testInteractionsIndexCanBeInitialized() { + log.info("testing index creation logic of the index object"); + CountDownLatch cdl = new CountDownLatch(3); + index.initInteractionsIndexIfAbsent(new LatchedActionListener<>(ActionListener.wrap(r -> { assert (r); }, e -> { + cdl.countDown(); + cdl.countDown(); + log.error(e); + assert (false); + }), cdl)); + index.initInteractionsIndexIfAbsent(new LatchedActionListener<>(ActionListener.wrap(r -> { assert (r); }, e -> { + cdl.countDown(); + cdl.countDown(); + log.error(e); + assert (false); + }), cdl)); + InteractionsIndex otherIndex = new InteractionsIndex(client, clusterService, new ConversationMetaIndex(client, clusterService)); + otherIndex.initInteractionsIndexIfAbsent(new LatchedActionListener<>(ActionListener.wrap(r -> { assert (r); }, e -> { + cdl.countDown(); + cdl.countDown(); + log.error(e); + assert (false); + }), cdl)); + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + /** + * Make sure nothing breaks when I add an interaction, with and without timestamp, + * and that the ids are different + */ + public void testCanAddNewInteraction() { + CountDownLatch cdl = new CountDownLatch(2); + String[] ids = new String[2]; + index + .createInteraction( + "test", + "test input", + "pt", + "test response", + "test origin", + "metadata", + new LatchedActionListener<>(ActionListener.wrap(id -> { + ids[0] = id; + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }), cdl) + ); + + index + .createInteraction( + "test", + "test input", + "pt", + "test response", + "test origin", + "metadata", + new LatchedActionListener<>(ActionListener.wrap(id -> { + ids[1] = id; + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }), cdl) + ); + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + assert (!ids[0].equals(ids[1])); + } + + /** + * Make sure I can get interactions out related to a conversation + */ + public void testGetInteractions() { + final String conversation = "test-conversation"; + CountDownLatch cdl = new CountDownLatch(1); + StepListener id1Listener = new StepListener<>(); + index.createInteraction(conversation, "test input", "pt", "test response", "test origin", "metadata", id1Listener); + + StepListener id2Listener = new StepListener<>(); + id1Listener.whenComplete(id -> { + index + .createInteraction( + conversation, + "test input", + "pt", + "test response", + "test origin", + "metadata", + Instant.now().plus(3, ChronoUnit.MINUTES), + id2Listener + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener> getListener = new StepListener<>(); + id2Listener.whenComplete(r -> { index.getInteractions(conversation, 0, 2, getListener); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + LatchedActionListener> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(interactions -> { + assert (interactions.size() == 2); + assert (interactions.get(0).getId().equals(id2Listener.result())); + assert (interactions.get(1).getId().equals(id1Listener.result())); + }, e -> { + log.error(e); + assert (false); + }), cdl); + getListener.whenComplete(finishAndAssert::onResponse, finishAndAssert::onFailure); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + public void testGetInteractionPages() { + final String conversation = "test-conversation"; + CountDownLatch cdl = new CountDownLatch(1); + StepListener id1Listener = new StepListener<>(); + index.createInteraction(conversation, "test input", "pt", "test response", "test origin", "metadata", id1Listener); + + StepListener id2Listener = new StepListener<>(); + id1Listener.whenComplete(id -> { + index + .createInteraction( + conversation, + "test input1", + "pt", + "test response", + "test origin", + "metadata", + Instant.now().plus(3, ChronoUnit.MINUTES), + id2Listener + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener id3Listener = new StepListener<>(); + id2Listener.whenComplete(id -> { + index + .createInteraction( + conversation, + "test input2", + "pt", + "test response", + "test origin", + "metadata", + Instant.now().plus(4, ChronoUnit.MINUTES), + id3Listener + ); + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener> getListener1 = new StepListener<>(); + id3Listener.whenComplete(r -> { index.getInteractions(conversation, 0, 2, getListener1); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener> getListener2 = new StepListener<>(); + getListener1.whenComplete(r -> { index.getInteractions(conversation, 2, 2, getListener2); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + LatchedActionListener> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(interactions2 -> { + List interactions1 = getListener1.result(); + String id1 = id1Listener.result(); + String id2 = id2Listener.result(); + String id3 = id3Listener.result(); + assert (interactions2.size() == 1); + assert (interactions1.size() == 2); + assert (interactions1.get(0).getId().equals(id3)); + assert (interactions1.get(1).getId().equals(id2)); + assert (interactions2.get(0).getId().equals(id1)); + }, e -> { + log.error(e); + assert (false); + }), cdl); + getListener2.whenComplete(finishAndAssert::onResponse, finishAndAssert::onFailure); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } + + public void testDeleteConversation() { + final String conversation1 = "conversation1"; + final String conversation2 = "conversation2"; + CountDownLatch cdl = new CountDownLatch(1); + StepListener iid1 = new StepListener<>(); + index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid1); + + StepListener iid2 = new StepListener<>(); + iid1 + .whenComplete( + r -> { index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid2); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); + + StepListener iid3 = new StepListener<>(); + iid2 + .whenComplete( + r -> { index.createInteraction(conversation2, "test input", "pt", "test response", "test origin", "metadata", iid3); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); + + StepListener iid4 = new StepListener<>(); + iid3 + .whenComplete( + r -> { index.createInteraction(conversation1, "test input", "pt", "test response", "test origin", "metadata", iid4); }, + e -> { + cdl.countDown(); + log.error(e); + assert (false); + } + ); + + StepListener deleteListener = new StepListener<>(); + iid4.whenComplete(r -> { index.deleteConversation(conversation1, deleteListener); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener> interactions1 = new StepListener<>(); + deleteListener.whenComplete(success -> { + if (success) { + index.getInteractions(conversation1, 0, 10, interactions1); + } else { + cdl.countDown(); + assert (false); + } + }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + StepListener> interactions2 = new StepListener<>(); + interactions1.whenComplete(interactions -> { index.getInteractions(conversation2, 0, 10, interactions2); }, e -> { + cdl.countDown(); + log.error(e); + assert (false); + }); + + LatchedActionListener> finishAndAssert = new LatchedActionListener<>(ActionListener.wrap(interactions -> { + assert (interactions.size() == 1); + assert (interactions.get(0).getId().equals(iid3.result())); + assert (interactions1.result().size() == 0); + }, e -> { + log.error(e); + assert (false); + }), cdl); + interactions2.whenComplete(finishAndAssert::onResponse, finishAndAssert::onFailure); + + try { + cdl.await(); + } catch (InterruptedException e) { + log.error(e); + } + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java new file mode 100644 index 0000000000..0e97c7e9f6 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/index/InteractionsIndexTests.java @@ -0,0 +1,585 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.index; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.time.Instant; +import java.util.List; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.OpenSearchWrapperException; +import org.opensearch.ResourceAlreadyExistsException; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.admin.indices.refresh.RefreshResponse; +import org.opensearch.action.bulk.BulkResponse; +import org.opensearch.action.index.IndexResponse; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.AdminClient; +import org.opensearch.client.Client; +import org.opensearch.client.IndicesAdminClient; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.commons.ConfigConstants; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.SendRequestTransportException; + +public class InteractionsIndexTests extends OpenSearchTestCase { + @Mock + Client client; + + @Mock + ClusterService clusterService; + + @Mock + ClusterState clusterState; + + @Mock + Metadata metadata; + + @Mock + AdminClient adminClient; + + @Mock + IndicesAdminClient indicesAdminClient; + + @Mock + ThreadPool threadPool; + + @Mock + ConversationMetaIndex conversationMetaIndex; + + InteractionsIndex interactionsIndex; + + @Before + public void setup() { + this.client = mock(Client.class); + this.clusterService = mock(ClusterService.class); + this.clusterState = mock(ClusterState.class); + this.metadata = mock(Metadata.class); + this.adminClient = mock(AdminClient.class); + this.indicesAdminClient = mock(IndicesAdminClient.class); + this.threadPool = mock(ThreadPool.class); + this.conversationMetaIndex = mock(ConversationMetaIndex.class); + + doReturn(clusterState).when(clusterService).state(); + doReturn(metadata).when(clusterState).metadata(); + doReturn(adminClient).when(client).admin(); + doReturn(indicesAdminClient).when(adminClient).indices(); + doReturn(threadPool).when(client).threadPool(); + doReturn(new ThreadContext(Settings.EMPTY)).when(threadPool).getThreadContext(); + this.interactionsIndex = spy(new InteractionsIndex(client, clusterService, conversationMetaIndex)); + } + + private void setupDoesNotMakeIndex() { + doReturn(false).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(new CreateIndexResponse(false, false, "some-other-index-entirely")); + return null; + }).when(indicesAdminClient).create(any(), any()); + } + + private void setupGrantAccess() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(true); + return null; + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + } + + private void setupDenyAccess(String user) { + String userstr = user == null ? "" : user + "||"; + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(false); + return null; + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + doAnswer(invocation -> { + ThreadContext tc = new ThreadContext(Settings.EMPTY); + tc.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, userstr); + return tc; + }).when(threadPool).getThreadContext(); + } + + private void setupRefreshSuccess() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(mock(RefreshResponse.class)); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + } + + public void testInit_DoesNotCreateIndex_ThenReturnFalse() { + setupDoesNotMakeIndex(); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + interactionsIndex.initInteractionsIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(createIndexListener, times(1)).onResponse(argCaptor.capture()); + assert (!argCaptor.getValue()); + } + + public void testInit_CreateIndexFails_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Test Error")); + return null; + }).when(indicesAdminClient).create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + interactionsIndex.initInteractionsIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createIndexListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Error")); + } + + public void testInit_CreateIndexFails_WithWrapped_OtherException_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new SendRequestTransportException(null, "action", new Exception("some other exception"))); + return null; + }).when(indicesAdminClient).create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + interactionsIndex.initInteractionsIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createIndexListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof OpenSearchWrapperException); + assert (argCaptor.getValue().getCause().getMessage().equals("some other exception")); + } + + public void testInit_ClientFails_WithResourceExists_ThenOK() { + doReturn(false).when(metadata).hasIndex(anyString()); + doThrow(new ResourceAlreadyExistsException("Test index exists")).when(indicesAdminClient).create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + interactionsIndex.initInteractionsIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(createIndexListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue()); + } + + public void testInit_ClientFails_WithWrappedResourceExists_ThenOK() { + doReturn(false).when(metadata).hasIndex(anyString()); + doThrow(new SendRequestTransportException(null, "action", new ResourceAlreadyExistsException("Test index exists"))) + .when(indicesAdminClient) + .create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + interactionsIndex.initInteractionsIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(createIndexListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue()); + } + + public void testInit_ClientFails_WithWrappedOtherException_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + doThrow(new SendRequestTransportException(null, "action", new Exception("Some other exception"))) + .when(indicesAdminClient) + .create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + interactionsIndex.initInteractionsIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createIndexListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof OpenSearchWrapperException); + assert (argCaptor.getValue().getCause().getMessage().equals("Some other exception")); + } + + public void testInit_ClientFails_ThenFail() { + doReturn(false).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Test Client Failure")).when(indicesAdminClient).create(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createIndexListener = mock(ActionListener.class); + interactionsIndex.initInteractionsIndexIfAbsent(createIndexListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createIndexListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Client Failure")); + } + + public void testCreate_NoIndex_ThenFail() { + setupDoesNotMakeIndex(); + @SuppressWarnings("unchecked") + ActionListener createInteractionListener = mock(ActionListener.class); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("no index to add conversation to")); + } + + public void testCreate_BadRestStatus_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + IndexResponse response = mock(IndexResponse.class); + doReturn(RestStatus.GONE).when(response).status(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(response); + return null; + }).when(client).index(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createInteractionListener = mock(ActionListener.class); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failed to create interaction")); + } + + public void testCreate_InternalFailure_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Test Failure")); + return null; + }).when(client).index(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createInteractionListener = mock(ActionListener.class); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Failure")); + } + + public void testCreate_ClientFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createInteractionListener = mock(ActionListener.class); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Client Failure")); + } + + public void testCreate_NoAccessNoUser_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupDenyAccess(null); + doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createInteractionListener = mock(ActionListener.class); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor + .getValue() + .getMessage() + .equals("User [" + ActionConstants.DEFAULT_USERNAME_FOR_ERRORS + "] does not have access to conversation cid")); + } + + public void testCreate_NoAccessWithUser_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupDenyAccess("user"); + doThrow(new RuntimeException("Test Client Failure")).when(client).index(any(), any()); + @SuppressWarnings("unchecked") + ActionListener createInteractionListener = mock(ActionListener.class); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [user] does not have access to conversation cid")); + } + + public void testCreate_CreateIndexFails_ThenFail() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(0); + al.onFailure(new Exception("Fail in Index Creation")); + return null; + }).when(interactionsIndex).initInteractionsIndexIfAbsent(any()); + @SuppressWarnings("unchecked") + ActionListener createInteractionListener = mock(ActionListener.class); + interactionsIndex.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta", createInteractionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(createInteractionListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Fail in Index Creation")); + } + + public void testGet_NoIndex_ThenEmpty() { + doReturn(false).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener> getInteractionsListener = mock(ActionListener.class); + interactionsIndex.getInteractions("cid", 0, 10, getInteractionsListener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(getInteractionsListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().size() == 0); + } + + public void testGet_SearchFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + setupRefreshSuccess(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Failure in Search")); + return null; + }).when(client).search(any(), any()); + @SuppressWarnings("unchecked") + ActionListener> getInteractionsListener = mock(ActionListener.class); + interactionsIndex.getInteractions("cid", 0, 10, getInteractionsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in Search")); + } + + public void testGet_RefreshFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Failed to Refresh")); + return null; + }).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener> getInteractionsListener = mock(ActionListener.class); + interactionsIndex.getInteractions("cid", 0, 10, getInteractionsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failed to Refresh")); + } + + public void testGet_ClientFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + doThrow(new RuntimeException("Client Failure")).when(indicesAdminClient).refresh(any(), any()); + @SuppressWarnings("unchecked") + ActionListener> getInteractionsListener = mock(ActionListener.class); + interactionsIndex.getInteractions("cid", 0, 10, getInteractionsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Client Failure")); + } + + public void testGet_NoAccessNoUser_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupDenyAccess(null); + @SuppressWarnings("unchecked") + ActionListener> getInteractionsListener = mock(ActionListener.class); + interactionsIndex.getInteractions("cid", 0, 10, getInteractionsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor + .getValue() + .getMessage() + .equals("User [" + ActionConstants.DEFAULT_USERNAME_FOR_ERRORS + "] does not have access to conversation cid")); + } + + public void testGetAll_BadMaxResults_ThenFail() { + @SuppressWarnings("unchecked") + ActionListener> getInteractionsListener = mock(ActionListener.class); + interactionsIndex.nextGetListener("cid", 0, 0, getInteractionsListener, List.of()); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("maxResults must be positive")); + } + + public void testGetAll_Recursion() { + List interactions = List + .of( + new Interaction("iid1", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), + new Interaction("iid2", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), + new Interaction("iid3", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta"), + new Interaction("iid4", Instant.now(), "cid", "inp", "pt", "rsp", "ogn", "meta") + ); + doAnswer(invocation -> { + ActionListener> al = invocation.getArgument(3); + al.onResponse(interactions.subList(0, 2)); + return null; + }).when(interactionsIndex).innerGetInteractions(anyString(), eq(0), anyInt(), any()); + doAnswer(invocation -> { + ActionListener> al = invocation.getArgument(3); + al.onResponse(interactions.subList(2, 4)); + return null; + }).when(interactionsIndex).innerGetInteractions(anyString(), eq(2), anyInt(), any()); + doAnswer(invocation -> { + ActionListener> al = invocation.getArgument(3); + al.onResponse(List.of()); + return null; + }).when(interactionsIndex).innerGetInteractions(anyString(), eq(4), anyInt(), any()); + @SuppressWarnings("unchecked") + ActionListener> getInteractionsListener = mock(ActionListener.class); + interactionsIndex.getAllInteractions("cid", 2, getInteractionsListener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(getInteractionsListener, times(1)).onResponse(argCaptor.capture()); + List result = argCaptor.getValue(); + assert (result.size() == 4); + assert (result.get(0).getId().equals("iid1")); + assert (result.get(1).getId().equals("iid2")); + assert (result.get(2).getId().equals("iid3")); + assert (result.get(3).getId().equals("iid4")); + } + + public void testGetAll_GetFails_ThenFail() { + doAnswer(invocation -> { + ActionListener> al = invocation.getArgument(3); + al.onFailure(new Exception("Failure in Get")); + return null; + }).when(interactionsIndex).innerGetInteractions(anyString(), anyInt(), anyInt(), any()); + @SuppressWarnings("unchecked") + ActionListener> getInteractionsListener = mock(ActionListener.class); + interactionsIndex.getAllInteractions("cid", 2, getInteractionsListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(getInteractionsListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure in Get")); + } + + public void testDelete_NoIndex_ThenReturnTrue() { + doReturn(false).when(metadata).hasIndex(anyString()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + interactionsIndex.deleteConversation("cid", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(deleteConversationListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue()); + } + + public void testDelete_BulkHasFailures_ReturnFalse() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + BulkResponse bulkResponse = mock(BulkResponse.class); + doReturn(true).when(bulkResponse).hasFailures(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(bulkResponse); + return null; + }).when(client).bulk(any(), any()); + doAnswer(invocation -> { + ActionListener> al = invocation.getArgument(2); + al.onResponse(List.of()); + return null; + }).when(interactionsIndex).getAllInteractions(anyString(), anyInt(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + interactionsIndex.deleteConversation("cid", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(deleteConversationListener, times(1)).onResponse(argCaptor.capture()); + assert (!argCaptor.getValue()); + } + + public void testDelete_BulkFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Failure during Bulk")); + return null; + }).when(client).bulk(any(), any()); + doAnswer(invocation -> { + ActionListener> al = invocation.getArgument(2); + al.onResponse(List.of()); + return null; + }).when(interactionsIndex).getAllInteractions(anyString(), anyInt(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + interactionsIndex.deleteConversation("cid", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure during Bulk")); + } + + public void testDelete_SearchFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupGrantAccess(); + doAnswer(invocation -> { + ActionListener> al = invocation.getArgument(2); + al.onFailure(new Exception("Failure during GetAllInteractions")); + return null; + }).when(interactionsIndex).getAllInteractions(anyString(), anyInt(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + interactionsIndex.deleteConversation("cid", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Failure during GetAllInteractions")); + } + + public void testDelete_NoAccessNoUser_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupDenyAccess(null); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + interactionsIndex.deleteConversation("cid", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor + .getValue() + .getMessage() + .equals("User [" + ActionConstants.DEFAULT_USERNAME_FOR_ERRORS + "] does not have access to conversation cid")); + } + + public void testDelete_NoAccessWithUser_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + setupDenyAccess("user"); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + interactionsIndex.deleteConversation("cid", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("User [user] does not have access to conversation cid")); + } + + public void testDelete_AccessFails_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onFailure(new Exception("Access Failure")); + return null; + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + interactionsIndex.deleteConversation("cid", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Access Failure")); + } + + public void testDelete_MainFailure_ThenFail() { + doReturn(true).when(metadata).hasIndex(anyString()); + doThrow(new RuntimeException("Test Failure")).when(conversationMetaIndex).checkAccess(anyString(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteConversationListener = mock(ActionListener.class); + interactionsIndex.deleteConversation("cid", deleteConversationListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("Test Failure")); + } +} diff --git a/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java new file mode 100644 index 0000000000..e39513d2d8 --- /dev/null +++ b/memory/src/test/java/org/opensearch/ml/memory/index/OpenSearchConversationalMemoryHandlerTests.java @@ -0,0 +1,244 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.memory.index; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.conversation.ConversationMeta; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.conversation.Interaction.InteractionBuilder; +import org.opensearch.test.OpenSearchTestCase; + +public class OpenSearchConversationalMemoryHandlerTests extends OpenSearchTestCase { + + @Mock + ConversationMetaIndex conversationMetaIndex; + + @Mock + InteractionsIndex interactionsIndex; + + OpenSearchConversationalMemoryHandler cmHandler; + + @Before + public void setup() { + conversationMetaIndex = mock(ConversationMetaIndex.class); + interactionsIndex = mock(InteractionsIndex.class); + cmHandler = new OpenSearchConversationalMemoryHandler(conversationMetaIndex, interactionsIndex); + } + + public void testCreateConversation_NoName_FutureSuccess() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(0); + al.onResponse("cid"); + return null; + }).when(conversationMetaIndex).createConversation(any(ActionListener.class)); + ActionFuture result = cmHandler.createConversation(); + assert (result.actionGet(200).equals("cid")); + } + + public void testCreateConversation_Named_FutureSucess() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse("cid"); + return null; + }).when(conversationMetaIndex).createConversation(anyString(), any()); + ActionFuture result = cmHandler.createConversation("FutureSuccess"); + assert (result.actionGet(200).equals("cid")); + } + + public void testCreateInteraction_Future() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(7); + al.onResponse("iid"); + return null; + }) + .when(interactionsIndex) + .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); + ActionFuture result = cmHandler.createInteraction("cid", "inp", "pt", "rsp", "ogn", "meta"); + assert (result.actionGet(200).equals("iid")); + } + + public void testCreateInteraction_FromBuilder_Success() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(7); + al.onResponse("iid"); + return null; + }) + .when(interactionsIndex) + .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); + InteractionBuilder builder = Interaction + .builder() + .conversationId("cid") + .input("inp") + .origin("origin") + .response("rsp") + .promptTemplate("pt") + .additionalInfo("meta"); + @SuppressWarnings("unchecked") + ActionListener createInteractionListener = mock(ActionListener.class); + cmHandler.createInteraction(builder, createInteractionListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(String.class); + verify(createInteractionListener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().equals("iid")); + } + + public void testCreateInteraction_FromBuilder_Future() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(7); + al.onResponse("iid"); + return null; + }) + .when(interactionsIndex) + .createInteraction(anyString(), anyString(), anyString(), anyString(), anyString(), anyString(), any(), any()); + InteractionBuilder builder = Interaction + .builder() + .origin("ogn") + .conversationId("cid") + .input("inp") + .response("rsp") + .promptTemplate("pt") + .additionalInfo("meta"); + ActionFuture result = cmHandler.createInteraction(builder); + assert (result.actionGet(200).equals("iid")); + } + + public void testGetInteractions_Future() { + doAnswer(invocation -> { + ActionListener> al = invocation.getArgument(3); + al.onResponse(List.of()); + return null; + }).when(interactionsIndex).getInteractions(anyString(), anyInt(), anyInt(), any()); + ActionFuture> result = cmHandler.getInteractions("cid", 0, 10); + assert (result.actionGet(200).size() == 0); + } + + public void testGetConversations_Future() { + doAnswer(invocation -> { + ActionListener> al = invocation.getArgument(1); + al.onResponse(List.of()); + return null; + }).when(conversationMetaIndex).getConversations(anyInt(), any()); + ActionFuture> result = cmHandler.getConversations(10); + assert (result.actionGet(200).size() == 0); + } + + public void testGetConversations_Page_Future() { + doAnswer(invocation -> { + ActionListener> al = invocation.getArgument(2); + al.onResponse(List.of()); + return null; + }).when(conversationMetaIndex).getConversations(anyInt(), anyInt(), any()); + ActionFuture> result = cmHandler.getConversations(30, 10); + assert (result.actionGet(200).size() == 0); + } + + public void testDelete_NoAccess() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(false); + return null; + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteListener = mock(ActionListener.class); + cmHandler.deleteConversation("cid", deleteListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(deleteListener, times(1)).onResponse(argCaptor.capture()); + assert (!argCaptor.getValue()); + } + + public void testDelete_ConversationMetaDeleteFalse_ThenFalse() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(true); + return null; + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(false); + return null; + }).when(conversationMetaIndex).deleteConversation(anyString(), any()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(true); + return null; + }).when(interactionsIndex).deleteConversation(anyString(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteListener = mock(ActionListener.class); + cmHandler.deleteConversation("cid", deleteListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(deleteListener, times(1)).onResponse(argCaptor.capture()); + assert (!argCaptor.getValue()); + } + + public void testDelete_InteractionsDeleteFalse_ThenFalse() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(true); + return null; + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(true); + return null; + }).when(conversationMetaIndex).deleteConversation(anyString(), any()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(false); + return null; + }).when(interactionsIndex).deleteConversation(anyString(), any()); + @SuppressWarnings("unchecked") + ActionListener deleteListener = mock(ActionListener.class); + cmHandler.deleteConversation("cid", deleteListener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Boolean.class); + verify(deleteListener, times(1)).onResponse(argCaptor.capture()); + assert (!argCaptor.getValue()); + } + + public void testDelete_AsFuture() { + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(true); + return null; + }).when(conversationMetaIndex).checkAccess(anyString(), any()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(true); + return null; + }).when(conversationMetaIndex).deleteConversation(anyString(), any()); + doAnswer(invocation -> { + ActionListener al = invocation.getArgument(1); + al.onResponse(true); + return null; + }).when(interactionsIndex).deleteConversation(anyString(), any()); + ActionFuture result = cmHandler.deleteConversation("cid"); + assert (result.actionGet(200)); + } +} diff --git a/memory/src/yamlRestTest/java/org/opensearch/ml/conversational/ConversationalClientYamlTestSuiteIT.java b/memory/src/yamlRestTest/java/org/opensearch/ml/conversational/ConversationalClientYamlTestSuiteIT.java new file mode 100644 index 0000000000..70007dbe0c --- /dev/null +++ b/memory/src/yamlRestTest/java/org/opensearch/ml/conversational/ConversationalClientYamlTestSuiteIT.java @@ -0,0 +1,25 @@ +/* + * + * * Copyright OpenSearch Contributors + * * SPDX-License-Identifier: Apache-2.0 + * + */ +package org.opensearch.ml.conversational; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; +import org.opensearch.test.rest.yaml.ClientYamlTestCandidate; +import org.opensearch.test.rest.yaml.OpenSearchClientYamlSuiteTestCase; + + +public class ConversationalClientYamlTestSuiteIT extends OpenSearchClientYamlSuiteTestCase { + + public ConversationalClientYamlTestSuiteIT(@Name("yaml") ClientYamlTestCandidate testCandidate) { + super(testCandidate); + } + + @ParametersFactory + public static Iterable parameters() throws Exception { + return OpenSearchClientYamlSuiteTestCase.createParameters(); + } +} diff --git a/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_create_conversation.json b/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_create_conversation.json new file mode 100644 index 0000000000..d4fa5fd14a --- /dev/null +++ b/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_create_conversation.json @@ -0,0 +1,19 @@ +{ + "_plugins.conversational_create_conversation": { + "stability": "stable", + "url": { + "paths": [ + { + "path": "/_plugins/ml/conversational/memory", + "methods": ["POST"] + } + ] + }, + "body": { + "name": { + "type": "string", + "description": "[optional] name of the conversation to be created" + } + } + } +} \ No newline at end of file diff --git a/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_create_interaction.json b/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_create_interaction.json new file mode 100644 index 0000000000..6754a9de26 --- /dev/null +++ b/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_create_interaction.json @@ -0,0 +1,44 @@ +{ + "_plugins.conversational_create_interaction": { + "stability": "stable", + "url": { + "paths": [ + { + "path": "/_plugins/ml/conversational/memory/{conversationId}", + "methods": ["POST"], + "parts": { + "conversationId": { + "type": "string", + "description": "ID of conversation to add this interaction to" + } + } + } + ] + }, + "params": { + "input": { + "type": "string", + "description": "human input in the interaction" + }, + "prompt": { + "type": "string", + "description": "prompting around the input" + }, + "response": { + "type": "string", + "description": "AI response from the input" + }, + "agent": { + "type": "string", + "description": "GenAI Agent used for this interaction" + }, + "attributes": { + "type": "string", + "description": "arbitrary XContent string of additional data associated with this interaction" + } + }, + "body": { + "description": "the interaction" + } + } +} \ No newline at end of file diff --git a/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_delete_conversation.json b/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_delete_conversation.json new file mode 100644 index 0000000000..1f350d07a9 --- /dev/null +++ b/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_delete_conversation.json @@ -0,0 +1,19 @@ +{ + "_plugins.conversational_delete_conversation": { + "stability": "stable", + "url": { + "paths": [ + { + "path": "/_plugins/ml/conversational/memory/{conversationId}", + "methods": ["DELETE"], + "parts": { + "conversationId": { + "type": "string", + "description": "id of the conversation to be deleted" + } + } + } + ] + } + } +} \ No newline at end of file diff --git a/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_get_conversations.json b/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_get_conversations.json new file mode 100644 index 0000000000..16a28af294 --- /dev/null +++ b/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_get_conversations.json @@ -0,0 +1,23 @@ +{ + "_plugins.conversational_get_conversations": { + "stability": "stable", + "url": { + "paths": [ + { + "path": "/_plugins/ml/conversational/memory", + "methods": ["GET"] + } + ] + }, + "params": { + "maxResults": { + "type": "number", + "description": "[optional] number of results to return (defaults to 10)" + }, + "nextToken": { + "type": "number", + "description": "[optional] token pointing to the next page of results" + } + } + } +} \ No newline at end of file diff --git a/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_get_interactions.json b/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_get_interactions.json new file mode 100644 index 0000000000..bb5cf610ad --- /dev/null +++ b/memory/src/yamlRestTest/resources/rest-api-spec/api/_plugins.conversational_get_interactions.json @@ -0,0 +1,29 @@ +{ + "_plugins.conversational_get_interactions": { + "stability": "stable", + "url": { + "paths": [ + { + "path": "/_plugins/ml/conversational/memory/{conversationId}", + "methods": ["GET"], + "parts": { + "conversationId": { + "type": "string", + "description": "ID of conversation to get interactions from" + } + } + } + ] + }, + "params": { + "maxResults": { + "type": "number", + "description": "[optional] number of results to return (defaults to 10)" + }, + "nextToken": { + "type": "number", + "description": "[optional] token pointing to the next page of results" + } + } + } +} \ No newline at end of file diff --git a/memory/src/yamlRestTest/resources/rest-api-spec/test/10_basic.yml b/memory/src/yamlRestTest/resources/rest-api-spec/test/10_basic.yml new file mode 100644 index 0000000000..34404f8afe --- /dev/null +++ b/memory/src/yamlRestTest/resources/rest-api-spec/test/10_basic.yml @@ -0,0 +1,8 @@ +"Test that the plugin is loaded in OpenSearch": + - do: + cat.plugins: + local: true + h: component + + - match: + $body: /^conversational\n$/ diff --git a/memory/src/yamlRestTest/resources/rest-api-spec/test/20_create_convo.yml b/memory/src/yamlRestTest/resources/rest-api-spec/test/20_create_convo.yml new file mode 100644 index 0000000000..b6373050a4 --- /dev/null +++ b/memory/src/yamlRestTest/resources/rest-api-spec/test/20_create_convo.yml @@ -0,0 +1,17 @@ +--- +"Test creating a conversation": + - do: + _plugins.conversational_create_conversation: + body: null + + - match: + $body.conversationId: /^.{10,}$/ +--- +"Test creating a conversation with a name": + - do: + _plugins.conversational_create_conversation: + body: + name: Test + + - match: + $body.conversationId: /^.{10,}$/ diff --git a/memory/src/yamlRestTest/resources/rest-api-spec/test/30_list_convos.yml b/memory/src/yamlRestTest/resources/rest-api-spec/test/30_list_convos.yml new file mode 100644 index 0000000000..1552d42d3c --- /dev/null +++ b/memory/src/yamlRestTest/resources/rest-api-spec/test/30_list_convos.yml @@ -0,0 +1,72 @@ +--- +"Test adding and getting back conversations": + + - do: + _plugins.conversational_create_conversation: + body: null + + - do: + _plugins.conversational_create_conversation: + body: + name: TEST + + - do: + _plugins.conversational_get_conversations: + params: null + + - match: + $body.conversations.0.conversationId: /.{10,}/ + + - match: + $body.conversations.1.conversationId: /.{10,}/ + + - match: + $body.conversations.0.numInteractions: 0 + + - match: + $body.conversations.1.numInteractions: 0 + + - match: + $body.conversations.0.name: TEST + + - match: + $body.conversations.1.name: /^$/ + +--- +"Test paginations": + - do: + _plugins.conversational_create_conversation: + body: + name: C1 + + - do: + _plugins.conversational_create_conversation: + body: + name: C2 + + - do: + _plugins.conversational_create_conversation: + body: + name: C3 + + - do: + _plugins.conversational_get_conversations: + maxResults: 2 + + - match: + $body.conversations.0.name: C3 + + - match: + $body.conversations.1.name: C2 + + - match: + $body.nextToken: 2 + + - do: + _plugins.conversational_get_conversations: + maxResults: 2 + nextToken: 2 + + - match: + $body.conversations.0.name: C1 + diff --git a/memory/src/yamlRestTest/resources/rest-api-spec/test/40_put_interaction.yml b/memory/src/yamlRestTest/resources/rest-api-spec/test/40_put_interaction.yml new file mode 100644 index 0000000000..13328bcb85 --- /dev/null +++ b/memory/src/yamlRestTest/resources/rest-api-spec/test/40_put_interaction.yml @@ -0,0 +1,13 @@ +"Test adding an interaction": + - do: + _plugins.conversational_create_interaction: + conversationId: test-cid + body: + input: test-input + prompt: test-prompt + response: test-response + agent: test-agent + attributes: test-attributes + + - match: + $body.interactionId: /.{10,}/ diff --git a/memory/src/yamlRestTest/resources/rest-api-spec/test/50_get_interactions.yml b/memory/src/yamlRestTest/resources/rest-api-spec/test/50_get_interactions.yml new file mode 100644 index 0000000000..7dda8cf0e9 --- /dev/null +++ b/memory/src/yamlRestTest/resources/rest-api-spec/test/50_get_interactions.yml @@ -0,0 +1,88 @@ +--- +"Test adding and getting back interactions": + + - do: + _plugins.conversational_create_interaction: + conversationId: test-cid + body: + input: test-input + prompt: test-prompt + response: test-response + agent: test-agent + attributes: test-attributes + + - do: + _plugins.conversational_create_interaction: + conversationId: test-cid + body: + input: test-input1 + prompt: test-prompt1 + response: test-response1 + agent: test-agent1 + attributes: test-attributes1 + + - do: + _plugins.conversational_get_interactions: + conversationId: test-cid + + - match: + $body.interactions.0.input: test-input1 + + - match: + $body.interactions.1.input: test-input + +--- +"Test adding interactions, paginated": + + - do: + _plugins.conversational_create_interaction: + conversationId: test-cid + body: + input: test-input1 + prompt: test-prompt + response: test-response + agent: test-agent + attributes: test-attributes + + - do: + _plugins.conversational_create_interaction: + conversationId: test-cid + body: + input: test-input2 + prompt: test-prompt + response: test-response + agent: test-agent + attributes: test-attributes + + - do: + _plugins.conversational_create_interaction: + conversationId: test-cid + body: + input: test-input3 + prompt: test-prompt + response: test-response + agent: test-agent + attributes: test-attributes + + - do: + _plugins.conversational_get_interactions: + conversationId: test-cid + maxResults: 2 + + - match: + $body.interactions.0.input: test-input3 + + - match: + $body.interactions.1.input: test-input2 + + - match: + $body.nextToken: 2 + + - do: + _plugins.conversational_get_interactions: + conversationId: test-cid + maxResults: 2 + nextToken: 2 + + - match: + $body.interactions.0.input: test-input1 diff --git a/memory/src/yamlRestTest/resources/rest-api-spec/test/60_delete_convo.yml b/memory/src/yamlRestTest/resources/rest-api-spec/test/60_delete_convo.yml new file mode 100644 index 0000000000..9584db8fe8 --- /dev/null +++ b/memory/src/yamlRestTest/resources/rest-api-spec/test/60_delete_convo.yml @@ -0,0 +1,9 @@ +--- +"Test deleting a conversation that doesn't exist": + + - do: + _plugins.conversational_delete_conversation: + conversationId: test + + - match: + $body.success: true diff --git a/plugin/build.gradle b/plugin/build.gradle index ec49b23fce..5629801080 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -44,6 +44,8 @@ opensearchplugin { dependencies { implementation project(':opensearch-ml-common') implementation project(':opensearch-ml-algorithms') + implementation project(':opensearch-ml-search-processors') + implementation project(':opensearch-ml-memory') implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" implementation "org.opensearch.client:opensearch-rest-client:${opensearch_version}" diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 6ee5c42bd4..ee2e4ffe38 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -11,7 +11,9 @@ import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import java.nio.file.Path; +import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -124,6 +126,18 @@ import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.indices.MLIndicesHandler; import org.opensearch.ml.indices.MLInputDatasetHandler; +import org.opensearch.ml.memory.ConversationalMemoryHandler; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationTransportAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionTransportAction; +import org.opensearch.ml.memory.action.conversation.DeleteConversationAction; +import org.opensearch.ml.memory.action.conversation.DeleteConversationTransportAction; +import org.opensearch.ml.memory.action.conversation.GetConversationsAction; +import org.opensearch.ml.memory.action.conversation.GetConversationsTransportAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsTransportAction; +import org.opensearch.ml.memory.index.OpenSearchConversationalMemoryHandler; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; import org.opensearch.ml.rest.RestMLCreateConnectorAction; @@ -151,6 +165,11 @@ import org.opensearch.ml.rest.RestMLUndeployModelAction; import org.opensearch.ml.rest.RestMLUpdateModelGroupAction; import org.opensearch.ml.rest.RestMLUploadModelChunkAction; +import org.opensearch.ml.rest.RestMemoryCreateConversationAction; +import org.opensearch.ml.rest.RestMemoryCreateInteractionAction; +import org.opensearch.ml.rest.RestMemoryDeleteConversationAction; +import org.opensearch.ml.rest.RestMemoryGetConversationsAction; +import org.opensearch.ml.rest.RestMemoryGetInteractionsAction; import org.opensearch.ml.settings.MLCommonsSettings; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.stats.MLClusterLevelStat; @@ -170,10 +189,19 @@ import org.opensearch.monitor.os.OsService; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.SearchPipelinePlugin; +import org.opensearch.plugins.SearchPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchRequestProcessor; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.FixedExecutorBuilder; import org.opensearch.threadpool.ThreadPool; @@ -183,7 +211,7 @@ import lombok.SneakyThrows; -public class MachineLearningPlugin extends Plugin implements ActionPlugin { +public class MachineLearningPlugin extends Plugin implements ActionPlugin, SearchPlugin, SearchPipelinePlugin { public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons."; public static final String GENERAL_THREAD_POOL = "opensearch_ml_general"; public static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute"; @@ -224,6 +252,10 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin { private MLFeatureEnabledSetting mlFeatureEnabledSetting; + private ConversationalMemoryHandler cmHandler; + + private volatile boolean ragSearchPipelineEnabled; + @Override public List> getActions() { return ImmutableList @@ -256,7 +288,13 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin { new ActionHandler<>(MLCreateConnectorAction.INSTANCE, TransportCreateConnectorAction.class), new ActionHandler<>(MLConnectorGetAction.INSTANCE, GetConnectorTransportAction.class), new ActionHandler<>(MLConnectorDeleteAction.INSTANCE, DeleteConnectorTransportAction.class), - new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class) + new ActionHandler<>(MLConnectorSearchAction.INSTANCE, SearchConnectorTransportAction.class), + + new ActionHandler<>(CreateConversationAction.INSTANCE, CreateConversationTransportAction.class), + new ActionHandler<>(GetConversationsAction.INSTANCE, GetConversationsTransportAction.class), + new ActionHandler<>(CreateInteractionAction.INSTANCE, CreateInteractionTransportAction.class), + new ActionHandler<>(GetInteractionsAction.INSTANCE, GetInteractionsTransportAction.class), + new ActionHandler<>(DeleteConversationAction.INSTANCE, DeleteConversationTransportAction.class) ); } @@ -289,6 +327,7 @@ public Collection createComponents( mlEngine = new MLEngine(dataPath, encryptor); nodeHelper = new DiscoveryNodeHelper(clusterService, settings); modelCacheHelper = new MLModelCacheHelper(clusterService, settings); + cmHandler = new OpenSearchConversationalMemoryHandler(client, clusterService); JvmService jvmService = new JvmService(environment.settings()); OsService osService = new OsService(environment.settings()); @@ -425,6 +464,12 @@ public Collection createComponents( encryptor ); + // TODO move this into MLFeatureEnabledSetting + // search processor factories below will get BooleanSupplier that supplies the current value being updated through this. + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it); + return ImmutableList .of( encryptor, @@ -450,7 +495,8 @@ public Collection createComponents( mlCommonsClusterEventListener, clusterManagerEventListener, mlCircuitBreakerService, - mlModelAutoRedeployer + mlModelAutoRedeployer, + cmHandler ); } @@ -493,6 +539,12 @@ public List getRestHandlers( RestMLGetConnectorAction restMLGetConnectorAction = new RestMLGetConnectorAction(); RestMLDeleteConnectorAction restMLDeleteConnectorAction = new RestMLDeleteConnectorAction(); RestMLSearchConnectorAction restMLSearchConnectorAction = new RestMLSearchConnectorAction(); + + RestMemoryCreateConversationAction restCreateConversationAction = new RestMemoryCreateConversationAction(); + RestMemoryGetConversationsAction restListConversationsAction = new RestMemoryGetConversationsAction(); + RestMemoryCreateInteractionAction restCreateInteractionAction = new RestMemoryCreateInteractionAction(); + RestMemoryGetInteractionsAction restListInteractionsAction = new RestMemoryGetInteractionsAction(); + RestMemoryDeleteConversationAction restDeleteConversationAction = new RestMemoryDeleteConversationAction(); return ImmutableList .of( restMLStatsAction, @@ -519,7 +571,12 @@ public List getRestHandlers( restMLCreateConnectorAction, restMLGetConnectorAction, restMLDeleteConnectorAction, - restMLSearchConnectorAction + restMLSearchConnectorAction, + restCreateConversationAction, + restListConversationsAction, + restCreateInteractionAction, + restListInteractionsAction, + restDeleteConversationAction ); } @@ -625,8 +682,58 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, MLCommonsSettings.ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES, MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES, - MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED + MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED, + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED, + MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED ); return settings; } + + /** + * + * Search processors for Retrieval Augmented Generation + * + */ + + @Override + public List> getSearchExts() { + List> searchExts = new ArrayList<>(); + + searchExts + .add( + new SearchPlugin.SearchExtSpec<>( + GenerativeQAParamExtBuilder.PARAMETER_NAME, + input -> new GenerativeQAParamExtBuilder(input), + parser -> GenerativeQAParamExtBuilder.parse(parser) + ) + ); + + return searchExts; + } + + @Override + public Map> getRequestProcessors(Parameters parameters) { + Map> requestProcessors = new HashMap<>(); + + requestProcessors + .put( + GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, + new GenerativeQARequestProcessor.Factory(() -> this.ragSearchPipelineEnabled) + ); + + return requestProcessors; + } + + @Override + public Map> getResponseProcessors(Parameters parameters) { + Map> responseProcessors = new HashMap<>(); + + responseProcessors + .put( + GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + new GenerativeQAResponseProcessor.Factory(this.client, () -> this.ragSearchPipelineEnabled) + ); + + return responseProcessors; + } } diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryCreateConversationAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryCreateConversationAction.java new file mode 100644 index 0000000000..3d91095e96 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryCreateConversationAction.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +/** + * Rest Action for creating a conversation + */ +public class RestMemoryCreateConversationAction extends BaseRestHandler { + private final static String CREATE_CONVERSATION_NAME = "conversational_create_conversation"; + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.POST, ActionConstants.CREATE_CONVERSATION_REST_PATH)); + } + + @Override + public String getName() { + return CREATE_CONVERSATION_NAME; + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + CreateConversationRequest ccRequest = CreateConversationRequest.fromRestRequest(request); + return channel -> client.execute(CreateConversationAction.INSTANCE, ccRequest, new RestToXContentListener<>(channel)); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryCreateInteractionAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryCreateInteractionAction.java new file mode 100644 index 0000000000..f06a18b547 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryCreateInteractionAction.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +/** + * Rest action for adding a new interaction to a conversation + */ +public class RestMemoryCreateInteractionAction extends BaseRestHandler { + private final static String CREATE_INTERACTION_NAME = "conversational_create_interaction"; + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.POST, ActionConstants.CREATE_INTERACTION_REST_PATH)); + } + + @Override + public String getName() { + return CREATE_INTERACTION_NAME; + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + CreateInteractionRequest piRequest = CreateInteractionRequest.fromRestRequest(request); + return channel -> client.execute(CreateInteractionAction.INSTANCE, piRequest, new RestToXContentListener<>(channel)); + } + +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryDeleteConversationAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryDeleteConversationAction.java new file mode 100644 index 0000000000..23dd617cf2 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryDeleteConversationAction.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.DeleteConversationAction; +import org.opensearch.ml.memory.action.conversation.DeleteConversationRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +/** + * Rest Action for deleting a conversation + */ +public class RestMemoryDeleteConversationAction extends BaseRestHandler { + private final static String DELETE_CONVERSATION_NAME = "conversational_delete_conversation"; + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.DELETE, ActionConstants.DELETE_CONVERSATION_REST_PATH)); + } + + @Override + public String getName() { + return DELETE_CONVERSATION_NAME; + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + DeleteConversationRequest dcRequest = DeleteConversationRequest.fromRestRequest(request); + return channel -> client.execute(DeleteConversationAction.INSTANCE, dcRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationsAction.java new file mode 100644 index 0000000000..8d8ab99aec --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetConversationsAction.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetConversationsAction; +import org.opensearch.ml.memory.action.conversation.GetConversationsRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +/** + * Rest Handler for list conversations + */ +public class RestMemoryGetConversationsAction extends BaseRestHandler { + private final static String GET_CONVERSATIONS_NAME = "conversational_get_conversations"; + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.GET, ActionConstants.GET_CONVERSATIONS_REST_PATH)); + } + + @Override + public String getName() { + return GET_CONVERSATIONS_NAME; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + GetConversationsRequest lcRequest = GetConversationsRequest.fromRestRequest(request); + return channel -> client.execute(GetConversationsAction.INSTANCE, lcRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionsAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionsAction.java new file mode 100644 index 0000000000..99b7e02fc3 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMemoryGetInteractionsAction.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.List; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +/** + * Rest Handler for get Interactions + */ +public class RestMemoryGetInteractionsAction extends BaseRestHandler { + private final static String GET_INTERACTIONS_NAME = "conversational_get_interactions"; + + @Override + public List routes() { + return List.of(new Route(RestRequest.Method.GET, ActionConstants.GET_INTERACTIONS_REST_PATH)); + } + + @Override + public String getName() { + return GET_INTERACTIONS_NAME; + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + GetInteractionsRequest giRequest = GetInteractionsRequest.fromRestRequest(request); + return channel -> client.execute(GetInteractionsAction.INSTANCE, giRequest, new RestToXContentListener<>(channel)); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index a89c00350c..4d62bb504f 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -9,6 +9,8 @@ import java.util.function.Function; import org.opensearch.common.settings.Setting; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; import com.google.common.collect.ImmutableList; @@ -165,4 +167,10 @@ private MLCommonsSettings() {} Setting.Property.NodeScope, Setting.Property.Dynamic ); + + public static final Setting ML_COMMONS_MEMORY_FEATURE_ENABLED = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED; + + // Feature flag for enabling search processors for Retrieval Augmented Generation using OpenSearch and Remote Inference. + public static final Setting ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED = + GenerativeQAProcessorConstants.RAG_PIPELINE_FEATURE_ENABLED; } diff --git a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java new file mode 100644 index 0000000000..d1d51de4fb --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java @@ -0,0 +1,66 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.plugin; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.plugins.SearchPipelinePlugin; +import org.opensearch.plugins.SearchPlugin; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; + +public class MachineLearningPluginTests { + + MachineLearningPlugin plugin = new MachineLearningPlugin(); + + @Test + public void testGetSearchExts() { + List> searchExts = plugin.getSearchExts(); + assertEquals(1, searchExts.size()); + SearchPlugin.SearchExtSpec spec = searchExts.get(0); + assertEquals(GenerativeQAParamExtBuilder.PARAMETER_NAME, spec.getName().getPreferredName()); + } + + @Test + public void testGetRequestProcessors() { + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map requestProcessors = plugin.getRequestProcessors(parameters); + assertEquals(1, requestProcessors.size()); + assertTrue( + requestProcessors.get(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE) instanceof GenerativeQARequestProcessor.Factory + ); + } + + @Test + public void testGetResponseProcessors() { + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map responseProcessors = plugin.getResponseProcessors(parameters); + assertEquals(1, responseProcessors.size()); + assertTrue( + responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory + ); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionIT.java new file mode 100644 index 0000000000..26f1486f3f --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionIT.java @@ -0,0 +1,78 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.Map; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryCreateConversationActionIT extends MLCommonsRestTestCase { + + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testCreateConversation() throws IOException { + Response response = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("conversation_id")); + } + + public void testCreateConversationNamed() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_CONVERSATION_REST_PATH, + null, + gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "name")), + null + ); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("conversation_id")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionTests.java new file mode 100644 index 0000000000..958776c466 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateConversationActionTests.java @@ -0,0 +1,80 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +import com.google.gson.Gson; + +public class RestMemoryCreateConversationActionTests extends OpenSearchTestCase { + + Gson gson; + + @Before + public void setup() { + gson = new Gson(); + } + + public void testBasics() { + RestMemoryCreateConversationAction action = new RestMemoryCreateConversationAction(); + assert (action.getName().equals("conversational_create_conversation")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new Route(RestRequest.Method.POST, ActionConstants.CREATE_CONVERSATION_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryCreateConversationAction action = new RestMemoryCreateConversationAction(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withContent( + new BytesArray(gson.toJson(Map.of(ActionConstants.REQUEST_CONVERSATION_NAME_FIELD, "test-name"))), + MediaTypeRegistry.JSON + ) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(CreateConversationRequest.class); + verify(client, times(1)).execute(eq(CreateConversationAction.INSTANCE), argumentCaptor.capture(), any()); + assert (argumentCaptor.getValue().getName().equals("test-name")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionIT.java new file mode 100644 index 0000000000..63f7093b6e --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionIT.java @@ -0,0 +1,90 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.Map; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryCreateInteractionActionIT extends MLCommonsRestTestCase { + + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testCreateInteraction() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String id = (String) ccmap.get("conversation_id"); + + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "some metadata" + ); + Response response = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", id), + null, + gson.toJson(params), + null + ); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("interaction_id")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionTests.java new file mode 100644 index 0000000000..ced83f730a --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryCreateInteractionActionTests.java @@ -0,0 +1,97 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +import com.google.gson.Gson; + +public class RestMemoryCreateInteractionActionTests extends OpenSearchTestCase { + + Gson gson; + + @Before + public void setup() { + gson = new Gson(); + } + + public void testBasics() { + RestMemoryCreateInteractionAction action = new RestMemoryCreateInteractionAction(); + assert (action.getName().equals("conversational_create_interaction")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new Route(RestRequest.Method.POST, ActionConstants.CREATE_INTERACTION_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "pt", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.ADDITIONAL_INFO_FIELD, + "metadata" + ); + RestMemoryCreateInteractionAction action = new RestMemoryCreateInteractionAction(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "cid")) + .withContent(new BytesArray(gson.toJson(params)), MediaTypeRegistry.JSON) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(CreateInteractionRequest.class); + verify(client, times(1)).execute(eq(CreateInteractionAction.INSTANCE), argCaptor.capture(), any()); + CreateInteractionRequest req = argCaptor.getValue(); + assert (req.getConversationId().equals("cid")); + assert (req.getInput().equals("input")); + assert (req.getPromptTemplate().equals("pt")); + assert (req.getResponse().equals("response")); + assert (req.getOrigin().equals("origin")); + assert (req.getAdditionalInfo().equals("metadata")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java new file mode 100644 index 0000000000..3bdb7ed213 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionIT.java @@ -0,0 +1,177 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Map; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryDeleteConversationActionIT extends MLCommonsRestTestCase { + + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testDeleteConversation_ThatExists() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String id = (String) ccmap.get("conversation_id"); + + Response response = TestHelper + .makeRequest( + client(), + "DELETE", + ActionConstants.DELETE_CONVERSATION_REST_PATH.replace("{conversation_id}", id), + null, + "", + null + ); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("success")); + assert ((Boolean) map.get("success")); + } + + public void testDeleteConversation_ThatDoesNotExist() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "DELETE", + ActionConstants.DELETE_CONVERSATION_REST_PATH.replace("{conversation_id}", "happybirthday"), + null, + "", + null + ); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("success")); + assert ((Boolean) map.get("success")); + } + + public void testDeleteConversation_WithInteractions() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String cid = (String) ccmap.get("conversation_id"); + + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "some metadata" + ); + Response ciresponse = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params), + null + ); + assert (ciresponse != null); + assert (TestHelper.restStatus(ciresponse) == RestStatus.OK); + HttpEntity cihttpEntity = ciresponse.getEntity(); + String cientityString = TestHelper.httpEntityToString(cihttpEntity); + Map cimap = gson.fromJson(cientityString, Map.class); + assert (cimap.containsKey("interaction_id")); + String iid = (String) cimap.get("interaction_id"); + + Response dcresponse = TestHelper + .makeRequest( + client(), + "DELETE", + ActionConstants.DELETE_CONVERSATION_REST_PATH.replace("{conversation_id}", cid), + null, + "", + null + ); + assert (dcresponse != null); + assert (TestHelper.restStatus(dcresponse) == RestStatus.OK); + HttpEntity dchttpEntity = dcresponse.getEntity(); + String dcentityString = TestHelper.httpEntityToString(dchttpEntity); + Map dcmap = gson.fromJson(dcentityString, Map.class); + assert (dcmap.containsKey("success")); + assert ((Boolean) dcmap.get("success")); + + Response gcresponse = TestHelper.makeRequest(client(), "GET", ActionConstants.GET_CONVERSATIONS_REST_PATH, null, "", null); + assert (gcresponse != null); + assert (TestHelper.restStatus(gcresponse) == RestStatus.OK); + HttpEntity gchttpEntity = gcresponse.getEntity(); + String gcentityString = TestHelper.httpEntityToString(gchttpEntity); + Map gcmap = gson.fromJson(gcentityString, Map.class); + assert (gcmap.containsKey("conversations")); + assert (!gcmap.containsKey("next_token")); + assert (((ArrayList) gcmap.get("conversations")).size() == 0); + + Response giresponse = TestHelper + .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null); + assert (giresponse != null); + assert (TestHelper.restStatus(giresponse) == RestStatus.OK); + HttpEntity gihttpEntity = giresponse.getEntity(); + String gientityString = TestHelper.httpEntityToString(gihttpEntity); + Map gimap = gson.fromJson(gientityString, Map.class); + assert (gimap.containsKey("interactions")); + assert (!gimap.containsKey("next_token")); + assert (((ArrayList) gimap.get("interactions")).size() == 0); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionTests.java new file mode 100644 index 0000000000..c8e733b3a2 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryDeleteConversationActionTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.DeleteConversationAction; +import org.opensearch.ml.memory.action.conversation.DeleteConversationRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryDeleteConversationActionTests extends OpenSearchTestCase { + public void testBasics() { + RestMemoryDeleteConversationAction action = new RestMemoryDeleteConversationAction(); + assert (action.getName().equals("conversational_delete_conversation")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new Route(RestRequest.Method.DELETE, ActionConstants.DELETE_CONVERSATION_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryDeleteConversationAction action = new RestMemoryDeleteConversationAction(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) + .withParams(Map.of(ActionConstants.CONVERSATION_ID_FIELD, "deleteme")) + .build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(DeleteConversationRequest.class); + verify(client, times(1)).execute(eq(DeleteConversationAction.INSTANCE), argCaptor.capture(), any()); + assert (argCaptor.getValue().getId().equals("deleteme")); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java new file mode 100644 index 0000000000..4c95296b08 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionIT.java @@ -0,0 +1,188 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Map; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetConversationsActionIT extends MLCommonsRestTestCase { + + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testNoConversations_EmptyList() throws IOException { + Response response = TestHelper.makeRequest(client(), "GET", ActionConstants.GET_CONVERSATIONS_REST_PATH, null, "", null); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("conversations")); + assert (!map.containsKey("next_token")); + assert (((ArrayList) map.get("conversations")).size() == 0); + } + + public void testGetConversations_LastPage() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String id = (String) ccmap.get("conversation_id"); + + Response response = TestHelper.makeRequest(client(), "GET", ActionConstants.GET_CONVERSATIONS_REST_PATH, null, "", null); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("conversations")); + assert (!map.containsKey("next_token")); + @SuppressWarnings("unchecked") + ArrayList conversations = (ArrayList) map.get("conversations"); + assert (conversations.size() == 1); + assert (conversations.get(0).containsKey("conversation_id")); + assert (((String) conversations.get(0).get("conversation_id")).equals(id)); + } + + public void testConversations_MorePages() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String id = (String) ccmap.get("conversation_id"); + + Response response = TestHelper + .makeRequest( + client(), + "GET", + ActionConstants.GET_CONVERSATIONS_REST_PATH, + Map.of(ActionConstants.REQUEST_MAX_RESULTS_FIELD, "1"), + "", + null + ); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("conversations")); + assert (map.containsKey("next_token")); + @SuppressWarnings("unchecked") + ArrayList conversations = (ArrayList) map.get("conversations"); + assert (conversations.size() == 1); + assert (conversations.get(0).containsKey("conversation_id")); + assert (((String) conversations.get(0).get("conversation_id")).equals(id)); + assert (((Double) map.get("next_token")).intValue() == 1); + } + + public void testGetConversations_nextPage() throws IOException { + Response ccresponse1 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse1 != null); + assert (TestHelper.restStatus(ccresponse1) == RestStatus.OK); + HttpEntity cchttpEntity1 = ccresponse1.getEntity(); + String ccentityString1 = TestHelper.httpEntityToString(cchttpEntity1); + Map ccmap1 = gson.fromJson(ccentityString1, Map.class); + assert (ccmap1.containsKey("conversation_id")); + String id1 = (String) ccmap1.get("conversation_id"); + + Response ccresponse2 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse2 != null); + assert (TestHelper.restStatus(ccresponse2) == RestStatus.OK); + HttpEntity cchttpEntity2 = ccresponse2.getEntity(); + String ccentityString2 = TestHelper.httpEntityToString(cchttpEntity2); + Map ccmap2 = gson.fromJson(ccentityString2, Map.class); + assert (ccmap2.containsKey("conversation_id")); + String id2 = (String) ccmap2.get("conversation_id"); + + Response response1 = TestHelper + .makeRequest( + client(), + "GET", + ActionConstants.GET_CONVERSATIONS_REST_PATH, + Map.of(ActionConstants.REQUEST_MAX_RESULTS_FIELD, "1"), + "", + null + ); + assert (response1 != null); + assert (TestHelper.restStatus(response1) == RestStatus.OK); + HttpEntity httpEntity1 = response1.getEntity(); + String entityString1 = TestHelper.httpEntityToString(httpEntity1); + Map map1 = gson.fromJson(entityString1, Map.class); + assert (map1.containsKey("conversations")); + assert (map1.containsKey("next_token")); + @SuppressWarnings("unchecked") + ArrayList conversations1 = (ArrayList) map1.get("conversations"); + assert (conversations1.size() == 1); + assert (conversations1.get(0).containsKey("conversation_id")); + assert (((String) conversations1.get(0).get("conversation_id")).equals(id2)); + assert (((Double) map1.get("next_token")).intValue() == 1); + + Response response = TestHelper + .makeRequest( + client(), + "GET", + ActionConstants.GET_CONVERSATIONS_REST_PATH, + Map.of(ActionConstants.REQUEST_MAX_RESULTS_FIELD, "1", ActionConstants.NEXT_TOKEN_FIELD, "1"), + "", + null + ); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("conversations")); + assert (map.containsKey("next_token")); + @SuppressWarnings("unchecked") + ArrayList conversations = (ArrayList) map.get("conversations"); + assert (conversations.size() == 1); + assert (conversations.get(0).containsKey("conversation_id")); + assert (((String) conversations.get(0).get("conversation_id")).equals(id1)); + assert (((Double) map.get("next_token")).intValue() == 2); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionTests.java new file mode 100644 index 0000000000..a3f2043d62 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetConversationsActionTests.java @@ -0,0 +1,63 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetConversationsAction; +import org.opensearch.ml.memory.action.conversation.GetConversationsRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryGetConversationsActionTests extends OpenSearchTestCase { + public void testBasics() { + RestMemoryGetConversationsAction action = new RestMemoryGetConversationsAction(); + assert (action.getName().equals("conversational_get_conversations")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new Route(RestRequest.Method.GET, ActionConstants.GET_CONVERSATIONS_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryGetConversationsAction action = new RestMemoryGetConversationsAction(); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetConversationsRequest.class); + verify(client, times(1)).execute(eq(GetConversationsAction.INSTANCE), argCaptor.capture(), any()); + GetConversationsRequest gcreq = argCaptor.getValue(); + assert (gcreq.getFrom() == 0); + assert (gcreq.getMaxResults() == ActionConstants.DEFAULT_MAX_RESULTS); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java new file mode 100644 index 0000000000..3272d4a991 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionIT.java @@ -0,0 +1,313 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Map; + +import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.settings.MLCommonsSettings; +import org.opensearch.ml.utils.TestHelper; + +import com.google.common.collect.ImmutableList; + +public class RestMemoryGetInteractionsActionIT extends MLCommonsRestTestCase { + + @Before + public void setupFeatureSettings() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"" + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey() + "\":true}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, response.getStatusLine().getStatusCode()); + } + + public void testGetInteractions_NoConversation() throws IOException { + Response response = TestHelper + .makeRequest( + client(), + "GET", + ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", "coffee"), + null, + "", + null + ); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("interactions")); + assert (!map.containsKey("next_token")); + assert (((ArrayList) map.get("interactions")).size() == 0); + } + + public void testGetInteractions_NoInteractions() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String cid = (String) ccmap.get("conversation_id"); + + Response response = TestHelper + .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("interactions")); + assert (!map.containsKey("next_token")); + assert (((ArrayList) map.get("interactions")).size() == 0); + } + + public void testGetInteractions_LastPage() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String cid = (String) ccmap.get("conversation_id"); + + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "some metadata" + ); + Response response = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params), + null + ); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("interaction_id")); + String iid = (String) map.get("interaction_id"); + + Response response1 = TestHelper + .makeRequest(client(), "GET", ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), null, "", null); + assert (response1 != null); + assert (TestHelper.restStatus(response1) == RestStatus.OK); + HttpEntity httpEntity1 = response1.getEntity(); + String entityString1 = TestHelper.httpEntityToString(httpEntity1); + Map map1 = gson.fromJson(entityString1, Map.class); + assert (map1.containsKey("interactions")); + assert (!map1.containsKey("next_token")); + assert (((ArrayList) map1.get("interactions")).size() == 1); + @SuppressWarnings("unchecked") + ArrayList interactions = (ArrayList) map1.get("interactions"); + assert (((String) interactions.get(0).get("interaction_id")).equals(iid)); + } + + public void testGetInteractions_MorePages() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String cid = (String) ccmap.get("conversation_id"); + + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "some metadata" + ); + Response response = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params), + null + ); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("interaction_id")); + String iid = (String) map.get("interaction_id"); + + Response response1 = TestHelper + .makeRequest( + client(), + "GET", + ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), + Map.of(ActionConstants.REQUEST_MAX_RESULTS_FIELD, "1"), + "", + null + ); + assert (response1 != null); + assert (TestHelper.restStatus(response1) == RestStatus.OK); + HttpEntity httpEntity1 = response1.getEntity(); + String entityString1 = TestHelper.httpEntityToString(httpEntity1); + Map map1 = gson.fromJson(entityString1, Map.class); + assert (map1.containsKey("interactions")); + assert (map1.containsKey("next_token")); + assert (((ArrayList) map1.get("interactions")).size() == 1); + @SuppressWarnings("unchecked") + ArrayList interactions = (ArrayList) map1.get("interactions"); + assert (((String) interactions.get(0).get("interaction_id")).equals(iid)); + assert (((Double) map1.get("next_token")).intValue() == 1); + } + + public void testGetInteractions_NextPage() throws IOException { + Response ccresponse = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null); + assert (ccresponse != null); + assert (TestHelper.restStatus(ccresponse) == RestStatus.OK); + HttpEntity cchttpEntity = ccresponse.getEntity(); + String ccentityString = TestHelper.httpEntityToString(cchttpEntity); + Map ccmap = gson.fromJson(ccentityString, Map.class); + assert (ccmap.containsKey("conversation_id")); + String cid = (String) ccmap.get("conversation_id"); + + Map params = Map + .of( + ActionConstants.INPUT_FIELD, + "input", + ActionConstants.AI_RESPONSE_FIELD, + "response", + ActionConstants.RESPONSE_ORIGIN_FIELD, + "origin", + ActionConstants.PROMPT_TEMPLATE_FIELD, + "promtp template", + ActionConstants.ADDITIONAL_INFO_FIELD, + "some metadata" + ); + Response response = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params), + null + ); + assert (response != null); + assert (TestHelper.restStatus(response) == RestStatus.OK); + HttpEntity httpEntity = response.getEntity(); + String entityString = TestHelper.httpEntityToString(httpEntity); + Map map = gson.fromJson(entityString, Map.class); + assert (map.containsKey("interaction_id")); + String iid = (String) map.get("interaction_id"); + + Response response2 = TestHelper + .makeRequest( + client(), + "POST", + ActionConstants.CREATE_INTERACTION_REST_PATH.replace("{conversation_id}", cid), + null, + gson.toJson(params), + null + ); + assert (response2 != null); + assert (TestHelper.restStatus(response2) == RestStatus.OK); + HttpEntity httpEntity2 = response2.getEntity(); + String entityString2 = TestHelper.httpEntityToString(httpEntity2); + Map map2 = gson.fromJson(entityString2, Map.class); + assert (map2.containsKey("interaction_id")); + String iid2 = (String) map2.get("interaction_id"); + + Response response1 = TestHelper + .makeRequest( + client(), + "GET", + ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), + Map.of(ActionConstants.REQUEST_MAX_RESULTS_FIELD, "1"), + "", + null + ); + assert (response1 != null); + assert (TestHelper.restStatus(response1) == RestStatus.OK); + HttpEntity httpEntity1 = response1.getEntity(); + String entityString1 = TestHelper.httpEntityToString(httpEntity1); + Map map1 = gson.fromJson(entityString1, Map.class); + assert (map1.containsKey("interactions")); + assert (map1.containsKey("next_token")); + assert (((ArrayList) map1.get("interactions")).size() == 1); + @SuppressWarnings("unchecked") + ArrayList interactions = (ArrayList) map1.get("interactions"); + assert (((String) interactions.get(0).get("interaction_id")).equals(iid2)); + assert (((Double) map1.get("next_token")).intValue() == 1); + + Response response3 = TestHelper + .makeRequest( + client(), + "GET", + ActionConstants.GET_INTERACTIONS_REST_PATH.replace("{conversation_id}", cid), + Map.of(ActionConstants.REQUEST_MAX_RESULTS_FIELD, "1", ActionConstants.NEXT_TOKEN_FIELD, "1"), + "", + null + ); + assert (response3 != null); + assert (TestHelper.restStatus(response3) == RestStatus.OK); + HttpEntity httpEntity3 = response3.getEntity(); + String entityString3 = TestHelper.httpEntityToString(httpEntity3); + Map map3 = gson.fromJson(entityString3, Map.class); + assert (map3.containsKey("interactions")); + assert (map3.containsKey("next_token")); + assert (((ArrayList) map3.get("interactions")).size() == 1); + @SuppressWarnings("unchecked") + ArrayList interactions3 = (ArrayList) map3.get("interactions"); + assert (((String) interactions3.get(0).get("interaction_id")).equals(iid)); + assert (((Double) map3.get("next_token")).intValue() == 2); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionTests.java new file mode 100644 index 0000000000..0882b8eccb --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMemoryGetInteractionsActionTests.java @@ -0,0 +1,75 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.rest; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.List; +import java.util.Map; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.ml.common.conversation.ActionConstants; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestHandler.Route; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; + +public class RestMemoryGetInteractionsActionTests extends OpenSearchTestCase { + + public void testBasics() { + RestMemoryGetInteractionsAction action = new RestMemoryGetInteractionsAction(); + assert (action.getName().equals("conversational_get_interactions")); + List routes = action.routes(); + assert (routes.size() == 1); + assert (routes.get(0).equals(new Route(RestRequest.Method.GET, ActionConstants.GET_INTERACTIONS_REST_PATH))); + } + + public void testPrepareRequest() throws Exception { + RestMemoryGetInteractionsAction action = new RestMemoryGetInteractionsAction(); + Map params = Map + .of( + ActionConstants.CONVERSATION_ID_FIELD, + "cid", + ActionConstants.REQUEST_MAX_RESULTS_FIELD, + "2", + ActionConstants.NEXT_TOKEN_FIELD, + "7" + ); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withParams(params).build(); + + NodeClient client = mock(NodeClient.class); + RestChannel channel = mock(RestChannel.class); + action.handleRequest(request, channel, client); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(GetInteractionsRequest.class); + verify(client, times(1)).execute(eq(GetInteractionsAction.INSTANCE), argCaptor.capture(), any()); + GetInteractionsRequest req = argCaptor.getValue(); + assert (req.getConversationId().equals("cid")); + assert (req.getFrom() == 7); + assert (req.getMaxResults() == 2); + } +} diff --git a/search-processors/README.md b/search-processors/README.md new file mode 100644 index 0000000000..2b3dc6ed52 --- /dev/null +++ b/search-processors/README.md @@ -0,0 +1,95 @@ +# conversational-search-processors +OpenSearch search processors providing conversational search capabilities +======= +# Plugin for Conversations Using Search Processors in OpenSearch +This repo is a WIP plugin for handling conversations in OpenSearch ([Per this RFC](https://github.com/opensearch-project/ml-commons/issues/1150)). + +Conversational Retrieval Augmented Generation (RAG) is implemented via Search processors that combine user questions and OpenSearch query results as input to an LLM, e.g. OpenAI, and return answers. + +## Creating a search pipeline with the GenerativeQAResponseProcessor + +``` +PUT /_search/pipeline/ +{ + "response_processors": [ + { + "retrieval_augmented_generation": { + "tag": , + "description": , + "model_id": "", + "context_field_list": [] (e.g. ["text"]) + } + } + ] +} +``` + +The 'model_id' parameter here needs to refer to a model of type REMOTE that has an HttpConnector instance associated with it. + +## Making a search request against an index using the above processor +``` +GET //_search\?search_pipeline\= +{ + "_source": ["title", "text"], + "query" : { + "neural": { + "text_vector": { + "query_text": , + "k": (e.g. 10), + "model_id": + } + } + }, + "ext": { + "generative_qa_parameters": { + "llm_model": (e.g. "gpt-3.5-turbo"), + "llm_question": + } + } +} +``` + +## Retrieval Augmented Generation response +``` +{ + "took": 3, + "timed_out": false, + "_shards": { + "total": 3, + "successful": 3, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 110, + "relation": "eq" + }, + "max_score": 0.55129033, + "hits": [ + { + "_index": "...", + "_id": "...", + "_score": 0.55129033, + "_source": { + "text": "...", + "title": "..." + } + }, + { + ... + } + ... + { + ... + } + ] + }, // end of hits + "ext": { + "retrieval_augmented_generation": { + "answer": "..." + } + } +} +``` +The RAG answer is returned as an "ext" to SearchResponse following the "hits" array. diff --git a/search-processors/build.gradle b/search-processors/build.gradle new file mode 100644 index 0000000000..3d5903a0d3 --- /dev/null +++ b/search-processors/build.gradle @@ -0,0 +1,75 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +plugins { + id 'java' + id 'jacoco' + id "io.freefair.lombok" +} + +repositories { + mavenCentral() + mavenLocal() +} + +dependencies { + + compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" + compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' + implementation 'org.apache.commons:commons-lang3:3.12.0' + //implementation project(':opensearch-ml-client') + implementation project(':opensearch-ml-common') + implementation project(':opensearch-ml-memory') + implementation group: 'org.opensearch', name: 'common-utils', version: "${common_utils_version}" + // https://mvnrepository.com/artifact/org.apache.httpcomponents.core5/httpcore5 + implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.1' + implementation("com.google.guava:guava:32.0.1-jre") + implementation group: 'org.json', name: 'json', version: '20230227' + implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' + testImplementation "org.opensearch.test:framework:${opensearch_version}" +} + +test { + include '**/*Tests.class' + systemProperty 'tests.security.manager', 'false' +} + +jacocoTestReport { + dependsOn /*integTest,*/ test + reports { + xml.required = true + html.required = true + } +} + +jacocoTestCoverageVerification { + violationRules { + rule { + limit { + counter = 'LINE' + minimum = 0.8 + } + limit { + counter = 'BRANCH' + minimum = 0.8 + } + } + } + dependsOn jacocoTestReport +} + +check.dependsOn jacocoTestCoverageVerification diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java new file mode 100644 index 0000000000..5b6e8159b8 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java @@ -0,0 +1,43 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative; + +import org.opensearch.common.settings.Setting; + +public class GenerativeQAProcessorConstants { + + // Identifier for the generative QA request processor + public static final String REQUEST_PROCESSOR_TYPE = "question_rewrite"; + + // Identifier for the generative QA response processor + public static final String RESPONSE_PROCESSOR_TYPE = "retrieval_augmented_generation"; + + // The model_id of the model registered and deployed in OpenSearch. + public static final String CONFIG_NAME_MODEL_ID = "model_id"; + + // The name of the model supported by an LLM, e.g. "gpt-3.5" in OpenAI. + public static final String CONFIG_NAME_LLM_MODEL = "llm_model"; + + // The field in search results that contain the context to be sent to the LLM. + public static final String CONFIG_NAME_CONTEXT_FIELD_LIST = "context_field_list"; + + public static final Setting RAG_PIPELINE_FEATURE_ENABLED = Setting + .boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final String FEATURE_NOT_ENABLED_ERROR_MSG = RAG_PIPELINE_FEATURE_ENABLED.getKey() + " is not enabled."; +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessor.java new file mode 100644 index 0000000000..b0a741575c --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessor.java @@ -0,0 +1,92 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.search.pipeline.AbstractProcessor; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchRequestProcessor; + +import java.util.Map; +import java.util.function.BooleanSupplier; + +/** + * Defines the request processor for generative QA search pipelines. + */ +public class GenerativeQARequestProcessor extends AbstractProcessor implements SearchRequestProcessor { + + private String modelId; + private final BooleanSupplier featureFlagSupplier; + + protected GenerativeQARequestProcessor(String tag, String description, boolean ignoreFailure, String modelId, BooleanSupplier supplier) { + super(tag, description, ignoreFailure); + this.modelId = modelId; + this.featureFlagSupplier = supplier; + } + + @Override + public SearchRequest processRequest(SearchRequest request) throws Exception { + + // TODO Use chat history to rephrase the question with full conversation context. + + if (!featureFlagSupplier.getAsBoolean()) { + throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG); + } + + return request; + } + + @Override + public String getType() { + return GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE; + } + + public static final class Factory implements Processor.Factory { + + private final BooleanSupplier featureFlagSupplier; + + public Factory(BooleanSupplier supplier) { + this.featureFlagSupplier = supplier; + } + + @Override + public SearchRequestProcessor create( + Map> processorFactories, + String tag, + String description, + boolean ignoreFailure, + Map config, + PipelineContext pipelineContext + ) throws Exception { + if (featureFlagSupplier.getAsBoolean()) { + return new GenerativeQARequestProcessor(tag, description, ignoreFailure, + ConfigurationUtils.readStringProperty(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, + tag, + config, + GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID + ), + this.featureFlagSupplier + ); + } else { + throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG); + } + } + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java new file mode 100644 index 0000000000..720d880628 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -0,0 +1,207 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative; + +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Client; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.search.SearchHit; +import org.opensearch.search.pipeline.AbstractProcessor; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient; +import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; +import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; +import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil; +import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator; +import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.BooleanSupplier; + +import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException; + +/** + * Defines the response processor for generative QA search pipelines. + * + */ +@Log4j2 +public class GenerativeQAResponseProcessor extends AbstractProcessor implements SearchResponseProcessor { + + private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10; + + // TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM. + + private final String llmModel; + private final List contextFields; + + @Setter + private ConversationalMemoryClient memoryClient; + + @Getter + @Setter + // Mainly for unit testing purpose + private Llm llm; + + private final BooleanSupplier featureFlagSupplier; + + protected GenerativeQAResponseProcessor(Client client, String tag, String description, boolean ignoreFailure, + Llm llm, String llmModel, List contextFields, BooleanSupplier supplier) { + super(tag, description, ignoreFailure); + this.llmModel = llmModel; + this.contextFields = contextFields; + this.llm = llm; + this.memoryClient = new ConversationalMemoryClient(client); + this.featureFlagSupplier = supplier; + } + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { + + log.info("Entering processResponse."); + + if (!this.featureFlagSupplier.getAsBoolean()) { + throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG); + } + + GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request); + String llmQuestion = params.getLlmQuestion(); + String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel(); + String conversationId = params.getConversationId(); + log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId); + List chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, DEFAULT_CHAT_HISTORY_WINDOW); + List searchResults = getSearchResults(response); + ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, searchResults)); + String answer = (String) output.getAnswers().get(0); + + String interactionId = null; + if (conversationId != null) { + interactionId = memoryClient.createInteraction(conversationId, llmQuestion, PromptUtil.DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE, answer, + GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, jsonArrayToString(searchResults)); + } + + return insertAnswer(response, answer, interactionId); + } + + @Override + public String getType() { + return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE; + } + + private SearchResponse insertAnswer(SearchResponse response, String answer, String interactionId) { + + // TODO return the interaction id in the response. + + return new GenerativeSearchResponse(answer, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(), + response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters()); + } + + private List getSearchResults(SearchResponse response) { + List searchResults = new ArrayList<>(); + for (SearchHit hit : response.getHits().getHits()) { + Map docSourceMap = hit.getSourceAsMap(); + for (String contextField : contextFields) { + Object context = docSourceMap.get(contextField); + if (context == null) { + log.error("Context " + contextField + " not found in search hit " + hit); + // TODO throw a more meaningful error here? + throw new RuntimeException(); + } + searchResults.add(context.toString()); + } + } + return searchResults; + } + + private static String jsonArrayToString(List listOfStrings) { + JsonArray array = new JsonArray(listOfStrings.size()); + listOfStrings.forEach(array::add); + return array.toString(); + } + + public static final class Factory implements Processor.Factory { + + private final Client client; + private final BooleanSupplier featureFlagSupplier; + + public Factory(Client client, BooleanSupplier supplier) { + this.client = client; + this.featureFlagSupplier = supplier; + } + + @Override + public SearchResponseProcessor create( + Map> processorFactories, + String tag, + String description, + boolean ignoreFailure, + Map config, + PipelineContext pipelineContext + ) throws Exception { + if (this.featureFlagSupplier.getAsBoolean()) { + String modelId = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + tag, + config, + GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID + ); + String llmModel = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + tag, + config, + GenerativeQAProcessorConstants.CONFIG_NAME_LLM_MODEL + ); + List contextFields = ConfigurationUtils.readList(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + tag, + config, + GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST + ); + if (contextFields.isEmpty()) { + throw newConfigurationException(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, + tag, + GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, + "required property can't be empty." + ); + } + log.info("model_id {}, llm_model {}, context_field_list {}", modelId, llmModel, contextFields); + return new GenerativeQAResponseProcessor(client, + tag, + description, + ignoreFailure, + ModelLocator.getLlm(modelId, client), + llmModel, + contextFields, + featureFlagSupplier + ); + } else { + throw new MLException(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG); + } + } + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java new file mode 100644 index 0000000000..2a22902c9a --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java @@ -0,0 +1,66 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; + +/** + * This is an extension of SearchResponse that adds LLM-generated answers to search responses in a dedicated "ext" section. + * + * TODO Add ExtBuilders to SearchResponse and get rid of this class. + */ +public class GenerativeSearchResponse extends SearchResponse { + + private static final String EXT_SECTION_NAME = "ext"; + private static final String GENERATIVE_QA_ANSWER_FIELD_NAME = "answer"; + + private final String answer; + + public GenerativeSearchResponse( + String answer, + SearchResponseSections internalResponse, + String scrollId, + int totalShards, + int successfulShards, + int skippedShards, + long tookInMillis, + ShardSearchFailure[] shardFailures, + Clusters clusters + ) { + super(internalResponse, scrollId, totalShards, successfulShards, skippedShards, tookInMillis, shardFailures, clusters); + this.answer = answer; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + innerToXContent(builder, params); + /* start of ext */ builder.startObject(EXT_SECTION_NAME); + /* start of our stuff */ builder.startObject(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE); + /* body of our stuff */ builder.field(GENERATIVE_QA_ANSWER_FIELD_NAME, this.answer); + /* end of our stuff */ builder.endObject(); + /* end of ext */ builder.endObject(); + builder.endObject(); + return builder; + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java new file mode 100644 index 0000000000..84a32b2368 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java @@ -0,0 +1,102 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.client; + +import com.google.common.base.Preconditions; +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; + +import java.util.ArrayList; +import java.util.List; + +/** + * An OpenSearch client wrapper for conversational memory related calls. + */ +@Log4j2 +@AllArgsConstructor +public class ConversationalMemoryClient { + + private final static Logger logger = LogManager.getLogger(); + + private Client client; + + public String createConversation(String name) { + + CreateConversationResponse response = client.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name)).actionGet(); + log.info("createConversation: id: {}", response.getId()); + return response.getId(); + } + + public String createInteraction(String conversationId, String input, String promptTemplate, String response, String origin, String additionalInfo) { + Preconditions.checkNotNull(conversationId); + Preconditions.checkNotNull(input); + Preconditions.checkNotNull(response); + CreateInteractionResponse res = client.execute(CreateInteractionAction.INSTANCE, + new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo)).actionGet(); + log.info("createInteraction: interactionId: {}", res.getId()); + return res.getId(); + } + + public List getInteractions(String conversationId, int lastN) { + + Preconditions.checkArgument(lastN > 0, "lastN must be at least 1."); + + log.info("In getInteractions, conversationId {}, lastN {}", conversationId, lastN); + + List interactions = new ArrayList<>(); + int from = 0; + boolean allInteractionsFetched = false; + int maxResults = lastN; + do { + GetInteractionsResponse response = + client.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, maxResults, from)).actionGet(); + List list = response.getInteractions(); + if (list != null && !CollectionUtils.isEmpty(list)) { + interactions.addAll(list); + from += list.size(); + maxResults -= list.size(); + log.info("Interactions: {}, from: {}, maxResults: {}", interactions, from, maxResults); + } else if (response.hasMorePages()) { + // If we didn't get any results back, we ignore this flag and break out of the loop + // to avoid an infinite loop. + // But in the future, we may support this mode, e.g. DynamoDB. + break; + } + log.info("Interactions: {}, from: {}, maxResults: {}", interactions, from, maxResults); + allInteractionsFetched = !response.hasMorePages(); + } while (from < lastN && !allInteractionsFetched); + + return interactions; + } + + +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java new file mode 100644 index 0000000000..265c20a76d --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java @@ -0,0 +1,80 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.searchpipelines.questionanswering.generative.client; + +import com.google.common.annotations.VisibleForTesting; +import lombok.AccessLevel; +import lombok.RequiredArgsConstructor; +import lombok.experimental.FieldDefaults; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; + +import java.util.function.Function; + +/** + * An internal facing ML client adapted from org.opensearch.ml.client.MachineLearningNodeClient. + */ +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +@RequiredArgsConstructor +public class MachineLearningInternalClient { + + Client client; + + public ActionFuture predict(String modelId, MLInput mlInput) { + PlainActionFuture actionFuture = PlainActionFuture.newFuture(); + predict(modelId, mlInput, actionFuture); + return actionFuture; + } + + @VisibleForTesting + void predict(String modelId, MLInput mlInput, ActionListener listener) { + validateMLInput(mlInput, true); + + MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder() + .mlInput(mlInput) + .modelId(modelId) + .dispatchTask(true) + .build(); + client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener)); + } + + private ActionListener getMlPredictionTaskResponseActionListener(ActionListener listener) { + ActionListener internalListener = ActionListener.wrap(predictionResponse -> { + listener.onResponse(predictionResponse.getOutput()); + }, listener::onFailure); + ActionListener actionListener = wrapActionListener(internalListener, res -> { + MLTaskResponse predictionResponse = MLTaskResponse.fromActionResponse(res); + return predictionResponse; + }); + return actionListener; + } + + private ActionListener wrapActionListener(final ActionListener listener, final Function recreate) { + ActionListener actionListener = ActionListener.wrap(r-> { + listener.onResponse(recreate.apply(r));; + }, e->{ + listener.onFailure(e); + }); + return actionListener; + } + + private void validateMLInput(MLInput mlInput, boolean requireInput) { + if (mlInput == null) { + throw new IllegalArgumentException("ML Input can't be null"); + } + if(requireInput && mlInput.getInputDataset() == null) { + throw new IllegalArgumentException("input data set can't be null"); + } + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java new file mode 100644 index 0000000000..8a6ee8cc65 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java @@ -0,0 +1,88 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; + +import java.io.IOException; +import java.util.Objects; + +/** + * This is the extension builder for generative QA search pipelines. + */ +@NoArgsConstructor +public class GenerativeQAParamExtBuilder extends SearchExtBuilder { + + // The name of the "ext" section containing Generative QA parameters. + public static final String PARAMETER_NAME = "generative_qa_parameters"; + + @Setter + @Getter + private GenerativeQAParameters params; + + public GenerativeQAParamExtBuilder(StreamInput input) throws IOException { + this.params = new GenerativeQAParameters(input); + } + + @Override + public int hashCode() { + return Objects.hash(this.getClass(), this.params); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (!(obj instanceof GenerativeQAParamExtBuilder)) { + return false; + } + + return this.params.equals(((GenerativeQAParamExtBuilder) obj).getParams()); + } + + @Override + public String getWriteableName() { + return PARAMETER_NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + this.params.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.value(this.params); + } + + public static GenerativeQAParamExtBuilder parse(XContentParser parser) throws IOException { + GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder(); + GenerativeQAParameters params = GenerativeQAParameters.parse(parser); + builder.setParams(params); + return builder; + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtil.java new file mode 100644 index 0000000000..52da6daa02 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtil.java @@ -0,0 +1,49 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.SearchExtBuilder; + +import java.util.Optional; + +/** + * Utility class for extracting generative QA search pipeline parameters from search requests. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class GenerativeQAParamUtil { + + public static GenerativeQAParameters getGenerativeQAParameters(SearchRequest request) { + GenerativeQAParamExtBuilder builder = null; + if (request.source() != null && request.source().ext() != null && !request.source().ext().isEmpty()) { + Optional b = request.source().ext().stream().filter(bldr -> GenerativeQAParamExtBuilder.PARAMETER_NAME.equals(bldr.getWriteableName())).findFirst(); + if (b.isPresent()) { + builder = (GenerativeQAParamExtBuilder) b.get(); + } + } + + GenerativeQAParameters params = null; + if (builder != null) { + params = builder.getParams(); + } + + return params; + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java new file mode 100644 index 0000000000..04d2b53674 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -0,0 +1,110 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import com.google.common.base.Preconditions; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.ObjectParser; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +/** + * Defines parameters for generative QA search pipelines. + * + */ +@AllArgsConstructor +@NoArgsConstructor +public class GenerativeQAParameters implements Writeable, ToXContentObject { + + private static final ObjectParser PARSER; + + private static final ParseField CONVERSATION_ID = new ParseField("conversation_id"); + private static final ParseField LLM_MODEL = new ParseField("llm_model"); + private static final ParseField LLM_QUESTION = new ParseField("llm_question"); + + static { + PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new); + PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID); + PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL); + PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION); + } + + @Setter + @Getter + private String conversationId; + + @Setter + @Getter + private String llmModel; + + @Setter + @Getter + private String llmQuestion; + + public GenerativeQAParameters(StreamInput input) throws IOException { + this.conversationId = input.readOptionalString(); + this.llmModel = input.readOptionalString(); + this.llmQuestion = input.readString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + return xContentBuilder.field(CONVERSATION_ID.getPreferredName(), this.conversationId) + .field(LLM_MODEL.getPreferredName(), this.llmModel) + .field(LLM_QUESTION.getPreferredName(), this.llmQuestion); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(conversationId); + out.writeOptionalString(llmModel); + + Preconditions.checkNotNull(llmQuestion, "llm_question must not be null."); + out.writeString(llmQuestion); + } + + public static GenerativeQAParameters parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + GenerativeQAParameters other = (GenerativeQAParameters) o; + return Objects.equals(this.conversationId, other.getConversationId()) + && Objects.equals(this.llmModel, other.getLlmModel()) + && Objects.equals(this.llmQuestion, other.getLlmQuestion()); + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java new file mode 100644 index 0000000000..faf80b9d7a --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java @@ -0,0 +1,41 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; +import org.opensearch.ml.common.conversation.Interaction; + +import java.util.List; + +/** + * Input for LLMs via HttpConnector + */ +@Log4j2 +@Getter +@Setter +@AllArgsConstructor +public class ChatCompletionInput { + + private String model; + private String question; + private List chatHistory; + private List contexts; +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java new file mode 100644 index 0000000000..b9bc891a7a --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; +import lombok.extern.log4j.Log4j2; + +import java.util.List; + +/** + * Output from LLMs via HttpConnector + */ +@Log4j2 +@Getter +@Setter +@AllArgsConstructor +public class ChatCompletionOutput { + + private List answers; +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java new file mode 100644 index 0000000000..58a3cad64c --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java @@ -0,0 +1,100 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient; +import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkNotNull; + +/** + * Wrapper for talking to LLMs via OpenSearch HttpConnector. + */ +@Log4j2 +public class DefaultLlmImpl implements Llm { + + private static final String CONNECTOR_INPUT_PARAMETER_MODEL = "model"; + private static final String CONNECTOR_INPUT_PARAMETER_MESSAGES = "messages"; + private static final String CONNECTOR_OUTPUT_CHOICES = "choices"; + private static final String CONNECTOR_OUTPUT_MESSAGE = "message"; + private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role"; + private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content"; + + private final String openSearchModelId; + + private MachineLearningInternalClient mlClient; + + public DefaultLlmImpl(String openSearchModelId, Client client) { + checkNotNull(openSearchModelId); + this.openSearchModelId = openSearchModelId; + this.mlClient = new MachineLearningInternalClient(client); + } + + @VisibleForTesting + void setMlClient(MachineLearningInternalClient mlClient) { + this.mlClient = mlClient; + } + + /** + * Use ChatCompletion API to generate an answer. + * + * @param chatCompletionInput + * @return + */ + @Override + public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) { + + Map inputParameters = new HashMap<>(); + inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); + String messages = PromptUtil.getChatCompletionPrompt(chatCompletionInput.getQuestion(), chatCompletionInput.getChatHistory(), chatCompletionInput.getContexts()); + inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); + log.info("Messages to LLM: {}", messages); + MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(inputParameters).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build(); + ActionFuture future = mlClient.predict(this.openSearchModelId, mlInput); + ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(); + + // Response from a remote model + Map dataAsMap = modelOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); + log.info("dataAsMap: {}", dataAsMap.toString()); + + // TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases. + + List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES); + Map firstChoiceMap = (Map) choices.get(0); + log.info("Choices: {}", firstChoiceMap.toString()); + Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE); + log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); + + return new ChatCompletionOutput(List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT))); + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java new file mode 100644 index 0000000000..e850561066 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java @@ -0,0 +1,26 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +/** + * Capabilities of large language models, e.g. completion, embeddings, etc. + */ +public interface Llm { + + ChatCompletionOutput doChatCompletion(ChatCompletionInput input); +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java new file mode 100644 index 0000000000..5d007420f7 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import org.opensearch.ml.common.conversation.Interaction; + +import java.util.List; + +/** + * Helper class for creating inputs and outputs for different implementations of LLMs. + */ +public class LlmIOUtil { + + public static ChatCompletionInput createChatCompletionInput(String llmModel, String question, List chatHistory, List contexts) { + + // TODO pick the right subclass based on the modelId. + + return new ChatCompletionInput(llmModel, question, chatHistory, contexts); + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocator.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocator.java new file mode 100644 index 0000000000..1b43574374 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocator.java @@ -0,0 +1,36 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.client.Client; + +/** + * Helper class for wiring LLMs based on the model ID. + * + * TODO Should we extend this use case beyond HttpConnectors/Remote Inference? + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class ModelLocator { + + public static Llm getLlm(String modelId, Client client) { + return new DefaultLlmImpl(modelId, client); + } + +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java new file mode 100644 index 0000000000..10e5a924c6 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -0,0 +1,161 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.prompt; + +import com.google.common.annotations.VisibleForTesting; +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonPrimitive; +import lombok.AccessLevel; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.apache.commons.text.StringEscapeUtils; +import org.opensearch.ml.common.conversation.Interaction; + +import java.util.ArrayList; +import java.util.List; + +/** + * A utility class for producing prompts for LLMs. + * + * TODO Should prompt engineering llm-specific? + * + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class PromptUtil { + + public static final String DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE = + "Generate a concise and informative answer in less than 100 words for the given question, taking into context: " + + "- An enumerated list of search results" + + "- A rephrase of the question that was used to generate the search results" + + "- The conversation history" + + "Cite search results using [${number}] notation." + + "Do not repeat yourself, and NEVER repeat anything in the chat history." + + "If there are any necessary steps or procedures in your answer, enumerate them."; + + private static final String roleUser = "user"; + + public static String getQuestionRephrasingPrompt(String originalQuestion, List chatHistory) { + return null; + } + + public static String getChatCompletionPrompt(String question, List chatHistory, List contexts) { + return buildMessageParameter(question, chatHistory, contexts); + } + + enum ChatRole { + USER("user"), + ASSISTANT("assistant"), + SYSTEM("system"); + + // TODO Add "function" + + @Getter + private String name; + + ChatRole(String name) { + this.name = name; + } + } + + @VisibleForTesting + static String buildMessageParameter(String question, List chatHistory, List contexts) { + + // TODO better prompt template management is needed here. + + JsonArray messageArray = new JsonArray(); + messageArray.add(new Message(ChatRole.USER, DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE).toJson()); + for (String result : contexts) { + messageArray.add(new Message(ChatRole.USER, "SEARCH RESULT: " + result).toJson()); + } + if (!chatHistory.isEmpty()) { + Messages.fromInteractions(chatHistory).getMessages().forEach(m -> messageArray.add(m.toJson())); + } + messageArray.add(new Message(ChatRole.USER, "QUESTION: " + question).toJson()); + messageArray.add(new Message(ChatRole.USER, "ANSWER:").toJson()); + + return messageArray.toString(); + } + + private static Gson gson = new Gson(); + + @Getter + static class Messages { + + @Getter + private List messages = new ArrayList<>(); + //private JsonArray jsonArray = new JsonArray(); + + public Messages(final List messages) { + addMessages(messages); + } + + public void addMessages(List messages) { + this.messages.addAll(messages); + } + + public static Messages fromInteractions(final List interactions) { + List messages = new ArrayList<>(); + + for (Interaction interaction : interactions) { + messages.add(new Message(ChatRole.USER, interaction.getInput())); + messages.add(new Message(ChatRole.ASSISTANT, interaction.getResponse())); + } + + return new Messages(messages); + } + } + + static class Message { + + private final static String MESSAGE_FIELD_ROLE = "role"; + private final static String MESSAGE_FIELD_CONTENT = "content"; + + @Getter + private ChatRole chatRole; + @Getter + private String content; + + private JsonObject json; + + public Message() { + json = new JsonObject(); + } + + public Message(ChatRole chatRole, String content) { + this(); + setChatRole(chatRole); + setContent(content); + } + + public void setChatRole(ChatRole chatRole) { + json.remove(MESSAGE_FIELD_ROLE); + json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(chatRole.getName())); + } + public void setContent(String content) { + this.content = StringEscapeUtils.escapeJson(content); + json.remove(MESSAGE_FIELD_CONTENT); + json.add(MESSAGE_FIELD_CONTENT, new JsonPrimitive(this.content)); + } + + public JsonObject toJson() { + return json; + } + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAParamUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAParamUtilTests.java new file mode 100644 index 0000000000..cbd5122371 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAParamUtilTests.java @@ -0,0 +1,39 @@ +package org.opensearch.searchpipelines.questionanswering.generative; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GenerativeQAParamUtilTests extends OpenSearchTestCase { + + public void testGenerativeQAParametersMissingParams() { + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + SearchSourceBuilder srcBulder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(srcBulder); + GenerativeQAParameters actual = GenerativeQAParamUtil.getGenerativeQAParameters(request); + assertNull(actual); + } + + public void testMisc() { + SearchRequest request = new SearchRequest(); + assertNull(GenerativeQAParamUtil.getGenerativeQAParameters(request)); + request.source(new SearchSourceBuilder()); + assertNull(GenerativeQAParamUtil.getGenerativeQAParameters(request)); + request.source(new SearchSourceBuilder().ext(List.of())); + assertNull(GenerativeQAParamUtil.getGenerativeQAParameters(request)); + + SearchExtBuilder extBuilder = mock(SearchExtBuilder.class); + when(extBuilder.getWriteableName()).thenReturn("foo"); + request.source(new SearchSourceBuilder().ext(List.of(extBuilder))); + assertNull(GenerativeQAParamUtil.getGenerativeQAParameters(request)); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java new file mode 100644 index 0000000000..a83ccd1767 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java @@ -0,0 +1,98 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative; + +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchRequestProcessor; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; +import java.util.function.BooleanSupplier; + +import static org.mockito.Mockito.mock; + +public class GenerativeQARequestProcessorTests extends OpenSearchTestCase { + + private BooleanSupplier alwaysOn = () -> true; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + public void testProcessorFactory() throws Exception { + + Map config = new HashMap<>(); + config.put("model_id", "foo"); + SearchRequestProcessor processor = + new GenerativeQARequestProcessor.Factory(alwaysOn).create(null, "tag", "desc", true, config, null); + assertTrue(processor instanceof GenerativeQARequestProcessor); + } + + public void testProcessRequest() throws Exception { + GenerativeQARequestProcessor processor = new GenerativeQARequestProcessor("tag", "desc", false, "foo", alwaysOn); + SearchRequest request = new SearchRequest(); + SearchRequest processed = processor.processRequest(request); + assertEquals(request, processed); + } + + public void testGetType() { + GenerativeQARequestProcessor processor = new GenerativeQARequestProcessor("tag", "desc", false, "foo", alwaysOn); + assertEquals(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, processor.getType()); + } + + public void testProcessorFactoryFeatureFlagDisabled() throws Exception { + + exceptionRule.expect(MLException.class); + exceptionRule.expectMessage(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG); + Map config = new HashMap<>(); + config.put("model_id", "foo"); + Processor processor = + new GenerativeQARequestProcessor.Factory(()->false).create(null, "tag", "desc", true, config, null); + } + + // Only to be used for the following test case. + private boolean featureFlag001 = false; + public void testProcessorFeatureFlagOffOnOff() throws Exception { + Map config = new HashMap<>(); + config.put("model_id", "foo"); + Processor.Factory factory = new GenerativeQARequestProcessor.Factory(()->featureFlag001); + boolean firstExceptionThrown = false; + try { + factory.create(null, "tag", "desc", true, config, null); + } catch (MLException e) { + assertEquals(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG, e.getMessage()); + firstExceptionThrown = true; + } + assertTrue(firstExceptionThrown); + featureFlag001 = true; + GenerativeQARequestProcessor processor = (GenerativeQARequestProcessor) factory.create(null, "tag", "desc", true, config, null); + featureFlag001 = false; + boolean secondExceptionThrown = false; + try { + processor.processRequest(mock(SearchRequest.class)); + } catch (MLException e) { + assertEquals(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG, e.getMessage()); + secondExceptionThrown = true; + } + assertTrue(secondExceptionThrown); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java new file mode 100644 index 0000000000..af8b1d9929 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java @@ -0,0 +1,282 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative; + +import org.junit.Rule; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.client.Client; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient; +import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput; +import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; +import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; +import org.opensearch.test.OpenSearchTestCase; + +import java.time.Instant; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BooleanSupplier; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class GenerativeQAResponseProcessorTests extends OpenSearchTestCase { + + private BooleanSupplier alwaysOn = () -> true; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + public void testProcessorFactoryRemoteModel() throws Exception { + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "xyz"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client, alwaysOn) + .create(null, "tag", "desc", true, config, null); + assertNotNull(processor); + } + + public void testGetType() { + Client client = mock(Client.class); + Llm llm = mock(Llm.class); + GenerativeQAResponseProcessor processor = new GenerativeQAResponseProcessor(client, null, null, false, llm, "foo", List.of("text"), alwaysOn); + assertEquals(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, processor.getType()); + } + + public void testProcessResponseNoSearchHits() throws Exception { + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client, alwaysOn) + .create(null, "tag", "desc", true, config, null); + + SearchRequest request = new SearchRequest(); // mock(SearchRequest.class); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // mock(SearchSourceBuilder.class); + GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind."); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + ChatCompletionOutput output = mock(ChatCompletionOutput.class); + when(llm.doChatCompletion(any())).thenReturn(output); + when(output.getAnswers()).thenReturn(List.of("foo")); + processor.setLlm(llm); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ChatCompletionInput.class); + boolean errorThrown = false; + try { + SearchResponse res = processor.processResponse(request, response); + } catch (Exception e) { + errorThrown = true; + } + assertTrue(errorThrown); + } + + public void testProcessResponse() throws Exception { + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client, alwaysOn) + .create(null, "tag", "desc", true, config, null); + + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); + when(memoryClient.getInteractions(any(), anyInt())).thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); + processor.setMemoryClient(memoryClient); + + SearchRequest request = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind."); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + XContentBuilder sourceContent = JsonXContent.contentBuilder() + .startObject() + .field("_id", String.valueOf(i)) + .field("text", "passage" + i) + .field("title", "This is the title for document " + i) + .endObject(); + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + ChatCompletionOutput output = mock(ChatCompletionOutput.class); + when(llm.doChatCompletion(any())).thenReturn(output); + when(output.getAnswers()).thenReturn(List.of("foo")); + processor.setLlm(llm); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ChatCompletionInput.class); + SearchResponse res = processor.processResponse(request, response); + verify(llm).doChatCompletion(captor.capture()); + ChatCompletionInput input = captor.getValue(); + assertTrue(input instanceof ChatCompletionInput); + List passages = ((ChatCompletionInput) input).getContexts(); + assertEquals("passage0", passages.get(0)); + assertEquals("passage1", passages.get(1)); + assertTrue(res instanceof GenerativeSearchResponse); + } + + public void testProcessResponseMissingContextField() throws Exception { + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client, alwaysOn) + .create(null, "tag", "desc", true, config, null); + + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); + when(memoryClient.getInteractions(any(), anyInt())).thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); + processor.setMemoryClient(memoryClient); + + SearchRequest request = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind."); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + XContentBuilder sourceContent = JsonXContent.contentBuilder() + .startObject() + .field("_id", String.valueOf(i)) + //.field("text", "passage" + i) + .field("title", "This is the title for document " + i) + .endObject(); + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + ChatCompletionOutput output = mock(ChatCompletionOutput.class); + when(llm.doChatCompletion(any())).thenReturn(output); + when(output.getAnswers()).thenReturn(List.of("foo")); + processor.setLlm(llm); + + boolean exceptionThrown = false; + + try { + SearchResponse res = processor.processResponse(request, response); + } catch (Exception e) { + exceptionThrown = true; + } + + assertTrue(exceptionThrown); + } + + public void testProcessorFactoryFeatureDisabled() throws Exception { + + exceptionRule.expect(MLException.class); + exceptionRule.expectMessage(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG); + + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "xyz"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + Processor processor = + new GenerativeQAResponseProcessor.Factory(client, () -> false) + .create(null, "tag", "desc", true, config, null); + } + + // Use this only for the following test case. + private boolean featureEnabled001; + public void testProcessorFeatureOffOnOff() throws Exception { + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "xyz"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + featureEnabled001 = false; + BooleanSupplier supplier = () -> featureEnabled001; + Processor.Factory factory = new GenerativeQAResponseProcessor.Factory(client, supplier); + GenerativeQAResponseProcessor processor; + boolean firstExceptionThrown = false; + try { + processor = (GenerativeQAResponseProcessor) factory.create(null, "tag", "desc", true, config, null); + } catch (MLException e) { + assertEquals(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG, e.getMessage()); + firstExceptionThrown = true; + } + assertTrue(firstExceptionThrown); + featureEnabled001 = true; + processor = (GenerativeQAResponseProcessor) factory.create(null, "tag", "desc", true, config, null); + + featureEnabled001 = false; + boolean secondExceptionThrown = false; + try { + processor.processResponse(mock(SearchRequest.class), mock(SearchResponse.class)); + } catch (MLException e) { + assertEquals(GenerativeQAProcessorConstants.FEATURE_NOT_ENABLED_ERROR_MSG, e.getMessage()); + secondExceptionThrown = true; + } + assertTrue(secondExceptionThrown); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponseTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponseTests.java new file mode 100644 index 0000000000..cead38b0a0 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponseTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentGenerator; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.io.OutputStream; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GenerativeSearchResponseTests extends OpenSearchTestCase { + + public void testToXContent() throws IOException { + String answer = "answer"; + SearchResponseSections internal = new SearchResponseSections(new SearchHits(new SearchHit[0], null, 0), null, null, false, false, null, 0); + GenerativeSearchResponse searchResponse = new GenerativeSearchResponse(answer, internal, null, 0, 0, 0, 0, new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY); + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + XContentBuilder actual = searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(actual); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java new file mode 100644 index 0000000000..67038d93cd --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java @@ -0,0 +1,174 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.client; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.stream.IntStream; + +import static org.mockito.Mockito.*; + +public class ConversationalMemoryClientTests extends OpenSearchTestCase { + + public void testCreateConversation() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateConversationRequest.class); + String conversationId = UUID.randomUUID().toString(); + CreateConversationResponse response = new CreateConversationResponse(conversationId); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(response); + when(client.execute(eq(CreateConversationAction.INSTANCE), any())).thenReturn(future); + String name = "foo"; + String actual = memoryClient.createConversation(name); + verify(client, times(1)).execute(eq(CreateConversationAction.INSTANCE), captor.capture()); + assertEquals(name, captor.getValue().getName()); + assertEquals(conversationId, actual); + } + + public void testGetInteractionsNoPagination() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + int lastN = 5; + String conversationId = UUID.randomUUID().toString(); + List interactions = new ArrayList<>(); + IntStream.range(0, lastN).forEach(i -> interactions.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + GetInteractionsResponse response = new GetInteractionsResponse(interactions, lastN, false); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(response); + when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); + + List actual = memoryClient.getInteractions(conversationId, lastN); + verify(client, times(1)).execute(eq(GetInteractionsAction.INSTANCE), captor.capture()); + GetInteractionsRequest actualRequest = captor.getValue(); + assertEquals(lastN, actual.size()); + assertEquals(conversationId, actualRequest.getConversationId()); + assertEquals(lastN, actualRequest.getMaxResults()); + assertEquals(0, actualRequest.getFrom()); + } + + public void testGetInteractionsWithPagination() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + int lastN = 5; + String conversationId = UUID.randomUUID().toString(); + List firstPage = new ArrayList<>(); + IntStream.range(0, lastN).forEach(i -> firstPage.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + GetInteractionsResponse response1 = new GetInteractionsResponse(firstPage, lastN, true); + List secondPage = new ArrayList<>(); + IntStream.range(0, lastN).forEach(i -> secondPage.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + GetInteractionsResponse response2 = new GetInteractionsResponse(secondPage, lastN, false); + ActionFuture future1 = mock(ActionFuture.class); + when(future1.actionGet()).thenReturn(response1); + ActionFuture future2 = mock(ActionFuture.class); + when(future2.actionGet()).thenReturn(response2); + when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future1).thenReturn(future2); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); + + List actual = memoryClient.getInteractions(conversationId, 2*lastN); + // Called twice + verify(client, times(2)).execute(eq(GetInteractionsAction.INSTANCE), captor.capture()); + List actualRequests = captor.getAllValues(); + assertEquals(2*lastN, actual.size()); + assertEquals(conversationId, actualRequests.get(0).getConversationId()); + assertEquals(2*lastN, actualRequests.get(0).getMaxResults()); + assertEquals(0, actualRequests.get(0).getFrom()); + assertEquals(lastN, actualRequests.get(1).getFrom()); + } + + public void testGetInteractionsNoMoreResults() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + int lastN = 5; + int found = lastN - 1; + String conversationId = UUID.randomUUID().toString(); + List interactions = new ArrayList<>(); + // Return fewer results than requested + IntStream.range(0, found).forEach(i -> interactions.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + GetInteractionsResponse response = new GetInteractionsResponse(interactions, found, false); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(response); + when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); + + List actual = memoryClient.getInteractions(conversationId, lastN); + verify(client, times(1)).execute(eq(GetInteractionsAction.INSTANCE), captor.capture()); + GetInteractionsRequest actualRequest = captor.getValue(); + assertEquals(found, actual.size()); + assertEquals(conversationId, actualRequest.getConversationId()); + assertEquals(lastN, actualRequest.getMaxResults()); + assertEquals(0, actualRequest.getFrom()); + } + + public void testAvoidInfiniteLoop() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + GetInteractionsResponse response1 = new GetInteractionsResponse(null, 0, true); + GetInteractionsResponse response2 = new GetInteractionsResponse(List.of(), 0, true); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(response1).thenReturn(response2); + when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); + List actual = memoryClient.getInteractions("1", 10); + assertTrue(actual.isEmpty()); + actual = memoryClient.getInteractions("1", 10); + assertTrue(actual.isEmpty()); + } + + public void testNoResults() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + GetInteractionsResponse response1 = new GetInteractionsResponse(null, 0, true); + GetInteractionsResponse response2 = new GetInteractionsResponse(List.of(), 0, false); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(response1).thenReturn(response2); + when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); + List actual = memoryClient.getInteractions("1", 10); + assertTrue(actual.isEmpty()); + actual = memoryClient.getInteractions("1", 10); + assertTrue(actual.isEmpty()); + } + + public void testCreateInteraction() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + String id = UUID.randomUUID().toString(); + CreateInteractionResponse res = new CreateInteractionResponse(id); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(res); + when(client.execute(eq(CreateInteractionAction.INSTANCE), any())).thenReturn(future); + String actual = memoryClient.createInteraction("cid", "input", "prompt", "answer", "origin", "hits"); + assertEquals(id, actual); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClientTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClientTests.java new file mode 100644 index 0000000000..ce921bac89 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClientTests.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.searchpipelines.questionanswering.generative.client; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class MachineLearningInternalClientTests { + @Mock(answer = RETURNS_DEEP_STUBS) + NodeClient client; + + @Mock + MLInputDataset input; + + @Mock + DataFrame output; + + @Mock + ActionListener dataFrameActionListener; + + @InjectMocks + MachineLearningInternalClient machineLearningInternalClient; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + } + + @Test + public void predict() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLPredictionOutput predictionOutput = MLPredictionOutput.builder() + .status("Success") + .predictionResult(output) + .taskId("taskId") + .build(); + actionListener.onResponse(MLTaskResponse.builder() + .output(predictionOutput) + .build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); + MLInput mlInput = MLInput.builder() + .algorithm(FunctionName.KMEANS) + .inputDataset(input) + .build(); + machineLearningInternalClient.predict(null, mlInput, dataFrameActionListener); + + verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), isA(MLPredictionTaskRequest.class), any()); + verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture()); + assertEquals(output, ((MLPredictionOutput)dataFrameArgumentCaptor.getValue()).getPredictionResult()); + } + + @Test + public void predict_Exception_WithNullAlgorithm() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("algorithm can't be null"); + MLInput mlInput = MLInput.builder() + .inputDataset(input) + .build(); + machineLearningInternalClient.predict(null, mlInput, dataFrameActionListener); + } + + @Test + public void predict_Exception_WithNullDataSet() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("input data set can't be null"); + MLInput mlInput = MLInput.builder() + .algorithm(FunctionName.KMEANS) + .build(); + machineLearningInternalClient.predict(null, mlInput, dataFrameActionListener); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java new file mode 100644 index 0000000000..b05b52062c --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentHelper; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.EOFException; +import java.io.IOException; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase { + + public void testCtor() throws IOException { + GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder(); + GenerativeQAParameters parameters = new GenerativeQAParameters(); + builder.setParams(parameters); + assertEquals(parameters, builder.getParams()); + + GenerativeQAParamExtBuilder builder1 = new GenerativeQAParamExtBuilder(new StreamInput() { + @Override + public byte readByte() throws IOException { + return 0; + } + + @Override + public void readBytes(byte[] b, int offset, int len) throws IOException { + + } + + @Override + public void close() throws IOException { + + } + + @Override + public int available() throws IOException { + return 0; + } + + @Override + protected void ensureCanReadBytes(int length) throws EOFException { + + } + + @Override + public int read() throws IOException { + return 0; + } + }); + + assertNotNull(builder1); + } + + public void testMiscMethods() throws IOException { + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c"); + GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d"); + GenerativeQAParamExtBuilder builder1 = new GenerativeQAParamExtBuilder(); + GenerativeQAParamExtBuilder builder2 = new GenerativeQAParamExtBuilder(); + builder1.setParams(param1); + builder2.setParams(param2); + assertEquals(builder1, builder1); + assertNotEquals(builder1, param1); + assertNotEquals(builder1, builder2); + assertNotEquals(builder1.hashCode(), builder2.hashCode()); + + StreamOutput so = mock(StreamOutput.class); + builder1.writeTo(so); + verify(so, times(2)).writeOptionalString(any()); + verify(so, times(1)).writeString(any()); + } + + public void testParse() throws IOException { + XContentParser xcParser = mock(XContentParser.class); + when(xcParser.nextToken()).thenReturn(XContentParser.Token.START_OBJECT).thenReturn(XContentParser.Token.END_OBJECT); + GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(xcParser); + assertNotNull(builder); + assertNotNull(builder.getParams()); + } + + public void testXContentRoundTrip() throws IOException { + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c"); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(param1); + XContentType xContentType = randomFrom(XContentType.values()); + BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true); + XContentParser parser = createParser(xContentType.xContent(), serialized); + GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser); + assertEquals(extBuilder, deserialized); + } + + public void testStreamRoundTrip() throws IOException { + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c"); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(param1); + BytesStreamOutput bso = new BytesStreamOutput(); + extBuilder.writeTo(bso); + GenerativeQAParamExtBuilder deserialized = new GenerativeQAParamExtBuilder(bso.bytes().streamInput()); + assertEquals(extBuilder, deserialized); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java new file mode 100644 index 0000000000..c6cf3e9399 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java @@ -0,0 +1,35 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +public class GenerativeQAParamUtilTests extends OpenSearchTestCase { + + public void testGenerativeQAParametersMissingParams() { + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + SearchSourceBuilder srcBulder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(srcBulder); + GenerativeQAParameters actual = GenerativeQAParamUtil.getGenerativeQAParameters(request); + assertNull(actual); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java new file mode 100644 index 0000000000..b2f9d9dc2f --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java @@ -0,0 +1,125 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentGenerator; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +public class GenerativeQAParametersTests extends OpenSearchTestCase { + + public void testGenerativeQAParameters() { + GenerativeQAParameters params = new GenerativeQAParameters("conversation_id", "llm_model", "llm_question"); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + SearchSourceBuilder srcBulder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(srcBulder); + GenerativeQAParameters actual = GenerativeQAParamUtil.getGenerativeQAParameters(request); + assertEquals(params, actual); + } + + static class DummyStreamOutput extends StreamOutput { + + List list = new ArrayList<>(); + + @Override + public void writeString(String str) { + list.add(str); + } + + public List getList() { + return list; + } + + @Override + public void writeByte(byte b) throws IOException { + + } + + @Override + public void writeBytes(byte[] b, int offset, int length) throws IOException { + + } + + @Override + public void flush() throws IOException { + + } + + @Override + public void close() throws IOException { + + } + + @Override + public void reset() throws IOException { + + } + } + public void testWriteTo() throws IOException { + String conversationId = "a"; + String llmModel = "b"; + String llmQuestion = "c"; + GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion); + StreamOutput output = new DummyStreamOutput(); + parameters.writeTo(output); + List actual = ((DummyStreamOutput) output).getList(); + assertEquals(3, actual.size()); + assertEquals(conversationId, actual.get(0)); + assertEquals(llmModel, actual.get(1)); + assertEquals(llmQuestion, actual.get(2)); + } + + public void testMisc() { + String conversationId = "a"; + String llmModel = "b"; + String llmQuestion = "c"; + GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion); + assertNotEquals(parameters, null); + assertNotEquals(parameters, "foo"); + assertEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, llmQuestion)); + assertNotEquals(parameters, new GenerativeQAParameters("", llmModel, llmQuestion)); + assertNotEquals(parameters, new GenerativeQAParameters(conversationId, "", llmQuestion)); + assertNotEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, "")); + } + + public void testToXConent() throws IOException { + String conversationId = "a"; + String llmModel = "b"; + String llmQuestion = "c"; + GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion); + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + assertNotNull(parameters.toXContent(builder, null)); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java new file mode 100644 index 0000000000..925b84b8b1 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java @@ -0,0 +1,56 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchTestCase; + +import java.time.Instant; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class ChatCompletionInputTests extends OpenSearchTestCase { + + public void testCtor() { + String model = "model"; + String question = "question"; + + ChatCompletionInput input = new ChatCompletionInput(model, question, Collections.emptyList(), Collections.emptyList()); + + assertNotNull(input); + } + + public void testGettersSetters() { + String model = "model"; + String question = "question"; + List history = List.of(Interaction.fromMap("1", + Map.of( + ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "convo1", + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "hello"))); + List contexts = List.of("result1", "result2"); + ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts); + assertEquals(model, input.getModel()); + assertEquals(question, input.getQuestion()); + assertEquals(history.get(0).getConversationId(), input.getChatHistory().get(0).getConversationId()); + assertEquals(contexts.get(0), input.getContexts().get(0)); + assertEquals(contexts.get(1), input.getContexts().get(1)); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java new file mode 100644 index 0000000000..c3f6c68688 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java @@ -0,0 +1,36 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +public class ChatCompletionOutputTests extends OpenSearchTestCase { + + public void testCtor() { + ChatCompletionOutput output = new ChatCompletionOutput(List.of("answer")); + assertNotNull(output); + } + + public void testGettersSetters() { + String answer = "answer"; + ChatCompletionOutput output = new ChatCompletionOutput(List.of(answer)); + assertEquals(answer, (String) output.getAnswers().get(0)); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java new file mode 100644 index 0000000000..0aba017245 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.client.Client; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient; +import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; +import org.opensearch.test.OpenSearchTestCase; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +public class DefaultLlmImplTests extends OpenSearchTestCase { + + @Mock + Client client; + + public void testBuildMessageParameter() { + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + String question = "Who am I"; + List contexts = new ArrayList<>(); + contexts.add("context 1"); + contexts.add("context 2"); + List chatHistory = List.of(Interaction.fromMap("convo1", Map.of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer1")), + Interaction.fromMap("convo1", Map.of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 2", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer2"))); + String parameter = PromptUtil.getChatCompletionPrompt(question, chatHistory, contexts); + Map parameters = Map.of("model", "foo", "messages", parameter); + assertTrue(isJson(parameter)); + } + + public void testChatCompletionApi() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + Map messageMap = Map.of("role", "agent", "content", "answer"); + Map dataAsMap = Map.of("choices", List.of(Map.of("message", messageMap))); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(mlOutput); + when(mlClient.predict(any(), any())).thenReturn(future); + ChatCompletionInput input = new ChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList()); + ChatCompletionOutput output = connector.doChatCompletion(input); + verify(mlClient, times(1)).predict(any(), captor.capture()); + MLInput mlInput = captor.getValue(); + assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); + assertEquals("answer", (String) output.getAnswers().get(0)); + } + + private boolean isJson(String Json) { + try { + new JSONObject(Json); + } catch (JSONException ex) { + try { + new JSONArray(Json); + } catch (JSONException ex1) { + return false; + } + } + return true; + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java new file mode 100644 index 0000000000..5d8395126b --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; + +public class LlmIOUtilTests extends OpenSearchTestCase { + + public void testCtor() { + assertNotNull(new LlmIOUtil()); + } + + public void testChatCompletionInput() { + ChatCompletionInput input = LlmIOUtil.createChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList()); + assertTrue(input instanceof ChatCompletionInput); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocatorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocatorTests.java new file mode 100644 index 0000000000..dcf3d223fb --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocatorTests.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import org.opensearch.client.Client; +import org.opensearch.test.OpenSearchTestCase; + +import static org.mockito.Mockito.mock; + +public class ModelLocatorTests extends OpenSearchTestCase { + + public void testGetRemoteLlm() { + Client client = mock(Client.class); + Llm llm = ModelLocator.getLlm("xyz", client); + assertTrue(llm instanceof DefaultLlmImpl); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java new file mode 100644 index 0000000000..dd3fed1c9d --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java @@ -0,0 +1,69 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.prompt; + +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.test.OpenSearchTestCase; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class PromptUtilTests extends OpenSearchTestCase { + + public void testPromptUtilStaticMethods() { + assertNull(PromptUtil.getQuestionRephrasingPrompt("question", Collections.emptyList())); + } + + public void testBuildMessageParameter() { + String question = "Who am I"; + List contexts = new ArrayList<>(); + List chatHistory = List.of(Interaction.fromMap("convo1", Map.of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer1")), + Interaction.fromMap("convo1", Map.of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 2", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer2"))); + contexts.add("context 1"); + contexts.add("context 2"); + String parameter = PromptUtil.buildMessageParameter(question, chatHistory, contexts); + Map parameters = Map.of("model", "foo", "messages", parameter); + assertTrue(isJson(parameter)); + } + + private boolean isJson(String Json) { + try { + new JSONObject(Json); + } catch (JSONException ex) { + try { + new JSONArray(Json); + } catch (JSONException ex1) { + return false; + } + } + return true; + } +} diff --git a/settings.gradle b/settings.gradle index b69d65c8b6..b6d0b19113 100644 --- a/settings.gradle +++ b/settings.gradle @@ -13,4 +13,7 @@ include 'plugin' project(":plugin").name = rootProject.name + "-plugin" include 'ml-algorithms' project(":ml-algorithms").name = rootProject.name + "-algorithms" - +include 'search-processors' +project(":search-processors").name = rootProject.name + "-search-processors" +include 'memory' +project(":memory").name = rootProject.name + "-memory"