[TorchToLinalg] Support lowering MaxPool3dWithIndices #3652
Conversation
@vivekkhandelwal1 Can you take a look if you have time? Thanks!
@lingzhiz1998 Please fix the CI failure.
You should be able to specify version-specific xfails for the fx_importer config. There are a few examples of this in https://github.com/llvm/torch-mlir/blob/main/projects/pt1/e2e_testing/xfail_sets.py#L2865-L2893
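As a rough sketch of what such an entry might look like (the set name, test name, and version-guard helper here are assumptions; the linked lines in xfail_sets.py are the authoritative reference):

```python
# Hypothetical version-gated xfail entry, in the style of xfail_sets.py.
from packaging import version
from torch_mlir._version import torch_version_for_comparison

if torch_version_for_comparison() < version.parse("2.5.0.dev"):
    FX_IMPORTER_XFAIL_SET = FX_IMPORTER_XFAIL_SET | {
        "MaxPool3dWithIndicesModule_basic",  # placeholder test name
    }
```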
Force-pushed from 7f6ab44 to dfd15a5
The fx importer test failure is not related to the PyTorch version. The root cause is that the fx importer translates maxpool3d to maxpool3dwithindices and returns the first result. The tests currently expected to fail will pass after this PR.
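A quick way to see this behavior at the ATen level (a small illustrative check, not code from this PR): max_pool3d produces exactly the first result of max_pool3d_with_indices.

```python
import torch

x = torch.randn(1, 1, 4, 4, 4)
# aten.max_pool3d_with_indices returns (values, indices);
# plain aten.max_pool3d is equivalent to taking just the first result.
vals, idx = torch.ops.aten.max_pool3d_with_indices(x, [2, 2, 2])
vals_only = torch.ops.aten.max_pool3d(x, [2, 2, 2])
assert torch.equal(vals, vals_only)
```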
done. |
Ah, I see. I saw that the xpasses were only happening on nightly so I just assumed. This is great, thanks for the pr! I'll give a review here shortly. |
As a refactoring and generalization of the current implementation, this is a nice improvement. However, I'm not really a fan of the original indices computation (anytime a tensor.extract op can be avoided in a linalg.generic, it probably should be), and I can't help but wonder if it would be more efficient to have a separate "max pooling with indices" computation that computes both the max and its corresponding index at the same time. I'm totally fine with this getting merged now, but considering that the newer fx importer route will compute the indices regardless of whether they are used, it might be important for performance to streamline this computation in the near future. I'm not sure about this, but if the indices tensor isn't used by anything, will the corresponding generic op eventually get removed from the IR? If not (or maybe in general), it would be good to add a canonicalizer for the "max pool with indices" ops which rewrites them as ordinary max pooling ops when the indices result has no consumers.
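To illustrate the fused computation suggested above, here is a small PyTorch reference sketch (an illustration only, not torch-mlir code; it returns per-window indices rather than the flat indices ATen produces, and assumes a cubic kernel with stride equal to kernel size):

```python
import torch

def ref_max_pool3d_with_indices(x, k):
    # x: (N, C, D, H, W); cubic kernel of size k with stride k (toy case).
    n, c, d, h, w = x.shape
    windows = (
        x.unfold(2, k, k)
         .unfold(3, k, k)
         .unfold(4, k, k)  # (N, C, D', H', W', k, k, k)
         .reshape(n, c, d // k, h // k, w // k, -1)
    )
    # A single reduction yields both the max and its in-window argmax,
    # instead of computing the max first and recovering the index afterwards.
    vals, idx = windows.max(dim=-1)
    return vals, idx

vals, idx = ref_max_pool3d_with_indices(torch.randn(1, 1, 4, 4, 4), 2)
```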
Force-pushed from dfd15a5 to 534ef82
@zjgarvey I agree with what you mentioned above and am willing to solve these problems, but I have other plans this week; I may start next week.
@lingzhiz1998 Would you like me to merge this PR in the meantime?
Yes.
As discussed in #3652, we should replace maxpool3dwithindices with maxpool3d if the indices have no users.
Support lowering torch.MaxPool3dWithIndices to the linalg backend.