Skip to content

Commit

Permalink
Add streaming API for Bedrock Anthropics
Browse files Browse the repository at this point in the history
  • Loading branch information
michalkozminski committed Feb 23, 2024
1 parent b363150 commit 6d20783
Show file tree
Hide file tree
Showing 6 changed files with 292 additions and 64 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package dev.langchain4j.model.bedrock;

import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.SuperBuilder;

/**
 * Streaming chat model for the Anthropic model family on Bedrock.
 * <p>
 * All request building and streaming logic is inherited from
 * {@link AbstractBedrockStreamingChatModel}; this class only supplies the model id,
 * taken from the configurable {@code model} field.
 */
@Getter
@SuperBuilder
public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel {

    /**
     * Model to invoke. Defaults to {@code AnthropicClaudeV2} when not set on the builder.
     */
    @Builder.Default
    private final BedrockAnthropicChatModel.Types model = BedrockAnthropicChatModel.Types.AnthropicClaudeV2;

    /**
     * Returns the Bedrock model id of the configured {@code model}.
     *
     * @return model id
     */
    @Override
    protected String getModelId() {
        return model.getValue();
    }

    /**
     * Bedrock Anthropic model ids
     * <p>
     * NOTE(review): the {@code model} field above is typed with
     * {@code BedrockAnthropicChatModel.Types}, not with this enum, so this enum is not
     * used inside this class. It is kept because it is part of the public API.
     */
    // The Javadoc must precede the annotation; placed after it, the comment is not
    // attached to the declaration.
    @Getter
    public enum Types {
        AnthropicClaudeV2("anthropic.claude-v2"),
        AnthropicClaudeV2_1("anthropic.claude-v2:1");

        /** Model id string as passed to Bedrock. */
        private final String value;

        Types(String modelId) {
            this.value = modelId;
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
Expand All @@ -30,47 +31,14 @@
*/
@Getter
@SuperBuilder
public abstract class AbstractBedrockChatModel<T extends BedrockChatModelResponse> implements ChatLanguageModel {
private static final String HUMAN_PROMPT = "Human:";
private static final String ASSISTANT_PROMPT = "Assistant:";

@Builder.Default
private final String humanPrompt = HUMAN_PROMPT;
@Builder.Default
private final String assistantPrompt = ASSISTANT_PROMPT;
@Builder.Default
private final Integer maxRetries = 5;
@Builder.Default
private final Region region = Region.US_EAST_1;
@Builder.Default
private final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build();
@Builder.Default
private final int maxTokens = 300;
@Builder.Default
private final float temperature = 1;
@Builder.Default
private final float topP = 0.999f;
@Builder.Default
private final String[] stopSequences = new String[]{};
public abstract class AbstractBedrockChatModel<T extends BedrockChatModelResponse> extends AbstractSharedBedrockChatModel implements ChatLanguageModel {
@Getter(lazy = true)
private final BedrockRuntimeClient client = initClient();

@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {

final String context = messages.stream()
.filter(message -> message.type() == ChatMessageType.SYSTEM)
.map(ChatMessage::text)
.collect(joining("\n"));

final String userMessages = messages.stream()
.filter(message -> message.type() != ChatMessageType.SYSTEM)
.map(this::chatMessageToString)
.collect(joining("\n"));

final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT);
final Map<String, Object> requestParameters = getRequestParameters(prompt);
final String body = Json.toJson(requestParameters);
final String body = convertMessagesToAwsBody(messages);

InvokeModelResponse invokeModelResponse = withRetry(() -> invoke(body), maxRetries);
final String response = invokeModelResponse.body().asUtf8String();
Expand All @@ -81,26 +49,6 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {
result.getFinishReason());
}

/**
* Convert chat message to string
*
* @param message chat message
* @return string
*/
protected String chatMessageToString(ChatMessage message) {
switch (message.type()) {
case SYSTEM:
return message.text();
case USER:
return humanPrompt + " " + message.text();
case AI:
return assistantPrompt + " " + message.text();
case TOOL_EXECUTION_RESULT:
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
}

throw new IllegalArgumentException("Unknown message type: " + message.type());
}

/**
* Get request parameters
Expand All @@ -110,13 +58,6 @@ protected String chatMessageToString(ChatMessage message) {
*/
protected abstract Map<String, Object> getRequestParameters(final String prompt);

/**
* Get model id
*
* @return model id
*/
protected abstract String getModelId();


/**
* Get response class type
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package dev.langchain4j.model.bedrock.internal;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Json;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import lombok.Getter;
import lombok.experimental.SuperBuilder;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

/**
* Bedrock Streaming chat model
*/
/**
 * Bedrock Streaming chat model.
 * <p>
 * Sends the conversation to Bedrock through {@link BedrockRuntimeAsyncClient} and forwards
 * every received chunk to a {@link StreamingResponseHandler}. The request body is produced
 * by {@code convertMessagesToAwsBody} and the model id by {@code getModelId}, both inherited
 * from {@link AbstractSharedBedrockChatModel}.
 */
@Getter
@SuperBuilder
public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel {

    /**
     * Async client used for all streaming invocations.
     * <p>
     * Created by a field initializer, which runs after the superclass constructor has
     * assigned {@code region} and {@code credentialsProvider}, so both are available here.
     */
    private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient();

    /**
     * JSON shape of a single streamed chunk; only the {@code completion} text is read.
     * <p>
     * Declared {@code static} so that instances created by the JSON mapper do not need
     * (and do not carry) a reference to an enclosing model instance.
     */
    static class StreamingResponse {
        public String completion;
    }

    /**
     * Streams the answer to a single user message.
     *
     * @param userMessage text of the user message
     * @param handler     receives the streamed tokens, the final response, and errors
     */
    @Override
    public void generate(String userMessage, StreamingResponseHandler<AiMessage> handler) {
        List<ChatMessage> messages = new ArrayList<>();
        messages.add(new UserMessage(userMessage));
        generate(messages, handler);
    }

    /**
     * Streams the answer to a conversation.
     * <p>
     * Each chunk's {@code completion} text is passed to {@code handler.onNext}; when the
     * stream completes, the concatenation of all chunks is passed to
     * {@code handler.onComplete}. This method blocks until the underlying call has finished.
     *
     * @param messages conversation to send
     * @param handler  receives the streamed tokens, the final response, and errors
     */
    @Override
    public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
        InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder()
                .body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages)))
                .modelId(getModelId())
                .contentType("application/json")
                .accept("application/json")
                .build();

        // StringBuffer is synchronized, so text appended in the chunk callback is safely
        // visible to the completion callback, and appending avoids rebuilding the whole
        // string for every chunk.
        final StringBuffer fullCompletion = new StringBuffer();

        InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder()
                .onChunk(chunk -> {
                    StreamingResponse parsed = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class);
                    // A chunk without completion text must not be appended as the literal
                    // "null" nor forwarded to the handler as a null token.
                    if (parsed == null || parsed.completion == null) {
                        return;
                    }
                    fullCompletion.append(parsed.completion);
                    handler.onNext(parsed.completion);
                })
                .build();

        InvokeModelWithResponseStreamResponseHandler responseHandler = InvokeModelWithResponseStreamResponseHandler.builder()
                .onEventStream(stream -> stream.subscribe(event -> event.accept(visitor)))
                .onComplete(() -> handler.onComplete(Response.from(new AiMessage(fullCompletion.toString()))))
                .onError(handler::onError)
                .build();

        // Blocks the caller until the stream has ended.
        // NOTE(review): on failure join() rethrows the error, presumably after
        // handler.onError has already been invoked, so callers may see it twice - confirm
        // whether blocking and rethrowing is intended for a streaming API.
        asyncClient.invokeModelWithResponseStream(request, responseHandler).join();
    }

    /**
     * Initialize async bedrock client
     *
     * @return async bedrock client configured with this model's region and credentials provider
     */
    private BedrockRuntimeAsyncClient initAsyncClient() {
        return BedrockRuntimeAsyncClient.builder()
                .region(region)
                .credentialsProvider(credentialsProvider)
                .build();
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package dev.langchain4j.model.bedrock.internal;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.internal.Json;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.SuperBuilder;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static java.util.stream.Collectors.joining;

/**
 * Configuration and prompt-building logic shared by the Bedrock chat model base classes.
 * <p>
 * Holds the builder-configurable request settings and turns a list of chat messages into
 * the JSON request body.
 */
@Getter
@SuperBuilder
public abstract class AbstractSharedBedrockChatModel {
    protected static final String HUMAN_PROMPT = "Human:";
    protected static final String ASSISTANT_PROMPT = "Assistant:";
    protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31";

    /** Prefix put in front of user messages. */
    @Builder.Default
    protected final String humanPrompt = HUMAN_PROMPT;
    /** Prefix put in front of AI messages and used as the trailing turn marker of the prompt. */
    @Builder.Default
    protected final String assistantPrompt = ASSISTANT_PROMPT;
    @Builder.Default
    protected final Integer maxRetries = 5;
    @Builder.Default
    protected final Region region = Region.US_EAST_1;
    @Builder.Default
    protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build();
    /** Sent as {@code max_tokens_to_sample}. */
    @Builder.Default
    protected final int maxTokens = 300;
    /** Sent as {@code temperature}. */
    @Builder.Default
    protected final float temperature = 1;
    /** Sent as {@code top_p}. */
    @Builder.Default
    protected final float topP = 0.999f;
    /** Sent as {@code stop_sequences}. */
    @Builder.Default
    protected final String[] stopSequences = new String[]{};
    /** Sent as {@code top_k}. */
    @Builder.Default
    protected final int topK = 250;
    /** Sent as {@code anthropic_version}. */
    @Builder.Default
    protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION;


    /**
     * Convert chat message to string
     * <p>
     * System messages are returned as-is; user and AI messages are prefixed with
     * {@code humanPrompt} and {@code assistantPrompt} respectively.
     *
     * @param message chat message
     * @return string
     * @throws IllegalArgumentException for tool execution results and unknown message types
     */
    protected String chatMessageToString(ChatMessage message) {
        switch (message.type()) {
            case SYSTEM:
                return message.text();
            case USER:
                return humanPrompt + " " + message.text();
            case AI:
                return assistantPrompt + " " + message.text();
            case TOOL_EXECUTION_RESULT:
                throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
        }

        throw new IllegalArgumentException("Unknown message type: " + message.type());
    }

    /**
     * Builds the JSON request body for the given conversation.
     * <p>
     * The prompt consists of all system messages (joined by newlines), a blank line, all
     * remaining messages rendered by {@link #chatMessageToString(ChatMessage)} (joined by
     * newlines), and finally the assistant prompt on its own line. The prompt is then
     * wrapped by {@link #getRequestParameters(String)} and serialized to JSON.
     *
     * @param messages conversation to convert
     * @return JSON request body
     */
    protected String convertMessagesToAwsBody(List<ChatMessage> messages) {
        final String context = messages.stream()
                .filter(message -> message.type() == ChatMessageType.SYSTEM)
                .map(ChatMessage::text)
                .collect(joining("\n"));

        final String userMessages = messages.stream()
                .filter(message -> message.type() != ChatMessageType.SYSTEM)
                .map(this::chatMessageToString)
                .collect(joining("\n"));

        // Use the configured assistantPrompt (not the ASSISTANT_PROMPT constant) so the
        // trailing turn marker matches the prefix applied to AI messages above; with the
        // default configuration both are identical.
        final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, assistantPrompt);
        final Map<String, Object> requestParameters = getRequestParameters(prompt);
        return Json.toJson(requestParameters);
    }

    /**
     * Get request parameters
     *
     * @param prompt fully rendered prompt text
     * @return map of request parameters to be serialized as the request body
     */
    protected Map<String, Object> getRequestParameters(String prompt) {
        final Map<String, Object> parameters = new HashMap<>(7);

        parameters.put("prompt", prompt);
        parameters.put("max_tokens_to_sample", getMaxTokens());
        parameters.put("temperature", getTemperature());
        parameters.put("top_k", topK);
        parameters.put("top_p", getTopP());
        parameters.put("stop_sequences", getStopSequences());
        parameters.put("anthropic_version", anthropicVersion);

        return parameters;
    }

    /**
     * Get model id
     *
     * @return model id
     */
    protected abstract String getModelId();

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
public class BedrockChatModelIT {

@Test
@Disabled("To run this test, you must have provide your own access key, secret, region")
// @Disabled("To run this test, you must have provide your own access key, secret, region")
void testBedrockAnthropicChatModel() {

BedrockAnthropicChatModel bedrockChatModel = BedrockAnthropicChatModel
Expand All @@ -37,7 +37,7 @@ void testBedrockAnthropicChatModel() {
}

@Test
@Disabled("To run this test, you must have provide your own access key, secret, region")
// @Disabled("To run this test, you must have provide your own access key, secret, region")
void testBedrockTitanChatModel() {

BedrockTitanChatModel bedrockChatModel = BedrockTitanChatModel
Expand Down
Loading

0 comments on commit 6d20783

Please sign in to comment.