diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/README.md b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/README.md new file mode 100644 index 0000000000..df1f8177c2 --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/README.md @@ -0,0 +1,38 @@ +# NVIDIA Triton Inference Server on SageMaker - Hugging Face Sentence Transformers + +## Introduction + +[HuggingFace Sentence Transformers](https://huggingface.co/sentence-transformers) is a Machine Learning (ML) framework and set of pre-trained models to +extract embeddings from sentences, text, and images. The models in this group can also be used with the default methods exposed through the [Transformers](https://github.com/huggingface/transformers) library. + +[NVIDIA Triton Inference Server](https://github.com/triton-inference-server/server/) is a high-performance ML model server, which enables the deployment of ML models in an easy, scalable, and cost-effective way. It also exposes many easy-to-use optimization features to make the most of the underlying hardware, in particular NVIDIA GPUs. + +In this example, we walk through how you can: +* Create an Amazon SageMaker Studio image based on the official [NVIDIA PyTorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) image, which includes the necessary dependencies to optimize your model +* Optimize a pre-trained HuggingFace Sentence Transformers model with NVIDIA TensorRT to enable high-performance inference +* Create a Triton Model Ensemble, which allows you to run, in sequence, a pre-processing step (input tokenization), model inference, and a post-processing step where sentence embeddings are computed from the raw token embeddings (a plain PyTorch sketch of this computation is shown after Step 2 below) + +This example is meant to serve as a basis for use cases in which you need to run your own code before and/or after your model, allowing you to optimize the bulk of the computation (the model) using tools such as TensorRT. + +![Triton Model Ensemble](images/triton-ensemble.png) + +#### ! Important: The example provided can also be tested using Amazon SageMaker Notebook Instances + +### Prerequisites + +1. An NVIDIA NGC account is required. Follow the instructions at https://docs.nvidia.com/ngc/ngc-catalog-user-guide/index.html#registering-activating-ngc-account + +## Step 1: Clone this repository + +## Step 2: Build Studio image + +In this example, we provide a [Dockerfile](./studio-image/image_tensorrt/Dockerfile) example to build a custom image for SageMaker Studio. + +To build the image, push it, and make it available in your Amazon SageMaker Studio environment, edit [sagemaker-studio-config](./studio-image/studio-domain-config.json) by replacing `$DOMAIN_ID` with your Studio domain ID. + +We also provide automation scripts to [build and push](./studio-image/build_image.sh) your Docker image to an ECR repository +and to [create](./studio-image/create_studio_image.sh) or [update](./studio-image/update_studio_image.sh) an Amazon SageMaker Image. Please follow the instructions in the [README](./studio-image/README.md) for additional info on the usage of these scripts.
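For reference, the end-to-end computation performed by the Triton Model Ensemble described in the introduction (tokenization, transformer inference, attention-masked mean pooling, L2 normalization) is equivalent to the following plain PyTorch/Transformers snippet. This is a minimal sketch, assuming `torch` and `transformers` are installed locally; the ensemble splits exactly this logic across the `preprocess`, `bert-trt`, and `postprocess` models.

```python
import torch
from transformers import AutoModel, AutoTokenizer

model_id = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)

sentences = ["Sentence 1", "Sentence 2"]
# Same settings as the ensemble's preprocess model: fixed sequence length of 128
encoded = tokenizer(sentences, padding="max_length", truncation=True, max_length=128, return_tensors="pt")

with torch.no_grad():
    token_embeddings = model(**encoded).last_hidden_state  # shape: (batch, 128, 384)

# Attention-masked mean pooling, as in ensemble_hf/postprocess/1/model.py
mask = encoded["attention_mask"].unsqueeze(-1).expand(token_embeddings.size()).float()
sentence_embeddings = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

print(sentence_embeddings.shape)  # torch.Size([2, 384])
```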
+ +## Step 3: Compile model, create an Amazon SageMaker Real-Time Endpoint with NVIDIA Triton Inference Server + +Clone this repository into your Amazon SageMaker Studio environment and execute the cells in the [notebook](./examples/triton_sentence_embeddings.ipynb) \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/bert-trt/config.pbtxt b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/bert-trt/config.pbtxt new file mode 100644 index 0000000000..3bf605dfc4 --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/bert-trt/config.pbtxt @@ -0,0 +1,32 @@ +name: "bert-trt" +platform: "tensorrt_plan" +max_batch_size: 16 +input [ + { + name: "token_ids" + data_type: TYPE_INT32 + dims: [128] + }, + { + name: "attn_mask" + data_type: TYPE_INT32 + dims: [128] + } +] +output [ + { + name: "output" + data_type: TYPE_FP32 + dims: [128, 384] + }, + { + name: "854" + data_type: TYPE_FP32 + dims: [384] + } +] +instance_group [ + { + kind: KIND_GPU + } + ] \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/ensemble/1/README.md b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/ensemble/1/README.md new file mode 100644 index 0000000000..a8e639ff9d --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/ensemble/1/README.md @@ -0,0 +1 @@ +Do not delete me! \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/ensemble/config.pbtxt b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/ensemble/config.pbtxt new file mode 100644 index 0000000000..aa36dd1ded --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/ensemble/config.pbtxt @@ -0,0 +1,70 @@ +name: "ensemble" +platform: "ensemble" +max_batch_size: 16 +input [ + { + name: "INPUT0" + data_type: TYPE_STRING + dims: [ 1 ] + } +] +output [ + { + name: "finaloutput" + data_type: TYPE_FP32 + dims: [384] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocess" + model_version: -1 + input_map { + key: "INPUT0" + value: "INPUT0" + } + output_map { + key: "OUTPUT0" + value: "token_ids" + } + output_map { + key: "OUTPUT1" + value: "attn_mask" + } + }, + { + model_name: "bert-trt" + model_version: -1 + input_map { + key: "token_ids" + value: "token_ids" + } + input_map { + key: "attn_mask" + value: "attn_mask" + } + output_map { + key: "output" + value: "output" + } + }, + { + model_name: "postprocess" + model_version: -1 + input_map { + key: "TOKEN_EMBEDS_POST" + value: "output" + } + input_map { + key: "ATTENTION_POST" + value: "attn_mask" + } + output_map { + key: "SENT_EMBED" + value: "finaloutput" + } + + } + ] +} \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/1/model.py b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/1/model.py new file mode 100644 index 0000000000..5373e90dbb --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/1/model.py @@ -0,0 +1,78 @@ +import json 
+import logging +import numpy as np +import subprocess +import sys +import os + +import triton_python_backend_utils as pb_utils + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class TritonPythonModel: + """Post-processing model of the ensemble: computes sentence embeddings from + the raw token embeddings via attention-masked mean pooling followed by L2 normalization. + """ + + def __mean_pooling(self, token_embeddings, attention_mask): + logger.info("token_embeddings: {}".format(token_embeddings)) + logger.info("attention_mask: {}".format(attention_mask)) + + input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() + return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) + + def initialize(self, args): + self.model_dir = args['model_repository'] + subprocess.check_call([sys.executable, "-m", "pip", "install", '-r', f'{self.model_dir}/requirements.txt']) + global torch + import torch + + self.device_id = args['model_instance_device_id'] + self.model_config = model_config = json.loads(args['model_config']) + self.device = torch.device(f'cuda:{self.device_id}') if torch.cuda.is_available() else torch.device('cpu') + + output0_config = pb_utils.get_output_config_by_name( + model_config, "SENT_EMBED") + + self.output0_dtype = pb_utils.triton_string_to_numpy( + output0_config["data_type"]) + + def execute(self, requests): + + responses = [] + + for request in requests: + tok_embeds = pb_utils.get_input_tensor_by_name(request, "TOKEN_EMBEDS_POST") + attn_mask = pb_utils.get_input_tensor_by_name(request, "ATTENTION_POST") + + tok_embeds = tok_embeds.as_numpy() + + logger.info("tok_embeds: {}".format(tok_embeds)) + logger.info("tok_embeds shape: {}".format(tok_embeds.shape)) + + tok_embeds = torch.tensor(tok_embeds,device=self.device) + + logger.info("tok_embeds_tensor: {}".format(tok_embeds)) + + attn_mask = attn_mask.as_numpy() + + logger.info("attn_mask: {}".format(attn_mask)) + logger.info("attn_mask shape: {}".format(attn_mask.shape)) + + attn_mask = torch.tensor(attn_mask,device=self.device) + + logger.info("attn_mask_tensor: {}".format(attn_mask)) + + sentence_embeddings = self.__mean_pooling(tok_embeds, attn_mask) + sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) + + out_0 = np.array(sentence_embeddings.cpu(),dtype=self.output0_dtype) + logger.info("out_0: {}".format(out_0)) + + out_tensor_0 = pb_utils.Tensor("SENT_EMBED", out_0) + logger.info("out_tensor_0: {}".format(out_tensor_0)) + + responses.append(pb_utils.InferenceResponse([out_tensor_0])) + + return responses \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/config.pbtxt b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/config.pbtxt new file mode 100644 index 0000000000..de573d924d --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/config.pbtxt @@ -0,0 +1,26 @@ +name: "postprocess" +backend: "python" +max_batch_size: 16 + +input [ + { + name: "TOKEN_EMBEDS_POST" + data_type: TYPE_FP32 + dims: [128, 384] + + }, + { + name: "ATTENTION_POST" + data_type: TYPE_INT32 + dims: [128] + } +] +output [ + { + name: "SENT_EMBED" + data_type: TYPE_FP32 + dims: [ 384 ] + } +] + +instance_group [{ kind: KIND_GPU }] \ No newline at end of file diff --git
a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/requirements.txt b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/requirements.txt new file mode 100644 index 0000000000..08ed5eeb4b --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/requirements.txt @@ -0,0 +1 @@ +torch \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/1/model.py b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/1/model.py new file mode 100644 index 0000000000..47b1f1befc --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/1/model.py @@ -0,0 +1,74 @@ +import json +import logging +import numpy as np +import subprocess +import sys + +import triton_python_backend_utils as pb_utils + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class TritonPythonModel: + """Pre-processing model of the ensemble: tokenizes the incoming batch of sentences + and returns the token ids and attention mask expected by the TensorRT model. + """ + + def initialize(self, args): + self.model_dir = args['model_repository'] + subprocess.check_call([sys.executable, "-m", "pip", "install", '-r', f'{self.model_dir}/requirements.txt']) + global transformers + import transformers + + self.tokenizer = transformers.AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') + self.model_config = model_config = json.loads(args['model_config']) + + output0_config = pb_utils.get_output_config_by_name( + model_config, "OUTPUT0") + output1_config = pb_utils.get_output_config_by_name( + model_config, "OUTPUT1") + + self.output0_dtype = pb_utils.triton_string_to_numpy( + output0_config['data_type']) + self.output1_dtype = pb_utils.triton_string_to_numpy( + output1_config['data_type']) + + def execute(self, requests): + + responses = [] + for request in requests: + logger.info("Request: {}".format(request)) + + in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0") + in_0 = in_0.as_numpy() + + logger.info("in_0: {}".format(in_0)) + + tok_batch = [] + + for i in range(in_0.shape[0]): + decoded_object = in_0[i,0].decode() + + logger.info("decoded_object: {}".format(decoded_object)) + + tok_batch.append(decoded_object) + + logger.info("tok_batch: {}".format(tok_batch)) + + # Pad and truncate to 128 tokens so the output always matches the fixed sequence length expected by the TensorRT engine + tok_sent = self.tokenizer(tok_batch, + padding='max_length', + truncation=True, + max_length=128, + ) + + + logger.info("Tokens: {}".format(tok_sent)) + + out_0 = np.array(tok_sent['input_ids'],dtype=self.output0_dtype) + out_1 = np.array(tok_sent['attention_mask'],dtype=self.output1_dtype) + out_tensor_0 = pb_utils.Tensor("OUTPUT0", out_0) + out_tensor_1 = pb_utils.Tensor("OUTPUT1", out_1) + + responses.append(pb_utils.InferenceResponse([out_tensor_0,out_tensor_1])) + return responses diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/config.pbtxt b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/config.pbtxt new file mode 100644 index 0000000000..c3f70b03dd --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/config.pbtxt @@ -0,0 +1,28 @@ +name: "preprocess" +backend:
"python" +max_batch_size: 16 + +input [ + { + name: "INPUT0" + data_type: TYPE_STRING + dims: [ 1 ] + + } +] +output [ + { + name: "OUTPUT0" + data_type: TYPE_INT32 + dims: [ 128 ] + }, + + { + name: "OUTPUT1" + data_type: TYPE_INT32 + dims: [ 128 ] + } + +] + +instance_group [{ kind: KIND_CPU }] \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/requirements.txt b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/requirements.txt new file mode 100644 index 0000000000..747b7aa97a --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/requirements.txt @@ -0,0 +1 @@ +transformers \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/triton_sentence_embeddings.ipynb b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/triton_sentence_embeddings.ipynb new file mode 100644 index 0000000000..5fef5c4405 --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/triton_sentence_embeddings.ipynb @@ -0,0 +1,1015 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ce5723a5", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# Prerequisites\n", + "\n", + "Install the necessary Python modules to use and interact with [NVIDIA Triton Inference Server](https://github.com/triton-inference-server/server/)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f5995424", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "! 
pip install torch==1.10.0 sagemaker transformers==4.9.1 tritonclient[all]" + ] + }, + { + "cell_type": "markdown", + "id": "4d9d12fa", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# Part 1 - Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3aef44c4", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import argparse\n", + "import boto3\n", + "import copy\n", + "import datetime\n", + "import json\n", + "import numpy as np\n", + "import os\n", + "import pandas as pd\n", + "import pprint\n", + "import re\n", + "import sagemaker\n", + "import sys\n", + "import time\n", + "from time import gmtime, strftime\n", + "import tritonclient.http as http_client" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe9a1636", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "session = boto3.Session()\n", + "role = sagemaker.get_execution_role()\n", + "\n", + "sm_client = session.client(\"sagemaker\")\n", + "sagemaker_session = sagemaker.Session(boto_session=session)\n", + "sm_runtime_client = boto3.client(\"sagemaker-runtime\")\n", + "\n", + "region = boto3.Session().region_name" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4996fe6", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "account_id_map = {\n", + " \"us-east-1\": \"785573368785\",\n", + " \"us-east-2\": \"007439368137\",\n", + " \"us-west-1\": \"710691900526\",\n", + " \"us-west-2\": \"301217895009\",\n", + " \"eu-west-1\": \"802834080501\",\n", + " \"eu-west-2\": \"205493899709\",\n", + " \"eu-west-3\": \"254080097072\",\n", + " \"eu-north-1\": \"601324751636\",\n", + " \"eu-south-1\": \"966458181534\",\n", + " \"eu-central-1\": \"746233611703\",\n", + " \"ap-east-1\": \"110948597952\",\n", + " \"ap-south-1\": \"763008648453\",\n", + " \"ap-northeast-1\": \"941853720454\",\n", + " \"ap-northeast-2\": \"151534178276\",\n", + " \"ap-southeast-1\": \"324986816169\",\n", + " \"ap-southeast-2\": \"355873309152\",\n", + " \"cn-northwest-1\": \"474822919863\",\n", + " \"cn-north-1\": \"472730292857\",\n", + " \"sa-east-1\": \"756306329178\",\n", + " \"ca-central-1\": \"464438896020\",\n", + " \"me-south-1\": \"836785723513\",\n", + " \"af-south-1\": \"774647643957\",\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "d31659f5", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "***" + ] + }, + { + "cell_type": "markdown", + "id": "14a0ba73", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# Part 2 - Generate TensorRT Model\n", + "\n", + "In the following cells, we are using [HuggingFace Auto Classes](https://huggingface.co/docs/transformers/model_doc/auto) to load a pre-trained model from the [HuggingFace Model Hub](https://huggingface.co/models). We then convert the model to the ONNX format, and compile it using NVIDIA TensorRT - namely its command-line wrapper tool, `trtexec` -, using the scripts provided in the official AWS Sample for [SageMaker Triton](https://github.com/aws/amazon-sagemaker-examples/tree/main/sagemaker-triton).\n", + "\n", + "NVIDIA TensorRT is an SDK that facilitates high-performance machine learning inference. You can use it to create `engines` from models that have already been trained, \n", + "optimizing for a selected GPU architecture. 
Triton natively supports the TensorRT runtime, which enables you to easily deploy a TensorRT engine and pair it with the rich features that Triton provides.\n", + "\n", + "### Parameters:\n", + "\n", + "* `model_id`: Model identifier from the Hugging Face Model Hub" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e3f1883", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "model_id = \"sentence-transformers/all-MiniLM-L6-v2\"" + ] + }, + { + "cell_type": "markdown", + "id": "436d115f", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Option 1 - TensorRT Model with Amazon SageMaker Studio" + ] + }, + { + "cell_type": "markdown", + "id": "629bfade-8f14-44b6-9be7-88a2fdb84ba9", + "metadata": {}, + "source": [ + "> **WARNING**: The next cell will only work if you have first created a custom Studio image, described in Step 2 of this repository's README. Set `RUNNING_IN_STUDIO` to `True` if this is the case." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6df2bf7d-a595-46b1-8d1a-43ab946bc858", + "metadata": {}, + "outputs": [], + "source": [ + "RUNNING_IN_STUDIO = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c0eb376", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "if RUNNING_IN_STUDIO:\n", + " !/bin/bash ./workspace/generate_model_trt.sh $model_id && rm -rf ensemble_hf/bert-trt/1 && mkdir -p ensemble_hf/bert-trt/1 && cp ./model.plan ensemble_hf/bert-trt/1/model.plan && rm -rf ./model.plan ./conversion_bs16_dy.txt ./model.onnx" + ] + }, + { + "cell_type": "markdown", + "id": "b34cf8a7", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Option 2 - TensorRT Model with SageMaker Notebook Instances \n", + "\n", + "To make sure we use a TensorRT version and dependencies that are compatible with the ones in our Triton container, we compile the model using the corresponding version of NVIDIA's PyTorch container image.\n", + "\n", + "If you take a look at the Python files within the `workspace` folder, you will see that we first convert the model into ONNX format, specifying dynamic axes so that inputs with a different batch size and sequence length can be passed to the model. TensorRT will treat other input dimensions as fixed, and optimize for those.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f1e86f3", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "! docker run --gpus=all --rm -it -v `pwd`/workspace:/workspace nvcr.io/nvidia/pytorch:21.08-py3 /bin/bash generate_model_trt.sh $model_id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd0bf459", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "! 
rm -rf ensemble_hf/bert-trt && mkdir -p ensemble_hf/bert-trt/1 && cp workspace/model.plan ensemble_hf/bert-trt/1/model.plan && rm -rf workspace/model.onnx workspace/core*" + ] + }, + { + "cell_type": "markdown", + "id": "fe095659", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "Explore the output logs of the compilation process; at the very end, we get a section headlined \"=== Performance summary ===\" which gives us a series of metrics on the obtained engine's performance (latency, throughput, etc.). " + ] + }, + { + "cell_type": "markdown", + "id": "8455c0ca", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# Part 3 - Run Local Triton Inference Server" + ] + }, + { + "cell_type": "markdown", + "id": "ed0cd1c9-c8ab-4f75-a368-a719dd165c04", + "metadata": {}, + "source": [ + "> **WARNING**: The cells under part 3 will only work if run within a SageMaker Notebook Instance!\n" + ] + }, + { + "cell_type": "markdown", + "id": "9c22df98", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "\n", + "\n", + "The following cells run the Triton Inference Server container in the background and load all the models within the `ensemble_hf` folder. The server won't exit if one or more of the models fail to load, because of `--exit-on-error=false`, which is useful for iterative code and model repository building. Remove `-d` to see the logs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35cac085", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "!sudo docker system prune -f" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6965857", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "!docker run --gpus=all -d --shm-size=4G --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd)/ensemble_hf:/model_repository nvcr.io/nvidia/tritonserver:21.08-py3 tritonserver --model-repository=/model_repository --exit-on-error=false --strict-model-config=false\n", + "time.sleep(20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9950b9c", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "CONTAINER_ID=!docker container ls -q\n", + "FIRST_CONTAINER_ID = CONTAINER_ID[0]" + ] + }, + { + "cell_type": "markdown", + "id": "3f903432-6e87-449c-84ea-4c3ed2fa445b", + "metadata": {}, + "source": [ + "Uncomment the next cell and run it to view the container logs and understand Triton model loading."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ca9f7dc", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# !docker logs $FIRST_CONTAINER_ID -f" + ] + }, + { + "cell_type": "markdown", + "id": "f5946837", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Test TensorRT model by invoking the local Triton Server" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b5775f1", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# Start a local Triton client\n", + "try:\n", + " triton_client = http_client.InferenceServerClient(url=\"localhost:8000\", verbose=True)\n", + "except Exception as e:\n", + " print(\"context creation failed: \" + str(e))\n", + " sys.exit()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36b46556", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# Create inputs to send to Triton\n", + "model_name = \"ensemble\"\n", + "\n", + "text_inputs = [\"Sentence 1\", \"Sentence 2\"]\n", + "\n", + "# Text is passed to Triton as BYTES\n", + "inputs = []\n", + "inputs.append(http_client.InferInput(\"INPUT0\", [len(text_inputs), 1], \"BYTES\"))\n", + "\n", + "# We need to structure batch inputs as such\n", + "batch_request = [[text_inputs[i]] for i in range(len(text_inputs))]\n", + "input0_real = np.array(batch_request, dtype=np.object_)\n", + "\n", + "inputs[0].set_data_from_numpy(input0_real, binary_data=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b47da0cb", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "outputs = []\n", + "\n", + "outputs.append(http_client.InferRequestedOutput(\"finaloutput\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2588e261", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "results = triton_client.infer(model_name=model_name, inputs=inputs, outputs=outputs)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a405f8ee", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "outputs_data = results.as_numpy(\"finaloutput\")\n", + "\n", + "for idx, output in enumerate(outputs_data):\n", + " print(text_inputs[idx])\n", + " print(output)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d88a95c", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# Use this to stop the container that was started in detached mode\n", + "!docker kill $FIRST_CONTAINER_ID" + ] + }, + { + "cell_type": "markdown", + "id": "dc75efff", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "***" + ] + }, + { + "cell_type": "markdown", + "id": "ef0e365f", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# Part 4 - Deploy Triton to SageMaker Real-Time Endpoint" + ] + }, + {
"cell_type": "markdown", + "id": "706db9cb", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Deploy with SageMaker Triton container" + ] + }, + { + "cell_type": "markdown", + "id": "bdceed91-9fbc-4ea3-ab9e-0e2599ba7281", + "metadata": {}, + "source": [ + "First we get the URI for the Sagemaker Triton container image that matches the one we used for TensorRT model compilation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0dd85b6", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "if region not in account_id_map.keys():\n", + " raise (\"UNSUPPORTED REGION\")\n", + "\n", + "base = \"amazonaws.com.cn\" if region.startswith(\"cn-\") else \"amazonaws.com\"\n", + "\n", + "triton_image_uri = \"{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:21.08-py3\".format(\n", + " account_id=account_id_map[region], region=region, base=base\n", + ")\n", + "\n", + "triton_image_uri" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "081d1204", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "print(sagemaker_session.default_bucket())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0adf2f45", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "ensemble_prefix = \"mme_gpu_tests/ensemble-singlemodel\"\n", + "!tar -C ensemble_hf/ -czf ensemble-sentencetrans.tar.gz .\n", + "model_uri_tf = sagemaker_session.upload_data(\n", + " path=\"ensemble-sentencetrans.tar.gz\", key_prefix=ensemble_prefix\n", + ")\n", + "\n", + "print(\"S3 model uri: {}\".format(model_uri_tf))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da4eb8ef", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# Important to define what which one of the models loaded by Triton is the default to be served by SM\n", + "# That is, SAGEMAKER_TRITON_DEFAULT_MODEL_NAME\n", + "container_model = {\n", + " \"Image\": triton_image_uri,\n", + " \"ModelDataUrl\": model_uri_tf,\n", + " \"Mode\": \"SingleModel\",\n", + " \"Environment\": {\"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME\": \"ensemble\"},\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "16b0ce98-b921-476d-8921-5e3ffb5cc7d4", + "metadata": {}, + "source": [ + "Register the model with Sagemaker." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e15ae6ac", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "sm_model_name = \"triton-sentence-ensemble\" + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n", + "\n", + "create_model_response = sm_client.create_model(\n", + " ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container_model\n", + ")\n", + "\n", + "print(\"Model Arn: \" + create_model_response[\"ModelArn\"])" + ] + }, + { + "cell_type": "markdown", + "id": "695e71ab-c842-4001-8674-0d5c67ffd79f", + "metadata": {}, + "source": [ + "Create an endpoint configuration." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6096778", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "endpoint_config_name = \"triton-sentence-ensemble\" + time.strftime(\n", + " \"%Y-%m-%d-%H-%M-%S\", time.gmtime()\n", + ")\n", + "\n", + "create_endpoint_config_response = sm_client.create_endpoint_config(\n", + " EndpointConfigName=endpoint_config_name,\n", + " ProductionVariants=[\n", + " {\n", + " \"InstanceType\": \"ml.g4dn.xlarge\",\n", + " \"InitialVariantWeight\": 1,\n", + " \"InitialInstanceCount\": 1,\n", + " \"ModelName\": sm_model_name,\n", + " \"VariantName\": \"AllTraffic\",\n", + " }\n", + " ],\n", + ")\n", + "\n", + "print(\"Endpoint Config Arn: \" + create_endpoint_config_response[\"EndpointConfigArn\"])" + ] + }, + { + "cell_type": "markdown", + "id": "c4893f46-dd98-4e60-9490-36026f128446", + "metadata": {}, + "source": [ + "Deploy the endpoint." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9932ae0", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "endpoint_name = \"triton-sentence-ensemble\" + time.strftime(\"%Y-%m-%d-%H-%M-%S\", time.gmtime())\n", + "\n", + "create_endpoint_response = sm_client.create_endpoint(\n", + " EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name\n", + ")\n", + "\n", + "print(\"Endpoint Arn: \" + create_endpoint_response[\"EndpointArn\"])" + ] + }, + { + "cell_type": "markdown", + "id": "a3ec6a48-5fc5-4495-9b1d-10f94ef95277", + "metadata": {}, + "source": [ + "Wait for the endpoint to be up and running." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f257b3e7", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "resp = sm_client.describe_endpoint(EndpointName=endpoint_name)\n", + "status = resp[\"EndpointStatus\"]\n", + "print(\"Status: \" + status)\n", + "\n", + "while status == \"Creating\":\n", + " time.sleep(60)\n", + " resp = sm_client.describe_endpoint(EndpointName=endpoint_name)\n", + " status = resp[\"EndpointStatus\"]\n", + " print(\"Status: \" + status)\n", + "\n", + "print(\"Arn: \" + resp[\"EndpointArn\"])\n", + "print(\"Status: \" + status)" + ] + }, + { + "cell_type": "markdown", + "id": "7aee2ad4", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "***" + ] + }, + { + "cell_type": "markdown", + "id": "59448028", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Test the SageMaker Triton Endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d55e6ea4", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "text_inputs = [\"Sentence 1\", \"Sentence 2\"]\n", + "\n", + "inputs = []\n", + "inputs.append(http_client.InferInput(\"INPUT0\", [len(text_inputs), 1], \"BYTES\"))\n", + "\n", + "batch_request = [[text_inputs[i]] for i in range(len(text_inputs))]\n", + "\n", + "input0_real = np.array(batch_request, dtype=np.object_)\n", + "\n", + "inputs[0].set_data_from_numpy(input0_real, binary_data=False)\n", + "\n", + "len(input0_real)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71acf686", + "metadata": { + 
"collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "outputs = []\n", + "\n", + "outputs.append(http_client.InferRequestedOutput(\"finaloutput\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8cc9a386", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "request_body, header_length = http_client.InferenceServerClient.generate_request_body(\n", + " inputs, outputs=outputs\n", + ")\n", + "\n", + "print(request_body)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2782361a", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "response = sm_runtime_client.invoke_endpoint(\n", + " EndpointName=endpoint_name,\n", + " ContentType=\"application/vnd.sagemaker-triton.binary+json;json-header-size={}\".format(\n", + " header_length\n", + " ),\n", + " Body=request_body,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65724cd1", + "metadata": { + "collapsed": false, + "jupyter": { + "outputs_hidden": false + }, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "## json.loads fails\n", + "# a = json.loads(response[\"Body\"].read().decode(\"utf8\"))\n", + "\n", + "header_length_prefix = \"application/vnd.sagemaker-triton.binary+json;json-header-size=\"\n", + "header_length_str = response[\"ContentType\"][len(header_length_prefix) :]\n", + "\n", + "# Read response body\n", + "result = http_client.InferenceServerClient.parse_response_body(\n", + " response[\"Body\"].read(), header_length=int(header_length_str)\n", + ")\n", + "\n", + "outputs_data = result.as_numpy(\"finaloutput\")\n", + "\n", + "for idx, output in enumerate(outputs_data):\n", + " print(text_inputs[idx])\n", + " print(output)" + ] + } + ], + "metadata": { + "instance_type": "ml.g4dn.xlarge", + "kernelspec": { + "display_name": "conda_pytorch_p38", + "language": "python", + "name": "conda_pytorch_p38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/workspace/generate_model_trt.sh b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/workspace/generate_model_trt.sh new file mode 100755 index 0000000000..8c68d40c60 --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/workspace/generate_model_trt.sh @@ -0,0 +1,14 @@ +#!/bin/bash +MODEL_NAME=$1 +python -m pip install transformers==4.9.1 +python onnx_exporter.py --model $MODEL_NAME + +trtexec \ + --onnx=model.onnx \ + --saveEngine=model.plan \ + --minShapes=token_ids:1x128,attn_mask:1x128 \ + --optShapes=token_ids:16x128,attn_mask:16x128 \ + --maxShapes=token_ids:32x128,attn_mask:32x128 \ + --verbose \ + --workspace=14000 \ +| tee conversion.txt \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/workspace/onnx_exporter.py 
b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/workspace/onnx_exporter.py new file mode 100644 index 0000000000..b7e0907e52 --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/examples/workspace/onnx_exporter.py @@ -0,0 +1,30 @@ +import torch +from transformers import AutoModel +import argparse +import os + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--save", default="model.onnx") + parser.add_argument("--model", required=True) + + args = parser.parse_args() + + model = AutoModel.from_pretrained(args.model, torchscript=True) + + bs = 1 + seq_len = 128 + dummy_inputs = (torch.randint(1000, (bs, seq_len),dtype=torch.int), torch.zeros(bs, seq_len, dtype=torch.int)) + + torch.onnx.export( + model, + dummy_inputs, + args.save, + export_params=True, + opset_version=10, + input_names=["token_ids", "attn_mask"], + output_names=["output"], + dynamic_axes={"token_ids": [0, 1], "attn_mask": [0, 1], "output": [0]}, + ) + + print("Saved {}".format(args.save)) diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/images/triton-ensemble.png b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/images/triton-ensemble.png new file mode 100644 index 0000000000..32f5c1a6aa Binary files /dev/null and b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/images/triton-ensemble.png differ diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/README.md b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/README.md new file mode 100644 index 0000000000..76083a1c26 --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/README.md @@ -0,0 +1,67 @@ +## build_image.sh + +This script allows you to build a custom Docker image and push it to an ECR repository. + +Parameters: +* IMAGE_NAME: *Mandatory* - Name of the image you want to build +* REGISTRY_NAME: *Mandatory* - Name of the ECR repository you want to use for pushing the image +* IMAGE_TAG: *Mandatory* - Tag to apply to the ECR image +* DOCKER_FILE: *Mandatory* - Dockerfile to build +* PLATFORM: *Optional* - Target platform (architecture) the image is built for, e.g. linux/amd64 +``` +./build_image.sh <IMAGE_NAME> <REGISTRY_NAME> <IMAGE_TAG> <DOCKER_FILE> [PLATFORM] +``` + +Examples: + +``` +./build_image.sh image_tensorrt nvidia-tensorrt-21.08 latest Dockerfile linux/amd64 +``` + +## create_studio_image.sh + +This script allows you to create the Amazon SageMaker Studio image + +Parameters: +* IMAGE_NAME: *Mandatory* - Name of the folder for the image +* REGISTRY_NAME: *Mandatory* - Name of the ECR repository where the image is stored +* SM_IMAGE_NAME: *Mandatory* - Name of the image you want to create +* ROLE_ARN: *Mandatory* - Used to get ECR image information when an image version is created + +``` +./create_studio_image.sh <IMAGE_NAME> <REGISTRY_NAME> <SM_IMAGE_NAME> <ROLE_ARN> +``` + +Examples: + +``` +./create_studio_image.sh image_tensorrt nvidia-tensorrt-21.08 nvidia-tensorrt-21-08 arn:aws:iam:::role/mlops-sagemaker-execution-role +``` + +## update_studio_image.sh + +This script allows you to update an existing Amazon SageMaker Studio image by creating a new image version + +Parameters: +* IMAGE_NAME: *Mandatory* - Name of the folder for the image +* REGISTRY_NAME: *Mandatory* - Name of the ECR repository where the image is stored +* SM_IMAGE_NAME: *Mandatory* - Name of the image you want to update +* ROLE_ARN: *Mandatory* - Used to get ECR image information when an image version is created + +``` +./update_studio_image.sh <IMAGE_NAME> <REGISTRY_NAME> <SM_IMAGE_NAME> <ROLE_ARN> +``` + 
+Examples: + +``` +./update_studio_image.sh image_tensorrt nvidia-tensorrt-21.08 nvidia-tensorrt-21-08 arn:aws:iam:::role/mlops-sagemaker-execution-role +``` + +## update_studio_domain.sh + +This script updates the Amazon SageMaker Studio domain to make the custom image available in Studio, using the configuration in studio-domain-config.json + +``` +./update_studio_domain.sh +``` \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/build_image.sh b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/build_image.sh new file mode 100755 index 0000000000..e532b7bc21 --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/build_image.sh @@ -0,0 +1,48 @@ +#!/bin/sh + +# Positional parameters +repo=$1 +registry_name=$2 +image_tag=$3 +docker_file=$4 +platforms=$5 + +echo "[INFO]: registry_name=${registry_name}" +echo "[INFO]: image_tag=${image_tag}" +echo "[INFO]: docker_file=${docker_file}" +echo "[INFO]: platforms=${platforms}" + +cd $repo + +account=$(aws sts get-caller-identity --query Account --output text) + +# Get the region defined in the current configuration +region=$(aws configure get region) + +echo "[INFO]: Region ${region}" + +fullname="${account}.dkr.ecr.${region}.amazonaws.com/${registry_name}:${image_tag}" + +echo "[INFO]: Image name: ${fullname}" + +# If the repository doesn't exist in ECR, create it. + +aws ecr describe-repositories --repository-names "${registry_name}" > /dev/null 2>&1 || aws ecr create-repository --repository-name "${registry_name}" > /dev/null + +## If you are extending Amazon SageMaker Images, you need to login to the account +# Get the login command from ECR and execute it directly +password=$(aws ecr --region ${region} get-login-password) + +docker login -u AWS -p ${password} "${account}.dkr.ecr.${region}.amazonaws.com" + +if [ -z ${platforms} ] +then + docker build -t ${fullname} -f ${docker_file} . +else + echo "Provided platform = ${platforms}" + docker build -t ${fullname} -f ${docker_file} . 
--platform=${platforms} +fi + +docker push ${fullname} \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/create_studio_image.sh b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/create_studio_image.sh new file mode 100755 index 0000000000..80b868b37e --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/create_studio_image.sh @@ -0,0 +1,54 @@ +#!/bin/sh + +IMAGE_NAME=$1 +REGISTRY_NAME=$2 +SM_IMAGE_NAME=$3 +ROLE_ARN=$4 + +if [ -z ${IMAGE_NAME} ] +then + echo "[INFO]: IMAGE_NAME not passed" + exit 1 +fi + +if [ -z ${REGISTRY_NAME} ] +then + echo "[INFO]: REGISTRY_NAME not passed" + exit 1 +fi + +if [ -z ${SM_IMAGE_NAME} ] +then + echo "[INFO]: SM_IMAGE_NAME not passed" + exit 1 +fi + +if [ -z ${ROLE_ARN} ] +then + echo "[INFO]: ROLE_ARN not passed" + exit 1 +fi + +echo "[INFO]: IMAGE_NAME=${IMAGE_NAME}" +echo "[INFO]: REGISTRY_NAME=${REGISTRY_NAME}" +echo "[INFO]: SM_IMAGE_NAME=${SM_IMAGE_NAME}" +echo "[INFO]: ROLE_ARN=${ROLE_ARN}" + +aws sagemaker create-image \ + --image-name ${SM_IMAGE_NAME} \ + --role-arn ${ROLE_ARN} \ + || exit 1 + +account=$(aws sts get-caller-identity --query Account --output text) +region=$(aws configure get region) + +aws sagemaker create-image-version \ + --image-name ${SM_IMAGE_NAME} \ + --base-image "${account}.dkr.ecr.${region}.amazonaws.com/${REGISTRY_NAME}:latest" \ + || exit 1 + +aws sagemaker delete-app-image-config --app-image-config-name ${SM_IMAGE_NAME}-config + +aws sagemaker describe-image-version --image-name ${SM_IMAGE_NAME} + +aws sagemaker create-app-image-config --cli-input-json file://${IMAGE_NAME}/app-image-config.json \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/image_tensorrt/Dockerfile b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/image_tensorrt/Dockerfile new file mode 100644 index 0000000000..eed1a58ccc --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/image_tensorrt/Dockerfile @@ -0,0 +1,5 @@ +FROM nvcr.io/nvidia/pytorch:21.08-py3 + +RUN pip install sagemaker transformers==4.9.1 tritonclient[all] + +RUN pip install ipykernel && python -m ipykernel install --sys-prefix \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/image_tensorrt/app-image-config.json b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/image_tensorrt/app-image-config.json new file mode 100644 index 0000000000..afe58441a9 --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/image_tensorrt/app-image-config.json @@ -0,0 +1,16 @@ +{ + "AppImageConfigName": "nvidia-tensorrt-21-08-config", + "KernelGatewayImageConfig": { + "KernelSpecs": [ + { + "Name": "python3", + "DisplayName": "Python3" + } + ], + "FileSystemConfig": { + "MountPath": "/root", + "DefaultUid": 0, + "DefaultGid": 0 + } + } +} \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/studio-domain-config.json b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/studio-domain-config.json new file mode 100644 index 0000000000..73df5f3928 --- /dev/null +++ 
b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/studio-domain-config.json @@ -0,0 +1,13 @@ +{ + "DomainId": "$DOMAIN_ID", + "DefaultUserSettings": { + "KernelGatewayAppSettings": { + "CustomImages": [ + { + "ImageName": "nvidia-tensorrt-21-08", + "AppImageConfigName": "nvidia-tensorrt-21-08-config" + } + ] + } + } +} \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/update_studio_domain.sh b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/update_studio_domain.sh new file mode 100755 index 0000000000..dde7dbe5a2 --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/update_studio_domain.sh @@ -0,0 +1,5 @@ +#!/bin/sh + +echo "studio-domain-config.json" + +aws sagemaker update-domain --cli-input-json file://studio-domain-config.json \ No newline at end of file diff --git a/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/update_studio_image.sh b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/update_studio_image.sh new file mode 100755 index 0000000000..36d860b07c --- /dev/null +++ b/inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/studio-image/update_studio_image.sh @@ -0,0 +1,49 @@ +#!/bin/sh + +IMAGE_NAME=$1 +REGISTRY_NAME=$2 +SM_IMAGE_NAME=$3 +ROLE_ARN=$4 + +if [ -z ${IMAGE_NAME} ] +then + echo "[INFO]: IMAGE_NAME not passed" + exit 1 +fi + +if [ -z ${REGISTRY_NAME} ] +then + echo "[INFO]: REGISTRY_NAME not passed" + exit 1 +fi + +if [ -z ${SM_IMAGE_NAME} ] +then + echo "[INFO]: SM_IMAGE_NAME not passed" + exit 1 +fi + +if [ -z ${ROLE_ARN} ] +then + echo "[INFO]: ROLE_ARN not passed" + exit 1 +fi + +echo "[INFO]: IMAGE_NAME=${IMAGE_NAME}" +echo "[INFO]: REGISTRY_NAME=${REGISTRY_NAME}" +echo "[INFO]: SM_IMAGE_NAME=${SM_IMAGE_NAME}" +echo "[INFO]: ROLE_ARN=${ROLE_ARN}" + +account=$(aws sts get-caller-identity --query Account --output text) +region=$(aws configure get region) + +aws sagemaker create-image-version \ + --image-name ${SM_IMAGE_NAME} \ + --base-image "${account}.dkr.ecr.${region}.amazonaws.com/${REGISTRY_NAME}:latest" \ + || exit 1 + +aws sagemaker delete-app-image-config --app-image-config-name ${SM_IMAGE_NAME}-config + +aws sagemaker describe-image-version --image-name ${SM_IMAGE_NAME} + +aws sagemaker create-app-image-config --cli-input-json file://${IMAGE_NAME}/app-image-config.json \ No newline at end of file
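As a usage note for the embeddings returned by the endpoint test in the notebook above: since the `postprocess` model already L2-normalizes its output, the dot product of two returned vectors is their cosine similarity. A small sketch, reusing `outputs_data` and `np` from the final notebook cells:

```python
# Cosine similarity between the two test sentences; outputs_data comes from
# result.as_numpy("finaloutput") in the cells above. The embeddings are already
# L2-normalized by the postprocess model, so a plain dot product is sufficient.
similarity = float(np.dot(outputs_data[0], outputs_data[1]))
print("Cosine similarity between the two sentences:", similarity)
```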
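The notebook does not include a clean-up step; to avoid ongoing charges after testing, the SageMaker resources created above can be removed. A minimal sketch, assuming the notebook variables `endpoint_name`, `endpoint_config_name`, and `sm_model_name` are still in scope:

```python
# Delete the SageMaker endpoint, endpoint configuration, and model created above.
sm_client.delete_endpoint(EndpointName=endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=sm_model_name)
```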