[Pipeliner] Multi-buffer TMA descriptors #5290
Conversation
PR chain (git-pr-chain, branch pipeliner_multi_buffer_tma_descriptors_59c9):
1. #5288 [NFC] Remove unused forOp argument from `setStageCluster`
2. #5289 [TESTING] Add golden sample test for pipelining matmul with descriptors
3. #5290 [Pipeliner] Multi-buffer TMA descriptors (this PR)
constexpr inline int TMA_ALIGN = 128;

template <typename BuilderT>
mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
This is simply factored out from TMALowering.cpp with minimal changes.
@@ -31,6 +31,7 @@ void init_triton_passes_common(py::module &&m) {
  ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass);
  ADD_PASS_WRAPPER_0("add_cse", createCSEPass);
  ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass);
  ADD_PASS_WRAPPER_0("print_ir", createPrintIRPass);
This is useful for debugging: it lets you print the IR at a specific point in the pipeline.
We are going to move the pipeliner to be an "uber pass", so unfortunately this won't help in those cases anymore.
It's also useful for generating the IR for golden sample tests. Just write a Python script and add the prints before the passes we care about.
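A rough sketch of that kind of script, assuming the new binding is exposed as passes.common.print_ir (matching the ADD_PASS_WRAPPER_0 line in this hunk); the pass list is illustrative and module construction is left to the caller.

from triton._C.libtriton import ir, passes

def dump_ir_before_licm(mod):
    # Build a pass manager over the module's MLIR context.
    pm = ir.pass_manager(mod.context)
    pm.enable_debug()
    passes.common.add_canonicalizer(pm)
    passes.common.add_cse(pm)
    # Dump the IR right before the pass we want a golden sample for.
    passes.common.print_ir(pm)
    passes.common.add_licm(pm)
    pm.run(mod)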
@@ -426,6 +427,9 @@ def matmul_kernel_device_tma_persistent(workspace_ptr, #
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Create an opaque value to prevent the descriptor creation from being
    # hoisted out of the loop
    zero = tl.inline_asm_elementwise("mov.b32 $0, 0;", "=r", [], dtype=tl.int32, is_pure=True, pack=1)
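For context, a hedged sketch of the situation the opaque value guards against; the rest of the kernel is not shown in this hunk, so the shapes, strides, and names below are assumed, and the descriptor API spelling (tl.make_tensor_descriptor) may differ from the version under review. With value-based descriptors, a creation whose operands are all loop-invariant can be hoisted out of the loop by LICM, and the `zero` above exists to prevent that so the pipeliner's in-loop descriptor path keeps being exercised.

import triton
import triton.language as tl

@triton.jit
def desc_created_per_iteration(a_ptr, M, K, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # All operands here are loop-invariant, so LICM is free to hoist this
        # creation out of the loop; per the hunk above, the test mixes an
        # opaque value into the computation to keep the creation in the body.
        a_desc = tl.make_tensor_descriptor(a_ptr, shape=[M, K], strides=[K, 1],
                                           block_shape=[BLOCK_M, BLOCK_K])
        a = a_desc.load([0, k * BLOCK_K])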
One cool thing about tensor descriptors being IR values is that they can now be hoisted by LICM.
I agree, it is quite nice to have value-based tensor descriptors.
Why do we want to block LICM in this case?
Looks good to me. Added a couple of questions.
It would be good if @pawelszczerbuk could take a look at the pipelining as well.
// TODO peter: walk to loop yield to find the init value if this is a
// loop-carried value. That would save us from allocating another buffer
// just for the init value
Why do we allocate an extra buffer for the init value in this case? Wouldn't the buffer be allocated only at the place where we create a descriptor?
The initial loop value will be created outside the loop, most likely by a call to tt.make_tensor_descriptor. Since this only replaces descriptor creation inside the loop, the initial value will be missed and lowered by the fallback lowering instead.
Ah interesting, yes, the way the original loop is written is a bit weird, as the first tt.make_tensor_descriptor is outside the loop. I wonder if we should write it differently to make it more friendly to the pipeliner, even if that means having a separate if block originally. But that's a detail, probably not worth looking at now.
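To make the scenario in this thread concrete, a hedged sketch of a loop-carried descriptor (shapes, strides, and names assumed; the descriptor API spelling may differ from the version under review): the value feeding the loop is created outside it, so only the creation inside the loop is rewritten to the multi-buffered descriptor memory, while the init value goes through the fallback lowering and currently needs its own buffer.

import triton
import triton.language as tl

@triton.jit
def loop_carried_descriptor(a_ptr, M, K, BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr):
    # Init value of the loop-carried descriptor: created outside the loop, so
    # it is not covered by the in-loop rewrite and takes the fallback lowering.
    a_desc = tl.make_tensor_descriptor(a_ptr, shape=[M, K], strides=[K, 1],
                                       block_shape=[BLOCK_M, BLOCK_K])
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = a_desc.load([0, k * BLOCK_K])
        # Re-created inside the loop: this is the creation the pass rewrites
        # into the multi-buffered TMA descriptor storage.
        a_desc = tl.make_tensor_descriptor(a_ptr, shape=[M, K], strides=[K, 1],
                                           block_shape=[BLOCK_M, BLOCK_K])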