-
Notifications
You must be signed in to change notification settings - Fork 6.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added sentence transformers example with TensorRT and Triton Ensemble #3615
Changes from 1 commit
fc2a765
1d3f798
071f233
94a2978
238bc09
806fca7
7578f71
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# NVIDIA Triton Inference Server on SageMaker - Hugging Face Sentence Transformers | ||
|
||
## Introduction | ||
|
||
[HuggingFace Sentence Transformers](https://huggingface.co/sentence-transformers) is a Machine Learning (ML) framework and set of pre-trained models to | ||
extract embeddings from sentence, text, and image. The models in this group can also be used with the default methods exposed through the [Transformers](https://www.google.com/search?q=transofrmers+githbu&rlz=1C5GCEM_enES937ES938&oq=transofrmers+githbu&aqs=chrome..69i57.3022j0j7&sourceid=chrome&ie=UTF-8) library. | ||
|
||
[NVIDIA Triton Inference Server](https://github.com/triton-inference-server/server/) is a high-performance ML model server, which enables the deployment of ML models in an easy, scalable, and cost-effective way. It also exposes many easy-to-use optimization features to make the most of the underlying hardware, in particular NVIDIA GPU's. | ||
|
||
In this example, we walk through how you can: | ||
* Create an Amazon SageMaker Studio image based on the official [NVIDIA PyTorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) image, which includes the necessary dependencies to optimize your model | ||
* Optimize a pre-trained HuggingFace Sentence Transformers model with NVIDIA TensorRT to enable high-performance inference | ||
* Create a Triton Model Ensemble, which will allow you to run in sequence a pre-processing step (input tokenization), model inference and post-processing, where sentence embeddings are computed from the raw token embeddings | ||
|
||
This example is meant to serve as a basis for use-cases in which you need to run your own code before and/or after your model, allowing you to optimize the bulk of the computation (the model) using tools such as TensorRT. | ||
|
||
<img src="images/triton-ensemble.png" alt="Triton Model Ensamble" /> | ||
|
||
#### ! Important: The example provided can be tested also by using Amazon SageMaker Notebook Instances | ||
|
||
### Prerequisites | ||
|
||
1. Required NVIDIA NGC Account. Follow the instruction https://docs.nvidia.com/ngc/ngc-catalog-user-guide/index.html#registering-activating-ngc-account | ||
|
||
## Step 1: Clone this repository | ||
|
||
## Step 2: Build Studio image | ||
|
||
In this example, we provide a [Dokerfile](./studio-image/image_tensorrt/Dockerfile) example to build a custom image for SageMaker Studio. | ||
|
||
To build the image, push it and make it available in your Amazon SageMaker Studio environment, edit [sagemaker-studio-config](./studio-image/studio-domain-config.json) by replacing `$DOMAIN_ID` with your Studio domain ID. | ||
|
||
We also provide automation scripts in order to [build and push](./studio-image/build_image.sh) your docker image to an ECR repository | ||
and [create](./studio-image/create_studio_image.sh) or [update](./studio-image/update_studio_image.sh) an Amazon SageMaker Image. Please follow the instructions in the [README](./studio-image/README.md) for additional info on the usage of this script. | ||
|
||
## Step 3: Compile model, create an Amazon SageMaker Real-Time Endpoint with NVIDIA Triton Inference Server | ||
|
||
Clone this repository into your Amazon SageMaker Studio environment and execute the cells in the [notebook](./examples/triton_sentence_embeddings.ipynb) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
name: "bert-trt" | ||
platform: "tensorrt_plan" | ||
max_batch_size: 16 | ||
input [ | ||
{ | ||
name: "token_ids" | ||
data_type: TYPE_INT32 | ||
dims: [128] | ||
}, | ||
{ | ||
name: "attn_mask" | ||
data_type: TYPE_INT32 | ||
dims: [128] | ||
} | ||
] | ||
output [ | ||
{ | ||
name: "output" | ||
data_type: TYPE_FP32 | ||
dims: [128, 384] | ||
}, | ||
{ | ||
name: "854" | ||
data_type: TYPE_FP32 | ||
dims: [384] | ||
} | ||
] | ||
instance_group [ | ||
{ | ||
kind: KIND_GPU | ||
} | ||
] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Do not delete me! |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
name: "ensemble" | ||
platform: "ensemble" | ||
max_batch_size: 16 | ||
input [ | ||
{ | ||
name: "INPUT0" | ||
data_type: TYPE_STRING | ||
dims: [ 1 ] | ||
} | ||
] | ||
output [ | ||
{ | ||
name: "finaloutput" | ||
data_type: TYPE_FP32 | ||
dims: [384] | ||
} | ||
] | ||
ensemble_scheduling { | ||
step [ | ||
{ | ||
model_name: "preprocess" | ||
model_version: -1 | ||
input_map { | ||
key: "INPUT0" | ||
value: "INPUT0" | ||
} | ||
output_map { | ||
key: "OUTPUT0" | ||
value: "token_ids" | ||
} | ||
output_map { | ||
key: "OUTPUT1" | ||
value: "attn_mask" | ||
} | ||
}, | ||
{ | ||
model_name: "bert-trt" | ||
model_version: -1 | ||
input_map { | ||
key: "token_ids" | ||
value: "token_ids" | ||
} | ||
input_map { | ||
key: "attn_mask" | ||
value: "attn_mask" | ||
} | ||
output_map { | ||
key: "output" | ||
value: "output" | ||
} | ||
}, | ||
{ | ||
model_name: "postprocess" | ||
model_version: -1 | ||
input_map { | ||
key: "TOKEN_EMBEDS_POST" | ||
value: "output" | ||
} | ||
input_map { | ||
key: "ATTENTION_POST" | ||
value: "attn_mask" | ||
} | ||
output_map { | ||
key: "SENT_EMBED" | ||
value: "finaloutput" | ||
} | ||
} | ||
] | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import json | ||
import logging | ||
import numpy as np | ||
import subprocess | ||
import sys | ||
import os | ||
|
||
import triton_python_backend_utils as pb_utils | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
class TritonPythonModel: | ||
"""This model loops through different dtypes to make sure that | ||
serialize_byte_tensor works correctly in the Python backend. | ||
""" | ||
|
||
def __mean_pooling(self, token_embeddings, attention_mask): | ||
logger.info("token_embeddings: {}".format(token_embeddings)) | ||
logger.info("attention_mask: {}".format(attention_mask)) | ||
|
||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | ||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) | ||
|
||
def initialize(self, args): | ||
self.model_dir = args['model_repository'] | ||
subprocess.check_call([sys.executable, "-m", "pip", "install", '-r', f'{self.model_dir}/requirements.txt']) | ||
global torch | ||
import torch | ||
|
||
self.device_id = args['model_instance_device_id'] | ||
self.model_config = model_config = json.loads(args['model_config']) | ||
self.device = torch.device(f'cuda:{self.device_id}') if torch.cuda.is_available() else torch.device('cpu') | ||
|
||
output0_config = pb_utils.get_output_config_by_name( | ||
model_config, "SENT_EMBED") | ||
|
||
self.output0_dtype = pb_utils.triton_string_to_numpy( | ||
output0_config["data_type"]) | ||
|
||
def execute(self, requests): | ||
|
||
responses = [] | ||
|
||
for request in requests: | ||
tok_embeds = pb_utils.get_input_tensor_by_name(request, "TOKEN_EMBEDS_POST") | ||
attn_mask = pb_utils.get_input_tensor_by_name(request, "ATTENTION_POST") | ||
|
||
tok_embeds = tok_embeds.as_numpy() | ||
|
||
logger.info("tok_embeds: {}".format(tok_embeds)) | ||
logger.info("tok_embeds shape: {}".format(tok_embeds.shape)) | ||
|
||
tok_embeds = torch.tensor(tok_embeds,device=self.device) | ||
|
||
logger.info("tok_embeds_tensor: {}".format(tok_embeds)) | ||
|
||
attn_mask = attn_mask.as_numpy() | ||
|
||
logger.info("attn_mask: {}".format(attn_mask)) | ||
logger.info("attn_mask shape: {}".format(attn_mask.shape)) | ||
|
||
attn_mask = torch.tensor(attn_mask,device=self.device) | ||
|
||
logger.info("attn_mask_tensor: {}".format(attn_mask)) | ||
|
||
sentence_embeddings = self.__mean_pooling(tok_embeds, attn_mask) | ||
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) | ||
|
||
out_0 = np.array(sentence_embeddings.cpu(),dtype=self.output0_dtype) | ||
logger.info("out_0: {}".format(out_0)) | ||
|
||
out_tensor_0 = pb_utils.Tensor("SENT_EMBED", out_0) | ||
logger.info("out_tensor_0: {}".format(out_tensor_0)) | ||
|
||
responses.append(pb_utils.InferenceResponse([out_tensor_0])) | ||
|
||
return responses |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
name: "postprocess" | ||
backend: "python" | ||
max_batch_size: 16 | ||
|
||
input [ | ||
{ | ||
name: "TOKEN_EMBEDS_POST" | ||
data_type: TYPE_FP32 | ||
dims: [128, 384] | ||
}, | ||
{ | ||
name: "ATTENTION_POST" | ||
data_type: TYPE_INT32 | ||
dims: [128] | ||
} | ||
] | ||
output [ | ||
{ | ||
name: "SENT_EMBED" | ||
data_type: TYPE_FP32 | ||
dims: [ 384 ] | ||
} | ||
] | ||
|
||
instance_group [{ kind: KIND_GPU }] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
torch |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import json | ||
import logging | ||
import numpy as np | ||
import subprocess | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji. Consider possible security implications associated with subprocess module. https://bandit.readthedocs.io/en/latest/blacklists/blacklist_imports.html#b404-import-subprocess |
||
import sys | ||
|
||
import triton_python_backend_utils as pb_utils | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
class TritonPythonModel: | ||
"""This model loops through different dtypes to make sure that | ||
serialize_byte_tensor works correctly in the Python backend. | ||
""" | ||
|
||
def initialize(self, args): | ||
self.model_dir = args['model_repository'] | ||
subprocess.check_call([sys.executable, "-m", "pip", "install", '-r', f'{self.model_dir}/requirements.txt']) | ||
global transformers | ||
import transformers | ||
|
||
self.tokenizer = transformers.AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') | ||
self.model_config = model_config = json.loads(args['model_config']) | ||
|
||
output0_config = pb_utils.get_output_config_by_name( | ||
model_config, "OUTPUT0") | ||
output1_config = pb_utils.get_output_config_by_name( | ||
model_config, "OUTPUT1") | ||
|
||
self.output0_dtype = pb_utils.triton_string_to_numpy( | ||
output0_config['data_type']) | ||
self.output1_dtype = pb_utils.triton_string_to_numpy( | ||
output0_config['data_type']) | ||
|
||
def execute(self, requests): | ||
|
||
file = open("logs.txt", "w") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji. Problem Fix More info |
||
|
||
responses = [] | ||
for request in requests: | ||
logger.info("Request: {}".format(request)) | ||
|
||
in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0") | ||
in_0 = in_0.as_numpy() | ||
|
||
logger.info("in_0: {}".format(in_0)) | ||
|
||
tok_batch = [] | ||
|
||
for i in range(in_0.shape[0]): | ||
decoded_object = in_0[i,0].decode() | ||
|
||
logger.info("decoded_object: {}".format(decoded_object)) | ||
|
||
tok_batch.append(decoded_object) | ||
|
||
logger.info("tok_batch: {}".format(tok_batch)) | ||
|
||
tok_sent = self.tokenizer(tok_batch, | ||
padding='max_length', | ||
max_length=128, | ||
) | ||
|
||
|
||
logger.info("Tokens: {}".format(tok_sent)) | ||
|
||
out_0 = np.array(tok_sent['input_ids'],dtype=self.output0_dtype) | ||
out_1 = np.array(tok_sent['attention_mask'],dtype=self.output1_dtype) | ||
out_tensor_0 = pb_utils.Tensor("OUTPUT0", out_0) | ||
out_tensor_1 = pb_utils.Tensor("OUTPUT1", out_1) | ||
|
||
responses.append(pb_utils.InferenceResponse([out_tensor_0,out_tensor_1])) | ||
return responses |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
name: "preprocess" | ||
backend: "python" | ||
max_batch_size: 16 | ||
|
||
input [ | ||
{ | ||
name: "INPUT0" | ||
data_type: TYPE_STRING | ||
dims: [ 1 ] | ||
} | ||
] | ||
output [ | ||
{ | ||
name: "OUTPUT0" | ||
data_type: TYPE_INT32 | ||
dims: [ 128 ] | ||
}, | ||
{ | ||
name: "OUTPUT1" | ||
data_type: TYPE_INT32 | ||
dims: [ 128 ] | ||
} | ||
] | ||
|
||
instance_group [{ kind: KIND_CPU }] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
transformers |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji.
Consider possible security implications associated with subprocess module. https://bandit.readthedocs.io/en/latest/blacklists/blacklist_imports.html#b404-import-subprocess