Skip to content

Commit

Permalink
[Mosaic GPU] Add ArriveExpect and Wait ops on dialect barriers with explicit handling of parities
Browse files Browse the repository at this point in the history

This makes dialect tests in mgpu_test.py truly express the entire computation at the warpgroup level.

PiperOrigin-RevId: 721371327
dimitar-asenov authored and Google-ML-Automation committed Jan 30, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 46512e6 commit 6214c25
Showing 3 changed files with 93 additions and 15 deletions.
38 changes: 35 additions & 3 deletions jax/experimental/mosaic/gpu/dialect_lowering.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@
import operator
from typing import Sequence, Type, cast

import jax
from jax._src.interpreters import mlir as mlir_interpreter
from jax._src.lib import mosaic_gpu_dialect as mgpu
from jax._src.lib.mlir import ir
@@ -118,12 +119,18 @@ def _fragmented_array_from_ir(
)


# TODO(dasenov): Remove this when minimum jaxlib version >= 0.5.1.
# Jaxlib doesn't contain the latest Mosaic GPU dialect bindings, so the new op
# classes are only looked up when the jax and jaxlib versions are identical.
# Otherwise both names are bound to None.
if jax.version._version == jax.lib.__version__:
  WaitOp = mgpu.WaitOp
  ArriveExpectTxOp = mgpu.ArriveExpectTxOp
else:
  WaitOp = None
  ArriveExpectTxOp = None

def _register_lowering(
op: str | Type[ir.OpView]
op: str | Type[ir.OpView] | None
) -> Callable[[MlirLoweringRule], MlirLoweringRule]:
def wrapper(f):
op_name = op if isinstance(op, str) else op.OPERATION_NAME # pytype: disable=attribute-error
_lowerings[op_name] = f
if op is not None:
op_name = op if isinstance(op, str) else op.OPERATION_NAME # pytype: disable=attribute-error
_lowerings[op_name] = f
return f

return wrapper
@@ -309,6 +316,31 @@ def _mgpu_wgmma_op_lowering_rule(
return [_fragmented_array_to_ir(new_acc.value, wgmma_op.accumulator.type)]


@_register_lowering(ArriveExpectTxOp)
def _mgpu_arrive_expect_tx_op_lowering_rule(
    ctx: LoweringContext, arrive_expect_tx_op: ArriveExpectTxOp
) -> Sequence[ir.Value]:
  """Lowers an `ArriveExpectTxOp` onto the barrier it references.

  The op's barrier memref is converted into a `utils.BarrierRef`, on which
  `arrive_expect_tx` is called with the op's `expect_tx` attribute value and
  the single-thread-per-warpgroup predicate taken from `ctx`.

  Returns:
    An empty list: the lowering yields no replacement values.
  """
  barrier_ref = utils.BarrierRef.from_dialect_barrier_memref(
      arrive_expect_tx_op.barrier
  )
  expect_tx = arrive_expect_tx_op.expect_tx.value
  predicate = ctx.single_thread_per_warpgroup_predicate
  barrier_ref.arrive_expect_tx(expect_tx, predicate)
  return []


@_register_lowering(WaitOp)
def _mgpu_wait_op_lowering_rule(
    _: LoweringContext, wait_op: WaitOp
) -> Sequence[ir.Value]:
  """Lowers a `WaitOp` into a parity wait on the barrier it references.

  The lowering context is not used. The op's barrier memref is converted into
  a `utils.BarrierRef`, and `wait_parity` is called on it with the op's
  `parity` operand.

  Returns:
    An empty list: the lowering yields no replacement values.
  """
  utils.BarrierRef.from_dialect_barrier_memref(wait_op.barrier).wait_parity(
      wait_op.parity
  )
  return []


def single_thread_predicates(module: ir.Module) -> tuple[ir.Value, ir.Value]:
"""Returns a single thread predicate per block and one per warpgroup."""
block_predicate = warpgroup_predicate = None
38 changes: 38 additions & 0 deletions jaxlib/mosaic/dialect/gpu/mosaic_gpu.td
Original file line number Diff line number Diff line change
@@ -72,6 +72,44 @@ def MosaicGPU_InitializeBarrierOp : Op<MosaicGPU_Dialect, "initialize_barrier",
}];
}

// Defines the `arrive_expect_tx` op of the Mosaic GPU dialect. The op is
// declared with an empty trait list and declares no results.
def MosaicGPU_ArriveExpectTxOp : Op<MosaicGPU_Dialect, "arrive_expect_tx", []> {
let summary = "Executes an arrive.expect_tx operation on the given barrier.";
let description = [{
A single thread in the warpgroup will execute an `arrive.expect_tx`
operation on the provided barrier with the provided `expect_tx`.
}];

// $barrier:   rank-0 memref with element type MosaicGPU_Barrier.
// $expect_tx: I32 attribute constrained to be non-negative.
let arguments = (ins
MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier,
ConfinedAttr<I32Attr, [IntNonNegative]>:$expect_tx);

// Textual form: the barrier and its type in parentheses after the `barrier`
// keyword, then the expect_tx attribute, then any remaining attributes.
let assemblyFormat = [{
`barrier` `(` $barrier `:` type($barrier) `)`
$expect_tx
attr-dict
}];
}

// Defines the `wait` op of the Mosaic GPU dialect. The op is declared with an
// empty trait list and declares no results.
def MosaicGPU_WaitOp : Op<MosaicGPU_Dialect, "wait", []> {
let summary = "Executes a wait operation on the given barrier.";
let description = [{
All threads in the warpgroup will block, waiting on the provided barrier
until:
- all pending threads have arrived on the barrier
- all expected byte transfers have been completed
- the barrier's parity matches the provided parity
}];

// $barrier: rank-0 memref with element type MosaicGPU_Barrier.
// $parity:  I1 operand; the parity is passed explicitly by the caller.
let arguments = (ins
MemRefRankOf<[MosaicGPU_Barrier], [0]>:$barrier,
I1:$parity
);
// Textual form: `barrier(...)` and `parity(...)` groups, each holding the
// operand and its type, followed by any remaining attributes.
let assemblyFormat = [{
`barrier` `(` $barrier `:` type($barrier) `)`
`parity` `(` $parity `:` type($parity) `)`
attr-dict
}];
}

def MosaicGPU_WGStridedFragLayout : AttrDef<MosaicGPU_Dialect, "WGStridedFragLayout", []> {
let summary = "Annotates an array that can be collapsed to 1D and sharded across threads.";
32 changes: 20 additions & 12 deletions tests/mosaic/gpu_test.py
Original file line number Diff line number Diff line change
@@ -2029,8 +2029,9 @@ def add(
smem: list[ir.Value],
):
del ctx
a_smem_ref, b_smem_ref, result_smem_ref = smem[:3]
tma_barrier = smem[3]
a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem
dialect_barrier = tma_barrier.as_dialect_barrier_memref()

memref_type = ir.MemRefType(a_gmem_ref.type)
shape = memref_type.shape
elt_type = memref_type.element_type
@@ -2040,13 +2041,15 @@ def add(
memref_bytes = utils.bytewidth(elt_type) # Also correct if rank == 0
for size in shape:
memref_bytes *= size
tma_barrier.arrive_expect_tx(2 * memref_bytes, single_thread_predicate())
mgpu_dialect.arrive_expect_tx(
barrier=dialect_barrier, expect_tx=2 * memref_bytes
)

# GMEM -> SMEM
mgpu_dialect.async_load(
source=a_gmem_ref,
destination=a_smem_ref,
barrier=tma_barrier.as_dialect_barrier_memref(),
barrier=dialect_barrier,
indices=[zero_i32, zero_i32],
slice_lengths=shape,
transforms=ir.ArrayAttr.get([]),
@@ -2057,7 +2060,7 @@ def add(
mgpu_dialect.async_load(
source=b_gmem_ref,
destination=b_smem_ref,
barrier=tma_barrier.as_dialect_barrier_memref(),
barrier=dialect_barrier,
indices=[zero_i32, zero_i32],
slice_lengths=shape,
transforms=ir.ArrayAttr.get([]),
@@ -2066,7 +2069,9 @@ def add(
swizzle=swizzle,
)

tma_barrier.wait()
parities = memref.load(tma_barrier.phases, [])
parity, _ = tma_barrier.update_parities(parities)
mgpu_dialect.wait(dialect_barrier, parity)

zero_index = arith.constant(ir.IndexType.get(), 0)

@@ -2137,24 +2142,25 @@ def matmul(
):
del ctx
a_smem_ref, b_smem_ref, result_smem_ref, tma_barrier = smem
dialect_barrier = tma_barrier.as_dialect_barrier_memref()

shape_a = ir.MemRefType(a_gmem_ref.type).shape
shape_b = ir.MemRefType(b_gmem_ref.type).shape
ab_elt_type = ir.MemRefType(a_gmem_ref.type).element_type
memref_bytes_a = utils.bytewidth(ab_elt_type) * math.prod(shape_a)
memref_bytes_b = utils.bytewidth(ab_elt_type) * math.prod(shape_b)

tma_barrier.arrive_expect_tx(
memref_bytes_a + memref_bytes_b,
single_thread_predicate(),
mgpu_dialect.arrive_expect_tx(
barrier=dialect_barrier,
expect_tx=memref_bytes_a + memref_bytes_b,
)

zero_i32 = arith.constant(ir.IntegerType.get_signless(32), 0)
# GMEM -> SMEM
mgpu_dialect.async_load(
source=a_gmem_ref,
destination=a_smem_ref,
barrier=tma_barrier.as_dialect_barrier_memref(),
barrier=dialect_barrier,
indices=[zero_i32, zero_i32, zero_i32, zero_i32],
slice_lengths=shape_a,
transforms=ir.ArrayAttr.get([]),
@@ -2165,7 +2171,7 @@ def matmul(
mgpu_dialect.async_load(
source=b_gmem_ref,
destination=b_smem_ref,
barrier=tma_barrier.as_dialect_barrier_memref(),
barrier=dialect_barrier,
indices=[zero_i32, zero_i32, zero_i32, zero_i32],
slice_lengths=shape_b,
transforms=ir.ArrayAttr.get([]),
@@ -2174,7 +2180,9 @@ def matmul(
swizzle=swizzle,
)

tma_barrier.wait()
parities = memref.load(tma_barrier.phases, [])
parity, _ = tma_barrier.update_parities(parities)
mgpu_dialect.wait(dialect_barrier, parity)

# Computation
shape_result = ir.MemRefType(result_gmem_ref.type).shape

0 comments on commit 6214c25

Please sign in to comment.