Skip to content

Commit

Permalink
[mgpu] FragentedArray.foreach() can now optionally return a new array
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700002715
  • Loading branch information
cperivol authored and Google-ML-Automation committed Nov 25, 2024
1 parent aa05dc0 commit 5eda2ba
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
25 changes: 21 additions & 4 deletions jax/experimental/mosaic/gpu/fragmented_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,15 +1223,32 @@ def select(self, on_true, on_false):
lambda t, p, f: arith.select(p, t, f), self, on_false,
)

def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]):
def foreach(
self,
fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None],
*,
create_array=False,
is_signed=None,
):
"""Call a function for each value and index."""
index = ir.IndexType.get()
for idx, reg in zip(self.layout.thread_idxs(self.shape), self.registers.flat, strict=True):
assert len(idx) == len(self.shape), (idx, self.shape)
new_regs = np.empty_like(self.registers)
for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True):
reg = self.registers[reg_idx]
assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape)
[elems] = ir.VectorType(reg.type).shape
vec = llvm.mlir_undef(reg.type)
for i in range(elems):
i = c(i, index)
fn(vector.extractelement(reg, position=i), (*idx[:-1], arith.addi(idx[-1], i)))
val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i)))
if create_array:
vec = vector.insertelement(val, vec, position=i)

if create_array:
new_regs[reg_idx] = vec

if create_array:
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)

def store_untiled(self, ref: ir.Value):
if not ir.MemRefType.isinstance(ref.type):
Expand Down
39 changes: 39 additions & 0 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,45 @@ def kernel(ctx, dst, _):
rhs = rhs = 0 if rhs_is_literal else iota + 1
np.testing.assert_array_equal(result, op(iota, rhs))

def test_foreach(self):
dtype = jnp.int32
swizzle = 128
tile = 64, swizzle // jnp.dtype(dtype).itemsize
shape = 128, 192
tiled_shape = mgpu.tile_shape(shape, tile)
mlir_dtype = utils.dtype_to_ir_type(dtype)
cst = 9999
def causal(val, idx):
row, col = idx
mask = arith.cmpi(arith.CmpIPredicate.uge, row, col)
return arith.select(mask, val, c(cst, mlir_dtype))

tiling = mgpu.TileTransform(tile)
def kernel(ctx, src, dst, scratch):
smem, barrier = scratch
ctx.async_copy(src_ref=src, dst_ref=smem, gmem_transform=tiling, swizzle=128, barrier=barrier)
barrier.wait()
x = mgpu.FragmentedArray.load_tiled(smem, 128, is_signed=True)
x.foreach(causal, create_array=True, is_signed=False).store_tiled(smem, 128)
mgpu.commit_shared()
ctx.async_copy(src_ref=smem, dst_ref=dst, gmem_transform=tiling, swizzle=128)
ctx.await_async_copy(0)

iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape)
result = mgpu.as_gpu_kernel(
kernel,
(1, 1, 1),
(128, 1, 1),
iota,
iota,
(
jax.ShapeDtypeStruct(dtype=dtype, shape=tiled_shape),
mgpu.TMABarrier(),
),
)(iota)
expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst
np.testing.assert_array_equal(result, expected)

@parameterized.product(
op=[operator.and_, operator.or_, operator.xor],
dtype=[jnp.uint32],
Expand Down

0 comments on commit 5eda2ba

Please sign in to comment.