[BACKEND] Linear Layout with stmatrix part 2: support stmatrix for `local_alloc` ops #4763
- In general I don't really understand the new code, could use help by way of additional comments.
- Have we run performance tests?
- Perhaps we can involve someone else from the Triton team so they can start learning this stuff?
// In the swizzled layout, the leading dimension (i.e., column dimension) is
// strided by swizzleByteSize. For example, in a matrix of size 128x128 with a
// data type of f16, stored in shared memory using 128B-swizzle mode, the offset
// of the element at index (1, 0) will be 72 due to the stride. Without
If you have a 128x128 matrix with no swizzling, I would have thought that the element at index (1,0) would be at offset 128. How do we get 64 and 72?
(I also don't see how this thing has to do with swizzling. As described is it just the stride of dimension 1, measured in bytes?)
> If you have a 128x128 matrix with no swizzling, I would have thought that the element at index (1,0) would be at offset 128. How do we get 64 and 72?
My bad. I'll just remove the comment about the offset without swizzling.
72 should be corrected to 128 + 2 bytes/element * 8 (swizzled coordinate) = 144, if we use bytes as the unit of offset.
Yeah, no regression found.
I'll discuss this issue with Phil and Thomas next week. Sorry for constantly bothering you to review these PRs.
It might be better if you could point out which parts you're not clear about?
No problem, it would just be good to reduce the bus factor here.
Just the ones already pointed out.
OK, let me revisit the implementation tomorrow and think about a better explanation.
Hi @jlebar, the comments have been updated.
LGTM
Sorry, it looks like I had review comments that never made it in. Feel free to ignore if you want, since you already have an LGTM from the Triton team.
// data type of f16, stored in shared memory using 128B swizzle mode, the offset
// of the element at index (1, 0) will be 128B + 2B * 8 (vector_width) = 144B
// due to the stride. However, if we apply swizzling without a leading offset,
// the offset would be 2B * 128 (num_columns) + 2B * 8 (vector_width) = 272B.
I almost get it, but maybe if you wrote it as index `(x,y)` instead of `(1,0)` then it would be clear? Or is the formula too gross?
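For what it's worth, here is a rough sketch of what a general `(row, col)` version of that arithmetic could look like. It assumes the common XOR-style swizzle (the 16-byte chunk index is XORed with `row % 8`), an 8-element vector of 2-byte f16 values, and a column that stays inside a single 128-byte swizzle row. `swizzledOffsetBytes` and its parameters are hypothetical names used only for illustration; they are not taken from the PR, and the actual implementation may compute this differently.

```cpp
#include <cassert>

// Hypothetical helper, not part of Triton.
// Byte offset of element (row, col) under an assumed XOR-based swizzle.
// Precondition (assumed): col * elemBytes < vec * elemBytes * maxPhase,
// i.e. the column stays inside one swizzle row (128 bytes by default).
constexpr int swizzledOffsetBytes(int row, int col, int rowStrideBytes,
                                  int elemBytes = 2, int vec = 8,
                                  int maxPhase = 8) {
  const int chunk = col / vec;                      // which 16-byte vector
  const int swizzledChunk = chunk ^ (row % maxPhase); // XOR with row phase
  const int within = col % vec;                     // position inside vector
  return row * rowStrideBytes + (swizzledChunk * vec + within) * elemBytes;
}
```

Under these assumptions, `swizzledOffsetBytes(1, 0, 128)` gives 144 (row stride equal to the 128B swizzle width, as with a leading offset), and `swizzledOffsetBytes(1, 0, 256)` gives 272 (row stride equal to 2B * 128 columns), which matches the two numbers in the code comment.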
The comments have been refactored significantly. They should have the information you want now.
return true;
}

} // anonymous namespace
std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
Presumably this function should have some unit tests?
Yeah, I will add a test when investigating Peter's issue. It seems like there are still some problems.
#4727
auto vals = unpackLLVector(loc, val, rewriter);
SmallVector<Value> inputs;
// Pack the input into 2xf16
Type packedTy = vec_ty(vals[0].getType(), 2);
is vals[0].getType() the same as elemTy?
return true;
auto vals = unpackLLVector(loc, val, rewriter);
SmallVector<Value> inputs;
// Pack the input into 2xf16
Do we want to assert that elemTy is f16 (or is a 16-bit scalar value or something?)
for (int i = 0; i < 4; i++) {
Value input = undef(packedTy);
for (int j = 0; j < 2; j++) {
input = insert_element(packedTy, input, vals[i * 2 + j], i32_val(j));
Do we want to assert something about the size of vals? Otherwise this could silently read off the end of the array?
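To make the concern concrete, here is a plain C++ analog of the packing loop with the size check in place. It is only a sketch: `packPairs` is a hypothetical name, it works on raw 16-bit integers rather than MLIR values, and placing element `j = 0` in the low 16 bits is an assumption made for illustration. The assertions actually added to the PR may look different.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical plain-C++ analog of the 2xf16 packing loop; not Triton or MLIR API.
// Packs eight 16-bit values into four 32-bit words, two values per word.
inline std::array<std::uint32_t, 4> packPairs(
    const std::vector<std::uint16_t>& vals) {
  // Without this check, vals[i * 2 + j] could read past the end of the array.
  assert(vals.size() == 8 && "expected exactly eight 16-bit values");
  std::array<std::uint32_t, 4> packed{};
  for (int i = 0; i < 4; ++i) {
    for (int j = 0; j < 2; ++j) {
      // Assumed ordering: element j = 0 goes in the low half of the word.
      packed[i] |= static_cast<std::uint32_t>(vals[i * 2 + j]) << (16 * j);
    }
  }
  return packed;
}
```

The loop indexes `vals[0]` through `vals[7]`, so the size assertion is what turns a silent out-of-bounds read into an immediate failure in debug builds.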
Oh, to be clear, it's not a direct copy and paste. It's based on the following lines:
Yes, this is much more helpful, thank you!!
Assertions have been added for the data type and number of elements. Merging this PR into main now.
…ocal_alloc` ops (triton-lang#4763) This PR enables the use of `stmatrix` for `local_alloc` ops through linear layout and removes the legacy code from the `TargetInfo` class.
This PR enables the use of `stmatrix` for `local_alloc` ops through linear layout and removes the legacy code from the `TargetInfo` class.