forked from langchain4j/langchain4j
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add streaming API for Bedrock Anthropics
- Loading branch information
1 parent
b363150
commit 6d20783
Showing
6 changed files
with
292 additions
and
64 deletions.
There are no files selected for viewing
33 changes: 33 additions & 0 deletions
33
...drock/src/main/java/dev/langchain4j/model/bedrock/BedrockAnthropicStreamingChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package dev.langchain4j.model.bedrock; | ||
|
||
import dev.langchain4j.model.bedrock.internal.AbstractBedrockStreamingChatModel; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
|
||
@Getter | ||
@SuperBuilder | ||
public class BedrockAnthropicStreamingChatModel extends AbstractBedrockStreamingChatModel { | ||
@Builder.Default | ||
private final BedrockAnthropicChatModel.Types model = BedrockAnthropicChatModel.Types.AnthropicClaudeV2; | ||
|
||
@Override | ||
protected String getModelId() { | ||
return model.getValue(); | ||
} | ||
|
||
@Getter | ||
/** | ||
* Bedrock Anthropic model ids | ||
*/ | ||
public enum Types { | ||
AnthropicClaudeV2("anthropic.claude-v2"), | ||
AnthropicClaudeV2_1("anthropic.claude-v2:1"); | ||
|
||
private final String value; | ||
|
||
Types(String modelID) { | ||
this.value = modelID; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
...c/main/java/dev/langchain4j/model/bedrock/internal/AbstractBedrockStreamingChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
package dev.langchain4j.model.bedrock.internal; | ||
|
||
import dev.langchain4j.agent.tool.ToolSpecification; | ||
import dev.langchain4j.data.message.AiMessage; | ||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.UserMessage; | ||
import dev.langchain4j.internal.Json; | ||
import dev.langchain4j.model.StreamingResponseHandler; | ||
import dev.langchain4j.model.chat.StreamingChatLanguageModel; | ||
import dev.langchain4j.model.output.Response; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
import software.amazon.awssdk.core.SdkBytes; | ||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; | ||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamRequest; | ||
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelWithResponseStreamResponseHandler; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.concurrent.atomic.AtomicReference; | ||
|
||
/** | ||
* Bedrock Streaming chat model | ||
*/ | ||
@Getter | ||
@SuperBuilder | ||
public abstract class AbstractBedrockStreamingChatModel extends AbstractSharedBedrockChatModel implements StreamingChatLanguageModel { | ||
@Getter | ||
private final BedrockRuntimeAsyncClient asyncClient = initAsyncClient(); | ||
|
||
class StreamingResponse { | ||
public String completion; | ||
} | ||
|
||
@Override | ||
public void generate(String userMessage, StreamingResponseHandler<AiMessage> handler) { | ||
List<ChatMessage> messages = new ArrayList<>(); | ||
messages.add(new UserMessage(userMessage)); | ||
generate(messages, handler); | ||
} | ||
|
||
@Override | ||
public void generate(List<ChatMessage> messages, StreamingResponseHandler<AiMessage> handler) { | ||
InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder() | ||
.body(SdkBytes.fromUtf8String(convertMessagesToAwsBody(messages))) | ||
.modelId(getModelId()) | ||
.contentType("application/json") | ||
.accept("application/json") | ||
.build(); | ||
|
||
AtomicReference<String> finalCompletion = new AtomicReference<>(""); | ||
|
||
InvokeModelWithResponseStreamResponseHandler.Visitor visitor = InvokeModelWithResponseStreamResponseHandler.Visitor.builder() | ||
.onChunk(chunk -> { | ||
StreamingResponse sr = Json.fromJson(chunk.bytes().asUtf8String(), StreamingResponse.class); | ||
finalCompletion.set(finalCompletion.get() + sr.completion); | ||
handler.onNext(sr.completion); | ||
}) | ||
.build(); | ||
|
||
InvokeModelWithResponseStreamResponseHandler h = InvokeModelWithResponseStreamResponseHandler.builder() | ||
.onEventStream(stream -> stream.subscribe(event -> event.accept(visitor))) | ||
.onComplete(() -> { | ||
handler.onComplete(Response.from(new AiMessage(finalCompletion.get()))); | ||
}) | ||
.onError(handler::onError) | ||
.build(); | ||
asyncClient.invokeModelWithResponseStream(request, h).join(); | ||
|
||
} | ||
|
||
/** | ||
* Initialize async bedrock client | ||
* | ||
* @return async bedrock client | ||
*/ | ||
private BedrockRuntimeAsyncClient initAsyncClient() { | ||
BedrockRuntimeAsyncClient client = BedrockRuntimeAsyncClient.builder() | ||
.region(region) | ||
.credentialsProvider(credentialsProvider) | ||
.build(); | ||
return client; | ||
} | ||
|
||
|
||
|
||
} |
110 changes: 110 additions & 0 deletions
110
.../src/main/java/dev/langchain4j/model/bedrock/internal/AbstractSharedBedrockChatModel.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
package dev.langchain4j.model.bedrock.internal; | ||
|
||
import dev.langchain4j.data.message.ChatMessage; | ||
import dev.langchain4j.data.message.ChatMessageType; | ||
import dev.langchain4j.internal.Json; | ||
import lombok.Builder; | ||
import lombok.Getter; | ||
import lombok.experimental.SuperBuilder; | ||
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; | ||
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; | ||
import software.amazon.awssdk.regions.Region; | ||
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static java.util.stream.Collectors.joining; | ||
|
||
@Getter | ||
@SuperBuilder | ||
public abstract class AbstractSharedBedrockChatModel { | ||
protected static final String HUMAN_PROMPT = "Human:"; | ||
protected static final String ASSISTANT_PROMPT = "Assistant:"; | ||
protected static final String DEFAULT_ANTHROPIC_VERSION = "bedrock-2023-05-31"; | ||
|
||
@Builder.Default | ||
protected final String humanPrompt = HUMAN_PROMPT; | ||
@Builder.Default | ||
protected final String assistantPrompt = ASSISTANT_PROMPT; | ||
@Builder.Default | ||
protected final Integer maxRetries = 5; | ||
@Builder.Default | ||
protected final Region region = Region.US_EAST_1; | ||
@Builder.Default | ||
protected final AwsCredentialsProvider credentialsProvider = DefaultCredentialsProvider.builder().build(); | ||
@Builder.Default | ||
protected final int maxTokens = 300; | ||
@Builder.Default | ||
protected final float temperature = 1; | ||
@Builder.Default | ||
protected final float topP = 0.999f; | ||
@Builder.Default | ||
protected final String[] stopSequences = new String[]{}; | ||
@Builder.Default | ||
protected final int topK = 250; | ||
@Builder.Default | ||
protected final String anthropicVersion = DEFAULT_ANTHROPIC_VERSION; | ||
|
||
|
||
/** | ||
* Convert chat message to string | ||
* | ||
* @param message chat message | ||
* @return string | ||
*/ | ||
protected String chatMessageToString(ChatMessage message) { | ||
switch (message.type()) { | ||
case SYSTEM: | ||
return message.text(); | ||
case USER: | ||
return humanPrompt + " " + message.text(); | ||
case AI: | ||
return assistantPrompt + " " + message.text(); | ||
case TOOL_EXECUTION_RESULT: | ||
throw new IllegalArgumentException("Tool execution results are not supported for Bedrock models"); | ||
} | ||
|
||
throw new IllegalArgumentException("Unknown message type: " + message.type()); | ||
} | ||
|
||
protected String convertMessagesToAwsBody(List<ChatMessage> messages) { | ||
final String context = messages.stream() | ||
.filter(message -> message.type() == ChatMessageType.SYSTEM) | ||
.map(ChatMessage::text) | ||
.collect(joining("\n")); | ||
|
||
final String userMessages = messages.stream() | ||
.filter(message -> message.type() != ChatMessageType.SYSTEM) | ||
.map(this::chatMessageToString) | ||
.collect(joining("\n")); | ||
|
||
final String prompt = String.format("%s\n\n%s\n%s", context, userMessages, ASSISTANT_PROMPT); | ||
final Map<String, Object> requestParameters = getRequestParameters(prompt); | ||
final String body = Json.toJson(requestParameters); | ||
return body; | ||
} | ||
|
||
protected Map<String, Object> getRequestParameters(String prompt) { | ||
final Map<String, Object> parameters = new HashMap<>(7); | ||
|
||
parameters.put("prompt", prompt); | ||
parameters.put("max_tokens_to_sample", getMaxTokens()); | ||
parameters.put("temperature", getTemperature()); | ||
parameters.put("top_k", topK); | ||
parameters.put("top_p", getTopP()); | ||
parameters.put("stop_sequences", getStopSequences()); | ||
parameters.put("anthropic_version", anthropicVersion); | ||
|
||
return parameters; | ||
} | ||
|
||
/** | ||
* Get model id | ||
* | ||
* @return model id | ||
*/ | ||
protected abstract String getModelId(); | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.