Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added sentence transformers example with TensorRT and Triton Ensemble #3615

Merged
merged 7 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# NVIDIA Triton Inference Server on SageMaker - Hugging Face Sentence Transformers

## Introduction

[HuggingFace Sentence Transformers](https://huggingface.co/sentence-transformers) is a Machine Learning (ML) framework and set of pre-trained models to
extract embeddings from sentence, text, and image. The models in this group can also be used with the default methods exposed through the [Transformers](https://www.google.com/search?q=transofrmers+githbu&rlz=1C5GCEM_enES937ES938&oq=transofrmers+githbu&aqs=chrome..69i57.3022j0j7&sourceid=chrome&ie=UTF-8) library.

[NVIDIA Triton Inference Server](https://github.com/triton-inference-server/server/) is a high-performance ML model server, which enables the deployment of ML models in an easy, scalable, and cost-effective way. It also exposes many easy-to-use optimization features to make the most of the underlying hardware, in particular NVIDIA GPU's.

In this example, we walk through how you can:
* Create an Amazon SageMaker Studio image based on the official [NVIDIA PyTorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) image, which includes the necessary dependencies to optimize your model
* Optimize a pre-trained HuggingFace Sentence Transformers model with NVIDIA TensorRT to enable high-performance inference
* Create a Triton Model Ensemble, which will allow you to run in sequence a pre-processing step (input tokenization), model inference and post-processing, where sentence embeddings are computed from the raw token embeddings

This example is meant to serve as a basis for use-cases in which you need to run your own code before and/or after your model, allowing you to optimize the bulk of the computation (the model) using tools such as TensorRT.

<img src="images/triton-ensemble.png" alt="Triton Model Ensamble" />

#### ! Important: The example provided can be tested also by using Amazon SageMaker Notebook Instances

### Prerequisites

1. Required NVIDIA NGC Account. Follow the instruction https://docs.nvidia.com/ngc/ngc-catalog-user-guide/index.html#registering-activating-ngc-account

## Step 1: Clone this repository

## Step 2: Build Studio image

In this example, we provide a [Dokerfile](./studio-image/image_tensorrt/Dockerfile) example to build a custom image for SageMaker Studio.

To build the image, push it and make it available in your Amazon SageMaker Studio environment, edit [sagemaker-studio-config](./studio-image/studio-domain-config.json) by replacing `$DOMAIN_ID` with your Studio domain ID.

We also provide automation scripts in order to [build and push](./studio-image/build_image.sh) your docker image to an ECR repository
and [create](./studio-image/create_studio_image.sh) or [update](./studio-image/update_studio_image.sh) an Amazon SageMaker Image. Please follow the instructions in the [README](./studio-image/README.md) for additional info on the usage of this script.

## Step 3: Compile model, create an Amazon SageMaker Real-Time Endpoint with NVIDIA Triton Inference Server

Clone this repository into your Amazon SageMaker Studio environment and execute the cells in the [notebook](./examples/triton_sentence_embeddings.ipynb)
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
name: "bert-trt"
platform: "tensorrt_plan"
max_batch_size: 16
input [
{
name: "token_ids"
data_type: TYPE_INT32
dims: [128]
},
{
name: "attn_mask"
data_type: TYPE_INT32
dims: [128]
}
]
output [
{
name: "output"
data_type: TYPE_FP32
dims: [128, 384]
},
{
name: "854"
data_type: TYPE_FP32
dims: [384]
}
]
instance_group [
{
kind: KIND_GPU
}
]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Do not delete me!
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
name: "ensemble"
platform: "ensemble"
max_batch_size: 16
input [
{
name: "INPUT0"
data_type: TYPE_STRING
dims: [ 1 ]
}
]
output [
{
name: "finaloutput"
data_type: TYPE_FP32
dims: [384]
}
]
ensemble_scheduling {
step [
{
model_name: "preprocess"
model_version: -1
input_map {
key: "INPUT0"
value: "INPUT0"
}
output_map {
key: "OUTPUT0"
value: "token_ids"
}
output_map {
key: "OUTPUT1"
value: "attn_mask"
}
},
{
model_name: "bert-trt"
model_version: -1
input_map {
key: "token_ids"
value: "token_ids"
}
input_map {
key: "attn_mask"
value: "attn_mask"
}
output_map {
key: "output"
value: "output"
}
},
{
model_name: "postprocess"
model_version: -1
input_map {
key: "TOKEN_EMBEDS_POST"
value: "output"
}
input_map {
key: "ATTENTION_POST"
value: "attn_mask"
}
output_map {
key: "SENT_EMBED"
value: "finaloutput"
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import json
import logging
import numpy as np
import subprocess
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji.

Consider possible security implications associated with subprocess module. https://bandit.readthedocs.io/en/latest/blacklists/blacklist_imports.html#b404-import-subprocess

import sys
import os

import triton_python_backend_utils as pb_utils

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TritonPythonModel:
"""This model loops through different dtypes to make sure that
serialize_byte_tensor works correctly in the Python backend.
"""

def __mean_pooling(self, token_embeddings, attention_mask):
logger.info("token_embeddings: {}".format(token_embeddings))
logger.info("attention_mask: {}".format(attention_mask))

input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

def initialize(self, args):
self.model_dir = args['model_repository']
subprocess.check_call([sys.executable, "-m", "pip", "install", '-r', f'{self.model_dir}/requirements.txt'])
global torch
import torch

self.device_id = args['model_instance_device_id']
self.model_config = model_config = json.loads(args['model_config'])
self.device = torch.device(f'cuda:{self.device_id}') if torch.cuda.is_available() else torch.device('cpu')

output0_config = pb_utils.get_output_config_by_name(
model_config, "SENT_EMBED")

self.output0_dtype = pb_utils.triton_string_to_numpy(
output0_config["data_type"])

def execute(self, requests):

responses = []

for request in requests:
tok_embeds = pb_utils.get_input_tensor_by_name(request, "TOKEN_EMBEDS_POST")
attn_mask = pb_utils.get_input_tensor_by_name(request, "ATTENTION_POST")

tok_embeds = tok_embeds.as_numpy()

logger.info("tok_embeds: {}".format(tok_embeds))
logger.info("tok_embeds shape: {}".format(tok_embeds.shape))

tok_embeds = torch.tensor(tok_embeds,device=self.device)

logger.info("tok_embeds_tensor: {}".format(tok_embeds))

attn_mask = attn_mask.as_numpy()

logger.info("attn_mask: {}".format(attn_mask))
logger.info("attn_mask shape: {}".format(attn_mask.shape))

attn_mask = torch.tensor(attn_mask,device=self.device)

logger.info("attn_mask_tensor: {}".format(attn_mask))

sentence_embeddings = self.__mean_pooling(tok_embeds, attn_mask)
sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)

out_0 = np.array(sentence_embeddings.cpu(),dtype=self.output0_dtype)
logger.info("out_0: {}".format(out_0))

out_tensor_0 = pb_utils.Tensor("SENT_EMBED", out_0)
logger.info("out_tensor_0: {}".format(out_tensor_0))

responses.append(pb_utils.InferenceResponse([out_tensor_0]))

return responses
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: "postprocess"
backend: "python"
max_batch_size: 16

input [
{
name: "TOKEN_EMBEDS_POST"
data_type: TYPE_FP32
dims: [128, 384]

},
{
name: "ATTENTION_POST"
data_type: TYPE_INT32
dims: [128]
}
]
output [
{
name: "SENT_EMBED"
data_type: TYPE_FP32
dims: [ 384 ]
}
]

instance_group [{ kind: KIND_GPU }]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
torch
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import json
import logging
import numpy as np
import subprocess
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji.

Consider possible security implications associated with subprocess module. https://bandit.readthedocs.io/en/latest/blacklists/blacklist_imports.html#b404-import-subprocess

import sys

import triton_python_backend_utils as pb_utils

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class TritonPythonModel:
"""This model loops through different dtypes to make sure that
serialize_byte_tensor works correctly in the Python backend.
"""

def initialize(self, args):
self.model_dir = args['model_repository']
subprocess.check_call([sys.executable, "-m", "pip", "install", '-r', f'{self.model_dir}/requirements.txt'])
global transformers
import transformers

self.tokenizer = transformers.AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
self.model_config = model_config = json.loads(args['model_config'])

output0_config = pb_utils.get_output_config_by_name(
model_config, "OUTPUT0")
output1_config = pb_utils.get_output_config_by_name(
model_config, "OUTPUT1")

self.output0_dtype = pb_utils.triton_string_to_numpy(
output0_config['data_type'])
self.output1_dtype = pb_utils.triton_string_to_numpy(
output0_config['data_type'])

def execute(self, requests):

file = open("logs.txt", "w")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji.

Problem
This line of code might contain a resource leak. Resource leaks can cause your system to slow down or crash.

Fix
Consider closing the following resource: file. The resource is allocated by call builtins.open. Execution paths that do not contain closure statements were detected. To prevent this resource leak, close file in a try-finally block or declare it using a with statement.

More info
View details about the with statement in the Python developer's guide (external link).


responses = []
for request in requests:
logger.info("Request: {}".format(request))

in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0")
in_0 = in_0.as_numpy()

logger.info("in_0: {}".format(in_0))

tok_batch = []

for i in range(in_0.shape[0]):
decoded_object = in_0[i,0].decode()

logger.info("decoded_object: {}".format(decoded_object))

tok_batch.append(decoded_object)

logger.info("tok_batch: {}".format(tok_batch))

tok_sent = self.tokenizer(tok_batch,
padding='max_length',
max_length=128,
)


logger.info("Tokens: {}".format(tok_sent))

out_0 = np.array(tok_sent['input_ids'],dtype=self.output0_dtype)
out_1 = np.array(tok_sent['attention_mask'],dtype=self.output1_dtype)
out_tensor_0 = pb_utils.Tensor("OUTPUT0", out_0)
out_tensor_1 = pb_utils.Tensor("OUTPUT1", out_1)

responses.append(pb_utils.InferenceResponse([out_tensor_0,out_tensor_1]))
return responses
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
name: "preprocess"
backend: "python"
max_batch_size: 16

input [
{
name: "INPUT0"
data_type: TYPE_STRING
dims: [ 1 ]

}
]
output [
{
name: "OUTPUT0"
data_type: TYPE_INT32
dims: [ 128 ]
},

{
name: "OUTPUT1"
data_type: TYPE_INT32
dims: [ 128 ]
}

]

instance_group [{ kind: KIND_CPU }]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
transformers
Loading