[Mosaic GPU] Relax TMEM stride constraints on dimensions of size 1
Strides along those dimensions don't affect anything.

PiperOrigin-RevId: 723896657
apaszke authored and Google-ML-Automation committed Feb 6, 2025
1 parent afad924 commit 1edfc11
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion jax/experimental/mosaic/gpu/launch_context.py
@@ -317,6 +317,8 @@ def init_tma_desc(host_ptr):
             if swizzle is None
             else swizzle
         )
+        # TODO(apaszke): Better verification (e.g. slice is non-zero)
+        # TODO(apaszke): We always know strides statically.
         args = [
             host_ptr,
             base_ptr,
@@ -454,7 +456,14 @@ def async_copy(
           f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
       )
     smem_strides, _ = smem_ref_ty.get_strides_and_offset()
-    if smem_strides != utils.get_contiguous_strides(smem_ref_ty.shape):
+    if any(
+        s != cs and d != 1  # Strides don't matter for dims of size 1.
+        for s, cs, d in zip(
+            smem_strides,
+            utils.get_contiguous_strides(smem_ref_ty.shape),
+            smem_ref_ty.shape,
+        )
+    ):
       raise ValueError(
           "async_copy needs the SMEM reference to be contiguous, but got"
           f" strides {smem_strides} for shape {smem_ref_ty.shape}"
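
To illustrate why the relaxed check is safe, here is a minimal standalone sketch in plain Python. It is not the Mosaic GPU code: `contiguous_strides` is a hypothetical stand-in for `utils.get_contiguous_strides` (assumed here to return row-major strides in elements), and `is_effectively_contiguous` and `element_offsets` are names invented for this example. The index along a dimension of size 1 is always 0, so its stride never contributes to any element's offset.

```python
import itertools


def contiguous_strides(shape):
    """Row-major strides in elements (hypothetical stand-in for the real helper)."""
    strides = []
    running = 1
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return strides[::-1]


def is_effectively_contiguous(strides, shape):
    """True if strides match the contiguous ones on every dim whose size is not 1."""
    return not any(
        s != cs and d != 1  # Same predicate as in the diff above.
        for s, cs, d in zip(strides, contiguous_strides(shape), shape)
    )


def element_offsets(strides, shape):
    """Linear offset of every element, enumerated in row-major index order."""
    return [
        sum(i * s for i, s in zip(idx, strides))
        for idx in itertools.product(*(range(d) for d in shape))
    ]


shape = (4, 1, 8)
odd_strides = [8, 1000, 1]  # Arbitrary stride on the size-1 dimension.

# The old equality check would reject odd_strides; the relaxed one accepts it,
# and every element still lands at the same offset as in the contiguous layout.
print(is_effectively_contiguous(odd_strides, shape))
print(element_offsets(odd_strides, shape) == element_offsets(contiguous_strides(shape), shape))
```

A mismatch on a dimension larger than 1, such as strides `[16, 8, 1]` for the same shape, is still rejected, because it changes where elements live in memory.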
