diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py
index 79f23ab2d7d8..9c2617c04971 100644
--- a/jax/experimental/mosaic/gpu/__init__.py
+++ b/jax/experimental/mosaic/gpu/__init__.py
@@ -22,6 +22,7 @@
     LaunchContext as LaunchContext,
     MemRefTransform as MemRefTransform,
     TMABarrier as TMABarrier,
+    ThreadSemantics as ThreadSemantics,
     TileTransform as TileTransform,
     TransposeTransform as TransposeTransform,
     Union as Union,
diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py
index b03c3a5b54fc..b802eb436672 100644
--- a/jax/experimental/mosaic/gpu/core.py
+++ b/jax/experimental/mosaic/gpu/core.py
@@ -17,6 +17,7 @@
 import contextlib
 import ctypes
 import dataclasses
+import enum
 import functools
 import hashlib
 import math
@@ -38,6 +39,15 @@
 from jaxlib.mlir.dialects import nvvm
 import numpy as np
 
+from jax._src.lib import mosaic_gpu_dialect as dialect  # noqa: F401
+
+if dialect is not None:
+  from . import dialect_lowering
+  from . import layout_inference
+else:
+  dialect_lowering = None
+  layout_inference = None
+
 from . import profiler
 from . import utils
 
@@ -942,6 +952,13 @@ def _declare_runtime_functions():
   )
 
 
+class ThreadSemantics(enum.Enum):
+  """Semantics for the kernel's instruction stream."""
+
+  Lane = enum.auto()
+  Warpgroup = enum.auto()
+
+
 def as_gpu_kernel(
     body,
     grid: tuple[int, int, int],
@@ -953,6 +970,7 @@ def as_gpu_kernel(
     cluster: tuple[int, int, int] = (1, 1, 1),
     module_name: str = "unknown",
     kernel_name: str | None = None,
+    thread_semantics: ThreadSemantics = ThreadSemantics.Lane,
 ):
   if isinstance(in_shape, list):
     in_shape = tuple(in_shape)
@@ -966,6 +984,12 @@ def as_gpu_kernel(
       )
   )
 
+  if thread_semantics == ThreadSemantics.Warpgroup and dialect is not None:
+    # Run Python lowering passes. The remaining passes will be run in C++ in
+    # jax/jaxlib/mosaic/gpu/custom_call.cc
+    layout_inference.infer_layout(module)  # pytype: disable=attribute-error
+    dialect_lowering.lower_mgpu_dialect(module)  # pytype: disable=attribute-error
+
   expected_arg_treedef = jax.tree.structure(in_shape)
   def _check_args(*args):
     arg_treedef = jax.tree.structure(args)
diff --git a/jax/experimental/mosaic/gpu/dialect_lowering.py b/jax/experimental/mosaic/gpu/dialect_lowering.py
index 99eb85079553..56d5aac39c86 100644
--- a/jax/experimental/mosaic/gpu/dialect_lowering.py
+++ b/jax/experimental/mosaic/gpu/dialect_lowering.py
@@ -237,6 +237,19 @@ def _vector_store_op_lowering_rule(
   return []
 
 
+@_register_lowering(arith.AddFOp)
+def _arith_addf_op_lowering_rule(add: arith.AddFOp) -> Sequence[ir.Value]:
+
+  fragmented_array_lhs = _fragmented_array_from_ir(add.lhs)
+  fragmented_array_rhs = _fragmented_array_from_ir(add.rhs)
+
+  return [
+      _fragmented_array_to_ir(
+          fragmented_array_lhs + fragmented_array_rhs, add.result.type
+      )
+  ]
+
+
 def lower_mgpu_dialect(module: ir.Module):
   module.context.append_dialect_registry(mlir_interpreter.upstream_dialects)
   module.context.load_all_available_dialects()
diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD
index 2139a266622a..f626b40d6194 100644
--- a/jaxlib/mosaic/gpu/BUILD
+++ b/jaxlib/mosaic/gpu/BUILD
@@ -119,6 +119,7 @@ cc_library(
         ":passes",
         ":target",
         "//jaxlib/cuda:cuda_vendor",
+        "//jaxlib/mosaic/dialect/gpu:mosaic_gpu",
         "@com_google_absl//absl/base:core_headers",
         "@com_google_absl//absl/cleanup",
         "@com_google_absl//absl/container:flat_hash_map",
diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc
index 54792b3097f7..383f68ddd3d1 100644
--- a/jaxlib/mosaic/gpu/custom_call.cc
+++ b/jaxlib/mosaic/gpu/custom_call.cc
@@ -83,6 +83,7 @@ limitations under the License.
 #include "mlir/include/mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
 #include "mlir/include/mlir/Transforms/Passes.h"
 #include "jaxlib/gpu/vendor.h"
+#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h"
 #include "jaxlib/mosaic/gpu/launch_lowering.h"
 #include "jaxlib/mosaic/gpu/passes.h"
 #include "jaxlib/mosaic/gpu/target.h"
@@ -206,7 +207,8 @@ void InitContext(mlir::MLIRContext* context) {
                   mlir::math::MathDialect, mlir::memref::MemRefDialect,
                   mlir::scf::SCFDialect, mlir::vector::VectorDialect,
                   mlir::gpu::GPUDialect, mlir::nvgpu::NVGPUDialect,
-                  mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect>();
+                  mlir::NVVM::NVVMDialect, mlir::LLVM::LLVMDialect,
+                  mosaic_gpu::MosaicGPUDialect>();
   mlir::registerConvertNVVMToLLVMInterface(registry);
   mlir::registerConvertComplexToLLVMInterface(registry);
   mlir::registerConvertMemRefToLLVMInterface(registry);
diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py
index 466188b6bdcd..8ffeee728f61 100644
--- a/tests/mosaic/gpu_test.py
+++ b/tests/mosaic/gpu_test.py
@@ -31,6 +31,7 @@
 from jax._src.lib.mlir.dialects import arith
 from jax._src.lib.mlir.dialects import scf
 from jax._src.lib.mlir.dialects import vector
+from jax.experimental.mosaic.gpu import dialect as mgpu_dialect  # pylint: disable=g-importing-member
 from jax.experimental.mosaic.gpu import fragmented_array as fa
 import jax.numpy as jnp
 import numpy as np
@@ -165,8 +166,11 @@ def setUp(self):
       self.skipTest("Only works on GPU with capability >= sm90")
     super().setUp()
     self.prng = np.random.default_rng(1234)
+    self.context = mlir.make_ir_context()
+    if mgpu_dialect is not None:
+      mgpu_dialect.register_dialect(self.context)
     self.enter_context(jtu.global_config_context(jax_traceback_filtering="off"))
-    self.enter_context(mlir.make_ir_context())
+    self.enter_context(self.context)
     self.enter_context(ir.Location.unknown())
 
 
@@ -1854,5 +1858,51 @@ def get_reg(addr):
     self.assertLessEqual(len(used_regs), expected_regs)
 
 
+class MosaicGpuDialectTest(TestCase, jtu.JaxTestCase):
+  """Device tests with lowering from the MLIR dialect and layout inference."""
+
+  def setUp(self):
+    if mgpu_dialect is None:
+      raise self.skipTest("Test requires Mosaic GPU dialect")
+    super().setUp()
+
+  def test_pointwise_kernel(self):
+    def add(ctx, a, b, result, smem):
+      del ctx, smem
+      shape = ir.MemRefType(a.type).shape
+      elt_type = ir.MemRefType(a.type).element_type
+
+      zero_index = arith.constant(ir.IndexType.get(), 0)
+
+      # GMEM -> registers
+      ab_type = ir.VectorType.get(shape, elt_type)
+      a = vector.load(ab_type, a, [zero_index, zero_index])
+      b = vector.load(ab_type, b, [zero_index, zero_index])
+
+      # Computation
+      add = arith.addf(a, b)
+
+      # Registers -> GMEM
+      vector.store(add, result, [zero_index, zero_index])
+
+    dtype = jnp.bfloat16
+    shape = (128, 128)
+    jax_shape = jax.ShapeDtypeStruct(shape, dtype)
+    kernel = mgpu.as_gpu_kernel(
+        add,
+        grid=(1, 1, 1),
+        block=(128, 1, 1),
+        in_shape=(jax_shape, jax_shape),
+        out_shape=jax_shape,
+        smem_scratch_shape=[],
+        thread_semantics=mgpu.ThreadSemantics.Warpgroup,
+    )
+
+    x = self.prng.uniform(-1, 1, shape).astype(dtype)
+    y = self.prng.uniform(-1, 1, shape).astype(dtype)
+
+    self.assertArraysEqual(jax.jit(kernel)(x, y), x + y)
+
+
 if __name__ == "__main__":
   absltest.main(testLoader=jtu.JaxTestLoader())