
🐛 [Bug] Cannot perform inference if the ExportedProgram has weighted layers and custom ops. #2576

Closed · Tracked by #2262
peri044 opened this issue Jan 5, 2024 · 3 comments
Labels: bug (Something isn't working)

peri044 (Collaborator) commented Jan 5, 2024

Bug Description

  1. The graph has a conv node that executes in PyTorch alongside a TensorRT engine node. The conv node's weight and bias are lifted as placeholders, which is why inference fails with a runtime error about a mismatch in the number of inputs (a sketch of how lifted placeholders appear follows this list).

Error message:

_check_input_constraints_for_graph(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/utils.py", line 48, in _check_input_constraints_for_graph
    check(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/utils.py", line 40, in check
    raise RuntimeError(msg)
RuntimeError: Unexpected number of inputs (expected 3, got 1)
  2. If we unlift these parameters (i.e. conv_weight and conv_bias are registered as get_attr nodes), there is a different error: GraphModule does not contain attribute conv_weight.
     Reason: a syntax error occurs in _create_graph_module_for_export, so the resulting GraphModule does not have these attributes.
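
For context, here is a minimal sketch (a toy module unrelated to Torch-TensorRT; the names and shapes are illustrative) of how torch.export lifts a conv's weight and bias into placeholder inputs, which is where the input-count mismatch comes from:

import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)

    def forward(self, x):
        return self.conv(x)

ep = torch.export.export(Toy().eval(), (torch.randn(1, 3, 224, 224),))

# The graph signature records which placeholders are lifted parameters vs. user inputs.
print(ep.graph_signature)

# The placeholder nodes include entries for the lifted conv weight and bias in
# addition to the user input, which is why calling the graph with a single
# tensor trips the "Unexpected number of inputs" check.
print([n.name for n in ep.graph_module.graph.nodes if n.op == "placeholder"])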

To Reproduce

Install the nightly version of Torch-TensorRT:

pip install --pre torch-tensorrt  --extra-index-url https://download.pytorch.org/whl/nightly/cu121

Run the following script to reproduce the error:

import torch
import torch_tensorrt

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        conv = self.conv(x)
        relu = self.relu(conv)
        mul = relu * 0.5
        return mul

input = torch.randn((1, 3, 224, 224), dtype=torch.float).to("cuda")
model = MyModule().eval().cuda()

compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            input.shape, dtype=torch.float, format=torch.contiguous_format
        )
    ],
    "ir": "dynamo",
    "min_block_size": 1,
    "torch_executed_ops": {"torch.ops.aten.convolution.default"},
}

# Trace and compile; torch_executed_ops forces the conv to stay in PyTorch.
exp_program = torch_tensorrt.dynamo.trace(model, **compile_spec)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, **compile_spec)
trt_exp_program = torch_tensorrt.dynamo.export(trt_gm, [input], ir="exported_program")

# Serialize, reload, and run inference on the exported program.
torch.export.save(trt_exp_program, "/tmp/trt.ep")
deser_trt_exp_program = torch.export.load("/tmp/trt.ep")
outputs_pyt = model(input)
outputs_trt = trt_exp_program(input)
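
As an optional sanity check (a sketch, not part of the original repro; it assumes trt_gm is a torch.fx.GraphModule, which is what the dynamo frontend returns), printing the graph nodes shows the partition, i.e. the aten convolution staying as a PyTorch call while the remaining ops are fused into a TensorRT engine node:

# Sketch: list the ops in the hybrid graph to confirm the PyTorch/TensorRT split.
for node in trt_gm.graph.nodes:
    print(node.op, node.target)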

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

peri044 (Collaborator, Author) commented Jan 5, 2024

Related PyTorch issue: pytorch/pytorch#116831

peri044 (Collaborator, Author) commented Jan 16, 2024

This is fixed by #2575

peri044 closed this as completed Apr 16, 2024
Hong753 commented Nov 29, 2024

Hello, I am having the same issue on torch_tensorrt 2.5.0+cu118 while compiling with a custom CUDA extension.

RuntimeError: Unexpected number of inputs (expected 9, got 7)

How am I supposed to fix this?
