forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
b363150
commit 21452f3
Showing
6 changed files
with
308 additions
and
35 deletions.
There are no files selected for viewing
23 changes: 23 additions & 0 deletions
23
...drock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package dev.langchain4j.model.bedrock; | ||
|
||
import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
|
||
@Getter | ||
@SuperBuilder | ||
public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel { | ||
@Getter | ||
public enum Types { | ||
AnthropicClaudeInstantV1("anthropic.claude-instant-v1"), | ||
AnthropicClaudeV1("anthropic.claude-v1"), | ||
AnthropicClaudeV2("anthropic.claude-v2"), | ||
AnthropicClaudeV2_1("anthropic.claude-v2:1"); | ||
|
||
private final String value; | ||
|
||
Types(String modelID) { | ||
this.value = modelID; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
109 changes: 109 additions & 0 deletions
109
...c/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
package dev.langchain4j.model.bedrock.internal; | ||
|
||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.ChatMessageType; | ||
import dev.langchain4j.data.message.UserMessage; | ||
import dev.langchain4j.internal.Json; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; | ||
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; | ||
import software.amazon.awssdk.core.SdkBytes; | ||
import software.amazon.awssdk.regions.Region; | ||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; | ||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; | ||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; | ||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler; | ||
|
||
import java.util.ArrayList; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
|
||
import static java.util.stream.Collectors.joining; | ||
|
||
/** | ||
* Bedrock chat model | ||
*/ | ||
@Getter | ||
@SuperBuilder | ||
public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel { | ||
@Getter | ||
private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient(); | ||
|
||
class StreamingResponse { | ||
public String completion; | ||
} | ||
|
||
|
||
@Override | ||
public void generate(String userMessage, StreamingResponseHandler<AiMessage> handler) { | ||
List<ChatMessage> messages = new ArrayList<>(); | ||
messages.add(new UserMessage(userMessage)); | ||
|
||
InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder() | ||
.body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages))) | ||
.modelId(getModelId()) | ||
.contentType("application/json") | ||
.accept("application/json") | ||
.build(); | ||
|
||
InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder() | ||
.onChunk(chunk -> { | ||
StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class); | ||
System.out.println("\n\nChunk:"); | ||
System.out.println(sr.completion); | ||
|
||
handler.onNext(sr.completion); | ||
}) | ||
.build(); | ||
|
||
InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder() | ||
.onEventStream(stream -> stream.subscribe(event -> event.accept(visitor))) | ||
.onComplete(() -> { | ||
handler.onComplete(Response.from(new AiMessage())); | ||
System.out.println("\n\nComplete"); | ||
}) | ||
.onError(e -> System.out.println("\n\nError: " + e.getMessage())) | ||
.build(); | ||
asyncClient.invokeModelWithResponseStream(request, h).join(); | ||
|
||
} | ||
|
||
@Override | ||
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) { | ||
} | ||
|
||
@Override | ||
public void generate(List<ChatMessage> messages, List<ToolSpecification> toolSpecifications, StreamingResponseHandler<AiMessage> handler) { | ||
StreamingChatLanguageModel.super.generate(messages, toolSpecifications, handler); | ||
} | ||
|
||
@Override | ||
public void generate(List<ChatMessage> messages, ToolSpecification toolSpecification, StreamingResponseHandler<AiMessage> handler) { | ||
StreamingChatLanguageModel.super.generate(messages, toolSpecification, handler); | ||
} | ||
|
||
/** | ||
* Initialize bedrock client | ||
* | ||
* @return bedrock client | ||
*/ | ||
private BedrockRuntimeAsyncClient initAsyncClient() { | ||
BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder() | ||
.region(region) | ||
.credentialsProvider(credentialsProvider) | ||
.build(); | ||
return client; | ||
} | ||
|
||
|
||
|
||
} |
111 changes: 111 additions & 0 deletions
111
.../src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
package dev.langchain4j.model.bedrock.internal; | ||
|
||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.ChatMessageType; | ||
import dev.langchain4j.internal.Json; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; | ||
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; | ||
import software.amazon.awssdk.regions.Region; | ||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static java.util.stream.Collectors.joining; | ||
|
||
@Getter | ||
@SuperBuilder | ||
public class AbstractSharedBedrockChatModel { | ||
protected static final String HUMAN_PROMPT = "Human:"; | ||
protected static final String ASSISTANT_PROMPT = "Assistant:"; | ||
protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31"; | ||
|
||
protected String getModelId() { | ||
return "anthropic.claude-v2"; | ||
} | ||
|
||
@Builder.Default | ||
protected final String humanPrompt = HUMAN_PROMPT; | ||
@Builder.Default | ||
protected final String assistantPrompt = ASSISTANT_PROMPT; | ||
@Builder.Default | ||
protected final Integer maxRetries = 5; | ||
@Builder.Default | ||
protected final Region region = Region.US_EAST_1; | ||
@Builder.Default | ||
protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build(); | ||
@Builder.Default | ||
protected final int maxTokens = 300; | ||
@Builder.Default | ||
protected final float temperature = 1; | ||
@Builder.Default | ||
protected final float topP = 0.999f; | ||
@Builder.Default | ||
protected final String[] stopSequences = new String[]{}; | ||
@Builder.Default | ||
protected final int topK = 250; | ||
@Builder.Default | ||
protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION; | ||
|
||
|
||
/** | ||
* Convert chat message to string | ||
* | ||
* @param message chat message | ||
* @return string | ||
*/ | ||
protected String chatMessageToString(ChatMessage message) { | ||
switch (message.type()) { | ||
case SYSTEM: | ||
return message.text(); | ||
case USER: | ||
return humanPrompt + " " + message.text(); | ||
case AI: | ||
return assistantPrompt + " " + message.text(); | ||
case TOOL_EXECUTION_RESULT: | ||
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); | ||
} | ||
|
||
throw new IllegalArgumentException("Unknown message type: " + message.type()); | ||
} | ||
|
||
protected String convertMessagesToAwsBody(List<ChatMessage> messages) { | ||
final String context = messages.stream() | ||
.filter(message -> message.type() == ChatMessageType.SYSTEM) | ||
.map(ChatMessage::text) | ||
.collect(joining("\n")); | ||
|
||
final String userMessages = messages.stream() | ||
.filter(message -> message.type() != ChatMessageType.SYSTEM) | ||
.map(this::chatMessageToString) | ||
.collect(joining("\n")); | ||
|
||
final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT); | ||
final Map<String, Object> requestParameters = getRequestParameters(prompt); | ||
final String body = Json.toJson(requestParameters); | ||
return body; | ||
} | ||
|
||
protected Map<String, Object> getRequestParameters(String prompt) { | ||
final Map<String, Object> parameters = new HashMap<>(7); | ||
|
||
parameters.put("prompt", prompt); | ||
parameters.put("max_tokens_to_sample", getMaxTokens()); | ||
parameters.put("temperature", getTemperature()); | ||
parameters.put("top_k", topK); | ||
parameters.put("top_p", getTopP()); | ||
parameters.put("stop_sequences", getStopSequences()); | ||
parameters.put("anthropic_version", anthropicVersion); | ||
|
||
return parameters; | ||
} | ||
|
||
|
||
|
||
|
||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
59 changes: 59 additions & 0 deletions
59
...in4j-bedrock/src/test/java/dev/langchain4j/model/bedrock/BedrockStreamingChatModelIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package dev.langchain4j.model.bedrock; | ||
|
||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.output.Response; | ||
import org.junit.jupiter.api.Disabled; | ||
import org.junit.jupiter.api.Test; | ||
import software.amazon.awssdk.regions.Region; | ||
|
||
import java.util.concurrent.CompletableFuture; | ||
import java.util.concurrent.ExecutionException; | ||
import java.util.concurrent.TimeoutException; | ||
|
||
import static java.util.concurrent.TimeUnit.SECONDS; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
public class BedrockStreamingChatModelIT { | ||
@Test | ||
// @Disabled("To run this test, you must have provide your own access key, secret, region") | ||
void testBedrockAnthropicStreamingChatModel() throws ExecutionException, InterruptedException, TimeoutException { | ||
BedrockAnthropicStreamingChatModel bedrockChatModel = BedrockAnthropicStreamingChatModel | ||
.builder() | ||
.temperature(0.50f) | ||
.maxTokens(300) | ||
.region(Region.US_EAST_1) | ||
.maxRetries(1) | ||
.build(); | ||
|
||
CompletableFuture<String> futureAnswer = new CompletableFuture<>(); | ||
CompletableFuture<Response<AiMessage>> futureResponse = new CompletableFuture<>(); | ||
bedrockChatModel.generate("What's the capital of Germany?", new StreamingResponseHandler<AiMessage>() { | ||
private final StringBuilder answerBuilder = new StringBuilder(); | ||
@Override | ||
public void onNext(String token) { | ||
System.out.println("onNext: '" + token + "'"); | ||
answerBuilder.append(token); | ||
} | ||
|
||
@Override | ||
public void onComplete(Response<AiMessage> response) { | ||
System.out.println("onComplete: '" + response + "'"); | ||
futureAnswer.complete(answerBuilder.toString()); | ||
futureResponse.complete(response); | ||
|
||
} | ||
|
||
@Override | ||
public void onError(Throwable throwable) { | ||
System.out.println(throwable); | ||
} | ||
|
||
}); | ||
String answer = futureAnswer.get(30, SECONDS); | ||
Response<AiMessage> response = futureResponse.get(30, SECONDS); | ||
|
||
assertThat(answer).contains("Berlin"); | ||
} | ||
|
||
} |