Avoid doing bias-add by setting the bias value as the outs operand
#836
Comments
Thanks Mahesh! Somehow the former IR naively felt more minimal at the time, but I appreciate the feedback. This is why we co-design :)
This looks like a part of
Seems fine to decompose it.
There is an open patch for the decomposition: #862. What do you think, @silvasean @MaheshRavishankar?
(a) is the approach I imagined.
Another related issue: #879
I am not sure I follow (a) fully. If the case is 3D LHS, 2D RHS, I would expect it to be lowered as a broadcast of the 2D RHS to 3D followed by a `linalg.batch_matmul`; I don't understand the "collapse batch dimension" and the "expand the batch dimensions" parts of (a) above. This has a downside though: the broadcast from the 2D RHS to 3D RHS will have to be materialized in memory. That is both a computation and a memory cost. It would be interesting to see if just using a `linalg.generic` here would avoid that.
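A minimal sketch (mine, not from the thread; names and shapes are illustrative) of the kind of `linalg.generic` the truncated sentence above seems to be pointing at: the RHS indexing map simply drops the batch dimension, so the 2D RHS is read in place and the 3D broadcast is never materialized. The `%acc` init is assumed to be zero-filled by the caller.

```mlir
#lhs3d = affine_map<(b, m, n, k) -> (b, m, k)>
#rhs2d = affine_map<(b, m, n, k) -> (k, n)>   // batch dim dropped: no broadcast
#out3d = affine_map<(b, m, n, k) -> (b, m, n)>

// 3D-LHS x 2D-RHS matmul as a single generic; %acc is assumed zero-filled.
func.func @matmul_no_bcast(%lhs: tensor<?x?x?xf32>, %rhs: tensor<?x?xf32>,
                           %acc: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %0 = linalg.generic
      {indexing_maps = [#lhs3d, #rhs2d, #out3d],
       iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
      ins(%lhs, %rhs : tensor<?x?x?xf32>, tensor<?x?xf32>)
      outs(%acc : tensor<?x?x?xf32>) {
  ^bb0(%l: f32, %r: f32, %a: f32):
    %p = arith.mulf %l, %r : f32
    %s = arith.addf %p, %a : f32
    linalg.yield %s : f32
  } -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
```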
torch's matmul allows arbitrary leading batch dimensions and combinations of broadcasting. It would be nice if linalg could be improved so that these broadcasts aren't materialized in memory.
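For illustration, a hypothetical example of this in torch-mlir's `torch` dialect (the shapes are mine; the comment's original example was elided in this copy): a `[2,3,5,4]` LHS multiplied by a `[4,6]` RHS broadcasts the RHS across both leading batch dimensions.

```mlir
// torch.aten.matmul broadcasts the 2D RHS over the leading batch dims of
// the LHS, producing a [2,3,5,6] result.
%0 = torch.aten.matmul %lhs, %rhs
    : !torch.vtensor<[2,3,5,4],f32>, !torch.vtensor<[4,6],f32>
    -> !torch.vtensor<[2,3,5,6],f32>
```

Lowered naively to linalg, that broadcast becomes an explicit materialized tensor, which is the cost being discussed here.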
@Shukla-Gaurav @erman-gurses is anyone working on this? The fix for this should roll in the change that @makslevental made here: #919
@silvasean I am working on this and will take care of #919 also. Thanks!
Looking at the IR generated from Torch-MLIR within IREE, after some fusion, I see this kind of pattern: if I am reading it correctly, a `batch_matmul` followed by a bias-add computation that is written as a broadcast of the bias into the output shape of the `batch_matmul`, followed by the `batch_matmul` itself. I am not sure this is the best way to represent the computation; it definitely trips up fusion at the Linalg level. At a very preliminary level, a better representation would avoid the explicit broadcast of `%cst_182` (and FWIW the TensorFlow MLIR lowering of op + bias-add is done this way). I tend to think of that as a more canonical representation of the computation here.
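To make the two representations concrete, here is a minimal sketch of both. The actual IR from the issue did not survive in this copy, so this is an assumed reconstruction of the pattern described, not the original: `%bias` stands in for `%cst_182`, and all shapes are dynamic.

```mlir
#bias1d = affine_map<(b, m, n) -> (n)>
#ident3 = affine_map<(b, m, n) -> (b, m, n)>

// Pattern described above: the bias is first materialized into the full
// output shape, then used as the outs (accumulator) of the batch_matmul.
func.func @bias_as_outs(%lhs: tensor<?x?x?xf32>, %rhs: tensor<?x?x?xf32>,
                        %bias: tensor<?xf32>,
                        %B: index, %M: index, %N: index) -> tensor<?x?x?xf32> {
  %empty = tensor.empty(%B, %M, %N) : tensor<?x?x?xf32>
  // Explicit broadcast of the bias (the step this issue wants to avoid).
  %bcast = linalg.generic
      {indexing_maps = [#bias1d, #ident3],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%bias : tensor<?xf32>) outs(%empty : tensor<?x?x?xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<?x?x?xf32>
  %res = linalg.batch_matmul
      ins(%lhs, %rhs : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
      outs(%bcast : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  return %res : tensor<?x?x?xf32>
}

// Suggested representation: zero-fill the accumulator, run the batch_matmul,
// then add the bias in a trailing elementwise generic whose indexing map
// broadcasts the 1D bias implicitly, so it is never materialized.
func.func @bias_as_elementwise_add(%lhs: tensor<?x?x?xf32>,
                                   %rhs: tensor<?x?x?xf32>,
                                   %bias: tensor<?xf32>, %B: index,
                                   %M: index, %N: index) -> tensor<?x?x?xf32> {
  %zero = arith.constant 0.0 : f32
  %empty = tensor.empty(%B, %M, %N) : tensor<?x?x?xf32>
  %fill = linalg.fill ins(%zero : f32)
      outs(%empty : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  %mm = linalg.batch_matmul
      ins(%lhs, %rhs : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
      outs(%fill : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  // Elementwise bias-add; this generic fuses easily with its producer.
  %res = linalg.generic
      {indexing_maps = [#ident3, #bias1d, #ident3],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%mm, %bias : tensor<?x?x?xf32>, tensor<?xf32>)
      outs(%empty : tensor<?x?x?xf32>) {
  ^bb0(%mv: f32, %bv: f32, %out: f32):
    %sum = arith.addf %mv, %bv : f32
    linalg.yield %sum : f32
  } -> tensor<?x?x?xf32>
  return %res : tensor<?x?x?xf32>
}
```

The second form keeps the matmul's accumulator trivially zero-initialized and leaves the bias addition to an elementwise op that elementwise-fusion patterns already understand, which is the fusion benefit discussed in the thread.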