Missed optimization: `a[mask] = b` -> `a = torch.where(mask, b, a)` #4248
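For context, a minimal sketch of the two constructs the title compares; using `b[mask]` on the right-hand side is my addition so the tensor shapes line up (with a scalar `b`, the original `a[mask] = b` form applies directly):

```python
import torch

a = torch.randn(8)
b = torch.randn(8)
mask = a > 0

# Boolean-mask assignment: how many elements get written depends on
# the values in the mask, i.e. on the data.
a1 = a.clone()
a1[mask] = b[mask]

# Static-shape equivalent suggested in the title.
a2 = torch.where(mask, b, a)

assert torch.equal(a1, a2)
```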
Comments
I think there are a couple of issues in the example above.
Thanks for the reply. I guess PyTorch doesn't provide any way for you to intercept calls at the Python level. Is this the sort of thing that should be fixed by the planned dynamic shapes support in #3884?
I don't think dynamic shape will help out of the box here, since there is a materialization of the index tensor. I think if we can make that op support dynamic shapes then we should be good. @ezyang in case you know which op/code does that.
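As a quick illustration of that data dependence (using `nonzero` as an example of an index-materializing op; it may not be the exact op in the XLA lowering):

```python
import torch

mask1 = torch.tensor([True, False, True, False])
mask2 = torch.tensor([True, True, True, False])

# The index tensor built from a boolean mask has a shape that depends
# on the values in the mask, not just on its static shape.
print(mask1.nonzero().shape)  # torch.Size([2, 1])
print(mask2.nonzero().shape)  # torch.Size([3, 1])
```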
In fact, dynamo would help you here, as it wouldn't immediately force the dynamic shape calculation when you compute the mask; later you can figure out that the mask is only used in a setter context and optimize it into a `where` calculation. However, I don't think we do this optimization yet.
sg, I guess we can revisit this issue later.
🐛 Bug
I'm trying to get an existing model running under pytorch/XLA that uses the construct `a[mask] = b` frequently, which seems to be a bottleneck. I'm guessing that this (as far as I can imagine, unnecessarily) becomes something like "create a slice with a dynamic shape, then write to it" and triggers recompilation. I'm running against my CPU currently.

To Reproduce
I tried to demonstrate the issue with a microbenchmark:
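The original snippet did not survive the copy; below is a sketch along the same lines. The tensor size, iteration count, and the use of `xm.mark_step()` / `met.metrics_report()` are my assumptions about the shape of the benchmark:

```python
import time
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
n = 1_000_000
a = torch.randn(n, device=device)
b = torch.randn(n, device=device)

def getitem_variant(m):
    c = a.clone()
    c[m] = b[m]                      # boolean-mask assignment
    return c

def where_variant(m):
    return torch.where(m, b, a)      # static-shape rewrite

def bench(fn, iters=20):
    start = time.time()
    for _ in range(iters):
        m = torch.rand(n, device=device) > 0.5  # fresh mask, new popcount
        fn(m)
        xm.mark_step()               # flush the pending XLA graph
    return time.time() - start

print("getitem:", bench(getitem_variant))
print(met.metrics_report())          # first metrics snapshot
print("where:  ", bench(where_variant))
print(met.metrics_report())          # second metrics snapshot
```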
Output (times only):
Grabbing some snippets from the two metrics outputs, the first call reports:
and the second reports:
Expected behavior
Ideally I think that the `getitem` and `where` variants should have roughly equal performance.

Environment
torch_xla/version.py shows: