Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Amazon Bedrock support #6226

Merged
merged 44 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
7462525
Add Bedrock
viveksilimkhan1 Nov 3, 2023
e3efb9f
Update supported models for Bedrock
viveksilimkhan1 Nov 7, 2023
a3f34ed
Fix supports and add extract response in Bedrock
viveksilimkhan1 Nov 7, 2023
5295bc9
Merge branch 'main' into add-bedrock
tstadel Nov 7, 2023
f8a910c
fix errors imports
tstadel Nov 7, 2023
1cd86c7
improve and refactor supports
tstadel Nov 7, 2023
3a343a7
fix install
tstadel Nov 7, 2023
be0e211
fix mypy
tstadel Nov 7, 2023
5c57455
fix pylint
tstadel Nov 7, 2023
7715cd7
fix existing tests
tstadel Nov 7, 2023
ce3e7d6
Added Anthropic Bedrock
viveksilimkhan1 Nov 7, 2023
5e4e306
fix tests
tstadel Nov 7, 2023
f97f7cd
Merge branch 'add-bedrock' of github.com:viveksilimkhan1/haystack int…
tstadel Nov 7, 2023
2447324
fix sagemaker tests
tstadel Nov 8, 2023
35269ac
add default prompt handler, constructor and supports tests
tstadel Nov 8, 2023
54e9f42
more tests
tstadel Nov 8, 2023
00f1c55
invoke refactoring
tstadel Nov 8, 2023
dc24acc
refactor model_kwargs
tstadel Nov 8, 2023
68960f7
fix mypy
tstadel Nov 8, 2023
49f0ce9
lstrip responses
tstadel Nov 8, 2023
662b10c
Add streaming support
viveksilimkhan1 Nov 9, 2023
31fff2b
bump boto3 version
tstadel Nov 9, 2023
45f3153
Merge branch 'add-bedrock' of github.com:viveksilimkhan1/haystack int…
tstadel Nov 9, 2023
5150136
add class docstrings, better exception names
tstadel Nov 13, 2023
fb3a076
fix layer name
tstadel Nov 14, 2023
0b1bcac
add tests for anthropic and cohere model adapters
tstadel Nov 14, 2023
23459da
update cohere params
tstadel Nov 14, 2023
7260ad2
update ai21 args and add tests
tstadel Nov 14, 2023
c5eadd9
support cohere command light model
tstadel Nov 14, 2023
f5f5915
add titan tests
tstadel Nov 14, 2023
6782243
better class names
tstadel Nov 14, 2023
614ae87
support meta llama 2 model
tstadel Nov 14, 2023
8a18e31
fix streaming support
tstadel Nov 14, 2023
ed0f4d0
more future-proof model adapter selection
tstadel Nov 14, 2023
4697d40
fix import
tstadel Nov 14, 2023
693e8bc
fix mypy
tstadel Nov 14, 2023
ec6b995
fix pylint for preview
tstadel Nov 14, 2023
2ec9f81
add tests for streaming
tstadel Nov 14, 2023
d10b91d
add release notes
tstadel Nov 14, 2023
c516d1d
Apply suggestions from code review
tstadel Nov 15, 2023
4a30206
fix format
tstadel Nov 15, 2023
9846fd1
fix tests after msg changes
tstadel Nov 15, 2023
4d70ce5
fix streaming for cohere
tstadel Nov 15, 2023
1997281
Merge branch 'main' into add-bedrock
tstadel Nov 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions haystack/errors.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
"""Custom Errors for Haystack"""

from typing import Optional


tstadel marked this conversation as resolved.
Show resolved Hide resolved
class HaystackError(Exception):
"""
Any error generated by Haystack.
Expand Down Expand Up @@ -202,6 +197,12 @@ class HuggingFaceInferenceUnauthorizedError(HuggingFaceInferenceError):
"""Exception for issues that occur in the HuggingFace inference node due to unauthorized access"""


class AmazonBedrockConfigurationError(NodeError):
    """Exception raised when AmazonBedrock node is not configured correctly"""

    # NOTE(review): this override only forwards to NodeError.__init__ with what
    # appear to be the same parameters and defaults; presumably kept for an
    # explicit signature — confirm against NodeError whether it can be removed.
    def __init__(self, message: Optional[str] = None, send_message_in_event: bool = False) -> None:
        super().__init__(message=message, send_message_in_event=send_message_in_event)

class SageMakerInferenceError(NodeError):
"""Exception for issues that occur in the SageMaker inference node"""

Expand Down
1 change: 1 addition & 0 deletions haystack/nodes/prompt/invocation_layer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from haystack.nodes.prompt.invocation_layer.cohere import CohereInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face_inference import HFInferenceEndpointInvocationLayer
from haystack.nodes.prompt.invocation_layer.amazon_bedrock import AmazonBedrockBaseInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_meta import SageMakerMetaInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_infer import SageMakerHFInferenceInvocationLayer
from haystack.nodes.prompt.invocation_layer.sagemaker_hf_text_gen import SageMakerHFTextGenerationInvocationLayer
134 changes: 134 additions & 0 deletions haystack/nodes/prompt/invocation_layer/amazon_bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import json
import logging
from abc import abstractmethod, ABC
from typing import Optional, Dict, Union, List, Any


from haystack.errors import AmazonBedrockConfigurationError
from haystack.lazy_imports import LazyImport
from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer.handlers import DefaultPromptHandler

logger = logging.getLogger(__name__)


with LazyImport(message="Run 'pip install farm-haystack[aws]'") as boto3_import:
import boto3
from botocore.exceptions import ClientError, BotoCoreError


class AmazonBedrockBaseInvocationLayer(PromptModelInvocationLayer, ABC):
    """
    Base class for Amazon Bedrock based invocation layers.

    Creates a boto3 session, talks to the ``bedrock-runtime`` endpoint and
    dispatches request/response JSON handling on the model family, since
    each Bedrock model family uses a different payload schema.
    """

    # Model ids grouped by family; each family has its own request body and
    # response layout (see _prepare_invoke and invoke).
    TITAN_MODELS = ["amazon.titan-text-express-v1", "amazon.titan-text-lite-v1"]
    AI21_MODELS = ["ai21.j2-ultra-v1", "ai21.j2-mid-v1"]
    COHERE_MODELS = ["cohere.command-text-v14"]

    def __init__(
        self,
        model_name_or_path: str,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        max_length: Optional[int] = 2048,
        **kwargs,
    ):
        """
        :param model_name_or_path: The Bedrock model id, for example "amazon.titan-text-express-v1".
        :param aws_access_key_id: AWS access key ID.
        :param aws_secret_access_key: AWS secret access key.
        :param aws_session_token: AWS session token.
        :param aws_region_name: AWS region name.
        :param aws_profile_name: AWS profile name.
        :param max_length: Maximum number of tokens the model may generate.
            (Fixed annotation: the default 2048 is an int, not a str.)
        :raises AmazonBedrockConfigurationError: If the boto3 session or the
            Bedrock runtime client cannot be created.
        """
        super().__init__(model_name_or_path, **kwargs)
        self.max_length = max_length

        try:
            session = self.create_session(
                aws_access_key_id=aws_access_key_id,
                aws_secret_access_key=aws_secret_access_key,
                aws_session_token=aws_session_token,
                aws_region_name=aws_region_name,
                aws_profile_name=aws_profile_name,
            )
            self.client = session.client("bedrock-runtime")
        except Exception as exc:
            # Chain the original boto3/botocore error instead of the previous
            # bare `except:` with a message-less raise, which hid the cause.
            raise AmazonBedrockConfigurationError(
                "Could not create the Amazon Bedrock client. "
                "Make sure your AWS credentials and region are configured correctly."
            ) from exc

    def _ensure_token_limit(self, prompt: Union[str, List[Dict[str, str]]]) -> Union[str, List[Dict[str, str]]]:
        """
        Return the prompt unchanged: Bedrock tokenizers are not public, so no
        client-side truncation is possible; the service truncates long prompts.
        """
        # Library code must not print(); use the module logger instead.
        logger.warning(
            "Tokenizer for the bedrock models are not available publicly. The tokens will get truncated automatically"
        )
        return prompt

    @classmethod
    def supports(cls, model_name_or_path: str, **kwargs) -> bool:
        """
        Return whether this invocation layer supports the given Bedrock model id.

        Previously this returned the model name (truthy) or an implicit None;
        it now returns an explicit bool, which is equivalent for the boolean
        checks callers perform.
        """
        return model_name_or_path in cls.TITAN_MODELS + cls.AI21_MODELS + cls.COHERE_MODELS

    @classmethod
    def create_session(
        cls,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Creates an AWS Session with the given parameters.

        :param aws_access_key_id: AWS access key ID.
        :param aws_secret_access_key: AWS secret access key.
        :param aws_session_token: AWS session token.
        :param aws_region_name: AWS region name.
        :param aws_profile_name: AWS profile name.
        :raise NoCredentialsError: If the AWS credentials are not provided or invalid.
        :return: The created AWS Session.
        """
        # First parameter renamed self -> cls: this is a @classmethod.
        boto3_import.check()
        return boto3.Session(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            region_name=aws_region_name,
            profile_name=aws_profile_name,
        )

    def _prepare_invoke(self, prompt, **kwargs):
        """
        Build the JSON request body for the configured model family.

        :param prompt: The text prompt to send to the model.
        :param kwargs: Generation parameters; generic PromptNode kwargs
            (top_k, stop_words) are translated or dropped.
        :return: The JSON-encoded request body as a str.
        :raises ValueError: If the model id belongs to no supported family
            (previously this fell through with `body` unbound -> NameError).
        """
        # pop() with a default instead of `del`/direct indexing: the old code
        # raised KeyError when top_k or stop_words were not supplied.
        kwargs.pop("top_k", None)
        stop_words = kwargs.pop("stop_words", None) or []

        if self.model_name_or_path in self.TITAN_MODELS:
            kwargs.setdefault("topP", 1)
            kwargs.setdefault("temperature", 0.3)
            kwargs["stopSequences"] = stop_words
            kwargs["maxTokenCount"] = self.max_length
            body = json.dumps({"inputText": prompt, "textGenerationConfig": {**kwargs}})
        elif self.model_name_or_path in self.AI21_MODELS:
            kwargs.setdefault("topP", 1)
            kwargs.setdefault("temperature", 0.3)
            kwargs["maxTokens"] = self.max_length
            body = json.dumps({"prompt": prompt, **kwargs})
        elif self.model_name_or_path in self.COHERE_MODELS:
            kwargs.setdefault("temperature", 0.3)
            kwargs["max_tokens"] = self.max_length
            body = json.dumps({"prompt": prompt, **kwargs})
        else:
            raise ValueError(f"Unsupported model {self.model_name_or_path!r} for Amazon Bedrock invocation.")
        return body

    def invoke(self, *args, **kwargs):
        """
        Invoke the Bedrock model with the prompt from kwargs.

        :return: A single-element list containing the generated text.
        :raises ValueError: If the model id belongs to no supported family
            (previously this returned with `responses` unbound -> NameError).
        """
        body = self._prepare_invoke(**kwargs)
        response = self.client.invoke_model(
            body=body, modelId=self.model_name_or_path, accept="application/json", contentType="application/json"
        )
        # Decode the streaming body exactly once; each family nests the
        # generated text differently in the response JSON.
        response_body = json.loads(response["body"].read().decode())
        if self.model_name_or_path in self.TITAN_MODELS:
            generated_text = response_body["results"][0]["outputText"]
        elif self.model_name_or_path in self.AI21_MODELS:
            generated_text = response_body["completions"][0]["data"]["text"]
        elif self.model_name_or_path in self.COHERE_MODELS:
            generated_text = response_body["generations"][0]["text"]
        else:
            raise ValueError(f"Unsupported model {self.model_name_or_path!r} for Amazon Bedrock invocation.")
        return [generated_text]
Show resolved Hide resolved