fix: replace torch.device("cuda") with torch.device("cuda:0") in devices initialization #3184
Conversation
wdyt @sjrl ? This is what I had in mind ☝️
haystack/modeling/utils.py
Outdated
@@ -96,6 +96,12 @@ def initialize_device_settings(
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend="nccl")

# HF transformers v4.21.2 pipeline object doesn't accept torch.device("cuda"), it has to be indexed cuda device
Small typo: add "an", so it reads "to be an indexed ...".
# HF transformers v4.21.2 pipeline object doesn't accept torch.device("cuda"), it has to be indexed cuda device
# TODO eventually remove once the limitation is fixed in HF transformers
Can you make this TODO an issue, so we can keep track of it in GitHub?
Looks good! Just two small comments.
Corrected comment and opened #3185
Related Issues
Proposed Changes:
As of version 4.21.2, HF transformers pipelines only accept indexed CUDA devices. Therefore, we need to replace instances of torch.device("cuda") with torch.device("cuda:0") in the device initialization util function.
How did you test it?
The fix is trivial; I tried a small test in an interpreter but didn't add any additional unit tests.
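For illustration, here is a minimal sketch of the kind of interpreter check described above, together with the device normalization this PR applies; the helper below is hypothetical and not the actual patch:

import torch

def normalize_cuda_device(device: torch.device) -> torch.device:
    # Hypothetical helper sketching the normalization done in
    # initialize_device_settings: HF transformers v4.21.2 pipelines
    # reject a bare torch.device("cuda"), so map it to an indexed device.
    if device.type == "cuda" and device.index is None:
        return torch.device("cuda:0")
    return device

# Quick interpreter-style checks (constructing CUDA devices does not require a GPU):
assert torch.device("cuda").index is None      # un-indexed device
assert torch.device("cuda:0").index == 0       # indexed device
assert normalize_cuda_device(torch.device("cuda")) == torch.device("cuda:0")
assert normalize_cuda_device(torch.device("cpu")) == torch.device("cpu")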
Notes for the reviewer
Think of a potential scenario where the proposed fix might break.