-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added sentence transformers example with TensorRT and Triton Ensemble (…
…#3615) * Added sentence transformers example with TensorRT and Triton Ensemble * Notebook changes to pass CI build * Grammar fixes and installing torch for CI build * Installing torch to pass CI build Co-authored-by: atqy <[email protected]>
- Loading branch information
1 parent
f1fe550
commit 4c834f0
Showing
22 changed files
with
1,665 additions
and
0 deletions.
There are no files selected for viewing
38 changes: 38 additions & 0 deletions
38
inference/nlp/realtime/huggingface/sentence-transformers-triton-ensemble/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# NVIDIA Triton Inference Server on SageMaker - Hugging Face Sentence Transformers | ||
|
||
## Introduction | ||
|
||
[HuggingFace Sentence Transformers](https://huggingface.co/sentence-transformers) is a Machine Learning (ML) framework and set of pre-trained models to | ||
extract embeddings from sentence, text, and image. The models in this group can also be used with the default methods exposed through the [Transformers](https://www.google.com/search?q=transofrmers+githbu&rlz=1C5GCEM_enES937ES938&oq=transofrmers+githbu&aqs=chrome..69i57.3022j0j7&sourceid=chrome&ie=UTF-8) library. | ||
|
||
[NVIDIA Triton Inference Server](https://github.com/triton-inference-server/server/) is a high-performance ML model server, which enables the deployment of ML models in an easy, scalable, and cost-effective way. It also exposes many easy-to-use optimization features to make the most of the underlying hardware, in particular NVIDIA GPU's. | ||
|
||
In this example, we walk through how you can: | ||
* Create an Amazon SageMaker Studio image based on the official [NVIDIA PyTorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) image, which includes the necessary dependencies to optimize your model | ||
* Optimize a pre-trained HuggingFace Sentence Transformers model with NVIDIA TensorRT to enable high-performance inference | ||
* Create a Triton Model Ensemble, which will allow you to run in sequence a pre-processing step (input tokenization), model inference and post-processing, where sentence embeddings are computed from the raw token embeddings | ||
|
||
This example is meant to serve as a basis for use-cases in which you need to run your own code before and/or after your model, allowing you to optimize the bulk of the computation (the model) using tools such as TensorRT. | ||
|
||
<img src="images/triton-ensemble.png" alt="Triton Model Ensamble" /> | ||
|
||
#### ! Important: The example provided can be tested also by using Amazon SageMaker Notebook Instances | ||
|
||
### Prerequisites | ||
|
||
1. Required NVIDIA NGC Account. Follow the instruction https://docs.nvidia.com/ngc/ngc-catalog-user-guide/index.html#registering-activating-ngc-account | ||
|
||
## Step 1: Clone this repository | ||
|
||
## Step 2: Build Studio image | ||
|
||
In this example, we provide a [Dokerfile](./studio-image/image_tensorrt/Dockerfile) example to build a custom image for SageMaker Studio. | ||
|
||
To build the image, push it and make it available in your Amazon SageMaker Studio environment, edit [sagemaker-studio-config](./studio-image/studio-domain-config.json) by replacing `$DOMAIN_ID` with your Studio domain ID. | ||
|
||
We also provide automation scripts in order to [build and push](./studio-image/build_image.sh) your docker image to an ECR repository | ||
and [create](./studio-image/create_studio_image.sh) or [update](./studio-image/update_studio_image.sh) an Amazon SageMaker Image. Please follow the instructions in the [README](./studio-image/README.md) for additional info on the usage of this script. | ||
|
||
## Step 3: Compile model, create an Amazon SageMaker Real-Time Endpoint with NVIDIA Triton Inference Server | ||
|
||
Clone this repository into your Amazon SageMaker Studio environment and execute the cells in the [notebook](./examples/triton_sentence_embeddings.ipynb) |
32 changes: 32 additions & 0 deletions
32
...gingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/bert-trt/config.pbtxt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
name: "bert-trt" | ||
platform: "tensorrt_plan" | ||
max_batch_size: 16 | ||
input [ | ||
{ | ||
name: "token_ids" | ||
data_type: TYPE_INT32 | ||
dims: [128] | ||
}, | ||
{ | ||
name: "attn_mask" | ||
data_type: TYPE_INT32 | ||
dims: [128] | ||
} | ||
] | ||
output [ | ||
{ | ||
name: "output" | ||
data_type: TYPE_FP32 | ||
dims: [128, 384] | ||
}, | ||
{ | ||
name: "854" | ||
data_type: TYPE_FP32 | ||
dims: [384] | ||
} | ||
] | ||
instance_group [ | ||
{ | ||
kind: KIND_GPU | ||
} | ||
] |
1 change: 1 addition & 0 deletions
1
...sentence-transformers-triton-ensemble/examples/ensemble_hf/ensemble/1/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Do not delete me! |
70 changes: 70 additions & 0 deletions
70
...gingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/ensemble/config.pbtxt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
name: "ensemble" | ||
platform: "ensemble" | ||
max_batch_size: 16 | ||
input [ | ||
{ | ||
name: "INPUT0" | ||
data_type: TYPE_STRING | ||
dims: [ 1 ] | ||
} | ||
] | ||
output [ | ||
{ | ||
name: "finaloutput" | ||
data_type: TYPE_FP32 | ||
dims: [384] | ||
} | ||
] | ||
ensemble_scheduling { | ||
step [ | ||
{ | ||
model_name: "preprocess" | ||
model_version: -1 | ||
input_map { | ||
key: "INPUT0" | ||
value: "INPUT0" | ||
} | ||
output_map { | ||
key: "OUTPUT0" | ||
value: "token_ids" | ||
} | ||
output_map { | ||
key: "OUTPUT1" | ||
value: "attn_mask" | ||
} | ||
}, | ||
{ | ||
model_name: "bert-trt" | ||
model_version: -1 | ||
input_map { | ||
key: "token_ids" | ||
value: "token_ids" | ||
} | ||
input_map { | ||
key: "attn_mask" | ||
value: "attn_mask" | ||
} | ||
output_map { | ||
key: "output" | ||
value: "output" | ||
} | ||
}, | ||
{ | ||
model_name: "postprocess" | ||
model_version: -1 | ||
input_map { | ||
key: "TOKEN_EMBEDS_POST" | ||
value: "output" | ||
} | ||
input_map { | ||
key: "ATTENTION_POST" | ||
value: "attn_mask" | ||
} | ||
output_map { | ||
key: "SENT_EMBED" | ||
value: "finaloutput" | ||
} | ||
} | ||
] | ||
} |
78 changes: 78 additions & 0 deletions
78
...ingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/1/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import json | ||
import logging | ||
import numpy as np | ||
import subprocess | ||
import sys | ||
import os | ||
|
||
import triton_python_backend_utils as pb_utils | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
class TritonPythonModel: | ||
"""This model loops through different dtypes to make sure that | ||
serialize_byte_tensor works correctly in the Python backend. | ||
""" | ||
|
||
def __mean_pooling(self, token_embeddings, attention_mask): | ||
logger.info("token_embeddings: {}".format(token_embeddings)) | ||
logger.info("attention_mask: {}".format(attention_mask)) | ||
|
||
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | ||
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) | ||
|
||
def initialize(self, args): | ||
self.model_dir = args['model_repository'] | ||
subprocess.check_call([sys.executable, "-m", "pip", "install", '-r', f'{self.model_dir}/requirements.txt']) | ||
global torch | ||
import torch | ||
|
||
self.device_id = args['model_instance_device_id'] | ||
self.model_config = model_config = json.loads(args['model_config']) | ||
self.device = torch.device(f'cuda:{self.device_id}') if torch.cuda.is_available() else torch.device('cpu') | ||
|
||
output0_config = pb_utils.get_output_config_by_name( | ||
model_config, "SENT_EMBED") | ||
|
||
self.output0_dtype = pb_utils.triton_string_to_numpy( | ||
output0_config["data_type"]) | ||
|
||
def execute(self, requests): | ||
|
||
responses = [] | ||
|
||
for request in requests: | ||
tok_embeds = pb_utils.get_input_tensor_by_name(request, "TOKEN_EMBEDS_POST") | ||
attn_mask = pb_utils.get_input_tensor_by_name(request, "ATTENTION_POST") | ||
|
||
tok_embeds = tok_embeds.as_numpy() | ||
|
||
logger.info("tok_embeds: {}".format(tok_embeds)) | ||
logger.info("tok_embeds shape: {}".format(tok_embeds.shape)) | ||
|
||
tok_embeds = torch.tensor(tok_embeds,device=self.device) | ||
|
||
logger.info("tok_embeds_tensor: {}".format(tok_embeds)) | ||
|
||
attn_mask = attn_mask.as_numpy() | ||
|
||
logger.info("attn_mask: {}".format(attn_mask)) | ||
logger.info("attn_mask shape: {}".format(attn_mask.shape)) | ||
|
||
attn_mask = torch.tensor(attn_mask,device=self.device) | ||
|
||
logger.info("attn_mask_tensor: {}".format(attn_mask)) | ||
|
||
sentence_embeddings = self.__mean_pooling(tok_embeds, attn_mask) | ||
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1) | ||
|
||
out_0 = np.array(sentence_embeddings.cpu(),dtype=self.output0_dtype) | ||
logger.info("out_0: {}".format(out_0)) | ||
|
||
out_tensor_0 = pb_utils.Tensor("SENT_EMBED", out_0) | ||
logger.info("out_tensor_0: {}".format(out_tensor_0)) | ||
|
||
responses.append(pb_utils.InferenceResponse([out_tensor_0])) | ||
|
||
return responses |
26 changes: 26 additions & 0 deletions
26
...gface/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/config.pbtxt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
name: "postprocess" | ||
backend: "python" | ||
max_batch_size: 16 | ||
|
||
input [ | ||
{ | ||
name: "TOKEN_EMBEDS_POST" | ||
data_type: TYPE_FP32 | ||
dims: [128, 384] | ||
}, | ||
{ | ||
name: "ATTENTION_POST" | ||
data_type: TYPE_INT32 | ||
dims: [128] | ||
} | ||
] | ||
output [ | ||
{ | ||
name: "SENT_EMBED" | ||
data_type: TYPE_FP32 | ||
dims: [ 384 ] | ||
} | ||
] | ||
|
||
instance_group [{ kind: KIND_GPU }] |
1 change: 1 addition & 0 deletions
1
...e/sentence-transformers-triton-ensemble/examples/ensemble_hf/postprocess/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
torch |
74 changes: 74 additions & 0 deletions
74
...gingface/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/1/model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import json | ||
import logging | ||
import numpy as np | ||
import subprocess | ||
import sys | ||
|
||
import triton_python_backend_utils as pb_utils | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
class TritonPythonModel: | ||
"""This model loops through different dtypes to make sure that | ||
serialize_byte_tensor works correctly in the Python backend. | ||
""" | ||
|
||
def initialize(self, args): | ||
self.model_dir = args['model_repository'] | ||
subprocess.check_call([sys.executable, "-m", "pip", "install", '-r', f'{self.model_dir}/requirements.txt']) | ||
global transformers | ||
import transformers | ||
|
||
self.tokenizer = transformers.AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') | ||
self.model_config = model_config = json.loads(args['model_config']) | ||
|
||
output0_config = pb_utils.get_output_config_by_name( | ||
model_config, "OUTPUT0") | ||
output1_config = pb_utils.get_output_config_by_name( | ||
model_config, "OUTPUT1") | ||
|
||
self.output0_dtype = pb_utils.triton_string_to_numpy( | ||
output0_config['data_type']) | ||
self.output1_dtype = pb_utils.triton_string_to_numpy( | ||
output0_config['data_type']) | ||
|
||
def execute(self, requests): | ||
|
||
file = open("logs.txt", "w") | ||
|
||
responses = [] | ||
for request in requests: | ||
logger.info("Request: {}".format(request)) | ||
|
||
in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0") | ||
in_0 = in_0.as_numpy() | ||
|
||
logger.info("in_0: {}".format(in_0)) | ||
|
||
tok_batch = [] | ||
|
||
for i in range(in_0.shape[0]): | ||
decoded_object = in_0[i,0].decode() | ||
|
||
logger.info("decoded_object: {}".format(decoded_object)) | ||
|
||
tok_batch.append(decoded_object) | ||
|
||
logger.info("tok_batch: {}".format(tok_batch)) | ||
|
||
tok_sent = self.tokenizer(tok_batch, | ||
padding='max_length', | ||
max_length=128, | ||
) | ||
|
||
|
||
logger.info("Tokens: {}".format(tok_sent)) | ||
|
||
out_0 = np.array(tok_sent['input_ids'],dtype=self.output0_dtype) | ||
out_1 = np.array(tok_sent['attention_mask'],dtype=self.output1_dtype) | ||
out_tensor_0 = pb_utils.Tensor("OUTPUT0", out_0) | ||
out_tensor_1 = pb_utils.Tensor("OUTPUT1", out_1) | ||
|
||
responses.append(pb_utils.InferenceResponse([out_tensor_0,out_tensor_1])) | ||
return responses |
28 changes: 28 additions & 0 deletions
28
...ngface/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/config.pbtxt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
name: "preprocess" | ||
backend: "python" | ||
max_batch_size: 16 | ||
|
||
input [ | ||
{ | ||
name: "INPUT0" | ||
data_type: TYPE_STRING | ||
dims: [ 1 ] | ||
} | ||
] | ||
output [ | ||
{ | ||
name: "OUTPUT0" | ||
data_type: TYPE_INT32 | ||
dims: [ 128 ] | ||
}, | ||
{ | ||
name: "OUTPUT1" | ||
data_type: TYPE_INT32 | ||
dims: [ 128 ] | ||
} | ||
] | ||
|
||
instance_group [{ kind: KIND_CPU }] |
1 change: 1 addition & 0 deletions
1
...ce/sentence-transformers-triton-ensemble/examples/ensemble_hf/preprocess/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
transformers |
Oops, something went wrong.