-
Notifications
You must be signed in to change notification settings - Fork 138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Add Retrieval Augmented Generation search processors #1275
Changes from all commits
1baa080
1eb9444
8f4d90d
a1df6a1
e076c92
8cbac4c
739d96a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
/* | ||
* Copyright 2023 Aryn | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.opensearch.ml.plugin; | ||
|
||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertTrue; | ||
import static org.mockito.Mockito.mock; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import org.junit.Test; | ||
import org.opensearch.common.settings.Settings; | ||
import org.opensearch.plugins.SearchPipelinePlugin; | ||
import org.opensearch.plugins.SearchPlugin; | ||
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants; | ||
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor; | ||
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor; | ||
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder; | ||
|
||
public class MachineLearningPluginTests { | ||
|
||
@Test | ||
public void testGetSearchExtsFeatureDisabled() { | ||
Settings settings = Settings.builder().build(); | ||
MachineLearningPlugin plugin = new MachineLearningPlugin(settings); | ||
List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts(); | ||
assertEquals(0, searchExts.size()); | ||
} | ||
|
||
@Test | ||
public void testGetSearchExtsFeatureDisabledExplicit() { | ||
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "false").build(); | ||
MachineLearningPlugin plugin = new MachineLearningPlugin(settings); | ||
List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts(); | ||
assertEquals(0, searchExts.size()); | ||
} | ||
|
||
@Test | ||
public void testGetRequestProcessorsFeatureDisabled() { | ||
Settings settings = Settings.builder().build(); | ||
MachineLearningPlugin plugin = new MachineLearningPlugin(settings); | ||
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); | ||
Map<String, ?> requestProcessors = plugin.getRequestProcessors(parameters); | ||
assertEquals(0, requestProcessors.size()); | ||
} | ||
|
||
@Test | ||
public void testGetResponseProcessorsFeatureDisabled() { | ||
Settings settings = Settings.builder().build(); | ||
MachineLearningPlugin plugin = new MachineLearningPlugin(settings); | ||
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); | ||
Map<String, ?> responseProcessors = plugin.getResponseProcessors(parameters); | ||
assertEquals(0, responseProcessors.size()); | ||
} | ||
|
||
@Test | ||
public void testGetSearchExts() { | ||
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build(); | ||
MachineLearningPlugin plugin = new MachineLearningPlugin(settings); | ||
List<SearchPlugin.SearchExtSpec<?>> searchExts = plugin.getSearchExts(); | ||
assertEquals(1, searchExts.size()); | ||
SearchPlugin.SearchExtSpec<?> spec = searchExts.get(0); | ||
assertEquals(GenerativeQAParamExtBuilder.PARAMETER_NAME, spec.getName().getPreferredName()); | ||
} | ||
|
||
@Test | ||
public void testGetRequestProcessors() { | ||
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build(); | ||
MachineLearningPlugin plugin = new MachineLearningPlugin(settings); | ||
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); | ||
Map<String, ?> requestProcessors = plugin.getRequestProcessors(parameters); | ||
assertEquals(1, requestProcessors.size()); | ||
assertTrue( | ||
requestProcessors.get(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE) instanceof GenerativeQARequestProcessor.Factory | ||
); | ||
} | ||
|
||
@Test | ||
public void testGetResponseProcessors() { | ||
Settings settings = Settings.builder().put("plugins.ml_commons.rag_pipeline_feature_enabled", "true").build(); | ||
MachineLearningPlugin plugin = new MachineLearningPlugin(settings); | ||
SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); | ||
Map<String, ?> responseProcessors = plugin.getResponseProcessors(parameters); | ||
assertEquals(1, responseProcessors.size()); | ||
assertTrue( | ||
responseProcessors.get(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE) instanceof GenerativeQAResponseProcessor.Factory | ||
); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,25 +17,31 @@ | |
*/ | ||
package org.opensearch.searchpipelines.questionanswering.generative; | ||
|
||
import com.google.gson.Gson; | ||
import com.google.gson.JsonArray; | ||
import lombok.Getter; | ||
import lombok.Setter; | ||
import lombok.extern.log4j.Log4j2; | ||
import org.opensearch.action.search.SearchRequest; | ||
import org.opensearch.action.search.SearchResponse; | ||
import org.opensearch.client.Client; | ||
import org.opensearch.ingest.ConfigurationUtils; | ||
import org.opensearch.ml.common.conversation.Interaction; | ||
import org.opensearch.search.SearchHit; | ||
import org.opensearch.search.pipeline.AbstractProcessor; | ||
import org.opensearch.search.pipeline.Processor; | ||
import org.opensearch.search.pipeline.SearchResponseProcessor; | ||
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamUtil; | ||
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParameters; | ||
import org.opensearch.searchpipelines.questionanswering.generative.client.ConversationalMemoryClient; | ||
import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; | ||
import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; | ||
import org.opensearch.searchpipelines.questionanswering.generative.llm.LlmIOUtil; | ||
import org.opensearch.searchpipelines.questionanswering.generative.llm.ModelLocator; | ||
import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
|
@@ -48,11 +54,16 @@ | |
@Log4j2 | ||
public class GenerativeQAResponseProcessor extends AbstractProcessor implements SearchResponseProcessor { | ||
|
||
private static final int DEFAULT_CHAT_HISTORY_WINDOW = 10; | ||
|
||
// TODO Add "interaction_count". This is how far back in chat history we want to go back when calling LLM. | ||
|
||
private final String llmModel; | ||
private final List<String> contextFields; | ||
|
||
@Setter | ||
private ConversationalMemoryClient memoryClient; | ||
|
||
@Getter | ||
@Setter | ||
// Mainly for unit testing purpose | ||
|
@@ -64,40 +75,46 @@ protected GenerativeQAResponseProcessor(Client client, String tag, String descri | |
this.llmModel = llmModel; | ||
this.contextFields = contextFields; | ||
this.llm = llm; | ||
this.memoryClient = new ConversationalMemoryClient(client); | ||
} | ||
|
||
@Override | ||
public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception { | ||
|
||
log.info("Entering processResponse."); | ||
|
||
List<String> chatHistory = getChatHistory(request); | ||
GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request); | ||
String llmQuestion = params.getLlmQuestion(); | ||
String llmModel = params.getLlmModel() == null ? this.llmModel : params.getLlmModel(); | ||
String conversationId = params.getConversationId(); | ||
log.info("LLM question: {}, LLM model {}, conversation id: {}", llmQuestion, llmModel, conversationId); | ||
List<Interaction> chatHistory = (conversationId == null) ? Collections.emptyList() : memoryClient.getInteractions(conversationId, DEFAULT_CHAT_HISTORY_WINDOW); | ||
List<String> searchResults = getSearchResults(response); | ||
ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, searchResults)); | ||
String answer = (String) output.getAnswers().get(0); | ||
|
||
String interactionId = null; | ||
if (conversationId != null) { | ||
interactionId = memoryClient.createInteraction(conversationId, llmQuestion, PromptUtil.DEFAULT_CHAT_COMPLETION_PROMPT_TEMPLATE, answer, | ||
GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, jsonArrayToString(searchResults)); | ||
} | ||
|
||
ChatCompletionOutput output = llm.doChatCompletion(LlmIOUtil.createChatCompletionInput(llmModel, llmQuestion, chatHistory, getSearchResults(response))); | ||
|
||
return insertAnswer(response, (String) output.getAnswers().get(0)); | ||
return insertAnswer(response, answer, interactionId); | ||
} | ||
|
||
@Override | ||
public String getType() { | ||
return GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE; | ||
} | ||
|
||
private SearchResponse insertAnswer(SearchResponse response, String answer) { | ||
private SearchResponse insertAnswer(SearchResponse response, String answer, String interactionId) { | ||
|
||
// TODO return the interaction id in the response. | ||
|
||
return new GenerativeSearchResponse(answer, response.getInternalResponse(), response.getScrollId(), response.getTotalShards(), response.getSuccessfulShards(), | ||
response.getSkippedShards(), response.getSuccessfulShards(), response.getShardFailures(), response.getClusters()); | ||
} | ||
|
||
// TODO Integrate with Conversational Memory | ||
private List<String> getChatHistory(SearchRequest request) { | ||
return new ArrayList<>(); | ||
} | ||
|
||
private List<String> getSearchResults(SearchResponse response) { | ||
List<String> searchResults = new ArrayList<>(); | ||
for (SearchHit hit : response.getHits().getHits()) { | ||
|
@@ -115,6 +132,12 @@ private List<String> getSearchResults(SearchResponse response) { | |
return searchResults; | ||
} | ||
|
||
private static String jsonArrayToString(List<String> listOfStrings) { | ||
JsonArray array = new JsonArray(listOfStrings.size()); | ||
listOfStrings.forEach(array::add); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need this line? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is used to construct a single string out of an array as a helper method to store the search results as a "additionalInfo" string in memory. |
||
return array.toString(); | ||
} | ||
|
||
public static final class Factory implements Processor.Factory<SearchResponseProcessor> { | ||
|
||
private final Client client; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm curious to know what would actually happen if someone were to enable the
ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED
flag but not enable theML_COMMONS_MEMORY_FEATURE_ENABLED
flag, or vice versa?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pipeline, no Memory: you can use the pipeline without memory anyway. Just a straightforward RAG pipeline without any conversation history. Trying to use memory anyway with that disabled will throw the feature flag error.
Memory, no Pipeline: you can use this memory implementation for your langchain app that lives in some other ecosystem entirely.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Memory" is the most fundamental piece and it can be used independently by any application or a REST client. You CAN use RAG without memory as well if you are not passing a conversationId. I think we need to make sure we document these use cases and scenarios as part of the 2.10.0 release.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@HenryL27 could you please explain
Trying to use memory anyway with that disabled will throw the feature flag error.
--> I assume customer will not see this error?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dhrubo-os this is the error we throw:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the confusion. I'm trying to understand it step by step:
ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED
OpenSearchException
to enable the other feature flag?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It won't. You will only get the error if you pass it a conversationId.