Fix prompt passing for Bedrock by passing a single string prompt for … #1490

Merged 4 commits on Oct 11, 2023

Changes from all commits
4 changes: 4 additions & 0 deletions search-processors/README.md
@@ -49,6 +49,10 @@ GET /<index>/_search\?search_pipeline\=<search pipeline name>
}
```

To use this with Bedrock models, use "bedrock/" as a prefix for the "llm_model" parameter, e.g. "bedrock/anthropic".

The latest RAG processor has been tested with OpenAI's GPT 3.5 and 4 models and Bedrock's Anthropic Claude (v2) model only.

## Retrieval Augmented Generation response
```
{
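For illustration, a Bedrock-backed request could look like the following. The index name, pipeline name, and parameter values are hypothetical stand-ins; see the pipeline configuration examples elsewhere in this README for the full set of `generative_qa_parameters` fields.

```
GET /my_index/_search?search_pipeline=my_rag_pipeline
{
  "query": { "match": { "text": "capital of france" } },
  "ext": {
    "generative_qa_parameters": {
      "llm_model": "bedrock/anthropic",
      "llm_question": "What is the capital of France?"
    }
  }
}
```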
ChatCompletionInput.java
@@ -42,4 +42,5 @@ public class ChatCompletionInput {
private int timeoutInSeconds;
private String systemPrompt;
private String userInstructions;
private Llm.ModelProvider modelProvider;
}
DefaultLlmImpl.java
@@ -63,7 +63,7 @@ public DefaultLlmImpl(String openSearchModelId, Client client) {
}

@VisibleForTesting
void setMlClient(MachineLearningInternalClient mlClient) {
protected void setMlClient(MachineLearningInternalClient mlClient) {
this.mlClient = mlClient;
}

@@ -76,19 +76,7 @@ void setMlClient(MachineLearningInternalClient mlClient) {
@Override
public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {

Map<String, String> inputParameters = new HashMap<>();
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
String messages = PromptUtil
.getChatCompletionPrompt(
chatCompletionInput.getSystemPrompt(),
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts()
);
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
log.info("Messages to LLM: {}", messages);
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(inputParameters).build();
MLInputDataset dataset = RemoteInferenceInputDataSet.builder().parameters(getInputParameters(chatCompletionInput)).build();
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataset).build();
ActionFuture<MLOutput> future = mlClient.predict(this.openSearchModelId, mlInput);
ModelTensorOutput modelOutput = (ModelTensorOutput) future.actionGet(chatCompletionInput.getTimeoutInSeconds() * 1000);
@@ -99,19 +87,83 @@ public ChatCompletionOutput doChatCompletion(ChatCompletionInput chatCompletionInput) {

// TODO dataAsMap can be null or can contain information such as throttling. Handle non-happy cases.

List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
return buildChatCompletionOutput(chatCompletionInput.getModelProvider(), dataAsMap);
}

protected Map<String, String> getInputParameters(ChatCompletionInput chatCompletionInput) {
Map<String, String> inputParameters = new HashMap<>();

if (chatCompletionInput.getModelProvider() == ModelProvider.OPENAI) {
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MODEL, chatCompletionInput.getModel());
String messages = PromptUtil
.getChatCompletionPrompt(
chatCompletionInput.getSystemPrompt(),
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts()
);
inputParameters.put(CONNECTOR_INPUT_PARAMETER_MESSAGES, messages);
log.info("Messages to LLM: {}", messages);
} else if (chatCompletionInput.getModelProvider() == ModelProvider.BEDROCK) {
inputParameters
.put(
"inputs",
PromptUtil
.buildSingleStringPrompt(
chatCompletionInput.getSystemPrompt(),
chatCompletionInput.getUserInstructions(),
chatCompletionInput.getQuestion(),
chatCompletionInput.getChatHistory(),
chatCompletionInput.getContexts()
)
);
} else {
throw new IllegalArgumentException("Unknown/unsupported model provider: " + chatCompletionInput.getModelProvider());
}

log.info("LLM input parameters: {}", inputParameters.toString());
return inputParameters;
}

protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider, Map<String, ?> dataAsMap) {

List<Object> answers = null;
List<String> errors = null;
if (choices == null) {
Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE));

if (provider == ModelProvider.OPENAI) {
List choices = (List) dataAsMap.get(CONNECTOR_OUTPUT_CHOICES);
if (choices == null) {
Map error = (Map) dataAsMap.get(CONNECTOR_OUTPUT_ERROR);
errors = List.of((String) error.get(CONNECTOR_OUTPUT_MESSAGE));
} else {
Map firstChoiceMap = (Map) choices.get(0);
log.info("Choices: {}", firstChoiceMap.toString());
Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE);
log
.info(
"role: {}, content: {}",
message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE),
message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT)
);
answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
}
} else if (provider == ModelProvider.BEDROCK) {
String response = (String) dataAsMap.get("completion");
if (response != null) {
answers = List.of(response);
} else {
Map error = (Map) dataAsMap.get("error");
if (error != null) {
errors = List.of((String) error.get("message"));
} else {
errors = List.of("Unknown error or response.");
}
}
} else {
Map firstChoiceMap = (Map) choices.get(0);
log.info("Choices: {}", firstChoiceMap.toString());
Map message = (Map) firstChoiceMap.get(CONNECTOR_OUTPUT_MESSAGE);
log.info("role: {}, content: {}", message.get(CONNECTOR_OUTPUT_MESSAGE_ROLE), message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
throw new IllegalArgumentException("Unknown/unsupported model provider: " + provider);
}

return new ChatCompletionOutput(answers, errors);
}
}
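In short: OpenAI connectors keep receiving a `model` id plus a JSON `messages` array, while Bedrock connectors now get the whole conversation flattened into a single string under the `inputs` parameter, and the answer is read back from the Bedrock response's `completion` field rather than OpenAI's `choices` array. A minimal sketch of the two parameter shapes; the model name and prompt text are invented examples, not connector defaults:

```java
import java.util.Map;

// Illustrative only: the shapes of the parameter maps that getInputParameters
// builds for each provider. Values here are made up for demonstration.
public class InputShapeSketch {
    public static void main(String[] args) {
        // OPENAI: a model id plus a serialized "messages" array.
        Map<String, String> openAiParams = Map.of(
            "model", "gpt-3.5-turbo",
            "messages", "[{\"role\":\"user\",\"content\":\"What is the capital of France?\"}]");
        // BEDROCK: one flattened prompt string under "inputs". Note the separators
        // are the literal two characters \n, since PromptUtil.NEWLINE is "\\n".
        Map<String, String> bedrockParams = Map.of(
            "inputs", "You are a helpful assistant.\\nQUESTION: What is the capital of France?\\n");
        System.out.println(openAiParams);
        System.out.println(bedrockParams);
    }
}
```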
Llm.java
@@ -22,5 +22,11 @@
*/
public interface Llm {

// TODO Ensure the current implementation works with all models supported by Bedrock.
enum ModelProvider {
OPENAI,
BEDROCK
}

ChatCompletionOutput doChatCompletion(ChatCompletionInput input);
}
LlmIOUtil.java
@@ -27,6 +27,8 @@
*/
public class LlmIOUtil {

private static final String BEDROCK_PROVIDER_PREFIX = "bedrock/";

public static ChatCompletionInput createChatCompletionInput(
String llmModel,
String question,
@@ -57,7 +59,19 @@ public static ChatCompletionInput createChatCompletionInput(
List<String> contexts,
int timeoutInSeconds
) {

return new ChatCompletionInput(llmModel, question, chatHistory, contexts, timeoutInSeconds, systemPrompt, userInstructions);
Llm.ModelProvider provider = Llm.ModelProvider.OPENAI;
if (llmModel != null && llmModel.startsWith(BEDROCK_PROVIDER_PREFIX)) {
provider = Llm.ModelProvider.BEDROCK;
}
return new ChatCompletionInput(
llmModel,
question,
chatHistory,
contexts,
timeoutInSeconds,
systemPrompt,
userInstructions,
provider
);
}
}
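The routing rule is simple enough to restate standalone. This hypothetical `route` helper is not part of the plugin; it just mirrors the prefix check above, with a local enum standing in for `Llm.ModelProvider`:

```java
// Hypothetical sketch mirroring LlmIOUtil's prefix check; not plugin API.
public class ProviderRoutingSketch {
    enum ModelProvider { OPENAI, BEDROCK } // stand-in for Llm.ModelProvider

    static ModelProvider route(String llmModel) {
        // "bedrock/..." ids select BEDROCK; anything else, including null,
        // keeps the OPENAI default, preserving pre-existing behavior.
        return (llmModel != null && llmModel.startsWith("bedrock/"))
            ? ModelProvider.BEDROCK
            : ModelProvider.OPENAI;
    }

    public static void main(String[] args) {
        System.out.println(route("bedrock/anthropic")); // BEDROCK
        System.out.println(route("gpt-4"));             // OPENAI
        System.out.println(route(null));                // OPENAI
    }
}
```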
PromptUtil.java
@@ -20,6 +20,7 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;

import org.apache.commons.text.StringEscapeUtils;
import org.opensearch.core.common.Strings;
@@ -54,6 +55,8 @@ public class PromptUtil {

private static final String roleUser = "user";

private static final String NEWLINE = "\\n";

public static String getQuestionRephrasingPrompt(String originalQuestion, List<Interaction> chatHistory) {
return null;
}
@@ -62,6 +65,8 @@ public static String getChatCompletionPrompt(String question, List<Interaction> chatHistory, List<String> contexts) {
return getChatCompletionPrompt(DEFAULT_SYSTEM_PROMPT, null, question, chatHistory, contexts);
}

// TODO Currently, this is OpenAI specific. Change this to indicate as such or address it as part of
// future prompt template management work.
public static String getChatCompletionPrompt(
String systemPrompt,
String userInstructions,
@@ -87,6 +92,48 @@ enum ChatRole {
}
}

public static String buildSingleStringPrompt(
String systemPrompt,
String userInstructions,
String question,
List<Interaction> chatHistory,
List<String> contexts
) {
if (Strings.isNullOrEmpty(systemPrompt) && Strings.isNullOrEmpty(userInstructions)) {
systemPrompt = DEFAULT_SYSTEM_PROMPT;
}

StringBuilder bldr = new StringBuilder();

if (!Strings.isNullOrEmpty(systemPrompt)) {
bldr.append(systemPrompt);
bldr.append(NEWLINE);
}
if (!Strings.isNullOrEmpty(userInstructions)) {
bldr.append(userInstructions);
bldr.append(NEWLINE);
}

for (int i = 0; i < contexts.size(); i++) {
bldr.append("SEARCH RESULT " + (i + 1) + ": " + contexts.get(i));
bldr.append(NEWLINE);
}
if (!chatHistory.isEmpty()) {
// The oldest interaction first
List<Message> messages = Messages.fromInteractions(chatHistory).getMessages();
Collections.reverse(messages);
messages.forEach(m -> {
bldr.append(m.toString());
bldr.append(NEWLINE);
});

}
bldr.append("QUESTION: " + question);
bldr.append(NEWLINE);

return bldr.toString();
}

@VisibleForTesting
static String buildMessageParameter(
String systemPrompt,
@@ -110,7 +157,6 @@ static String buildMessageParameter(
}
if (!chatHistory.isEmpty()) {
// The oldest interaction first
// Collections.reverse(chatHistory);
List<Message> messages = Messages.fromInteractions(chatHistory).getMessages();
Collections.reverse(messages);
messages.forEach(m -> messageArray.add(m.toJson()));
@@ -163,6 +209,8 @@ public static Messages fromInteractions(final List<Interaction> interactions) {
}
}

// TODO This is OpenAI specific. Either change this to OpenAiMessage or have it handle
// vendor specific messages.
static class Message {

private final static String MESSAGE_FIELD_ROLE = "role";
@@ -186,6 +234,7 @@ public Message(ChatRole chatRole, String content) {
}

public void setChatRole(ChatRole chatRole) {
this.chatRole = chatRole;
json.remove(MESSAGE_FIELD_ROLE);
json.add(MESSAGE_FIELD_ROLE, new JsonPrimitive(chatRole.getName()));
}
@@ -199,5 +248,10 @@ public void setContent(String content) {
public JsonObject toJson() {
return json;
}

@Override
public String toString() {
return String.format(Locale.ROOT, "%s: %s", chatRole.getName(), content);
}
}
}
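To make the flattened format concrete: with a system prompt of "You are a helpful assistant." (a made-up value; the real DEFAULT_SYSTEM_PROMPT is not shown in this diff), one search result, one prior exchange, and a question, buildSingleStringPrompt would produce roughly the string below. Because NEWLINE is defined as "\\n", the separators are the literal characters \n rather than real line breaks, and each history entry is rendered "role: content" by Message.toString (the role names here are illustrative).

```
You are a helpful assistant.\nSEARCH RESULT 1: Paris is the capital of France.\nuser: hello\nassistant: Hi, how can I help you?\nQUESTION: What is the capital of France?\n
```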
ChatCompletionInputTests.java
@@ -41,7 +41,8 @@ public void testCtor() {
Collections.emptyList(),
0,
systemPrompt,
userInstructions
userInstructions,
Llm.ModelProvider.OPENAI
);

assertNotNull(input);
@@ -70,7 +71,16 @@ public void testGettersSetters() {
)
);
List<String> contexts = List.of("result1", "result2");
ChatCompletionInput input = new ChatCompletionInput(model, question, history, contexts, 0, systemPrompt, userInstructions);
ChatCompletionInput input = new ChatCompletionInput(
model,
question,
history,
contexts,
0,
systemPrompt,
userInstructions,
Llm.ModelProvider.OPENAI
);
assertEquals(model, input.getModel());
assertEquals(question, input.getQuestion());
assertEquals(history.get(0).getConversationId(), input.getChatHistory().get(0).getConversationId());