Commit

aws#4725: Change model deployment to JumpStart
HubGab-Git committed Sep 29, 2024
1 parent faf8648 commit 0267902
Showing 1 changed file with 81 additions and 160 deletions.
@@ -45,7 +45,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
},
@@ -56,10 +55,7 @@
},
"outputs": [],
"source": [
"!pip install --upgrade sagemaker --quiet\n",
"!pip install ipywidgets==7.0.0 --quiet\n",
"!pip install langchain==0.0.148 --quiet\n",
"!pip install faiss-cpu --quiet"
"!pip install --upgrade sagemaker --quiet"
]
},
{
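Note: a minimal sanity check one might run after the upgraded install (not part of this commit); the JumpStart-based deployment below needs a reasonably recent sagemaker release.

import sagemaker

# Confirm the upgrade took effect in the running kernel (restart the kernel if not).
print(sagemaker.__version__)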
@@ -70,52 +66,11 @@
},
"outputs": [],
"source": [
"import time\n",
"import sagemaker, boto3, json\n",
"from sagemaker.session import Session\n",
"from sagemaker.model import Model\n",
"from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n",
"from sagemaker.predictor import Predictor\n",
"from sagemaker import Session\n",
"from sagemaker.utils import name_from_base\n",
"from typing import Any, Dict, List, Optional\n",
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
"from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n",
"\n",
"sagemaker_session = Session()\n",
"aws_role = sagemaker_session.get_caller_identity_arn()\n",
"aws_region = boto3.Session().region_name\n",
"sess = sagemaker.Session()\n",
"model_version = \"1.*\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type=\"application/json\"):\n",
" client = boto3.client(\"runtime.sagemaker\")\n",
" response = client.invoke_endpoint(\n",
" EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json\n",
" )\n",
" return response\n",
"\n",
"\n",
"def parse_response_model_flan_t5(query_response):\n",
" model_predictions = json.loads(query_response[\"Body\"].read())\n",
" generated_text = model_predictions[\"generated_texts\"]\n",
" return generated_text\n",
"\n",
"from sagemaker.jumpstart.model import JumpStartModel\n",
"\n",
"def parse_response_multiple_texts_bloomz(query_response):\n",
" generated_text = []\n",
" model_predictions = json.loads(query_response[\"Body\"].read())\n",
" for x in model_predictions[0]:\n",
" generated_text.append(x[\"generated_text\"])\n",
" return generated_text"
"sagemaker_session = Session()"
]
},
{
@@ -135,30 +90,21 @@
"source": [
"_MODEL_CONFIG_ = {\n",
" \"huggingface-text2text-flan-t5-xxl\": {\n",
" \"instance type\": \"ml.g5.12xlarge\",\n",
" \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
" \"parse_function\": parse_response_model_flan_t5,\n",
" \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n",
" \"model_version\": \"2.*\",\n",
" \"instance type\": \"ml.g5.12xlarge\"\n",
" },\n",
" \"huggingface-textembedding-gpt-j-6b\": {\n",
" \"instance type\": \"ml.g5.24xlarge\",\n",
" \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n",
" },\n",
" # \"huggingface-textgeneration1-bloomz-7b1-fp16\": {\n",
" # \"instance type\": \"ml.g5.12xlarge\",\n",
" # \"env\": {},\n",
" # \"parse_function\": parse_response_multiple_texts_bloomz,\n",
" # \"prompt\": \"\"\"question: \\\"{question}\"\\\\n\\nContext: \\\"{context}\"\\\\n\\nAnswer:\"\"\",\n",
" \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n",
" \"model_version\": \"1.*\",\n",
" \"instance type\": \"ml.g5.24xlarge\"\n",
" }\n",
" # \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n",
" # \"model_version\": \"3.*\",\n",
" # \"instance type\": \"ml.g5.12xlarge\"\n",
" # },\n",
" # \"huggingface-text2text-flan-ul2-bf16\": {\n",
" # \"instance type\": \"ml.g5.24xlarge\",\n",
" # \"env\": {\n",
" # \"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\",\n",
" # \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"\n",
" # },\n",
" # \"parse_function\": parse_response_model_flan_t5,\n",
" # \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n",
" # },\n",
" # \"model_version\": \"2.*\",\n",
" # \"instance type\": \"ml.g5.24xlarge\"\n",
" # }\n",
"}"
]
},
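For anyone adapting _MODEL_CONFIG_, the JumpStart catalog can be queried for current model IDs; a short sketch, assuming the list_jumpstart_models helper shipped with recent sagemaker releases:

from sagemaker.jumpstart.notebook_utils import list_jumpstart_models

# list_jumpstart_models() returns the catalog's model IDs; filter client-side
# to verify that the IDs hard-coded above still exist.
all_models = list_jumpstart_models()
print([m for m in all_models if "flan-t5" in m or "textembedding" in m])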
@@ -168,41 +114,32 @@
"metadata": {},
"outputs": [],
"source": [
"newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n",
"\n",
"for model_id in _MODEL_CONFIG_:\n",
" endpoint_name = name_from_base(f\"jumpstart-example-raglc-{model_id}\")\n",
" inference_instance_type = _MODEL_CONFIG_[model_id][\"instance type\"]\n",
"\n",
" # Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.\n",
" deploy_image_uri = image_uris.retrieve(\n",
" region=None,\n",
" framework=None, # automatically inferred from model_id\n",
" image_scope=\"inference\",\n",
" endpoint_name = name_from_base(f'jumpstart-example-raglc-{model_id}')\n",
" inference_instance_type = _MODEL_CONFIG_[model_id]['instance type']\n",
" model_version = _MODEL_CONFIG_[model_id]['model_version']\n",
"\n",
" print(f'Deploying {model_id}...')\n",
"\n",
" model = JumpStartModel(\n",
" model_id=model_id,\n",
" model_version=model_version,\n",
" instance_type=inference_instance_type,\n",
" )\n",
" # Retrieve the model uri.\n",
" model_uri = model_uris.retrieve(\n",
" model_id=model_id, model_version=model_version, model_scope=\"inference\"\n",
" )\n",
" model_inference = Model(\n",
" image_uri=deploy_image_uri,\n",
" model_data=model_uri,\n",
" role=aws_role,\n",
" predictor_cls=Predictor,\n",
" name=endpoint_name,\n",
" env=_MODEL_CONFIG_[model_id][\"env\"],\n",
" )\n",
" model_predictor_inference = model_inference.deploy(\n",
" initial_instance_count=1,\n",
" instance_type=inference_instance_type,\n",
" predictor_cls=Predictor,\n",
" endpoint_name=endpoint_name,\n",
" model_version=model_version\n",
" )\n",
" print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")\n",
" _MODEL_CONFIG_[model_id][\"endpoint_name\"] = endpoint_name"
"\n",
" try:\n",
" predictor = model.deploy(\n",
" initial_instance_count=1,\n",
" instance_type=inference_instance_type,\n",
" endpoint_name=name_from_base(\n",
" f\"jumpstart-example-raglc-{model_id}\"\n",
" )\n",
" )\n",
" print(f\"Deployed endpoint: {predictor.endpoint_name}\")\n",
" _MODEL_CONFIG_[model_id]['predictor'] = predictor\n",
" except Exception as e:\n",
" print(f\"Error deploying {model_id}: {str(e)}\")\n",
"\n",
"print(\"Deployment process completed.\")"
]
},
{
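Each deployed endpoint bills per instance-hour, so a teardown pass is worth keeping nearby; a cleanup sketch (not part of the commit), assuming the predictor objects stored in _MODEL_CONFIG_ above:

# Delete the endpoints and models created by the loop above
# once experimentation is finished, to stop incurring charges.
for model_id, config in _MODEL_CONFIG_.items():
    predictor = config.get("predictor")
    if predictor is not None:
        predictor.delete_model()
        predictor.delete_endpoint()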
@@ -229,26 +166,16 @@
"metadata": {},
"outputs": [],
"source": [
"payload = {\n",
" \"text_inputs\": question,\n",
" \"max_length\": 100,\n",
" \"num_return_sequences\": 1,\n",
" \"top_k\": 50,\n",
" \"top_p\": 0.95,\n",
" \"do_sample\": True,\n",
"}\n",
"\n",
"list_of_LLMs = list(_MODEL_CONFIG_.keys())\n",
"list_of_LLMs.remove(\"huggingface-textembedding-gpt-j-6b\") # remove the embedding model\n",
"\n",
"list_of_LLMs = [model for model in list_of_LLMs if \"textembedding\" not in model]\n",
"\n",
"for model_id in list_of_LLMs:\n",
" endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n",
" query_response = query_endpoint_with_json_payload(\n",
" json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n",
" )\n",
" generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n",
" print(f\"For model: {model_id}, the generated output is: {generated_texts[0]}\\n\")"
" predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n",
" response = predictor.predict({\n",
" \"inputs\": question\n",
" })\n",
" print(f\"For model: {model_id}, the generated output is:\\n\")\n",
" print(f\"{response[0]['generated_text']}\\n\")"
]
},
{
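The new request shape drops the explicit generation parameters that the old payload carried. Assuming the endpoint follows the Hugging Face {"inputs": ...} convention implied here, generation settings can usually be passed alongside; a hedged sketch:

# Assumption: the serving container accepts a "parameters" dict next to
# "inputs", mirroring the settings the removed payload passed explicitly.
response = predictor.predict({
    "inputs": question,
    "parameters": {"max_new_tokens": 100, "top_k": 50, "top_p": 0.95, "do_sample": True},
})
print(response[0]["generated_text"])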
@@ -283,31 +210,15 @@
"metadata": {},
"outputs": [],
"source": [
"parameters = {\n",
" \"max_length\": 200,\n",
" \"num_return_sequences\": 1,\n",
" \"top_k\": 250,\n",
" \"top_p\": 0.95,\n",
" \"do_sample\": False,\n",
" \"temperature\": 1,\n",
"}\n",
"prompt = f'Answer based on context:\\n\\n{context}\\n\\n{question}'\n",
"\n",
"for model_id in list_of_LLMs:\n",
" endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n",
"\n",
" prompt = _MODEL_CONFIG_[model_id][\"prompt\"]\n",
"\n",
" text_input = prompt.replace(\"{context}\", context)\n",
" text_input = text_input.replace(\"{question}\", question)\n",
" payload = {\"text_inputs\": text_input, **parameters}\n",
"\n",
" query_response = query_endpoint_with_json_payload(\n",
" json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n",
" )\n",
" generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n",
" print(\n",
" f\"{bold}For model: {model_id}, the generated output is: {generated_texts[0]}{unbold}{newline}\"\n",
" )"
" predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n",
" response = predictor.predict({\n",
" \"inputs\": prompt\n",
" })\n",
" print(f\"For model: {model_id}, the generated output is:\\n\")\n",
" print(f\"{response[0]['generated_text']}\\n\")"
]
},
{
@@ -365,7 +276,11 @@
"outputs": [],
"source": [
"from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n",
"from langchain.embeddings import SagemakerEndpointEmbeddings\n",
"from typing import List\n",
"import boto3\n",
"\n",
"aws_region = boto3.Session().region_name\n",
"\n",
"class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):\n",
" def embed_documents(self, texts: List[str], chunk_size: int = 5) -> List[List[float]]:\n",
@@ -405,9 +320,12 @@
"\n",
"\n",
"content_handler = ContentHandler()\n",
"endpoint_name=_MODEL_CONFIG_[\n",
" \"huggingface-textembedding-all-MiniLM-L6-v2\"\n",
" ][\"predictor\"].endpoint_name\n",
"\n",
"embeddings = SagemakerEndpointEmbeddingsJumpStart(\n",
" endpoint_name=_MODEL_CONFIG_[\"huggingface-textembedding-gpt-j-6b\"][\"endpoint_name\"],\n",
" endpoint_name=endpoint_name,\n",
" region_name=aws_region,\n",
" content_handler=content_handler,\n",
")"
@@ -428,33 +346,35 @@
"source": [
"from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint\n",
"\n",
"parameters = {\n",
" \"max_length\": 200,\n",
" \"num_return_sequences\": 1,\n",
" \"top_k\": 250,\n",
" \"top_p\": 0.95,\n",
" \"do_sample\": False,\n",
" \"temperature\": 1,\n",
"}\n",
"\n",
"\n",
"class ContentHandler(LLMContentHandler):\n",
" content_type = \"application/json\"\n",
" accepts = \"application/json\"\n",
"\n",
" def transform_input(self, prompt: str, model_kwargs={}) -> bytes:\n",
" input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n",
" input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n",
" return input_str.encode(\"utf-8\")\n",
"\n",
" def transform_output(self, output: bytes) -> str:\n",
" response_json = json.loads(output.read().decode(\"utf-8\"))\n",
" return response_json[\"generated_texts\"][0]\n",
" return response_json[0][\"generated_text\"]\n",
"\n",
"\n",
"content_handler = ContentHandler()\n",
"endpoint_name=_MODEL_CONFIG_[\n",
" \"huggingface-text2text-flan-t5-xxl\"\n",
" ][\"predictor\"].endpoint_name\n",
"\n",
"parameters = {\n",
" \"max_length\": 200,\n",
" \"num_return_sequences\": 1,\n",
" \"top_k\": 250,\n",
" \"top_p\": 0.95,\n",
" \"do_sample\": False,\n",
" \"temperature\": 1,\n",
"}\n",
"\n",
"sm_llm = SagemakerEndpoint(\n",
" endpoint_name=_MODEL_CONFIG_[\"huggingface-text2text-flan-t5-xxl\"][\"endpoint_name\"],\n",
" endpoint_name=endpoint_name,\n",
" region_name=aws_region,\n",
" model_kwargs=parameters,\n",
" content_handler=content_handler,\n",
@@ -568,7 +488,8 @@
"from langchain.text_splitter import CharacterTextSplitter\n",
"from langchain import PromptTemplate\n",
"from langchain.chains.question_answering import load_qa_chain\n",
"from langchain.document_loaders.csv_loader import CSVLoader"
"from langchain.document_loaders.csv_loader import CSVLoader\n",
"import json"
]
},
{
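These imports feed the retrieval-augmented QA chain assembled in the collapsed cells that follow. A hedged sketch of how the pieces typically compose with a FAISS index (the CSV path, chunk sizes, and k are illustrative assumptions, not the commit's exact code):

from langchain.vectorstores import FAISS

# Load and chunk the source documents (the file name is hypothetical).
documents = CSVLoader(file_path="documents.csv").load()
docs = CharacterTextSplitter(chunk_size=300, chunk_overlap=0).split_documents(documents)

# Embed the chunks with the SageMaker embedding endpoint and index them.
index = FAISS.from_documents(docs, embeddings)
relevant_docs = index.similarity_search(question, k=3)

# Answer with the Flan-T5 endpoint, stuffing the retrieved chunks as context.
prompt_template = PromptTemplate(
    template="Answer based on context:\n\n{context}\n\n{question}",
    input_variables=["context", "question"],
)
chain = load_qa_chain(llm=sm_llm, chain_type="stuff", prompt=prompt_template)
print(chain({"input_documents": relevant_docs, "question": question}, return_only_outputs=True))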
@@ -1384,9 +1305,9 @@
],
"instance_type": "ml.t3.medium",
"kernelspec": {
"display_name": "Python 3 (Data Science 2.0)",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-38"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -1398,7 +1319,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.11.9"
}
},
"nbformat": 4,