
[QST] Why it won't OOB in tiled_copy pipeline #2018

Closed
ZhZhang711 opened this issue Dec 31, 2024 · 4 comments
What is your question?
A toy example (I am a newbie, and some of the atom choices may be "brainless"):

  using ELM = cutlass::half_t;
  using bM = decltype(Int<128>{});
  using bN = decltype(Int<128>{});
  using bK = decltype(Int<16>{});
  TiledMMA tmma =
      make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{}, Layout<Shape<_2, _2, _2>>{},
                     Tile<_32, _32, _16>{});
  auto thr_mma = tmma.get_thread_slice(0);

  auto sA = make_tensor(make_smem_ptr((ELM *)(0)), Layout<Shape<bM, bK>>{});  // Let's assume A is somehow copied to this sA
  auto sB = make_tensor(make_smem_ptr((ELM *)(0)), Layout<Shape<bN, bK>>{});  // Let's assume the same as well

  Tensor tSrA = thr_mma.partition_fragment_A(sA);
  Tensor tSrB = thr_mma.partition_fragment_B(sB);
  Tensor acc = partition_fragment_C(tmma, Shape<bM, bN>{});

  auto cp_atom = Copy_Atom<SM75_U32x4_LDSM_N, ELM>{};
  auto smem_tiled_cp_A = make_tiled_copy_A(cp_atom, tmma);
  auto smem_thr_cp_A = smem_tiled_cp_A.get_thread_slice(0);
  Tensor tSsA = smem_thr_cp_A.partition_S(sA);
  auto smem_tiled_cp_B = make_tiled_copy_B(cp_atom, tmma);
  auto smem_thr_cp_B = smem_tiled_cp_B.get_thread_slice(0);
  Tensor tSsB = smem_thr_cp_B.partition_S(sB);

  Tensor tSrA_copy_view = smem_thr_cp_A.retile_D(tSrA);
  Tensor tSrB_copy_view = smem_thr_cp_B.retile_D(tSrB);

  printf("\n");
  cute::print(layout<>(tSrA));
  printf("\n");
  cute::print(layout<>(tSsA));
  printf("\n");
  cute::print(layout<>(tSrA_copy_view));
  printf("\n");

and stdout would give me this:

(_4,_8,_2):(_1,_4,_32)
(((_2,_4),_2),_4,_1):(((_1,_128),_1024),_32,_0)
((_8,_2),_4,_1):((_1,_32),_8,_0)

Many examples then run a pipeline that iterates over the K-mode like this:

  cute::copy(smem_tiled_cp_A, tSsA(_, _, _0{}), tSrA_copy_view(_, _, _0{}));
  cute::copy(smem_tiled_cp_B, tSsB(_, _, _0{}), tSrB_copy_view(_, _, _0{}));

  for (int i = 0; i < size<2>(tSrA); ++i) {
    if (i < size<2>(tSrA) - 1) {  // prefetch
      cute::copy(smem_tiled_cp_A, tSsA(_, _, i + 1), tSrA_copy_view(_, _, i + 1));
      cute::copy(smem_tiled_cp_B, tSsB(_, _, i + 1), tSrB_copy_view(_, _, i + 1));
    }
    cute::gemm(tmma, tSrA(_, _, i), tSrB(_, _, i), acc);
  }

The question is: the K-extent (size<2>) of tSsA and tSrA_copy_view is 1, but that of tSrA is 2. It seems a single copy from smem to registers is sufficient for 2 gemms in this case, so won't tSsA(_, _, i + 1) and tSrA_copy_view(_, _, i + 1) go out of bounds when i == 0?
Hope someone can guide me through this, thanks!

@ZhZhang711 ZhZhang711 changed the title [QST] Why it won't OOB in tiled_copy [QST] Why it won't OOB in tiled_copy pipeline Dec 31, 2024

ccecka commented Jan 1, 2025

Yes, those would go out of bounds because those coordinates are not valid for the shape as you noted.

The confusion stems from conflating two distinct partitioning patterns: one for the MMA and one for the CPY. If we write it like this:

  Tensor tCrA = thr_mma.partition_fragment_A(sA);
  Tensor tCrB = thr_mma.partition_fragment_B(sB);
  Tensor tCrC = partition_fragment_C(tmma, Shape<bM, bN>{});

  Tensor tAsA = smem_thr_cp_A.partition_S(sA);
  Tensor tArA = smem_thr_cp_A.retile_D(tCrA);

  Tensor tBsB = smem_thr_cp_B.partition_S(sB);
  Tensor tBrB = smem_thr_cp_B.retile_D(tCrB);

Then it's clearer that

  • tCrA is "the registers for A (rA) partitioned across threads for the mma to compute C (tC)", and
  • tArA is "the registers for A (rA) partitioned across threads for the ldsm to copy A (tA)",

and these can be very different shapes/orderings of the same registers, because they are fragments for different stages with different partitioning patterns.

There happen to be two MMA instructions in the K-mode of tCrA and one CPY instruction in the K-mode of tArA. Because of that, there doesn't seem to be any prefetching opportunity here, and we can't loop over the MMA instructions or the CPY instructions in inner loops:

cute::copy(smem_tiled_copy_A, tAsA, tArA);   // CPY the entire sA tile with the tA partitioning
cute::copy(smem_tiled_copy_B, tBsB, tBrB);   // CPY the entire sB tile with the tB partitioning
cute::gemm(tmma, tCrA, tCrB, tCrC);          // MMA to consume entire sA/sB tile with the tC partitioning

I also note that you're using 2 independent MMAs in the K-mode of the tile. Because those MMAs are distinct, their accumulators would have to be accumulated in the epilogue and a lot of care has to be taken in the mainloop. I believe that CUTLASS has not found a real use case for this... which makes sense since it necessarily decreases the compute intensity of the inner loops. In general, we don't recommend that and you should probably stick to

  TiledMMA tmma = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{}, 
                                 Layout<Shape<_2, _4>>{},   // 2x4 MMAs instead of 2x2x2
                                 Tile<_32, _32, _16>{});    // Still 32x32x16

while still paying attention to the distinct partitioning patterns of the tensors like above.

But then, to increase the compute intensity even more, I would try to get the LDSM to work down the MN-modes rather than across the K-modes, so perhaps something like this is even better

  TiledMMA tmma = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{}, 
                                 Layout<Shape<_2, _4>>{},
                                 Tile<_64, _64, _8>{});

and now you're into layout engineering.

@ZhZhang711 (Author)
@ccecka Huge thanks for your detailed explanation! I understand this much better now, though I have a couple of follow-up questions. The config you gave splits the copies in the K-mode into 2 chunks:

  TiledMMA tmma = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{}, 
                                 Layout<Shape<_2, _4>>{},
                                 Tile<_64, _64, _8>{});
  // layouts:
  // tCrA   : (_4,_8,_4):(_1,_4,_32)
  // tAsA   : (((_2,_4),(_2,_2)),_2,_2):(((_1,_128),(_32,_512)),_64,_1024)
  // tCrA_cp: ((_8,(_2,_2)),_2,_2):((_1,(_8,_32)),_16,_64)

but MMA_K and CPY_K still don't match (4 vs 2), which means we can only prefetch once in this "4 gemms, 2 copies" situation, right?
So I am wondering: if I halve the K-mode of the permutations once more, down to _4 as below, would it be faster? We could then prefetch 3 times, giving a CPY -> MMA x4 pipeline.

  TiledMMA tmma = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NT{}, 
                                 Layout<Shape<_2, _4>>{},
                                 Tile<_64, _64, _4>{});
  // layouts:
  // tCrA   : (_4,_8,_4):(_1,_4,_32)
  // tAsA   : (((_2,_4),_2),_2,_4):(((_1,_128),_32),_64,_512)
  // tCrA_cp: ((_8,_2),_2,_4):((_1,_8),_16,_32)


ccecka commented Jan 2, 2025

In your upper case, you could have a size-2 pipeline. In your lower case, you could have a size-4 pipeline. Yes, the lower case would almost certainly be better. In general, the smaller the K-tile, the higher the compute intensity.

@ZhZhang711 (Author)

Thanks for your guidance!
