Skip to content

Commit

Permalink
early prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
michalkozminski committed Feb 13, 2024
1 parent b363150 commit 21452f3
Show file tree
Hide file tree
Showing 6 changed files with 308 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package dev.langchain4j.model.bedrock;

import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.SuperBuilder;

@Getter
@SuperBuilder
public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel {
@Getter
public enum Types {
AnthropicClaudeInstantV1("anthropic.claude-instant-v1"),
AnthropicClaudeV1("anthropic.claude-v1"),
AnthropicClaudeV2("anthropic.claude-v2"),
AnthropicClaudeV2_1("anthropic.claude-v2:1");

private final String value;

Types(String modelID) {
this.value = modelID;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;
Expand All @@ -30,7 +31,7 @@
*/
@Getter
@SuperBuilder
public abstract class AbstractBedrockChatModel<T extends BedrockChatModelResponse> implements ChatLanguageModel {
public abstract class AbstractBedrockChatModel<T extends BedrockChatModelResponse> extends AbstractSharedBedrockChatModel implements ChatLanguageModel {
private static final String HUMAN_PROMPT = "Human:";
private static final String ASSISTANT_PROMPT = "Assistant:";

Expand Down Expand Up @@ -58,19 +59,7 @@ public abstract class AbstractBedrockChatModel<T extends BedrockChatModelRespons
@Override
public Response<AiMessage> generate(List<ChatMessage> messages) {

final String context = messages.stream()
.filter(message -> message.type() == ChatMessageType.SYSTEM)
.map(ChatMessage::text)
.collect(joining("\n"));

final String userMessages = messages.stream()
.filter(message -> message.type() != ChatMessageType.SYSTEM)
.map(this::chatMessageToString)
.collect(joining("\n"));

final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT);
final Map<String, Object> requestParameters = getRequestParameters(prompt);
final String body = Json.toJson(requestParameters);
final String body = convertMessagesToAwsBody(messages);

InvokeModelResponse invokeModelResponse = withRetry(() -> invoke(body), maxRetries);
final String response = invokeModelResponse.body().asUtf8String();
Expand All @@ -81,26 +70,8 @@ public Response<AiMessage> generate(List<ChatMessage> messages) {
result.getFinishReason());
}

/**
* Convert chat message to string
*
* @param message chat message
* @return string
*/
protected String chatMessageToString(ChatMessage message) {
switch (message.type()) {
case SYSTEM:
return message.text();
case USER:
return humanPrompt + " " + message.text();
case AI:
return assistantPrompt + " " + message.text();
case TOOL_EXECUTION_RESULT:
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
}

throw new IllegalArgumentException("Unknown message type: " + message.type());
}



/**
* Get request parameters
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package dev.langchain4j.model.bedrock.internal;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Json;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.Response;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.SuperBuilder;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.stream.Collectors.joining;

/**
* Bedrock chat model
*/
@Getter
@SuperBuilder
public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel {
@Getter
private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient();

class StreamingResponse {
public String completion;
}


@Override
public void generate(String userMessage, StreamingResponseHandler<AiMessage> handler) {
List<ChatMessage> messages = new ArrayList<>();
messages.add(new UserMessage(userMessage));

InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder()
.body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages)))
.modelId(getModelId())
.contentType("application/json")
.accept("application/json")
.build();

InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder()
.onChunk(chunk -> {
StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class);
System.out.println("\n\nChunk:");
System.out.println(sr.completion);

handler.onNext(sr.completion);
})
.build();

InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder()
.onEventStream(stream -> stream.subscribe(event -> event.accept(visitor)))
.onComplete(() -> {
handler.onComplete(Response.from(new AiMessage()));
System.out.println("\n\nComplete");
})
.onError(e -> System.out.println("\n\nError: " + e.getMessage()))
.build();
asyncClient.invokeModelWithResponseStream(request, h).join();

}

@Override
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) {
}

@Override
public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler<AiMessage> handler) {
StreamingChatLanguageModel.super.generate(messages, toolSpecifications, handler);
}

@Override
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) {
StreamingChatLanguageModel.super.generate(messages, toolSpecification, handler);
}

/**
* Initialize bedrock client
*
* @return bedrock client
*/
private BedrockRuntimeAsyncClient initAsyncClient() {
BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder()
.region(region)
.credentialsProvider(credentialsProvider)
.build();
return client;
}



}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package dev.langchain4j.model.bedrock.internal;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.internal.Json;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.SuperBuilder;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static java.util.stream.Collectors.joining;

/**
 * Configuration and request-building logic shared by the Bedrock chat model classes.
 *
 * <p>Holds the builder-configurable settings (prompt prefixes, retries, region, credentials and
 * sampling parameters) and converts a list of chat messages into the JSON request body.
 */
@Getter
@SuperBuilder
public class AbstractSharedBedrockChatModel {
    protected static final String HUMAN_PROMPT = "Human:";
    protected static final String ASSISTANT_PROMPT = "Assistant:";
    protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31";

    @Builder.Default
    protected final String humanPrompt = HUMAN_PROMPT;
    @Builder.Default
    protected final String assistantPrompt = ASSISTANT_PROMPT;
    @Builder.Default
    protected final Integer maxRetries = 5;
    @Builder.Default
    protected final Region region = Region.US_EAST_1;
    @Builder.Default
    protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build();
    @Builder.Default
    protected final int maxTokens = 300;
    @Builder.Default
    protected final float temperature = 1;
    @Builder.Default
    protected final float topP = 0.999f;
    @Builder.Default
    protected final String[] stopSequences = new String[]{};
    @Builder.Default
    protected final int topK = 250;
    @Builder.Default
    protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION;

    /**
     * Identifier of the model to invoke; this base implementation always answers Claude v2.
     *
     * @return model id
     */
    protected String getModelId() {
        return "anthropic.claude-v2";
    }

    /**
     * Renders one chat message as a prompt line.
     *
     * <p>System messages are returned as-is; user and AI messages are prefixed with the configured
     * human / assistant prompt followed by a space.
     *
     * @param message chat message
     * @return prompt line for the message
     * @throws IllegalArgumentException for tool execution results or any other message type
     */
    protected String chatMessageToString(ChatMessage message) {
        final ChatMessageType kind = message.type();
        switch (kind) {
            case USER:
                return humanPrompt + " " + message.text();
            case AI:
                return assistantPrompt + " " + message.text();
            case SYSTEM:
                return message.text();
            case TOOL_EXECUTION_RESULT:
                throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models");
            default:
                throw new IllegalArgumentException("Unknown message type: " + kind);
        }
    }

    /**
     * Builds the JSON request body for a conversation.
     *
     * <p>System message texts are joined with newlines to form the leading context; all remaining
     * messages are rendered with {@link #chatMessageToString(ChatMessage)} and joined with
     * newlines; the prompt ends with the assistant cue.
     *
     * @param messages conversation messages
     * @return JSON body as a string
     */
    protected String convertMessagesToAwsBody(List<ChatMessage> messages) {
        final String systemContext = messages.stream()
                .filter(AbstractSharedBedrockChatModel::isSystemMessage)
                .map(ChatMessage::text)
                .collect(joining("\n"));

        final String dialogue = messages.stream()
                .filter(candidate -> !isSystemMessage(candidate))
                .map(this::chatMessageToString)
                .collect(joining("\n"));

        // NOTE(review): the closing cue uses the ASSISTANT_PROMPT constant, not the configurable
        // assistantPrompt field used for AI messages above — confirm this is intended.
        final String prompt = systemContext + "\n\n" + dialogue + "\n" + ASSISTANT_PROMPT;

        return Json.toJson(getRequestParameters(prompt));
    }

    /**
     * Assembles the request parameters sent alongside the prompt.
     *
     * <p>{@code top_k} and {@code anthropic_version} are read from the fields directly, the other
     * settings through their getters.
     *
     * @param prompt fully rendered prompt
     * @return parameter map keyed by request field name
     */
    protected Map<String, Object> getRequestParameters(String prompt) {
        final Map<String, Object> payload = new HashMap<>(7);

        payload.put("prompt", prompt);
        payload.put("max_tokens_to_sample", getMaxTokens());
        payload.put("temperature", getTemperature());
        payload.put("top_k", topK);
        payload.put("top_p", getTopP());
        payload.put("stop_sequences", getStopSequences());
        payload.put("anthropic_version", anthropicVersion);

        return payload;
    }

    /** Tells whether the message is a system message. */
    private static boolean isSystemMessage(ChatMessage message) {
        return message.type() == ChatMessageType.SYSTEM;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
public class BedrockChatModelIT {

@Test
@Disabled("To run this test, you must have provide your own access key, secret, region")
// @Disabled("To run this test, you must have provide your own access key, secret, region")
void testBedrockAnthropicChatModel() {

BedrockAnthropicChatModel bedrockChatModel = BedrockAnthropicChatModel
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package dev.langchain4j.model.bedrock;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.output.Response;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.regions.Region;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;

/**
 * Integration test for the streaming Anthropic chat model on Bedrock.
 *
 * <p>Disabled by default because it calls the real service and therefore needs credentials.
 */
public class BedrockStreamingChatModelIT {

    @Test
    @Disabled("To run this test, you must have provide your own access key, secret, region")
    void testBedrockAnthropicStreamingChatModel() throws ExecutionException, InterruptedException, TimeoutException {
        BedrockAnthropicStreamingChatModel bedrockChatModel = BedrockAnthropicStreamingChatModel
                .builder()
                .temperature(0.50f)
                .maxTokens(300)
                .region(Region.US_EAST_1)
                .maxRetries(1)
                .build();

        CompletableFuture<String> futureAnswer = new CompletableFuture<>();
        CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>();

        bedrockChatModel.generate("What's the capital of Germany?", new StreamingResponseHandler<AiMessage>() {
            private final StringBuilder answerBuilder = new StringBuilder();

            @Override
            public void onNext(String token) {
                answerBuilder.append(token);
            }

            @Override
            public void onComplete(Response<AiMessage> response) {
                futureAnswer.complete(answerBuilder.toString());
                futureResponse.complete(response);
            }

            @Override
            public void onError(Throwable throwable) {
                // Fail both futures so the test reports the real cause immediately instead of
                // waiting for the 30 second timeout below.
                futureAnswer.completeExceptionally(throwable);
                futureResponse.completeExceptionally(throwable);
            }
        });

        String answer = futureAnswer.get(30, SECONDS);
        Response<AiMessage> response = futureResponse.get(30, SECONDS);

        assertThat(answer).contains("Berlin");
        assertThat(response).isNotNull();
    }
}

0 comments on commit 21452f3

Please sign in to comment.