From 1baa0804924e69b4205b057d6461e2e2395f28c2 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Fri, 1 Sep 2023 12:32:51 -0700 Subject: [PATCH 1/7] Put RAG pipeline behind a feature flag. Signed-off-by: Austin Lee --- .../ml/plugin/MachineLearningPlugin.java | 34 ++++++++++++++++--- .../ml/settings/MLCommonsSettings.java | 3 ++ settings.gradle | 2 -- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 70b827261c..21ba7e9dca 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -7,9 +7,11 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED; import java.nio.file.Path; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -247,6 +249,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc private ConversationalMemoryHandler cmHandler; + private volatile boolean ragSearchPipelineEnabled; + @Override public List> getActions() { return ImmutableList @@ -338,6 +342,11 @@ public Collection createComponents( stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); + // TODO move this into MLFeatureEnabledSetting + this.ragSearchPipelineEnabled = ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(clusterService.getSettings()); + clusterService.getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it); + mlIndicesHandler = new MLIndicesHandler(clusterService, client); mlTaskManager = new MLTaskManager(client, threadPool, mlIndicesHandler); modelHelper = new ModelHelper(mlEngine); @@ -654,30 +663,45 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED, MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED, MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, - MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED + MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED, + ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED ); return settings; } + /** + * + * Search processors for Retrieval Augmented Generation + * + */ + @Override public List> getSearchExts() { - return List + return ragSearchPipelineEnabled ? List .of( new SearchPlugin.SearchExtSpec<>( GenerativeQAParamExtBuilder.PARAMETER_NAME, input -> new GenerativeQAParamExtBuilder(input), parser -> GenerativeQAParamExtBuilder.parse(parser) ) - ); + ) + // Feature not enabled + : Collections.emptyList(); } @Override public Map> getRequestProcessors(Parameters parameters) { - return Map.of(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory()); + return ragSearchPipelineEnabled ? + Map.of(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory()) + // Feature not enabled + : Collections.emptyMap(); } @Override public Map> getResponseProcessors(Parameters parameters) { - return Map.of(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client)); + return ragSearchPipelineEnabled ? + Map.of(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client)) + // Feature not enabled + : Collections.emptyMap(); } } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index dc9c209535..de03e16eb9 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -130,4 +130,7 @@ private MLCommonsSettings() {} ); public static final Setting ML_COMMONS_MEMORY_FEATURE_ENABLED = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED; + // Feature flag for enabling search processors for Retrieval Augmented Generation using OpenSearch and Remote Inference. + public static final Setting ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED = Setting + .boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); } diff --git a/settings.gradle b/settings.gradle index bf697450c1..b6d0b19113 100644 --- a/settings.gradle +++ b/settings.gradle @@ -15,7 +15,5 @@ include 'ml-algorithms' project(":ml-algorithms").name = rootProject.name + "-algorithms" include 'search-processors' project(":search-processors").name = rootProject.name + "-search-processors" -include 'conversational-memory' -project(":conversational-memory").name = rootProject.name + "-memory" include 'memory' project(":memory").name = rootProject.name + "-memory" From 1eb94445fe40fd81f6ac4c6acf288cf7563bc5e9 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Tue, 15 Aug 2023 09:45:45 -0700 Subject: [PATCH 2/7] Add support for chat history in RAG using the Conversational Memory API Signed-off-by: Austin Lee --- search-processors/build.gradle | 7 +- .../GenerativeQAResponseProcessor.java | 43 ++++- .../client/ConversationalMemoryClient.java | 99 ++++++++++ .../ext/GenerativeQAParamExtBuilder.java | 4 +- .../generative/llm/ChatCompletionInput.java | 3 +- .../generative/llm/DefaultLlmImpl.java | 7 +- .../generative/llm/LlmIOUtil.java | 4 +- .../generative/prompt/PromptUtil.java | 137 +++++++++++--- .../GenerativeQAParamUtilTests.java | 39 ++++ .../GenerativeQARequestProcessorTests.java | 9 +- .../GenerativeQAResponseProcessorTests.java | 76 +++++++- .../GenerativeSearchResponseTests.java | 53 ++++++ .../ConversationalMemoryClientTests.java | 174 ++++++++++++++++++ .../ext/GenerativeQAParamExtBuilderTests.java | 52 +++++- .../ext/GenerativeQAParamUtilTests.java | 6 - .../ext/GenerativeQAParametersTests.java | 36 +++- .../llm/ChatCompletionInputTests.java | 12 +- .../generative/llm/DefaultLlmImplTests.java | 15 +- .../generative/llm/LlmIOUtilTests.java | 4 + .../generative/prompt/PromptUtilTests.java | 14 +- 20 files changed, 721 insertions(+), 73 deletions(-) create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAParamUtilTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponseTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java diff --git a/search-processors/build.gradle b/search-processors/build.gradle index 4371f37e15..3d5903a0d3 100644 --- a/search-processors/build.gradle +++ b/search-processors/build.gradle @@ -29,9 +29,11 @@ repositories { dependencies { compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" + compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' implementation 'org.apache.commons:commons-lang3:3.12.0' //implementation project(':opensearch-ml-client') implementation project(':opensearch-ml-common') + implementation project(':opensearch-ml-memory') implementation group: 'org.opensearch', name: 'common-utils', version: "${common_utils_version}" // https://mvnrepository.com/artifact/org.apache.httpcomponents.core5/httpcore5 implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.1' @@ -59,11 +61,11 @@ jacocoTestCoverageVerification { rule { limit { counter = 'LINE' - minimum = 0.65 //TODO: increase coverage to 0.90 + minimum = 0.8 } limit { counter = 'BRANCH' - minimum = 0.55 //TODO: increase coverage to 0.85 + minimum = 0.8 } } } @@ -71,4 +73,3 @@ jacocoTestCoverageVerification { } check.dependsOn jacocoTestCoverageVerification -//jacocoTestCoverageVerification.dependsOn jacocoTestReport diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java index 60c746f1d6..ea0fe8cdd1 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -17,6 +17,8 @@ */ package org.opensearch.searchpipelines.questionanswering.generative; +import com.google.gson.Gson; +import com.google.gson.JsonArray; import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; @@ -24,18 +26,22 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.search.SearchHit; import org.opensearch.search.pipeline.AbstractProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient; import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil; import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator; +import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -48,11 +54,16 @@ @Log4j2 public class GenerativeQAResponseProcessor extends AbstractProcessor implements SearchResponseProcessor { + private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10; + // TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM. private final String llmModel; private final List contextFields; + @Setter + private ConversationalMemoryClient memoryClient; + @Getter @Setter // Mainly for unit testing purpose @@ -64,6 +75,7 @@ protected GenerativeQAResponseProcessor(Client client, String tag, String descri this.llmModel = llmModel; this.contextFields = contextFields; this.llm = llm; + this.memoryClient = new ConversationalMemoryClient(client); } @Override @@ -71,16 +83,23 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp log.info("Entering processResponse."); - List chatHistory = getChatHistory(request); GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request); String llmQuestion = params.getLlmQuestion(); String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel(); String conversationId = params.getConversationId(); log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId); + List chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, DEFAULT_CHAT_HISTORY_WINDOW); + List searchResults = getSearchResults(response); + ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, searchResults)); + String answer = (String) output.getAnswers().get(0); + + String interactionId = null; + if (conversationId != null) { + interactionId = memoryClient.createInteraction(conversationId, llmQuestion, PromptUtil.DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE, answer, + GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, jsonArrayToString(searchResults)); + } - ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, getSearchResults(response))); - - return insertAnswer(response, (String) output.getAnswers().get(0)); + return insertAnswer(response, answer, interactionId); } @Override @@ -88,16 +107,14 @@ public String getType() { return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE; } - private SearchResponse insertAnswer(SearchResponse response, String answer) { + private SearchResponse insertAnswer(SearchResponse response, String answer, String interactionId) { + + // TODO return the interaction id in the response. + return new GenerativeSearchResponse(answer, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(), response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters()); } - // TODO Integrate with Conversational Memory - private List getChatHistory(SearchRequest request) { - return new ArrayList<>(); - } - private List getSearchResults(SearchResponse response) { List searchResults = new ArrayList<>(); for (SearchHit hit : response.getHits().getHits()) { @@ -115,6 +132,12 @@ private List getSearchResults(SearchResponse response) { return searchResults; } + private static String jsonArrayToString(List listOfStrings) { + JsonArray array = new JsonArray(listOfStrings.size()); + listOfStrings.forEach(array::add); + return array.toString(); + } + public static final class Factory implements Processor.Factory { private final Client client; diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java new file mode 100644 index 0000000000..724eef0b48 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java @@ -0,0 +1,99 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.client; + +import com.google.common.base.Preconditions; +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionRequest; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; + +import java.util.ArrayList; +import java.util.List; + +/** + * An OpenSearch client wrapper for conversational memory related calls. + */ +@Log4j2 +@AllArgsConstructor +public class ConversationalMemoryClient { + + private final static Logger logger = LogManager.getLogger(); + + private Client client; + + public String createConversation(String name) { + + CreateConversationResponse response = client.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name)).actionGet(); + + return response.getId(); + } + + public String createInteraction(String conversationId, String input, String promptTemplate, String response, String origin, String additionalInfo) { + Preconditions.checkNotNull(conversationId); + Preconditions.checkNotNull(input); + Preconditions.checkNotNull(response); + CreateInteractionResponse res = client.execute(CreateInteractionAction.INSTANCE, + new CreateInteractionRequest(conversationId, input, promptTemplate, response, origin, additionalInfo)).actionGet(); + log.info("createInteraction: interactionId: {}", res.getId()); + return res.getId(); + } + + public List getInteractions(String conversationId, int lastN) { + + log.info("In getInteractions, conversationId {}, lastN {}", conversationId, lastN); + + List interactions = new ArrayList<>(); + int from = 0; + boolean done = false; + int maxResults = lastN; + do { + GetInteractionsResponse response = + client.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, maxResults, from)).actionGet(); + List list = response.getInteractions(); + if (list != null && !list.isEmpty()) { + interactions.addAll(list); + from += list.size(); + maxResults -= list.size(); + log.info("Interactions: {}, from: {}, maxResults: {}", interactions, from, maxResults); + } else if (response.hasMorePages()) { + // If we didn't get any results back, we ignore this flag and break out of the loop + // to avoid an infinite loop. + // But in the future, we may support this mode, e.g. DynamoDB. + break; + } + log.info("Interactions: {}, from: {}, maxResults: {}", interactions, from, maxResults); + done = !response.hasMorePages(); + } while (from < lastN && !done); + + return interactions; + } + + +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java index fc4d5a2222..8a6ee8cc65 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java @@ -61,7 +61,7 @@ public boolean equals(Object obj) { return false; } - return Objects.equals(this.getParams(), ((GenerativeQAParamExtBuilder) obj).getParams()); + return this.params.equals(((GenerativeQAParamExtBuilder) obj).getParams()); } @Override @@ -76,7 +76,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return builder.value(params); + return builder.value(this.params); } public static GenerativeQAParamExtBuilder parse(XContentParser parser) throws IOException { diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java index b1ea2f9706..faf80b9d7a 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java @@ -21,6 +21,7 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; +import org.opensearch.ml.common.conversation.Interaction; import java.util.List; @@ -35,6 +36,6 @@ public class ChatCompletionInput { private String model; private String question; - private List chatHistory; + private List chatHistory; private List contexts; } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java index 456faafe1c..58a3cad64c 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java @@ -57,7 +57,6 @@ public DefaultLlmImpl(String openSearchModelId, Client client) { checkNotNull(openSearchModelId); this.openSearchModelId = openSearchModelId; this.mlClient = new MachineLearningInternalClient(client); - } @VisibleForTesting @@ -76,13 +75,15 @@ public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionI Map inputParameters = new HashMap<>(); inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); - inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, PromptUtil.getChatCompletionPrompt(chatCompletionInput.getQuestion(), chatCompletionInput.getChatHistory(), chatCompletionInput.getContexts())); + String messages = PromptUtil.getChatCompletionPrompt(chatCompletionInput.getQuestion(), chatCompletionInput.getChatHistory(), chatCompletionInput.getContexts()); + inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); + log.info("Messages to LLM: {}", messages); MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(inputParameters).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build(); ActionFuture future = mlClient.predict(this.openSearchModelId, mlInput); ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(); - // Response from the (remote) model + // Response from a remote model Map dataAsMap = modelOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); log.info("dataAsMap: {}", dataAsMap.toString()); diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java index badb1920bd..5d007420f7 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java @@ -17,6 +17,8 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; +import org.opensearch.ml.common.conversation.Interaction; + import java.util.List; /** @@ -24,7 +26,7 @@ */ public class LlmIOUtil { - public static ChatCompletionInput createChatCompletionInput(String llmModel, String question, List chatHistory, List contexts) { + public static ChatCompletionInput createChatCompletionInput(String llmModel, String question, List chatHistory, List contexts) { // TODO pick the right subclass based on the modelId. diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java index f7b5049bd9..45b72a41b8 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -18,57 +18,144 @@ package org.opensearch.searchpipelines.questionanswering.generative.prompt; import com.google.common.annotations.VisibleForTesting; +import com.google.gson.Gson; +import com.google.gson.JsonArray; +import com.google.gson.JsonObject; +import com.google.gson.JsonPrimitive; import lombok.AccessLevel; +import lombok.Getter; import lombok.NoArgsConstructor; import org.apache.commons.text.StringEscapeUtils; +import org.opensearch.ml.common.conversation.Interaction; +import java.util.ArrayList; import java.util.List; -import java.util.Locale; /** + * A utility class for producing prompts for LLMs. + * * TODO Should prompt engineering llm-specific? * */ @NoArgsConstructor(access = AccessLevel.PRIVATE) public class PromptUtil { + public static final String DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE = + "Generate a concise and informative answer in less than 100 words for the given question, taking into context: " + + "- An enumerated list of search results" + + "- A rephrase of the question that was used to generate the search results" + + "- The conversation history" + + "Cite search results using [${number}] notation." + + "Do not repeat yourself, and NEVER repeat anything in the chat history." + + "If there are any necessary steps or procedures in your answer, enumerate them."; + private static final String roleUser = "user"; - public static String getQuestionRephrasingPrompt(String originalQuestion, List chatHistory) { + public static String getQuestionRephrasingPrompt(String originalQuestion, List chatHistory) { return null; } - public static String getChatCompletionPrompt(String question, List chatHistory, List contexts) { + public static String getChatCompletionPrompt(String question, List chatHistory, List contexts) { return buildMessageParameter(question, chatHistory, contexts); } + enum Role { + USER("user"), + ASSISTANT("assistant"), + SYSTEM("system"); + + // TODO Add "function" + + @Getter + private String name; + + Role(String name) { + this.name = name; + } + } + @VisibleForTesting - static String buildMessageParameter(String question, List chatHistory, List contexts) { + static String buildMessageParameter(String question, List chatHistory, List contexts) { + // TODO better prompt template management is needed here. - String instructions = "Generate a concise and informative answer in less than 100 words for the given question, taking into context: " - + "- An enumerated list of search results" - + "- A rephrase of the question that was used to generate the search results" - + "- The conversation history" - + "Cite search results using [${number}] notation." - + "Do not repeat yourself, and NEVER repeat anything in the chat history." - + "If there are any necessary steps or procedures in your answer, enumerate them."; - StringBuffer sb = new StringBuffer(); - sb.append("[\n"); - sb.append(formatMessage(roleUser, instructions)); - sb.append(",\n"); + + JsonArray messageArray = new JsonArray(); + messageArray.add(new Message(Role.USER, DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE).toJson()); for (String result : contexts) { - sb.append(formatMessage(roleUser, "SEARCH RESULTS: " + result)); - sb.append(",\n"); + messageArray.add(new Message(Role.USER, "SEARCH RESULT: " + result).toJson()); + } + if (!chatHistory.isEmpty()) { + Messages.fromInteractions(chatHistory).getMessages().forEach(m -> messageArray.add(m.toJson())); + } + messageArray.add(new Message(Role.USER, "QUESTION: " + question).toJson()); + messageArray.add(new Message(Role.USER, "ANSWER:").toJson()); + + return messageArray.toString(); + } + + private static Gson gson = new Gson(); + + @Getter + static class Messages { + + @Getter + private List messages = new ArrayList<>(); + //private JsonArray jsonArray = new JsonArray(); + + public Messages(final List messages) { + addMessages(messages); + } + + public void addMessages(List messages) { + this.messages.addAll(messages); + } + + public static Messages fromInteractions(final List interactions) { + List messages = new ArrayList<>(); + + for (Interaction interaction : interactions) { + messages.add(new Message(Role.USER, interaction.getInput())); + messages.add(new Message(Role.ASSISTANT, interaction.getResponse())); + } + + return new Messages(messages); } - sb.append(formatMessage(roleUser, "QUESTION: " + question)); - sb.append(",\n"); - sb.append(formatMessage(roleUser, "ANSWER:")); - sb.append("\n"); - sb.append("]"); - return sb.toString(); } - private static String formatMessage(String role, String content) { - return String.format(Locale.ROOT, "{\"role\": \"%s\", \"content\": \"%s\"}", role, StringEscapeUtils.escapeJson(content)); + static class Message { + + private final static String MESSAGE_FIELD_ROLE = "role"; + private final static String MESSAGE_FIELD_CONTENT = "content"; + + @Getter + private Role role; + @Getter + private String content; + + private JsonObject json; + + public Message() { + json = new JsonObject(); + } + + public Message(Role role, String content) { + this(); + setRole(role); + setContent(content); + } + + public void setRole(Role role) { + json.remove(MESSAGE_FIELD_ROLE); + json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(role.getName())); + } + public void setContent(String content) { + this.content = StringEscapeUtils.escapeJson(content); + json.remove(MESSAGE_FIELD_CONTENT); + json.add(MESSAGE_FIELD_CONTENT, new JsonPrimitive(this.content)); + } + + public JsonObject toJson() { + return json; + } } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAParamUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAParamUtilTests.java new file mode 100644 index 0000000000..cbd5122371 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAParamUtilTests.java @@ -0,0 +1,39 @@ +package org.opensearch.searchpipelines.questionanswering.generative; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GenerativeQAParamUtilTests extends OpenSearchTestCase { + + public void testGenerativeQAParametersMissingParams() { + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + SearchSourceBuilder srcBulder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(srcBulder); + GenerativeQAParameters actual = GenerativeQAParamUtil.getGenerativeQAParameters(request); + assertNull(actual); + } + + public void testMisc() { + SearchRequest request = new SearchRequest(); + assertNull(GenerativeQAParamUtil.getGenerativeQAParameters(request)); + request.source(new SearchSourceBuilder()); + assertNull(GenerativeQAParamUtil.getGenerativeQAParameters(request)); + request.source(new SearchSourceBuilder().ext(List.of())); + assertNull(GenerativeQAParamUtil.getGenerativeQAParameters(request)); + + SearchExtBuilder extBuilder = mock(SearchExtBuilder.class); + when(extBuilder.getWriteableName()).thenReturn("foo"); + request.source(new SearchSourceBuilder().ext(List.of(extBuilder))); + assertNull(GenerativeQAParamUtil.getGenerativeQAParameters(request)); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java index c5dcb40266..cdfce4421f 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java @@ -18,16 +18,12 @@ package org.opensearch.searchpipelines.questionanswering.generative; import org.opensearch.action.search.SearchRequest; -import org.opensearch.client.Client; import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.test.OpenSearchTestCase; -import java.util.Collections; import java.util.HashMap; import java.util.Map; -import static org.mockito.Mockito.mock; - public class GenerativeQARequestProcessorTests extends OpenSearchTestCase { public void testProcessorFactory() throws Exception { @@ -45,4 +41,9 @@ public void testProcessRequest() throws Exception { SearchRequest processed = processor.processRequest(request); assertEquals(request, processed); } + + public void testGetType() { + GenerativeQARequestProcessor processor = new GenerativeQARequestProcessor("tag", "desc", false, "foo"); + assertEquals(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, processor.getType()); + } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java index 02ba81af06..98ec14d59e 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java @@ -25,21 +25,25 @@ import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient; +import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput; import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; -import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput; import org.opensearch.test.OpenSearchTestCase; +import java.time.Instant; import java.util.HashMap; import java.util.List; import java.util.Map; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -57,6 +61,13 @@ public void testProcessorFactoryRemoteModel() throws Exception { assertNotNull(processor); } + public void testGetType() { + Client client = mock(Client.class); + Llm llm = mock(Llm.class); + GenerativeQAResponseProcessor processor = new GenerativeQAResponseProcessor(client, null, null, false, llm, "foo", List.of("text")); + assertEquals(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, processor.getType()); + } + public void testProcessResponseNoSearchHits() throws Exception { Client client = mock(Client.class); Map config = new HashMap<>(); @@ -109,8 +120,12 @@ public void testProcessResponse() throws Exception { GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client) .create(null, "tag", "desc", true, config, null); - SearchRequest request = new SearchRequest(); // mock(SearchRequest.class); - SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // mock(SearchSourceBuilder.class); + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); + when(memoryClient.getInteractions(any(), anyInt())).thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); + processor.setMemoryClient(memoryClient); + + SearchRequest request = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind."); GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); extBuilder.setParams(params); @@ -150,4 +165,59 @@ public void testProcessResponse() throws Exception { assertEquals("passage1", passages.get(1)); assertTrue(res instanceof GenerativeSearchResponse); } + + public void testProcessResponseMissingContextField() throws Exception { + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client) + .create(null, "tag", "desc", true, config, null); + + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); + when(memoryClient.getInteractions(any(), anyInt())).thenReturn(List.of(new Interaction("0", Instant.now(), "1", "question", "", "answer", "foo", "{}"))); + processor.setMemoryClient(memoryClient); + + SearchRequest request = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind."); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + XContentBuilder sourceContent = JsonXContent.contentBuilder() + .startObject() + .field("_id", String.valueOf(i)) + //.field("text", "passage" + i) + .field("title", "This is the title for document " + i) + .endObject(); + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + ChatCompletionOutput output = mock(ChatCompletionOutput.class); + when(llm.doChatCompletion(any())).thenReturn(output); + when(output.getAnswers()).thenReturn(List.of("foo")); + processor.setLlm(llm); + + boolean exceptionThrown = false; + + try { + SearchResponse res = processor.processResponse(request, response); + } catch (Exception e) { + exceptionThrown = true; + } + + assertTrue(exceptionThrown); + } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponseTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponseTests.java new file mode 100644 index 0000000000..cead38b0a0 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponseTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentGenerator; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.io.OutputStream; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class GenerativeSearchResponseTests extends OpenSearchTestCase { + + public void testToXContent() throws IOException { + String answer = "answer"; + SearchResponseSections internal = new SearchResponseSections(new SearchHits(new SearchHit[0], null, 0), null, null, false, false, null, 0); + GenerativeSearchResponse searchResponse = new GenerativeSearchResponse(answer, internal, null, 0, 0, 0, 0, new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY); + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + XContentBuilder actual = searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(actual); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java new file mode 100644 index 0000000000..67038d93cd --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClientTests.java @@ -0,0 +1,174 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.client; + +import org.mockito.ArgumentCaptor; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.ml.common.conversation.Interaction; +import org.opensearch.ml.memory.action.conversation.CreateConversationAction; +import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; +import org.opensearch.ml.memory.action.conversation.CreateConversationResponse; +import org.opensearch.ml.memory.action.conversation.CreateInteractionAction; +import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse; +import org.opensearch.ml.memory.action.conversation.GetInteractionsAction; +import org.opensearch.ml.memory.action.conversation.GetInteractionsRequest; +import org.opensearch.ml.memory.action.conversation.GetInteractionsResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.stream.IntStream; + +import static org.mockito.Mockito.*; + +public class ConversationalMemoryClientTests extends OpenSearchTestCase { + + public void testCreateConversation() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + ArgumentCaptor captor = ArgumentCaptor.forClass(CreateConversationRequest.class); + String conversationId = UUID.randomUUID().toString(); + CreateConversationResponse response = new CreateConversationResponse(conversationId); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(response); + when(client.execute(eq(CreateConversationAction.INSTANCE), any())).thenReturn(future); + String name = "foo"; + String actual = memoryClient.createConversation(name); + verify(client, times(1)).execute(eq(CreateConversationAction.INSTANCE), captor.capture()); + assertEquals(name, captor.getValue().getName()); + assertEquals(conversationId, actual); + } + + public void testGetInteractionsNoPagination() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + int lastN = 5; + String conversationId = UUID.randomUUID().toString(); + List interactions = new ArrayList<>(); + IntStream.range(0, lastN).forEach(i -> interactions.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + GetInteractionsResponse response = new GetInteractionsResponse(interactions, lastN, false); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(response); + when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); + + List actual = memoryClient.getInteractions(conversationId, lastN); + verify(client, times(1)).execute(eq(GetInteractionsAction.INSTANCE), captor.capture()); + GetInteractionsRequest actualRequest = captor.getValue(); + assertEquals(lastN, actual.size()); + assertEquals(conversationId, actualRequest.getConversationId()); + assertEquals(lastN, actualRequest.getMaxResults()); + assertEquals(0, actualRequest.getFrom()); + } + + public void testGetInteractionsWithPagination() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + int lastN = 5; + String conversationId = UUID.randomUUID().toString(); + List firstPage = new ArrayList<>(); + IntStream.range(0, lastN).forEach(i -> firstPage.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + GetInteractionsResponse response1 = new GetInteractionsResponse(firstPage, lastN, true); + List secondPage = new ArrayList<>(); + IntStream.range(0, lastN).forEach(i -> secondPage.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + GetInteractionsResponse response2 = new GetInteractionsResponse(secondPage, lastN, false); + ActionFuture future1 = mock(ActionFuture.class); + when(future1.actionGet()).thenReturn(response1); + ActionFuture future2 = mock(ActionFuture.class); + when(future2.actionGet()).thenReturn(response2); + when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future1).thenReturn(future2); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); + + List actual = memoryClient.getInteractions(conversationId, 2*lastN); + // Called twice + verify(client, times(2)).execute(eq(GetInteractionsAction.INSTANCE), captor.capture()); + List actualRequests = captor.getAllValues(); + assertEquals(2*lastN, actual.size()); + assertEquals(conversationId, actualRequests.get(0).getConversationId()); + assertEquals(2*lastN, actualRequests.get(0).getMaxResults()); + assertEquals(0, actualRequests.get(0).getFrom()); + assertEquals(lastN, actualRequests.get(1).getFrom()); + } + + public void testGetInteractionsNoMoreResults() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + int lastN = 5; + int found = lastN - 1; + String conversationId = UUID.randomUUID().toString(); + List interactions = new ArrayList<>(); + // Return fewer results than requested + IntStream.range(0, found).forEach(i -> interactions.add(new Interaction(Integer.toString(i), Instant.now(), conversationId, "foo", "bar", "x", "y", null))); + GetInteractionsResponse response = new GetInteractionsResponse(interactions, found, false); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(response); + when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); + ArgumentCaptor captor = ArgumentCaptor.forClass(GetInteractionsRequest.class); + + List actual = memoryClient.getInteractions(conversationId, lastN); + verify(client, times(1)).execute(eq(GetInteractionsAction.INSTANCE), captor.capture()); + GetInteractionsRequest actualRequest = captor.getValue(); + assertEquals(found, actual.size()); + assertEquals(conversationId, actualRequest.getConversationId()); + assertEquals(lastN, actualRequest.getMaxResults()); + assertEquals(0, actualRequest.getFrom()); + } + + public void testAvoidInfiniteLoop() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + GetInteractionsResponse response1 = new GetInteractionsResponse(null, 0, true); + GetInteractionsResponse response2 = new GetInteractionsResponse(List.of(), 0, true); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(response1).thenReturn(response2); + when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); + List actual = memoryClient.getInteractions("1", 10); + assertTrue(actual.isEmpty()); + actual = memoryClient.getInteractions("1", 10); + assertTrue(actual.isEmpty()); + } + + public void testNoResults() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + GetInteractionsResponse response1 = new GetInteractionsResponse(null, 0, true); + GetInteractionsResponse response2 = new GetInteractionsResponse(List.of(), 0, false); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(response1).thenReturn(response2); + when(client.execute(eq(GetInteractionsAction.INSTANCE), any())).thenReturn(future); + List actual = memoryClient.getInteractions("1", 10); + assertTrue(actual.isEmpty()); + actual = memoryClient.getInteractions("1", 10); + assertTrue(actual.isEmpty()); + } + + public void testCreateInteraction() { + Client client = mock(Client.class); + ConversationalMemoryClient memoryClient = new ConversationalMemoryClient(client); + String id = UUID.randomUUID().toString(); + CreateInteractionResponse res = new CreateInteractionResponse(id); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(res); + when(client.execute(eq(CreateInteractionAction.INSTANCE), any())).thenReturn(future); + String actual = memoryClient.createInteraction("cid", "input", "prompt", "answer", "origin", "hits"); + assertEquals(id, actual); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java index df00113f58..b05b52062c 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -17,14 +17,24 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.ext; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.StreamInput; -import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; -import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentHelper; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.test.OpenSearchTestCase; import java.io.EOFException; import java.io.IOException; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase { public void testCtor() throws IOException { @@ -68,14 +78,50 @@ public int read() throws IOException { assertNotNull(builder1); } - public void testMiscMethods() { + public void testMiscMethods() throws IOException { GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c"); GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d"); GenerativeQAParamExtBuilder builder1 = new GenerativeQAParamExtBuilder(); GenerativeQAParamExtBuilder builder2 = new GenerativeQAParamExtBuilder(); builder1.setParams(param1); builder2.setParams(param2); + assertEquals(builder1, builder1); + assertNotEquals(builder1, param1); assertNotEquals(builder1, builder2); assertNotEquals(builder1.hashCode(), builder2.hashCode()); + + StreamOutput so = mock(StreamOutput.class); + builder1.writeTo(so); + verify(so, times(2)).writeOptionalString(any()); + verify(so, times(1)).writeString(any()); + } + + public void testParse() throws IOException { + XContentParser xcParser = mock(XContentParser.class); + when(xcParser.nextToken()).thenReturn(XContentParser.Token.START_OBJECT).thenReturn(XContentParser.Token.END_OBJECT); + GenerativeQAParamExtBuilder builder = GenerativeQAParamExtBuilder.parse(xcParser); + assertNotNull(builder); + assertNotNull(builder.getParams()); + } + + public void testXContentRoundTrip() throws IOException { + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c"); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(param1); + XContentType xContentType = randomFrom(XContentType.values()); + BytesReference serialized = XContentHelper.toXContent(extBuilder, xContentType, true); + XContentParser parser = createParser(xContentType.xContent(), serialized); + GenerativeQAParamExtBuilder deserialized = GenerativeQAParamExtBuilder.parse(parser); + assertEquals(extBuilder, deserialized); + } + + public void testStreamRoundTrip() throws IOException { + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c"); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(param1); + BytesStreamOutput bso = new BytesStreamOutput(); + extBuilder.writeTo(bso); + GenerativeQAParamExtBuilder deserialized = new GenerativeQAParamExtBuilder(bso.bytes().streamInput()); + assertEquals(extBuilder, deserialized); } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java index 9811d5e768..c6cf3e9399 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java @@ -19,16 +19,10 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; -import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; -import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; import org.opensearch.test.OpenSearchTestCase; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNull; - public class GenerativeQAParamUtilTests extends OpenSearchTestCase { public void testGenerativeQAParametersMissingParams() { diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java index a18af80200..b2f9d9dc2f 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java @@ -19,16 +19,20 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentGenerator; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; -import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; -import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; +import java.io.OutputStream; import java.util.ArrayList; import java.util.List; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + public class GenerativeQAParametersTests extends OpenSearchTestCase { public void testGenerativeQAParameters() { @@ -92,4 +96,30 @@ public void testWriteTo() throws IOException { assertEquals(llmModel, actual.get(1)); assertEquals(llmQuestion, actual.get(2)); } + + public void testMisc() { + String conversationId = "a"; + String llmModel = "b"; + String llmQuestion = "c"; + GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion); + assertNotEquals(parameters, null); + assertNotEquals(parameters, "foo"); + assertEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, llmQuestion)); + assertNotEquals(parameters, new GenerativeQAParameters("", llmModel, llmQuestion)); + assertNotEquals(parameters, new GenerativeQAParameters(conversationId, "", llmQuestion)); + assertNotEquals(parameters, new GenerativeQAParameters(conversationId, llmModel, "")); + } + + public void testToXConent() throws IOException { + String conversationId = "a"; + String llmModel = "b"; + String llmQuestion = "c"; + GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion); + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + assertNotNull(parameters.toXContent(builder, null)); + } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java index 4c404cd4b5..925b84b8b1 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java @@ -17,10 +17,14 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.test.OpenSearchTestCase; +import java.time.Instant; import java.util.Collections; import java.util.List; +import java.util.Map; public class ChatCompletionInputTests extends OpenSearchTestCase { @@ -36,12 +40,16 @@ public void testCtor() { public void testGettersSetters() { String model = "model"; String question = "question"; - List history = List.of("hello"); + List history = List.of(Interaction.fromMap("1", + Map.of( + ConversationalIndexConstants.INTERACTIONS_CONVERSATION_ID_FIELD, "convo1", + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "hello"))); List contexts = List.of("result1", "result2"); ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts); assertEquals(model, input.getModel()); assertEquals(question, input.getQuestion()); - assertEquals(history.get(0), input.getChatHistory().get(0)); + assertEquals(history.get(0).getConversationId(), input.getChatHistory().get(0).getConversationId()); assertEquals(contexts.get(0), input.getContexts().get(0)); assertEquals(contexts.get(1), input.getContexts().get(1)); } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java index b5a5421fa4..0aba017245 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -24,6 +24,8 @@ import org.mockito.Mock; import org.opensearch.common.action.ActionFuture; import org.opensearch.client.Client; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; @@ -35,6 +37,7 @@ import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; import org.opensearch.test.OpenSearchTestCase; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -52,13 +55,17 @@ public void testBuildMessageParameter() { DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); String question = "Who am I"; List contexts = new ArrayList<>(); - List chatHistory = new ArrayList<>(); contexts.add("context 1"); contexts.add("context 2"); - chatHistory.add("message 1"); - chatHistory.add("message 2"); + List chatHistory = List.of(Interaction.fromMap("convo1", Map.of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer1")), + Interaction.fromMap("convo1", Map.of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 2", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer2"))); String parameter = PromptUtil.getChatCompletionPrompt(question, chatHistory, contexts); - //System.out.println(parameter); Map parameters = Map.of("model", "foo", "messages", parameter); assertTrue(isJson(parameter)); } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java index bf2ec3bf7c..5d8395126b 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java @@ -23,6 +23,10 @@ public class LlmIOUtilTests extends OpenSearchTestCase { + public void testCtor() { + assertNotNull(new LlmIOUtil()); + } + public void testChatCompletionInput() { ChatCompletionInput input = LlmIOUtil.createChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList()); assertTrue(input instanceof ChatCompletionInput); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java index 2fdbfc0fe1..dd3fed1c9d 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java @@ -20,8 +20,11 @@ import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; +import org.opensearch.ml.common.conversation.ConversationalIndexConstants; +import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.test.OpenSearchTestCase; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -36,11 +39,16 @@ public void testPromptUtilStaticMethods() { public void testBuildMessageParameter() { String question = "Who am I"; List contexts = new ArrayList<>(); - List chatHistory = new ArrayList<>(); + List chatHistory = List.of(Interaction.fromMap("convo1", Map.of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer1")), + Interaction.fromMap("convo1", Map.of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, "message 2", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, "answer2"))); contexts.add("context 1"); contexts.add("context 2"); - chatHistory.add("message 1"); - chatHistory.add("message 2"); String parameter = PromptUtil.buildMessageParameter(question, chatHistory, contexts); Map parameters = Map.of("model", "foo", "messages", parameter); assertTrue(isJson(parameter)); From 8f4d90dc0809d68200a802f1ac8862be4cfd0d04 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Fri, 1 Sep 2023 14:15:33 -0700 Subject: [PATCH 3/7] Fix spotless Signed-off-by: Austin Lee --- .../ml/plugin/MachineLearningPlugin.java | 38 ++++++++++--------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 21ba7e9dca..259f8568c4 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -344,7 +344,8 @@ public Collection createComponents( // TODO move this into MLFeatureEnabledSetting this.ragSearchPipelineEnabled = ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(clusterService.getSettings()); - clusterService.getClusterSettings() + clusterService + .getClusterSettings() .addSettingsUpdateConsumer(ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it); mlIndicesHandler = new MLIndicesHandler(clusterService, client); @@ -677,31 +678,32 @@ public List> getSettings() { @Override public List> getSearchExts() { - return ragSearchPipelineEnabled ? List - .of( - new SearchPlugin.SearchExtSpec<>( - GenerativeQAParamExtBuilder.PARAMETER_NAME, - input -> new GenerativeQAParamExtBuilder(input), - parser -> GenerativeQAParamExtBuilder.parse(parser) + return ragSearchPipelineEnabled + ? List + .of( + new SearchPlugin.SearchExtSpec<>( + GenerativeQAParamExtBuilder.PARAMETER_NAME, + input -> new GenerativeQAParamExtBuilder(input), + parser -> GenerativeQAParamExtBuilder.parse(parser) + ) ) - ) - // Feature not enabled - : Collections.emptyList(); + // Feature not enabled + : Collections.emptyList(); } @Override public Map> getRequestProcessors(Parameters parameters) { - return ragSearchPipelineEnabled ? - Map.of(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory()) - // Feature not enabled - : Collections.emptyMap(); + return ragSearchPipelineEnabled + ? Map.of(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory()) + // Feature not enabled + : Collections.emptyMap(); } @Override public Map> getResponseProcessors(Parameters parameters) { - return ragSearchPipelineEnabled ? - Map.of(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client)) - // Feature not enabled - : Collections.emptyMap(); + return ragSearchPipelineEnabled + ? Map.of(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client)) + // Feature not enabled + : Collections.emptyMap(); } } From a1df6a193a9b05a2b672432327bdac409438c616 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Mon, 4 Sep 2023 08:53:52 -0700 Subject: [PATCH 4/7] Fix RAG feature flag enablement. Signed-off-by: Austin Lee --- .../ml/plugin/MachineLearningPlugin.java | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 259f8568c4..a7c6cf1d65 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -7,7 +7,6 @@ import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; -import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED; import java.nio.file.Path; import java.util.Collection; @@ -251,6 +250,10 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc private volatile boolean ragSearchPipelineEnabled; + public MachineLearningPlugin(Settings settings) { + this.ragSearchPipelineEnabled = MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings); + } + @Override public List> getActions() { return ImmutableList @@ -342,12 +345,6 @@ public Collection createComponents( stats.put(MLNodeLevelStat.ML_NODE_TOTAL_CIRCUIT_BREAKER_TRIGGER_COUNT, new MLStat<>(false, new CounterSupplier())); this.mlStats = new MLStats(stats); - // TODO move this into MLFeatureEnabledSetting - this.ragSearchPipelineEnabled = ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(clusterService.getSettings()); - clusterService - .getClusterSettings() - .addSettingsUpdateConsumer(ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it); - mlIndicesHandler = new MLIndicesHandler(clusterService, client); mlTaskManager = new MLTaskManager(client, threadPool, mlIndicesHandler); modelHelper = new ModelHelper(mlEngine); @@ -459,6 +456,11 @@ public Collection createComponents( encryptor ); + // TODO move this into MLFeatureEnabledSetting + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it); + return ImmutableList .of( encryptor, @@ -665,7 +667,7 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED, MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX, MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED, - ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED + MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED ); return settings; } From e076c92736c7aeca2719308aba17cf1b918e757b Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Tue, 5 Sep 2023 12:34:54 -0700 Subject: [PATCH 5/7] Address review comments and suggestions. Signed-off-by: Austin Lee --- .../generative/client/ConversationalMemoryClient.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java index 724eef0b48..d95461bcd2 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java @@ -23,6 +23,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; +import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.ml.memory.action.conversation.CreateConversationAction; import org.opensearch.ml.memory.action.conversation.CreateConversationRequest; @@ -51,7 +52,7 @@ public class ConversationalMemoryClient { public String createConversation(String name) { CreateConversationResponse response = client.execute(CreateConversationAction.INSTANCE, new CreateConversationRequest(name)).actionGet(); - + log.info("createConversation: id: {}", response.getId()); return response.getId(); } @@ -67,6 +68,8 @@ public String createInteraction(String conversationId, String input, String prom public List getInteractions(String conversationId, int lastN) { + Preconditions.checkArgument(lastN > 0, "lastN must be at least 1."); + log.info("In getInteractions, conversationId {}, lastN {}", conversationId, lastN); List interactions = new ArrayList<>(); @@ -77,7 +80,7 @@ public List getInteractions(String conversationId, int lastN) { GetInteractionsResponse response = client.execute(GetInteractionsAction.INSTANCE, new GetInteractionsRequest(conversationId, maxResults, from)).actionGet(); List list = response.getInteractions(); - if (list != null && !list.isEmpty()) { + if (list != null && !CollectionUtils.isEmpty(list)) { interactions.addAll(list); from += list.size(); maxResults -= list.size(); From 8cbac4cd86d795d8ad3e0ea1bcab0cf258d8f55e Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Tue, 5 Sep 2023 14:21:53 -0700 Subject: [PATCH 6/7] Address comments. Signed-off-by: Austin Lee --- .../ml/plugin/MachineLearningPlugin.java | 41 ++++++++++++------- .../client/ConversationalMemoryClient.java | 6 +-- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index a7c6cf1d65..5587099797 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -9,8 +9,9 @@ import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX; import java.nio.file.Path; +import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -680,32 +681,42 @@ public List> getSettings() { @Override public List> getSearchExts() { - return ragSearchPipelineEnabled - ? List - .of( + List> searchExts = new ArrayList<>(); + + if (ragSearchPipelineEnabled) { + searchExts + .add( new SearchPlugin.SearchExtSpec<>( GenerativeQAParamExtBuilder.PARAMETER_NAME, input -> new GenerativeQAParamExtBuilder(input), parser -> GenerativeQAParamExtBuilder.parse(parser) ) - ) - // Feature not enabled - : Collections.emptyList(); + ); + } + + return searchExts; } @Override public Map> getRequestProcessors(Parameters parameters) { - return ragSearchPipelineEnabled - ? Map.of(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory()) - // Feature not enabled - : Collections.emptyMap(); + Map> requestProcessors = new HashMap<>(); + + if (ragSearchPipelineEnabled) { + requestProcessors.put(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory()); + } + + return requestProcessors; } @Override public Map> getResponseProcessors(Parameters parameters) { - return ragSearchPipelineEnabled - ? Map.of(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client)) - // Feature not enabled - : Collections.emptyMap(); + Map> responseProcessors = new HashMap<>(); + + if (ragSearchPipelineEnabled) { + responseProcessors + .put(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client)); + } + + return responseProcessors; } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java index d95461bcd2..84a32b2368 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/ConversationalMemoryClient.java @@ -74,7 +74,7 @@ public List getInteractions(String conversationId, int lastN) { List interactions = new ArrayList<>(); int from = 0; - boolean done = false; + boolean allInteractionsFetched = false; int maxResults = lastN; do { GetInteractionsResponse response = @@ -92,8 +92,8 @@ public List getInteractions(String conversationId, int lastN) { break; } log.info("Interactions: {}, from: {}, maxResults: {}", interactions, from, maxResults); - done = !response.hasMorePages(); - } while (from < lastN && !done); + allInteractionsFetched = !response.hasMorePages(); + } while (from < lastN && !allInteractionsFetched); return interactions; } From 739d96a8b558234f573413b8e03a0cb86ef3314f Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Tue, 5 Sep 2023 15:34:20 -0700 Subject: [PATCH 7/7] Add unit tests for MachineLearningPlugin Signed-off-by: Austin Lee --- .../ml/plugin/MachineLearningPluginTests.java | 105 ++++++++++++++++++ .../generative/prompt/PromptUtil.java | 26 ++--- 2 files changed, 118 insertions(+), 13 deletions(-) create mode 100644 plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java diff --git a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java new file mode 100644 index 0000000000..b810f7c439 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java @@ -0,0 +1,105 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.ml.plugin; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +import java.util.List; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.common.settings.Settings; +import org.opensearch.plugins.SearchPipelinePlugin; +import org.opensearch.plugins.SearchPlugin; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; + +public class MachineLearningPluginTests { + + @Test + public void testGetSearchExtsFeatureDisabled() { + Settings settings = Settings.builder().build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + List> searchExts = plugin.getSearchExts(); + assertEquals(0, searchExts.size()); + } + + @Test + public void testGetSearchExtsFeatureDisabledExplicit() { + Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "false").build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + List> searchExts = plugin.getSearchExts(); + assertEquals(0, searchExts.size()); + } + + @Test + public void testGetRequestProcessorsFeatureDisabled() { + Settings settings = Settings.builder().build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map requestProcessors = plugin.getRequestProcessors(parameters); + assertEquals(0, requestProcessors.size()); + } + + @Test + public void testGetResponseProcessorsFeatureDisabled() { + Settings settings = Settings.builder().build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map responseProcessors = plugin.getResponseProcessors(parameters); + assertEquals(0, responseProcessors.size()); + } + + @Test + public void testGetSearchExts() { + Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + List> searchExts = plugin.getSearchExts(); + assertEquals(1, searchExts.size()); + SearchPlugin.SearchExtSpec spec = searchExts.get(0); + assertEquals(GenerativeQAParamExtBuilder.PARAMETER_NAME, spec.getName().getPreferredName()); + } + + @Test + public void testGetRequestProcessors() { + Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map requestProcessors = plugin.getRequestProcessors(parameters); + assertEquals(1, requestProcessors.size()); + assertTrue( + requestProcessors.get(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE) instanceof GenerativeQARequestProcessor.Factory + ); + } + + @Test + public void testGetResponseProcessors() { + Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build(); + MachineLearningPlugin plugin = new MachineLearningPlugin(settings); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map responseProcessors = plugin.getResponseProcessors(parameters); + assertEquals(1, responseProcessors.size()); + assertTrue( + responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory + ); + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java index 45b72a41b8..10e5a924c6 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -59,7 +59,7 @@ public static String getChatCompletionPrompt(String question, List return buildMessageParameter(question, chatHistory, contexts); } - enum Role { + enum ChatRole { USER("user"), ASSISTANT("assistant"), SYSTEM("system"); @@ -69,7 +69,7 @@ enum Role { @Getter private String name; - Role(String name) { + ChatRole(String name) { this.name = name; } } @@ -80,15 +80,15 @@ static String buildMessageParameter(String question, List chatHisto // TODO better prompt template management is needed here. JsonArray messageArray = new JsonArray(); - messageArray.add(new Message(Role.USER, DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE).toJson()); + messageArray.add(new Message(ChatRole.USER, DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE).toJson()); for (String result : contexts) { - messageArray.add(new Message(Role.USER, "SEARCH RESULT: " + result).toJson()); + messageArray.add(new Message(ChatRole.USER, "SEARCH RESULT: " + result).toJson()); } if (!chatHistory.isEmpty()) { Messages.fromInteractions(chatHistory).getMessages().forEach(m -> messageArray.add(m.toJson())); } - messageArray.add(new Message(Role.USER, "QUESTION: " + question).toJson()); - messageArray.add(new Message(Role.USER, "ANSWER:").toJson()); + messageArray.add(new Message(ChatRole.USER, "QUESTION: " + question).toJson()); + messageArray.add(new Message(ChatRole.USER, "ANSWER:").toJson()); return messageArray.toString(); } @@ -114,8 +114,8 @@ public static Messages fromInteractions(final List interactions) { List messages = new ArrayList<>(); for (Interaction interaction : interactions) { - messages.add(new Message(Role.USER, interaction.getInput())); - messages.add(new Message(Role.ASSISTANT, interaction.getResponse())); + messages.add(new Message(ChatRole.USER, interaction.getInput())); + messages.add(new Message(ChatRole.ASSISTANT, interaction.getResponse())); } return new Messages(messages); @@ -128,7 +128,7 @@ static class Message { private final static String MESSAGE_FIELD_CONTENT = "content"; @Getter - private Role role; + private ChatRole chatRole; @Getter private String content; @@ -138,15 +138,15 @@ public Message() { json = new JsonObject(); } - public Message(Role role, String content) { + public Message(ChatRole chatRole, String content) { this(); - setRole(role); + setChatRole(chatRole); setContent(content); } - public void setRole(Role role) { + public void setChatRole(ChatRole chatRole) { json.remove(MESSAGE_FIELD_ROLE); - json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(role.getName())); + json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(chatRole.getName())); } public void setContent(String content) { this.content = StringEscapeUtils.escapeJson(content);