
Avoid doing bias-add by setting the bias value as the outs operand #836

MaheshRavishankar opened this issue May 6, 2022 · 11 comments

@MaheshRavishankar
Contributor

Looking at the IR generated from Torch-MLIR within IREE, after some fusion, I see this kind of pattern:

 %114 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_18 : tensor<384xf32>) outs(%49 : tensor<1x128x384xf32>) {
  ^bb0(%arg1: f32, %arg2: f32):
    linalg.yield %arg1 : f32
  } -> tensor<1x128x384xf32>
  %115 = linalg.batch_matmul ins(%113, %cst_182 : tensor<1x128x384xf32>, tensor<1x384x384xf32>) outs(%114 : tensor<1x128x384xf32>) -> tensor<1x128x384xf32>

If I am reading this correctly, this is a batch_matmul followed by a bias-add, but it is written as a broadcast of the bias into the output shape of the batch_matmul, with the broadcast result then used as the outs operand of the batch_matmul. I am not sure this is the best way to represent the computation; it definitely trips up fusion at the Linalg level. A better representation would be

%fill = linalg.fill ins(%cst_zero : f32) outs(%49 : tensor<1x128x384xf32>) -> tensor<1x128x384xf32>
%115 = linalg.batch_matmul ins(%113, %cst_182 : tensor<1x128x384xf32>, tensor<1x384x384xf32>) outs(%fill : tensor<1x128x384xf32>) -> tensor<1x128x384xf32>
%116 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    ins(%115, %cst_18 : tensor<1x128x384xf32>, tensor<384xf32>) outs(%49 : tensor<1x128x384xf32>) {
  ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
    %0 = arith.addf %b0, %b1 : f32
    linalg.yield %0 : f32
  } -> tensor<1x128x384xf32>

At a very preliminary level, this representation avoids the explicit broadcast of %cst_18 (and FWIW the TensorFlow MLIR lowering of op + bias-add is done this way). I tend to think of this as a more canonical representation of the computation here.

@silvasean
Contributor

Thanks Mahesh, somehow the former IR naively felt more minimal at the time, but thanks for the feedback! This is why we co-design :)

@Shukla-Gaurav
Collaborator

Shukla-Gaurav commented May 13, 2022

This looks like part of the AtenLinearOp lowering to linalg. I can modify that lowering (in the linalg conversion pass) to separate out the bias addition, but moving the lowering to the decomposition pass (AtenMatMul + AtenAdd) seems like a better and cleaner approach. What do you think @silvasean @MaheshRavishankar?
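
To illustrate, a rough sketch of what the decomposed torch-dialect IR could look like (this is only a sketch, not taken from the actual patch; op spellings, shapes, and value names are approximate):

// decompose aten.linear(%input, %weight, %bias) into transpose + matmul + add
%weight_t = torch.aten.t %weight : !torch.vtensor<[384,384],f32> -> !torch.vtensor<[384,384],f32>
%matmul = torch.aten.matmul %input, %weight_t : !torch.vtensor<[1,128,384],f32>, !torch.vtensor<[384,384],f32> -> !torch.vtensor<[1,128,384],f32>
%int1 = torch.constant.int 1
%result = torch.aten.add.Tensor %matmul, %bias, %int1 : !torch.vtensor<[1,128,384],f32>, !torch.vtensor<[384],f32>, !torch.int -> !torch.vtensor<[1,128,384],f32>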

@silvasean
Contributor

Seems fine to decompose it.

@Shukla-Gaurav
Collaborator

There is an open patch for the decomposition: #862
CI fails for that PR because the aten.matmul op does not handle higher-dimensional cases; a specific test case with (3D, 2D) inputs fails here. I am trying to handle the cases for aten.matmul where at least one matrix is 3D.
Possible approaches in my mind:
a) Lower to linalg.batch_matmul: 1. broadcast the lower-rank matrix along the batch dimensions, 2. collapse the batch dimensions, 3. do the matrix multiply with linalg.batch_matmul, 4. expand the batch dimensions (see the sketch below).
Although this approach seems efficient, the 4th step will create issues for dynamic dimensions AFAIK.
b) Lower to linalg.generic: this seems to be an inefficient approach.
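
To make (a) concrete, here is a rough sketch for a hypothetical [2,3,4,5] x [5,6] case (value names, shapes, and init tensors are illustrative only, and the op spellings may not match the exact MLIR version):

// 1. Broadcast the 2D RHS across the batch dimensions of the LHS.
%rhs_bcast = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d2, d3)>,
                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
    iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
    ins(%rhs : tensor<5x6xf32>) outs(%rhs_init : tensor<2x3x5x6xf32>) {
  ^bb0(%in : f32, %out : f32):
    linalg.yield %in : f32
  } -> tensor<2x3x5x6xf32>
// 2. Collapse the batch dimensions of both operands into a single one.
%lhs_c = tensor.collapse_shape %lhs [[0, 1], [2], [3]] : tensor<2x3x4x5xf32> into tensor<6x4x5xf32>
%rhs_c = tensor.collapse_shape %rhs_bcast [[0, 1], [2], [3]] : tensor<2x3x5x6xf32> into tensor<6x5x6xf32>
// 3. A single batched matmul over the collapsed batch dimension (%fill is a zero-filled init).
%mm = linalg.batch_matmul ins(%lhs_c, %rhs_c : tensor<6x4x5xf32>, tensor<6x5x6xf32>) outs(%fill : tensor<6x4x6xf32>) -> tensor<6x4x6xf32>
// 4. Expand the single batch dimension back to the original batch dimensions.
%res = tensor.expand_shape %mm [[0, 1], [2], [3]] : tensor<6x4x6xf32> into tensor<2x3x4x6xf32>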

What do you think @silvasean @MaheshRavishankar ?

@powderluv
Collaborator

@ThomasRaoux

@silvasean
Contributor

a) is the approach I imagined.

@erman-gurses self-assigned this May 27, 2022
@Shukla-Gaurav
Collaborator

Another related issue: #879

@MaheshRavishankar
Contributor Author

I am not sure I follow (a) fully. If the case is 3D LHS, 2D RHS, I would expect it to be lowered as

  1. Broadcast the 2D RHS to 3D RHS
  2. Use batch-matmul

I don't understand the "collapse the batch dimensions" and "expand the batch dimensions" parts of (a) above.

This has a downside though: the broadcast from the 2D RHS to a 3D RHS has to be materialized in memory, which is both a computation and a memory cost. It would be interesting to see if just using a linalg.generic works here. In the final state of things, a linalg.matmul/linalg.batch_matmul and a linalg.generic that express the same computation should end up being handled the same way, but we might not be there yet.
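
For reference, a rough sketch of what that linalg.generic could look like for the 3D x 2D case (value names and shapes are illustrative; %fill is assumed to be a zero-filled init tensor). The map on the RHS broadcasts it across the batch dimension without materializing a 3D copy:

%res = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>,
                     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
                     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
    ins(%lhs, %rhs : tensor<1x128x384xf32>, tensor<384x384xf32>)
    outs(%fill : tensor<1x128x384xf32>) {
  ^bb0(%l : f32, %r : f32, %acc : f32):
    %0 = arith.mulf %l, %r : f32
    %1 = arith.addf %acc, %0 : f32
    linalg.yield %1 : f32
  } -> tensor<1x128x384xf32>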

@silvasean
Contributor

silvasean commented Jun 6, 2022

I don't understand the "collapse the batch dimensions" and "expand the batch dimensions" parts of (a) above.

torch's matmul allows arbitrary leading batch dimensions and combinations of broadcasting, e.g. [42,2,1,4,5,6] x [1,3,4,6,7]. So in general we need to resolve all of that down to a single leading batch dimension for linalg.batch_matmul.
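
Working through that example (roughly): the matrix dimensions are [5,6] x [6,7], and the batch dimensions [42,2,1,4] and [1,3,4] broadcast to [42,2,3,4]. Collapsing those into a single batch dimension of 42*2*3*4 = 1008 gives a linalg.batch_matmul of [1008,5,6] x [1008,6,7] -> [1008,5,7], which is then expanded back to [42,2,3,4,5,7].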

It would be nice if linalg could be improved so that these broadcasts aren't materialized in memory.

@silvasean
Contributor

@Shukla-Gaurav @erman-gurses anyone working on this?

The fix for this should roll in the change that @makslevental made here: #919

@Shukla-Gaurav
Collaborator

@silvasean I am working on this and will take care of #919 as well. Thanks!

qedawkins pushed a commit to nod-ai/torch-mlir that referenced this issue Oct 3, 2022