[VectorDistribution] Add support for multi-subgroup attention #18188
Conversation
Awesome work, overall looks great. Just a quick Q/NIT, but not blocking.
@@ -297,6 +297,7 @@ OnlineAttentionOp::decomposeOperation(OpBuilder &b) {
   Value sZero = b.create<arith::ConstantOp>(loc, b.getZeroAttr(elementType));
   Value s = b.create<linalg::FillOp>(loc, sZero, emptyS).getResult(0);
   s = computeMatmul(b, loc, getQueryMap(), getKeyMap(), sMap, query, key, s);
+  s.getDefiningOp()->setAttr("attention_qk_matmul", b.getUnitAttr());
NIT: Any thoughts on making this a standardized/registered attribute in the LinalgExt dialect?
This is an anti-pattern. We shouldn't rely on such attributes, so I don't want to "codify" them. Let's land this for now but plan to unwind it in the medium term. Could you add a note here that we shouldn't be relying on such attributes?
Is there a way you could avoid decomposing the attention operation until vector distribution and instead handle the layout distribution for attention directly?
Today, IIUC, we'd need to decompose earlier than vector distribution because attention decomposes into non-trivial ops such as matmuls and shuffles/reductions (and, to a lesser extent, reads and broadcasts), which require layout analysis and vector distribution to ensure the thread-distributed shapes play nicely with each other.
Probably not… if we did that, we would effectively be writing microkernels for attention, hardcoded for each intrinsic type at thread level. Which is fine… but I'm not sure we want to do that. One thing we could do is subgroup distribution at the attention op level and thread distribution after decomposition. This would require a major rewrite of vector distribution, splitting it into subgroup-level and thread-level distribution, and I'm also not sure we can properly split things up. I'd rather land this patch and invest the effort in teaching TileAndFuse to do attention instead of rewriting VectorDistribution.
Well, we could also just decompose within the pass as a "pre-processing" step. Then the attribute becomes an internal detail of the pass.
OK, I stamped it, but please add TODOs/warnings noting that this is unstable.
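For illustration, a minimal sketch of what the "decompose within the pass" option above could look like. The preprocessAttention helper and the surrounding pass structure are hypothetical (includes and namespace qualifiers elided); only OnlineAttentionOp::decomposeOperation and the "attention_qk_matmul" attribute come from this patch.

// Hypothetical pre-processing step inside the distribution pass (sketch only):
// decompose attention locally so the "attention_qk_matmul" hint stays an
// internal detail of the pass instead of leaking into IR other passes see.
static void preprocessAttention(mlir::func::FuncOp funcOp) {
  llvm::SmallVector<IREE::LinalgExt::OnlineAttentionOp> toDecompose;
  funcOp.walk([&](IREE::LinalgExt::OnlineAttentionOp op) {
    toDecompose.push_back(op);
  });
  mlir::IRRewriter rewriter(funcOp.getContext());
  for (IREE::LinalgExt::OnlineAttentionOp op : toDecompose) {
    rewriter.setInsertionPoint(op);
    // decomposeOperation expands the op into fill/matmul/elementwise ops and
    // tags the QK matmul with the discardable hint.
    auto results = op.decomposeOperation(rewriter);
    if (mlir::succeeded(results))
      rewriter.replaceOp(op, *results);
  }
  // ...then run layout analysis and vector distribution, consuming the hint
  // and dropping it before the pass finishes.
}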
There are some tests that exceed shared memory, so I'm going to wait for #18415 to land before I land this.
[VectorDistribution] Add support for multi-subgroup attention
No tech debt
Add TODO comment
Add configuration heuristics for attention
address comments
address more comments
Update tests
This patch adds support for distributing attention to multiple subgroups.
Currently, we distinguish the two matmuls in attention by setting a discardable attribute on the matmuls (set during decomposition), used as a hint to layout anchoring about what to do when it encounters these matmuls. (Note that even if these hints were dropped, it would only lead to a drop in performance, because layout anchoring would no longer know it is dealing with attention.) The correct way to handle these matmuls would be to make mma_schedule an operation-specific lowering config and teach decomposition to propagate that config to the two matmuls it creates. This is blocked by work on TileAndDistributeToWorkgroups supporting consumer fusion, and needs some heavy lifting.
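As a minimal illustration of the hint mechanism, a layout-anchoring consumer might look like the sketch below. The isAttentionQKMatmul helper is hypothetical, not the actual IREE anchoring code; only the "attention_qk_matmul" attribute name comes from this patch.

// Hypothetical helper: check for the discardable hint attached by
// OnlineAttentionOp::decomposeOperation. If the hint was dropped, the caller
// falls back to generic contraction anchoring, so only performance is
// affected, not correctness.
static bool isAttentionQKMatmul(mlir::Operation *op) {
  return op->hasAttr("attention_qk_matmul");
}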