-
Notifications
You must be signed in to change notification settings - Fork 41
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use
blockwise_broadcast_reduce
in reduction fusions. (#1668)
Currently, we handle reductions that are being fused with gemm-like operations by using atomic stores to the destination buffer. This can be cripplingly slow when most of the output is being reduced as evidenced in layer_norm cases. This PR adds the ability to blockwise_broadcast_reduce on the block sub-tiles of of gemm output. However, in-order to that we need to make sure the reduction dimension is uniformly distributed across the blocks. This is achieved by : Firstly, this PR introduces a utility where for a given set of upper dimensions, it can traverse a transform stack and produce a list of sub-dimensions per each lower dimension where the upper reduction axes are mapped to. Then, this PR introduces ShuffleGemmForReductions pass, which will split and transpose the parallel dimension of the gemm such that reduction dimension is uniformly distributed across the blocks. Then at AlignTiling pass, we extract the block subtile when fusing in the rock.reduce operator. Then perform a blockwise_broadcast_reduce on the block subtile. Since we only want to write the partial reductions per block, we pad out broadcasted part of the subtile. (We rely on any block coordinate that goes to the padded region within the block will not be written out) Then we need to do Recombine the modified sub-tile coordinate maps with grid-only coordinates maps. a) Here, we drop all the upper-dimensions except g_block, m_block and n_block and obtain the grid-only transform map stack. b) In parallel, we re-use getLowerSubDimensions utility to figure out which sub-dimension gets mapped with the above grid-only dimensions. c) Then we extract of those sub-dimensions in a bottom up fashion and stitch it up with the said grid-only transform map stack.
- Loading branch information
1 parent
c2c5c91
commit 608638c
Showing
24 changed files
with
1,757 additions
and
81 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.