From 91582daedc74387047e08ab5eef2e00e80aa0417 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Wed, 11 Oct 2023 13:39:01 -0700 Subject: [PATCH 1/4] Fix prompt passing for Bedrock by passing a single string prompt for Bedrock models. (https://github.com/opensearch-project/ml-commons/issues/1476) Signed-off-by: Austin Lee --- .../generative/llm/ChatCompletionInput.java | 3 + .../generative/llm/DefaultLlmImpl.java | 80 +++++++++++++------ .../questionanswering/generative/llm/Llm.java | 5 ++ .../generative/llm/LlmIOUtil.java | 17 +++- .../generative/prompt/PromptUtil.java | 50 ++++++++++++ .../llm/ChatCompletionInputTests.java | 5 +- .../generative/llm/DefaultLlmImplTests.java | 6 +- 7 files changed, 137 insertions(+), 29 deletions(-) diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java index 85e1173875..61e7ecae76 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java @@ -18,7 +18,9 @@ package org.opensearch.searchpipelines.questionanswering.generative.llm; import java.util.List; +import java.util.Map; +import lombok.Builder; import org.opensearch.ml.common.conversation.Interaction; import lombok.AllArgsConstructor; @@ -42,4 +44,5 @@ public class ChatCompletionInput { private int timeoutInSeconds; private String systemPrompt; private String userInstructions; + private Llm.ModelProvider modelProvider; } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java index beef67b9e9..45b46f2a1a 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java @@ -63,7 +63,7 @@ public DefaultLlmImpl(String openSearchModelId, Client client) { } @VisibleForTesting - void setMlClient(MachineLearningInternalClient mlClient) { + protected void setMlClient(MachineLearningInternalClient mlClient) { this.mlClient = mlClient; } @@ -76,19 +76,7 @@ void setMlClient(MachineLearningInternalClient mlClient) { @Override public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) { - Map inputParameters = new HashMap<>(); - inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); - String messages = PromptUtil - .getChatCompletionPrompt( - chatCompletionInput.getSystemPrompt(), - chatCompletionInput.getUserInstructions(), - chatCompletionInput.getQuestion(), - chatCompletionInput.getChatHistory(), - chatCompletionInput.getContexts() - ); - inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); - log.info("Messages to LLM: {}", messages); - MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(inputParameters).build(); + MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build(); ActionFuture 
future = mlClient.predict(this.openSearchModelId, mlInput); ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000); @@ -99,19 +87,65 @@ public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionI // TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases. - List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES); + return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap); + } + + protected Map getInputParameters(ChatCompletionInput chatCompletionInput) { + Map inputParameters = new HashMap<>(); + + if (chatCompletionInput.getModelProvider() == ModelProvider.OPENAI) { + inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); + String messages = PromptUtil.getChatCompletionPrompt( + chatCompletionInput.getSystemPrompt(), + chatCompletionInput.getUserInstructions(), + chatCompletionInput.getQuestion(), + chatCompletionInput.getChatHistory(), + chatCompletionInput.getContexts() + ); + inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); + log.info("Messages to LLM: {}", messages); + } else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK) { + inputParameters.put("inputs", PromptUtil.buildSingleStringPrompt(chatCompletionInput.getSystemPrompt(), + chatCompletionInput.getUserInstructions(), + chatCompletionInput.getQuestion(), + chatCompletionInput.getChatHistory(), + chatCompletionInput.getContexts())); + } else { + throw new IllegalArgumentException("Unknown/unsupported model provider: " + chatCompletionInput.getModelProvider()); + } + + log.info("LLM input parameters: {}", inputParameters.toString()); + return inputParameters; + } + + protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map dataAsMap) { + List answers = null; List errors = null; - if (choices == null) { - Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR); - errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE)); + + if (provider == ModelProvider.OPENAI) { + List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES); + if (choices == null) { + Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR); + errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE)); + } else { + Map firstChoiceMap = (Map) choices.get(0); + log.info("Choices: {}", firstChoiceMap.toString()); + Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE); + log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); + answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); + } + } else if (provider == ModelProvider.BEDROCK) { + String response = (String) dataAsMap.get("completion"); + if (response != null) { + answers = List.of(response); + } else { + // Error + } } else { - Map firstChoiceMap = (Map) choices.get(0); - log.info("Choices: {}", firstChoiceMap.toString()); - Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE); - log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); - answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); + throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider); } + return new ChatCompletionOutput(answers, errors); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java 
b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java index e850561066..faf136d550 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java @@ -22,5 +22,10 @@ */ public interface Llm { + enum ModelProvider { + OPENAI, + BEDROCK + } + ChatCompletionOutput doChatCompletion(ChatCompletionInput input); } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java index fb95ed63bf..b8fcf48096 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java @@ -17,7 +17,9 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; @@ -27,6 +29,14 @@ */ public class LlmIOUtil { + private static final String CONNECTOR_INPUT_PARAMETER_MODEL = "model"; + private static final String CONNECTOR_INPUT_PARAMETER_MESSAGES = "messages"; + private static final String CONNECTOR_OUTPUT_CHOICES = "choices"; + private static final String CONNECTOR_OUTPUT_MESSAGE = "message"; + private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role"; + private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content"; + private static final String CONNECTOR_OUTPUT_ERROR = "error"; + public static ChatCompletionInput createChatCompletionInput( String llmModel, String question, @@ -57,7 +67,10 @@ public static ChatCompletionInput createChatCompletionInput( List contexts, int timeoutInSeconds ) { - - return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions); + Llm.ModelProvider provider = Llm.ModelProvider.OPENAI; + if (llmModel != null && llmModel.startsWith("bedrock/")) { + provider = Llm.ModelProvider.BEDROCK; + } + return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions, provider); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java index 9c57ffbf0f..f38e694c78 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Locale; import org.apache.commons.text.StringEscapeUtils; import org.opensearch.core.common.Strings; @@ -62,6 +63,8 @@ public static String getChatCompletionPrompt(String question, List return getChatCompletionPrompt(DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts); } + // TODO Currently, this is OpenAI specific. 
Change this to indicate as such or address it as part of + // future prompt template management work. public static String getChatCompletionPrompt( String systemPrompt, String userInstructions, @@ -87,6 +90,46 @@ enum ChatRole { } } + static final String NEWLINE = "\\n"; + + public static String buildSingleStringPrompt ( + String systemPrompt, + String userInstructions, + String question, + List chatHistory, + List contexts + ) { + if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) { + systemPrompt = DEFAULT_SYSTEM_PROMPT; + } + + StringBuilder bldr = new StringBuilder(); + bldr.append(systemPrompt); + bldr.append(NEWLINE); + bldr.append(userInstructions); + bldr.append(NEWLINE); + + for (int i = 0; i < contexts.size(); i++) { + bldr.append("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i)); + bldr.append(NEWLINE); + } + if (!chatHistory.isEmpty()) { + // The oldest interaction first + // Collections.reverse(chatHistory); + List messages = Messages.fromInteractions(chatHistory).getMessages(); + Collections.reverse(messages); + messages.forEach(m -> { + bldr.append(m.toString()); + bldr.append(NEWLINE); + }); + + } + bldr.append("QUESTION: " + question); + bldr.append(NEWLINE); + + return bldr.toString(); + } + @VisibleForTesting static String buildMessageParameter( String systemPrompt, @@ -163,6 +206,8 @@ public static Messages fromInteractions(final List interactions) { } } + // TODO This is OpenAI specific. Either change this to OpenAiMessage or have it handle + // vendor specific messages. static class Message { private final static String MESSAGE_FIELD_ROLE = "role"; @@ -199,5 +244,10 @@ public void setContent(String content) { public JsonObject toJson() { return json; } + + @Override + public String toString() { + return String.format(Locale.ROOT, "%s: %s", chatRole.getName(), content); + } } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java index 0e34dd0bf1..403291f27c 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java @@ -41,7 +41,8 @@ public void testCtor() { Collections.emptyList(), 0, systemPrompt, - userInstructions + userInstructions, + Llm.ModelProvider.OPENAI ); assertNotNull(input); @@ -70,7 +71,7 @@ public void testGettersSetters() { ) ); List contexts = List.of("result1", "result2"); - ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions); + ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions, Llm.ModelProvider.OPENAI); assertEquals(model, input.getModel()); assertEquals(question, input.getQuestion()); assertEquals(history.get(0).getConversationId(), input.getChatHistory().get(0).getConversationId()); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java index 218bd65ec9..551a0e68bc 100644 --- 
a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -111,7 +111,8 @@ public void testChatCompletionApi() throws Exception { Collections.emptyList(), 0, "prompt", - "instructions" + "instructions", + Llm.ModelProvider.OPENAI ); ChatCompletionOutput output = connector.doChatCompletion(input); verify(mlClient, times(1)).predict(any(), captor.capture()); @@ -141,7 +142,8 @@ public void testChatCompletionThrowingError() throws Exception { Collections.emptyList(), 0, "prompt", - "instructions" + "instructions", + Llm.ModelProvider.OPENAI ); ChatCompletionOutput output = connector.doChatCompletion(input); verify(mlClient, times(1)).predict(any(), captor.capture()); From e0b85ba08991c27b45a381a1549abd3f0834ebeb Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Wed, 11 Oct 2023 14:38:16 -0700 Subject: [PATCH 2/4] Add unit tests, apply Spotless. Signed-off-by: Austin Lee --- search-processors/README.md | 2 + .../generative/llm/ChatCompletionInput.java | 2 - .../generative/llm/DefaultLlmImpl.java | 48 +++++++++----- .../generative/llm/LlmIOUtil.java | 23 +++---- .../generative/prompt/PromptUtil.java | 55 ++++++++-------- .../llm/ChatCompletionInputTests.java | 11 +++- .../generative/llm/DefaultLlmImplTests.java | 62 +++++++++++++++++++ .../generative/llm/LlmIOUtilTests.java | 6 ++ .../generative/prompt/PromptUtilTests.java | 40 ++++++++++++ 9 files changed, 193 insertions(+), 56 deletions(-) diff --git a/search-processors/README.md b/search-processors/README.md index 2b3dc6ed52..7bb8572b41 100644 --- a/search-processors/README.md +++ b/search-processors/README.md @@ -49,6 +49,8 @@ GET //_search\?search_pipeline\= } ``` +To use this with Bedrock models, use "bedrock/" as a prefix for the "llm_model" parameters, e.g. "bedrock/anthropic". 
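For illustration, a search request that targets a Bedrock model then carries the prefixed name in its `generative_qa_parameters` block. This is a minimal sketch, assuming the same pipeline setup and parameter names as the README example above; the exact parameter set may differ by version:

```
GET /<index>/_search?search_pipeline=<pipeline_name>
{
  "query": { "match": { "text": "..." } },
  "ext": {
    "generative_qa_parameters": {
      "llm_model": "bedrock/anthropic",
      "llm_question": "Who am I?"
    }
  }
}
```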
+ ## Retrieval Augmented Generation response ``` { diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java index 61e7ecae76..3b9c829706 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInput.java @@ -18,9 +18,7 @@ package org.opensearch.searchpipelines.questionanswering.generative.llm; import java.util.List; -import java.util.Map; -import lombok.Builder; import org.opensearch.ml.common.conversation.Interaction; import lombok.AllArgsConstructor; diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java index 45b46f2a1a..9fbb96a1b7 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java @@ -95,27 +95,35 @@ protected Map getInputParameters(ChatCompletionInput chatComplet if (chatCompletionInput.getModelProvider() == ModelProvider.OPENAI) { inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel()); - String messages = PromptUtil.getChatCompletionPrompt( - chatCompletionInput.getSystemPrompt(), - chatCompletionInput.getUserInstructions(), - chatCompletionInput.getQuestion(), - chatCompletionInput.getChatHistory(), - chatCompletionInput.getContexts() - ); + String messages = PromptUtil + .getChatCompletionPrompt( + chatCompletionInput.getSystemPrompt(), + chatCompletionInput.getUserInstructions(), + chatCompletionInput.getQuestion(), + chatCompletionInput.getChatHistory(), + chatCompletionInput.getContexts() + ); inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages); log.info("Messages to LLM: {}", messages); } else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK) { - inputParameters.put("inputs", PromptUtil.buildSingleStringPrompt(chatCompletionInput.getSystemPrompt(), - chatCompletionInput.getUserInstructions(), - chatCompletionInput.getQuestion(), - chatCompletionInput.getChatHistory(), - chatCompletionInput.getContexts())); + inputParameters + .put( + "inputs", + PromptUtil + .buildSingleStringPrompt( + chatCompletionInput.getSystemPrompt(), + chatCompletionInput.getUserInstructions(), + chatCompletionInput.getQuestion(), + chatCompletionInput.getChatHistory(), + chatCompletionInput.getContexts() + ) + ); } else { throw new IllegalArgumentException("Unknown/unsupported model provider: " + chatCompletionInput.getModelProvider()); } log.info("LLM input parameters: {}", inputParameters.toString()); - return inputParameters; + return inputParameters; } protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map dataAsMap) { @@ -132,7 +140,12 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map firstChoiceMap = (Map) choices.get(0); log.info("Choices: {}", firstChoiceMap.toString()); Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE); - log.info("role: {}, content: {}", 
message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); + log + .info( + "role: {}, content: {}", + message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), + message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT) + ); answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)); } } else if (provider == ModelProvider.BEDROCK) { @@ -140,7 +153,12 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, if (response != null) { answers = List.of(response); } else { - // Error + Map error = (Map) dataAsMap.get("error"); + if (error != null) { + errors = List.of((String) error.get("message")); + } else { + errors = List.of("Unknown error or response."); + } } } else { throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider); diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java index b8fcf48096..0a84858557 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtil.java @@ -17,9 +17,7 @@ */ package org.opensearch.searchpipelines.questionanswering.generative.llm; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.opensearch.ml.common.conversation.Interaction; import org.opensearch.searchpipelines.questionanswering.generative.prompt.PromptUtil; @@ -29,13 +27,7 @@ */ public class LlmIOUtil { - private static final String CONNECTOR_INPUT_PARAMETER_MODEL = "model"; - private static final String CONNECTOR_INPUT_PARAMETER_MESSAGES = "messages"; - private static final String CONNECTOR_OUTPUT_CHOICES = "choices"; - private static final String CONNECTOR_OUTPUT_MESSAGE = "message"; - private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role"; - private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content"; - private static final String CONNECTOR_OUTPUT_ERROR = "error"; + private static final String BEDROCK_PROVIDER_PREFIX = "bedrock/"; public static ChatCompletionInput createChatCompletionInput( String llmModel, @@ -68,9 +60,18 @@ public static ChatCompletionInput createChatCompletionInput( int timeoutInSeconds ) { Llm.ModelProvider provider = Llm.ModelProvider.OPENAI; - if (llmModel != null && llmModel.startsWith("bedrock/")) { + if (llmModel != null && llmModel.startsWith(BEDROCK_PROVIDER_PREFIX)) { provider = Llm.ModelProvider.BEDROCK; } - return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions, provider); + return new ChatCompletionInput( + llmModel, + question, + chatHistory, + contexts, + timeoutInSeconds, + systemPrompt, + userInstructions, + provider + ); } } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java index f38e694c78..c494e17b89 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -55,6 +55,8 @@ public class PromptUtil { private static final String roleUser = 
"user"; + private static final String NEWLINE = "\\n"; + public static String getQuestionRephrasingPrompt(String originalQuestion, List chatHistory) { return null; } @@ -63,8 +65,8 @@ public static String getChatCompletionPrompt(String question, List return getChatCompletionPrompt(DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts); } - // TODO Currently, this is OpenAI specific. Change this to indicate as such or address it as part of - // future prompt template management work. + // TODO Currently, this is OpenAI specific. Change this to indicate as such or address it as part of + // future prompt template management work. public static String getChatCompletionPrompt( String systemPrompt, String userInstructions, @@ -90,9 +92,7 @@ enum ChatRole { } } - static final String NEWLINE = "\\n"; - - public static String buildSingleStringPrompt ( + public static String buildSingleStringPrompt( String systemPrompt, String userInstructions, String question, @@ -103,31 +103,32 @@ public static String buildSingleStringPrompt ( systemPrompt = DEFAULT_SYSTEM_PROMPT; } - StringBuilder bldr = new StringBuilder(); - bldr.append(systemPrompt); - bldr.append(NEWLINE); + StringBuilder bldr = new StringBuilder(); + bldr.append(systemPrompt); + bldr.append(NEWLINE); + if (!Strings.isNullOrEmpty(userInstructions)) { bldr.append(userInstructions); bldr.append(NEWLINE); + } - for (int i = 0; i < contexts.size(); i++) { - bldr.append("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i)); + for (int i = 0; i < contexts.size(); i++) { + bldr.append("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i)); + bldr.append(NEWLINE); + } + if (!chatHistory.isEmpty()) { + // The oldest interaction first + List messages = Messages.fromInteractions(chatHistory).getMessages(); + Collections.reverse(messages); + messages.forEach(m -> { + bldr.append(m.toString()); bldr.append(NEWLINE); - } - if (!chatHistory.isEmpty()) { - // The oldest interaction first - // Collections.reverse(chatHistory); - List messages = Messages.fromInteractions(chatHistory).getMessages(); - Collections.reverse(messages); - messages.forEach(m -> { - bldr.append(m.toString()); - bldr.append(NEWLINE); - }); + }); - } - bldr.append("QUESTION: " + question); - bldr.append(NEWLINE); + } + bldr.append("QUESTION: " + question); + bldr.append(NEWLINE); - return bldr.toString(); + return bldr.toString(); } @VisibleForTesting @@ -153,7 +154,6 @@ static String buildMessageParameter( } if (!chatHistory.isEmpty()) { // The oldest interaction first - // Collections.reverse(chatHistory); List messages = Messages.fromInteractions(chatHistory).getMessages(); Collections.reverse(messages); messages.forEach(m -> messageArray.add(m.toJson())); @@ -206,8 +206,8 @@ public static Messages fromInteractions(final List interactions) { } } - // TODO This is OpenAI specific. Either change this to OpenAiMessage or have it handle - // vendor specific messages. + // TODO This is OpenAI specific. Either change this to OpenAiMessage or have it handle + // vendor specific messages. 
static class Message { private final static String MESSAGE_FIELD_ROLE = "role"; @@ -231,6 +231,7 @@ public Message(ChatRole chatRole, String content) { } public void setChatRole(ChatRole chatRole) { + this.chatRole = chatRole; json.remove(MESSAGE_FIELD_ROLE); json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(chatRole.getName())); } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java index 403291f27c..3be8ba8c49 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/ChatCompletionInputTests.java @@ -71,7 +71,16 @@ public void testGettersSetters() { ) ); List contexts = List.of("result1", "result2"); - ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions, Llm.ModelProvider.OPENAI); + ChatCompletionInput input = new ChatCompletionInput( + model, + question, + history, + contexts, + 0, + systemPrompt, + userInstructions, + Llm.ModelProvider.OPENAI + ); assertEquals(model, input.getModel()); assertEquals(question, input.getQuestion()); assertEquals(history.get(0).getConversationId(), input.getChatHistory().get(0).getConversationId()); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java index 551a0e68bc..ae402e0ee3 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java @@ -121,6 +121,36 @@ public void testChatCompletionApi() throws Exception { assertEquals("answer", (String) output.getAnswers().get(0)); } + public void testChatCompletionApiForBedrock() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + Map messageMap = Map.of("role", "agent", "content", "answer"); + Map dataAsMap = Map.of("completion", "answer"); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet(anyLong())).thenReturn(mlOutput); + when(mlClient.predict(any(), any())).thenReturn(future); + ChatCompletionInput input = new ChatCompletionInput( + "bedrock/model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions", + Llm.ModelProvider.BEDROCK + ); + ChatCompletionOutput output = connector.doChatCompletion(input); + verify(mlClient, times(1)).predict(any(), captor.capture()); + MLInput mlInput = captor.getValue(); + assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); + assertEquals("answer", (String) output.getAnswers().get(0)); + 
} + public void testChatCompletionThrowingError() throws Exception { MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); @@ -153,6 +183,38 @@ public void testChatCompletionThrowingError() throws Exception { assertEquals(errorMessage, (String) output.getErrors().get(0)); } + public void testChatCompletionBedrockThrowingError() throws Exception { + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); + ArgumentCaptor captor = ArgumentCaptor.forClass(MLInput.class); + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client); + connector.setMlClient(mlClient); + + String errorMessage = "throttled"; + Map messageMap = Map.of("message", errorMessage); + Map dataAsMap = Map.of("error", messageMap); + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap); + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); + ActionFuture future = mock(ActionFuture.class); + when(future.actionGet(anyLong())).thenReturn(mlOutput); + when(mlClient.predict(any(), any())).thenReturn(future); + ChatCompletionInput input = new ChatCompletionInput( + "model", + "question", + Collections.emptyList(), + Collections.emptyList(), + 0, + "prompt", + "instructions", + Llm.ModelProvider.BEDROCK + ); + ChatCompletionOutput output = connector.doChatCompletion(input); + verify(mlClient, times(1)).predict(any(), captor.capture()); + MLInput mlInput = captor.getValue(); + assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet); + assertTrue(output.isErrorOccurred()); + assertEquals(errorMessage, (String) output.getErrors().get(0)); + } + private boolean isJson(String Json) { try { new JSONObject(Json); diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java index 41d44f18ca..a2a34db27b 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/LlmIOUtilTests.java @@ -32,4 +32,10 @@ public void testChatCompletionInput() { .createChatCompletionInput("model", "question", Collections.emptyList(), Collections.emptyList(), 0); assertTrue(input instanceof ChatCompletionInput); } + + public void testChatCompletionInputForBedrock() { + ChatCompletionInput input = LlmIOUtil + .createChatCompletionInput("bedrock/model", "question", Collections.emptyList(), Collections.emptyList(), 0); + assertTrue(input instanceof ChatCompletionInput); + } } diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java index 583ab17149..a3aedf4e5d 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtilTests.java @@ -77,6 +77,46 @@ public void testBuildMessageParameter() { assertTrue(isJson(parameter)); } + public void testBuildBedrockInputParameter() { + 
String systemPrompt = "You are the best."; + String userInstructions = null; + String question = "Who am I"; + List contexts = new ArrayList<>(); + List chatHistory = List + .of( + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 1", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer1" + ) + ), + Interaction + .fromMap( + "convo1", + Map + .of( + ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD, + Instant.now().toString(), + ConversationalIndexConstants.INTERACTIONS_INPUT_FIELD, + "message 2", + ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD, + "answer2" + ) + ) + ); + contexts.add("context 1"); + contexts.add("context 2"); + String parameter = PromptUtil.buildSingleStringPrompt(systemPrompt, userInstructions, question, chatHistory, contexts); + assertTrue(parameter.contains(systemPrompt)); + } + private boolean isJson(String Json) { try { new JSONObject(Json); From bdebe6592782514077d1083f16a14197ea9e503c Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Wed, 11 Oct 2023 14:52:12 -0700 Subject: [PATCH 3/4] Check if systemPrompt is null. Signed-off-by: Austin Lee --- .../questionanswering/generative/prompt/PromptUtil.java | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java index c494e17b89..3a8a21614e 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/prompt/PromptUtil.java @@ -104,8 +104,11 @@ public static String buildSingleStringPrompt( } StringBuilder bldr = new StringBuilder(); - bldr.append(systemPrompt); - bldr.append(NEWLINE); + + if (!Strings.isNullOrEmpty(systemPrompt)) { + bldr.append(systemPrompt); + bldr.append(NEWLINE); + } if (!Strings.isNullOrEmpty(userInstructions)) { bldr.append(userInstructions); bldr.append(NEWLINE); From a598e5ea0badf584af698f3b5c3b1ca61db76cf3 Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Wed, 11 Oct 2023 16:10:40 -0700 Subject: [PATCH 4/4] Address review comments. Signed-off-by: Austin Lee --- search-processors/README.md | 2 ++ .../searchpipelines/questionanswering/generative/llm/Llm.java | 1 + 2 files changed, 3 insertions(+) diff --git a/search-processors/README.md b/search-processors/README.md index 7bb8572b41..030691855b 100644 --- a/search-processors/README.md +++ b/search-processors/README.md @@ -51,6 +51,8 @@ GET //_search\?search_pipeline\= To use this with Bedrock models, use "bedrock/" as a prefix for the "llm_model" parameters, e.g. "bedrock/anthropic". +The latest RAG processor has been tested with OpenAI's GPT 3.5 and 4 models and Bedrock's Anthropic Claude (v2) model only. 
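The provider selection behind this README note is a plain prefix check on the `llm_model` value (see the `LlmIOUtil` change in this patch series). Below is a minimal standalone sketch of that rule; `ProviderDetectionSketch` and the `detect` helper are illustrative names, not part of the patch:

```java
// Sketch of the provider-detection rule from LlmIOUtil in this patch:
// a "bedrock/" prefix selects the Bedrock request/response handling; anything
// else (including null) falls back to the OpenAI chat-completion handling.
public final class ProviderDetectionSketch {

    enum ModelProvider { OPENAI, BEDROCK }

    static ModelProvider detect(String llmModel) {
        if (llmModel != null && llmModel.startsWith("bedrock/")) {
            return ModelProvider.BEDROCK;
        }
        return ModelProvider.OPENAI;
    }

    public static void main(String[] args) {
        System.out.println(detect("bedrock/anthropic")); // BEDROCK
        System.out.println(detect("gpt-3.5-turbo"));     // OPENAI
    }
}
```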
+ ## Retrieval Augmented Generation response ``` { diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java index faf136d550..be5efdb294 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/Llm.java @@ -22,6 +22,7 @@ */ public interface Llm { + // TODO Ensure the current implementation works with all models supported by Bedrock. enum ModelProvider { OPENAI, BEDROCK
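For reference, the two response shapes that `DefaultLlmImpl.buildChatCompletionOutput` now distinguishes can be summarized in a simplified stand-in (not the shipped class); `ResponseShapeSketch` and `extractAnswer` are illustrative names:

```java
import java.util.List;
import java.util.Map;

// Sketch of the two response shapes handled in this patch: OpenAI returns the
// answer under choices[0].message.content, while Bedrock (Claude) returns it
// under a top-level "completion" key; any other provider is rejected.
final class ResponseShapeSketch {

    @SuppressWarnings("unchecked")
    static String extractAnswer(String provider, Map<String, ?> dataAsMap) {
        if ("OPENAI".equals(provider)) {
            List<?> choices = (List<?>) dataAsMap.get("choices");
            Map<String, ?> firstChoice = (Map<String, ?>) choices.get(0);
            Map<String, ?> message = (Map<String, ?>) firstChoice.get("message");
            return (String) message.get("content");
        } else if ("BEDROCK".equals(provider)) {
            return (String) dataAsMap.get("completion");
        }
        throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider);
    }

    public static void main(String[] args) {
        Map<String, ?> openAi = Map.of("choices", List.of(Map.of("message", Map.of("content", "answer"))));
        Map<String, ?> bedrock = Map.of("completion", "answer");
        System.out.println(extractAnswer("OPENAI", openAi));   // answer
        System.out.println(extractAnswer("BEDROCK", bedrock)); // answer
    }
}
```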