Dynamic batch results are strange. #3744

Closed
KyungHyunLim opened this issue Mar 27, 2024 · 13 comments
@KyungHyunLim

KyungHyunLim commented Mar 27, 2024

Description

I have created an onnx model with the Python script below.
Then, I created the tensorrt model with the following command.

trtexec --onnx=model.onnx --saveEngine=model.plan --fp16 --minShapes=input_ids:1x128,attention_mask:1x128 --optShapes=input_ids:4x128,attention_mask:4x128 --maxShapes=input_ids:100x128,attention_mask:100x128 --shapes=input_ids:3x128,attention_mask:3x128

I moved both models to the triton inference server.
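
For reference, a typical Triton model-repository layout for the plan file would look roughly like the sketch below (the directory names are assumptions; "bert" matches the config.pbtxt shown later in this thread, and the onnx model would sit in a sibling directory using the onnxruntime_onnx platform):

model_repository/
└── bert/
    ├── config.pbtxt
    └── 1/
        └── model.plan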

tensorrt example

In the case of the tensorrt model, the results differ between inferring the same sentence once and inferring it three times.

querys = ["I love dog"]
embd = get_embeddings("trt_getEmbedding", querys)
print(embd)

querys = ["I love dog", "I love dog", "I love dog"]
embd = get_embeddings("trt_getEmbedding", querys)
print(embd[:768])
print(embd[768:768*2])
print(embd[768*2:])

==> output
(output screenshot)

onnx example

The onnx model gives the same results in both cases.

querys = ["I love dog"]
embd = get_embeddings("getEmbedding_onnx", querys)
print(embd)

querys = ["I love dog", "I love dog", "I love dog"]
embd = get_embeddings("getEmbedding_onnx", querys)
print(embd[:768])
print(embd[768:768*2])
print(embd[768*2:])

==> output
(output screenshot)

I need to use dynamic batch sizes from 1x128 to 100x128.
Is there a problem with the tensorrt model conversion process?
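
The get_embeddings helper used in the snippets above is not included in the issue; a minimal sketch of what such a helper might look like with tritonclient and a Hugging Face tokenizer is below (the tokenizer path, server URL, and output flattening are assumptions, not the code actually used):

import numpy as np
import tritonclient.http as httpclient
from transformers import AutoTokenizer

# Assumed tokenizer checkpoint and server address; adjust to the real setup.
tokenizer = AutoTokenizer.from_pretrained("./5/")
client = httpclient.InferenceServerClient(url="localhost:8000")

def get_embeddings(model_name, queries, max_length=128):
    enc = tokenizer(queries, padding="max_length", truncation=True,
                    max_length=max_length, return_tensors="np")
    input_ids = enc["input_ids"].astype(np.int32)
    attention_mask = enc["attention_mask"].astype(np.int32)

    inputs = [
        httpclient.InferInput("input_ids", list(input_ids.shape), "INT32"),
        httpclient.InferInput("attention_mask", list(attention_mask.shape), "INT32"),
    ]
    inputs[0].set_data_from_numpy(input_ids)
    inputs[1].set_data_from_numpy(attention_mask)

    result = client.infer(model_name, inputs,
                          outputs=[httpclient.InferRequestedOutput("outputs")])
    # Flatten so that embd[:768] is the first sentence, embd[768:768*2] the second, etc.
    return result.as_numpy("outputs").reshape(-1)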

Environment

nvcr.io/nvidia/tritonserver:23.04-py3
nvcr.io/nvidia/tensorrt:23.04-py3

TensorRT Version: 8.6.1
NVIDIA GPU: 4090
NVIDIA Driver Version: 530.41.03
CUDA Version: 12.1
CUDNN Version: 8.9.0

[03/22/2024-01:11:21] [I] [TRT] Input filename: model.onnx
[03/22/2024-01:11:21] [I] [TRT] ONNX IR version: 0.0.8
[03/22/2024-01:11:21] [I] [TRT] Opset version: 15
[03/22/2024-01:11:21] [I] [TRT] Producer name: pytorch
[03/22/2024-01:11:21] [I] [TRT] Producer version: 1.13.1

Relevant Files

make_onnx.py

from typing import List, Tuple, Optional, Union

import numpy as np
import torch

from transformers import RobertaPreTrainedModel
from transformers.models.roberta.modeling_roberta import RobertaEmbeddings, RobertaEncoder, RobertaPooler
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

class RobertaModel(RobertaPreTrainedModel):
    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config)
        self.config = config

        self.embeddings = RobertaEmbeddings(config)
        self.encoder = RobertaEncoder(config)

        self.pooler = RobertaPooler(config) if add_pooling_layer else None

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """ 
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    # Copied from transformers.models.bert.modeling_bert.BertModel.forward
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        batch_size, seq_length = input_shape
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
                token_type_ids = buffered_token_type_ids_expanded
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return pooled_output

max_length = 128
bert_model = RobertaModel.from_pretrained('./5/', torchscript=True).to('cuda')
input_ids = torch.as_tensor(np.ones([1, max_length]), dtype=torch.int32).cuda()
attention_mask = torch.as_tensor(np.ones([1, max_length]), dtype=torch.int32).cuda()

# print(torch.onnx.__version__)

dynamic_axes = {
    'input_ids': {0 : 'batch_size'},
    'attention_mask': {0 : 'batch_size'},
    'outputs': {0 : 'batch_size'}
}
torch.onnx.export(
    bert_model,
    (input_ids, attention_mask),
    'model32.onnx',
    input_names=['input_ids', 'attention_mask'],
    output_names=['outputs'],
    dynamic_axes=dynamic_axes,
    opset_version=15)
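
A quick way to sanity-check the exported graph before any TensorRT conversion is to run it at batch 1 and batch 3 with onnxruntime and compare the rows; a minimal sketch (assuming the int32 inputs used in the export above — the polygraphy inspect output later in the thread reports int64, in which case cast accordingly):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model32.onnx", providers=["CPUExecutionProvider"])

ids = np.ones((1, 128), dtype=np.int32)
mask = np.ones((1, 128), dtype=np.int32)
out1 = sess.run(["outputs"], {"input_ids": ids, "attention_mask": mask})[0]

# Repeat the same row three times; every output row should match the batch-1 result.
out3 = sess.run(["outputs"], {"input_ids": np.repeat(ids, 3, axis=0),
                              "attention_mask": np.repeat(mask, 3, axis=0)})[0]
print(np.abs(out3 - out1).max())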
@lix19937

use polygraphy run model.onnx --trt --onnxrt --input-shapes to compare the output between trt and onnxruntime

@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label Mar 31, 2024
@zerollzeng zerollzeng self-assigned this Mar 31, 2024
@zerollzeng
Collaborator

Exactly. Did you set the input shapes before feeding another input shape for inference?

@KyungHyunLim
Author

polygraphy run model.onnx --trt --onnxrt --input-shapes=input_ids:3x128 attention_mask:3x128
==>

root@0f8e2a20f78e:/workspace# polygraphy run model.onnx --trt --onnxrt --input-shapes=input_ids:3x128,attention_mask:3x128
[W] Input tensor: input_ids:3x128,attention_mask | For TensorRT profile, overriding dynamic shape: BoundedShape(['3', 'x', '1', '2', '8'], min=None, max=None) to: [1, 1, 1, 1, 1]
[I] RUNNING | Command: /usr/local/bin/polygraphy run model.onnx --trt --onnxrt --input-shapes=input_ids:3x128,attention_mask:3x128
[I] Will generate inference input data according to provided TensorMetadata: {input_ids:3x128,attention_mask [shape=('3', 'x', '1', '2', '8')]}
[I] trt-runner-N0-04/02/24-23:37:18     | Activating and starting inference
[W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[W]     Input tensor: input_ids (dtype=DataType.INT32, shape=(-1, 128)) | No shapes provided; Will use shape: [1, 128] for min/opt/max in profile.
[W]     This will cause the tensor to have a static shape. If this is incorrect, please set the range of shapes for this input tensor.
[W]     Input tensor: attention_mask (dtype=DataType.INT32, shape=(-1, 128)) | No shapes provided; Will use shape: [1, 128] for min/opt/max in profile.
[!]     Invalid inputs were provided to the optimization profile: {'input_ids:3x128,attention_mask'}
        Note: Inputs available in the TensorRT network are: {'input_ids', 'attention_mask'}
[E] FAILED | Runtime: 7.027s | Command: /usr/local/bin/polygraphy run model.onnx --trt --onnxrt --input-shapes=input_ids:3x128,attention_mask:3x128

An error occurs...
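
The failure above looks like a shape-syntax issue rather than a model problem: polygraphy expects space-separated name:[dims] entries instead of a single comma-joined NxM string. A sketch of the intended comparison (flag names assumed from the standard polygraphy run CLI, with the same min/opt/max profile as the trtexec build):

polygraphy run model.onnx --trt --onnxrt \
    --input-shapes input_ids:[3,128] attention_mask:[3,128] \
    --trt-min-shapes input_ids:[1,128] attention_mask:[1,128] \
    --trt-opt-shapes input_ids:[4,128] attention_mask:[4,128] \
    --trt-max-shapes input_ids:[100,128] attention_mask:[100,128]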

@KyungHyunLim
Author

The shape is not set separately before inference.
This is my triton config file.

name: "bert"
platform: "tensorrt_plan" 
max_batch_size: 100

optimization { execution_accelerators {
  gpu_execution_accelerator : [ {
    name : "tensorrt"
    parameters { key: "precision_mode" value: "FP16" }
    parameters { key: "max_workspace_size_bytes" value: "1073741824" }
    }]
}}

instance_group [ { count: 2 }]
dynamic_batching { }


input [
    {
        name: "input_ids"
        data_type: TYPE_INT32 
        dims: [128]
    },
    {
        name: "attention_mask"
        data_type: TYPE_INT32 
        dims: [128]
    }
] 

output [
    {
        name: "outputs"
        data_type: TYPE_FP32 
        dims: [768]
    }
]
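
Since max_batch_size is set, Triton prepends the batch dimension, so the engine should be fed [-1, 128] tensors that fall inside the optimization profile. One way to confirm how the model was registered is to query the server's metadata and config endpoints (a sketch, assuming the default HTTP port 8000 and the model name bert from this config):

curl -s localhost:8000/v2/models/bert | python3 -m json.tool
curl -s localhost:8000/v2/models/bert/config | python3 -m json.tool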

@KyungHyunLim
Author

There is no problem with the polygraphy command in fp32, but a problem occurs in fp16.

polygraphy run model.onnx --trt --onnxr --fp16
==>

[I] onnxrt-runner-N0-04/04/24-00:29:09  | Completed 1 iteration(s) in 41 ms | Average inference time: 41 ms.
[I] Accuracy Comparison | trt-runner-N0-04/04/24-00:29:09 vs. onnxrt-runner-N0-04/04/24-00:29:09
[I]     Comparing Output: 'outputs' (dtype=float32, shape=(1, 768)) with 'outputs' (dtype=float32, shape=(1, 768))
[I]         Tolerance: [abs=1e-05, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-04/04/24-00:29:09: outputs | Stats: mean=0.019415, std-dev=0.51338, var=0.26356, median=0.018784, min=-0.96875 at (0, 316), max=0.95947 at (0, 358), avg-magnitude=0.44511
[I]             ---- Histogram ----
                Bin Range            |  Num Elems | Visualization
                (-0.969  , -0.776  ) |         47 | ##################
                (-0.776  , -0.583  ) |         71 | ############################
                (-0.583  , -0.39   ) |         84 | #################################
                (-0.39   , -0.197  ) |         99 | ########################################
                (-0.197  , -0.00464) |         72 | #############################
                (-0.00464, 0.188   ) |         79 | ###############################
                (0.188   , 0.381   ) |         92 | #####################################
                (0.381   , 0.574   ) |         84 | #################################
                (0.574   , 0.767   ) |         86 | ##################################
                (0.767   , 0.959   ) |         54 | #####################
[I]         onnxrt-runner-N0-04/04/24-00:29:09: outputs | Stats: mean=0.019295, std-dev=0.51356, var=0.26375, median=0.021457, min=-0.96847 at (0, 316), max=0.95893 at (0, 358), avg-magnitude=0.44535
[I]             ---- Histogram ----
                Bin Range            |  Num Elems | Visualization
                (-0.969  , -0.776  ) |         48 | ###################
                (-0.776  , -0.583  ) |         70 | ############################
                (-0.583  , -0.39   ) |         85 | ##################################
                (-0.39   , -0.197  ) |         99 | ########################################
                (-0.197  , -0.00464) |         71 | ############################
                (-0.00464, 0.188   ) |         81 | ################################
                (0.188   , 0.381   ) |         90 | ####################################
                (0.381   , 0.574   ) |         83 | #################################
                (0.574   , 0.767   ) |         87 | ###################################
                (0.767   , 0.959   ) |         54 | #####################
[I]         Error Metrics: outputs
[I]             Minimum Required Tolerance: elemwise error | [abs=0.010903] OR [rel=2.4414] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0.0026119, std-dev=0.0022534, var=5.0778e-06, median=0.0019732, min=6.1989e-06 at (0, 323), max=0.010903 at (0, 591), avg-magnitude=0.0026119
[I]                 ---- Histogram ----
                    Bin Range          |  Num Elems | Visualization
                    (6.2e-06, 0.0011 ) |        251 | ########################################
                    (0.0011 , 0.00219) |        162 | #########################
                    (0.00219, 0.00328) |        118 | ##################
                    (0.00328, 0.00436) |         82 | #############
                    (0.00436, 0.00545) |         59 | #########
                    (0.00545, 0.00654) |         40 | ######
                    (0.00654, 0.00763) |         23 | ###
                    (0.00763, 0.00872) |         20 | ###
                    (0.00872, 0.00981) |          6 | 
                    (0.00981, 0.0109 ) |          7 | #
[I]             Relative Difference | Stats: mean=0.023913, std-dev=0.11717, var=0.013729, median=0.0045968, min=1.2172e-05 at (0, 323), max=2.4414 at (0, 9), avg-magnitude=0.023913
[I]                 ---- Histogram ----
                    Bin Range         |  Num Elems | Visualization
                    (1.22e-05, 0.244) |        758 | ########################################
                    (0.244   , 0.488) |          7 | 
                    (0.488   , 0.732) |          0 | 
                    (0.732   , 0.977) |          1 | 
                    (0.977   , 1.22 ) |          0 | 
                    (1.22    , 1.46 ) |          0 | 
                    (1.46    , 1.71 ) |          1 | 
                    (1.71    , 1.95 ) |          0 | 
                    (1.95    , 2.2  ) |          0 | 
                    (2.2     , 2.44 ) |          1 | 
[E]         FAILED | Output: 'outputs' | Difference exceeds tolerance (rel=1e-05, abs=1e-05)
[E]     FAILED | Mismatched outputs: ['outputs']
[E] Accuracy Summary | trt-runner-N0-04/04/24-00:29:09 vs. onnxrt-runner-N0-04/04/24-00:29:09 | Passed: 0/1 iterations | Pass Rate: 0.0%
[E] FAILED | Runtime: 32.517s | Command: /usr/local/bin/polygraphy run model.onnx --trt --onnxr --fp16

@zerollzeng
Collaborator

[I] Error Metrics: outputs
[I] Minimum Required Tolerance: elemwise error | [abs=0.010903] OR [rel=2.4414] (requirements may be lower if both abs/rel tolerances are set)
[I] Absolute Difference | Stats: mean=0.0026119, std-dev=0.0022534, var=5.0778e-06, median=0.0019732, min=6.1989e-06 at (0, 323), max=0.010903 at (0, 591), avg-magnitude=0.0026119
[I] ---- Histogram ----
Bin Range | Num Elems | Visualization
(6.2e-06, 0.0011 ) | 251 | ########################################
(0.0011 , 0.00219) | 162 | #########################
(0.00219, 0.00328) | 118 | ##################
(0.00328, 0.00436) | 82 | #############
(0.00436, 0.00545) | 59 | #########
(0.00545, 0.00654) | 40 | ######
(0.00654, 0.00763) | 23 | ###
(0.00763, 0.00872) | 20 | ###
(0.00872, 0.00981) | 6 |
(0.00981, 0.0109 ) | 7 | #

The diff looks good to me for FP16. Did you observe significant accuracy loss on the real dataset?

@zerollzeng
Collaborator

The failure here was because the default polygraphy error tolerance is 1e-5.
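
If the goal is only to confirm that the FP16 engine is within a reasonable tolerance of ONNX Runtime, the same comparison can be rerun with looser thresholds (a sketch using polygraphy's --atol/--rtol options):

polygraphy run model.onnx --trt --onnxrt --fp16 --atol 1e-2 --rtol 1e-2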

@littleMatch03

@KyungHyunLim, is this issue solved?

@KyungHyunLim
Author

@littleMatch03 I'm still trying to figure it out.

@KyungHyunLim
Author

@zerollzeng

There was no problem during training.

When I give ["I love dog"] as input and when I give ["I love dog", "I love dog", "I love dog"] as input, the embedding vectors I get back are different.

The onnx model returns similar results for both inputs.

For example, briefly showing the Triton results:

  • onnx model
    ["i love dog"] => inference => [2.0, 3.0, 4.0]
    ["I love dog", "I love dog", "I love dog"] => inference => [ [1.99, 2.99, 4.01], [1.99, 2.99, 4.01], [1.99, 2.99, 4.01]]

  • trt model
    ["i love dog"] => inference => [1.98, 3.02, 3.97]
    ["I love dog", "I love dog", "I love dog"] => inference => [[0.1, 5.9, 4.8], [1.0, 6.0, 8.2], [0.3, 0.5, 2.7]]

There is no problem with a batch size of 1 like this, but when the batch size exceeds 1, there is a problem.
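
One way to narrow down whether the batch>1 problem comes from the engine itself or from the Triton batching path is to run the serialized plan directly with trtexec at batch 3 (a sketch; trtexec generates random inputs unless --loadInputs is given, so reproducing the identical-rows check would require dumping the tokenized inputs to files and passing them in):

trtexec --loadEngine=model.plan \
    --shapes=input_ids:3x128,attention_mask:3x128 \
    --dumpOutput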

@KyungHyunLim
Author

KyungHyunLim commented Apr 18, 2024

This is the result of using polygraphy inspect model

[I] Loading bytes from /workspace/model.plan
[W] CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage and speed up TensorRT initialization. See "Lazy Loading" section of CUDA documentation https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#lazy-loading
[I] ==== TensorRT Engine ====
    Name: Unnamed Network 0 | Explicit Batch Engine
    
    ---- 2 Engine Input(s) ----
    {input_ids [dtype=int32, shape=(-1, 128)],
     attention_mask [dtype=int32, shape=(-1, 128)]}
    
    ---- 1 Engine Output(s) ----
    {outputs [dtype=float32, shape=(-1, 768)]}
    
    ---- Memory ----
    Device Memory: 236263936 bytes
    
    ---- 1 Profile(s) (3 Tensor(s) Each) ----
    - Profile: 0
        Tensor: input_ids               (Input), Index: 0 | Shapes: min=(1, 128), opt=(3, 128), max=(100, 128)
        Tensor: attention_mask          (Input), Index: 1 | Shapes: min=(1, 128), opt=(3, 128), max=(100, 128)
        Tensor: outputs                (Output), Index: 2 | Shape: (-1, 768)
    
    ---- 2 Layer(s) ----
[I] Loading model: /workspace/model.onnx
[I] ==== ONNX Model ====
    Name: torch_jit | ONNX Opset: 15
    
    ---- 2 Graph Input(s) ----
    {input_ids [dtype=int64, shape=('batch_size', 128)],
     attention_mask [dtype=int64, shape=('batch_size', 128)]}
    
    ---- 1 Graph Output(s) ----
    {outputs [dtype=float32, shape=('batch_size', 768)]}
    
    ---- 199 Initializer(s) ----
    
    ---- 1374 Node(s) ----

@zerollzeng
Collaborator

There is no problem with a batch size of 1 like this, but when the batch size exceeds 1, there is a problem.

Normally it's caused by not setting the correct dynamic shape profile or input shapes.
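
For anyone running the engine outside Triton with the TensorRT Python API, the runtime input shapes have to be set on the execution context before every inference whose batch size changes; a minimal sketch (assuming TensorRT 8.6 with pycuda, and the binding order input_ids, attention_mask, outputs reported by polygraphy inspect above):

import numpy as np
import pycuda.autoinit  # noqa: F401 (creates a CUDA context)
import pycuda.driver as cuda
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(logger)
with open("model.plan", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

batch = 3
input_ids = np.ones((batch, 128), dtype=np.int32)
attention_mask = np.ones((batch, 128), dtype=np.int32)
outputs = np.empty((batch, 768), dtype=np.float32)

# With dynamic shapes, the actual input shapes must be set before execution.
context.set_input_shape("input_ids", input_ids.shape)
context.set_input_shape("attention_mask", attention_mask.shape)

d_ids = cuda.mem_alloc(input_ids.nbytes)
d_mask = cuda.mem_alloc(attention_mask.nbytes)
d_out = cuda.mem_alloc(outputs.nbytes)
cuda.memcpy_htod(d_ids, input_ids)
cuda.memcpy_htod(d_mask, attention_mask)

# Bindings follow the engine's tensor order: input_ids, attention_mask, outputs.
context.execute_v2([int(d_ids), int(d_mask), int(d_out)])
cuda.memcpy_dtoh(outputs, d_out)

# All three rows should match when the three inputs are identical.
print(np.abs(outputs - outputs[0]).max())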

@ttyio
Collaborator

ttyio commented Jul 2, 2024

Closing since there has been no activity for more than 3 weeks. Please reopen if you still have questions. Thanks all!

@ttyio ttyio closed this as completed Jul 2, 2024