diff --git a/lib/langchain/llm/aws_bedrock.rb b/lib/langchain/llm/aws_bedrock.rb
index 9ba47f88d..0a87ebd0d 100644
--- a/lib/langchain/llm/aws_bedrock.rb
+++ b/lib/langchain/llm/aws_bedrock.rb
@@ -48,7 +48,7 @@ class AwsBedrock < Base
 
     attr_reader :client, :defaults
 
-    SUPPORTED_COMPLETION_PROVIDERS = %i[anthropic cohere ai21].freeze
+    SUPPORTED_COMPLETION_PROVIDERS = %i[anthropic ai21 cohere meta].freeze
     SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[anthropic].freeze
     SUPPORTED_EMBEDDING_PROVIDERS = %i[amazon].freeze
 
@@ -209,6 +209,8 @@ def compose_parameters(params)
         compose_parameters_cohere params
       elsif completion_provider == :ai21
         compose_parameters_ai21 params
+      elsif completion_provider == :meta
+        compose_parameters_meta params
       end
     end
 
@@ -219,6 +221,8 @@ def parse_response(response)
         Langchain::LLM::CohereResponse.new(JSON.parse(response.body.string))
       elsif completion_provider == :ai21
         Langchain::LLM::AI21Response.new(JSON.parse(response.body.string, symbolize_names: true))
+      elsif completion_provider == :meta
+        Langchain::LLM::AwsBedrockMetaResponse.new(JSON.parse(response.body.string))
       end
     end
 
@@ -282,6 +286,16 @@ def compose_parameters_ai21(params)
       }
     end
 
+    def compose_parameters_meta(params)
+      default_params = @defaults.merge(params)
+
+      {
+        temperature: default_params[:temperature],
+        top_p: default_params[:top_p],
+        max_gen_len: default_params[:max_tokens_to_sample]
+      }
+    end
+
     def response_from_chunks(chunks)
       raw_response = {}
 
diff --git a/lib/langchain/llm/response/aws_bedrock_meta_response.rb b/lib/langchain/llm/response/aws_bedrock_meta_response.rb
new file mode 100644
index 000000000..fd05ef91c
--- /dev/null
+++ b/lib/langchain/llm/response/aws_bedrock_meta_response.rb
@@ -0,0 +1,29 @@
+# frozen_string_literal: true
+
+module Langchain::LLM
+  class AwsBedrockMetaResponse < BaseResponse
+    def completion
+      completions.first
+    end
+
+    def completions
+      [raw_response.dig("generation")]
+    end
+
+    def stop_reason
+      raw_response.dig("stop_reason")
+    end
+
+    def prompt_tokens
+      raw_response.dig("prompt_token_count").to_i
+    end
+
+    def completion_tokens
+      raw_response.dig("generation_token_count").to_i
+    end
+
+    def total_tokens
+      prompt_tokens + completion_tokens
+    end
+  end
+end
diff --git a/spec/fixtures/llm/aws_bedrock_meta/complete.json b/spec/fixtures/llm/aws_bedrock_meta/complete.json
new file mode 100644
index 000000000..e4c2fb8a4
--- /dev/null
+++ b/spec/fixtures/llm/aws_bedrock_meta/complete.json
@@ -0,0 +1,6 @@
+{
+  "generation": "The sky has no definitive",
+  "prompt_token_count": 792,
+  "generation_token_count": 300,
+  "stop_reason": "length"
+}
diff --git a/spec/langchain/llm/response/aws_bedrock_meta_response_spec.rb b/spec/langchain/llm/response/aws_bedrock_meta_response_spec.rb
new file mode 100644
index 000000000..f0f2e168d
--- /dev/null
+++ b/spec/langchain/llm/response/aws_bedrock_meta_response_spec.rb
@@ -0,0 +1,15 @@
+# frozen_string_literal: true
+
+RSpec.describe Langchain::LLM::AwsBedrockMetaResponse do
+  let(:raw_completion_response) {
+    JSON.parse File.read("spec/fixtures/llm/aws_bedrock_meta/complete.json")
+  }
+
+  subject { described_class.new(raw_completion_response) }
+
+  describe "#completion" do
+    it "returns completion" do
+      expect(subject.completion).to eq("The sky has no definitive")
+    end
+  end
+end
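Outside the diff itself, here is a minimal sketch of how the new response class behaves once this change is in place. It only uses methods introduced above, and the inline payload simply mirrors `spec/fixtures/llm/aws_bedrock_meta/complete.json`; it is illustrative and not part of the PR.

```ruby
require "json"
require "langchain"

# Raw Bedrock/Meta payload, shaped like the fixture added in this PR.
raw = JSON.parse(<<~JSON)
  {
    "generation": "The sky has no definitive",
    "prompt_token_count": 792,
    "generation_token_count": 300,
    "stop_reason": "length"
  }
JSON

response = Langchain::LLM::AwsBedrockMetaResponse.new(raw)
response.completion    # => "The sky has no definitive"
response.stop_reason   # => "length"
response.total_tokens  # => 1092 (prompt_token_count + generation_token_count)
```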