Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BACKEND] Linear Layout with stmatrix part 2: support stmatrix for local_alloc ops #4763

Merged
merged 22 commits into from
Oct 1, 2024

Conversation

Jokeren
Copy link
Contributor

@Jokeren Jokeren commented Sep 19, 2024

This PR enables the use of stmatrix for local_alloc ops through linear layout and removes the legacy code from the TargetInfo class.

@Jokeren Jokeren changed the title [DRAFT][BACKEND] Linear Layout with stmatrix part 2: support stmatrix for local_alloc ops [BACKEND] Linear Layout with stmatrix part 2: support stmatrix for local_alloc ops Sep 25, 2024
@Jokeren Jokeren marked this pull request as ready for review September 25, 2024 02:13
Copy link
Collaborator

@jlebar jlebar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • In general I don't really understand the new code, could use help by way of additional comments.
  • Have we run performance tests?
  • Perhaps we can involve someone else from the Triton team so they can start learning this stuff?

// In the swizzled layout, the leading dimension (i.e., column dimension) is
// strided by swizzleByteSize. For example, in a matrix of size 128x128 with a
// data type of f16, stored in shared memory using 128B-swizzle mode, the offset
// of the element at index (1, 0) will be 72 due to the stride. Without
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have a 128x128 matrix with no swizzling, I would have thought that the element at index (1,0) would be at offset 128. How do we get 64 and 72?

(I also don't see how this thing has to do with swizzling. As described is it just the stride of dimension 1, measured in bytes?)

Copy link
Contributor Author

@Jokeren Jokeren Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have a 128x128 matrix with no swizzling, I would have thought that the element at index (1,0) would be at offset 128. How do we get 64 and 72?

My bad... I'll just remove comment about the offset without swizzling.

Copy link
Contributor Author

@Jokeren Jokeren Sep 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

72 should be corrected with 128 + 2 bytes/element * 8 swizzled coordinate = 144 if we use the number of bytes as the unit of offset.

@Jokeren
Copy link
Contributor Author

Jokeren commented Sep 26, 2024

Have we run performance tests?

Yeah, no regression found.

Perhaps we can involve someone else from the Triton team so they can start learning this stuff?

I'll discuss with Phil and Thomas next week regarding this issue. Sorry about consistently bothering you on reviewing these PRs.

@Jokeren
Copy link
Contributor Author

Jokeren commented Sep 26, 2024

In general I don't really understand the new code, could use help by way of additional comments.

Might be better if you could point out me which parts you're not clear about?

@jlebar
Copy link
Collaborator

jlebar commented Sep 27, 2024

Sorry about consistently bothering you on reviewing these PRs.

No problem, it would just be good to reduce the bus factor here.

Might be better if you could point out me which parts you're not clear about?

Just the ones already pointed out.

@Jokeren
Copy link
Contributor Author

Jokeren commented Sep 27, 2024

Just the ones already pointed out.

OK, let me revisit the implementation tomorrow and think about a better explanation

@Jokeren
Copy link
Contributor Author

Jokeren commented Sep 29, 2024

Hi @jlebar , comments have been updated

Copy link
Collaborator

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jlebar
Copy link
Collaborator

jlebar commented Sep 30, 2024

Sorry, it looks like I had review comments that never made it in. Feel free to ignore if you want, since you already have an LGTM from the Triton team.

// data type of f16, stored in shared memory using 128B swizzle mode, the offset
// of the element at index (1, 0) will be 128B + 2B * 8 (vector_width) = 144B
// due to the stride. However, if we apply swizzling without a leading offset,
// the offset would be 2B * 128 (num_columns) + 2B * 8 (vector_width) = 272B.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I almost get it, but maybe if you wrote it as index (x,y) instead of (1,0) then it would be clear? Or is the formula too gross?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comments have been refactored significantly. Should have information you want now?

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp Outdated Show resolved Hide resolved
return true;
}

} // anonymous namespace
std::optional<LinearLayout> chooseStMatrixLayoutLeadingOffset(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably this function should have some unit tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I will add a test when investigating peter's issue. Seems like there're still some problems.
#4727

auto vals = unpackLLVector(loc, val, rewriter);
SmallVector<Value> inputs;
// Pack the input into 2xf16
Type packedTy = vec_ty(vals[0].getType(), 2);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is vals[0].getType() the same as elemTy?

return true;
auto vals = unpackLLVector(loc, val, rewriter);
SmallVector<Value> inputs;
// Pack the input into 2xf16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to assert that elemTy is f16 (or is a 16-bit scalar value or something?)

for (int i = 0; i < 4; i++) {
Value input = undef(packedTy);
for (int j = 0; j < 2; j++) {
input = insert_element(packedTy, input, vals[i * 2 + j], i32_val(j));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to assert something about the size of vals? Otherwise this could silently read off the end of the array?

@Jokeren
Copy link
Contributor Author

Jokeren commented Sep 30, 2024

I did some searching and I'm not finding where this was copy-pasted from (not sure which old code you're referring to), but yeah, in order to review for correctness I think I need to understand it.

Oh, to be clear, it's not a direct copy and paste.

It's based on the following lines:

Copy link
Collaborator

@jlebar jlebar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is much more helpful, thank you!!

@Jokeren
Copy link
Contributor Author

Jokeren commented Oct 1, 2024

Assertions have been added for the data type and number of elements. Merge this PR into main now.

@Jokeren Jokeren merged commit 49266aa into main Oct 1, 2024
7 checks passed
@Jokeren Jokeren deleted the keren/local-alloc branch October 1, 2024 00:53
SamGinzburg pushed a commit to SamGinzburg/triton that referenced this pull request Oct 1, 2024
…ocal_alloc` ops (triton-lang#4763)

This PR enables the use of `stmatrix` for `local_alloc` ops through
linear layout and removes the legacy code from the `TargetInfo` class.
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
…ocal_alloc` ops (triton-lang#4763)

This PR enables the use of `stmatrix` for `local_alloc` ops through
linear layout and removes the legacy code from the `TargetInfo` class.
bertmaher pushed a commit to bertmaher/triton that referenced this pull request Dec 10, 2024
…ocal_alloc` ops (triton-lang#4763)

This PR enables the use of `stmatrix` for `local_alloc` ops through
linear layout and removes the legacy code from the `TargetInfo` class.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants