Invalid CUDA stream handle with multiple sessions on multiple GPUs #18432

Closed
Numeri opened this issue Nov 14, 2023 · 6 comments
Labels
ep:CUDA (issues related to the CUDA execution provider)

Comments

@Numeri

Numeri commented Nov 14, 2023

Describe the issue

In onnxruntime-gpu>=1.14.1, using multiple sessions loaded onto different GPUs in a very specific order with iobinding causes the CUDA Execution Provider to fail.

Specifically, if you:

  1. Load three models – two on one GPU and one on another GPU
  2. Use a model on one GPU
  3. Use a model on the other GPU
  4. Use the remaining, not-yet-used model on the first GPU, with iobinding

then your inference call will crash.

The exact failure mode seems to vary depending on the ONNX graph loaded. With a proprietary model, where the first executed node is a Gather node, I got the following error:

  File "/home/username/exp/miniconda3/envs/envname/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 331, in run_with_iobinding
    self._sess.run_with_iobinding(iobinding._iobinding, run_options)
RuntimeError: Error in execution: Non-zero status code returned while running Gather node. Name:'/_embedding/token_embedding/Gather'
Status Message: CUDA error cudaErrorInvalidResourceHandle:invalid resource handle

In the MWE provided below, which uses this MNIST model, where the first executed node is a Convolution, I get the following error:

2023-11-14 15:12:52.405993900 [E:onnxruntime:Default, cuda_call.cc:116 CudaCall] CUDNN failure 8: CUDNN_STATUS_EXECUTION_FAILED ; GPU=3 ; hostname=hostname ; file=/onnxruntime_src/onnxruntime/core/providers/cuda/nn/conv.cc ; line=392 ; expr=cudnnConvolutionForward(cudnn_handle, &alpha, s_.x_tensor, s_.x_data, s_.w_desc, s_.w_data, s_.conv_desc, s_.algo, workspace.get(), s_.workspace_bytes, &beta, s_.y_tensor, s_.y_data); 
2023-11-14 15:12:52.406048249 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Conv node. Name:'Convolution28' Status Message: CUDNN failure 8: CUDNN_STATUS_EXECUTION_FAILED ; GPU=3 ; hostname=hostname; file=/onnxruntime_src/onnxruntime/core/providers/cuda/nn/conv.cc ; line=392 ; expr=cudnnConvolutionForward(cudnn_handle, &alpha, s_.x_tensor, s_.x_data, s_.w_desc, s_.w_data, s_.conv_desc, s_.algo, workspace.get(), s_.workspace_bytes, &beta, s_.y_tensor, s_.y_data);

I've confirmed this issue on two different virtual machines: one with 4×T4 GPUs and CUDA 11.8, and one with 2×A100 GPUs and CUDA 12.2.

To reproduce

  1. Obtain a machine with multiple GPUs
  2. Download the MNIST model from the onnxruntime-inference-examples repo
  3. Run the following MWE with either python mwe.py mnist.onnx minimal or python mwe.py mnist.onnx random
from dataclasses import dataclass
import random
import sys

import numpy as np
import onnxruntime
import torchvision
import torchvision.transforms as transforms


@dataclass
class Model:
    session: onnxruntime.InferenceSession
    run_options: onnxruntime.RunOptions | None
    cuda_device: int


def load_model(model_filename: str, cuda_device: int) -> Model:
    sess_options = onnxruntime.SessionOptions()
    sess_options.log_severity_level = 3
    sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL

    providers = [
        (
            'CUDAExecutionProvider',
            {
               'device_id': cuda_device,
            }
        ),
        'CPUExecutionProvider',
    ]

    session = onnxruntime.InferenceSession(
        model_filename,
        providers=providers,
        sess_options=sess_options,
    )

    run_options = None

    return Model(session, run_options, cuda_device)


def get_test_data():
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    testset = torchvision.datasets.MNIST(
        root='./data',
        train=False,
        download=True,
        transform=transform,
    )

    test_data = [
        (np.expand_dims(image.numpy(), 0), label)
        for image, label
        in testset
    ]

    return test_data


def run_inference(model: Model, use_iobinding: bool, image: np.ndarray):
    if use_iobinding:
        binding = model.session.io_binding()

        image = np.ascontiguousarray(image)
        image_on_gpu = onnxruntime.OrtValue.ortvalue_from_numpy(image, 'cuda', model.cuda_device)

        binding.bind_ortvalue_input('Input3', image_on_gpu)
        binding.bind_output(name='Plus214_Output_0', device_type='cuda', device_id=model.cuda_device)

        binding.synchronize_inputs()
        model.session.run_with_iobinding(binding)
        binding.synchronize_outputs()

        outputs = binding.get_outputs()
        logits = outputs[0].numpy()
    else:
        inputs = {
            'Input3': image,
        }
        outputs = ['Plus214_Output_0']

        outputs = model.session.run(outputs, inputs, run_options=model.run_options)
        logits = outputs[0]

    return np.argmax(logits)


def main():
    num_gpus = 2
    test_data = get_test_data()

    # Either run the minimal MWE or a more extensive random run that better shows the edge cases
    if sys.argv[2] == 'minimal':
        num_models_per_gpu = 2

        model_order = [0, 1, 2]
        iobinding_settings = [True] * 3
    else:
        num_models_per_gpu = 10

        model_order = [random.randint(0, num_models_per_gpu*num_gpus - 1) for _ in range(len(test_data))]
        iobinding_settings = [bool(random.randint(0, 1)) for _ in range(len(test_data))]

    # Load models
    models = []
    for _ in range(num_models_per_gpu):
        for cuda_device in range(num_gpus):
            models.append(load_model(sys.argv[1], cuda_device))

    # Do inference, triggering bug
    for model_index, use_iobinding, (sample, label) in zip(model_order, iobinding_settings, test_data):
        print(f'Model {model_index:>2d} on GPU {models[model_index].cuda_device} with {use_iobinding = }:\t', end='')
        try:
            run_inference(models[model_index], use_iobinding, sample)
            print('succeeded')
        except:
            print('failed')
            break

    return models, test_data


if __name__ == "__main__":
    main()

Running the minimal mode should produce output like:

Model  0 on GPU 0 with use_iobinding = True:    succeeded
Model  1 on GPU 1 with use_iobinding = True:    succeeded
Model  2 on GPU 0 with use_iobinding = True:    failed

Running the random mode should produce output like:

Model  3 on GPU 3 with use_iobinding = True:    succeeded
Model 18 on GPU 2 with use_iobinding = False:   succeeded
Model  0 on GPU 0 with use_iobinding = False:   succeeded
Model 11 on GPU 3 with use_iobinding = False:   succeeded
Model  3 on GPU 3 with use_iobinding = False:   succeeded
Model  8 on GPU 0 with use_iobinding = False:   succeeded
Model 18 on GPU 2 with use_iobinding = True:    succeeded
Model  1 on GPU 1 with use_iobinding = True:    succeeded
Model  9 on GPU 1 with use_iobinding = True:    succeeded
Model 17 on GPU 1 with use_iobinding = False:   succeeded
Model  3 on GPU 3 with use_iobinding = True:    succeeded
Model 10 on GPU 2 with use_iobinding = False:   succeeded
Model 18 on GPU 2 with use_iobinding = False:   succeeded
Model 15 on GPU 3 with use_iobinding = False:   succeeded
Model 18 on GPU 2 with use_iobinding = False:   succeeded
Model  5 on GPU 1 with use_iobinding = True:    failed

This shows that the bug is avoided as long as you follow these rules (the failing sequence is condensed in the sketch after this list):

  • You can use as many models as you want on one GPU, in any order you want, up until you use a model on another GPU.
  • After using a model on a new GPU, you can't use models on previously used GPUs unless one of the following is true:
    • You are using a model that you have already used at least once
    • You are not using iobinding
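
Concretely, the 'minimal' mode above boils down to this sequence (reusing load_model, run_inference, and get_test_data from the MWE):

sample, _ = get_test_data()[0]

m0 = load_model('mnist.onnx', cuda_device=0)  # first model on GPU 0
m1 = load_model('mnist.onnx', cuda_device=1)  # model on GPU 1
m2 = load_model('mnist.onnx', cuda_device=0)  # second model on GPU 0

run_inference(m0, use_iobinding=True, image=sample)  # GPU 0: succeeds
run_inference(m1, use_iobinding=True, image=sample)  # GPU 1: succeeds
run_inference(m2, use_iobinding=True, image=sample)  # first use of m2 (GPU 0) with iobinding: fails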

Urgency

This is a regression from version 1.13, where this code works without any issues, and it is blocking us from upgrading our production systems to 1.16.2.

The actual urgency is probably only medium, as we can work around the bug by always performing an inference call on each model right after loading it (a sketch of that workaround follows below). On the other hand, we'd really prefer to avoid this hack and will probably stay on 1.13 until this is fixed.
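
For reference, the workaround looks roughly like this (hypothetical warm_up helper and dummy_image, reusing Model and run_inference from the MWE; per the rules above, a non-iobinding call is always safe, and afterwards the session counts as already used):

def warm_up(model: Model, dummy_image: np.ndarray) -> None:
    # One throwaway inference right after loading; skipping iobinding keeps the
    # warm-up itself safe regardless of which GPUs have been touched so far.
    run_inference(model, use_iobinding=False, image=dummy_image)

# In the loading loop of the MWE:
#     model = load_model(sys.argv[1], cuda_device)
#     warm_up(model, dummy_image)   # dummy_image: any valid input, e.g. the first test sample
#     models.append(model)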

Platform

Linux

OS Version

Ubuntu 20.04

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.14.1 and up

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 11.8 and 12.2

github-actions bot added the ep:CUDA (issues related to the CUDA execution provider) label Nov 14, 2023
@Numeri
Author

Numeri commented Nov 15, 2023

It turns out this bug was introduced somewhere between 1.14.0 and 1.14.1. I've updated the description above.

@yufenglee
Member

@jslhcl, @souptc, please help take a look. It looks like it's related to #14719.

@jslhcl
Contributor

jslhcl commented Nov 15, 2023

Sure. Investigating

@jslhcl
Contributor

jslhcl commented Nov 17, 2023

Found the root cause: the failing inference does not have the proper CUDA device set on the corresponding stream. cudaSetDevice() should be called before the CUDA stream is created. Will send the fix PR soon. Thanks for the reproduction script.
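
To illustrate the required ordering, here is a rough analogy using PyTorch's CUDA API (torch is already pulled in by the MWE via torchvision; create_stream is only an illustrative helper, the actual fix lives in ONNX Runtime's C++ CUDA execution provider):

import torch

def create_stream(device_id: int) -> torch.cuda.Stream:
    # A CUDA stream belongs to the device that is current when it is created.
    # Without setting the device first, the stream is tied to whichever GPU
    # happens to be current, and work later submitted to it for the intended
    # GPU fails with errors like cudaErrorInvalidResourceHandle.
    torch.cuda.set_device(device_id)  # analogous to calling cudaSetDevice()
    return torch.cuda.Stream()        # created on the now-current device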

jslhcl added a commit that referenced this issue Dec 5, 2023
### Description
Set the CUDA device before creating the CUDA stream for the IOBinding case.

### Motivation and Context
This fixes issue #18432, where inference fails in the IOBinding case when there are multiple CUDA devices, because the CUDA device is not set properly before the CUDA stream is created.
@jslhcl
Contributor

jslhcl commented Dec 5, 2023

The fix is checked in. Please give it a try in the latest build.

jslhcl closed this as completed Dec 5, 2023
@Numeri
Author

Numeri commented Dec 6, 2023

Thank you so much for the prompt fix!
