[Feature] Add Retrieval Augmented Generation search processors #1275

Merged
MachineLearningPlugin.java

@@ -9,7 +9,9 @@
 import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;

 import java.nio.file.Path;
+import java.util.ArrayList;
 import java.util.Collection;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -247,6 +249,12 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc

     private ConversationalMemoryHandler cmHandler;

+    private volatile boolean ragSearchPipelineEnabled;
+
+    public MachineLearningPlugin(Settings settings) {
+        this.ragSearchPipelineEnabled = MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings);
+    }
+
     @Override
     public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
         return ImmutableList
@@ -449,6 +457,11 @@ public Collection<Object> createComponents(
             encryptor
         );

+        // TODO move this into MLFeatureEnabledSetting
+        clusterService
+            .getClusterSettings()
+            .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it);
+
         return ImmutableList
             .of(
                 encryptor,
@@ -654,30 +667,56 @@ public List<Setting<?>> getSettings() {
             MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED,
             MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED,
             MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX,
-            MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED
+            MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED,
+            MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED
         );
         return settings;
     }

+    /**
+     *
+     * Search processors for Retrieval Augmented Generation
+     *
+     */
+
     @Override
     public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
-        return List
-            .of(
-                new SearchPlugin.SearchExtSpec<>(
-                    GenerativeQAParamExtBuilder.PARAMETER_NAME,
-                    input -> new GenerativeQAParamExtBuilder(input),
-                    parser -> GenerativeQAParamExtBuilder.parse(parser)
-                )
-            );
+        List<SearchPlugin.SearchExtSpec<?>> searchExts = new ArrayList<>();
+
+        if (ragSearchPipelineEnabled) {
+            searchExts
+                .add(
+                    new SearchPlugin.SearchExtSpec<>(
+                        GenerativeQAParamExtBuilder.PARAMETER_NAME,
+                        input -> new GenerativeQAParamExtBuilder(input),
+                        parser -> GenerativeQAParamExtBuilder.parse(parser)
+                    )
+                );
+        }
+
+        return searchExts;
     }

     @Override
     public Map<String, Processor.Factory<SearchRequestProcessor>> getRequestProcessors(Parameters parameters) {
-        return Map.of(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory());
+        Map<String, Processor.Factory<SearchRequestProcessor>> requestProcessors = new HashMap<>();
+
+        if (ragSearchPipelineEnabled) {
+            requestProcessors.put(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory());
+        }
+
+        return requestProcessors;
     }

     @Override
     public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Parameters parameters) {
-        return Map.of(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client));
+        Map<String, Processor.Factory<SearchResponseProcessor>> responseProcessors = new HashMap<>();
+
+        if (ragSearchPipelineEnabled) {
+            responseProcessors
+                .put(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client));
+        }
+
+        return responseProcessors;
     }
 }
MLCommonsSettings.java

@@ -130,4 +130,7 @@ private MLCommonsSettings() {}
     );

     public static final Setting<Boolean> ML_COMMONS_MEMORY_FEATURE_ENABLED = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED;
+    // Feature flag for enabling search processors for Retrieval Augmented Generation using OpenSearch and Remote Inference.
+    public static final Setting<Boolean> ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED = Setting
+        .boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
 }

Review thread on the new feature flag:

Collaborator: I'm curious to know what would actually happen if someone were to enable the ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED flag but not the ML_COMMONS_MEMORY_FEATURE_ENABLED flag, or vice versa?

Collaborator (HenryL27): Pipeline, no Memory: you can use the pipeline without memory anyway. Just a straightforward RAG pipeline without any conversation history; trying to use memory with that flag disabled will throw the feature-flag error.
Memory, no Pipeline: you can use this memory implementation for your crazy LangChain app that lives in some other ecosystem entirely.

Author: "Memory" is the most fundamental piece, and it can be used independently by any application or REST client. You CAN use RAG without memory as well, if you are not passing a conversationId. I think we need to make sure we document these use cases and scenarios as part of the 2.10.0 release.

Collaborator (dhrubo-os): @HenryL27 could you please explain "Trying to use memory anyway with that disabled will throw the feature flag error"? I assume the customer will not see this error?

Author: @dhrubo-os this is the error we throw:

```java
new OpenSearchException(
    "The experimental Conversation Memory feature is not enabled. To enable, please update the setting "
        + ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED.getKey()
```
Collaborator (dhrubo-os): Sorry for the confusion. I'm trying to understand it step by step:

1. I enabled ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.
2. I asked a question in the RAG pipeline. Can I continue my Q/A interaction with the LLM now, or will it show me this OpenSearchException telling me to enable the other feature flag?

Author: It won't. You will only get the error if you pass it a conversationId.
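
Since the setting is declared with Setting.Property.Dynamic, it can also be flipped at runtime through the cluster settings API (PUT _cluster/settings) rather than only via opensearch.yml at startup. A minimal sketch using the transport Client; the helper class itself is hypothetical:

```java
import org.opensearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;

public final class RagFlagToggle {

    // Enables the RAG search pipeline cluster-wide without a node restart.
    public static void enableRagPipeline(Client client) {
        ClusterUpdateSettingsRequest request = new ClusterUpdateSettingsRequest()
            .persistentSettings(
                Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", true).build()
            );
        client.admin().cluster().updateSettings(request).actionGet();
    }
}
```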

MachineLearningPluginTests.java (new file)

@@ -0,0 +1,105 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.plugin;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;

import java.util.List;
import java.util.Map;

import org.junit.Test;
import org.opensearch.common.settings.Settings;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder;

public class MachineLearningPluginTests {

    @Test
    public void testGetSearchExtsFeatureDisabled() {
        Settings settings = Settings.builder().build();
        MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
        List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts();
        assertEquals(0, searchExts.size());
    }

    @Test
    public void testGetSearchExtsFeatureDisabledExplicit() {
        Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "false").build();
        MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
        List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts();
        assertEquals(0, searchExts.size());
    }

    @Test
    public void testGetRequestProcessorsFeatureDisabled() {
        Settings settings = Settings.builder().build();
        MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
        SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
        Map<String, ?> requestProcessors = plugin.getRequestProcessors(parameters);
        assertEquals(0, requestProcessors.size());
    }

    @Test
    public void testGetResponseProcessorsFeatureDisabled() {
        Settings settings = Settings.builder().build();
        MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
        SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
        Map<String, ?> responseProcessors = plugin.getResponseProcessors(parameters);
        assertEquals(0, responseProcessors.size());
    }

    @Test
    public void testGetSearchExts() {
        Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build();
        MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
        List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts();
        assertEquals(1, searchExts.size());
        SearchPlugin.SearchExtSpec<?> spec = searchExts.get(0);
        assertEquals(GenerativeQAParamExtBuilder.PARAMETER_NAME, spec.getName().getPreferredName());
    }

    @Test
    public void testGetRequestProcessors() {
        Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build();
        MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
        SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
        Map<String, ?> requestProcessors = plugin.getRequestProcessors(parameters);
        assertEquals(1, requestProcessors.size());
        assertTrue(
            requestProcessors.get(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE) instanceof GenerativeQARequestProcessor.Factory
        );
    }

    @Test
    public void testGetResponseProcessors() {
        Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build();
        MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
        SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
        Map<String, ?> responseProcessors = plugin.getResponseProcessors(parameters);
        assertEquals(1, responseProcessors.size());
        assertTrue(
            responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory
        );
    }
}
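
These tests exercise the startup path, where the constructor reads the flag from node settings. The dynamic path added in createComponents relies on OpenSearch's settings-update consumers; a rough sketch of that mechanism in isolation (illustrative only, assuming ClusterSettings#applySettings semantics; the local Setting constant mirrors the one added in this PR):

```java
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;

public class DynamicFlagSketch {

    // Mirrors MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED from this PR.
    private static final Setting<Boolean> RAG_PIPELINE_ENABLED = Setting
        .boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

    public static void main(String[] args) {
        AtomicBoolean ragEnabled = new AtomicBoolean(false);

        ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, Set.<Setting<?>>of(RAG_PIPELINE_ENABLED));

        // The same registration MachineLearningPlugin performs in createComponents.
        clusterSettings.addSettingsUpdateConsumer(RAG_PIPELINE_ENABLED, ragEnabled::set);

        // Simulate a dynamic cluster-settings update.
        clusterSettings.applySettings(Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", true).build());

        System.out.println(ragEnabled.get()); // prints: true
    }
}
```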
search-processors/build.gradle (4 additions, 3 deletions)

@@ -29,9 +29,11 @@ repositories {
 dependencies {

     compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
+    compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
     implementation 'org.apache.commons:commons-lang3:3.12.0'
     //implementation project(':opensearch-ml-client')
     implementation project(':opensearch-ml-common')
+    implementation project(':opensearch-ml-memory')
     implementation group: 'org.opensearch', name: 'common-utils', version: "${common_utils_version}"
     // https://mvnrepository.com/artifact/org.apache.httpcomponents.core5/httpcore5
     implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.1'
@@ -59,16 +61,15 @@ jacocoTestCoverageVerification {
         rule {
             limit {
                 counter = 'LINE'
-                minimum = 0.65 //TODO: increase coverage to 0.90
+                minimum = 0.8
             }
             limit {
                 counter = 'BRANCH'
-                minimum = 0.55 //TODO: increase coverage to 0.85
+                minimum = 0.8
             }
         }
     }
     dependsOn jacocoTestReport
 }

 check.dependsOn jacocoTestCoverageVerification
-//jacocoTestCoverageVerification.dependsOn jacocoTestReport
GenerativeQAResponseProcessor.java

@@ -17,25 +17,31 @@
  */
 package org.opensearch.searchpipelines.questionanswering.generative;

+import com.google.gson.Gson;
+import com.google.gson.JsonArray;
 import lombok.Getter;
 import lombok.Setter;
 import lombok.extern.log4j.Log4j2;
 import org.opensearch.action.search.SearchRequest;
 import org.opensearch.action.search.SearchResponse;
 import org.opensearch.client.Client;
 import org.opensearch.ingest.ConfigurationUtils;
+import org.opensearch.ml.common.conversation.Interaction;
 import org.opensearch.search.SearchHit;
 import org.opensearch.search.pipeline.AbstractProcessor;
 import org.opensearch.search.pipeline.Processor;
 import org.opensearch.search.pipeline.SearchResponseProcessor;
 import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
 import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
+import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
 import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
 import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
 import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
 import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator;
+import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;

 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
@@ -48,11 +54,16 @@
 @Log4j2
 public class GenerativeQAResponseProcessor extends AbstractProcessor implements SearchResponseProcessor {

+    private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;
+
+    // TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.
+
     private final String llmModel;
     private final List<String> contextFields;

+    @Setter
+    private ConversationalMemoryClient memoryClient;
+
     @Getter
     @Setter
     // Mainly for unit testing purpose
@@ -64,40 +75,46 @@ protected GenerativeQAResponseProcessor(Client client, String tag, String descri
         this.llmModel = llmModel;
         this.contextFields = contextFields;
         this.llm = llm;
+        this.memoryClient = new ConversationalMemoryClient(client);
     }

     @Override
     public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {

         log.info("Entering processResponse.");

-        List<String> chatHistory = getChatHistory(request);
         GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);
         String llmQuestion = params.getLlmQuestion();
         String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
-
-        ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, getSearchResults(response)));
-
-        return insertAnswer(response, (String) output.getAnswers().get(0));
+        String conversationId = params.getConversationId();
+        log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId);
+        List<Interaction> chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, DEFAULT_CHAT_HISTORY_WINDOW);
+        List<String> searchResults = getSearchResults(response);
+        ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, searchResults));
+        String answer = (String) output.getAnswers().get(0);
+
+        String interactionId = null;
+        if (conversationId != null) {
+            interactionId = memoryClient.createInteraction(conversationId, llmQuestion, PromptUtil.DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE, answer,
+                GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, jsonArrayToString(searchResults));
+        }
+
+        return insertAnswer(response, answer, interactionId);
     }

     @Override
     public String getType() {
         return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE;
     }

-    private SearchResponse insertAnswer(SearchResponse response, String answer) {
+    private SearchResponse insertAnswer(SearchResponse response, String answer, String interactionId) {
+
+        // TODO return the interaction id in the response.

         return new GenerativeSearchResponse(answer, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(),
             response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters());
     }

-    // TODO Integrate with Conversational Memory
-    private List<String> getChatHistory(SearchRequest request) {
-        return new ArrayList<>();
-    }

     private List<String> getSearchResults(SearchResponse response) {
         List<String> searchResults = new ArrayList<>();
         for (SearchHit hit : response.getHits().getHits()) {

@@ -115,6 +132,12 @@ private List<String> getSearchResults(SearchResponse response) {
         return searchResults;
     }

+    private static String jsonArrayToString(List<String> listOfStrings) {
+        JsonArray array = new JsonArray(listOfStrings.size());
+        listOfStrings.forEach(array::add);

Review thread on this line:

Collaborator: do we need this line?

Author: This is used to construct a single string out of an array, as a helper method to store the search results as an "additionalInfo" string in memory.

+        return array.toString();
+    }

     public static final class Factory implements Processor.Factory<SearchResponseProcessor> {

         private final Client client;
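
For reference, a quick self-contained illustration (not part of the PR) of what the jsonArrayToString helper produces, given Gson's JsonArray semantics:

```java
import java.util.List;

import com.google.gson.JsonArray;

public class JsonArrayToStringDemo {

    public static void main(String[] args) {
        List<String> searchResults = List.of("doc one text", "doc two text");

        // Mirrors jsonArrayToString above.
        JsonArray array = new JsonArray(searchResults.size());
        searchResults.forEach(array::add);

        // Prints: ["doc one text","doc two text"]
        System.out.println(array.toString());
    }
}
```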