[QST] Why it won't OOB in tiled_copy pipeline #2018
What is your question?

A toy example (I am a newbie, and there might be some "brainless" atom choices):

[code block not captured]

and stdout would give me this:

[printed layouts not captured]

And then many examples will launch a pipeline iterating over the K-mode like this:

[pipeline code not captured]

The question is, the K-mode of tSsA and tSrA_copy_view is 1, but that of tSrA is 2. It seems a single copy from smem to registers is sufficient for 2 gemms in this case, so won't tSsA(_, _, i + 1) and tSrA_copy_view(_, _, i + 1) go out of bounds when i == 0? Hope someone can guide me through this, thanks!
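For reference, a minimal paraphrase of the pipeline pattern in question, using the tensor names above (this is a reconstruction, not the original code; tSrB, tSrC, smem_tiled_copy_A, and tiled_mma are assumed names):

```cpp
// K_BLOCK_MAX comes from the MMA fragment, whose K-mode is 2 here,
// while tSsA and tSrA_copy_view only have a K-mode of size 1.
auto K_BLOCK_MAX = size<2>(tSrA);

CUTE_UNROLL
for (int i = 0; i < K_BLOCK_MAX; ++i) {
  if (i + 1 < K_BLOCK_MAX) {
    // Prefetch the next k-block from smem into registers.
    // When i == 0 this reads coordinate 1 of a size-1 K-mode --
    // the out-of-bounds access the question is about.
    cute::copy(smem_tiled_copy_A, tSsA(_, _, i + 1), tSrA_copy_view(_, _, i + 1));
  }
  cute::gemm(tiled_mma, tSrA(_, _, i), tSrB(_, _, i), tSrC);
}
```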
Comments

Yes, those would go out of bounds, because those coordinates are not valid for the shape, as you noted. The confusion stems from conflating two individual partitioning patterns: one for the MMA and one for the CPY. If we write it like this:

```cpp
Tensor tCrA = thr_mma.partition_fragment_A(sA);
Tensor tCrB = thr_mma.partition_fragment_B(sB);
Tensor tCrC = partition_fragment_C(tmma, Shape<bM, bN>{});

Tensor tAsA = smem_thr_cp_A.partition_S(sA);
Tensor tArA = smem_thr_cp_A.retile_D(tCrA);

Tensor tBsB = smem_thr_cp_B.partition_S(sB);
Tensor tBrB = smem_thr_cp_B.retile_D(tCrB);
```

then it's clearer that tArA and tCrA (and likewise tBrB and tCrB) could be very different shapes/orders of the same registers, because they are fragments for different stages that have different partitionings. There happen to be two MMA instructions in the K-mode for each copy:

```cpp
cute::copy(smem_tiled_copy_A, tAsA, tArA);  // CPY the entire sA tile with the tA partitioning
cute::copy(smem_tiled_copy_B, tBsB, tBrB);  // CPY the entire sB tile with the tB partitioning
cute::gemm(tmma, tCrA, tCrB, tCrC);         // MMA to consume the entire sA/sB tile with the tC partitioning
```

I also note that you're using 2 independent MMAs in the K-mode of the tile. Because those MMAs are distinct, their accumulators would have to be accumulated in the epilogue, and a lot of care has to be taken in the mainloop. I believe that CUTLASS has not found a real use case for this... which makes sense, since it necessarily decreases the compute intensity of the inner loops. In general, we don't recommend that, and you should probably stick to

```cpp
TiledMMA tmma = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{},
                               Layout<Shape<_2, _4>>{},  // 2x4 MMAs instead of 2x2x2
                               Tile<_32, _32, _16>{});   // Still 32x32x16
```

while still paying attention to the distinct partitioning patterns of the tensors, as above. But then, to increase the compute intensity even more, I would try to get the LDSM to work down the MN-modes rather than across the K-modes, so perhaps something like this is even better:

```cpp
TiledMMA tmma = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{},
                               Layout<Shape<_2, _4>>{},
                               Tile<_64, _64, _8>{});
```

and now you're into layout engineering.
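One way to make that distinction concrete (an addition, not from the thread) is to assert at compile time that the copy partitioning and the MMA partitioning agree on the K-mode before indexing both with the same k-block index:

```cpp
// If these fail, a per-k-block pipeline that indexes the copy tensors and
// the MMA fragments with the same K index would read out of bounds,
// as in the original question.
CUTE_STATIC_ASSERT_V(size<2>(tAsA) == size<2>(tCrA));  // CPY K-blocks == MMA K-blocks for A
CUTE_STATIC_ASSERT_V(size<2>(tBsB) == size<2>(tCrB));  // likewise for B
```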
@ccecka Huge thanks for your detailed explanations! I understand this much better than before. There are some more questions, though. The config you gave can split the copies in the K-mode into 2 chunks:

[shape printout not captured]

but

[shape printout not captured]
In your upper case, you could have a size-2 pipeline. In your lower case, you could have a size-4 pipeline. Yes, the lower case would almost certainly be better. In general, the smaller the K-tile, the higher the compute intensity.
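As a concrete illustration (a sketch, not from the thread; tensor and copy names are assumed to match the earlier snippets), a size-K_BLOCK_MAX pipeline overlaps the smem-to-register copy of k-block k+1 with the MMA on k-block k:

```cpp
// Number of k-blocks == pipeline size (2 in the upper case, 4 in the lower),
// valid once the copy and MMA partitionings agree on the K-mode.
auto K_BLOCK_MAX = size<2>(tCrA);

// Prologue: prefetch k-block 0 into registers
cute::copy(smem_tiled_copy_A, tAsA(_, _, 0), tArA(_, _, 0));
cute::copy(smem_tiled_copy_B, tBsB(_, _, 0), tBrB(_, _, 0));

CUTE_UNROLL
for (int k = 0; k < K_BLOCK_MAX; ++k) {
  if (k + 1 < K_BLOCK_MAX) {
    // The copy for k-block k+1 overlaps the MMA on k-block k
    cute::copy(smem_tiled_copy_A, tAsA(_, _, k + 1), tArA(_, _, k + 1));
    cute::copy(smem_tiled_copy_B, tBsB(_, _, k + 1), tBrB(_, _, k + 1));
  }
  cute::gemm(tmma, tCrA(_, _, k), tCrB(_, _, k), tCrC);
}
```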
Thanks for your guidance!