
Issue to build deberta-v3-base due to missing validUnaryType && datatype on TensorRT 8.6 #3587

Closed
VibhuJawa opened this issue Jan 10, 2024 · 6 comments
Assignees
Labels
triaged Issue has been triaged by maintainers

Comments

@VibhuJawa

VibhuJawa commented Jan 10, 2024

Description

I am running into a conversion issue while trying to convert deberta-v3-base into a TensorRT engine. The ONNX parser fails with:

 UNSUPPORTED_NODE: Assertion failed: validUnaryType && "This version of TensorRT does not support the given operator with the given input data type."

The full trace is included after the minimal reproducible example (MRE) below.

Environment

TensorRT Version:

8.6

NVIDIA GPU:

V100

NVIDIA Driver Version:
525.105.17

CUDA Version:
12.0

CUDNN Version:

Operating System:

Python Version (if applicable): 3.10

Steps To Reproduce

from onnx import TensorProto  # needed by remove_uint8_cast below
from transformers import DebertaV2ForSequenceClassification
import tensorrt as trt
import torch

def remove_uint8_cast(graph):
    nodes = [node for node in graph.nodes if node.op == 'Cast' and node.attrs["to"] == TensorProto.UINT8]

    for node in nodes:
        input_node = node.i()
        input_node.outputs = node.outputs
        node.outputs.clear()

    return graph


deberta_model = DebertaV2ForSequenceClassification.from_pretrained("microsoft/deberta-v3-base").cuda()
deberta_model.eval()
vocab_size = deberta_model.config.vocab_size
batch_size = 32
seq_len = 12
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), dtype=torch.long).to('cuda')
attention_mask = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.long).to('cuda')
input_names = ['input_ids', 'attention_mask']
output_names = ['output']
dynamic_axes={'input_ids'   : {0 : 'batch_size', 1: 'seq_len'},
              'attention_mask'   : {0 : 'batch_size', 1: 'seq_len'},
              'output' : {0 : 'batch_size'}}

torch.onnx.export(deberta_model,
                  (input_ids, attention_mask),
                  "model.onnx",
                  export_params=True,
                  opset_version=13,
                  do_constant_folding=True,
                  input_names = input_names,
                  output_names = output_names,
                  dynamic_axes = dynamic_axes
                 )

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
TRT_BUILDER = trt.Builder(TRT_LOGGER)
network = TRT_BUILDER.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
onnx_parser = trt.OnnxParser(network, TRT_LOGGER)
parse_success = onnx_parser.parse_from_file("model.onnx")

for idx in range(onnx_parser.num_errors):
    print(onnx_parser.get_error(idx))
Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:554: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(mid - 1).type_as(relative_pos),
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:558: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:717: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:717: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:792: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:792: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  scale = torch.sqrt(torch.tensor(pos_key_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:804: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:804: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  scale = torch.sqrt(torch.tensor(pos_query_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:805: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if key_layer.size(-2) != query_layer.size(-2):
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:112: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
[01/10/2024-08:55:53] [TRT] [I] The logger passed into createInferBuilder differs from one already provided for an existing builder, runtime, or refitter. Uses of the global logger, returned by nvinfer1::getLogger(), will return the existing value.
[01/10/2024-08:55:53] [TRT] [I] [MemUsageChange] Init CUDA: CPU +0, GPU +0, now: CPU 3739, GPU 3915 (MiB)
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:604] Reading dangerously large protocol message.  If the message turns out to be larger than 2147483647 bytes, parsing will be halted for security reasons.  To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:81] The total number of bytes read was 738637429
[01/10/2024-08:55:54] [TRT] [I] ----------------------------------------------------------------
[01/10/2024-08:55:54] [TRT] [I] Input filename:   model.onnx
[01/10/2024-08:55:54] [TRT] [I] ONNX IR version:  0.0.7
[01/10/2024-08:55:54] [TRT] [I] Opset version:    13
[01/10/2024-08:55:54] [TRT] [I] Producer name:    pytorch
[01/10/2024-08:55:54] [TRT] [I] Producer version: 2.1.0
[01/10/2024-08:55:54] [TRT] [I] Domain:           
[01/10/2024-08:55:54] [TRT] [I] Model version:    0
[01/10/2024-08:55:54] [TRT] [I] Doc string:       
[01/10/2024-08:55:54] [TRT] [I] ----------------------------------------------------------------
[01/10/2024-08:55:55] [TRT] [W] onnx2trt_utils.cpp:374: Your ONNX model has been generated with INT64 weights, while TensorRT does not natively support INT64. Attempting to cast down to INT32.
[01/10/2024-08:55:55] [TRT] [E] ModelImporter.cpp:771: While parsing node number 66 [Sign -> "/deberta/encoder/Sign_output_0"]:
[01/10/2024-08:55:55] [TRT] [E] ModelImporter.cpp:772: --- Begin node ---
[01/10/2024-08:55:55] [TRT] [E] ModelImporter.cpp:773: input: "/deberta/encoder/Sub_output_0"
output: "/deberta/encoder/Sign_output_0"
name: "/deberta/encoder/Sign"
op_type: "Sign"

[01/10/2024-08:55:55] [TRT] [E] ModelImporter.cpp:774: --- End node ---
[01/10/2024-08:55:55] [TRT] [E] ModelImporter.cpp:777: ERROR: onnx2trt_utils.cpp:1779 In function unaryHelper:
[8] Assertion failed: validUnaryType && "This version of TensorRT does not support the given operator with the given input data type."
In node 66 (unaryHelper): UNSUPPORTED_NODE: Assertion failed: validUnaryType && "This version of TensorRT does not support the given operator with the given input data type."
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:604] Reading dangerously large protocol message.  If the message turns out to be larger than 2147483647 bytes, parsing will be halted for security reasons.  To increase the limit (or to disable these warnings), see CodedInputStream::SetTotalBytesLimit() in google/protobuf/io/coded_stream.h.
[libprotobuf WARNING google/protobuf/io/coded_stream.cc:81] The total number of bytes read was 738637429
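The failing node is `/deberta/encoder/Sign`, whose input is the output of a `Sub`. Judging by the TracerWarnings that point at modeling_deberta_v2.py:554 and :558, it most likely comes from DeBERTa-v2's log-bucketed relative positions, where a sign is taken of integer (INT64, cast down to INT32 by the parser) position differences; the TRT 8.6 parser appears to reject Sign for that integer input type. The scalar sketch below is my own pure-Python paraphrase of that bucketing, not the transformers code, and it assumes deberta-v3-base uses `bucket_size=256` and `max_position=512`.

```python
import math


def sign(x: int) -> int:
    """Integer sign (-1, 0 or 1): the role torch.sign plays on relative_pos."""
    return (x > 0) - (x < 0)


def log_bucket_position(relative_pos: int, bucket_size: int = 256, max_position: int = 512) -> int:
    """Scalar paraphrase of DeBERTa-v2 log-bucketed relative positions (assumed defaults)."""
    mid = bucket_size // 2
    # Small offsets are clamped to mid - 1 so they always take the identity branch below.
    abs_pos = mid - 1 if -mid < relative_pos < mid else abs(relative_pos)
    if abs_pos <= mid:
        return relative_pos  # |offset| <= mid: kept as-is
    # Larger offsets are compressed logarithmically; offsets up to max_position - 1
    # land at most at bucket_size - 1, and the integer sign restores the direction.
    log_pos = math.ceil(math.log(abs_pos / mid) / math.log((max_position - 1) / mid) * (mid - 1)) + mid
    return log_pos * sign(relative_pos)


# Relative positions for a 12-token sequence are query index minus key index.
seq_len = 12
offsets = [q - k for q in range(seq_len) for k in range(seq_len)]
print(sorted({log_bucket_position(r) for r in offsets}) == list(range(-11, 12)))  # True
```

With `seq_len = 12` every offset stays in the identity branch, but the Sign is still applied to an integer tensor in the exported graph, which matches the node the parser rejects.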

I also tried following #3124, but to no avail:

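# import onnx
# import onnx_graphsurgeon as gs
# from polygraphy.backend.onnx import fold_constants  # assumed source of fold_constants (not stated in the original)
#
# graph = gs.import_onnx(onnx.load("model.onnx"))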
# graph = remove_uint8_cast(graph)

# graph.cleanup().toposort()
# onnx.save_model(gs.export_onnx(graph), "model_updated.onnx")

# model = fold_constants(onnx.load("model_updated.onnx"))
# onnx.save(model, "model_updated_folded.onnx")
@zerollzeng zerollzeng self-assigned this Jan 11, 2024
@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label Jan 11, 2024
@zerollzeng
Collaborator

Looks like a known limitation. Could you please try the latest TRT 9.2? Thanks!

@VibhuJawa
Author

VibhuJawa commented Jan 11, 2024


Could you point me to how I can get access to the latest TRT 9.2, please?

I don't see any nightly wheels.

@VibhuJawa
Author

VibhuJawa commented Jan 11, 2024

I tested it with tensorrt==9.0.1.post12.dev4, and I can now create the engine.

I want to support dynamic batch and sequence sizes, but I am running into the warnings below, which, as I understand it, mean the exported model will not work correctly for other input shapes. Can you suggest how to get that working?

You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:554: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.tensor(mid - 1).type_as(relative_pos),
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:558: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  torch.ceil(torch.log(abs_pos / mid) / torch.log(torch.tensor((max_position - 1) / mid)) * (mid - 1)) + mid
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:717: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:717: UserWarning: 
.................
  if key_layer.size(-2) != query_layer.size(-2):
/datasets/vjawa/miniconda3/envs/TensorRT_Transformers/lib/python3.10/site-packages/transformers/models/deberta_v2/modeling_deberta_v2.py:112: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
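Besides exporting with the dynamic `batch_size` and `seq_len` axes declared in `dynamic_axes`, a TensorRT build also needs an optimization profile giving a (min, opt, max) shape for each dynamic input. The helper below only computes and validates those shape triples in plain Python; the ranges used (batch 1 to 64, sequence 1 to 512) are hypothetical, and the commented lines show roughly where the triples would go in the TensorRT Python API (`create_optimization_profile`, `set_shape`, `add_optimization_profile`).

```python
# Shape triples for a TensorRT optimization profile (pure Python; ranges are hypothetical).
Shape = tuple[int, ...]


def profile_shapes(
    input_names: list[str],
    batch: tuple[int, int, int],
    seq: tuple[int, int, int],
) -> dict[str, tuple[Shape, Shape, Shape]]:
    """Return {name: (min_shape, opt_shape, max_shape)} for 2-D [batch_size, seq_len] inputs."""
    for label, (lo, opt, hi) in (("batch", batch), ("seq", seq)):
        # TensorRT expects min <= opt <= max per dynamic dimension; this sketch also requires >= 1.
        if not 1 <= lo <= opt <= hi:
            raise ValueError(f"{label} range must satisfy 1 <= min <= opt <= max, got {(lo, opt, hi)}")
    # zip pairs the bounds: (min_b, min_s), (opt_b, opt_s), (max_b, max_s).
    min_shape, opt_shape, max_shape = zip(batch, seq)
    return {name: (min_shape, opt_shape, max_shape) for name in input_names}


shapes = profile_shapes(["input_ids", "attention_mask"], batch=(1, 32, 64), seq=(1, 12, 512))
print(shapes["input_ids"])  # ((1, 1), (32, 12), (64, 512))

# With TensorRT installed (not executed here), the triples would be registered roughly as:
#   profile = TRT_BUILDER.create_optimization_profile()
#   for name, (mn, opt, mx) in shapes.items():
#       profile.set_shape(name, mn, opt, mx)
#   config = TRT_BUILDER.create_builder_config()
#   config.add_optimization_profile(profile)
```

The same triple is used for `input_ids` and `attention_mask` because both share the `[batch_size, seq_len]` layout in the export above.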

@zerollzeng
Collaborator

It's a warning from transformers (PyTorch). I think it means that using dynamic shapes may cause unexpected behavior during ONNX export (e.g. some tensors still become constants). You can ask for help in the transformers repo.
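One way to judge whether a given TracerWarning matters for dynamic shapes is to check whether the value it freezes depends on the input shape at all. The sketch below recomputes the constants flagged at lines 554, 558 and 717 of modeling_deberta_v2.py from model hyperparameters only. The formulas are my paraphrase of those lines, and the deberta-v3-base values (hidden size 768, 12 heads, `pos_att_type` of `["p2c", "c2p"]`, 256 position buckets, 512 max positions) are assumptions to verify against the model's config.json.

```python
import math


def traced_constants(hidden_size: int, num_heads: int, pos_att_type: list[str],
                     position_buckets: int, max_position: int) -> dict[str, float]:
    """Values the TracerWarnings say get baked into the ONNX graph (paraphrased formulas)."""
    head_dim = hidden_size // num_heads  # what query_layer.size(-1) evaluates to
    # scale_factor starts at 1 and gains 1 for each of c2p / p2c attention.
    scale_factor = 1 + ("c2p" in pos_att_type) + ("p2c" in pos_att_type)
    mid = position_buckets // 2
    return {
        "mid_minus_1": mid - 1,                                 # ~ line 554
        "log_max_ratio": math.log((max_position - 1) / mid),    # ~ line 558
        "attention_scale": math.sqrt(head_dim * scale_factor),  # ~ line 717
    }


consts = traced_constants(768, 12, ["p2c", "c2p"], 256, 512)
print(consts["mid_minus_1"], round(consts["attention_scale"], 3))  # 127 13.856
```

None of the arguments is `batch_size` or `seq_len`, which suggests (but does not prove) that these particular frozen constants are shape-independent. The warning at line 805 is a different case, because it freezes a shape comparison rather than a hyperparameter, and deserves a separate look.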

@ttyio
Collaborator

ttyio commented Feb 6, 2024

Closing since there has been no activity for more than 3 weeks. Thanks, all!

@ttyio ttyio closed this as completed Feb 6, 2024
@copasseron

Is there any release or tag of TensorRT 8.6.1 that resolves this issue?

I could build it with TRT 9.2.0, but I want to deploy this model on NVIDIA Triton Inference Server.

However, the Triton TensorRT backend does not yet support TensorRT versions newer than 8.6.1, and the TensorRT runtime version must be the same as the one used to build the engine.

What would be the best solution for this problem?
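The build/runtime constraint above can be made explicit as a pre-deployment check. This is a minimal sketch that assumes agreement on major.minor.patch is what matters (build numbers and pip suffixes such as `.post12.dev4` are ignored); in practice each version string would come from `tensorrt.__version__` in the build environment and in the serving environment.

```python
def release(version: str) -> tuple[int, int, int]:
    """Parse the leading major.minor.patch of a TensorRT version string."""
    pieces = version.split(".")[:3]
    if len(pieces) < 3 or not all(p.isdigit() for p in pieces):
        raise ValueError(f"unexpected TensorRT version string: {version!r}")
    major, minor, patch = (int(p) for p in pieces)
    return major, minor, patch


def same_trt_release(build_version: str, runtime_version: str) -> bool:
    """True if the engine's build version and the runtime agree on major.minor.patch (assumed rule)."""
    return release(build_version) == release(runtime_version)


print(same_trt_release("9.2.0", "8.6.1"))    # False: a 9.2 engine on an 8.6.1 runtime
print(same_trt_release("8.6.1.6", "8.6.1"))  # True
```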


4 participants