From 0c3999337fa8c1a7b6934d1541a1364e3f21ddb9 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Fri, 1 Sep 2023 12:24:01 -0700 Subject: [PATCH] Use Search Pipeline processors, Remote Inference and HttpConnector to enable Retrieval Augmented Generation (RAG) (#1195) * Use Search Pipeline processors, Remote Inference and HttpConnector to enable Retrieval Augmented Generation (RAG) (https://github.com/opensearch-project/ml-commons/issues/1150) Signed-off-by: Austin Lee * Address test coverage. Signed-off-by: Austin Lee * Fix/update imports due to changes coming from core. Signed-off-by: Austin Lee * Update license header. Signed-off-by: Austin Lee * Address comments. Signed-off-by: Austin Lee * Use List for context fields so we can pull contexts from multiple fields when constructing contexts for LLMs. Signed-off-by: Austin Lee * Address review comments. Signed-off-by: Austin Lee * Fix spotless issue. Signed-off-by: Austin Lee * Update README. Signed-off-by: Austin Lee * Fix ml-client shadowJar implicit dependency issue. Signed-off-by: Austin Lee * Add a wrapper client for ML predict. Signed-off-by: Austin Lee * Add tests for the internal ML client. 
Signed-off-by: Austin Lee --------- Signed-off-by: Austin Lee Signed-off-by: Austin Lee --- plugin/build.gradle | 1 + .../ml/plugin/MachineLearningPlugin.java | 33 +++- search-processors/README.md | 95 +++++++++++ search-processors/build.gradle | 74 +++++++++ .../GenerativeQAProcessorConstants.java | 36 +++++ .../GenerativeQARequestProcessor.java | 69 ++++++++ .../GenerativeQAResponseProcessor.java | 145 +++++++++++++++++ .../generative/GenerativeSearchResponse.java | 66 ++++++++ .../client/MachineLearningInternalClient.java | 80 +++++++++ .../ext/GenerativeQAParamExtBuilder.java | 88 ++++++++++ .../generative/ext/GenerativeQAParamUtil.java | 49 ++++++ .../ext/GenerativeQAParameters.java | 110 +++++++++++++ .../generative/llm/ChatCompletionInput.java | 40 +++++ .../generative/llm/ChatCompletionOutput.java | 37 +++++ .../generative/llm/DefaultLlmImpl.java | 99 ++++++++++++ .../questionanswering/generative/llm/Llm.java | 26 +++ .../generative/llm/LlmIOUtil.java | 33 ++++ .../generative/llm/ModelLocator.java | 36 +++++ .../generative/prompt/PromptUtil.java | 74 +++++++++ .../GenerativeQARequestProcessorTests.java | 48 ++++++ .../GenerativeQAResponseProcessorTests.java | 153 ++++++++++++++++++ .../MachineLearningInternalClientTests.java | 105 ++++++++++++ .../ext/GenerativeQAParamExtBuilderTests.java | 81 ++++++++++ .../ext/GenerativeQAParamUtilTests.java | 41 +++++ .../ext/GenerativeQAParametersTests.java | 95 +++++++++++ .../llm/ChatCompletionInputTests.java | 48 ++++++ .../llm/ChatCompletionOutputTests.java | 36 +++++ .../generative/llm/DefaultLlmImplTests.java | 99 ++++++++++++ .../generative/llm/LlmIOUtilTests.java | 30 ++++ .../generative/llm/ModelLocatorTests.java | 32 ++++ .../generative/prompt/PromptUtilTests.java | 61 +++++++ settings.gradle | 4 + 32 files changed, 2023 insertions(+), 1 deletion(-) create mode 100644 search-processors/README.md create mode 100644 search-processors/build.gradle create mode 100644 
search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessor.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtil.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java create mode 100644 search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocator.java create mode 100644 
search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClientTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocatorTests.java create mode 100644 search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java diff --git a/plugin/build.gradle b/plugin/build.gradle index cd6e9cbd3f..8b885a7f96 100644 --- a/plugin/build.gradle +++ 
b/plugin/build.gradle @@ -44,6 +44,7 @@ opensearchplugin { dependencies { implementation project(':opensearch-ml-common') implementation project(':opensearch-ml-algorithms') + implementation project(':opensearch-ml-search-processors') implementation project(':opensearch-ml-memory') implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index cbaedb45fd..70b827261c 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -184,10 +184,19 @@ import org.opensearch.monitor.os.OsService; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; +import org.opensearch.plugins.SearchPipelinePlugin; +import org.opensearch.plugins.SearchPlugin; import org.opensearch.repositories.RepositoriesService; import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; import org.opensearch.script.ScriptService; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchRequestProcessor; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; import org.opensearch.threadpool.ExecutorBuilder; import org.opensearch.threadpool.FixedExecutorBuilder; import org.opensearch.threadpool.ThreadPool; @@ -197,7 +206,7 @@ import lombok.SneakyThrows; -public class MachineLearningPlugin extends Plugin implements 
ActionPlugin { +public class MachineLearningPlugin extends Plugin implements ActionPlugin, SearchPlugin, SearchPipelinePlugin { public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons."; public static final String GENERAL_THREAD_POOL = "opensearch_ml_general"; public static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute"; @@ -649,4 +658,26 @@ public List> getSettings() { ); return settings; } + + @Override + public List> getSearchExts() { + return List + .of( + new SearchPlugin.SearchExtSpec<>( + GenerativeQAParamExtBuilder.PARAMETER_NAME, + input -> new GenerativeQAParamExtBuilder(input), + parser -> GenerativeQAParamExtBuilder.parse(parser) + ) + ); + } + + @Override + public Map> getRequestProcessors(Parameters parameters) { + return Map.of(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory()); + } + + @Override + public Map> getResponseProcessors(Parameters parameters) { + return Map.of(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client)); + } } diff --git a/search-processors/README.md b/search-processors/README.md new file mode 100644 index 0000000000..2b3dc6ed52 --- /dev/null +++ b/search-processors/README.md @@ -0,0 +1,95 @@ +# conversational-search-processors +OpenSearch search processors providing conversational search capabilities +======= +# Plugin for Conversations Using Search Processors in OpenSearch +This repo is a WIP plugin for handling conversations in OpenSearch ([Per this RFC](https://github.com/opensearch-project/ml-commons/issues/1150)). + +Conversational Retrieval Augmented Generation (RAG) is implemented via Search processors that combine user questions and OpenSearch query results as input to an LLM, e.g. OpenAI, and return answers. 
+ +## Creating a search pipeline with the GenerativeQAResponseProcessor + +``` +PUT /_search/pipeline/ +{ + "response_processors": [ + { + "retrieval_augmented_generation": { + "tag": , + "description": , + "model_id": "", + "context_field_list": [] (e.g. ["text"]) + } + } + ] +} +``` + +The 'model_id' parameter here needs to refer to a model of type REMOTE that has an HttpConnector instance associated with it. + +## Making a search request against an index using the above processor +``` +GET //_search\?search_pipeline\= +{ + "_source": ["title", "text"], + "query" : { + "neural": { + "text_vector": { + "query_text": , + "k": (e.g. 10), + "model_id": + } + } + }, + "ext": { + "generative_qa_parameters": { + "llm_model": (e.g. "gpt-3.5-turbo"), + "llm_question": + } + } +} +``` + +## Retrieval Augmented Generation response +``` +{ + "took": 3, + "timed_out": false, + "_shards": { + "total": 3, + "successful": 3, + "skipped": 0, + "failed": 0 + }, + "hits": { + "total": { + "value": 110, + "relation": "eq" + }, + "max_score": 0.55129033, + "hits": [ + { + "_index": "...", + "_id": "...", + "_score": 0.55129033, + "_source": { + "text": "...", + "title": "..." + } + }, + { + ... + } + ... + { + ... + } + ] + }, // end of hits + "ext": { + "retrieval_augmented_generation": { + "answer": "..." + } + } +} +``` +The RAG answer is returned as an "ext" to SearchResponse following the "hits" array. diff --git a/search-processors/build.gradle b/search-processors/build.gradle new file mode 100644 index 0000000000..4371f37e15 --- /dev/null +++ b/search-processors/build.gradle @@ -0,0 +1,74 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +plugins { + id 'java' + id 'jacoco' + id "io.freefair.lombok" +} + +repositories { + mavenCentral() + mavenLocal() +} + +dependencies { + + compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}" + implementation 'org.apache.commons:commons-lang3:3.12.0' + //implementation project(':opensearch-ml-client') + implementation project(':opensearch-ml-common') + implementation group: 'org.opensearch', name: 'common-utils', version: "${common_utils_version}" + // https://mvnrepository.com/artifact/org.apache.httpcomponents.core5/httpcore5 + implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.1' + implementation("com.google.guava:guava:32.0.1-jre") + implementation group: 'org.json', name: 'json', version: '20230227' + implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' + testImplementation "org.opensearch.test:framework:${opensearch_version}" +} + +test { + include '**/*Tests.class' + systemProperty 'tests.security.manager', 'false' +} + +jacocoTestReport { + dependsOn /*integTest,*/ test + reports { + xml.required = true + html.required = true + } +} + +jacocoTestCoverageVerification { + violationRules { + rule { + limit { + counter = 'LINE' + minimum = 0.65 //TODO: increase coverage to 0.90 + } + limit { + counter = 'BRANCH' + minimum = 0.55 //TODO: increase coverage to 0.85 + } + } + } + dependsOn jacocoTestReport +} + +check.dependsOn jacocoTestCoverageVerification +//jacocoTestCoverageVerification.dependsOn jacocoTestReport diff --git 
a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java
new file mode 100644
index 0000000000..e04131afb5
--- /dev/null
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2023 Aryn
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.opensearch.searchpipelines.questionanswering.generative;
+
+/**
+ * String constants shared by the generative QA search pipeline processors:
+ * the processor type identifiers and the names of their configuration properties.
+ */
+public class GenerativeQAProcessorConstants {
+
+    // Identifier for the generative QA request processor
+    public static final String REQUEST_PROCESSOR_TYPE = "question_rewrite";
+
+    // Identifier for the generative QA response processor
+    public static final String RESPONSE_PROCESSOR_TYPE = "retrieval_augmented_generation";
+
+    // The model_id of the model registered and deployed in OpenSearch.
+    public static final String CONFIG_NAME_MODEL_ID = "model_id";
+
+    // The name of the model supported by an LLM, e.g. "gpt-3.5" in OpenAI.
+    public static final String CONFIG_NAME_LLM_MODEL = "llm_model";
+
+    // The fields in search results that contain the context to be sent to the LLM.
+    public static final String CONFIG_NAME_CONTEXT_FIELD_LIST = "context_field_list";
+}
diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessor.java
new file mode 100644
index 0000000000..ef03b2326f
--- /dev/null
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessor.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright 2023 Aryn
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.opensearch.searchpipelines.questionanswering.generative;
+
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.ingest.ConfigurationUtils;
+import org.opensearch.search.pipeline.AbstractProcessor;
+import org.opensearch.search.pipeline.Processor;
+import org.opensearch.search.pipeline.SearchRequestProcessor;
+
+import java.util.Map;
+
+/**
+ * Defines the request processor for generative QA search pipelines.
+ *
+ * <p>Currently a pass-through: the request is returned unchanged (see the TODO in
+ * {@link #processRequest(SearchRequest)}).
+ */
+public class GenerativeQARequestProcessor extends AbstractProcessor implements SearchRequestProcessor {
+
+    // Value of the "model_id" processor setting; stored but not yet used by processRequest.
+    private final String modelId;
+
+    protected GenerativeQARequestProcessor(String tag, String description, boolean ignoreFailure, String modelId) {
+        super(tag, description, ignoreFailure);
+        this.modelId = modelId;
+    }
+
+    /**
+     * Returns the incoming request unchanged.
+     */
+    @Override
+    public SearchRequest processRequest(SearchRequest request) throws Exception {
+
+        // TODO Use chat history to rephrase the question with full conversation context.
+
+        return request;
+    }
+
+    @Override
+    public String getType() {
+        return GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE;
+    }
+
+    /**
+     * Creates {@link GenerativeQARequestProcessor} instances from pipeline configuration.
+     */
+    public static final class Factory implements Processor.Factory<SearchRequestProcessor> {
+
+        /**
+         * Builds a request processor, reading the "model_id" string property from {@code config}.
+         */
+        @Override
+        public SearchRequestProcessor create(
+            Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
+            String tag,
+            String description,
+            boolean ignoreFailure,
+            Map<String, Object> config,
+            PipelineContext pipelineContext
+        ) throws Exception {
+            String modelId = ConfigurationUtils.readStringProperty(
+                GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE,
+                tag,
+                config,
+                GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID
+            );
+            return new GenerativeQARequestProcessor(tag, description, ignoreFailure, modelId);
+        }
+    }
+}
diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java
new file mode 100644
index 0000000000..60c746f1d6
--- /dev/null
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java
@@ -0,0 +1,145 @@
+/*
+ * Copyright 2023 Aryn
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.opensearch.searchpipelines.questionanswering.generative;
+
+import lombok.Getter;
+import lombok.Setter;
+import lombok.extern.log4j.Log4j2;
+import org.opensearch.action.search.SearchRequest;
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.client.Client;
+import org.opensearch.ingest.ConfigurationUtils;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.pipeline.AbstractProcessor;
+import org.opensearch.search.pipeline.Processor;
+import org.opensearch.search.pipeline.SearchResponseProcessor;
+import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
+import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
+import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
+import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
+import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
+import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;
+
+/**
+ * Defines the response processor for generative QA search pipelines.
+ *
+ * <p>For each search response it collects the configured context fields from the hits, asks the
+ * configured {@link Llm} for a chat completion, and returns a {@link GenerativeSearchResponse}
+ * carrying the first answer.
+ */
+@Log4j2
+public class GenerativeQAResponseProcessor extends AbstractProcessor implements SearchResponseProcessor {
+
+    // TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.
+
+    // Pipeline-level default; a per-request "llm_model" takes precedence when present.
+    private final String llmModel;
+    // Names of the _source fields whose values are passed to the LLM as context.
+    private final List<String> contextFields;
+
+    @Getter
+    @Setter
+    // Mainly for unit testing purpose
+    private Llm llm;
+
+    protected GenerativeQAResponseProcessor(Client client, String tag, String description, boolean ignoreFailure,
+        Llm llm, String llmModel, List<String> contextFields) {
+        super(tag, description, ignoreFailure);
+        this.llmModel = llmModel;
+        this.contextFields = contextFields;
+        this.llm = llm;
+    }
+
+    /**
+     * Sends the question from the request's generative QA parameters, together with the context
+     * extracted from the hits, to the LLM and returns a copy of {@code response} that carries the answer.
+     *
+     * @param request  the search request; its generative QA parameters supply the question and,
+     *                 optionally, the LLM model name
+     * @param response the search response whose hits provide the context
+     * @return a {@link GenerativeSearchResponse} wrapping {@code response} plus the first LLM answer
+     * @throws RuntimeException if a hit lacks one of the configured context fields
+     */
+    @Override
+    public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
+
+        log.info("Entering processResponse.");
+
+        List<String> chatHistory = getChatHistory(request);
+        // NOTE(review): presumably null when the request has no generative QA "ext" section, which
+        // would make the next line throw a NullPointerException - confirm against GenerativeQAParamUtil.
+        GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);
+        String llmQuestion = params.getLlmQuestion();
+        String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
+        String conversationId = params.getConversationId();
+        log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId);
+
+        ChatCompletionOutput output = llm.doChatCompletion(
+            LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, getSearchResults(response))
+        );
+
+        return insertAnswer(response, (String) output.getAnswers().get(0));
+    }
+
+    @Override
+    public String getType() {
+        return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE;
+    }
+
+    // Rebuilds the response with the answer attached, copying every other attribute from the original.
+    private SearchResponse insertAnswer(SearchResponse response, String answer) {
+        return new GenerativeSearchResponse(
+            answer,
+            response.getInternalResponse(),
+            response.getScrollId(),
+            response.getTotalShards(),
+            response.getSuccessfulShards(),
+            response.getSkippedShards(),
+            // tookInMillis: the original elapsed time, not the successful shard count.
+            response.getTook().millis(),
+            response.getShardFailures(),
+            response.getClusters()
+        );
+    }
+
+    // TODO Integrate with Conversational Memory
+    private List<String> getChatHistory(SearchRequest request) {
+        return new ArrayList<>();
+    }
+
+    // Collects, hit by hit, the string form of every configured context field.
+    private List<String> getSearchResults(SearchResponse response) {
+        List<String> searchResults = new ArrayList<>();
+        for (SearchHit hit : response.getHits().getHits()) {
+            Map<String, Object> docSourceMap = hit.getSourceAsMap();
+            for (String contextField : contextFields) {
+                Object context = docSourceMap.get(contextField);
+                if (context == null) {
+                    String message = "Context " + contextField + " not found in search hit " + hit;
+                    log.error(message);
+                    // TODO throw a more meaningful error here?
+                    throw new RuntimeException(message);
+                }
+                searchResults.add(context.toString());
+            }
+        }
+        return searchResults;
+    }
+
+    /**
+     * Creates {@link GenerativeQAResponseProcessor} instances from pipeline configuration.
+     */
+    public static final class Factory implements Processor.Factory<SearchResponseProcessor> {
+
+        private final Client client;
+
+        public Factory(Client client) {
+            this.client = client;
+        }
+
+        /**
+         * Builds a response processor from the optional "model_id" and "llm_model" properties and the
+         * "context_field_list" property, rejecting an empty context field list.
+         */
+        @Override
+        public SearchResponseProcessor create(
+            Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
+            String tag,
+            String description,
+            boolean ignoreFailure,
+            Map<String, Object> config,
+            PipelineContext pipelineContext
+        ) throws Exception {
+            String modelId = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, tag, config, GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID);
+            String llmModel = ConfigurationUtils.readOptionalStringProperty(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, tag, config, GenerativeQAProcessorConstants.CONFIG_NAME_LLM_MODEL);
+            List<String> contextFields = ConfigurationUtils.readList(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, tag, config, GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST);
+            if (contextFields.isEmpty()) {
+                throw newConfigurationException(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, tag, GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, "required property can't be empty.");
+            }
+            log.info("model_id {}, llm_model {}, context_field_list {}", modelId, llmModel, contextFields);
+            return new GenerativeQAResponseProcessor(client, tag, description, ignoreFailure, ModelLocator.getLlm(modelId, client), llmModel, contextFields);
+        }
+    }
+}
diff --git 
a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java
new file mode 100644
index 0000000000..2a22902c9a
--- /dev/null
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeSearchResponse.java
@@ -0,0 +1,71 @@
+/*
+ * Copyright 2023 Aryn
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.opensearch.searchpipelines.questionanswering.generative;
+
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.action.search.SearchResponseSections;
+import org.opensearch.action.search.ShardSearchFailure;
+import org.opensearch.core.xcontent.XContentBuilder;
+
+import java.io.IOException;
+
+/**
+ * A {@link SearchResponse} that also renders an LLM-generated answer in a dedicated "ext" section.
+ *
+ * TODO Add ExtBuilders to SearchResponse and get rid of this class.
+ */
+public class GenerativeSearchResponse extends SearchResponse {
+
+    private static final String EXT_SECTION_NAME = "ext";
+    private static final String GENERATIVE_QA_ANSWER_FIELD_NAME = "answer";
+
+    private final String answer;
+
+    public GenerativeSearchResponse(
+        String answer,
+        SearchResponseSections internalResponse,
+        String scrollId,
+        int totalShards,
+        int successfulShards,
+        int skippedShards,
+        long tookInMillis,
+        ShardSearchFailure[] shardFailures,
+        Clusters clusters
+    ) {
+        super(internalResponse, scrollId, totalShards, successfulShards, skippedShards, tookInMillis, shardFailures, clusters);
+        this.answer = answer;
+    }
+
+    /**
+     * Writes the regular search response body followed by an "ext" object that nests the
+     * answer under the response processor type name.
+     */
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        innerToXContent(builder, params);
+        builder
+            .startObject(EXT_SECTION_NAME)
+            .startObject(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE)
+            .field(GENERATIVE_QA_ANSWER_FIELD_NAME, this.answer)
+            .endObject()
+            .endObject();
+        builder.endObject();
+        return builder;
+    }
+}
diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java
new file mode 100644
index 0000000000..265c20a76d
--- /dev/null
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClient.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.searchpipelines.questionanswering.generative.client;
+
+import com.google.common.annotations.VisibleForTesting;
+import lombok.AccessLevel;
+import 
lombok.RequiredArgsConstructor;
+import lombok.experimental.FieldDefaults;
+import org.opensearch.action.support.PlainActionFuture;
+import org.opensearch.client.Client;
+import org.opensearch.common.action.ActionFuture;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.action.ActionResponse;
+import org.opensearch.ml.common.input.MLInput;
+import org.opensearch.ml.common.output.MLOutput;
+import org.opensearch.ml.common.transport.MLTaskResponse;
+import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
+import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
+
+import java.util.function.Function;
+
+/**
+ * An internal facing ML client adapted from org.opensearch.ml.client.MachineLearningNodeClient.
+ *
+ * <p>It exposes only the predict operation, executed through the wrapped {@link Client}.
+ */
+@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
+@RequiredArgsConstructor
+public class MachineLearningInternalClient {
+
+    Client client;
+
+    /**
+     * Runs a prediction and returns a future for its output.
+     *
+     * @param modelId id of the model to invoke
+     * @param mlInput prediction input; must be non-null and carry an input data set
+     * @return a future completed with the prediction output, or failed with the execution error
+     * @throws IllegalArgumentException if {@code mlInput} or its input data set is null
+     */
+    public ActionFuture<MLOutput> predict(String modelId, MLInput mlInput) {
+        PlainActionFuture<MLOutput> actionFuture = PlainActionFuture.newFuture();
+        predict(modelId, mlInput, actionFuture);
+        return actionFuture;
+    }
+
+    /**
+     * Listener-based variant of {@link #predict(String, MLInput)}.
+     */
+    @VisibleForTesting
+    void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
+        validateMLInput(mlInput, true);
+
+        MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder()
+            .mlInput(mlInput)
+            .modelId(modelId)
+            .dispatchTask(true)
+            .build();
+        client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener));
+    }
+
+    // Adapts a listener for MLOutput into one that accepts the raw prediction task response.
+    private ActionListener<MLTaskResponse> getMlPredictionTaskResponseActionListener(ActionListener<MLOutput> listener) {
+        ActionListener<MLTaskResponse> internalListener = ActionListener.wrap(
+            predictionResponse -> listener.onResponse(predictionResponse.getOutput()),
+            listener::onFailure
+        );
+        return wrapActionListener(internalListener, MLTaskResponse::fromActionResponse);
+    }
+
+    // Passes each response through the recreate function before handing it to the delegate listener.
+    private <T extends ActionResponse> ActionListener<T> wrapActionListener(
+        final ActionListener<T> listener,
+        final Function<ActionResponse, T> recreate
+    ) {
+        return ActionListener.wrap(r -> listener.onResponse(recreate.apply(r)), listener::onFailure);
+    }
+
+    // Rejects a null input, and a null input data set when requireInput is true.
+    private void validateMLInput(MLInput mlInput, boolean requireInput) {
+        if (mlInput == null) {
+            throw new IllegalArgumentException("ML Input can't be null");
+        }
+        if (requireInput && mlInput.getInputDataset() == null) {
+            throw new IllegalArgumentException("input data set can't be null");
+        }
+    }
+}
diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java
new file mode 100644
index 0000000000..fc4d5a2222
--- /dev/null
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilder.java
@@ -0,0 +1,88 @@
+/*
+ * Copyright 2023 Aryn
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; + +import java.io.IOException; +import java.util.Objects; + +/** + * This is the extension builder for generative QA search pipelines. + */ +@NoArgsConstructor +public class GenerativeQAParamExtBuilder extends SearchExtBuilder { + + // The name of the "ext" section containing Generative QA parameters. + public static final String PARAMETER_NAME = "generative_qa_parameters"; + + @Setter + @Getter + private GenerativeQAParameters params; + + public GenerativeQAParamExtBuilder(StreamInput input) throws IOException { + this.params = new GenerativeQAParameters(input); + } + + @Override + public int hashCode() { + return Objects.hash(this.getClass(), this.params); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (!(obj instanceof GenerativeQAParamExtBuilder)) { + return false; + } + + return Objects.equals(this.getParams(), ((GenerativeQAParamExtBuilder) obj).getParams()); + } + + @Override + public String getWriteableName() { + return PARAMETER_NAME; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + this.params.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.value(params); + } + + public static GenerativeQAParamExtBuilder parse(XContentParser parser) throws IOException { + GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder(); + GenerativeQAParameters params = GenerativeQAParameters.parse(parser); + builder.setParams(params); + return builder; + } 
+} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtil.java new file mode 100644 index 0000000000..52da6daa02 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtil.java @@ -0,0 +1,49 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.SearchExtBuilder; + +import java.util.Optional; + +/** + * Utility class for extracting generative QA search pipeline parameters from search requests. 
+ */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class GenerativeQAParamUtil { + + public static GenerativeQAParameters getGenerativeQAParameters(SearchRequest request) { + GenerativeQAParamExtBuilder builder = null; + if (request.source() != null && request.source().ext() != null && !request.source().ext().isEmpty()) { + Optional b = request.source().ext().stream().filter(bldr -> GenerativeQAParamExtBuilder.PARAMETER_NAME.equals(bldr.getWriteableName())).findFirst(); + if (b.isPresent()) { + builder = (GenerativeQAParamExtBuilder) b.get(); + } + } + + GenerativeQAParameters params = null; + if (builder != null) { + params = builder.getParams(); + } + + return params; + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java new file mode 100644 index 0000000000..04d2b53674 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParameters.java @@ -0,0 +1,110 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import com.google.common.base.Preconditions; +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.common.io.stream.Writeable; +import org.opensearch.core.ParseField; +import org.opensearch.core.xcontent.ObjectParser; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +/** + * Defines parameters for generative QA search pipelines. + * + */ +@AllArgsConstructor +@NoArgsConstructor +public class GenerativeQAParameters implements Writeable, ToXContentObject { + + private static final ObjectParser PARSER; + + private static final ParseField CONVERSATION_ID = new ParseField("conversation_id"); + private static final ParseField LLM_MODEL = new ParseField("llm_model"); + private static final ParseField LLM_QUESTION = new ParseField("llm_question"); + + static { + PARSER = new ObjectParser<>("generative_qa_parameters", GenerativeQAParameters::new); + PARSER.declareString(GenerativeQAParameters::setConversationId, CONVERSATION_ID); + PARSER.declareString(GenerativeQAParameters::setLlmModel, LLM_MODEL); + PARSER.declareString(GenerativeQAParameters::setLlmQuestion, LLM_QUESTION); + } + + @Setter + @Getter + private String conversationId; + + @Setter + @Getter + private String llmModel; + + @Setter + @Getter + private String llmQuestion; + + public GenerativeQAParameters(StreamInput input) throws IOException { + this.conversationId = input.readOptionalString(); + this.llmModel = input.readOptionalString(); + this.llmQuestion = input.readString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder 
xContentBuilder, Params params) throws IOException { + return xContentBuilder.field(CONVERSATION_ID.getPreferredName(), this.conversationId) + .field(LLM_MODEL.getPreferredName(), this.llmModel) + .field(LLM_QUESTION.getPreferredName(), this.llmQuestion); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(conversationId); + out.writeOptionalString(llmModel); + + Preconditions.checkNotNull(llmQuestion, "llm_question must not be null."); + out.writeString(llmQuestion); + } + + public static GenerativeQAParameters parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + GenerativeQAParameters other = (GenerativeQAParameters) o; + return Objects.equals(this.conversationId, other.getConversationId()) + && Objects.equals(this.llmModel, other.getLlmModel()) + && Objects.equals(this.llmQuestion, other.getLlmQuestion()); + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java new file mode 100644 index 0000000000..b1ea2f9706 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java @@ -0,0 +1,40 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.opensearch.searchpipelines.questionanswering.generative.llm;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.extern.log4j.Log4j2;
+
+import java.util.List;
+
+/**
+ * Input for LLMs via HttpConnector.
+ *
+ * A plain data holder: Lombok generates the getters, the setters and an
+ * all-arguments constructor taking the fields in declaration order.
+ */
+@Log4j2
+@Getter
+@Setter
+@AllArgsConstructor
+public class ChatCompletionInput {
+
+    // Name of the LLM model to use.
+    private String model;
+    // The question to be answered.
+    private String question;
+    // Earlier conversation turns. NOTE(review): element type is not declared here - confirm with callers.
+    private List chatHistory;
+    // Context passages for the LLM. NOTE(review): presumably strings (PromptUtil iterates them as String) - confirm.
+    private List contexts;
+}
diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java
new file mode 100644
index 0000000000..b9bc891a7a
--- /dev/null
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutput.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2023 Aryn
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.opensearch.searchpipelines.questionanswering.generative.llm;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.extern.log4j.Log4j2;
+
+import java.util.List;
+
+/**
+ * Output from LLMs via HttpConnector.
+ *
+ * A plain data holder: Lombok generates the getter, the setter and a
+ * single-argument constructor.
+ */
+@Log4j2
+@Getter
+@Setter
+@AllArgsConstructor
+public class ChatCompletionOutput {
+
+    // Answers returned by the LLM. NOTE(review): element type is not declared here - confirm with callers.
+    private List answers;
+}
diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java
new file mode 100644
index 0000000000..456faafe1c
--- /dev/null
+++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java
@@ -0,0 +1,99 @@
+/*
+ * Copyright 2023 Aryn
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import com.google.common.annotations.VisibleForTesting; +import lombok.extern.log4j.Log4j2; +import org.opensearch.client.Client; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.MLInputDataset; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient; +import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkNotNull; + +/** + * Wrapper for talking to LLMs via OpenSearch HttpConnector. + */ +@Log4j2 +public class DefaultLlmImpl implements Llm { + + private static final String CONNECTOR_INPUT_PARAMETER_MODEL = "model"; + private static final String CONNECTOR_INPUT_PARAMETER_MESSAGES = "messages"; + private static final String CONNECTOR_OUTPUT_CHOICES = "choices"; + private static final String CONNECTOR_OUTPUT_MESSAGE = "message"; + private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role"; + private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content"; + + private final String openSearchModelId; + + private MachineLearningInternalClient mlClient; + + public DefaultLlmImpl(String openSearchModelId, Client client) { + checkNotNull(openSearchModelId); + this.openSearchModelId = openSearchModelId; + this.mlClient = new MachineLearningInternalClient(client); + + } + + @VisibleForTesting + void setMlClient(MachineLearningInternalClient mlClient) { + this.mlClient = mlClient; + } + + /** + * Use ChatCompletion API to generate an answer. 
+ * + * @param chatCompletionInput + * @return + */ + @Override + public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) { + + Map inputParameters = new HashMap<>(); + inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); + inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, PromptUtil.getChatCompletionPrompt(chatCompletionInput.getQuestion(), chatCompletionInput.getChatHistory(), chatCompletionInput.getContexts())); + MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(inputParameters).build(); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build(); + ActionFuture future = mlClient.predict(this.openSearchModelId, mlInput); + ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(); + + // Response from the (remote) model + Map dataAsMap = modelOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap(); + log.info("dataAsMap: {}", dataAsMap.toString()); + + // TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases. 
+ + List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES); + Map firstChoiceMap = (Map) choices.get(0); + log.info("Choices: {}", firstChoiceMap.toString()); + Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE); + log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); + + return new ChatCompletionOutput(List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT))); + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java new file mode 100644 index 0000000000..e850561066 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java @@ -0,0 +1,26 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +/** + * Capabilities of large language models, e.g. completion, embeddings, etc. 
+ */ +public interface Llm { + + ChatCompletionOutput doChatCompletion(ChatCompletionInput input); +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java new file mode 100644 index 0000000000..badb1920bd --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java @@ -0,0 +1,33 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import java.util.List; + +/** + * Helper class for creating inputs and outputs for different implementations of LLMs. + */ +public class LlmIOUtil { + + public static ChatCompletionInput createChatCompletionInput(String llmModel, String question, List chatHistory, List contexts) { + + // TODO pick the right subclass based on the modelId. 
+ + return new ChatCompletionInput(llmModel, question, chatHistory, contexts); + } +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocator.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocator.java new file mode 100644 index 0000000000..1b43574374 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocator.java @@ -0,0 +1,36 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.opensearch.client.Client; + +/** + * Helper class for wiring LLMs based on the model ID. + * + * TODO Should we extend this use case beyond HttpConnectors/Remote Inference? 
+ */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class ModelLocator { + + public static Llm getLlm(String modelId, Client client) { + return new DefaultLlmImpl(modelId, client); + } + +} diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java new file mode 100644 index 0000000000..f7b5049bd9 --- /dev/null +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -0,0 +1,74 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.prompt; + +import com.google.common.annotations.VisibleForTesting; +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.apache.commons.text.StringEscapeUtils; + +import java.util.List; +import java.util.Locale; + +/** + * TODO Should prompt engineering llm-specific? 
+ * + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class PromptUtil { + + private static final String roleUser = "user"; + + public static String getQuestionRephrasingPrompt(String originalQuestion, List chatHistory) { + return null; + } + + public static String getChatCompletionPrompt(String question, List chatHistory, List contexts) { + return buildMessageParameter(question, chatHistory, contexts); + } + + @VisibleForTesting + static String buildMessageParameter(String question, List chatHistory, List contexts) { + // TODO better prompt template management is needed here. + String instructions = "Generate a concise and informative answer in less than 100 words for the given question, taking into context: " + + "- An enumerated list of search results" + + "- A rephrase of the question that was used to generate the search results" + + "- The conversation history" + + "Cite search results using [${number}] notation." + + "Do not repeat yourself, and NEVER repeat anything in the chat history." 
+ + "If there are any necessary steps or procedures in your answer, enumerate them."; + StringBuffer sb = new StringBuffer(); + sb.append("[\n"); + sb.append(formatMessage(roleUser, instructions)); + sb.append(",\n"); + for (String result : contexts) { + sb.append(formatMessage(roleUser, "SEARCH RESULTS: " + result)); + sb.append(",\n"); + } + sb.append(formatMessage(roleUser, "QUESTION: " + question)); + sb.append(",\n"); + sb.append(formatMessage(roleUser, "ANSWER:")); + sb.append("\n"); + sb.append("]"); + return sb.toString(); + } + + private static String formatMessage(String role, String content) { + return String.format(Locale.ROOT, "{\"role\": \"%s\", \"content\": \"%s\"}", role, StringEscapeUtils.escapeJson(content)); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java new file mode 100644 index 0000000000..c5dcb40266 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQARequestProcessorTests.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.opensearch.searchpipelines.questionanswering.generative; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.client.Client; +import org.opensearch.search.pipeline.SearchRequestProcessor; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +public class GenerativeQARequestProcessorTests extends OpenSearchTestCase { + + public void testProcessorFactory() throws Exception { + + Map config = new HashMap<>(); + config.put("model_id", "foo"); + SearchRequestProcessor processor = + new GenerativeQARequestProcessor.Factory().create(null, "tag", "desc", true, config, null); + assertTrue(processor instanceof GenerativeQARequestProcessor); + } + + public void testProcessRequest() throws Exception { + GenerativeQARequestProcessor processor = new GenerativeQARequestProcessor("tag", "desc", false, "foo"); + SearchRequest request = new SearchRequest(); + SearchRequest processed = processor.processRequest(request); + assertEquals(request, processed); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java new file mode 100644 index 0000000000..02ba81af06 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java @@ -0,0 +1,153 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative; + +import org.mockito.ArgumentCaptor; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.client.Client; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; +import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; +import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class GenerativeQAResponseProcessorTests extends OpenSearchTestCase { + + public void testProcessorFactoryRemoteModel() throws Exception { + Client client = 
mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "xyz"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client) + .create(null, "tag", "desc", true, config, null); + assertNotNull(processor); + } + + public void testProcessResponseNoSearchHits() throws Exception { + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client) + .create(null, "tag", "desc", true, config, null); + + SearchRequest request = new SearchRequest(); // mock(SearchRequest.class); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // mock(SearchSourceBuilder.class); + GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind."); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + ChatCompletionOutput output = mock(ChatCompletionOutput.class); + when(llm.doChatCompletion(any())).thenReturn(output); + 
when(output.getAnswers()).thenReturn(List.of("foo")); + processor.setLlm(llm); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ChatCompletionInput.class); + boolean errorThrown = false; + try { + SearchResponse res = processor.processResponse(request, response); + } catch (Exception e) { + errorThrown = true; + } + assertTrue(errorThrown); + } + + public void testProcessResponse() throws Exception { + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(client) + .create(null, "tag", "desc", true, config, null); + + SearchRequest request = new SearchRequest(); // mock(SearchRequest.class); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); // mock(SearchSourceBuilder.class); + GenerativeQAParameters params = new GenerativeQAParameters("12345", "llm_model", "You are kind."); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + XContentBuilder sourceContent = JsonXContent.contentBuilder() + .startObject() + .field("_id", String.valueOf(i)) + .field("text", "passage" + i) + .field("title", "This is the title for document " + i) + .endObject(); + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, 
null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + ChatCompletionOutput output = mock(ChatCompletionOutput.class); + when(llm.doChatCompletion(any())).thenReturn(output); + when(output.getAnswers()).thenReturn(List.of("foo")); + processor.setLlm(llm); + + ArgumentCaptor captor = ArgumentCaptor.forClass(ChatCompletionInput.class); + SearchResponse res = processor.processResponse(request, response); + verify(llm).doChatCompletion(captor.capture()); + ChatCompletionInput input = captor.getValue(); + assertTrue(input instanceof ChatCompletionInput); + List passages = ((ChatCompletionInput) input).getContexts(); + assertEquals("passage0", passages.get(0)); + assertEquals("passage1", passages.get(1)); + assertTrue(res instanceof GenerativeSearchResponse); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClientTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClientTests.java new file mode 100644 index 0000000000..ce921bac89 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/client/MachineLearningInternalClientTests.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.searchpipelines.questionanswering.generative.client; + +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.MLInputDataset; +import 
org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.MLPredictionOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Answers.RETURNS_DEEP_STUBS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class MachineLearningInternalClientTests { + @Mock(answer = RETURNS_DEEP_STUBS) + NodeClient client; + + @Mock + MLInputDataset input; + + @Mock + DataFrame output; + + @Mock + ActionListener dataFrameActionListener; + + @InjectMocks + MachineLearningInternalClient machineLearningInternalClient; + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + } + + @Test + public void predict() { + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLPredictionOutput predictionOutput = MLPredictionOutput.builder() + .status("Success") + .predictionResult(output) + .taskId("taskId") + .build(); + actionListener.onResponse(MLTaskResponse.builder() + .output(predictionOutput) + .build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + ArgumentCaptor dataFrameArgumentCaptor = ArgumentCaptor.forClass(MLOutput.class); + MLInput mlInput = MLInput.builder() + .algorithm(FunctionName.KMEANS) + .inputDataset(input) + .build(); + machineLearningInternalClient.predict(null, mlInput, dataFrameActionListener); + + verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), 
isA(MLPredictionTaskRequest.class), any()); + verify(dataFrameActionListener).onResponse(dataFrameArgumentCaptor.capture()); + assertEquals(output, ((MLPredictionOutput)dataFrameArgumentCaptor.getValue()).getPredictionResult()); + } + + @Test + public void predict_Exception_WithNullAlgorithm() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("algorithm can't be null"); + MLInput mlInput = MLInput.builder() + .inputDataset(input) + .build(); + machineLearningInternalClient.predict(null, mlInput, dataFrameActionListener); + } + + @Test + public void predict_Exception_WithNullDataSet() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("input data set can't be null"); + MLInput mlInput = MLInput.builder() + .algorithm(FunctionName.KMEANS) + .build(); + machineLearningInternalClient.predict(null, mlInput, dataFrameActionListener); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java new file mode 100644 index 0000000000..df00113f58 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamExtBuilderTests.java @@ -0,0 +1,81 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.EOFException; +import java.io.IOException; + +public class GenerativeQAParamExtBuilderTests extends OpenSearchTestCase { + + public void testCtor() throws IOException { + GenerativeQAParamExtBuilder builder = new GenerativeQAParamExtBuilder(); + GenerativeQAParameters parameters = new GenerativeQAParameters(); + builder.setParams(parameters); + assertEquals(parameters, builder.getParams()); + + GenerativeQAParamExtBuilder builder1 = new GenerativeQAParamExtBuilder(new StreamInput() { + @Override + public byte readByte() throws IOException { + return 0; + } + + @Override + public void readBytes(byte[] b, int offset, int len) throws IOException { + + } + + @Override + public void close() throws IOException { + + } + + @Override + public int available() throws IOException { + return 0; + } + + @Override + protected void ensureCanReadBytes(int length) throws EOFException { + + } + + @Override + public int read() throws IOException { + return 0; + } + }); + + assertNotNull(builder1); + } + + public void testMiscMethods() { + GenerativeQAParameters param1 = new GenerativeQAParameters("a", "b", "c"); + GenerativeQAParameters param2 = new GenerativeQAParameters("a", "b", "d"); + GenerativeQAParamExtBuilder builder1 = new GenerativeQAParamExtBuilder(); + GenerativeQAParamExtBuilder builder2 = new GenerativeQAParamExtBuilder(); + builder1.setParams(param1); + builder2.setParams(param2); + assertNotEquals(builder1, builder2); + assertNotEquals(builder1.hashCode(), 
builder2.hashCode()); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java new file mode 100644 index 0000000000..9811d5e768 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParamUtilTests.java @@ -0,0 +1,41 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +public class GenerativeQAParamUtilTests extends OpenSearchTestCase { + + public void testGenerativeQAParametersMissingParams() { + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + SearchSourceBuilder srcBulder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(srcBulder); + GenerativeQAParameters actual = GenerativeQAParamUtil.getGenerativeQAParameters(request); + assertNull(actual); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java new file mode 100644 index 0000000000..a18af80200 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/ext/GenerativeQAParametersTests.java @@ -0,0 +1,95 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.ext; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; +import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class GenerativeQAParametersTests extends OpenSearchTestCase { + + public void testGenerativeQAParameters() { + GenerativeQAParameters params = new GenerativeQAParameters("conversation_id", "llm_model", "llm_question"); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(params); + SearchSourceBuilder srcBulder = SearchSourceBuilder.searchSource().ext(List.of(extBuilder)); + SearchRequest request = new SearchRequest("my_index").source(srcBulder); + GenerativeQAParameters actual = GenerativeQAParamUtil.getGenerativeQAParameters(request); + assertEquals(params, actual); + } + + static class DummyStreamOutput extends StreamOutput { + + List list = new ArrayList<>(); + + @Override + public void writeString(String str) { + list.add(str); + } + + public List getList() { + return list; + } + + @Override + public void writeByte(byte b) 
throws IOException { + + } + + @Override + public void writeBytes(byte[] b, int offset, int length) throws IOException { + + } + + @Override + public void flush() throws IOException { + + } + + @Override + public void close() throws IOException { + + } + + @Override + public void reset() throws IOException { + + } + } + public void testWriteTo() throws IOException { + String conversationId = "a"; + String llmModel = "b"; + String llmQuestion = "c"; + GenerativeQAParameters parameters = new GenerativeQAParameters(conversationId, llmModel, llmQuestion); + StreamOutput output = new DummyStreamOutput(); + parameters.writeTo(output); + List actual = ((DummyStreamOutput) output).getList(); + assertEquals(3, actual.size()); + assertEquals(conversationId, actual.get(0)); + assertEquals(llmModel, actual.get(1)); + assertEquals(llmQuestion, actual.get(2)); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java new file mode 100644 index 0000000000..4c404cd4b5 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.List; + +public class ChatCompletionInputTests extends OpenSearchTestCase { + + public void testCtor() { + String model = "model"; + String question = "question"; + + ChatCompletionInput input = new ChatCompletionInput(model, question, Collections.emptyList(), Collections.emptyList()); + + assertNotNull(input); + } + + public void testGettersSetters() { + String model = "model"; + String question = "question"; + List history = List.of("hello"); + List contexts = List.of("result1", "result2"); + ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts); + assertEquals(model, input.getModel()); + assertEquals(question, input.getQuestion()); + assertEquals(history.get(0), input.getChatHistory().get(0)); + assertEquals(contexts.get(0), input.getContexts().get(0)); + assertEquals(contexts.get(1), input.getContexts().get(1)); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java new file mode 100644 index 0000000000..c3f6c68688 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionOutputTests.java @@ -0,0 +1,36 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +public class ChatCompletionOutputTests extends OpenSearchTestCase { + + public void testCtor() { + ChatCompletionOutput output = new ChatCompletionOutput(List.of("answer")); + assertNotNull(output); + } + + public void testGettersSetters() { + String answer = "answer"; + ChatCompletionOutput output = new ChatCompletionOutput(List.of(answer)); + assertEquals(answer, (String) output.getAnswers().get(0)); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java new file mode 100644 index 0000000000..b5a5421fa4 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -0,0 +1,99 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.opensearch.common.action.ActionFuture; +import org.opensearch.client.Client; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; +import org.opensearch.ml.common.output.model.MLResultDataType; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient; +import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.*; + +public class DefaultLlmImplTests extends OpenSearchTestCase { + + @Mock + Client client; + + public void testBuildMessageParameter() { + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + String question = "Who am I"; + List contexts = new ArrayList<>(); + List chatHistory = new ArrayList<>(); + contexts.add("context 1"); + contexts.add("context 2"); + 
chatHistory.add("message 1"); + chatHistory.add("message 2"); + String parameter = PromptUtil.getChatCompletionPrompt(question, chatHistory, contexts); + //System.out.println(parameter); + Map parameters = Map.of("model", "foo", "messages", parameter); + assertTrue(isJson(parameter)); + } + + public void testChatCompletionApi() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + Map messageMap = Map.of("role", "agent", "content", "answer"); + Map dataAsMap = Map.of("choices", List.of(Map.of("message", messageMap))); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet()).thenReturn(mlOutput); + when(mlClient.predict(any(), any())).thenReturn(future); + ChatCompletionInput input = new ChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList()); + ChatCompletionOutput output = connector.doChatCompletion(input); + verify(mlClient, times(1)).predict(any(), captor.capture()); + MLInput mlInput = captor.getValue(); + assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); + assertEquals("answer", (String) output.getAnswers().get(0)); + } + + private boolean isJson(String Json) { + try { + new JSONObject(Json); + } catch (JSONException ex) { + try { + new JSONArray(Json); + } catch (JSONException ex1) { + return false; + } + } + return true; + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java 
b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java new file mode 100644 index 0000000000..bf2ec3bf7c --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java @@ -0,0 +1,30 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; + +public class LlmIOUtilTests extends OpenSearchTestCase { + + public void testChatCompletionInput() { + ChatCompletionInput input = LlmIOUtil.createChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList()); + assertTrue(input instanceof ChatCompletionInput); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocatorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocatorTests.java new file mode 100644 index 0000000000..dcf3d223fb --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ModelLocatorTests.java @@ -0,0 +1,32 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.llm; + +import org.opensearch.client.Client; +import org.opensearch.test.OpenSearchTestCase; + +import static org.mockito.Mockito.mock; + +public class ModelLocatorTests extends OpenSearchTestCase { + + public void testGetRemoteLlm() { + Client client = mock(Client.class); + Llm llm = ModelLocator.getLlm("xyz", client); + assertTrue(llm instanceof DefaultLlmImpl); + } +} diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java new file mode 100644 index 0000000000..2fdbfc0fe1 --- /dev/null +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java @@ -0,0 +1,61 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.searchpipelines.questionanswering.generative.prompt; + +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public class PromptUtilTests extends OpenSearchTestCase { + + public void testPromptUtilStaticMethods() { + assertNull(PromptUtil.getQuestionRephrasingPrompt("question", Collections.emptyList())); + } + + public void testBuildMessageParameter() { + String question = "Who am I"; + List contexts = new ArrayList<>(); + List chatHistory = new ArrayList<>(); + contexts.add("context 1"); + contexts.add("context 2"); + chatHistory.add("message 1"); + chatHistory.add("message 2"); + String parameter = PromptUtil.buildMessageParameter(question, chatHistory, contexts); + Map parameters = Map.of("model", "foo", "messages", parameter); + assertTrue(isJson(parameter)); + } + + private boolean isJson(String Json) { + try { + new JSONObject(Json); + } catch (JSONException ex) { + try { + new JSONArray(Json); + } catch (JSONException ex1) { + return false; + } + } + return true; + } +} diff --git a/settings.gradle b/settings.gradle index cb60211c8c..bf697450c1 100644 --- a/settings.gradle +++ b/settings.gradle @@ -13,5 +13,9 @@ include 'plugin' project(":plugin").name = rootProject.name + "-plugin" include 'ml-algorithms' project(":ml-algorithms").name = rootProject.name + "-algorithms" +include 'search-processors' 
+project(":search-processors").name = rootProject.name + "-search-processors" +include 'conversational-memory' +project(":conversational-memory").name = rootProject.name + "-memory" include 'memory' project(":memory").name = rootProject.name + "-memory"