
FusedConv Cuda EP invalid argument error. #12321

Open · kiennguyen94 opened this issue Jul 26, 2022 · 5 comments

kiennguyen94 (Contributor) commented Jul 26, 2022

Describe the bug
When running a model containing Conv layers with graph optimizations enabled, ORT throws the following error:

2022-07-26 09:49:04.435452395 [E:onnxruntime:Default, cuda_call.cc:118 CudaCall] CUDNN failure 3: CUDNN_STATUS_BAD_PARAM ; GPU=1 ; hostname=tiger09.som.ma ; expr=cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data, &alpha, Base::s_.y_tensor, Base::s_.y_data);
2022-07-26 09:49:04.435491739 [E:onnxruntime:, sequential_executor.cc:368 Execute] Non-zero status code returned while running FusedConv node. Name:'Conv_125' Status Message: CUDNN error executing cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data, &alpha, Base::s_.y_tensor, Base::s_.y_data)
Traceback (most recent call last):
  File "final_repo.py", line 57, in <module>
    output = model.run(None, {"audio_signal": au_sig, "length": length}, run_options=run_opt)
  File "/n/w1-knguyen/conda3/install/envs/py3_cuda/lib/python3.7/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 200, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Non-zero status code returned while running FusedConv node. Name:'Conv_125' Status Message: CUDNN error executing cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data, &alpha, Base::s_.y_tensor, Base::s_.y_data)

Urgency
None

System information

  • OS Platform and Distribution: CentOS 7
  • ONNX Runtime installed from: source
  • ONNX Runtime version: ddb45e9, also observed with 1.12
  • Python version: 3.7.13
  • Visual Studio version (if applicable): None
  • GCC/Compiler version (if compiling from source): 7.5.0
  • CUDA/cuDNN version: 11.3/8.2.1
  • GPU model and memory: 1080 TI - 12GB

To Reproduce

Expected behavior

  • Expect no error.

Screenshots
None

Additional context

  • The code fails at
    CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), &alpha, Base::s_.z_tensor, Base::s_.z_data,
    with the error code CUDNN_STATUS_BAD_PARAM, which means that either the dimensions of Z and Y don't match or their datatypes are incompatible (per the cuDNN docs).
  • It turns out that when Z is the output of some previous op, it can have a missing dimension (e.g. Z is [1, 1024, 16] whereas Y is [1, 1024, 16, 1]).
  • So the question is: do we want to support automatic dimension matching in cuda/FusedConv? If so, I think it is enough to add a dimension check (if len(Z.shape) == len(Y.shape) - 1, extend Z.shape by 1) and then call ORT_RETURN_IF_ERROR(s_.z_tensor.Set(new_z_dim, CudnnTensor::GetDataType<CudaT>())); right next to the existing
    ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
    (see the sketch after this list).
  • Alternatively, we could treat this as an export error, so that the exporter must apply a Reshape to Z before FusedConv. But given that cpu/FusedConv works fine, I don't think we want that.
  • I'm happy to contribute, just need some guidance.
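
A minimal sketch of the first option. This is not the actual patch: the names z_dims, y_dims, and TensorShapeVector are assumptions based on the surrounding code in the CUDA Conv implementation, not verified against the ORT source.

// Sketch only: pad Z's dims with trailing 1s so that its rank matches
// Y's before the cudnnAddTensor call runs.
// z_dims and y_dims are assumed to hold the shapes of Z and Y.
TensorShapeVector z_dims_cudnn(z_dims.begin(), z_dims.end());
while (z_dims_cudnn.size() < y_dims.size()) {
  z_dims_cudnn.push_back(1);  // e.g. [1, 1024, 16] -> [1, 1024, 16, 1]
}
ORT_RETURN_IF_ERROR(s_.z_tensor.Set(z_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));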
RandySheriffH (Contributor)

@kiennguyen94: thanks for reporting this; we are working on a similar issue, #11548.
And yes, the mismatch between Z and Y is the culprit.

RandySheriffH self-assigned this Jul 26, 2022
RandySheriffH (Contributor)

@kiennguyen94: BTW, for your case, would you mind sharing a model and a sample script with us so we can double-confirm that the fix works?

kiennguyen94 (Contributor, Author) commented Jul 28, 2022

@RandySheriffH Sorry for the late reply; here's the repro: https://github.com/kiennguyen94/ort_load_repro.
Just extract the model tarball with tar xvzf ./citrinet.tgz, then run python ./final_repo.py.

I have the fix in a temporary draft PR, #12366. It fixes this particular repro, but I'm not sure whether it introduces unwanted behavior.

RandySheriffH (Contributor)

@kiennguyen94, thanks! For this issue your fix in #12366 should work; I plan to bring it in along with other fixes via https://github.com/microsoft/onnxruntime/tree/FuseConvShapeMismatch.

xfbs commented Feb 27, 2024

For anyone else who runs into this issue: we were able to work around it by lowering the graph optimization level.

import onnxruntime as ort

# Basic optimizations skip the extended fusions (e.g. FusedConv) that trigger the error.
sess_opt = ort.SessionOptions()
sess_opt.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
return ort.InferenceSession(onnx_path, providers=providers, sess_options=sess_opt)

The theory is that some of the optimizations produce a graph that cannot be executed with CUDA, while the ONNX file itself is fine.
