-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SymIntify roi_align #7448
Merged
Merged
SymIntify roi_align #7448
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
b01039a
SymIntify roi_align
ezyang 50e0487
add metas
ezyang fd68bd7
fixup
ezyang 8f428fc
Merge branch 'main' into symint-roi-align
pmeier 041e437
Fix import order
ezyang 62cc94c
Merge branch 'symint-roi-align' of github.com:ezyang/vision into symi…
ezyang 0d0058a
Fix lint
ezyang ca8194c
lintfix
ezyang 5c753d4
Merge branch 'main' into symint-roi-align
NicolasHug File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import torch | ||
import torch.library | ||
|
||
# Ensure that torch.ops.torchvision is visible | ||
import torchvision.extension # noqa: F401 | ||
|
||
from torch._prims_common import check | ||
|
||
_meta_lib = torch.library.Library("torchvision", "IMPL", "Meta") | ||
|
||
vision = torch.ops.torchvision | ||
|
||
|
||
def register_meta(op): | ||
def wrapper(fn): | ||
_meta_lib.impl(op, fn) | ||
return fn | ||
|
||
return wrapper | ||
|
||
|
||
@register_meta(vision.roi_align.default) | ||
def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): | ||
check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]") | ||
check( | ||
input.dtype == rois.dtype, | ||
lambda: ( | ||
"Expected tensor for input to have the same type as tensor for rois; " | ||
f"but type {input.dtype} does not equal {rois.dtype}" | ||
), | ||
) | ||
num_rois = rois.size(0) | ||
_, channels, height, width = input.size() | ||
return input.new_empty((num_rois, channels, pooled_height, pooled_width)) | ||
|
||
|
||
@register_meta(vision._roi_align_backward.default) | ||
def meta_roi_align_backward( | ||
grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio, aligned | ||
): | ||
check( | ||
grad.dtype == rois.dtype, | ||
lambda: ( | ||
"Expected tensor for grad to have the same type as tensor for rois; " | ||
f"but type {grad.dtype} does not equal {rois.dtype}" | ||
), | ||
) | ||
return grad.new_empty((batch_size, channels, height, width)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we're now only registering the
SymInt
signature, do we still need to keep the pre-existingroi_align
definitions/declarations that are usingint64_t
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In core, we had to keep the old signatures because the C++ API is public API, and the SymInt signature is not exactly interchangeable with the int signature (as it can affect what implicit conversions are specified). If your old signatures are not public API, we can remove them too, but I'm guessing they are public-ish? In any case, this is the most conservative change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm open to other strategies, as we will have to do this for every function we SymInt'ify which is going to be a bit of a pain.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it, thanks.
Yeah... They're in the "we want them to be private but we don't know who's using them in the wild" category.
Let's keep them in for now and perhaps reconsider if this becomes too much of a mess.