Invalid CUDA stream handle with multiple sessions on multiple GPUs #18432
Comments
It turns out this bug was introduced somewhere between 1.14.0 and 1.14.1. I've updated the description above.
Sure. Investigating.
Found the root cause: the failing inference does not set the proper CUDA device for the corresponding stream. `cudaSetDevice()` should be called before the CUDA stream is created. Will send the fix PR soon. Thanks for the reproduction script.
### Description
Set the CUDA device before creating the CUDA stream for the IOBinding case.

### Motivation and Context
This fixes issue #18432, where inference fails in the IOBinding case when there are multiple CUDA devices. The reason is that the CUDA device is not set properly before the CUDA stream is created.
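For illustration, here is a minimal sketch of the ordering the fix enforces, written against the NVIDIA `cuda-python` runtime bindings rather than ONNX Runtime's C++ internals (the package choice and helper name are illustrative assumptions, not part of the actual fix):

```python
# Minimal sketch of the ordering requirement, using the NVIDIA cuda-python
# bindings (pip install cuda-python). Illustrative only; the real fix lives
# in ONNX Runtime's C++ CUDA execution provider.
from cuda import cudart


def create_stream_on_device(device_id: int):
    # A stream is created on whichever device is current, so select the
    # target device first. Skipping this step leaves the stream bound to
    # the default device (usually 0), and later launches from another
    # device's session fail with an invalid-handle style error.
    err, = cudart.cudaSetDevice(device_id)
    assert err == cudart.cudaError_t.cudaSuccess, err
    err, stream = cudart.cudaStreamCreate()
    assert err == cudart.cudaError_t.cudaSuccess, err
    return stream
```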
The fix is checked in. Please give the latest build a try.
Thank you so much for the prompt fix!
Describe the issue
In `onnxruntime-gpu>=1.14.1`, using multiple sessions loaded onto different GPUs in a very specific order with IOBinding causes the CUDA Execution Provider to fail. Specifically, if you:

then your inference call will crash.
The exact failure mode seems to vary depending on the ONNX graph loaded. With a proprietary model, where the first executed node is a `Gather` node, I got the following error:

In the MWE provided below, which uses this MNIST model, where the first executed node is a `Convolution`, I get the following error:

I've confirmed this issue on two different virtual machines: one with 4×T4 GPUs and CUDA 11.8, and one with 2×A100 GPUs and CUDA 12.2.
To reproduce
`python mwe.py mnist.onnx minimal` or `python mwe.py mnist.onnx random`

Running the `minimal` mode should produce output like:

Running the `random` mode should produce output like:

This shows that the bug is not triggered if you follow these conditions:
Urgency
This is a regression from version `1.13`, where this code works without any issues, and it is blocking us from upgrading our production systems to `1.16.2`.

The actual urgency is probably only medium, as we can work around the bug by always performing inference on each model right after loading it (a minimal sketch of this workaround follows below). On the other hand, we'd really prefer to avoid this hack and will probably remain on `1.13` until this is fixed.
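A minimal sketch of that workaround, assuming a hypothetical `load_with_warmup` helper, float32 inputs, and placeholder shapes (none of which come from the original report):

```python
import numpy as np
import onnxruntime as ort


def load_with_warmup(model_path: str, device_id: int) -> ort.InferenceSession:
    """Load a session on the given GPU and run one throwaway inference,
    so the session is exercised immediately after loading."""
    session = ort.InferenceSession(
        model_path,
        providers=[("CUDAExecutionProvider", {"device_id": device_id})],
    )
    # Build a dummy feed from the model's declared input shapes,
    # substituting 1 for symbolic dimensions; assumes float32 inputs.
    feed = {}
    for inp in session.get_inputs():
        shape = [d if isinstance(d, int) else 1 for d in inp.shape]
        feed[inp.name] = np.zeros(shape, dtype=np.float32)
    session.run(None, feed)  # plain (non-IOBinding) warm-up call
    return session
```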
Platform
Linux
OS Version
Ubuntu 20.04
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.14.1 and up
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
CUDA 11.8 and 12.2