Use Search Pipeline processors, Remote Inference and HttpConnector to enable Retrieval Augmented Generation (RAG) (#1195)

* Use Search Pipeline processors, Remote Inference and HttpConnector to
enable Retrieval Augmented Generation (RAG) (#1150)

Signed-off-by: Austin Lee <[email protected]>

* Address test coverage.

Signed-off-by: Austin Lee <[email protected]>

* Fix/update imports due to changes coming from core.

Signed-off-by: Austin Lee <[email protected]>

* Update license header.

Signed-off-by: Austin Lee <[email protected]>

* Address comments.

Signed-off-by: Austin Lee <[email protected]>

* Use List for context fields so we can pull contexts from multiple fields when constructing contexts for LLMs.

Signed-off-by: Austin Lee <[email protected]>

* Address review comments.

Signed-off-by: Austin Lee <[email protected]>

* Fix spotless issue.

Signed-off-by: Austin Lee <[email protected]>

* Update README.

Signed-off-by: Austin Lee <[email protected]>

* Fix ml-client shadowJar implicit dependency issue.

Signed-off-by: Austin Lee <[email protected]>

* Add a wrapper client for ML predict.

Signed-off-by: Austin Lee <[email protected]>

* Add tests for the internal ML client.

Signed-off-by: Austin Lee <[email protected]>

---------

Signed-off-by: Austin Lee <[email protected]>
Signed-off-by: Austin Lee <[email protected]>
austintlee authored Sep 1, 2023
1 parent da7bdc9 commit 0c39993
Showing 32 changed files with 2,023 additions and 1 deletion.
1 change: 1 addition & 0 deletions plugin/build.gradle
@@ -44,6 +44,7 @@ opensearchplugin {
dependencies {
implementation project(':opensearch-ml-common')
implementation project(':opensearch-ml-algorithms')
implementation project(':opensearch-ml-search-processors')
implementation project(':opensearch-ml-memory')

implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
@@ -184,10 +184,19 @@
import org.opensearch.monitor.os.OsService;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPipelinePlugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.rest.RestController;
import org.opensearch.rest.RestHandler;
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQARequestProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAResponseProcessor;
import org.opensearch.searchpipelines.questionanswering.generative.ext.GenerativeQAParamExtBuilder;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
@@ -197,7 +206,7 @@

import lombok.SneakyThrows;

public class MachineLearningPlugin extends Plugin implements ActionPlugin {
public class MachineLearningPlugin extends Plugin implements ActionPlugin, SearchPlugin, SearchPipelinePlugin {
public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons.";
public static final String GENERAL_THREAD_POOL = "opensearch_ml_general";
public static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute";
@@ -649,4 +658,26 @@ public List<Setting<?>> getSettings() {
);
return settings;
}

@Override
public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
return List
.of(
new SearchPlugin.SearchExtSpec<>(
GenerativeQAParamExtBuilder.PARAMETER_NAME,
input -> new GenerativeQAParamExtBuilder(input),
parser -> GenerativeQAParamExtBuilder.parse(parser)
)
);
}

@Override
public Map<String, Processor.Factory<SearchRequestProcessor>> getRequestProcessors(Parameters parameters) {
return Map.of(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, new GenerativeQARequestProcessor.Factory());
}

@Override
public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProcessors(Parameters parameters) {
return Map.of(GenerativeQAProcessorConstants.RESPONSE_PROCESSOR_TYPE, new GenerativeQAResponseProcessor.Factory(this.client));
}
}
95 changes: 95 additions & 0 deletions search-processors/README.md
@@ -0,0 +1,95 @@
# conversational-search-processors

OpenSearch search processors providing conversational search capabilities. This repo is a WIP plugin for handling conversations in OpenSearch ([per this RFC](https://github.com/opensearch-project/ml-commons/issues/1150)).

Conversational Retrieval Augmented Generation (RAG) is implemented via search processors that combine user questions and OpenSearch query results as input to an LLM (e.g. OpenAI) and return the generated answers.

## Creating a search pipeline with the GenerativeQAResponseProcessor

```
PUT /_search/pipeline/<search pipeline name>
{
"response_processors": [
{
"retrieval_augmented_generation": {
"tag": <tag>,
"description": <description>,
"model_id": "<model_id>",
"context_field_list": [<field>] (e.g. ["text"])
}
}
]
}
```

The `model_id` parameter must refer to a model of type REMOTE that has an HttpConnector instance associated with it.
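
For reference, setting up such a model typically means creating an HTTP connector and then registering and deploying a REMOTE model that points at it. The sketch below is illustrative only: it assumes the ml-commons connector and model registration APIs and an OpenAI endpoint, and the connector name, endpoint, credential key, and request body template are example values rather than anything defined by this plugin.

```
# 1. Create an HTTP connector to the LLM endpoint (example values)
POST /_plugins/_ml/connectors/_create
{
  "name": "OpenAI chat connector",
  "description": "Connector to the OpenAI chat completions API",
  "version": 1,
  "protocol": "http",
  "parameters": {
    "endpoint": "api.openai.com",
    "model": "gpt-3.5-turbo"
  },
  "credential": {
    "openAI_key": "<your API key>"
  },
  "actions": [
    {
      "action_type": "predict",
      "method": "POST",
      "url": "https://api.openai.com/v1/chat/completions",
      "headers": {
        "Authorization": "Bearer ${credential.openAI_key}"
      },
      "request_body": "{ \"model\": \"${parameters.model}\", \"messages\": ${parameters.messages} }"
    }
  ]
}

# 2. Register and deploy a REMOTE model backed by the connector
POST /_plugins/_ml/models/_register
{
  "name": "openai-gpt-3.5-turbo",
  "function_name": "remote",
  "description": "OpenAI chat model for RAG",
  "connector_id": "<connector_id returned by the previous call>"
}

POST /_plugins/_ml/models/<model_id>/_deploy
```

The `model_id` returned by the register/deploy calls is the value to use in the pipeline definition above.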

## Making a search request against an index using the above processor
```
GET /<index>/_search?search_pipeline=<search pipeline name>
{
"_source": ["title", "text"],
"query" : {
"neural": {
"text_vector": {
"query_text": <query string>,
"k": <integer> (e.g. 10),
"model_id": <model_id>
}
}
},
"ext": {
"generative_qa_parameters": {
"llm_model": <LLM model> (e.g. "gpt-3.5-turbo"),
"llm_question": <question string>
}
}
}
```
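
Alternatively, if your OpenSearch version supports default search pipelines, the pipeline can be attached to the index so it does not have to be passed on every request. This is a minimal sketch; `index.search.default_pipeline` is a core OpenSearch setting, not something added by this plugin.

```
PUT /<index>/_settings
{
  "index.search.default_pipeline": "<search pipeline name>"
}
```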

## Retrieval Augmented Generation response
```
{
"took": 3,
"timed_out": false,
"_shards": {
"total": 3,
"successful": 3,
"skipped": 0,
"failed": 0
},
"hits": {
"total": {
"value": 110,
"relation": "eq"
},
"max_score": 0.55129033,
"hits": [
{
"_index": "...",
"_id": "...",
"_score": 0.55129033,
"_source": {
"text": "...",
"title": "..."
}
},
{
...
}
...
{
...
}
]
}, // end of hits
"ext": {
"retrieval_augmented_generation": {
"answer": "..."
}
}
}
```
The RAG answer is returned in the "ext" section of the SearchResponse, following the "hits" array.
74 changes: 74 additions & 0 deletions search-processors/build.gradle
@@ -0,0 +1,74 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
plugins {
id 'java'
id 'jacoco'
id "io.freefair.lombok"
}

repositories {
mavenCentral()
mavenLocal()
}

dependencies {

compileOnly group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
implementation 'org.apache.commons:commons-lang3:3.12.0'
//implementation project(':opensearch-ml-client')
implementation project(':opensearch-ml-common')
implementation group: 'org.opensearch', name: 'common-utils', version: "${common_utils_version}"
// https://mvnrepository.com/artifact/org.apache.httpcomponents.core5/httpcore5
implementation group: 'org.apache.httpcomponents.core5', name: 'httpcore5', version: '5.2.1'
implementation("com.google.guava:guava:32.0.1-jre")
implementation group: 'org.json', name: 'json', version: '20230227'
implementation group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
testImplementation "org.opensearch.test:framework:${opensearch_version}"
}

test {
include '**/*Tests.class'
systemProperty 'tests.security.manager', 'false'
}

jacocoTestReport {
dependsOn /*integTest,*/ test
reports {
xml.required = true
html.required = true
}
}

jacocoTestCoverageVerification {
violationRules {
rule {
limit {
counter = 'LINE'
minimum = 0.65 //TODO: increase coverage to 0.90
}
limit {
counter = 'BRANCH'
minimum = 0.55 //TODO: increase coverage to 0.85
}
}
}
dependsOn jacocoTestReport
}

check.dependsOn jacocoTestCoverageVerification
//jacocoTestCoverageVerification.dependsOn jacocoTestReport
@@ -0,0 +1,36 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.searchpipelines.questionanswering.generative;

public class GenerativeQAProcessorConstants {

// Identifier for the generative QA request processor
public static final String REQUEST_PROCESSOR_TYPE = "question_rewrite";

// Identifier for the generative QA response processor
public static final String RESPONSE_PROCESSOR_TYPE = "retrieval_augmented_generation";

// The model_id of the model registered and deployed in OpenSearch.
public static final String CONFIG_NAME_MODEL_ID = "model_id";

// The name of the LLM model to use, e.g. "gpt-3.5-turbo" in OpenAI.
public static final String CONFIG_NAME_LLM_MODEL = "llm_model";

// The fields in search results that contain the contexts to be sent to the LLM.
public static final String CONFIG_NAME_CONTEXT_FIELD_LIST = "context_field_list";
}
@@ -0,0 +1,69 @@
/*
* Copyright 2023 Aryn
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.opensearch.searchpipelines.questionanswering.generative;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

import java.util.Map;

/**
* Defines the request processor for generative QA search pipelines.
*/
public class GenerativeQARequestProcessor extends AbstractProcessor implements SearchRequestProcessor {

private String modelId;

protected GenerativeQARequestProcessor(String tag, String description, boolean ignoreFailure, String modelId) {
super(tag, description, ignoreFailure);
this.modelId = modelId;
}

@Override
public SearchRequest processRequest(SearchRequest request) throws Exception {

// TODO Use chat history to rephrase the question with full conversation context.

return request;
}

@Override
public String getType() {
return GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE;
}

public static final class Factory implements Processor.Factory<SearchRequestProcessor> {

@Override
public SearchRequestProcessor create(
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) throws Exception {
return new GenerativeQARequestProcessor(tag, description, ignoreFailure,
ConfigurationUtils.readStringProperty(GenerativeQAProcessorConstants.REQUEST_PROCESSOR_TYPE, tag, config, GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID)
);
}
}
}
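
For completeness, a request processor of this type would be registered in a pipeline under the `question_rewrite` key, mirroring the response processor configuration shown in the README. This is a minimal sketch; per the TODO above, the processor currently passes the request through unchanged.

```
PUT /_search/pipeline/<search pipeline name>
{
  "request_processors": [
    {
      "question_rewrite": {
        "tag": <tag>,
        "description": <description>,
        "model_id": "<model_id>"
      }
    }
  ]
}
```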
