From 89f9b850525c8f5708e4ea13b4b168757df67408 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 11 Oct 2023 16:25:08 -0700 Subject: [PATCH] =?UTF-8?q?Fix=20prompt=20passing=20for=20Bedrock=20by=20p?= =?UTF-8?q?assing=20a=20single=20string=20prompt=20for=20=E2=80=A6=20(#149?= =?UTF-8?q?0)=20(#1497)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix prompt passing for Bedrock by passing a single string prompt for Bedrock models. (https://github.com/opensearch-project/ml-commons/issues/1476) Signed-off-by: Austin Lee * Add unit tests, apply Spotless. Signed-off-by: Austin Lee * Check if systemPrompt is null. Signed-off-by: Austin Lee * Address review comments. Signed-off-by: Austin Lee --------- Signed-off-by: Austin Lee (cherry picked from commit e18f2499582483e0a88f8336c6e078d41143c2f4) Co-authored-by: Austin Lee --- search-processors/README.md | 4 + .../generative/llm/ChatCompletionInput.java | 1 + .../generative/llm/DefaultLlmImpl.java | 98 ++++++++++++++----- .../questionanswering/generative/llm/Llm.java | 6 ++ .../generative/llm/LlmIOUtil.java | 18 +++- .../generative/prompt/PromptUtil.java | 56 ++++++++++- .../llm/ChatCompletionInputTests.java | 14 ++- .../generative/llm/DefaultLlmImplTests.java | 68 ++++++++++++- .../generative/llm/LlmIOUtilTests.java | 6 ++ .../generative/prompt/PromptUtilTests.java | 40 ++++++++ 10 files changed, 281 insertions(+), 30 deletions(-) diff --git a/search-processors/README.md b/search-processors/README.md index 2b3dc6ed52..030691855b 100644 --- a/search-processors/README.md +++ b/search-processors/README.md @@ -49,6 +49,10 @@ GET //_search\?search_pipeline\= } ``` +To use this with Bedrock models, use "bedrock/" as a prefix for the "llm_model" parameters, e.g. "bedrock/anthropic". + +The latest RAG processor has been tested with OpenAI's GPT 3.5 and 4 models and Bedrock's Anthropic Claude (v2) model only. + ## Retrieval Augmented Generation response ``` { diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java index 85e1173875..3b9c829706 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java @@ -42,4 +42,5 @@ public class ChatCompletionInput { private int timeoutInSeconds; private String systemPrompt; private String userInstructions; + private Llm.ModelProvider modelProvider; } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java index beef67b9e9..9fbb96a1b7 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java @@ -63,7 +63,7 @@ public DefaultLlmImpl(String openSearchModelId, Client client) { } @VisibleForTesting - void setMlClient(MachineLearningInternalClient mlClient) { + protected void setMlClient(MachineLearningInternalClient mlClient) { this.mlClient = mlClient; } @@ -76,19 +76,7 @@ void setMlClient(MachineLearningInternalClient mlClient) { @Override public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) { - Map inputParameters = new HashMap<>(); - inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); - String messages = PromptUtil - .getChatCompletionPrompt( - chatCompletionInput.getSystemPrompt(), - chatCompletionInput.getUserInstructions(), - chatCompletionInput.getQuestion(), - chatCompletionInput.getChatHistory(), - chatCompletionInput.getContexts() - ); - inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); - log.info("Messages to LLM: {}", messages); - MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(inputParameters).build(); + MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build(); ActionFuture future = mlClient.predict(this.openSearchModelId, mlInput); ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000); @@ -99,19 +87,83 @@ public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionI // TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases. - List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES); + return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap); + } + + protected Map getInputParameters(ChatCompletionInput chatCompletionInput) { + Map inputParameters = new HashMap<>(); + + if (chatCompletionInput.getModelProvider() == ModelProvider.OPENAI) { + inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); + String messages = PromptUtil + .getChatCompletionPrompt( + chatCompletionInput.getSystemPrompt(), + chatCompletionInput.getUserInstructions(), + chatCompletionInput.getQuestion(), + chatCompletionInput.getChatHistory(), + chatCompletionInput.getContexts() + ); + inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); + log.info("Messages to LLM: {}", messages); + } else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK) { + inputParameters + .put( + "inputs", + PromptUtil + .buildSingleStringPrompt( + chatCompletionInput.getSystemPrompt(), + chatCompletionInput.getUserInstructions(), + chatCompletionInput.getQuestion(), + chatCompletionInput.getChatHistory(), + chatCompletionInput.getContexts() + ) + ); + } else { + throw new IllegalArgumentException("Unknown/unsupported model provider: " + chatCompletionInput.getModelProvider()); + } + + log.info("LLM input parameters: {}", inputParameters.toString()); + return inputParameters; + } + + protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map dataAsMap) { + List answers = null; List errors = null; - if (choices == null) { - Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR); - errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE)); + + if (provider == ModelProvider.OPENAI) { + List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES); + if (choices == null) { + Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR); + errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE)); + } else { + Map firstChoiceMap = (Map) choices.get(0); + log.info("Choices: {}", firstChoiceMap.toString()); + Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE); + log + .info( + "role: {}, content: {}", + message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), + message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT) + ); + answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); + } + } else if (provider == ModelProvider.BEDROCK) { + String response = (String) dataAsMap.get("completion"); + if (response != null) { + answers = List.of(response); + } else { + Map error = (Map) dataAsMap.get("error"); + if (error != null) { + errors = List.of((String) error.get("message")); + } else { + errors = List.of("Unknown error or response."); + } + } } else { - Map firstChoiceMap = (Map) choices.get(0); - log.info("Choices: {}", firstChoiceMap.toString()); - Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE); - log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); - answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); + throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider); } + return new ChatCompletionOutput(answers, errors); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java index e850561066..be5efdb294 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java @@ -22,5 +22,11 @@ */ public interface Llm { + // TODO Ensure the current implementation works with all models supported by Bedrock. + enum ModelProvider { + OPENAI, + BEDROCK + } + ChatCompletionOutput doChatCompletion(ChatCompletionInput input); } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java index fb95ed63bf..0a84858557 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java @@ -27,6 +27,8 @@ */ public class LlmIOUtil { + private static final String BEDROCK_PROVIDER_PREFIX = "bedrock/"; + public static ChatCompletionInput createChatCompletionInput( String llmModel, String question, @@ -57,7 +59,19 @@ public static ChatCompletionInput createChatCompletionInput( List contexts, int timeoutInSeconds ) { - - return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions); + Llm.ModelProvider provider = Llm.ModelProvider.OPENAI; + if (llmModel != null && llmModel.startsWith(BEDROCK_PROVIDER_PREFIX)) { + provider = Llm.ModelProvider.BEDROCK; + } + return new ChatCompletionInput( + llmModel, + question, + chatHistory, + contexts, + timeoutInSeconds, + systemPrompt, + userInstructions, + provider + ); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java index 9c57ffbf0f..3a8a21614e 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Locale; import org.apache.commons.text.StringEscapeUtils; import org.opensearch.core.common.Strings; @@ -54,6 +55,8 @@ public class PromptUtil { private static final String roleUser = "user"; + private static final String NEWLINE = "\\n"; + public static String getQuestionRephrasingPrompt(String originalQuestion, List chatHistory) { return null; } @@ -62,6 +65,8 @@ public static String getChatCompletionPrompt(String question, List return getChatCompletionPrompt(DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts); } + // TODO Currently, this is OpenAI specific. Change this to indicate as such or address it as part of + // future prompt template management work. public static String getChatCompletionPrompt( String systemPrompt, String userInstructions, @@ -87,6 +92,48 @@ enum ChatRole { } } + public static String buildSingleStringPrompt( + String systemPrompt, + String userInstructions, + String question, + List chatHistory, + List contexts + ) { + if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) { + systemPrompt = DEFAULT_SYSTEM_PROMPT; + } + + StringBuilder bldr = new StringBuilder(); + + if (!Strings.isNullOrEmpty(systemPrompt)) { + bldr.append(systemPrompt); + bldr.append(NEWLINE); + } + if (!Strings.isNullOrEmpty(userInstructions)) { + bldr.append(userInstructions); + bldr.append(NEWLINE); + } + + for (int i = 0; i < contexts.size(); i++) { + bldr.append("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i)); + bldr.append(NEWLINE); + } + if (!chatHistory.isEmpty()) { + // The oldest interaction first + List messages = Messages.fromInteractions(chatHistory).getMessages(); + Collections.reverse(messages); + messages.forEach(m -> { + bldr.append(m.toString()); + bldr.append(NEWLINE); + }); + + } + bldr.append("QUESTION: " + question); + bldr.append(NEWLINE); + + return bldr.toString(); + } + @VisibleForTesting static String buildMessageParameter( String systemPrompt, @@ -110,7 +157,6 @@ static String buildMessageParameter( } if (!chatHistory.isEmpty()) { // The oldest interaction first - // Collections.reverse(chatHistory); List messages = Messages.fromInteractions(chatHistory).getMessages(); Collections.reverse(messages); messages.forEach(m -> messageArray.add(m.toJson())); @@ -163,6 +209,8 @@ public static Messages fromInteractions(final List interactions) { } } + // TODO This is OpenAI specific. Either change this to OpenAiMessage or have it handle + // vendor specific messages. static class Message { private final static String MESSAGE_FIELD_ROLE = "role"; @@ -186,6 +234,7 @@ public Message(ChatRole chatRole, String content) { } public void setChatRole(ChatRole chatRole) { + this.chatRole = chatRole; json.remove(MESSAGE_FIELD_ROLE); json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(chatRole.getName())); } @@ -199,5 +248,10 @@ public void setContent(String content) { public JsonObject toJson() { return json; } + + @Override + public String toString() { + return String.format(Locale.ROOT, "%s: %s", chatRole.getName(), content); + } } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java index 0e34dd0bf1..3be8ba8c49 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java @@ -41,7 +41,8 @@ public void testCtor() { Collections.emptyList(), 0, systemPrompt, - userInstructions + userInstructions, + Llm.ModelProvider.OPENAI ); assertNotNull(input); @@ -70,7 +71,16 @@ public void testGettersSetters() { ) ); List contexts = List.of("result1", "result2"); - ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions); + ChatCompletionInput input = new ChatCompletionInput( + model, + question, + history, + contexts, + 0, + systemPrompt, + userInstructions, + Llm.ModelProvider.OPENAI + ); assertEquals(model, input.getModel()); assertEquals(question, input.getQuestion()); assertEquals(history.get(0).getConversationId(), input.getChatHistory().get(0).getConversationId()); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java index 218bd65ec9..ae402e0ee3 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -111,7 +111,38 @@ public void testChatCompletionApi() throws Exception { Collections.emptyList(), 0, "prompt", - "instructions" + "instructions", + Llm.ModelProvider.OPENAI + ); + ChatCompletionOutput output = connector.doChatCompletion(input); + verify(mlClient, times(1)).predict(any(), captor.capture()); + MLInput mlInput = captor.getValue(); + assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); + assertEquals("answer", (String) output.getAnswers().get(0)); + } + + public void testChatCompletionApiForBedrock() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + Map messageMap = Map.of("role", "agent", "content", "answer"); + Map dataAsMap = Map.of("completion", "answer"); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet(anyLong())).thenReturn(mlOutput); + when(mlClient.predict(any(), any())).thenReturn(future); + ChatCompletionInput input = new ChatCompletionInput( + "bedrock/model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions", + Llm.ModelProvider.BEDROCK ); ChatCompletionOutput output = connector.doChatCompletion(input); verify(mlClient, times(1)).predict(any(), captor.capture()); @@ -141,7 +172,40 @@ public void testChatCompletionThrowingError() throws Exception { Collections.emptyList(), 0, "prompt", - "instructions" + "instructions", + Llm.ModelProvider.OPENAI + ); + ChatCompletionOutput output = connector.doChatCompletion(input); + verify(mlClient, times(1)).predict(any(), captor.capture()); + MLInput mlInput = captor.getValue(); + assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); + assertTrue(output.isErrorOccurred()); + assertEquals(errorMessage, (String) output.getErrors().get(0)); + } + + public void testChatCompletionBedrockThrowingError() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + String errorMessage = "throttled"; + Map messageMap = Map.of("message", errorMessage); + Map dataAsMap = Map.of("error", messageMap); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet(anyLong())).thenReturn(mlOutput); + when(mlClient.predict(any(), any())).thenReturn(future); + ChatCompletionInput input = new ChatCompletionInput( + "model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions", + Llm.ModelProvider.BEDROCK ); ChatCompletionOutput output = connector.doChatCompletion(input); verify(mlClient, times(1)).predict(any(), captor.capture()); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java index 41d44f18ca..a2a34db27b 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java @@ -32,4 +32,10 @@ public void testChatCompletionInput() { .createChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList(), 0); assertTrue(input instanceof ChatCompletionInput); } + + public void testChatCompletionInputForBedrock() { + ChatCompletionInput input = LlmIOUtil + .createChatCompletionInput("bedrock/model", "question", Collections.emptyList(), Collections.emptyList(), 0); + assertTrue(input instanceof ChatCompletionInput); + } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java index 583ab17149..a3aedf4e5d 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java @@ -77,6 +77,46 @@ public void testBuildMessageParameter() { assertTrue(isJson(parameter)); } + public void testBuildBedrockInputParameter() { + String systemPrompt = "You are the best."; + String userInstructions = null; + String question = "Who am I"; + List contexts = new ArrayList<>(); + List chatHistory = List + .of( + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer1" + ) + ), + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 2", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer2" + ) + ) + ); + contexts.add("context 1"); + contexts.add("context 2"); + String parameter = PromptUtil.buildSingleStringPrompt(systemPrompt, userInstructions, question, chatHistory, contexts); + assertTrue(parameter.contains(systemPrompt)); + } + private boolean isJson(String Json) { try { new JSONObject(Json);