
[QST] Question on customize epilogue reduction #1301

Closed
zejia-lin opened this issue Jan 11, 2024 · 7 comments
Labels
question Question

Comments

@zejia-lin

What is your question?

Hello, I found that many epilogues are element-wise. I was wondering whether the epilogue could be customized to sum over 2*2 tiles instead of performing an element-wise operation. That is, for D = AB + C, where A is an (m*2, k) matrix, B is a (k, n*2) matrix, and C and D are (m, n) matrices. Since AB produces an (m*2, n*2) matrix, is it possible to sum up every 2*2 tile of that output and produce an (m, n) matrix?
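For concreteness, a plain CPU-only C++ sketch of the intended semantics might look like the following (this is a hypothetical reference, not CUTLASS code; the function name and int types are my own choices):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical reference semantics: A is (2m x k) int8, B is (k x 2n) int8,
// accumulation in int32. The full product P = A*B is (2m x 2n);
// D[i][j] is the sum of the 2x2 tile of P at (2i, 2j), plus C[i][j].
std::vector<int32_t> tile_sum_gemm(const std::vector<int8_t>& A,
                                   const std::vector<int8_t>& B,
                                   const std::vector<int32_t>& C,
                                   int m, int n, int k) {
  // Compute the full (2m x 2n) product first, purely for reference.
  std::vector<int32_t> P(2 * m * 2 * n, 0);
  for (int i = 0; i < 2 * m; ++i)
    for (int j = 0; j < 2 * n; ++j)
      for (int p = 0; p < k; ++p)
        P[i * 2 * n + j] += int32_t(A[i * k + p]) * int32_t(B[p * 2 * n + j]);
  // Reduce every 2x2 tile of P and add C.
  std::vector<int32_t> D(m * n);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      D[i * n + j] = P[(2 * i) * 2 * n + 2 * j]
                   + P[(2 * i) * 2 * n + 2 * j + 1]
                   + P[(2 * i + 1) * 2 * n + 2 * j]
                   + P[(2 * i + 1) * 2 * n + 2 * j + 1]
                   + C[i * n + j];
  return D;
}
```

The question is whether this reduction can be fused into the GEMM epilogue rather than materializing the (2m x 2n) intermediate.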

Many thanks for any advice.

@hwu36
Collaborator

hwu36 commented Jan 11, 2024

which hardware do you use? what is the data type? do you want to use tensor cores?

@zejia-lin
Author

zejia-lin commented Jan 12, 2024

Thanks!

I am using an A100; both CUTLASS 2.x and 3.x are suitable for me. The data types of A and B are int8 with int32 accumulation, and C and D are int32. I do want to use tensor cores.

My custom kernel is basically similar to pooling: it takes 2*2 elements and returns 1 element, but performs a more complex operation internally. I found that issue #188 said CUTLASS had no pooling as of March 2, 2021. I was wondering whether such functionality exists now.

Specifically, I could easily implement it if there were an interface to:

  1. Operate on the accumulator fragment after the GEMM is performed and before it is written to global memory — the epilogue stage, I guess.
  2. Change the load and store pattern of C and D. A and B define a 2m-by-2n-by-k problem and produce a 2m-by-2n matrix; my kernel works on that 2m-by-2n matrix and produces an m-by-n matrix, which is the dimension of C and D.

I am concerned about memory consumption, so I don't want to store the 2m-by-2n matrix in global memory and launch another kernel to perform this operation.

@hwu36
Collaborator

hwu36 commented Jan 13, 2024

In 2.x, you can get the row coordinate from row_offset + thread_start_row_ in https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L398

Every thread owns several fragments, and every fragment owns kElementsPerAccess consecutive elements in the same row. You can first do a 1x2 reduction here, then do a further reduction with the threads that own the next row.

You can first dump the row coordinates and check the mapping between thread id and row coordinate. All the mapping information you need is actually in ThreadMap (https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L71).

Don't forget to update the memory pointer at the end (memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset)) for the new coordinates.

@zejia-lin
Author

Thank you for the detailed reply. I'll try it later.


This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

@mnicely
Collaborator

mnicely commented Feb 22, 2024

@zejia-lin have you resolved your issue?

@zejia-lin
Author

I am sorry for the late response. I was not able to resolve it with reasonable effort, so I am closing this issue.
