
[QST] stages parameters in TileDescription #1184

Closed
jeromeku opened this issue Nov 14, 2023 · 2 comments
Labels
question Question

Comments

@jeromeku
Contributor

What is your question?
How is the stages parameter in TileDescription of the Python CUTLASS library determined for non-Hopper architectures? What exactly is the pipeline, and how is the "maximum number of stages for an operation on a given architecture" computed?

More generally, where does this fit into the GEMM hierarchy as described here and here?

Thanks!

@jackkosaian
Contributor

Stages here refers to the number of buffers in shared memory available for copying data from global memory to shared memory. If you had one stage, you would need to synchronously load the operand and then consume it (i.e., do math on it). With two stages, you could consume the operand loaded in buffer 0 while buffer 1 is being filled. The same idea extends to larger stage counts: while one buffer is being consumed, the remaining buffers can be filling with operands for later iterations.

From the description above, you can see how a "pipeline" is formed: we load the operands for an operation that will run later while we are computing on operands that have already arrived.
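To make the buffering concrete, here is a minimal Python simulation of the issue order in such a pipeline. It is only a sketch under simplifying assumptions: shared memory is modeled as a ring of `stages` buffers, the prologue prefetches up to `stages - 1` k-tiles, and each mainloop iteration issues the load for a future k-tile before doing math on the oldest resident one. The function name `pipeline_schedule` and the event strings are hypothetical; this is not CUTLASS code, and on real hardware the load and the math overlap in time rather than running in sequence.

```python
from collections import deque


def pipeline_schedule(num_k_tiles: int, stages: int) -> list[str]:
    """Hypothetical sketch: issue order of loads/math for a ring of `stages` smem buffers."""
    if stages < 1:
        raise ValueError("stages must be >= 1")
    events = []
    resident = deque()  # (k_tile, buffer) pairs that are loaded but not yet consumed
    next_load = 0

    def issue_load() -> None:
        nonlocal next_load
        buf = next_load % stages  # buffers are reused round-robin
        events.append(f"load k={next_load} -> buf{buf}")
        resident.append((next_load, buf))
        next_load += 1

    # Prologue: prefetch up to stages - 1 k-tiles before any math happens.
    # With a single stage nothing is prefetched, so every load is followed
    # immediately by the math that consumes it (the synchronous case).
    for _ in range(min(stages - 1, num_k_tiles)):
        issue_load()

    # Mainloop: start loading a future k-tile (into the buffer freed by the
    # previous iteration), then compute on the oldest resident k-tile.
    for _ in range(num_k_tiles):
        if next_load < num_k_tiles:
            issue_load()
        k, buf = resident.popleft()
        events.append(f"compute k={k} <- buf{buf}")
    return events


if __name__ == "__main__":
    for event in pipeline_schedule(num_k_tiles=4, stages=2):
        print(event)
```

With 4 k-tiles and 2 stages, the load of k=1 into buf1 is issued before the math on k=0, and buf0 is refilled with k=2 only after k=0 has been consumed, which is the overlap described above.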

For SM80+ architectures, the maximum possible stage count is limited primarily by the amount of shared memory available per SM. It can be estimated roughly as the number of stages, each holding one tile of A and one tile of B, that fit into shared memory:

tile_size_bytes = (tile_m * tile_k * sizeof(element_A)) + (tile_n * tile_k * sizeof(element_B))
max_stage_count = floor(shared_mem_per_sm / tile_size_bytes)
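The rough estimate above can be written as a small Python helper. This is a sketch, not the Python interface's actual code: the function name and parameters are hypothetical, element sizes are passed in bytes (e.g. 2 for fp16), and the 164 KiB capacity in the example is only an illustrative assumption; check the actual shared-memory limit of your device.

```python
def max_stage_count(tile_m: int, tile_n: int, tile_k: int,
                    bytes_a: int, bytes_b: int,
                    smem_capacity_bytes: int) -> int:
    """Hypothetical helper: how many (A, B) tile pairs fit in shared memory."""
    # One stage holds an M x K tile of A and an N x K tile of B.
    tile_size_bytes = tile_m * tile_k * bytes_a + tile_n * tile_k * bytes_b
    # Floor division: a partially fitting stage cannot be used.
    return smem_capacity_bytes // tile_size_bytes


ASSUMED_SMEM_BYTES = 164 * 1024  # illustrative capacity only

# 128x128x32 fp16 tiles: 8 KiB for A + 8 KiB for B = 16 KiB per stage.
print(max_stage_count(128, 128, 32, 2, 2, ASSUMED_SMEM_BYTES))
# 128x256x64 fp16 tiles: 16 KiB for A + 32 KiB for B = 48 KiB per stage.
print(max_stage_count(128, 256, 64, 2, 2, ASSUMED_SMEM_BYTES))
```

Under that assumed capacity, the first tile shape allows 10 stages and the larger one only 3, which shows how quickly bigger tiles eat into the achievable pipeline depth.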

This calculation is imperfect because it does not account for shared memory used by other components of the kernel (e.g., the epilogue). However, it is sufficient for CUTLASS 2.x kernels (typically, CC <= 89), because these kernels use epilogues that do not require shared memory to be isolated between the mainloop and the epilogue.
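For kernels where other components do need their own shared memory, one simple adjustment is to set that amount aside before dividing. The sketch below is only an assumption about how such a reservation could be modeled; the function name and the `reserved_bytes` parameter are hypothetical, and the reserved size is something you would have to know for your particular kernel.

```python
def max_stage_count_with_reserve(tile_size_bytes: int,
                                  smem_capacity_bytes: int,
                                  reserved_bytes: int = 0) -> int:
    """Hypothetical sketch: stage bound after reserving smem for other kernel parts."""
    # Shared memory left for mainloop stages (never negative).
    usable = max(smem_capacity_bytes - reserved_bytes, 0)
    return usable // tile_size_bytes


# 16 KiB per stage, an assumed 164 KiB capacity, and an assumed 32 KiB reservation.
print(max_stage_count_with_reserve(16 * 1024, 164 * 1024))             # no reservation
print(max_stage_count_with_reserve(16 * 1024, 164 * 1024, 32 * 1024))  # with reservation
```

With these illustrative numbers, reserving 32 KiB drops the bound from 10 stages to 8.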

You can follow how this is calculated in the Python interface here. Note that stage counts exceeding 2 are only available for SM80+ in CUTLASS.
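Combining the shared-memory bound with the architecture restriction mentioned above can be sketched as follows. The function name is hypothetical, and the compute capability is assumed to be passed as an integer such as 75 for SM75 or 80 for SM80; the only rule encoded is the one stated in this thread, that more than 2 stages requires SM80+.

```python
def max_supported_stages(cc: int, smem_bound: int) -> int:
    """Hypothetical sketch: cap a shared-memory stage bound by the architecture limit."""
    if cc >= 80:
        # SM80+: the shared-memory estimate is the limiting factor.
        return smem_bound
    # Below SM80, CUTLASS offers at most 2 stages.
    return min(2, smem_bound)


print(max_supported_stages(80, 10))
print(max_supported_stages(75, 10))
```

The first call keeps the shared-memory bound of 10, while the SM75 call is capped at 2.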

@jeromeku
Contributor Author

Hi @jackkosaian:

Many thanks for the clear explanation!
