
Include cublas error details when getting cublas handle fails #3695

Merged · 7 commits into deepspeedai:master · Jun 13, 2023

Conversation

jli (Contributor) commented on Jun 6, 2023

I've been getting hard-to-debug errors in some DeepSpeed runs. During initialization, one of the worker processes raises RuntimeError: Fail to create cublas handle. with no further details, which feels pretty mysterious.

This change adds details of the failure status to the error message by using cublasGetStatusName (https://docs.nvidia.com/cuda/cublas/#cublasgetstatusname) and cublasGetStatusString (https://docs.nvidia.com/cuda/cublas/#cublasgetstatusstring).


Original error message (using DeepSpeed 0.9.2): RuntimeError: Fail to create cublas handle.

New error message with this change: RuntimeError: Failed to create cublas handle: CUBLAS_STATUS_NOT_INITIALIZED the library was not initialized

This is still not a great error message, but it yields better search results (most results suggest it's due to running out of GPU memory; bizarrely, some people also report that removing ~/.nv fixes it...).
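
For reference, a minimal sketch (not the exact context.h code) of how the handle-creation check can surface this information, assuming cuBLAS >= 11.4.2 where both helper functions exist:

```cpp
#include <cublas_v2.h>

#include <stdexcept>
#include <string>

// Hypothetical helper, not the actual DeepSpeed code: create a cuBLAS handle
// and, on failure, report both the status name and its human-readable
// description via cublasGetStatusName/cublasGetStatusString.
cublasHandle_t create_cublas_handle()
{
    cublasHandle_t handle;
    cublasStatus_t status = cublasCreate(&handle);
    if (status != CUBLAS_STATUS_SUCCESS) {
        // Produces e.g. "Failed to create cublas handle:
        // CUBLAS_STATUS_NOT_INITIALIZED the library was not initialized"
        throw std::runtime_error(std::string("Failed to create cublas handle: ") +
                                 cublasGetStatusName(status) + " " +
                                 cublasGetStatusString(status));
    }
    return handle;
}
```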

jli (Contributor, Author) commented on Jun 6, 2023

@microsoft-github-policy-service agree

loadams changed the title from "include cublas error details when getting cublas handle fails" to "Include cublas error details when getting cublas handle fails" on Jun 7, 2023
jli (Contributor, Author) commented on Jun 7, 2023

CI checks are failing because the CI environment uses a version of CUDA/cuBLAS that predates these functions. cublasGetStatusName and cublasGetStatusString were added in CUDA 11.4.2 (released in late 2021, I believe).

@loadams It seems like DeepSpeed doesn't specify any minimum version of CUDA, so I'm guessing you'd rather not include this change as-is. If so, maybe we could instead include just the raw enum number, which the user could look up manually?

Details:

Example error:

/tmp/actions-runner/_work/DeepSpeed/DeepSpeed/unit-test-venv/lib/python3.8/site-packages/deepspeed/ops/csrc/includes/context.h(56): error: identifier "cublasGetStatusName" is undefined

According to the ds_report output, the CI checks use these versions:

  • nv-accelerate-v100: torch cuda version 11.7, nvcc 11.1
  • nv-torch19-p40: torch cuda version 11.1, nvcc 11.1
  • nv-torch19-v100: torch cuda version 11.1, nvcc 11.1

I'm not sure why nv-accelerate-v100 still failed with CUDA 11.7, but maybe it's because it's using nvcc 11.1.

(the nv-inference failure seems unrelated? TypeError: can't assign a NoneType to a torch.cuda.HalfTensor)

loadams (Collaborator) commented on Jun 7, 2023

Thanks for the info @jli.

cc @mrwyattii and @jeffra, FYI. At least until we have a minimum CUDA version, I think it would make sense to at least print the raw enum number to add more debug info for users.
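
A minimal sketch of the fallback being discussed, under the assumption that a compile-time version check (here CUDA_VERSION from cuda.h, with a conservative 11.5 threshold since the 11.4.x point release isn't distinguishable via that macro) is an acceptable way to gate the newer cuBLAS helpers; on older toolkits only the raw enum value would be printed:

```cpp
#include <cublas_v2.h>
#include <cuda.h>  // CUDA_VERSION

#include <string>

// Illustrative only, not the merged code: fall back to the raw status code on
// toolkits that predate cublasGetStatusName/cublasGetStatusString. The guard
// below is an assumption and may need adjusting for a real build.
static std::string cublas_status_to_string(cublasStatus_t status)
{
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 11050)  // conservative: CUDA 11.5+
    return std::string(cublasGetStatusName(status)) + " " + cublasGetStatusString(status);
#else
    // Older CUDA: only the raw enum value is available for users to look up.
    return "cublas error code " + std::to_string(static_cast<int>(status));
#endif
}
```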

loadams enabled auto-merge (squash) on Jun 13, 2023, 18:32
loadams merged commit 46bb08c into deepspeedai:master on Jun 13, 2023