🐛 [Bug] Encountered bug when using Torch-TensorRT with torchscript model Conformer Transducer #2197

Open · kzelias opened this issue Aug 14, 2023 · 16 comments
Labels: bug (Something isn't working)

kzelias commented Aug 14, 2023

Bug Description

I get an error when converting a Conformer Transducer encoder to TensorRT (ASR task).

To Reproduce

requirements.txt

CODE:

import nemo.collections.asr as nemo_asr
import torch
import torch_tensorrt as torchtrt


nemo_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="stt_en_conformer_transducer_large")
nemo_model.freeze()
nemo_model.export(output="temp_rnnt.ts", check_trace=True)


with torchtrt.logging.debug():
    variant = "encoder-temp_rnnt.ts"
    precisions = [torch.float, torch.half]
    batch_size = 1

    model = torch.jit.load(variant)

    inputs = [
        torchtrt.Input(shape=[batch_size, 80, 8269]),  # 8269 = mel-spectrogram frames for a 1-minute wav after resampling
        torchtrt.Input(shape=[1]),
    ]

    for precision in precisions:
        compile_settings = {
            "inputs": inputs, 
            "enabled_precisions": {precision},
            "workspace_size": 2000000000,
            "truncate_long_and_double": True,
        }
        print(f"Generating Torchscript-TensorRT module for batchsize {batch_size} precision {precision}")
        trt_ts_module = torchtrt.compile(model, **compile_settings)
        torch.jit.save(trt_ts_module, f"{variant.replace('.ts','')}_bs{batch_size}_{precision}.torch-tensorrt")

CONSOLE:

Generating Torchscript-TensorRT module for batchsize 1 precision torch.float32
WARNING: [Torch-TensorRT] - Data types for input tensors have been modified by inserting aten::to operations which cast INT64 inputs to INT32. To disable this, please recompile using INT32 inputs
WARNING: [Torch-TensorRT] - Truncating intermediate graph input type from at::kLong to at::kInt
(the WARNING above repeats five times in the log)
WARNING: [Torch-TensorRT TorchScript Conversion Context] - CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - Trying to record the value lengths1.1 with the ITensor (Unnamed Layer* 13) [Unary]_output again.
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - Truncating aten::to output type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Trying to record the value padding_length.1 with the ITensor (Unnamed Layer* 26) [Identity]_output again.
WARNING: [Torch-TensorRT] - Truncating aten::to output type from at::kLong to at::kInt
WARNING: [Torch-TensorRT] - Trying to record the value 28 with the ITensor (Unnamed Layer* 26) [Identity]_output again.
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - Unable to process input type of at::kLong, truncate type to at::kInt in scalar_to_tensor_util 
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Int64 to Int32
WARNING: [Torch-TensorRT] - CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
WARNING: [Torch-TensorRT TorchScript Conversion Context] - CUDA lazy loading is not enabled. Enabling it can significantly reduce device memory usage. See `CUDA_MODULE_LOADING` in https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: [graphShapeAnalyzer.cpp::analyzeShapes::1872] Error Code 4: Miscellaneous (IElementWiseLayer %103 : Tensor = aten::add(%matrix_ac.1, %matrix_bd0.1, %124) # /usr/local/lib/python3.8/dist-packages/nemo/collections/asr/parts/submodules/multi_head_attention.py:243:0: broadcast dimensions must be conformable)
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
(the ERROR line above repeats nine more times in the log)
Segmentation fault (core dumped)

Expected behavior

I expect a TensorRT-compiled module as output.

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): 1.4.0
  • PyTorch Version (e.g. 1.0): 2.0.1+cu118
  • CPU Architecture: AMD EPYC 7763 64-Core Processor
  • OS (e.g., Linux): Ubuntu 20.04.5 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Python version: 3.8.10
  • CUDA version: release 11.8, V11.8.89
  • GPU models and configuration: NVIDIA A100 80GB
  • image: nvcr.io/nvidia/tensorrt:22.12-py3

Additional context

I want to export the Conformer Transducer encoder and decoder models from TorchScript to TensorRT.


kzelias commented Aug 15, 2023

Full debug log:
debug_tensorrt_conformer_transducer.log


kzelias commented Aug 15, 2023

I updated the code above for full reproduction.

P.S. None of the NeMo Conformer models compile to TensorRT; I tried stt_en_conformer_transducer_large and stt_en_conformer_transducer_small.

narendasan commented:

@gs-olive can you take a look at this NeMo model?

gs-olive commented:

I am able to reproduce this error in the TorchScript path on the latest main with NeMo toolkit 1.20.0. It seems to stem from tensor addition operands that are not broadcastable (it looks like tensors of shape [1, 4, 6204, 6204] and [1, 4, 6204, 4135] are being added). I'm not yet sure what is causing this mismatch, but it seems to originate in either a converter or a lowering pass.
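
For reference, a minimal sketch (not from the issue, using small stand-ins for the reported shapes) of why an add over those shapes cannot broadcast:

import torch

# Stand-ins for the reported shapes [1, 4, 6204, 6204] and [1, 4, 6204, 4135]:
# the trailing dimensions differ and neither is 1, so broadcasting rules
# reject the addition outright.
matrix_ac = torch.rand(1, 4, 8, 8)
matrix_bd = torch.rand(1, 4, 8, 5)
try:
    matrix_ac + matrix_bd
except RuntimeError as e:
    print(e)  # "The size of tensor a (8) must match the size of tensor b (5) ..."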


gs-olive commented Aug 16, 2023

When tracing this model with the torch_compile IR option on main, we encounter the errors from #2183 and #2227, for which a fix is in progress. I will post an update on this model once that fix is ready. In the meantime, I am looking further into the TorchScript broadcasting issue.


kzelias commented Aug 16, 2023

Thanks! Do I understand correctly that if I try to compile this model from native PyTorch to TensorRT, it might work? Or is the problem in the Conformer architecture itself?


gs-olive commented Aug 16, 2023

The issue does not seem to be with the Conformer architecture itself, since inference in plain PyTorch is working, and it is scripting to TorchScript successfully. There is a possibility that PyTorch --> ONNX --> TensorRT might work, yes. I have verified that with #2228 and #2234, we are able to compile this model with torchtrt.compile(model, ir="torch_compile", ...). I am still investigating the TorchScript path for the model.
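
A hedged sketch of that ONNX route (NeMo's export() selects the format from the file extension, and RNNT exports follow the same "encoder-" file prefix seen in the repro; the trtexec invocation is illustrative only):

import nemo.collections.asr as nemo_asr

# Sketch only: export the pretrained model to ONNX instead of TorchScript.
nemo_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
    model_name="stt_en_conformer_transducer_large"
)
nemo_model.freeze()
nemo_model.export(output="temp_rnnt.onnx")  # .onnx extension selects ONNX export

# Then build a TensorRT engine from the exported encoder, e.g. with trtexec:
#   trtexec --onnx=encoder-temp_rnnt.onnx --saveEngine=encoder.plan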


gs-olive commented Aug 18, 2023

Regarding the TorchScript path, the bug occurs on this line, where the shapes of matrix_ac and matrix_bd disagree. Specifically, this line attempts to drop extra elements from matrix_bd to match matrix_ac, but matrix_bd already has fewer elements than matrix_ac, so the truncation has no effect. This behavior does not occur in plain Torch, so the issue is likely in Torch-TensorRT rather than in the NeMo code or Torch itself.
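
A small sketch (assumed shapes, not the exact NeMo code) of why that truncating slice can be a no-op:

import torch

# In PyTorch, slicing past the end of a dimension is a silent no-op, so the
# "drop extra elements" slice only helps when matrix_bd is the longer tensor.
matrix_ac = torch.rand(1, 4, 8, 8)
matrix_bd = torch.rand(1, 4, 8, 6)  # already fewer columns than matrix_ac
matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]  # slice [:8] on a size-6 dim
print(matrix_bd.shape)  # torch.Size([1, 4, 8, 6]) -- unchanged, so the add still mismatches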


gs-olive commented Sep 1, 2023

After further investigation of this issue, we may be able to compile this model via the ir="dynamo" path, which also allows model saving and loading. Currently we need #2195, and possibly #2249, to fully compile, save, and load this model. Additionally, I used a wrapper class to ensure the inputs are a list of tensors instead of named arguments, as below:

import nemo.collections.asr as nemo_asr
import torch


class ModelWrapper(torch.nn.Module):
    """Wraps the NeMo model so it takes positional tensor inputs."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.nemo_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
            model_name="stt_en_conformer_transducer_small"
        )
        self.nemo_model.freeze()
        self.nemo_model.eval().cuda()

    def forward(self, x, y):
        # Map the positional args back to the model's named arguments.
        return self.nemo_model(processed_signal=x, processed_signal_length=y)

I will follow up on this issue again as these PRs and improvements are merged.

gs-olive commented:

Hello - the referenced PRs have been merged, and the model building/serialization is now functional for this model in the Dynamo path! The script I used to compile, serialize, and reload the model can be found below:

Code Sample
import nemo.collections.asr as nemo_asr
import torch
import torch_tensorrt as torchtrt

batch_size = 1
inputs = [
        torchtrt.Input(shape=[batch_size, 80, 8269]),
        torchtrt.Input(shape=[batch_size]),
    ]

torch_inputs = [torch.rand([batch_size, 80, 8269]).cuda(),
                torch.rand([batch_size]).cuda()]

class Wrapper(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.nemo_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(model_name="stt_en_conformer_transducer_small")
        self.nemo_model.freeze()
        self.nemo_model.eval().cuda()

    def forward(self, x, y):
        return self.nemo_model(processed_signal=x,
                               processed_signal_length=y)

# Trace the model through to FX
nemo_model = Wrapper().eval().cuda()
fx_graphmodule = torch.fx.experimental.proxy_tensor.make_fx(nemo_model)(*torch_inputs)

compile_settings = {
    "inputs": inputs,
    "enabled_precisions": {torch.float, torch.half},
    "truncate_long_and_double": True,
    "min_block_size": 85,
}

# Compile TRT-optimized model
trt_fx_module = torchtrt.compile(fx_graphmodule, ir="dynamo", **compile_settings)
trt_out = trt_fx_module(*torch_inputs)

# Trace through the output model with TorchScript
# Serialize and save the resultant graph
trt_script_model = torch.jit.trace(trt_fx_module, torch_inputs)
torch.jit.save(trt_script_model, "trt_model.ts")

# Reload model from save and perform inference
reloaded_model = torch.jit.load("trt_model.ts").cuda()
trt_reloaded_out = reloaded_model(*torch_inputs)

Not all of the operators in the graph currently have converters, and I believe roughly 16 TRT engines are generated as a result. If full graph support for this model is important to you, please let me know or file new issues so that each of the missing operators can be implemented. Additionally, please let me know whether the compilation is functional on your machine!
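
As a rough check, a sketch for counting the TRT engines in the compiled module from the script above (this assumes the partitioner names accelerated submodules with "_run_on_acc", a convention that may differ across versions):

# Count TRT engine submodules in the compiled GraphModule.
num_engines = sum(
    1 for name, _ in trt_fx_module.named_children() if "_run_on_acc" in name
)
print(f"TRT engines generated: {num_engines}")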


kzelias commented Feb 21, 2024

@gs-olive hello!
I tried to use your code, but I got an error:
torch._export.verifier.SpecViolationError: Node.meta reshape_default is missing val field.
log: conf_trt_log.log
debug log: conf_trt_debug_log.log

image: nvcr.io/nvidia/tensorrt:23.07-py3
python: 3.10.6
reqs:

  • tensorrt @ file:///TensorRT-8.6.1.6/python/tensorrt-8.6.1-cp310-none-linux_x86_64.whl#sha256=2684b4772cb16088184266728a0668f5dac14e66f088c4ccff2096ccb222d74c
  • torch_tensorrt==2.2.0
  • torch==2.2.0
  • torchaudio==2.2.0
  • torchvision==0.17.0
  • nemo_toolkit==1.22.0

gs-olive commented:

Thanks for the follow-up. Based on the logs it seems that compilation succeeded but model serialization did not. As suggested by @peri044 - could you add output_format="torchscript" to the compile_settings dictionary in the sample script and try it?
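
For concreteness, the compile_settings dictionary from the sample script with that one suggested line added (output_format is a torch_tensorrt 2.x compile option):

compile_settings = {
    "inputs": inputs,
    "enabled_precisions": {torch.float, torch.half},
    "truncate_long_and_double": True,
    "min_block_size": 85,
    "output_format": "torchscript",  # return a TorchScript module that torch.jit.save can serialize
}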


kzelias commented Feb 27, 2024

It works, thanks!
But the NeMo model is 56 MB, while after conversion the TRT model is ~800 MB for input tensor [1, 80, 6000]. Is that expected?
-rw-r--r-- 1 root root 796M Feb 27 13:44 trt_model.ts

trt_debug.log

(for tensor [1, 80, 8269] the model is 1.4 GB)

gs-olive commented:

Hello - based on the logs, I believe the large model size is due to segmentation, since there seem to be some operators in this model that we don't currently have converters for. Could you specify debug=True in the compile_settings and capture the model output? We can then see which operators are not being converted.
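
A one-line sketch of that suggestion, applied to the earlier sample script:

# Enable verbose partitioning/conversion output (logs which ATen ops lack converters).
compile_settings["debug"] = True
trt_fx_module = torchtrt.compile(fx_graphmodule, ir="dynamo", **compile_settings)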


kzelias commented Feb 27, 2024

conf_trt_ts_debug.log. Is that enough?

If you export this model to TorchScript or ONNX, it decomposes into two files, encoder.ts and decoder.ts.
Maybe that's why there are some problems converting to TRT?
Does it make sense to try converting the two .ts files to TRT separately?


gs-olive commented Feb 27, 2024

Yes, this is very helpful, thank you - it looks like we are missing the torch.ops.aten.glu.default converter here, which is causing some of the segmentation. It is possible that the encoder/decoder separation contributes, but I also think converter support is important for reducing the number of TRT engines generated. I have filed a converter request for this operator here: #2663.
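
For reference, a small sketch of what aten.glu computes (standard PyTorch semantics, not Torch-TensorRT code):

import torch
import torch.nn.functional as F

# The gated linear unit splits the input in two along `dim` and returns
# first_half * sigmoid(second_half).
x = torch.rand(2, 8)
out = F.glu(x, dim=-1)  # shape (2, 4)
a, b = x.chunk(2, dim=-1)
assert torch.allclose(out, a * torch.sigmoid(b))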
