diff --git a/jax/experimental/mosaic/gpu/launch_context.py b/jax/experimental/mosaic/gpu/launch_context.py index 7744bf80c0e2..6333af1387e9 100644 --- a/jax/experimental/mosaic/gpu/launch_context.py +++ b/jax/experimental/mosaic/gpu/launch_context.py @@ -317,6 +317,8 @@ def init_tma_desc(host_ptr): if swizzle is None else swizzle ) + # TODO(apaszke): Better verification (e.g. slice is non-zero) + # TODO(apaszke): We always know strides statically. args = [ host_ptr, base_ptr, @@ -454,7 +456,14 @@ def async_copy( f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" ) smem_strides, _ = smem_ref_ty.get_strides_and_offset() - if smem_strides != utils.get_contiguous_strides(smem_ref_ty.shape): + if any( + s != cs and d != 1 # Strides don't matter for dims of size 1. + for s, cs, d in zip( + smem_strides, + utils.get_contiguous_strides(smem_ref_ty.shape), + smem_ref_ty.shape, + ) + ): raise ValueError( "async_copy needs the SMEM reference to be contiguous, but got" f" strides {smem_strides} for shape {smem_ref_ty.shape}"