Skip to content

Commit

Permalink
[Feature] Add Retrieval Augmented Generation search processors (#1275)
Browse files Browse the repository at this point in the history
* Put RAG pipeline behind a feature flag.

Signed-off-by: Austin Lee <[email protected]>

* Add support for chat history in RAG using the Conversational Memory API

Signed-off-by: Austin Lee <[email protected]>

* Fix spotless

Signed-off-by: Austin Lee <[email protected]>

* Fix RAG feature flag enablement.

Signed-off-by: Austin Lee <[email protected]>

* Address review comments and suggestions.

Signed-off-by: Austin Lee <[email protected]>

* Address comments.

Signed-off-by: Austin Lee <[email protected]>

* Add unit tests for MachineLearningPlugin

Signed-off-by: Austin Lee <[email protected]>

---------

Signed-off-by: Austin Lee <[email protected]>
  • Loading branch information
austintlee authored Sep 5, 2023
1 parent 0c39993 commit 180c791
Show file tree
Hide file tree
Showing 24 changed files with 882 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;

import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -247,6 +249,12 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin, Searc

private ConversationalMemoryHandler cmHandler;

private volatile boolean ragSearchPipelineEnabled;

public MachineLearningPlugin(Settings settings) {
this.ragSearchPipelineEnabled = MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED.get(settings);
}

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
return ImmutableList
Expand Down Expand Up @@ -449,6 +457,11 @@ public Collection<Object> createComponents(
encryptor
);

// TODO move this into MLFeatureEnabledSetting
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> ragSearchPipelineEnabled = it);

return ImmutableList
.of(
encryptor,
Expand Down Expand Up @@ -654,30 +667,56 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_MODEL_ACCESS_CONTROL_ENABLED,
MLCommonsSettings.ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED,
MLCommonsSettings.ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX,
MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED
MLCommonsSettings.ML_COMMONS_MEMORY_FEATURE_ENABLED,
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED
);
return settings;
}

/**
*
* Search processors for Retrieval Augmented Generation
*
*/

@Override
public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
return List
.of(
new SearchPlugin.SearchExtSpec<>(
GenerativeQAParamExtBuilder.PARAMETER_NAME,
input -> new GenerativeQAParamExtBuilder(input),
parser -> GenerativeQAParamExtBuilder.parse(parser)
)
);
List<SearchPlugin.SearchExtSpec<?>> searchExts = new ArrayList<>();

if (ragSearchPipelineEnabled) {
searchExts
.add(
new SearchPlugin.SearchExtSpec<>(
GenerativeQAParamExtBuilder.PARAMETER_NAME,
input -> new GenerativeQAParamExtBuilder(input),
parser -> GenerativeQAParamExtBuilder.parse(parser)
)
);
}

return searchExts;
}

@Override
public Map<String, Processor.Factory<SearchRequestProcessor>> getRequestProcessors(Parameters parameters) {
return Map.of(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory());
Map<String, Processor.Factory<SearchRequestProcessor>> requestProcessors = new HashMap<>();

if (ragSearchPipelineEnabled) {
requestProcessors.put(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory());
}

return requestProcessors;
}

@Override
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Parameters parameters) {
return Map.of(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client));
Map<String, Processor.Factory<SearchResponseProcessor>> responseProcessors = new HashMap<>();

if (ragSearchPipelineEnabled) {
responseProcessors
.put(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client));
}

return responseProcessors;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,7 @@ private MLCommonsSettings() {}
);

public static final Setting<Boolean> ML_COMMONS_MEMORY_FEATURE_ENABLED = ConversationalIndexConstants.ML_COMMONS_MEMORY_FEATURE_ENABLED;
// Feature flag for enabling search processors for Retrieval Augmented Generation using OpenSearch and Remote Inference.
public static final Setting<Boolean> ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED = Setting
.boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.ml.plugin;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;

import java.util.List;
import java.util.Map;

import org.junit.Test;
import org.opensearch.common.settings.Settings;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder;

public class MachineLearningPluginTests {

@Test
public void testGetSearchExtsFeatureDisabled() {
Settings settings = Settings.builder().build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts();
assertEquals(0, searchExts.size());
}

@Test
public void testGetSearchExtsFeatureDisabledExplicit() {
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "false").build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts();
assertEquals(0, searchExts.size());
}

@Test
public void testGetRequestProcessorsFeatureDisabled() {
Settings settings = Settings.builder().build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, ?> requestProcessors = plugin.getRequestProcessors(parameters);
assertEquals(0, requestProcessors.size());
}

@Test
public void testGetResponseProcessorsFeatureDisabled() {
Settings settings = Settings.builder().build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, ?> responseProcessors = plugin.getResponseProcessors(parameters);
assertEquals(0, responseProcessors.size());
}

@Test
public void testGetSearchExts() {
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts();
assertEquals(1, searchExts.size());
SearchPlugin.SearchExtSpec<?> spec = searchExts.get(0);
assertEquals(GenerativeQAParamExtBuilder.PARAMETER_NAME, spec.getName().getPreferredName());
}

@Test
public void testGetRequestProcessors() {
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, ?> requestProcessors = plugin.getRequestProcessors(parameters);
assertEquals(1, requestProcessors.size());
assertTrue(
requestProcessors.get(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE) instanceof GenerativeQARequestProcessor.Factory
);
}

@Test
public void testGetResponseProcessors() {
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build();
MachineLearningPlugin plugin = new MachineLearningPlugin(settings);
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class);
Map<String, ?> responseProcessors = plugin.getResponseProcessors(parameters);
assertEquals(1, responseProcessors.size());
assertTrue(
responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory
);
}
}
7 changes: 4 additions & 3 deletions search-processors/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ repositories {
dependencies {

compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
implementation 'org.apache.commons:commons-lang3:3.12.0'
//implementation project(':opensearch-ml-client')
implementation project(':opensearch-ml-common')
implementation project(':opensearch-ml-memory')
implementation group: 'org.opensearch', name: 'common-utils', version: "${common_utils_version}"
// https://mvnrepository.com/artifact/org.apache.httpcomponents.core5/httpcore5
implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.1'
Expand Down Expand Up @@ -59,16 +61,15 @@ jacocoTestCoverageVerification {
rule {
limit {
counter = 'LINE'
minimum = 0.65 //TODO: increase coverage to 0.90
minimum = 0.8
}
limit {
counter = 'BRANCH'
minimum = 0.55 //TODO: increase coverage to 0.85
minimum = 0.8
}
}
}
dependsOn jacocoTestReport
}

check.dependsOn jacocoTestCoverageVerification
//jacocoTestCoverageVerification.dependsOn jacocoTestReport
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,31 @@
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import com.google.gson.Gson;
import com.google.gson.JsonArray;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.Client;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.search.SearchHit;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters;
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil;
import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator;
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand All @@ -48,11 +54,16 @@
@Log4j2
public class GenerativeQAResponseProcessor extends AbstractProcessor implements SearchResponseProcessor {

private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10;

// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM.

private final String llmModel;
private final List<String> contextFields;

@Setter
private ConversationalMemoryClient memoryClient;

@Getter
@Setter
// Mainly for unit testing purpose
Expand All @@ -64,40 +75,46 @@ protected GenerativeQAResponseProcessor(Client client, String tag, String descri
this.llmModel = llmModel;
this.contextFields = contextFields;
this.llm = llm;
this.memoryClient = new ConversationalMemoryClient(client);
}

@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {

log.info("Entering processResponse.");

List<String> chatHistory = getChatHistory(request);
GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);
String llmQuestion = params.getLlmQuestion();
String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel();
String conversationId = params.getConversationId();
log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId);
List<Interaction> chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, DEFAULT_CHAT_HISTORY_WINDOW);
List<String> searchResults = getSearchResults(response);
ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, searchResults));
String answer = (String) output.getAnswers().get(0);

String interactionId = null;
if (conversationId != null) {
interactionId = memoryClient.createInteraction(conversationId, llmQuestion, PromptUtil.DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE, answer,
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, jsonArrayToString(searchResults));
}

ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, getSearchResults(response)));

return insertAnswer(response, (String) output.getAnswers().get(0));
return insertAnswer(response, answer, interactionId);
}

@Override
public String getType() {
return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE;
}

private SearchResponse insertAnswer(SearchResponse response, String answer) {
private SearchResponse insertAnswer(SearchResponse response, String answer, String interactionId) {

// TODO return the interaction id in the response.

return new GenerativeSearchResponse(answer, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(),
response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters());
}

// TODO Integrate with Conversational Memory
private List<String> getChatHistory(SearchRequest request) {
return new ArrayList<>();
}

private List<String> getSearchResults(SearchResponse response) {
List<String> searchResults = new ArrayList<>();
for (SearchHit hit : response.getHits().getHits()) {
Expand All @@ -115,6 +132,12 @@ private List<String> getSearchResults(SearchResponse response) {
return searchResults;
}

private static String jsonArrayToString(List<String> listOfStrings) {
JsonArray array = new JsonArray(listOfStrings.size());
listOfStrings.forEach(array::add);
return array.toString();
}

public static final class Factory implements Processor.Factory<SearchResponseProcessor> {

private final Client client;
Expand Down
Loading

0 comments on commit 180c791

Please sign in to comment.