Skip to content

Commit

Permalink
[mgpu] FragentedArray.foreach() can now optionally return a new array
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 700002715
  • Loading branch information
cperivol authored and Google-ML-Automation committed Nov 27, 2024
1 parent 03b6945 commit 75deb01
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 4 deletions.
22 changes: 18 additions & 4 deletions jax/experimental/mosaic/gpu/fragmented_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,15 +1243,29 @@ def select(self, on_true, on_false):
lambda t, p, f: arith.select(p, t, f), self, on_false,
)

def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]):
def foreach(
self,
fn: Callable[[ir.Value, tuple[ir.Value, ...]], ir.Value | None],
*,
create_array=False,
is_signed=None,
):
"""Call a function for each value and index."""
index = ir.IndexType.get()
for idx, reg in zip(self.layout.thread_idxs(self.shape), self.registers.flat, strict=True):
assert len(idx) == len(self.shape), (idx, self.shape)
new_regs = None
if create_array:
new_regs = np.full_like(self.registers, llvm.mlir_undef(self.registers.flat[0].type))
for mlir_idx, reg_idx in zip(self.layout.thread_idxs(self.shape), np.ndindex(self.registers.shape), strict=True):
reg = self.registers[reg_idx]
assert len(mlir_idx) == len(self.shape), (mlir_idx, self.shape)
[elems] = ir.VectorType(reg.type).shape
for i in range(elems):
i = c(i, index)
fn(vector.extractelement(reg, position=i), (*idx[:-1], arith.addi(idx[-1], i)))
val = fn(vector.extractelement(reg, position=i), (*mlir_idx[:-1], arith.addi(mlir_idx[-1], i)))
if create_array:
new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i)

return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)

def store_untiled(self, ref: ir.Value):
if not ir.MemRefType.isinstance(ref.type):
Expand Down
33 changes: 33 additions & 0 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1361,6 +1361,39 @@ def kernel(ctx, dst, _):
rhs = rhs = 0 if rhs_is_literal else iota + 1
np.testing.assert_array_equal(result, op(iota, rhs))

def test_foreach(self):
dtype = jnp.int32
swizzle = 128
tile = 64, swizzle // jnp.dtype(dtype).itemsize
shape = 128, 192
tiled_shape = mgpu.tile_shape(shape, tile)
mlir_dtype = utils.dtype_to_ir_type(dtype)
cst = 9999
def causal(val, idx):
row, col = idx
mask = arith.cmpi(arith.CmpIPredicate.uge, row, col)
return arith.select(mask, val, c(cst, mlir_dtype))

tiling = mgpu.TileTransform(tile)
def kernel(ctx, dst, smem):
x = iota_tensor(shape[0], shape[1], dtype)
x.foreach(causal, create_array=True, is_signed=False).store_untiled(smem)
mgpu.commit_shared()
ctx.async_copy(src_ref=smem, dst_ref=dst)
ctx.await_async_copy(0)

iota = np.arange(np.prod(shape), dtype=dtype).reshape(*shape)
result = mgpu.as_gpu_kernel(
kernel,
(1, 1, 1),
(128, 1, 1),
(),
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
jax.ShapeDtypeStruct(shape=shape, dtype=dtype),
)()
expected = jnp.tril(iota) + jnp.triu(jnp.ones(shape), k=1) * cst
np.testing.assert_array_equal(result, expected)

@parameterized.product(
op=[operator.and_, operator.or_, operator.xor],
dtype=[jnp.uint32],
Expand Down

0 comments on commit 75deb01

Please sign in to comment.