Add Meta models support for AWS Bedrock LLM (#764)
* Update aws_bedrock.rb

* Revert "Update aws_bedrock.rb"

This reverts commit 34ac330.

* Add Aws Bedrock Meta support

* Update complete.json

* Fix bedrock renaming of response

* Fix specs
Fodoj authored Sep 12, 2024
1 parent feaeb73 commit 8b94656
Showing 4 changed files with 65 additions and 1 deletion.
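
Below is a minimal usage sketch of what this change enables. It assumes the existing AwsBedrock constructor accepts the completion model via default_options and that a Meta model such as meta.llama3-8b-instruct-v1:0 is enabled in the target AWS account; the option name, the model id, and the provider-selection-by-model-prefix behaviour are assumptions for illustration, not shown in this commit.

require "langchain"

# Assumed option name and model id; the :meta provider is presumably selected
# from the "meta." prefix of the configured completion model.
llm = Langchain::LLM::AwsBedrock.new(
  default_options: {completion_model_name: "meta.llama3-8b-instruct-v1:0"}
)

response = llm.complete(prompt: "What color is the sky?")
puts response.completion   # served by the new AwsBedrockMetaResponse
puts response.total_tokens # prompt_tokens + completion_tokens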
16 changes: 15 additions & 1 deletion lib/langchain/llm/aws_bedrock.rb
@@ -48,7 +48,7 @@ class AwsBedrock < Base

     attr_reader :client, :defaults

-    SUPPORTED_COMPLETION_PROVIDERS = %i[anthropic cohere ai21].freeze
+    SUPPORTED_COMPLETION_PROVIDERS = %i[anthropic ai21 cohere meta].freeze
     SUPPORTED_CHAT_COMPLETION_PROVIDERS = %i[anthropic].freeze
     SUPPORTED_EMBEDDING_PROVIDERS = %i[amazon].freeze

@@ -209,6 +209,8 @@ def compose_parameters(params)
         compose_parameters_cohere params
       elsif completion_provider == :ai21
         compose_parameters_ai21 params
+      elsif completion_provider == :meta
+        compose_parameters_meta params
       end
     end

@@ -219,6 +221,8 @@ def parse_response(response)
         Langchain::LLM::CohereResponse.new(JSON.parse(response.body.string))
       elsif completion_provider == :ai21
         Langchain::LLM::AI21Response.new(JSON.parse(response.body.string, symbolize_names: true))
+      elsif completion_provider == :meta
+        Langchain::LLM::AwsBedrockMetaResponse.new(JSON.parse(response.body.string))
       end
     end

@@ -282,6 +286,16 @@ def compose_parameters_ai21(params)
       }
     end

+    def compose_parameters_meta(params)
+      default_params = @defaults.merge(params)
+
+      {
+        temperature: default_params[:temperature],
+        top_p: default_params[:top_p],
+        max_gen_len: default_params[:max_tokens_to_sample]
+      }
+    end
+
     def response_from_chunks(chunks)
       raw_response = {}

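The Meta-specific work in this file is the parameter translation in compose_parameters_meta; here is a standalone sketch of that mapping, with illustrative values:

require "json"

# Mirrors compose_parameters_meta: per-call params override the defaults, and the
# generic :max_tokens_to_sample option is renamed to :max_gen_len, the length key
# Meta Llama models on Bedrock expect.
defaults = {temperature: 0.7, top_p: 0.9, max_tokens_to_sample: 300}
params = {temperature: 0.2}

default_params = defaults.merge(params)
meta_body = {
  temperature: default_params[:temperature],
  top_p: default_params[:top_p],
  max_gen_len: default_params[:max_tokens_to_sample]
}

puts JSON.generate(meta_body) # => {"temperature":0.2,"top_p":0.9,"max_gen_len":300}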
29 changes: 29 additions & 0 deletions lib/langchain/llm/response/aws_bedrock_meta_response.rb
@@ -0,0 +1,29 @@
# frozen_string_literal: true

module Langchain::LLM
  class AwsBedrockMetaResponse < BaseResponse
    def completion
      completions.first
    end

    def completions
      [raw_response.dig("generation")]
    end

    def stop_reason
      raw_response.dig("stop_reason")
    end

    def prompt_tokens
      raw_response.dig("prompt_token_count").to_i
    end

    def completion_tokens
      raw_response.dig("generation_token_count").to_i
    end

    def total_tokens
      prompt_tokens + completion_tokens
    end
  end
end
6 changes: 6 additions & 0 deletions spec/fixtures/llm/aws_bedrock_meta/complete.json
@@ -0,0 +1,6 @@
{
  "generation": "The sky has no definitive",
  "prompt_token_count": 792,
  "generation_token_count": 300,
  "stop_reason": "length"
}
15 changes: 15 additions & 0 deletions spec/langchain/llm/response/aws_bedrock_meta_response_spec.rb
@@ -0,0 +1,15 @@
# frozen_string_literal: true

RSpec.describe Langchain::LLM::AwsBedrockMetaResponse do
  let(:raw_chat_completions_response) {
    JSON.parse File.read("spec/fixtures/llm/aws_bedrock_meta/complete.json")
  }

  subject { described_class.new(raw_chat_completions_response) }

  describe "#complete" do
    it "returns completion" do
      expect(subject.completion).to eq("The sky has no definitive")
    end
  end
end
