
Use blockwise_broadcast_reduce in reduction fusions. #1668

Merged
merged 54 commits into ROCm:develop on Oct 15, 2024

Conversation

manupak
Contributor

@manupak manupak commented Oct 3, 2024

[TODO]: add more e2e tests and unit tests for passes.

Currently, we handle reductions that are fused with gemm-like operations by using atomic stores to the destination buffer. This can be cripplingly slow when most of the output is being reduced, as evidenced in layer_norm cases.

This PR adds the ability to run blockwise_broadcast_reduce on the block sub-tiles of the gemm output. However, in order to do that we need to make sure the reduction dimension is uniformly distributed across the blocks. This is achieved by:

  1. Firstly, this PR introduces a utility that, for a given set of upper dimensions, traverses a transform stack and produces, for each lower dimension, the list of sub-dimensions to which the upper reduction axes are mapped.
  2. Then, this PR introduces the ShuffleGemmForReductions pass, which splits and transposes the parallel dimension of the gemm so that the reduction dimension is uniformly distributed across the blocks.
  3. Then, in the AlignTiling pass, we extract the block subtile when fusing in the rock.reduce operator and perform a blockwise_broadcast_reduce on that subtile.
  4. Since we only want to write the partial reductions per block, we pad out the broadcasted part of the subtile. (We rely on the fact that any block coordinate that falls in the padded region within the block will not be written out.)
  5. Then we need to recombine the modified sub-tile coordinate maps with the grid-only coordinate maps:
    a) Here, we drop all the upper dimensions except g_block, m_block and n_block to obtain the grid-only transform map stack.
    b) In parallel, we reuse the getLowerSubDimensions utility to figure out which sub-dimensions map to the above grid-only dimensions.
    c) Then we extract those sub-dimensions in a bottom-up fashion and stitch them onto the said grid-only transform map stack.
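As a toy illustration of the sub-dimension tracking in step 1 (the function name and the merge-only setup are illustrative, not the actual getLowerSubDimensions implementation): when several upper dimensions are merged into one lower dimension, each upper dimension occupies a (stride, size) sub-dimension of the merged dimension, recoverable from the upper sizes alone.

```python
# Toy model of tracking upper dims through a single "merge" transform:
# upper dims of sizes (10, 64, 64) merge into one lower dim of size
# 40960. Each upper dim then corresponds to a (stride, size)
# sub-dimension of the merged lower dim, so a reduction axis among the
# upper dims can be located inside the lower dim.
def merge_subdims(upper_sizes):
    subdims = []
    stride = 1
    for size in reversed(upper_sizes):  # innermost upper dim varies fastest
        subdims.append((stride, size))
        stride *= size
    return list(reversed(subdims))

# For sizes (10, 64, 64), the three upper dims land at strides
# 4096, 64 and 1 within the 40960-wide merged dimension.
print(merge_subdims([10, 64, 64]))  # [(4096, 10), (64, 64), (1, 64)]
```

The real utility has to walk a whole stack of such transforms (merges, pads, transposes), but the per-merge bookkeeping is of this shape.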

I'll try to create some slides to explain all of this.

In the cases I've tested, this yields two orders of magnitude (~100x) gains over the pure-atomics approach to reduction fusions.
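The gain over pure atomics can be seen with a toy NumPy model (illustrative sizes and names, not the actual pass): once each block reduces its own subtile, only one partial value per row per block column has to be combined in global memory, instead of every element.

```python
import numpy as np

# Hypothetical sizes: an (M, N) gemm output reduced over N, tiled into
# (BM, BN) block subtiles.
M, N, BM, BN = 8, 16, 4, 4
out = np.arange(M * N, dtype=np.float64).reshape(M, N)

# Pure-atomics view: every element of a row is atomically added into the
# single destination slot for that row -> M*N (128) atomic stores.
reference = out.sum(axis=1)

# Blockwise view: each (BM, BN) block first reduces its own subtile
# along the local n axis; only one partial per row per block column is
# then combined -> M * (N // BN) (32) atomic stores.
tiles = out.reshape(M // BM, BM, N // BN, BN)  # (m_block, m_local, n_block, n_local)
partials = tiles.sum(axis=3)                   # per-block partial reduce
blockwise = partials.sum(axis=2).reshape(M)    # final combine across block columns

assert np.allclose(blockwise, reference)
```

The shuffle pass exists to make such a clean per-block split possible, by ensuring each block's tile holds an equal slice of the reduction dimension.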

manupak added 30 commits August 20, 2024 12:02
TODO : from creating a new view to represent
partial reduction in a global view.
a blockwise_reduction to reduce the number
of atomic_add s issued.
There is a fractional dim being created due to reduction
being applied on a partial dimension
* add asserts to where fractional dims are created when removing
  upperDims
TODO : support cases where subdim is removed in pad
* Currently it has code to obtain all
  views leading to gemm.
  back to sub dimensions from upper to lower.
Next: output transposes
but I still see blocking non dividability.
we need to re-think transposes
* abort if reduce to gemm is not invertible
* return failure for removeUpperDims divisibility cases
  so that blockwise reductions are not used.
* make the shuffle pass run on the largest reduction
renamed pass
TODO : fix recombine logic
* added another iface for getLowerSubDims
mlir/lib/Dialect/Rock/Transforms/AlignTiling.cpp (outdated review thread, resolved)
LLVM_DEBUG(llvm::dbgs()
<< "readOperand = " << readOperand->get() << "\n");
// Test against the write operand to guard against [MemRead, MemWrite]
if (readOperand && readOperand != writerOperand &&
Collaborator

Doesn't BufferDependencyAnalysis do this?

Collaborator

Or have this cached?

Contributor Author

This is heavily inspired by your code here:

// Test against the write operand to guard against [MemRead, MemWrite]
if (maybeRecursiveReadOperand &&
    maybeRecursiveReadOperand != writeOperand &&
    isa<MemoryEffects::Read>(effect.getEffect())) {
  collectInputFusionWriteOperands(maybeRecursiveReadOperand, bufferDeps,
                                  state);
}
}

I think it's safer to have it here as well.

mlir/lib/Dialect/Rock/utility/transformMapUtils.cpp (outdated review thread, resolved)
mlir/lib/Dialect/Rock/utility/transformMapUtils.cpp (outdated review thread, resolved)
in AlignTiling -- and allow it to fail rather than
asserting if something is not as expected.
@manupak manupak changed the title [PROTOTYPE] Use blockwise_broadcast_reduce in reduction fusions. Use blockwise_broadcast_reduce in reduction fusions. Oct 10, 2024
@manupak manupak marked this pull request as ready for review October 10, 2024 14:15
@manupak manupak requested a review from dhernandez0 October 10, 2024 14:15
@manupak
Contributor Author

manupak commented Oct 10, 2024

Thanks @krzysz00 for all the reviews on this big PR. I really appreciate it.
I have addressed all of them.

I have not added the unit tests for the passes and align-tiling yet... which I will do next.
@dhernandez0 would you be able to take a look here?

%11 = migraphx.mul %4, %4 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1>
%12 = migraphx.mul %11, %10 : <2x32x10x64x64xf32, 1310720x40960x4096x64x1>, <2x32x10x64x64xf32, 0x0x0x0x0> -> <2x32x10x64x64xf32, 1310720x40960x4096x64x1>
%13 = migraphx.reshape %12 {dims = [2, 32, 40960]} : <2x32x10x64x64xf32, 1310720x40960x4096x64x1> -> <2x32x40960xf32, 1310720x40960x1>
%14 = migraphx.reduce_sum %13 {axes = [2]} : <2x32x40960xf32, 1310720x40960x1> -> <2x32x1xf32, 32x1x1>
Contributor

@dhernandez0 dhernandez0 Oct 11, 2024

nit: add tests where the axis of reduction is not 2. Also, tests where m/n is not reduced at all. This is to test the fixes introduced today.

Contributor Author

Thanks, will add them.
I also need to fix/add unit tests for align-tiling/rock-shuffle-gemm-for-reductions.
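For reference, the migraphx sequence quoted above (square, multiply by an all-zero-stride operand, reshape, reduce_sum over axis 2) can be mirrored in NumPy. Treating the zero-stride operand %10 as a single broadcast scalar is an assumption of this sketch, as is its placeholder value.

```python
import numpy as np

# Mirrors the quoted ops: %11 = mul(x, x); %12 = mul(%11, b);
# %13 = reshape to (2, 32, 40960); %14 = reduce_sum over axis 2.
# The 0x0x0x0x0 strides on %10 are modeled as a scalar broadcast
# (an assumption); 0.5 is a placeholder value.
x = np.random.default_rng(0).standard_normal((2, 32, 10, 64, 64)).astype(np.float32)
b = np.float32(0.5)

sq = x * x                                 # %11
scaled = sq * b                            # %12
flat = scaled.reshape(2, 32, 40960)        # %13 (10 * 64 * 64 == 40960)
reduced = flat.sum(axis=2, keepdims=True)  # %14 -> shape (2, 32, 1)

assert reduced.shape == (2, 32, 1)
```

This also shows why the reduction here is a good stress test for the PR: the reduced extent (40960) spans what were the m/n dimensions of the underlying gemm output.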

@manupak
Contributor Author

manupak commented Oct 14, 2024

@dhernandez0 I've added the promised tests now!

@dhernandez0
Contributor

dhernandez0 commented Oct 14, 2024

@dhernandez0 I've added the promised tests now!

Thanks! I was wondering if we should add a test where the reduction occurs along the G axis only for completeness. I understand it would only involve atomics, but it might be good to include it just in case. I've already approved the PR, so whatever you decide.

@manupak manupak merged commit 1dea35b into ROCm:develop Oct 15, 2024
20 checks passed
dhernandez0 pushed a commit that referenced this pull request Oct 29, 2024