[Mosaic GPU] Add end-to-end lowering example for a pointwise kernel using the dialect and layout inference.

Also implement a lowering rule for `arith.AddFOp`.

PiperOrigin-RevId: 707131747
bchetioui authored and Google-ML-Automation committed Dec 17, 2024
1 parent 473e2bf commit 36b12d5
Showing 6 changed files with 93 additions and 2 deletions.
1 change: 1 addition & 0 deletions jax/experimental/mosaic/gpu/__init__.py
@@ -22,6 +22,7 @@
LaunchContext as LaunchContext,
MemRefTransform as MemRefTransform,
TMABarrier as TMABarrier,
ThreadSemantics as ThreadSemantics,
TileTransform as TileTransform,
TransposeTransform as TransposeTransform,
Union as Union,
24 changes: 24 additions & 0 deletions jax/experimental/mosaic/gpu/core.py
@@ -17,6 +17,7 @@
import contextlib
import ctypes
import dataclasses
import enum
import functools
import hashlib
import math
@@ -38,6 +39,15 @@
from jaxlib.mlir.dialects import nvvm
import numpy as np

from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401

if dialect is not None:
from . import dialect_lowering
from . import layout_inference
else:
dialect_lowering = None
layout_inference = None

from . import profiler
from . import utils

@@ -942,6 +952,13 @@ def _declare_runtime_functions():
)


class ThreadSemantics(enum.Enum):
"""Semantics for the kernel's instruction stream."""

Lane = enum.auto()
Warpgroup = enum.auto()


def as_gpu_kernel(
body,
grid: tuple[int, int, int],
@@ -953,6 +970,7 @@
cluster: tuple[int, int, int] = (1, 1, 1),
module_name: str = "unknown",
kernel_name: str | None = None,
thread_semantics: ThreadSemantics = ThreadSemantics.Lane,
):
if isinstance(in_shape, list):
in_shape = tuple(in_shape)
@@ -966,6 +984,12 @@
)
)

if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None:
# Run Python lowering passes. The remaining passes will be run in C++ in
# jax/jaxlib/mosaic/gpu/custom_call.cc
layout_inference.infer_layout(module) # pytype: disable=attribute-error
dialect_lowering.lower_mgpu_dialect(module) # pytype: disable=attribute-error

expected_arg_treedef = jax.tree.structure(in_shape)
def _check_args(*args):
arg_treedef = jax.tree.structure(args)
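For illustration, a minimal sketch of how a caller might opt into the new path; `dialect_body` here is a stand-in for a kernel body written against the dialect (emitting ops such as `vector.load` and `arith.addf`), and the complete runnable version is the `gpu_test.py` diff below.

import jax
import jax.numpy as jnp
import jax.experimental.mosaic.gpu as mgpu

io_shape = jax.ShapeDtypeStruct((128, 128), jnp.bfloat16)
kernel = mgpu.as_gpu_kernel(
    dialect_body,  # assumed: emits dialect-level ops rather than FragmentedArray code
    grid=(1, 1, 1),
    block=(128, 1, 1),
    in_shape=(io_shape, io_shape),
    out_shape=io_shape,
    smem_scratch_shape=[],
    # With Warpgroup semantics, as_gpu_kernel runs infer_layout and
    # lower_mgpu_dialect before handing the module to the C++ passes.
    thread_semantics=mgpu.ThreadSemantics.Warpgroup,
)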
13 changes: 13 additions & 0 deletions jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -237,6 +237,19 @@ def _vector_store_op_lowering_rule(
return []


@_register_lowering(arith.AddFOp)
def _arith_addf_op_lowering_rule(add: arith.AddFOp) -> Sequence[ir.Value]:

fragmented_array_lhs = _fragmented_array_from_ir(add.lhs)
fragmented_array_rhs = _fragmented_array_from_ir(add.rhs)

return [
_fragmented_array_to_ir(
fragmented_array_lhs + fragmented_array_rhs, add.result.type
)
]


def lower_mgpu_dialect(module: ir.Module):
module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
module.context.load_all_available_dialects()
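The same registration pattern extends naturally to other elementwise ops. As a purely hypothetical sketch (not part of this commit), a rule for `arith.MulFOp` would continue `dialect_lowering.py` like this, assuming `FragmentedArray` overloads `*` the way it overloads `+` above:

@_register_lowering(arith.MulFOp)
def _arith_mulf_op_lowering_rule(mul: arith.MulFOp) -> Sequence[ir.Value]:
  # Reinterpret the MLIR operands as FragmentedArrays, combine them with the
  # overloaded operator, and convert the result back to dialect IR.
  lhs = _fragmented_array_from_ir(mul.lhs)
  rhs = _fragmented_array_from_ir(mul.rhs)
  return [_fragmented_array_to_ir(lhs * rhs, mul.result.type)]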
1 change: 1 addition & 0 deletions jaxlib/mosaic/gpu/BUILD
@@ -119,6 +119,7 @@ cc_library(
":passes",
":target",
"//jaxlib/cuda:cuda_vendor",
"//jaxlib/mosaic/dialect/gpu:mosaic_gpu",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/container:flat_hash_map",
4 changes: 3 additions & 1 deletion jaxlib/mosaic/gpu/custom_call.cc
@@ -83,6 +83,7 @@ limitations under the License.
#include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
#include "mlir/include/mlir/Transforms/Passes.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h"
#include "jaxlib/mosaic/gpu/launch_lowering.h"
#include "jaxlib/mosaic/gpu/passes.h"
#include "jaxlib/mosaic/gpu/target.h"
@@ -206,7 +207,8 @@ void InitContext(mlir::MLIRContext* context) {
mlir::math::MathDialect, mlir::memref::MemRefDialect,
mlir::scf::SCFDialect, mlir::vector::VectorDialect,
mlir::gpu::GPUDialect, mlir::nvgpu::NVGPUDialect,
mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect,
mosaic_gpu::MosaicGPUDialect>();
mlir::registerConvertNVVMToLLVMInterface(registry);
mlir::registerConvertComplexToLLVMInterface(registry);
mlir::registerConvertMemRefToLLVMInterface(registry);
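On the Python side, the dialect likewise has to be registered on the MLIR context that builds the kernel module. A small sketch of that step, mirroring what the test below does in setUp (assuming `mlir` here is `jax._src.interpreters.mlir`, as in the test module):

from jax._src.interpreters import mlir
from jax.experimental.mosaic.gpu import dialect as mgpu_dialect

context = mlir.make_ir_context()
if mgpu_dialect is not None:  # the dialect may be absent in older jaxlib builds
  mgpu_dialect.register_dialect(context)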
52 changes: 51 additions & 1 deletion tests/mosaic/gpu_test.py
@@ -31,6 +31,7 @@
from jax._src.lib.mlir.dialects import arith
from jax._src.lib.mlir.dialects import scf
from jax._src.lib.mlir.dialects import vector
from jax.experimental.mosaic.gpu import dialect as mgpu_dialect # pylint: disable=g-importing-member
from jax.experimental.mosaic.gpu import fragmented_array as fa
import jax.numpy as jnp
import numpy as np
@@ -165,8 +166,11 @@ def setUp(self):
self.skipTest("Only works on GPU with capability >= sm90")
super().setUp()
self.prng = np.random.default_rng(1234)
self.context = mlir.make_ir_context()
if mgpu_dialect is not None:
mgpu_dialect.register_dialect(self.context)
self.enter_context(jtu.global_config_context(jax_traceback_filtering="off"))
self.enter_context(mlir.make_ir_context())
self.enter_context(self.context)
self.enter_context(ir.Location.unknown())


@@ -1854,5 +1858,51 @@ def get_reg(addr):
self.assertLessEqual(len(used_regs), expected_regs)


class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
"""Device tests with lowering from the MLIR dialect and layout inference."""

def setUp(self):
if mgpu_dialect is None:
raise self.skipTest("Test requires Mosaic GPU dialect")
super().setUp()

def test_pointwise_kernel(self):
def add(ctx, a, b, result, smem):
del ctx, smem
shape = ir.MemRefType(a.type).shape
elt_type = ir.MemRefType(a.type).element_type

zero_index = arith.constant(ir.IndexType.get(), 0)

# GMEM -> registers
ab_type = ir.VectorType.get(shape, elt_type)
a = vector.load(ab_type, a, [zero_index, zero_index])
b = vector.load(ab_type, b, [zero_index, zero_index])

# Computation
add = arith.addf(a, b)

# Registers -> GMEM
vector.store(add, result, [zero_index, zero_index])

dtype = jnp.bfloat16
shape = (128, 128)
jax_shape = jax.ShapeDtypeStruct(shape, dtype)
kernel = mgpu.as_gpu_kernel(
add,
grid=(1, 1, 1),
block=(128, 1, 1),
in_shape=(jax_shape, jax_shape),
out_shape=jax_shape,
smem_scratch_shape=[],
thread_semantics=mgpu.ThreadSemantics.Warpgroup,
)

x = self.prng.uniform(-1, 1, shape).astype(dtype)
y = self.prng.uniform(-1, 1, shape).astype(dtype)

self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
