diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java new file mode 100644 index 00000000000..16a2cb1f800 --- /dev/null +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java @@ -0,0 +1,33 @@ +package dev.langchain4j.model.bedrock; + +import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel; +import lombok.Builder; +import lombok.Getter; +import lombok.experimental.SuperBuilder; + +@Getter +@SuperBuilder +public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel { + @Builder.Default + private final BedrockAnthropicChatModel.Types model = BedrockAnthropicChatModel.Types.AnthropicClaudeV2; + + @Override + protected String getModelId() { + return model.getValue(); + } + + @Getter + /** + * Bedrock Anthropic model ids + */ + public enum Types { + AnthropicClaudeV2("anthropic.claude-v2"), + AnthropicClaudeV2_1("anthropic.claude-v2:1"); + + private final String value; + + Types(String modelID) { + this.value = modelID; + } + } +} diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java index 71a552fbb7a..77fba2d57df 100644 --- a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockChatModel.java @@ -13,6 +13,7 @@ import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest; import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse; @@ -30,47 +31,14 @@ */ @Getter @SuperBuilder -public abstract class AbstractBedrockChatModel implements ChatLanguageModel { - private static final String HUMAN_PROMPT = "Human:"; - private static final String ASSISTANT_PROMPT = "Assistant:"; - - @Builder.Default - private final String humanPrompt = HUMAN_PROMPT; - @Builder.Default - private final String assistantPrompt = ASSISTANT_PROMPT; - @Builder.Default - private final Integer maxRetries = 5; - @Builder.Default - private final Region region = Region.US_EAST_1; - @Builder.Default - private final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build(); - @Builder.Default - private final int maxTokens = 300; - @Builder.Default - private final float temperature = 1; - @Builder.Default - private final float topP = 0.999f; - @Builder.Default - private final String[] stopSequences = new String[]{}; +public abstract class AbstractBedrockChatModel extends AbstractSharedBedrockChatModel implements ChatLanguageModel { @Getter(lazy = true) private final BedrockRuntimeClient client = initClient(); @Override public Response generate(List messages) { - final String context = messages.stream() - .filter(message -> message.type() == ChatMessageType.SYSTEM) - .map(ChatMessage::text) - .collect(joining("\n")); - - final String userMessages = messages.stream() - .filter(message -> message.type() != ChatMessageType.SYSTEM) - .map(this::chatMessageToString) - .collect(joining("\n")); - - final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT); - final Map requestParameters = getRequestParameters(prompt); - final String body = Json.toJson(requestParameters); + final String body = convertMessagesToAwsBody(messages); InvokeModelResponse invokeModelResponse = withRetry(() -> invoke(body), maxRetries); final String response = invokeModelResponse.body().asUtf8String(); @@ -81,26 +49,6 @@ public Response generate(List messages) { result.getFinishReason()); } - /** - * Convert chat message to string - * - * @param message chat message - * @return string - */ - protected String chatMessageToString(ChatMessage message) { - switch (message.type()) { - case SYSTEM: - return message.text(); - case USER: - return humanPrompt + " " + message.text(); - case AI: - return assistantPrompt + " " + message.text(); - case TOOL_EXECUTION_RESULT: - throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); - } - - throw new IllegalArgumentException("Unknown message type: " + message.type()); - } /** * Get request parameters @@ -110,13 +58,6 @@ protected String chatMessageToString(ChatMessage message) { */ protected abstract Map getRequestParameters(final String prompt); - /** - * Get model id - * - * @return model id - */ - protected abstract String getModelId(); - /** * Get response class type diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java new file mode 100644 index 00000000000..5470a9c23d2 --- /dev/null +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java @@ -0,0 +1,87 @@ +package dev.langchain4j.model.bedrock.internal; + +import dev.langchain4j.agent.tool.ToolSpecification; +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.UserMessage; +import dev.langchain4j.internal.Json; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.chat.StreamingChatLanguageModel; +import dev.langchain4j.model.output.Response; +import lombok.Getter; +import lombok.experimental.SuperBuilder; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Bedrock Streaming chat model + */ +@Getter +@SuperBuilder +public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel { + @Getter + private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient(); + + class StreamingResponse { + public String completion; + } + + @Override + public void generate(String userMessage, StreamingResponseHandler handler) { + List messages = new ArrayList<>(); + messages.add(new UserMessage(userMessage)); + generate(messages, handler); + } + + @Override + public void generate(List messages, StreamingResponseHandler handler) { + InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder() + .body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages))) + .modelId(getModelId()) + .contentType("application/json") + .accept("application/json") + .build(); + + AtomicReference finalCompletion = new AtomicReference<>(""); + + InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder() + .onChunk(chunk -> { + StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class); + finalCompletion.set(finalCompletion.get() + sr.completion); + handler.onNext(sr.completion); + }) + .build(); + + InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder() + .onEventStream(stream -> stream.subscribe(event -> event.accept(visitor))) + .onComplete(() -> { + handler.onComplete(Response.from(new AiMessage(finalCompletion.get()))); + }) + .onError(handler::onError) + .build(); + asyncClient.invokeModelWithResponseStream(request, h).join(); + + } + + /** + * Initialize async bedrock client + * + * @return async bedrock client + */ + private BedrockRuntimeAsyncClient initAsyncClient() { + BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder() + .region(region) + .credentialsProvider(credentialsProvider) + .build(); + return client; + } + + + +} diff --git a/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java new file mode 100644 index 00000000000..737ecae129b --- /dev/null +++ b/langchain4j-bedrock/src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java @@ -0,0 +1,110 @@ +package dev.langchain4j.model.bedrock.internal; + +import dev.langchain4j.data.message.ChatMessage; +import dev.langchain4j.data.message.ChatMessageType; +import dev.langchain4j.internal.Json; +import lombok.Builder; +import lombok.Getter; +import lombok.experimental.SuperBuilder; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static java.util.stream.Collectors.joining; + +@Getter +@SuperBuilder +public abstract class AbstractSharedBedrockChatModel { + protected static final String HUMAN_PROMPT = "Human:"; + protected static final String ASSISTANT_PROMPT = "Assistant:"; + protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31"; + + @Builder.Default + protected final String humanPrompt = HUMAN_PROMPT; + @Builder.Default + protected final String assistantPrompt = ASSISTANT_PROMPT; + @Builder.Default + protected final Integer maxRetries = 5; + @Builder.Default + protected final Region region = Region.US_EAST_1; + @Builder.Default + protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build(); + @Builder.Default + protected final int maxTokens = 300; + @Builder.Default + protected final float temperature = 1; + @Builder.Default + protected final float topP = 0.999f; + @Builder.Default + protected final String[] stopSequences = new String[]{}; + @Builder.Default + protected final int topK = 250; + @Builder.Default + protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION; + + + /** + * Convert chat message to string + * + * @param message chat message + * @return string + */ + protected String chatMessageToString(ChatMessage message) { + switch (message.type()) { + case SYSTEM: + return message.text(); + case USER: + return humanPrompt + " " + message.text(); + case AI: + return assistantPrompt + " " + message.text(); + case TOOL_EXECUTION_RESULT: + throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); + } + + throw new IllegalArgumentException("Unknown message type: " + message.type()); + } + + protected String convertMessagesToAwsBody(List messages) { + final String context = messages.stream() + .filter(message -> message.type() == ChatMessageType.SYSTEM) + .map(ChatMessage::text) + .collect(joining("\n")); + + final String userMessages = messages.stream() + .filter(message -> message.type() != ChatMessageType.SYSTEM) + .map(this::chatMessageToString) + .collect(joining("\n")); + + final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT); + final Map requestParameters = getRequestParameters(prompt); + final String body = Json.toJson(requestParameters); + return body; + } + + protected Map getRequestParameters(String prompt) { + final Map parameters = new HashMap<>(7); + + parameters.put("prompt", prompt); + parameters.put("max_tokens_to_sample", getMaxTokens()); + parameters.put("temperature", getTemperature()); + parameters.put("top_k", topK); + parameters.put("top_p", getTopP()); + parameters.put("stop_sequences", getStopSequences()); + parameters.put("anthropic_version", anthropicVersion); + + return parameters; + } + + /** + * Get model id + * + * @return model id + */ + protected abstract String getModelId(); + +} diff --git a/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockChatModelIT.java b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockChatModelIT.java index 9319e8928d0..334e3b4667b 100644 --- a/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockChatModelIT.java +++ b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockChatModelIT.java @@ -14,7 +14,7 @@ public class BedrockChatModelIT { @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") +// @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockAnthropicChatModel() { BedrockAnthropicChatModel bedrockChatModel = BedrockAnthropicChatModel @@ -37,7 +37,7 @@ void testBedrockAnthropicChatModel() { } @Test - @Disabled("To run this test, you must have provide your own access key, secret, region") +// @Disabled("To run this test, you must have provide your own access key, secret, region") void testBedrockTitanChatModel() { BedrockTitanChatModel bedrockChatModel = BedrockTitanChatModel diff --git a/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java new file mode 100644 index 00000000000..a65ab2a2c6c --- /dev/null +++ b/langchain4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java @@ -0,0 +1,57 @@ +package dev.langchain4j.model.bedrock; + +import dev.langchain4j.data.message.AiMessage; +import dev.langchain4j.model.StreamingResponseHandler; +import dev.langchain4j.model.output.Response; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.regions.Region; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeoutException; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; + +public class BedrockStreamingChatModelIT { + @Test + @Disabled("To run this test, you must have provide your own access key, secret, region") + void testBedrockAnthropicStreamingChatModel() throws ExecutionException, InterruptedException, TimeoutException { + BedrockAnthropicStreamingChatModel bedrockChatModel = BedrockAnthropicStreamingChatModel + .builder() + .temperature(0.50f) + .maxTokens(300) + .region(Region.US_EAST_1) + .maxRetries(1) + .build(); + + CompletableFuture futureAnswer = new CompletableFuture<>(); + CompletableFuture> futureResponse = new CompletableFuture<>(); + bedrockChatModel.generate("What's the capital of Poland?", new StreamingResponseHandler() { + private final StringBuilder answerBuilder = new StringBuilder(); + @Override + public void onNext(String token) { + answerBuilder.append(token); + } + + @Override + public void onComplete(Response response) { + futureAnswer.complete(answerBuilder.toString()); + futureResponse.complete(response); + } + + @Override + public void onError(Throwable throwable) { + System.out.println(throwable); + } + + }); + String answer = futureAnswer.get(30, SECONDS); + Response response = futureResponse.get(30, SECONDS); + + assertThat(answer).contains("Warsaw"); + assertThat(response.content().text()).contains("Warsaw"); + } + +}