What is your question?
How is the stages parameter in TileDescription of the python cutlass library determined for non-Hopper architectures? What exactly is the pipeline, and how is the "maximum number of stages for an operation on a given architecture" computed?
More generally, where does this fit into the GEMM hierarchy as described here and here?
Thanks!
Stages here is referring to the number of buffers in shared memory available for copying data from global memory to shared memory. If you had one stage, you would need to synchronously load the operand and consume it (i.e., do math on it). With two stages, you could consume the operand loaded in buffer 0 while buffer 1 is being filled. This concept extends out for larger stage counts as well.
From the description above, you can see how a "pipeline" is formed: we load an operand for some operation that must be completed in the future while we are computing another operation.
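To make the overlap concrete, here is a minimal sketch of the schedule in plain Python. It is not CUTLASS code: `pipeline_schedule` and its step model (every load and every math step costs one time unit) are assumptions made purely for illustration. Tile k lives in buffer `k % stages`, so a load issued during iteration k always targets the buffer that was freed by the previous iteration.

```python
from typing import List, Optional, Tuple

# One time step: (k-tile being loaded into shared memory, k-tile being computed on).
# None means that unit is idle during the step.
Step = Tuple[Optional[int], Optional[int]]


def pipeline_schedule(num_k_tiles: int, stages: int) -> List[Step]:
    """Hypothetical, simplified model of a multi-stage mainloop schedule."""
    if stages < 1 or num_k_tiles < 0:
        raise ValueError("stages must be >= 1 and num_k_tiles must be >= 0")

    steps: List[Step] = []

    if stages == 1:
        # One buffer: the load must finish before the math can start,
        # and the math must finish before the buffer can be refilled.
        for k in range(num_k_tiles):
            steps.append((k, None))
            steps.append((None, k))
        return steps

    # Prologue: fill (stages - 1) buffers before any math is issued.
    prefetch = min(stages - 1, num_k_tiles)
    for k in range(prefetch):
        steps.append((k, None))

    # Mainloop: compute on tile k while tile (k + prefetch) is being loaded.
    for k in range(num_k_tiles):
        nxt = k + prefetch
        steps.append((nxt if nxt < num_k_tiles else None, k))

    return steps
```

Under this model, four k-tiles take 8 steps with one stage and 5 steps with two stages, because every load after the first is hidden behind math. In this idealized model, stages beyond two do not shorten the schedule further; their benefit on real hardware comes from tolerating global memory latency that is longer than one math step, which the model does not capture.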
For SM80+ architectures, the maximum stage count possible is primarily limited by the amount of available shared memory per SM. Each stage holds one tile of A and one tile of B, so the maximum can be roughly calculated as the number of such A and B tile pairs that fit into shared memory.
This calculation is imperfect because it does not account for shared memory used in other components of the kernel (e.g., the epilogue). However, it is a sufficient calculation for CUTLASS 2.x kernels (typically, CC <= 89) as these kernels use epilogues that do not require isolation of shared memory between the mainloop and epilogue.
You can follow how this is calculated in the Python interface here. Note that stage counts exceeding 2 are only available for SM80+ in CUTLASS.
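The rough calculation described above can be sketched as follows. This is an illustration, not the code from the CUTLASS Python interface: the function names `smem_bytes_per_stage` and `max_stages` are hypothetical, the shared memory capacity is passed in by the caller, and any per-stage bookkeeping or epilogue usage is ignored.

```python
def smem_bytes_per_stage(tile_m: int, tile_n: int, tile_k: int,
                         bits_a: int, bits_b: int) -> int:
    """Bytes needed to hold one A tile (M x K) and one B tile (K x N)."""
    bytes_a = tile_m * tile_k * bits_a // 8
    bytes_b = tile_k * tile_n * bits_b // 8
    return bytes_a + bytes_b


def max_stages(tile_m: int, tile_n: int, tile_k: int,
               bits_a: int, bits_b: int,
               smem_capacity_bytes: int, sm_arch: int) -> int:
    """Rough upper bound on the stage count (hypothetical helper).

    Ignores shared memory used outside the mainloop, e.g. by the epilogue.
    """
    per_stage = smem_bytes_per_stage(tile_m, tile_n, tile_k, bits_a, bits_b)
    fit = smem_capacity_bytes // per_stage
    if sm_arch < 80:
        # Stage counts above 2 are only available for SM80+ in CUTLASS.
        return min(fit, 2)
    return fit
```

As a worked example, a 128x128x32 threadblock tile with 16-bit A and B operands needs 8192 + 8192 = 16384 bytes per stage. Assuming 163 KB (166912 bytes) of shared memory usable by one threadblock, which is the value assumed here for an SM80 part such as A100, `max_stages(128, 128, 32, 16, 16, 163 * 1024, 80)` gives 10. Being able to fit that many stages does not mean that many stages is the fastest choice.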