🐛 [Bug] Error when weight streaming and CUDA graphs are used #3308

Open
keehyuna opened this issue Dec 2, 2024 · 0 comments
keehyuna commented Dec 2, 2024

Bug Description

"RuntimeError: CUDA error: invalid argument" is raised when CUDA graphs are enabled and the weight streaming budget is changed.
It seems the CUDA graph needs to be re-recorded whenever the weight streaming budget changes.
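
The suspected fix can be sketched in plain Python: treat the budget as part of the recording's identity and capture again whenever it differs from the one used at record time. This is only an illustration of the invalidation pattern, not Torch-TensorRT's implementation; `BudgetAwareGraphRunner`, `fake_record`, and the integer budgets are hypothetical stand-ins, and no CUDA calls are made.

```python
from typing import Any, Callable, Optional


class BudgetAwareGraphRunner:
    """Hypothetical wrapper: re-record whenever the budget changes."""

    def __init__(self, record: Callable[[int], Callable[..., Any]]) -> None:
        # `record(budget)` stands in for a CUDA graph capture and returns
        # a callable that replays the captured work.
        self._record = record
        self._replay: Optional[Callable[..., Any]] = None
        self._recorded_budget: Optional[int] = None
        self.record_count = 0

    def run(self, budget: int, *args: Any) -> Any:
        # A recording made under one budget is treated as stale under
        # another, so capture again before replaying.
        if self._replay is None or budget != self._recorded_budget:
            self._replay = self._record(budget)
            self._recorded_budget = budget
            self.record_count += 1
        return self._replay(*args)


def fake_record(budget: int) -> Callable[[int], int]:
    # Stand-in for graph capture; the "replay" just adds the budget.
    return lambda x: x + budget


runner = BudgetAwareGraphRunner(fake_record)
runner.run(20, 1)  # first call: records
runner.run(20, 2)  # same budget: replays the existing recording
runner.run(40, 3)  # budget changed: records again
print(runner.record_count)
```

In this sketch only the first call and the call after the budget change trigger a recording; repeated calls under an unchanged budget reuse the existing one.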

To Reproduce

import torch
import torch_tensorrt as torchtrt

# SampleModel and INPUT_SIZE are not defined in this report.
model = SampleModel().eval().cuda()
input = [torch.randn(*INPUT_SIZE, dtype=torch.float32).cuda()]
fx_graph = torch.fx.symbolic_trace(model)

optimized_model = torchtrt.compile(
    fx_graph,
    inputs=input,
    ir="dynamo",
    min_block_size=1,
    cache_built_engines=False,
    reuse_cached_engines=False,
    use_python_runtime=True,
    use_explicit_typing=True,
    enable_weight_streaming=True,
)

torchtrt.runtime.set_cudagraphs_mode(True)

# Weight streaming context keeps current device budget size

with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx:
    new_budget = int(weight_streaming_ctx.total_device_budget * 0.2)
    weight_streaming_ctx.device_budget = new_budget
    optimized_model(*input)

    new_budget = int(weight_streaming_ctx.total_device_budget * 0.4)
    weight_streaming_ctx.device_budget = new_budget
    optimized_model(*input)

Expected behavior

No CUDA runtime error is raised when the weight streaming budget is changed while CUDA graphs are enabled.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context

@keehyuna keehyuna added the bug Something isn't working label Dec 2, 2024
@keehyuna keehyuna self-assigned this Dec 2, 2024