Option to set 'non_blocking' for to(device) in BatchEncoding and BatchFeature #34883
Hi @daniel-bogdoll, thanks for adding this! It looks great to me. Do you think it might be worth extending the same option to BatchFeature to ensure consistent capabilities?
Thanks @qubvel, sure thing! Which tests would I need to run to make sure modifications in the to() function of BatchFeature get tested? Just to make sure, I assume you refer to
Yes, I refer to this one, but I'm not sure it's properly tested anywhere. I was able to find only
Maybe we can do it as simply as:
non_blocking = kwargs.get("non_blocking", False)
...
elif isinstance(v, torch.Tensor) and device is not None:
    new_data[k] = v.to(device=device, non_blocking=non_blocking)
...
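The suggestion above can be sketched in isolation roughly as follows. This is only an illustration of the kwargs.get pattern, not the actual BatchFeature.to() code: FakeTensor and move_batch are hypothetical names, and FakeTensor stands in for torch.Tensor so the snippet runs without PyTorch installed.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; records how .to() was called."""

    def __init__(self, name):
        self.name = name
        self.device = "cpu"
        self.non_blocking = None

    def to(self, device=None, non_blocking=False):
        # Return a new object, as Tensor.to() does when the device changes.
        moved = FakeTensor(self.name)
        moved.device = device
        moved.non_blocking = non_blocking
        return moved


def move_batch(data, device=None, **kwargs):
    """Hypothetical helper mirroring the loop discussed in the comment."""
    # Read the optional flag from kwargs; False keeps the old behavior.
    non_blocking = kwargs.get("non_blocking", False)
    new_data = {}
    for k, v in data.items():
        if isinstance(v, FakeTensor) and device is not None:
            new_data[k] = v.to(device=device, non_blocking=non_blocking)
        else:
            # Non-tensor values are passed through unchanged.
            new_data[k] = v
    return new_data


batch = {"pixel_values": FakeTensor("pixel_values"), "labels": ["cat"]}
default_move = move_batch(batch, device="cuda")
async_move = move_batch(batch, device="cuda", non_blocking=True)
```

Callers that never pass non_blocking get the same result as before, which matches the "no behavioral changes" claim in the PR description.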
That's how I would have tried it as well. But what about this block?
Here device is derived from
I don't think so. At some point it may be worth refactoring this method to use more explicit args and kwargs. For now, we can add a note in the docstring that
@qubvel Done! Thanks for the super-fast replies, was a pleasure! Tests fail now, though: For the first one, as you stated here (#34826 (comment)), it does not seem to be related.
As the second one is a timeout issue, it also seems unrelated:
Thanks for the updates! Looks great, just a small suggestion.
Co-authored-by: Pavel Iakubovskii <[email protected]>
@ArthurZucker or @LysandreJik please review when you have bandwidth
Yeah, sounds super good!
@@ -799,12 +799,13 @@ def as_tensor(value, dtype=None):

         return self

-    def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
+    def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
Why do we need *?
@qubvel suggested this to enforce it as a keyword argument for future backwards compatibility. All arguments after the * are forced to be passed as keyword arguments: #34883 (comment)
Yes, only device can be passed as a positional argument with * introduced. This way, we will prevent anyone from using batch_feature.to("cuda", True) instead of batch_feature.to("cuda", non_blocking=True). This would be useful in case we introduce more positional arguments in the future or need to change the order, for example, when adding dtype.
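For readers unfamiliar with the bare * in a signature, here is a minimal sketch of the behavior described above. The Batch class is a hypothetical stand-in, not the transformers implementation; only the shape of the to() signature is taken from the diff.

```python
from typing import Dict


class Batch:
    """Hypothetical stand-in for a batch container; not the transformers class."""

    def __init__(self, data: Dict[str, object]) -> None:
        self.data = data
        self.last_call = None

    def to(self, device: str, *, non_blocking: bool = False) -> "Batch":
        # Parameters after the bare `*` can only be passed by keyword.
        self.last_call = (device, non_blocking)
        return self


batch = Batch({"input_ids": [1, 2, 3]})

# The keyword form is accepted.
batch.to("cuda", non_blocking=True)

# The positional form raises TypeError, because `to` takes only one
# positional argument besides `self`.
try:
    batch.to("cuda", True)
    positional_rejected = False
except TypeError:
    positional_rejected = True
```

Because the flag can never be passed positionally, a later positional parameter such as dtype could be added without silently changing the meaning of existing calls.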
Thanks for explaining, good decision @qubvel ! 🤗
Option to set 'non_blocking' for to(device) operation in BatchEncoding for performance improvements. Defaults to 'false', thus no behavioral changes.
What does this PR do?
This minor PR adds the non_blocking option to the to() function.
Previous: def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
New: def to(self, device: Union[str, "torch.device"], non_blocking: bool = False) -> "BatchEncoding":
Since non_blocking defaults to 'False', this PR does not introduce behavioral changes.
I realized, when utilizing Zero Shot Object Detection models, that it was not possible to set this option, leading to sub-optimal performance during inference.