diff --git a/docs/docs/integrations/llms/sagemaker.ipynb b/docs/docs/integrations/llms/sagemaker.ipynb index 067aeaaa5600d..32659dd2aaa7c 100644 --- a/docs/docs/integrations/llms/sagemaker.ipynb +++ b/docs/docs/integrations/llms/sagemaker.ipynb @@ -82,6 +82,15 @@ "]" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example to initialize with external boto3 session\n", + "\n", + "### for cross account scenarios" + ] + }, { "cell_type": "code", "execution_count": null, @@ -92,7 +101,77 @@ "source": [ "from typing import Dict\n", "\n", - "from langchain.prompts import PromptTemplate\nfrom langchain.llms import SagemakerEndpoint\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.llms import SagemakerEndpoint\n", + "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n", + "from langchain.chains.question_answering import load_qa_chain\n", + "import json\n", + "import boto3\n", + "\n", + "query = \"\"\"How long was Elizabeth hospitalized?\n", + "\"\"\"\n", + "\n", + "prompt_template = \"\"\"Use the following pieces of context to answer the question at the end.\n", + "\n", + "{context}\n", + "\n", + "Question: {question}\n", + "Answer:\"\"\"\n", + "PROMPT = PromptTemplate(\n", + " template=prompt_template, input_variables=[\"context\", \"question\"]\n", + ")\n", + "\n", + "roleARN = 'arn:aws:iam::123456789:role/cross-account-role'\n", + "sts_client = boto3.client('sts')\n", + "response = sts_client.assume_role(RoleArn=roleARN, \n", + " RoleSessionName='CrossAccountSession')\n", + "\n", + "client = boto3.client(\n", + " \"sagemaker-runtime\",\n", + " region_name=\"us-west-2\", \n", + " aws_access_key_id=response['Credentials']['AccessKeyId'],\n", + " aws_secret_access_key=response['Credentials']['SecretAccessKey'],\n", + " aws_session_token = response['Credentials']['SessionToken']\n", + ")\n", + "\n", + "class ContentHandler(LLMContentHandler):\n", + " content_type = \"application/json\"\n", + " accepts = 
\"application/json\"\n", + "\n", + " def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:\n", + " input_str = json.dumps({prompt: prompt, **model_kwargs})\n", + " return input_str.encode(\"utf-8\")\n", + "\n", + " def transform_output(self, output: bytes) -> str:\n", + " response_json = json.loads(output.read().decode(\"utf-8\"))\n", + " return response_json[0][\"generated_text\"]\n", + "\n", + "\n", + "content_handler = ContentHandler()\n", + "\n", + "chain = load_qa_chain(\n", + " llm=SagemakerEndpoint(\n", + " endpoint_name=\"endpoint-name\",\n", + " client=client,\n", + " model_kwargs={\"temperature\": 1e-10},\n", + " content_handler=content_handler,\n", + " ),\n", + " prompt=PROMPT,\n", + ")\n", + "\n", + "chain({\"input_documents\": docs, \"question\": query}, return_only_outputs=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Dict\n", + "\n", + "from langchain.prompts import PromptTemplate\n", + "from langchain.llms import SagemakerEndpoint\n", "from langchain.llms.sagemaker_endpoint import LLMContentHandler\n", "from langchain.chains.question_answering import load_qa_chain\n", "import json\n", diff --git a/docs/docs/integrations/text_embedding/sagemaker-endpoint.ipynb b/docs/docs/integrations/text_embedding/sagemaker-endpoint.ipynb index ec80112e1019d..98d423890db56 100644 --- a/docs/docs/integrations/text_embedding/sagemaker-endpoint.ipynb +++ b/docs/docs/integrations/text_embedding/sagemaker-endpoint.ipynb @@ -43,7 +43,7 @@ "from langchain.embeddings import SagemakerEndpointEmbeddings\n", "from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n", "import json\n", - "\n", + "import boto3\n", "\n", "class ContentHandler(EmbeddingsContentHandler):\n", " content_type = \"application/json\"\n", @@ -87,7 +87,18 @@ " endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\",\n", " region_name=\"us-east-1\",\n", " 
content_handler=content_handler,\n", - ")" + ")\n", + "\n", + "\n", + "# client = boto3.client(\n", + "# \"sagemaker-runtime\",\n", + "# region_name=\"us-west-2\" \n", + "# )\n", + "# embeddings = SagemakerEndpointEmbeddings(\n", + "# endpoint_name=\"huggingface-pytorch-inference-2023-03-21-16-14-03-834\", \n", + "# client=client,\n", + "# content_handler=content_handler,\n", + "# )" ] }, { diff --git a/libs/langchain/langchain/embeddings/sagemaker_endpoint.py b/libs/langchain/langchain/embeddings/sagemaker_endpoint.py index 6bfd29f2c227a..0e724624ae89d 100644 --- a/libs/langchain/langchain/embeddings/sagemaker_endpoint.py +++ b/libs/langchain/langchain/embeddings/sagemaker_endpoint.py @@ -46,8 +46,18 @@ class SagemakerEndpointEmbeddings(BaseModel, Embeddings): region_name=region_name, credentials_profile_name=credentials_profile_name ) + + # Use with an externally created boto3 client + client = boto3.client( + "sagemaker-runtime", + region_name=region_name + ) + se = SagemakerEndpointEmbeddings( + endpoint_name=endpoint_name, + client=client + ) """ - client: Any #: :meta private: + client: Any = None endpoint_name: str = "" """The name of the endpoint from the deployed Sagemaker model. @@ -106,6 +116,10 @@ class Config: @root_validator() def validate_environment(cls, values: Dict) -> Dict: + """Skip validation when a boto3 client is provided externally.""" + if values.get("client") is not None: + return values + """Validate that AWS credentials to and python package exists in environment.""" try: import boto3